diff --git a/src/cachetools/__init__.py b/src/cachetools/__init__.py index ca67c588..e9a61f78 100644 --- a/src/cachetools/__init__.py +++ b/src/cachetools/__init__.py @@ -16,7 +16,7 @@ __all__ = ( 'cached', 'cachedmethod' ) -__version__ = '2.1.0' +__version__ = '3.1.0' if hasattr(functools.update_wrapper(lambda f: f(), lambda: 42), '__wrapped__'): _update_wrapper = functools.update_wrapper @@ -79,7 +79,7 @@ def cachedmethod(cache, key=keys.hashkey, lock=None): c = cache(self) if c is None: return method(self, *args, **kwargs) - k = key(self, *args, **kwargs) + k = key(*args, **kwargs) try: return c[k] except KeyError: @@ -95,7 +95,7 @@ def cachedmethod(cache, key=keys.hashkey, lock=None): c = cache(self) if c is None: return method(self, *args, **kwargs) - k = key(self, *args, **kwargs) + k = key(*args, **kwargs) try: with lock(self): return c[k] diff --git a/src/cachetools/abc.py b/src/cachetools/abc.py index b663d96e..3bc43cc4 100644 --- a/src/cachetools/abc.py +++ b/src/cachetools/abc.py @@ -1,10 +1,14 @@ from __future__ import absolute_import -import collections from abc import abstractmethod +try: + from collections.abc import MutableMapping +except ImportError: + from collections import MutableMapping -class DefaultMapping(collections.MutableMapping): + +class DefaultMapping(MutableMapping): __slots__ = () diff --git a/src/cachetools/cache.py b/src/cachetools/cache.py index a9a3e571..5cb80715 100644 --- a/src/cachetools/cache.py +++ b/src/cachetools/cache.py @@ -1,7 +1,5 @@ from __future__ import absolute_import -from warnings import warn - from .abc import DefaultMapping @@ -16,20 +14,12 @@ class _DefaultSize(object): return 1 -_deprecated = object() - - class Cache(DefaultMapping): """Mutable mapping to serve as a simple cache or cache base class.""" __size = _DefaultSize() - def __init__(self, maxsize, missing=_deprecated, getsizeof=None): - if missing is not _deprecated: - warn("Cache constructor parameter 'missing' is deprecated", - DeprecationWarning, 3) - if missing: - self.__missing = missing + def __init__(self, maxsize, getsizeof=None): if getsizeof: self.getsizeof = getsizeof if self.getsizeof is not Cache.getsizeof: @@ -77,12 +67,7 @@ class Cache(DefaultMapping): return key in self.__data def __missing__(self, key): - value = self.__missing(key) - try: - self.__setitem__(key, value) - except ValueError: - pass # value too large - return value + raise KeyError(key) def __iter__(self): return iter(self.__data) @@ -104,7 +89,3 @@ class Cache(DefaultMapping): def getsizeof(value): """Return the size of a cache element's value.""" return 1 - - @staticmethod - def __missing(key): - raise KeyError(key) diff --git a/src/cachetools/func.py b/src/cachetools/func.py index 5a2ce847..8ced5dda 100644 --- a/src/cachetools/func.py +++ b/src/cachetools/func.py @@ -5,11 +5,15 @@ from __future__ import absolute_import import collections import functools import random -import time + +try: + from time import monotonic as default_timer +except ImportError: + from time import time as default_timer try: from threading import RLock -except ImportError: +except ImportError: # pragma: no cover from dummy_threading import RLock from . import keys @@ -26,6 +30,24 @@ _CacheInfo = collections.namedtuple('CacheInfo', [ ]) +class _UnboundCache(dict): + + maxsize = None + + @property + def currsize(self): + return len(self) + + +class _UnboundTTLCache(TTLCache): + def __init__(self, ttl, timer): + TTLCache.__init__(self, float('inf'), ttl, timer) + + @property + def maxsize(self): + return None + + def _cache(cache, typed=False): def decorator(func): key = keys.typedkey if typed else keys.hashkey @@ -77,7 +99,10 @@ def lfu_cache(maxsize=128, typed=False): algorithm. """ - return _cache(LFUCache(maxsize), typed) + if maxsize is None: + return _cache(_UnboundCache(), typed) + else: + return _cache(LFUCache(maxsize), typed) def lru_cache(maxsize=128, typed=False): @@ -86,7 +111,10 @@ def lru_cache(maxsize=128, typed=False): algorithm. """ - return _cache(LRUCache(maxsize), typed) + if maxsize is None: + return _cache(_UnboundCache(), typed) + else: + return _cache(LRUCache(maxsize), typed) def rr_cache(maxsize=128, choice=random.choice, typed=False): @@ -95,12 +123,18 @@ def rr_cache(maxsize=128, choice=random.choice, typed=False): algorithm. """ - return _cache(RRCache(maxsize, choice), typed) + if maxsize is None: + return _cache(_UnboundCache(), typed) + else: + return _cache(RRCache(maxsize, choice), typed) -def ttl_cache(maxsize=128, ttl=600, timer=time.time, typed=False): +def ttl_cache(maxsize=128, ttl=600, timer=default_timer, typed=False): """Decorator to wrap a function with a memoizing callable that saves up to `maxsize` results based on a Least Recently Used (LRU) algorithm with a per-item time-to-live (TTL) value. """ - return _cache(TTLCache(maxsize, ttl, timer), typed) + if maxsize is None: + return _cache(_UnboundTTLCache(ttl, timer), typed) + else: + return _cache(TTLCache(maxsize, ttl, timer), typed) diff --git a/src/cachetools/lfu.py b/src/cachetools/lfu.py index 76a264a6..4857c4e9 100644 --- a/src/cachetools/lfu.py +++ b/src/cachetools/lfu.py @@ -2,14 +2,14 @@ from __future__ import absolute_import import collections -from .cache import Cache, _deprecated +from .cache import Cache class LFUCache(Cache): """Least Frequently Used (LFU) cache implementation.""" - def __init__(self, maxsize, missing=_deprecated, getsizeof=None): - Cache.__init__(self, maxsize, missing, getsizeof) + def __init__(self, maxsize, getsizeof=None): + Cache.__init__(self, maxsize, getsizeof) self.__counter = collections.Counter() def __getitem__(self, key, cache_getitem=Cache.__getitem__): diff --git a/src/cachetools/lru.py b/src/cachetools/lru.py index 991b0e23..44ec4f1c 100644 --- a/src/cachetools/lru.py +++ b/src/cachetools/lru.py @@ -2,14 +2,14 @@ from __future__ import absolute_import import collections -from .cache import Cache, _deprecated +from .cache import Cache class LRUCache(Cache): """Least Recently Used (LRU) cache implementation.""" - def __init__(self, maxsize, missing=_deprecated, getsizeof=None): - Cache.__init__(self, maxsize, missing, getsizeof) + def __init__(self, maxsize, getsizeof=None): + Cache.__init__(self, maxsize, getsizeof) self.__order = collections.OrderedDict() def __getitem__(self, key, cache_getitem=Cache.__getitem__): diff --git a/src/cachetools/rr.py b/src/cachetools/rr.py index 1aeed438..09ff7708 100644 --- a/src/cachetools/rr.py +++ b/src/cachetools/rr.py @@ -2,7 +2,7 @@ from __future__ import absolute_import import random -from .cache import Cache, _deprecated +from .cache import Cache # random.choice cannot be pickled in Python 2.7 @@ -13,9 +13,8 @@ def _choice(seq): class RRCache(Cache): """Random Replacement (RR) cache implementation.""" - def __init__(self, maxsize, choice=random.choice, missing=_deprecated, - getsizeof=None): - Cache.__init__(self, maxsize, missing, getsizeof) + def __init__(self, maxsize, choice=random.choice, getsizeof=None): + Cache.__init__(self, maxsize, getsizeof) # TODO: use None as default, assing to self.choice directly? if choice is random.choice: self.__choice = _choice diff --git a/src/cachetools/ttl.py b/src/cachetools/ttl.py index d4c3b37b..1edde3ab 100644 --- a/src/cachetools/ttl.py +++ b/src/cachetools/ttl.py @@ -1,9 +1,13 @@ from __future__ import absolute_import import collections -import time -from .cache import Cache, _deprecated +try: + from time import monotonic as default_timer +except ImportError: + from time import time as default_timer + +from .cache import Cache class _Link(object): @@ -57,9 +61,8 @@ class _Timer(object): class TTLCache(Cache): """LRU Cache implementation with per-item time-to-live (TTL) value.""" - def __init__(self, maxsize, ttl, timer=time.time, missing=_deprecated, - getsizeof=None): - Cache.__init__(self, maxsize, missing, getsizeof) + def __init__(self, maxsize, ttl, timer=default_timer, getsizeof=None): + Cache.__init__(self, maxsize, getsizeof) self.__root = root = _Link() root.prev = root.next = root self.__links = collections.OrderedDict() diff --git a/src/dns/__init__.py b/src/dns/__init__.py index c848e485..c1ce8e60 100644 --- a/src/dns/__init__.py +++ b/src/dns/__init__.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009, 2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/_compat.py b/src/dns/_compat.py index 956f9a13..ca0931c2 100644 --- a/src/dns/_compat.py +++ b/src/dns/_compat.py @@ -2,7 +2,11 @@ import sys import decimal from decimal import Context -if sys.version_info > (3,): +PY3 = sys.version_info[0] == 3 +PY2 = sys.version_info[0] == 2 + + +if PY3: long = int xrange = range else: @@ -10,7 +14,7 @@ else: xrange = xrange # pylint: disable=xrange-builtin # unicode / binary types -if sys.version_info > (3,): +if PY3: text_type = str binary_type = bytes string_types = (str,) @@ -19,6 +23,10 @@ if sys.version_info > (3,): return x.decode() def maybe_encode(x): return x.encode() + def maybe_chr(x): + return x + def maybe_ord(x): + return x else: text_type = unicode # pylint: disable=unicode-builtin, undefined-variable binary_type = str @@ -30,6 +38,10 @@ else: return x def maybe_encode(x): return x + def maybe_chr(x): + return chr(x) + def maybe_ord(x): + return ord(x) def round_py2_compat(what): diff --git a/src/dns/dnssec.py b/src/dns/dnssec.py index fec12082..35da6b5a 100644 --- a/src/dns/dnssec.py +++ b/src/dns/dnssec.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009, 2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -20,7 +22,6 @@ import struct import time import dns.exception -import dns.hash import dns.name import dns.node import dns.rdataset @@ -31,27 +32,40 @@ from ._compat import string_types class UnsupportedAlgorithm(dns.exception.DNSException): - """The DNSSEC algorithm is not supported.""" class ValidationFailure(dns.exception.DNSException): - """The DNSSEC signature is invalid.""" + +#: RSAMD5 RSAMD5 = 1 +#: DH DH = 2 +#: DSA DSA = 3 +#: ECC ECC = 4 +#: RSASHA1 RSASHA1 = 5 +#: DSANSEC3SHA1 DSANSEC3SHA1 = 6 +#: RSASHA1NSEC3SHA1 RSASHA1NSEC3SHA1 = 7 +#: RSASHA256 RSASHA256 = 8 +#: RSASHA512 RSASHA512 = 10 +#: ECDSAP256SHA256 ECDSAP256SHA256 = 13 +#: ECDSAP384SHA384 ECDSAP384SHA384 = 14 +#: INDIRECT INDIRECT = 252 +#: PRIVATEDNS PRIVATEDNS = 253 +#: PRIVATEOID PRIVATEOID = 254 _algorithm_by_text = { @@ -75,12 +89,14 @@ _algorithm_by_text = { # cannot make any mistakes (e.g. omissions, cut-and-paste errors) that # would cause the mapping not to be true inverse. -_algorithm_by_value = dict((y, x) for x, y in _algorithm_by_text.items()) +_algorithm_by_value = {y: x for x, y in _algorithm_by_text.items()} def algorithm_from_text(text): - """Convert text into a DNSSEC algorithm value - @rtype: int""" + """Convert text into a DNSSEC algorithm value. + + Returns an ``int``. + """ value = _algorithm_by_text.get(text.upper()) if value is None: @@ -90,7 +106,9 @@ def algorithm_from_text(text): def algorithm_to_text(value): """Convert a DNSSEC algorithm value to text - @rtype: string""" + + Returns a ``str``. + """ text = _algorithm_by_value.get(value) if text is None: @@ -105,6 +123,14 @@ def _to_rdata(record, origin): def key_id(key, origin=None): + """Return the key id (a 16-bit number) for the specified key. + + Note the *origin* parameter of this function is historical and + is not needed. + + Returns an ``int`` between 0 and 65535. + """ + rdata = _to_rdata(key, origin) rdata = bytearray(rdata) if key.algorithm == RSAMD5: @@ -121,12 +147,28 @@ def key_id(key, origin=None): def make_ds(name, key, algorithm, origin=None): + """Create a DS record for a DNSSEC key. + + *name* is the owner name of the DS record. + + *key* is a ``dns.rdtypes.ANY.DNSKEY``. + + *algorithm* is a string describing which hash algorithm to use. The + currently supported hashes are "SHA1" and "SHA256". Case does not + matter for these strings. + + *origin* is a ``dns.name.Name`` and will be used as the origin + if *key* is a relative name. + + Returns a ``dns.rdtypes.ANY.DS``. + """ + if algorithm.upper() == 'SHA1': dsalg = 1 - hash = dns.hash.hashes['SHA1']() + hash = SHA1.new() elif algorithm.upper() == 'SHA256': dsalg = 2 - hash = dns.hash.hashes['SHA256']() + hash = SHA256.new() else: raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) @@ -198,15 +240,15 @@ def _is_sha512(algorithm): def _make_hash(algorithm): if _is_md5(algorithm): - return dns.hash.hashes['MD5']() + return MD5.new() if _is_sha1(algorithm): - return dns.hash.hashes['SHA1']() + return SHA1.new() if _is_sha256(algorithm): - return dns.hash.hashes['SHA256']() + return SHA256.new() if _is_sha384(algorithm): - return dns.hash.hashes['SHA384']() + return SHA384.new() if _is_sha512(algorithm): - return dns.hash.hashes['SHA512']() + return SHA512.new() raise ValidationFailure('unknown hash for algorithm %u' % algorithm) @@ -232,31 +274,32 @@ def _make_algorithm_id(algorithm): def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): """Validate an RRset against a single signature rdata - The owner name of the rrsig is assumed to be the same as the owner name - of the rrset. + The owner name of *rrsig* is assumed to be the same as the owner name + of *rrset*. - @param rrset: The RRset to validate - @type rrset: dns.rrset.RRset or (dns.name.Name, dns.rdataset.Rdataset) - tuple - @param rrsig: The signature rdata - @type rrsig: dns.rrset.Rdata - @param keys: The key dictionary. - @type keys: a dictionary keyed by dns.name.Name with node or rdataset - values - @param origin: The origin to use for relative names - @type origin: dns.name.Name or None - @param now: The time to use when validating the signatures. The default - is the current time. - @type now: int + *rrset* is the RRset to validate. It can be a ``dns.rrset.RRset`` or + a ``(dns.name.Name, dns.rdataset.Rdataset)`` tuple. + + *rrsig* is a ``dns.rdata.Rdata``, the signature to validate. + + *keys* is the key dictionary, used to find the DNSKEY associated with + a given name. The dictionary is keyed by a ``dns.name.Name``, and has + ``dns.node.Node`` or ``dns.rdataset.Rdataset`` values. + + *origin* is a ``dns.name.Name``, the origin to use for relative names. + + *now* is an ``int``, the time to use when validating the signatures, + in seconds since the UNIX epoch. The default is the current time. """ if isinstance(origin, string_types): origin = dns.name.from_text(origin, dns.name.root) - for candidate_key in _find_candidate_keys(keys, rrsig): - if not candidate_key: - raise ValidationFailure('unknown key') + candidate_keys = _find_candidate_keys(keys, rrsig) + if candidate_keys is None: + raise ValidationFailure('unknown key') + for candidate_key in candidate_keys: # For convenience, allow the rrset to be specified as a (name, # rdataset) tuple as well as a proper rrset if isinstance(rrset, tuple): @@ -284,11 +327,13 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): keyptr = keyptr[2:] rsa_e = keyptr[0:bytes_] rsa_n = keyptr[bytes_:] - keylen = len(rsa_n) * 8 - pubkey = Crypto.PublicKey.RSA.construct( - (Crypto.Util.number.bytes_to_long(rsa_n), - Crypto.Util.number.bytes_to_long(rsa_e))) - sig = (Crypto.Util.number.bytes_to_long(rrsig.signature),) + try: + pubkey = CryptoRSA.construct( + (number.bytes_to_long(rsa_n), + number.bytes_to_long(rsa_e))) + except ValueError: + raise ValidationFailure('invalid public key') + sig = rrsig.signature elif _is_dsa(rrsig.algorithm): keyptr = candidate_key.key (t,) = struct.unpack('!B', keyptr[0:1]) @@ -301,36 +346,37 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): dsa_g = keyptr[0:octets] keyptr = keyptr[octets:] dsa_y = keyptr[0:octets] - pubkey = Crypto.PublicKey.DSA.construct( - (Crypto.Util.number.bytes_to_long(dsa_y), - Crypto.Util.number.bytes_to_long(dsa_g), - Crypto.Util.number.bytes_to_long(dsa_p), - Crypto.Util.number.bytes_to_long(dsa_q))) - (dsa_r, dsa_s) = struct.unpack('!20s20s', rrsig.signature[1:]) - sig = (Crypto.Util.number.bytes_to_long(dsa_r), - Crypto.Util.number.bytes_to_long(dsa_s)) + pubkey = CryptoDSA.construct( + (number.bytes_to_long(dsa_y), + number.bytes_to_long(dsa_g), + number.bytes_to_long(dsa_p), + number.bytes_to_long(dsa_q))) + sig = rrsig.signature[1:] elif _is_ecdsa(rrsig.algorithm): + # use ecdsa for NIST-384p -- not currently supported by pycryptodome + + keyptr = candidate_key.key + if rrsig.algorithm == ECDSAP256SHA256: curve = ecdsa.curves.NIST256p key_len = 32 elif rrsig.algorithm == ECDSAP384SHA384: curve = ecdsa.curves.NIST384p key_len = 48 - else: - # shouldn't happen - raise ValidationFailure('unknown ECDSA curve') - keyptr = candidate_key.key - x = Crypto.Util.number.bytes_to_long(keyptr[0:key_len]) - y = Crypto.Util.number.bytes_to_long(keyptr[key_len:key_len * 2]) - assert ecdsa.ecdsa.point_is_valid(curve.generator, x, y) + + x = number.bytes_to_long(keyptr[0:key_len]) + y = number.bytes_to_long(keyptr[key_len:key_len * 2]) + if not ecdsa.ecdsa.point_is_valid(curve.generator, x, y): + raise ValidationFailure('invalid ECDSA key') point = ecdsa.ellipticcurve.Point(curve.curve, x, y, curve.order) verifying_key = ecdsa.keys.VerifyingKey.from_public_point(point, curve) pubkey = ECKeyWrapper(verifying_key, key_len) r = rrsig.signature[:key_len] s = rrsig.signature[key_len:] - sig = ecdsa.ecdsa.Signature(Crypto.Util.number.bytes_to_long(r), - Crypto.Util.number.bytes_to_long(s)) + sig = ecdsa.ecdsa.Signature(number.bytes_to_long(r), + number.bytes_to_long(s)) + else: raise ValidationFailure('unknown algorithm %u' % rrsig.algorithm) @@ -352,44 +398,49 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): hash.update(rrlen) hash.update(rrdata) - digest = hash.digest() - - if _is_rsa(rrsig.algorithm): - # PKCS1 algorithm identifier goop - digest = _make_algorithm_id(rrsig.algorithm) + digest - padlen = keylen // 8 - len(digest) - 3 - digest = struct.pack('!%dB' % (2 + padlen + 1), - *([0, 1] + [0xFF] * padlen + [0])) + digest - elif _is_dsa(rrsig.algorithm) or _is_ecdsa(rrsig.algorithm): - pass - else: - # Raise here for code clarity; this won't actually ever happen - # since if the algorithm is really unknown we'd already have - # raised an exception above - raise ValidationFailure('unknown algorithm %u' % rrsig.algorithm) - - if pubkey.verify(digest, sig): + try: + if _is_rsa(rrsig.algorithm): + verifier = pkcs1_15.new(pubkey) + # will raise ValueError if verify fails: + verifier.verify(hash, sig) + elif _is_dsa(rrsig.algorithm): + verifier = DSS.new(pubkey, 'fips-186-3') + verifier.verify(hash, sig) + elif _is_ecdsa(rrsig.algorithm): + digest = hash.digest() + if not pubkey.verify(digest, sig): + raise ValueError + else: + # Raise here for code clarity; this won't actually ever happen + # since if the algorithm is really unknown we'd already have + # raised an exception above + raise ValidationFailure('unknown algorithm %u' % rrsig.algorithm) + # If we got here, we successfully verified so we can return without error return + except ValueError: + # this happens on an individual validation failure + continue + # nothing verified -- raise failure: raise ValidationFailure('verify failure') def _validate(rrset, rrsigset, keys, origin=None, now=None): - """Validate an RRset + """Validate an RRset. - @param rrset: The RRset to validate - @type rrset: dns.rrset.RRset or (dns.name.Name, dns.rdataset.Rdataset) - tuple - @param rrsigset: The signature RRset - @type rrsigset: dns.rrset.RRset or (dns.name.Name, dns.rdataset.Rdataset) - tuple - @param keys: The key dictionary. - @type keys: a dictionary keyed by dns.name.Name with node or rdataset - values - @param origin: The origin to use for relative names - @type origin: dns.name.Name or None - @param now: The time to use when validating the signatures. The default - is the current time. - @type now: int + *rrset* is the RRset to validate. It can be a ``dns.rrset.RRset`` or + a ``(dns.name.Name, dns.rdataset.Rdataset)`` tuple. + + *rrsigset* is the signature RRset to be validated. It can be a + ``dns.rrset.RRset`` or a ``(dns.name.Name, dns.rdataset.Rdataset)`` tuple. + + *keys* is the key dictionary, used to find the DNSKEY associated with + a given name. The dictionary is keyed by a ``dns.name.Name``, and has + ``dns.node.Node`` or ``dns.rdataset.Rdataset`` values. + + *origin* is a ``dns.name.Name``, the origin to use for relative names. + + *now* is an ``int``, the time to use when validating the signatures, + in seconds since the UNIX epoch. The default is the current time. """ if isinstance(origin, string_types): @@ -408,7 +459,7 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None): rrsigrdataset = rrsigset rrname = rrname.choose_relativity(origin) - rrsigname = rrname.choose_relativity(origin) + rrsigname = rrsigname.choose_relativity(origin) if rrname != rrsigname: raise ValidationFailure("owner names do not match") @@ -422,36 +473,47 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None): def _need_pycrypto(*args, **kwargs): - raise NotImplementedError("DNSSEC validation requires pycrypto") + raise NotImplementedError("DNSSEC validation requires pycryptodome/pycryptodomex") + try: - import Crypto.PublicKey.RSA - import Crypto.PublicKey.DSA - import Crypto.Util.number - validate = _validate - validate_rrsig = _validate_rrsig - _have_pycrypto = True + try: + # test we're using pycryptodome, not pycrypto (which misses SHA1 for example) + from Crypto.Hash import MD5, SHA1, SHA256, SHA384, SHA512 + from Crypto.PublicKey import RSA as CryptoRSA, DSA as CryptoDSA + from Crypto.Signature import pkcs1_15, DSS + from Crypto.Util import number + except ImportError: + from Cryptodome.Hash import MD5, SHA1, SHA256, SHA384, SHA512 + from Cryptodome.PublicKey import RSA as CryptoRSA, DSA as CryptoDSA + from Cryptodome.Signature import pkcs1_15, DSS + from Cryptodome.Util import number except ImportError: validate = _need_pycrypto validate_rrsig = _need_pycrypto _have_pycrypto = False - -try: - import ecdsa - import ecdsa.ecdsa - import ecdsa.ellipticcurve - import ecdsa.keys - _have_ecdsa = True - - class ECKeyWrapper(object): - - def __init__(self, key, key_len): - self.key = key - self.key_len = key_len - - def verify(self, digest, sig): - diglong = Crypto.Util.number.bytes_to_long(digest) - return self.key.pubkey.verifies(diglong, sig) - -except ImportError: _have_ecdsa = False +else: + validate = _validate + validate_rrsig = _validate_rrsig + _have_pycrypto = True + + try: + import ecdsa + import ecdsa.ecdsa + import ecdsa.ellipticcurve + import ecdsa.keys + except ImportError: + _have_ecdsa = False + else: + _have_ecdsa = True + + class ECKeyWrapper(object): + + def __init__(self, key, key_len): + self.key = key + self.key_len = key_len + + def verify(self, digest, sig): + diglong = number.bytes_to_long(digest) + return self.key.pubkey.verifies(diglong, sig) diff --git a/src/dns/dnssec.pyi b/src/dns/dnssec.pyi new file mode 100644 index 00000000..5699b3e1 --- /dev/null +++ b/src/dns/dnssec.pyi @@ -0,0 +1,19 @@ +from typing import Union, Dict, Tuple, Optional +from . import rdataset, rrset, exception, name, rdtypes, rdata, node +import dns.rdtypes.ANY.DS as DS +import dns.rdtypes.ANY.DNSKEY as DNSKEY + +_have_ecdsa : bool +_have_pycrypto : bool + +def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None: + ... + +def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None: + ... + +class ValidationFailure(exception.DNSException): + ... + +def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS: + ... diff --git a/src/dns/e164.py b/src/dns/e164.py index 99300730..758c47a7 100644 --- a/src/dns/e164.py +++ b/src/dns/e164.py @@ -1,4 +1,6 @@ -# Copyright (C) 2006, 2007, 2009, 2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,31 +15,32 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS E.164 helpers - -@var public_enum_domain: The DNS public ENUM domain, e164.arpa. -@type public_enum_domain: dns.name.Name object -""" - +"""DNS E.164 helpers.""" import dns.exception import dns.name import dns.resolver -from ._compat import string_types +from ._compat import string_types, maybe_decode +#: The public E.164 domain. public_enum_domain = dns.name.from_text('e164.arpa.') def from_e164(text, origin=public_enum_domain): """Convert an E.164 number in textual form into a Name object whose value is the ENUM domain name for that number. - @param text: an E.164 number in textual form. - @type text: str - @param origin: The domain in which the number should be constructed. - The default is e164.arpa. - @type origin: dns.name.Name object or None - @rtype: dns.name.Name object + + Non-digits in the text are ignored, i.e. "16505551212", + "+1.650.555.1212" and "1 (650) 555-1212" are all the same. + + *text*, a ``text``, is an E.164 number in textual form. + + *origin*, a ``dns.name.Name``, the domain in which the number + should be constructed. The default is ``e164.arpa.``. + + Returns a ``dns.name.Name``. """ + parts = [d for d in text if d.isdigit()] parts.reverse() return dns.name.from_text('.'.join(parts), origin=origin) @@ -45,14 +48,23 @@ def from_e164(text, origin=public_enum_domain): def to_e164(name, origin=public_enum_domain, want_plus_prefix=True): """Convert an ENUM domain name into an E.164 number. - @param name: the ENUM domain name. - @type name: dns.name.Name object. - @param origin: A domain containing the ENUM domain name. The - name is relativized to this domain before being converted to text. - @type origin: dns.name.Name object or None - @param want_plus_prefix: if True, add a '+' to the beginning of the - returned number. - @rtype: str + + Note that dnspython does not have any information about preferred + number formats within national numbering plans, so all numbers are + emitted as a simple string of digits, prefixed by a '+' (unless + *want_plus_prefix* is ``False``). + + *name* is a ``dns.name.Name``, the ENUM domain name. + + *origin* is a ``dns.name.Name``, a domain containing the ENUM + domain name. The name is relativized to this domain before being + converted to text. If ``None``, no relativization is done. + + *want_plus_prefix* is a ``bool``. If True, add a '+' to the beginning of + the returned number. + + Returns a ``text``. + """ if origin is not None: name = name.relativize(origin) @@ -63,14 +75,22 @@ def to_e164(name, origin=public_enum_domain, want_plus_prefix=True): text = b''.join(dlabels) if want_plus_prefix: text = b'+' + text - return text + return maybe_decode(text) def query(number, domains, resolver=None): """Look for NAPTR RRs for the specified number in the specified domains. e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.']) + + *number*, a ``text`` is the number to look for. + + *domains* is an iterable containing ``dns.name.Name`` values. + + *resolver*, a ``dns.resolver.Resolver``, is the resolver to use. If + ``None``, the default resolver is used. """ + if resolver is None: resolver = dns.resolver.get_default_resolver() e_nx = dns.resolver.NXDOMAIN() diff --git a/src/dns/e164.pyi b/src/dns/e164.pyi new file mode 100644 index 00000000..37a99fed --- /dev/null +++ b/src/dns/e164.pyi @@ -0,0 +1,10 @@ +from typing import Optional, Iterable +from . import name, resolver +def from_e164(text : str, origin=name.Name(".")) -> name.Name: + ... + +def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str: + ... + +def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer: + ... diff --git a/src/dns/edns.py b/src/dns/edns.py index 8ac676bc..5660f7bb 100644 --- a/src/dns/edns.py +++ b/src/dns/edns.py @@ -1,4 +1,6 @@ -# Copyright (C) 2009, 2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2009-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,18 +17,42 @@ """EDNS Options""" -NSID = 3 +from __future__ import absolute_import +import math +import struct + +import dns.inet + +#: NSID +NSID = 3 +#: DAU +DAU = 5 +#: DHU +DHU = 6 +#: N3U +N3U = 7 +#: ECS (client-subnet) +ECS = 8 +#: EXPIRE +EXPIRE = 9 +#: COOKIE +COOKIE = 10 +#: KEEPALIVE +KEEPALIVE = 11 +#: PADDING +PADDING = 12 +#: CHAIN +CHAIN = 13 class Option(object): - """Base class for all EDNS option types. - """ + """Base class for all EDNS option types.""" def __init__(self, otype): """Initialize an option. - @param otype: The rdata type - @type otype: int + + *otype*, an ``int``, is the option type. """ self.otype = otype @@ -37,23 +63,26 @@ class Option(object): @classmethod def from_wire(cls, otype, wire, current, olen): - """Build an EDNS option object from wire format + """Build an EDNS option object from wire format. + + *otype*, an ``int``, is the option type. + + *wire*, a ``binary``, is the wire-format message. + + *current*, an ``int``, is the offset in *wire* of the beginning + of the rdata. + + *olen*, an ``int``, is the length of the wire-format option data + + Returns a ``dns.edns.Option``. + """ - @param otype: The option type - @type otype: int - @param wire: The wire-format message - @type wire: string - @param current: The offset in wire of the beginning of the rdata. - @type current: int - @param olen: The length of the wire-format option data - @type olen: int - @rtype: dns.edns.Option instance""" raise NotImplementedError def _cmp(self, other): """Compare an EDNS option with another option of the same type. - Return < 0 if self < other, 0 if self == other, - and > 0 if self > other. + + Returns < 0 if < *other*, 0 if == *other*, and > 0 if > *other*. """ raise NotImplementedError @@ -98,7 +127,7 @@ class Option(object): class GenericOption(Option): - """Generate Rdata Class + """Generic Option Class This class is used for EDNS option types for which we have no better implementation. @@ -111,6 +140,9 @@ class GenericOption(Option): def to_wire(self, file): file.write(self.data) + def to_text(self): + return "Generic %d" % self.otype + @classmethod def from_wire(cls, otype, wire, current, olen): return cls(otype, wire[current: current + olen]) @@ -122,11 +154,96 @@ class GenericOption(Option): return 1 return -1 + +class ECSOption(Option): + """EDNS Client Subnet (ECS, RFC7871)""" + + def __init__(self, address, srclen=None, scopelen=0): + """*address*, a ``text``, is the client address information. + + *srclen*, an ``int``, the source prefix length, which is the + leftmost number of bits of the address to be used for the + lookup. The default is 24 for IPv4 and 56 for IPv6. + + *scopelen*, an ``int``, the scope prefix length. This value + must be 0 in queries, and should be set in responses. + """ + + super(ECSOption, self).__init__(ECS) + af = dns.inet.af_for_address(address) + + if af == dns.inet.AF_INET6: + self.family = 2 + if srclen is None: + srclen = 56 + elif af == dns.inet.AF_INET: + self.family = 1 + if srclen is None: + srclen = 24 + else: + raise ValueError('Bad ip family') + + self.address = address + self.srclen = srclen + self.scopelen = scopelen + + addrdata = dns.inet.inet_pton(af, address) + nbytes = int(math.ceil(srclen/8.0)) + + # Truncate to srclen and pad to the end of the last octet needed + # See RFC section 6 + self.addrdata = addrdata[:nbytes] + nbits = srclen % 8 + if nbits != 0: + last = struct.pack('B', ord(self.addrdata[-1:]) & (0xff << nbits)) + self.addrdata = self.addrdata[:-1] + last + + def to_text(self): + return "ECS {}/{} scope/{}".format(self.address, self.srclen, + self.scopelen) + + def to_wire(self, file): + file.write(struct.pack('!H', self.family)) + file.write(struct.pack('!BB', self.srclen, self.scopelen)) + file.write(self.addrdata) + + @classmethod + def from_wire(cls, otype, wire, cur, olen): + family, src, scope = struct.unpack('!HBB', wire[cur:cur+4]) + cur += 4 + + addrlen = int(math.ceil(src/8.0)) + + if family == 1: + af = dns.inet.AF_INET + pad = 4 - addrlen + elif family == 2: + af = dns.inet.AF_INET6 + pad = 16 - addrlen + else: + raise ValueError('unsupported family') + + addr = dns.inet.inet_ntop(af, wire[cur:cur+addrlen] + b'\x00' * pad) + return cls(addr, src, scope) + + def _cmp(self, other): + if self.addrdata == other.addrdata: + return 0 + if self.addrdata > other.addrdata: + return 1 + return -1 + _type_to_class = { + ECS: ECSOption } - def get_option_class(otype): + """Return the class for the specified option type. + + The GenericOption class is used if a more specific class is not + known. + """ + cls = _type_to_class.get(otype) if cls is None: cls = GenericOption @@ -134,17 +251,19 @@ def get_option_class(otype): def option_from_wire(otype, wire, current, olen): - """Build an EDNS option object from wire format + """Build an EDNS option object from wire format. - @param otype: The option type - @type otype: int - @param wire: The wire-format message - @type wire: string - @param current: The offset in wire of the beginning of the rdata. - @type current: int - @param olen: The length of the wire-format option data - @type olen: int - @rtype: dns.edns.Option instance""" + *otype*, an ``int``, is the option type. + + *wire*, a ``binary``, is the wire-format message. + + *current*, an ``int``, is the offset in *wire* of the beginning + of the rdata. + + *olen*, an ``int``, is the length of the wire-format option data + + Returns an instance of a subclass of ``dns.edns.Option``. + """ cls = get_option_class(otype) return cls.from_wire(otype, wire, current, olen) diff --git a/src/dns/entropy.py b/src/dns/entropy.py index de7a70a5..00c6a4b3 100644 --- a/src/dns/entropy.py +++ b/src/dns/entropy.py @@ -1,4 +1,6 @@ -# Copyright (C) 2009, 2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2009-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -25,6 +27,11 @@ except ImportError: class EntropyPool(object): + # This is an entropy pool for Python implementations that do not + # have a working SystemRandom. I'm not sure there are any, but + # leaving this code doesn't hurt anything as the library code + # is used if present. + def __init__(self, seed=None): self.pool_index = 0 self.digest = None diff --git a/src/dns/entropy.pyi b/src/dns/entropy.pyi new file mode 100644 index 00000000..818f805a --- /dev/null +++ b/src/dns/entropy.pyi @@ -0,0 +1,10 @@ +from typing import Optional +from random import SystemRandom + +system_random : Optional[SystemRandom] + +def random_16() -> int: + pass + +def between(first: int, last: int) -> int: + pass diff --git a/src/dns/exception.py b/src/dns/exception.py index 6c0b1f4b..71ff04f1 100644 --- a/src/dns/exception.py +++ b/src/dns/exception.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,32 +15,35 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""Common DNS Exceptions.""" +"""Common DNS Exceptions. +Dnspython modules may also define their own exceptions, which will +always be subclasses of ``DNSException``. +""" class DNSException(Exception): - """Abstract base class shared by all dnspython exceptions. It supports two basic modes of operation: - a) Old/compatible mode is used if __init__ was called with - empty **kwargs. - In compatible mode all *args are passed to standard Python Exception class - as before and all *args are printed by standard __str__ implementation. - Class variable msg (or doc string if msg is None) is returned from str() - if *args is empty. + a) Old/compatible mode is used if ``__init__`` was called with + empty *kwargs*. In compatible mode all *args* are passed + to the standard Python Exception class as before and all *args* are + printed by the standard ``__str__`` implementation. Class variable + ``msg`` (or doc string if ``msg`` is ``None``) is returned from ``str()`` + if *args* is empty. - b) New/parametrized mode is used if __init__ was called with - non-empty **kwargs. - In the new mode *args has to be empty and all kwargs has to exactly match - set in class variable self.supp_kwargs. All kwargs are stored inside - self.kwargs and used in new __str__ implementation to construct - formatted message based on self.fmt string. + b) New/parametrized mode is used if ``__init__`` was called with + non-empty *kwargs*. + In the new mode *args* must be empty and all kwargs must match + those set in class variable ``supp_kwargs``. All kwargs are stored inside + ``self.kwargs`` and used in a new ``__str__`` implementation to construct + a formatted message based on the ``fmt`` class variable, a ``string``. - In the simplest case it is enough to override supp_kwargs and fmt - class variables to get nice parametrized messages. + In the simplest case it is enough to override the ``supp_kwargs`` + and ``fmt`` class variables to get nice parametrized messages. """ + msg = None # non-parametrized message supp_kwargs = set() # accepted parameters for _fmt_kwargs (sanity check) fmt = None # message parametrized with results from _fmt_kwargs @@ -102,27 +107,22 @@ class DNSException(Exception): class FormError(DNSException): - """DNS message is malformed.""" class SyntaxError(DNSException): - """Text input is malformed.""" class UnexpectedEnd(SyntaxError): - """Text input ended unexpectedly.""" class TooBig(DNSException): - """The DNS message is too big.""" class Timeout(DNSException): - """The DNS operation timed out.""" - supp_kwargs = set(['timeout']) + supp_kwargs = {'timeout'} fmt = "The DNS operation timed out after {timeout} seconds" diff --git a/src/dns/exception.pyi b/src/dns/exception.pyi new file mode 100644 index 00000000..4b346cc4 --- /dev/null +++ b/src/dns/exception.pyi @@ -0,0 +1,9 @@ +from typing import Set, Optional, Dict + +class DNSException(Exception): + supp_kwargs : Set[str] + kwargs : Optional[Dict] + +class SyntaxError(DNSException): ... +class FormError(DNSException): ... +class Timeout(DNSException): ... diff --git a/src/dns/flags.py b/src/dns/flags.py index 388d6aaa..0119dec7 100644 --- a/src/dns/flags.py +++ b/src/dns/flags.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -17,16 +19,24 @@ # Standard DNS flags +#: Query Response QR = 0x8000 +#: Authoritative Answer AA = 0x0400 +#: Truncated Response TC = 0x0200 +#: Recursion Desired RD = 0x0100 +#: Recursion Available RA = 0x0080 +#: Authentic Data AD = 0x0020 +#: Checking Disabled CD = 0x0010 # EDNS flags +#: DNSSEC answer OK DO = 0x8000 _by_text = { @@ -48,9 +58,9 @@ _edns_by_text = { # cannot make any mistakes (e.g. omissions, cut-and-paste errors) that # would cause the mappings not to be true inverses. -_by_value = dict((y, x) for x, y in _by_text.items()) +_by_value = {y: x for x, y in _by_text.items()} -_edns_by_value = dict((y, x) for x, y in _edns_by_text.items()) +_edns_by_value = {y: x for x, y in _edns_by_text.items()} def _order_flags(table): @@ -83,7 +93,9 @@ def _to_text(flags, table, order): def from_text(text): """Convert a space-separated list of flag text values into a flags value. - @rtype: int""" + + Returns an ``int`` + """ return _from_text(text, _by_text) @@ -91,7 +103,9 @@ def from_text(text): def to_text(flags): """Convert a flags value into a space-separated list of flag text values. - @rtype: string""" + + Returns a ``text``. + """ return _to_text(flags, _by_value, _flags_order) @@ -99,7 +113,9 @@ def to_text(flags): def edns_from_text(text): """Convert a space-separated list of EDNS flag text values into a EDNS flags value. - @rtype: int""" + + Returns an ``int`` + """ return _from_text(text, _edns_by_text) @@ -107,6 +123,8 @@ def edns_from_text(text): def edns_to_text(flags): """Convert an EDNS flags value into a space-separated list of EDNS flag text values. - @rtype: string""" + + Returns a ``text``. + """ return _to_text(flags, _edns_by_value, _edns_flags_order) diff --git a/src/dns/grange.py b/src/dns/grange.py index 9ce9f67a..ffe8be7c 100644 --- a/src/dns/grange.py +++ b/src/dns/grange.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2012-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -17,18 +19,16 @@ import dns - def from_text(text): - """Convert the text form of a range in a GENERATE statement to an + """Convert the text form of a range in a ``$GENERATE`` statement to an integer. - @param text: the textual range - @type text: string - @return: The start, stop and step values. - @rtype: tuple - """ - # TODO, figure out the bounds on start, stop and step. + *text*, a ``str``, the textual range in ``$GENERATE`` form. + Returns a tuple of three ``int`` values ``(start, stop, step)``. + """ + + # TODO, figure out the bounds on start, stop and step. step = 1 cur = '' state = 0 diff --git a/src/dns/hash.py b/src/dns/hash.py index 966838a1..1713e628 100644 --- a/src/dns/hash.py +++ b/src/dns/hash.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -16,7 +18,11 @@ """Hashing backwards compatibility wrapper""" import hashlib +import warnings +warnings.warn( + "dns.hash module will be removed in future versions. Please use hashlib instead.", + DeprecationWarning) hashes = {} hashes['MD5'] = hashlib.md5 diff --git a/src/dns/inet.py b/src/dns/inet.py index 73490a9d..c8d7c1b4 100644 --- a/src/dns/inet.py +++ b/src/dns/inet.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -20,6 +22,7 @@ import socket import dns.ipv4 import dns.ipv6 +from ._compat import maybe_ord # We assume that AF_INET is always defined. @@ -38,13 +41,14 @@ except AttributeError: def inet_pton(family, text): """Convert the textual form of a network address into its binary form. - @param family: the address family - @type family: int - @param text: the textual address - @type text: string - @raises NotImplementedError: the address family specified is not + *family* is an ``int``, the address family. + + *text* is a ``text``, the textual address. + + Raises ``NotImplementedError`` if the address family specified is not implemented. - @rtype: string + + Returns a ``binary``. """ if family == AF_INET: @@ -58,14 +62,16 @@ def inet_pton(family, text): def inet_ntop(family, address): """Convert the binary form of a network address into its textual form. - @param family: the address family - @type family: int - @param address: the binary address - @type address: string - @raises NotImplementedError: the address family specified is not + *family* is an ``int``, the address family. + + *address* is a ``binary``, the network address in binary form. + + Raises ``NotImplementedError`` if the address family specified is not implemented. - @rtype: string + + Returns a ``text``. """ + if family == AF_INET: return dns.ipv4.inet_ntoa(address) elif family == AF_INET6: @@ -77,11 +83,14 @@ def inet_ntop(family, address): def af_for_address(text): """Determine the address family of a textual-form network address. - @param text: the textual address - @type text: string - @raises ValueError: the address family cannot be determined from the input. - @rtype: int + *text*, a ``text``, the textual address. + + Raises ``ValueError`` if the address family cannot be determined + from the input. + + Returns an ``int``. """ + try: dns.ipv4.inet_aton(text) return AF_INET @@ -96,16 +105,20 @@ def af_for_address(text): def is_multicast(text): """Is the textual-form network address a multicast address? - @param text: the textual address - @raises ValueError: the address family cannot be determined from the input. - @rtype: bool + *text*, a ``text``, the textual address. + + Raises ``ValueError`` if the address family cannot be determined + from the input. + + Returns a ``bool``. """ + try: - first = ord(dns.ipv4.inet_aton(text)[0]) + first = maybe_ord(dns.ipv4.inet_aton(text)[0]) return first >= 224 and first <= 239 except Exception: try: - first = ord(dns.ipv6.inet_aton(text)[0]) + first = maybe_ord(dns.ipv6.inet_aton(text)[0]) return first == 255 except Exception: raise ValueError diff --git a/src/dns/inet.pyi b/src/dns/inet.pyi new file mode 100644 index 00000000..6d9dcc70 --- /dev/null +++ b/src/dns/inet.pyi @@ -0,0 +1,4 @@ +from typing import Union +from socket import AddressFamily + +AF_INET6 : Union[int, AddressFamily] diff --git a/src/dns/ipv4.py b/src/dns/ipv4.py index 3fef282b..8fc4f7dc 100644 --- a/src/dns/ipv4.py +++ b/src/dns/ipv4.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -21,26 +23,28 @@ import dns.exception from ._compat import binary_type def inet_ntoa(address): - """Convert an IPv4 address in network form to text form. + """Convert an IPv4 address in binary form to text form. - @param address: The IPv4 address - @type address: string - @returns: string + *address*, a ``binary``, the IPv4 address in binary form. + + Returns a ``text``. """ + if len(address) != 4: raise dns.exception.SyntaxError if not isinstance(address, bytearray): address = bytearray(address) - return (u'%u.%u.%u.%u' % (address[0], address[1], - address[2], address[3])).encode() + return ('%u.%u.%u.%u' % (address[0], address[1], + address[2], address[3])) def inet_aton(text): - """Convert an IPv4 address in text form to network form. + """Convert an IPv4 address in text form to binary form. - @param text: The IPv4 address - @type text: string - @returns: string + *text*, a ``text``, the IPv4 address in textual form. + + Returns a ``binary``. """ + if not isinstance(text, binary_type): text = text.encode() parts = text.split(b'.') diff --git a/src/dns/ipv6.py b/src/dns/ipv6.py index cbaee8ed..128e56c8 100644 --- a/src/dns/ipv6.py +++ b/src/dns/ipv6.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -22,15 +24,15 @@ import dns.exception import dns.ipv4 from ._compat import xrange, binary_type, maybe_decode -_leading_zero = re.compile(b'0+([0-9a-f]+)') +_leading_zero = re.compile(r'0+([0-9a-f]+)') def inet_ntoa(address): - """Convert a network format IPv6 address into text. + """Convert an IPv6 address in binary form to text form. - @param address: the binary address - @type address: string - @rtype: string - @raises ValueError: the address isn't 16 bytes long + *address*, a ``binary``, the IPv6 address in binary form. + + Raises ``ValueError`` if the address isn't 16 bytes long. + Returns a ``text``. """ if len(address) != 16: @@ -40,7 +42,7 @@ def inet_ntoa(address): i = 0 l = len(hex) while i < l: - chunk = hex[i : i + 4] + chunk = maybe_decode(hex[i : i + 4]) # strip leading zeros. we do this with an re instead of # with lstrip() because lstrip() didn't support chars until # python 2.2.2 @@ -57,7 +59,7 @@ def inet_ntoa(address): start = -1 last_was_zero = False for i in xrange(8): - if chunks[i] != b'0': + if chunks[i] != '0': if last_was_zero: end = i current_len = end - start @@ -77,31 +79,30 @@ def inet_ntoa(address): if best_len > 1: if best_start == 0 and \ (best_len == 6 or - best_len == 5 and chunks[5] == b'ffff'): + best_len == 5 and chunks[5] == 'ffff'): # We have an embedded IPv4 address if best_len == 6: - prefix = b'::' + prefix = '::' else: - prefix = b'::ffff:' + prefix = '::ffff:' hex = prefix + dns.ipv4.inet_ntoa(address[12:]) else: - hex = b':'.join(chunks[:best_start]) + b'::' + \ - b':'.join(chunks[best_start + best_len:]) + hex = ':'.join(chunks[:best_start]) + '::' + \ + ':'.join(chunks[best_start + best_len:]) else: - hex = b':'.join(chunks) - return maybe_decode(hex) + hex = ':'.join(chunks) + return hex -_v4_ending = re.compile(b'(.*):(\d+\.\d+\.\d+\.\d+)$') -_colon_colon_start = re.compile(b'::.*') -_colon_colon_end = re.compile(b'.*::$') +_v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$') +_colon_colon_start = re.compile(br'::.*') +_colon_colon_end = re.compile(br'.*::$') def inet_aton(text): - """Convert a text format IPv6 address into network format. + """Convert an IPv6 address in text form to binary form. - @param text: the textual address - @type text: string - @rtype: string - @raises dns.exception.SyntaxError: the text was not properly formatted + *text*, a ``text``, the IPv6 address in textual form. + + Returns a ``binary``. """ # @@ -118,8 +119,9 @@ def inet_aton(text): m = _v4_ending.match(text) if not m is None: b = bytearray(dns.ipv4.inet_aton(m.group(2))) - text = (u"%s:%02x%02x:%02x%02x" % (m.group(1).decode(), b[0], b[1], - b[2], b[3])).encode() + text = (u"{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(), + b[0], b[1], b[2], + b[3])).encode() # # Try to turn '::' into ':'; if no match try to # turn '::' into ':' @@ -169,4 +171,11 @@ def inet_aton(text): _mapped_prefix = b'\x00' * 10 + b'\xff\xff' def is_mapped(address): + """Is the specified address a mapped IPv4 address? + + *address*, a ``binary`` is an IPv6 address in binary form. + + Returns a ``bool``. + """ + return address.startswith(_mapped_prefix) diff --git a/src/dns/message.py b/src/dns/message.py index a0df18e6..9d2b2f43 100644 --- a/src/dns/message.py +++ b/src/dns/message.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -40,114 +42,46 @@ from ._compat import long, xrange, string_types class ShortHeader(dns.exception.FormError): - """The DNS packet passed to from_wire() is too short.""" class TrailingJunk(dns.exception.FormError): - """The DNS packet passed to from_wire() has extra junk at the end of it.""" class UnknownHeaderField(dns.exception.DNSException): - """The header field name was not recognized when converting from text into a message.""" class BadEDNS(dns.exception.FormError): - - """OPT record occurred somewhere other than the start of + """An OPT record occurred somewhere other than the start of the additional data section.""" class BadTSIG(dns.exception.FormError): - """A TSIG record occurred somewhere other than the end of the additional data section.""" class UnknownTSIGKey(dns.exception.DNSException): - """A TSIG with an unknown key was received.""" +#: The question section number +QUESTION = 0 + +#: The answer section number +ANSWER = 1 + +#: The authority section number +AUTHORITY = 2 + +#: The additional section number +ADDITIONAL = 3 + class Message(object): - - """A DNS message. - - @ivar id: The query id; the default is a randomly chosen id. - @type id: int - @ivar flags: The DNS flags of the message. @see: RFC 1035 for an - explanation of these flags. - @type flags: int - @ivar question: The question section. - @type question: list of dns.rrset.RRset objects - @ivar answer: The answer section. - @type answer: list of dns.rrset.RRset objects - @ivar authority: The authority section. - @type authority: list of dns.rrset.RRset objects - @ivar additional: The additional data section. - @type additional: list of dns.rrset.RRset objects - @ivar edns: The EDNS level to use. The default is -1, no Edns. - @type edns: int - @ivar ednsflags: The EDNS flags - @type ednsflags: long - @ivar payload: The EDNS payload size. The default is 0. - @type payload: int - @ivar options: The EDNS options - @type options: list of dns.edns.Option objects - @ivar request_payload: The associated request's EDNS payload size. - @type request_payload: int - @ivar keyring: The TSIG keyring to use. The default is None. - @type keyring: dict - @ivar keyname: The TSIG keyname to use. The default is None. - @type keyname: dns.name.Name object - @ivar keyalgorithm: The TSIG algorithm to use; defaults to - dns.tsig.default_algorithm. Constants for TSIG algorithms are defined - in dns.tsig, and the currently implemented algorithms are - HMAC_MD5, HMAC_SHA1, HMAC_SHA224, HMAC_SHA256, HMAC_SHA384, and - HMAC_SHA512. - @type keyalgorithm: string - @ivar request_mac: The TSIG MAC of the request message associated with - this message; used when validating TSIG signatures. @see: RFC 2845 for - more information on TSIG fields. - @type request_mac: string - @ivar fudge: TSIG time fudge; default is 300 seconds. - @type fudge: int - @ivar original_id: TSIG original id; defaults to the message's id - @type original_id: int - @ivar tsig_error: TSIG error code; default is 0. - @type tsig_error: int - @ivar other_data: TSIG other data. - @type other_data: string - @ivar mac: The TSIG MAC for this message. - @type mac: string - @ivar xfr: Is the message being used to contain the results of a DNS - zone transfer? The default is False. - @type xfr: bool - @ivar origin: The origin of the zone in messages which are used for - zone transfers or for DNS dynamic updates. The default is None. - @type origin: dns.name.Name object - @ivar tsig_ctx: The TSIG signature context associated with this - message. The default is None. - @type tsig_ctx: hmac.HMAC object - @ivar had_tsig: Did the message decoded from wire format have a TSIG - signature? - @type had_tsig: bool - @ivar multi: Is this message part of a multi-message sequence? The - default is false. This variable is used when validating TSIG signatures - on messages which are part of a zone transfer. - @type multi: bool - @ivar first: Is this message standalone, or the first of a multi - message sequence? This variable is used when validating TSIG signatures - on messages which are part of a zone transfer. - @type first: bool - @ivar index: An index of rrsets in the message. The index key is - (section, name, rdclass, rdtype, covers, deleting). Indexing can be - disabled by setting the index to None. - @type index: dict - """ + """A DNS message.""" def __init__(self, id=None): if id is None: @@ -167,12 +101,12 @@ class Message(object): self.keyring = None self.keyname = None self.keyalgorithm = dns.tsig.default_algorithm - self.request_mac = '' - self.other_data = '' + self.request_mac = b'' + self.other_data = b'' self.tsig_error = 0 self.fudge = 300 self.original_id = self.id - self.mac = '' + self.mac = b'' self.xfr = False self.origin = None self.tsig_ctx = None @@ -190,10 +124,10 @@ class Message(object): def to_text(self, origin=None, relativize=True, **kw): """Convert the message to text. - The I{origin}, I{relativize}, and any other keyword - arguments are passed to the rrset to_wire() method. + The *origin*, *relativize*, and any other keyword + arguments are passed to the RRset ``to_wire()`` method. - @rtype: string + Returns a ``text``. """ s = StringIO() @@ -209,6 +143,8 @@ class Message(object): s.write(u'eflags %s\n' % dns.flags.edns_to_text(self.ednsflags)) s.write(u'payload %d\n' % self.payload) + for opt in self.options: + s.write(u'option %s\n' % opt.to_text()) is_update = dns.opcode.is_update(self.flags) if is_update: s.write(u';ZONE\n') @@ -245,7 +181,10 @@ class Message(object): def __eq__(self, other): """Two messages are equal if they have the same content in the header, question, answer, and authority sections. - @rtype: bool""" + + Returns a ``bool``. + """ + if not isinstance(other, Message): return False if self.id != other.id: @@ -273,13 +212,14 @@ class Message(object): return True def __ne__(self, other): - """Are two messages not equal? - @rtype: bool""" return not self.__eq__(other) def is_response(self, other): - """Is other a response to self? - @rtype: bool""" + """Is this message a response to *other*? + + Returns a ``bool``. + """ + if other.flags & dns.flags.QR == 0 or \ self.id != other.id or \ dns.opcode.from_flags(self.flags) != \ @@ -299,14 +239,48 @@ class Message(object): return True def section_number(self, section): + """Return the "section number" of the specified section for use + in indexing. The question section is 0, the answer section is 1, + the authority section is 2, and the additional section is 3. + + *section* is one of the section attributes of this message. + + Raises ``ValueError`` if the section isn't known. + + Returns an ``int``. + """ + if section is self.question: - return 0 + return QUESTION elif section is self.answer: - return 1 + return ANSWER elif section is self.authority: - return 2 + return AUTHORITY elif section is self.additional: - return 3 + return ADDITIONAL + else: + raise ValueError('unknown section') + + def section_from_number(self, number): + """Return the "section number" of the specified section for use + in indexing. The question section is 0, the answer section is 1, + the authority section is 2, and the additional section is 3. + + *section* is one of the section attributes of this message. + + Raises ``ValueError`` if the section isn't known. + + Returns an ``int``. + """ + + if number == QUESTION: + return self.question + elif number == ANSWER: + return self.answer + elif number == AUTHORITY: + return self.authority + elif number == ADDITIONAL: + return self.additional else: raise ValueError('unknown section') @@ -315,30 +289,45 @@ class Message(object): force_unique=False): """Find the RRset with the given attributes in the specified section. - @param section: the section of the message to look in, e.g. - self.answer. - @type section: list of dns.rrset.RRset objects - @param name: the name of the RRset - @type name: dns.name.Name object - @param rdclass: the class of the RRset - @type rdclass: int - @param rdtype: the type of the RRset - @type rdtype: int - @param covers: the covers value of the RRset - @type covers: int - @param deleting: the deleting value of the RRset - @type deleting: int - @param create: If True, create the RRset if it is not found. - The created RRset is appended to I{section}. - @type create: bool - @param force_unique: If True and create is also True, create a - new RRset regardless of whether a matching RRset exists already. - @type force_unique: bool - @raises KeyError: the RRset was not found and create was False - @rtype: dns.rrset.RRset object""" + *section*, an ``int`` section number, or one of the section + attributes of this message. This specifies the + the section of the message to search. For example:: - key = (self.section_number(section), - name, rdclass, rdtype, covers, deleting) + my_message.find_rrset(my_message.answer, name, rdclass, rdtype) + my_message.find_rrset(dns.message.ANSWER, name, rdclass, rdtype) + + *name*, a ``dns.name.Name``, the name of the RRset. + + *rdclass*, an ``int``, the class of the RRset. + + *rdtype*, an ``int``, the type of the RRset. + + *covers*, an ``int`` or ``None``, the covers value of the RRset. + The default is ``None``. + + *deleting*, an ``int`` or ``None``, the deleting value of the RRset. + The default is ``None``. + + *create*, a ``bool``. If ``True``, create the RRset if it is not found. + The created RRset is appended to *section*. + + *force_unique*, a ``bool``. If ``True`` and *create* is also ``True``, + create a new RRset regardless of whether a matching RRset exists + already. The default is ``False``. This is useful when creating + DDNS Update messages, as order matters for them. + + Raises ``KeyError`` if the RRset was not found and create was + ``False``. + + Returns a ``dns.rrset.RRset object``. + """ + + if isinstance(section, int): + section_number = section + section = self.section_from_number(section_number) + else: + section_number = self.section_number(section) + key = (section_number, name, rdclass, rdtype, covers, deleting) if not force_unique: if self.index is not None: rrset = self.index.get(key) @@ -363,26 +352,35 @@ class Message(object): If the RRset is not found, None is returned. - @param section: the section of the message to look in, e.g. - self.answer. - @type section: list of dns.rrset.RRset objects - @param name: the name of the RRset - @type name: dns.name.Name object - @param rdclass: the class of the RRset - @type rdclass: int - @param rdtype: the type of the RRset - @type rdtype: int - @param covers: the covers value of the RRset - @type covers: int - @param deleting: the deleting value of the RRset - @type deleting: int - @param create: If True, create the RRset if it is not found. - The created RRset is appended to I{section}. - @type create: bool - @param force_unique: If True and create is also True, create a - new RRset regardless of whether a matching RRset exists already. - @type force_unique: bool - @rtype: dns.rrset.RRset object or None""" + *section*, an ``int`` section number, or one of the section + attributes of this message. This specifies the + the section of the message to search. For example:: + + my_message.get_rrset(my_message.answer, name, rdclass, rdtype) + my_message.get_rrset(dns.message.ANSWER, name, rdclass, rdtype) + + *name*, a ``dns.name.Name``, the name of the RRset. + + *rdclass*, an ``int``, the class of the RRset. + + *rdtype*, an ``int``, the type of the RRset. + + *covers*, an ``int`` or ``None``, the covers value of the RRset. + The default is ``None``. + + *deleting*, an ``int`` or ``None``, the deleting value of the RRset. + The default is ``None``. + + *create*, a ``bool``. If ``True``, create the RRset if it is not found. + The created RRset is appended to *section*. + + *force_unique*, a ``bool``. If ``True`` and *create* is also ``True``, + create a new RRset regardless of whether a matching RRset exists + already. The default is ``False``. This is useful when creating + DDNS Update messages, as order matters for them. + + Returns a ``dns.rrset.RRset object`` or ``None``. + """ try: rrset = self.find_rrset(section, name, rdclass, rdtype, covers, @@ -395,17 +393,19 @@ class Message(object): """Return a string containing the message in DNS compressed wire format. - Additional keyword arguments are passed to the rrset to_wire() + Additional keyword arguments are passed to the RRset ``to_wire()`` method. - @param origin: The origin to be appended to any relative names. - @type origin: dns.name.Name object - @param max_size: The maximum size of the wire format output; default - is 0, which means 'the message's request payload, if nonzero, or - 65536'. - @type max_size: int - @raises dns.exception.TooBig: max_size was exceeded - @rtype: string + *origin*, a ``dns.name.Name`` or ``None``, the origin to be appended + to any relative names. + + *max_size*, an ``int``, the maximum size of the wire format + output; default is 0, which means "the message's request + payload, if nonzero, or 65535". + + Raises ``dns.exception.TooBig`` if *max_size* was exceeded. + + Returns a ``binary``. """ if max_size == 0: @@ -438,30 +438,34 @@ class Message(object): return r.get_wire() def use_tsig(self, keyring, keyname=None, fudge=300, - original_id=None, tsig_error=0, other_data='', + original_id=None, tsig_error=0, other_data=b'', algorithm=dns.tsig.default_algorithm): """When sending, a TSIG signature using the specified keyring and keyname should be added. - @param keyring: The TSIG keyring to use; defaults to None. - @type keyring: dict - @param keyname: The name of the TSIG key to use; defaults to None. - The key must be defined in the keyring. If a keyring is specified - but a keyname is not, then the key used will be the first key in the - keyring. Note that the order of keys in a dictionary is not defined, - so applications should supply a keyname when a keyring is used, unless - they know the keyring contains only one key. - @type keyname: dns.name.Name or string - @param fudge: TSIG time fudge; default is 300 seconds. - @type fudge: int - @param original_id: TSIG original id; defaults to the message's id - @type original_id: int - @param tsig_error: TSIG error code; default is 0. - @type tsig_error: int - @param other_data: TSIG other data. - @type other_data: string - @param algorithm: The TSIG algorithm to use; defaults to - dns.tsig.default_algorithm + See the documentation of the Message class for a complete + description of the keyring dictionary. + + *keyring*, a ``dict``, the TSIG keyring to use. If a + *keyring* is specified but a *keyname* is not, then the key + used will be the first key in the *keyring*. Note that the + order of keys in a dictionary is not defined, so applications + should supply a keyname when a keyring is used, unless they + know the keyring contains only one key. + + *keyname*, a ``dns.name.Name`` or ``None``, the name of the TSIG key + to use; defaults to ``None``. The key must be defined in the keyring. + + *fudge*, an ``int``, the TSIG time fudge. + + *original_id*, an ``int``, the TSIG original id. If ``None``, + the message's id is used. + + *tsig_error*, an ``int``, the TSIG error code. + + *other_data*, a ``binary``, the TSIG other data. + + *algorithm*, a ``dns.name.Name``, the TSIG algorithm to use. """ self.keyring = keyring @@ -483,23 +487,26 @@ class Message(object): def use_edns(self, edns=0, ednsflags=0, payload=1280, request_payload=None, options=None): """Configure EDNS behavior. - @param edns: The EDNS level to use. Specifying None, False, or -1 - means 'do not use EDNS', and in this case the other parameters are - ignored. Specifying True is equivalent to specifying 0, i.e. 'use - EDNS0'. - @type edns: int or bool or None - @param ednsflags: EDNS flag values. - @type ednsflags: int - @param payload: The EDNS sender's payload field, which is the maximum - size of UDP datagram the sender can handle. - @type payload: int - @param request_payload: The EDNS payload size to use when sending - this message. If not specified, defaults to the value of payload. - @type request_payload: int or None - @param options: The EDNS options - @type options: None or list of dns.edns.Option objects - @see: RFC 2671 + + *edns*, an ``int``, is the EDNS level to use. Specifying + ``None``, ``False``, or ``-1`` means "do not use EDNS", and in this case + the other parameters are ignored. Specifying ``True`` is + equivalent to specifying 0, i.e. "use EDNS0". + + *ednsflags*, an ``int``, the EDNS flag values. + + *payload*, an ``int``, is the EDNS sender's payload field, which is the + maximum size of UDP datagram the sender can handle. I.e. how big + a response to this message can be. + + *request_payload*, an ``int``, is the EDNS payload size to use when + sending this message. If not specified, defaults to the value of + *payload*. + + *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS + options. """ + if edns is None or edns is False: edns = -1 if edns is True: @@ -525,11 +532,13 @@ class Message(object): def want_dnssec(self, wanted=True): """Enable or disable 'DNSSEC desired' flag in requests. - @param wanted: Is DNSSEC desired? If True, EDNS is enabled if - required, and then the DO bit is set. If False, the DO bit is - cleared if EDNS is enabled. - @type wanted: bool + + *wanted*, a ``bool``. If ``True``, then DNSSEC data is + desired in the response, EDNS is enabled if required, and then + the DO bit is set. If ``False``, the DO bit is cleared if + EDNS is enabled. """ + if wanted: if self.edns < 0: self.use_edns() @@ -539,14 +548,15 @@ class Message(object): def rcode(self): """Return the rcode. - @rtype: int + + Returns an ``int``. """ return dns.rcode.from_flags(self.flags, self.ednsflags) def set_rcode(self, rcode): """Set the rcode. - @param rcode: the rcode - @type rcode: int + + *rcode*, an ``int``, is the rcode to set. """ (value, evalue) = dns.rcode.to_flags(rcode) self.flags &= 0xFFF0 @@ -558,14 +568,15 @@ class Message(object): def opcode(self): """Return the opcode. - @rtype: int + + Returns an ``int``. """ return dns.opcode.from_flags(self.flags) def set_opcode(self, opcode): """Set the opcode. - @param opcode: the opcode - @type opcode: int + + *opcode*, an ``int``, is the opcode to set. """ self.flags &= 0x87FF self.flags |= dns.opcode.to_flags(opcode) @@ -575,23 +586,16 @@ class _WireReader(object): """Wire format reader. - @ivar wire: the wire-format message. - @type wire: string - @ivar message: The message object being built - @type message: dns.message.Message object - @ivar current: When building a message object from wire format, this + wire: a binary, is the wire-format message. + message: The message object being built + current: When building a message object from wire format, this variable contains the offset from the beginning of wire of the next octet to be read. - @type current: int - @ivar updating: Is the message a dynamic update? - @type updating: bool - @ivar one_rr_per_rrset: Put each RR into its own RRset? - @type one_rr_per_rrset: bool - @ivar ignore_trailing: Ignore trailing junk at end of request? - @type ignore_trailing: bool - @ivar zone_rdclass: The class of the zone in messages which are + updating: Is the message a dynamic update? + one_rr_per_rrset: Put each RR into its own RRset? + ignore_trailing: Ignore trailing junk at end of request? + zone_rdclass: The class of the zone in messages which are DNS dynamic updates. - @type zone_rdclass: int """ def __init__(self, wire, message, question_only=False, @@ -606,10 +610,9 @@ class _WireReader(object): self.ignore_trailing = ignore_trailing def _get_question(self, qcount): - """Read the next I{qcount} records from the wire data and add them to + """Read the next *qcount* records from the wire data and add them to the question section. - @param qcount: the number of questions in the message - @type qcount: int""" + """ if self.updating and qcount > 1: raise dns.exception.FormError @@ -632,10 +635,10 @@ class _WireReader(object): def _get_section(self, section, count): """Read the next I{count} records from the wire data and add them to the specified section. - @param section: the section of the message to which to add records - @type section: list of dns.rrset.RRset objects - @param count: the number of records to read - @type count: int""" + + section: the section of the message to which to add records + count: the number of records to read + """ if self.updating or self.one_rr_per_rrset: force_unique = True @@ -753,45 +756,58 @@ class _WireReader(object): self.message.tsig_ctx.update(self.wire) -def from_wire(wire, keyring=None, request_mac='', xfr=False, origin=None, +def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, tsig_ctx=None, multi=False, first=True, question_only=False, one_rr_per_rrset=False, ignore_trailing=False): """Convert a DNS wire format message into a message object. - @param keyring: The keyring to use if the message is signed. - @type keyring: dict - @param request_mac: If the message is a response to a TSIG-signed request, - I{request_mac} should be set to the MAC of that request. - @type request_mac: string - @param xfr: Is this message part of a zone transfer? - @type xfr: bool - @param origin: If the message is part of a zone transfer, I{origin} - should be the origin name of the zone. - @type origin: dns.name.Name object - @param tsig_ctx: The ongoing TSIG context, used when validating zone - transfers. - @type tsig_ctx: hmac.HMAC object - @param multi: Is this message part of a multiple message sequence? - @type multi: bool - @param first: Is this message standalone, or the first of a multi - message sequence? - @type first: bool - @param question_only: Read only up to the end of the question section? - @type question_only: bool - @param one_rr_per_rrset: Put each RR into its own RRset - @type one_rr_per_rrset: bool - @param ignore_trailing: Ignore trailing junk at end of request? - @type ignore_trailing: bool - @raises ShortHeader: The message is less than 12 octets long. - @raises TrailingJunk: There were octets in the message past the end - of the proper DNS message. - @raises BadEDNS: An OPT record was in the wrong section, or occurred more - than once. - @raises BadTSIG: A TSIG record was not the last record of the additional - data section. - @rtype: dns.message.Message object""" + *keyring*, a ``dict``, the keyring to use if the message is signed. + + *request_mac*, a ``binary``. If the message is a response to a + TSIG-signed request, *request_mac* should be set to the MAC of + that request. + + *xfr*, a ``bool``, should be set to ``True`` if this message is part of + a zone transfer. + + *origin*, a ``dns.name.Name`` or ``None``. If the message is part + of a zone transfer, *origin* should be the origin name of the + zone. + + *tsig_ctx*, a ``hmac.HMAC`` objext, the ongoing TSIG context, used + when validating zone transfers. + + *multi*, a ``bool``, should be set to ``True`` if this message + part of a multiple message sequence. + + *first*, a ``bool``, should be set to ``True`` if this message is + stand-alone, or the first message in a multi-message sequence. + + *question_only*, a ``bool``. If ``True``, read only up to + the end of the question section. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its + own RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the message. + + Raises ``dns.message.ShortHeader`` if the message is less than 12 octets + long. + + Raises ``dns.messaage.TrailingJunk`` if there were octets in the message + past the end of the proper DNS message, and *ignore_trailing* is ``False``. + + Raises ``dns.message.BadEDNS`` if an OPT record was in the + wrong section, or occurred more than once. + + Raises ``dns.message.BadTSIG`` if a TSIG record was not the last + record of the additional data section. + + Returns a ``dns.message.Message``. + """ m = Message(id=0) m.keyring = keyring @@ -813,18 +829,12 @@ class _TextReader(object): """Text format reader. - @ivar tok: the tokenizer - @type tok: dns.tokenizer.Tokenizer object - @ivar message: The message object being built - @type message: dns.message.Message object - @ivar updating: Is the message a dynamic update? - @type updating: bool - @ivar zone_rdclass: The class of the zone in messages which are + tok: the tokenizer. + message: The message object being built. + updating: Is the message a dynamic update? + zone_rdclass: The class of the zone in messages which are DNS dynamic updates. - @type zone_rdclass: int - @ivar last_name: The most recently read name when building a message object - from text format. - @type last_name: dns.name.Name object + last_name: The most recently read name when building a message object. """ def __init__(self, text, message): @@ -997,11 +1007,14 @@ class _TextReader(object): def from_text(text): """Convert the text format message into a message object. - @param text: The text format message. - @type text: string - @raises UnknownHeaderField: - @raises dns.exception.SyntaxError: - @rtype: dns.message.Message object""" + *text*, a ``text``, the text format message. + + Raises ``dns.message.UnknownHeaderField`` if a header is unknown. + + Raises ``dns.exception.SyntaxError`` if the text is badly formed. + + Returns a ``dns.message.Message object`` + """ # 'text' can also be a file, but we don't publish that fact # since it's an implementation detail. The official file @@ -1018,11 +1031,15 @@ def from_text(text): def from_file(f): """Read the next text format message from the specified file. - @param f: file or string. If I{f} is a string, it is treated - as the name of a file to open. - @raises UnknownHeaderField: - @raises dns.exception.SyntaxError: - @rtype: dns.message.Message object""" + *f*, a ``file`` or ``text``. If *f* is text, it is treated as the + pathname of a file to open. + + Raises ``dns.message.UnknownHeaderField`` if a header is unknown. + + Raises ``dns.exception.SyntaxError`` if the text is badly formed. + + Returns a ``dns.message.Message object`` + """ str_type = string_types opts = 'rU' @@ -1052,30 +1069,35 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, The query will have a randomly chosen query id, and its DNS flags will be set to dns.flags.RD. - @param qname: The query name. - @type qname: dns.name.Name object or string - @param rdtype: The desired rdata type. - @type rdtype: int - @param rdclass: The desired rdata class; the default is class IN. - @type rdclass: int - @param use_edns: The EDNS level to use; the default is None (no EDNS). + qname, a ``dns.name.Name`` or ``text``, the query name. + + *rdtype*, an ``int`` or ``text``, the desired rdata type. + + *rdclass*, an ``int`` or ``text``, the desired rdata class; the default + is class IN. + + *use_edns*, an ``int``, ``bool`` or ``None``. The EDNS level to use; the + default is None (no EDNS). See the description of dns.message.Message.use_edns() for the possible values for use_edns and their meanings. - @type use_edns: int or bool or None - @param want_dnssec: Should the query indicate that DNSSEC is desired? - @type want_dnssec: bool - @param ednsflags: EDNS flag values. - @type ednsflags: int - @param payload: The EDNS sender's payload field, which is the maximum - size of UDP datagram the sender can handle. - @type payload: int - @param request_payload: The EDNS payload size to use when sending - this message. If not specified, defaults to the value of payload. - @type request_payload: int or None - @param options: The EDNS options - @type options: None or list of dns.edns.Option objects - @see: RFC 2671 - @rtype: dns.message.Message object""" + + *want_dnssec*, a ``bool``. If ``True``, DNSSEC data is desired. + + *ednsflags*, an ``int``, the EDNS flag values. + + *payload*, an ``int``, is the EDNS sender's payload field, which is the + maximum size of UDP datagram the sender can handle. I.e. how big + a response to this message can be. + + *request_payload*, an ``int``, is the EDNS payload size to use when + sending this message. If not specified, defaults to the value of + *payload*. + + *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS + options. + + Returns a ``dns.message.Message`` + """ if isinstance(qname, string_types): qname = dns.name.from_text(qname) @@ -1124,16 +1146,17 @@ def make_response(query, recursion_available=False, our_payload=8192, question section, so the query's question RRsets should not be changed. - @param query: the query to respond to - @type query: dns.message.Message object - @param recursion_available: should RA be set in the response? - @type recursion_available: bool - @param our_payload: payload size to advertise in EDNS responses; default - is 8192. - @type our_payload: int - @param fudge: TSIG time fudge; default is 300 seconds. - @type fudge: int - @rtype: dns.message.Message object""" + *query*, a ``dns.message.Message``, the query to respond to. + + *recursion_available*, a ``bool``, should RA be set in the response? + + *our_payload*, an ``int``, the payload size to advertise in EDNS + responses. + + *fudge*, an ``int``, the TSIG time fudge. + + Returns a ``dns.message.Message`` object. + """ if query.flags & dns.flags.QR: raise dns.exception.FormError('specified query message is not a query') @@ -1146,7 +1169,7 @@ def make_response(query, recursion_available=False, our_payload=8192, if query.edns >= 0: response.use_edns(0, 0, our_payload, query.payload) if query.had_tsig: - response.use_tsig(query.keyring, query.keyname, fudge, None, 0, '', + response.use_tsig(query.keyring, query.keyname, fudge, None, 0, b'', query.keyalgorithm) response.request_mac = query.mac return response diff --git a/src/dns/message.pyi b/src/dns/message.pyi new file mode 100644 index 00000000..ed99b3c0 --- /dev/null +++ b/src/dns/message.pyi @@ -0,0 +1,55 @@ +from typing import Optional, Dict, List, Tuple, Union +from . import name, rrset, tsig, rdatatype, entropy, edns, rdataclass +import hmac + +class Message: + def to_wire(self, origin : Optional[name.Name]=None, max_size=0, **kw) -> bytes: + ... + def find_rrset(self, section : List[rrset.RRset], name : name.Name, rdclass : int, rdtype : int, + covers=rdatatype.NONE, deleting : Optional[int]=None, create=False, + force_unique=False) -> rrset.RRset: + ... + def __init__(self, id : Optional[int] =None) -> None: + self.id : int + self.flags = 0 + self.question : List[rrset.RRset] = [] + self.answer : List[rrset.RRset] = [] + self.authority : List[rrset.RRset] = [] + self.additional : List[rrset.RRset] = [] + self.edns = -1 + self.ednsflags = 0 + self.payload = 0 + self.options : List[edns.Option] = [] + self.request_payload = 0 + self.keyring = None + self.keyname = None + self.keyalgorithm = tsig.default_algorithm + self.request_mac = b'' + self.other_data = b'' + self.tsig_error = 0 + self.fudge = 300 + self.original_id = self.id + self.mac = b'' + self.xfr = False + self.origin = None + self.tsig_ctx = None + self.had_tsig = False + self.multi = False + self.first = True + self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {} +def from_text(a : str) -> Message: + ... + +def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None, + tsig_ctx : Optional[hmac.HMAC] = None, multi=False, first=True, + question_only=False, one_rr_per_rrset=False, + ignore_trailing=False) -> Message: + ... +def make_response(query : Message, recursion_available=False, our_payload=8192, + fudge=300) -> Message: + ... + +def make_query(qname : Union[name.Name,str], rdtype : Union[str,int], rdclass : Union[int,str] =rdataclass.IN, use_edns : Optional[bool] = None, + want_dnssec=False, ednsflags : Optional[int] = None, payload : Optional[int] = None, + request_payload : Optional[int] = None, options : Optional[List[edns.Option]] = None) -> Message: + ... diff --git a/src/dns/name.py b/src/dns/name.py index 97e216c8..0bcfd834 100644 --- a/src/dns/name.py +++ b/src/dns/name.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -14,11 +16,6 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. """DNS Names. - -@var root: The DNS root name. -@type root: dns.name.Name object -@var empty: The empty DNS name. -@type empty: dns.name.Name object """ from io import BytesIO @@ -38,79 +35,76 @@ import dns.wiredata from ._compat import long, binary_type, text_type, unichr, maybe_decode try: - maxint = sys.maxint + maxint = sys.maxint # pylint: disable=sys-max-int except AttributeError: maxint = (1 << (8 * struct.calcsize("P"))) // 2 - 1 + +# fullcompare() result values + +#: The compared names have no relationship to each other. NAMERELN_NONE = 0 +#: the first name is a superdomain of the second. NAMERELN_SUPERDOMAIN = 1 +#: The first name is a subdomain of the second. NAMERELN_SUBDOMAIN = 2 +#: The compared names are equal. NAMERELN_EQUAL = 3 +#: The compared names have a common ancestor. NAMERELN_COMMONANCESTOR = 4 class EmptyLabel(dns.exception.SyntaxError): - """A DNS label is empty.""" class BadEscape(dns.exception.SyntaxError): - """An escaped code in a text format of DNS name is invalid.""" class BadPointer(dns.exception.FormError): - """A DNS compression pointer points forward instead of backward.""" class BadLabelType(dns.exception.FormError): - """The label type in DNS name wire format is unknown.""" class NeedAbsoluteNameOrOrigin(dns.exception.DNSException): - """An attempt was made to convert a non-absolute name to wire when there was also a non-absolute (or missing) origin.""" class NameTooLong(dns.exception.FormError): - """A DNS name is > 255 octets long.""" class LabelTooLong(dns.exception.SyntaxError): - """A DNS label is > 63 octets long.""" class AbsoluteConcatenation(dns.exception.DNSException): - """An attempt was made to append anything other than the empty name to an absolute DNS name.""" class NoParent(dns.exception.DNSException): - """An attempt was made to get the parent of the root name or the empty name.""" class NoIDNA2008(dns.exception.DNSException): - """IDNA 2008 processing was requested but the idna module is not available.""" class IDNAException(dns.exception.DNSException): - """IDNA processing raised an exception.""" - supp_kwargs = set(['idna_exception']) + supp_kwargs = {'idna_exception'} fmt = "IDNA processing exception: {idna_exception}" -class IDNACodec(object): +class IDNACodec(object): """Abstract base class for IDNA encoder/decoders.""" def __init__(self): @@ -131,21 +125,24 @@ class IDNACodec(object): label = maybe_decode(label) return _escapify(label, True) -class IDNA2003Codec(IDNACodec): +class IDNA2003Codec(IDNACodec): """IDNA 2003 encoder/decoder.""" def __init__(self, strict_decode=False): """Initialize the IDNA 2003 encoder/decoder. - @param strict_decode: If True, then IDNA2003 checking is done when - decoding. This can cause failures if the name was encoded with - IDNA2008. The default is False. - @type strict_decode: bool + + *strict_decode* is a ``bool``. If `True`, then IDNA2003 checking + is done when decoding. This can cause failures if the name + was encoded with IDNA2008. The default is `False`. """ + super(IDNA2003Codec, self).__init__() self.strict_decode = strict_decode def encode(self, label): + """Encode *label*.""" + if label == '': return b'' try: @@ -154,6 +151,7 @@ class IDNA2003Codec(IDNACodec): raise LabelTooLong def decode(self, label): + """Decode *label*.""" if not self.strict_decode: return super(IDNA2003Codec, self).decode(label) if label == b'': @@ -163,34 +161,34 @@ class IDNA2003Codec(IDNACodec): except Exception as e: raise IDNAException(idna_exception=e) -class IDNA2008Codec(IDNACodec): - """IDNA 2008 encoder/decoder.""" +class IDNA2008Codec(IDNACodec): + """IDNA 2008 encoder/decoder. + + *uts_46* is a ``bool``. If True, apply Unicode IDNA + compatibility processing as described in Unicode Technical + Standard #46 (http://unicode.org/reports/tr46/). + If False, do not apply the mapping. The default is False. + + *transitional* is a ``bool``: If True, use the + "transitional" mode described in Unicode Technical Standard + #46. The default is False. + + *allow_pure_ascii* is a ``bool``. If True, then a label which + consists of only ASCII characters is allowed. This is less + strict than regular IDNA 2008, but is also necessary for mixed + names, e.g. a name with starting with "_sip._tcp." and ending + in an IDN suffix which would otherwise be disallowed. The + default is False. + + *strict_decode* is a ``bool``: If True, then IDNA2008 checking + is done when decoding. This can cause failures if the name + was encoded with IDNA2003. The default is False. + """ def __init__(self, uts_46=False, transitional=False, allow_pure_ascii=False, strict_decode=False): - """Initialize the IDNA 2008 encoder/decoder. - @param uts_46: If True, apply Unicode IDNA compatibility processing - as described in Unicode Technical Standard #46 - (U{http://unicode.org/reports/tr46/}). This parameter is only - meaningful if IDNA 2008 is in use. If False, do not apply - the mapping. The default is False - @type uts_46: bool - @param transitional: If True, use the "transitional" mode described - in Unicode Technical Standard #46. This parameter is only - meaningful if IDNA 2008 is in use. The default is False. - @type transitional: bool - @param allow_pure_ascii: If True, then a label which - consists of only ASCII characters is allowed. This is less strict - than regular IDNA 2008, but is also necessary for mixed names, - e.g. a name with starting with "_sip._tcp." and ending in an IDN - suffixm which would otherwise be disallowed. The default is False - @type allow_pure_ascii: bool - @param strict_decode: If True, then IDNA2008 checking is done when - decoding. This can cause failures if the name was encoded with - IDNA2003. The default is False. - @type strict_decode: bool - """ + """Initialize the IDNA 2008 encoder/decoder.""" super(IDNA2008Codec, self).__init__() self.uts_46 = uts_46 self.transitional = transitional @@ -277,9 +275,14 @@ def _escapify(label, unicode_mode=False): def _validate_labels(labels): """Check for empty labels in the middle of a label sequence, labels that are too long, and for too many labels. - @raises NameTooLong: the name as a whole is too long - @raises EmptyLabel: a label is empty (i.e. the root label) and appears - in a position other than the end of the label sequence""" + + Raises ``dns.name.NameTooLong`` if the name as a whole is too long. + + Raises ``dns.name.EmptyLabel`` if a label is empty (i.e. the root + label) and appears in a position other than the end of the label + sequence + + """ l = len(labels) total = 0 @@ -299,7 +302,12 @@ def _validate_labels(labels): raise EmptyLabel -def _ensure_bytes(label): +def _maybe_convert_to_binary(label): + """If label is ``text``, convert it to ``binary``. If it is already + ``binary`` just return it. + + """ + if isinstance(label, binary_type): return label if isinstance(label, text_type): @@ -311,24 +319,23 @@ class Name(object): """A DNS name. - The dns.name.Name class represents a DNS name as a tuple of labels. - Instances of the class are immutable. - - @ivar labels: The tuple of labels in the name. Each label is a string of - up to 63 octets.""" + The dns.name.Name class represents a DNS name as a tuple of + labels. Each label is a `binary` in DNS wire format. Instances + of the class are immutable. + """ __slots__ = ['labels'] def __init__(self, labels): - """Initialize a domain name from a list of labels. - @param labels: the labels - @type labels: any iterable whose values are strings + """*labels* is any iterable whose values are ``text`` or ``binary``. """ - labels = [_ensure_bytes(x) for x in labels] + + labels = [_maybe_convert_to_binary(x) for x in labels] super(Name, self).__setattr__('labels', tuple(labels)) _validate_labels(self.labels) def __setattr__(self, name, value): + # Names are immutable raise TypeError("object doesn't support attribute assignment") def __copy__(self): @@ -338,6 +345,7 @@ class Name(object): return Name(copy.deepcopy(self.labels, memo)) def __getstate__(self): + # Names can be pickled return {'labels': self.labels} def __setstate__(self, state): @@ -346,21 +354,24 @@ class Name(object): def is_absolute(self): """Is the most significant label of this name the root label? - @rtype: bool + + Returns a ``bool``. """ return len(self.labels) > 0 and self.labels[-1] == b'' def is_wild(self): """Is this name wild? (I.e. Is the least significant label '*'?) - @rtype: bool + + Returns a ``bool``. """ return len(self.labels) > 0 and self.labels[0] == b'*' def __hash__(self): """Return a case-insensitive hash of the name. - @rtype: int + + Returns an ``int``. """ h = long(0) @@ -370,20 +381,35 @@ class Name(object): return int(h % maxint) def fullcompare(self, other): - """Compare two names, returning a 3-tuple (relation, order, nlabels). + """Compare two names, returning a 3-tuple + ``(relation, order, nlabels)``. - I{relation} describes the relation ship between the names, - and is one of: dns.name.NAMERELN_NONE, - dns.name.NAMERELN_SUPERDOMAIN, dns.name.NAMERELN_SUBDOMAIN, - dns.name.NAMERELN_EQUAL, or dns.name.NAMERELN_COMMONANCESTOR + *relation* describes the relation ship between the names, + and is one of: ``dns.name.NAMERELN_NONE``, + ``dns.name.NAMERELN_SUPERDOMAIN``, ``dns.name.NAMERELN_SUBDOMAIN``, + ``dns.name.NAMERELN_EQUAL``, or ``dns.name.NAMERELN_COMMONANCESTOR``. - I{order} is < 0 if self < other, > 0 if self > other, and == - 0 if self == other. A relative name is always less than an + *order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and == + 0 if *self* == *other*. A relative name is always less than an absolute name. If both names have the same relativity, then the DNSSEC order relation is used to order them. - I{nlabels} is the number of significant labels that the two names + *nlabels* is the number of significant labels that the two names have in common. + + Here are some examples. Names ending in "." are absolute names, + those not ending in "." are relative names. + + ============= ============= =========== ===== ======= + self other relation order nlabels + ============= ============= =========== ===== ======= + www.example. www.example. equal 0 3 + www.example. example. subdomain > 0 2 + example. www.example. superdomain < 0 2 + example1.com. example2.com. common anc. < 0 2 + example1 example2. none < 0 0 + example1. example2 none > 0 0 + ============= ============= =========== ===== ======= """ sabs = self.is_absolute() @@ -433,8 +459,10 @@ class Name(object): def is_subdomain(self, other): """Is self a subdomain of other? - The notion of subdomain includes equality. - @rtype: bool + Note that the notion of subdomain includes equality, e.g. + "dnpython.org" is a subdomain of itself. + + Returns a ``bool``. """ (nr, o, nl) = self.fullcompare(other) @@ -445,8 +473,10 @@ class Name(object): def is_superdomain(self, other): """Is self a superdomain of other? - The notion of subdomain includes equality. - @rtype: bool + Note that the notion of superdomain includes equality, e.g. + "dnpython.org" is a superdomain of itself. + + Returns a ``bool``. """ (nr, o, nl) = self.fullcompare(other) @@ -457,7 +487,6 @@ class Name(object): def canonicalize(self): """Return a name which is equal to the current name, but is in DNSSEC canonical form. - @rtype: dns.name.Name object """ return Name([x.lower() for x in self.labels]) @@ -505,10 +534,13 @@ class Name(object): return self.to_text(False) def to_text(self, omit_final_dot=False): - """Convert name to text format. - @param omit_final_dot: If True, don't emit the final dot (denoting the - root label) for absolute names. The default is False. - @rtype: string + """Convert name to DNS text format. + + *omit_final_dot* is a ``bool``. If True, don't emit the final + dot (denoting the root label) for absolute names. The default + is False. + + Returns a ``text``. """ if len(self.labels) == 0: @@ -527,16 +559,17 @@ class Name(object): IDN ACE labels are converted to Unicode. - @param omit_final_dot: If True, don't emit the final dot (denoting the - root label) for absolute names. The default is False. - @type omit_final_dot: bool - @param idna_codec: IDNA encoder/decoder. If None, the - IDNA_2003_Practical encoder/decoder is used. The IDNA_2003_Practical - decoder does not impose any policy, it just decodes punycode, so if - you don't want checking for compliance, you can use this decoder for - IDNA2008 as well. - @type idna_codec: dns.name.IDNA - @rtype: string + *omit_final_dot* is a ``bool``. If True, don't emit the final + dot (denoting the root label) for absolute names. The default + is False. + *idna_codec* specifies the IDNA encoder/decoder. If None, the + dns.name.IDNA_2003_Practical encoder/decoder is used. + The IDNA_2003_Practical decoder does + not impose any policy, it just decodes punycode, so if you + don't want checking for compliance, you can use this decoder + for IDNA2008 as well. + + Returns a ``text``. """ if len(self.labels) == 0: @@ -554,15 +587,18 @@ class Name(object): def to_digestable(self, origin=None): """Convert name to a format suitable for digesting in hashes. - The name is canonicalized and converted to uncompressed wire format. + The name is canonicalized and converted to uncompressed wire + format. All names in wire format are absolute. If the name + is a relative name, then an origin must be supplied. - @param origin: If the name is relative and origin is not None, then - origin will be appended to it. - @type origin: dns.name.Name object - @raises NeedAbsoluteNameOrOrigin: All names in wire format are - absolute. If self is a relative name, then an origin must be supplied; - if it is missing, then this exception is raised - @rtype: string + *origin* is a ``dns.name.Name`` or ``None``. If the name is + relative and origin is not ``None``, then origin will be appended + to the name. + + Raises ``dns.name.NeedAbsoluteNameOrOrigin`` if the name is + relative and no origin was provided. + + Returns a ``binary``. """ if not self.is_absolute(): @@ -579,19 +615,21 @@ class Name(object): def to_wire(self, file=None, compress=None, origin=None): """Convert name to wire format, possibly compressing it. - @param file: the file where the name is emitted (typically - a BytesIO file). If None, a string containing the wire name - will be returned. - @type file: file or None - @param compress: The compression table. If None (the default) names - will not be compressed. - @type compress: dict - @param origin: If the name is relative and origin is not None, then - origin will be appended to it. - @type origin: dns.name.Name object - @raises NeedAbsoluteNameOrOrigin: All names in wire format are - absolute. If self is a relative name, then an origin must be supplied; - if it is missing, then this exception is raised + *file* is the file where the name is emitted (typically a + BytesIO file). If ``None`` (the default), a ``binary`` + containing the wire name will be returned. + + *compress*, a ``dict``, is the compression table to use. If + ``None`` (the default), names will not be compressed. + + *origin* is a ``dns.name.Name`` or ``None``. If the name is + relative and origin is not ``None``, then *origin* will be appended + to it. + + Raises ``dns.name.NeedAbsoluteNameOrOrigin`` if the name is + relative and no origin was provided. + + Returns a ``binary`` or ``None``. """ if file is None: @@ -634,7 +672,8 @@ class Name(object): def __len__(self): """The length of the name (in labels). - @rtype: int + + Returns an ``int``. """ return len(self.labels) @@ -649,14 +688,14 @@ class Name(object): return self.relativize(other) def split(self, depth): - """Split a name into a prefix and suffix at depth. + """Split a name into a prefix and suffix names at the specified depth. - @param depth: the number of labels in the suffix - @type depth: int - @raises ValueError: the depth was not >= 0 and <= the length of the + *depth* is an ``int`` specifying the number of labels in the suffix + + Raises ``ValueError`` if *depth* was not >= 0 and <= the length of the name. - @returns: the tuple (prefix, suffix) - @rtype: tuple + + Returns the tuple ``(prefix, suffix)``. """ l = len(self.labels) @@ -671,9 +710,11 @@ class Name(object): def concatenate(self, other): """Return a new name which is the concatenation of self and other. - @rtype: dns.name.Name object - @raises AbsoluteConcatenation: self is absolute and other is - not the empty name + + Raises ``dns.name.AbsoluteConcatenation`` if the name is + absolute and *other* is not the empty name. + + Returns a ``dns.name.Name``. """ if self.is_absolute() and len(other) > 0: @@ -683,9 +724,14 @@ class Name(object): return Name(labels) def relativize(self, origin): - """If self is a subdomain of origin, return a new name which is self - relative to origin. Otherwise return self. - @rtype: dns.name.Name object + """If the name is a subdomain of *origin*, return a new name which is + the name relative to origin. Otherwise return the name. + + For example, relativizing ``www.dnspython.org.`` to origin + ``dnspython.org.`` returns the name ``www``. Relativizing ``example.`` + to origin ``dnspython.org.`` returns ``example.``. + + Returns a ``dns.name.Name``. """ if origin is not None and self.is_subdomain(origin): @@ -694,9 +740,14 @@ class Name(object): return self def derelativize(self, origin): - """If self is a relative name, return a new name which is the - concatenation of self and origin. Otherwise return self. - @rtype: dns.name.Name object + """If the name is a relative name, return a new name which is the + concatenation of the name and origin. Otherwise return the name. + + For example, derelativizing ``www`` to origin ``dnspython.org.`` + returns the name ``www.dnspython.org.``. Derelativizing ``example.`` + to origin ``dnspython.org.`` returns ``example.``. + + Returns a ``dns.name.Name``. """ if not self.is_absolute(): @@ -705,11 +756,14 @@ class Name(object): return self def choose_relativity(self, origin=None, relativize=True): - """Return a name with the relativity desired by the caller. If - origin is None, then self is returned. Otherwise, if - relativize is true the name is relativized, and if relativize is - false the name is derelativized. - @rtype: dns.name.Name object + """Return a name with the relativity desired by the caller. + + If *origin* is ``None``, then the name is returned. + Otherwise, if *relativize* is ``True`` the name is + relativized, and if *relativize* is ``False`` the name is + derelativized. + + Returns a ``dns.name.Name``. """ if origin: @@ -722,31 +776,41 @@ class Name(object): def parent(self): """Return the parent of the name. - @rtype: dns.name.Name object - @raises NoParent: the name is either the root name or the empty name, - and thus has no parent. + + For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``. + + Raises ``dns.name.NoParent`` if the name is either the root name or the + empty name, and thus has no parent. + + Returns a ``dns.name.Name``. """ + if self == root or self == empty: raise NoParent return Name(self.labels[1:]) +#: The root name, '.' root = Name([b'']) -empty = Name([]) +#: The empty name. +empty = Name([]) def from_unicode(text, origin=root, idna_codec=None): """Convert unicode text into a Name object. - Labels are encoded in IDN ACE form. + Labels are encoded in IDN ACE form according to rules specified by + the IDNA codec. - @param text: The text to convert into a name. - @type text: Unicode string - @param origin: The origin to append to non-absolute names. - @type origin: dns.name.Name - @param idna_codec: IDNA encoder/decoder. If None, the default IDNA 2003 - encoder/decoder is used. - @type idna_codec: dns.name.IDNA - @rtype: dns.name.Name object + *text*, a ``text``, is the text to convert into a name. + + *origin*, a ``dns.name.Name``, specifies the origin to + append to non-absolute names. The default is the root name. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + Returns a ``dns.name.Name``. """ if not isinstance(text, text_type): @@ -809,14 +873,16 @@ def from_unicode(text, origin=root, idna_codec=None): def from_text(text, origin=root, idna_codec=None): """Convert text into a Name object. - @param text: The text to convert into a name. - @type text: string - @param origin: The origin to append to non-absolute names. - @type origin: dns.name.Name - @param idna_codec: IDNA encoder/decoder. If None, the default IDNA 2003 - encoder/decoder is used. - @type idna_codec: dns.name.IDNA - @rtype: dns.name.Name object + *text*, a ``text``, is the text to convert into a name. + + *origin*, a ``dns.name.Name``, specifies the origin to + append to non-absolute names. The default is the root name. + + *idna_codec*, a ``dns.name.IDNACodec``, specifies the IDNA + encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder + is used. + + Returns a ``dns.name.Name``. """ if isinstance(text, text_type): @@ -878,17 +944,21 @@ def from_text(text, origin=root, idna_codec=None): def from_wire(message, current): """Convert possibly compressed wire format into a Name. - @param message: the entire DNS message - @type message: string - @param current: the offset of the beginning of the name from the start - of the message - @type current: int - @raises dns.name.BadPointer: a compression pointer did not point backwards - in the message - @raises dns.name.BadLabelType: an invalid label type was encountered. - @returns: a tuple consisting of the name that was read and the number - of bytes of the wire format message which were consumed reading it - @rtype: (dns.name.Name object, int) tuple + + *message* is a ``binary`` containing an entire DNS message in DNS + wire form. + + *current*, an ``int``, is the offset of the beginning of the name + from the start of the message + + Raises ``dns.name.BadPointer`` if a compression pointer did not + point backwards in the message. + + Raises ``dns.name.BadLabelType`` if an invalid label type was encountered. + + Returns a ``(dns.name.Name, int)`` tuple consisting of the name + that was read and the number of bytes of the wire format message + which were consumed reading it. """ if not isinstance(message, binary_type): diff --git a/src/dns/name.pyi b/src/dns/name.pyi new file mode 100644 index 00000000..5a8061b2 --- /dev/null +++ b/src/dns/name.pyi @@ -0,0 +1,35 @@ +from typing import Optional, Union, Tuple, Iterable, List + +class Name: + def is_subdomain(self, o : Name) -> bool: ... + def is_superdomain(self, o : Name) -> bool: ... + def __init__(self, labels : Iterable[Union[bytes,str]]) -> None: + self.labels : List[bytes] + def is_absolute(self) -> bool: ... + def is_wild(self) -> bool: ... + def fullcompare(self, other) -> Tuple[int,int,int]: ... + def canonicalize(self) -> Name: ... + def __lt__(self, other : Name): ... + def __le__(self, other : Name): ... + def __ge__(self, other : Name): ... + def __gt__(self, other : Name): ... + def to_text(self, omit_final_dot=False) -> str: ... + def to_unicode(self, omit_final_dot=False, idna_codec=None) -> str: ... + def to_digestable(self, origin=None) -> bytes: ... + def to_wire(self, file=None, compress=None, origin=None) -> Optional[bytes]: ... + def __add__(self, other : Name): ... + def __sub__(self, other : Name): ... + def split(self, depth) -> List[Tuple[str,str]]: ... + def concatenate(self, other : Name) -> Name: ... + def relativize(self, origin): ... + def derelativize(self, origin): ... + def choose_relativity(self, origin : Optional[Name] = None, relativize=True): ... + def parent(self) -> Name: ... + +class IDNACodec: + pass + +def from_text(text, origin : Optional[Name] = Name('.'), idna_codec : Optional[IDNACodec] = None) -> Name: + ... + +empty : Name diff --git a/src/dns/namedict.py b/src/dns/namedict.py index 58e40344..37a13104 100644 --- a/src/dns/namedict.py +++ b/src/dns/namedict.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # Copyright (C) 2016 Coresec Systems AB # # Permission to use, copy, modify, and distribute this software and its @@ -31,20 +33,20 @@ from ._compat import xrange class NameDict(collections.MutableMapping): - """A dictionary whose keys are dns.name.Name objects. - @ivar max_depth: the maximum depth of the keys that have ever been - added to the dictionary. - @type max_depth: int - @ivar max_depth_items: the number of items of maximum depth - @type max_depth_items: int + + In addition to being like a regular Python dictionary, this + dictionary can also get the deepest match for a given key. """ __slots__ = ["max_depth", "max_depth_items", "__store"] def __init__(self, *args, **kwargs): + super(NameDict, self).__init__() self.__store = dict() + #: the maximum depth of the keys that have ever been added self.max_depth = 0 + #: the number of items of maximum depth self.max_depth_items = 0 self.update(dict(*args, **kwargs)) @@ -83,14 +85,16 @@ class NameDict(collections.MutableMapping): return key in self.__store def get_deepest_match(self, name): - """Find the deepest match to I{name} in the dictionary. + """Find the deepest match to *fname* in the dictionary. The deepest match is the longest name in the dictionary which is - a superdomain of I{name}. + a superdomain of *name*. Note that *superdomain* includes matching + *name* itself. - @param name: the name - @type name: dns.name.Name object - @rtype: (key, value) tuple + *name*, a ``dns.name.Name``, the name to find. + + Returns a ``(key, value)`` where *key* is the deepest + ``dns.name.Name``, and *value* is the value associated with *key*. """ depth = len(name) diff --git a/src/dns/node.py b/src/dns/node.py index 7c25060e..8a7f19f5 100644 --- a/src/dns/node.py +++ b/src/dns/node.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -24,19 +26,12 @@ import dns.renderer class Node(object): - """A DNS node. - - A node is a set of rdatasets - - @ivar rdatasets: the node's rdatasets - @type rdatasets: list of dns.rdataset.Rdataset objects""" + """A Node is a set of rdatasets.""" __slots__ = ['rdatasets'] def __init__(self): - """Initialize a DNS node. - """ - + #: the set of rdatsets, represented as a list. self.rdatasets = [] def to_text(self, name, **kw): @@ -44,9 +39,10 @@ class Node(object): Each rdataset at the node is printed. Any keyword arguments to this method are passed on to the rdataset's to_text() method. - @param name: the owner name of the rdatasets - @type name: dns.name.Name object - @rtype: string + + *name*, a ``dns.name.Name`` or ``text``, the owner name of the rdatasets. + + Returns a ``text``. """ s = StringIO() @@ -60,10 +56,6 @@ class Node(object): return '' def __eq__(self, other): - """Two nodes are equal if they have the same rdatasets. - - @rtype: bool - """ # # This is inefficient. Good thing we don't need to do it much. # @@ -89,11 +81,11 @@ class Node(object): """Find an rdataset matching the specified properties in the current node. - @param rdclass: The class of the rdataset - @type rdclass: int - @param rdtype: The type of the rdataset - @type rdtype: int - @param covers: The covered type. Usually this value is + *rdclass*, an ``int``, the class of the rdataset. + + *rdtype*, an ``int``, the type of the rdataset. + + *covers*, an ``int``, the covered type. Usually this value is dns.rdatatype.NONE, but if the rdtype is dns.rdatatype.SIG or dns.rdatatype.RRSIG, then the covers value will be the rdata type the SIG/RRSIG covers. The library treats the SIG and RRSIG @@ -101,12 +93,13 @@ class Node(object): types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This makes RRSIGs much easier to work with than if RRSIGs covering different rdata types were aggregated into a single RRSIG rdataset. - @type covers: int - @param create: If True, create the rdataset if it is not found. - @type create: bool - @raises KeyError: An rdataset of the desired type and class does - not exist and I{create} is not True. - @rtype: dns.rdataset.Rdataset object + + *create*, a ``bool``. If True, create the rdataset if it is not found. + + Raises ``KeyError`` if an rdataset of the desired type and class does + not exist and *create* is not ``True``. + + Returns a ``dns.rdataset.Rdataset``. """ for rds in self.rdatasets: @@ -124,17 +117,24 @@ class Node(object): current node. None is returned if an rdataset of the specified type and - class does not exist and I{create} is not True. + class does not exist and *create* is not ``True``. - @param rdclass: The class of the rdataset - @type rdclass: int - @param rdtype: The type of the rdataset - @type rdtype: int - @param covers: The covered type. - @type covers: int - @param create: If True, create the rdataset if it is not found. - @type create: bool - @rtype: dns.rdataset.Rdataset object or None + *rdclass*, an ``int``, the class of the rdataset. + + *rdtype*, an ``int``, the type of the rdataset. + + *covers*, an ``int``, the covered type. Usually this value is + dns.rdatatype.NONE, but if the rdtype is dns.rdatatype.SIG or + dns.rdatatype.RRSIG, then the covers value will be the rdata + type the SIG/RRSIG covers. The library treats the SIG and RRSIG + types as if they were a family of + types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This makes RRSIGs much + easier to work with than if RRSIGs covering different rdata + types were aggregated into a single RRSIG rdataset. + + *create*, a ``bool``. If True, create the rdataset if it is not found. + + Returns a ``dns.rdataset.Rdataset`` or ``None``. """ try: @@ -149,12 +149,11 @@ class Node(object): If a matching rdataset does not exist, it is not an error. - @param rdclass: The class of the rdataset - @type rdclass: int - @param rdtype: The type of the rdataset - @type rdtype: int - @param covers: The covered type. - @type covers: int + *rdclass*, an ``int``, the class of the rdataset. + + *rdtype*, an ``int``, the type of the rdataset. + + *covers*, an ``int``, the covered type. """ rds = self.get_rdataset(rdclass, rdtype, covers) @@ -164,11 +163,16 @@ class Node(object): def replace_rdataset(self, replacement): """Replace an rdataset. - It is not an error if there is no rdataset matching I{replacement}. + It is not an error if there is no rdataset matching *replacement*. - Ownership of the I{replacement} object is transferred to the node; - in other words, this method does not store a copy of I{replacement} - at the node, it stores I{replacement} itself. + Ownership of the *replacement* object is transferred to the node; + in other words, this method does not store a copy of *replacement* + at the node, it stores *replacement* itself. + + *replacement*, a ``dns.rdataset.Rdataset``. + + Raises ``ValueError`` if *replacement* is not a + ``dns.rdataset.Rdataset``. """ if not isinstance(replacement, dns.rdataset.Rdataset): diff --git a/src/dns/node.pyi b/src/dns/node.pyi new file mode 100644 index 00000000..0997edf9 --- /dev/null +++ b/src/dns/node.pyi @@ -0,0 +1,17 @@ +from typing import List, Optional, Union +from . import rdataset, rdatatype, name +class Node: + def __init__(self): + self.rdatasets : List[rdataset.Rdataset] + def to_text(self, name : Union[str,name.Name], **kw) -> str: + ... + def find_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE, + create=False) -> rdataset.Rdataset: + ... + def get_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE, + create=False) -> Optional[rdataset.Rdataset]: + ... + def delete_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE): + ... + def replace_rdataset(self, replacement : rdataset.Rdataset) -> None: + ... diff --git a/src/dns/opcode.py b/src/dns/opcode.py index 70d704fb..c0735ba4 100644 --- a/src/dns/opcode.py +++ b/src/dns/opcode.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -17,10 +19,15 @@ import dns.exception +#: Query QUERY = 0 +#: Inverse Query (historical) IQUERY = 1 +#: Server Status (unspecified and unimplemented anywhere) STATUS = 2 +#: Notify NOTIFY = 4 +#: Dynamic Update UPDATE = 5 _by_text = { @@ -35,21 +42,21 @@ _by_text = { # cannot make any mistakes (e.g. omissions, cut-and-paste errors) that # would cause the mapping not to be true inverse. -_by_value = dict((y, x) for x, y in _by_text.items()) +_by_value = {y: x for x, y in _by_text.items()} class UnknownOpcode(dns.exception.DNSException): - """An DNS opcode is unknown.""" def from_text(text): """Convert text into an opcode. - @param text: the textual opcode - @type text: string - @raises UnknownOpcode: the opcode is unknown - @rtype: int + *text*, a ``text``, the textual opcode + + Raises ``dns.opcode.UnknownOpcode`` if the opcode is unknown. + + Returns an ``int``. """ if text.isdigit(): @@ -65,8 +72,9 @@ def from_text(text): def from_flags(flags): """Extract an opcode from DNS message flags. - @param flags: int - @rtype: int + *flags*, an ``int``, the DNS flags. + + Returns an ``int``. """ return (flags & 0x7800) >> 11 @@ -75,7 +83,10 @@ def from_flags(flags): def to_flags(value): """Convert an opcode to a value suitable for ORing into DNS message flags. - @rtype: int + + *value*, an ``int``, the DNS opcode value. + + Returns an ``int``. """ return (value << 11) & 0x7800 @@ -84,10 +95,11 @@ def to_flags(value): def to_text(value): """Convert an opcode to text. - @param value: the opcdoe - @type value: int - @raises UnknownOpcode: the opcode is unknown - @rtype: string + *value*, an ``int`` the opcode value, + + Raises ``dns.opcode.UnknownOpcode`` if the opcode is unknown. + + Returns a ``text``. """ text = _by_value.get(value) @@ -97,11 +109,11 @@ def to_text(value): def is_update(flags): - """True if the opcode in flags is UPDATE. + """Is the opcode in flags UPDATE? - @param flags: DNS flags - @type flags: int - @rtype: bool + *flags*, an ``int``, the DNS message flags. + + Returns a ``bool``. """ return from_flags(flags) == UPDATE diff --git a/src/dns/py.typed b/src/dns/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/dns/query.py b/src/dns/query.py index bfecd43e..c0c517cc 100644 --- a/src/dns/query.py +++ b/src/dns/query.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -28,11 +30,12 @@ import dns.exception import dns.inet import dns.name import dns.message +import dns.rcode import dns.rdataclass import dns.rdatatype -from ._compat import long, string_types +from ._compat import long, string_types, PY3 -if sys.version_info > (3,): +if PY3: select_error = OSError else: select_error = select.error @@ -42,34 +45,36 @@ else: socket_factory = socket.socket class UnexpectedSource(dns.exception.DNSException): - """A DNS query response came from an unexpected address or port.""" class BadResponse(dns.exception.FormError): - """A DNS query response does not respond to the question asked.""" +class TransferError(dns.exception.DNSException): + """A zone transfer response got a non-zero rcode.""" + + def __init__(self, rcode): + message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode) + super(TransferError, self).__init__(message) + self.rcode = rcode + + def _compute_expiration(timeout): if timeout is None: return None else: return time.time() + timeout +# This module can use either poll() or select() as the "polling backend". +# +# A backend function takes an fd, bools for readability, writablity, and +# error detection, and a timeout. def _poll_for(fd, readable, writable, error, timeout): - """Poll polling backend. - @param fd: File descriptor - @type fd: int - @param readable: Whether to wait for readability - @type readable: bool - @param writable: Whether to wait for writability - @type writable: bool - @param timeout: Deadline timeout (expiration time, in seconds) - @type timeout: float - @return True on success, False on timeout - """ + """Poll polling backend.""" + event_mask = 0 if readable: event_mask |= select.POLLIN @@ -90,17 +95,8 @@ def _poll_for(fd, readable, writable, error, timeout): def _select_for(fd, readable, writable, error, timeout): - """Select polling backend. - @param fd: File descriptor - @type fd: int - @param readable: Whether to wait for readability - @type readable: bool - @param writable: Whether to wait for writability - @type writable: bool - @param timeout: Deadline timeout (expiration time, in seconds) - @type timeout: float - @return True on success, False on timeout - """ + """Select polling backend.""" + rset, wset, xset = [], [], [] if readable: @@ -119,6 +115,10 @@ def _select_for(fd, readable, writable, error, timeout): def _wait_for(fd, readable, writable, error, expiration): + # Use the selected polling backend to wait for any of the specified + # events. An "expiration" absolute time is converted into a relative + # timeout. + done = False while not done: if expiration is None: @@ -137,9 +137,8 @@ def _wait_for(fd, readable, writable, error, expiration): def _set_polling_backend(fn): - """ - Internal API. Do not use. - """ + # Internal API. Do not use. + global _polling_backend _polling_backend = fn @@ -165,8 +164,11 @@ def _addresses_equal(af, a1, a2): # Convert the first value of the tuple, which is a textual format # address into binary form, so that we are not confused by different # textual representations of the same address - n1 = dns.inet.inet_pton(af, a1[0]) - n2 = dns.inet.inet_pton(af, a2[0]) + try: + n1 = dns.inet.inet_pton(af, a1[0]) + n2 = dns.inet.inet_pton(af, a2[0]) + except dns.exception.SyntaxError: + return False return n1 == n2 and a1[1:] == a2[1:] @@ -193,68 +195,140 @@ def _destination_and_source(af, where, port, source, source_port): return (af, destination, source) +def send_udp(sock, what, destination, expiration=None): + """Send a DNS message to the specified UDP socket. + + *sock*, a ``socket``. + + *what*, a ``binary`` or ``dns.message.Message``, the message to send. + + *destination*, a destination tuple appropriate for the address family + of the socket, specifying where to send the query. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + Returns an ``(int, float)`` tuple of bytes sent and the sent time. + """ + + if isinstance(what, dns.message.Message): + what = what.to_wire() + _wait_for_writable(sock, expiration) + sent_time = time.time() + n = sock.sendto(what, destination) + return (n, sent_time) + + +def receive_udp(sock, destination, expiration=None, + ignore_unexpected=False, one_rr_per_rrset=False, + keyring=None, request_mac=b'', ignore_trailing=False): + """Read a DNS message from a UDP socket. + + *sock*, a ``socket``. + + *destination*, a destination tuple appropriate for the address family + of the socket, specifying where the associated query was sent. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *request_mac*, a ``binary``, the MAC of the request (for TSIG). + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + Raises if the message is malformed, if network errors occur, of if + there is a timeout. + + Returns a ``dns.message.Message`` object. + """ + + wire = b'' + while 1: + _wait_for_readable(sock, expiration) + (wire, from_address) = sock.recvfrom(65535) + if _addresses_equal(sock.family, from_address, destination) or \ + (dns.inet.is_multicast(destination[0]) and + from_address[1:] == destination[1:]): + break + if not ignore_unexpected: + raise UnexpectedSource('got a response from ' + '%s instead of %s' % (from_address, + destination)) + received_time = time.time() + r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing) + return (r, received_time) + def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False): + ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False): """Return the response obtained after sending a query via UDP. - @param q: the query - @type q: dns.message.Message - @param where: where to send the message - @type where: string containing an IPv4 or IPv6 address - @param timeout: The number of seconds to wait before the query times out. - If None, the default, wait forever. - @type timeout: float - @param port: The port to which to send the message. The default is 53. - @type port: int - @param af: the address family to use. The default is None, which - causes the address family to use to be inferred from the form of where. - If the inference attempt fails, AF_INET is used. - @type af: int - @rtype: dns.message.Message object - @param source: source address. The default is the wildcard address. - @type source: string - @param source_port: The port from which to send the message. + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``text`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *af*, an ``int``, the address family to use. The default is ``None``, + which causes the address family to use to be inferred from the form of + *where*. If the inference attempt fails, AF_INET is used. This + parameter is historical; you need never set it. + + *source*, a ``text`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. The default is 0. - @type source_port: int - @param ignore_unexpected: If True, ignore responses from unexpected - sources. The default is False. - @type ignore_unexpected: bool - @param one_rr_per_rrset: Put each RR into its own RRset - @type one_rr_per_rrset: bool + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + Returns a ``dns.message.Message``. """ wire = q.to_wire() (af, destination, source) = _destination_and_source(af, where, port, source, source_port) s = socket_factory(af, socket.SOCK_DGRAM, 0) - begin_time = None + received_time = None + sent_time = None try: expiration = _compute_expiration(timeout) s.setblocking(0) if source is not None: s.bind(source) - _wait_for_writable(s, expiration) - begin_time = time.time() - s.sendto(wire, destination) - while 1: - _wait_for_readable(s, expiration) - (wire, from_address) = s.recvfrom(65535) - if _addresses_equal(af, from_address, destination) or \ - (dns.inet.is_multicast(where) and - from_address[1:] == destination[1:]): - break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) + (_, sent_time) = send_udp(s, wire, destination, expiration) + (r, received_time) = receive_udp(s, destination, expiration, + ignore_unexpected, one_rr_per_rrset, + q.keyring, q.mac, ignore_trailing) finally: - if begin_time is None: + if sent_time is None or received_time is None: response_time = 0 else: - response_time = time.time() - begin_time + response_time = received_time - sent_time s.close() - r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac, - one_rr_per_rrset=one_rr_per_rrset) r.time = response_time if not q.is_response(r): raise BadResponse @@ -290,6 +364,67 @@ def _net_write(sock, data, expiration): current += sock.send(data[current:]) +def send_tcp(sock, what, expiration=None): + """Send a DNS message to the specified TCP socket. + + *sock*, a ``socket``. + + *what*, a ``binary`` or ``dns.message.Message``, the message to send. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + Returns an ``(int, float)`` tuple of bytes sent and the sent time. + """ + + if isinstance(what, dns.message.Message): + what = what.to_wire() + l = len(what) + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = struct.pack("!H", l) + what + _wait_for_writable(sock, expiration) + sent_time = time.time() + _net_write(sock, tcpmsg, expiration) + return (len(tcpmsg), sent_time) + +def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, + keyring=None, request_mac=b'', ignore_trailing=False): + """Read a DNS message from a TCP socket. + + *sock*, a ``socket``. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *request_mac*, a ``binary``, the MAC of the request (for TSIG). + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + Raises if the message is malformed, if network errors occur, of if + there is a timeout. + + Returns a ``dns.message.Message`` object. + """ + + ldata = _net_read(sock, 2, expiration) + (l,) = struct.unpack("!H", ldata) + wire = _net_read(sock, l, expiration) + received_time = time.time() + r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing) + return (r, received_time) + def _connect(s, address): try: s.connect(address) @@ -305,30 +440,37 @@ def _connect(s, address): def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, - one_rr_per_rrset=False): + one_rr_per_rrset=False, ignore_trailing=False): """Return the response obtained after sending a query via TCP. - @param q: the query - @type q: dns.message.Message object - @param where: where to send the message - @type where: string containing an IPv4 or IPv6 address - @param timeout: The number of seconds to wait before the query times out. - If None, the default, wait forever. - @type timeout: float - @param port: The port to which to send the message. The default is 53. - @type port: int - @param af: the address family to use. The default is None, which - causes the address family to use to be inferred from the form of where. - If the inference attempt fails, AF_INET is used. - @type af: int - @rtype: dns.message.Message object - @param source: source address. The default is the wildcard address. - @type source: string - @param source_port: The port from which to send the message. + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``text`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *af*, an ``int``, the address family to use. The default is ``None``, + which causes the address family to use to be inferred from the form of + *where*. If the inference attempt fails, AF_INET is used. This + parameter is historical; you need never set it. + + *source*, a ``text`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. The default is 0. - @type source_port: int - @param one_rr_per_rrset: Put each RR into its own RRset - @type one_rr_per_rrset: bool + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + Returns a ``dns.message.Message``. """ wire = q.to_wire() @@ -336,6 +478,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, source, source_port) s = socket_factory(af, socket.SOCK_STREAM, 0) begin_time = None + received_time = None try: expiration = _compute_expiration(timeout) s.setblocking(0) @@ -343,25 +486,15 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, if source is not None: s.bind(source) _connect(s, destination) - - l = len(wire) - - # copying the wire into tcpmsg is inefficient, but lets us - # avoid writev() or doing a short write that would get pushed - # onto the net - tcpmsg = struct.pack("!H", l) + wire - _net_write(s, tcpmsg, expiration) - ldata = _net_read(s, 2, expiration) - (l,) = struct.unpack("!H", ldata) - wire = _net_read(s, l, expiration) + send_tcp(s, wire, expiration) + (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, + q.keyring, q.mac, ignore_trailing) finally: - if begin_time is None: + if begin_time is None or received_time is None: response_time = 0 else: - response_time = time.time() - begin_time + response_time = received_time - begin_time s.close() - r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac, - one_rr_per_rrset=one_rr_per_rrset) r.time = response_time if not q.is_response(r): raise BadResponse @@ -374,51 +507,59 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, use_udp=False, keyalgorithm=dns.tsig.default_algorithm): """Return a generator for the responses to a zone transfer. - @param where: where to send the message - @type where: string containing an IPv4 or IPv6 address - @param zone: The name of the zone to transfer - @type zone: dns.name.Name object or string - @param rdtype: The type of zone transfer. The default is - dns.rdatatype.AXFR. - @type rdtype: int or string - @param rdclass: The class of the zone transfer. The default is - dns.rdataclass.IN. - @type rdclass: int or string - @param timeout: The number of seconds to wait for each response message. - If None, the default, wait forever. - @type timeout: float - @param port: The port to which to send the message. The default is 53. - @type port: int - @param keyring: The TSIG keyring to use - @type keyring: dict - @param keyname: The name of the TSIG key to use - @type keyname: dns.name.Name object or string - @param relativize: If True, all names in the zone will be relativized to - the zone origin. It is essential that the relativize setting matches - the one specified to dns.zone.from_xfr(). - @type relativize: bool - @param af: the address family to use. The default is None, which - causes the address family to use to be inferred from the form of where. - If the inference attempt fails, AF_INET is used. - @type af: int - @param lifetime: The total number of seconds to spend doing the transfer. - If None, the default, then there is no limit on the time the transfer may - take. - @type lifetime: float - @rtype: generator of dns.message.Message objects. - @param source: source address. The default is the wildcard address. - @type source: string - @param source_port: The port from which to send the message. + *where*. If the inference attempt fails, AF_INET is used. This + parameter is historical; you need never set it. + + *zone*, a ``dns.name.Name`` or ``text``, the name of the zone to transfer. + + *rdtype*, an ``int`` or ``text``, the type of zone transfer. The + default is ``dns.rdatatype.AXFR``. ``dns.rdatatype.IXFR`` can be + used to do an incremental transfer instead. + + *rdclass*, an ``int`` or ``text``, the class of the zone transfer. + The default is ``dns.rdataclass.IN``. + + *timeout*, a ``float``, the number of seconds to wait for each + response message. If None, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *keyname*, a ``dns.name.Name`` or ``text``, the name of the TSIG + key to use. + + *relativize*, a ``bool``. If ``True``, all names in the zone will be + relativized to the zone origin. It is essential that the + relativize setting matches the one specified to + ``dns.zone.from_xfr()`` if using this generator to make a zone. + + *af*, an ``int``, the address family to use. The default is ``None``, + which causes the address family to use to be inferred from the form of + *where*. If the inference attempt fails, AF_INET is used. This + parameter is historical; you need never set it. + + *lifetime*, a ``float``, the total number of seconds to spend + doing the transfer. If ``None``, the default, then there is no + limit on the time the transfer may take. + + *source*, a ``text`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. The default is 0. - @type source_port: int - @param serial: The SOA serial number to use as the base for an IXFR diff - sequence (only meaningful if rdtype == dns.rdatatype.IXFR). - @type serial: int - @param use_udp: Use UDP (only meaningful for IXFR) - @type use_udp: bool - @param keyalgorithm: The TSIG algorithm to use; defaults to - dns.tsig.default_algorithm - @type keyalgorithm: string + + *serial*, an ``int``, the SOA serial number to use as the base for + an IXFR diff sequence (only meaningful if *rdtype* is + ``dns.rdatatype.IXFR``). + + *use_udp*, a ``bool``. If ``True``, use UDP (only meaningful for IXFR). + + *keyalgorithm*, a ``dns.name.Name`` or ``text``, the TSIG algorithm to use. + + Raises on errors, and so does the generator. + + Returns a generator of ``dns.message.Message`` objects. """ if isinstance(zone, string_types): @@ -481,6 +622,9 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, xfr=True, origin=origin, tsig_ctx=tsig_ctx, multi=True, first=first, one_rr_per_rrset=is_ixfr) + rcode = r.rcode() + if rcode != dns.rcode.NOERROR: + raise TransferError(rcode) tsig_ctx = r.tsig_ctx first = False answer_index = 0 diff --git a/src/dns/query.pyi b/src/dns/query.pyi new file mode 100644 index 00000000..fe5ef826 --- /dev/null +++ b/src/dns/query.pyi @@ -0,0 +1,15 @@ +from typing import Optional, Union, Dict, Generator, Any +from . import message, tsig, rdatatype, rdataclass, name, message +def tcp(q : message.Message, where : str, timeout : float = None, port=53, af : Optional[int] = None, source : Optional[str] = None, source_port : int = 0, + one_rr_per_rrset=False) -> message.Message: + pass + +def xfr(where : None, zone : Union[name.Name,str], rdtype=rdatatype.AXFR, rdclass=rdataclass.IN, + timeout : Optional[float] =None, port=53, keyring : Optional[Dict[name.Name, bytes]] =None, keyname : Union[str,name.Name]=None, relativize=True, + af : Optional[int] =None, lifetime : Optional[float]=None, source : Optional[str] =None, source_port=0, serial=0, + use_udp=False, keyalgorithm=tsig.default_algorithm) -> Generator[Any,Any,message.Message]: + pass + +def udp(q : message.Message, where : str, timeout : Optional[float] = None, port=53, af : Optional[int] = None, source : Optional[str] = None, source_port=0, + ignore_unexpected=False, one_rr_per_rrset=False) -> message.Message: + ... diff --git a/src/dns/rcode.py b/src/dns/rcode.py index 314815f7..5191e1b1 100644 --- a/src/dns/rcode.py +++ b/src/dns/rcode.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -18,18 +20,29 @@ import dns.exception from ._compat import long - +#: No error NOERROR = 0 +#: Form error FORMERR = 1 +#: Server failure SERVFAIL = 2 +#: Name does not exist ("Name Error" in RFC 1025 terminology). NXDOMAIN = 3 +#: Not implemented NOTIMP = 4 +#: Refused REFUSED = 5 +#: Name exists. YXDOMAIN = 6 +#: RRset exists. YXRRSET = 7 +#: RRset does not exist. NXRRSET = 8 +#: Not authoritative. NOTAUTH = 9 +#: Name not in zone. NOTZONE = 10 +#: Bad EDNS version. BADVERS = 16 _by_text = { @@ -51,21 +64,21 @@ _by_text = { # cannot make any mistakes (e.g. omissions, cut-and-paste errors) that # would cause the mapping not to be a true inverse. -_by_value = dict((y, x) for x, y in _by_text.items()) +_by_value = {y: x for x, y in _by_text.items()} class UnknownRcode(dns.exception.DNSException): - """A DNS rcode is unknown.""" def from_text(text): """Convert text into an rcode. - @param text: the textual rcode - @type text: string - @raises UnknownRcode: the rcode is unknown - @rtype: int + *text*, a ``text``, the textual rcode or an integer in textual form. + + Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown. + + Returns an ``int``. """ if text.isdigit(): @@ -81,12 +94,13 @@ def from_text(text): def from_flags(flags, ednsflags): """Return the rcode value encoded by flags and ednsflags. - @param flags: the DNS flags - @type flags: int - @param ednsflags: the EDNS flags - @type ednsflags: int - @raises ValueError: rcode is < 0 or > 4095 - @rtype: int + *flags*, an ``int``, the DNS flags field. + + *ednsflags*, an ``int``, the EDNS flags field. + + Raises ``ValueError`` if rcode is < 0 or > 4095 + + Returns an ``int``. """ value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0) @@ -98,10 +112,11 @@ def from_flags(flags, ednsflags): def to_flags(value): """Return a (flags, ednsflags) tuple which encodes the rcode. - @param value: the rcode - @type value: int - @raises ValueError: rcode is < 0 or > 4095 - @rtype: (int, int) tuple + *value*, an ``int``, the rcode. + + Raises ``ValueError`` if rcode is < 0 or > 4095. + + Returns an ``(int, int)`` tuple. """ if value < 0 or value > 4095: @@ -114,11 +129,15 @@ def to_flags(value): def to_text(value): """Convert rcode into text. - @param value: the rcode - @type value: int - @rtype: string + *value*, and ``int``, the rcode. + + Raises ``ValueError`` if rcode is < 0 or > 4095. + + Returns a ``text``. """ + if value < 0 or value > 4095: + raise ValueError('rcode must be >= 0 and <= 4095') text = _by_value.get(value) if text is None: text = str(value) diff --git a/src/dns/rdata.py b/src/dns/rdata.py index 9e9344d5..ea1971dc 100644 --- a/src/dns/rdata.py +++ b/src/dns/rdata.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,17 +15,7 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS rdata. - -@var _rdata_modules: A dictionary mapping a (rdclass, rdtype) tuple to -the module which implements that type. -@type _rdata_modules: dict -@var _module_prefix: The prefix to use when forming modules names. The -default is 'dns.rdtypes'. Changing this value will break the library. -@type _module_prefix: string -@var _hex_chunk: At most this many octets that will be represented in each -chunk of hexstring that _hexify() produces before whitespace occurs. -@type _hex_chunk: int""" +"""DNS rdata.""" from io import BytesIO import base64 @@ -37,17 +29,17 @@ import dns.tokenizer import dns.wiredata from ._compat import xrange, string_types, text_type +try: + import threading as _threading +except ImportError: + import dummy_threading as _threading + _hex_chunksize = 32 def _hexify(data, chunksize=_hex_chunksize): """Convert a binary string into its hex encoding, broken up into chunks - of I{chunksize} characters separated by a space. - - @param data: the binary string - @type data: string - @param chunksize: the chunk size. Default is L{dns.rdata._hex_chunksize} - @rtype: string + of chunksize characters separated by a space. """ line = binascii.hexlify(data) @@ -60,13 +52,7 @@ _base64_chunksize = 32 def _base64ify(data, chunksize=_base64_chunksize): """Convert a binary string into its base64 encoding, broken up into chunks - of I{chunksize} characters separated by a space. - - @param data: the binary string - @type data: string - @param chunksize: the chunk size. Default is - L{dns.rdata._base64_chunksize} - @rtype: string + of chunksize characters separated by a space. """ line = base64.b64encode(data) @@ -77,13 +63,7 @@ def _base64ify(data, chunksize=_base64_chunksize): __escaped = bytearray(b'"\\') def _escapify(qstring): - """Escape the characters in a quoted string which need it. - - @param qstring: the string - @type qstring: string - @returns: the escaped string - @rtype: string - """ + """Escape the characters in a quoted string which need it.""" if isinstance(qstring, text_type): qstring = qstring.encode() @@ -104,10 +84,6 @@ def _escapify(qstring): def _truncate_bitmap(what): """Determine the index of greatest byte that isn't all zeros, and return the bitmap that contains all the bytes less than that index. - - @param what: a string of octets representing a bitmap. - @type what: string - @rtype: string """ for i in xrange(len(what) - 1, -1, -1): @@ -117,30 +93,30 @@ def _truncate_bitmap(what): class Rdata(object): - - """Base class for all DNS rdata types. - """ + """Base class for all DNS rdata types.""" __slots__ = ['rdclass', 'rdtype'] def __init__(self, rdclass, rdtype): """Initialize an rdata. - @param rdclass: The rdata class - @type rdclass: int - @param rdtype: The rdata type - @type rdtype: int + + *rdclass*, an ``int`` is the rdataclass of the Rdata. + *rdtype*, an ``int`` is the rdatatype of the Rdata. """ self.rdclass = rdclass self.rdtype = rdtype def covers(self): - """DNS SIG/RRSIG rdatas apply to a specific type; this type is + """Return the type a Rdata covers. + + DNS SIG/RRSIG rdatas apply to a specific type; this type is returned by the covers() function. If the rdata type is not SIG or RRSIG, dns.rdatatype.NONE is returned. This is useful when creating rdatasets, allowing the rdataset to contain only RRSIGs of a particular type, e.g. RRSIG(NS). - @rtype: int + + Returns an ``int``. """ return dns.rdatatype.NONE @@ -149,37 +125,52 @@ class Rdata(object): """Return a 32-bit type value, the least significant 16 bits of which are the ordinary DNS type, and the upper 16 bits of which are the "covered" type, if any. - @rtype: int + + Returns an ``int``. """ return self.covers() << 16 | self.rdtype def to_text(self, origin=None, relativize=True, **kw): """Convert an rdata to text format. - @rtype: string + + Returns a ``text``. """ + raise NotImplementedError def to_wire(self, file, compress=None, origin=None): """Convert an rdata to wire format. - @rtype: string + + Returns a ``binary``. """ raise NotImplementedError def to_digestable(self, origin=None): """Convert rdata to a format suitable for digesting in hashes. This - is also the DNSSEC canonical form.""" + is also the DNSSEC canonical form. + + Returns a ``binary``. + """ + f = BytesIO() self.to_wire(f, None, origin) return f.getvalue() def validate(self): """Check that the current contents of the rdata's fields are - valid. If you change an rdata by assigning to its fields, + valid. + + If you change an rdata by assigning to its fields, it is a good idea to call validate() when you are done making changes. + + Raises various exceptions if there are problems. + + Returns ``None``. """ + dns.rdata.from_text(self.rdclass, self.rdtype, self.to_text()) def __repr__(self): @@ -197,17 +188,20 @@ class Rdata(object): def _cmp(self, other): """Compare an rdata with another rdata of the same rdtype and - rdclass. Return < 0 if self < other in the DNSSEC ordering, - 0 if self == other, and > 0 if self > other. + rdclass. + + Return < 0 if self < other in the DNSSEC ordering, 0 if self + == other, and > 0 if self > other. + """ our = self.to_digestable(dns.name.root) their = other.to_digestable(dns.name.root) if our == their: return 0 - if our > their: + elif our > their: return 1 - - return -1 + else: + return -1 def __eq__(self, other): if not isinstance(other, Rdata): @@ -253,42 +247,10 @@ class Rdata(object): @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): - """Build an rdata object from text format. - - @param rdclass: The rdata class - @type rdclass: int - @param rdtype: The rdata type - @type rdtype: int - @param tok: The tokenizer - @type tok: dns.tokenizer.Tokenizer - @param origin: The origin to use for relative names - @type origin: dns.name.Name - @param relativize: should names be relativized? - @type relativize: bool - @rtype: dns.rdata.Rdata instance - """ - raise NotImplementedError @classmethod def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - """Build an rdata object from wire format - - @param rdclass: The rdata class - @type rdclass: int - @param rdtype: The rdata type - @type rdtype: int - @param wire: The wire-format message - @type wire: string - @param current: The offset in wire of the beginning of the rdata. - @type current: int - @param rdlen: The length of the wire-format rdata - @type rdlen: int - @param origin: The origin to use for relative names - @type origin: dns.name.Name - @rtype: dns.rdata.Rdata instance - """ - raise NotImplementedError def choose_relativity(self, origin=None, relativize=True): @@ -296,12 +258,9 @@ class Rdata(object): relativization. """ - pass - - class GenericRdata(Rdata): - """Generate Rdata Class + """Generic Rdata Class This class is used for rdata types for which we have no better implementation. It implements the DNS "unknown RRs" scheme. @@ -319,7 +278,7 @@ class GenericRdata(Rdata): @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): token = tok.get() - if not token.is_identifier() or token.value != '\#': + if not token.is_identifier() or token.value != r'\#': raise dns.exception.SyntaxError( r'generic rdata does not start with \#') length = tok.get_int() @@ -345,16 +304,17 @@ class GenericRdata(Rdata): _rdata_modules = {} _module_prefix = 'dns.rdtypes' - +_import_lock = _threading.Lock() def get_rdata_class(rdclass, rdtype): def import_module(name): - mod = __import__(name) - components = name.split('.') - for comp in components[1:]: - mod = getattr(mod, comp) - return mod + with _import_lock: + mod = __import__(name) + components = name.split('.') + for comp in components[1:]: + mod = getattr(mod, comp) + return mod mod = _rdata_modules.get((rdclass, rdtype)) rdclass_text = dns.rdataclass.to_text(rdclass) @@ -392,20 +352,23 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True): Once a class is chosen, its from_text() class method is called with the parameters to this function. - If I{tok} is a string, then a tokenizer is created and the string + If *tok* is a ``text``, then a tokenizer is created and the string is used as its input. - @param rdclass: The rdata class - @type rdclass: int - @param rdtype: The rdata type - @type rdtype: int - @param tok: The tokenizer or input text - @type tok: dns.tokenizer.Tokenizer or string - @param origin: The origin to use for relative names - @type origin: dns.name.Name - @param relativize: Should names be relativized? - @type relativize: bool - @rtype: dns.rdata.Rdata instance""" + *rdclass*, an ``int``, the rdataclass. + + *rdtype*, an ``int``, the rdatatype. + + *tok*, a ``dns.tokenizer.Tokenizer`` or a ``text``. + + *origin*, a ``dns.name.Name`` (or ``None``), the + origin to use for relative names. + + *relativize*, a ``bool``. If true, name will be relativized to + the specified origin. + + Returns an instance of the chosen Rdata subclass. + """ if isinstance(tok, string_types): tok = dns.tokenizer.Tokenizer(tok) @@ -439,20 +402,55 @@ def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): Once a class is chosen, its from_wire() class method is called with the parameters to this function. - @param rdclass: The rdata class - @type rdclass: int - @param rdtype: The rdata type - @type rdtype: int - @param wire: The wire-format message - @type wire: string - @param current: The offset in wire of the beginning of the rdata. - @type current: int - @param rdlen: The length of the wire-format rdata - @type rdlen: int - @param origin: The origin to use for relative names - @type origin: dns.name.Name - @rtype: dns.rdata.Rdata instance""" + *rdclass*, an ``int``, the rdataclass. + + *rdtype*, an ``int``, the rdatatype. + + *wire*, a ``binary``, the wire-format message. + + *current*, an ``int``, the offset in wire of the beginning of + the rdata. + + *rdlen*, an ``int``, the length of the wire-format rdata + + *origin*, a ``dns.name.Name`` (or ``None``). If not ``None``, + then names will be relativized to this origin. + + Returns an instance of the chosen Rdata subclass. + """ wire = dns.wiredata.maybe_wrap(wire) cls = get_rdata_class(rdclass, rdtype) return cls.from_wire(rdclass, rdtype, wire, current, rdlen, origin) + + +class RdatatypeExists(dns.exception.DNSException): + """DNS rdatatype already exists.""" + supp_kwargs = {'rdclass', 'rdtype'} + fmt = "The rdata type with class {rdclass} and rdtype {rdtype} " + \ + "already exists." + + +def register_type(implementation, rdtype, rdtype_text, is_singleton=False, + rdclass=dns.rdataclass.IN): + """Dynamically register a module to handle an rdatatype. + + *implementation*, a module implementing the type in the usual dnspython + way. + + *rdtype*, an ``int``, the rdatatype to register. + + *rdtype_text*, a ``text``, the textual form of the rdatatype. + + *is_singleton*, a ``bool``, indicating if the type is a singleton (i.e. + RRsets of the type can have only one member.) + + *rdclass*, the rdataclass of the type, or ``dns.rdataclass.ANY`` if + it applies to all classes. + """ + + existing_cls = get_rdata_class(rdclass, rdtype) + if existing_cls != GenericRdata: + raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) + _rdata_modules[(rdclass, rdtype)] = implementation + dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton) diff --git a/src/dns/rdata.pyi b/src/dns/rdata.pyi new file mode 100644 index 00000000..8663955c --- /dev/null +++ b/src/dns/rdata.pyi @@ -0,0 +1,17 @@ +from typing import Dict, Tuple, Any, Optional +from .name import Name +class Rdata: + def __init__(self): + self.address : str + def to_wire(self, file, compress : Optional[Dict[Name,int]], origin : Optional[Name]) -> bytes: + ... + @classmethod + def from_text(cls, rdclass : int, rdtype : int, tok, origin=None, relativize=True): + ... +_rdata_modules : Dict[Tuple[Any,Rdata],Any] + +def from_text(rdclass : int, rdtype : int, tok : Optional[str], origin : Optional[Name] = None, relativize : bool = True): + ... + +def from_wire(rdclass : int, rdtype : int, wire : bytes, current : int, rdlen : int, origin : Optional[Name] = None): + ... diff --git a/src/dns/rdataclass.py b/src/dns/rdataclass.py index 17a4810d..b88aa85b 100644 --- a/src/dns/rdataclass.py +++ b/src/dns/rdataclass.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,15 +15,7 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS Rdata Classes. - -@var _by_text: The rdata class textual name to value mapping -@type _by_text: dict -@var _by_value: The rdata class value to textual name mapping -@type _by_value: dict -@var _metaclasses: If an rdataclass is a metaclass, there will be a mapping -whose key is the rdatatype value and whose value is True in this dictionary. -@type _metaclasses: dict""" +"""DNS Rdata Classes.""" import re @@ -47,7 +41,7 @@ _by_text = { # cannot make any mistakes (e.g. omissions, cut-and-paste errors) that # would cause the mapping not to be true inverse. -_by_value = dict((y, x) for x, y in _by_text.items()) +_by_value = {y: x for x, y in _by_text.items()} # Now that we've built the inverse map, we can add class aliases to # the _by_text mapping. @@ -67,17 +61,22 @@ _unknown_class_pattern = re.compile('CLASS([0-9]+)$', re.I) class UnknownRdataclass(dns.exception.DNSException): - """A DNS class is unknown.""" def from_text(text): """Convert text into a DNS rdata class value. - @param text: the text - @type text: string - @rtype: int - @raises dns.rdataclass.UnknownRdataclass: the class is unknown - @raises ValueError: the rdata class value is not >= 0 and <= 65535 + + The input text can be a defined DNS RR class mnemonic or + instance of the DNS generic class syntax. + + For example, "IN" and "CLASS1" will both result in a value of 1. + + Raises ``dns.rdatatype.UnknownRdataclass`` if the class is unknown. + + Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535. + + Returns an ``int``. """ value = _by_text.get(text.upper()) @@ -92,11 +91,14 @@ def from_text(text): def to_text(value): - """Convert a DNS rdata class to text. - @param value: the rdata class value - @type value: int - @rtype: string - @raises ValueError: the rdata class value is not >= 0 and <= 65535 + """Convert a DNS rdata type value to text. + + If the value has a known mnemonic, it will be used, otherwise the + DNS generic class syntax will be used. + + Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535. + + Returns a ``str``. """ if value < 0 or value > 65535: @@ -108,10 +110,12 @@ def to_text(value): def is_metaclass(rdclass): - """True if the class is a metaclass. - @param rdclass: the rdata class - @type rdclass: int - @rtype: bool""" + """True if the specified class is a metaclass. + + The currently defined metaclasses are ANY and NONE. + + *rdclass* is an ``int``. + """ if rdclass in _metaclasses: return True diff --git a/src/dns/rdataset.py b/src/dns/rdataset.py index db266f2f..f1afe241 100644 --- a/src/dns/rdataset.py +++ b/src/dns/rdataset.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -31,50 +33,37 @@ SimpleSet = dns.set.Set class DifferingCovers(dns.exception.DNSException): - """An attempt was made to add a DNS SIG/RRSIG whose covered type is not the same as that of the other rdatas in the rdataset.""" class IncompatibleTypes(dns.exception.DNSException): - """An attempt was made to add DNS RR data of an incompatible type.""" class Rdataset(dns.set.Set): - """A DNS rdataset. - - @ivar rdclass: The class of the rdataset - @type rdclass: int - @ivar rdtype: The type of the rdataset - @type rdtype: int - @ivar covers: The covered type. Usually this value is - dns.rdatatype.NONE, but if the rdtype is dns.rdatatype.SIG or - dns.rdatatype.RRSIG, then the covers value will be the rdata - type the SIG/RRSIG covers. The library treats the SIG and RRSIG - types as if they were a family of - types, e.g. RRSIG(A), RRSIG(NS), RRSIG(SOA). This makes RRSIGs much - easier to work with than if RRSIGs covering different rdata - types were aggregated into a single RRSIG rdataset. - @type covers: int - @ivar ttl: The DNS TTL (Time To Live) value - @type ttl: int - """ + """A DNS rdataset.""" __slots__ = ['rdclass', 'rdtype', 'covers', 'ttl'] - def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE, ttl=0): """Create a new rdataset of the specified class and type. - @see: the description of the class instance variables for the - meaning of I{rdclass} and I{rdtype}""" + *rdclass*, an ``int``, the rdataclass. + + *rdtype*, an ``int``, the rdatatype. + + *covers*, an ``int``, the covered rdatatype. + + *ttl*, an ``int``, the TTL. + """ super(Rdataset, self).__init__() self.rdclass = rdclass self.rdtype = rdtype self.covers = covers - self.ttl = 0 + self.ttl = ttl def _clone(self): obj = super(Rdataset, self)._clone() @@ -85,11 +74,14 @@ class Rdataset(dns.set.Set): return obj def update_ttl(self, ttl): - """Set the TTL of the rdataset to be the lesser of the set's current + """Perform TTL minimization. + + Set the TTL of the rdataset to be the lesser of the set's current TTL or the specified TTL. If the set contains no rdatas, set the TTL to the specified TTL. - @param ttl: The TTL - @type ttl: int""" + + *ttl*, an ``int``. + """ if len(self) == 0: self.ttl = ttl @@ -99,13 +91,19 @@ class Rdataset(dns.set.Set): def add(self, rd, ttl=None): """Add the specified rdata to the rdataset. - If the optional I{ttl} parameter is supplied, then - self.update_ttl(ttl) will be called prior to adding the rdata. + If the optional *ttl* parameter is supplied, then + ``self.update_ttl(ttl)`` will be called prior to adding the rdata. - @param rd: The rdata - @type rd: dns.rdata.Rdata object - @param ttl: The TTL - @type ttl: int""" + *rd*, a ``dns.rdata.Rdata``, the rdata + + *ttl*, an ``int``, the TTL. + + Raises ``dns.rdataset.IncompatibleTypes`` if the type and class + do not match the type and class of the rdataset. + + Raises ``dns.rdataset.DifferingCovers`` if the type is a signature + type and the covered type does not match that of the rdataset. + """ # # If we're adding a signature, do some special handling to @@ -139,8 +137,9 @@ class Rdataset(dns.set.Set): def update(self, other): """Add all rdatas in other to self. - @param other: The rdataset from which to update - @type other: dns.rdataset.Rdataset object""" + *other*, a ``dns.rdataset.Rdataset``, the rdataset from which + to update. + """ self.update_ttl(other.ttl) super(Rdataset, self).update(other) @@ -157,10 +156,6 @@ class Rdataset(dns.set.Set): return self.to_text() def __eq__(self, other): - """Two rdatasets are equal if they have the same class, type, and - covers, and contain the same rdata. - @rtype: bool""" - if not isinstance(other, Rdataset): return False if self.rdclass != other.rdclass or \ @@ -176,20 +171,23 @@ class Rdataset(dns.set.Set): override_rdclass=None, **kw): """Convert the rdataset into DNS master file format. - @see: L{dns.name.Name.choose_relativity} for more information - on how I{origin} and I{relativize} determine the way names + See ``dns.name.Name.choose_relativity`` for more information + on how *origin* and *relativize* determine the way names are emitted. Any additional keyword arguments are passed on to the rdata - to_text() method. + ``to_text()`` method. + + *name*, a ``dns.name.Name``. If name is not ``None``, emit RRs with + *name* as the owner name. + + *origin*, a ``dns.name.Name`` or ``None``, the origin for relative + names. + + *relativize*, a ``bool``. If ``True``, names will be relativized + to *origin*. + """ - @param name: If name is not None, emit a RRs with I{name} as - the owner name. - @type name: dns.name.Name object - @param origin: The origin for relative names, or None. - @type origin: dns.name.Name object - @param relativize: True if names should names be relativized - @type relativize: bool""" if name is not None: name = name.choose_relativity(origin, relativize) ntext = str(name) @@ -208,9 +206,9 @@ class Rdataset(dns.set.Set): # some dynamic updates, so we don't need to print out the TTL # (which is meaningless anyway). # - s.write(u'%s%s%s %s\n' % (ntext, pad, - dns.rdataclass.to_text(rdclass), - dns.rdatatype.to_text(self.rdtype))) + s.write(u'{}{}{} {}\n'.format(ntext, pad, + dns.rdataclass.to_text(rdclass), + dns.rdatatype.to_text(self.rdtype))) else: for rd in self: s.write(u'%s%s%d %s %s %s\n' % @@ -227,16 +225,26 @@ class Rdataset(dns.set.Set): override_rdclass=None, want_shuffle=True): """Convert the rdataset to wire format. - @param name: The owner name of the RRset that will be emitted - @type name: dns.name.Name object - @param file: The file to which the wire format data will be appended - @type file: file - @param compress: The compression table to use; the default is None. - @type compress: dict - @param origin: The origin to be appended to any relative names when - they are emitted. The default is None. - @returns: the number of records emitted - @rtype: int + *name*, a ``dns.name.Name`` is the owner name to use. + + *file* is the file where the name is emitted (typically a + BytesIO file). + + *compress*, a ``dict``, is the compression table to use. If + ``None`` (the default), names will not be compressed. + + *origin* is a ``dns.name.Name`` or ``None``. If the name is + relative and origin is not ``None``, then *origin* will be appended + to it. + + *override_rdclass*, an ``int``, is used as the class instead of the + class of the rdataset. This is useful when rendering rdatasets + associated with dynamic updates. + + *want_shuffle*, a ``bool``. If ``True``, then the order of the + Rdatas within the Rdataset will be shuffled before rendering. + + Returns an ``int``, the number of records emitted. """ if override_rdclass is not None: @@ -272,8 +280,9 @@ class Rdataset(dns.set.Set): return len(self) def match(self, rdclass, rdtype, covers): - """Returns True if this rdataset matches the specified class, type, - and covers""" + """Returns ``True`` if this rdataset matches the specified class, + type, and covers. + """ if self.rdclass == rdclass and \ self.rdtype == rdtype and \ self.covers == covers: @@ -285,7 +294,7 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas): """Create an rdataset with the specified class, type, and TTL, and with the specified list of rdatas in text format. - @rtype: dns.rdataset.Rdataset object + Returns a ``dns.rdataset.Rdataset`` object. """ if isinstance(rdclass, string_types): @@ -304,7 +313,7 @@ def from_text(rdclass, rdtype, ttl, *text_rdatas): """Create an rdataset with the specified class, type, and TTL, and with the specified rdatas in text format. - @rtype: dns.rdataset.Rdataset object + Returns a ``dns.rdataset.Rdataset`` object. """ return from_text_list(rdclass, rdtype, ttl, text_rdatas) @@ -314,7 +323,7 @@ def from_rdata_list(ttl, rdatas): """Create an rdataset with the specified TTL, and with the specified list of rdata objects. - @rtype: dns.rdataset.Rdataset object + Returns a ``dns.rdataset.Rdataset`` object. """ if len(rdatas) == 0: @@ -332,7 +341,7 @@ def from_rdata(ttl, *rdatas): """Create an rdataset with the specified TTL, and with the specified rdata objects. - @rtype: dns.rdataset.Rdataset object + Returns a ``dns.rdataset.Rdataset`` object. """ return from_rdata_list(ttl, rdatas) diff --git a/src/dns/rdataset.pyi b/src/dns/rdataset.pyi new file mode 100644 index 00000000..3efff88a --- /dev/null +++ b/src/dns/rdataset.pyi @@ -0,0 +1,58 @@ +from typing import Optional, Dict, List, Union +from io import BytesIO +from . import exception, name, set, rdatatype, rdata, rdataset + +class DifferingCovers(exception.DNSException): + """An attempt was made to add a DNS SIG/RRSIG whose covered type + is not the same as that of the other rdatas in the rdataset.""" + + +class IncompatibleTypes(exception.DNSException): + """An attempt was made to add DNS RR data of an incompatible type.""" + + +class Rdataset(set.Set): + def __init__(self, rdclass, rdtype, covers=rdatatype.NONE, ttl=0): + self.rdclass : int = rdclass + self.rdtype : int = rdtype + self.covers : int = covers + self.ttl : int = ttl + + def update_ttl(self, ttl : int) -> None: + ... + + def add(self, rd : rdata.Rdata, ttl : Optional[int] =None): + ... + + def union_update(self, other : Rdataset): + ... + + def intersection_update(self, other : Rdataset): + ... + + def update(self, other : Rdataset): + ... + + def to_text(self, name : Optional[name.Name] =None, origin : Optional[name.Name] =None, relativize=True, + override_rdclass : Optional[int] =None, **kw) -> bytes: + ... + + def to_wire(self, name : Optional[name.Name], file : BytesIO, compress : Optional[Dict[name.Name, int]] = None, origin : Optional[name.Name] = None, + override_rdclass : Optional[int] = None, want_shuffle=True) -> int: + ... + + def match(self, rdclass : int, rdtype : int, covers : int) -> bool: + ... + + +def from_text_list(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, text_rdatas : str) -> rdataset.Rdataset: + ... + +def from_text(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, *text_rdatas : str) -> rdataset.Rdataset: + ... + +def from_rdata_list(ttl : int, rdatas : List[rdata.Rdata]) -> rdataset.Rdataset: + ... + +def from_rdata(ttl : int, *rdatas : List[rdata.Rdata]) -> rdataset.Rdataset: + ... diff --git a/src/dns/rdatatype.py b/src/dns/rdatatype.py index 15284f64..b247bc9c 100644 --- a/src/dns/rdatatype.py +++ b/src/dns/rdatatype.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,18 +15,7 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS Rdata Types. - -@var _by_text: The rdata type textual name to value mapping -@type _by_text: dict -@var _by_value: The rdata type value to textual name mapping -@type _by_value: dict -@var _metatypes: If an rdatatype is a metatype, there will be a mapping -whose key is the rdatatype value and whose value is True in this dictionary. -@type _metatypes: dict -@var _singletons: If an rdatatype is a singleton, there will be a mapping -whose key is the rdatatype value and whose value is True in this dictionary. -@type _singletons: dict""" +"""DNS Rdata Types.""" import re @@ -82,6 +73,7 @@ TLSA = 52 HIP = 55 CDS = 59 CDNSKEY = 60 +OPENPGPKEY = 61 CSYNC = 62 SPF = 99 UNSPEC = 103 @@ -153,6 +145,7 @@ _by_text = { 'HIP': HIP, 'CDS': CDS, 'CDNSKEY': CDNSKEY, + 'OPENPGPKEY': OPENPGPKEY, 'CSYNC': CSYNC, 'SPF': SPF, 'UNSPEC': UNSPEC, @@ -176,8 +169,7 @@ _by_text = { # cannot make any mistakes (e.g. omissions, cut-and-paste errors) that # would cause the mapping not to be true inverse. -_by_value = dict((y, x) for x, y in _by_text.items()) - +_by_value = {y: x for x, y in _by_text.items()} _metatypes = { OPT: True @@ -188,24 +180,30 @@ _singletons = { NXT: True, DNAME: True, NSEC: True, - # CNAME is technically a singleton, but we allow multiple CNAMEs. + CNAME: True, } _unknown_type_pattern = re.compile('TYPE([0-9]+)$', re.I) class UnknownRdatatype(dns.exception.DNSException): - """DNS resource record type is unknown.""" def from_text(text): """Convert text into a DNS rdata type value. - @param text: the text - @type text: string - @raises dns.rdatatype.UnknownRdatatype: the type is unknown - @raises ValueError: the rdata type value is not >= 0 and <= 65535 - @rtype: int""" + + The input text can be a defined DNS RR type mnemonic or + instance of the DNS generic type syntax. + + For example, "NS" and "TYPE2" will both result in a value of 2. + + Raises ``dns.rdatatype.UnknownRdatatype`` if the type is unknown. + + Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535. + + Returns an ``int``. + """ value = _by_text.get(text.upper()) if value is None: @@ -219,11 +217,15 @@ def from_text(text): def to_text(value): - """Convert a DNS rdata type to text. - @param value: the rdata type value - @type value: int - @raises ValueError: the rdata type value is not >= 0 and <= 65535 - @rtype: string""" + """Convert a DNS rdata type value to text. + + If the value has a known mnemonic, it will be used, otherwise the + DNS generic type syntax will be used. + + Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535. + + Returns a ``str``. + """ if value < 0 or value > 65535: raise ValueError("type must be between >= 0 and <= 65535") @@ -234,10 +236,15 @@ def to_text(value): def is_metatype(rdtype): - """True if the type is a metatype. - @param rdtype: the type - @type rdtype: int - @rtype: bool""" + """True if the specified type is a metatype. + + *rdtype* is an ``int``. + + The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA, + MAILB, ANY, and OPT. + + Returns a ``bool``. + """ if rdtype >= TKEY and rdtype <= ANY or rdtype in _metatypes: return True @@ -245,11 +252,36 @@ def is_metatype(rdtype): def is_singleton(rdtype): - """True if the type is a singleton. - @param rdtype: the type - @type rdtype: int - @rtype: bool""" + """Is the specified type a singleton type? + + Singleton types can only have a single rdata in an rdataset, or a single + RR in an RRset. + + The currently defined singleton types are CNAME, DNAME, NSEC, NXT, and + SOA. + + *rdtype* is an ``int``. + + Returns a ``bool``. + """ if rdtype in _singletons: return True return False + + +def register_type(rdtype, rdtype_text, is_singleton=False): # pylint: disable=redefined-outer-name + """Dynamically register an rdatatype. + + *rdtype*, an ``int``, the rdatatype to register. + + *rdtype_text*, a ``text``, the textual form of the rdatatype. + + *is_singleton*, a ``bool``, indicating if the type is a singleton (i.e. + RRsets of the type can have only one member.) + """ + + _by_text[rdtype_text] = rdtype + _by_value[rdtype] = rdtype_text + if is_singleton: + _singletons[rdtype] = True diff --git a/src/dns/rdtypes/ANY/AFSDB.py b/src/dns/rdtypes/ANY/AFSDB.py index f3d51540..c6a700cf 100644 --- a/src/dns/rdtypes/ANY/AFSDB.py +++ b/src/dns/rdtypes/ANY/AFSDB.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/AVC.py b/src/dns/rdtypes/ANY/AVC.py index 137c9de9..7f340b39 100644 --- a/src/dns/rdtypes/ANY/AVC.py +++ b/src/dns/rdtypes/ANY/AVC.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2016 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/CAA.py b/src/dns/rdtypes/ANY/CAA.py index f2e41ad0..0acf201a 100644 --- a/src/dns/rdtypes/ANY/CAA.py +++ b/src/dns/rdtypes/ANY/CAA.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/CDNSKEY.py b/src/dns/rdtypes/ANY/CDNSKEY.py index 83f3d51f..653ae1be 100644 --- a/src/dns/rdtypes/ANY/CDNSKEY.py +++ b/src/dns/rdtypes/ANY/CDNSKEY.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/CDS.py b/src/dns/rdtypes/ANY/CDS.py index e1abfc36..a63041dd 100644 --- a/src/dns/rdtypes/ANY/CDS.py +++ b/src/dns/rdtypes/ANY/CDS.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/CERT.py b/src/dns/rdtypes/ANY/CERT.py index 1c35c23d..eea27b52 100644 --- a/src/dns/rdtypes/ANY/CERT.py +++ b/src/dns/rdtypes/ANY/CERT.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/CNAME.py b/src/dns/rdtypes/ANY/CNAME.py index 65cf570c..11d42aa7 100644 --- a/src/dns/rdtypes/ANY/CNAME.py +++ b/src/dns/rdtypes/ANY/CNAME.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/CSYNC.py b/src/dns/rdtypes/ANY/CSYNC.py index bf95cb27..06292fb2 100644 --- a/src/dns/rdtypes/ANY/CSYNC.py +++ b/src/dns/rdtypes/ANY/CSYNC.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011, 2016 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/DLV.py b/src/dns/rdtypes/ANY/DLV.py index cd1244c1..16352125 100644 --- a/src/dns/rdtypes/ANY/DLV.py +++ b/src/dns/rdtypes/ANY/DLV.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/DNAME.py b/src/dns/rdtypes/ANY/DNAME.py index dac97214..2499283c 100644 --- a/src/dns/rdtypes/ANY/DNAME.py +++ b/src/dns/rdtypes/ANY/DNAME.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/DNSKEY.py b/src/dns/rdtypes/ANY/DNSKEY.py index e915e98b..e36f7bc5 100644 --- a/src/dns/rdtypes/ANY/DNSKEY.py +++ b/src/dns/rdtypes/ANY/DNSKEY.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/DS.py b/src/dns/rdtypes/ANY/DS.py index 577c8d84..7d457b22 100644 --- a/src/dns/rdtypes/ANY/DS.py +++ b/src/dns/rdtypes/ANY/DS.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/GPOS.py b/src/dns/rdtypes/ANY/GPOS.py index a359a771..422822f0 100644 --- a/src/dns/rdtypes/ANY/GPOS.py +++ b/src/dns/rdtypes/ANY/GPOS.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -80,7 +82,7 @@ class GPOS(dns.rdata.Rdata): self.altitude = altitude def to_text(self, origin=None, relativize=True, **kw): - return '%s %s %s' % (self.latitude.decode(), + return '{} {} {}'.format(self.latitude.decode(), self.longitude.decode(), self.altitude.decode()) diff --git a/src/dns/rdtypes/ANY/HINFO.py b/src/dns/rdtypes/ANY/HINFO.py index e5a1bea3..e4e0b34a 100644 --- a/src/dns/rdtypes/ANY/HINFO.py +++ b/src/dns/rdtypes/ANY/HINFO.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -45,8 +47,8 @@ class HINFO(dns.rdata.Rdata): self.os = os def to_text(self, origin=None, relativize=True, **kw): - return '"%s" "%s"' % (dns.rdata._escapify(self.cpu), - dns.rdata._escapify(self.os)) + return '"{}" "{}"'.format(dns.rdata._escapify(self.cpu), + dns.rdata._escapify(self.os)) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): diff --git a/src/dns/rdtypes/ANY/HIP.py b/src/dns/rdtypes/ANY/HIP.py index fbe955c3..7c876b2d 100644 --- a/src/dns/rdtypes/ANY/HIP.py +++ b/src/dns/rdtypes/ANY/HIP.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2010, 2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/ISDN.py b/src/dns/rdtypes/ANY/ISDN.py index da2ae3af..f5f5f8b9 100644 --- a/src/dns/rdtypes/ANY/ISDN.py +++ b/src/dns/rdtypes/ANY/ISDN.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -46,7 +48,7 @@ class ISDN(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): if self.subaddress: - return '"%s" "%s"' % (dns.rdata._escapify(self.address), + return '"{}" "{}"'.format(dns.rdata._escapify(self.address), dns.rdata._escapify(self.subaddress)) else: return '"%s"' % dns.rdata._escapify(self.address) diff --git a/src/dns/rdtypes/ANY/LOC.py b/src/dns/rdtypes/ANY/LOC.py index b433da94..da9bb03a 100644 --- a/src/dns/rdtypes/ANY/LOC.py +++ b/src/dns/rdtypes/ANY/LOC.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -156,7 +158,7 @@ class LOC(dns.rdata.Rdata): if self.size != _default_size or \ self.horizontal_precision != _default_hprec or \ self.vertical_precision != _default_vprec: - text += " %0.2fm %0.2fm %0.2fm" % ( + text += " {:0.2f}m {:0.2f}m {:0.2f}m".format( self.size / 100.0, self.horizontal_precision / 100.0, self.vertical_precision / 100.0 ) diff --git a/src/dns/rdtypes/ANY/MX.py b/src/dns/rdtypes/ANY/MX.py index 3a6735dc..0a06494f 100644 --- a/src/dns/rdtypes/ANY/MX.py +++ b/src/dns/rdtypes/ANY/MX.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/NS.py b/src/dns/rdtypes/ANY/NS.py index ae56d819..f9fcf637 100644 --- a/src/dns/rdtypes/ANY/NS.py +++ b/src/dns/rdtypes/ANY/NS.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/NSEC.py b/src/dns/rdtypes/ANY/NSEC.py index dfe96859..4e3da729 100644 --- a/src/dns/rdtypes/ANY/NSEC.py +++ b/src/dns/rdtypes/ANY/NSEC.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -50,7 +52,7 @@ class NSEC(dns.rdata.Rdata): bits.append(dns.rdatatype.to_text(window * 256 + i * 8 + j)) text += (' ' + ' '.join(bits)) - return '%s%s' % (next, text) + return '{}{}'.format(next, text) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): diff --git a/src/dns/rdtypes/ANY/NSEC3.py b/src/dns/rdtypes/ANY/NSEC3.py index 9a15687b..1c281c4a 100644 --- a/src/dns/rdtypes/ANY/NSEC3.py +++ b/src/dns/rdtypes/ANY/NSEC3.py @@ -1,4 +1,6 @@ -# Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2004-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -21,18 +23,21 @@ import struct import dns.exception import dns.rdata import dns.rdatatype -from dns._compat import xrange, text_type +from dns._compat import xrange, text_type, PY3 -try: - b32_hex_to_normal = string.maketrans('0123456789ABCDEFGHIJKLMNOPQRSTUV', - 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567') - b32_normal_to_hex = string.maketrans('ABCDEFGHIJKLMNOPQRSTUVWXYZ234567', - '0123456789ABCDEFGHIJKLMNOPQRSTUV') -except AttributeError: +# pylint: disable=deprecated-string-function +if PY3: b32_hex_to_normal = bytes.maketrans(b'0123456789ABCDEFGHIJKLMNOPQRSTUV', b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567') b32_normal_to_hex = bytes.maketrans(b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567', b'0123456789ABCDEFGHIJKLMNOPQRSTUV') +else: + b32_hex_to_normal = string.maketrans('0123456789ABCDEFGHIJKLMNOPQRSTUV', + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567') + b32_normal_to_hex = string.maketrans('ABCDEFGHIJKLMNOPQRSTUVWXYZ234567', + '0123456789ABCDEFGHIJKLMNOPQRSTUV') +# pylint: enable=deprecated-string-function + # hash algorithm constants SHA1 = 1 @@ -130,7 +135,7 @@ class NSEC3(dns.rdata.Rdata): new_window = nrdtype // 256 if new_window != window: if octets != 0: - windows.append((window, ''.join(bitmap[0:octets]))) + windows.append((window, bitmap[0:octets])) bitmap = bytearray(b'\0' * 32) window = new_window offset = nrdtype % 256 diff --git a/src/dns/rdtypes/ANY/NSEC3PARAM.py b/src/dns/rdtypes/ANY/NSEC3PARAM.py index 36bf7409..87c36e56 100644 --- a/src/dns/rdtypes/ANY/NSEC3PARAM.py +++ b/src/dns/rdtypes/ANY/NSEC3PARAM.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/OPENPGPKEY.py b/src/dns/rdtypes/ANY/OPENPGPKEY.py new file mode 100644 index 00000000..a066cf98 --- /dev/null +++ b/src/dns/rdtypes/ANY/OPENPGPKEY.py @@ -0,0 +1,60 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2016 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import base64 + +import dns.exception +import dns.rdata +import dns.tokenizer + +class OPENPGPKEY(dns.rdata.Rdata): + + """OPENPGPKEY record + + @ivar key: the key + @type key: bytes + @see: RFC 7929 + """ + + def __init__(self, rdclass, rdtype, key): + super(OPENPGPKEY, self).__init__(rdclass, rdtype) + self.key = key + + def to_text(self, origin=None, relativize=True, **kw): + return dns.rdata._base64ify(self.key) + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + chunks = [] + while 1: + t = tok.get().unescape() + if t.is_eol_or_eof(): + break + if not t.is_identifier(): + raise dns.exception.SyntaxError + chunks.append(t.value.encode()) + b64 = b''.join(chunks) + key = base64.b64decode(b64) + return cls(rdclass, rdtype, key) + + def to_wire(self, file, compress=None, origin=None): + file.write(self.key) + + @classmethod + def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + key = wire[current: current + rdlen].unwrap() + return cls(rdclass, rdtype, key) diff --git a/src/dns/rdtypes/ANY/PTR.py b/src/dns/rdtypes/ANY/PTR.py index 250187a6..20cd5076 100644 --- a/src/dns/rdtypes/ANY/PTR.py +++ b/src/dns/rdtypes/ANY/PTR.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/RP.py b/src/dns/rdtypes/ANY/RP.py index e9071c76..8f07be90 100644 --- a/src/dns/rdtypes/ANY/RP.py +++ b/src/dns/rdtypes/ANY/RP.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -39,7 +41,7 @@ class RP(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): mbox = self.mbox.choose_relativity(origin, relativize) txt = self.txt.choose_relativity(origin, relativize) - return "%s %s" % (str(mbox), str(txt)) + return "{} {}".format(str(mbox), str(txt)) @classmethod def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): diff --git a/src/dns/rdtypes/ANY/RRSIG.py b/src/dns/rdtypes/ANY/RRSIG.py index 953dfb9a..d3756ece 100644 --- a/src/dns/rdtypes/ANY/RRSIG.py +++ b/src/dns/rdtypes/ANY/RRSIG.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/RT.py b/src/dns/rdtypes/ANY/RT.py index 88b75486..d0feb79e 100644 --- a/src/dns/rdtypes/ANY/RT.py +++ b/src/dns/rdtypes/ANY/RT.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/SOA.py b/src/dns/rdtypes/ANY/SOA.py index cc0098e8..aec81cad 100644 --- a/src/dns/rdtypes/ANY/SOA.py +++ b/src/dns/rdtypes/ANY/SOA.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/SPF.py b/src/dns/rdtypes/ANY/SPF.py index f3e0904e..41dee623 100644 --- a/src/dns/rdtypes/ANY/SPF.py +++ b/src/dns/rdtypes/ANY/SPF.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/SSHFP.py b/src/dns/rdtypes/ANY/SSHFP.py index 7e846b34..c18311e9 100644 --- a/src/dns/rdtypes/ANY/SSHFP.py +++ b/src/dns/rdtypes/ANY/SSHFP.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2005-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/TLSA.py b/src/dns/rdtypes/ANY/TLSA.py index 790a93b9..a135c2b3 100644 --- a/src/dns/rdtypes/ANY/TLSA.py +++ b/src/dns/rdtypes/ANY/TLSA.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2005-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/TXT.py b/src/dns/rdtypes/ANY/TXT.py index 6c7fa450..c5ae919c 100644 --- a/src/dns/rdtypes/ANY/TXT.py +++ b/src/dns/rdtypes/ANY/TXT.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/URI.py b/src/dns/rdtypes/ANY/URI.py index b5595b51..f5b65ed6 100644 --- a/src/dns/rdtypes/ANY/URI.py +++ b/src/dns/rdtypes/ANY/URI.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # Copyright (C) 2015 Red Hat, Inc. # diff --git a/src/dns/rdtypes/ANY/X25.py b/src/dns/rdtypes/ANY/X25.py index 8732ccf0..e530a2c2 100644 --- a/src/dns/rdtypes/ANY/X25.py +++ b/src/dns/rdtypes/ANY/X25.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/ANY/__init__.py b/src/dns/rdtypes/ANY/__init__.py index ea9c3e2e..ca41ef80 100644 --- a/src/dns/rdtypes/ANY/__init__.py +++ b/src/dns/rdtypes/ANY/__init__.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -17,10 +19,13 @@ __all__ = [ 'AFSDB', + 'AVC', + 'CAA', 'CDNSKEY', 'CDS', 'CERT', 'CNAME', + 'CSYNC', 'DLV', 'DNAME', 'DNSKEY', @@ -37,7 +42,7 @@ __all__ = [ 'NSEC', 'NSEC3', 'NSEC3PARAM', - 'TLSA', + 'OPENPGPKEY', 'PTR', 'RP', 'RRSIG', @@ -45,6 +50,8 @@ __all__ = [ 'SOA', 'SPF', 'SSHFP', + 'TLSA', 'TXT', + 'URI', 'X25', ] diff --git a/src/dns/rdtypes/CH/A.py b/src/dns/rdtypes/CH/A.py new file mode 100644 index 00000000..e65d192d --- /dev/null +++ b/src/dns/rdtypes/CH/A.py @@ -0,0 +1,70 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.rdtypes.mxbase +import struct + +class A(dns.rdtypes.mxbase.MXBase): + + """A record for Chaosnet + @ivar domain: the domain of the address + @type domain: dns.name.Name object + @ivar address: the 16-bit address + @type address: int""" + + __slots__ = ['domain', 'address'] + + def __init__(self, rdclass, rdtype, address, domain): + super(A, self).__init__(rdclass, rdtype, address, domain) + self.domain = domain + self.address = address + + def to_text(self, origin=None, relativize=True, **kw): + domain = self.domain.choose_relativity(origin, relativize) + return '%s %o' % (domain, self.address) + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + domain = tok.get_name() + address = tok.get_uint16(base=8) + domain = domain.choose_relativity(origin, relativize) + tok.get_eol() + return cls(rdclass, rdtype, address, domain) + + def to_wire(self, file, compress=None, origin=None): + self.domain.to_wire(file, compress, origin) + pref = struct.pack("!H", self.address) + file.write(pref) + + def to_digestable(self, origin=None): + return self.domain.to_digestable(origin) + \ + struct.pack("!H", self.address) + + @classmethod + def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + (domain, cused) = dns.name.from_wire(wire[: current + rdlen-2], + current) + current += cused + (address,) = struct.unpack('!H', wire[current: current + 2]) + if cused+2 != rdlen: + raise dns.exception.FormError + if origin is not None: + domain = domain.relativize(origin) + return cls(rdclass, rdtype, address, domain) + + def choose_relativity(self, origin=None, relativize=True): + self.domain = self.domain.choose_relativity(origin, relativize) diff --git a/src/dns/rdtypes/CH/__init__.py b/src/dns/rdtypes/CH/__init__.py new file mode 100644 index 00000000..7184a733 --- /dev/null +++ b/src/dns/rdtypes/CH/__init__.py @@ -0,0 +1,22 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Class CH rdata type classes.""" + +__all__ = [ + 'A', +] diff --git a/src/dns/rdtypes/IN/A.py b/src/dns/rdtypes/IN/A.py index 3775548f..89989824 100644 --- a/src/dns/rdtypes/IN/A.py +++ b/src/dns/rdtypes/IN/A.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -48,5 +50,5 @@ class A(dns.rdata.Rdata): @classmethod def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): - address = dns.ipv4.inet_ntoa(wire[current: current + rdlen]).decode() + address = dns.ipv4.inet_ntoa(wire[current: current + rdlen]) return cls(rdclass, rdtype, address) diff --git a/src/dns/rdtypes/IN/AAAA.py b/src/dns/rdtypes/IN/AAAA.py index 4352404d..a77c5bf2 100644 --- a/src/dns/rdtypes/IN/AAAA.py +++ b/src/dns/rdtypes/IN/AAAA.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/IN/APL.py b/src/dns/rdtypes/IN/APL.py index 57ef6c0a..48faf88a 100644 --- a/src/dns/rdtypes/IN/APL.py +++ b/src/dns/rdtypes/IN/APL.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,14 +15,15 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -import struct import binascii +import codecs +import struct import dns.exception import dns.inet import dns.rdata import dns.tokenizer -from dns._compat import xrange +from dns._compat import xrange, maybe_chr class APLItem(object): @@ -63,7 +66,7 @@ class APLItem(object): # last = 0 for i in xrange(len(address) - 1, -1, -1): - if address[i] != chr(0): + if address[i] != maybe_chr(0): last = i + 1 break address = address[0: last] @@ -121,6 +124,7 @@ class APL(dns.rdata.Rdata): @classmethod def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + items = [] while 1: if rdlen == 0: @@ -142,18 +146,18 @@ class APL(dns.rdata.Rdata): l = len(address) if header[0] == 1: if l < 4: - address += '\x00' * (4 - l) + address += b'\x00' * (4 - l) address = dns.inet.inet_ntop(dns.inet.AF_INET, address) elif header[0] == 2: if l < 16: - address += '\x00' * (16 - l) + address += b'\x00' * (16 - l) address = dns.inet.inet_ntop(dns.inet.AF_INET6, address) else: # # This isn't really right according to the RFC, but it # seems better than throwing an exception # - address = address.encode('hex_codec') + address = codecs.encode(address, 'hex_codec') current += afdlen rdlen -= afdlen item = APLItem(header[0], negation, address, header[1]) diff --git a/src/dns/rdtypes/IN/DHCID.py b/src/dns/rdtypes/IN/DHCID.py index 5b8626a5..cec64590 100644 --- a/src/dns/rdtypes/IN/DHCID.py +++ b/src/dns/rdtypes/IN/DHCID.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/IN/IPSECKEY.py b/src/dns/rdtypes/IN/IPSECKEY.py index c673e839..8f49ba13 100644 --- a/src/dns/rdtypes/IN/IPSECKEY.py +++ b/src/dns/rdtypes/IN/IPSECKEY.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/IN/KX.py b/src/dns/rdtypes/IN/KX.py index adbfe34b..1318a582 100644 --- a/src/dns/rdtypes/IN/KX.py +++ b/src/dns/rdtypes/IN/KX.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/IN/NAPTR.py b/src/dns/rdtypes/IN/NAPTR.py index 5ae2feb1..32fa4745 100644 --- a/src/dns/rdtypes/IN/NAPTR.py +++ b/src/dns/rdtypes/IN/NAPTR.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/IN/NSAP.py b/src/dns/rdtypes/IN/NSAP.py index 05d0745e..336befc7 100644 --- a/src/dns/rdtypes/IN/NSAP.py +++ b/src/dns/rdtypes/IN/NSAP.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/IN/NSAP_PTR.py b/src/dns/rdtypes/IN/NSAP_PTR.py index 56967df0..a5b66c80 100644 --- a/src/dns/rdtypes/IN/NSAP_PTR.py +++ b/src/dns/rdtypes/IN/NSAP_PTR.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/IN/PX.py b/src/dns/rdtypes/IN/PX.py index e1ef102b..2dbaee6c 100644 --- a/src/dns/rdtypes/IN/PX.py +++ b/src/dns/rdtypes/IN/PX.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/IN/SRV.py b/src/dns/rdtypes/IN/SRV.py index f4396d61..b2c1bc9f 100644 --- a/src/dns/rdtypes/IN/SRV.py +++ b/src/dns/rdtypes/IN/SRV.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/IN/WKS.py b/src/dns/rdtypes/IN/WKS.py index 1d4012c3..96f98ada 100644 --- a/src/dns/rdtypes/IN/WKS.py +++ b/src/dns/rdtypes/IN/WKS.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/IN/__init__.py b/src/dns/rdtypes/IN/__init__.py index 24cf1ece..d7e69c9f 100644 --- a/src/dns/rdtypes/IN/__init__.py +++ b/src/dns/rdtypes/IN/__init__.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -20,6 +22,7 @@ __all__ = [ 'AAAA', 'APL', 'DHCID', + 'IPSECKEY', 'KX', 'NAPTR', 'NSAP', diff --git a/src/dns/rdtypes/__init__.py b/src/dns/rdtypes/__init__.py index 826efbb6..1ac137f1 100644 --- a/src/dns/rdtypes/__init__.py +++ b/src/dns/rdtypes/__init__.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -18,6 +20,7 @@ __all__ = [ 'ANY', 'IN', + 'CH', 'euibase', 'mxbase', 'nsbase', diff --git a/src/dns/rdtypes/dnskeybase.py b/src/dns/rdtypes/dnskeybase.py index 85c4b23f..3e7e87ef 100644 --- a/src/dns/rdtypes/dnskeybase.py +++ b/src/dns/rdtypes/dnskeybase.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2004-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -38,7 +40,7 @@ _flag_by_text = { # We construct the inverse mapping programmatically to ensure that we # cannot make any mistakes (e.g. omissions, cut-and-paste errors) that # would cause the mapping not to be true inverse. -_flag_by_value = dict((y, x) for x, y in _flag_by_text.items()) +_flag_by_value = {y: x for x, y in _flag_by_text.items()} def flags_to_text_set(flags): diff --git a/src/dns/rdtypes/dnskeybase.pyi b/src/dns/rdtypes/dnskeybase.pyi new file mode 100644 index 00000000..e102a698 --- /dev/null +++ b/src/dns/rdtypes/dnskeybase.pyi @@ -0,0 +1,37 @@ +from typing import Set, Any + +SEP : int +REVOKE : int +ZONE : int + +def flags_to_text_set(flags : int) -> Set[str]: + ... + +def flags_from_text_set(texts_set) -> int: + ... + +from .. import rdata + +class DNSKEYBase(rdata.Rdata): + def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): + self.flags : int + self.protocol : int + self.key : str + self.algorithm : int + + def to_text(self, origin : Any = None, relativize=True, **kw : Any): + ... + + @classmethod + def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True): + ... + + def to_wire(self, file, compress=None, origin=None): + ... + + @classmethod + def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + ... + + def flags_to_text_set(self) -> Set[str]: + ... diff --git a/src/dns/rdtypes/dsbase.py b/src/dns/rdtypes/dsbase.py index 1ee28e4a..26ae9d5c 100644 --- a/src/dns/rdtypes/dsbase.py +++ b/src/dns/rdtypes/dsbase.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2010, 2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/mxbase.py b/src/dns/rdtypes/mxbase.py index 5ac8cef9..9a3fa623 100644 --- a/src/dns/rdtypes/mxbase.py +++ b/src/dns/rdtypes/mxbase.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/nsbase.py b/src/dns/rdtypes/nsbase.py index 79333a14..97a22326 100644 --- a/src/dns/rdtypes/nsbase.py +++ b/src/dns/rdtypes/nsbase.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/rdtypes/txtbase.py b/src/dns/rdtypes/txtbase.py index 352b027b..645a57ec 100644 --- a/src/dns/rdtypes/txtbase.py +++ b/src/dns/rdtypes/txtbase.py @@ -1,4 +1,6 @@ -# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -20,30 +22,35 @@ import struct import dns.exception import dns.rdata import dns.tokenizer -from dns._compat import binary_type +from dns._compat import binary_type, string_types class TXTBase(dns.rdata.Rdata): """Base class for rdata that is like a TXT record - @ivar strings: the text strings - @type strings: list of string + @ivar strings: the strings + @type strings: list of binary @see: RFC 1035""" __slots__ = ['strings'] def __init__(self, rdclass, rdtype, strings): super(TXTBase, self).__init__(rdclass, rdtype) - if isinstance(strings, str): + if isinstance(strings, binary_type) or \ + isinstance(strings, string_types): strings = [strings] - self.strings = strings[:] + self.strings = [] + for string in strings: + if isinstance(string, string_types): + string = string.encode() + self.strings.append(string) def to_text(self, origin=None, relativize=True, **kw): txt = '' prefix = '' for s in self.strings: - txt += '%s"%s"' % (prefix, dns.rdata._escapify(s)) + txt += '{}"{}"'.format(prefix, dns.rdata._escapify(s)) prefix = ' ' return txt diff --git a/src/dns/rdtypes/txtbase.pyi b/src/dns/rdtypes/txtbase.pyi new file mode 100644 index 00000000..af447d50 --- /dev/null +++ b/src/dns/rdtypes/txtbase.pyi @@ -0,0 +1,6 @@ +from .. import rdata + +class TXTBase(rdata.Rdata): + ... +class TXT(TXTBase): + ... diff --git a/src/dns/renderer.py b/src/dns/renderer.py index 670fb28f..d7ef8c7f 100644 --- a/src/dns/renderer.py +++ b/src/dns/renderer.py @@ -1,4 +1,6 @@ -# Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -32,7 +34,6 @@ ADDITIONAL = 3 class Renderer(object): - """Helper class for building DNS wire-format messages. Most applications can use the higher-level L{dns.message.Message} @@ -54,41 +55,27 @@ class Renderer(object): r.add_tsig(keyname, secret, 300, 1, 0, '', request_mac) wire = r.get_wire() - @ivar output: where rendering is written - @type output: BytesIO object - @ivar id: the message id - @type id: int - @ivar flags: the message flags - @type flags: int - @ivar max_size: the maximum size of the message - @type max_size: int - @ivar origin: the origin to use when rendering relative names - @type origin: dns.name.Name object - @ivar compress: the compression table - @type compress: dict - @ivar section: the section currently being rendered - @type section: int (dns.renderer.QUESTION, dns.renderer.ANSWER, - dns.renderer.AUTHORITY, or dns.renderer.ADDITIONAL) - @ivar counts: list of the number of RRs in each section - @type counts: int list of length 4 - @ivar mac: the MAC of the rendered message (if TSIG was used) - @type mac: string + output, a BytesIO, where rendering is written + + id: the message id + + flags: the message flags + + max_size: the maximum size of the message + + origin: the origin to use when rendering relative names + + compress: the compression table + + section: an int, the section currently being rendered + + counts: list of the number of RRs in each section + + mac: the MAC of the rendered message (if TSIG was used) """ def __init__(self, id=None, flags=0, max_size=65535, origin=None): - """Initialize a new renderer. - - @param id: the message id - @type id: int - @param flags: the DNS message flags - @type flags: int - @param max_size: the maximum message size; the default is 65535. - If rendering results in a message greater than I{max_size}, - then L{dns.exception.TooBig} will be raised. - @type max_size: int - @param origin: the origin to use when rendering relative names - @type origin: dns.name.Name or None. - """ + """Initialize a new renderer.""" self.output = BytesIO() if id is None: @@ -105,12 +92,9 @@ class Renderer(object): self.mac = '' def _rollback(self, where): - """Truncate the output buffer at offset I{where}, and remove any + """Truncate the output buffer at offset *where*, and remove any compression table entries that pointed beyond the truncation point. - - @param where: the offset - @type where: int """ self.output.seek(where) @@ -128,9 +112,7 @@ class Renderer(object): Sections must be rendered order: QUESTION, ANSWER, AUTHORITY, ADDITIONAL. Sections may be empty. - @param section: the section - @type section: int - @raises dns.exception.FormError: an attempt was made to set + Raises dns.exception.FormError if an attempt was made to set a section value less than the current section. """ @@ -140,15 +122,7 @@ class Renderer(object): self.section = section def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN): - """Add a question to the message. - - @param qname: the question name - @type qname: dns.name.Name - @param rdtype: the question rdata type - @type rdtype: int - @param rdclass: the question rdata class - @type rdclass: int - """ + """Add a question to the message.""" self._set_section(QUESTION) before = self.output.tell() @@ -165,11 +139,6 @@ class Renderer(object): Any keyword arguments are passed on to the rdataset's to_wire() routine. - - @param section: the section - @type section: int - @param rrset: the rrset - @type rrset: dns.rrset.RRset object """ self._set_section(section) @@ -187,13 +156,6 @@ class Renderer(object): Any keyword arguments are passed on to the rdataset's to_wire() routine. - - @param section: the section - @type section: int - @param name: the owner name - @type name: dns.name.Name object - @param rdataset: the rdataset - @type rdataset: dns.rdataset.Rdataset object """ self._set_section(section) @@ -207,19 +169,7 @@ class Renderer(object): self.counts[section] += n def add_edns(self, edns, ednsflags, payload, options=None): - """Add an EDNS OPT record to the message. - - @param edns: The EDNS level to use. - @type edns: int - @param ednsflags: EDNS flag values. - @type ednsflags: int - @param payload: The EDNS sender's payload field, which is the maximum - size of UDP datagram the sender can handle. - @type payload: int - @param options: The EDNS options list - @type options: list of dns.edns.Option instances - @see: RFC 2671 - """ + """Add an EDNS OPT record to the message.""" # make sure the EDNS version in ednsflags agrees with edns ednsflags &= long(0xFF00FFFF) @@ -255,29 +205,8 @@ class Renderer(object): def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data, request_mac, algorithm=dns.tsig.default_algorithm): - """Add a TSIG signature to the message. + """Add a TSIG signature to the message.""" - @param keyname: the TSIG key name - @type keyname: dns.name.Name object - @param secret: the secret to use - @type secret: string - @param fudge: TSIG time fudge - @type fudge: int - @param id: the message id to encode in the tsig signature - @type id: int - @param tsig_error: TSIG error code; default is 0. - @type tsig_error: int - @param other_data: TSIG other data. - @type other_data: string - @param request_mac: This message is a response to the request which - had the specified MAC. - @type request_mac: string - @param algorithm: the TSIG algorithm to use - @type algorithm: dns.name.Name object - """ - - self._set_section(ADDITIONAL) - before = self.output.tell() s = self.output.getvalue() (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s, keyname, @@ -289,16 +218,52 @@ class Renderer(object): other_data, request_mac, algorithm=algorithm) + self._write_tsig(tsig_rdata, keyname) + + def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error, + other_data, request_mac, + algorithm=dns.tsig.default_algorithm): + """Add a TSIG signature to the message. Unlike add_tsig(), this can be + used for a series of consecutive DNS envelopes, e.g. for a zone + transfer over TCP [RFC2845, 4.4]. + + For the first message in the sequence, give ctx=None. For each + subsequent message, give the ctx that was returned from the + add_multi_tsig() call for the previous message.""" + + s = self.output.getvalue() + (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s, + keyname, + secret, + int(time.time()), + fudge, + id, + tsig_error, + other_data, + request_mac, + ctx=ctx, + first=ctx is None, + multi=True, + algorithm=algorithm) + self._write_tsig(tsig_rdata, keyname) + return ctx + + def _write_tsig(self, tsig_rdata, keyname): + self._set_section(ADDITIONAL) + before = self.output.tell() + keyname.to_wire(self.output, self.compress, self.origin) self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0)) rdata_start = self.output.tell() self.output.write(tsig_rdata) + after = self.output.tell() assert after - rdata_start < 65536 if after >= self.max_size: self._rollback(before) raise dns.exception.TooBig + self.output.seek(rdata_start - 2) self.output.write(struct.pack('!H', after - rdata_start)) self.counts[ADDITIONAL] += 1 @@ -321,9 +286,6 @@ class Renderer(object): self.output.seek(0, 2) def get_wire(self): - """Return the wire format message. - - @rtype: string - """ + """Return the wire format message.""" return self.output.getvalue() diff --git a/src/dns/resolver.py b/src/dns/resolver.py index abc431d7..806e5b2b 100644 --- a/src/dns/resolver.py +++ b/src/dns/resolver.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,10 +15,7 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS stub resolver. - -@var default_resolver: The default resolver object -@type default_resolver: dns.resolver.Resolver object""" +"""DNS stub resolver.""" import socket import sys @@ -49,9 +48,8 @@ if sys.platform == 'win32': import _winreg # pylint: disable=import-error class NXDOMAIN(dns.exception.DNSException): - """The DNS query name does not exist.""" - supp_kwargs = set(['qnames', 'responses']) + supp_kwargs = {'qnames', 'responses'} fmt = None # we have our own __str__ implementation def _check_kwargs(self, qnames, responses=None): @@ -73,9 +71,9 @@ class NXDOMAIN(dns.exception.DNSException): if len(qnames) > 1: msg = 'None of DNS query names exist' else: - msg = self.__doc__[:-1] + msg = 'The DNS query name does not exist' qnames = ', '.join(map(str, qnames)) - return "%s: %s" % (msg, qnames) + return "{}: {}".format(msg, qnames) def canonical_name(self): if not 'qnames' in self.kwargs: @@ -107,9 +105,30 @@ class NXDOMAIN(dns.exception.DNSException): responses0[qname1] = responses1[qname1] return NXDOMAIN(qnames=qnames0, responses=responses0) + def qnames(self): + """All of the names that were tried. + + Returns a list of ``dns.name.Name``. + """ + return self.kwargs['qnames'] + + def responses(self): + """A map from queried names to their NXDOMAIN responses. + + Returns a dict mapping a ``dns.name.Name`` to a + ``dns.message.Message``. + """ + return self.kwargs['responses'] + + def response(self, qname): + """The response for query *qname*. + + Returns a ``dns.message.Message``. + """ + return self.kwargs['responses'][qname] + class YXDOMAIN(dns.exception.DNSException): - """The DNS query name is too long after DNAME substitution.""" # The definition of the Timeout exception has moved from here to the @@ -120,11 +139,10 @@ Timeout = dns.exception.Timeout class NoAnswer(dns.exception.DNSException): - """The DNS response does not contain an answer to the question.""" fmt = 'The DNS response does not contain an answer ' + \ 'to the question: {query}' - supp_kwargs = set(['response']) + supp_kwargs = {'response'} def _fmt_kwargs(self, **kwargs): return super(NoAnswer, self)._fmt_kwargs( @@ -132,73 +150,53 @@ class NoAnswer(dns.exception.DNSException): class NoNameservers(dns.exception.DNSException): - """All nameservers failed to answer the query. errors: list of servers and respective errors The type of errors is - [(server ip address, any object convertible to string)]. + [(server IP address, any object convertible to string)]. Non-empty errors list will add explanatory message () """ msg = "All nameservers failed to answer the query." fmt = "%s {query}: {errors}" % msg[:-1] - supp_kwargs = set(['request', 'errors']) + supp_kwargs = {'request', 'errors'} def _fmt_kwargs(self, **kwargs): srv_msgs = [] for err in kwargs['errors']: - srv_msgs.append('Server %s %s port %s answered %s' % (err[0], + srv_msgs.append('Server {} {} port {} answered {}'.format(err[0], 'TCP' if err[1] else 'UDP', err[2], err[3])) return super(NoNameservers, self)._fmt_kwargs( query=kwargs['request'].question, errors='; '.join(srv_msgs)) class NotAbsolute(dns.exception.DNSException): - """An absolute domain name is required but a relative name was provided.""" class NoRootSOA(dns.exception.DNSException): - """There is no SOA RR at the DNS root name. This should never happen!""" class NoMetaqueries(dns.exception.DNSException): - """DNS metaqueries are not allowed.""" class Answer(object): - - """DNS stub resolver answer + """DNS stub resolver answer. Instances of this class bundle up the result of a successful DNS resolution. For convenience, the answer object implements much of the sequence - protocol, forwarding to its rrset. E.g. "for a in answer" is - equivalent to "for a in answer.rrset", "answer[i]" is equivalent - to "answer.rrset[i]", and "answer[i:j]" is equivalent to - "answer.rrset[i:j]". + protocol, forwarding to its ``rrset`` attribute. E.g. + ``for a in answer`` is equivalent to ``for a in answer.rrset``. + ``answer[i]`` is equivalent to ``answer.rrset[i]``, and + ``answer[i:j]`` is equivalent to ``answer.rrset[i:j]``. Note that CNAMEs or DNAMEs in the response may mean that answer - node's name might not be the query name. - - @ivar qname: The query name - @type qname: dns.name.Name object - @ivar rdtype: The query type - @type rdtype: int - @ivar rdclass: The query class - @type rdclass: int - @ivar response: The response message - @type response: dns.message.Message object - @ivar rrset: The answer - @type rrset: dns.rrset.RRset object - @ivar expiration: The time when the answer expires - @type expiration: float (seconds since the epoch) - @ivar canonical_name: The canonical name of the query name - @type canonical_name: dns.name.Name object + RRset's name might not be the query name. """ def __init__(self, qname, rdtype, rdclass, response, @@ -278,32 +276,22 @@ class Answer(object): return self.rrset and iter(self.rrset) or iter(tuple()) def __getitem__(self, i): + if self.rrset is None: + raise IndexError return self.rrset[i] def __delitem__(self, i): + if self.rrset is None: + raise IndexError del self.rrset[i] class Cache(object): - - """Simple DNS answer cache. - - @ivar data: A dictionary of cached data - @type data: dict - @ivar cleaning_interval: The number of seconds between cleanings. The - default is 300 (5 minutes). - @type cleaning_interval: float - @ivar next_cleaning: The time the cache should next be cleaned (in seconds - since the epoch.) - @type next_cleaning: float - """ + """Simple thread-safe DNS answer cache.""" def __init__(self, cleaning_interval=300.0): - """Initialize a DNS cache. - - @param cleaning_interval: the number of seconds between periodic - cleanings. The default is 300.0 - @type cleaning_interval: float. + """*cleaning_interval*, a ``float`` is the number of seconds between + periodic cleanings. """ self.data = {} @@ -326,12 +314,14 @@ class Cache(object): self.next_cleaning = now + self.cleaning_interval def get(self, key): - """Get the answer associated with I{key}. Returns None if - no answer is cached for the key. - @param key: the key - @type key: (dns.name.Name, int, int) tuple whose values are the - query name, rdtype, and rdclass. - @rtype: dns.resolver.Answer object or None + """Get the answer associated with *key*. + + Returns None if no answer is cached for the key. + + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. + + Returns a ``dns.resolver.Answer`` or ``None``. """ try: @@ -346,11 +336,11 @@ class Cache(object): def put(self, key, value): """Associate key and value in the cache. - @param key: the key - @type key: (dns.name.Name, int, int) tuple whose values are the - query name, rdtype, and rdclass. - @param value: The answer being cached - @type value: dns.resolver.Answer object + + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. + + *value*, a ``dns.resolver.Answer``, the answer. """ try: @@ -363,11 +353,11 @@ class Cache(object): def flush(self, key=None): """Flush the cache. - If I{key} is specified, only that item is flushed. Otherwise + If *key* is not ``None``, only that item is flushed. Otherwise the entire cache is flushed. - @param key: the key to flush - @type key: (dns.name.Name, int, int) tuple or None + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. """ try: @@ -383,9 +373,7 @@ class Cache(object): class LRUCacheNode(object): - - """LRUCache node. - """ + """LRUCache node.""" def __init__(self, key, value): self.key = key @@ -411,30 +399,20 @@ class LRUCacheNode(object): class LRUCache(object): - - """Bounded least-recently-used DNS answer cache. + """Thread-safe, bounded, least-recently-used DNS answer cache. This cache is better than the simple cache (above) if you're running a web crawler or other process that does a lot of resolutions. The LRUCache has a maximum number of nodes, and when it is full, the least-recently used node is removed to make space for a new one. - - @ivar data: A dictionary of cached data - @type data: dict - @ivar sentinel: sentinel node for circular doubly linked list of nodes - @type sentinel: LRUCacheNode object - @ivar max_size: The maximum number of nodes - @type max_size: int """ def __init__(self, max_size=100000): - """Initialize a DNS cache. - - @param max_size: The maximum number of nodes to cache; the default is - 100,000. Must be greater than 1. - @type max_size: int + """*max_size*, an ``int``, is the maximum number of nodes to cache; + it must be greater than 0. """ + self.data = {} self.set_max_size(max_size) self.sentinel = LRUCacheNode(None, None) @@ -446,13 +424,16 @@ class LRUCache(object): self.max_size = max_size def get(self, key): - """Get the answer associated with I{key}. Returns None if - no answer is cached for the key. - @param key: the key - @type key: (dns.name.Name, int, int) tuple whose values are the - query name, rdtype, and rdclass. - @rtype: dns.resolver.Answer object or None + """Get the answer associated with *key*. + + Returns None if no answer is cached for the key. + + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. + + Returns a ``dns.resolver.Answer`` or ``None``. """ + try: self.lock.acquire() node = self.data.get(key) @@ -471,12 +452,13 @@ class LRUCache(object): def put(self, key, value): """Associate key and value in the cache. - @param key: the key - @type key: (dns.name.Name, int, int) tuple whose values are the - query name, rdtype, and rdclass. - @param value: The answer being cached - @type value: dns.resolver.Answer object + + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. + + *value*, a ``dns.resolver.Answer``, the answer. """ + try: self.lock.acquire() node = self.data.get(key) @@ -496,12 +478,13 @@ class LRUCache(object): def flush(self, key=None): """Flush the cache. - If I{key} is specified, only that item is flushed. Otherwise + If *key* is not ``None``, only that item is flushed. Otherwise the entire cache is flushed. - @param key: the key to flush - @type key: (dns.name.Name, int, int) tuple or None + *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + query name, rdtype, and rdclass respectively. """ + try: self.lock.acquire() if key is not None: @@ -522,62 +505,19 @@ class LRUCache(object): class Resolver(object): - - """DNS stub resolver - - @ivar domain: The domain of this host - @type domain: dns.name.Name object - @ivar nameservers: A list of nameservers to query. Each nameserver is - a string which contains the IP address of a nameserver. - @type nameservers: list of strings - @ivar search: The search list. If the query name is a relative name, - the resolver will construct an absolute query name by appending the search - names one by one to the query name. - @type search: list of dns.name.Name objects - @ivar port: The port to which to send queries. The default is 53. - @type port: int - @ivar timeout: The number of seconds to wait for a response from a - server, before timing out. - @type timeout: float - @ivar lifetime: The total number of seconds to spend trying to get an - answer to the question. If the lifetime expires, a Timeout exception - will occur. - @type lifetime: float - @ivar keyring: The TSIG keyring to use. The default is None. - @type keyring: dict - @ivar keyname: The TSIG keyname to use. The default is None. - @type keyname: dns.name.Name object - @ivar keyalgorithm: The TSIG key algorithm to use. The default is - dns.tsig.default_algorithm. - @type keyalgorithm: string - @ivar edns: The EDNS level to use. The default is -1, no Edns. - @type edns: int - @ivar ednsflags: The EDNS flags - @type ednsflags: int - @ivar payload: The EDNS payload size. The default is 0. - @type payload: int - @ivar flags: The message flags to use. The default is None (i.e. not - overwritten) - @type flags: int - @ivar cache: The cache to use. The default is None. - @type cache: dns.resolver.Cache object - @ivar retry_servfail: should we retry a nameserver if it says SERVFAIL? - The default is 'false'. - @type retry_servfail: bool - """ + """DNS stub resolver.""" def __init__(self, filename='/etc/resolv.conf', configure=True): - """Initialize a resolver instance. + """*filename*, a ``text`` or file object, specifying a file + in standard /etc/resolv.conf format. This parameter is meaningful + only when *configure* is true and the platform is POSIX. - @param filename: The filename of a configuration file in - standard /etc/resolv.conf format. This parameter is meaningful - only when I{configure} is true and the platform is POSIX. - @type filename: string or file object - @param configure: If True (the default), the resolver instance - is configured in the normal fashion for the operating system - the resolver is running on. (I.e. a /etc/resolv.conf file on - POSIX systems and from the registry on Windows systems.) - @type configure: bool""" + *configure*, a ``bool``. If True (the default), the resolver + instance is configured in the normal fashion for the operating + system the resolver is running on. (I.e. by reading a + /etc/resolv.conf file on POSIX systems and from the registry + on Windows systems.) + """ self.domain = None self.nameservers = None @@ -606,6 +546,7 @@ class Resolver(object): def reset(self): """Reset all resolver configuration to the defaults.""" + self.domain = \ dns.name.Name(dns.name.from_text(socket.gethostname())[1:]) if len(self.domain) == 0: @@ -628,9 +569,10 @@ class Resolver(object): self.rotate = False def read_resolv_conf(self, f): - """Process f as a file in the /etc/resolv.conf format. If f is - a string, it is used as the name of the file to open; otherwise it + """Process *f* as a file in the /etc/resolv.conf format. If f is + a ``text``, it is used as the name of the file to open; otherwise it is treated as the file itself.""" + if isinstance(f, string_types): try: f = open(f, 'r') @@ -684,7 +626,6 @@ class Resolver(object): return split_char def _config_win32_nameservers(self, nameservers): - """Configure a NameServer registry entry.""" # we call str() on nameservers to convert it from unicode to ascii nameservers = str(nameservers) split_char = self._determine_split_char(nameservers) @@ -694,12 +635,10 @@ class Resolver(object): self.nameservers.append(ns) def _config_win32_domain(self, domain): - """Configure a Domain registry entry.""" # we call str() on domain to convert it from unicode to ascii self.domain = dns.name.from_text(str(domain)) def _config_win32_search(self, search): - """Configure a Search registry entry.""" # we call str() on search to convert it from unicode to ascii search = str(search) split_char = self._determine_split_char(search) @@ -708,14 +647,14 @@ class Resolver(object): if s not in self.search: self.search.append(dns.name.from_text(s)) - def _config_win32_fromkey(self, key): - """Extract DNS info from a registry key.""" + def _config_win32_fromkey(self, key, always_try_domain): try: servers, rtype = _winreg.QueryValueEx(key, 'NameServer') except WindowsError: # pylint: disable=undefined-variable servers = None if servers: self._config_win32_nameservers(servers) + if servers or always_try_domain: try: dom, rtype = _winreg.QueryValueEx(key, 'Domain') if dom: @@ -744,6 +683,7 @@ class Resolver(object): def read_registry(self): """Extract resolver configuration from the Windows registry.""" + lm = _winreg.ConnectRegistry(None, _winreg.HKEY_LOCAL_MACHINE) want_scan = False try: @@ -759,7 +699,7 @@ class Resolver(object): r'SYSTEM\CurrentControlSet' r'\Services\VxD\MSTCP') try: - self._config_win32_fromkey(tcp_params) + self._config_win32_fromkey(tcp_params, True) finally: tcp_params.Close() if want_scan: @@ -777,7 +717,7 @@ class Resolver(object): if not self._win32_is_nic_enabled(lm, guid, key): continue try: - self._config_win32_fromkey(key) + self._config_win32_fromkey(key, False) finally: key.Close() except EnvironmentError: @@ -842,7 +782,8 @@ class Resolver(object): except WindowsError: # pylint: disable=undefined-variable return False - def _compute_timeout(self, start): + def _compute_timeout(self, start, lifetime=None): + lifetime = self.lifetime if lifetime is None else lifetime now = time.time() duration = now - start if duration < 0: @@ -854,44 +795,54 @@ class Resolver(object): # happen, e.g. under vmware with older linux kernels. # Pretend it didn't happen. now = start - if duration >= self.lifetime: + if duration >= lifetime: raise Timeout(timeout=duration) - return min(self.lifetime - duration, self.timeout) + return min(lifetime - duration, self.timeout) def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, source_port=0): + tcp=False, source=None, raise_on_no_answer=True, source_port=0, + lifetime=None): """Query nameservers to find the answer to the question. - The I{qname}, I{rdtype}, and I{rdclass} parameters may be objects + The *qname*, *rdtype*, and *rdclass* parameters may be objects of the appropriate type, or strings that can be converted into objects - of the appropriate type. E.g. For I{rdtype} the integer 2 and the - the string 'NS' both mean to query for records with DNS rdata type NS. + of the appropriate type. - @param qname: the query name - @type qname: dns.name.Name object or string - @param rdtype: the query type - @type rdtype: int or string - @param rdclass: the query class - @type rdclass: int or string - @param tcp: use TCP to make the query (default is False). - @type tcp: bool - @param source: bind to this IP address (defaults to machine default - IP). - @type source: IP address in dotted quad notation - @param raise_on_no_answer: raise NoAnswer if there's no answer - (defaults is True). - @type raise_on_no_answer: bool - @param source_port: The port from which to send the message. - The default is 0. - @type source_port: int - @rtype: dns.resolver.Answer instance - @raises Timeout: no answers could be found in the specified lifetime - @raises NXDOMAIN: the query name does not exist - @raises YXDOMAIN: the query name is too long after DNAME substitution - @raises NoAnswer: the response did not contain an answer and - raise_on_no_answer is True. - @raises NoNameservers: no non-broken nameservers are available to - answer the question.""" + *qname*, a ``dns.name.Name`` or ``text``, the query name. + + *rdtype*, an ``int`` or ``text``, the query type. + + *rdclass*, an ``int`` or ``text``, the query class. + + *tcp*, a ``bool``. If ``True``, use TCP to make the query. + + *source*, a ``text`` or ``None``. If not ``None``, bind to this IP + address when making queries. + + *raise_on_no_answer*, a ``bool``. If ``True``, raise + ``dns.resolver.NoAnswer`` if there's no answer to the question. + + *source_port*, an ``int``, the port from which to send the message. + + *lifetime*, a ``float``, how long query should run before timing out. + + Raises ``dns.exception.Timeout`` if no answers could be found + in the specified lifetime. + + Raises ``dns.resolver.NXDOMAIN`` if the query name does not exist. + + Raises ``dns.resolver.YXDOMAIN`` if the query name is too long after + DNAME substitution. + + Raises ``dns.resolver.NoAnswer`` if *raise_on_no_answer* is + ``True`` and the query name exists but has no RRset of the + desired type and class. + + Raises ``dns.resolver.NoNameservers`` if no non-broken + nameservers are available to answer the question. + + Returns a ``dns.resolver.Answer`` instance. + """ if isinstance(qname, string_types): qname = dns.name.from_text(qname, None) @@ -946,7 +897,7 @@ class Resolver(object): if len(nameservers) == 0: raise NoNameservers(request=request, errors=errors) for nameserver in nameservers[:]: - timeout = self._compute_timeout(start) + timeout = self._compute_timeout(start, lifetime) port = self.nameserver_ports.get(nameserver, self.port) try: tcp_attempt = tcp @@ -963,7 +914,7 @@ class Resolver(object): if response.flags & dns.flags.TC: # Response truncated; retry with TCP. tcp_attempt = True - timeout = self._compute_timeout(start) + timeout = self._compute_timeout(start, lifetime) response = \ dns.query.tcp(request, nameserver, timeout, port, @@ -1038,7 +989,7 @@ class Resolver(object): # But we still have servers to try. Sleep a bit # so we don't pound them! # - timeout = self._compute_timeout(start) + timeout = self._compute_timeout(start, lifetime) sleep_time = min(timeout, backoff) backoff *= 2 time.sleep(sleep_time) @@ -1059,17 +1010,22 @@ class Resolver(object): algorithm=dns.tsig.default_algorithm): """Add a TSIG signature to the query. - @param keyring: The TSIG keyring to use; defaults to None. - @type keyring: dict - @param keyname: The name of the TSIG key to use; defaults to None. - The key must be defined in the keyring. If a keyring is specified - but a keyname is not, then the key used will be the first key in the - keyring. Note that the order of keys in a dictionary is not defined, - so applications should supply a keyname when a keyring is used, unless - they know the keyring contains only one key. - @param algorithm: The TSIG key algorithm to use. The default - is dns.tsig.default_algorithm. - @type algorithm: string""" + See the documentation of the Message class for a complete + description of the keyring dictionary. + + *keyring*, a ``dict``, the TSIG keyring to use. If a + *keyring* is specified but a *keyname* is not, then the key + used will be the first key in the *keyring*. Note that the + order of keys in a dictionary is not defined, so applications + should supply a keyname when a keyring is used, unless they + know the keyring contains only one key. + + *keyname*, a ``dns.name.Name`` or ``None``, the name of the TSIG key + to use; defaults to ``None``. The key must be defined in the keyring. + + *algorithm*, a ``dns.name.Name``, the TSIG algorithm to use. + """ + self.keyring = keyring if keyname is None: self.keyname = list(self.keyring.keys())[0] @@ -1078,14 +1034,19 @@ class Resolver(object): self.keyalgorithm = algorithm def use_edns(self, edns, ednsflags, payload): - """Configure Edns. + """Configure EDNS behavior. - @param edns: The EDNS level to use. The default is -1, no Edns. - @type edns: int - @param ednsflags: The EDNS flags - @type ednsflags: int - @param payload: The EDNS payload size. The default is 0. - @type payload: int""" + *edns*, an ``int``, is the EDNS level to use. Specifying + ``None``, ``False``, or ``-1`` means "do not use EDNS", and in this case + the other parameters are ignored. Specifying ``True`` is + equivalent to specifying 0, i.e. "use EDNS0". + + *ednsflags*, an ``int``, the EDNS flag values. + + *payload*, an ``int``, is the EDNS sender's payload field, which is the + maximum size of UDP datagram the sender can handle. I.e. how big + a response to this message can be. + """ if edns is None: edns = -1 @@ -1094,12 +1055,15 @@ class Resolver(object): self.payload = payload def set_flags(self, flags): - """Overrides the default flags with your own + """Overrides the default flags with your own. + + *flags*, an ``int``, the message flags to use. + """ - @param flags: The flags to overwrite the default with - @type flags: int""" self.flags = flags + +#: The default resolver. default_resolver = None @@ -1113,37 +1077,49 @@ def get_default_resolver(): def reset_default_resolver(): """Re-initialize default resolver. - resolv.conf will be re-read immediatelly. + Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX + systems) will be re-read immediately. """ + global default_resolver default_resolver = Resolver() def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, tcp=False, source=None, raise_on_no_answer=True, - source_port=0): + source_port=0, lifetime=None): """Query nameservers to find the answer to the question. This is a convenience function that uses the default resolver object to make the query. - @see: L{dns.resolver.Resolver.query} for more information on the - parameters.""" + + See ``dns.resolver.Resolver.query`` for more information on the + parameters. + """ + return get_default_resolver().query(qname, rdtype, rdclass, tcp, source, - raise_on_no_answer, source_port) + raise_on_no_answer, source_port, + lifetime) def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): """Find the name of the zone which contains the specified name. - @param name: the query name - @type name: absolute dns.name.Name object or string - @param rdclass: The query class - @type rdclass: int - @param tcp: use TCP to make the query (default is False). - @type tcp: bool - @param resolver: the resolver to use - @type resolver: dns.resolver.Resolver object or None - @rtype: dns.name.Name""" + *name*, an absolute ``dns.name.Name`` or ``text``, the query name. + + *rdclass*, an ``int``, the query class. + + *tcp*, a ``bool``. If ``True``, use TCP to make the query. + + *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use. + If ``None``, the default resolver is used. + + Raises ``dns.resolver.NoRootSOA`` if there is no SOA RR at the DNS + root. (This is only likely to happen if you're using non-default + root servers in your network and they are misconfigured.) + + Returns a ``dns.name.Name``. + """ if isinstance(name, string_types): name = dns.name.from_text(name, dns.name.root) @@ -1240,7 +1216,7 @@ def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, v4addrs.append(rdata.address) except dns.resolver.NXDOMAIN: raise socket.gaierror(socket.EAI_NONAME) - except: + except Exception: raise socket.gaierror(socket.EAI_SYSTEM) port = None try: @@ -1379,9 +1355,9 @@ def override_system_resolver(resolver=None): The resolver to use may be specified; if it's not, the default resolver will be used. - @param resolver: the resolver to use - @type resolver: dns.resolver.Resolver object or None + resolver, a ``dns.resolver.Resolver`` or ``None``, the resolver to use. """ + if resolver is None: resolver = get_default_resolver() global _resolver @@ -1395,8 +1371,8 @@ def override_system_resolver(resolver=None): def restore_system_resolver(): - """Undo the effects of override_system_resolver(). - """ + """Undo the effects of prior override_system_resolver().""" + global _resolver _resolver = None socket.getaddrinfo = _original_getaddrinfo diff --git a/src/dns/resolver.pyi b/src/dns/resolver.pyi new file mode 100644 index 00000000..e839ec21 --- /dev/null +++ b/src/dns/resolver.pyi @@ -0,0 +1,31 @@ +from typing import Union, Optional, List +from . import exception, rdataclass, name, rdatatype + +import socket +_gethostbyname = socket.gethostbyname +class NXDOMAIN(exception.DNSException): + ... +def query(qname : str, rdtype : Union[int,str] = 0, rdclass : Union[int,str] = 0, + tcp=False, source=None, raise_on_no_answer=True, + source_port=0): + ... +class LRUCache: + def __init__(self, max_size=1000): + ... + def get(self, key): + ... + def put(self, key, val): + ... +class Answer: + def __init__(self, qname, rdtype, rdclass, response, + raise_on_no_answer=True): + ... +def zone_for_name(name, rdclass : int = rdataclass.IN, tcp=False, resolver : Optional[Resolver] = None): + ... + +class Resolver: + def __init__(self, configure): + self.nameservers : List[str] + def query(self, qname : str, rdtype : Union[int,str] = rdatatype.A, rdclass : Union[int,str] = rdataclass.IN, + tcp : bool = False, source : Optional[str] = None, raise_on_no_answer=True, source_port : int = 0): + ... diff --git a/src/dns/reversename.py b/src/dns/reversename.py index 9ea9395a..8f095fa9 100644 --- a/src/dns/reversename.py +++ b/src/dns/reversename.py @@ -1,4 +1,6 @@ -# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2006-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,21 +15,16 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""DNS Reverse Map Names. - -@var ipv4_reverse_domain: The DNS IPv4 reverse-map domain, in-addr.arpa. -@type ipv4_reverse_domain: dns.name.Name object -@var ipv6_reverse_domain: The DNS IPv6 reverse-map domain, ip6.arpa. -@type ipv6_reverse_domain: dns.name.Name object -""" +"""DNS Reverse Map Names.""" import binascii -import sys import dns.name import dns.ipv6 import dns.ipv4 +from dns._compat import PY3 + ipv4_reverse_domain = dns.name.from_text('in-addr.arpa.') ipv6_reverse_domain = dns.name.from_text('ip6.arpa.') @@ -35,15 +32,19 @@ ipv6_reverse_domain = dns.name.from_text('ip6.arpa.') def from_address(text): """Convert an IPv4 or IPv6 address in textual form into a Name object whose value is the reverse-map domain name of the address. - @param text: an IPv4 or IPv6 address in textual form (e.g. '127.0.0.1', - '::1') - @type text: str - @rtype: dns.name.Name object + + *text*, a ``text``, is an IPv4 or IPv6 address in textual form + (e.g. '127.0.0.1', '::1') + + Raises ``dns.exception.SyntaxError`` if the address is badly formed. + + Returns a ``dns.name.Name``. """ + try: v6 = dns.ipv6.inet_aton(text) if dns.ipv6.is_mapped(v6): - if sys.version_info >= (3,): + if PY3: parts = ['%d' % byte for byte in v6[12:]] else: parts = ['%d' % ord(byte) for byte in v6[12:]] @@ -61,10 +62,16 @@ def from_address(text): def to_address(name): """Convert a reverse map domain name into textual address form. - @param name: an IPv4 or IPv6 address in reverse-map form. - @type name: dns.name.Name object - @rtype: str + + *name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name + form. + + Raises ``dns.exception.SyntaxError`` if the name does not have a + reverse-map form. + + Returns a ``text``. """ + if name.is_subdomain(ipv4_reverse_domain): name = name.relativize(ipv4_reverse_domain) labels = list(name.labels) diff --git a/src/dns/reversename.pyi b/src/dns/reversename.pyi new file mode 100644 index 00000000..97f072ea --- /dev/null +++ b/src/dns/reversename.pyi @@ -0,0 +1,6 @@ +from . import name +def from_address(text : str) -> name.Name: + ... + +def to_address(name : name.Name) -> str: + ... diff --git a/src/dns/rrset.py b/src/dns/rrset.py index d0f8f937..a53ec324 100644 --- a/src/dns/rrset.py +++ b/src/dns/rrset.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -67,10 +69,6 @@ class RRset(dns.rdataset.Rdataset): return self.to_text() def __eq__(self, other): - """Two RRsets are equal if they have the same name and the same - rdataset - - @rtype: bool""" if not isinstance(other, RRset): return False if self.name != other.name: @@ -78,8 +76,9 @@ class RRset(dns.rdataset.Rdataset): return super(RRset, self).__eq__(other) def match(self, name, rdclass, rdtype, covers, deleting=None): - """Returns True if this rrset matches the specified class, type, - covers, and deletion state.""" + """Returns ``True`` if this rrset matches the specified class, type, + covers, and deletion state. + """ if not super(RRset, self).match(rdclass, rdtype, covers): return False @@ -90,23 +89,31 @@ class RRset(dns.rdataset.Rdataset): def to_text(self, origin=None, relativize=True, **kw): """Convert the RRset into DNS master file format. - @see: L{dns.name.Name.choose_relativity} for more information - on how I{origin} and I{relativize} determine the way names + See ``dns.name.Name.choose_relativity`` for more information + on how *origin* and *relativize* determine the way names are emitted. Any additional keyword arguments are passed on to the rdata - to_text() method. + ``to_text()`` method. - @param origin: The origin for relative names, or None. - @type origin: dns.name.Name object - @param relativize: True if names should names be relativized - @type relativize: bool""" + *origin*, a ``dns.name.Name`` or ``None``, the origin for relative + names. + + *relativize*, a ``bool``. If ``True``, names will be relativized + to *origin*. + """ return super(RRset, self).to_text(self.name, origin, relativize, self.deleting, **kw) def to_wire(self, file, compress=None, origin=None, **kw): - """Convert the RRset to wire format.""" + """Convert the RRset to wire format. + + All keyword arguments are passed to ``dns.rdataset.to_wire()``; see + that function for details. + + Returns an ``int``, the number of records emitted. + """ return super(RRset, self).to_wire(self.name, file, compress, origin, self.deleting, **kw) @@ -114,7 +121,7 @@ class RRset(dns.rdataset.Rdataset): def to_rdataset(self): """Convert an RRset into an Rdataset. - @rtype: dns.rdataset.Rdataset object + Returns a ``dns.rdataset.Rdataset``. """ return dns.rdataset.from_rdata_list(self.ttl, list(self)) @@ -124,7 +131,7 @@ def from_text_list(name, ttl, rdclass, rdtype, text_rdatas, """Create an RRset with the specified name, TTL, class, and type, and with the specified list of rdatas in text format. - @rtype: dns.rrset.RRset object + Returns a ``dns.rrset.RRset`` object. """ if isinstance(name, string_types): @@ -145,7 +152,7 @@ def from_text(name, ttl, rdclass, rdtype, *text_rdatas): """Create an RRset with the specified name, TTL, class, and type and with the specified rdatas in text format. - @rtype: dns.rrset.RRset object + Returns a ``dns.rrset.RRset`` object. """ return from_text_list(name, ttl, rdclass, rdtype, text_rdatas) @@ -155,7 +162,7 @@ def from_rdata_list(name, ttl, rdatas, idna_codec=None): """Create an RRset with the specified name and TTL, and with the specified list of rdata objects. - @rtype: dns.rrset.RRset object + Returns a ``dns.rrset.RRset`` object. """ if isinstance(name, string_types): @@ -176,7 +183,7 @@ def from_rdata(name, ttl, *rdatas): """Create an RRset with the specified name and TTL, and with the specified rdata objects. - @rtype: dns.rrset.RRset object + Returns a ``dns.rrset.RRset`` object. """ return from_rdata_list(name, ttl, rdatas) diff --git a/src/dns/rrset.pyi b/src/dns/rrset.pyi new file mode 100644 index 00000000..0a81a2a0 --- /dev/null +++ b/src/dns/rrset.pyi @@ -0,0 +1,10 @@ +from typing import List, Optional +from . import rdataset, rdatatype + +class RRset(rdataset.Rdataset): + def __init__(self, name, rdclass : int , rdtype : int, covers=rdatatype.NONE, + deleting : Optional[int] =None) -> None: + self.name = name + self.deleting = deleting +def from_text(name : str, ttl : int, rdclass : str, rdtype : str, *text_rdatas : str): + ... diff --git a/src/dns/set.py b/src/dns/set.py index ef7fd295..81329bf4 100644 --- a/src/dns/set.py +++ b/src/dns/set.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -13,27 +15,22 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""A simple Set class.""" - - class Set(object): """A simple set class. - Sets are not in Python until 2.3, and rdata are not immutable so - we cannot use sets.Set anyway. This class implements subset of - the 2.3 Set interface using a list as the container. - - @ivar items: A list of the items which are in the set - @type items: list""" + This class was originally used to deal with sets being missing in + ancient versions of python, but dnspython will continue to use it + as these sets are based on lists and are thus indexable, and this + ability is widely used in dnspython applications. + """ __slots__ = ['items'] def __init__(self, items=None): """Initialize the set. - @param items: the initial set of items - @type items: any iterable or None + *items*, an iterable or ``None``, the initial set of items. """ self.items = [] @@ -45,16 +42,22 @@ class Set(object): return "dns.simpleset.Set(%s)" % repr(self.items) def add(self, item): - """Add an item to the set.""" + """Add an item to the set. + """ + if item not in self.items: self.items.append(item) def remove(self, item): - """Remove an item from the set.""" + """Remove an item from the set. + """ + self.items.remove(item) def discard(self, item): - """Remove an item from the set if present.""" + """Remove an item from the set if present. + """ + try: self.items.remove(item) except ValueError: @@ -79,19 +82,22 @@ class Set(object): return obj def __copy__(self): - """Make a (shallow) copy of the set.""" + """Make a (shallow) copy of the set. + """ + return self._clone() def copy(self): - """Make a (shallow) copy of the set.""" + """Make a (shallow) copy of the set. + """ + return self._clone() def union_update(self, other): """Update the set, adding any elements from other which are not already in the set. - @param other: the collection of items with which to update the set - @type other: Set object """ + if not isinstance(other, Set): raise ValueError('other must be a Set instance') if self is other: @@ -102,9 +108,8 @@ class Set(object): def intersection_update(self, other): """Update the set, removing any elements from other which are not in both sets. - @param other: the collection of items with which to update the set - @type other: Set object """ + if not isinstance(other, Set): raise ValueError('other must be a Set instance') if self is other: @@ -118,9 +123,8 @@ class Set(object): def difference_update(self, other): """Update the set, removing any elements from other which are in the set. - @param other: the collection of items with which to update the set - @type other: Set object """ + if not isinstance(other, Set): raise ValueError('other must be a Set instance') if self is other: @@ -130,11 +134,9 @@ class Set(object): self.discard(item) def union(self, other): - """Return a new set which is the union of I{self} and I{other}. + """Return a new set which is the union of ``self`` and ``other``. - @param other: the other set - @type other: Set object - @rtype: the same type as I{self} + Returns the same Set type as this set. """ obj = self._clone() @@ -142,11 +144,10 @@ class Set(object): return obj def intersection(self, other): - """Return a new set which is the intersection of I{self} and I{other}. + """Return a new set which is the intersection of ``self`` and + ``other``. - @param other: the other set - @type other: Set object - @rtype: the same type as I{self} + Returns the same Set type as this set. """ obj = self._clone() @@ -154,12 +155,10 @@ class Set(object): return obj def difference(self, other): - """Return a new set which I{self} - I{other}, i.e. the items - in I{self} which are not also in I{other}. + """Return a new set which ``self`` - ``other``, i.e. the items + in ``self`` which are not also in ``other``. - @param other: the other set - @type other: Set object - @rtype: the same type as I{self} + Returns the same Set type as this set. """ obj = self._clone() @@ -197,8 +196,11 @@ class Set(object): def update(self, other): """Update the set, adding any elements from other which are not already in the set. - @param other: the collection of items with which to update the set - @type other: any iterable type""" + + *other*, the collection of items with which to update the set, which + may be any iterable type. + """ + for item in other: self.add(item) @@ -233,9 +235,9 @@ class Set(object): del self.items[i] def issubset(self, other): - """Is I{self} a subset of I{other}? + """Is this set a subset of *other*? - @rtype: bool + Returns a ``bool``. """ if not isinstance(other, Set): @@ -246,9 +248,9 @@ class Set(object): return True def issuperset(self, other): - """Is I{self} a superset of I{other}? + """Is this set a superset of *other*? - @rtype: bool + Returns a ``bool``. """ if not isinstance(other, Set): diff --git a/src/dns/tokenizer.py b/src/dns/tokenizer.py index 04b98254..880b71ce 100644 --- a/src/dns/tokenizer.py +++ b/src/dns/tokenizer.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -44,32 +46,20 @@ DELIMITER = 6 class UngetBufferFull(dns.exception.DNSException): - """An attempt was made to unget a token when the unget buffer was full.""" class Token(object): - """A DNS master file format token. - @ivar ttype: The token type - @type ttype: int - @ivar value: The token value - @type value: string - @ivar has_escape: Does the token value contain escapes? - @type has_escape: bool + ttype: The token type + value: The token value + has_escape: Does the token value contain escapes? """ def __init__(self, ttype, value='', has_escape=False): - """Initialize a token instance. + """Initialize a token instance.""" - @param ttype: The token type - @type ttype: int - @param value: The token value - @type value: string - @param has_escape: Does the token value contain escapes? - @type has_escape: bool - """ self.ttype = ttype self.value = value self.has_escape = has_escape @@ -160,46 +150,43 @@ class Token(object): class Tokenizer(object): - """A DNS master file format tokenizer. - A token is a (type, value) tuple, where I{type} is an int, and - I{value} is a string. The valid types are EOF, EOL, WHITESPACE, - IDENTIFIER, QUOTED_STRING, COMMENT, and DELIMITER. + A token object is basically a (type, value) tuple. The valid + types are EOF, EOL, WHITESPACE, IDENTIFIER, QUOTED_STRING, + COMMENT, and DELIMITER. - @ivar file: The file to tokenize - @type file: file - @ivar ungotten_char: The most recently ungotten character, or None. - @type ungotten_char: string - @ivar ungotten_token: The most recently ungotten token, or None. - @type ungotten_token: (int, string) token tuple - @ivar multiline: The current multiline level. This value is increased + file: The file to tokenize + + ungotten_char: The most recently ungotten character, or None. + + ungotten_token: The most recently ungotten token, or None. + + multiline: The current multiline level. This value is increased by one every time a '(' delimiter is read, and decreased by one every time a ')' delimiter is read. - @type multiline: int - @ivar quoting: This variable is true if the tokenizer is currently + + quoting: This variable is true if the tokenizer is currently reading a quoted string. - @type quoting: bool - @ivar eof: This variable is true if the tokenizer has encountered EOF. - @type eof: bool - @ivar delimiters: The current delimiter dictionary. - @type delimiters: dict - @ivar line_number: The current line number - @type line_number: int - @ivar filename: A filename that will be returned by the L{where} method. - @type filename: string + + eof: This variable is true if the tokenizer has encountered EOF. + + delimiters: The current delimiter dictionary. + + line_number: The current line number + + filename: A filename that will be returned by the where() method. """ def __init__(self, f=sys.stdin, filename=None): """Initialize a tokenizer instance. - @param f: The file to tokenize. The default is sys.stdin. + f: The file to tokenize. The default is sys.stdin. This parameter may also be a string, in which case the tokenizer will take its input from the contents of the string. - @type f: file or string - @param filename: the name of the filename that the L{where} method + + filename: the name of the filename that the where() method will return. - @type filename: string """ if isinstance(f, text_type): @@ -228,7 +215,6 @@ class Tokenizer(object): def _get_char(self): """Read a character from input. - @rtype: string """ if self.ungotten_char is None: @@ -248,7 +234,7 @@ class Tokenizer(object): def where(self): """Return the current location in the input. - @rtype: (string, int) tuple. The first item is the filename of + Returns a (string, int) tuple. The first item is the filename of the input, the second is the current line number. """ @@ -261,9 +247,8 @@ class Tokenizer(object): an error to try to unget a character when the unget buffer is not empty. - @param c: the character to unget - @type c: string - @raises UngetBufferFull: there is already an ungotten char + c: the character to unget + raises UngetBufferFull: there is already an ungotten char """ if self.ungotten_char is not None: @@ -278,7 +263,7 @@ class Tokenizer(object): If the tokenizer is in multiline mode, then newlines are whitespace. - @rtype: int + Returns the number of characters skipped. """ skipped = 0 @@ -293,15 +278,17 @@ class Tokenizer(object): def get(self, want_leading=False, want_comment=False): """Get the next token. - @param want_leading: If True, return a WHITESPACE token if the + want_leading: If True, return a WHITESPACE token if the first character read is whitespace. The default is False. - @type want_leading: bool - @param want_comment: If True, return a COMMENT token if the + + want_comment: If True, return a COMMENT token if the first token read is a comment. The default is False. - @type want_comment: bool - @rtype: Token object - @raises dns.exception.UnexpectedEnd: input ended prematurely - @raises dns.exception.SyntaxError: input was badly formed + + Raises dns.exception.UnexpectedEnd: input ended prematurely + + Raises dns.exception.SyntaxError: input was badly formed + + Returns a Token. """ if self.ungotten_token is not None: @@ -420,9 +407,9 @@ class Tokenizer(object): an error to try to unget a token when the unget buffer is not empty. - @param token: the token to unget - @type token: Token object - @raises UngetBufferFull: there is already an ungotten token + token: the token to unget + + Raises UngetBufferFull: there is already an ungotten token """ if self.ungotten_token is not None: @@ -431,7 +418,8 @@ class Tokenizer(object): def next(self): """Return the next item in an iteration. - @rtype: (int, string) + + Returns a Token. """ token = self.get() @@ -446,11 +434,12 @@ class Tokenizer(object): # Helpers - def get_int(self): + def get_int(self, base=10): """Read the next token and interpret it as an integer. - @raises dns.exception.SyntaxError: - @rtype: int + Raises dns.exception.SyntaxError if not an integer. + + Returns an int. """ token = self.get().unescape() @@ -458,14 +447,15 @@ class Tokenizer(object): raise dns.exception.SyntaxError('expecting an identifier') if not token.value.isdigit(): raise dns.exception.SyntaxError('expecting an integer') - return int(token.value) + return int(token.value, base) def get_uint8(self): """Read the next token and interpret it as an 8-bit unsigned integer. - @raises dns.exception.SyntaxError: - @rtype: int + Raises dns.exception.SyntaxError if not an 8-bit unsigned integer. + + Returns an int. """ value = self.get_int() @@ -474,26 +464,32 @@ class Tokenizer(object): '%d is not an unsigned 8-bit integer' % value) return value - def get_uint16(self): + def get_uint16(self, base=10): """Read the next token and interpret it as a 16-bit unsigned integer. - @raises dns.exception.SyntaxError: - @rtype: int + Raises dns.exception.SyntaxError if not a 16-bit unsigned integer. + + Returns an int. """ - value = self.get_int() + value = self.get_int(base=base) if value < 0 or value > 65535: - raise dns.exception.SyntaxError( - '%d is not an unsigned 16-bit integer' % value) + if base == 8: + raise dns.exception.SyntaxError( + '%o is not an octal unsigned 16-bit integer' % value) + else: + raise dns.exception.SyntaxError( + '%d is not an unsigned 16-bit integer' % value) return value def get_uint32(self): """Read the next token and interpret it as a 32-bit unsigned integer. - @raises dns.exception.SyntaxError: - @rtype: int + Raises dns.exception.SyntaxError if not a 32-bit unsigned integer. + + Returns an int. """ token = self.get().unescape() @@ -510,8 +506,9 @@ class Tokenizer(object): def get_string(self, origin=None): """Read the next token and interpret it as a string. - @raises dns.exception.SyntaxError: - @rtype: string + Raises dns.exception.SyntaxError if not a string. + + Returns a string. """ token = self.get().unescape() @@ -520,10 +517,11 @@ class Tokenizer(object): return token.value def get_identifier(self, origin=None): - """Read the next token and raise an exception if it is not an identifier. + """Read the next token, which should be an identifier. - @raises dns.exception.SyntaxError: - @rtype: string + Raises dns.exception.SyntaxError if not an identifier. + + Returns a string. """ token = self.get().unescape() @@ -534,8 +532,10 @@ class Tokenizer(object): def get_name(self, origin=None): """Read the next token and interpret it as a DNS name. - @raises dns.exception.SyntaxError: - @rtype: dns.name.Name object""" + Raises dns.exception.SyntaxError if not a name. + + Returns a dns.name.Name. + """ token = self.get() if not token.is_identifier(): @@ -546,8 +546,7 @@ class Tokenizer(object): """Read the next token and raise an exception if it isn't EOL or EOF. - @raises dns.exception.SyntaxError: - @rtype: string + Returns a string. """ token = self.get() @@ -558,6 +557,14 @@ class Tokenizer(object): return token.value def get_ttl(self): + """Read the next token and interpret it as a DNS TTL. + + Raises dns.exception.SyntaxError or dns.ttl.BadTTL if not an + identifier or badly formed. + + Returns an int. + """ + token = self.get().unescape() if not token.is_identifier(): raise dns.exception.SyntaxError('expecting an identifier') diff --git a/src/dns/tsig.py b/src/dns/tsig.py index c57d879f..3daa3878 100644 --- a/src/dns/tsig.py +++ b/src/dns/tsig.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2001-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -15,11 +17,11 @@ """DNS TSIG support.""" +import hashlib import hmac import struct import dns.exception -import dns.hash import dns.rdataclass import dns.name from ._compat import long, string_types, text_type @@ -68,12 +70,12 @@ HMAC_SHA384 = dns.name.from_text("hmac-sha384") HMAC_SHA512 = dns.name.from_text("hmac-sha512") _hashes = { - HMAC_SHA224: 'SHA224', - HMAC_SHA256: 'SHA256', - HMAC_SHA384: 'SHA384', - HMAC_SHA512: 'SHA512', - HMAC_SHA1: 'SHA1', - HMAC_MD5: 'MD5', + HMAC_SHA224: hashlib.sha224, + HMAC_SHA256: hashlib.sha256, + HMAC_SHA384: hashlib.sha384, + HMAC_SHA512: hashlib.sha512, + HMAC_SHA1: hashlib.sha1, + HMAC_MD5: hashlib.md5, } default_algorithm = HMAC_MD5 @@ -211,7 +213,7 @@ def get_algorithm(algorithm): algorithm = dns.name.from_text(algorithm) try: - return (algorithm.to_digestable(), dns.hash.hashes[_hashes[algorithm]]) + return (algorithm.to_digestable(), _hashes[algorithm]) except KeyError: raise NotImplementedError("TSIG algorithm " + str(algorithm) + " is not supported") diff --git a/src/dns/tsigkeyring.py b/src/dns/tsigkeyring.py index 01f87027..5e5fe1cb 100644 --- a/src/dns/tsigkeyring.py +++ b/src/dns/tsigkeyring.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its diff --git a/src/dns/tsigkeyring.pyi b/src/dns/tsigkeyring.pyi new file mode 100644 index 00000000..b5d51e15 --- /dev/null +++ b/src/dns/tsigkeyring.pyi @@ -0,0 +1,7 @@ +from typing import Dict +from . import name + +def from_text(textring : Dict[str,str]) -> Dict[name.Name,bytes]: + ... +def to_text(keyring : Dict[name.Name,bytes]) -> Dict[str, str]: + ... diff --git a/src/dns/ttl.py b/src/dns/ttl.py index a27d8251..4be16bee 100644 --- a/src/dns/ttl.py +++ b/src/dns/ttl.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -20,7 +22,6 @@ from ._compat import long class BadTTL(dns.exception.SyntaxError): - """DNS TTL value is not well-formed.""" @@ -29,10 +30,11 @@ def from_text(text): The BIND 8 units syntax for TTLs (e.g. '1w6d4h3m10s') is supported. - @param text: the textual TTL - @type text: string - @raises dns.ttl.BadTTL: the TTL is not well-formed - @rtype: int + *text*, a ``text``, the textual TTL. + + Raises ``dns.ttl.BadTTL`` if the TTL is not well-formed. + + Returns an ``int``. """ if text.isdigit(): diff --git a/src/dns/update.py b/src/dns/update.py index 59728d98..96a00d5d 100644 --- a/src/dns/update.py +++ b/src/dns/update.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -32,26 +34,25 @@ class Update(dns.message.Message): keyname=None, keyalgorithm=dns.tsig.default_algorithm): """Initialize a new DNS Update object. - @param zone: The zone which is being updated. - @type zone: A dns.name.Name or string - @param rdclass: The class of the zone; defaults to dns.rdataclass.IN. - @type rdclass: An int designating the class, or a string whose value - is the name of a class. - @param keyring: The TSIG keyring to use; defaults to None. - @type keyring: dict - @param keyname: The name of the TSIG key to use; defaults to None. - The key must be defined in the keyring. If a keyring is specified - but a keyname is not, then the key used will be the first key in the - keyring. Note that the order of keys in a dictionary is not defined, - so applications should supply a keyname when a keyring is used, unless - they know the keyring contains only one key. - @type keyname: dns.name.Name or string - @param keyalgorithm: The TSIG algorithm to use; defaults to - dns.tsig.default_algorithm. Constants for TSIG algorithms are defined - in dns.tsig, and the currently implemented algorithms are - HMAC_MD5, HMAC_SHA1, HMAC_SHA224, HMAC_SHA256, HMAC_SHA384, and - HMAC_SHA512. - @type keyalgorithm: string + See the documentation of the Message class for a complete + description of the keyring dictionary. + + *zone*, a ``dns.name.Name`` or ``text``, the zone which is being + updated. + + *rdclass*, an ``int`` or ``text``, the class of the zone. + + *keyring*, a ``dict``, the TSIG keyring to use. If a + *keyring* is specified but a *keyname* is not, then the key + used will be the first key in the *keyring*. Note that the + order of keys in a dictionary is not defined, so applications + should supply a keyname when a keyring is used, unless they + know the keyring contains only one key. + + *keyname*, a ``dns.name.Name`` or ``None``, the name of the TSIG key + to use; defaults to ``None``. The key must be defined in the keyring. + + *keyalgorithm*, a ``dns.name.Name``, the TSIG algorithm to use. """ super(Update, self).__init__() self.flags |= dns.opcode.to_flags(dns.opcode.UPDATE) @@ -77,8 +78,10 @@ class Update(dns.message.Message): rrset.add(rd, ttl) def _add(self, replace, section, name, *args): - """Add records. The first argument is the replace mode. If - false, RRs are added to an existing RRset; if true, the RRset + """Add records. + + *replace* is the replacement mode. If ``False``, + RRs are added to an existing RRset; if ``True``, the RRset is replaced with the specified contents. The second argument is the section to add to. The third argument is always a name. The other arguments can be: @@ -87,7 +90,8 @@ class Update(dns.message.Message): - ttl, rdata... - - ttl, rdtype, string...""" + - ttl, rdtype, string... + """ if isinstance(name, string_types): name = dns.name.from_text(name, None) @@ -117,27 +121,34 @@ class Update(dns.message.Message): self._add_rr(name, ttl, rd, section=section) def add(self, name, *args): - """Add records. The first argument is always a name. The other + """Add records. + + The first argument is always a name. The other arguments can be: - rdataset... - ttl, rdata... - - ttl, rdtype, string...""" + - ttl, rdtype, string... + """ + self._add(False, self.authority, name, *args) def delete(self, name, *args): - """Delete records. The first argument is always a name. The other + """Delete records. + + The first argument is always a name. The other arguments can be: - - I{nothing} + - *empty* - rdataset... - rdata... - - rdtype, [string...]""" + - rdtype, [string...] + """ if isinstance(name, string_types): name = dns.name.from_text(name, None) @@ -171,7 +182,9 @@ class Update(dns.message.Message): self._add_rr(name, 0, rd, dns.rdataclass.NONE) def replace(self, name, *args): - """Replace records. The first argument is always a name. The other + """Replace records. + + The first argument is always a name. The other arguments can be: - rdataset... @@ -181,21 +194,25 @@ class Update(dns.message.Message): - ttl, rdtype, string... Note that if you want to replace the entire node, you should do - a delete of the name followed by one or more calls to add.""" + a delete of the name followed by one or more calls to add. + """ self._add(True, self.authority, name, *args) def present(self, name, *args): """Require that an owner name (and optionally an rdata type, or specific rdataset) exists as a prerequisite to the - execution of the update. The first argument is always a name. + execution of the update. + + The first argument is always a name. The other arguments can be: - rdataset... - rdata... - - rdtype, string...""" + - rdtype, string... + """ if isinstance(name, string_types): name = dns.name.from_text(name, None) @@ -243,7 +260,20 @@ class Update(dns.message.Message): def to_wire(self, origin=None, max_size=65535): """Return a string containing the update in DNS compressed wire format. - @rtype: string""" + + *origin*, a ``dns.name.Name`` or ``None``, the origin to be + appended to any relative names. If *origin* is ``None``, then + the origin of the ``dns.update.Update`` message object is used + (i.e. the *zone* parameter passed when the Update object was + created). + + *max_size*, an ``int``, the maximum size of the wire format + output; default is 0, which means "the message's request + payload, if nonzero, or 65535". + + Returns a ``binary``. + """ + if origin is None: origin = self.origin return super(Update, self).to_wire(origin, max_size) diff --git a/src/dns/update.pyi b/src/dns/update.pyi new file mode 100644 index 00000000..eeac0591 --- /dev/null +++ b/src/dns/update.pyi @@ -0,0 +1,21 @@ +from typing import Optional,Dict,Union,Any + +from . import message, tsig, rdataclass, name + +class Update(message.Message): + def __init__(self, zone : Union[name.Name, str], rdclass : Union[int,str] = rdataclass.IN, keyring : Optional[Dict[name.Name,bytes]] = None, + keyname : Optional[name.Name] = None, keyalgorithm : Optional[name.Name] = tsig.default_algorithm) -> None: + self.id : int + def add(self, name : Union[str,name.Name], *args : Any): + ... + def delete(self, name, *args : Any): + ... + def replace(self, name : Union[str,name.Name], *args : Any): + ... + def present(self, name : Union[str,name.Name], *args : Any): + ... + def absent(self, name : Union[str,name.Name], rdtype=None): + """Require that an owner name (and optionally an rdata type) does + not exist as a prerequisite to the execution of the update.""" + def to_wire(self, origin : Optional[name.Name] = None, max_size=65535, **kw) -> bytes: + ... diff --git a/src/dns/version.py b/src/dns/version.py index 9e8dbb1b..f116904b 100644 --- a/src/dns/version.py +++ b/src/dns/version.py @@ -1,4 +1,6 @@ -# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,13 +17,19 @@ """dnspython release version information.""" +#: MAJOR MAJOR = 1 -MINOR = 15 +#: MINOR +MINOR = 16 +#: MICRO MICRO = 0 +#: RELEASELEVEL RELEASELEVEL = 0x0f +#: SERIAL SERIAL = 0 if RELEASELEVEL == 0x0f: + #: version version = '%d.%d.%d' % (MAJOR, MINOR, MICRO) elif RELEASELEVEL == 0x00: version = '%d.%d.%dx%d' % \ @@ -30,5 +38,6 @@ else: version = '%d.%d.%d%x%d' % \ (MAJOR, MINOR, MICRO, RELEASELEVEL, SERIAL) +#: hexversion hexversion = MAJOR << 24 | MINOR << 16 | MICRO << 8 | RELEASELEVEL << 4 | \ SERIAL diff --git a/src/dns/wiredata.py b/src/dns/wiredata.py index ccef5954..ea3c1e67 100644 --- a/src/dns/wiredata.py +++ b/src/dns/wiredata.py @@ -1,4 +1,6 @@ -# Copyright (C) 2011 Nominum, Inc. +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2011,2017 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -15,10 +17,8 @@ """DNS Wire Data Helper""" -import sys - import dns.exception -from ._compat import binary_type, string_types +from ._compat import binary_type, string_types, PY2 # Figure out what constant python passes for an unspecified slice bound. # It's supposed to be sys.maxint, yet on 64-bit windows sys.maxint is 2^31 - 1 @@ -32,7 +32,7 @@ class _SliceUnspecifiedBound(binary_type): def __getitem__(self, key): return key.stop - if sys.version_info < (3,): + if PY2: def __getslice__(self, i, j): # pylint: disable=getslice-method return self.__getitem__(slice(i, j)) @@ -40,7 +40,7 @@ _unspecified_bound = _SliceUnspecifiedBound()[1:] class WireData(binary_type): - # WireData is a string with stricter slicing + # WireData is a binary type with stricter slicing def __getitem__(self, key): try: @@ -51,7 +51,7 @@ class WireData(binary_type): start = key.start stop = key.stop - if sys.version_info < (3,): + if PY2: if stop == _unspecified_bound: # handle the case where the right bound is unspecified stop = len(self) @@ -76,7 +76,7 @@ class WireData(binary_type): except IndexError: raise dns.exception.FormError - if sys.version_info < (3,): + if PY2: def __getslice__(self, i, j): # pylint: disable=getslice-method return self.__getitem__(slice(i, j)) diff --git a/src/dns/zone.py b/src/dns/zone.py index 468618f6..1e2fe781 100644 --- a/src/dns/zone.py +++ b/src/dns/zone.py @@ -1,3 +1,5 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its @@ -28,14 +30,12 @@ import dns.node import dns.rdataclass import dns.rdatatype import dns.rdata +import dns.rdtypes.ANY.SOA import dns.rrset import dns.tokenizer import dns.ttl import dns.grange -from ._compat import string_types, text_type - - -_py3 = sys.version_info > (3,) +from ._compat import string_types, text_type, PY3 class BadZone(dns.exception.DNSException): @@ -157,25 +157,25 @@ class Zone(object): return self.nodes.__iter__() def iterkeys(self): - if _py3: - return self.nodes.keys() + if PY3: + return self.nodes.keys() # pylint: disable=dict-keys-not-iterating else: return self.nodes.iterkeys() # pylint: disable=dict-iter-method def keys(self): - return self.nodes.keys() + return self.nodes.keys() # pylint: disable=dict-keys-not-iterating def itervalues(self): - if _py3: - return self.nodes.values() + if PY3: + return self.nodes.values() # pylint: disable=dict-values-not-iterating else: return self.nodes.itervalues() # pylint: disable=dict-iter-method def values(self): - return self.nodes.values() + return self.nodes.values() # pylint: disable=dict-values-not-iterating def items(self): - return self.nodes.items() + return self.nodes.items() # pylint: disable=dict-items-not-iterating iteritems = items @@ -261,7 +261,7 @@ class Zone(object): exist? @type create: bool @raises KeyError: the node or rdata could not be found - @rtype: dns.rrset.RRset object + @rtype: dns.rdataset.Rdataset object """ name = self._validate_name(name) @@ -296,7 +296,7 @@ class Zone(object): @param create: should the node and rdataset be created if they do not exist? @type create: bool - @rtype: dns.rrset.RRset object + @rtype: dns.rdataset.Rdataset object or None """ try: @@ -451,7 +451,7 @@ class Zone(object): rdtype = dns.rdatatype.from_text(rdtype) if isinstance(covers, string_types): covers = dns.rdatatype.from_text(covers) - for (name, node) in self.iteritems(): + for (name, node) in self.iteritems(): # pylint: disable=dict-iter-method for rds in node: if rdtype == dns.rdatatype.ANY or \ (rds.rdtype == rdtype and rds.covers == covers): @@ -474,7 +474,7 @@ class Zone(object): rdtype = dns.rdatatype.from_text(rdtype) if isinstance(covers, string_types): covers = dns.rdatatype.from_text(covers) - for (name, node) in self.iteritems(): + for (name, node) in self.iteritems(): # pylint: disable=dict-iter-method for rds in node: if rdtype == dns.rdatatype.ANY or \ (rds.rdtype == rdtype and rds.covers == covers): @@ -525,7 +525,7 @@ class Zone(object): names = list(self.keys()) names.sort() else: - names = self.iterkeys() + names = self.iterkeys() # pylint: disable=dict-iter-method for n in names: l = self[n].to_text(n, origin=self.origin, relativize=relativize) @@ -589,8 +589,14 @@ class _MasterReader(object): @ivar tok: The tokenizer @type tok: dns.tokenizer.Tokenizer object - @ivar ttl: The default TTL - @type ttl: int + @ivar last_ttl: The last seen explicit TTL for an RR + @type last_ttl: int + @ivar last_ttl_known: Has last TTL been detected + @type last_ttl_known: bool + @ivar default_ttl: The default TTL from a $TTL directive or SOA RR + @type default_ttl: int + @ivar default_ttl_known: Has default TTL been detected + @type default_ttl_known: bool @ivar last_name: The last name read @type last_name: dns.name.Name object @ivar current_origin: The current origin @@ -600,8 +606,8 @@ class _MasterReader(object): @ivar zone: the zone @type zone: dns.zone.Zone object @ivar saved_state: saved reader state (used when processing $INCLUDE) - @type saved_state: list of (tokenizer, current_origin, last_name, file) - tuples. + @type saved_state: list of (tokenizer, current_origin, last_name, file, + last_ttl, last_ttl_known, default_ttl, default_ttl_known) tuples. @ivar current_file: the file object of the $INCLUDed file being parsed (None if no $INCLUDE is active). @ivar allow_include: is $INCLUDE allowed? @@ -618,7 +624,10 @@ class _MasterReader(object): self.tok = tok self.current_origin = origin self.relativize = relativize - self.ttl = 0 + self.last_ttl = 0 + self.last_ttl_known = False + self.default_ttl = 0 + self.default_ttl_known = False self.last_name = self.current_origin self.zone = zone_factory(origin, rdclass, relativize=relativize) self.saved_state = [] @@ -659,11 +668,18 @@ class _MasterReader(object): # TTL try: ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError except dns.ttl.BadTTL: - ttl = self.ttl + if not (self.last_ttl_known or self.default_ttl_known): + raise dns.exception.SyntaxError("Missing default TTL value") + if self.default_ttl_known: + ttl = self.default_ttl + else: + ttl = self.last_ttl # Class try: rdclass = dns.rdataclass.from_text(token.value) @@ -701,7 +717,14 @@ class _MasterReader(object): # helpful filename:line info. (ty, va) = sys.exc_info()[:2] raise dns.exception.SyntaxError( - "caught exception %s: %s" % (str(ty), str(va))) + "caught exception {}: {}".format(str(ty), str(va))) + + if not self.default_ttl_known and isinstance(rd, dns.rdtypes.ANY.SOA.SOA): + # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default + # TTL from the SOA minttl if no $TTL statement is present before the + # SOA is parsed. + self.default_ttl = rd.minimum + self.default_ttl_known = True rd.choose_relativity(self.zone.origin, self.relativize) covers = rd.covers() @@ -778,11 +801,18 @@ class _MasterReader(object): # TTL try: ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError except dns.ttl.BadTTL: - ttl = self.ttl + if not (self.last_ttl_known or self.default_ttl_known): + raise dns.exception.SyntaxError("Missing default TTL value") + if self.default_ttl_known: + ttl = self.default_ttl + else: + ttl = self.last_ttl # Class try: rdclass = dns.rdataclass.from_text(token.value) @@ -884,7 +914,10 @@ class _MasterReader(object): self.current_origin, self.last_name, self.current_file, - self.ttl) = self.saved_state.pop(-1) + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known) = self.saved_state.pop(-1) continue break elif token.is_eol(): @@ -898,7 +931,8 @@ class _MasterReader(object): token = self.tok.get() if not token.is_identifier(): raise dns.exception.SyntaxError("bad $TTL") - self.ttl = dns.ttl.from_text(token.value) + self.default_ttl = dns.ttl.from_text(token.value) + self.default_ttl_known = True self.tok.get_eol() elif c == u'$ORIGIN': self.current_origin = self.tok.get_name() @@ -923,7 +957,10 @@ class _MasterReader(object): self.current_origin, self.last_name, self.current_file, - self.ttl)) + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known)) self.current_file = open(filename, 'r') self.tok = dns.tokenizer.Tokenizer(self.current_file, filename) @@ -1024,7 +1061,10 @@ def from_file(f, origin=None, rdclass=dns.rdataclass.IN, """ str_type = string_types - opts = 'rU' + if PY3: + opts = 'r' + else: + opts = 'rU' if isinstance(f, str_type): if filename is None: diff --git a/src/dns/zone.pyi b/src/dns/zone.pyi new file mode 100644 index 00000000..911d7a01 --- /dev/null +++ b/src/dns/zone.pyi @@ -0,0 +1,55 @@ +from typing import Generator, Optional, Union, Tuple, Iterable, Callable, Any, Iterator, TextIO, BinaryIO, Dict +from . import rdata, zone, rdataclass, name, rdataclass, message, rdatatype, exception, node, rdataset, rrset, rdatatype + +class BadZone(exception.DNSException): ... +class NoSOA(BadZone): ... +class NoNS(BadZone): ... +class UnknownOrigin(BadZone): ... + +class Zone: + def __getitem__(self, key : str) -> node.Node: + ... + def __init__(self, origin : Union[str,name.Name], rdclass : int = rdataclass.IN, relativize : bool = True) -> None: + self.nodes : Dict[str,node.Node] + self.origin = origin + def values(self): + return self.nodes.values() + def iterate_rdatas(self, rdtype : Union[int,str] = rdatatype.ANY, covers : Union[int,str] = None) -> Iterable[Tuple[name.Name, int, rdata.Rdata]]: + ... + def __iter__(self) -> Iterator[str]: + ... + def get_node(self, name : Union[name.Name,str], create=False) -> Optional[node.Node]: + ... + def find_rrset(self, name : Union[str,name.Name], rdtype : Union[int,str], covers=rdatatype.NONE) -> rrset.RRset: + ... + def find_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE, + create=False) -> rdataset.Rdataset: + ... + def get_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE, create=False) -> Optional[rdataset.Rdataset]: + ... + def get_rrset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> Optional[rrset.RRset]: + ... + def replace_rdataset(self, name : Union[str,name.Name], replacement : rdataset.Rdataset) -> None: + ... + def delete_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> None: + ... + def iterate_rdatasets(self, rdtype : Union[str,int] =rdatatype.ANY, + covers : Union[str,int] =rdatatype.NONE): + ... + def to_file(self, f : Union[TextIO, BinaryIO, str], sorted=True, relativize=True, nl : Optional[bytes] = None): + ... + def to_text(self, sorted=True, relativize=True, nl : Optional[bytes] = None) -> bytes: + ... + +def from_xfr(xfr : Generator[Any,Any,message.Message], zone_factory : Callable[..., zone.Zone] = zone.Zone, relativize=True, check_origin=True): + ... + +def from_text(text : str, origin : Optional[Union[str,name.Name]] = None, rdclass : int = rdataclass.IN, + relativize=True, zone_factory : Callable[...,zone.Zone] = zone.Zone, filename : Optional[str] = None, + allow_include=False, check_origin=True) -> zone.Zone: + ... + +def from_file(f, origin : Optional[Union[str,name.Name]] = None, rdclass=rdataclass.IN, + relativize=True, zone_factory : Callable[..., zone.Zone] = Zone, filename : Optional[str] = None, + allow_include=True, check_origin=True) -> zone.Zone: + ... diff --git a/src/google/auth/_default.py b/src/google/auth/_default.py index 1f75be05..c93b4896 100644 --- a/src/google/auth/_default.py +++ b/src/google/auth/_default.py @@ -41,7 +41,7 @@ _HELP_MESSAGE = """\ Could not automatically determine credentials. Please set {env} or \ explicitly create credentials and re-run the application. For more \ information, please see \ -https://developers.google.com/accounts/docs/application-default-credentials. +https://cloud.google.com/docs/authentication/getting-started """.format(env=environment_vars.CREDENTIALS).strip() # Warning when using Cloud SDK user credentials @@ -51,7 +51,7 @@ Cloud SDK. We recommend that most server applications use service accounts \ instead. If your application continues to use end user credentials from Cloud \ SDK, you might receive a "quota exceeded" or "API not enabled" error. For \ more information about service accounts, see \ -https://cloud.google.com/docs/authentication/.""" +https://cloud.google.com/docs/authentication/""" def _warn_about_problematic_credentials(credentials): diff --git a/src/google/auth/_oauth2client.py b/src/google/auth/_oauth2client.py index 71fd7bf4..afe7dc45 100644 --- a/src/google/auth/_oauth2client.py +++ b/src/google/auth/_oauth2client.py @@ -25,6 +25,7 @@ import six from google.auth import _helpers import google.auth.app_engine +import google.auth.compute_engine import google.oauth2.credentials import google.oauth2.service_account @@ -37,7 +38,7 @@ except ImportError as caught_exc: ImportError('oauth2client is not installed.'), caught_exc) try: - import oauth2client.contrib.appengine + import oauth2client.contrib.appengine # pytype: disable=import-error _HAS_APPENGINE = True except ImportError: _HAS_APPENGINE = False diff --git a/src/google/auth/app_engine.py b/src/google/auth/app_engine.py index f47dae12..91ba8427 100644 --- a/src/google/auth/app_engine.py +++ b/src/google/auth/app_engine.py @@ -28,10 +28,12 @@ from google.auth import _helpers from google.auth import credentials from google.auth import crypt +# pytype: disable=import-error try: from google.appengine.api import app_identity except ImportError: app_identity = None +# pytype: enable=import-error class Signer(crypt.Signer): diff --git a/src/google/auth/impersonated_credentials.py b/src/google/auth/impersonated_credentials.py new file mode 100644 index 00000000..32dfe830 --- /dev/null +++ b/src/google/auth/impersonated_credentials.py @@ -0,0 +1,231 @@ +# Copyright 2018 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google Cloud Impersonated credentials. + +This module provides authentication for applications where local credentials +impersonates a remote service account using `IAM Credentials API`_. + +This class can be used to impersonate a service account as long as the original +Credential object has the "Service Account Token Creator" role on the target +service account. + + .. _IAM Credentials API: + https://cloud.google.com/iam/credentials/reference/rest/ +""" + +import copy +from datetime import datetime +import json + +import six +from six.moves import http_client + +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions + +_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds + +_IAM_SCOPE = ['https://www.googleapis.com/auth/iam'] + +_IAM_ENDPOINT = ('https://iamcredentials.googleapis.com/v1/projects/-' + + '/serviceAccounts/{}:generateAccessToken') + +_REFRESH_ERROR = 'Unable to acquire impersonated credentials' + + +def _make_iam_token_request(request, principal, headers, body): + """Makes a request to the Google Cloud IAM service for an access token. + Args: + request (Request): The Request object to use. + principal (str): The principal to request an access token for. + headers (Mapping[str, str]): Map of headers to transmit. + body (Mapping[str, str]): JSON Payload body for the iamcredentials + API call. + + Raises: + TransportError: Raised if there is an underlying HTTP connection + Error + DefaultCredentialsError: Raised if the impersonated credentials + are not available. Common reasons are + `iamcredentials.googleapis.com` is not enabled or the + `Service Account Token Creator` is not assigned + """ + iam_endpoint = _IAM_ENDPOINT.format(principal) + + body = json.dumps(body) + + response = request( + url=iam_endpoint, + method='POST', + headers=headers, + body=body) + + response_body = response.data.decode('utf-8') + + if response.status != http_client.OK: + exceptions.RefreshError(_REFRESH_ERROR, response_body) + + try: + token_response = json.loads(response.data.decode('utf-8')) + token = token_response['accessToken'] + expiry = datetime.strptime( + token_response['expireTime'], '%Y-%m-%dT%H:%M:%SZ') + + return token, expiry + + except (KeyError, ValueError) as caught_exc: + new_exc = exceptions.RefreshError( + '{}: No access token or invalid expiration in response.'.format( + _REFRESH_ERROR), + response_body) + six.raise_from(new_exc, caught_exc) + + +class Credentials(credentials.Credentials): + """This module defines impersonated credentials which are essentially + impersonated identities. + + Impersonated Credentials allows credentials issued to a user or + service account to impersonate another. The target service account must + grant the originating credential principal the + `Service Account Token Creator`_ IAM role: + + For more information about Token Creator IAM role and + IAMCredentials API, see + `Creating Short-Lived Service Account Credentials`_. + + .. _Service Account Token Creator: + https://cloud.google.com/iam/docs/service-accounts#the_service_account_token_creator_role + + .. _Creating Short-Lived Service Account Credentials: + https://cloud.google.com/iam/docs/creating-short-lived-service-account-credentials + + Usage: + + First grant source_credentials the `Service Account Token Creator` + role on the target account to impersonate. In this example, the + service account represented by svc_account.json has the + token creator role on + `impersonated-account@_project_.iam.gserviceaccount.com`. + + Enable the IAMCredentials API on the source project: + `gcloud services enable iamcredentials.googleapis.com`. + + Initialize a source credential which does not have access to + list bucket:: + + from google.oauth2 import service_acccount + + target_scopes = [ + 'https://www.googleapis.com/auth/devstorage.read_only'] + + source_credentials = ( + service_account.Credentials.from_service_account_file( + '/path/to/svc_account.json', + scopes=target_scopes)) + + Now use the source credentials to acquire credentials to impersonate + another service account:: + + from google.auth import impersonated_credentials + + target_credentials = impersonated_credentials.Credentials( + source_credentials=source_credentials, + target_principal='impersonated-account@_project_.iam.gserviceaccount.com', + target_scopes = target_scopes, + lifetime=500) + + Resource access is granted:: + + client = storage.Client(credentials=target_credentials) + buckets = client.list_buckets(project='your_project') + for bucket in buckets: + print bucket.name + """ + + def __init__(self, source_credentials, target_principal, + target_scopes, delegates=None, + lifetime=_DEFAULT_TOKEN_LIFETIME_SECS): + """ + Args: + source_credentials (google.auth.Credentials): The source credential + used as to acquire the impersonated credentials. + target_principal (str): The service account to impersonate. + target_scopes (Sequence[str]): Scopes to request during the + authorization grant. + delegates (Sequence[str]): The chained list of delegates required + to grant the final access_token. If set, the sequence of + identities must have "Service Account Token Creator" capability + granted to the prceeding identity. For example, if set to + [serviceAccountB, serviceAccountC], the source_credential + must have the Token Creator role on serviceAccountB. + serviceAccountB must have the Token Creator on serviceAccountC. + Finally, C must have Token Creator on target_principal. + If left unset, source_credential must have that role on + target_principal. + lifetime (int): Number of seconds the delegated credential should + be valid for (upto 3600). + """ + + super(Credentials, self).__init__() + + self._source_credentials = copy.copy(source_credentials) + self._source_credentials._scopes = _IAM_SCOPE + self._target_principal = target_principal + self._target_scopes = target_scopes + self._delegates = delegates + self._lifetime = lifetime + self.token = None + self.expiry = _helpers.utcnow() + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + self._update_token(request) + + @property + def expired(self): + return _helpers.utcnow() >= self.expiry + + def _update_token(self, request): + """Updates credentials with a new access_token representing + the impersonated account. + + Args: + request (google.auth.transport.requests.Request): Request object + to use for refreshing credentials. + """ + + # Refresh our source credentials. + self._source_credentials.refresh(request) + + body = { + "delegates": self._delegates, + "scope": self._target_scopes, + "lifetime": str(self._lifetime) + "s" + } + + headers = { + 'Content-Type': 'application/json', + } + + # Apply the source credentials authentication info. + self._source_credentials.apply(headers) + + self.token, self.expiry = _make_iam_token_request( + request=request, + principal=self._target_principal, + headers=headers, + body=body) diff --git a/src/google/auth/jwt.py b/src/google/auth/jwt.py index ef23db23..3805f371 100644 --- a/src/google/auth/jwt.py +++ b/src/google/auth/jwt.py @@ -738,7 +738,7 @@ class OnDemandCredentials( parts = urllib.parse.urlsplit(url) # Strip query string and fragment audience = urllib.parse.urlunsplit( - (parts.scheme, parts.netloc, parts.path, None, None)) + (parts.scheme, parts.netloc, parts.path, "", "")) token = self._get_jwt_for_audience(audience) self.apply(headers, token=token) diff --git a/src/google/oauth2/credentials.py b/src/google/oauth2/credentials.py index 8e2a7f80..4cb909cb 100644 --- a/src/google/oauth2/credentials.py +++ b/src/google/oauth2/credentials.py @@ -43,7 +43,7 @@ from google.oauth2 import _client # The Google OAuth 2.0 token endpoint. Used for authorized user credentials. -_GOOGLE_OAUTH2_TOKEN_ENDPOINT = 'https://accounts.google.com/o/oauth2/token' +_GOOGLE_OAUTH2_TOKEN_ENDPOINT = 'https://oauth2.googleapis.com/token' class Credentials(credentials.ReadOnlyScoped, credentials.Credentials): diff --git a/src/googleapiclient/__init__.py b/src/googleapiclient/__init__.py index 5b4ca889..ee986d29 100644 --- a/src/googleapiclient/__init__.py +++ b/src/googleapiclient/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.7.3" +__version__ = "1.7.8" # Set default logging handler to avoid "No handler found" warnings. import logging diff --git a/src/googleapiclient/channel.py b/src/googleapiclient/channel.py index 0fdb080f..3caee13a 100644 --- a/src/googleapiclient/channel.py +++ b/src/googleapiclient/channel.py @@ -1,3 +1,17 @@ +# Copyright 2014 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Channel notifications support. Classes and functions to support channel subscriptions and notifications @@ -53,7 +67,7 @@ always be upper case. Example of unsubscribing. - service.channels().stop(channel.body()) + service.channels().stop(channel.body()).execute() """ from __future__ import absolute_import diff --git a/src/googleapiclient/discovery.py b/src/googleapiclient/discovery.py index 7762d84f..7d895bbe 100644 --- a/src/googleapiclient/discovery.py +++ b/src/googleapiclient/discovery.py @@ -126,14 +126,16 @@ class _BytesGenerator(BytesGenerator): _write_lines = BytesGenerator.write def fix_method_name(name): - """Fix method names to avoid reserved word conflicts. + """Fix method names to avoid '$' characters and reserved word conflicts. Args: name: string, method name. Returns: - The name with an '_' appended if the name is a reserved word. + The name with '_' appended if the name is a reserved word and '$' + replaced with '_'. """ + name = name.replace('$', '_') if keyword.iskeyword(name) or name in RESERVED_WORDS: return name + '_' else: @@ -219,7 +221,7 @@ def build(serviceName, try: content = _retrieve_discovery_doc( - requested_url, discovery_http, cache_discovery, cache) + requested_url, discovery_http, cache_discovery, cache, developerKey) return build_from_document(content, base=discovery_url, http=http, developerKey=developerKey, model=model, requestBuilder=requestBuilder, credentials=credentials) @@ -233,7 +235,8 @@ def build(serviceName, "name: %s version: %s" % (serviceName, version)) -def _retrieve_discovery_doc(url, http, cache_discovery, cache=None): +def _retrieve_discovery_doc(url, http, cache_discovery, cache=None, + developerKey=None): """Retrieves the discovery_doc from cache or the internet. Args: @@ -264,6 +267,8 @@ def _retrieve_discovery_doc(url, http, cache_discovery, cache=None): # document to avoid exceeding the quota on discovery requests. if 'REMOTE_ADDR' in os.environ: actual_url = _add_query_parameter(url, 'userIp', os.environ['REMOTE_ADDR']) + if developerKey: + actual_url = _add_query_parameter(url, 'key', developerKey) logger.info('URL being requested: GET %s', actual_url) resp, content = http.request(actual_url) @@ -360,7 +365,9 @@ def build_from_document( # The credentials need to be scoped. credentials = _auth.with_scopes(credentials, scopes) - # Create an authorized http instance + # If credentials are provided, create an authorized http instance; + # otherwise, skip authentication. + if credentials: http = _auth.authorized_http(credentials) # If the service doesn't require scopes then there is no need for diff --git a/src/googleapiclient/http.py b/src/googleapiclient/http.py index a7f14b73..4949d0cf 100644 --- a/src/googleapiclient/http.py +++ b/src/googleapiclient/http.py @@ -73,6 +73,8 @@ DEFAULT_CHUNK_SIZE = 100*1024*1024 MAX_URI_LENGTH = 2048 +MAX_BATCH_LIMIT = 1000 + _TOO_MANY_REQUESTS = 429 DEFAULT_HTTP_TIMEOUT_SEC = 60 @@ -173,6 +175,8 @@ def _retry_request(http, num_retries, req_type, sleep, rand, uri, method, *args, 'WSAETIMEDOUT', 'ETIMEDOUT', 'EPIPE', 'ECONNABORTED'}: raise exception = socket_error + except httplib2.ServerNotFoundError as server_not_found_error: + exception = server_not_found_error if exception: if retry_num == num_retries: @@ -645,6 +649,14 @@ class MediaIoBaseDownload(object): self._sleep = time.sleep self._rand = random.random + self._headers = {} + for k, v in six.iteritems(request.headers): + # allow users to supply custom headers by setting them on the request + # but strip out the ones that are set by default on requests generated by + # API methods like Drive's files().get(fileId=...) + if not k.lower() in ('accept', 'accept-encoding', 'user-agent'): + self._headers[k] = v + @util.positional(1) def next_chunk(self, num_retries=0): """Get the next chunk of the download. @@ -664,10 +676,9 @@ class MediaIoBaseDownload(object): googleapiclient.errors.HttpError if the response was not a 2xx. httplib2.HttpLib2Error if a transport error has occured. """ - headers = { - 'range': 'bytes=%d-%d' % ( + headers = self._headers.copy() + headers['range'] = 'bytes=%d-%d' % ( self._progress, self._progress + self._chunksize) - } http = self._request.http resp, content = _retry_request( @@ -1169,7 +1180,10 @@ class BatchHttpRequest(object): if self._base_id is None: self._base_id = uuid.uuid4() - return '<%s+%s>' % (self._base_id, quote(id_)) + # NB: we intentionally leave whitespace between base/id and '+', so RFC2822 + # line folding works properly on Python 3; see + # https://github.com/google/google-api-python-client/issues/164 + return '<%s + %s>' % (self._base_id, quote(id_)) def _header_to_id(self, header): """Convert a Content-ID header value to an id. @@ -1190,7 +1204,7 @@ class BatchHttpRequest(object): raise BatchError("Invalid value for Content-ID: %s" % header) if '+' not in header: raise BatchError("Invalid value for Content-ID: %s" % header) - base, id_ = header[1:-1].rsplit('+', 1) + base, id_ = header[1:-1].split(' + ', 1) return unquote(id_) @@ -1300,8 +1314,8 @@ class BatchHttpRequest(object): request id, and the second is the deserialized response object. The third is an googleapiclient.errors.HttpError exception object if an HTTP error occurred while processing the request, or None if no errors occurred. - request_id: string, A unique id for the request. The id will be passed to - the callback with the response. + request_id: string, A unique id for the request. The id will be passed + to the callback with the response. Returns: None @@ -1310,6 +1324,10 @@ class BatchHttpRequest(object): BatchError if a media request is added to a batch. KeyError is the request_id is not unique. """ + + if len(self._order) >= MAX_BATCH_LIMIT: + raise BatchError("Exceeded the maximum calls(%d) in a single batch request." + % MAX_BATCH_LIMIT) if request_id is None: request_id = self._new_id() if request.resumable is not None: diff --git a/src/googleapiclient/sample_tools.py b/src/googleapiclient/sample_tools.py index 21fede3e..5cb7a06e 100644 --- a/src/googleapiclient/sample_tools.py +++ b/src/googleapiclient/sample_tools.py @@ -28,14 +28,6 @@ import os from googleapiclient import discovery from googleapiclient.http import build_http -try: - from oauth2client import client - from oauth2client import file - from oauth2client import tools -except ImportError: - raise ImportError('googleapiclient.sample_tools requires oauth2client. Please install oauth2client and try again.') - - def init(argv, name, version, doc, filename, scope=None, parents=[], discovery_filename=None): """A common initialization routine for samples. @@ -60,6 +52,13 @@ def init(argv, name, version, doc, filename, scope=None, parents=[], discovery_f A tuple of (service, flags), where service is the service object and flags is the parsed command-line flags. """ + try: + from oauth2client import client + from oauth2client import file + from oauth2client import tools + except ImportError: + raise ImportError('googleapiclient.sample_tools requires oauth2client. Please install oauth2client and try again.') + if scope is None: scope = 'https://www.googleapis.com/auth/' + name diff --git a/src/httplib2/__init__.py b/src/httplib2/__init__.py index 18b013d9..fee091d7 100644 --- a/src/httplib2/__init__.py +++ b/src/httplib2/__init__.py @@ -1,55 +1,53 @@ +"""Small, fast HTTP client library for Python. + +Features persistent connections, cache, and Google App Engine Standard +Environment support. +""" + from __future__ import print_function -""" -httplib2 - -A caching http interface that supports ETags and gzip -to conserve bandwidth. - -Requires Python 2.3 or later - -Changelog: -2007-08-18, Rick: Modified so it's able to use a socks proxy if needed. - -""" __author__ = "Joe Gregorio (joe@bitworking.org)" __copyright__ = "Copyright 2006, Joe Gregorio" -__contributors__ = ["Thomas Broyer (t.broyer@ltgt.net)", - "James Antill", - "Xavier Verges Farrero", - "Jonathan Feinberg", - "Blair Zajac", - "Sam Ruby", - "Louis Nyffenegger", - "Alex Yu"] +__contributors__ = [ + "Thomas Broyer (t.broyer@ltgt.net)", + "James Antill", + "Xavier Verges Farrero", + "Jonathan Feinberg", + "Blair Zajac", + "Sam Ruby", + "Louis Nyffenegger", + "Alex Yu", +] __license__ = "MIT" -__version__ = '0.11.3' +__version__ = '0.12.1' -import re -import sys -import email -import email.Utils -import email.Message -import email.FeedParser -import StringIO -import gzip -import zlib -import httplib -import urlparse -import urllib import base64 -import os -import copy import calendar -import time -import random +import copy +import email +import email.FeedParser +import email.Message +import email.Utils import errno +import gzip +import httplib +import os +import random +import re +import StringIO +import sys +import time +import urllib +import urlparse +import zlib + try: from hashlib import sha1 as _sha, md5 as _md5 except ImportError: # prior to Python 2.5, these were separate modules import sha import md5 + _sha = sha.new _md5 = md5.new import hmac @@ -73,12 +71,13 @@ try: except ImportError: pass if ssl is not None: - ssl_SSLError = getattr(ssl, 'SSLError', None) - ssl_CertificateError = getattr(ssl, 'CertificateError', None) + ssl_SSLError = getattr(ssl, "SSLError", None) + ssl_CertificateError = getattr(ssl, "CertificateError", None) -def _ssl_wrap_socket(sock, key_file, cert_file, disable_validation, - ca_certs, ssl_version, hostname): +def _ssl_wrap_socket( + sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname +): if disable_validation: cert_reqs = ssl.CERT_NONE else: @@ -86,53 +85,69 @@ def _ssl_wrap_socket(sock, key_file, cert_file, disable_validation, if ssl_version is None: ssl_version = ssl.PROTOCOL_SSLv23 - if hasattr(ssl, 'SSLContext'): # Python 2.7.9 + if hasattr(ssl, "SSLContext"): # Python 2.7.9 context = ssl.SSLContext(ssl_version) context.verify_mode = cert_reqs - context.check_hostname = (cert_reqs != ssl.CERT_NONE) + context.check_hostname = cert_reqs != ssl.CERT_NONE if cert_file: context.load_cert_chain(cert_file, key_file) if ca_certs: context.load_verify_locations(ca_certs) return context.wrap_socket(sock, server_hostname=hostname) else: - return ssl.wrap_socket(sock, keyfile=key_file, certfile=cert_file, - cert_reqs=cert_reqs, ca_certs=ca_certs, - ssl_version=ssl_version) + return ssl.wrap_socket( + sock, + keyfile=key_file, + certfile=cert_file, + cert_reqs=cert_reqs, + ca_certs=ca_certs, + ssl_version=ssl_version, + ) -def _ssl_wrap_socket_unsupported(sock, key_file, cert_file, disable_validation, - ca_certs, ssl_version, hostname): +def _ssl_wrap_socket_unsupported( + sock, key_file, cert_file, disable_validation, ca_certs, ssl_version, hostname +): if not disable_validation: raise CertificateValidationUnsupported( - "SSL certificate validation is not supported without " - "the ssl module installed. To avoid this error, install " - "the ssl module, or explicity disable validation.") + "SSL certificate validation is not supported without " + "the ssl module installed. To avoid this error, install " + "the ssl module, or explicity disable validation." + ) ssl_sock = socket.ssl(sock, key_file, cert_file) return httplib.FakeSocket(sock, ssl_sock) + if ssl is None: _ssl_wrap_socket = _ssl_wrap_socket_unsupported - -if sys.version_info >= (2,3): +if sys.version_info >= (2, 3): from iri2uri import iri2uri else: + def iri2uri(uri): return uri -def has_timeout(timeout): # python 2.6 - if hasattr(socket, '_GLOBAL_DEFAULT_TIMEOUT'): - return (timeout is not None and timeout is not socket._GLOBAL_DEFAULT_TIMEOUT) - return (timeout is not None) + +def has_timeout(timeout): # python 2.6 + if hasattr(socket, "_GLOBAL_DEFAULT_TIMEOUT"): + return timeout is not None and timeout is not socket._GLOBAL_DEFAULT_TIMEOUT + return timeout is not None + __all__ = [ - 'Http', 'Response', 'ProxyInfo', 'HttpLib2Error', 'RedirectMissingLocation', - 'RedirectLimit', 'FailedToDecompressContent', - 'UnimplementedDigestAuthOptionError', - 'UnimplementedHmacDigestAuthOptionError', - 'debuglevel', 'ProxiesUnavailableError'] - + "Http", + "Response", + "ProxyInfo", + "HttpLib2Error", + "RedirectMissingLocation", + "RedirectLimit", + "FailedToDecompressContent", + "UnimplementedDigestAuthOptionError", + "UnimplementedHmacDigestAuthOptionError", + "debuglevel", + "ProxiesUnavailableError", +] # The httplib debug level, set to a non-zero value to get debug output debuglevel = 0 @@ -141,7 +156,8 @@ debuglevel = 0 RETRIES = 2 # Python 2.3 support -if sys.version_info < (2,4): +if sys.version_info < (2, 4): + def sorted(seq): seq.sort() return seq @@ -154,11 +170,15 @@ def HTTPResponse__getheaders(self): raise httplib.ResponseNotReady() return self.msg.items() -if not hasattr(httplib.HTTPResponse, 'getheaders'): + +if not hasattr(httplib.HTTPResponse, "getheaders"): httplib.HTTPResponse.getheaders = HTTPResponse__getheaders + # All exceptions raised here derive from HttpLib2Error -class HttpLib2Error(Exception): pass +class HttpLib2Error(Exception): + pass + # Some exceptions can be caught and optionally # be turned back into responses. @@ -168,26 +188,65 @@ class HttpLib2ErrorWithResponse(HttpLib2Error): self.content = content HttpLib2Error.__init__(self, desc) -class RedirectMissingLocation(HttpLib2ErrorWithResponse): pass -class RedirectLimit(HttpLib2ErrorWithResponse): pass -class FailedToDecompressContent(HttpLib2ErrorWithResponse): pass -class UnimplementedDigestAuthOptionError(HttpLib2ErrorWithResponse): pass -class UnimplementedHmacDigestAuthOptionError(HttpLib2ErrorWithResponse): pass -class MalformedHeader(HttpLib2Error): pass -class RelativeURIError(HttpLib2Error): pass -class ServerNotFoundError(HttpLib2Error): pass -class ProxiesUnavailableError(HttpLib2Error): pass -class CertificateValidationUnsupported(HttpLib2Error): pass -class SSLHandshakeError(HttpLib2Error): pass -class NotSupportedOnThisPlatform(HttpLib2Error): pass +class RedirectMissingLocation(HttpLib2ErrorWithResponse): + pass + + +class RedirectLimit(HttpLib2ErrorWithResponse): + pass + + +class FailedToDecompressContent(HttpLib2ErrorWithResponse): + pass + + +class UnimplementedDigestAuthOptionError(HttpLib2ErrorWithResponse): + pass + + +class UnimplementedHmacDigestAuthOptionError(HttpLib2ErrorWithResponse): + pass + + +class MalformedHeader(HttpLib2Error): + pass + + +class RelativeURIError(HttpLib2Error): + pass + + +class ServerNotFoundError(HttpLib2Error): + pass + + +class ProxiesUnavailableError(HttpLib2Error): + pass + + +class CertificateValidationUnsupported(HttpLib2Error): + pass + + +class SSLHandshakeError(HttpLib2Error): + pass + + +class NotSupportedOnThisPlatform(HttpLib2Error): + pass + + class CertificateHostnameMismatch(SSLHandshakeError): def __init__(self, desc, host, cert): HttpLib2Error.__init__(self, desc) self.host = host self.cert = cert -class NotRunningAppEngineEnvironment(HttpLib2Error): pass + +class NotRunningAppEngineEnvironment(HttpLib2Error): + pass + # Open Items: # ----------- @@ -204,32 +263,34 @@ class NotRunningAppEngineEnvironment(HttpLib2Error): pass # Does not handle Cache-Control: max-stale # Does not use Age: headers when calculating cache freshness. - # The number of redirections to follow before giving up. # Note that only GET redirects are automatically followed. # Will also honor 301 requests by saving that info and never # requesting that URI again. DEFAULT_MAX_REDIRECTS = 5 -try: - # Users can optionally provide a module that tells us where the CA_CERTS - # are located. - import ca_certs_locater - CA_CERTS = ca_certs_locater.get() -except ImportError: - # Default CA certificates file bundled with httplib2. - CA_CERTS = os.path.join( - os.path.dirname(os.path.abspath(__file__ )), "cacerts.txt") +from httplib2 import certs +CA_CERTS = certs.where() # Which headers are hop-by-hop headers by default -HOP_BY_HOP = ['connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization', 'te', 'trailers', 'transfer-encoding', 'upgrade'] +HOP_BY_HOP = [ + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", +] def _get_end2end_headers(response): hopbyhop = list(HOP_BY_HOP) - hopbyhop.extend([x.strip() for x in response.get('connection', '').split(',')]) + hopbyhop.extend([x.strip() for x in response.get("connection", "").split(",")]) return [header for header in response.keys() if header not in hopbyhop] + URI = re.compile(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?") @@ -259,53 +320,62 @@ def urlnorm(uri): # Cache filename construction (original borrowed from Venus http://intertwingly.net/code/venus/) -re_url_scheme = re.compile(r'^\w+://') -re_slash = re.compile(r'[?/:|]+') +re_url_scheme = re.compile(r"^\w+://") +re_unsafe = re.compile(r"[^\w\-_.()=!]+") def safename(filename): """Return a filename suitable for the cache. - Strips dangerous and common characters to create a filename we can use to store the cache in. """ - - try: - if re_url_scheme.match(filename): - if isinstance(filename,str): - filename = filename.decode('utf-8') - filename = filename.encode('idna') - else: - filename = filename.encode('idna') - except UnicodeError: - pass - if isinstance(filename,unicode): - filename=filename.encode('utf-8') - filemd5 = _md5(filename).hexdigest() + if isinstance(filename, str): + filename_bytes = filename + filename = filename.decode("utf-8") + else: + filename_bytes = filename.encode("utf-8") + filemd5 = _md5(filename_bytes).hexdigest() filename = re_url_scheme.sub("", filename) - filename = re_slash.sub(",", filename) + filename = re_unsafe.sub("", filename) + + # limit length of filename (vital for Windows) + # https://github.com/httplib2/httplib2/pull/74 + # C:\Users\ \AppData\Local\Temp\ , + # 9 chars + max 104 chars + 20 chars + x + 1 + 32 = max 259 chars + # Thus max safe filename x = 93 chars. Let it be 90 to make a round sum: + filename = filename[:90] - # limit length of filename - if len(filename)>200: - filename=filename[:200] return ",".join((filename, filemd5)) -NORMALIZE_SPACE = re.compile(r'(?:\r\n)?[ \t]+') + +NORMALIZE_SPACE = re.compile(r"(?:\r\n)?[ \t]+") def _normalize_headers(headers): - return dict([ (key.lower(), NORMALIZE_SPACE.sub(value, ' ').strip()) for (key, value) in headers.iteritems()]) + return dict( + [ + (key.lower(), NORMALIZE_SPACE.sub(value, " ").strip()) + for (key, value) in headers.iteritems() + ] + ) def _parse_cache_control(headers): retval = {} - if 'cache-control' in headers: - parts = headers['cache-control'].split(',') - parts_with_args = [tuple([x.strip().lower() for x in part.split("=", 1)]) for part in parts if -1 != part.find("=")] - parts_wo_args = [(name.strip().lower(), 1) for name in parts if -1 == name.find("=")] + if "cache-control" in headers: + parts = headers["cache-control"].split(",") + parts_with_args = [ + tuple([x.strip().lower() for x in part.split("=", 1)]) + for part in parts + if -1 != part.find("=") + ] + parts_wo_args = [ + (name.strip().lower(), 1) for name in parts if -1 == name.find("=") + ] retval = dict(parts_with_args + parts_wo_args) return retval + # Whether to use a strict mode to parse WWW-Authenticate headers # Might lead to bad results in case of ill-formed header value, # so disabled by default, falling back to relaxed parsing. @@ -317,10 +387,16 @@ USE_WWW_AUTH_STRICT_PARSING = 0 # "(?:[^\0-\x08\x0A-\x1f\x7f-\xff\\\"]|\\[\0-\x7f])*?" matches a "quoted-string" as defined by HTTP, when LWS have already been replaced by a single space # Actually, as an auth-param value can be either a token or a quoted-string, they are combined in a single pattern which matches both: # \"?((?<=\")(?:[^\0-\x1f\x7f-\xff\\\"]|\\[\0-\x7f])*?(?=\")|(?@,;:\\\"/[\]?={} \t]+(?!\"))\"? -WWW_AUTH_STRICT = re.compile(r"^(?:\s*(?:,\s*)?([^\0-\x1f\x7f-\xff()<>@,;:\\\"/[\]?={} \t]+)\s*=\s*\"?((?<=\")(?:[^\0-\x08\x0A-\x1f\x7f-\xff\\\"]|\\[\0-\x7f])*?(?=\")|(?@,;:\\\"/[\]?={} \t]+(?!\"))\"?)(.*)$") -WWW_AUTH_RELAXED = re.compile(r"^(?:\s*(?:,\s*)?([^ \t\r\n=]+)\s*=\s*\"?((?<=\")(?:[^\\\"]|\\.)*?(?=\")|(?@,;:\\\"/[\]?={} \t]+)\s*=\s*\"?((?<=\")(?:[^\0-\x08\x0A-\x1f\x7f-\xff\\\"]|\\[\0-\x7f])*?(?=\")|(?@,;:\\\"/[\]?={} \t]+(?!\"))\"?)(.*)$" +) +WWW_AUTH_RELAXED = re.compile( + r"^(?:\s*(?:,\s*)?([^ \t\r\n=]+)\s*=\s*\"?((?<=\")(?:[^\\\"]|\\.)*?(?=\")|(? 0: + if service == "xapi" and request_uri.find("calendar") > 0: service = "cl" # No point in guessing Base or Spreadsheet - #elif request_uri.find("spreadsheets") > 0: + # elif request_uri.find("spreadsheets") > 0: # service = "wise" - auth = dict(Email=credentials[0], Passwd=credentials[1], service=service, source=headers['user-agent']) - resp, content = self.http.request("https://www.google.com/accounts/ClientLogin", method="POST", body=urlencode(auth), headers={'Content-Type': 'application/x-www-form-urlencoded'}) - lines = content.split('\n') + auth = dict( + Email=credentials[0], + Passwd=credentials[1], + service=service, + source=headers["user-agent"], + ) + resp, content = self.http.request( + "https://www.google.com/accounts/ClientLogin", + method="POST", + body=urlencode(auth), + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + lines = content.split("\n") d = dict([tuple(line.split("=", 1)) for line in lines if line]) if resp.status == 403: self.Auth = "" else: - self.Auth = d['Auth'] + self.Auth = d["Auth"] def request(self, method, request_uri, headers, content): """Modify the request headers to add the appropriate Authorization header.""" - headers['authorization'] = 'GoogleLogin Auth=' + self.Auth + headers["authorization"] = "GoogleLogin Auth=" + self.Auth AUTH_SCHEME_CLASSES = { @@ -720,7 +916,7 @@ AUTH_SCHEME_CLASSES = { "wsse": WsseAuthentication, "digest": DigestAuthentication, "hmacdigest": HmacDigestAuthentication, - "googlelogin": GoogleLoginAuthentication + "googlelogin": GoogleLoginAuthentication, } AUTH_SCHEME_ORDER = ["hmacdigest", "googlelogin", "digest", "wsse", "basic"] @@ -731,7 +927,10 @@ class FileCache(object): Not really safe to use if multiple threads or processes are going to be running on the same cache. """ - def __init__(self, cache, safe=safename): # use safe=lambda x: md5.new(x).hexdigest() for the old behavior + + def __init__( + self, cache, safe=safename + ): # use safe=lambda x: md5.new(x).hexdigest() for the old behavior self.cache = cache self.safe = safe if not os.path.exists(cache): @@ -779,6 +978,7 @@ class Credentials(object): class KeyCerts(Credentials): """Identical to Credentials except that name/password are mapped to key/cert.""" + pass @@ -788,32 +988,35 @@ class AllHosts(object): class ProxyInfo(object): """Collect information required to use a proxy.""" + bypass_hosts = () - def __init__(self, proxy_type, proxy_host, proxy_port, - proxy_rdns=True, proxy_user=None, proxy_pass=None, proxy_headers=None): - """ - Args: + def __init__( + self, + proxy_type, + proxy_host, + proxy_port, + proxy_rdns=True, + proxy_user=None, + proxy_pass=None, + proxy_headers=None, + ): + """Args: + proxy_type: The type of proxy server. This must be set to one of - socks.PROXY_TYPE_XXX constants. For example: - - p = ProxyInfo(proxy_type=socks.PROXY_TYPE_HTTP, - proxy_host='localhost', proxy_port=8000) - + socks.PROXY_TYPE_XXX constants. For example: p = + ProxyInfo(proxy_type=socks.PROXY_TYPE_HTTP, proxy_host='localhost', + proxy_port=8000) proxy_host: The hostname or IP address of the proxy server. - proxy_port: The port that the proxy server is running on. - proxy_rdns: If True (default), DNS queries will not be performed locally, and instead, handed to the proxy to resolve. This is useful - if the network does not allow resolution of non-local names. In + if the network does not allow resolution of non-local names. In httplib2 0.9 and earlier, this defaulted to False. - proxy_user: The username used to authenticate with the proxy server. - proxy_pass: The password used to authenticate with the proxy server. - - proxy_headers: Additional or modified headers for the proxy connect request. + proxy_headers: Additional or modified headers for the proxy connect + request. """ self.proxy_type = proxy_type self.proxy_host = proxy_host @@ -824,8 +1027,15 @@ class ProxyInfo(object): self.proxy_headers = proxy_headers def astuple(self): - return (self.proxy_type, self.proxy_host, self.proxy_port, - self.proxy_rdns, self.proxy_user, self.proxy_pass, self.proxy_headers) + return ( + self.proxy_type, + self.proxy_host, + self.proxy_port, + self.proxy_rdns, + self.proxy_user, + self.proxy_pass, + self.proxy_headers, + ) def isgood(self): return (self.proxy_host != None) and (self.proxy_port != None) @@ -838,54 +1048,54 @@ class ProxyInfo(object): if self.bypass_hosts is AllHosts: return True - hostname = '.' + hostname.lstrip('.') + hostname = "." + hostname.lstrip(".") for skip_name in self.bypass_hosts: # *.suffix - if skip_name.startswith('.') and hostname.endswith(skip_name): + if skip_name.startswith(".") and hostname.endswith(skip_name): return True # exact match - if hostname == '.' + skip_name: + if hostname == "." + skip_name: return True return False def __repr__(self): return ( - '').format(p=self) + "" + ).format(p=self) -def proxy_info_from_environment(method='http'): +def proxy_info_from_environment(method="http"): + """Read proxy info from the environment variables. """ - Read proxy info from the environment variables. - """ - if method not in ['http', 'https']: + if method not in ["http", "https"]: return - env_var = method + '_proxy' + env_var = method + "_proxy" url = os.environ.get(env_var, os.environ.get(env_var.upper())) if not url: return return proxy_info_from_url(url, method, None) -def proxy_info_from_url(url, method='http', noproxy=None): - """ - Construct a ProxyInfo from a URL (such as http_proxy env var) +def proxy_info_from_url(url, method="http", noproxy=None): + """Construct a ProxyInfo from a URL (such as http_proxy env var) """ url = urlparse.urlparse(url) username = None password = None port = None - if '@' in url[1]: - ident, host_port = url[1].split('@', 1) - if ':' in ident: - username, password = ident.split(':', 1) + if "@" in url[1]: + ident, host_port = url[1].split("@", 1) + if ":" in ident: + username, password = ident.split(":", 1) else: password = ident else: host_port = url[1] - if ':' in host_port: - host, port = host_port.split(':', 1) + if ":" in host_port: + host, port = host_port.split(":", 1) else: host = host_port @@ -896,23 +1106,23 @@ def proxy_info_from_url(url, method='http', noproxy=None): proxy_type = 3 # socks.PROXY_TYPE_HTTP pi = ProxyInfo( - proxy_type = proxy_type, - proxy_host = host, - proxy_port = port, - proxy_user = username or None, - proxy_pass = password or None, - proxy_headers = None, + proxy_type=proxy_type, + proxy_host=host, + proxy_port=port, + proxy_user=username or None, + proxy_pass=password or None, + proxy_headers=None, ) bypass_hosts = [] # If not given an explicit noproxy value, respect values in env vars. if noproxy is None: - noproxy = os.environ.get('no_proxy', os.environ.get('NO_PROXY', '')) + noproxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY", "")) # Special case: A single '*' character means all hosts should be bypassed. - if noproxy == '*': + if noproxy == "*": bypass_hosts = AllHosts elif noproxy.strip(): - bypass_hosts = noproxy.split(',') + bypass_hosts = noproxy.split(",") bypass_hosts = filter(bool, bypass_hosts) # To exclude empty string. pi.bypass_hosts = bypass_hosts @@ -920,8 +1130,7 @@ def proxy_info_from_url(url, method='http', noproxy=None): class HTTPConnectionWithTimeout(httplib.HTTPConnection): - """ - HTTPConnection subclass that supports timeouts + """HTTPConnection subclass that supports timeouts All timeouts are in seconds. If None is passed for timeout then Python's default timeout for sockets will be used. See for example @@ -939,11 +1148,14 @@ class HTTPConnectionWithTimeout(httplib.HTTPConnection): # Mostly verbatim from httplib.py. if self.proxy_info and socks is None: raise ProxiesUnavailableError( - 'Proxy support missing but proxy use was requested!') + "Proxy support missing but proxy use was requested!" + ) msg = "getaddrinfo returns an empty list" if self.proxy_info and self.proxy_info.isgood(): use_proxy = True - proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers = self.proxy_info.astuple() + proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers = ( + self.proxy_info.astuple() + ) host = proxy_host port = proxy_port @@ -958,7 +1170,15 @@ class HTTPConnectionWithTimeout(httplib.HTTPConnection): try: if use_proxy: self.sock = socks.socksocket(af, socktype, proto) - self.sock.setproxy(proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers) + self.sock.setproxy( + proxy_type, + proxy_host, + proxy_port, + proxy_rdns, + proxy_user, + proxy_pass, + proxy_headers, + ) else: self.sock = socket.socket(af, socktype, proto) self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -969,7 +1189,19 @@ class HTTPConnectionWithTimeout(httplib.HTTPConnection): if self.debuglevel > 0: print("connect: (%s, %s) ************" % (self.host, self.port)) if use_proxy: - print("proxy: %s ************" % str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers))) + print( + "proxy: %s ************" + % str( + ( + proxy_host, + proxy_port, + proxy_rdns, + proxy_user, + proxy_pass, + proxy_headers, + ) + ) + ) if use_proxy: self.sock.connect((self.host, self.port) + sa[2:]) else: @@ -978,39 +1210,59 @@ class HTTPConnectionWithTimeout(httplib.HTTPConnection): if self.debuglevel > 0: print("connect fail: (%s, %s)" % (self.host, self.port)) if use_proxy: - print("proxy: %s" % str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers))) + print( + "proxy: %s" + % str( + ( + proxy_host, + proxy_port, + proxy_rdns, + proxy_user, + proxy_pass, + proxy_headers, + ) + ) + ) if self.sock: self.sock.close() self.sock = None continue break if not self.sock: - raise socket.error, msg + raise socket.error(msg) class HTTPSConnectionWithTimeout(httplib.HTTPSConnection): - """ - This class allows communication via SSL. + """This class allows communication via SSL. All timeouts are in seconds. If None is passed for timeout then Python's default timeout for sockets will be used. See for example the docs of socket.setdefaulttimeout(): http://docs.python.org/library/socket.html#socket.setdefaulttimeout """ - def __init__(self, host, port=None, key_file=None, cert_file=None, - strict=None, timeout=None, proxy_info=None, - ca_certs=None, disable_ssl_certificate_validation=False, - ssl_version=None): - httplib.HTTPSConnection.__init__(self, host, port=port, - key_file=key_file, - cert_file=cert_file, strict=strict) + + def __init__( + self, + host, + port=None, + key_file=None, + cert_file=None, + strict=None, + timeout=None, + proxy_info=None, + ca_certs=None, + disable_ssl_certificate_validation=False, + ssl_version=None, + ): + httplib.HTTPSConnection.__init__( + self, host, port=port, key_file=key_file, cert_file=cert_file, strict=strict + ) self.timeout = timeout self.proxy_info = proxy_info if ca_certs is None: ca_certs = CA_CERTS self.ca_certs = ca_certs - self.disable_ssl_certificate_validation = \ - disable_ssl_certificate_validation + self.disable_ssl_certificate_validation = disable_ssl_certificate_validation self.ssl_version = ssl_version # The following two methods were adapted from https_wrapper.py, released @@ -1041,12 +1293,10 @@ class HTTPSConnectionWithTimeout(httplib.HTTPSConnection): Returns: list: A list of valid host globs. """ - if 'subjectAltName' in cert: - return [x[1] for x in cert['subjectAltName'] - if x[0].lower() == 'dns'] + if "subjectAltName" in cert: + return [x[1] for x in cert["subjectAltName"] if x[0].lower() == "dns"] else: - return [x[0][1] for x in cert['subject'] - if x[0][0].lower() == 'commonname'] + return [x[0][1] for x in cert["subject"] if x[0][0].lower() == "commonname"] def _ValidateCertificateHostname(self, cert, hostname): """Validates that a given hostname is valid for an SSL certificate. @@ -1059,8 +1309,8 @@ class HTTPSConnectionWithTimeout(httplib.HTTPSConnection): """ hosts = self._GetValidHostsForCert(cert) for host in hosts: - host_re = host.replace('.', '\.').replace('*', '[^.]*') - if re.search('^%s$' % (host_re,), hostname, re.I): + host_re = host.replace(".", "\.").replace("*", "[^.]*") + if re.search("^%s$" % (host_re,), hostname, re.I): return True return False @@ -1070,7 +1320,9 @@ class HTTPSConnectionWithTimeout(httplib.HTTPSConnection): msg = "getaddrinfo returns an empty list" if self.proxy_info and self.proxy_info.isgood(): use_proxy = True - proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers = self.proxy_info.astuple() + proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers = ( + self.proxy_info.astuple() + ) host = proxy_host port = proxy_port @@ -1086,7 +1338,15 @@ class HTTPSConnectionWithTimeout(httplib.HTTPSConnection): if use_proxy: sock = socks.socksocket(family, socktype, proto) - sock.setproxy(proxy_type, proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers) + sock.setproxy( + proxy_type, + proxy_host, + proxy_port, + proxy_rdns, + proxy_user, + proxy_pass, + proxy_headers, + ) else: sock = socket.socket(family, socktype, proto) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -1098,22 +1358,46 @@ class HTTPSConnectionWithTimeout(httplib.HTTPSConnection): sock.connect((self.host, self.port) + sockaddr[:2]) else: sock.connect(sockaddr) - self.sock =_ssl_wrap_socket( - sock, self.key_file, self.cert_file, - self.disable_ssl_certificate_validation, self.ca_certs, - self.ssl_version, self.host) + self.sock = _ssl_wrap_socket( + sock, + self.key_file, + self.cert_file, + self.disable_ssl_certificate_validation, + self.ca_certs, + self.ssl_version, + self.host, + ) if self.debuglevel > 0: print("connect: (%s, %s)" % (self.host, self.port)) if use_proxy: - print("proxy: %s" % str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers))) + print( + "proxy: %s" + % str( + ( + proxy_host, + proxy_port, + proxy_rdns, + proxy_user, + proxy_pass, + proxy_headers, + ) + ) + ) if not self.disable_ssl_certificate_validation: cert = self.sock.getpeercert() - hostname = self.host.split(':', 0)[0] + hostname = self.host.split(":", 0)[0] if not self._ValidateCertificateHostname(cert, hostname): raise CertificateHostnameMismatch( - 'Server presented certificate that does not match ' - 'host %s: %s' % (hostname, cert), hostname, cert) - except (ssl_SSLError, ssl_CertificateError, CertificateHostnameMismatch) as e: + "Server presented certificate that does not match " + "host %s: %s" % (hostname, cert), + hostname, + cert, + ) + except ( + ssl_SSLError, + ssl_CertificateError, + CertificateHostnameMismatch, + ) as e: if sock: sock.close() if self.sock: @@ -1123,7 +1407,7 @@ class HTTPSConnectionWithTimeout(httplib.HTTPSConnection): # to get at more detailed error information, in particular # whether the error is due to certificate validation or # something else (such as SSL protocol mismatch). - if getattr(e, 'errno', None) == ssl.SSL_ERROR_SSL: + if getattr(e, "errno", None) == ssl.SSL_ERROR_SSL: raise SSLHandshakeError(e) else: raise @@ -1133,31 +1417,56 @@ class HTTPSConnectionWithTimeout(httplib.HTTPSConnection): if self.debuglevel > 0: print("connect fail: (%s, %s)" % (self.host, self.port)) if use_proxy: - print("proxy: %s" % str((proxy_host, proxy_port, proxy_rdns, proxy_user, proxy_pass, proxy_headers))) + print( + "proxy: %s" + % str( + ( + proxy_host, + proxy_port, + proxy_rdns, + proxy_user, + proxy_pass, + proxy_headers, + ) + ) + ) if self.sock: self.sock.close() self.sock = None continue break if not self.sock: - raise socket.error, msg + raise socket.error(msg) + SCHEME_TO_CONNECTION = { - 'http': HTTPConnectionWithTimeout, - 'https': HTTPSConnectionWithTimeout + "http": HTTPConnectionWithTimeout, + "https": HTTPSConnectionWithTimeout, } def _new_fixed_fetch(validate_certificate): - def fixed_fetch(url, payload=None, method="GET", headers={}, - allow_truncated=False, follow_redirects=True, - deadline=None): - if deadline is None: - deadline = socket.getdefaulttimeout() - return fetch(url, payload=payload, method=method, headers=headers, - allow_truncated=allow_truncated, - follow_redirects=follow_redirects, deadline=deadline, - validate_certificate=validate_certificate) + + def fixed_fetch( + url, + payload=None, + method="GET", + headers={}, + allow_truncated=False, + follow_redirects=True, + deadline=None, + ): + return fetch( + url, + payload=payload, + method=method, + headers=headers, + allow_truncated=allow_truncated, + follow_redirects=follow_redirects, + deadline=deadline, + validate_certificate=validate_certificate, + ) + return fixed_fetch @@ -1168,12 +1477,23 @@ class AppEngineHttpConnection(httplib.HTTPConnection): disable_ssl_certificate_validation, and ssl_version are all dropped on the ground. """ - def __init__(self, host, port=None, key_file=None, cert_file=None, - strict=None, timeout=None, proxy_info=None, ca_certs=None, - disable_ssl_certificate_validation=False, - ssl_version=None): - httplib.HTTPConnection.__init__(self, host, port=port, - strict=strict, timeout=timeout) + + def __init__( + self, + host, + port=None, + key_file=None, + cert_file=None, + strict=None, + timeout=None, + proxy_info=None, + ca_certs=None, + disable_ssl_certificate_validation=False, + ssl_version=None, + ): + httplib.HTTPConnection.__init__( + self, host, port=port, strict=strict, timeout=timeout + ) class AppEngineHttpsConnection(httplib.HTTPSConnection): @@ -1182,38 +1502,58 @@ class AppEngineHttpsConnection(httplib.HTTPSConnection): The parameters proxy_info, ca_certs, disable_ssl_certificate_validation, and ssl_version are all dropped on the ground. """ - def __init__(self, host, port=None, key_file=None, cert_file=None, - strict=None, timeout=None, proxy_info=None, ca_certs=None, - disable_ssl_certificate_validation=False, - ssl_version=None): - httplib.HTTPSConnection.__init__(self, host, port=port, - key_file=key_file, - cert_file=cert_file, strict=strict, - timeout=timeout) - self._fetch = _new_fixed_fetch( - not disable_ssl_certificate_validation) -# Use a different connection object for Google App Engine + def __init__( + self, + host, + port=None, + key_file=None, + cert_file=None, + strict=None, + timeout=None, + proxy_info=None, + ca_certs=None, + disable_ssl_certificate_validation=False, + ssl_version=None, + ): + httplib.HTTPSConnection.__init__( + self, + host, + port=port, + key_file=key_file, + cert_file=cert_file, + strict=strict, + timeout=timeout, + ) + self._fetch = _new_fixed_fetch(not disable_ssl_certificate_validation) + + +# Use a different connection object for Google App Engine Standard Environment. +def is_gae_instance(): + server_software = os.environ.get('SERVER_SOFTWARE', '') + if (server_software.startswith('Google App Engine/') or + server_software.startswith('Development/') or + server_software.startswith('testutil/')): + return True + return False + + try: - server_software = os.environ.get('SERVER_SOFTWARE') - if not server_software: - raise NotRunningAppEngineEnvironment() - elif not (server_software.startswith('Google App Engine/') or - server_software.startswith('Development/')): + if not is_gae_instance(): raise NotRunningAppEngineEnvironment() from google.appengine.api import apiproxy_stub_map - if apiproxy_stub_map.apiproxy.GetStub('urlfetch') is None: - raise ImportError # Bail out; we're not actually running on App Engine. + if apiproxy_stub_map.apiproxy.GetStub("urlfetch") is None: + raise ImportError + from google.appengine.api.urlfetch import fetch - from google.appengine.api.urlfetch import InvalidURLError # Update the connection classes to use the Googel App Engine specific ones. SCHEME_TO_CONNECTION = { - 'http': AppEngineHttpConnection, - 'https': AppEngineHttpsConnection + "http": AppEngineHttpConnection, + "https": AppEngineHttpsConnection, } -except (ImportError, AttributeError, NotRunningAppEngineEnvironment): +except (ImportError, NotRunningAppEngineEnvironment): pass @@ -1231,10 +1571,16 @@ class Http(object): and more. """ - def __init__(self, cache=None, timeout=None, - proxy_info=proxy_info_from_environment, - ca_certs=None, disable_ssl_certificate_validation=False, - ssl_version=None): + + def __init__( + self, + cache=None, + timeout=None, + proxy_info=proxy_info_from_environment, + ca_certs=None, + disable_ssl_certificate_validation=False, + ssl_version=None, + ): """If 'cache' is a string then it is used as a directory name for a disk cache. Otherwise it must be an object that supports the same interface as FileCache. @@ -1262,8 +1608,7 @@ class Http(object): """ self.proxy_info = proxy_info self.ca_certs = ca_certs - self.disable_ssl_certificate_validation = \ - disable_ssl_certificate_validation + self.disable_ssl_certificate_validation = disable_ssl_certificate_validation self.ssl_version = ssl_version # Map domain name to an httplib connection @@ -1308,10 +1653,10 @@ class Http(object): state_dict = copy.copy(self.__dict__) # In case request is augmented by some foreign object such as # credentials which handle auth - if 'request' in state_dict: - del state_dict['request'] - if 'connections' in state_dict: - del state_dict['connections'] + if "request" in state_dict: + del state_dict["request"] + if "connections" in state_dict: + del state_dict["connections"] return state_dict def __setstate__(self, state): @@ -1322,11 +1667,13 @@ class Http(object): """A generator that creates Authorization objects that can be applied to requests. """ - challenges = _parse_www_authenticate(response, 'www-authenticate') + challenges = _parse_www_authenticate(response, "www-authenticate") for cred in self.credentials.iter(host): for scheme in AUTH_SCHEME_ORDER: if scheme in challenges: - yield AUTH_SCHEME_CLASSES[scheme](cred, host, request_uri, headers, response, content, self) + yield AUTH_SCHEME_CLASSES[scheme]( + cred, host, request_uri, headers, response, content, self + ) def add_credentials(self, name, password, domain=""): """Add a name and password that will be used @@ -1350,7 +1697,7 @@ class Http(object): while i < RETRIES: i += 1 try: - if hasattr(conn, 'sock') and conn.sock is None: + if hasattr(conn, "sock") and conn.sock is None: conn.connect() conn.request(method, request_uri, body, headers) except socket.timeout: @@ -1363,8 +1710,8 @@ class Http(object): raise except socket.error as e: err = 0 - if hasattr(e, 'args'): - err = getattr(e, 'args')[0] + if hasattr(e, "args"): + err = getattr(e, "args")[0] else: err = e.errno if err == errno.ECONNREFUSED: # Connection refused @@ -1374,15 +1721,15 @@ class Http(object): except httplib.HTTPException: # Just because the server closed the connection doesn't apparently mean # that the server didn't send a response. - if hasattr(conn, 'sock') and conn.sock is None: - if i < RETRIES-1: + if hasattr(conn, "sock") and conn.sock is None: + if i < RETRIES - 1: conn.close() conn.connect() continue else: conn.close() raise - if i < RETRIES-1: + if i < RETRIES - 1: conn.close() conn.connect() continue @@ -1402,7 +1749,7 @@ class Http(object): conn.close() raise except (socket.error, httplib.HTTPException): - if i < RETRIES-1: + if i < RETRIES - 1: conn.close() conn.connect() continue @@ -1421,77 +1768,121 @@ class Http(object): break return (response, content) - - def _request(self, conn, host, absolute_uri, request_uri, method, body, headers, redirections, cachekey): + def _request( + self, + conn, + host, + absolute_uri, + request_uri, + method, + body, + headers, + redirections, + cachekey, + ): """Do the actual request using the connection object and also follow one level of redirects if necessary""" - auths = [(auth.depth(request_uri), auth) for auth in self.authorizations if auth.inscope(host, request_uri)] + auths = [ + (auth.depth(request_uri), auth) + for auth in self.authorizations + if auth.inscope(host, request_uri) + ] auth = auths and sorted(auths)[0][1] or None if auth: auth.request(method, request_uri, headers, body) - (response, content) = self._conn_request(conn, request_uri, method, body, headers) + (response, content) = self._conn_request( + conn, request_uri, method, body, headers + ) if auth: if auth.response(response, body): auth.request(method, request_uri, headers, body) - (response, content) = self._conn_request(conn, request_uri, method, body, headers ) + (response, content) = self._conn_request( + conn, request_uri, method, body, headers + ) response._stale_digest = 1 if response.status == 401: - for authorization in self._auth_from_challenge(host, request_uri, headers, response, content): + for authorization in self._auth_from_challenge( + host, request_uri, headers, response, content + ): authorization.request(method, request_uri, headers, body) - (response, content) = self._conn_request(conn, request_uri, method, body, headers, ) + (response, content) = self._conn_request( + conn, request_uri, method, body, headers + ) if response.status != 401: self.authorizations.append(authorization) authorization.response(response, body) break - if (self.follow_all_redirects or (method in ["GET", "HEAD"]) or response.status == 303): + if ( + self.follow_all_redirects + or (method in ["GET", "HEAD"]) + or response.status == 303 + ): if self.follow_redirects and response.status in [300, 301, 302, 303, 307]: # Pick out the location header and basically start from the beginning # remembering first to strip the ETag header and decrement our 'depth' if redirections: - if 'location' not in response and response.status != 300: - raise RedirectMissingLocation( _("Redirected but the response is missing a Location: header."), response, content) + if "location" not in response and response.status != 300: + raise RedirectMissingLocation( + _( + "Redirected but the response is missing a Location: header." + ), + response, + content, + ) # Fix-up relative redirects (which violate an RFC 2616 MUST) - if 'location' in response: - location = response['location'] + if "location" in response: + location = response["location"] (scheme, authority, path, query, fragment) = parse_uri(location) if authority == None: - response['location'] = urlparse.urljoin(absolute_uri, location) + response["location"] = urlparse.urljoin( + absolute_uri, location + ) if response.status == 301 and method in ["GET", "HEAD"]: - response['-x-permanent-redirect-url'] = response['location'] - if 'content-location' not in response: - response['content-location'] = absolute_uri + response["-x-permanent-redirect-url"] = response["location"] + if "content-location" not in response: + response["content-location"] = absolute_uri _updateCache(headers, response, content, self.cache, cachekey) - if 'if-none-match' in headers: - del headers['if-none-match'] - if 'if-modified-since' in headers: - del headers['if-modified-since'] - if 'authorization' in headers and not self.forward_authorization_headers: - del headers['authorization'] - if 'location' in response: - location = response['location'] + if "if-none-match" in headers: + del headers["if-none-match"] + if "if-modified-since" in headers: + del headers["if-modified-since"] + if ( + "authorization" in headers + and not self.forward_authorization_headers + ): + del headers["authorization"] + if "location" in response: + location = response["location"] old_response = copy.deepcopy(response) - if 'content-location' not in old_response: - old_response['content-location'] = absolute_uri + if "content-location" not in old_response: + old_response["content-location"] = absolute_uri redirect_method = method if response.status in [302, 303]: redirect_method = "GET" body = None (response, content) = self.request( - location, method=redirect_method, - body=body, headers=headers, - redirections=redirections - 1) + location, + method=redirect_method, + body=body, + headers=headers, + redirections=redirections - 1, + ) response.previous = old_response else: - raise RedirectLimit("Redirected more times than rediection_limit allows.", response, content) + raise RedirectLimit( + "Redirected more times than rediection_limit allows.", + response, + content, + ) elif response.status in [200, 203] and method in ["GET", "HEAD"]: # Don't cache 206's since we aren't going to handle byte range requests - if 'content-location' not in response: - response['content-location'] = absolute_uri + if "content-location" not in response: + response["content-location"] = absolute_uri _updateCache(headers, response, content, self.cache, cachekey) return (response, content) @@ -1499,12 +1890,19 @@ class Http(object): def _normalize_headers(self, headers): return _normalize_headers(headers) -# Need to catch and rebrand some exceptions -# Then need to optionally turn all exceptions into status codes -# including all socket.* and httplib.* exceptions. + # Need to catch and rebrand some exceptions + # Then need to optionally turn all exceptions into status codes + # including all socket.* and httplib.* exceptions. - - def request(self, uri, method="GET", body=None, headers=None, redirections=DEFAULT_MAX_REDIRECTS, connection_type=None): + def request( + self, + uri, + method="GET", + body=None, + headers=None, + redirections=DEFAULT_MAX_REDIRECTS, + connection_type=None, + ): """ Performs a single HTTP request. The 'uri' is the URI of the HTTP resource and can begin with either @@ -1526,63 +1924,63 @@ class Http(object): being and instance of the 'Response' class, the second being a string that contains the response entity body. """ + conn_key = '' + try: if headers is None: headers = {} else: headers = self._normalize_headers(headers) - if 'user-agent' not in headers: - headers['user-agent'] = "Python-httplib2/%s (gzip)" % __version__ + if "user-agent" not in headers: + headers["user-agent"] = "Python-httplib2/%s (gzip)" % __version__ uri = iri2uri(uri) (scheme, authority, request_uri, defrag_uri) = urlnorm(uri) - domain_port = authority.split(":")[0:2] - if len(domain_port) == 2 and domain_port[1] == '443' and scheme == 'http': - scheme = 'https' - authority = domain_port[0] proxy_info = self._get_proxy_info(scheme, authority) - conn_key = scheme+":"+authority - if conn_key in self.connections: - conn = self.connections[conn_key] - else: + conn_key = scheme + ":" + authority + conn = self.connections.get(conn_key) + if conn is None: if not connection_type: connection_type = SCHEME_TO_CONNECTION[scheme] certs = list(self.certificates.iter(authority)) - if scheme == 'https': + if scheme == "https": if certs: conn = self.connections[conn_key] = connection_type( - authority, key_file=certs[0][0], - cert_file=certs[0][1], timeout=self.timeout, - proxy_info=proxy_info, - ca_certs=self.ca_certs, - disable_ssl_certificate_validation= - self.disable_ssl_certificate_validation, - ssl_version=self.ssl_version) + authority, + key_file=certs[0][0], + cert_file=certs[0][1], + timeout=self.timeout, + proxy_info=proxy_info, + ca_certs=self.ca_certs, + disable_ssl_certificate_validation=self.disable_ssl_certificate_validation, + ssl_version=self.ssl_version, + ) else: conn = self.connections[conn_key] = connection_type( - authority, timeout=self.timeout, - proxy_info=proxy_info, - ca_certs=self.ca_certs, - disable_ssl_certificate_validation= - self.disable_ssl_certificate_validation, - ssl_version=self.ssl_version) + authority, + timeout=self.timeout, + proxy_info=proxy_info, + ca_certs=self.ca_certs, + disable_ssl_certificate_validation=self.disable_ssl_certificate_validation, + ssl_version=self.ssl_version, + ) else: conn = self.connections[conn_key] = connection_type( - authority, timeout=self.timeout, - proxy_info=proxy_info) + authority, timeout=self.timeout, proxy_info=proxy_info + ) conn.set_debuglevel(debuglevel) - if 'range' not in headers and 'accept-encoding' not in headers: - headers['accept-encoding'] = 'gzip, deflate' + if "range" not in headers and "accept-encoding" not in headers: + headers["accept-encoding"] = "gzip, deflate" info = email.Message.Message() cached_value = None if self.cache: - cachekey = defrag_uri.encode('utf-8') + cachekey = defrag_uri.encode("utf-8") cached_value = self.cache.get(cachekey) if cached_value: # info = email.message_from_string(cached_value) @@ -1591,7 +1989,7 @@ class Http(object): # to fix the non-existent bug not fixed in this # bug report: http://mail.python.org/pipermail/python-bugs-list/2005-September/030289.html try: - info, content = cached_value.split('\r\n\r\n', 1) + info, content = cached_value.split("\r\n\r\n", 1) feedparser = email.FeedParser.FeedParser() feedparser.feed(info) info = feedparser.close() @@ -1603,9 +2001,15 @@ class Http(object): else: cachekey = None - if method in self.optimistic_concurrency_methods and self.cache and 'etag' in info and not self.ignore_etag and 'if-match' not in headers: + if ( + method in self.optimistic_concurrency_methods + and self.cache + and "etag" in info + and not self.ignore_etag + and "if-match" not in headers + ): # http://www.w3.org/1999/04/Editing/ - headers['if-match'] = info['etag'] + headers["if-match"] = info["etag"] if method not in ["GET", "HEAD"] and self.cache and cachekey: # RFC 2616 Section 13.10 @@ -1613,24 +2017,36 @@ class Http(object): # Check the vary header in the cache to see if this request # matches what varies in the cache. - if method in ['GET', 'HEAD'] and 'vary' in info: - vary = info['vary'] - vary_headers = vary.lower().replace(' ', '').split(',') + if method in ["GET", "HEAD"] and "vary" in info: + vary = info["vary"] + vary_headers = vary.lower().replace(" ", "").split(",") for header in vary_headers: - key = '-varied-%s' % header + key = "-varied-%s" % header value = info[key] if headers.get(header, None) != value: cached_value = None break - if cached_value and method in ["GET", "HEAD"] and self.cache and 'range' not in headers: - if '-x-permanent-redirect-url' in info: + if ( + cached_value + and method in ["GET", "HEAD"] + and self.cache + and "range" not in headers + ): + if "-x-permanent-redirect-url" in info: # Should cached permanent redirects be counted in our redirection count? For now, yes. if redirections <= 0: - raise RedirectLimit("Redirected more times than rediection_limit allows.", {}, "") + raise RedirectLimit( + "Redirected more times than rediection_limit allows.", + {}, + "", + ) (response, new_content) = self.request( - info['-x-permanent-redirect-url'], method='GET', - headers=headers, redirections=redirections - 1) + info["-x-permanent-redirect-url"], + method="GET", + headers=headers, + redirections=redirections - 1, + ) response.previous = Response(info) response.previous.fromcache = True else: @@ -1646,7 +2062,7 @@ class Http(object): if entry_disposition == "FRESH": if not cached_value: - info['status'] = '504' + info["status"] = "504" content = "" response = Response(info) if cached_value: @@ -1654,14 +2070,28 @@ class Http(object): return (response, content) if entry_disposition == "STALE": - if 'etag' in info and not self.ignore_etag and not 'if-none-match' in headers: - headers['if-none-match'] = info['etag'] - if 'last-modified' in info and not 'last-modified' in headers: - headers['if-modified-since'] = info['last-modified'] + if ( + "etag" in info + and not self.ignore_etag + and not "if-none-match" in headers + ): + headers["if-none-match"] = info["etag"] + if "last-modified" in info and not "last-modified" in headers: + headers["if-modified-since"] = info["last-modified"] elif entry_disposition == "TRANSPARENT": pass - (response, new_content) = self._request(conn, authority, uri, request_uri, method, body, headers, redirections, cachekey) + (response, new_content) = self._request( + conn, + authority, + uri, + request_uri, + method, + body, + headers, + redirections, + cachekey, + ) if response.status == 304 and method == "GET": # Rewrite the cache entry with the new end-to-end headers @@ -1674,7 +2104,9 @@ class Http(object): merged_response = Response(info) if hasattr(response, "_stale_digest"): merged_response._stale_digest = response._stale_digest - _updateCache(headers, merged_response, content, self.cache, cachekey) + _updateCache( + headers, merged_response, content, self.cache, cachekey + ) response = merged_response response.status = 200 response.fromcache = True @@ -1686,39 +2118,58 @@ class Http(object): content = new_content else: cc = _parse_cache_control(headers) - if 'only-if-cached' in cc: - info['status'] = '504' + if "only-if-cached" in cc: + info["status"] = "504" response = Response(info) content = "" else: - (response, content) = self._request(conn, authority, uri, request_uri, method, body, headers, redirections, cachekey) + (response, content) = self._request( + conn, + authority, + uri, + request_uri, + method, + body, + headers, + redirections, + cachekey, + ) except Exception as e: + is_timeout = isinstance(e, socket.timeout) + if is_timeout: + conn = self.connections.pop(conn_key, None) + if conn: + conn.close() + if self.force_exception_to_status_code: if isinstance(e, HttpLib2ErrorWithResponse): response = e.response content = e.content response.status = 500 response.reason = str(e) - elif isinstance(e, socket.timeout): + elif is_timeout: content = "Request Timeout" - response = Response({ - "content-type": "text/plain", - "status": "408", - "content-length": len(content) - }) + response = Response( + { + "content-type": "text/plain", + "status": "408", + "content-length": len(content), + } + ) response.reason = "Request Timeout" else: content = str(e) - response = Response({ - "content-type": "text/plain", - "status": "400", - "content-length": len(content) - }) + response = Response( + { + "content-type": "text/plain", + "status": "400", + "content-length": len(content), + } + ) response.reason = "Bad Request" else: raise - return (response, content) def _get_proxy_info(self, scheme, authority): @@ -1730,8 +2181,7 @@ class Http(object): if callable(proxy_info): proxy_info = proxy_info(scheme) - if (hasattr(proxy_info, 'applies_to') - and not proxy_info.applies_to(hostname)): + if hasattr(proxy_info, "applies_to") and not proxy_info.applies_to(hostname): proxy_info = None return proxy_info @@ -1741,13 +2191,14 @@ class Response(dict): """Is this response from our local cache""" fromcache = False + """HTTP protocol version used by server. - """HTTP protocol version used by server. 10 for HTTP/1.0, 11 for HTTP/1.1. """ + 10 for HTTP/1.0, 11 for HTTP/1.1. + """ version = 11 "Status code returned by server. " status = 200 - """Reason phrase returned by server.""" reason = "Ok" @@ -1760,21 +2211,21 @@ class Response(dict): for key, value in info.getheaders(): self[key.lower()] = value self.status = info.status - self['status'] = str(self.status) + self["status"] = str(self.status) self.reason = info.reason self.version = info.version elif isinstance(info, email.Message.Message): for key, value in info.items(): self[key.lower()] = value - self.status = int(self['status']) + self.status = int(self["status"]) else: for key, value in info.iteritems(): self[key.lower()] = value - self.status = int(self.get('status', self.status)) - self.reason = self.get('reason', self.reason) + self.status = int(self.get("status", self.status)) + self.reason = self.get("reason", self.reason) def __getattr__(self, name): - if name == 'dict': + if name == "dict": return self else: raise AttributeError(name) diff --git a/src/httplib2/certs.py b/src/httplib2/certs.py new file mode 100644 index 00000000..59d1ffc7 --- /dev/null +++ b/src/httplib2/certs.py @@ -0,0 +1,42 @@ +"""Utilities for certificate management.""" + +import os + +certifi_available = False +certifi_where = None +try: + from certifi import where as certifi_where + certifi_available = True +except ImportError: + pass + +custom_ca_locater_available = False +custom_ca_locater_where = None +try: + from ca_certs_locater import get as custom_ca_locater_where + custom_ca_locater_available = True +except ImportError: + pass + + +BUILTIN_CA_CERTS = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cacerts.txt" +) + + +def where(): + env = os.environ.get("HTTPLIB2_CA_CERTS") + if env is not None: + if os.path.isfile(env): + return env + else: + raise RuntimeError("Environment variable HTTPLIB2_CA_CERTS not a valid file") + if custom_ca_locater_available: + return custom_ca_locater_where() + if certifi_available: + return certifi_where() + return BUILTIN_CA_CERTS + + +if __name__ == "__main__": + print(where()) diff --git a/src/httplib2/iri2uri.py b/src/httplib2/iri2uri.py index d88c91fd..0a978a78 100644 --- a/src/httplib2/iri2uri.py +++ b/src/httplib2/iri2uri.py @@ -1,20 +1,13 @@ -""" -iri2uri +"""Converts an IRI to a URI.""" -Converts an IRI to a URI. - -""" __author__ = "Joe Gregorio (joe@bitworking.org)" __copyright__ = "Copyright 2006, Joe Gregorio" __contributors__ = [] __version__ = "1.0.0" __license__ = "MIT" -__history__ = """ -""" import urlparse - # Convert an IRI to a URI following the rules in RFC 3987 # # The characters we need to enocde and escape are defined in the spec: @@ -50,6 +43,7 @@ escape_range = [ (0x100000, 0x10FFFD), ] + def encode(c): retval = c i = ord(c) @@ -57,7 +51,7 @@ def encode(c): if i < low: break if i >= low and i <= high: - retval = "".join(["%%%2X" % ord(o) for o in c.encode('utf-8')]) + retval = "".join(["%%%2X" % ord(o) for o in c.encode("utf-8")]) break return retval @@ -66,9 +60,9 @@ def iri2uri(uri): """Convert an IRI to a URI. Note that IRIs must be passed in a unicode strings. That is, do not utf-8 encode the IRI before passing it into the function.""" - if isinstance(uri ,unicode): + if isinstance(uri, unicode): (scheme, authority, path, query, fragment) = urlparse.urlsplit(uri) - authority = authority.encode('idna') + authority = authority.encode("idna") # For each character in 'ucschar' or 'iprivate' # 1. encode as utf-8 # 2. then %-encode each octet of that utf-8 @@ -76,11 +70,11 @@ def iri2uri(uri): uri = "".join([encode(c) for c in uri]) return uri + if __name__ == "__main__": import unittest class Test(unittest.TestCase): - def test_uris(self): """Test that URIs are invariant under the transformation.""" invariant = [ @@ -91,20 +85,39 @@ if __name__ == "__main__": u"news:comp.infosystems.www.servers.unix", u"tel:+1-816-555-1212", u"telnet://192.0.2.16:80/", - u"urn:oasis:names:specification:docbook:dtd:xml:4.1.2" ] + u"urn:oasis:names:specification:docbook:dtd:xml:4.1.2", + ] for uri in invariant: self.assertEqual(uri, iri2uri(uri)) def test_iri(self): - """ Test that the right type of escaping is done for each part of the URI.""" - self.assertEqual("http://xn--o3h.com/%E2%98%84", iri2uri(u"http://\N{COMET}.com/\N{COMET}")) - self.assertEqual("http://bitworking.org/?fred=%E2%98%84", iri2uri(u"http://bitworking.org/?fred=\N{COMET}")) - self.assertEqual("http://bitworking.org/#%E2%98%84", iri2uri(u"http://bitworking.org/#\N{COMET}")) + """Test that the right type of escaping is done for each part of the URI.""" + self.assertEqual( + "http://xn--o3h.com/%E2%98%84", + iri2uri(u"http://\N{COMET}.com/\N{COMET}"), + ) + self.assertEqual( + "http://bitworking.org/?fred=%E2%98%84", + iri2uri(u"http://bitworking.org/?fred=\N{COMET}"), + ) + self.assertEqual( + "http://bitworking.org/#%E2%98%84", + iri2uri(u"http://bitworking.org/#\N{COMET}"), + ) self.assertEqual("#%E2%98%84", iri2uri(u"#\N{COMET}")) - self.assertEqual("/fred?bar=%E2%98%9A#%E2%98%84", iri2uri(u"/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}")) - self.assertEqual("/fred?bar=%E2%98%9A#%E2%98%84", iri2uri(iri2uri(u"/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}"))) - self.assertNotEqual("/fred?bar=%E2%98%9A#%E2%98%84", iri2uri(u"/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}".encode('utf-8'))) + self.assertEqual( + "/fred?bar=%E2%98%9A#%E2%98%84", + iri2uri(u"/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}"), + ) + self.assertEqual( + "/fred?bar=%E2%98%9A#%E2%98%84", + iri2uri(iri2uri(u"/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}")), + ) + self.assertNotEqual( + "/fred?bar=%E2%98%9A#%E2%98%84", + iri2uri( + u"/fred?bar=\N{BLACK LEFT POINTING INDEX}#\N{COMET}".encode("utf-8") + ), + ) unittest.main() - - diff --git a/src/httplib2/socks.py b/src/httplib2/socks.py index dbbe5114..5cef7760 100644 --- a/src/httplib2/socks.py +++ b/src/httplib2/socks.py @@ -1,4 +1,5 @@ """SocksiPy - Python SOCKS module. + Version 1.00 Copyright 2006 Dan-Haim. All rights reserved. @@ -24,20 +25,14 @@ OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMANGE. - This module provides a standard socket-like interface for Python for tunneling connections through SOCKS proxies. -""" - -""" - -Minor modifications made by Christopher Gilbert (http://motomastyle.com/) -for use in PyLoris (http://pyloris.sourceforge.net/) +Minor modifications made by Christopher Gilbert (http://motomastyle.com/) for +use in PyLoris (http://pyloris.sourceforge.net/). Minor modifications made by Mario Vilas (http://breakingcode.wordpress.com/) -mainly to merge bug fixes found in Sourceforge - +mainly to merge bug fixes found in Sourceforge. """ import base64 @@ -45,8 +40,8 @@ import socket import struct import sys -if getattr(socket, 'socket', None) is None: - raise ImportError('socket.socket missing, proxy support unusable') +if getattr(socket, "socket", None) is None: + raise ImportError("socket.socket missing, proxy support unusable") PROXY_TYPE_SOCKS4 = 1 PROXY_TYPE_SOCKS5 = 2 @@ -56,21 +51,42 @@ PROXY_TYPE_HTTP_NO_TUNNEL = 4 _defaultproxy = None _orgsocket = socket.socket -class ProxyError(Exception): pass -class GeneralProxyError(ProxyError): pass -class Socks5AuthError(ProxyError): pass -class Socks5Error(ProxyError): pass -class Socks4Error(ProxyError): pass -class HTTPError(ProxyError): pass -_generalerrors = ("success", +class ProxyError(Exception): + pass + + +class GeneralProxyError(ProxyError): + pass + + +class Socks5AuthError(ProxyError): + pass + + +class Socks5Error(ProxyError): + pass + + +class Socks4Error(ProxyError): + pass + + +class HTTPError(ProxyError): + pass + + +_generalerrors = ( + "success", "invalid data", "not connected", "not available", "bad proxy type", - "bad input") + "bad input", +) -_socks5errors = ("succeeded", +_socks5errors = ( + "succeeded", "general SOCKS server failure", "connection not allowed by ruleset", "Network unreachable", @@ -79,21 +95,30 @@ _socks5errors = ("succeeded", "TTL expired", "Command not supported", "Address type not supported", - "Unknown error") + "Unknown error", +) -_socks5autherrors = ("succeeded", +_socks5autherrors = ( + "succeeded", "authentication is required", "all offered authentication methods were rejected", "unknown username or invalid password", - "unknown error") + "unknown error", +) -_socks4errors = ("request granted", +_socks4errors = ( + "request granted", "request rejected or failed", "request rejected because SOCKS server cannot connect to identd on the client", - "request rejected because the client program and identd report different user-ids", - "unknown error") + "request rejected because the client program and identd report different " + "user-ids", + "unknown error", +) -def setdefaultproxy(proxytype=None, addr=None, port=None, rdns=True, username=None, password=None): + +def setdefaultproxy( + proxytype=None, addr=None, port=None, rdns=True, username=None, password=None +): """setdefaultproxy(proxytype, addr[, port[, rdns[, username[, password]]]]) Sets a default proxy which all further socksocket objects will use, unless explicitly changed. @@ -101,11 +126,14 @@ def setdefaultproxy(proxytype=None, addr=None, port=None, rdns=True, username=No global _defaultproxy _defaultproxy = (proxytype, addr, port, rdns, username, password) + def wrapmodule(module): """wrapmodule(module) + Attempts to replace a module's socket library with a SOCKS socket. Must set a default proxy using setdefaultproxy(...) first. - This will only work on modules that import socket directly into the namespace; + This will only work on modules that import socket directly into the + namespace; most of the Python Standard Library falls into this category. """ if _defaultproxy != None: @@ -113,6 +141,7 @@ def wrapmodule(module): else: raise GeneralProxyError((4, "no proxy specified")) + class socksocket(socket.socket): """socksocket([family[, type[, proto]]]) -> socket object Open a SOCKS enabled socket. The parameters are the same as @@ -120,7 +149,9 @@ class socksocket(socket.socket): you must specify family=AF_INET, type=SOCK_STREAM and proto=0. """ - def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, _sock=None): + def __init__( + self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, _sock=None + ): _orgsocket.__init__(self, family, type, proto, _sock) if _defaultproxy != None: self.__proxy = _defaultproxy @@ -137,8 +168,9 @@ class socksocket(socket.socket): """ data = self.recv(count) while len(data) < count: - d = self.recv(count-len(data)) - if not d: raise GeneralProxyError((0, "connection closed unexpectedly")) + d = self.recv(count - len(data)) + if not d: + raise GeneralProxyError((0, "connection closed unexpectedly")) data = data + d return data @@ -167,7 +199,7 @@ class socksocket(socket.socket): hdrs.remove(endpt) host = host.split(" ")[1] endpt = endpt.split(" ") - if (self.__proxy[4] != None and self.__proxy[5] != None): + if self.__proxy[4] != None and self.__proxy[5] != None: hdrs.insert(0, self.__getauthheader()) hdrs.insert(0, "Host: %s" % host) hdrs.insert(0, "%s http://%s%s %s" % (endpt[0], host, endpt[1], endpt[2])) @@ -177,8 +209,18 @@ class socksocket(socket.socket): auth = self.__proxy[4] + ":" + self.__proxy[5] return "Proxy-Authorization: Basic " + base64.b64encode(auth) - def setproxy(self, proxytype=None, addr=None, port=None, rdns=True, username=None, password=None, headers=None): + def setproxy( + self, + proxytype=None, + addr=None, + port=None, + rdns=True, + username=None, + password=None, + headers=None, + ): """setproxy(proxytype, addr[, port[, rdns[, username[, password]]]]) + Sets the proxy to be used. proxytype - The type of the proxy to be used. Three types are supported: PROXY_TYPE_SOCKS4 (including socks4a), @@ -193,7 +235,8 @@ class socksocket(socket.socket): The default is no authentication. password - Password to authenticate with to the server. Only relevant when username is also provided. - headers - Additional or modified headers for the proxy connect request. + headers - Additional or modified headers for the proxy connect + request. """ self.__proxy = (proxytype, addr, port, rdns, username, password, headers) @@ -202,15 +245,15 @@ class socksocket(socket.socket): Negotiates a connection through a SOCKS5 server. """ # First we'll send the authentication packages we support. - if (self.__proxy[4]!=None) and (self.__proxy[5]!=None): + if (self.__proxy[4] != None) and (self.__proxy[5] != None): # The username/password details were supplied to the # setproxy method so we support the USERNAME/PASSWORD # authentication (in addition to the standard none). - self.sendall(struct.pack('BBBB', 0x05, 0x02, 0x00, 0x02)) + self.sendall(struct.pack("BBBB", 0x05, 0x02, 0x00, 0x02)) else: # No username/password were entered, therefore we # only support connections with no authentication. - self.sendall(struct.pack('BBB', 0x05, 0x01, 0x00)) + self.sendall(struct.pack("BBB", 0x05, 0x01, 0x00)) # We'll receive the server's response to determine which # method was selected chosenauth = self.__recvall(2) @@ -224,7 +267,13 @@ class socksocket(socket.socket): elif chosenauth[1:2] == chr(0x02).encode(): # Okay, we need to perform a basic username/password # authentication. - self.sendall(chr(0x01).encode() + chr(len(self.__proxy[4])) + self.__proxy[4] + chr(len(self.__proxy[5])) + self.__proxy[5]) + self.sendall( + chr(0x01).encode() + + chr(len(self.__proxy[4])) + + self.__proxy[4] + + chr(len(self.__proxy[5])) + + self.__proxy[5] + ) authstat = self.__recvall(2) if authstat[0:1] != chr(0x01).encode(): # Bad response @@ -243,7 +292,7 @@ class socksocket(socket.socket): else: raise GeneralProxyError((1, _generalerrors[1])) # Now we can request the actual connection - req = struct.pack('BBB', 0x05, 0x01, 0x00) + req = struct.pack("BBB", 0x05, 0x01, 0x00) # If the given destination address is an IP address, we'll # use the IPv4 address request even if remote resolving was specified. try: @@ -254,7 +303,12 @@ class socksocket(socket.socket): if self.__proxy[3]: # Resolve remotely ipaddr = None - req = req + chr(0x03).encode() + chr(len(destaddr)).encode() + destaddr.encode() + req = ( + req + + chr(0x03).encode() + + chr(len(destaddr)).encode() + + destaddr.encode() + ) else: # Resolve locally ipaddr = socket.inet_aton(socket.gethostbyname(destaddr)) @@ -269,7 +323,7 @@ class socksocket(socket.socket): elif resp[1:2] != chr(0x00).encode(): # Connection failed self.close() - if ord(resp[1:2])<=8: + if ord(resp[1:2]) <= 8: raise Socks5Error((ord(resp[1:2]), _socks5errors[ord(resp[1:2])])) else: raise Socks5Error((9, _socks5errors[9])) @@ -281,7 +335,7 @@ class socksocket(socket.socket): boundaddr = self.__recvall(ord(resp[4:5])) else: self.close() - raise GeneralProxyError((1,_generalerrors[1])) + raise GeneralProxyError((1, _generalerrors[1])) boundport = struct.unpack(">H", self.__recvall(2))[0] self.__proxysockname = (boundaddr, boundport) if ipaddr != None: @@ -308,7 +362,7 @@ class socksocket(socket.socket): """ return self.__proxypeername - def __negotiatesocks4(self,destaddr,destport): + def __negotiatesocks4(self, destaddr, destport): """__negotiatesocks4(self,destaddr,destport) Negotiates a connection through a SOCKS4 server. """ @@ -340,7 +394,7 @@ class socksocket(socket.socket): if resp[0:1] != chr(0x00).encode(): # Bad data self.close() - raise GeneralProxyError((1,_generalerrors[1])) + raise GeneralProxyError((1, _generalerrors[1])) if resp[1:2] != chr(0x5A).encode(): # Server returned an error self.close() @@ -350,7 +404,10 @@ class socksocket(socket.socket): else: raise Socks4Error((94, _socks4errors[4])) # Get the bound address/port - self.__proxysockname = (socket.inet_ntoa(resp[4:]), struct.unpack(">H", resp[2:4])[0]) + self.__proxysockname = ( + socket.inet_ntoa(resp[4:]), + struct.unpack(">H", resp[2:4])[0], + ) if rmtrslv != None: self.__proxypeername = (socket.inet_ntoa(ipaddr), destport) else: @@ -365,18 +422,18 @@ class socksocket(socket.socket): addr = socket.gethostbyname(destaddr) else: addr = destaddr - headers = ["CONNECT ", addr, ":", str(destport), " HTTP/1.1\r\n"] + headers = ["CONNECT ", addr, ":", str(destport), " HTTP/1.1\r\n"] wrote_host_header = False wrote_auth_header = False if self.__proxy[6] != None: for key, val in self.__proxy[6].iteritems(): headers += [key, ": ", val, "\r\n"] - wrote_host_header = (key.lower() == "host") - wrote_auth_header = (key.lower() == "proxy-authorization") + wrote_host_header = key.lower() == "host" + wrote_auth_header = key.lower() == "proxy-authorization" if not wrote_host_header: headers += ["Host: ", destaddr, "\r\n"] if not wrote_auth_header: - if (self.__proxy[4] != None and self.__proxy[5] != None): + if self.__proxy[4] != None and self.__proxy[5] != None: headers += [self.__getauthheader(), "\r\n"] headers.append("\r\n") self.sendall("".join(headers).encode()) @@ -409,7 +466,12 @@ class socksocket(socket.socket): To select the proxy server use setproxy(). """ # Do a minimal input check first - if (not type(destpair) in (list,tuple)) or (len(destpair) < 2) or (not isinstance(destpair[0], basestring)) or (type(destpair[1]) != int): + if ( + (not type(destpair) in (list, tuple)) + or (len(destpair) < 2) + or (not isinstance(destpair[0], basestring)) + or (type(destpair[1]) != int) + ): raise GeneralProxyError((5, _generalerrors[5])) if self.__proxy[0] == PROXY_TYPE_SOCKS5: if self.__proxy[2] != None: @@ -423,23 +485,23 @@ class socksocket(socket.socket): portnum = self.__proxy[2] else: portnum = 1080 - _orgsocket.connect(self,(self.__proxy[1], portnum)) + _orgsocket.connect(self, (self.__proxy[1], portnum)) self.__negotiatesocks4(destpair[0], destpair[1]) elif self.__proxy[0] == PROXY_TYPE_HTTP: if self.__proxy[2] != None: portnum = self.__proxy[2] else: portnum = 8080 - _orgsocket.connect(self,(self.__proxy[1], portnum)) + _orgsocket.connect(self, (self.__proxy[1], portnum)) self.__negotiatehttp(destpair[0], destpair[1]) elif self.__proxy[0] == PROXY_TYPE_HTTP_NO_TUNNEL: if self.__proxy[2] != None: portnum = self.__proxy[2] else: portnum = 8080 - _orgsocket.connect(self,(self.__proxy[1],portnum)) + _orgsocket.connect(self, (self.__proxy[1], portnum)) if destpair[1] == 443: - self.__negotiatehttp(destpair[0],destpair[1]) + self.__negotiatehttp(destpair[0], destpair[1]) else: self.__httptunnel = False elif self.__proxy[0] == None: diff --git a/src/httplib2/test/functional/test_proxies.py b/src/httplib2/test/functional/test_proxies.py index e11369da..939140d4 100644 --- a/src/httplib2/test/functional/test_proxies.py +++ b/src/httplib2/test/functional/test_proxies.py @@ -27,35 +27,35 @@ LogLevel Info class FunctionalProxyHttpTest(unittest.TestCase): def setUp(self): if not socks: - raise nose.SkipTest('socks module unavailable') + raise nose.SkipTest("socks module unavailable") if not subprocess: - raise nose.SkipTest('subprocess module unavailable') + raise nose.SkipTest("subprocess module unavailable") # start a short-lived miniserver so we can get a likely port # for the proxy - self.httpd, self.proxyport = miniserver.start_server( - miniserver.ThisDirHandler) + self.httpd, self.proxyport = miniserver.start_server(miniserver.ThisDirHandler) self.httpd.shutdown() - self.httpd, self.port = miniserver.start_server( - miniserver.ThisDirHandler) + self.httpd, self.port = miniserver.start_server(miniserver.ThisDirHandler) self.pidfile = tempfile.mktemp() self.logfile = tempfile.mktemp() fd, self.conffile = tempfile.mkstemp() - f = os.fdopen(fd, 'w') - our_cfg = tinyproxy_cfg % {'user': os.getlogin(), - 'pidfile': self.pidfile, - 'port': self.proxyport, - 'logfile': self.logfile} + f = os.fdopen(fd, "w") + our_cfg = tinyproxy_cfg % { + "user": os.getlogin(), + "pidfile": self.pidfile, + "port": self.proxyport, + "logfile": self.logfile, + } f.write(our_cfg) f.close() try: # TODO use subprocess.check_call when 2.4 is dropped - ret = subprocess.call(['tinyproxy', '-c', self.conffile]) + ret = subprocess.call(["tinyproxy", "-c", self.conffile]) self.assertEqual(0, ret) except OSError as e: if e.errno == errno.ENOENT: - raise nose.SkipTest('tinyproxy not available') + raise nose.SkipTest("tinyproxy not available") raise def tearDown(self): @@ -65,25 +65,23 @@ class FunctionalProxyHttpTest(unittest.TestCase): os.kill(pid, signal.SIGTERM) except OSError as e: if e.errno == errno.ESRCH: - print('\n\n\nTinyProxy Failed to start, log follows:') + print("\n\n\nTinyProxy Failed to start, log follows:") print(open(self.logfile).read()) - print('end tinyproxy log\n\n\n') + print("end tinyproxy log\n\n\n") raise - map(os.unlink, (self.pidfile, - self.logfile, - self.conffile)) + map(os.unlink, (self.pidfile, self.logfile, self.conffile)) def testSimpleProxy(self): - proxy_info = httplib2.ProxyInfo(socks.PROXY_TYPE_HTTP, - 'localhost', self.proxyport) + proxy_info = httplib2.ProxyInfo( + socks.PROXY_TYPE_HTTP, "localhost", self.proxyport + ) client = httplib2.Http(proxy_info=proxy_info) - src = 'miniserver.py' - response, body = client.request('http://localhost:%d/%s' % - (self.port, src)) + src = "miniserver.py" + response, body = client.request("http://localhost:%d/%s" % (self.port, src)) self.assertEqual(response.status, 200) self.assertEqual(body, open(os.path.join(miniserver.HERE, src)).read()) lf = open(self.logfile).read() - expect = ('Established connection to host "127.0.0.1" ' - 'using file descriptor') - self.assertTrue(expect in lf, - 'tinyproxy did not proxy a request for miniserver') + expect = 'Established connection to host "127.0.0.1" ' "using file descriptor" + self.assertTrue( + expect in lf, "tinyproxy did not proxy a request for miniserver" + ) diff --git a/src/httplib2/test/miniserver.py b/src/httplib2/test/miniserver.py index f72eccac..47c3ee5b 100644 --- a/src/httplib2/test/miniserver.py +++ b/src/httplib2/test/miniserver.py @@ -12,8 +12,8 @@ logger = logging.getLogger(__name__) class ThisDirHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): def translate_path(self, path): - path = path.split('?', 1)[0].split('#', 1)[0] - return os.path.join(HERE, *filter(None, path.split('/'))) + path = path.split("?", 1)[0].split("#", 1)[0] + return os.path.join(HERE, *filter(None, path.split("/"))) def log_message(self, s, *args): # output via logging so nose can catch it @@ -38,12 +38,13 @@ class ShutdownServer(SocketServer.TCPServer): SocketServer.TCPServer.server_bind(self) if self.__use_tls: import ssl - self.socket = ssl.wrap_socket(self.socket, - os.path.join(os.path.dirname(__file__), 'server.key'), - os.path.join(os.path.dirname(__file__), 'server.pem'), - True - ) + self.socket = ssl.wrap_socket( + self.socket, + os.path.join(os.path.dirname(__file__), "server.key"), + os.path.join(os.path.dirname(__file__), "server.pem"), + True, + ) def serve_forever(self, poll_interval=0.1): """Handle one request at a time until shutdown. diff --git a/src/httplib2/test/smoke_test.py b/src/httplib2/test/smoke_test.py index 9f1e6f01..25e9cf2e 100644 --- a/src/httplib2/test/smoke_test.py +++ b/src/httplib2/test/smoke_test.py @@ -8,16 +8,14 @@ from httplib2.test import miniserver class HttpSmokeTest(unittest.TestCase): def setUp(self): - self.httpd, self.port = miniserver.start_server( - miniserver.ThisDirHandler) + self.httpd, self.port = miniserver.start_server(miniserver.ThisDirHandler) def tearDown(self): self.httpd.shutdown() def testGetFile(self): client = httplib2.Http() - src = 'miniserver.py' - response, body = client.request('http://localhost:%d/%s' % - (self.port, src)) + src = "miniserver.py" + response, body = client.request("http://localhost:%d/%s" % (self.port, src)) self.assertEqual(response.status, 200) self.assertEqual(body, open(os.path.join(miniserver.HERE, src)).read()) diff --git a/src/httplib2/test/test_no_socket.py b/src/httplib2/test/test_no_socket.py index 66ba0563..d251cbc3 100644 --- a/src/httplib2/test/test_no_socket.py +++ b/src/httplib2/test/test_no_socket.py @@ -8,6 +8,7 @@ import unittest import httplib2 + class MissingSocketTest(unittest.TestCase): def setUp(self): self._oldsocks = httplib2.socks @@ -17,8 +18,8 @@ class MissingSocketTest(unittest.TestCase): httplib2.socks = self._oldsocks def testProxyDisabled(self): - proxy_info = httplib2.ProxyInfo('blah', - 'localhost', 0) + proxy_info = httplib2.ProxyInfo("blah", "localhost", 0) client = httplib2.Http(proxy_info=proxy_info) - self.assertRaises(httplib2.ProxiesUnavailableError, - client.request, 'http://localhost:-1/') + self.assertRaises( + httplib2.ProxiesUnavailableError, client.request, "http://localhost:-1/" + ) diff --git a/src/httplib2/test/test_ssl_context.py b/src/httplib2/test/test_ssl_context.py index 5cf9efb0..43504dc7 100644 --- a/src/httplib2/test/test_ssl_context.py +++ b/src/httplib2/test/test_ssl_context.py @@ -10,15 +10,14 @@ import unittest import httplib2 from httplib2.test import miniserver - logger = logging.getLogger(__name__) class KeepAliveHandler(BaseHTTPServer.BaseHTTPRequestHandler): + """Request handler that keeps the HTTP connection open, so that the test can inspect the resulting SSL connection object + """ - Request handler that keeps the HTTP connection open, so that the test can - inspect the resulting SSL connection object - """ + def do_GET(self): self.send_response(200) self.send_header("Content-Length", "0") @@ -40,7 +39,7 @@ class HttpsContextTest(unittest.TestCase): else: return - self.ca_certs_path = os.path.join(os.path.dirname(__file__), 'server.pem') + self.ca_certs_path = os.path.join(os.path.dirname(__file__), "server.pem") self.httpd, self.port = miniserver.start_server(KeepAliveHandler, True) def tearDown(self): @@ -50,16 +49,16 @@ class HttpsContextTest(unittest.TestCase): client = httplib2.Http(ca_certs=self.ca_certs_path) # Establish connection to local server - client.request('https://localhost:%d/' % (self.port)) + client.request("https://localhost:%d/" % (self.port)) # Verify that connection uses a TLS context with the correct hostname - conn = client.connections['https:localhost:%d' % self.port] + conn = client.connections["https:localhost:%d" % self.port] self.assertIsInstance(conn.sock, ssl.SSLSocket) - self.assertTrue(hasattr(conn.sock, 'context')) + self.assertTrue(hasattr(conn.sock, "context")) self.assertIsInstance(conn.sock.context, ssl.SSLContext) self.assertTrue(conn.sock.context.check_hostname) - self.assertEqual(conn.sock.server_hostname, 'localhost') + self.assertEqual(conn.sock.server_hostname, "localhost") self.assertEqual(conn.sock.context.verify_mode, ssl.CERT_REQUIRED) self.assertEqual(conn.sock.context.protocol, ssl.PROTOCOL_SSLv23) @@ -72,15 +71,15 @@ class HttpsContextTest(unittest.TestCase): # which was also added to original patch. # url host is intentionally different, we provoke ssl hostname mismatch error - url = 'https://127.0.0.1:%d/' % (self.port,) + url = "https://127.0.0.1:%d/" % (self.port,) http = httplib2.Http(ca_certs=self.ca_certs_path, proxy_info=None) def once(): try: http.request(url) - assert False, 'expected certificate hostname mismatch error' + assert False, "expected certificate hostname mismatch error" except Exception as e: - print('%s errno=%s' % (repr(e), getattr(e, 'errno', None))) + print("%s errno=%s" % (repr(e), getattr(e, "errno", None))) once() once() diff --git a/src/oauth2client/__init__.py b/src/oauth2client/__init__.py index 31dd701f..92bc191d 100644 --- a/src/oauth2client/__init__.py +++ b/src/oauth2client/__init__.py @@ -14,10 +14,11 @@ """Client library for using OAuth2, especially with Google APIs.""" -__version__ = '4.1.2' +__version__ = '4.1.3' GOOGLE_AUTH_URI = 'https://accounts.google.com/o/oauth2/v2/auth' -GOOGLE_DEVICE_URI = 'https://accounts.google.com/o/oauth2/device/code' -GOOGLE_REVOKE_URI = 'https://accounts.google.com/o/oauth2/revoke' -GOOGLE_TOKEN_URI = 'https://www.googleapis.com/oauth2/v4/token' -GOOGLE_TOKEN_INFO_URI = 'https://www.googleapis.com/oauth2/v3/tokeninfo' +GOOGLE_DEVICE_URI = 'https://oauth2.googleapis.com/device/code' +GOOGLE_REVOKE_URI = 'https://oauth2.googleapis.com/revoke' +GOOGLE_TOKEN_URI = 'https://oauth2.googleapis.com/token' +GOOGLE_TOKEN_INFO_URI = 'https://oauth2.googleapis.com/tokeninfo' + diff --git a/src/oauth2client/contrib/django_util/views.py b/src/oauth2client/contrib/django_util/views.py index 009b544c..1835208a 100644 --- a/src/oauth2client/contrib/django_util/views.py +++ b/src/oauth2client/contrib/django_util/views.py @@ -28,6 +28,7 @@ from django import shortcuts from django.conf import settings from django.core import urlresolvers from django.shortcuts import redirect +from django.utils import html import jsonpickle from six.moves.urllib import parse @@ -109,6 +110,7 @@ def oauth2_callback(request): if 'error' in request.GET: reason = request.GET.get( 'error_description', request.GET.get('error', '')) + reason = html.escape(reason) return http.HttpResponseBadRequest( 'Authorization failed {0}'.format(reason)) diff --git a/src/oauth2client/oauth2client/__init__.py b/src/oauth2client/oauth2client/__init__.py new file mode 100644 index 00000000..92bc191d --- /dev/null +++ b/src/oauth2client/oauth2client/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client library for using OAuth2, especially with Google APIs.""" + +__version__ = '4.1.3' + +GOOGLE_AUTH_URI = 'https://accounts.google.com/o/oauth2/v2/auth' +GOOGLE_DEVICE_URI = 'https://oauth2.googleapis.com/device/code' +GOOGLE_REVOKE_URI = 'https://oauth2.googleapis.com/revoke' +GOOGLE_TOKEN_URI = 'https://oauth2.googleapis.com/token' +GOOGLE_TOKEN_INFO_URI = 'https://oauth2.googleapis.com/tokeninfo' + diff --git a/src/oauth2client/oauth2client/_helpers.py b/src/oauth2client/oauth2client/_helpers.py new file mode 100644 index 00000000..e9123971 --- /dev/null +++ b/src/oauth2client/oauth2client/_helpers.py @@ -0,0 +1,341 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for commonly used utilities.""" + +import base64 +import functools +import inspect +import json +import logging +import os +import warnings + +import six +from six.moves import urllib + + +logger = logging.getLogger(__name__) + +POSITIONAL_WARNING = 'WARNING' +POSITIONAL_EXCEPTION = 'EXCEPTION' +POSITIONAL_IGNORE = 'IGNORE' +POSITIONAL_SET = frozenset([POSITIONAL_WARNING, POSITIONAL_EXCEPTION, + POSITIONAL_IGNORE]) + +positional_parameters_enforcement = POSITIONAL_WARNING + +_SYM_LINK_MESSAGE = 'File: {0}: Is a symbolic link.' +_IS_DIR_MESSAGE = '{0}: Is a directory' +_MISSING_FILE_MESSAGE = 'Cannot access {0}: No such file or directory' + + +def positional(max_positional_args): + """A decorator to declare that only the first N arguments my be positional. + + This decorator makes it easy to support Python 3 style keyword-only + parameters. For example, in Python 3 it is possible to write:: + + def fn(pos1, *, kwonly1=None, kwonly1=None): + ... + + All named parameters after ``*`` must be a keyword:: + + fn(10, 'kw1', 'kw2') # Raises exception. + fn(10, kwonly1='kw1') # Ok. + + Example + ^^^^^^^ + + To define a function like above, do:: + + @positional(1) + def fn(pos1, kwonly1=None, kwonly2=None): + ... + + If no default value is provided to a keyword argument, it becomes a + required keyword argument:: + + @positional(0) + def fn(required_kw): + ... + + This must be called with the keyword parameter:: + + fn() # Raises exception. + fn(10) # Raises exception. + fn(required_kw=10) # Ok. + + When defining instance or class methods always remember to account for + ``self`` and ``cls``:: + + class MyClass(object): + + @positional(2) + def my_method(self, pos1, kwonly1=None): + ... + + @classmethod + @positional(2) + def my_method(cls, pos1, kwonly1=None): + ... + + The positional decorator behavior is controlled by + ``_helpers.positional_parameters_enforcement``, which may be set to + ``POSITIONAL_EXCEPTION``, ``POSITIONAL_WARNING`` or + ``POSITIONAL_IGNORE`` to raise an exception, log a warning, or do + nothing, respectively, if a declaration is violated. + + Args: + max_positional_arguments: Maximum number of positional arguments. All + parameters after the this index must be + keyword only. + + Returns: + A decorator that prevents using arguments after max_positional_args + from being used as positional parameters. + + Raises: + TypeError: if a key-word only argument is provided as a positional + parameter, but only if + _helpers.positional_parameters_enforcement is set to + POSITIONAL_EXCEPTION. + """ + + def positional_decorator(wrapped): + @functools.wraps(wrapped) + def positional_wrapper(*args, **kwargs): + if len(args) > max_positional_args: + plural_s = '' + if max_positional_args != 1: + plural_s = 's' + message = ('{function}() takes at most {args_max} positional ' + 'argument{plural} ({args_given} given)'.format( + function=wrapped.__name__, + args_max=max_positional_args, + args_given=len(args), + plural=plural_s)) + if positional_parameters_enforcement == POSITIONAL_EXCEPTION: + raise TypeError(message) + elif positional_parameters_enforcement == POSITIONAL_WARNING: + logger.warning(message) + return wrapped(*args, **kwargs) + return positional_wrapper + + if isinstance(max_positional_args, six.integer_types): + return positional_decorator + else: + args, _, _, defaults = inspect.getargspec(max_positional_args) + return positional(len(args) - len(defaults))(max_positional_args) + + +def scopes_to_string(scopes): + """Converts scope value to a string. + + If scopes is a string then it is simply passed through. If scopes is an + iterable then a string is returned that is all the individual scopes + concatenated with spaces. + + Args: + scopes: string or iterable of strings, the scopes. + + Returns: + The scopes formatted as a single string. + """ + if isinstance(scopes, six.string_types): + return scopes + else: + return ' '.join(scopes) + + +def string_to_scopes(scopes): + """Converts stringifed scope value to a list. + + If scopes is a list then it is simply passed through. If scopes is an + string then a list of each individual scope is returned. + + Args: + scopes: a string or iterable of strings, the scopes. + + Returns: + The scopes in a list. + """ + if not scopes: + return [] + elif isinstance(scopes, six.string_types): + return scopes.split(' ') + else: + return scopes + + +def parse_unique_urlencoded(content): + """Parses unique key-value parameters from urlencoded content. + + Args: + content: string, URL-encoded key-value pairs. + + Returns: + dict, The key-value pairs from ``content``. + + Raises: + ValueError: if one of the keys is repeated. + """ + urlencoded_params = urllib.parse.parse_qs(content) + params = {} + for key, value in six.iteritems(urlencoded_params): + if len(value) != 1: + msg = ('URL-encoded content contains a repeated value:' + '%s -> %s' % (key, ', '.join(value))) + raise ValueError(msg) + params[key] = value[0] + return params + + +def update_query_params(uri, params): + """Updates a URI with new query parameters. + + If a given key from ``params`` is repeated in the ``uri``, then + the URI will be considered invalid and an error will occur. + + If the URI is valid, then each value from ``params`` will + replace the corresponding value in the query parameters (if + it exists). + + Args: + uri: string, A valid URI, with potential existing query parameters. + params: dict, A dictionary of query parameters. + + Returns: + The same URI but with the new query parameters added. + """ + parts = urllib.parse.urlparse(uri) + query_params = parse_unique_urlencoded(parts.query) + query_params.update(params) + new_query = urllib.parse.urlencode(query_params) + new_parts = parts._replace(query=new_query) + return urllib.parse.urlunparse(new_parts) + + +def _add_query_parameter(url, name, value): + """Adds a query parameter to a url. + + Replaces the current value if it already exists in the URL. + + Args: + url: string, url to add the query parameter to. + name: string, query parameter name. + value: string, query parameter value. + + Returns: + Updated query parameter. Does not update the url if value is None. + """ + if value is None: + return url + else: + return update_query_params(url, {name: value}) + + +def validate_file(filename): + if os.path.islink(filename): + raise IOError(_SYM_LINK_MESSAGE.format(filename)) + elif os.path.isdir(filename): + raise IOError(_IS_DIR_MESSAGE.format(filename)) + elif not os.path.isfile(filename): + warnings.warn(_MISSING_FILE_MESSAGE.format(filename)) + + +def _parse_pem_key(raw_key_input): + """Identify and extract PEM keys. + + Determines whether the given key is in the format of PEM key, and extracts + the relevant part of the key if it is. + + Args: + raw_key_input: The contents of a private key file (either PEM or + PKCS12). + + Returns: + string, The actual key if the contents are from a PEM file, or + else None. + """ + offset = raw_key_input.find(b'-----BEGIN ') + if offset != -1: + return raw_key_input[offset:] + + +def _json_encode(data): + return json.dumps(data, separators=(',', ':')) + + +def _to_bytes(value, encoding='ascii'): + """Converts a string value to bytes, if necessary. + + Unfortunately, ``six.b`` is insufficient for this task since in + Python2 it does not modify ``unicode`` objects. + + Args: + value: The string/bytes value to be converted. + encoding: The encoding to use to convert unicode to bytes. Defaults + to "ascii", which will not allow any characters from ordinals + larger than 127. Other useful values are "latin-1", which + which will only allows byte ordinals (up to 255) and "utf-8", + which will encode any unicode that needs to be. + + Returns: + The original value converted to bytes (if unicode) or as passed in + if it started out as bytes. + + Raises: + ValueError if the value could not be converted to bytes. + """ + result = (value.encode(encoding) + if isinstance(value, six.text_type) else value) + if isinstance(result, six.binary_type): + return result + else: + raise ValueError('{0!r} could not be converted to bytes'.format(value)) + + +def _from_bytes(value): + """Converts bytes to a string value, if necessary. + + Args: + value: The string/bytes value to be converted. + + Returns: + The original value converted to unicode (if bytes) or as passed in + if it started out as unicode. + + Raises: + ValueError if the value could not be converted to unicode. + """ + result = (value.decode('utf-8') + if isinstance(value, six.binary_type) else value) + if isinstance(result, six.text_type): + return result + else: + raise ValueError( + '{0!r} could not be converted to unicode'.format(value)) + + +def _urlsafe_b64encode(raw_bytes): + raw_bytes = _to_bytes(raw_bytes, encoding='utf-8') + return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=') + + +def _urlsafe_b64decode(b64string): + # Guard against unicode strings, which base64 can't handle. + b64string = _to_bytes(b64string) + padded = b64string + b'=' * (4 - len(b64string) % 4) + return base64.urlsafe_b64decode(padded) diff --git a/src/oauth2client/oauth2client/_openssl_crypt.py b/src/oauth2client/oauth2client/_openssl_crypt.py new file mode 100644 index 00000000..77fac743 --- /dev/null +++ b/src/oauth2client/oauth2client/_openssl_crypt.py @@ -0,0 +1,136 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenSSL Crypto-related routines for oauth2client.""" + +from OpenSSL import crypto + +from oauth2client import _helpers + + +class OpenSSLVerifier(object): + """Verifies the signature on a message.""" + + def __init__(self, pubkey): + """Constructor. + + Args: + pubkey: OpenSSL.crypto.PKey, The public key to verify with. + """ + self._pubkey = pubkey + + def verify(self, message, signature): + """Verifies a message against a signature. + + Args: + message: string or bytes, The message to verify. If string, will be + encoded to bytes as utf-8. + signature: string or bytes, The signature on the message. If string, + will be encoded to bytes as utf-8. + + Returns: + True if message was signed by the private key associated with the + public key that this object was constructed with. + """ + message = _helpers._to_bytes(message, encoding='utf-8') + signature = _helpers._to_bytes(signature, encoding='utf-8') + try: + crypto.verify(self._pubkey, signature, message, 'sha256') + return True + except crypto.Error: + return False + + @staticmethod + def from_string(key_pem, is_x509_cert): + """Construct a Verified instance from a string. + + Args: + key_pem: string, public key in PEM format. + is_x509_cert: bool, True if key_pem is an X509 cert, otherwise it + is expected to be an RSA key in PEM format. + + Returns: + Verifier instance. + + Raises: + OpenSSL.crypto.Error: if the key_pem can't be parsed. + """ + key_pem = _helpers._to_bytes(key_pem) + if is_x509_cert: + pubkey = crypto.load_certificate(crypto.FILETYPE_PEM, key_pem) + else: + pubkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key_pem) + return OpenSSLVerifier(pubkey) + + +class OpenSSLSigner(object): + """Signs messages with a private key.""" + + def __init__(self, pkey): + """Constructor. + + Args: + pkey: OpenSSL.crypto.PKey (or equiv), The private key to sign with. + """ + self._key = pkey + + def sign(self, message): + """Signs a message. + + Args: + message: bytes, Message to be signed. + + Returns: + string, The signature of the message for the given key. + """ + message = _helpers._to_bytes(message, encoding='utf-8') + return crypto.sign(self._key, message, 'sha256') + + @staticmethod + def from_string(key, password=b'notasecret'): + """Construct a Signer instance from a string. + + Args: + key: string, private key in PKCS12 or PEM format. + password: string, password for the private key file. + + Returns: + Signer instance. + + Raises: + OpenSSL.crypto.Error if the key can't be parsed. + """ + key = _helpers._to_bytes(key) + parsed_pem_key = _helpers._parse_pem_key(key) + if parsed_pem_key: + pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key) + else: + password = _helpers._to_bytes(password, encoding='utf-8') + pkey = crypto.load_pkcs12(key, password).get_privatekey() + return OpenSSLSigner(pkey) + + +def pkcs12_key_as_pem(private_key_bytes, private_key_password): + """Convert the contents of a PKCS#12 key to PEM using pyOpenSSL. + + Args: + private_key_bytes: Bytes. PKCS#12 key in DER format. + private_key_password: String. Password for PKCS#12 key. + + Returns: + String. PEM contents of ``private_key_bytes``. + """ + private_key_password = _helpers._to_bytes(private_key_password) + pkcs12 = crypto.load_pkcs12(private_key_bytes, private_key_password) + return crypto.dump_privatekey(crypto.FILETYPE_PEM, + pkcs12.get_privatekey()) diff --git a/src/oauth2client/oauth2client/_pkce.py b/src/oauth2client/oauth2client/_pkce.py new file mode 100644 index 00000000..e4952d8c --- /dev/null +++ b/src/oauth2client/oauth2client/_pkce.py @@ -0,0 +1,67 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility functions for implementing Proof Key for Code Exchange (PKCE) by OAuth +Public Clients + +See RFC7636. +""" + +import base64 +import hashlib +import os + + +def code_verifier(n_bytes=64): + """ + Generates a 'code_verifier' as described in section 4.1 of RFC 7636. + + This is a 'high-entropy cryptographic random string' that will be + impractical for an attacker to guess. + + Args: + n_bytes: integer between 31 and 96, inclusive. default: 64 + number of bytes of entropy to include in verifier. + + Returns: + Bytestring, representing urlsafe base64-encoded random data. + """ + verifier = base64.urlsafe_b64encode(os.urandom(n_bytes)).rstrip(b'=') + # https://tools.ietf.org/html/rfc7636#section-4.1 + # minimum length of 43 characters and a maximum length of 128 characters. + if len(verifier) < 43: + raise ValueError("Verifier too short. n_bytes must be > 30.") + elif len(verifier) > 128: + raise ValueError("Verifier too long. n_bytes must be < 97.") + else: + return verifier + + +def code_challenge(verifier): + """ + Creates a 'code_challenge' as described in section 4.2 of RFC 7636 + by taking the sha256 hash of the verifier and then urlsafe + base64-encoding it. + + Args: + verifier: bytestring, representing a code_verifier as generated by + code_verifier(). + + Returns: + Bytestring, representing a urlsafe base64-encoded sha256 hash digest, + without '=' padding. + """ + digest = hashlib.sha256(verifier).digest() + return base64.urlsafe_b64encode(digest).rstrip(b'=') diff --git a/src/oauth2client/oauth2client/_pure_python_crypt.py b/src/oauth2client/oauth2client/_pure_python_crypt.py new file mode 100644 index 00000000..2c5d43aa --- /dev/null +++ b/src/oauth2client/oauth2client/_pure_python_crypt.py @@ -0,0 +1,184 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pure Python crypto-related routines for oauth2client. + +Uses the ``rsa``, ``pyasn1`` and ``pyasn1_modules`` packages +to parse PEM files storing PKCS#1 or PKCS#8 keys as well as +certificates. +""" + +from pyasn1.codec.der import decoder +from pyasn1_modules import pem +from pyasn1_modules.rfc2459 import Certificate +from pyasn1_modules.rfc5208 import PrivateKeyInfo +import rsa +import six + +from oauth2client import _helpers + + +_PKCS12_ERROR = r"""\ +PKCS12 format is not supported by the RSA library. +Either install PyOpenSSL, or please convert .p12 format +to .pem format: + $ cat key.p12 | \ + > openssl pkcs12 -nodes -nocerts -passin pass:notasecret | \ + > openssl rsa > key.pem +""" + +_POW2 = (128, 64, 32, 16, 8, 4, 2, 1) +_PKCS1_MARKER = ('-----BEGIN RSA PRIVATE KEY-----', + '-----END RSA PRIVATE KEY-----') +_PKCS8_MARKER = ('-----BEGIN PRIVATE KEY-----', + '-----END PRIVATE KEY-----') +_PKCS8_SPEC = PrivateKeyInfo() + + +def _bit_list_to_bytes(bit_list): + """Converts an iterable of 1's and 0's to bytes. + + Combines the list 8 at a time, treating each group of 8 bits + as a single byte. + """ + num_bits = len(bit_list) + byte_vals = bytearray() + for start in six.moves.xrange(0, num_bits, 8): + curr_bits = bit_list[start:start + 8] + char_val = sum(val * digit + for val, digit in zip(_POW2, curr_bits)) + byte_vals.append(char_val) + return bytes(byte_vals) + + +class RsaVerifier(object): + """Verifies the signature on a message. + + Args: + pubkey: rsa.key.PublicKey (or equiv), The public key to verify with. + """ + + def __init__(self, pubkey): + self._pubkey = pubkey + + def verify(self, message, signature): + """Verifies a message against a signature. + + Args: + message: string or bytes, The message to verify. If string, will be + encoded to bytes as utf-8. + signature: string or bytes, The signature on the message. If + string, will be encoded to bytes as utf-8. + + Returns: + True if message was signed by the private key associated with the + public key that this object was constructed with. + """ + message = _helpers._to_bytes(message, encoding='utf-8') + try: + return rsa.pkcs1.verify(message, signature, self._pubkey) + except (ValueError, rsa.pkcs1.VerificationError): + return False + + @classmethod + def from_string(cls, key_pem, is_x509_cert): + """Construct an RsaVerifier instance from a string. + + Args: + key_pem: string, public key in PEM format. + is_x509_cert: bool, True if key_pem is an X509 cert, otherwise it + is expected to be an RSA key in PEM format. + + Returns: + RsaVerifier instance. + + Raises: + ValueError: if the key_pem can't be parsed. In either case, error + will begin with 'No PEM start marker'. If + ``is_x509_cert`` is True, will fail to find the + "-----BEGIN CERTIFICATE-----" error, otherwise fails + to find "-----BEGIN RSA PUBLIC KEY-----". + """ + key_pem = _helpers._to_bytes(key_pem) + if is_x509_cert: + der = rsa.pem.load_pem(key_pem, 'CERTIFICATE') + asn1_cert, remaining = decoder.decode(der, asn1Spec=Certificate()) + if remaining != b'': + raise ValueError('Unused bytes', remaining) + + cert_info = asn1_cert['tbsCertificate']['subjectPublicKeyInfo'] + key_bytes = _bit_list_to_bytes(cert_info['subjectPublicKey']) + pubkey = rsa.PublicKey.load_pkcs1(key_bytes, 'DER') + else: + pubkey = rsa.PublicKey.load_pkcs1(key_pem, 'PEM') + return cls(pubkey) + + +class RsaSigner(object): + """Signs messages with a private key. + + Args: + pkey: rsa.key.PrivateKey (or equiv), The private key to sign with. + """ + + def __init__(self, pkey): + self._key = pkey + + def sign(self, message): + """Signs a message. + + Args: + message: bytes, Message to be signed. + + Returns: + string, The signature of the message for the given key. + """ + message = _helpers._to_bytes(message, encoding='utf-8') + return rsa.pkcs1.sign(message, self._key, 'SHA-256') + + @classmethod + def from_string(cls, key, password='notasecret'): + """Construct an RsaSigner instance from a string. + + Args: + key: string, private key in PEM format. + password: string, password for private key file. Unused for PEM + files. + + Returns: + RsaSigner instance. + + Raises: + ValueError if the key cannot be parsed as PKCS#1 or PKCS#8 in + PEM format. + """ + key = _helpers._from_bytes(key) # pem expects str in Py3 + marker_id, key_bytes = pem.readPemBlocksFromFile( + six.StringIO(key), _PKCS1_MARKER, _PKCS8_MARKER) + + if marker_id == 0: + pkey = rsa.key.PrivateKey.load_pkcs1(key_bytes, + format='DER') + elif marker_id == 1: + key_info, remaining = decoder.decode( + key_bytes, asn1Spec=_PKCS8_SPEC) + if remaining != b'': + raise ValueError('Unused bytes', remaining) + pkey_info = key_info.getComponentByName('privateKey') + pkey = rsa.key.PrivateKey.load_pkcs1(pkey_info.asOctets(), + format='DER') + else: + raise ValueError('No key could be detected.') + + return cls(pkey) diff --git a/src/oauth2client/oauth2client/_pycrypto_crypt.py b/src/oauth2client/oauth2client/_pycrypto_crypt.py new file mode 100644 index 00000000..fd2ce0cd --- /dev/null +++ b/src/oauth2client/oauth2client/_pycrypto_crypt.py @@ -0,0 +1,124 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""pyCrypto Crypto-related routines for oauth2client.""" + +from Crypto.Hash import SHA256 +from Crypto.PublicKey import RSA +from Crypto.Signature import PKCS1_v1_5 +from Crypto.Util.asn1 import DerSequence + +from oauth2client import _helpers + + +class PyCryptoVerifier(object): + """Verifies the signature on a message.""" + + def __init__(self, pubkey): + """Constructor. + + Args: + pubkey: OpenSSL.crypto.PKey (or equiv), The public key to verify + with. + """ + self._pubkey = pubkey + + def verify(self, message, signature): + """Verifies a message against a signature. + + Args: + message: string or bytes, The message to verify. If string, will be + encoded to bytes as utf-8. + signature: string or bytes, The signature on the message. + + Returns: + True if message was signed by the private key associated with the + public key that this object was constructed with. + """ + message = _helpers._to_bytes(message, encoding='utf-8') + return PKCS1_v1_5.new(self._pubkey).verify( + SHA256.new(message), signature) + + @staticmethod + def from_string(key_pem, is_x509_cert): + """Construct a Verified instance from a string. + + Args: + key_pem: string, public key in PEM format. + is_x509_cert: bool, True if key_pem is an X509 cert, otherwise it + is expected to be an RSA key in PEM format. + + Returns: + Verifier instance. + """ + if is_x509_cert: + key_pem = _helpers._to_bytes(key_pem) + pemLines = key_pem.replace(b' ', b'').split() + certDer = _helpers._urlsafe_b64decode(b''.join(pemLines[1:-1])) + certSeq = DerSequence() + certSeq.decode(certDer) + tbsSeq = DerSequence() + tbsSeq.decode(certSeq[0]) + pubkey = RSA.importKey(tbsSeq[6]) + else: + pubkey = RSA.importKey(key_pem) + return PyCryptoVerifier(pubkey) + + +class PyCryptoSigner(object): + """Signs messages with a private key.""" + + def __init__(self, pkey): + """Constructor. + + Args: + pkey, OpenSSL.crypto.PKey (or equiv), The private key to sign with. + """ + self._key = pkey + + def sign(self, message): + """Signs a message. + + Args: + message: string, Message to be signed. + + Returns: + string, The signature of the message for the given key. + """ + message = _helpers._to_bytes(message, encoding='utf-8') + return PKCS1_v1_5.new(self._key).sign(SHA256.new(message)) + + @staticmethod + def from_string(key, password='notasecret'): + """Construct a Signer instance from a string. + + Args: + key: string, private key in PEM format. + password: string, password for private key file. Unused for PEM + files. + + Returns: + Signer instance. + + Raises: + NotImplementedError if the key isn't in PEM format. + """ + parsed_pem_key = _helpers._parse_pem_key(_helpers._to_bytes(key)) + if parsed_pem_key: + pkey = RSA.importKey(parsed_pem_key) + else: + raise NotImplementedError( + 'No key in PEM format was detected. This implementation ' + 'can only use the PyCrypto library for keys in PEM ' + 'format.') + return PyCryptoSigner(pkey) diff --git a/src/oauth2client/oauth2client/client.py b/src/oauth2client/oauth2client/client.py new file mode 100644 index 00000000..7618960e --- /dev/null +++ b/src/oauth2client/oauth2client/client.py @@ -0,0 +1,2170 @@ +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An OAuth 2.0 client. + +Tools for interacting with OAuth 2.0 protected resources. +""" + +import collections +import copy +import datetime +import json +import logging +import os +import shutil +import socket +import sys +import tempfile + +import six +from six.moves import http_client +from six.moves import urllib + +import oauth2client +from oauth2client import _helpers +from oauth2client import _pkce +from oauth2client import clientsecrets +from oauth2client import transport + + +HAS_OPENSSL = False +HAS_CRYPTO = False +try: + from oauth2client import crypt + HAS_CRYPTO = True + HAS_OPENSSL = crypt.OpenSSLVerifier is not None +except ImportError: # pragma: NO COVER + pass + + +logger = logging.getLogger(__name__) + +# Expiry is stored in RFC3339 UTC format +EXPIRY_FORMAT = '%Y-%m-%dT%H:%M:%SZ' + +# Which certs to use to validate id_tokens received. +ID_TOKEN_VERIFICATION_CERTS = 'https://www.googleapis.com/oauth2/v1/certs' +# This symbol previously had a typo in the name; we keep the old name +# around for now, but will remove it in the future. +ID_TOKEN_VERIFICATON_CERTS = ID_TOKEN_VERIFICATION_CERTS + +# Constant to use for the out of band OAuth 2.0 flow. +OOB_CALLBACK_URN = 'urn:ietf:wg:oauth:2.0:oob' + +# The value representing user credentials. +AUTHORIZED_USER = 'authorized_user' + +# The value representing service account credentials. +SERVICE_ACCOUNT = 'service_account' + +# The environment variable pointing the file with local +# Application Default Credentials. +GOOGLE_APPLICATION_CREDENTIALS = 'GOOGLE_APPLICATION_CREDENTIALS' +# The ~/.config subdirectory containing gcloud credentials. Intended +# to be swapped out in tests. +_CLOUDSDK_CONFIG_DIRECTORY = 'gcloud' +# The environment variable name which can replace ~/.config if set. +_CLOUDSDK_CONFIG_ENV_VAR = 'CLOUDSDK_CONFIG' + +# The error message we show users when we can't find the Application +# Default Credentials. +ADC_HELP_MSG = ( + 'The Application Default Credentials are not available. They are ' + 'available if running in Google Compute Engine. Otherwise, the ' + 'environment variable ' + + GOOGLE_APPLICATION_CREDENTIALS + + ' must be defined pointing to a file defining the credentials. See ' + 'https://developers.google.com/accounts/docs/' + 'application-default-credentials for more information.') + +_WELL_KNOWN_CREDENTIALS_FILE = 'application_default_credentials.json' + +# The access token along with the seconds in which it expires. +AccessTokenInfo = collections.namedtuple( + 'AccessTokenInfo', ['access_token', 'expires_in']) + +DEFAULT_ENV_NAME = 'UNKNOWN' + +# If set to True _get_environment avoid GCE check (_detect_gce_environment) +NO_GCE_CHECK = os.getenv('NO_GCE_CHECK', 'False') + +# Timeout in seconds to wait for the GCE metadata server when detecting the +# GCE environment. +try: + GCE_METADATA_TIMEOUT = int(os.getenv('GCE_METADATA_TIMEOUT', 3)) +except ValueError: # pragma: NO COVER + GCE_METADATA_TIMEOUT = 3 + +_SERVER_SOFTWARE = 'SERVER_SOFTWARE' +_GCE_METADATA_URI = 'http://' + os.getenv('GCE_METADATA_IP', '169.254.169.254') +_METADATA_FLAVOR_HEADER = 'metadata-flavor' # lowercase header +_DESIRED_METADATA_FLAVOR = 'Google' +_GCE_HEADERS = {_METADATA_FLAVOR_HEADER: _DESIRED_METADATA_FLAVOR} + +# Expose utcnow() at module level to allow for +# easier testing (by replacing with a stub). +_UTCNOW = datetime.datetime.utcnow + +# NOTE: These names were previously defined in this module but have been +# moved into `oauth2client.transport`, +clean_headers = transport.clean_headers +MemoryCache = transport.MemoryCache +REFRESH_STATUS_CODES = transport.REFRESH_STATUS_CODES + + +class SETTINGS(object): + """Settings namespace for globally defined values.""" + env_name = None + + +class Error(Exception): + """Base error for this module.""" + + +class FlowExchangeError(Error): + """Error trying to exchange an authorization grant for an access token.""" + + +class AccessTokenRefreshError(Error): + """Error trying to refresh an expired access token.""" + + +class HttpAccessTokenRefreshError(AccessTokenRefreshError): + """Error (with HTTP status) trying to refresh an expired access token.""" + def __init__(self, *args, **kwargs): + super(HttpAccessTokenRefreshError, self).__init__(*args) + self.status = kwargs.get('status') + + +class TokenRevokeError(Error): + """Error trying to revoke a token.""" + + +class UnknownClientSecretsFlowError(Error): + """The client secrets file called for an unknown type of OAuth 2.0 flow.""" + + +class AccessTokenCredentialsError(Error): + """Having only the access_token means no refresh is possible.""" + + +class VerifyJwtTokenError(Error): + """Could not retrieve certificates for validation.""" + + +class NonAsciiHeaderError(Error): + """Header names and values must be ASCII strings.""" + + +class ApplicationDefaultCredentialsError(Error): + """Error retrieving the Application Default Credentials.""" + + +class OAuth2DeviceCodeError(Error): + """Error trying to retrieve a device code.""" + + +class CryptoUnavailableError(Error, NotImplementedError): + """Raised when a crypto library is required, but none is available.""" + + +def _parse_expiry(expiry): + if expiry and isinstance(expiry, datetime.datetime): + return expiry.strftime(EXPIRY_FORMAT) + else: + return None + + +class Credentials(object): + """Base class for all Credentials objects. + + Subclasses must define an authorize() method that applies the credentials + to an HTTP transport. + + Subclasses must also specify a classmethod named 'from_json' that takes a + JSON string as input and returns an instantiated Credentials object. + """ + + NON_SERIALIZED_MEMBERS = frozenset(['store']) + + def authorize(self, http): + """Take an httplib2.Http instance (or equivalent) and authorizes it. + + Authorizes it for the set of credentials, usually by replacing + http.request() with a method that adds in the appropriate headers and + then delegates to the original Http.request() method. + + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + """ + raise NotImplementedError + + def refresh(self, http): + """Forces a refresh of the access_token. + + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + """ + raise NotImplementedError + + def revoke(self, http): + """Revokes a refresh_token and makes the credentials void. + + Args: + http: httplib2.Http, an http object to be used to make the revoke + request. + """ + raise NotImplementedError + + def apply(self, headers): + """Add the authorization to the headers. + + Args: + headers: dict, the headers to add the Authorization header to. + """ + raise NotImplementedError + + def _to_json(self, strip, to_serialize=None): + """Utility function that creates JSON repr. of a Credentials object. + + Args: + strip: array, An array of names of members to exclude from the + JSON. + to_serialize: dict, (Optional) The properties for this object + that will be serialized. This allows callers to + modify before serializing. + + Returns: + string, a JSON representation of this instance, suitable to pass to + from_json(). + """ + curr_type = self.__class__ + if to_serialize is None: + to_serialize = copy.copy(self.__dict__) + else: + # Assumes it is a str->str dictionary, so we don't deep copy. + to_serialize = copy.copy(to_serialize) + for member in strip: + if member in to_serialize: + del to_serialize[member] + to_serialize['token_expiry'] = _parse_expiry( + to_serialize.get('token_expiry')) + # Add in information we will need later to reconstitute this instance. + to_serialize['_class'] = curr_type.__name__ + to_serialize['_module'] = curr_type.__module__ + for key, val in to_serialize.items(): + if isinstance(val, bytes): + to_serialize[key] = val.decode('utf-8') + if isinstance(val, set): + to_serialize[key] = list(val) + return json.dumps(to_serialize) + + def to_json(self): + """Creating a JSON representation of an instance of Credentials. + + Returns: + string, a JSON representation of this instance, suitable to pass to + from_json(). + """ + return self._to_json(self.NON_SERIALIZED_MEMBERS) + + @classmethod + def new_from_json(cls, json_data): + """Utility class method to instantiate a Credentials subclass from JSON. + + Expects the JSON string to have been produced by to_json(). + + Args: + json_data: string or bytes, JSON from to_json(). + + Returns: + An instance of the subclass of Credentials that was serialized with + to_json(). + """ + json_data_as_unicode = _helpers._from_bytes(json_data) + data = json.loads(json_data_as_unicode) + # Find and call the right classmethod from_json() to restore + # the object. + module_name = data['_module'] + try: + module_obj = __import__(module_name) + except ImportError: + # In case there's an object from the old package structure, + # update it + module_name = module_name.replace('.googleapiclient', '') + module_obj = __import__(module_name) + + module_obj = __import__(module_name, + fromlist=module_name.split('.')[:-1]) + kls = getattr(module_obj, data['_class']) + return kls.from_json(json_data_as_unicode) + + @classmethod + def from_json(cls, unused_data): + """Instantiate a Credentials object from a JSON description of it. + + The JSON should have been produced by calling .to_json() on the object. + + Args: + unused_data: dict, A deserialized JSON object. + + Returns: + An instance of a Credentials subclass. + """ + return Credentials() + + +class Flow(object): + """Base class for all Flow objects.""" + pass + + +class Storage(object): + """Base class for all Storage objects. + + Store and retrieve a single credential. This class supports locking + such that multiple processes and threads can operate on a single + store. + """ + def __init__(self, lock=None): + """Create a Storage instance. + + Args: + lock: An optional threading.Lock-like object. Must implement at + least acquire() and release(). Does not need to be + re-entrant. + """ + self._lock = lock + + def acquire_lock(self): + """Acquires any lock necessary to access this Storage. + + This lock is not reentrant. + """ + if self._lock is not None: + self._lock.acquire() + + def release_lock(self): + """Release the Storage lock. + + Trying to release a lock that isn't held will result in a + RuntimeError in the case of a threading.Lock or multiprocessing.Lock. + """ + if self._lock is not None: + self._lock.release() + + def locked_get(self): + """Retrieve credential. + + The Storage lock must be held when this is called. + + Returns: + oauth2client.client.Credentials + """ + raise NotImplementedError + + def locked_put(self, credentials): + """Write a credential. + + The Storage lock must be held when this is called. + + Args: + credentials: Credentials, the credentials to store. + """ + raise NotImplementedError + + def locked_delete(self): + """Delete a credential. + + The Storage lock must be held when this is called. + """ + raise NotImplementedError + + def get(self): + """Retrieve credential. + + The Storage lock must *not* be held when this is called. + + Returns: + oauth2client.client.Credentials + """ + self.acquire_lock() + try: + return self.locked_get() + finally: + self.release_lock() + + def put(self, credentials): + """Write a credential. + + The Storage lock must be held when this is called. + + Args: + credentials: Credentials, the credentials to store. + """ + self.acquire_lock() + try: + self.locked_put(credentials) + finally: + self.release_lock() + + def delete(self): + """Delete credential. + + Frees any resources associated with storing the credential. + The Storage lock must *not* be held when this is called. + + Returns: + None + """ + self.acquire_lock() + try: + return self.locked_delete() + finally: + self.release_lock() + + +class OAuth2Credentials(Credentials): + """Credentials object for OAuth 2.0. + + Credentials can be applied to an httplib2.Http object using the authorize() + method, which then adds the OAuth 2.0 access token to each request. + + OAuth2Credentials objects may be safely pickled and unpickled. + """ + + @_helpers.positional(8) + def __init__(self, access_token, client_id, client_secret, refresh_token, + token_expiry, token_uri, user_agent, revoke_uri=None, + id_token=None, token_response=None, scopes=None, + token_info_uri=None, id_token_jwt=None): + """Create an instance of OAuth2Credentials. + + This constructor is not usually called by the user, instead + OAuth2Credentials objects are instantiated by the OAuth2WebServerFlow. + + Args: + access_token: string, access token. + client_id: string, client identifier. + client_secret: string, client secret. + refresh_token: string, refresh token. + token_expiry: datetime, when the access_token expires. + token_uri: string, URI of token endpoint. + user_agent: string, The HTTP User-Agent to provide for this + application. + revoke_uri: string, URI for revoke endpoint. Defaults to None; a + token can't be revoked if this is None. + id_token: object, The identity of the resource owner. + token_response: dict, the decoded response to the token request. + None if a token hasn't been requested yet. Stored + because some providers (e.g. wordpress.com) include + extra fields that clients may want. + scopes: list, authorized scopes for these credentials. + token_info_uri: string, the URI for the token info endpoint. + Defaults to None; scopes can not be refreshed if + this is None. + id_token_jwt: string, the encoded and signed identity JWT. The + decoded version of this is stored in id_token. + + Notes: + store: callable, A callable that when passed a Credential + will store the credential back to where it came from. + This is needed to store the latest access_token if it + has expired and been refreshed. + """ + self.access_token = access_token + self.client_id = client_id + self.client_secret = client_secret + self.refresh_token = refresh_token + self.store = None + self.token_expiry = token_expiry + self.token_uri = token_uri + self.user_agent = user_agent + self.revoke_uri = revoke_uri + self.id_token = id_token + self.id_token_jwt = id_token_jwt + self.token_response = token_response + self.scopes = set(_helpers.string_to_scopes(scopes or [])) + self.token_info_uri = token_info_uri + + # True if the credentials have been revoked or expired and can't be + # refreshed. + self.invalid = False + + def authorize(self, http): + """Authorize an httplib2.Http instance with these credentials. + + The modified http.request method will add authentication headers to + each request and will refresh access_tokens when a 401 is received on a + request. In addition the http.request method has a credentials + property, http.request.credentials, which is the Credentials object + that authorized it. + + Args: + http: An instance of ``httplib2.Http`` or something that acts + like it. + + Returns: + A modified instance of http that was passed in. + + Example:: + + h = httplib2.Http() + h = credentials.authorize(h) + + You can't create a new OAuth subclass of httplib2.Authentication + because it never gets passed the absolute URI, which is needed for + signing. So instead we have to overload 'request' with a closure + that adds in the Authorization header and then calls the original + version of 'request()'. + """ + transport.wrap_http_for_auth(self, http) + return http + + def refresh(self, http): + """Forces a refresh of the access_token. + + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + """ + self._refresh(http) + + def revoke(self, http): + """Revokes a refresh_token and makes the credentials void. + + Args: + http: httplib2.Http, an http object to be used to make the revoke + request. + """ + self._revoke(http) + + def apply(self, headers): + """Add the authorization to the headers. + + Args: + headers: dict, the headers to add the Authorization header to. + """ + headers['Authorization'] = 'Bearer ' + self.access_token + + def has_scopes(self, scopes): + """Verify that the credentials are authorized for the given scopes. + + Returns True if the credentials authorized scopes contain all of the + scopes given. + + Args: + scopes: list or string, the scopes to check. + + Notes: + There are cases where the credentials are unaware of which scopes + are authorized. Notably, credentials obtained and stored before + this code was added will not have scopes, AccessTokenCredentials do + not have scopes. In both cases, you can use refresh_scopes() to + obtain the canonical set of scopes. + """ + scopes = _helpers.string_to_scopes(scopes) + return set(scopes).issubset(self.scopes) + + def retrieve_scopes(self, http): + """Retrieves the canonical list of scopes for this access token. + + Gets the scopes from the OAuth2 provider. + + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + + Returns: + A set of strings containing the canonical list of scopes. + """ + self._retrieve_scopes(http) + return self.scopes + + @classmethod + def from_json(cls, json_data): + """Instantiate a Credentials object from a JSON description of it. + + The JSON should have been produced by calling .to_json() on the object. + + Args: + json_data: string or bytes, JSON to deserialize. + + Returns: + An instance of a Credentials subclass. + """ + data = json.loads(_helpers._from_bytes(json_data)) + if (data.get('token_expiry') and + not isinstance(data['token_expiry'], datetime.datetime)): + try: + data['token_expiry'] = datetime.datetime.strptime( + data['token_expiry'], EXPIRY_FORMAT) + except ValueError: + data['token_expiry'] = None + retval = cls( + data['access_token'], + data['client_id'], + data['client_secret'], + data['refresh_token'], + data['token_expiry'], + data['token_uri'], + data['user_agent'], + revoke_uri=data.get('revoke_uri', None), + id_token=data.get('id_token', None), + id_token_jwt=data.get('id_token_jwt', None), + token_response=data.get('token_response', None), + scopes=data.get('scopes', None), + token_info_uri=data.get('token_info_uri', None)) + retval.invalid = data['invalid'] + return retval + + @property + def access_token_expired(self): + """True if the credential is expired or invalid. + + If the token_expiry isn't set, we assume the token doesn't expire. + """ + if self.invalid: + return True + + if not self.token_expiry: + return False + + now = _UTCNOW() + if now >= self.token_expiry: + logger.info('access_token is expired. Now: %s, token_expiry: %s', + now, self.token_expiry) + return True + return False + + def get_access_token(self, http=None): + """Return the access token and its expiration information. + + If the token does not exist, get one. + If the token expired, refresh it. + """ + if not self.access_token or self.access_token_expired: + if not http: + http = transport.get_http_object() + self.refresh(http) + return AccessTokenInfo(access_token=self.access_token, + expires_in=self._expires_in()) + + def set_store(self, store): + """Set the Storage for the credential. + + Args: + store: Storage, an implementation of Storage object. + This is needed to store the latest access_token if it + has expired and been refreshed. This implementation uses + locking to check for updates before updating the + access_token. + """ + self.store = store + + def _expires_in(self): + """Return the number of seconds until this token expires. + + If token_expiry is in the past, this method will return 0, meaning the + token has already expired. + + If token_expiry is None, this method will return None. Note that + returning 0 in such a case would not be fair: the token may still be + valid; we just don't know anything about it. + """ + if self.token_expiry: + now = _UTCNOW() + if self.token_expiry > now: + time_delta = self.token_expiry - now + # TODO(orestica): return time_delta.total_seconds() + # once dropping support for Python 2.6 + return time_delta.days * 86400 + time_delta.seconds + else: + return 0 + + def _updateFromCredential(self, other): + """Update this Credential from another instance.""" + self.__dict__.update(other.__getstate__()) + + def __getstate__(self): + """Trim the state down to something that can be pickled.""" + d = copy.copy(self.__dict__) + del d['store'] + return d + + def __setstate__(self, state): + """Reconstitute the state of the object from being pickled.""" + self.__dict__.update(state) + self.store = None + + def _generate_refresh_request_body(self): + """Generate the body that will be used in the refresh request.""" + body = urllib.parse.urlencode({ + 'grant_type': 'refresh_token', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'refresh_token': self.refresh_token, + }) + return body + + def _generate_refresh_request_headers(self): + """Generate the headers that will be used in the refresh request.""" + headers = { + 'content-type': 'application/x-www-form-urlencoded', + } + + if self.user_agent is not None: + headers['user-agent'] = self.user_agent + + return headers + + def _refresh(self, http): + """Refreshes the access_token. + + This method first checks by reading the Storage object if available. + If a refresh is still needed, it holds the Storage lock until the + refresh is completed. + + Args: + http: an object to be used to make HTTP requests. + + Raises: + HttpAccessTokenRefreshError: When the refresh fails. + """ + if not self.store: + self._do_refresh_request(http) + else: + self.store.acquire_lock() + try: + new_cred = self.store.locked_get() + + if (new_cred and not new_cred.invalid and + new_cred.access_token != self.access_token and + not new_cred.access_token_expired): + logger.info('Updated access_token read from Storage') + self._updateFromCredential(new_cred) + else: + self._do_refresh_request(http) + finally: + self.store.release_lock() + + def _do_refresh_request(self, http): + """Refresh the access_token using the refresh_token. + + Args: + http: an object to be used to make HTTP requests. + + Raises: + HttpAccessTokenRefreshError: When the refresh fails. + """ + body = self._generate_refresh_request_body() + headers = self._generate_refresh_request_headers() + + logger.info('Refreshing access_token') + resp, content = transport.request( + http, self.token_uri, method='POST', + body=body, headers=headers) + content = _helpers._from_bytes(content) + if resp.status == http_client.OK: + d = json.loads(content) + self.token_response = d + self.access_token = d['access_token'] + self.refresh_token = d.get('refresh_token', self.refresh_token) + if 'expires_in' in d: + delta = datetime.timedelta(seconds=int(d['expires_in'])) + self.token_expiry = delta + _UTCNOW() + else: + self.token_expiry = None + if 'id_token' in d: + self.id_token = _extract_id_token(d['id_token']) + self.id_token_jwt = d['id_token'] + else: + self.id_token = None + self.id_token_jwt = None + # On temporary refresh errors, the user does not actually have to + # re-authorize, so we unflag here. + self.invalid = False + if self.store: + self.store.locked_put(self) + else: + # An {'error':...} response body means the token is expired or + # revoked, so we flag the credentials as such. + logger.info('Failed to retrieve access token: %s', content) + error_msg = 'Invalid response {0}.'.format(resp.status) + try: + d = json.loads(content) + if 'error' in d: + error_msg = d['error'] + if 'error_description' in d: + error_msg += ': ' + d['error_description'] + self.invalid = True + if self.store is not None: + self.store.locked_put(self) + except (TypeError, ValueError): + pass + raise HttpAccessTokenRefreshError(error_msg, status=resp.status) + + def _revoke(self, http): + """Revokes this credential and deletes the stored copy (if it exists). + + Args: + http: an object to be used to make HTTP requests. + """ + self._do_revoke(http, self.refresh_token or self.access_token) + + def _do_revoke(self, http, token): + """Revokes this credential and deletes the stored copy (if it exists). + + Args: + http: an object to be used to make HTTP requests. + token: A string used as the token to be revoked. Can be either an + access_token or refresh_token. + + Raises: + TokenRevokeError: If the revoke request does not return with a + 200 OK. + """ + logger.info('Revoking token') + query_params = {'token': token} + token_revoke_uri = _helpers.update_query_params( + self.revoke_uri, query_params) + resp, content = transport.request(http, token_revoke_uri) + if resp.status == http_client.METHOD_NOT_ALLOWED: + body = urllib.parse.urlencode(query_params) + resp, content = transport.request(http, token_revoke_uri, + method='POST', body=body) + if resp.status == http_client.OK: + self.invalid = True + else: + error_msg = 'Invalid response {0}.'.format(resp.status) + try: + d = json.loads(_helpers._from_bytes(content)) + if 'error' in d: + error_msg = d['error'] + except (TypeError, ValueError): + pass + raise TokenRevokeError(error_msg) + + if self.store: + self.store.delete() + + def _retrieve_scopes(self, http): + """Retrieves the list of authorized scopes from the OAuth2 provider. + + Args: + http: an object to be used to make HTTP requests. + """ + self._do_retrieve_scopes(http, self.access_token) + + def _do_retrieve_scopes(self, http, token): + """Retrieves the list of authorized scopes from the OAuth2 provider. + + Args: + http: an object to be used to make HTTP requests. + token: A string used as the token to identify the credentials to + the provider. + + Raises: + Error: When refresh fails, indicating the the access token is + invalid. + """ + logger.info('Refreshing scopes') + query_params = {'access_token': token, 'fields': 'scope'} + token_info_uri = _helpers.update_query_params( + self.token_info_uri, query_params) + resp, content = transport.request(http, token_info_uri) + content = _helpers._from_bytes(content) + if resp.status == http_client.OK: + d = json.loads(content) + self.scopes = set(_helpers.string_to_scopes(d.get('scope', ''))) + else: + error_msg = 'Invalid response {0}.'.format(resp.status) + try: + d = json.loads(content) + if 'error_description' in d: + error_msg = d['error_description'] + except (TypeError, ValueError): + pass + raise Error(error_msg) + + +class AccessTokenCredentials(OAuth2Credentials): + """Credentials object for OAuth 2.0. + + Credentials can be applied to an httplib2.Http object using the + authorize() method, which then signs each request from that object + with the OAuth 2.0 access token. This set of credentials is for the + use case where you have acquired an OAuth 2.0 access_token from + another place such as a JavaScript client or another web + application, and wish to use it from Python. Because only the + access_token is present it can not be refreshed and will in time + expire. + + AccessTokenCredentials objects may be safely pickled and unpickled. + + Usage:: + + credentials = AccessTokenCredentials('', + 'my-user-agent/1.0') + http = httplib2.Http() + http = credentials.authorize(http) + + Raises: + AccessTokenCredentialsExpired: raised when the access_token expires or + is revoked. + """ + + def __init__(self, access_token, user_agent, revoke_uri=None): + """Create an instance of OAuth2Credentials + + This is one of the few types if Credentials that you should contrust, + Credentials objects are usually instantiated by a Flow. + + Args: + access_token: string, access token. + user_agent: string, The HTTP User-Agent to provide for this + application. + revoke_uri: string, URI for revoke endpoint. Defaults to None; a + token can't be revoked if this is None. + """ + super(AccessTokenCredentials, self).__init__( + access_token, + None, + None, + None, + None, + None, + user_agent, + revoke_uri=revoke_uri) + + @classmethod + def from_json(cls, json_data): + data = json.loads(_helpers._from_bytes(json_data)) + retval = AccessTokenCredentials( + data['access_token'], + data['user_agent']) + return retval + + def _refresh(self, http): + """Refreshes the access token. + + Args: + http: unused HTTP object. + + Raises: + AccessTokenCredentialsError: always + """ + raise AccessTokenCredentialsError( + 'The access_token is expired or invalid and can\'t be refreshed.') + + def _revoke(self, http): + """Revokes the access_token and deletes the store if available. + + Args: + http: an object to be used to make HTTP requests. + """ + self._do_revoke(http, self.access_token) + + +def _detect_gce_environment(): + """Determine if the current environment is Compute Engine. + + Returns: + Boolean indicating whether or not the current environment is Google + Compute Engine. + """ + # NOTE: The explicit ``timeout`` is a workaround. The underlying + # issue is that resolving an unknown host on some networks will take + # 20-30 seconds; making this timeout short fixes the issue, but + # could lead to false negatives in the event that we are on GCE, but + # the metadata resolution was particularly slow. The latter case is + # "unlikely". + http = transport.get_http_object(timeout=GCE_METADATA_TIMEOUT) + try: + response, _ = transport.request( + http, _GCE_METADATA_URI, headers=_GCE_HEADERS) + return ( + response.status == http_client.OK and + response.get(_METADATA_FLAVOR_HEADER) == _DESIRED_METADATA_FLAVOR) + except socket.error: # socket.timeout or socket.error(64, 'Host is down') + logger.info('Timeout attempting to reach GCE metadata service.') + return False + + +def _in_gae_environment(): + """Detects if the code is running in the App Engine environment. + + Returns: + True if running in the GAE environment, False otherwise. + """ + if SETTINGS.env_name is not None: + return SETTINGS.env_name in ('GAE_PRODUCTION', 'GAE_LOCAL') + + try: + import google.appengine # noqa: unused import + except ImportError: + pass + else: + server_software = os.environ.get(_SERVER_SOFTWARE, '') + if server_software.startswith('Google App Engine/'): + SETTINGS.env_name = 'GAE_PRODUCTION' + return True + elif server_software.startswith('Development/'): + SETTINGS.env_name = 'GAE_LOCAL' + return True + + return False + + +def _in_gce_environment(): + """Detect if the code is running in the Compute Engine environment. + + Returns: + True if running in the GCE environment, False otherwise. + """ + if SETTINGS.env_name is not None: + return SETTINGS.env_name == 'GCE_PRODUCTION' + + if NO_GCE_CHECK != 'True' and _detect_gce_environment(): + SETTINGS.env_name = 'GCE_PRODUCTION' + return True + return False + + +class GoogleCredentials(OAuth2Credentials): + """Application Default Credentials for use in calling Google APIs. + + The Application Default Credentials are being constructed as a function of + the environment where the code is being run. + More details can be found on this page: + https://developers.google.com/accounts/docs/application-default-credentials + + Here is an example of how to use the Application Default Credentials for a + service that requires authentication:: + + from googleapiclient.discovery import build + from oauth2client.client import GoogleCredentials + + credentials = GoogleCredentials.get_application_default() + service = build('compute', 'v1', credentials=credentials) + + PROJECT = 'bamboo-machine-422' + ZONE = 'us-central1-a' + request = service.instances().list(project=PROJECT, zone=ZONE) + response = request.execute() + + print(response) + """ + + NON_SERIALIZED_MEMBERS = ( + frozenset(['_private_key']) | + OAuth2Credentials.NON_SERIALIZED_MEMBERS) + """Members that aren't serialized when object is converted to JSON.""" + + def __init__(self, access_token, client_id, client_secret, refresh_token, + token_expiry, token_uri, user_agent, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI): + """Create an instance of GoogleCredentials. + + This constructor is not usually called by the user, instead + GoogleCredentials objects are instantiated by + GoogleCredentials.from_stream() or + GoogleCredentials.get_application_default(). + + Args: + access_token: string, access token. + client_id: string, client identifier. + client_secret: string, client secret. + refresh_token: string, refresh token. + token_expiry: datetime, when the access_token expires. + token_uri: string, URI of token endpoint. + user_agent: string, The HTTP User-Agent to provide for this + application. + revoke_uri: string, URI for revoke endpoint. Defaults to + oauth2client.GOOGLE_REVOKE_URI; a token can't be + revoked if this is None. + """ + super(GoogleCredentials, self).__init__( + access_token, client_id, client_secret, refresh_token, + token_expiry, token_uri, user_agent, revoke_uri=revoke_uri) + + def create_scoped_required(self): + """Whether this Credentials object is scopeless. + + create_scoped(scopes) method needs to be called in order to create + a Credentials object for API calls. + """ + return False + + def create_scoped(self, scopes): + """Create a Credentials object for the given scopes. + + The Credentials type is preserved. + """ + return self + + @classmethod + def from_json(cls, json_data): + # TODO(issue 388): eliminate the circularity that is the reason for + # this non-top-level import. + from oauth2client import service_account + data = json.loads(_helpers._from_bytes(json_data)) + + # We handle service_account.ServiceAccountCredentials since it is a + # possible return type of GoogleCredentials.get_application_default() + if (data['_module'] == 'oauth2client.service_account' and + data['_class'] == 'ServiceAccountCredentials'): + return service_account.ServiceAccountCredentials.from_json(data) + elif (data['_module'] == 'oauth2client.service_account' and + data['_class'] == '_JWTAccessCredentials'): + return service_account._JWTAccessCredentials.from_json(data) + + token_expiry = _parse_expiry(data.get('token_expiry')) + google_credentials = cls( + data['access_token'], + data['client_id'], + data['client_secret'], + data['refresh_token'], + token_expiry, + data['token_uri'], + data['user_agent'], + revoke_uri=data.get('revoke_uri', None)) + google_credentials.invalid = data['invalid'] + return google_credentials + + @property + def serialization_data(self): + """Get the fields and values identifying the current credentials.""" + return { + 'type': 'authorized_user', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'refresh_token': self.refresh_token + } + + @staticmethod + def _implicit_credentials_from_gae(): + """Attempts to get implicit credentials in Google App Engine env. + + If the current environment is not detected as App Engine, returns None, + indicating no Google App Engine credentials can be detected from the + current environment. + + Returns: + None, if not in GAE, else an appengine.AppAssertionCredentials + object. + """ + if not _in_gae_environment(): + return None + + return _get_application_default_credential_GAE() + + @staticmethod + def _implicit_credentials_from_gce(): + """Attempts to get implicit credentials in Google Compute Engine env. + + If the current environment is not detected as Compute Engine, returns + None, indicating no Google Compute Engine credentials can be detected + from the current environment. + + Returns: + None, if not in GCE, else a gce.AppAssertionCredentials object. + """ + if not _in_gce_environment(): + return None + + return _get_application_default_credential_GCE() + + @staticmethod + def _implicit_credentials_from_files(): + """Attempts to get implicit credentials from local credential files. + + First checks if the environment variable GOOGLE_APPLICATION_CREDENTIALS + is set with a filename and then falls back to a configuration file (the + "well known" file) associated with the 'gcloud' command line tool. + + Returns: + Credentials object associated with the + GOOGLE_APPLICATION_CREDENTIALS file or the "well known" file if + either exist. If neither file is define, returns None, indicating + no credentials from a file can detected from the current + environment. + """ + credentials_filename = _get_environment_variable_file() + if not credentials_filename: + credentials_filename = _get_well_known_file() + if os.path.isfile(credentials_filename): + extra_help = (' (produced automatically when running' + ' "gcloud auth login" command)') + else: + credentials_filename = None + else: + extra_help = (' (pointed to by ' + GOOGLE_APPLICATION_CREDENTIALS + + ' environment variable)') + + if not credentials_filename: + return + + # If we can read the credentials from a file, we don't need to know + # what environment we are in. + SETTINGS.env_name = DEFAULT_ENV_NAME + + try: + return _get_application_default_credential_from_file( + credentials_filename) + except (ApplicationDefaultCredentialsError, ValueError) as error: + _raise_exception_for_reading_json(credentials_filename, + extra_help, error) + + @classmethod + def _get_implicit_credentials(cls): + """Gets credentials implicitly from the environment. + + Checks environment in order of precedence: + - Environment variable GOOGLE_APPLICATION_CREDENTIALS pointing to + a file with stored credentials information. + - Stored "well known" file associated with `gcloud` command line tool. + - Google App Engine (production and testing) + - Google Compute Engine production environment. + + Raises: + ApplicationDefaultCredentialsError: raised when the credentials + fail to be retrieved. + """ + # Environ checks (in order). + environ_checkers = [ + cls._implicit_credentials_from_files, + cls._implicit_credentials_from_gae, + cls._implicit_credentials_from_gce, + ] + + for checker in environ_checkers: + credentials = checker() + if credentials is not None: + return credentials + + # If no credentials, fail. + raise ApplicationDefaultCredentialsError(ADC_HELP_MSG) + + @staticmethod + def get_application_default(): + """Get the Application Default Credentials for the current environment. + + Raises: + ApplicationDefaultCredentialsError: raised when the credentials + fail to be retrieved. + """ + return GoogleCredentials._get_implicit_credentials() + + @staticmethod + def from_stream(credential_filename): + """Create a Credentials object by reading information from a file. + + It returns an object of type GoogleCredentials. + + Args: + credential_filename: the path to the file from where the + credentials are to be read + + Raises: + ApplicationDefaultCredentialsError: raised when the credentials + fail to be retrieved. + """ + if credential_filename and os.path.isfile(credential_filename): + try: + return _get_application_default_credential_from_file( + credential_filename) + except (ApplicationDefaultCredentialsError, ValueError) as error: + extra_help = (' (provided as parameter to the ' + 'from_stream() method)') + _raise_exception_for_reading_json(credential_filename, + extra_help, + error) + else: + raise ApplicationDefaultCredentialsError( + 'The parameter passed to the from_stream() ' + 'method should point to a file.') + + +def _save_private_file(filename, json_contents): + """Saves a file with read-write permissions on for the owner. + + Args: + filename: String. Absolute path to file. + json_contents: JSON serializable object to be saved. + """ + temp_filename = tempfile.mktemp() + file_desc = os.open(temp_filename, os.O_WRONLY | os.O_CREAT, 0o600) + with os.fdopen(file_desc, 'w') as file_handle: + json.dump(json_contents, file_handle, sort_keys=True, + indent=2, separators=(',', ': ')) + shutil.move(temp_filename, filename) + + +def save_to_well_known_file(credentials, well_known_file=None): + """Save the provided GoogleCredentials to the well known file. + + Args: + credentials: the credentials to be saved to the well known file; + it should be an instance of GoogleCredentials + well_known_file: the name of the file where the credentials are to be + saved; this parameter is supposed to be used for + testing only + """ + # TODO(orestica): move this method to tools.py + # once the argparse import gets fixed (it is not present in Python 2.6) + + if well_known_file is None: + well_known_file = _get_well_known_file() + + config_dir = os.path.dirname(well_known_file) + if not os.path.isdir(config_dir): + raise OSError( + 'Config directory does not exist: {0}'.format(config_dir)) + + credentials_data = credentials.serialization_data + _save_private_file(well_known_file, credentials_data) + + +def _get_environment_variable_file(): + application_default_credential_filename = ( + os.environ.get(GOOGLE_APPLICATION_CREDENTIALS, None)) + + if application_default_credential_filename: + if os.path.isfile(application_default_credential_filename): + return application_default_credential_filename + else: + raise ApplicationDefaultCredentialsError( + 'File ' + application_default_credential_filename + + ' (pointed by ' + + GOOGLE_APPLICATION_CREDENTIALS + + ' environment variable) does not exist!') + + +def _get_well_known_file(): + """Get the well known file produced by command 'gcloud auth login'.""" + # TODO(orestica): Revisit this method once gcloud provides a better way + # of pinpointing the exact location of the file. + default_config_dir = os.getenv(_CLOUDSDK_CONFIG_ENV_VAR) + if default_config_dir is None: + if os.name == 'nt': + try: + default_config_dir = os.path.join(os.environ['APPDATA'], + _CLOUDSDK_CONFIG_DIRECTORY) + except KeyError: + # This should never happen unless someone is really + # messing with things. + drive = os.environ.get('SystemDrive', 'C:') + default_config_dir = os.path.join(drive, '\\', + _CLOUDSDK_CONFIG_DIRECTORY) + else: + default_config_dir = os.path.join(os.path.expanduser('~'), + '.config', + _CLOUDSDK_CONFIG_DIRECTORY) + + return os.path.join(default_config_dir, _WELL_KNOWN_CREDENTIALS_FILE) + + +def _get_application_default_credential_from_file(filename): + """Build the Application Default Credentials from file.""" + # read the credentials from the file + with open(filename) as file_obj: + client_credentials = json.load(file_obj) + + credentials_type = client_credentials.get('type') + if credentials_type == AUTHORIZED_USER: + required_fields = set(['client_id', 'client_secret', 'refresh_token']) + elif credentials_type == SERVICE_ACCOUNT: + required_fields = set(['client_id', 'client_email', 'private_key_id', + 'private_key']) + else: + raise ApplicationDefaultCredentialsError( + "'type' field should be defined (and have one of the '" + + AUTHORIZED_USER + "' or '" + SERVICE_ACCOUNT + "' values)") + + missing_fields = required_fields.difference(client_credentials.keys()) + + if missing_fields: + _raise_exception_for_missing_fields(missing_fields) + + if client_credentials['type'] == AUTHORIZED_USER: + return GoogleCredentials( + access_token=None, + client_id=client_credentials['client_id'], + client_secret=client_credentials['client_secret'], + refresh_token=client_credentials['refresh_token'], + token_expiry=None, + token_uri=oauth2client.GOOGLE_TOKEN_URI, + user_agent='Python client library') + else: # client_credentials['type'] == SERVICE_ACCOUNT + from oauth2client import service_account + return service_account._JWTAccessCredentials.from_json_keyfile_dict( + client_credentials) + + +def _raise_exception_for_missing_fields(missing_fields): + raise ApplicationDefaultCredentialsError( + 'The following field(s) must be defined: ' + ', '.join(missing_fields)) + + +def _raise_exception_for_reading_json(credential_file, + extra_help, + error): + raise ApplicationDefaultCredentialsError( + 'An error was encountered while reading json file: ' + + credential_file + extra_help + ': ' + str(error)) + + +def _get_application_default_credential_GAE(): + from oauth2client.contrib.appengine import AppAssertionCredentials + + return AppAssertionCredentials([]) + + +def _get_application_default_credential_GCE(): + from oauth2client.contrib.gce import AppAssertionCredentials + + return AppAssertionCredentials() + + +class AssertionCredentials(GoogleCredentials): + """Abstract Credentials object used for OAuth 2.0 assertion grants. + + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. It must + be subclassed to generate the appropriate assertion string. + + AssertionCredentials objects may be safely pickled and unpickled. + """ + + @_helpers.positional(2) + def __init__(self, assertion_type, user_agent=None, + token_uri=oauth2client.GOOGLE_TOKEN_URI, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI, + **unused_kwargs): + """Constructor for AssertionFlowCredentials. + + Args: + assertion_type: string, assertion type that will be declared to the + auth server + user_agent: string, The HTTP User-Agent to provide for this + application. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. + """ + super(AssertionCredentials, self).__init__( + None, + None, + None, + None, + None, + token_uri, + user_agent, + revoke_uri=revoke_uri) + self.assertion_type = assertion_type + + def _generate_refresh_request_body(self): + assertion = self._generate_assertion() + + body = urllib.parse.urlencode({ + 'assertion': assertion, + 'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer', + }) + + return body + + def _generate_assertion(self): + """Generate assertion string to be used in the access token request.""" + raise NotImplementedError + + def _revoke(self, http): + """Revokes the access_token and deletes the store if available. + + Args: + http: an object to be used to make HTTP requests. + """ + self._do_revoke(http, self.access_token) + + def sign_blob(self, blob): + """Cryptographically sign a blob (of bytes). + + Args: + blob: bytes, Message to be signed. + + Returns: + tuple, A pair of the private key ID used to sign the blob and + the signed contents. + """ + raise NotImplementedError('This method is abstract.') + + +def _require_crypto_or_die(): + """Ensure we have a crypto library, or throw CryptoUnavailableError. + + The oauth2client.crypt module requires either PyCrypto or PyOpenSSL + to be available in order to function, but these are optional + dependencies. + """ + if not HAS_CRYPTO: + raise CryptoUnavailableError('No crypto library available') + + +@_helpers.positional(2) +def verify_id_token(id_token, audience, http=None, + cert_uri=ID_TOKEN_VERIFICATION_CERTS): + """Verifies a signed JWT id_token. + + This function requires PyOpenSSL and because of that it does not work on + App Engine. + + Args: + id_token: string, A Signed JWT. + audience: string, The audience 'aud' that the token should be for. + http: httplib2.Http, instance to use to make the HTTP request. Callers + should supply an instance that has caching enabled. + cert_uri: string, URI of the certificates in JSON format to + verify the JWT against. + + Returns: + The deserialized JSON in the JWT. + + Raises: + oauth2client.crypt.AppIdentityError: if the JWT fails to verify. + CryptoUnavailableError: if no crypto library is available. + """ + _require_crypto_or_die() + if http is None: + http = transport.get_cached_http() + + resp, content = transport.request(http, cert_uri) + if resp.status == http_client.OK: + certs = json.loads(_helpers._from_bytes(content)) + return crypt.verify_signed_jwt_with_certs(id_token, certs, audience) + else: + raise VerifyJwtTokenError('Status code: {0}'.format(resp.status)) + + +def _extract_id_token(id_token): + """Extract the JSON payload from a JWT. + + Does the extraction w/o checking the signature. + + Args: + id_token: string or bytestring, OAuth 2.0 id_token. + + Returns: + object, The deserialized JSON payload. + """ + if type(id_token) == bytes: + segments = id_token.split(b'.') + else: + segments = id_token.split(u'.') + + if len(segments) != 3: + raise VerifyJwtTokenError( + 'Wrong number of segments in token: {0}'.format(id_token)) + + return json.loads( + _helpers._from_bytes(_helpers._urlsafe_b64decode(segments[1]))) + + +def _parse_exchange_token_response(content): + """Parses response of an exchange token request. + + Most providers return JSON but some (e.g. Facebook) return a + url-encoded string. + + Args: + content: The body of a response + + Returns: + Content as a dictionary object. Note that the dict could be empty, + i.e. {}. That basically indicates a failure. + """ + resp = {} + content = _helpers._from_bytes(content) + try: + resp = json.loads(content) + except Exception: + # different JSON libs raise different exceptions, + # so we just do a catch-all here + resp = _helpers.parse_unique_urlencoded(content) + + # some providers respond with 'expires', others with 'expires_in' + if resp and 'expires' in resp: + resp['expires_in'] = resp.pop('expires') + + return resp + + +@_helpers.positional(4) +def credentials_from_code(client_id, client_secret, scope, code, + redirect_uri='postmessage', http=None, + user_agent=None, + token_uri=oauth2client.GOOGLE_TOKEN_URI, + auth_uri=oauth2client.GOOGLE_AUTH_URI, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI, + device_uri=oauth2client.GOOGLE_DEVICE_URI, + token_info_uri=oauth2client.GOOGLE_TOKEN_INFO_URI, + pkce=False, + code_verifier=None): + """Exchanges an authorization code for an OAuth2Credentials object. + + Args: + client_id: string, client identifier. + client_secret: string, client secret. + scope: string or iterable of strings, scope(s) to request. + code: string, An authorization code, most likely passed down from + the client + redirect_uri: string, this is generally set to 'postmessage' to match + the redirect_uri that the client specified + http: httplib2.Http, optional http instance to use to do the fetch + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + auth_uri: string, URI for authorization endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider + can be used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider + can be used. + device_uri: string, URI for device authorization endpoint. For + convenience defaults to Google's endpoints but any OAuth + 2.0 provider can be used. + pkce: boolean, default: False, Generate and include a "Proof Key + for Code Exchange" (PKCE) with your authorization and token + requests. This adds security for installed applications that + cannot protect a client_secret. See RFC 7636 for details. + code_verifier: bytestring or None, default: None, parameter passed + as part of the code exchange when pkce=True. If + None, a code_verifier will automatically be + generated as part of step1_get_authorize_url(). See + RFC 7636 for details. + + Returns: + An OAuth2Credentials object. + + Raises: + FlowExchangeError if the authorization code cannot be exchanged for an + access token + """ + flow = OAuth2WebServerFlow(client_id, client_secret, scope, + redirect_uri=redirect_uri, + user_agent=user_agent, + auth_uri=auth_uri, + token_uri=token_uri, + revoke_uri=revoke_uri, + device_uri=device_uri, + token_info_uri=token_info_uri, + pkce=pkce, + code_verifier=code_verifier) + + credentials = flow.step2_exchange(code, http=http) + return credentials + + +@_helpers.positional(3) +def credentials_from_clientsecrets_and_code(filename, scope, code, + message=None, + redirect_uri='postmessage', + http=None, + cache=None, + device_uri=None): + """Returns OAuth2Credentials from a clientsecrets file and an auth code. + + Will create the right kind of Flow based on the contents of the + clientsecrets file or will raise InvalidClientSecretsError for unknown + types of Flows. + + Args: + filename: string, File name of clientsecrets. + scope: string or iterable of strings, scope(s) to request. + code: string, An authorization code, most likely passed down from + the client + message: string, A friendly string to display to the user if the + clientsecrets file is missing or invalid. If message is + provided then sys.exit will be called in the case of an error. + If message in not provided then + clientsecrets.InvalidClientSecretsError will be raised. + redirect_uri: string, this is generally set to 'postmessage' to match + the redirect_uri that the client specified + http: httplib2.Http, optional http instance to use to do the fetch + cache: An optional cache service client that implements get() and set() + methods. See clientsecrets.loadfile() for details. + device_uri: string, OAuth 2.0 device authorization endpoint + pkce: boolean, default: False, Generate and include a "Proof Key + for Code Exchange" (PKCE) with your authorization and token + requests. This adds security for installed applications that + cannot protect a client_secret. See RFC 7636 for details. + code_verifier: bytestring or None, default: None, parameter passed + as part of the code exchange when pkce=True. If + None, a code_verifier will automatically be + generated as part of step1_get_authorize_url(). See + RFC 7636 for details. + + Returns: + An OAuth2Credentials object. + + Raises: + FlowExchangeError: if the authorization code cannot be exchanged for an + access token + UnknownClientSecretsFlowError: if the file describes an unknown kind + of Flow. + clientsecrets.InvalidClientSecretsError: if the clientsecrets file is + invalid. + """ + flow = flow_from_clientsecrets(filename, scope, message=message, + cache=cache, redirect_uri=redirect_uri, + device_uri=device_uri) + credentials = flow.step2_exchange(code, http=http) + return credentials + + +class DeviceFlowInfo(collections.namedtuple('DeviceFlowInfo', ( + 'device_code', 'user_code', 'interval', 'verification_url', + 'user_code_expiry'))): + """Intermediate information the OAuth2 for devices flow.""" + + @classmethod + def FromResponse(cls, response): + """Create a DeviceFlowInfo from a server response. + + The response should be a dict containing entries as described here: + + http://tools.ietf.org/html/draft-ietf-oauth-v2-05#section-3.7.1 + """ + # device_code, user_code, and verification_url are required. + kwargs = { + 'device_code': response['device_code'], + 'user_code': response['user_code'], + } + # The response may list the verification address as either + # verification_url or verification_uri, so we check for both. + verification_url = response.get( + 'verification_url', response.get('verification_uri')) + if verification_url is None: + raise OAuth2DeviceCodeError( + 'No verification_url provided in server response') + kwargs['verification_url'] = verification_url + # expires_in and interval are optional. + kwargs.update({ + 'interval': response.get('interval'), + 'user_code_expiry': None, + }) + if 'expires_in' in response: + kwargs['user_code_expiry'] = ( + _UTCNOW() + + datetime.timedelta(seconds=int(response['expires_in']))) + return cls(**kwargs) + + +def _oauth2_web_server_flow_params(kwargs): + """Configures redirect URI parameters for OAuth2WebServerFlow.""" + params = { + 'access_type': 'offline', + 'response_type': 'code', + } + + params.update(kwargs) + + # Check for the presence of the deprecated approval_prompt param and + # warn appropriately. + approval_prompt = params.get('approval_prompt') + if approval_prompt is not None: + logger.warning( + 'The approval_prompt parameter for OAuth2WebServerFlow is ' + 'deprecated. Please use the prompt parameter instead.') + + if approval_prompt == 'force': + logger.warning( + 'approval_prompt="force" has been adjusted to ' + 'prompt="consent"') + params['prompt'] = 'consent' + del params['approval_prompt'] + + return params + + +class OAuth2WebServerFlow(Flow): + """Does the Web Server Flow for OAuth 2.0. + + OAuth2WebServerFlow objects may be safely pickled and unpickled. + """ + + @_helpers.positional(4) + def __init__(self, client_id, + client_secret=None, + scope=None, + redirect_uri=None, + user_agent=None, + auth_uri=oauth2client.GOOGLE_AUTH_URI, + token_uri=oauth2client.GOOGLE_TOKEN_URI, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI, + login_hint=None, + device_uri=oauth2client.GOOGLE_DEVICE_URI, + token_info_uri=oauth2client.GOOGLE_TOKEN_INFO_URI, + authorization_header=None, + pkce=False, + code_verifier=None, + **kwargs): + """Constructor for OAuth2WebServerFlow. + + The kwargs argument is used to set extra query parameters on the + auth_uri. For example, the access_type and prompt + query parameters can be set via kwargs. + + Args: + client_id: string, client identifier. + client_secret: string client secret. + scope: string or iterable of strings, scope(s) of the credentials + being requested. + redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' + for a non-web-based application, or a URI that + handles the callback from the authorization server. + user_agent: string, HTTP User-Agent to provide for this + application. + auth_uri: string, URI for authorization endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider + can be used. + token_uri: string, URI for token endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. + login_hint: string, Either an email address or domain. Passing this + hint will either pre-fill the email box on the sign-in + form or select the proper multi-login session, thereby + simplifying the login flow. + device_uri: string, URI for device authorization endpoint. For + convenience defaults to Google's endpoints but any + OAuth 2.0 provider can be used. + authorization_header: string, For use with OAuth 2.0 providers that + require a client to authenticate using a + header value instead of passing client_secret + in the POST body. + pkce: boolean, default: False, Generate and include a "Proof Key + for Code Exchange" (PKCE) with your authorization and token + requests. This adds security for installed applications that + cannot protect a client_secret. See RFC 7636 for details. + code_verifier: bytestring or None, default: None, parameter passed + as part of the code exchange when pkce=True. If + None, a code_verifier will automatically be + generated as part of step1_get_authorize_url(). See + RFC 7636 for details. + **kwargs: dict, The keyword arguments are all optional and required + parameters for the OAuth calls. + """ + # scope is a required argument, but to preserve backwards-compatibility + # we don't want to rearrange the positional arguments + if scope is None: + raise TypeError("The value of scope must not be None") + self.client_id = client_id + self.client_secret = client_secret + self.scope = _helpers.scopes_to_string(scope) + self.redirect_uri = redirect_uri + self.login_hint = login_hint + self.user_agent = user_agent + self.auth_uri = auth_uri + self.token_uri = token_uri + self.revoke_uri = revoke_uri + self.device_uri = device_uri + self.token_info_uri = token_info_uri + self.authorization_header = authorization_header + self._pkce = pkce + self.code_verifier = code_verifier + self.params = _oauth2_web_server_flow_params(kwargs) + + @_helpers.positional(1) + def step1_get_authorize_url(self, redirect_uri=None, state=None): + """Returns a URI to redirect to the provider. + + Args: + redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' + for a non-web-based application, or a URI that + handles the callback from the authorization server. + This parameter is deprecated, please move to passing + the redirect_uri in via the constructor. + state: string, Opaque state string which is passed through the + OAuth2 flow and returned to the client as a query parameter + in the callback. + + Returns: + A URI as a string to redirect the user to begin the authorization + flow. + """ + if redirect_uri is not None: + logger.warning(( + 'The redirect_uri parameter for ' + 'OAuth2WebServerFlow.step1_get_authorize_url is deprecated. ' + 'Please move to passing the redirect_uri in via the ' + 'constructor.')) + self.redirect_uri = redirect_uri + + if self.redirect_uri is None: + raise ValueError('The value of redirect_uri must not be None.') + + query_params = { + 'client_id': self.client_id, + 'redirect_uri': self.redirect_uri, + 'scope': self.scope, + } + if state is not None: + query_params['state'] = state + if self.login_hint is not None: + query_params['login_hint'] = self.login_hint + if self._pkce: + if not self.code_verifier: + self.code_verifier = _pkce.code_verifier() + challenge = _pkce.code_challenge(self.code_verifier) + query_params['code_challenge'] = challenge + query_params['code_challenge_method'] = 'S256' + + query_params.update(self.params) + return _helpers.update_query_params(self.auth_uri, query_params) + + @_helpers.positional(1) + def step1_get_device_and_user_codes(self, http=None): + """Returns a user code and the verification URL where to enter it + + Returns: + A user code as a string for the user to authorize the application + An URL as a string where the user has to enter the code + """ + if self.device_uri is None: + raise ValueError('The value of device_uri must not be None.') + + body = urllib.parse.urlencode({ + 'client_id': self.client_id, + 'scope': self.scope, + }) + headers = { + 'content-type': 'application/x-www-form-urlencoded', + } + + if self.user_agent is not None: + headers['user-agent'] = self.user_agent + + if http is None: + http = transport.get_http_object() + + resp, content = transport.request( + http, self.device_uri, method='POST', body=body, headers=headers) + content = _helpers._from_bytes(content) + if resp.status == http_client.OK: + try: + flow_info = json.loads(content) + except ValueError as exc: + raise OAuth2DeviceCodeError( + 'Could not parse server response as JSON: "{0}", ' + 'error: "{1}"'.format(content, exc)) + return DeviceFlowInfo.FromResponse(flow_info) + else: + error_msg = 'Invalid response {0}.'.format(resp.status) + try: + error_dict = json.loads(content) + if 'error' in error_dict: + error_msg += ' Error: {0}'.format(error_dict['error']) + except ValueError: + # Couldn't decode a JSON response, stick with the + # default message. + pass + raise OAuth2DeviceCodeError(error_msg) + + @_helpers.positional(2) + def step2_exchange(self, code=None, http=None, device_flow_info=None): + """Exchanges a code for OAuth2Credentials. + + Args: + code: string, a dict-like object, or None. For a non-device + flow, this is either the response code as a string, or a + dictionary of query parameters to the redirect_uri. For a + device flow, this should be None. + http: httplib2.Http, optional http instance to use when fetching + credentials. + device_flow_info: DeviceFlowInfo, return value from step1 in the + case of a device flow. + + Returns: + An OAuth2Credentials object that can be used to authorize requests. + + Raises: + FlowExchangeError: if a problem occurred exchanging the code for a + refresh_token. + ValueError: if code and device_flow_info are both provided or both + missing. + """ + if code is None and device_flow_info is None: + raise ValueError('No code or device_flow_info provided.') + if code is not None and device_flow_info is not None: + raise ValueError('Cannot provide both code and device_flow_info.') + + if code is None: + code = device_flow_info.device_code + elif not isinstance(code, (six.string_types, six.binary_type)): + if 'code' not in code: + raise FlowExchangeError(code.get( + 'error', 'No code was supplied in the query parameters.')) + code = code['code'] + + post_data = { + 'client_id': self.client_id, + 'code': code, + 'scope': self.scope, + } + if self.client_secret is not None: + post_data['client_secret'] = self.client_secret + if self._pkce: + post_data['code_verifier'] = self.code_verifier + if device_flow_info is not None: + post_data['grant_type'] = 'http://oauth.net/grant_type/device/1.0' + else: + post_data['grant_type'] = 'authorization_code' + post_data['redirect_uri'] = self.redirect_uri + body = urllib.parse.urlencode(post_data) + headers = { + 'content-type': 'application/x-www-form-urlencoded', + } + if self.authorization_header is not None: + headers['Authorization'] = self.authorization_header + if self.user_agent is not None: + headers['user-agent'] = self.user_agent + + if http is None: + http = transport.get_http_object() + + resp, content = transport.request( + http, self.token_uri, method='POST', body=body, headers=headers) + d = _parse_exchange_token_response(content) + if resp.status == http_client.OK and 'access_token' in d: + access_token = d['access_token'] + refresh_token = d.get('refresh_token', None) + if not refresh_token: + logger.info( + 'Received token response with no refresh_token. Consider ' + "reauthenticating with prompt='consent'.") + token_expiry = None + if 'expires_in' in d: + delta = datetime.timedelta(seconds=int(d['expires_in'])) + token_expiry = delta + _UTCNOW() + + extracted_id_token = None + id_token_jwt = None + if 'id_token' in d: + extracted_id_token = _extract_id_token(d['id_token']) + id_token_jwt = d['id_token'] + + logger.info('Successfully retrieved access token') + return OAuth2Credentials( + access_token, self.client_id, self.client_secret, + refresh_token, token_expiry, self.token_uri, self.user_agent, + revoke_uri=self.revoke_uri, id_token=extracted_id_token, + id_token_jwt=id_token_jwt, token_response=d, scopes=self.scope, + token_info_uri=self.token_info_uri) + else: + logger.info('Failed to retrieve access token: %s', content) + if 'error' in d: + # you never know what those providers got to say + error_msg = (str(d['error']) + + str(d.get('error_description', ''))) + else: + error_msg = 'Invalid response: {0}.'.format(str(resp.status)) + raise FlowExchangeError(error_msg) + + +@_helpers.positional(2) +def flow_from_clientsecrets(filename, scope, redirect_uri=None, + message=None, cache=None, login_hint=None, + device_uri=None, pkce=None, code_verifier=None, + prompt=None): + """Create a Flow from a clientsecrets file. + + Will create the right kind of Flow based on the contents of the + clientsecrets file or will raise InvalidClientSecretsError for unknown + types of Flows. + + Args: + filename: string, File name of client secrets. + scope: string or iterable of strings, scope(s) to request. + redirect_uri: string, Either the string 'urn:ietf:wg:oauth:2.0:oob' for + a non-web-based application, or a URI that handles the + callback from the authorization server. + message: string, A friendly string to display to the user if the + clientsecrets file is missing or invalid. If message is + provided then sys.exit will be called in the case of an error. + If message in not provided then + clientsecrets.InvalidClientSecretsError will be raised. + cache: An optional cache service client that implements get() and set() + methods. See clientsecrets.loadfile() for details. + login_hint: string, Either an email address or domain. Passing this + hint will either pre-fill the email box on the sign-in form + or select the proper multi-login session, thereby + simplifying the login flow. + device_uri: string, URI for device authorization endpoint. For + convenience defaults to Google's endpoints but any + OAuth 2.0 provider can be used. + + Returns: + A Flow object. + + Raises: + UnknownClientSecretsFlowError: if the file describes an unknown kind of + Flow. + clientsecrets.InvalidClientSecretsError: if the clientsecrets file is + invalid. + """ + try: + client_type, client_info = clientsecrets.loadfile(filename, + cache=cache) + if client_type in (clientsecrets.TYPE_WEB, + clientsecrets.TYPE_INSTALLED): + constructor_kwargs = { + 'redirect_uri': redirect_uri, + 'auth_uri': client_info['auth_uri'], + 'token_uri': client_info['token_uri'], + 'login_hint': login_hint, + } + revoke_uri = client_info.get('revoke_uri') + optional = ( + 'revoke_uri', + 'device_uri', + 'pkce', + 'code_verifier', + 'prompt' + ) + for param in optional: + if locals()[param] is not None: + constructor_kwargs[param] = locals()[param] + + return OAuth2WebServerFlow( + client_info['client_id'], client_info['client_secret'], + scope, **constructor_kwargs) + + except clientsecrets.InvalidClientSecretsError as e: + if message is not None: + if e.args: + message = ('The client secrets were invalid: ' + '\n{0}\n{1}'.format(e, message)) + sys.exit(message) + else: + raise + else: + raise UnknownClientSecretsFlowError( + 'This OAuth 2.0 flow is unsupported: {0!r}'.format(client_type)) diff --git a/src/oauth2client/oauth2client/clientsecrets.py b/src/oauth2client/oauth2client/clientsecrets.py new file mode 100644 index 00000000..1598142e --- /dev/null +++ b/src/oauth2client/oauth2client/clientsecrets.py @@ -0,0 +1,173 @@ +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for reading OAuth 2.0 client secret files. + +A client_secrets.json file contains all the information needed to interact with +an OAuth 2.0 protected service. +""" + +import json + +import six + + +# Properties that make a client_secrets.json file valid. +TYPE_WEB = 'web' +TYPE_INSTALLED = 'installed' + +VALID_CLIENT = { + TYPE_WEB: { + 'required': [ + 'client_id', + 'client_secret', + 'redirect_uris', + 'auth_uri', + 'token_uri', + ], + 'string': [ + 'client_id', + 'client_secret', + ], + }, + TYPE_INSTALLED: { + 'required': [ + 'client_id', + 'client_secret', + 'redirect_uris', + 'auth_uri', + 'token_uri', + ], + 'string': [ + 'client_id', + 'client_secret', + ], + }, +} + + +class Error(Exception): + """Base error for this module.""" + + +class InvalidClientSecretsError(Error): + """Format of ClientSecrets file is invalid.""" + + +def _validate_clientsecrets(clientsecrets_dict): + """Validate parsed client secrets from a file. + + Args: + clientsecrets_dict: dict, a dictionary holding the client secrets. + + Returns: + tuple, a string of the client type and the information parsed + from the file. + """ + _INVALID_FILE_FORMAT_MSG = ( + 'Invalid file format. See ' + 'https://developers.google.com/api-client-library/' + 'python/guide/aaa_client_secrets') + + if clientsecrets_dict is None: + raise InvalidClientSecretsError(_INVALID_FILE_FORMAT_MSG) + try: + (client_type, client_info), = clientsecrets_dict.items() + except (ValueError, AttributeError): + raise InvalidClientSecretsError( + _INVALID_FILE_FORMAT_MSG + ' ' + 'Expected a JSON object with a single property for a "web" or ' + '"installed" application') + + if client_type not in VALID_CLIENT: + raise InvalidClientSecretsError( + 'Unknown client type: {0}.'.format(client_type)) + + for prop_name in VALID_CLIENT[client_type]['required']: + if prop_name not in client_info: + raise InvalidClientSecretsError( + 'Missing property "{0}" in a client type of "{1}".'.format( + prop_name, client_type)) + for prop_name in VALID_CLIENT[client_type]['string']: + if client_info[prop_name].startswith('[['): + raise InvalidClientSecretsError( + 'Property "{0}" is not configured.'.format(prop_name)) + return client_type, client_info + + +def load(fp): + obj = json.load(fp) + return _validate_clientsecrets(obj) + + +def loads(s): + obj = json.loads(s) + return _validate_clientsecrets(obj) + + +def _loadfile(filename): + try: + with open(filename, 'r') as fp: + obj = json.load(fp) + except IOError as exc: + raise InvalidClientSecretsError('Error opening file', exc.filename, + exc.strerror, exc.errno) + return _validate_clientsecrets(obj) + + +def loadfile(filename, cache=None): + """Loading of client_secrets JSON file, optionally backed by a cache. + + Typical cache storage would be App Engine memcache service, + but you can pass in any other cache client that implements + these methods: + + * ``get(key, namespace=ns)`` + * ``set(key, value, namespace=ns)`` + + Usage:: + + # without caching + client_type, client_info = loadfile('secrets.json') + # using App Engine memcache service + from google.appengine.api import memcache + client_type, client_info = loadfile('secrets.json', cache=memcache) + + Args: + filename: string, Path to a client_secrets.json file on a filesystem. + cache: An optional cache service client that implements get() and set() + methods. If not specified, the file is always being loaded from + a filesystem. + + Raises: + InvalidClientSecretsError: In case of a validation error or some + I/O failure. Can happen only on cache miss. + + Returns: + (client_type, client_info) tuple, as _loadfile() normally would. + JSON contents is validated only during first load. Cache hits are not + validated. + """ + _SECRET_NAMESPACE = 'oauth2client:secrets#ns' + + if not cache: + return _loadfile(filename) + + obj = cache.get(filename, namespace=_SECRET_NAMESPACE) + if obj is None: + client_type, client_info = _loadfile(filename) + obj = {client_type: client_info} + cache.set(filename, obj, namespace=_SECRET_NAMESPACE) + + return next(six.iteritems(obj)) diff --git a/src/oauth2client/oauth2client/contrib/__init__.py b/src/oauth2client/oauth2client/contrib/__init__.py new file mode 100644 index 00000000..ecfd06c9 --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/__init__.py @@ -0,0 +1,6 @@ +"""Contributed modules. + +Contrib contains modules that are not considered part of the core oauth2client +library but provide additional functionality. These modules are intended to +make it easier to use oauth2client. +""" diff --git a/src/oauth2client/oauth2client/contrib/_appengine_ndb.py b/src/oauth2client/oauth2client/contrib/_appengine_ndb.py new file mode 100644 index 00000000..c863e8f4 --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/_appengine_ndb.py @@ -0,0 +1,163 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google App Engine utilities helper. + +Classes that directly require App Engine's ndb library. Provided +as a separate module in case of failure to import ndb while +other App Engine libraries are present. +""" + +import logging + +from google.appengine.ext import ndb + +from oauth2client import client + + +NDB_KEY = ndb.Key +"""Key constant used by :mod:`oauth2client.contrib.appengine`.""" + +NDB_MODEL = ndb.Model +"""Model constant used by :mod:`oauth2client.contrib.appengine`.""" + +_LOGGER = logging.getLogger(__name__) + + +class SiteXsrfSecretKeyNDB(ndb.Model): + """NDB Model for storage for the sites XSRF secret key. + + Since this model uses the same kind as SiteXsrfSecretKey, it can be + used interchangeably. This simply provides an NDB model for interacting + with the same data the DB model interacts with. + + There should only be one instance stored of this model, the one used + for the site. + """ + secret = ndb.StringProperty() + + @classmethod + def _get_kind(cls): + """Return the kind name for this class.""" + return 'SiteXsrfSecretKey' + + +class FlowNDBProperty(ndb.PickleProperty): + """App Engine NDB datastore Property for Flow. + + Serves the same purpose as the DB FlowProperty, but for NDB models. + Since PickleProperty inherits from BlobProperty, the underlying + representation of the data in the datastore will be the same as in the + DB case. + + Utility property that allows easy storage and retrieval of an + oauth2client.Flow + """ + + def _validate(self, value): + """Validates a value as a proper Flow object. + + Args: + value: A value to be set on the property. + + Raises: + TypeError if the value is not an instance of Flow. + """ + _LOGGER.info('validate: Got type %s', type(value)) + if value is not None and not isinstance(value, client.Flow): + raise TypeError( + 'Property {0} must be convertible to a flow ' + 'instance; received: {1}.'.format(self._name, value)) + + +class CredentialsNDBProperty(ndb.BlobProperty): + """App Engine NDB datastore Property for Credentials. + + Serves the same purpose as the DB CredentialsProperty, but for NDB + models. Since CredentialsProperty stores data as a blob and this + inherits from BlobProperty, the data in the datastore will be the same + as in the DB case. + + Utility property that allows easy storage and retrieval of Credentials + and subclasses. + """ + + def _validate(self, value): + """Validates a value as a proper credentials object. + + Args: + value: A value to be set on the property. + + Raises: + TypeError if the value is not an instance of Credentials. + """ + _LOGGER.info('validate: Got type %s', type(value)) + if value is not None and not isinstance(value, client.Credentials): + raise TypeError( + 'Property {0} must be convertible to a credentials ' + 'instance; received: {1}.'.format(self._name, value)) + + def _to_base_type(self, value): + """Converts our validated value to a JSON serialized string. + + Args: + value: A value to be set in the datastore. + + Returns: + A JSON serialized version of the credential, else '' if value + is None. + """ + if value is None: + return '' + else: + return value.to_json() + + def _from_base_type(self, value): + """Converts our stored JSON string back to the desired type. + + Args: + value: A value from the datastore to be converted to the + desired type. + + Returns: + A deserialized Credentials (or subclass) object, else None if + the value can't be parsed. + """ + if not value: + return None + try: + # Uses the from_json method of the implied class of value + credentials = client.Credentials.new_from_json(value) + except ValueError: + credentials = None + return credentials + + +class CredentialsNDBModel(ndb.Model): + """NDB Model for storage of OAuth 2.0 Credentials + + Since this model uses the same kind as CredentialsModel and has a + property which can serialize and deserialize Credentials correctly, it + can be used interchangeably with a CredentialsModel to access, insert + and delete the same entities. This simply provides an NDB model for + interacting with the same data the DB model interacts with. + + Storage of the model is keyed by the user.user_id(). + """ + credentials = CredentialsNDBProperty() + + @classmethod + def _get_kind(cls): + """Return the kind name for this class.""" + return 'CredentialsModel' diff --git a/src/oauth2client/oauth2client/contrib/_metadata.py b/src/oauth2client/oauth2client/contrib/_metadata.py new file mode 100644 index 00000000..564cd398 --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/_metadata.py @@ -0,0 +1,118 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Provides helper methods for talking to the Compute Engine metadata server. + +See https://cloud.google.com/compute/docs/metadata +""" + +import datetime +import json +import os + +from six.moves import http_client +from six.moves.urllib import parse as urlparse + +from oauth2client import _helpers +from oauth2client import client +from oauth2client import transport + + +METADATA_ROOT = 'http://{}/computeMetadata/v1/'.format( + os.getenv('GCE_METADATA_ROOT', 'metadata.google.internal')) +METADATA_HEADERS = {'Metadata-Flavor': 'Google'} + + +def get(http, path, root=METADATA_ROOT, recursive=None): + """Fetch a resource from the metadata server. + + Args: + http: an object to be used to make HTTP requests. + path: A string indicating the resource to retrieve. For example, + 'instance/service-accounts/default' + root: A string indicating the full path to the metadata server root. + recursive: A boolean indicating whether to do a recursive query of + metadata. See + https://cloud.google.com/compute/docs/metadata#aggcontents + + Returns: + A dictionary if the metadata server returns JSON, otherwise a string. + + Raises: + http_client.HTTPException if an error corrured while + retrieving metadata. + """ + url = urlparse.urljoin(root, path) + url = _helpers._add_query_parameter(url, 'recursive', recursive) + + response, content = transport.request( + http, url, headers=METADATA_HEADERS) + + if response.status == http_client.OK: + decoded = _helpers._from_bytes(content) + if response['content-type'] == 'application/json': + return json.loads(decoded) + else: + return decoded + else: + raise http_client.HTTPException( + 'Failed to retrieve {0} from the Google Compute Engine' + 'metadata service. Response:\n{1}'.format(url, response)) + + +def get_service_account_info(http, service_account='default'): + """Get information about a service account from the metadata server. + + Args: + http: an object to be used to make HTTP requests. + service_account: An email specifying the service account for which to + look up information. Default will be information for the "default" + service account of the current compute engine instance. + + Returns: + A dictionary with information about the specified service account, + for example: + + { + 'email': '...', + 'scopes': ['scope', ...], + 'aliases': ['default', '...'] + } + """ + return get( + http, + 'instance/service-accounts/{0}/'.format(service_account), + recursive=True) + + +def get_token(http, service_account='default'): + """Fetch an oauth token for the + + Args: + http: an object to be used to make HTTP requests. + service_account: An email specifying the service account this token + should represent. Default will be a token for the "default" service + account of the current compute engine instance. + + Returns: + A tuple of (access token, token expiration), where access token is the + access token as a string and token expiration is a datetime object + that indicates when the access token will expire. + """ + token_json = get( + http, + 'instance/service-accounts/{0}/token'.format(service_account)) + token_expiry = client._UTCNOW() + datetime.timedelta( + seconds=token_json['expires_in']) + return token_json['access_token'], token_expiry diff --git a/src/oauth2client/oauth2client/contrib/appengine.py b/src/oauth2client/oauth2client/contrib/appengine.py new file mode 100644 index 00000000..c1326eeb --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/appengine.py @@ -0,0 +1,910 @@ +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for Google App Engine + +Utilities for making it easier to use OAuth 2.0 on Google App Engine. +""" + +import cgi +import json +import logging +import os +import pickle +import threading + +from google.appengine.api import app_identity +from google.appengine.api import memcache +from google.appengine.api import users +from google.appengine.ext import db +from google.appengine.ext.webapp.util import login_required +import webapp2 as webapp + +import oauth2client +from oauth2client import _helpers +from oauth2client import client +from oauth2client import clientsecrets +from oauth2client import transport +from oauth2client.contrib import xsrfutil + +# This is a temporary fix for a Google internal issue. +try: + from oauth2client.contrib import _appengine_ndb +except ImportError: # pragma: NO COVER + _appengine_ndb = None + + +logger = logging.getLogger(__name__) + +OAUTH2CLIENT_NAMESPACE = 'oauth2client#ns' + +XSRF_MEMCACHE_ID = 'xsrf_secret_key' + +if _appengine_ndb is None: # pragma: NO COVER + CredentialsNDBModel = None + CredentialsNDBProperty = None + FlowNDBProperty = None + _NDB_KEY = None + _NDB_MODEL = None + SiteXsrfSecretKeyNDB = None +else: + CredentialsNDBModel = _appengine_ndb.CredentialsNDBModel + CredentialsNDBProperty = _appengine_ndb.CredentialsNDBProperty + FlowNDBProperty = _appengine_ndb.FlowNDBProperty + _NDB_KEY = _appengine_ndb.NDB_KEY + _NDB_MODEL = _appengine_ndb.NDB_MODEL + SiteXsrfSecretKeyNDB = _appengine_ndb.SiteXsrfSecretKeyNDB + + +def _safe_html(s): + """Escape text to make it safe to display. + + Args: + s: string, The text to escape. + + Returns: + The escaped text as a string. + """ + return cgi.escape(s, quote=1).replace("'", ''') + + +class SiteXsrfSecretKey(db.Model): + """Storage for the sites XSRF secret key. + + There will only be one instance stored of this model, the one used for the + site. + """ + secret = db.StringProperty() + + +def _generate_new_xsrf_secret_key(): + """Returns a random XSRF secret key.""" + return os.urandom(16).encode("hex") + + +def xsrf_secret_key(): + """Return the secret key for use for XSRF protection. + + If the Site entity does not have a secret key, this method will also create + one and persist it. + + Returns: + The secret key. + """ + secret = memcache.get(XSRF_MEMCACHE_ID, namespace=OAUTH2CLIENT_NAMESPACE) + if not secret: + # Load the one and only instance of SiteXsrfSecretKey. + model = SiteXsrfSecretKey.get_or_insert(key_name='site') + if not model.secret: + model.secret = _generate_new_xsrf_secret_key() + model.put() + secret = model.secret + memcache.add(XSRF_MEMCACHE_ID, secret, + namespace=OAUTH2CLIENT_NAMESPACE) + + return str(secret) + + +class AppAssertionCredentials(client.AssertionCredentials): + """Credentials object for App Engine Assertion Grants + + This object will allow an App Engine application to identify itself to + Google and other OAuth 2.0 servers that can verify assertions. It can be + used for the purpose of accessing data stored under an account assigned to + the App Engine application itself. + + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. + """ + + @_helpers.positional(2) + def __init__(self, scope, **kwargs): + """Constructor for AppAssertionCredentials + + Args: + scope: string or iterable of strings, scope(s) of the credentials + being requested. + **kwargs: optional keyword args, including: + service_account_id: service account id of the application. If None + or unspecified, the default service account for + the app is used. + """ + self.scope = _helpers.scopes_to_string(scope) + self._kwargs = kwargs + self.service_account_id = kwargs.get('service_account_id', None) + self._service_account_email = None + + # Assertion type is no longer used, but still in the + # parent class signature. + super(AppAssertionCredentials, self).__init__(None) + + @classmethod + def from_json(cls, json_data): + data = json.loads(json_data) + return AppAssertionCredentials(data['scope']) + + def _refresh(self, http): + """Refreshes the access token. + + Since the underlying App Engine app_identity implementation does its + own caching we can skip all the storage hoops and just to a refresh + using the API. + + Args: + http: unused HTTP object + + Raises: + AccessTokenRefreshError: When the refresh fails. + """ + try: + scopes = self.scope.split() + (token, _) = app_identity.get_access_token( + scopes, service_account_id=self.service_account_id) + except app_identity.Error as e: + raise client.AccessTokenRefreshError(str(e)) + self.access_token = token + + @property + def serialization_data(self): + raise NotImplementedError('Cannot serialize credentials ' + 'for Google App Engine.') + + def create_scoped_required(self): + return not self.scope + + def create_scoped(self, scopes): + return AppAssertionCredentials(scopes, **self._kwargs) + + def sign_blob(self, blob): + """Cryptographically sign a blob (of bytes). + + Implements abstract method + :meth:`oauth2client.client.AssertionCredentials.sign_blob`. + + Args: + blob: bytes, Message to be signed. + + Returns: + tuple, A pair of the private key ID used to sign the blob and + the signed contents. + """ + return app_identity.sign_blob(blob) + + @property + def service_account_email(self): + """Get the email for the current service account. + + Returns: + string, The email associated with the Google App Engine + service account. + """ + if self._service_account_email is None: + self._service_account_email = ( + app_identity.get_service_account_name()) + return self._service_account_email + + +class FlowProperty(db.Property): + """App Engine datastore Property for Flow. + + Utility property that allows easy storage and retrieval of an + oauth2client.Flow + """ + + # Tell what the user type is. + data_type = client.Flow + + # For writing to datastore. + def get_value_for_datastore(self, model_instance): + flow = super(FlowProperty, self).get_value_for_datastore( + model_instance) + return db.Blob(pickle.dumps(flow)) + + # For reading from datastore. + def make_value_from_datastore(self, value): + if value is None: + return None + return pickle.loads(value) + + def validate(self, value): + if value is not None and not isinstance(value, client.Flow): + raise db.BadValueError( + 'Property {0} must be convertible ' + 'to a FlowThreeLegged instance ({1})'.format(self.name, value)) + return super(FlowProperty, self).validate(value) + + def empty(self, value): + return not value + + +class CredentialsProperty(db.Property): + """App Engine datastore Property for Credentials. + + Utility property that allows easy storage and retrieval of + oauth2client.Credentials + """ + + # Tell what the user type is. + data_type = client.Credentials + + # For writing to datastore. + def get_value_for_datastore(self, model_instance): + logger.info("get: Got type " + str(type(model_instance))) + cred = super(CredentialsProperty, self).get_value_for_datastore( + model_instance) + if cred is None: + cred = '' + else: + cred = cred.to_json() + return db.Blob(cred) + + # For reading from datastore. + def make_value_from_datastore(self, value): + logger.info("make: Got type " + str(type(value))) + if value is None: + return None + if len(value) == 0: + return None + try: + credentials = client.Credentials.new_from_json(value) + except ValueError: + credentials = None + return credentials + + def validate(self, value): + value = super(CredentialsProperty, self).validate(value) + logger.info("validate: Got type " + str(type(value))) + if value is not None and not isinstance(value, client.Credentials): + raise db.BadValueError( + 'Property {0} must be convertible ' + 'to a Credentials instance ({1})'.format(self.name, value)) + return value + + +class StorageByKeyName(client.Storage): + """Store and retrieve a credential to and from the App Engine datastore. + + This Storage helper presumes the Credentials have been stored as a + CredentialsProperty or CredentialsNDBProperty on a datastore model class, + and that entities are stored by key_name. + """ + + @_helpers.positional(4) + def __init__(self, model, key_name, property_name, cache=None, user=None): + """Constructor for Storage. + + Args: + model: db.Model or ndb.Model, model class + key_name: string, key name for the entity that has the credentials + property_name: string, name of the property that is a + CredentialsProperty or CredentialsNDBProperty. + cache: memcache, a write-through cache to put in front of the + datastore. If the model you are using is an NDB model, using + a cache will be redundant since the model uses an instance + cache and memcache for you. + user: users.User object, optional. Can be used to grab user ID as a + key_name if no key name is specified. + """ + super(StorageByKeyName, self).__init__() + + if key_name is None: + if user is None: + raise ValueError('StorageByKeyName called with no ' + 'key name or user.') + key_name = user.user_id() + + self._model = model + self._key_name = key_name + self._property_name = property_name + self._cache = cache + + def _is_ndb(self): + """Determine whether the model of the instance is an NDB model. + + Returns: + Boolean indicating whether or not the model is an NDB or DB model. + """ + # issubclass will fail if one of the arguments is not a class, only + # need worry about new-style classes since ndb and db models are + # new-style + if isinstance(self._model, type): + if _NDB_MODEL is not None and issubclass(self._model, _NDB_MODEL): + return True + elif issubclass(self._model, db.Model): + return False + + raise TypeError( + 'Model class not an NDB or DB model: {0}.'.format(self._model)) + + def _get_entity(self): + """Retrieve entity from datastore. + + Uses a different model method for db or ndb models. + + Returns: + Instance of the model corresponding to the current storage object + and stored using the key name of the storage object. + """ + if self._is_ndb(): + return self._model.get_by_id(self._key_name) + else: + return self._model.get_by_key_name(self._key_name) + + def _delete_entity(self): + """Delete entity from datastore. + + Attempts to delete using the key_name stored on the object, whether or + not the given key is in the datastore. + """ + if self._is_ndb(): + _NDB_KEY(self._model, self._key_name).delete() + else: + entity_key = db.Key.from_path(self._model.kind(), self._key_name) + db.delete(entity_key) + + @db.non_transactional(allow_existing=True) + def locked_get(self): + """Retrieve Credential from datastore. + + Returns: + oauth2client.Credentials + """ + credentials = None + if self._cache: + json = self._cache.get(self._key_name) + if json: + credentials = client.Credentials.new_from_json(json) + if credentials is None: + entity = self._get_entity() + if entity is not None: + credentials = getattr(entity, self._property_name) + if self._cache: + self._cache.set(self._key_name, credentials.to_json()) + + if credentials and hasattr(credentials, 'set_store'): + credentials.set_store(self) + return credentials + + @db.non_transactional(allow_existing=True) + def locked_put(self, credentials): + """Write a Credentials to the datastore. + + Args: + credentials: Credentials, the credentials to store. + """ + entity = self._model.get_or_insert(self._key_name) + setattr(entity, self._property_name, credentials) + entity.put() + if self._cache: + self._cache.set(self._key_name, credentials.to_json()) + + @db.non_transactional(allow_existing=True) + def locked_delete(self): + """Delete Credential from datastore.""" + + if self._cache: + self._cache.delete(self._key_name) + + self._delete_entity() + + +class CredentialsModel(db.Model): + """Storage for OAuth 2.0 Credentials + + Storage of the model is keyed by the user.user_id(). + """ + credentials = CredentialsProperty() + + +def _build_state_value(request_handler, user): + """Composes the value for the 'state' parameter. + + Packs the current request URI and an XSRF token into an opaque string that + can be passed to the authentication server via the 'state' parameter. + + Args: + request_handler: webapp.RequestHandler, The request. + user: google.appengine.api.users.User, The current user. + + Returns: + The state value as a string. + """ + uri = request_handler.request.url + token = xsrfutil.generate_token(xsrf_secret_key(), user.user_id(), + action_id=str(uri)) + return uri + ':' + token + + +def _parse_state_value(state, user): + """Parse the value of the 'state' parameter. + + Parses the value and validates the XSRF token in the state parameter. + + Args: + state: string, The value of the state parameter. + user: google.appengine.api.users.User, The current user. + + Returns: + The redirect URI, or None if XSRF token is not valid. + """ + uri, token = state.rsplit(':', 1) + if xsrfutil.validate_token(xsrf_secret_key(), token, user.user_id(), + action_id=uri): + return uri + else: + return None + + +class OAuth2Decorator(object): + """Utility for making OAuth 2.0 easier. + + Instantiate and then use with oauth_required or oauth_aware + as decorators on webapp.RequestHandler methods. + + :: + + decorator = OAuth2Decorator( + client_id='837...ent.com', + client_secret='Qh...wwI', + scope='https://www.googleapis.com/auth/plus') + + class MainHandler(webapp.RequestHandler): + @decorator.oauth_required + def get(self): + http = decorator.http() + # http is authorized with the user's Credentials and can be + # used in API calls + + """ + + def set_credentials(self, credentials): + self._tls.credentials = credentials + + def get_credentials(self): + """A thread local Credentials object. + + Returns: + A client.Credentials object, or None if credentials hasn't been set + in this thread yet, which may happen when calling has_credentials + inside oauth_aware. + """ + return getattr(self._tls, 'credentials', None) + + credentials = property(get_credentials, set_credentials) + + def set_flow(self, flow): + self._tls.flow = flow + + def get_flow(self): + """A thread local Flow object. + + Returns: + A credentials.Flow object, or None if the flow hasn't been set in + this thread yet, which happens in _create_flow() since Flows are + created lazily. + """ + return getattr(self._tls, 'flow', None) + + flow = property(get_flow, set_flow) + + @_helpers.positional(4) + def __init__(self, client_id, client_secret, scope, + auth_uri=oauth2client.GOOGLE_AUTH_URI, + token_uri=oauth2client.GOOGLE_TOKEN_URI, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI, + user_agent=None, + message=None, + callback_path='/oauth2callback', + token_response_param=None, + _storage_class=StorageByKeyName, + _credentials_class=CredentialsModel, + _credentials_property_name='credentials', + **kwargs): + """Constructor for OAuth2Decorator + + Args: + client_id: string, client identifier. + client_secret: string client secret. + scope: string or iterable of strings, scope(s) of the credentials + being requested. + auth_uri: string, URI for authorization endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 provider + can be used. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. + user_agent: string, User agent of your application, default to + None. + message: Message to display if there are problems with the + OAuth 2.0 configuration. The message may contain HTML and + will be presented on the web interface for any method that + uses the decorator. + callback_path: string, The absolute path to use as the callback + URI. Note that this must match up with the URI given + when registering the application in the APIs + Console. + token_response_param: string. If provided, the full JSON response + to the access token request will be encoded + and included in this query parameter in the + callback URI. This is useful with providers + (e.g. wordpress.com) that include extra + fields that the client may want. + _storage_class: "Protected" keyword argument not typically provided + to this constructor. A storage class to aid in + storing a Credentials object for a user in the + datastore. Defaults to StorageByKeyName. + _credentials_class: "Protected" keyword argument not typically + provided to this constructor. A db or ndb Model + class to hold credentials. Defaults to + CredentialsModel. + _credentials_property_name: "Protected" keyword argument not + typically provided to this constructor. + A string indicating the name of the + field on the _credentials_class where a + Credentials object will be stored. + Defaults to 'credentials'. + **kwargs: dict, Keyword arguments are passed along as kwargs to + the OAuth2WebServerFlow constructor. + """ + self._tls = threading.local() + self.flow = None + self.credentials = None + self._client_id = client_id + self._client_secret = client_secret + self._scope = _helpers.scopes_to_string(scope) + self._auth_uri = auth_uri + self._token_uri = token_uri + self._revoke_uri = revoke_uri + self._user_agent = user_agent + self._kwargs = kwargs + self._message = message + self._in_error = False + self._callback_path = callback_path + self._token_response_param = token_response_param + self._storage_class = _storage_class + self._credentials_class = _credentials_class + self._credentials_property_name = _credentials_property_name + + def _display_error_message(self, request_handler): + request_handler.response.out.write('') + request_handler.response.out.write(_safe_html(self._message)) + request_handler.response.out.write('') + + def oauth_required(self, method): + """Decorator that starts the OAuth 2.0 dance. + + Starts the OAuth dance for the logged in user if they haven't already + granted access for this application. + + Args: + method: callable, to be decorated method of a webapp.RequestHandler + instance. + """ + + def check_oauth(request_handler, *args, **kwargs): + if self._in_error: + self._display_error_message(request_handler) + return + + user = users.get_current_user() + # Don't use @login_decorator as this could be used in a + # POST request. + if not user: + request_handler.redirect(users.create_login_url( + request_handler.request.uri)) + return + + self._create_flow(request_handler) + + # Store the request URI in 'state' so we can use it later + self.flow.params['state'] = _build_state_value( + request_handler, user) + self.credentials = self._storage_class( + self._credentials_class, None, + self._credentials_property_name, user=user).get() + + if not self.has_credentials(): + return request_handler.redirect(self.authorize_url()) + try: + resp = method(request_handler, *args, **kwargs) + except client.AccessTokenRefreshError: + return request_handler.redirect(self.authorize_url()) + finally: + self.credentials = None + return resp + + return check_oauth + + def _create_flow(self, request_handler): + """Create the Flow object. + + The Flow is calculated lazily since we don't know where this app is + running until it receives a request, at which point redirect_uri can be + calculated and then the Flow object can be constructed. + + Args: + request_handler: webapp.RequestHandler, the request handler. + """ + if self.flow is None: + redirect_uri = request_handler.request.relative_url( + self._callback_path) # Usually /oauth2callback + self.flow = client.OAuth2WebServerFlow( + self._client_id, self._client_secret, self._scope, + redirect_uri=redirect_uri, user_agent=self._user_agent, + auth_uri=self._auth_uri, token_uri=self._token_uri, + revoke_uri=self._revoke_uri, **self._kwargs) + + def oauth_aware(self, method): + """Decorator that sets up for OAuth 2.0 dance, but doesn't do it. + + Does all the setup for the OAuth dance, but doesn't initiate it. + This decorator is useful if you want to create a page that knows + whether or not the user has granted access to this application. + From within a method decorated with @oauth_aware the has_credentials() + and authorize_url() methods can be called. + + Args: + method: callable, to be decorated method of a webapp.RequestHandler + instance. + """ + + def setup_oauth(request_handler, *args, **kwargs): + if self._in_error: + self._display_error_message(request_handler) + return + + user = users.get_current_user() + # Don't use @login_decorator as this could be used in a + # POST request. + if not user: + request_handler.redirect(users.create_login_url( + request_handler.request.uri)) + return + + self._create_flow(request_handler) + + self.flow.params['state'] = _build_state_value(request_handler, + user) + self.credentials = self._storage_class( + self._credentials_class, None, + self._credentials_property_name, user=user).get() + try: + resp = method(request_handler, *args, **kwargs) + finally: + self.credentials = None + return resp + return setup_oauth + + def has_credentials(self): + """True if for the logged in user there are valid access Credentials. + + Must only be called from with a webapp.RequestHandler subclassed method + that had been decorated with either @oauth_required or @oauth_aware. + """ + return self.credentials is not None and not self.credentials.invalid + + def authorize_url(self): + """Returns the URL to start the OAuth dance. + + Must only be called from with a webapp.RequestHandler subclassed method + that had been decorated with either @oauth_required or @oauth_aware. + """ + url = self.flow.step1_get_authorize_url() + return str(url) + + def http(self, *args, **kwargs): + """Returns an authorized http instance. + + Must only be called from within an @oauth_required decorated method, or + from within an @oauth_aware decorated method where has_credentials() + returns True. + + Args: + *args: Positional arguments passed to httplib2.Http constructor. + **kwargs: Positional arguments passed to httplib2.Http constructor. + """ + return self.credentials.authorize( + transport.get_http_object(*args, **kwargs)) + + @property + def callback_path(self): + """The absolute path where the callback will occur. + + Note this is the absolute path, not the absolute URI, that will be + calculated by the decorator at runtime. See callback_handler() for how + this should be used. + + Returns: + The callback path as a string. + """ + return self._callback_path + + def callback_handler(self): + """RequestHandler for the OAuth 2.0 redirect callback. + + Usage:: + + app = webapp.WSGIApplication([ + ('/index', MyIndexHandler), + ..., + (decorator.callback_path, decorator.callback_handler()) + ]) + + Returns: + A webapp.RequestHandler that handles the redirect back from the + server during the OAuth 2.0 dance. + """ + decorator = self + + class OAuth2Handler(webapp.RequestHandler): + """Handler for the redirect_uri of the OAuth 2.0 dance.""" + + @login_required + def get(self): + error = self.request.get('error') + if error: + errormsg = self.request.get('error_description', error) + self.response.out.write( + 'The authorization request failed: {0}'.format( + _safe_html(errormsg))) + else: + user = users.get_current_user() + decorator._create_flow(self) + credentials = decorator.flow.step2_exchange( + self.request.params) + decorator._storage_class( + decorator._credentials_class, None, + decorator._credentials_property_name, + user=user).put(credentials) + redirect_uri = _parse_state_value( + str(self.request.get('state')), user) + if redirect_uri is None: + self.response.out.write( + 'The authorization request failed') + return + + if (decorator._token_response_param and + credentials.token_response): + resp_json = json.dumps(credentials.token_response) + redirect_uri = _helpers._add_query_parameter( + redirect_uri, decorator._token_response_param, + resp_json) + + self.redirect(redirect_uri) + + return OAuth2Handler + + def callback_application(self): + """WSGI application for handling the OAuth 2.0 redirect callback. + + If you need finer grained control use `callback_handler` which returns + just the webapp.RequestHandler. + + Returns: + A webapp.WSGIApplication that handles the redirect back from the + server during the OAuth 2.0 dance. + """ + return webapp.WSGIApplication([ + (self.callback_path, self.callback_handler()) + ]) + + +class OAuth2DecoratorFromClientSecrets(OAuth2Decorator): + """An OAuth2Decorator that builds from a clientsecrets file. + + Uses a clientsecrets file as the source for all the information when + constructing an OAuth2Decorator. + + :: + + decorator = OAuth2DecoratorFromClientSecrets( + os.path.join(os.path.dirname(__file__), 'client_secrets.json') + scope='https://www.googleapis.com/auth/plus') + + class MainHandler(webapp.RequestHandler): + @decorator.oauth_required + def get(self): + http = decorator.http() + # http is authorized with the user's Credentials and can be + # used in API calls + + """ + + @_helpers.positional(3) + def __init__(self, filename, scope, message=None, cache=None, **kwargs): + """Constructor + + Args: + filename: string, File name of client secrets. + scope: string or iterable of strings, scope(s) of the credentials + being requested. + message: string, A friendly string to display to the user if the + clientsecrets file is missing or invalid. The message may + contain HTML and will be presented on the web interface + for any method that uses the decorator. + cache: An optional cache service client that implements get() and + set() + methods. See clientsecrets.loadfile() for details. + **kwargs: dict, Keyword arguments are passed along as kwargs to + the OAuth2WebServerFlow constructor. + """ + client_type, client_info = clientsecrets.loadfile(filename, + cache=cache) + if client_type not in (clientsecrets.TYPE_WEB, + clientsecrets.TYPE_INSTALLED): + raise clientsecrets.InvalidClientSecretsError( + "OAuth2Decorator doesn't support this OAuth 2.0 flow.") + + constructor_kwargs = dict(kwargs) + constructor_kwargs.update({ + 'auth_uri': client_info['auth_uri'], + 'token_uri': client_info['token_uri'], + 'message': message, + }) + revoke_uri = client_info.get('revoke_uri') + if revoke_uri is not None: + constructor_kwargs['revoke_uri'] = revoke_uri + super(OAuth2DecoratorFromClientSecrets, self).__init__( + client_info['client_id'], client_info['client_secret'], + scope, **constructor_kwargs) + if message is not None: + self._message = message + else: + self._message = 'Please configure your application for OAuth 2.0.' + + +@_helpers.positional(2) +def oauth2decorator_from_clientsecrets(filename, scope, + message=None, cache=None): + """Creates an OAuth2Decorator populated from a clientsecrets file. + + Args: + filename: string, File name of client secrets. + scope: string or list of strings, scope(s) of the credentials being + requested. + message: string, A friendly string to display to the user if the + clientsecrets file is missing or invalid. The message may + contain HTML and will be presented on the web interface for + any method that uses the decorator. + cache: An optional cache service client that implements get() and set() + methods. See clientsecrets.loadfile() for details. + + Returns: An OAuth2Decorator + """ + return OAuth2DecoratorFromClientSecrets(filename, scope, + message=message, cache=cache) diff --git a/src/oauth2client/oauth2client/contrib/devshell.py b/src/oauth2client/oauth2client/contrib/devshell.py new file mode 100644 index 00000000..691765f0 --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/devshell.py @@ -0,0 +1,152 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 utitilies for Google Developer Shell environment.""" + +import datetime +import json +import os +import socket + +from oauth2client import _helpers +from oauth2client import client + +DEVSHELL_ENV = 'DEVSHELL_CLIENT_PORT' + + +class Error(Exception): + """Errors for this module.""" + pass + + +class CommunicationError(Error): + """Errors for communication with the Developer Shell server.""" + + +class NoDevshellServer(Error): + """Error when no Developer Shell server can be contacted.""" + + +# The request for credential information to the Developer Shell client socket +# is always an empty PBLite-formatted JSON object, so just define it as a +# constant. +CREDENTIAL_INFO_REQUEST_JSON = '[]' + + +class CredentialInfoResponse(object): + """Credential information response from Developer Shell server. + + The credential information response from Developer Shell socket is a + PBLite-formatted JSON array with fields encoded by their index in the + array: + + * Index 0 - user email + * Index 1 - default project ID. None if the project context is not known. + * Index 2 - OAuth2 access token. None if there is no valid auth context. + * Index 3 - Seconds until the access token expires. None if not present. + """ + + def __init__(self, json_string): + """Initialize the response data from JSON PBLite array.""" + pbl = json.loads(json_string) + if not isinstance(pbl, list): + raise ValueError('Not a list: ' + str(pbl)) + pbl_len = len(pbl) + self.user_email = pbl[0] if pbl_len > 0 else None + self.project_id = pbl[1] if pbl_len > 1 else None + self.access_token = pbl[2] if pbl_len > 2 else None + self.expires_in = pbl[3] if pbl_len > 3 else None + + +def _SendRecv(): + """Communicate with the Developer Shell server socket.""" + + port = int(os.getenv(DEVSHELL_ENV, 0)) + if port == 0: + raise NoDevshellServer() + + sock = socket.socket() + sock.connect(('localhost', port)) + + data = CREDENTIAL_INFO_REQUEST_JSON + msg = '{0}\n{1}'.format(len(data), data) + sock.sendall(_helpers._to_bytes(msg, encoding='utf-8')) + + header = sock.recv(6).decode() + if '\n' not in header: + raise CommunicationError('saw no newline in the first 6 bytes') + len_str, json_str = header.split('\n', 1) + to_read = int(len_str) - len(json_str) + if to_read > 0: + json_str += sock.recv(to_read, socket.MSG_WAITALL).decode() + + return CredentialInfoResponse(json_str) + + +class DevshellCredentials(client.GoogleCredentials): + """Credentials object for Google Developer Shell environment. + + This object will allow a Google Developer Shell session to identify its + user to Google and other OAuth 2.0 servers that can verify assertions. It + can be used for the purpose of accessing data stored under the user + account. + + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. + """ + + def __init__(self, user_agent=None): + super(DevshellCredentials, self).__init__( + None, # access_token, initialized below + None, # client_id + None, # client_secret + None, # refresh_token + None, # token_expiry + None, # token_uri + user_agent) + self._refresh(None) + + def _refresh(self, http): + """Refreshes the access token. + + Args: + http: unused HTTP object + """ + self.devshell_response = _SendRecv() + self.access_token = self.devshell_response.access_token + expires_in = self.devshell_response.expires_in + if expires_in is not None: + delta = datetime.timedelta(seconds=expires_in) + self.token_expiry = client._UTCNOW() + delta + else: + self.token_expiry = None + + @property + def user_email(self): + return self.devshell_response.user_email + + @property + def project_id(self): + return self.devshell_response.project_id + + @classmethod + def from_json(cls, json_data): + raise NotImplementedError( + 'Cannot load Developer Shell credentials from JSON.') + + @property + def serialization_data(self): + raise NotImplementedError( + 'Cannot serialize Developer Shell credentials.') diff --git a/src/oauth2client/oauth2client/contrib/dictionary_storage.py b/src/oauth2client/oauth2client/contrib/dictionary_storage.py new file mode 100644 index 00000000..6ee333fa --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/dictionary_storage.py @@ -0,0 +1,65 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dictionary storage for OAuth2 Credentials.""" + +from oauth2client import client + + +class DictionaryStorage(client.Storage): + """Store and retrieve credentials to and from a dictionary-like object. + + Args: + dictionary: A dictionary or dictionary-like object. + key: A string or other hashable. The credentials will be stored in + ``dictionary[key]``. + lock: An optional threading.Lock-like object. The lock will be + acquired before anything is written or read from the + dictionary. + """ + + def __init__(self, dictionary, key, lock=None): + """Construct a DictionaryStorage instance.""" + super(DictionaryStorage, self).__init__(lock=lock) + self._dictionary = dictionary + self._key = key + + def locked_get(self): + """Retrieve the credentials from the dictionary, if they exist. + + Returns: A :class:`oauth2client.client.OAuth2Credentials` instance. + """ + serialized = self._dictionary.get(self._key) + + if serialized is None: + return None + + credentials = client.OAuth2Credentials.from_json(serialized) + credentials.set_store(self) + + return credentials + + def locked_put(self, credentials): + """Save the credentials to the dictionary. + + Args: + credentials: A :class:`oauth2client.client.OAuth2Credentials` + instance. + """ + serialized = credentials.to_json() + self._dictionary[self._key] = serialized + + def locked_delete(self): + """Remove the credentials from the dictionary, if they exist.""" + self._dictionary.pop(self._key, None) diff --git a/src/oauth2client/oauth2client/contrib/django_util/__init__.py b/src/oauth2client/oauth2client/contrib/django_util/__init__.py new file mode 100644 index 00000000..644a8f9f --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/django_util/__init__.py @@ -0,0 +1,489 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for the Django web framework. + +Provides Django views and helpers the make using the OAuth2 web server +flow easier. It includes an ``oauth_required`` decorator to automatically +ensure that user credentials are available, and an ``oauth_enabled`` decorator +to check if the user has authorized, and helper shortcuts to create the +authorization URL otherwise. + +There are two basic use cases supported. The first is using Google OAuth as the +primary form of authentication, which is the simpler approach recommended +for applications without their own user system. + +The second use case is adding Google OAuth credentials to an +existing Django model containing a Django user field. Most of the +configuration is the same, except for `GOOGLE_OAUTH_MODEL_STORAGE` in +settings.py. See "Adding Credentials To An Existing Django User System" for +usage differences. + +Only Django versions 1.8+ are supported. + +Configuration +=============== + +To configure, you'll need a set of OAuth2 web application credentials from +`Google Developer's Console `. + +Add the helper to your INSTALLED_APPS: + +.. code-block:: python + :caption: settings.py + :name: installed_apps + + INSTALLED_APPS = ( + # other apps + "django.contrib.sessions.middleware" + "oauth2client.contrib.django_util" + ) + +This helper also requires the Django Session Middleware, so +``django.contrib.sessions.middleware`` should be in INSTALLED_APPS as well. +MIDDLEWARE or MIDDLEWARE_CLASSES (in Django versions <1.10) should also +contain the string 'django.contrib.sessions.middleware.SessionMiddleware'. + + +Add the client secrets created earlier to the settings. You can either +specify the path to the credentials file in JSON format + +.. code-block:: python + :caption: settings.py + :name: secrets_file + + GOOGLE_OAUTH2_CLIENT_SECRETS_JSON=/path/to/client-secret.json + +Or, directly configure the client Id and client secret. + + +.. code-block:: python + :caption: settings.py + :name: secrets_config + + GOOGLE_OAUTH2_CLIENT_ID=client-id-field + GOOGLE_OAUTH2_CLIENT_SECRET=client-secret-field + +By default, the default scopes for the required decorator only contains the +``email`` scopes. You can change that default in the settings. + +.. code-block:: python + :caption: settings.py + :name: scopes + + GOOGLE_OAUTH2_SCOPES = ('email', 'https://www.googleapis.com/auth/calendar',) + +By default, the decorators will add an `oauth` object to the Django request +object, and include all of its state and helpers inside that object. If the +`oauth` name conflicts with another usage, it can be changed + +.. code-block:: python + :caption: settings.py + :name: request_prefix + + # changes request.oauth to request.google_oauth + GOOGLE_OAUTH2_REQUEST_ATTRIBUTE = 'google_oauth' + +Add the oauth2 routes to your application's urls.py urlpatterns. + +.. code-block:: python + :caption: urls.py + :name: urls + + from oauth2client.contrib.django_util.site import urls as oauth2_urls + + urlpatterns += [url(r'^oauth2/', include(oauth2_urls))] + +To require OAuth2 credentials for a view, use the `oauth2_required` decorator. +This creates a credentials object with an id_token, and allows you to create +an `http` object to build service clients with. These are all attached to the +request.oauth + +.. code-block:: python + :caption: views.py + :name: views_required + + from oauth2client.contrib.django_util.decorators import oauth_required + + @oauth_required + def requires_default_scopes(request): + email = request.oauth.credentials.id_token['email'] + service = build(serviceName='calendar', version='v3', + http=request.oauth.http, + developerKey=API_KEY) + events = service.events().list(calendarId='primary').execute()['items'] + return HttpResponse("email: {0} , calendar: {1}".format( + email,str(events))) + return HttpResponse( + "email: {0} , calendar: {1}".format(email, str(events))) + +To make OAuth2 optional and provide an authorization link in your own views. + +.. code-block:: python + :caption: views.py + :name: views_enabled2 + + from oauth2client.contrib.django_util.decorators import oauth_enabled + + @oauth_enabled + def optional_oauth2(request): + if request.oauth.has_credentials(): + # this could be passed into a view + # request.oauth.http is also initialized + return HttpResponse("User email: {0}".format( + request.oauth.credentials.id_token['email'])) + else: + return HttpResponse( + 'Here is an OAuth Authorize link: Authorize' + ''.format(request.oauth.get_authorize_redirect())) + +If a view needs a scope not included in the default scopes specified in +the settings, you can use [incremental auth](https://developers.google.com/identity/sign-in/web/incremental-auth) +and specify additional scopes in the decorator arguments. + +.. code-block:: python + :caption: views.py + :name: views_required_additional_scopes + + @oauth_enabled(scopes=['https://www.googleapis.com/auth/drive']) + def drive_required(request): + if request.oauth.has_credentials(): + service = build(serviceName='drive', version='v2', + http=request.oauth.http, + developerKey=API_KEY) + events = service.files().list().execute()['items'] + return HttpResponse(str(events)) + else: + return HttpResponse( + 'Here is an OAuth Authorize link: Authorize' + ''.format(request.oauth.get_authorize_redirect())) + + +To provide a callback on authorization being completed, use the +oauth2_authorized signal: + +.. code-block:: python + :caption: views.py + :name: signals + + from oauth2client.contrib.django_util.signals import oauth2_authorized + + def test_callback(sender, request, credentials, **kwargs): + print("Authorization Signal Received {0}".format( + credentials.id_token['email'])) + + oauth2_authorized.connect(test_callback) + +Adding Credentials To An Existing Django User System +===================================================== + +As an alternative to storing the credentials in the session, the helper +can be configured to store the fields on a Django model. This might be useful +if you need to use the credentials outside the context of a user request. It +also prevents the need for a logged in user to repeat the OAuth flow when +starting a new session. + +To use, change ``settings.py`` + +.. code-block:: python + :caption: settings.py + :name: storage_model_config + + GOOGLE_OAUTH2_STORAGE_MODEL = { + 'model': 'path.to.model.MyModel', + 'user_property': 'user_id', + 'credentials_property': 'credential' + } + +Where ``path.to.model`` class is the fully qualified name of a +``django.db.model`` class containing a ``django.contrib.auth.models.User`` +field with the name specified by `user_property` and a +:class:`oauth2client.contrib.django_util.models.CredentialsField` with the name +specified by `credentials_property`. For the sample configuration given, +our model would look like + +.. code-block:: python + :caption: models.py + :name: storage_model_model + + from django.contrib.auth.models import User + from oauth2client.contrib.django_util.models import CredentialsField + + class MyModel(models.Model): + # ... other fields here ... + user = models.OneToOneField(User) + credential = CredentialsField() +""" + +import importlib + +import django.conf +from django.core import exceptions +from django.core import urlresolvers +from six.moves.urllib import parse + +from oauth2client import clientsecrets +from oauth2client import transport +from oauth2client.contrib import dictionary_storage +from oauth2client.contrib.django_util import storage + +GOOGLE_OAUTH2_DEFAULT_SCOPES = ('email',) +GOOGLE_OAUTH2_REQUEST_ATTRIBUTE = 'oauth' + + +def _load_client_secrets(filename): + """Loads client secrets from the given filename. + + Args: + filename: The name of the file containing the JSON secret key. + + Returns: + A 2-tuple, the first item containing the client id, and the second + item containing a client secret. + """ + client_type, client_info = clientsecrets.loadfile(filename) + + if client_type != clientsecrets.TYPE_WEB: + raise ValueError( + 'The flow specified in {} is not supported, only the WEB flow ' + 'type is supported.'.format(client_type)) + return client_info['client_id'], client_info['client_secret'] + + +def _get_oauth2_client_id_and_secret(settings_instance): + """Initializes client id and client secret based on the settings. + + Args: + settings_instance: An instance of ``django.conf.settings``. + + Returns: + A 2-tuple, the first item is the client id and the second + item is the client secret. + """ + secret_json = getattr(settings_instance, + 'GOOGLE_OAUTH2_CLIENT_SECRETS_JSON', None) + if secret_json is not None: + return _load_client_secrets(secret_json) + else: + client_id = getattr(settings_instance, "GOOGLE_OAUTH2_CLIENT_ID", + None) + client_secret = getattr(settings_instance, + "GOOGLE_OAUTH2_CLIENT_SECRET", None) + if client_id is not None and client_secret is not None: + return client_id, client_secret + else: + raise exceptions.ImproperlyConfigured( + "Must specify either GOOGLE_OAUTH2_CLIENT_SECRETS_JSON, or " + "both GOOGLE_OAUTH2_CLIENT_ID and " + "GOOGLE_OAUTH2_CLIENT_SECRET in settings.py") + + +def _get_storage_model(): + """This configures whether the credentials will be stored in the session + or the Django ORM based on the settings. By default, the credentials + will be stored in the session, unless `GOOGLE_OAUTH2_STORAGE_MODEL` + is found in the settings. Usually, the ORM storage is used to integrate + credentials into an existing Django user system. + + Returns: + A tuple containing three strings, or None. If + ``GOOGLE_OAUTH2_STORAGE_MODEL`` is configured, the tuple + will contain the fully qualifed path of the `django.db.model`, + the name of the ``django.contrib.auth.models.User`` field on the + model, and the name of the + :class:`oauth2client.contrib.django_util.models.CredentialsField` + field on the model. If Django ORM storage is not configured, + this function returns None. + """ + storage_model_settings = getattr(django.conf.settings, + 'GOOGLE_OAUTH2_STORAGE_MODEL', None) + if storage_model_settings is not None: + return (storage_model_settings['model'], + storage_model_settings['user_property'], + storage_model_settings['credentials_property']) + else: + return None, None, None + + +class OAuth2Settings(object): + """Initializes Django OAuth2 Helper Settings + + This class loads the OAuth2 Settings from the Django settings, and then + provides those settings as attributes to the rest of the views and + decorators in the module. + + Attributes: + scopes: A list of OAuth2 scopes that the decorators and views will use + as defaults. + request_prefix: The name of the attribute that the decorators use to + attach the UserOAuth2 object to the Django request object. + client_id: The OAuth2 Client ID. + client_secret: The OAuth2 Client Secret. + """ + + def __init__(self, settings_instance): + self.scopes = getattr(settings_instance, 'GOOGLE_OAUTH2_SCOPES', + GOOGLE_OAUTH2_DEFAULT_SCOPES) + self.request_prefix = getattr(settings_instance, + 'GOOGLE_OAUTH2_REQUEST_ATTRIBUTE', + GOOGLE_OAUTH2_REQUEST_ATTRIBUTE) + info = _get_oauth2_client_id_and_secret(settings_instance) + self.client_id, self.client_secret = info + + # Django 1.10 deprecated MIDDLEWARE_CLASSES in favor of MIDDLEWARE + middleware_settings = getattr(settings_instance, 'MIDDLEWARE', None) + if middleware_settings is None: + middleware_settings = getattr( + settings_instance, 'MIDDLEWARE_CLASSES', None) + if middleware_settings is None: + raise exceptions.ImproperlyConfigured( + 'Django settings has neither MIDDLEWARE nor MIDDLEWARE_CLASSES' + 'configured') + + if ('django.contrib.sessions.middleware.SessionMiddleware' not in + middleware_settings): + raise exceptions.ImproperlyConfigured( + 'The Google OAuth2 Helper requires session middleware to ' + 'be installed. Edit your MIDDLEWARE_CLASSES or MIDDLEWARE ' + 'setting to include \'django.contrib.sessions.middleware.' + 'SessionMiddleware\'.') + (self.storage_model, self.storage_model_user_property, + self.storage_model_credentials_property) = _get_storage_model() + + +oauth2_settings = OAuth2Settings(django.conf.settings) + +_CREDENTIALS_KEY = 'google_oauth2_credentials' + + +def get_storage(request): + """ Gets a Credentials storage object provided by the Django OAuth2 Helper + object. + + Args: + request: Reference to the current request object. + + Returns: + An :class:`oauth2.client.Storage` object. + """ + storage_model = oauth2_settings.storage_model + user_property = oauth2_settings.storage_model_user_property + credentials_property = oauth2_settings.storage_model_credentials_property + + if storage_model: + module_name, class_name = storage_model.rsplit('.', 1) + module = importlib.import_module(module_name) + storage_model_class = getattr(module, class_name) + return storage.DjangoORMStorage(storage_model_class, + user_property, + request.user, + credentials_property) + else: + # use session + return dictionary_storage.DictionaryStorage( + request.session, key=_CREDENTIALS_KEY) + + +def _redirect_with_params(url_name, *args, **kwargs): + """Helper method to create a redirect response with URL params. + + This builds a redirect string that converts kwargs into a + query string. + + Args: + url_name: The name of the url to redirect to. + kwargs: the query string param and their values to build. + + Returns: + A properly formatted redirect string. + """ + url = urlresolvers.reverse(url_name, args=args) + params = parse.urlencode(kwargs, True) + return "{0}?{1}".format(url, params) + + +def _credentials_from_request(request): + """Gets the authorized credentials for this flow, if they exist.""" + # ORM storage requires a logged in user + if (oauth2_settings.storage_model is None or + request.user.is_authenticated()): + return get_storage(request).get() + else: + return None + + +class UserOAuth2(object): + """Class to create oauth2 objects on Django request objects containing + credentials and helper methods. + """ + + def __init__(self, request, scopes=None, return_url=None): + """Initialize the Oauth2 Object. + + Args: + request: Django request object. + scopes: Scopes desired for this OAuth2 flow. + return_url: The url to return to after the OAuth flow is complete, + defaults to the request's current URL path. + """ + self.request = request + self.return_url = return_url or request.get_full_path() + if scopes: + self._scopes = set(oauth2_settings.scopes) | set(scopes) + else: + self._scopes = set(oauth2_settings.scopes) + + def get_authorize_redirect(self): + """Creates a URl to start the OAuth2 authorization flow.""" + get_params = { + 'return_url': self.return_url, + 'scopes': self._get_scopes() + } + + return _redirect_with_params('google_oauth:authorize', **get_params) + + def has_credentials(self): + """Returns True if there are valid credentials for the current user + and required scopes.""" + credentials = _credentials_from_request(self.request) + return (credentials and not credentials.invalid and + credentials.has_scopes(self._get_scopes())) + + def _get_scopes(self): + """Returns the scopes associated with this object, kept up to + date for incremental auth.""" + if _credentials_from_request(self.request): + return (self._scopes | + _credentials_from_request(self.request).scopes) + else: + return self._scopes + + @property + def scopes(self): + """Returns the scopes associated with this OAuth2 object.""" + # make sure previously requested custom scopes are maintained + # in future authorizations + return self._get_scopes() + + @property + def credentials(self): + """Gets the authorized credentials for this flow, if they exist.""" + return _credentials_from_request(self.request) + + @property + def http(self): + """Helper: create HTTP client authorized with OAuth2 credentials.""" + if self.has_credentials(): + return self.credentials.authorize(transport.get_http_object()) + return None diff --git a/src/oauth2client/oauth2client/contrib/django_util/apps.py b/src/oauth2client/oauth2client/contrib/django_util/apps.py new file mode 100644 index 00000000..86676b91 --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/django_util/apps.py @@ -0,0 +1,32 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Application Config For Django OAuth2 Helper. + +Django 1.7+ provides an +[applications](https://docs.djangoproject.com/en/1.8/ref/applications/) +API so that Django projects can introspect on installed applications using a +stable API. This module exists to follow that convention. +""" + +import sys + +# Django 1.7+ only supports Python 2.7+ +if sys.hexversion >= 0x02070000: # pragma: NO COVER + from django.apps import AppConfig + + class GoogleOAuth2HelperConfig(AppConfig): + """ App Config for Django Helper""" + name = 'oauth2client.django_util' + verbose_name = "Google OAuth2 Django Helper" diff --git a/src/oauth2client/oauth2client/contrib/django_util/decorators.py b/src/oauth2client/oauth2client/contrib/django_util/decorators.py new file mode 100644 index 00000000..e62e1710 --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/django_util/decorators.py @@ -0,0 +1,145 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Decorators for Django OAuth2 Flow. + +Contains two decorators, ``oauth_required`` and ``oauth_enabled``. + +``oauth_required`` will ensure that a user has an oauth object containing +credentials associated with the request, and if not, redirect to the +authorization flow. + +``oauth_enabled`` will attach the oauth2 object containing credentials if it +exists. If it doesn't, the view will still render, but helper methods will be +attached to start the oauth2 flow. +""" + +from django import shortcuts +import django.conf +from six import wraps +from six.moves.urllib import parse + +from oauth2client.contrib import django_util + + +def oauth_required(decorated_function=None, scopes=None, **decorator_kwargs): + """ Decorator to require OAuth2 credentials for a view. + + + .. code-block:: python + :caption: views.py + :name: views_required_2 + + + from oauth2client.django_util.decorators import oauth_required + + @oauth_required + def requires_default_scopes(request): + email = request.credentials.id_token['email'] + service = build(serviceName='calendar', version='v3', + http=request.oauth.http, + developerKey=API_KEY) + events = service.events().list( + calendarId='primary').execute()['items'] + return HttpResponse( + "email: {0}, calendar: {1}".format(email, str(events))) + + Args: + decorated_function: View function to decorate, must have the Django + request object as the first argument. + scopes: Scopes to require, will default. + decorator_kwargs: Can include ``return_url`` to specify the URL to + return to after OAuth2 authorization is complete. + + Returns: + An OAuth2 Authorize view if credentials are not found or if the + credentials are missing the required scopes. Otherwise, + the decorated view. + """ + def curry_wrapper(wrapped_function): + @wraps(wrapped_function) + def required_wrapper(request, *args, **kwargs): + if not (django_util.oauth2_settings.storage_model is None or + request.user.is_authenticated()): + redirect_str = '{0}?next={1}'.format( + django.conf.settings.LOGIN_URL, + parse.quote(request.path)) + return shortcuts.redirect(redirect_str) + + return_url = decorator_kwargs.pop('return_url', + request.get_full_path()) + user_oauth = django_util.UserOAuth2(request, scopes, return_url) + if not user_oauth.has_credentials(): + return shortcuts.redirect(user_oauth.get_authorize_redirect()) + setattr(request, django_util.oauth2_settings.request_prefix, + user_oauth) + return wrapped_function(request, *args, **kwargs) + + return required_wrapper + + if decorated_function: + return curry_wrapper(decorated_function) + else: + return curry_wrapper + + +def oauth_enabled(decorated_function=None, scopes=None, **decorator_kwargs): + """ Decorator to enable OAuth Credentials if authorized, and setup + the oauth object on the request object to provide helper functions + to start the flow otherwise. + + .. code-block:: python + :caption: views.py + :name: views_enabled3 + + from oauth2client.django_util.decorators import oauth_enabled + + @oauth_enabled + def optional_oauth2(request): + if request.oauth.has_credentials(): + # this could be passed into a view + # request.oauth.http is also initialized + return HttpResponse("User email: {0}".format( + request.oauth.credentials.id_token['email']) + else: + return HttpResponse('Here is an OAuth Authorize link: + Authorize'.format( + request.oauth.get_authorize_redirect())) + + + Args: + decorated_function: View function to decorate. + scopes: Scopes to require, will default. + decorator_kwargs: Can include ``return_url`` to specify the URL to + return to after OAuth2 authorization is complete. + + Returns: + The decorated view function. + """ + def curry_wrapper(wrapped_function): + @wraps(wrapped_function) + def enabled_wrapper(request, *args, **kwargs): + return_url = decorator_kwargs.pop('return_url', + request.get_full_path()) + user_oauth = django_util.UserOAuth2(request, scopes, return_url) + setattr(request, django_util.oauth2_settings.request_prefix, + user_oauth) + return wrapped_function(request, *args, **kwargs) + + return enabled_wrapper + + if decorated_function: + return curry_wrapper(decorated_function) + else: + return curry_wrapper diff --git a/src/oauth2client/oauth2client/contrib/django_util/models.py b/src/oauth2client/oauth2client/contrib/django_util/models.py new file mode 100644 index 00000000..37cc6970 --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/django_util/models.py @@ -0,0 +1,82 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains classes used for the Django ORM storage.""" + +import base64 +import pickle + +from django.db import models +from django.utils import encoding +import jsonpickle + +import oauth2client + + +class CredentialsField(models.Field): + """Django ORM field for storing OAuth2 Credentials.""" + + def __init__(self, *args, **kwargs): + if 'null' not in kwargs: + kwargs['null'] = True + super(CredentialsField, self).__init__(*args, **kwargs) + + def get_internal_type(self): + return 'BinaryField' + + def from_db_value(self, value, expression, connection, context): + """Overrides ``models.Field`` method. This converts the value + returned from the database to an instance of this class. + """ + return self.to_python(value) + + def to_python(self, value): + """Overrides ``models.Field`` method. This is used to convert + bytes (from serialization etc) to an instance of this class""" + if value is None: + return None + elif isinstance(value, oauth2client.client.Credentials): + return value + else: + try: + return jsonpickle.decode( + base64.b64decode(encoding.smart_bytes(value)).decode()) + except ValueError: + return pickle.loads( + base64.b64decode(encoding.smart_bytes(value))) + + def get_prep_value(self, value): + """Overrides ``models.Field`` method. This is used to convert + the value from an instances of this class to bytes that can be + inserted into the database. + """ + if value is None: + return None + else: + return encoding.smart_text( + base64.b64encode(jsonpickle.encode(value).encode())) + + def value_to_string(self, obj): + """Convert the field value from the provided model to a string. + + Used during model serialization. + + Args: + obj: db.Model, model object + + Returns: + string, the serialized field value + """ + value = self._get_val_from_obj(obj) + return self.get_prep_value(value) diff --git a/src/oauth2client/oauth2client/contrib/django_util/signals.py b/src/oauth2client/oauth2client/contrib/django_util/signals.py new file mode 100644 index 00000000..e9356b4d --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/django_util/signals.py @@ -0,0 +1,28 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Signals for Google OAuth2 Helper. + +This module contains signals for Google OAuth2 Helper. Currently it only +contains one, which fires when an OAuth2 authorization flow has completed. +""" + +import django.dispatch + +"""Signal that fires when OAuth2 Flow has completed. +It passes the Django request object and the OAuth2 credentials object to the + receiver. +""" +oauth2_authorized = django.dispatch.Signal( + providing_args=["request", "credentials"]) diff --git a/src/oauth2client/oauth2client/contrib/django_util/site.py b/src/oauth2client/oauth2client/contrib/django_util/site.py new file mode 100644 index 00000000..631f79be --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/django_util/site.py @@ -0,0 +1,26 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains Django URL patterns used for OAuth2 flow.""" + +from django.conf import urls + +from oauth2client.contrib.django_util import views + +urlpatterns = [ + urls.url(r'oauth2callback/', views.oauth2_callback, name="callback"), + urls.url(r'oauth2authorize/', views.oauth2_authorize, name="authorize") +] + +urls = (urlpatterns, "google_oauth", "google_oauth") diff --git a/src/oauth2client/oauth2client/contrib/django_util/storage.py b/src/oauth2client/oauth2client/contrib/django_util/storage.py new file mode 100644 index 00000000..5682919b --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/django_util/storage.py @@ -0,0 +1,81 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains a storage module that stores credentials using the Django ORM.""" + +from oauth2client import client + + +class DjangoORMStorage(client.Storage): + """Store and retrieve a single credential to and from the Django datastore. + + This Storage helper presumes the Credentials + have been stored as a CredentialsField + on a db model class. + """ + + def __init__(self, model_class, key_name, key_value, property_name): + """Constructor for Storage. + + Args: + model: string, fully qualified name of db.Model model class. + key_name: string, key name for the entity that has the credentials + key_value: string, key value for the entity that has the + credentials. + property_name: string, name of the property that is an + CredentialsProperty. + """ + super(DjangoORMStorage, self).__init__() + self.model_class = model_class + self.key_name = key_name + self.key_value = key_value + self.property_name = property_name + + def locked_get(self): + """Retrieve stored credential from the Django ORM. + + Returns: + oauth2client.Credentials retrieved from the Django ORM, associated + with the ``model``, ``key_value``->``key_name`` pair used to query + for the model, and ``property_name`` identifying the + ``CredentialsProperty`` field, all of which are defined in the + constructor for this Storage object. + + """ + query = {self.key_name: self.key_value} + entities = self.model_class.objects.filter(**query) + if len(entities) > 0: + credential = getattr(entities[0], self.property_name) + if getattr(credential, 'set_store', None) is not None: + credential.set_store(self) + return credential + else: + return None + + def locked_put(self, credentials): + """Write a Credentials to the Django datastore. + + Args: + credentials: Credentials, the credentials to store. + """ + entity, _ = self.model_class.objects.get_or_create( + **{self.key_name: self.key_value}) + + setattr(entity, self.property_name, credentials) + entity.save() + + def locked_delete(self): + """Delete Credentials from the datastore.""" + query = {self.key_name: self.key_value} + self.model_class.objects.filter(**query).delete() diff --git a/src/oauth2client/oauth2client/contrib/django_util/views.py b/src/oauth2client/oauth2client/contrib/django_util/views.py new file mode 100644 index 00000000..1835208a --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/django_util/views.py @@ -0,0 +1,193 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the views used by the OAuth2 flows. + +Their are two views used by the OAuth2 flow, the authorize and the callback +view. The authorize view kicks off the three-legged OAuth flow, and the +callback view validates the flow and if successful stores the credentials +in the configured storage.""" + +import hashlib +import json +import os + +from django import http +from django import shortcuts +from django.conf import settings +from django.core import urlresolvers +from django.shortcuts import redirect +from django.utils import html +import jsonpickle +from six.moves.urllib import parse + +from oauth2client import client +from oauth2client.contrib import django_util +from oauth2client.contrib.django_util import get_storage +from oauth2client.contrib.django_util import signals + +_CSRF_KEY = 'google_oauth2_csrf_token' +_FLOW_KEY = 'google_oauth2_flow_{0}' + + +def _make_flow(request, scopes, return_url=None): + """Creates a Web Server Flow + + Args: + request: A Django request object. + scopes: the request oauth2 scopes. + return_url: The URL to return to after the flow is complete. Defaults + to the path of the current request. + + Returns: + An OAuth2 flow object that has been stored in the session. + """ + # Generate a CSRF token to prevent malicious requests. + csrf_token = hashlib.sha256(os.urandom(1024)).hexdigest() + + request.session[_CSRF_KEY] = csrf_token + + state = json.dumps({ + 'csrf_token': csrf_token, + 'return_url': return_url, + }) + + flow = client.OAuth2WebServerFlow( + client_id=django_util.oauth2_settings.client_id, + client_secret=django_util.oauth2_settings.client_secret, + scope=scopes, + state=state, + redirect_uri=request.build_absolute_uri( + urlresolvers.reverse("google_oauth:callback"))) + + flow_key = _FLOW_KEY.format(csrf_token) + request.session[flow_key] = jsonpickle.encode(flow) + return flow + + +def _get_flow_for_token(csrf_token, request): + """ Looks up the flow in session to recover information about requested + scopes. + + Args: + csrf_token: The token passed in the callback request that should + match the one previously generated and stored in the request on the + initial authorization view. + + Returns: + The OAuth2 Flow object associated with this flow based on the + CSRF token. + """ + flow_pickle = request.session.get(_FLOW_KEY.format(csrf_token), None) + return None if flow_pickle is None else jsonpickle.decode(flow_pickle) + + +def oauth2_callback(request): + """ View that handles the user's return from OAuth2 provider. + + This view verifies the CSRF state and OAuth authorization code, and on + success stores the credentials obtained in the storage provider, + and redirects to the return_url specified in the authorize view and + stored in the session. + + Args: + request: Django request. + + Returns: + A redirect response back to the return_url. + """ + if 'error' in request.GET: + reason = request.GET.get( + 'error_description', request.GET.get('error', '')) + reason = html.escape(reason) + return http.HttpResponseBadRequest( + 'Authorization failed {0}'.format(reason)) + + try: + encoded_state = request.GET['state'] + code = request.GET['code'] + except KeyError: + return http.HttpResponseBadRequest( + 'Request missing state or authorization code') + + try: + server_csrf = request.session[_CSRF_KEY] + except KeyError: + return http.HttpResponseBadRequest( + 'No existing session for this flow.') + + try: + state = json.loads(encoded_state) + client_csrf = state['csrf_token'] + return_url = state['return_url'] + except (ValueError, KeyError): + return http.HttpResponseBadRequest('Invalid state parameter.') + + if client_csrf != server_csrf: + return http.HttpResponseBadRequest('Invalid CSRF token.') + + flow = _get_flow_for_token(client_csrf, request) + + if not flow: + return http.HttpResponseBadRequest('Missing Oauth2 flow.') + + try: + credentials = flow.step2_exchange(code) + except client.FlowExchangeError as exchange_error: + return http.HttpResponseBadRequest( + 'An error has occurred: {0}'.format(exchange_error)) + + get_storage(request).put(credentials) + + signals.oauth2_authorized.send(sender=signals.oauth2_authorized, + request=request, credentials=credentials) + + return shortcuts.redirect(return_url) + + +def oauth2_authorize(request): + """ View to start the OAuth2 Authorization flow. + + This view starts the OAuth2 authorization flow. If scopes is passed in + as a GET URL parameter, it will authorize those scopes, otherwise the + default scopes specified in settings. The return_url can also be + specified as a GET parameter, otherwise the referer header will be + checked, and if that isn't found it will return to the root path. + + Args: + request: The Django request object. + + Returns: + A redirect to Google OAuth2 Authorization. + """ + return_url = request.GET.get('return_url', None) + if not return_url: + return_url = request.META.get('HTTP_REFERER', '/') + + scopes = request.GET.getlist('scopes', django_util.oauth2_settings.scopes) + # Model storage (but not session storage) requires a logged in user + if django_util.oauth2_settings.storage_model: + if not request.user.is_authenticated(): + return redirect('{0}?next={1}'.format( + settings.LOGIN_URL, parse.quote(request.get_full_path()))) + # This checks for the case where we ended up here because of a logged + # out user but we had credentials for it in the first place + else: + user_oauth = django_util.UserOAuth2(request, scopes, return_url) + if user_oauth.has_credentials(): + return redirect(return_url) + + flow = _make_flow(request=request, scopes=scopes, return_url=return_url) + auth_url = flow.step1_get_authorize_url() + return shortcuts.redirect(auth_url) diff --git a/src/oauth2client/oauth2client/contrib/flask_util.py b/src/oauth2client/oauth2client/contrib/flask_util.py new file mode 100644 index 00000000..fabd613b --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/flask_util.py @@ -0,0 +1,557 @@ +# Copyright 2015 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for the Flask web framework + +Provides a Flask extension that makes using OAuth2 web server flow easier. +The extension includes views that handle the entire auth flow and a +``@required`` decorator to automatically ensure that user credentials are +available. + + +Configuration +============= + +To configure, you'll need a set of OAuth2 web application credentials from the +`Google Developer's Console `__. + +.. code-block:: python + + from oauth2client.contrib.flask_util import UserOAuth2 + + app = Flask(__name__) + + app.config['SECRET_KEY'] = 'your-secret-key' + + app.config['GOOGLE_OAUTH2_CLIENT_SECRETS_FILE'] = 'client_secrets.json' + + # or, specify the client id and secret separately + app.config['GOOGLE_OAUTH2_CLIENT_ID'] = 'your-client-id' + app.config['GOOGLE_OAUTH2_CLIENT_SECRET'] = 'your-client-secret' + + oauth2 = UserOAuth2(app) + + +Usage +===== + +Once configured, you can use the :meth:`UserOAuth2.required` decorator to +ensure that credentials are available within a view. + +.. code-block:: python + :emphasize-lines: 3,7,10 + + # Note that app.route should be the outermost decorator. + @app.route('/needs_credentials') + @oauth2.required + def example(): + # http is authorized with the user's credentials and can be used + # to make http calls. + http = oauth2.http() + + # Or, you can access the credentials directly + credentials = oauth2.credentials + +If you want credentials to be optional for a view, you can leave the decorator +off and use :meth:`UserOAuth2.has_credentials` to check. + +.. code-block:: python + :emphasize-lines: 3 + + @app.route('/optional') + def optional(): + if oauth2.has_credentials(): + return 'Credentials found!' + else: + return 'No credentials!' + + +When credentials are available, you can use :attr:`UserOAuth2.email` and +:attr:`UserOAuth2.user_id` to access information from the `ID Token +`__, if +available. + +.. code-block:: python + :emphasize-lines: 4 + + @app.route('/info') + @oauth2.required + def info(): + return "Hello, {} ({})".format(oauth2.email, oauth2.user_id) + + +URLs & Trigging Authorization +============================= + +The extension will add two new routes to your application: + + * ``"oauth2.authorize"`` -> ``/oauth2authorize`` + * ``"oauth2.callback"`` -> ``/oauth2callback`` + +When configuring your OAuth2 credentials on the Google Developer's Console, be +sure to add ``http[s]://[your-app-url]/oauth2callback`` as an authorized +callback url. + +Typically you don't not need to use these routes directly, just be sure to +decorate any views that require credentials with ``@oauth2.required``. If +needed, you can trigger authorization at any time by redirecting the user +to the URL returned by :meth:`UserOAuth2.authorize_url`. + +.. code-block:: python + :emphasize-lines: 3 + + @app.route('/login') + def login(): + return oauth2.authorize_url("/") + + +Incremental Auth +================ + +This extension also supports `Incremental Auth `__. To enable it, +configure the extension with ``include_granted_scopes``. + +.. code-block:: python + + oauth2 = UserOAuth2(app, include_granted_scopes=True) + +Then specify any additional scopes needed on the decorator, for example: + +.. code-block:: python + :emphasize-lines: 2,7 + + @app.route('/drive') + @oauth2.required(scopes=["https://www.googleapis.com/auth/drive"]) + def requires_drive(): + ... + + @app.route('/calendar') + @oauth2.required(scopes=["https://www.googleapis.com/auth/calendar"]) + def requires_calendar(): + ... + +The decorator will ensure that the the user has authorized all specified scopes +before allowing them to access the view, and will also ensure that credentials +do not lose any previously authorized scopes. + + +Storage +======= + +By default, the extension uses a Flask session-based storage solution. This +means that credentials are only available for the duration of a session. It +also means that with Flask's default configuration, the credentials will be +visible in the session cookie. It's highly recommended to use database-backed +session and to use https whenever handling user credentials. + +If you need the credentials to be available longer than a user session or +available outside of a request context, you will need to implement your own +:class:`oauth2client.Storage`. +""" + +from functools import wraps +import hashlib +import json +import os +import pickle + +try: + from flask import Blueprint + from flask import _app_ctx_stack + from flask import current_app + from flask import redirect + from flask import request + from flask import session + from flask import url_for + import markupsafe +except ImportError: # pragma: NO COVER + raise ImportError('The flask utilities require flask 0.9 or newer.') + +import six.moves.http_client as httplib + +from oauth2client import client +from oauth2client import clientsecrets +from oauth2client import transport +from oauth2client.contrib import dictionary_storage + + +_DEFAULT_SCOPES = ('email',) +_CREDENTIALS_KEY = 'google_oauth2_credentials' +_FLOW_KEY = 'google_oauth2_flow_{0}' +_CSRF_KEY = 'google_oauth2_csrf_token' + + +def _get_flow_for_token(csrf_token): + """Retrieves the flow instance associated with a given CSRF token from + the Flask session.""" + flow_pickle = session.pop( + _FLOW_KEY.format(csrf_token), None) + + if flow_pickle is None: + return None + else: + return pickle.loads(flow_pickle) + + +class UserOAuth2(object): + """Flask extension for making OAuth 2.0 easier. + + Configuration values: + + * ``GOOGLE_OAUTH2_CLIENT_SECRETS_FILE`` path to a client secrets json + file, obtained from the credentials screen in the Google Developers + console. + * ``GOOGLE_OAUTH2_CLIENT_ID`` the oauth2 credentials' client ID. This + is only needed if ``GOOGLE_OAUTH2_CLIENT_SECRETS_FILE`` is not + specified. + * ``GOOGLE_OAUTH2_CLIENT_SECRET`` the oauth2 credentials' client + secret. This is only needed if ``GOOGLE_OAUTH2_CLIENT_SECRETS_FILE`` + is not specified. + + If app is specified, all arguments will be passed along to init_app. + + If no app is specified, then you should call init_app in your application + factory to finish initialization. + """ + + def __init__(self, app=None, *args, **kwargs): + self.app = app + if app is not None: + self.init_app(app, *args, **kwargs) + + def init_app(self, app, scopes=None, client_secrets_file=None, + client_id=None, client_secret=None, authorize_callback=None, + storage=None, **kwargs): + """Initialize this extension for the given app. + + Arguments: + app: A Flask application. + scopes: Optional list of scopes to authorize. + client_secrets_file: Path to a file containing client secrets. You + can also specify the GOOGLE_OAUTH2_CLIENT_SECRETS_FILE config + value. + client_id: If not specifying a client secrets file, specify the + OAuth2 client id. You can also specify the + GOOGLE_OAUTH2_CLIENT_ID config value. You must also provide a + client secret. + client_secret: The OAuth2 client secret. You can also specify the + GOOGLE_OAUTH2_CLIENT_SECRET config value. + authorize_callback: A function that is executed after successful + user authorization. + storage: A oauth2client.client.Storage subclass for storing the + credentials. By default, this is a Flask session based storage. + kwargs: Any additional args are passed along to the Flow + constructor. + """ + self.app = app + self.authorize_callback = authorize_callback + self.flow_kwargs = kwargs + + if storage is None: + storage = dictionary_storage.DictionaryStorage( + session, key=_CREDENTIALS_KEY) + self.storage = storage + + if scopes is None: + scopes = app.config.get('GOOGLE_OAUTH2_SCOPES', _DEFAULT_SCOPES) + self.scopes = scopes + + self._load_config(client_secrets_file, client_id, client_secret) + + app.register_blueprint(self._create_blueprint()) + + def _load_config(self, client_secrets_file, client_id, client_secret): + """Loads oauth2 configuration in order of priority. + + Priority: + 1. Config passed to the constructor or init_app. + 2. Config passed via the GOOGLE_OAUTH2_CLIENT_SECRETS_FILE app + config. + 3. Config passed via the GOOGLE_OAUTH2_CLIENT_ID and + GOOGLE_OAUTH2_CLIENT_SECRET app config. + + Raises: + ValueError if no config could be found. + """ + if client_id and client_secret: + self.client_id, self.client_secret = client_id, client_secret + return + + if client_secrets_file: + self._load_client_secrets(client_secrets_file) + return + + if 'GOOGLE_OAUTH2_CLIENT_SECRETS_FILE' in self.app.config: + self._load_client_secrets( + self.app.config['GOOGLE_OAUTH2_CLIENT_SECRETS_FILE']) + return + + try: + self.client_id, self.client_secret = ( + self.app.config['GOOGLE_OAUTH2_CLIENT_ID'], + self.app.config['GOOGLE_OAUTH2_CLIENT_SECRET']) + except KeyError: + raise ValueError( + 'OAuth2 configuration could not be found. Either specify the ' + 'client_secrets_file or client_id and client_secret or set ' + 'the app configuration variables ' + 'GOOGLE_OAUTH2_CLIENT_SECRETS_FILE or ' + 'GOOGLE_OAUTH2_CLIENT_ID and GOOGLE_OAUTH2_CLIENT_SECRET.') + + def _load_client_secrets(self, filename): + """Loads client secrets from the given filename.""" + client_type, client_info = clientsecrets.loadfile(filename) + if client_type != clientsecrets.TYPE_WEB: + raise ValueError( + 'The flow specified in {0} is not supported.'.format( + client_type)) + + self.client_id = client_info['client_id'] + self.client_secret = client_info['client_secret'] + + def _make_flow(self, return_url=None, **kwargs): + """Creates a Web Server Flow""" + # Generate a CSRF token to prevent malicious requests. + csrf_token = hashlib.sha256(os.urandom(1024)).hexdigest() + + session[_CSRF_KEY] = csrf_token + + state = json.dumps({ + 'csrf_token': csrf_token, + 'return_url': return_url + }) + + kw = self.flow_kwargs.copy() + kw.update(kwargs) + + extra_scopes = kw.pop('scopes', []) + scopes = set(self.scopes).union(set(extra_scopes)) + + flow = client.OAuth2WebServerFlow( + client_id=self.client_id, + client_secret=self.client_secret, + scope=scopes, + state=state, + redirect_uri=url_for('oauth2.callback', _external=True), + **kw) + + flow_key = _FLOW_KEY.format(csrf_token) + session[flow_key] = pickle.dumps(flow) + + return flow + + def _create_blueprint(self): + bp = Blueprint('oauth2', __name__) + bp.add_url_rule('/oauth2authorize', 'authorize', self.authorize_view) + bp.add_url_rule('/oauth2callback', 'callback', self.callback_view) + + return bp + + def authorize_view(self): + """Flask view that starts the authorization flow. + + Starts flow by redirecting the user to the OAuth2 provider. + """ + args = request.args.to_dict() + + # Scopes will be passed as mutliple args, and to_dict() will only + # return one. So, we use getlist() to get all of the scopes. + args['scopes'] = request.args.getlist('scopes') + + return_url = args.pop('return_url', None) + if return_url is None: + return_url = request.referrer or '/' + + flow = self._make_flow(return_url=return_url, **args) + auth_url = flow.step1_get_authorize_url() + + return redirect(auth_url) + + def callback_view(self): + """Flask view that handles the user's return from OAuth2 provider. + + On return, exchanges the authorization code for credentials and stores + the credentials. + """ + if 'error' in request.args: + reason = request.args.get( + 'error_description', request.args.get('error', '')) + reason = markupsafe.escape(reason) + return ('Authorization failed: {0}'.format(reason), + httplib.BAD_REQUEST) + + try: + encoded_state = request.args['state'] + server_csrf = session[_CSRF_KEY] + code = request.args['code'] + except KeyError: + return 'Invalid request', httplib.BAD_REQUEST + + try: + state = json.loads(encoded_state) + client_csrf = state['csrf_token'] + return_url = state['return_url'] + except (ValueError, KeyError): + return 'Invalid request state', httplib.BAD_REQUEST + + if client_csrf != server_csrf: + return 'Invalid request state', httplib.BAD_REQUEST + + flow = _get_flow_for_token(server_csrf) + + if flow is None: + return 'Invalid request state', httplib.BAD_REQUEST + + # Exchange the auth code for credentials. + try: + credentials = flow.step2_exchange(code) + except client.FlowExchangeError as exchange_error: + current_app.logger.exception(exchange_error) + content = 'An error occurred: {0}'.format(exchange_error) + return content, httplib.BAD_REQUEST + + # Save the credentials to the storage. + self.storage.put(credentials) + + if self.authorize_callback: + self.authorize_callback(credentials) + + return redirect(return_url) + + @property + def credentials(self): + """The credentials for the current user or None if unavailable.""" + ctx = _app_ctx_stack.top + + if not hasattr(ctx, _CREDENTIALS_KEY): + ctx.google_oauth2_credentials = self.storage.get() + + return ctx.google_oauth2_credentials + + def has_credentials(self): + """Returns True if there are valid credentials for the current user.""" + if not self.credentials: + return False + # Is the access token expired? If so, do we have an refresh token? + elif (self.credentials.access_token_expired and + not self.credentials.refresh_token): + return False + else: + return True + + @property + def email(self): + """Returns the user's email address or None if there are no credentials. + + The email address is provided by the current credentials' id_token. + This should not be used as unique identifier as the user can change + their email. If you need a unique identifier, use user_id. + """ + if not self.credentials: + return None + try: + return self.credentials.id_token['email'] + except KeyError: + current_app.logger.error( + 'Invalid id_token {0}'.format(self.credentials.id_token)) + + @property + def user_id(self): + """Returns the a unique identifier for the user + + Returns None if there are no credentials. + + The id is provided by the current credentials' id_token. + """ + if not self.credentials: + return None + try: + return self.credentials.id_token['sub'] + except KeyError: + current_app.logger.error( + 'Invalid id_token {0}'.format(self.credentials.id_token)) + + def authorize_url(self, return_url, **kwargs): + """Creates a URL that can be used to start the authorization flow. + + When the user is directed to the URL, the authorization flow will + begin. Once complete, the user will be redirected to the specified + return URL. + + Any kwargs are passed into the flow constructor. + """ + return url_for('oauth2.authorize', return_url=return_url, **kwargs) + + def required(self, decorated_function=None, scopes=None, + **decorator_kwargs): + """Decorator to require OAuth2 credentials for a view. + + If credentials are not available for the current user, then they will + be redirected to the authorization flow. Once complete, the user will + be redirected back to the original page. + """ + + def curry_wrapper(wrapped_function): + @wraps(wrapped_function) + def required_wrapper(*args, **kwargs): + return_url = decorator_kwargs.pop('return_url', request.url) + + requested_scopes = set(self.scopes) + if scopes is not None: + requested_scopes |= set(scopes) + if self.has_credentials(): + requested_scopes |= self.credentials.scopes + + requested_scopes = list(requested_scopes) + + # Does the user have credentials and does the credentials have + # all of the needed scopes? + if (self.has_credentials() and + self.credentials.has_scopes(requested_scopes)): + return wrapped_function(*args, **kwargs) + # Otherwise, redirect to authorization + else: + auth_url = self.authorize_url( + return_url, + scopes=requested_scopes, + **decorator_kwargs) + + return redirect(auth_url) + + return required_wrapper + + if decorated_function: + return curry_wrapper(decorated_function) + else: + return curry_wrapper + + def http(self, *args, **kwargs): + """Returns an authorized http instance. + + Can only be called if there are valid credentials for the user, such + as inside of a view that is decorated with @required. + + Args: + *args: Positional arguments passed to httplib2.Http constructor. + **kwargs: Positional arguments passed to httplib2.Http constructor. + + Raises: + ValueError if no credentials are available. + """ + if not self.credentials: + raise ValueError('No credentials available.') + return self.credentials.authorize( + transport.get_http_object(*args, **kwargs)) diff --git a/src/oauth2client/oauth2client/contrib/gce.py b/src/oauth2client/oauth2client/contrib/gce.py new file mode 100644 index 00000000..aaab15ff --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/gce.py @@ -0,0 +1,156 @@ +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for Google Compute Engine + +Utilities for making it easier to use OAuth 2.0 on Google Compute Engine. +""" + +import logging +import warnings + +from six.moves import http_client + +from oauth2client import client +from oauth2client.contrib import _metadata + + +logger = logging.getLogger(__name__) + +_SCOPES_WARNING = """\ +You have requested explicit scopes to be used with a GCE service account. +Using this argument will have no effect on the actual scopes for tokens +requested. These scopes are set at VM instance creation time and +can't be overridden in the request. +""" + + +class AppAssertionCredentials(client.AssertionCredentials): + """Credentials object for Compute Engine Assertion Grants + + This object will allow a Compute Engine instance to identify itself to + Google and other OAuth 2.0 servers that can verify assertions. It can be + used for the purpose of accessing data stored under an account assigned to + the Compute Engine instance itself. + + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. + + Note that :attr:`service_account_email` and :attr:`scopes` + will both return None until the credentials have been refreshed. + To check whether credentials have previously been refreshed use + :attr:`invalid`. + """ + + def __init__(self, email=None, *args, **kwargs): + """Constructor for AppAssertionCredentials + + Args: + email: an email that specifies the service account to use. + Only necessary if using custom service accounts + (see https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#createdefaultserviceaccount). + """ + if 'scopes' in kwargs: + warnings.warn(_SCOPES_WARNING) + kwargs['scopes'] = None + + # Assertion type is no longer used, but still in the + # parent class signature. + super(AppAssertionCredentials, self).__init__(None, *args, **kwargs) + + self.service_account_email = email + self.scopes = None + self.invalid = True + + @classmethod + def from_json(cls, json_data): + raise NotImplementedError( + 'Cannot serialize credentials for GCE service accounts.') + + def to_json(self): + raise NotImplementedError( + 'Cannot serialize credentials for GCE service accounts.') + + def retrieve_scopes(self, http): + """Retrieves the canonical list of scopes for this access token. + + Overrides client.Credentials.retrieve_scopes. Fetches scopes info + from the metadata server. + + Args: + http: httplib2.Http, an http object to be used to make the refresh + request. + + Returns: + A set of strings containing the canonical list of scopes. + """ + self._retrieve_info(http) + return self.scopes + + def _retrieve_info(self, http): + """Retrieves service account info for invalid credentials. + + Args: + http: an object to be used to make HTTP requests. + """ + if self.invalid: + info = _metadata.get_service_account_info( + http, + service_account=self.service_account_email or 'default') + self.invalid = False + self.service_account_email = info['email'] + self.scopes = info['scopes'] + + def _refresh(self, http): + """Refreshes the access token. + + Skip all the storage hoops and just refresh using the API. + + Args: + http: an object to be used to make HTTP requests. + + Raises: + HttpAccessTokenRefreshError: When the refresh fails. + """ + try: + self._retrieve_info(http) + self.access_token, self.token_expiry = _metadata.get_token( + http, service_account=self.service_account_email) + except http_client.HTTPException as err: + raise client.HttpAccessTokenRefreshError(str(err)) + + @property + def serialization_data(self): + raise NotImplementedError( + 'Cannot serialize credentials for GCE service accounts.') + + def create_scoped_required(self): + return False + + def sign_blob(self, blob): + """Cryptographically sign a blob (of bytes). + + This method is provided to support a common interface, but + the actual key used for a Google Compute Engine service account + is not available, so it can't be used to sign content. + + Args: + blob: bytes, Message to be signed. + + Raises: + NotImplementedError, always. + """ + raise NotImplementedError( + 'Compute Engine service accounts cannot sign blobs') diff --git a/src/oauth2client/oauth2client/contrib/keyring_storage.py b/src/oauth2client/oauth2client/contrib/keyring_storage.py new file mode 100644 index 00000000..4af94488 --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/keyring_storage.py @@ -0,0 +1,95 @@ +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A keyring based Storage. + +A Storage for Credentials that uses the keyring module. +""" + +import threading + +import keyring + +from oauth2client import client + + +class Storage(client.Storage): + """Store and retrieve a single credential to and from the keyring. + + To use this module you must have the keyring module installed. See + . This is an optional module and is + not installed with oauth2client by default because it does not work on all + the platforms that oauth2client supports, such as Google App Engine. + + The keyring module is a + cross-platform library for access the keyring capabilities of the local + system. The user will be prompted for their keyring password when this + module is used, and the manner in which the user is prompted will vary per + platform. + + Usage:: + + from oauth2client import keyring_storage + + s = keyring_storage.Storage('name_of_application', 'user1') + credentials = s.get() + + """ + + def __init__(self, service_name, user_name): + """Constructor. + + Args: + service_name: string, The name of the service under which the + credentials are stored. + user_name: string, The name of the user to store credentials for. + """ + super(Storage, self).__init__(lock=threading.Lock()) + self._service_name = service_name + self._user_name = user_name + + def locked_get(self): + """Retrieve Credential from file. + + Returns: + oauth2client.client.Credentials + """ + credentials = None + content = keyring.get_password(self._service_name, self._user_name) + + if content is not None: + try: + credentials = client.Credentials.new_from_json(content) + credentials.set_store(self) + except ValueError: + pass + + return credentials + + def locked_put(self, credentials): + """Write Credentials to file. + + Args: + credentials: Credentials, the credentials to store. + """ + keyring.set_password(self._service_name, self._user_name, + credentials.to_json()) + + def locked_delete(self): + """Delete Credentials file. + + Args: + credentials: Credentials, the credentials to store. + """ + keyring.set_password(self._service_name, self._user_name, '') diff --git a/src/oauth2client/oauth2client/contrib/multiprocess_file_storage.py b/src/oauth2client/oauth2client/contrib/multiprocess_file_storage.py new file mode 100644 index 00000000..e9e8c8cd --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/multiprocess_file_storage.py @@ -0,0 +1,355 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiprocess file credential storage. + +This module provides file-based storage that supports multiple credentials and +cross-thread and process access. + +This module supersedes the functionality previously found in `multistore_file`. + +This module provides :class:`MultiprocessFileStorage` which: + * Is tied to a single credential via a user-specified key. This key can be + used to distinguish between multiple users, client ids, and/or scopes. + * Can be safely accessed and refreshed across threads and processes. + +Process & thread safety guarantees the following behavior: + * If one thread or process refreshes a credential, subsequent refreshes + from other processes will re-fetch the credentials from the file instead + of performing an http request. + * If two processes or threads attempt to refresh concurrently, only one + will be able to acquire the lock and refresh, with the deadlock caveat + below. + * The interprocess lock will not deadlock, instead, the if a process can + not acquire the interprocess lock within ``INTERPROCESS_LOCK_DEADLINE`` + it will allow refreshing the credential but will not write the updated + credential to disk, This logic happens during every lock cycle - if the + credentials are refreshed again it will retry locking and writing as + normal. + +Usage +===== + +Before using the storage, you need to decide how you want to key the +credentials. A few common strategies include: + + * If you're storing credentials for multiple users in a single file, use + a unique identifier for each user as the key. + * If you're storing credentials for multiple client IDs in a single file, + use the client ID as the key. + * If you're storing multiple credentials for one user, use the scopes as + the key. + * If you have a complicated setup, use a compound key. For example, you + can use a combination of the client ID and scopes as the key. + +Create an instance of :class:`MultiprocessFileStorage` for each credential you +want to store, for example:: + + filename = 'credentials' + key = '{}-{}'.format(client_id, user_id) + storage = MultiprocessFileStorage(filename, key) + +To store the credentials:: + + storage.put(credentials) + +If you're going to continue to use the credentials after storing them, be sure +to call :func:`set_store`:: + + credentials.set_store(storage) + +To retrieve the credentials:: + + storage.get(credentials) + +""" + +import base64 +import json +import logging +import os +import threading + +import fasteners +from six import iteritems + +from oauth2client import _helpers +from oauth2client import client + + +#: The maximum amount of time, in seconds, to wait when acquire the +#: interprocess lock before falling back to read-only mode. +INTERPROCESS_LOCK_DEADLINE = 1 + +logger = logging.getLogger(__name__) +_backends = {} +_backends_lock = threading.Lock() + + +def _create_file_if_needed(filename): + """Creates the an empty file if it does not already exist. + + Returns: + True if the file was created, False otherwise. + """ + if os.path.exists(filename): + return False + else: + # Equivalent to "touch". + open(filename, 'a+b').close() + logger.info('Credential file {0} created'.format(filename)) + return True + + +def _load_credentials_file(credentials_file): + """Load credentials from the given file handle. + + The file is expected to be in this format: + + { + "file_version": 2, + "credentials": { + "key": "base64 encoded json representation of credentials." + } + } + + This function will warn and return empty credentials instead of raising + exceptions. + + Args: + credentials_file: An open file handle. + + Returns: + A dictionary mapping user-defined keys to an instance of + :class:`oauth2client.client.Credentials`. + """ + try: + credentials_file.seek(0) + data = json.load(credentials_file) + except Exception: + logger.warning( + 'Credentials file could not be loaded, will ignore and ' + 'overwrite.') + return {} + + if data.get('file_version') != 2: + logger.warning( + 'Credentials file is not version 2, will ignore and ' + 'overwrite.') + return {} + + credentials = {} + + for key, encoded_credential in iteritems(data.get('credentials', {})): + try: + credential_json = base64.b64decode(encoded_credential) + credential = client.Credentials.new_from_json(credential_json) + credentials[key] = credential + except: + logger.warning( + 'Invalid credential {0} in file, ignoring.'.format(key)) + + return credentials + + +def _write_credentials_file(credentials_file, credentials): + """Writes credentials to a file. + + Refer to :func:`_load_credentials_file` for the format. + + Args: + credentials_file: An open file handle, must be read/write. + credentials: A dictionary mapping user-defined keys to an instance of + :class:`oauth2client.client.Credentials`. + """ + data = {'file_version': 2, 'credentials': {}} + + for key, credential in iteritems(credentials): + credential_json = credential.to_json() + encoded_credential = _helpers._from_bytes(base64.b64encode( + _helpers._to_bytes(credential_json))) + data['credentials'][key] = encoded_credential + + credentials_file.seek(0) + json.dump(data, credentials_file) + credentials_file.truncate() + + +class _MultiprocessStorageBackend(object): + """Thread-local backend for multiprocess storage. + + Each process has only one instance of this backend per file. All threads + share a single instance of this backend. This ensures that all threads + use the same thread lock and process lock when accessing the file. + """ + + def __init__(self, filename): + self._file = None + self._filename = filename + self._process_lock = fasteners.InterProcessLock( + '{0}.lock'.format(filename)) + self._thread_lock = threading.Lock() + self._read_only = False + self._credentials = {} + + def _load_credentials(self): + """(Re-)loads the credentials from the file.""" + if not self._file: + return + + loaded_credentials = _load_credentials_file(self._file) + self._credentials.update(loaded_credentials) + + logger.debug('Read credential file') + + def _write_credentials(self): + if self._read_only: + logger.debug('In read-only mode, not writing credentials.') + return + + _write_credentials_file(self._file, self._credentials) + logger.debug('Wrote credential file {0}.'.format(self._filename)) + + def acquire_lock(self): + self._thread_lock.acquire() + locked = self._process_lock.acquire(timeout=INTERPROCESS_LOCK_DEADLINE) + + if locked: + _create_file_if_needed(self._filename) + self._file = open(self._filename, 'r+') + self._read_only = False + + else: + logger.warn( + 'Failed to obtain interprocess lock for credentials. ' + 'If a credential is being refreshed, other processes may ' + 'not see the updated access token and refresh as well.') + if os.path.exists(self._filename): + self._file = open(self._filename, 'r') + else: + self._file = None + self._read_only = True + + self._load_credentials() + + def release_lock(self): + if self._file is not None: + self._file.close() + self._file = None + + if not self._read_only: + self._process_lock.release() + + self._thread_lock.release() + + def _refresh_predicate(self, credentials): + if credentials is None: + return True + elif credentials.invalid: + return True + elif credentials.access_token_expired: + return True + else: + return False + + def locked_get(self, key): + # Check if the credential is already in memory. + credentials = self._credentials.get(key, None) + + # Use the refresh predicate to determine if the entire store should be + # reloaded. This basically checks if the credentials are invalid + # or expired. This covers the situation where another process has + # refreshed the credentials and this process doesn't know about it yet. + # In that case, this process won't needlessly refresh the credentials. + if self._refresh_predicate(credentials): + self._load_credentials() + credentials = self._credentials.get(key, None) + + return credentials + + def locked_put(self, key, credentials): + self._load_credentials() + self._credentials[key] = credentials + self._write_credentials() + + def locked_delete(self, key): + self._load_credentials() + self._credentials.pop(key, None) + self._write_credentials() + + +def _get_backend(filename): + """A helper method to get or create a backend with thread locking. + + This ensures that only one backend is used per-file per-process, so that + thread and process locks are appropriately shared. + + Args: + filename: The full path to the credential storage file. + + Returns: + An instance of :class:`_MultiprocessStorageBackend`. + """ + filename = os.path.abspath(filename) + + with _backends_lock: + if filename not in _backends: + _backends[filename] = _MultiprocessStorageBackend(filename) + return _backends[filename] + + +class MultiprocessFileStorage(client.Storage): + """Multiprocess file credential storage. + + Args: + filename: The path to the file where credentials will be stored. + key: An arbitrary string used to uniquely identify this set of + credentials. For example, you may use the user's ID as the key or + a combination of the client ID and user ID. + """ + def __init__(self, filename, key): + self._key = key + self._backend = _get_backend(filename) + + def acquire_lock(self): + self._backend.acquire_lock() + + def release_lock(self): + self._backend.release_lock() + + def locked_get(self): + """Retrieves the current credentials from the store. + + Returns: + An instance of :class:`oauth2client.client.Credentials` or `None`. + """ + credential = self._backend.locked_get(self._key) + + if credential is not None: + credential.set_store(self) + + return credential + + def locked_put(self, credentials): + """Writes the given credentials to the store. + + Args: + credentials: an instance of + :class:`oauth2client.client.Credentials`. + """ + return self._backend.locked_put(self._key, credentials) + + def locked_delete(self): + """Deletes the current credentials from the store.""" + return self._backend.locked_delete(self._key) diff --git a/src/oauth2client/oauth2client/contrib/sqlalchemy.py b/src/oauth2client/oauth2client/contrib/sqlalchemy.py new file mode 100644 index 00000000..7d9fd4b2 --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/sqlalchemy.py @@ -0,0 +1,173 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 utilities for SQLAlchemy. + +Utilities for using OAuth 2.0 in conjunction with a SQLAlchemy. + +Configuration +============= + +In order to use this storage, you'll need to create table +with :class:`oauth2client.contrib.sqlalchemy.CredentialsType` column. +It's recommended to either put this column on some sort of user info +table or put the column in a table with a belongs-to relationship to +a user info table. + +Here's an example of a simple table with a :class:`CredentialsType` +column that's related to a user table by the `user_id` key. + +.. code-block:: python + + from sqlalchemy import Column, ForeignKey, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import relationship + + from oauth2client.contrib.sqlalchemy import CredentialsType + + + Base = declarative_base() + + + class Credentials(Base): + __tablename__ = 'credentials' + + user_id = Column(Integer, ForeignKey('user.id')) + credentials = Column(CredentialsType) + + + class User(Base): + id = Column(Integer, primary_key=True) + # bunch of other columns + credentials = relationship('Credentials') + + +Usage +===== + +With tables ready, you are now able to store credentials in database. +We will reuse tables defined above. + +.. code-block:: python + + from sqlalchemy.orm import Session + + from oauth2client.client import OAuth2Credentials + from oauth2client.contrib.sql_alchemy import Storage + + session = Session() + user = session.query(User).first() + storage = Storage( + session=session, + model_class=Credentials, + # This is the key column used to identify + # the row that stores the credentials. + key_name='user_id', + key_value=user.id, + property_name='credentials', + ) + + # Store + credentials = OAuth2Credentials(...) + storage.put(credentials) + + # Retrieve + credentials = storage.get() + + # Delete + storage.delete() + +""" + +from __future__ import absolute_import + +import sqlalchemy.types + +from oauth2client import client + + +class CredentialsType(sqlalchemy.types.PickleType): + """Type representing credentials. + + Alias for :class:`sqlalchemy.types.PickleType`. + """ + + +class Storage(client.Storage): + """Store and retrieve a single credential to and from SQLAlchemy. + This helper presumes the Credentials + have been stored as a Credentials column + on a db model class. + """ + + def __init__(self, session, model_class, key_name, + key_value, property_name): + """Constructor for Storage. + + Args: + session: An instance of :class:`sqlalchemy.orm.Session`. + model_class: SQLAlchemy declarative mapping. + key_name: string, key name for the entity that has the credentials + key_value: key value for the entity that has the credentials + property_name: A string indicating which property on the + ``model_class`` to store the credentials. + This property must be a + :class:`CredentialsType` column. + """ + super(Storage, self).__init__() + + self.session = session + self.model_class = model_class + self.key_name = key_name + self.key_value = key_value + self.property_name = property_name + + def locked_get(self): + """Retrieve stored credential. + + Returns: + A :class:`oauth2client.Credentials` instance or `None`. + """ + filters = {self.key_name: self.key_value} + query = self.session.query(self.model_class).filter_by(**filters) + entity = query.first() + + if entity: + credential = getattr(entity, self.property_name) + if credential and hasattr(credential, 'set_store'): + credential.set_store(self) + return credential + else: + return None + + def locked_put(self, credentials): + """Write a credentials to the SQLAlchemy datastore. + + Args: + credentials: :class:`oauth2client.Credentials` + """ + filters = {self.key_name: self.key_value} + query = self.session.query(self.model_class).filter_by(**filters) + entity = query.first() + + if not entity: + entity = self.model_class(**filters) + + setattr(entity, self.property_name, credentials) + self.session.add(entity) + + def locked_delete(self): + """Delete credentials from the SQLAlchemy datastore.""" + filters = {self.key_name: self.key_value} + self.session.query(self.model_class).filter_by(**filters).delete() diff --git a/src/oauth2client/oauth2client/contrib/xsrfutil.py b/src/oauth2client/oauth2client/contrib/xsrfutil.py new file mode 100644 index 00000000..7c3ec035 --- /dev/null +++ b/src/oauth2client/oauth2client/contrib/xsrfutil.py @@ -0,0 +1,101 @@ +# Copyright 2014 the Melange authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper methods for creating & verifying XSRF tokens.""" + +import base64 +import binascii +import hmac +import time + +from oauth2client import _helpers + + +# Delimiter character +DELIMITER = b':' + +# 1 hour in seconds +DEFAULT_TIMEOUT_SECS = 60 * 60 + + +@_helpers.positional(2) +def generate_token(key, user_id, action_id='', when=None): + """Generates a URL-safe token for the given user, action, time tuple. + + Args: + key: secret key to use. + user_id: the user ID of the authenticated user. + action_id: a string identifier of the action they requested + authorization for. + when: the time in seconds since the epoch at which the user was + authorized for this action. If not set the current time is used. + + Returns: + A string XSRF protection token. + """ + digester = hmac.new(_helpers._to_bytes(key, encoding='utf-8')) + digester.update(_helpers._to_bytes(str(user_id), encoding='utf-8')) + digester.update(DELIMITER) + digester.update(_helpers._to_bytes(action_id, encoding='utf-8')) + digester.update(DELIMITER) + when = _helpers._to_bytes(str(when or int(time.time())), encoding='utf-8') + digester.update(when) + digest = digester.digest() + + token = base64.urlsafe_b64encode(digest + DELIMITER + when) + return token + + +@_helpers.positional(3) +def validate_token(key, token, user_id, action_id="", current_time=None): + """Validates that the given token authorizes the user for the action. + + Tokens are invalid if the time of issue is too old or if the token + does not match what generateToken outputs (i.e. the token was forged). + + Args: + key: secret key to use. + token: a string of the token generated by generateToken. + user_id: the user ID of the authenticated user. + action_id: a string identifier of the action they requested + authorization for. + + Returns: + A boolean - True if the user is authorized for the action, False + otherwise. + """ + if not token: + return False + try: + decoded = base64.urlsafe_b64decode(token) + token_time = int(decoded.split(DELIMITER)[-1]) + except (TypeError, ValueError, binascii.Error): + return False + if current_time is None: + current_time = time.time() + # If the token is too old it's not valid. + if current_time - token_time > DEFAULT_TIMEOUT_SECS: + return False + + # The given token should match the generated one with the same time. + expected_token = generate_token(key, user_id, action_id=action_id, + when=token_time) + if len(token) != len(expected_token): + return False + + # Perform constant time comparison to avoid timing attacks + different = 0 + for x, y in zip(bytearray(token), bytearray(expected_token)): + different |= x ^ y + return not different diff --git a/src/oauth2client/oauth2client/crypt.py b/src/oauth2client/oauth2client/crypt.py new file mode 100644 index 00000000..13260982 --- /dev/null +++ b/src/oauth2client/oauth2client/crypt.py @@ -0,0 +1,250 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Crypto-related routines for oauth2client.""" + +import json +import logging +import time + +from oauth2client import _helpers +from oauth2client import _pure_python_crypt + + +RsaSigner = _pure_python_crypt.RsaSigner +RsaVerifier = _pure_python_crypt.RsaVerifier + +CLOCK_SKEW_SECS = 300 # 5 minutes in seconds +AUTH_TOKEN_LIFETIME_SECS = 300 # 5 minutes in seconds +MAX_TOKEN_LIFETIME_SECS = 86400 # 1 day in seconds + +logger = logging.getLogger(__name__) + + +class AppIdentityError(Exception): + """Error to indicate crypto failure.""" + + +def _bad_pkcs12_key_as_pem(*args, **kwargs): + raise NotImplementedError('pkcs12_key_as_pem requires OpenSSL.') + + +try: + from oauth2client import _openssl_crypt + OpenSSLSigner = _openssl_crypt.OpenSSLSigner + OpenSSLVerifier = _openssl_crypt.OpenSSLVerifier + pkcs12_key_as_pem = _openssl_crypt.pkcs12_key_as_pem +except ImportError: # pragma: NO COVER + OpenSSLVerifier = None + OpenSSLSigner = None + pkcs12_key_as_pem = _bad_pkcs12_key_as_pem + +try: + from oauth2client import _pycrypto_crypt + PyCryptoSigner = _pycrypto_crypt.PyCryptoSigner + PyCryptoVerifier = _pycrypto_crypt.PyCryptoVerifier +except ImportError: # pragma: NO COVER + PyCryptoVerifier = None + PyCryptoSigner = None + + +if OpenSSLSigner: + Signer = OpenSSLSigner + Verifier = OpenSSLVerifier +elif PyCryptoSigner: # pragma: NO COVER + Signer = PyCryptoSigner + Verifier = PyCryptoVerifier +else: # pragma: NO COVER + Signer = RsaSigner + Verifier = RsaVerifier + + +def make_signed_jwt(signer, payload, key_id=None): + """Make a signed JWT. + + See http://self-issued.info/docs/draft-jones-json-web-token.html. + + Args: + signer: crypt.Signer, Cryptographic signer. + payload: dict, Dictionary of data to convert to JSON and then sign. + key_id: string, (Optional) Key ID header. + + Returns: + string, The JWT for the payload. + """ + header = {'typ': 'JWT', 'alg': 'RS256'} + if key_id is not None: + header['kid'] = key_id + + segments = [ + _helpers._urlsafe_b64encode(_helpers._json_encode(header)), + _helpers._urlsafe_b64encode(_helpers._json_encode(payload)), + ] + signing_input = b'.'.join(segments) + + signature = signer.sign(signing_input) + segments.append(_helpers._urlsafe_b64encode(signature)) + + logger.debug(str(segments)) + + return b'.'.join(segments) + + +def _verify_signature(message, signature, certs): + """Verifies signed content using a list of certificates. + + Args: + message: string or bytes, The message to verify. + signature: string or bytes, The signature on the message. + certs: iterable, certificates in PEM format. + + Raises: + AppIdentityError: If none of the certificates can verify the message + against the signature. + """ + for pem in certs: + verifier = Verifier.from_string(pem, is_x509_cert=True) + if verifier.verify(message, signature): + return + + # If we have not returned, no certificate confirms the signature. + raise AppIdentityError('Invalid token signature') + + +def _check_audience(payload_dict, audience): + """Checks audience field from a JWT payload. + + Does nothing if the passed in ``audience`` is null. + + Args: + payload_dict: dict, A dictionary containing a JWT payload. + audience: string or NoneType, an audience to check for in + the JWT payload. + + Raises: + AppIdentityError: If there is no ``'aud'`` field in the payload + dictionary but there is an ``audience`` to check. + AppIdentityError: If the ``'aud'`` field in the payload dictionary + does not match the ``audience``. + """ + if audience is None: + return + + audience_in_payload = payload_dict.get('aud') + if audience_in_payload is None: + raise AppIdentityError( + 'No aud field in token: {0}'.format(payload_dict)) + if audience_in_payload != audience: + raise AppIdentityError('Wrong recipient, {0} != {1}: {2}'.format( + audience_in_payload, audience, payload_dict)) + + +def _verify_time_range(payload_dict): + """Verifies the issued at and expiration from a JWT payload. + + Makes sure the current time (in UTC) falls between the issued at and + expiration for the JWT (with some skew allowed for via + ``CLOCK_SKEW_SECS``). + + Args: + payload_dict: dict, A dictionary containing a JWT payload. + + Raises: + AppIdentityError: If there is no ``'iat'`` field in the payload + dictionary. + AppIdentityError: If there is no ``'exp'`` field in the payload + dictionary. + AppIdentityError: If the JWT expiration is too far in the future (i.e. + if the expiration would imply a token lifetime + longer than what is allowed.) + AppIdentityError: If the token appears to have been issued in the + future (up to clock skew). + AppIdentityError: If the token appears to have expired in the past + (up to clock skew). + """ + # Get the current time to use throughout. + now = int(time.time()) + + # Make sure issued at and expiration are in the payload. + issued_at = payload_dict.get('iat') + if issued_at is None: + raise AppIdentityError( + 'No iat field in token: {0}'.format(payload_dict)) + expiration = payload_dict.get('exp') + if expiration is None: + raise AppIdentityError( + 'No exp field in token: {0}'.format(payload_dict)) + + # Make sure the expiration gives an acceptable token lifetime. + if expiration >= now + MAX_TOKEN_LIFETIME_SECS: + raise AppIdentityError( + 'exp field too far in future: {0}'.format(payload_dict)) + + # Make sure (up to clock skew) that the token wasn't issued in the future. + earliest = issued_at - CLOCK_SKEW_SECS + if now < earliest: + raise AppIdentityError('Token used too early, {0} < {1}: {2}'.format( + now, earliest, payload_dict)) + # Make sure (up to clock skew) that the token isn't already expired. + latest = expiration + CLOCK_SKEW_SECS + if now > latest: + raise AppIdentityError('Token used too late, {0} > {1}: {2}'.format( + now, latest, payload_dict)) + + +def verify_signed_jwt_with_certs(jwt, certs, audience=None): + """Verify a JWT against public certs. + + See http://self-issued.info/docs/draft-jones-json-web-token.html. + + Args: + jwt: string, A JWT. + certs: dict, Dictionary where values of public keys in PEM format. + audience: string, The audience, 'aud', that this JWT should contain. If + None then the JWT's 'aud' parameter is not verified. + + Returns: + dict, The deserialized JSON payload in the JWT. + + Raises: + AppIdentityError: if any checks are failed. + """ + jwt = _helpers._to_bytes(jwt) + + if jwt.count(b'.') != 2: + raise AppIdentityError( + 'Wrong number of segments in token: {0}'.format(jwt)) + + header, payload, signature = jwt.split(b'.') + message_to_sign = header + b'.' + payload + signature = _helpers._urlsafe_b64decode(signature) + + # Parse token. + payload_bytes = _helpers._urlsafe_b64decode(payload) + try: + payload_dict = json.loads(_helpers._from_bytes(payload_bytes)) + except: + raise AppIdentityError('Can\'t parse token: {0}'.format(payload_bytes)) + + # Verify that the signature matches the message. + _verify_signature(message_to_sign, signature, certs.values()) + + # Verify the issued at and created times in the payload. + _verify_time_range(payload_dict) + + # Check audience. + _check_audience(payload_dict, audience) + + return payload_dict diff --git a/src/oauth2client/oauth2client/file.py b/src/oauth2client/oauth2client/file.py new file mode 100644 index 00000000..3551c80d --- /dev/null +++ b/src/oauth2client/oauth2client/file.py @@ -0,0 +1,95 @@ +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for OAuth. + +Utilities for making it easier to work with OAuth 2.0 +credentials. +""" + +import os +import threading + +from oauth2client import _helpers +from oauth2client import client + + +class Storage(client.Storage): + """Store and retrieve a single credential to and from a file.""" + + def __init__(self, filename): + super(Storage, self).__init__(lock=threading.Lock()) + self._filename = filename + + def locked_get(self): + """Retrieve Credential from file. + + Returns: + oauth2client.client.Credentials + + Raises: + IOError if the file is a symbolic link. + """ + credentials = None + _helpers.validate_file(self._filename) + try: + f = open(self._filename, 'rb') + content = f.read() + f.close() + except IOError: + return credentials + + try: + credentials = client.Credentials.new_from_json(content) + credentials.set_store(self) + except ValueError: + pass + + return credentials + + def _create_file_if_needed(self): + """Create an empty file if necessary. + + This method will not initialize the file. Instead it implements a + simple version of "touch" to ensure the file has been created. + """ + if not os.path.exists(self._filename): + old_umask = os.umask(0o177) + try: + open(self._filename, 'a+b').close() + finally: + os.umask(old_umask) + + def locked_put(self, credentials): + """Write Credentials to file. + + Args: + credentials: Credentials, the credentials to store. + + Raises: + IOError if the file is a symbolic link. + """ + self._create_file_if_needed() + _helpers.validate_file(self._filename) + f = open(self._filename, 'w') + f.write(credentials.to_json()) + f.close() + + def locked_delete(self): + """Delete Credentials file. + + Args: + credentials: Credentials, the credentials to store. + """ + os.unlink(self._filename) diff --git a/src/oauth2client/oauth2client/service_account.py b/src/oauth2client/oauth2client/service_account.py new file mode 100644 index 00000000..540bfaaa --- /dev/null +++ b/src/oauth2client/oauth2client/service_account.py @@ -0,0 +1,685 @@ +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""oauth2client Service account credentials class.""" + +import base64 +import copy +import datetime +import json +import time + +import oauth2client +from oauth2client import _helpers +from oauth2client import client +from oauth2client import crypt +from oauth2client import transport + + +_PASSWORD_DEFAULT = 'notasecret' +_PKCS12_KEY = '_private_key_pkcs12' +_PKCS12_ERROR = r""" +This library only implements PKCS#12 support via the pyOpenSSL library. +Either install pyOpenSSL, or please convert the .p12 file +to .pem format: + $ cat key.p12 | \ + > openssl pkcs12 -nodes -nocerts -passin pass:notasecret | \ + > openssl rsa > key.pem +""" + + +class ServiceAccountCredentials(client.AssertionCredentials): + """Service Account credential for OAuth 2.0 signed JWT grants. + + Supports + + * JSON keyfile (typically contains a PKCS8 key stored as + PEM text) + * ``.p12`` key (stores PKCS12 key and certificate) + + Makes an assertion to server using a signed JWT assertion in exchange + for an access token. + + This credential does not require a flow to instantiate because it + represents a two legged flow, and therefore has all of the required + information to generate and refresh its own access tokens. + + Args: + service_account_email: string, The email associated with the + service account. + signer: ``crypt.Signer``, A signer which can be used to sign content. + scopes: List or string, (Optional) Scopes to use when acquiring + an access token. + private_key_id: string, (Optional) Private key identifier. Typically + only used with a JSON keyfile. Can be sent in the + header of a JWT token assertion. + client_id: string, (Optional) Client ID for the project that owns the + service account. + user_agent: string, (Optional) User agent to use when sending + request. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + kwargs: dict, Extra key-value pairs (both strings) to send in the + payload body when making an assertion. + """ + + MAX_TOKEN_LIFETIME_SECS = 3600 + """Max lifetime of the token (one hour, in seconds).""" + + NON_SERIALIZED_MEMBERS = ( + frozenset(['_signer']) | + client.AssertionCredentials.NON_SERIALIZED_MEMBERS) + """Members that aren't serialized when object is converted to JSON.""" + + # Can be over-ridden by factory constructors. Used for + # serialization/deserialization purposes. + _private_key_pkcs8_pem = None + _private_key_pkcs12 = None + _private_key_password = None + + def __init__(self, + service_account_email, + signer, + scopes='', + private_key_id=None, + client_id=None, + user_agent=None, + token_uri=oauth2client.GOOGLE_TOKEN_URI, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI, + **kwargs): + + super(ServiceAccountCredentials, self).__init__( + None, user_agent=user_agent, token_uri=token_uri, + revoke_uri=revoke_uri) + + self._service_account_email = service_account_email + self._signer = signer + self._scopes = _helpers.scopes_to_string(scopes) + self._private_key_id = private_key_id + self.client_id = client_id + self._user_agent = user_agent + self._kwargs = kwargs + + def _to_json(self, strip, to_serialize=None): + """Utility function that creates JSON repr. of a credentials object. + + Over-ride is needed since PKCS#12 keys will not in general be JSON + serializable. + + Args: + strip: array, An array of names of members to exclude from the + JSON. + to_serialize: dict, (Optional) The properties for this object + that will be serialized. This allows callers to + modify before serializing. + + Returns: + string, a JSON representation of this instance, suitable to pass to + from_json(). + """ + if to_serialize is None: + to_serialize = copy.copy(self.__dict__) + pkcs12_val = to_serialize.get(_PKCS12_KEY) + if pkcs12_val is not None: + to_serialize[_PKCS12_KEY] = base64.b64encode(pkcs12_val) + return super(ServiceAccountCredentials, self)._to_json( + strip, to_serialize=to_serialize) + + @classmethod + def _from_parsed_json_keyfile(cls, keyfile_dict, scopes, + token_uri=None, revoke_uri=None): + """Helper for factory constructors from JSON keyfile. + + Args: + keyfile_dict: dict-like object, The parsed dictionary-like object + containing the contents of the JSON keyfile. + scopes: List or string, Scopes to use when acquiring an + access token. + token_uri: string, URI for OAuth 2.0 provider token endpoint. + If unset and not present in keyfile_dict, defaults + to Google's endpoints. + revoke_uri: string, URI for OAuth 2.0 provider revoke endpoint. + If unset and not present in keyfile_dict, defaults + to Google's endpoints. + + Returns: + ServiceAccountCredentials, a credentials object created from + the keyfile contents. + + Raises: + ValueError, if the credential type is not :data:`SERVICE_ACCOUNT`. + KeyError, if one of the expected keys is not present in + the keyfile. + """ + creds_type = keyfile_dict.get('type') + if creds_type != client.SERVICE_ACCOUNT: + raise ValueError('Unexpected credentials type', creds_type, + 'Expected', client.SERVICE_ACCOUNT) + + service_account_email = keyfile_dict['client_email'] + private_key_pkcs8_pem = keyfile_dict['private_key'] + private_key_id = keyfile_dict['private_key_id'] + client_id = keyfile_dict['client_id'] + if not token_uri: + token_uri = keyfile_dict.get('token_uri', + oauth2client.GOOGLE_TOKEN_URI) + if not revoke_uri: + revoke_uri = keyfile_dict.get('revoke_uri', + oauth2client.GOOGLE_REVOKE_URI) + + signer = crypt.Signer.from_string(private_key_pkcs8_pem) + credentials = cls(service_account_email, signer, scopes=scopes, + private_key_id=private_key_id, + client_id=client_id, token_uri=token_uri, + revoke_uri=revoke_uri) + credentials._private_key_pkcs8_pem = private_key_pkcs8_pem + return credentials + + @classmethod + def from_json_keyfile_name(cls, filename, scopes='', + token_uri=None, revoke_uri=None): + + """Factory constructor from JSON keyfile by name. + + Args: + filename: string, The location of the keyfile. + scopes: List or string, (Optional) Scopes to use when acquiring an + access token. + token_uri: string, URI for OAuth 2.0 provider token endpoint. + If unset and not present in the key file, defaults + to Google's endpoints. + revoke_uri: string, URI for OAuth 2.0 provider revoke endpoint. + If unset and not present in the key file, defaults + to Google's endpoints. + + Returns: + ServiceAccountCredentials, a credentials object created from + the keyfile. + + Raises: + ValueError, if the credential type is not :data:`SERVICE_ACCOUNT`. + KeyError, if one of the expected keys is not present in + the keyfile. + """ + with open(filename, 'r') as file_obj: + client_credentials = json.load(file_obj) + return cls._from_parsed_json_keyfile(client_credentials, scopes, + token_uri=token_uri, + revoke_uri=revoke_uri) + + @classmethod + def from_json_keyfile_dict(cls, keyfile_dict, scopes='', + token_uri=None, revoke_uri=None): + """Factory constructor from parsed JSON keyfile. + + Args: + keyfile_dict: dict-like object, The parsed dictionary-like object + containing the contents of the JSON keyfile. + scopes: List or string, (Optional) Scopes to use when acquiring an + access token. + token_uri: string, URI for OAuth 2.0 provider token endpoint. + If unset and not present in keyfile_dict, defaults + to Google's endpoints. + revoke_uri: string, URI for OAuth 2.0 provider revoke endpoint. + If unset and not present in keyfile_dict, defaults + to Google's endpoints. + + Returns: + ServiceAccountCredentials, a credentials object created from + the keyfile. + + Raises: + ValueError, if the credential type is not :data:`SERVICE_ACCOUNT`. + KeyError, if one of the expected keys is not present in + the keyfile. + """ + return cls._from_parsed_json_keyfile(keyfile_dict, scopes, + token_uri=token_uri, + revoke_uri=revoke_uri) + + @classmethod + def _from_p12_keyfile_contents(cls, service_account_email, + private_key_pkcs12, + private_key_password=None, scopes='', + token_uri=oauth2client.GOOGLE_TOKEN_URI, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI): + """Factory constructor from JSON keyfile. + + Args: + service_account_email: string, The email associated with the + service account. + private_key_pkcs12: string, The contents of a PKCS#12 keyfile. + private_key_password: string, (Optional) Password for PKCS#12 + private key. Defaults to ``notasecret``. + scopes: List or string, (Optional) Scopes to use when acquiring an + access token. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. + + Returns: + ServiceAccountCredentials, a credentials object created from + the keyfile. + + Raises: + NotImplementedError if pyOpenSSL is not installed / not the + active crypto library. + """ + if private_key_password is None: + private_key_password = _PASSWORD_DEFAULT + if crypt.Signer is not crypt.OpenSSLSigner: + raise NotImplementedError(_PKCS12_ERROR) + signer = crypt.Signer.from_string(private_key_pkcs12, + private_key_password) + credentials = cls(service_account_email, signer, scopes=scopes, + token_uri=token_uri, revoke_uri=revoke_uri) + credentials._private_key_pkcs12 = private_key_pkcs12 + credentials._private_key_password = private_key_password + return credentials + + @classmethod + def from_p12_keyfile(cls, service_account_email, filename, + private_key_password=None, scopes='', + token_uri=oauth2client.GOOGLE_TOKEN_URI, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI): + + """Factory constructor from JSON keyfile. + + Args: + service_account_email: string, The email associated with the + service account. + filename: string, The location of the PKCS#12 keyfile. + private_key_password: string, (Optional) Password for PKCS#12 + private key. Defaults to ``notasecret``. + scopes: List or string, (Optional) Scopes to use when acquiring an + access token. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. + + Returns: + ServiceAccountCredentials, a credentials object created from + the keyfile. + + Raises: + NotImplementedError if pyOpenSSL is not installed / not the + active crypto library. + """ + with open(filename, 'rb') as file_obj: + private_key_pkcs12 = file_obj.read() + return cls._from_p12_keyfile_contents( + service_account_email, private_key_pkcs12, + private_key_password=private_key_password, scopes=scopes, + token_uri=token_uri, revoke_uri=revoke_uri) + + @classmethod + def from_p12_keyfile_buffer(cls, service_account_email, file_buffer, + private_key_password=None, scopes='', + token_uri=oauth2client.GOOGLE_TOKEN_URI, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI): + """Factory constructor from JSON keyfile. + + Args: + service_account_email: string, The email associated with the + service account. + file_buffer: stream, A buffer that implements ``read()`` + and contains the PKCS#12 key contents. + private_key_password: string, (Optional) Password for PKCS#12 + private key. Defaults to ``notasecret``. + scopes: List or string, (Optional) Scopes to use when acquiring an + access token. + token_uri: string, URI for token endpoint. For convenience defaults + to Google's endpoints but any OAuth 2.0 provider can be + used. + revoke_uri: string, URI for revoke endpoint. For convenience + defaults to Google's endpoints but any OAuth 2.0 + provider can be used. + + Returns: + ServiceAccountCredentials, a credentials object created from + the keyfile. + + Raises: + NotImplementedError if pyOpenSSL is not installed / not the + active crypto library. + """ + private_key_pkcs12 = file_buffer.read() + return cls._from_p12_keyfile_contents( + service_account_email, private_key_pkcs12, + private_key_password=private_key_password, scopes=scopes, + token_uri=token_uri, revoke_uri=revoke_uri) + + def _generate_assertion(self): + """Generate the assertion that will be used in the request.""" + now = int(time.time()) + payload = { + 'aud': self.token_uri, + 'scope': self._scopes, + 'iat': now, + 'exp': now + self.MAX_TOKEN_LIFETIME_SECS, + 'iss': self._service_account_email, + } + payload.update(self._kwargs) + return crypt.make_signed_jwt(self._signer, payload, + key_id=self._private_key_id) + + def sign_blob(self, blob): + """Cryptographically sign a blob (of bytes). + + Implements abstract method + :meth:`oauth2client.client.AssertionCredentials.sign_blob`. + + Args: + blob: bytes, Message to be signed. + + Returns: + tuple, A pair of the private key ID used to sign the blob and + the signed contents. + """ + return self._private_key_id, self._signer.sign(blob) + + @property + def service_account_email(self): + """Get the email for the current service account. + + Returns: + string, The email associated with the service account. + """ + return self._service_account_email + + @property + def serialization_data(self): + # NOTE: This is only useful for JSON keyfile. + return { + 'type': 'service_account', + 'client_email': self._service_account_email, + 'private_key_id': self._private_key_id, + 'private_key': self._private_key_pkcs8_pem, + 'client_id': self.client_id, + } + + @classmethod + def from_json(cls, json_data): + """Deserialize a JSON-serialized instance. + + Inverse to :meth:`to_json`. + + Args: + json_data: dict or string, Serialized JSON (as a string or an + already parsed dictionary) representing a credential. + + Returns: + ServiceAccountCredentials from the serialized data. + """ + if not isinstance(json_data, dict): + json_data = json.loads(_helpers._from_bytes(json_data)) + + private_key_pkcs8_pem = None + pkcs12_val = json_data.get(_PKCS12_KEY) + password = None + if pkcs12_val is None: + private_key_pkcs8_pem = json_data['_private_key_pkcs8_pem'] + signer = crypt.Signer.from_string(private_key_pkcs8_pem) + else: + # NOTE: This assumes that private_key_pkcs8_pem is not also + # in the serialized data. This would be very incorrect + # state. + pkcs12_val = base64.b64decode(pkcs12_val) + password = json_data['_private_key_password'] + signer = crypt.Signer.from_string(pkcs12_val, password) + + credentials = cls( + json_data['_service_account_email'], + signer, + scopes=json_data['_scopes'], + private_key_id=json_data['_private_key_id'], + client_id=json_data['client_id'], + user_agent=json_data['_user_agent'], + **json_data['_kwargs'] + ) + if private_key_pkcs8_pem is not None: + credentials._private_key_pkcs8_pem = private_key_pkcs8_pem + if pkcs12_val is not None: + credentials._private_key_pkcs12 = pkcs12_val + if password is not None: + credentials._private_key_password = password + credentials.invalid = json_data['invalid'] + credentials.access_token = json_data['access_token'] + credentials.token_uri = json_data['token_uri'] + credentials.revoke_uri = json_data['revoke_uri'] + token_expiry = json_data.get('token_expiry', None) + if token_expiry is not None: + credentials.token_expiry = datetime.datetime.strptime( + token_expiry, client.EXPIRY_FORMAT) + return credentials + + def create_scoped_required(self): + return not self._scopes + + def create_scoped(self, scopes): + result = self.__class__(self._service_account_email, + self._signer, + scopes=scopes, + private_key_id=self._private_key_id, + client_id=self.client_id, + user_agent=self._user_agent, + **self._kwargs) + result.token_uri = self.token_uri + result.revoke_uri = self.revoke_uri + result._private_key_pkcs8_pem = self._private_key_pkcs8_pem + result._private_key_pkcs12 = self._private_key_pkcs12 + result._private_key_password = self._private_key_password + return result + + def create_with_claims(self, claims): + """Create credentials that specify additional claims. + + Args: + claims: dict, key-value pairs for claims. + + Returns: + ServiceAccountCredentials, a copy of the current service account + credentials with updated claims to use when obtaining access + tokens. + """ + new_kwargs = dict(self._kwargs) + new_kwargs.update(claims) + result = self.__class__(self._service_account_email, + self._signer, + scopes=self._scopes, + private_key_id=self._private_key_id, + client_id=self.client_id, + user_agent=self._user_agent, + **new_kwargs) + result.token_uri = self.token_uri + result.revoke_uri = self.revoke_uri + result._private_key_pkcs8_pem = self._private_key_pkcs8_pem + result._private_key_pkcs12 = self._private_key_pkcs12 + result._private_key_password = self._private_key_password + return result + + def create_delegated(self, sub): + """Create credentials that act as domain-wide delegation of authority. + + Use the ``sub`` parameter as the subject to delegate on behalf of + that user. + + For example:: + + >>> account_sub = 'foo@email.com' + >>> delegate_creds = creds.create_delegated(account_sub) + + Args: + sub: string, An email address that this service account will + act on behalf of (via domain-wide delegation). + + Returns: + ServiceAccountCredentials, a copy of the current service account + updated to act on behalf of ``sub``. + """ + return self.create_with_claims({'sub': sub}) + + +def _datetime_to_secs(utc_time): + # TODO(issue 298): use time_delta.total_seconds() + # time_delta.total_seconds() not supported in Python 2.6 + epoch = datetime.datetime(1970, 1, 1) + time_delta = utc_time - epoch + return time_delta.days * 86400 + time_delta.seconds + + +class _JWTAccessCredentials(ServiceAccountCredentials): + """Self signed JWT credentials. + + Makes an assertion to server using a self signed JWT from service account + credentials. These credentials do NOT use OAuth 2.0 and instead + authenticate directly. + """ + _MAX_TOKEN_LIFETIME_SECS = 3600 + """Max lifetime of the token (one hour, in seconds).""" + + def __init__(self, + service_account_email, + signer, + scopes=None, + private_key_id=None, + client_id=None, + user_agent=None, + token_uri=oauth2client.GOOGLE_TOKEN_URI, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI, + additional_claims=None): + if additional_claims is None: + additional_claims = {} + super(_JWTAccessCredentials, self).__init__( + service_account_email, + signer, + private_key_id=private_key_id, + client_id=client_id, + user_agent=user_agent, + token_uri=token_uri, + revoke_uri=revoke_uri, + **additional_claims) + + def authorize(self, http): + """Authorize an httplib2.Http instance with a JWT assertion. + + Unless specified, the 'aud' of the assertion will be the base + uri of the request. + + Args: + http: An instance of ``httplib2.Http`` or something that acts + like it. + Returns: + A modified instance of http that was passed in. + Example:: + h = httplib2.Http() + h = credentials.authorize(h) + """ + transport.wrap_http_for_jwt_access(self, http) + return http + + def get_access_token(self, http=None, additional_claims=None): + """Create a signed jwt. + + Args: + http: unused + additional_claims: dict, additional claims to add to + the payload of the JWT. + Returns: + An AccessTokenInfo with the signed jwt + """ + if additional_claims is None: + if self.access_token is None or self.access_token_expired: + self.refresh(None) + return client.AccessTokenInfo( + access_token=self.access_token, expires_in=self._expires_in()) + else: + # Create a 1 time token + token, unused_expiry = self._create_token(additional_claims) + return client.AccessTokenInfo( + access_token=token, expires_in=self._MAX_TOKEN_LIFETIME_SECS) + + def revoke(self, http): + """Cannot revoke JWTAccessCredentials tokens.""" + pass + + def create_scoped_required(self): + # JWTAccessCredentials are unscoped by definition + return True + + def create_scoped(self, scopes, token_uri=oauth2client.GOOGLE_TOKEN_URI, + revoke_uri=oauth2client.GOOGLE_REVOKE_URI): + # Returns an OAuth2 credentials with the given scope + result = ServiceAccountCredentials(self._service_account_email, + self._signer, + scopes=scopes, + private_key_id=self._private_key_id, + client_id=self.client_id, + user_agent=self._user_agent, + token_uri=token_uri, + revoke_uri=revoke_uri, + **self._kwargs) + if self._private_key_pkcs8_pem is not None: + result._private_key_pkcs8_pem = self._private_key_pkcs8_pem + if self._private_key_pkcs12 is not None: + result._private_key_pkcs12 = self._private_key_pkcs12 + if self._private_key_password is not None: + result._private_key_password = self._private_key_password + return result + + def refresh(self, http): + """Refreshes the access_token. + + The HTTP object is unused since no request needs to be made to + get a new token, it can just be generated locally. + + Args: + http: unused HTTP object + """ + self._refresh(None) + + def _refresh(self, http): + """Refreshes the access_token. + + Args: + http: unused HTTP object + """ + self.access_token, self.token_expiry = self._create_token() + + def _create_token(self, additional_claims=None): + now = client._UTCNOW() + lifetime = datetime.timedelta(seconds=self._MAX_TOKEN_LIFETIME_SECS) + expiry = now + lifetime + payload = { + 'iat': _datetime_to_secs(now), + 'exp': _datetime_to_secs(expiry), + 'iss': self._service_account_email, + 'sub': self._service_account_email + } + payload.update(self._kwargs) + if additional_claims is not None: + payload.update(additional_claims) + jwt = crypt.make_signed_jwt(self._signer, payload, + key_id=self._private_key_id) + return jwt.decode('ascii'), expiry diff --git a/src/oauth2client/oauth2client/tools.py b/src/oauth2client/oauth2client/tools.py new file mode 100644 index 00000000..51669934 --- /dev/null +++ b/src/oauth2client/oauth2client/tools.py @@ -0,0 +1,256 @@ +# Copyright 2014 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Command-line tools for authenticating via OAuth 2.0 + +Do the OAuth 2.0 Web Server dance for a command line application. Stores the +generated credentials in a common file that is used by other example apps in +the same directory. +""" + +from __future__ import print_function + +import logging +import socket +import sys + +from six.moves import BaseHTTPServer +from six.moves import http_client +from six.moves import input +from six.moves import urllib + +from oauth2client import _helpers +from oauth2client import client + + +__all__ = ['argparser', 'run_flow', 'message_if_missing'] + +_CLIENT_SECRETS_MESSAGE = """WARNING: Please configure OAuth 2.0 + +To make this sample run you will need to populate the client_secrets.json file +found at: + + {file_path} + +with information from the APIs Console . + +""" + +_FAILED_START_MESSAGE = """ +Failed to start a local webserver listening on either port 8080 +or port 8090. Please check your firewall settings and locally +running programs that may be blocking or using those ports. + +Falling back to --noauth_local_webserver and continuing with +authorization. +""" + +_BROWSER_OPENED_MESSAGE = """ +Your browser has been opened to visit: + + {address} + +If your browser is on a different machine then exit and re-run this +application with the command-line parameter + + --noauth_local_webserver +""" + +_GO_TO_LINK_MESSAGE = """ +Go to the following link in your browser: + + {address} +""" + + +def _CreateArgumentParser(): + try: + import argparse + except ImportError: # pragma: NO COVER + return None + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument('--auth_host_name', default='localhost', + help='Hostname when running a local web server.') + parser.add_argument('--noauth_local_webserver', action='store_true', + default=False, help='Do not run a local web server.') + parser.add_argument('--auth_host_port', default=[8080, 8090], type=int, + nargs='*', help='Port web server should listen on.') + parser.add_argument( + '--logging_level', default='ERROR', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + help='Set the logging level of detail.') + return parser + + +# argparser is an ArgumentParser that contains command-line options expected +# by tools.run(). Pass it in as part of the 'parents' argument to your own +# ArgumentParser. +argparser = _CreateArgumentParser() + + +class ClientRedirectServer(BaseHTTPServer.HTTPServer): + """A server to handle OAuth 2.0 redirects back to localhost. + + Waits for a single request and parses the query parameters + into query_params and then stops serving. + """ + query_params = {} + + +class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler): + """A handler for OAuth 2.0 redirects back to localhost. + + Waits for a single request and parses the query parameters + into the servers query_params and then stops serving. + """ + + def do_GET(self): + """Handle a GET request. + + Parses the query parameters and prints a message + if the flow has completed. Note that we can't detect + if an error occurred. + """ + self.send_response(http_client.OK) + self.send_header('Content-type', 'text/html') + self.end_headers() + parts = urllib.parse.urlparse(self.path) + query = _helpers.parse_unique_urlencoded(parts.query) + self.server.query_params = query + self.wfile.write( + b'Authentication Status') + self.wfile.write( + b'

The authentication flow has completed.

') + self.wfile.write(b'') + + def log_message(self, format, *args): + """Do not log messages to stdout while running as cmd. line program.""" + + +@_helpers.positional(3) +def run_flow(flow, storage, flags=None, http=None): + """Core code for a command-line application. + + The ``run()`` function is called from your application and runs + through all the steps to obtain credentials. It takes a ``Flow`` + argument and attempts to open an authorization server page in the + user's default web browser. The server asks the user to grant your + application access to the user's data. If the user grants access, + the ``run()`` function returns new credentials. The new credentials + are also stored in the ``storage`` argument, which updates the file + associated with the ``Storage`` object. + + It presumes it is run from a command-line application and supports the + following flags: + + ``--auth_host_name`` (string, default: ``localhost``) + Host name to use when running a local web server to handle + redirects during OAuth authorization. + + ``--auth_host_port`` (integer, default: ``[8080, 8090]``) + Port to use when running a local web server to handle redirects + during OAuth authorization. Repeat this option to specify a list + of values. + + ``--[no]auth_local_webserver`` (boolean, default: ``True``) + Run a local web server to handle redirects during OAuth + authorization. + + The tools module defines an ``ArgumentParser`` the already contains the + flag definitions that ``run()`` requires. You can pass that + ``ArgumentParser`` to your ``ArgumentParser`` constructor:: + + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + parents=[tools.argparser]) + flags = parser.parse_args(argv) + + Args: + flow: Flow, an OAuth 2.0 Flow to step through. + storage: Storage, a ``Storage`` to store the credential in. + flags: ``argparse.Namespace``, (Optional) The command-line flags. This + is the object returned from calling ``parse_args()`` on + ``argparse.ArgumentParser`` as described above. Defaults + to ``argparser.parse_args()``. + http: An instance of ``httplib2.Http.request`` or something that + acts like it. + + Returns: + Credentials, the obtained credential. + """ + if flags is None: + flags = argparser.parse_args() + logging.getLogger().setLevel(getattr(logging, flags.logging_level)) + if not flags.noauth_local_webserver: + success = False + port_number = 0 + for port in flags.auth_host_port: + port_number = port + try: + httpd = ClientRedirectServer((flags.auth_host_name, port), + ClientRedirectHandler) + except socket.error: + pass + else: + success = True + break + flags.noauth_local_webserver = not success + if not success: + print(_FAILED_START_MESSAGE) + + if not flags.noauth_local_webserver: + oauth_callback = 'http://{host}:{port}/'.format( + host=flags.auth_host_name, port=port_number) + else: + oauth_callback = client.OOB_CALLBACK_URN + flow.redirect_uri = oauth_callback + authorize_url = flow.step1_get_authorize_url() + + if not flags.noauth_local_webserver: + import webbrowser + webbrowser.open(authorize_url, new=1, autoraise=True) + print(_BROWSER_OPENED_MESSAGE.format(address=authorize_url)) + else: + print(_GO_TO_LINK_MESSAGE.format(address=authorize_url)) + + code = None + if not flags.noauth_local_webserver: + httpd.handle_request() + if 'error' in httpd.query_params: + sys.exit('Authentication request was rejected.') + if 'code' in httpd.query_params: + code = httpd.query_params['code'] + else: + print('Failed to find "code" in the query parameters ' + 'of the redirect.') + sys.exit('Try running with --noauth_local_webserver.') + else: + code = input('Enter verification code: ').strip() + + try: + credential = flow.step2_exchange(code, http=http) + except client.FlowExchangeError as e: + sys.exit('Authentication has failed: {0}'.format(e)) + + storage.put(credential) + credential.set_store(storage) + print('Authentication successful.') + + return credential + + +def message_if_missing(filename): + """Helpful message to display if the CLIENT_SECRETS file is missing.""" + return _CLIENT_SECRETS_MESSAGE.format(file_path=filename) diff --git a/src/oauth2client/oauth2client/transport.py b/src/oauth2client/oauth2client/transport.py new file mode 100644 index 00000000..79a61f1c --- /dev/null +++ b/src/oauth2client/oauth2client/transport.py @@ -0,0 +1,285 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import httplib2 +import six +from six.moves import http_client + +from oauth2client import _helpers + + +_LOGGER = logging.getLogger(__name__) +# Properties present in file-like streams / buffers. +_STREAM_PROPERTIES = ('read', 'seek', 'tell') + +# Google Data client libraries may need to set this to [401, 403]. +REFRESH_STATUS_CODES = (http_client.UNAUTHORIZED,) + + +class MemoryCache(object): + """httplib2 Cache implementation which only caches locally.""" + + def __init__(self): + self.cache = {} + + def get(self, key): + return self.cache.get(key) + + def set(self, key, value): + self.cache[key] = value + + def delete(self, key): + self.cache.pop(key, None) + + +def get_cached_http(): + """Return an HTTP object which caches results returned. + + This is intended to be used in methods like + oauth2client.client.verify_id_token(), which calls to the same URI + to retrieve certs. + + Returns: + httplib2.Http, an HTTP object with a MemoryCache + """ + return _CACHED_HTTP + + +def get_http_object(*args, **kwargs): + """Return a new HTTP object. + + Args: + *args: tuple, The positional arguments to be passed when + contructing a new HTTP object. + **kwargs: dict, The keyword arguments to be passed when + contructing a new HTTP object. + + Returns: + httplib2.Http, an HTTP object. + """ + return httplib2.Http(*args, **kwargs) + + +def _initialize_headers(headers): + """Creates a copy of the headers. + + Args: + headers: dict, request headers to copy. + + Returns: + dict, the copied headers or a new dictionary if the headers + were None. + """ + return {} if headers is None else dict(headers) + + +def _apply_user_agent(headers, user_agent): + """Adds a user-agent to the headers. + + Args: + headers: dict, request headers to add / modify user + agent within. + user_agent: str, the user agent to add. + + Returns: + dict, the original headers passed in, but modified if the + user agent is not None. + """ + if user_agent is not None: + if 'user-agent' in headers: + headers['user-agent'] = (user_agent + ' ' + headers['user-agent']) + else: + headers['user-agent'] = user_agent + + return headers + + +def clean_headers(headers): + """Forces header keys and values to be strings, i.e not unicode. + + The httplib module just concats the header keys and values in a way that + may make the message header a unicode string, which, if it then tries to + contatenate to a binary request body may result in a unicode decode error. + + Args: + headers: dict, A dictionary of headers. + + Returns: + The same dictionary but with all the keys converted to strings. + """ + clean = {} + try: + for k, v in six.iteritems(headers): + if not isinstance(k, six.binary_type): + k = str(k) + if not isinstance(v, six.binary_type): + v = str(v) + clean[_helpers._to_bytes(k)] = _helpers._to_bytes(v) + except UnicodeEncodeError: + from oauth2client.client import NonAsciiHeaderError + raise NonAsciiHeaderError(k, ': ', v) + return clean + + +def wrap_http_for_auth(credentials, http): + """Prepares an HTTP object's request method for auth. + + Wraps HTTP requests with logic to catch auth failures (typically + identified via a 401 status code). In the event of failure, tries + to refresh the token used and then retry the original request. + + Args: + credentials: Credentials, the credentials used to identify + the authenticated user. + http: httplib2.Http, an http object to be used to make + auth requests. + """ + orig_request_method = http.request + + # The closure that will replace 'httplib2.Http.request'. + def new_request(uri, method='GET', body=None, headers=None, + redirections=httplib2.DEFAULT_MAX_REDIRECTS, + connection_type=None): + if not credentials.access_token: + _LOGGER.info('Attempting refresh to obtain ' + 'initial access_token') + credentials._refresh(orig_request_method) + + # Clone and modify the request headers to add the appropriate + # Authorization header. + headers = _initialize_headers(headers) + credentials.apply(headers) + _apply_user_agent(headers, credentials.user_agent) + + body_stream_position = None + # Check if the body is a file-like stream. + if all(getattr(body, stream_prop, None) for stream_prop in + _STREAM_PROPERTIES): + body_stream_position = body.tell() + + resp, content = request(orig_request_method, uri, method, body, + clean_headers(headers), + redirections, connection_type) + + # A stored token may expire between the time it is retrieved and + # the time the request is made, so we may need to try twice. + max_refresh_attempts = 2 + for refresh_attempt in range(max_refresh_attempts): + if resp.status not in REFRESH_STATUS_CODES: + break + _LOGGER.info('Refreshing due to a %s (attempt %s/%s)', + resp.status, refresh_attempt + 1, + max_refresh_attempts) + credentials._refresh(orig_request_method) + credentials.apply(headers) + if body_stream_position is not None: + body.seek(body_stream_position) + + resp, content = request(orig_request_method, uri, method, body, + clean_headers(headers), + redirections, connection_type) + + return resp, content + + # Replace the request method with our own closure. + http.request = new_request + + # Set credentials as a property of the request method. + http.request.credentials = credentials + + +def wrap_http_for_jwt_access(credentials, http): + """Prepares an HTTP object's request method for JWT access. + + Wraps HTTP requests with logic to catch auth failures (typically + identified via a 401 status code). In the event of failure, tries + to refresh the token used and then retry the original request. + + Args: + credentials: _JWTAccessCredentials, the credentials used to identify + a service account that uses JWT access tokens. + http: httplib2.Http, an http object to be used to make + auth requests. + """ + orig_request_method = http.request + wrap_http_for_auth(credentials, http) + # The new value of ``http.request`` set by ``wrap_http_for_auth``. + authenticated_request_method = http.request + + # The closure that will replace 'httplib2.Http.request'. + def new_request(uri, method='GET', body=None, headers=None, + redirections=httplib2.DEFAULT_MAX_REDIRECTS, + connection_type=None): + if 'aud' in credentials._kwargs: + # Preemptively refresh token, this is not done for OAuth2 + if (credentials.access_token is None or + credentials.access_token_expired): + credentials.refresh(None) + return request(authenticated_request_method, uri, + method, body, headers, redirections, + connection_type) + else: + # If we don't have an 'aud' (audience) claim, + # create a 1-time token with the uri root as the audience + headers = _initialize_headers(headers) + _apply_user_agent(headers, credentials.user_agent) + uri_root = uri.split('?', 1)[0] + token, unused_expiry = credentials._create_token({'aud': uri_root}) + + headers['Authorization'] = 'Bearer ' + token + return request(orig_request_method, uri, method, body, + clean_headers(headers), + redirections, connection_type) + + # Replace the request method with our own closure. + http.request = new_request + + # Set credentials as a property of the request method. + http.request.credentials = credentials + + +def request(http, uri, method='GET', body=None, headers=None, + redirections=httplib2.DEFAULT_MAX_REDIRECTS, + connection_type=None): + """Make an HTTP request with an HTTP object and arguments. + + Args: + http: httplib2.Http, an http object to be used to make requests. + uri: string, The URI to be requested. + method: string, The HTTP method to use for the request. Defaults + to 'GET'. + body: string, The payload / body in HTTP request. By default + there is no payload. + headers: dict, Key-value pairs of request headers. By default + there are no headers. + redirections: int, The number of allowed 203 redirects for + the request. Defaults to 5. + connection_type: httplib.HTTPConnection, a subclass to be used for + establishing connection. If not set, the type + will be determined from the ``uri``. + + Returns: + tuple, a pair of a httplib2.Response with the status code and other + headers and the bytes of the content returned. + """ + # NOTE: Allowing http or http.request is temporary (See Issue 601). + http_callable = getattr(http, 'request', http) + return http_callable(uri, method=method, body=body, headers=headers, + redirections=redirections, + connection_type=connection_type) + + +_CACHED_HTTP = httplib2.Http(MemoryCache()) diff --git a/src/pyasn1/__init__.py b/src/pyasn1/__init__.py index 71bb22fe..68db4f1b 100644 --- a/src/pyasn1/__init__.py +++ b/src/pyasn1/__init__.py @@ -1,7 +1,7 @@ import sys # https://www.python.org/dev/peps/pep-0396/ -__version__ = '0.4.3' +__version__ = '0.4.5' if sys.version_info[:2] < (2, 4): raise RuntimeError('PyASN1 requires Python 2.4 or later') diff --git a/src/pyasn1/codec/ber/decoder.py b/src/pyasn1/codec/ber/decoder.py index a27b3e0e..591bbc4b 100644 --- a/src/pyasn1/codec/ber/decoder.py +++ b/src/pyasn1/codec/ber/decoder.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1 import debug @@ -18,6 +18,8 @@ from pyasn1.type import useful __all__ = ['decode'] +LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_DECODER) + noValue = base.noValue @@ -70,6 +72,10 @@ class ExplicitTagDecoder(AbstractSimpleDecoder): value, _ = decodeFun(head, asn1Spec, tagSet, length, **options) + if LOG: + LOG('explicit tag container carries %d octets of trailing payload ' + '(will be lost!): %s' % (len(_), debug.hexdump(_))) + return value, tail def indefLenValueDecoder(self, substrate, asn1Spec, @@ -120,7 +126,8 @@ class BooleanDecoder(IntegerDecoder): protoComponent = univ.Boolean(0) def _createComponent(self, asn1Spec, tagSet, value, **options): - return IntegerDecoder._createComponent(self, asn1Spec, tagSet, value and 1 or 0, **options) + return IntegerDecoder._createComponent( + self, asn1Spec, tagSet, value and 1 or 0, **options) class BitStringDecoder(AbstractSimpleDecoder): @@ -134,8 +141,8 @@ class BitStringDecoder(AbstractSimpleDecoder): head, tail = substrate[:length], substrate[length:] if substrateFun: - return substrateFun(self._createComponent(asn1Spec, tagSet, noValue, **options), - substrate, length) + return substrateFun(self._createComponent( + asn1Spec, tagSet, noValue, **options), substrate, length) if not head: raise error.PyAsn1Error('Empty BIT STRING substrate') @@ -148,12 +155,17 @@ class BitStringDecoder(AbstractSimpleDecoder): 'Trailing bits overflow %s' % trailingBits ) - value = self.protoComponent.fromOctetString(head[1:], internalFormat=True, padding=trailingBits) + value = self.protoComponent.fromOctetString( + head[1:], internalFormat=True, padding=trailingBits) return self._createComponent(asn1Spec, tagSet, value, **options), tail if not self.supportConstructedForm: - raise error.PyAsn1Error('Constructed encoding form prohibited at %s' % self.__class__.__name__) + raise error.PyAsn1Error('Constructed encoding form prohibited ' + 'at %s' % self.__class__.__name__) + + if LOG: + LOG('assembling constructed serialization') # All inner fragments are of the same type, treat them as octet string substrateFun = self.substrateCollector @@ -234,6 +246,9 @@ class OctetStringDecoder(AbstractSimpleDecoder): if not self.supportConstructedForm: raise error.PyAsn1Error('Constructed encoding form prohibited at %s' % self.__class__.__name__) + if LOG: + LOG('assembling constructed serialization') + # All inner fragments are of the same type, treat them as octet string substrateFun = self.substrateCollector @@ -267,7 +282,9 @@ class OctetStringDecoder(AbstractSimpleDecoder): allowEoo=True, **options) if component is eoo.endOfOctets: break + header += component + else: raise error.SubstrateUnderrunError( 'No EOO seen before substrate ends' @@ -374,59 +391,90 @@ class RealDecoder(AbstractSimpleDecoder): if fo & 0x80: # binary encoding if not head: raise error.PyAsn1Error("Incomplete floating-point value") + + if LOG: + LOG('decoding binary encoded REAL') + n = (fo & 0x03) + 1 + if n == 4: n = oct2int(head[0]) head = head[1:] + eo, head = head[:n], head[n:] + if not eo or not head: raise error.PyAsn1Error('Real exponent screwed') + e = oct2int(eo[0]) & 0x80 and -1 or 0 + while eo: # exponent e <<= 8 e |= oct2int(eo[0]) eo = eo[1:] + b = fo >> 4 & 0x03 # base bits + if b > 2: raise error.PyAsn1Error('Illegal Real base') + if b == 1: # encbase = 8 e *= 3 + elif b == 2: # encbase = 16 e *= 4 p = 0 + while head: # value p <<= 8 p |= oct2int(head[0]) head = head[1:] + if fo & 0x40: # sign bit p = -p + sf = fo >> 2 & 0x03 # scale bits p *= 2 ** sf value = (p, 2, e) + elif fo & 0x40: # infinite value + if LOG: + LOG('decoding infinite REAL') + value = fo & 0x01 and '-inf' or 'inf' + elif fo & 0xc0 == 0: # character encoding if not head: raise error.PyAsn1Error("Incomplete floating-point value") + + if LOG: + LOG('decoding character encoded REAL') + try: if fo & 0x3 == 0x1: # NR1 value = (int(head), 10, 0) + elif fo & 0x3 == 0x2: # NR2 value = float(head) + elif fo & 0x3 == 0x3: # NR3 value = float(head) + else: raise error.SubstrateUnderrunError( 'Unknown NR (tag %s)' % fo ) + except ValueError: raise error.SubstrateUnderrunError( 'Bad character Real syntax' ) + else: raise error.SubstrateUnderrunError( 'Unknown encoding (tag %s)' % fo ) + return self._createComponent(asn1Spec, tagSet, value, **options), tail @@ -447,10 +495,12 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): def _decodeComponents(self, substrate, tagSet=None, decodeFun=None, **options): components = [] componentTypes = set() + while substrate: component, substrate = decodeFun(substrate, **options) if component is eoo.endOfOctets: break + components.append(component) componentTypes.add(component.tagSet) @@ -460,6 +510,7 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): # * otherwise -> likely SEQUENCE OF/SET OF if len(componentTypes) > 1: protoComponent = self.protoRecordComponent + else: protoComponent = self.protoSequenceComponent @@ -469,6 +520,10 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): tagSet=tag.TagSet(protoComponent.tagSet.baseTag, *tagSet.superTags) ) + if LOG: + LOG('guessed %r container type (pass `asn1Spec` to guide the ' + 'decoder)' % asn1Object) + for idx, component in enumerate(components): asn1Object.setComponentByPosition( idx, component, @@ -490,8 +545,10 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): if substrateFun is not None: if asn1Spec is not None: asn1Object = asn1Spec.clone() + elif self.protoComponent is not None: asn1Object = self.protoComponent.clone(tagSet=tagSet) + else: asn1Object = self.protoRecordComponent, self.protoSequenceComponent @@ -501,8 +558,12 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): asn1Object, trailing = self._decodeComponents( head, tagSet=tagSet, decodeFun=decodeFun, **options ) + if trailing: - raise error.PyAsn1Error('Unused trailing %d octets encountered' % len(trailing)) + if LOG: + LOG('Unused trailing %d octets encountered: %s' % ( + len(trailing), debug.hexdump(trailing))) + return asn1Object, tail asn1Object = asn1Spec.clone() @@ -514,21 +575,31 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): isSetType = asn1Spec.typeId == univ.Set.typeId isDeterministic = not isSetType and not namedTypes.hasOptionalOrDefault + if LOG: + LOG('decoding %sdeterministic %s type %r chosen by type ID' % ( + not isDeterministic and 'non-' or '', isSetType and 'SET' or '', + asn1Spec)) + seenIndices = set() idx = 0 while head: if not namedTypes: componentType = None + elif isSetType: componentType = namedTypes.tagMapUnique + else: try: if isDeterministic: componentType = namedTypes[idx].asn1Object + elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted: componentType = namedTypes.getTagMapNearPosition(idx) + else: componentType = namedTypes[idx].asn1Object + except IndexError: raise error.PyAsn1Error( 'Excessive components decoded at %r' % (asn1Spec,) @@ -539,6 +610,7 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): if not isDeterministic and namedTypes: if isSetType: idx = namedTypes.getPositionByType(component.effectiveTagSet) + elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted: idx = namedTypes.getPositionNearType(component.effectiveTagSet, idx) @@ -551,14 +623,22 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): seenIndices.add(idx) idx += 1 + if LOG: + LOG('seen component indices %s' % seenIndices) + if namedTypes: if not namedTypes.requiredComponents.issubset(seenIndices): - raise error.PyAsn1Error('ASN.1 object %s has uninitialized components' % asn1Object.__class__.__name__) + raise error.PyAsn1Error( + 'ASN.1 object %s has uninitialized ' + 'components' % asn1Object.__class__.__name__) if namedTypes.hasOpenTypes: openTypes = options.get('openTypes', {}) + if LOG: + LOG('using open types map: %r' % openTypes) + if openTypes or options.get('decodeOpenTypes', False): for idx, namedType in enumerate(namedTypes.namedTypes): @@ -581,8 +661,15 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): openType = namedType.openType[governingValue] except KeyError: + if LOG: + LOG('failed to resolve open type by governing ' + 'value %r' % (governingValue,)) continue + if LOG: + LOG('resolved open type %r by governing ' + 'value %r' % (openType, governingValue)) + component, rest = decodeFun( asn1Object.getComponentByPosition(idx).asOctets(), asn1Spec=openType @@ -598,6 +685,9 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): componentType = asn1Spec.componentType + if LOG: + LOG('decoding type %r chosen by given `asn1Spec`' % componentType) + idx = 0 while head: @@ -607,6 +697,7 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): verifyConstraints=False, matchTags=False, matchConstraints=False ) + idx += 1 return asn1Object, tail @@ -621,8 +712,10 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): if substrateFun is not None: if asn1Spec is not None: asn1Object = asn1Spec.clone() + elif self.protoComponent is not None: asn1Object = self.protoComponent.clone(tagSet=tagSet) + else: asn1Object = self.protoRecordComponent, self.protoSequenceComponent @@ -642,21 +735,31 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): isSetType = asn1Object.typeId == univ.Set.typeId isDeterministic = not isSetType and not namedTypes.hasOptionalOrDefault + if LOG: + LOG('decoding %sdeterministic %s type %r chosen by type ID' % ( + not isDeterministic and 'non-' or '', isSetType and 'SET' or '', + asn1Spec)) + seenIndices = set() idx = 0 while substrate: if len(namedTypes) <= idx: asn1Spec = None + elif isSetType: asn1Spec = namedTypes.tagMapUnique + else: try: if isDeterministic: asn1Spec = namedTypes[idx].asn1Object + elif namedTypes[idx].isOptional or namedTypes[idx].isDefaulted: asn1Spec = namedTypes.getTagMapNearPosition(idx) + else: asn1Spec = namedTypes[idx].asn1Object + except IndexError: raise error.PyAsn1Error( 'Excessive components decoded at %r' % (asn1Object,) @@ -686,13 +789,19 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): 'No EOO seen before substrate ends' ) + if LOG: + LOG('seen component indices %s' % seenIndices) + if namedTypes: if not namedTypes.requiredComponents.issubset(seenIndices): raise error.PyAsn1Error('ASN.1 object %s has uninitialized components' % asn1Object.__class__.__name__) if namedTypes.hasOpenTypes: - openTypes = options.get('openTypes', None) + openTypes = options.get('openTypes', {}) + + if LOG: + LOG('using open types map: %r' % openTypes) if openTypes or options.get('decodeOpenTypes', False): @@ -716,8 +825,15 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): openType = namedType.openType[governingValue] except KeyError: + if LOG: + LOG('failed to resolve open type by governing ' + 'value %r' % (governingValue,)) continue + if LOG: + LOG('resolved open type %r by governing ' + 'value %r' % (openType, governingValue)) + component, rest = decodeFun( asn1Object.getComponentByPosition(idx).asOctets(), asn1Spec=openType, allowEoo=True @@ -734,6 +850,9 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): componentType = asn1Spec.componentType + if LOG: + LOG('decoding type %r chosen by given `asn1Spec`' % componentType) + idx = 0 while substrate: @@ -747,7 +866,9 @@ class UniversalConstructedTypeDecoder(AbstractConstructedDecoder): verifyConstraints=False, matchTags=False, matchConstraints=False ) + idx += 1 + else: raise error.SubstrateUnderrunError( 'No EOO seen before substrate ends' @@ -794,18 +915,25 @@ class ChoiceDecoder(AbstractConstructedDecoder): if asn1Spec is None: asn1Object = self.protoComponent.clone(tagSet=tagSet) + else: asn1Object = asn1Spec.clone() if substrateFun: return substrateFun(asn1Object, substrate, length) - if asn1Object.tagSet == tagSet: # explicitly tagged Choice + if asn1Object.tagSet == tagSet: + if LOG: + LOG('decoding %s as explicitly tagged CHOICE' % (tagSet,)) + component, head = decodeFun( head, asn1Object.componentTagMap, **options ) else: + if LOG: + LOG('decoding %s as untagged CHOICE' % (tagSet,)) + component, head = decodeFun( head, asn1Object.componentTagMap, tagSet, length, state, **options @@ -813,6 +941,9 @@ class ChoiceDecoder(AbstractConstructedDecoder): effectiveTagSet = component.effectiveTagSet + if LOG: + LOG('decoded component %s, effective tag set %s' % (component, effectiveTagSet)) + asn1Object.setComponentByType( effectiveTagSet, component, verifyConstraints=False, @@ -834,18 +965,26 @@ class ChoiceDecoder(AbstractConstructedDecoder): if substrateFun: return substrateFun(asn1Object, substrate, length) - if asn1Object.tagSet == tagSet: # explicitly tagged Choice + if asn1Object.tagSet == tagSet: + if LOG: + LOG('decoding %s as explicitly tagged CHOICE' % (tagSet,)) + component, substrate = decodeFun( substrate, asn1Object.componentType.tagMapUnique, **options ) + # eat up EOO marker eooMarker, substrate = decodeFun( substrate, allowEoo=True, **options ) + if eooMarker is not eoo.endOfOctets: raise error.PyAsn1Error('No EOO seen before substrate ends') else: + if LOG: + LOG('decoding %s as untagged CHOICE' % (tagSet,)) + component, substrate = decodeFun( substrate, asn1Object.componentType.tagMapUnique, tagSet, length, state, **options @@ -853,6 +992,9 @@ class ChoiceDecoder(AbstractConstructedDecoder): effectiveTagSet = component.effectiveTagSet + if LOG: + LOG('decoded component %s, effective tag set %s' % (component, effectiveTagSet)) + asn1Object.setComponentByType( effectiveTagSet, component, verifyConstraints=False, @@ -877,6 +1019,9 @@ class AnyDecoder(AbstractSimpleDecoder): length += len(fullSubstrate) - len(substrate) substrate = fullSubstrate + if LOG: + LOG('decoding as untagged ANY, substrate %s' % debug.hexdump(substrate)) + if substrateFun: return substrateFun(self._createComponent(asn1Spec, tagSet, noValue, **options), substrate, length) @@ -892,12 +1037,19 @@ class AnyDecoder(AbstractSimpleDecoder): if asn1Spec is not None and tagSet == asn1Spec.tagSet: # tagged Any type -- consume header substrate header = null + + if LOG: + LOG('decoding as tagged ANY') + else: fullSubstrate = options['fullSubstrate'] # untagged Any, recover header substrate header = fullSubstrate[:-len(substrate)] + if LOG: + LOG('decoding as untagged ANY, header substrate %s' % debug.hexdump(header)) + # Any components do not inherit initial tag asn1Spec = self.protoComponent @@ -905,6 +1057,9 @@ class AnyDecoder(AbstractSimpleDecoder): asn1Object = self._createComponent(asn1Spec, tagSet, noValue, **options) return substrateFun(asn1Object, header + substrate, length + len(header)) + if LOG: + LOG('assembling constructed serialization') + # All inner fragments are of the same type, treat them as octet string substrateFun = self.substrateCollector @@ -914,13 +1069,17 @@ class AnyDecoder(AbstractSimpleDecoder): allowEoo=True, **options) if component is eoo.endOfOctets: break + header += component + else: raise error.SubstrateUnderrunError( 'No EOO seen before substrate ends' ) + if substrateFun: return header, substrate + else: return self._createComponent(asn1Spec, tagSet, header, **options), substrate @@ -1063,21 +1222,16 @@ class Decoder(object): decodeFun=None, substrateFun=None, **options): - if debug.logger & debug.flagDecoder: - logger = debug.logger - else: - logger = None - - if logger: - logger('decoder called at scope %s with state %d, working with up to %d octets of substrate: %s' % (debug.scope, state, len(substrate), debug.hexdump(substrate))) + if LOG: + LOG('decoder called at scope %s with state %d, working with up to %d octets of substrate: %s' % (debug.scope, state, len(substrate), debug.hexdump(substrate))) allowEoo = options.pop('allowEoo', False) # Look for end-of-octets sentinel if allowEoo and self.supportIndefLength: if substrate[:2] == self.__eooSentinel: - if logger: - logger('end-of-octets sentinel found') + if LOG: + LOG('end-of-octets sentinel found') return eoo.endOfOctets, substrate[2:] value = noValue @@ -1090,26 +1244,32 @@ class Decoder(object): fullSubstrate = substrate while state is not stStop: + if state is stDecodeTag: if not substrate: raise error.SubstrateUnderrunError( 'Short octet stream on tag decoding' ) + # Decode tag isShortTag = True firstOctet = substrate[0] substrate = substrate[1:] + try: lastTag = tagCache[firstOctet] + except KeyError: integerTag = oct2int(firstOctet) tagClass = integerTag & 0xC0 tagFormat = integerTag & 0x20 tagId = integerTag & 0x1F + if tagId == 0x1F: isShortTag = False lengthOctetIdx = 0 tagId = 0 + try: while True: integerTag = oct2int(substrate[lengthOctetIdx]) @@ -1118,42 +1278,55 @@ class Decoder(object): tagId |= (integerTag & 0x7F) if not integerTag & 0x80: break + substrate = substrate[lengthOctetIdx:] + except IndexError: raise error.SubstrateUnderrunError( 'Short octet stream on long tag decoding' ) + lastTag = tag.Tag( tagClass=tagClass, tagFormat=tagFormat, tagId=tagId ) + if isShortTag: # cache short tags tagCache[firstOctet] = lastTag + if tagSet is None: if isShortTag: try: tagSet = tagSetCache[firstOctet] + except KeyError: # base tag not recovered tagSet = tag.TagSet((), lastTag) tagSetCache[firstOctet] = tagSet else: tagSet = tag.TagSet((), lastTag) + else: tagSet = lastTag + tagSet + state = stDecodeLength - if logger: - logger('tag decoded into %s, decoding length' % tagSet) + + if LOG: + LOG('tag decoded into %s, decoding length' % tagSet) + if state is stDecodeLength: # Decode length if not substrate: raise error.SubstrateUnderrunError( 'Short octet stream on length decoding' ) + firstOctet = oct2int(substrate[0]) + if firstOctet < 128: size = 1 length = firstOctet + elif firstOctet > 128: size = firstOctet & 0x7F # encoded in size bytes @@ -1164,28 +1337,36 @@ class Decoder(object): raise error.SubstrateUnderrunError( '%s<%s at %s' % (size, len(encodedLength), tagSet) ) + length = 0 for lengthOctet in encodedLength: length <<= 8 length |= lengthOctet size += 1 + else: size = 1 length = -1 substrate = substrate[size:] + if length == -1: if not self.supportIndefLength: raise error.PyAsn1Error('Indefinite length encoding not supported by this codec') + else: if len(substrate) < length: raise error.SubstrateUnderrunError('%d-octet short' % (length - len(substrate))) + state = stGetValueDecoder - if logger: - logger('value length decoded into %d, payload substrate is: %s' % (length, debug.hexdump(length == -1 and substrate or substrate[:length]))) + + if LOG: + LOG('value length decoded into %d, payload substrate is: %s' % (length, debug.hexdump(length == -1 and substrate or substrate[:length]))) + if state is stGetValueDecoder: if asn1Spec is None: state = stGetValueDecoderByTag + else: state = stGetValueDecoderByAsn1Spec # @@ -1207,41 +1388,55 @@ class Decoder(object): if state is stGetValueDecoderByTag: try: concreteDecoder = tagMap[tagSet] + except KeyError: concreteDecoder = None + if concreteDecoder: state = stDecodeValue + else: try: concreteDecoder = tagMap[tagSet[:1]] + except KeyError: concreteDecoder = None + if concreteDecoder: state = stDecodeValue else: state = stTryAsExplicitTag - if logger: - logger('codec %s chosen by a built-in type, decoding %s' % (concreteDecoder and concreteDecoder.__class__.__name__ or "", state is stDecodeValue and 'value' or 'as explicit tag')) + + if LOG: + LOG('codec %s chosen by a built-in type, decoding %s' % (concreteDecoder and concreteDecoder.__class__.__name__ or "", state is stDecodeValue and 'value' or 'as explicit tag')) debug.scope.push(concreteDecoder is None and '?' or concreteDecoder.protoComponent.__class__.__name__) + if state is stGetValueDecoderByAsn1Spec: + if asn1Spec.__class__ is tagmap.TagMap: try: chosenSpec = asn1Spec[tagSet] + except KeyError: chosenSpec = None - if logger: - logger('candidate ASN.1 spec is a map of:') + + if LOG: + LOG('candidate ASN.1 spec is a map of:') + for firstOctet, v in asn1Spec.presentTypes.items(): - logger(' %s -> %s' % (firstOctet, v.__class__.__name__)) + LOG(' %s -> %s' % (firstOctet, v.__class__.__name__)) + if asn1Spec.skipTypes: - logger('but neither of: ') + LOG('but neither of: ') for firstOctet, v in asn1Spec.skipTypes.items(): - logger(' %s -> %s' % (firstOctet, v.__class__.__name__)) - logger('new candidate ASN.1 spec is %s, chosen by %s' % (chosenSpec is None and '' or chosenSpec.prettyPrintType(), tagSet)) + LOG(' %s -> %s' % (firstOctet, v.__class__.__name__)) + LOG('new candidate ASN.1 spec is %s, chosen by %s' % (chosenSpec is None and '' or chosenSpec.prettyPrintType(), tagSet)) + elif tagSet == asn1Spec.tagSet or tagSet in asn1Spec.tagMap: chosenSpec = asn1Spec - if logger: - logger('candidate ASN.1 spec is %s' % asn1Spec.__class__.__name__) + if LOG: + LOG('candidate ASN.1 spec is %s' % asn1Spec.__class__.__name__) + else: chosenSpec = None @@ -1249,29 +1444,38 @@ class Decoder(object): try: # ambiguous type or just faster codec lookup concreteDecoder = typeMap[chosenSpec.typeId] - if logger: - logger('value decoder chosen for an ambiguous type by type ID %s' % (chosenSpec.typeId,)) + + if LOG: + LOG('value decoder chosen for an ambiguous type by type ID %s' % (chosenSpec.typeId,)) + except KeyError: # use base type for codec lookup to recover untagged types baseTagSet = tag.TagSet(chosenSpec.tagSet.baseTag, chosenSpec.tagSet.baseTag) try: # base type or tagged subtype concreteDecoder = tagMap[baseTagSet] - if logger: - logger('value decoder chosen by base %s' % (baseTagSet,)) + + if LOG: + LOG('value decoder chosen by base %s' % (baseTagSet,)) + except KeyError: concreteDecoder = None + if concreteDecoder: asn1Spec = chosenSpec state = stDecodeValue + else: state = stTryAsExplicitTag + else: concreteDecoder = None state = stTryAsExplicitTag - if logger: - logger('codec %s chosen by ASN.1 spec, decoding %s' % (state is stDecodeValue and concreteDecoder.__class__.__name__ or "", state is stDecodeValue and 'value' or 'as explicit tag')) + + if LOG: + LOG('codec %s chosen by ASN.1 spec, decoding %s' % (state is stDecodeValue and concreteDecoder.__class__.__name__ or "", state is stDecodeValue and 'value' or 'as explicit tag')) debug.scope.push(chosenSpec is None and '?' or chosenSpec.__class__.__name__) + if state is stDecodeValue: if not options.get('recursiveFlag', True) and not substrateFun: # deprecate this substrateFun = lambda a, b, c: (a, b[:c]) @@ -1285,6 +1489,7 @@ class Decoder(object): self, substrateFun, **options ) + else: value, substrate = concreteDecoder.valueDecoder( substrate, asn1Spec, @@ -1293,33 +1498,42 @@ class Decoder(object): **options ) - if logger: - logger('codec %s yields type %s, value:\n%s\n...remaining substrate is: %s' % (concreteDecoder.__class__.__name__, value.__class__.__name__, isinstance(value, base.Asn1Item) and value.prettyPrint() or value, substrate and debug.hexdump(substrate) or '')) + if LOG: + LOG('codec %s yields type %s, value:\n%s\n...remaining substrate is: %s' % (concreteDecoder.__class__.__name__, value.__class__.__name__, isinstance(value, base.Asn1Item) and value.prettyPrint() or value, substrate and debug.hexdump(substrate) or '')) state = stStop break + if state is stTryAsExplicitTag: if tagSet and tagSet[0].tagFormat == tag.tagFormatConstructed and tagSet[0].tagClass != tag.tagClassUniversal: # Assume explicit tagging concreteDecoder = explicitTagDecoder state = stDecodeValue + else: concreteDecoder = None state = self.defaultErrorState - if logger: - logger('codec %s chosen, decoding %s' % (concreteDecoder and concreteDecoder.__class__.__name__ or "", state is stDecodeValue and 'value' or 'as failure')) + + if LOG: + LOG('codec %s chosen, decoding %s' % (concreteDecoder and concreteDecoder.__class__.__name__ or "", state is stDecodeValue and 'value' or 'as failure')) + if state is stDumpRawValue: concreteDecoder = self.defaultRawDecoder - if logger: - logger('codec %s chosen, decoding value' % concreteDecoder.__class__.__name__) + + if LOG: + LOG('codec %s chosen, decoding value' % concreteDecoder.__class__.__name__) + state = stDecodeValue + if state is stErrorCondition: raise error.PyAsn1Error( '%s not in asn1Spec: %r' % (tagSet, asn1Spec) ) - if logger: + + if LOG: debug.scope.pop() - logger('decoder left scope %s, call completed' % debug.scope) + LOG('decoder left scope %s, call completed' % debug.scope) + return value, substrate diff --git a/src/pyasn1/codec/ber/encoder.py b/src/pyasn1/codec/ber/encoder.py index 0094b224..65b85141 100644 --- a/src/pyasn1/codec/ber/encoder.py +++ b/src/pyasn1/codec/ber/encoder.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1 import debug @@ -17,6 +17,8 @@ from pyasn1.type import useful __all__ = ['encode'] +LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_ENCODER) + class AbstractItemEncoder(object): supportIndefLenMode = True @@ -31,29 +33,39 @@ class AbstractItemEncoder(object): encodedTag = tagClass | tagFormat if isConstructed: encodedTag |= tag.tagFormatConstructed + if tagId < 31: return encodedTag | tagId, + else: substrate = tagId & 0x7f, + tagId >>= 7 + while tagId: substrate = (0x80 | (tagId & 0x7f),) + substrate tagId >>= 7 + return (encodedTag | 0x1F,) + substrate def encodeLength(self, length, defMode): if not defMode and self.supportIndefLenMode: return (0x80,) + if length < 0x80: return length, + else: substrate = () while length: substrate = (length & 0xff,) + substrate length >>= 8 + substrateLen = len(substrate) + if substrateLen > 126: raise error.PyAsn1Error('Length octets overflow (%d)' % substrateLen) + return (0x80 | substrateLen,) + substrate def encodeValue(self, value, asn1Spec, encodeFun, **options): @@ -85,16 +97,33 @@ class AbstractItemEncoder(object): value, asn1Spec, encodeFun, **options ) + if LOG: + LOG('encoded %svalue %s into %s' % ( + isConstructed and 'constructed ' or '', value, substrate + )) + if not substrate and isConstructed and options.get('ifNotEmpty', False): return substrate - # primitive form implies definite mode if not isConstructed: defModeOverride = True + if LOG: + LOG('overridden encoding mode into definitive for primitive type') + header = self.encodeTag(singleTag, isConstructed) + + if LOG: + LOG('encoded %stag %s into %s' % ( + isConstructed and 'constructed ' or '', + singleTag, debug.hexdump(ints2octs(header)))) + header += self.encodeLength(len(substrate), defModeOverride) + if LOG: + LOG('encoded %s octets (tag + payload) into %s' % ( + len(substrate), debug.hexdump(ints2octs(header)))) + if isOctets: substrate = ints2octs(header) + substrate @@ -131,6 +160,11 @@ class IntegerEncoder(AbstractItemEncoder): def encodeValue(self, value, asn1Spec, encodeFun, **options): if value == 0: + if LOG: + LOG('encoding %spayload for zero INTEGER' % ( + self.supportCompactZero and 'no ' or '' + )) + # de-facto way to encode zero if self.supportCompactZero: return (), False, False @@ -157,11 +191,15 @@ class BitStringEncoder(AbstractItemEncoder): substrate = alignedValue.asOctets() return int2oct(len(substrate) * 8 - valueLength) + substrate, False, True + if LOG: + LOG('encoding into up to %s-octet chunks' % maxChunkSize) + baseTag = value.tagSet.baseTag # strip off explicit tags if baseTag: tagSet = tag.TagSet(baseTag, baseTag) + else: tagSet = tag.TagSet() @@ -195,44 +233,47 @@ class OctetStringEncoder(AbstractItemEncoder): if not maxChunkSize or len(substrate) <= maxChunkSize: return substrate, False, True - else: + if LOG: + LOG('encoding into up to %s-octet chunks' % maxChunkSize) - # strip off explicit tags for inner chunks + # strip off explicit tags for inner chunks - if asn1Spec is None: - baseTag = value.tagSet.baseTag + if asn1Spec is None: + baseTag = value.tagSet.baseTag - # strip off explicit tags - if baseTag: - tagSet = tag.TagSet(baseTag, baseTag) - else: - tagSet = tag.TagSet() + # strip off explicit tags + if baseTag: + tagSet = tag.TagSet(baseTag, baseTag) - asn1Spec = value.clone(tagSet=tagSet) + else: + tagSet = tag.TagSet() - elif not isOctetsType(value): - baseTag = asn1Spec.tagSet.baseTag + asn1Spec = value.clone(tagSet=tagSet) - # strip off explicit tags - if baseTag: - tagSet = tag.TagSet(baseTag, baseTag) - else: - tagSet = tag.TagSet() + elif not isOctetsType(value): + baseTag = asn1Spec.tagSet.baseTag - asn1Spec = asn1Spec.clone(tagSet=tagSet) + # strip off explicit tags + if baseTag: + tagSet = tag.TagSet(baseTag, baseTag) - pos = 0 - substrate = null + else: + tagSet = tag.TagSet() - while True: - chunk = value[pos:pos + maxChunkSize] - if not chunk: - break + asn1Spec = asn1Spec.clone(tagSet=tagSet) - substrate += encodeFun(chunk, asn1Spec, **options) - pos += maxChunkSize + pos = 0 + substrate = null - return substrate, True, True + while True: + chunk = value[pos:pos + maxChunkSize] + if not chunk: + break + + substrate += encodeFun(chunk, asn1Spec, **options) + pos += maxChunkSize + + return substrate, True, True class NullEncoder(AbstractItemEncoder): @@ -268,8 +309,10 @@ class ObjectIdentifierEncoder(AbstractItemEncoder): oid = (second + 80,) + oid[2:] else: raise error.PyAsn1Error('Impossible first/second arcs at %s' % (value,)) + elif first == 2: oid = (second + 80,) + oid[2:] + else: raise error.PyAsn1Error('Impossible first/second arcs at %s' % (value,)) @@ -280,15 +323,19 @@ class ObjectIdentifierEncoder(AbstractItemEncoder): if 0 <= subOid <= 127: # Optimize for the common case octets += (subOid,) + elif subOid > 127: # Pack large Sub-Object IDs res = (subOid & 0x7f,) subOid >>= 7 + while subOid: res = (0x80 | (subOid & 0x7f),) + res subOid >>= 7 + # Add packed Sub-Object ID to resulted Object ID octets += res + else: raise error.PyAsn1Error('Negative OID arc %s at %s' % (subOid, value)) @@ -304,12 +351,16 @@ class RealEncoder(AbstractItemEncoder): ms, es = 1, 1 if m < 0: ms = -1 # mantissa sign + if e < 0: - es = -1 # exponenta sign + es = -1 # exponent sign + m *= ms + if encbase == 8: m *= 2 ** (abs(e) % 3 * es) e = abs(e) // 3 * es + elif encbase == 16: m *= 2 ** (abs(e) % 4 * es) e = abs(e) // 4 * es @@ -320,6 +371,7 @@ class RealEncoder(AbstractItemEncoder): e -= 1 continue break + return ms, int(m), encbase, e def _chooseEncBase(self, value): @@ -327,23 +379,32 @@ class RealEncoder(AbstractItemEncoder): encBase = [2, 8, 16] if value.binEncBase in encBase: return self._dropFloatingPoint(m, value.binEncBase, e) + elif self.binEncBase in encBase: return self._dropFloatingPoint(m, self.binEncBase, e) - # auto choosing base 2/8/16 + + # auto choosing base 2/8/16 mantissa = [m, m, m] - exponenta = [e, e, e] + exponent = [e, e, e] sign = 1 encbase = 2 e = float('inf') + for i in range(3): (sign, mantissa[i], encBase[i], - exponenta[i]) = self._dropFloatingPoint(mantissa[i], encBase[i], exponenta[i]) - if abs(exponenta[i]) < abs(e) or (abs(exponenta[i]) == abs(e) and mantissa[i] < m): - e = exponenta[i] + exponent[i]) = self._dropFloatingPoint(mantissa[i], encBase[i], exponent[i]) + + if abs(exponent[i]) < abs(e) or (abs(exponent[i]) == abs(e) and mantissa[i] < m): + e = exponent[i] m = int(mantissa[i]) encbase = encBase[i] + + if LOG: + LOG('automatically chosen REAL encoding base %s, sign %s, mantissa %s, ' + 'exponent %s' % (encbase, sign, m, e)) + return sign, m, encbase, e def encodeValue(self, value, asn1Spec, encodeFun, **options): @@ -352,69 +413,98 @@ class RealEncoder(AbstractItemEncoder): if value.isPlusInf: return (0x40,), False, False + if value.isMinusInf: return (0x41,), False, False + m, b, e = value + if not m: return null, False, True + if b == 10: + if LOG: + LOG('encoding REAL into character form') + return str2octs('\x03%dE%s%d' % (m, e == 0 and '+' or '', e)), False, True + elif b == 2: fo = 0x80 # binary encoding ms, m, encbase, e = self._chooseEncBase(value) + if ms < 0: # mantissa sign fo |= 0x40 # sign bit - # exponenta & mantissa normalization + + # exponent & mantissa normalization if encbase == 2: while m & 0x1 == 0: m >>= 1 e += 1 + elif encbase == 8: while m & 0x7 == 0: m >>= 3 e += 1 fo |= 0x10 + else: # encbase = 16 while m & 0xf == 0: m >>= 4 e += 1 fo |= 0x20 + sf = 0 # scale factor + while m & 0x1 == 0: m >>= 1 sf += 1 + if sf > 3: raise error.PyAsn1Error('Scale factor overflow') # bug if raised + fo |= sf << 2 eo = null if e == 0 or e == -1: eo = int2oct(e & 0xff) + else: while e not in (0, -1): eo = int2oct(e & 0xff) + eo e >>= 8 + if e == 0 and eo and oct2int(eo[0]) & 0x80: eo = int2oct(0) + eo + if e == -1 and eo and not (oct2int(eo[0]) & 0x80): eo = int2oct(0xff) + eo + n = len(eo) if n > 0xff: raise error.PyAsn1Error('Real exponent overflow') + if n == 1: pass + elif n == 2: fo |= 1 + elif n == 3: fo |= 2 + else: fo |= 3 eo = int2oct(n & 0xff) + eo + po = null + while m: po = int2oct(m & 0xff) + po m >>= 8 + substrate = int2oct(fo) + eo + po + return substrate, False, True + else: raise error.PyAsn1Error('Prohibited Real base %s' % b) @@ -439,10 +529,14 @@ class SequenceEncoder(AbstractItemEncoder): namedType = namedTypes[idx] if namedType.isOptional and not component.isValue: - continue + if LOG: + LOG('not encoding OPTIONAL component %r' % (namedType,)) + continue if namedType.isDefaulted and component == namedType.asn1Object: - continue + if LOG: + LOG('not encoding DEFAULT component %r' % (namedType,)) + continue if self.omitEmptyOptionals: options.update(ifNotEmpty=namedType.isOptional) @@ -455,6 +549,9 @@ class SequenceEncoder(AbstractItemEncoder): if wrapType.tagSet and not wrapType.isSameTypeWith(component): chunk = encodeFun(chunk, wrapType, **options) + if LOG: + LOG('wrapped open type with wrap type %r' % (wrapType,)) + substrate += chunk else: @@ -465,12 +562,17 @@ class SequenceEncoder(AbstractItemEncoder): component = value[namedType.name] except KeyError: - raise error.PyAsn1Error('Component name "%s" not found in %r' % (namedType.name, value)) + raise error.PyAsn1Error('Component name "%s" not found in %r' % ( + namedType.name, value)) if namedType.isOptional and namedType.name not in value: + if LOG: + LOG('not encoding OPTIONAL component %r' % (namedType,)) continue if namedType.isDefaulted and component == namedType.asn1Object: + if LOG: + LOG('not encoding DEFAULT component %r' % (namedType,)) continue if self.omitEmptyOptionals: @@ -484,6 +586,9 @@ class SequenceEncoder(AbstractItemEncoder): if wrapType.tagSet and not wrapType.isSameTypeWith(component): chunk = encodeFun(chunk, wrapType, **options) + if LOG: + LOG('wrapped open type with wrap type %r' % (wrapType,)) + substrate += chunk return substrate, True, True @@ -620,13 +725,8 @@ class Encoder(object): raise error.PyAsn1Error('Value %r is not ASN.1 type instance ' 'and "asn1Spec" not given' % (value,)) - if debug.logger & debug.flagEncoder: - logger = debug.logger - else: - logger = None - - if logger: - logger('encoder called in %sdef mode, chunk size %s for ' + if LOG: + LOG('encoder called in %sdef mode, chunk size %s for ' 'type %s, value:\n%s' % (not options.get('defMode', True) and 'in' or '', options.get('maxChunkSize', 0), asn1Spec is None and value.prettyPrintType() or asn1Spec.prettyPrintType(), value)) if self.fixedDefLengthMode is not None: @@ -639,8 +739,8 @@ class Encoder(object): try: concreteEncoder = self.__typeMap[typeId] - if logger: - logger('using value codec %s chosen by type ID %s' % (concreteEncoder.__class__.__name__, typeId)) + if LOG: + LOG('using value codec %s chosen by type ID %s' % (concreteEncoder.__class__.__name__, typeId)) except KeyError: if asn1Spec is None: @@ -657,13 +757,13 @@ class Encoder(object): except KeyError: raise error.PyAsn1Error('No encoder for %r (%s)' % (value, tagSet)) - if logger: - logger('using value codec %s chosen by tagSet %s' % (concreteEncoder.__class__.__name__, tagSet)) + if LOG: + LOG('using value codec %s chosen by tagSet %s' % (concreteEncoder.__class__.__name__, tagSet)) substrate = concreteEncoder.encode(value, asn1Spec, self, **options) - if logger: - logger('codec %s built %s octets of substrate: %s\nencoder completed' % (concreteEncoder, len(substrate), debug.hexdump(substrate))) + if LOG: + LOG('codec %s built %s octets of substrate: %s\nencoder completed' % (concreteEncoder, len(substrate), debug.hexdump(substrate))) return substrate diff --git a/src/pyasn1/codec/ber/eoo.py b/src/pyasn1/codec/ber/eoo.py index d4cd827a..b613b530 100644 --- a/src/pyasn1/codec/ber/eoo.py +++ b/src/pyasn1/codec/ber/eoo.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1.type import base diff --git a/src/pyasn1/codec/cer/decoder.py b/src/pyasn1/codec/cer/decoder.py index 66572ecb..5099e3c2 100644 --- a/src/pyasn1/codec/cer/decoder.py +++ b/src/pyasn1/codec/cer/decoder.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1 import error diff --git a/src/pyasn1/codec/cer/encoder.py b/src/pyasn1/codec/cer/encoder.py index 768d3c11..788567f2 100644 --- a/src/pyasn1/codec/cer/encoder.py +++ b/src/pyasn1/codec/cer/encoder.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1 import error @@ -41,7 +41,7 @@ class TimeEncoderMixIn(object): def encodeValue(self, value, asn1Spec, encodeFun, **options): # Encoding constraints: # - minutes are mandatory, seconds are optional - # - subseconds must NOT be zero + # - sub-seconds must NOT be zero # - no hanging fraction dot # - time in UTC (Z) # - only dot is allowed for fractions diff --git a/src/pyasn1/codec/der/decoder.py b/src/pyasn1/codec/der/decoder.py index f67d0250..261bab89 100644 --- a/src/pyasn1/codec/der/decoder.py +++ b/src/pyasn1/codec/der/decoder.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1.codec.cer import decoder diff --git a/src/pyasn1/codec/der/encoder.py b/src/pyasn1/codec/der/encoder.py index 756d9fe9..5e3c5717 100644 --- a/src/pyasn1/codec/der/encoder.py +++ b/src/pyasn1/codec/der/encoder.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1 import error diff --git a/src/pyasn1/codec/native/decoder.py b/src/pyasn1/codec/native/decoder.py index 78fcda68..10e20158 100644 --- a/src/pyasn1/codec/native/decoder.py +++ b/src/pyasn1/codec/native/decoder.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1 import debug @@ -14,6 +14,8 @@ from pyasn1.type import useful __all__ = ['decode'] +LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_DECODER) + class AbstractScalarDecoder(object): def __call__(self, pyObject, asn1Spec, decodeFun=None, **options): @@ -136,13 +138,10 @@ class Decoder(object): self.__typeMap = typeMap def __call__(self, pyObject, asn1Spec, **options): - if debug.logger & debug.flagDecoder: - logger = debug.logger - else: - logger = None - if logger: + + if LOG: debug.scope.push(type(pyObject).__name__) - logger('decoder called at scope %s, working with type %s' % (debug.scope, type(pyObject).__name__)) + LOG('decoder called at scope %s, working with type %s' % (debug.scope, type(pyObject).__name__)) if asn1Spec is None or not isinstance(asn1Spec, base.Asn1Item): raise error.PyAsn1Error('asn1Spec is not valid (should be an instance of an ASN.1 Item, not %s)' % asn1Spec.__class__.__name__) @@ -159,13 +158,13 @@ class Decoder(object): except KeyError: raise error.PyAsn1Error('Unknown ASN.1 tag %s' % asn1Spec.tagSet) - if logger: - logger('calling decoder %s on Python type %s <%s>' % (type(valueDecoder).__name__, type(pyObject).__name__, repr(pyObject))) + if LOG: + LOG('calling decoder %s on Python type %s <%s>' % (type(valueDecoder).__name__, type(pyObject).__name__, repr(pyObject))) value = valueDecoder(pyObject, asn1Spec, self, **options) - if logger: - logger('decoder %s produced ASN.1 type %s <%s>' % (type(valueDecoder).__name__, type(value).__name__, repr(value))) + if LOG: + LOG('decoder %s produced ASN.1 type %s <%s>' % (type(valueDecoder).__name__, type(value).__name__, repr(value))) debug.scope.pop() return value diff --git a/src/pyasn1/codec/native/encoder.py b/src/pyasn1/codec/native/encoder.py index 87e50f2b..50caa531 100644 --- a/src/pyasn1/codec/native/encoder.py +++ b/src/pyasn1/codec/native/encoder.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # try: @@ -20,6 +20,8 @@ from pyasn1.type import useful __all__ = ['encode'] +LOG = debug.registerLoggee(__name__, flags=debug.DEBUG_ENCODER) + class AbstractItemEncoder(object): def encode(self, value, encodeFun, **options): @@ -132,14 +134,40 @@ tagMap = { useful.UTCTime.tagSet: OctetStringEncoder() } -# Type-to-codec map for ambiguous ASN.1 types + +# Put in ambiguous & non-ambiguous types for faster codec lookup typeMap = { + univ.Boolean.typeId: BooleanEncoder(), + univ.Integer.typeId: IntegerEncoder(), + univ.BitString.typeId: BitStringEncoder(), + univ.OctetString.typeId: OctetStringEncoder(), + univ.Null.typeId: NullEncoder(), + univ.ObjectIdentifier.typeId: ObjectIdentifierEncoder(), + univ.Enumerated.typeId: IntegerEncoder(), + univ.Real.typeId: RealEncoder(), + # Sequence & Set have same tags as SequenceOf & SetOf univ.Set.typeId: SetEncoder(), univ.SetOf.typeId: SequenceOfEncoder(), univ.Sequence.typeId: SequenceEncoder(), univ.SequenceOf.typeId: SequenceOfEncoder(), univ.Choice.typeId: ChoiceEncoder(), - univ.Any.typeId: AnyEncoder() + univ.Any.typeId: AnyEncoder(), + # character string types + char.UTF8String.typeId: OctetStringEncoder(), + char.NumericString.typeId: OctetStringEncoder(), + char.PrintableString.typeId: OctetStringEncoder(), + char.TeletexString.typeId: OctetStringEncoder(), + char.VideotexString.typeId: OctetStringEncoder(), + char.IA5String.typeId: OctetStringEncoder(), + char.GraphicString.typeId: OctetStringEncoder(), + char.VisibleString.typeId: OctetStringEncoder(), + char.GeneralString.typeId: OctetStringEncoder(), + char.UniversalString.typeId: OctetStringEncoder(), + char.BMPString.typeId: OctetStringEncoder(), + # useful types + useful.ObjectDescriptor.typeId: OctetStringEncoder(), + useful.GeneralizedTime.typeId: OctetStringEncoder(), + useful.UTCTime.typeId: OctetStringEncoder() } @@ -154,14 +182,9 @@ class Encoder(object): if not isinstance(value, base.Asn1Item): raise error.PyAsn1Error('value is not valid (should be an instance of an ASN.1 Item)') - if debug.logger & debug.flagEncoder: - logger = debug.logger - else: - logger = None - - if logger: + if LOG: debug.scope.push(type(value).__name__) - logger('encoder called for type %s <%s>' % (type(value).__name__, value.prettyPrint())) + LOG('encoder called for type %s <%s>' % (type(value).__name__, value.prettyPrint())) tagSet = value.tagSet @@ -178,13 +201,13 @@ class Encoder(object): except KeyError: raise error.PyAsn1Error('No encoder for %s' % (value,)) - if logger: - logger('using value codec %s chosen by %s' % (concreteEncoder.__class__.__name__, tagSet)) + if LOG: + LOG('using value codec %s chosen by %s' % (concreteEncoder.__class__.__name__, tagSet)) pyObject = concreteEncoder.encode(value, self, **options) - if logger: - logger('encoder %s produced: %s' % (type(concreteEncoder).__name__, repr(pyObject))) + if LOG: + LOG('encoder %s produced: %s' % (type(concreteEncoder).__name__, repr(pyObject))) debug.scope.pop() return pyObject diff --git a/src/pyasn1/compat/binary.py b/src/pyasn1/compat/binary.py index c38a6508..addbdc9c 100644 --- a/src/pyasn1/compat/binary.py +++ b/src/pyasn1/compat/binary.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from sys import version_info diff --git a/src/pyasn1/compat/calling.py b/src/pyasn1/compat/calling.py index c60b50d8..778a3d15 100644 --- a/src/pyasn1/compat/calling.py +++ b/src/pyasn1/compat/calling.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from sys import version_info diff --git a/src/pyasn1/compat/dateandtime.py b/src/pyasn1/compat/dateandtime.py index 27526ade..5e471bf7 100644 --- a/src/pyasn1/compat/dateandtime.py +++ b/src/pyasn1/compat/dateandtime.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # import time diff --git a/src/pyasn1/compat/integer.py b/src/pyasn1/compat/integer.py index bb3d099e..4b31791d 100644 --- a/src/pyasn1/compat/integer.py +++ b/src/pyasn1/compat/integer.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # import sys diff --git a/src/pyasn1/compat/octets.py b/src/pyasn1/compat/octets.py index a06db5dd..99d23bb3 100644 --- a/src/pyasn1/compat/octets.py +++ b/src/pyasn1/compat/octets.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from sys import version_info diff --git a/src/pyasn1/compat/string.py b/src/pyasn1/compat/string.py index 4d8a045a..b9bc8c38 100644 --- a/src/pyasn1/compat/string.py +++ b/src/pyasn1/compat/string.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from sys import version_info diff --git a/src/pyasn1/debug.py b/src/pyasn1/debug.py index ab72fa84..8707aa88 100644 --- a/src/pyasn1/debug.py +++ b/src/pyasn1/debug.py @@ -1,10 +1,11 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # import logging +import sys from pyasn1 import __version__ from pyasn1 import error @@ -12,18 +13,20 @@ from pyasn1.compat.octets import octs2ints __all__ = ['Debug', 'setLogger', 'hexdump'] -flagNone = 0x0000 -flagEncoder = 0x0001 -flagDecoder = 0x0002 -flagAll = 0xffff +DEBUG_NONE = 0x0000 +DEBUG_ENCODER = 0x0001 +DEBUG_DECODER = 0x0002 +DEBUG_ALL = 0xffff -flagMap = { - 'none': flagNone, - 'encoder': flagEncoder, - 'decoder': flagDecoder, - 'all': flagAll +FLAG_MAP = { + 'none': DEBUG_NONE, + 'encoder': DEBUG_ENCODER, + 'decoder': DEBUG_DECODER, + 'all': DEBUG_ALL } +LOGGEE_MAP = {} + class Printer(object): # noinspection PyShadowingNames @@ -66,7 +69,7 @@ class Debug(object): defaultPrinter = Printer() def __init__(self, *flags, **options): - self._flags = flagNone + self._flags = DEBUG_NONE if 'loggerName' in options: # route our logs to parent logger @@ -89,9 +92,9 @@ class Debug(object): flag = flag[1:] try: if inverse: - self._flags &= ~flagMap[flag] + self._flags &= ~FLAG_MAP[flag] else: - self._flags |= flagMap[flag] + self._flags |= FLAG_MAP[flag] except KeyError: raise error.PyAsn1Error('bad debug flag %s' % flag) @@ -109,17 +112,26 @@ class Debug(object): def __rand__(self, flag): return flag & self._flags - -logger = 0 +_LOG = DEBUG_NONE def setLogger(userLogger): - global logger + global _LOG if userLogger: - logger = userLogger + _LOG = userLogger else: - logger = 0 + _LOG = DEBUG_NONE + + # Update registered logging clients + for module, (name, flags) in LOGGEE_MAP.items(): + setattr(module, name, _LOG & flags and _LOG or DEBUG_NONE) + + +def registerLoggee(module, name='LOG', flags=DEBUG_NONE): + LOGGEE_MAP[sys.modules[module]] = name, flags + setLogger(_LOG) + return _LOG def hexdump(octets): diff --git a/src/pyasn1/error.py b/src/pyasn1/error.py index c05e65c6..7f606bbd 100644 --- a/src/pyasn1/error.py +++ b/src/pyasn1/error.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # @@ -25,5 +25,5 @@ class SubstrateUnderrunError(PyAsn1Error): """Create pyasn1 exception object The `SubstrateUnderrunError` exception indicates insufficient serialised - data on input of a deserialisation routine. + data on input of a de-serialization routine. """ diff --git a/src/pyasn1/type/base.py b/src/pyasn1/type/base.py index adaab228..7995118b 100644 --- a/src/pyasn1/type/base.py +++ b/src/pyasn1/type/base.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # import sys @@ -343,10 +343,10 @@ class AbstractSimpleAsn1Item(Asn1ItemBase): value = self._value - initilaizers = self.readOnly.copy() - initilaizers.update(kwargs) + initializers = self.readOnly.copy() + initializers.update(kwargs) - return self.__class__(value, **initilaizers) + return self.__class__(value, **initializers) def subtype(self, value=noValue, **kwargs): """Create a specialization of |ASN.1| schema or value object. @@ -540,10 +540,10 @@ class AbstractConstructedAsn1Item(Asn1ItemBase): """ cloneValueFlag = kwargs.pop('cloneValueFlag', False) - initilaizers = self.readOnly.copy() - initilaizers.update(kwargs) + initializers = self.readOnly.copy() + initializers.update(kwargs) - clone = self.__class__(**initilaizers) + clone = self.__class__(**initializers) if cloneValueFlag: self._cloneComponentValues(clone, cloneValueFlag) diff --git a/src/pyasn1/type/char.py b/src/pyasn1/type/char.py index 493badb9..617b98dd 100644 --- a/src/pyasn1/type/char.py +++ b/src/pyasn1/type/char.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # import sys diff --git a/src/pyasn1/type/constraint.py b/src/pyasn1/type/constraint.py index a7043310..9d8883dd 100644 --- a/src/pyasn1/type/constraint.py +++ b/src/pyasn1/type/constraint.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # Original concept and code by Mike C. Fletcher. @@ -352,7 +352,7 @@ class InnerTypeConstraint(AbstractConstraint): if idx not in self.__multipleTypeConstraint: raise error.ValueConstraintError(value) constraint, status = self.__multipleTypeConstraint[idx] - if status == 'ABSENT': # XXX presense is not checked! + if status == 'ABSENT': # XXX presence is not checked! raise error.ValueConstraintError(value) constraint(value) diff --git a/src/pyasn1/type/error.py b/src/pyasn1/type/error.py index b2056bd6..80fcf3bd 100644 --- a/src/pyasn1/type/error.py +++ b/src/pyasn1/type/error.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1.error import PyAsn1Error diff --git a/src/pyasn1/type/namedtype.py b/src/pyasn1/type/namedtype.py index f162d194..71f5f11c 100644 --- a/src/pyasn1/type/namedtype.py +++ b/src/pyasn1/type/namedtype.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # import sys @@ -265,18 +265,18 @@ class NamedTypes(object): return nameToPosMap def __computeAmbiguousTypes(self): - ambigiousTypes = {} - partialAmbigiousTypes = () + ambiguousTypes = {} + partialAmbiguousTypes = () for idx, namedType in reversed(tuple(enumerate(self.__namedTypes))): if namedType.isOptional or namedType.isDefaulted: - partialAmbigiousTypes = (namedType,) + partialAmbigiousTypes + partialAmbiguousTypes = (namedType,) + partialAmbiguousTypes else: - partialAmbigiousTypes = (namedType,) - if len(partialAmbigiousTypes) == len(self.__namedTypes): - ambigiousTypes[idx] = self + partialAmbiguousTypes = (namedType,) + if len(partialAmbiguousTypes) == len(self.__namedTypes): + ambiguousTypes[idx] = self else: - ambigiousTypes[idx] = NamedTypes(*partialAmbigiousTypes, **dict(terminal=True)) - return ambigiousTypes + ambiguousTypes[idx] = NamedTypes(*partialAmbiguousTypes, **dict(terminal=True)) + return ambiguousTypes def getTypeByPosition(self, idx): """Return ASN.1 type object by its position in fields set. diff --git a/src/pyasn1/type/namedval.py b/src/pyasn1/type/namedval.py index 59257e48..2233aaf7 100644 --- a/src/pyasn1/type/namedval.py +++ b/src/pyasn1/type/namedval.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # ASN.1 named integers diff --git a/src/pyasn1/type/opentype.py b/src/pyasn1/type/opentype.py index d14ab340..d37a533b 100644 --- a/src/pyasn1/type/opentype.py +++ b/src/pyasn1/type/opentype.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # diff --git a/src/pyasn1/type/tag.py b/src/pyasn1/type/tag.py index 95c226f6..b46f491c 100644 --- a/src/pyasn1/type/tag.py +++ b/src/pyasn1/type/tag.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1 import error diff --git a/src/pyasn1/type/tagmap.py b/src/pyasn1/type/tagmap.py index a9d237f2..e53a14d5 100644 --- a/src/pyasn1/type/tagmap.py +++ b/src/pyasn1/type/tagmap.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # from pyasn1 import error diff --git a/src/pyasn1/type/univ.py b/src/pyasn1/type/univ.py index a19f6baa..7fab69f2 100644 --- a/src/pyasn1/type/univ.py +++ b/src/pyasn1/type/univ.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # import math @@ -2347,7 +2347,9 @@ class SequenceAndSetBase(base.AbstractConstructedAsn1Item): if value is noValue: if componentTypeLen: - value = componentType.getTypeByPosition(idx).clone() + value = componentType.getTypeByPosition(idx) + if isinstance(value, base.AbstractConstructedAsn1Item): + value = value.clone(cloneValueFlag=componentType[idx].isDefaulted) elif currentValue is noValue: raise error.PyAsn1Error('Component type not defined') @@ -2457,7 +2459,7 @@ class SequenceAndSetBase(base.AbstractConstructedAsn1Item): scope += 1 representation = self.__class__.__name__ + ':\n' for idx, componentValue in enumerate(self._componentValues): - if componentValue is not noValue: + if componentValue is not noValue and componentValue.isValue: representation += ' ' * scope if self.componentType: representation += self.componentType.getNameByPosition(idx) diff --git a/src/pyasn1/type/useful.py b/src/pyasn1/type/useful.py index 146916d4..7536b95c 100644 --- a/src/pyasn1/type/useful.py +++ b/src/pyasn1/type/useful.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1 software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # import datetime diff --git a/src/pyasn1_modules/__init__.py b/src/pyasn1_modules/__init__.py index a3aedb62..def0b77e 100644 --- a/src/pyasn1_modules/__init__.py +++ b/src/pyasn1_modules/__init__.py @@ -1,2 +1,2 @@ # http://www.python.org/dev/peps/pep-0396/ -__version__ = '0.2.2' +__version__ = '0.2.4' diff --git a/src/pyasn1_modules/pem.py b/src/pyasn1_modules/pem.py index e72b97fd..a6090bdd 100644 --- a/src/pyasn1_modules/pem.py +++ b/src/pyasn1_modules/pem.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # import base64 diff --git a/src/pyasn1_modules/rfc1155.py b/src/pyasn1_modules/rfc1155.py index efe39bc3..611e97eb 100644 --- a/src/pyasn1_modules/rfc1155.py +++ b/src/pyasn1_modules/rfc1155.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # SNMPv1 message syntax diff --git a/src/pyasn1_modules/rfc1157.py b/src/pyasn1_modules/rfc1157.py index c616dfcf..b80d926a 100644 --- a/src/pyasn1_modules/rfc1157.py +++ b/src/pyasn1_modules/rfc1157.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # SNMPv1 message syntax diff --git a/src/pyasn1_modules/rfc1901.py b/src/pyasn1_modules/rfc1901.py index 16c83327..04533da0 100644 --- a/src/pyasn1_modules/rfc1901.py +++ b/src/pyasn1_modules/rfc1901.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # SNMPv2c message syntax diff --git a/src/pyasn1_modules/rfc1902.py b/src/pyasn1_modules/rfc1902.py index b4373f5e..d1a16489 100644 --- a/src/pyasn1_modules/rfc1902.py +++ b/src/pyasn1_modules/rfc1902.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # SNMPv2c message syntax diff --git a/src/pyasn1_modules/rfc1905.py b/src/pyasn1_modules/rfc1905.py index e35f37df..567e818a 100644 --- a/src/pyasn1_modules/rfc1905.py +++ b/src/pyasn1_modules/rfc1905.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # SNMPv2c PDU syntax diff --git a/src/pyasn1_modules/rfc2251.py b/src/pyasn1_modules/rfc2251.py index 88ee9a87..84c3d87c 100644 --- a/src/pyasn1_modules/rfc2251.py +++ b/src/pyasn1_modules/rfc2251.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # LDAP message syntax diff --git a/src/pyasn1_modules/rfc2314.py b/src/pyasn1_modules/rfc2314.py index 5a6d9273..a4532176 100644 --- a/src/pyasn1_modules/rfc2314.py +++ b/src/pyasn1_modules/rfc2314.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # PKCS#10 syntax diff --git a/src/pyasn1_modules/rfc2315.py b/src/pyasn1_modules/rfc2315.py index c7e53b9b..932c9849 100644 --- a/src/pyasn1_modules/rfc2315.py +++ b/src/pyasn1_modules/rfc2315.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # PKCS#7 message syntax diff --git a/src/pyasn1_modules/rfc2437.py b/src/pyasn1_modules/rfc2437.py index 0866f570..1139eb4b 100644 --- a/src/pyasn1_modules/rfc2437.py +++ b/src/pyasn1_modules/rfc2437.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # PKCS#1 syntax diff --git a/src/pyasn1_modules/rfc2459.py b/src/pyasn1_modules/rfc2459.py index 3d00adf7..071e5dab 100644 --- a/src/pyasn1_modules/rfc2459.py +++ b/src/pyasn1_modules/rfc2459.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # X.509 message syntax diff --git a/src/pyasn1_modules/rfc2511.py b/src/pyasn1_modules/rfc2511.py index 00ef4419..6b3c37ce 100644 --- a/src/pyasn1_modules/rfc2511.py +++ b/src/pyasn1_modules/rfc2511.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # X.509 certificate Request Message Format (CRMF) syntax diff --git a/src/pyasn1_modules/rfc2560.py b/src/pyasn1_modules/rfc2560.py index f6e0df07..c37e25b6 100644 --- a/src/pyasn1_modules/rfc2560.py +++ b/src/pyasn1_modules/rfc2560.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # OCSP request/response syntax diff --git a/src/pyasn1_modules/rfc2986.py b/src/pyasn1_modules/rfc2986.py index 47562c0b..014f2cb9 100644 --- a/src/pyasn1_modules/rfc2986.py +++ b/src/pyasn1_modules/rfc2986.py @@ -3,7 +3,7 @@ # This file is part of pyasn1-modules software. # # Created by Joel Johnson with asn1ate tool. -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # PKCS #10: Certification Request Syntax Specification diff --git a/src/pyasn1_modules/rfc3280.py b/src/pyasn1_modules/rfc3280.py index 58dba38b..6c45b8fe 100644 --- a/src/pyasn1_modules/rfc3280.py +++ b/src/pyasn1_modules/rfc3280.py @@ -3,7 +3,7 @@ # This file is part of pyasn1-modules software. # # Created by Stanisław Pitucha with asn1ate tool. -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # Internet X.509 Public Key Infrastructure Certificate and Certificate diff --git a/src/pyasn1_modules/rfc3281.py b/src/pyasn1_modules/rfc3281.py index 9378a45e..39ce8242 100644 --- a/src/pyasn1_modules/rfc3281.py +++ b/src/pyasn1_modules/rfc3281.py @@ -3,7 +3,7 @@ # This file is part of pyasn1-modules software. # # Created by Stanisław Pitucha with asn1ate tool. -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # An Internet Attribute Certificate Profile for Authorization diff --git a/src/pyasn1_modules/rfc3412.py b/src/pyasn1_modules/rfc3412.py index 8644c627..59f84959 100644 --- a/src/pyasn1_modules/rfc3412.py +++ b/src/pyasn1_modules/rfc3412.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # SNMPv3 message syntax diff --git a/src/pyasn1_modules/rfc3414.py b/src/pyasn1_modules/rfc3414.py index 28183796..b9087cb5 100644 --- a/src/pyasn1_modules/rfc3414.py +++ b/src/pyasn1_modules/rfc3414.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # SNMPv3 message syntax diff --git a/src/pyasn1_modules/rfc3447.py b/src/pyasn1_modules/rfc3447.py index ff5c6b52..a5499feb 100644 --- a/src/pyasn1_modules/rfc3447.py +++ b/src/pyasn1_modules/rfc3447.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # PKCS#1 syntax diff --git a/src/pyasn1_modules/rfc3852.py b/src/pyasn1_modules/rfc3852.py index 04b215e3..7c8f6c6f 100644 --- a/src/pyasn1_modules/rfc3852.py +++ b/src/pyasn1_modules/rfc3852.py @@ -3,7 +3,7 @@ # This file is part of pyasn1-modules software. # # Created by Stanisław Pitucha with asn1ate tool. -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # Cryptographic Message Syntax (CMS) diff --git a/src/pyasn1_modules/rfc4210.py b/src/pyasn1_modules/rfc4210.py index 39b468f5..c43fc608 100644 --- a/src/pyasn1_modules/rfc4210.py +++ b/src/pyasn1_modules/rfc4210.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # Certificate Management Protocol structures as per RFC4210 diff --git a/src/pyasn1_modules/rfc4211.py b/src/pyasn1_modules/rfc4211.py index 01c10cd5..cc792611 100644 --- a/src/pyasn1_modules/rfc4211.py +++ b/src/pyasn1_modules/rfc4211.py @@ -3,7 +3,7 @@ # This file is part of pyasn1-modules software. # # Created by Stanisław Pitucha with asn1ate tool. -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # Internet X.509 Public Key Infrastructure Certificate Request diff --git a/src/pyasn1_modules/rfc5083.py b/src/pyasn1_modules/rfc5083.py new file mode 100644 index 00000000..5240aaa5 --- /dev/null +++ b/src/pyasn1_modules/rfc5083.py @@ -0,0 +1,46 @@ +# This file is being contributed to of pyasn1-modules software. +# +# Created by Russ Housley without assistance from the asn1ate tool. +# Copyright (c) 2018, Vigil Security, LLC +# License: http://snmplabs.com/pyasn1/license.html +# +# Authenticated-Enveloped-Data for the Cryptographic Message Syntax (CMS) +# +# ASN.1 source from: +# https://www.rfc-editor.org/rfc/rfc5083.txt + +from pyasn1.type import namedtype, tag, univ +from pyasn1_modules import rfc5652 + + +MAX = float('inf') + + +def _buildOid(*components): + output = [] + for x in tuple(components): + if isinstance(x, univ.ObjectIdentifier): + output.extend(list(x)) + else: + output.append(int(x)) + return univ.ObjectIdentifier(output) + + +id_ct_authEnvelopedData = _buildOid(1, 2, 840, 113549, 1, 9, 16, 1, 23) + + +class AuthEnvelopedData(univ.Sequence): + pass + +AuthEnvelopedData.componentType = namedtype.NamedTypes( + namedtype.NamedType('version', rfc5652.CMSVersion()), + namedtype.OptionalNamedType('originatorInfo', rfc5652.OriginatorInfo().subtype( + implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatConstructed, 0))), + namedtype.NamedType('recipientInfos', rfc5652.RecipientInfos()), + namedtype.NamedType('authEncryptedContentInfo', rfc5652.EncryptedContentInfo()), + namedtype.OptionalNamedType('authAttrs', rfc5652.AuthAttributes().subtype( + implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 1))), + namedtype.NamedType('mac', rfc5652.MessageAuthenticationCode()), + namedtype.OptionalNamedType('unauthAttrs', rfc5652.UnauthAttributes().subtype( + implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2))) +) diff --git a/src/pyasn1_modules/rfc5084.py b/src/pyasn1_modules/rfc5084.py new file mode 100644 index 00000000..99ce60d6 --- /dev/null +++ b/src/pyasn1_modules/rfc5084.py @@ -0,0 +1,87 @@ +# This file is being contributed to pyasn1-modules software. +# +# Created by Russ Housley with assistance from the asn1ate tool, with manual +# changes to AES_CCM_ICVlen.subtypeSpec and added comments +# +# Copyright (c) 2018-2019, Vigil Security, LLC +# License: http://snmplabs.com/pyasn1/license.html +# +# AES-CCM and AES-GCM Algorithms fo use with the Authenticated-Enveloped-Data +# protecting content type for the Cryptographic Message Syntax (CMS) +# +# ASN.1 source from: +# https://www.rfc-editor.org/rfc/rfc5084.txt + + +from pyasn1.type import univ, char, namedtype, namedval, tag, constraint, useful + + +def _OID(*components): + output = [] + for x in tuple(components): + if isinstance(x, univ.ObjectIdentifier): + output.extend(list(x)) + else: + output.append(int(x)) + + return univ.ObjectIdentifier(output) + + +class AES_CCM_ICVlen(univ.Integer): + pass + + +class AES_GCM_ICVlen(univ.Integer): + pass + + +AES_CCM_ICVlen.subtypeSpec = constraint.SingleValueConstraint(4, 6, 8, 10, 12, 14, 16) + + +AES_GCM_ICVlen.subtypeSpec = constraint.ValueRangeConstraint(12, 16) + + +class CCMParameters(univ.Sequence): + pass + + +CCMParameters.componentType = namedtype.NamedTypes( + namedtype.NamedType('aes-nonce', univ.OctetString().subtype(subtypeSpec=constraint.ValueSizeConstraint(7, 13))), + # The aes-nonce parameter contains 15-L octets, where L is the size of the length field. L=8 is RECOMMENDED. + # Within the scope of any content-authenticated-encryption key, the nonce value MUST be unique. + namedtype.DefaultedNamedType('aes-ICVlen', AES_CCM_ICVlen().subtype(value=12)) +) + + +class GCMParameters(univ.Sequence): + pass + + +GCMParameters.componentType = namedtype.NamedTypes( + namedtype.NamedType('aes-nonce', univ.OctetString()), + # The aes-nonce may have any number of bits between 8 and 2^64, but it MUST be a multiple of 8 bits. + # Within the scope of any content-authenticated-encryption key, the nonce value MUST be unique. + # A nonce value of 12 octets can be processed more efficiently, so that length is RECOMMENDED. + namedtype.DefaultedNamedType('aes-ICVlen', AES_GCM_ICVlen().subtype(value=12)) +) + + +aes = _OID(2, 16, 840, 1, 101, 3, 4, 1) + + +id_aes128_CCM = _OID(aes, 7) + + +id_aes128_GCM = _OID(aes, 6) + + +id_aes192_CCM = _OID(aes, 27) + + +id_aes192_GCM = _OID(aes, 26) + + +id_aes256_CCM = _OID(aes, 47) + + +id_aes256_GCM = _OID(aes, 46) diff --git a/src/pyasn1_modules/rfc5208.py b/src/pyasn1_modules/rfc5208.py index 85bb5302..14082a89 100644 --- a/src/pyasn1_modules/rfc5208.py +++ b/src/pyasn1_modules/rfc5208.py @@ -1,7 +1,7 @@ # # This file is part of pyasn1-modules software. # -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # PKCS#8 syntax diff --git a/src/pyasn1_modules/rfc5280.py b/src/pyasn1_modules/rfc5280.py index 1a01352c..80bded50 100644 --- a/src/pyasn1_modules/rfc5280.py +++ b/src/pyasn1_modules/rfc5280.py @@ -3,7 +3,7 @@ # This file is part of pyasn1-modules software. # # Created by Stanisław Pitucha with asn1ate tool. -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # Internet X.509 Public Key Infrastructure Certificate and Certificate @@ -283,7 +283,7 @@ class CertificateSerialNumber(univ.Integer): class AlgorithmIdentifier(univ.Sequence): componentType = namedtype.NamedTypes( namedtype.NamedType('algorithm', univ.ObjectIdentifier()), - namedtype.OptionalNamedType('parameters', univ.Any()) + namedtype.OptionalNamedType('parameters', univ.Any(), openType=opentype.OpenType) ) diff --git a/src/pyasn1_modules/rfc5652.py b/src/pyasn1_modules/rfc5652.py index 309d1d61..094ce746 100644 --- a/src/pyasn1_modules/rfc5652.py +++ b/src/pyasn1_modules/rfc5652.py @@ -3,7 +3,7 @@ # This file is part of pyasn1-modules software. # # Created by Stanisław Pitucha with asn1ate tool. -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # Cryptographic Message Syntax (CMS) diff --git a/src/pyasn1_modules/rfc6402.py b/src/pyasn1_modules/rfc6402.py index 3814a3d2..7c9f8629 100644 --- a/src/pyasn1_modules/rfc6402.py +++ b/src/pyasn1_modules/rfc6402.py @@ -3,7 +3,7 @@ # This file is part of pyasn1-modules software. # # Created by Stanisław Pitucha with asn1ate tool. -# Copyright (c) 2005-2018, Ilya Etingof +# Copyright (c) 2005-2019, Ilya Etingof # License: http://snmplabs.com/pyasn1/license.html # # Certificate Management over CMS (CMC) Updates diff --git a/src/pyasn1_modules/rfc8103.py b/src/pyasn1_modules/rfc8103.py new file mode 100644 index 00000000..5e2d787f --- /dev/null +++ b/src/pyasn1_modules/rfc8103.py @@ -0,0 +1,38 @@ +# This file is being contributed to pyasn1-modules software. +# +# Created by Russ Housley with assistance from the asn1ate tool. +# Auto-generated by asn1ate v.0.6.0 from rfc8103.asn. +# +# Copyright (c) 2019, Vigil Security, LLC +# License: http://snmplabs.com/pyasn1/license.html +# +# ChaCha20Poly1305 algorithm fo use with the Authenticated-Enveloped-Data +# protecting content type for the Cryptographic Message Syntax (CMS) +# +# ASN.1 source from: +# https://www.rfc-editor.org/rfc/rfc8103.txt + +from pyasn1.type import univ, char, namedtype, namedval, tag, constraint, useful + + +def _OID(*components): + output = [] + for x in tuple(components): + if isinstance(x, univ.ObjectIdentifier): + output.extend(list(x)) + else: + output.append(int(x)) + + return univ.ObjectIdentifier(output) + + +class AEADChaCha20Poly1305Nonce(univ.OctetString): + pass + + +AEADChaCha20Poly1305Nonce.subtypeSpec = constraint.ValueSizeConstraint(12, 12) + + +id_alg_AEADChaCha20Poly1305 = _OID(1, 2, 840, 113549, 1, 9, 16, 3, 18) + + diff --git a/src/pyasn1_modules/rfc8226.py b/src/pyasn1_modules/rfc8226.py new file mode 100644 index 00000000..cd9bfd15 --- /dev/null +++ b/src/pyasn1_modules/rfc8226.py @@ -0,0 +1,123 @@ +# This file is being contributed to pyasn1-modules software. +# +# Created by Russ Housley with assistance from the asn1ate tool, with manual +# changes to implement appropriate constraints and added comments +# +# Copyright (c) 2019, Vigil Security, LLC +# License: http://snmplabs.com/pyasn1/license.html +# +# JWT Claim Constraints and TN Authorization List for certificate extensions. +# +# ASN.1 source from: +# https://www.rfc-editor.org/rfc/rfc8226.txt (with errata corrected) + +from pyasn1.type import univ, char, namedtype, namedval, tag, constraint, useful + + +MAX = float('inf') + + +def _OID(*components): + output = [] + for x in tuple(components): + if isinstance(x, univ.ObjectIdentifier): + output.extend(list(x)) + else: + output.append(int(x)) + + return univ.ObjectIdentifier(output) + + +class JWTClaimName(char.IA5String): + pass + + +class JWTClaimNames(univ.SequenceOf): + pass + + +JWTClaimNames.componentType = JWTClaimName() +JWTClaimNames.subtypeSpec=constraint.ValueSizeConstraint(1, MAX) + + +class JWTClaimPermittedValues(univ.Sequence): + pass + + +JWTClaimPermittedValues.componentType = namedtype.NamedTypes( + namedtype.NamedType('claim', JWTClaimName()), + namedtype.NamedType('permitted', univ.SequenceOf(componentType=char.UTF8String()).subtype(subtypeSpec=constraint.ValueSizeConstraint(1, MAX))) +) + + +class JWTClaimPermittedValuesList(univ.SequenceOf): + pass + + +JWTClaimPermittedValuesList.componentType = JWTClaimPermittedValues() +JWTClaimPermittedValuesList.subtypeSpec=constraint.ValueSizeConstraint(1, MAX) + + +class JWTClaimConstraints(univ.Sequence): + pass + + +JWTClaimConstraints.componentType = namedtype.NamedTypes( + namedtype.OptionalNamedType('mustInclude', JWTClaimNames().subtype(explicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 0))), + namedtype.OptionalNamedType('permittedValues', JWTClaimPermittedValuesList().subtype(explicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 1))) +) + + +JWTClaimConstraints.sizeSpec = univ.Sequence.sizeSpec + constraint.ValueSizeConstraint(1, 2) + + +id_pe_JWTClaimConstraints = _OID(1, 3, 6, 1, 5, 5, 7, 1, 27) + + +class ServiceProviderCode(char.IA5String): + pass + + +class TelephoneNumber(char.IA5String): + pass + + +TelephoneNumber.subtypeSpec = constraint.ConstraintsIntersection( + constraint.ValueSizeConstraint(1, 15), + constraint.PermittedAlphabetConstraint('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '#', '*') +) + + +class TelephoneNumberRange(univ.Sequence): + pass + + +TelephoneNumberRange.componentType = namedtype.NamedTypes( + namedtype.NamedType('start', TelephoneNumber()), + namedtype.NamedType('count', univ.Integer().subtype(subtypeSpec=constraint.ValueRangeConstraint(2, MAX))) +) + + +class TNEntry(univ.Choice): + pass + + +TNEntry.componentType = namedtype.NamedTypes( + namedtype.NamedType('spc', ServiceProviderCode().subtype(explicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 0))), + namedtype.NamedType('range', TelephoneNumberRange().subtype(explicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatConstructed, 1))), + namedtype.NamedType('one', TelephoneNumber().subtype(explicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2))) +) + + +class TNAuthorizationList(univ.SequenceOf): + pass + + +TNAuthorizationList.componentType = TNEntry() +TNAuthorizationList.subtypeSpec=constraint.ValueSizeConstraint(1, MAX) + + +id_pe_TNAuthList = _OID(1, 3, 6, 1, 5, 5, 7, 1, 26) + + +id_ad_stirTNList = _OID(1, 3, 6, 1, 5, 5, 7, 48, 14) diff --git a/src/rsa/__init__.py b/src/rsa/__init__.py index c572c06b..9b05c6c8 100644 --- a/src/rsa/__init__.py +++ b/src/rsa/__init__.py @@ -18,19 +18,18 @@ Module for calculating large primes, and RSA encryption, decryption, signing and verification. Includes generating public and private keys. -WARNING: this implementation does not use random padding, compression of the -cleartext input to prevent repetitions, or other common security improvements. -Use with care. +WARNING: this implementation does not use compression of the cleartext input to +prevent repetitions, or other common security improvements. Use with care. """ from rsa.key import newkeys, PrivateKey, PublicKey from rsa.pkcs1 import encrypt, decrypt, sign, verify, DecryptionError, \ - VerificationError + VerificationError, find_signature_hash, sign_hash, compute_hash __author__ = "Sybren Stuvel, Barry Mead and Yesudeep Mangalapilly" -__date__ = "2016-03-29" -__version__ = '3.4.2' +__date__ = "2018-09-16" +__version__ = '4.0' # Do doctest if we're run directly if __name__ == "__main__": @@ -39,4 +38,5 @@ if __name__ == "__main__": doctest.testmod() __all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify", 'PublicKey', - 'PrivateKey', 'DecryptionError', 'VerificationError'] + 'PrivateKey', 'DecryptionError', 'VerificationError', + 'compute_hash', 'sign_hash'] diff --git a/src/rsa/_compat.py b/src/rsa/_compat.py index 93393d9f..71197a55 100644 --- a/src/rsa/_compat.py +++ b/src/rsa/_compat.py @@ -18,18 +18,17 @@ from __future__ import absolute_import +import itertools import sys from struct import pack -try: - MAX_INT = sys.maxsize -except AttributeError: - MAX_INT = sys.maxint - +MAX_INT = sys.maxsize MAX_INT64 = (1 << 63) - 1 MAX_INT32 = (1 << 31) - 1 MAX_INT16 = (1 << 15) - 1 +PY2 = sys.version_info[0] == 2 + # Determine the word size of the processor. if MAX_INT == MAX_INT64: # 64-bit processor. @@ -41,32 +40,26 @@ else: # Else we just assume 64-bit processor keeping up with modern times. MACHINE_WORD_SIZE = 64 -try: - # < Python3 - unicode_type = unicode -except NameError: - # Python3. - unicode_type = str - -# Fake byte literals. -if str is unicode_type: - def byte_literal(s): - return s.encode('latin1') -else: - def byte_literal(s): - return s - -# ``long`` is no more. Do type detection using this instead. -try: +if PY2: integer_types = (int, long) -except NameError: - integer_types = (int,) + range = xrange + zip = itertools.izip +else: + integer_types = (int, ) + range = range + zip = zip -b = byte_literal -# To avoid calling b() multiple times in tight loops. -ZERO_BYTE = b('\x00') -EMPTY_BYTE = b('') +def write_to_stdout(data): + """Writes bytes to stdout + + :type data: bytes + """ + if PY2: + sys.stdout.write(data) + else: + # On Py3 we must use the buffer interface to write bytes. + sys.stdout.buffer.write(data) def is_bytes(obj): @@ -109,6 +102,27 @@ def byte(num): return pack("B", num) +def xor_bytes(b1, b2): + """ + Returns the bitwise XOR result between two bytes objects, b1 ^ b2. + + Bitwise XOR operation is commutative, so order of parameters doesn't + generate different results. If parameters have different length, extra + length of the largest one is ignored. + + :param b1: + First bytes object. + :param b2: + Second bytes object. + :returns: + Bytes object, result of XOR operation. + """ + if PY2: + return ''.join(byte(ord(x) ^ ord(y)) for x, y in zip(b1, b2)) + + return bytes(x ^ y for x, y in zip(b1, b2)) + + def get_word_alignment(num, force_arch=64, _machine_word_size=MACHINE_WORD_SIZE): """ diff --git a/src/rsa/cli.py b/src/rsa/cli.py index 3a218782..6450af42 100644 --- a/src/rsa/cli.py +++ b/src/rsa/cli.py @@ -26,7 +26,6 @@ import sys from optparse import OptionParser import rsa -import rsa.bigfile import rsa.pkcs1 HASH_METHODS = sorted(rsa.pkcs1.HASH_METHODS.keys()) @@ -84,7 +83,7 @@ def keygen(): outfile.write(data) else: print('Writing private key to stdout', file=sys.stderr) - sys.stdout.write(data) + rsa._compat.write_to_stdout(data) class CryptoOperation(object): @@ -113,7 +112,7 @@ class CryptoOperation(object): self.output_help = self.output_help % self.__class__.__dict__ @abc.abstractmethod - def perform_operation(self, indata, key, cli_args=None): + def perform_operation(self, indata, key, cli_args): """Performs the program's operation. Implement in a subclass. @@ -190,7 +189,7 @@ class CryptoOperation(object): outfile.write(outdata) else: print('Writing output to stdout', file=sys.stderr) - sys.stdout.write(outdata) + rsa._compat.write_to_stdout(outdata) class EncryptOperation(CryptoOperation): @@ -198,8 +197,7 @@ class EncryptOperation(CryptoOperation): keyname = 'public' description = ('Encrypts a file. The file must be shorter than the key ' - 'length in order to be encrypted. For larger files, use the ' - 'pyrsa-encrypt-bigfile command.') + 'length in order to be encrypted.') operation = 'encrypt' operation_past = 'encrypted' operation_progressive = 'encrypting' @@ -215,8 +213,7 @@ class DecryptOperation(CryptoOperation): keyname = 'private' description = ('Decrypts a file. The original file must be shorter than ' - 'the key length in order to have been encrypted. For larger ' - 'files, use the pyrsa-decrypt-bigfile command.') + 'the key length in order to have been encrypted.') operation = 'decrypt' operation_past = 'decrypted' operation_progressive = 'decrypting' @@ -285,99 +282,7 @@ class VerifyOperation(CryptoOperation): print('Verification OK', file=sys.stderr) -class BigfileOperation(CryptoOperation): - """CryptoOperation that doesn't read the entire file into memory.""" - - def __init__(self): - CryptoOperation.__init__(self) - - self.file_objects = [] - - def __del__(self): - """Closes any open file handles.""" - - for fobj in self.file_objects: - fobj.close() - - def __call__(self): - """Runs the program.""" - - (cli, cli_args) = self.parse_cli() - - key = self.read_key(cli_args[0], cli.keyform) - - # Get the file handles - infile = self.get_infile(cli.input) - outfile = self.get_outfile(cli.output) - - # Call the operation - print(self.operation_progressive.title(), file=sys.stderr) - self.perform_operation(infile, outfile, key, cli_args) - - def get_infile(self, inname): - """Returns the input file object""" - - if inname: - print('Reading input from %s' % inname, file=sys.stderr) - fobj = open(inname, 'rb') - self.file_objects.append(fobj) - else: - print('Reading input from stdin', file=sys.stderr) - fobj = sys.stdin - - return fobj - - def get_outfile(self, outname): - """Returns the output file object""" - - if outname: - print('Will write output to %s' % outname, file=sys.stderr) - fobj = open(outname, 'wb') - self.file_objects.append(fobj) - else: - print('Will write output to stdout', file=sys.stderr) - fobj = sys.stdout - - return fobj - - -class EncryptBigfileOperation(BigfileOperation): - """Encrypts a file to VARBLOCK format.""" - - keyname = 'public' - description = ('Encrypts a file to an encrypted VARBLOCK file. The file ' - 'can be larger than the key length, but the output file is only ' - 'compatible with Python-RSA.') - operation = 'encrypt' - operation_past = 'encrypted' - operation_progressive = 'encrypting' - - def perform_operation(self, infile, outfile, pub_key, cli_args=None): - """Encrypts files to VARBLOCK.""" - - return rsa.bigfile.encrypt_bigfile(infile, outfile, pub_key) - - -class DecryptBigfileOperation(BigfileOperation): - """Decrypts a file in VARBLOCK format.""" - - keyname = 'private' - description = ('Decrypts an encrypted VARBLOCK file that was encrypted ' - 'with pyrsa-encrypt-bigfile') - operation = 'decrypt' - operation_past = 'decrypted' - operation_progressive = 'decrypting' - key_class = rsa.PrivateKey - - def perform_operation(self, infile, outfile, priv_key, cli_args=None): - """Decrypts a VARBLOCK file.""" - - return rsa.bigfile.decrypt_bigfile(infile, outfile, priv_key) - - encrypt = EncryptOperation() decrypt = DecryptOperation() sign = SignOperation() verify = VerifyOperation() -encrypt_bigfile = EncryptBigfileOperation() -decrypt_bigfile = DecryptBigfileOperation() diff --git a/src/rsa/common.py b/src/rsa/common.py index e0743340..f7aa2d14 100644 --- a/src/rsa/common.py +++ b/src/rsa/common.py @@ -14,17 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +from rsa._compat import zip + """Common functionality shared by several modules.""" +class NotRelativePrimeError(ValueError): + def __init__(self, a, b, d, msg=None): + super(NotRelativePrimeError, self).__init__( + msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d)) + self.a = a + self.b = b + self.d = d + + def bit_size(num): """ Number of bits needed to represent a integer excluding any prefix 0 bits. - As per definition from https://wiki.python.org/moin/BitManipulation and - to match the behavior of the Python 3 API. - Usage:: >>> bit_size(1023) @@ -41,41 +49,11 @@ def bit_size(num): :returns: Returns the number of bits in the integer. """ - if num == 0: - return 0 - if num < 0: - num = -num - # Make sure this is an int and not a float. - num & 1 - - hex_num = "%x" % num - return ((len(hex_num) - 1) * 4) + { - '0': 0, '1': 1, '2': 2, '3': 2, - '4': 3, '5': 3, '6': 3, '7': 3, - '8': 4, '9': 4, 'a': 4, 'b': 4, - 'c': 4, 'd': 4, 'e': 4, 'f': 4, - }[hex_num[0]] - - -def _bit_size(number): - """ - Returns the number of bits required to hold a specific long number. - """ - if number < 0: - raise ValueError('Only nonnegative numbers possible: %s' % number) - - if number == 0: - return 0 - - # This works, even with very large numbers. When using math.log(number, 2), - # you'll get rounding errors and it'll fail. - bits = 0 - while number: - bits += 1 - number >>= 1 - - return bits + try: + return num.bit_length() + except AttributeError: + raise TypeError('bit_size(num) only supports integers, not %r' % type(num)) def byte_size(number): @@ -98,11 +76,33 @@ def byte_size(number): :returns: The number of bytes required to hold a specific long number. """ - quanta, mod = divmod(bit_size(number), 8) - if mod or number == 0: + if number == 0: + return 1 + return ceil_div(bit_size(number), 8) + + +def ceil_div(num, div): + """ + Returns the ceiling function of a division between `num` and `div`. + + Usage:: + + >>> ceil_div(100, 7) + 15 + >>> ceil_div(100, 10) + 10 + >>> ceil_div(1, 4) + 1 + + :param num: Division's numerator, a number + :param div: Division's divisor, a number + + :return: Rounded up result of the division between the parameters. + """ + quanta, mod = divmod(num, div) + if mod: quanta += 1 return quanta - # return int(math.ceil(bit_size(number) / 8.0)) def extended_gcd(a, b): @@ -131,7 +131,7 @@ def extended_gcd(a, b): def inverse(x, n): - """Returns x^-1 (mod n) + """Returns the inverse of x % n under multiplication, a.k.a x^-1 (mod n) >>> inverse(7, 4) 3 @@ -142,7 +142,7 @@ def inverse(x, n): (divider, inv, _) = extended_gcd(x, n) if divider != 1: - raise ValueError("x (%d) and n (%d) are not relatively prime" % (x, n)) + raise NotRelativePrimeError(x, n, divider) return inv diff --git a/src/rsa/key.py b/src/rsa/key.py index 64600a27..1004412b 100644 --- a/src/rsa/key.py +++ b/src/rsa/key.py @@ -34,14 +34,16 @@ of pyasn1. """ import logging -from rsa._compat import b +import warnings +from rsa._compat import range import rsa.prime import rsa.pem import rsa.common import rsa.randnum import rsa.core + log = logging.getLogger(__name__) DEFAULT_EXPONENT = 65537 @@ -55,15 +57,56 @@ class AbstractKey(object): self.n = n self.e = e + @classmethod + def _load_pkcs1_pem(cls, keyfile): + """Loads a key in PKCS#1 PEM format, implement in a subclass. + + :param keyfile: contents of a PEM-encoded file that contains + the public key. + :type keyfile: bytes + + :return: the loaded key + :rtype: AbstractKey + """ + + @classmethod + def _load_pkcs1_der(cls, keyfile): + """Loads a key in PKCS#1 PEM format, implement in a subclass. + + :param keyfile: contents of a DER-encoded file that contains + the public key. + :type keyfile: bytes + + :return: the loaded key + :rtype: AbstractKey + """ + + def _save_pkcs1_pem(self): + """Saves the key in PKCS#1 PEM format, implement in a subclass. + + :returns: the PEM-encoded key. + :rtype: bytes + """ + + def _save_pkcs1_der(self): + """Saves the key in PKCS#1 DER format, implement in a subclass. + + :returns: the DER-encoded key. + :rtype: bytes + """ + @classmethod def load_pkcs1(cls, keyfile, format='PEM'): """Loads a key in PKCS#1 DER or PEM format. :param keyfile: contents of a DER- or PEM-encoded file that contains - the public key. + the key. + :type keyfile: bytes :param format: the format of the file to load; 'PEM' or 'DER' + :type format: str - :return: a PublicKey object + :return: the loaded key + :rtype: AbstractKey """ methods = { @@ -87,10 +130,12 @@ class AbstractKey(object): formats)) def save_pkcs1(self, format='PEM'): - """Saves the public key in PKCS#1 DER or PEM format. + """Saves the key in PKCS#1 DER or PEM format. :param format: the format to save; 'PEM' or 'DER' - :returns: the DER- or PEM-encoded public key. + :type format: str + :returns: the DER- or PEM-encoded key. + :rtype: bytes """ methods = { @@ -139,7 +184,7 @@ class PublicKey(AbstractKey): This key is also known as the 'encryption key'. It contains the 'n' and 'e' values. - Supports attributes as well as dictionary-like access. Attribute accesss is + Supports attributes as well as dictionary-like access. Attribute access is faster, though. >>> PublicKey(5, 3) @@ -185,6 +230,9 @@ class PublicKey(AbstractKey): def __ne__(self, other): return not (self == other) + def __hash__(self): + return hash((self.n, self.e)) + @classmethod def _load_pkcs1_der(cls, keyfile): """Loads a key in PKCS#1 DER format. @@ -215,7 +263,8 @@ class PublicKey(AbstractKey): def _save_pkcs1_der(self): """Saves the public key in PKCS#1 DER format. - @returns: the DER-encoded public key. + :returns: the DER-encoded public key. + :rtype: bytes """ from pyasn1.codec.der import encoder @@ -247,6 +296,7 @@ class PublicKey(AbstractKey): """Saves a PKCS#1 PEM-encoded public key file. :return: contents of a PEM-encoded file that contains the public key. + :rtype: bytes """ der = self._save_pkcs1_der() @@ -264,6 +314,7 @@ class PublicKey(AbstractKey): :param keyfile: contents of a PEM-encoded file that contains the public key, from OpenSSL. + :type keyfile: bytes :return: a PublicKey object """ @@ -277,6 +328,7 @@ class PublicKey(AbstractKey): :param keyfile: contents of a DER-encoded file that contains the public key, from OpenSSL. :return: a PublicKey object + :rtype: bytes """ @@ -298,57 +350,36 @@ class PrivateKey(AbstractKey): This key is also known as the 'decryption key'. It contains the 'n', 'e', 'd', 'p', 'q' and other values. - Supports attributes as well as dictionary-like access. Attribute accesss is + Supports attributes as well as dictionary-like access. Attribute access is faster, though. >>> PrivateKey(3247, 65537, 833, 191, 17) PrivateKey(3247, 65537, 833, 191, 17) - exp1, exp2 and coef can be given, but if None or omitted they will be calculated: + exp1, exp2 and coef will be calculated: - >>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287, exp2=4) + >>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287) >>> pk.exp1 55063 - >>> pk.exp2 # this is of course not a correct value, but it is the one we passed. - 4 + >>> pk.exp2 + 10095 >>> pk.coef 50797 - If you give exp1, exp2 or coef, they will be used as-is: - - >>> pk = PrivateKey(1, 2, 3, 4, 5, 6, 7, 8) - >>> pk.exp1 - 6 - >>> pk.exp2 - 7 - >>> pk.coef - 8 - """ __slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef') - def __init__(self, n, e, d, p, q, exp1=None, exp2=None, coef=None): + def __init__(self, n, e, d, p, q): AbstractKey.__init__(self, n, e) self.d = d self.p = p self.q = q - # Calculate the other values if they aren't supplied - if exp1 is None: - self.exp1 = int(d % (p - 1)) - else: - self.exp1 = exp1 - - if exp2 is None: - self.exp2 = int(d % (q - 1)) - else: - self.exp2 = exp2 - - if coef is None: - self.coef = rsa.common.inverse(q, p) - else: - self.coef = coef + # Calculate exponents and coefficient. + self.exp1 = int(d % (p - 1)) + self.exp2 = int(d % (q - 1)) + self.coef = rsa.common.inverse(q, p) def __getitem__(self, key): return getattr(self, key) @@ -383,6 +414,9 @@ class PrivateKey(AbstractKey): def __ne__(self, other): return not (self == other) + def __hash__(self): + return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef)) + def blinded_decrypt(self, encrypted): """Decrypts the message using blinding to prevent side-channel attacks. @@ -420,6 +454,7 @@ class PrivateKey(AbstractKey): :param keyfile: contents of a DER-encoded file that contains the private key. + :type keyfile: bytes :return: a PrivateKey object First let's construct a DER encoded key: @@ -456,13 +491,26 @@ class PrivateKey(AbstractKey): if priv[0] != 0: raise ValueError('Unable to read this file, version %s != 0' % priv[0]) - as_ints = tuple(int(x) for x in priv[1:9]) - return cls(*as_ints) + as_ints = map(int, priv[1:6]) + key = cls(*as_ints) + + exp1, exp2, coef = map(int, priv[6:9]) + + if (key.exp1, key.exp2, key.coef) != (exp1, exp2, coef): + warnings.warn( + 'You have provided a malformed keyfile. Either the exponents ' + 'or the coefficient are incorrect. Using the correct values ' + 'instead.', + UserWarning, + ) + + return key def _save_pkcs1_der(self): """Saves the private key in PKCS#1 DER format. - @returns: the DER-encoded private key. + :returns: the DER-encoded private key. + :rtype: bytes """ from pyasn1.type import univ, namedtype @@ -470,15 +518,15 @@ class PrivateKey(AbstractKey): class AsnPrivKey(univ.Sequence): componentType = namedtype.NamedTypes( - namedtype.NamedType('version', univ.Integer()), - namedtype.NamedType('modulus', univ.Integer()), - namedtype.NamedType('publicExponent', univ.Integer()), - namedtype.NamedType('privateExponent', univ.Integer()), - namedtype.NamedType('prime1', univ.Integer()), - namedtype.NamedType('prime2', univ.Integer()), - namedtype.NamedType('exponent1', univ.Integer()), - namedtype.NamedType('exponent2', univ.Integer()), - namedtype.NamedType('coefficient', univ.Integer()), + namedtype.NamedType('version', univ.Integer()), + namedtype.NamedType('modulus', univ.Integer()), + namedtype.NamedType('publicExponent', univ.Integer()), + namedtype.NamedType('privateExponent', univ.Integer()), + namedtype.NamedType('prime1', univ.Integer()), + namedtype.NamedType('prime2', univ.Integer()), + namedtype.NamedType('exponent1', univ.Integer()), + namedtype.NamedType('exponent2', univ.Integer()), + namedtype.NamedType('coefficient', univ.Integer()), ) # Create the ASN object @@ -504,20 +552,22 @@ class PrivateKey(AbstractKey): :param keyfile: contents of a PEM-encoded file that contains the private key. + :type keyfile: bytes :return: a PrivateKey object """ - der = rsa.pem.load_pem(keyfile, b('RSA PRIVATE KEY')) + der = rsa.pem.load_pem(keyfile, b'RSA PRIVATE KEY') return cls._load_pkcs1_der(der) def _save_pkcs1_pem(self): """Saves a PKCS#1 PEM-encoded private key file. :return: contents of a PEM-encoded file that contains the private key. + :rtype: bytes """ der = self._save_pkcs1_der() - return rsa.pem.save_pem(der, b('RSA PRIVATE KEY')) + return rsa.pem.save_pem(der, b'RSA PRIVATE KEY') def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True): @@ -615,9 +665,11 @@ def calculate_keys_custom_exponent(p, q, exponent): try: d = rsa.common.inverse(exponent, phi_n) - except ValueError: - raise ValueError("e (%d) and phi_n (%d) are not relatively prime" % - (exponent, phi_n)) + except rsa.common.NotRelativePrimeError as ex: + raise rsa.common.NotRelativePrimeError( + exponent, phi_n, ex.d, + msg="e (%d) and phi_n (%d) are not relatively prime (divider=%i)" % + (exponent, phi_n, ex.d)) if (exponent * d) % phi_n != 1: raise ValueError("e (%d) and d (%d) are not mult. inv. modulo " @@ -731,7 +783,7 @@ if __name__ == '__main__': if failures: break - if (count and count % 10 == 0) or count == 1: + if (count % 10 == 0 and count) or count == 1: print('%i times' % count) except KeyboardInterrupt: print('Aborted') diff --git a/src/rsa/machine_size.py b/src/rsa/machine_size.py new file mode 100644 index 00000000..2a871b8f --- /dev/null +++ b/src/rsa/machine_size.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Detection of 32-bit and 64-bit machines and byte alignment.""" + +import sys + +MAX_INT = sys.maxsize +MAX_INT64 = (1 << 63) - 1 +MAX_INT32 = (1 << 31) - 1 +MAX_INT16 = (1 << 15) - 1 + +# Determine the word size of the processor. +if MAX_INT == MAX_INT64: + # 64-bit processor. + MACHINE_WORD_SIZE = 64 +elif MAX_INT == MAX_INT32: + # 32-bit processor. + MACHINE_WORD_SIZE = 32 +else: + # Else we just assume 64-bit processor keeping up with modern times. + MACHINE_WORD_SIZE = 64 + + +def get_word_alignment(num, force_arch=64, + _machine_word_size=MACHINE_WORD_SIZE): + """ + Returns alignment details for the given number based on the platform + Python is running on. + + :param num: + Unsigned integral number. + :param force_arch: + If you don't want to use 64-bit unsigned chunks, set this to + anything other than 64. 32-bit chunks will be preferred then. + Default 64 will be used when on a 64-bit machine. + :param _machine_word_size: + (Internal) The machine word size used for alignment. + :returns: + 4-tuple:: + + (word_bits, word_bytes, + max_uint, packing_format_type) + """ + max_uint64 = 0xffffffffffffffff + max_uint32 = 0xffffffff + max_uint16 = 0xffff + max_uint8 = 0xff + + if force_arch == 64 and _machine_word_size >= 64 and num > max_uint32: + # 64-bit unsigned integer. + return 64, 8, max_uint64, "Q" + elif num > max_uint16: + # 32-bit unsigned integer + return 32, 4, max_uint32, "L" + elif num > max_uint8: + # 16-bit unsigned integer. + return 16, 2, max_uint16, "H" + else: + # 8-bit unsigned integer. + return 8, 1, max_uint8, "B" diff --git a/src/rsa/parallel.py b/src/rsa/parallel.py index edc924fd..a3fe3122 100644 --- a/src/rsa/parallel.py +++ b/src/rsa/parallel.py @@ -28,6 +28,7 @@ from __future__ import print_function import multiprocessing as mp +from rsa._compat import range import rsa.prime import rsa.randnum @@ -94,7 +95,7 @@ if __name__ == '__main__': if failures: break - if count and count % 10 == 0: + if count % 10 == 0 and count: print('%i times' % count) print('Doctests done') diff --git a/src/rsa/pem.py b/src/rsa/pem.py index 0f68cb2a..2ddfae86 100644 --- a/src/rsa/pem.py +++ b/src/rsa/pem.py @@ -17,19 +17,20 @@ """Functions that load and write PEM-encoded files.""" import base64 -from rsa._compat import b, is_bytes + +from rsa._compat import is_bytes, range def _markers(pem_marker): """ - Returns the start and end PEM markers + Returns the start and end PEM markers, as bytes. """ - if is_bytes(pem_marker): - pem_marker = pem_marker.decode('utf-8') + if not is_bytes(pem_marker): + pem_marker = pem_marker.encode('ascii') - return (b('-----BEGIN %s-----' % pem_marker), - b('-----END %s-----' % pem_marker)) + return (b'-----BEGIN ' + pem_marker + b'-----', + b'-----END ' + pem_marker + b'-----') def load_pem(contents, pem_marker): @@ -81,7 +82,7 @@ def load_pem(contents, pem_marker): break # Load fields - if b(':') in line: + if b':' in line: continue pem_lines.append(line) @@ -94,7 +95,7 @@ def load_pem(contents, pem_marker): raise ValueError('No PEM end marker "%s" found' % pem_end) # Base64-decode the contents - pem = b('').join(pem_lines) + pem = b''.join(pem_lines) return base64.standard_b64decode(pem) @@ -106,13 +107,13 @@ def save_pem(contents, pem_marker): when your file has '-----BEGIN RSA PRIVATE KEY-----' and '-----END RSA PRIVATE KEY-----' markers. - :return: the base64-encoded content between the start and end markers. + :return: the base64-encoded content between the start and end markers, as bytes. """ (pem_start, pem_end) = _markers(pem_marker) - b64 = base64.standard_b64encode(contents).replace(b('\n'), b('')) + b64 = base64.standard_b64encode(contents).replace(b'\n', b'') pem_lines = [pem_start] for block_start in range(0, len(b64), 64): @@ -120,6 +121,6 @@ def save_pem(contents, pem_marker): pem_lines.append(block) pem_lines.append(pem_end) - pem_lines.append(b('')) + pem_lines.append(b'') - return b('\n').join(pem_lines) + return b'\n'.join(pem_lines) diff --git a/src/rsa/pkcs1.py b/src/rsa/pkcs1.py index 28f0dc54..84f0e3b6 100644 --- a/src/rsa/pkcs1.py +++ b/src/rsa/pkcs1.py @@ -31,21 +31,23 @@ to your users. import hashlib import os -from rsa._compat import b +from rsa._compat import range from rsa import common, transform, core # ASN.1 codes that describe the hash algorithm used. HASH_ASN1 = { - 'MD5': b('\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10'), - 'SHA-1': b('\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'), - 'SHA-256': b('\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'), - 'SHA-384': b('\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30'), - 'SHA-512': b('\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'), + 'MD5': b'\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10', + 'SHA-1': b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14', + 'SHA-224': b'\x30\x2d\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x04\x05\x00\x04\x1c', + 'SHA-256': b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20', + 'SHA-384': b'\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30', + 'SHA-512': b'\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40', } HASH_METHODS = { 'MD5': hashlib.md5, 'SHA-1': hashlib.sha1, + 'SHA-224': hashlib.sha224, 'SHA-256': hashlib.sha256, 'SHA-384': hashlib.sha384, 'SHA-512': hashlib.sha512, @@ -87,7 +89,7 @@ def _pad_for_encryption(message, target_length): ' space for %i' % (msglength, max_msglength)) # Get random padding - padding = b('') + padding = b'' padding_length = target_length - msglength - 3 # We remove 0-bytes, so we'll end up with less padding than we've asked for, @@ -99,15 +101,15 @@ def _pad_for_encryption(message, target_length): # after removing the 0-bytes. This increases the chance of getting # enough bytes, especially when needed_bytes is small new_padding = os.urandom(needed_bytes + 5) - new_padding = new_padding.replace(b('\x00'), b('')) + new_padding = new_padding.replace(b'\x00', b'') padding = padding + new_padding[:needed_bytes] assert len(padding) == padding_length - return b('').join([b('\x00\x02'), - padding, - b('\x00'), - message]) + return b''.join([b'\x00\x02', + padding, + b'\x00', + message]) def _pad_for_signing(message, target_length): @@ -138,10 +140,10 @@ def _pad_for_signing(message, target_length): padding_length = target_length - msglength - 3 - return b('').join([b('\x00\x01'), - padding_length * b('\xff'), - b('\x00'), - message]) + return b''.join([b'\x00\x01', + padding_length * b'\xff', + b'\x00', + message]) def encrypt(message, pub_key): @@ -233,19 +235,53 @@ def decrypt(crypto, priv_key): cleartext = transform.int2bytes(decrypted, blocksize) # If we can't find the cleartext marker, decryption failed. - if cleartext[0:2] != b('\x00\x02'): + if cleartext[0:2] != b'\x00\x02': raise DecryptionError('Decryption failed') # Find the 00 separator between the padding and the message try: - sep_idx = cleartext.index(b('\x00'), 2) + sep_idx = cleartext.index(b'\x00', 2) except ValueError: raise DecryptionError('Decryption failed') return cleartext[sep_idx + 1:] -def sign(message, priv_key, hash): +def sign_hash(hash_value, priv_key, hash_method): + """Signs a precomputed hash with the private key. + + Hashes the message, then signs the hash with the given key. This is known + as a "detached signature", because the message itself isn't altered. + + :param hash_value: A precomputed hash to sign (ignores message). Should be set to + None if needing to hash and sign message. + :param priv_key: the :py:class:`rsa.PrivateKey` to sign with + :param hash_method: the hash method used on the message. Use 'MD5', 'SHA-1', + 'SHA-224', SHA-256', 'SHA-384' or 'SHA-512'. + :return: a message signature block. + :raise OverflowError: if the private key is too small to contain the + requested hash. + + """ + + # Get the ASN1 code for this hash method + if hash_method not in HASH_ASN1: + raise ValueError('Invalid hash method: %s' % hash_method) + asn1code = HASH_ASN1[hash_method] + + # Encrypt the hash with the private key + cleartext = asn1code + hash_value + keylength = common.byte_size(priv_key.n) + padded = _pad_for_signing(cleartext, keylength) + + payload = transform.bytes2int(padded) + encrypted = priv_key.blinded_encrypt(payload) + block = transform.int2bytes(encrypted, keylength) + + return block + + +def sign(message, priv_key, hash_method): """Signs the message with the private key. Hashes the message, then signs the hash with the given key. This is known @@ -255,32 +291,16 @@ def sign(message, priv_key, hash): object. If ``message`` has a ``read()`` method, it is assumed to be a file-like object. :param priv_key: the :py:class:`rsa.PrivateKey` to sign with - :param hash: the hash method used on the message. Use 'MD5', 'SHA-1', - 'SHA-256', 'SHA-384' or 'SHA-512'. + :param hash_method: the hash method used on the message. Use 'MD5', 'SHA-1', + 'SHA-224', SHA-256', 'SHA-384' or 'SHA-512'. :return: a message signature block. :raise OverflowError: if the private key is too small to contain the requested hash. """ - # Get the ASN1 code for this hash method - if hash not in HASH_ASN1: - raise ValueError('Invalid hash method: %s' % hash) - asn1code = HASH_ASN1[hash] - - # Calculate the hash - hash = _hash(message, hash) - - # Encrypt the hash with the private key - cleartext = asn1code + hash - keylength = common.byte_size(priv_key.n) - padded = _pad_for_signing(cleartext, keylength) - - payload = transform.bytes2int(padded) - encrypted = priv_key.blinded_encrypt(payload) - block = transform.int2bytes(encrypted, keylength) - - return block + msg_hash = compute_hash(message, hash_method) + return sign_hash(msg_hash, priv_key, hash_method) def verify(message, signature, pub_key): @@ -294,6 +314,7 @@ def verify(message, signature, pub_key): :param signature: the signature block, as created with :py:func:`rsa.sign`. :param pub_key: the :py:class:`rsa.PublicKey` of the person signing the message. :raise VerificationError: when the signature doesn't match the message. + :returns: the name of the used hash. """ @@ -304,7 +325,7 @@ def verify(message, signature, pub_key): # Get the hash method method_name = _find_method_hash(clearsig) - message_hash = _hash(message, method_name) + message_hash = compute_hash(message, method_name) # Reconstruct the expected padded hash cleartext = HASH_ASN1[method_name] + message_hash @@ -314,10 +335,50 @@ def verify(message, signature, pub_key): if expected != clearsig: raise VerificationError('Verification failed') - return True + return method_name -def _hash(message, method_name): +def find_signature_hash(signature, pub_key): + """Returns the hash name detected from the signature. + + If you also want to verify the message, use :py:func:`rsa.verify()` instead. + It also returns the name of the used hash. + + :param signature: the signature block, as created with :py:func:`rsa.sign`. + :param pub_key: the :py:class:`rsa.PublicKey` of the person signing the message. + :returns: the name of the used hash. + """ + + keylength = common.byte_size(pub_key.n) + encrypted = transform.bytes2int(signature) + decrypted = core.decrypt_int(encrypted, pub_key.e, pub_key.n) + clearsig = transform.int2bytes(decrypted, keylength) + + return _find_method_hash(clearsig) + + +def yield_fixedblocks(infile, blocksize): + """Generator, yields each block of ``blocksize`` bytes in the input file. + + :param infile: file to read and separate in blocks. + :param blocksize: block size in bytes. + :returns: a generator that yields the contents of each block + """ + + while True: + block = infile.read(blocksize) + + read_bytes = len(block) + if read_bytes == 0: + break + + yield block + + if read_bytes < blocksize: + break + + +def compute_hash(message, method_name): """Returns the message digest. :param message: the signed message. Can be an 8-bit string or a file-like @@ -335,11 +396,8 @@ def _hash(message, method_name): hasher = method() if hasattr(message, 'read') and hasattr(message.read, '__call__'): - # Late import to prevent DeprecationWarnings. - from . import varblock - # read as 1K blocks - for block in varblock.yield_fixedblocks(message, 1024): + for block in yield_fixedblocks(message, 1024): hasher.update(block) else: # hash the message object itself. @@ -375,7 +433,7 @@ if __name__ == '__main__': if failures: break - if count and count % 100 == 0: + if count % 100 == 0 and count: print('%i times' % count) print('Doctests done') diff --git a/src/rsa/pkcs1_v2.py b/src/rsa/pkcs1_v2.py new file mode 100644 index 00000000..5f9c7ddc --- /dev/null +++ b/src/rsa/pkcs1_v2.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for PKCS#1 version 2 encryption and signing + +This module implements certain functionality from PKCS#1 version 2. Main +documentation is RFC 2437: https://tools.ietf.org/html/rfc2437 +""" + +from rsa._compat import range +from rsa import ( + common, + pkcs1, + transform, +) + + +def mgf1(seed, length, hasher='SHA-1'): + """ + MGF1 is a Mask Generation Function based on a hash function. + + A mask generation function takes an octet string of variable length and a + desired output length as input, and outputs an octet string of the desired + length. The plaintext-awareness of RSAES-OAEP relies on the random nature of + the output of the mask generation function, which in turn relies on the + random nature of the underlying hash. + + :param bytes seed: seed from which mask is generated, an octet string + :param int length: intended length in octets of the mask, at most 2^32(hLen) + :param str hasher: hash function (hLen denotes the length in octets of the hash + function output) + + :return: mask, an octet string of length `length` + :rtype: bytes + + :raise OverflowError: when `length` is too large for the specified `hasher` + :raise ValueError: when specified `hasher` is invalid + """ + + try: + hash_length = pkcs1.HASH_METHODS[hasher]().digest_size + except KeyError: + raise ValueError( + 'Invalid `hasher` specified. Please select one of: {hash_list}'.format( + hash_list=', '.join(sorted(pkcs1.HASH_METHODS.keys())) + ) + ) + + # If l > 2^32(hLen), output "mask too long" and stop. + if length > (2**32 * hash_length): + raise OverflowError( + "Desired length should be at most 2**32 times the hasher's output " + "length ({hash_length} for {hasher} function)".format( + hash_length=hash_length, + hasher=hasher, + ) + ) + + # Looping `counter` from 0 to ceil(l / hLen)-1, build `output` based on the + # hashes formed by (`seed` + C), being `C` an octet string of length 4 + # generated by converting `counter` with the primitive I2OSP + output = b''.join( + pkcs1.compute_hash( + seed + transform.int2bytes(counter, fill_size=4), + method_name=hasher, + ) + for counter in range(common.ceil_div(length, hash_length) + 1) + ) + + # Output the leading `length` octets of `output` as the octet string mask. + return output[:length] + + +__all__ = [ + 'mgf1', +] + +if __name__ == '__main__': + print('Running doctests 1000x or until failure') + import doctest + + for count in range(1000): + (failures, tests) = doctest.testmod() + if failures: + break + + if count % 100 == 0 and count: + print('%i times' % count) + + print('Doctests done') diff --git a/src/rsa/prime.py b/src/rsa/prime.py index 6f23f9da..3d63542e 100644 --- a/src/rsa/prime.py +++ b/src/rsa/prime.py @@ -20,6 +20,8 @@ Implementation based on the book Algorithm Design by Michael T. Goodrich and Roberto Tamassia, 2002. """ +from rsa._compat import range +import rsa.common import rsa.randnum __all__ = ['getprime', 'are_relatively_prime'] @@ -37,6 +39,32 @@ def gcd(p, q): return p +def get_primality_testing_rounds(number): + """Returns minimum number of rounds for Miller-Rabing primality testing, + based on number bitsize. + + According to NIST FIPS 186-4, Appendix C, Table C.3, minimum number of + rounds of M-R testing, using an error probability of 2 ** (-100), for + different p, q bitsizes are: + * p, q bitsize: 512; rounds: 7 + * p, q bitsize: 1024; rounds: 4 + * p, q bitsize: 1536; rounds: 3 + See: http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf + """ + + # Calculate number bitsize. + bitsize = rsa.common.bit_size(number) + # Set number of rounds. + if bitsize >= 1536: + return 3 + if bitsize >= 1024: + return 4 + if bitsize >= 512: + return 7 + # For smaller bitsizes, set arbitrary number of rounds. + return 10 + + def miller_rabin_primality_testing(n, k): """Calculates whether n is composite (which is always correct) or prime (which theoretically is incorrect with error probability 4**-k), by @@ -69,7 +97,7 @@ def miller_rabin_primality_testing(n, k): # Test k witnesses. for _ in range(k): # Generate random integer a, where 2 <= a <= (n - 2) - a = rsa.randnum.randint(n - 4) + 2 + a = rsa.randnum.randint(n - 3) + 1 x = pow(a, d, n) if x == 1 or x == n - 1: @@ -99,26 +127,21 @@ def is_prime(number): False >>> is_prime(41) True - >>> [x for x in range(901, 1000) if is_prime(x)] - [907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997] """ # Check for small numbers. if number < 10: - return number in [2, 3, 5, 7] + return number in {2, 3, 5, 7} # Check for even numbers. if not (number & 1): return False - # According to NIST FIPS 186-4, Appendix C, Table C.3, minimum number of - # rounds of M-R testing, using an error probability of 2 ** (-100), for - # different p, q bitsizes are: - # * p, q bitsize: 512; rounds: 7 - # * p, q bitsize: 1024; rounds: 4 - # * p, q bitsize: 1536; rounds: 3 - # See: http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf - return miller_rabin_primality_testing(number, 7) + # Calculate minimum number of rounds. + k = get_primality_testing_rounds(number) + + # Run primality testing with (minimum + 1) rounds. + return miller_rabin_primality_testing(number, k + 1) def getprime(nbits): @@ -172,7 +195,7 @@ if __name__ == '__main__': if failures: break - if count and count % 100 == 0: + if count % 100 == 0 and count: print('%i times' % count) print('Doctests done') diff --git a/src/rsa/randnum.py b/src/rsa/randnum.py index 3c788a57..310acaa6 100644 --- a/src/rsa/randnum.py +++ b/src/rsa/randnum.py @@ -88,7 +88,7 @@ def randint(maxvalue): if value <= maxvalue: break - if tries and tries % 10 == 0: + if tries % 10 == 0 and tries: # After a lot of tries to get the right number of bits but still # smaller than maxvalue, decrease the number of bits by 1. That'll # dramatically increase the chances to get a large enough number. diff --git a/src/rsa/transform.py b/src/rsa/transform.py index 16061a94..628d0afb 100644 --- a/src/rsa/transform.py +++ b/src/rsa/transform.py @@ -21,20 +21,11 @@ From bytes to a number, number to bytes, etc. from __future__ import absolute_import -try: - # We'll use psyco if available on 32-bit architectures to speed up code. - # Using psyco (if available) cuts down the execution time on Python 2.5 - # at least by half. - import psyco - - psyco.full() -except ImportError: - pass - import binascii from struct import pack -from rsa import common -from rsa._compat import is_integer, b, byte, get_word_alignment, ZERO_BYTE, EMPTY_BYTE + +from rsa._compat import byte, is_integer +from rsa import common, machine_size def bytes2int(raw_bytes): @@ -92,7 +83,7 @@ def _int2bytes(number, block_size=None): # Do some bounds checking if number == 0: needed_bytes = 1 - raw_bytes = [ZERO_BYTE] + raw_bytes = [b'\x00'] else: needed_bytes = common.byte_size(number) raw_bytes = [] @@ -110,14 +101,14 @@ def _int2bytes(number, block_size=None): # Pad with zeroes to fill the block if block_size and block_size > 0: - padding = (block_size - needed_bytes) * ZERO_BYTE + padding = (block_size - needed_bytes) * b'\x00' else: - padding = EMPTY_BYTE + padding = b'' - return padding + EMPTY_BYTE.join(raw_bytes) + return padding + b''.join(raw_bytes) -def bytes_leading(raw_bytes, needle=ZERO_BYTE): +def bytes_leading(raw_bytes, needle=b'\x00'): """ Finds the number of prefixed byte occurrences in the haystack. @@ -126,7 +117,7 @@ def bytes_leading(raw_bytes, needle=ZERO_BYTE): :param raw_bytes: Raw bytes. :param needle: - The byte to count. Default \000. + The byte to count. Default \x00. :returns: The number of leading needle bytes. """ @@ -186,11 +177,11 @@ def int2bytes(number, fill_size=None, chunk_size=None, overflow=False): # Ensure these are integers. number & 1 - raw_bytes = b('') + raw_bytes = b'' # Pack the integer one machine word at a time into bytes. num = number - word_bits, _, max_uint, pack_type = get_word_alignment(num) + word_bits, _, max_uint, pack_type = machine_size.get_word_alignment(num) pack_format = ">%s" % pack_type while num > 0: raw_bytes = pack(pack_format, num & max_uint) + raw_bytes @@ -198,7 +189,7 @@ def int2bytes(number, fill_size=None, chunk_size=None, overflow=False): # Obtain the index of the first non-zero byte. zero_leading = bytes_leading(raw_bytes) if number == 0: - raw_bytes = ZERO_BYTE + raw_bytes = b'\x00' # De-padding. raw_bytes = raw_bytes[zero_leading:] @@ -209,12 +200,12 @@ def int2bytes(number, fill_size=None, chunk_size=None, overflow=False): "Need %d bytes for number, but fill size is %d" % (length, fill_size) ) - raw_bytes = raw_bytes.rjust(fill_size, ZERO_BYTE) + raw_bytes = raw_bytes.rjust(fill_size, b'\x00') elif chunk_size and chunk_size > 0: remainder = length % chunk_size if remainder: padding_size = chunk_size - remainder - raw_bytes = raw_bytes.rjust(length + padding_size, ZERO_BYTE) + raw_bytes = raw_bytes.rjust(length + padding_size, b'\x00') return raw_bytes