From 71da849ba9a880fc4dfaff6a6eff0b87fcd9f915 Mon Sep 17 00:00:00 2001 From: ejochman <34144949+ejochman@users.noreply.github.com> Date: Thu, 23 Jan 2020 09:07:55 -0800 Subject: [PATCH] Move transport customizations to their own module (#1071) * Move transport customizations to their own module This helps to accomplish a few things: 1) Makes the forced user-agent customization on HTTP requests a bit clearer by subclassing the targeted objects (as opposed to hiding the behavior behind a forced override of the google_auth_httplib2 object methods) 2) Standardizes the creation of HTTP objects. These objects can still largely be customized, but using a single creation mechanism will standardize a default and streamline creation, thereby decreasing the code that would otherwise be replicated in the caller 3) Moves create_http() to a more general purpose module, since it will likely be used by more than gapi-related methods. * Use string values for TLS version tests More closely matches [existing behavior](https://github.com/jay0lee/GAM/blob/4fb73e6073bb33034f4ea2221548a471a7829b06/src/var.py#L853) that sets the global default to a string value. --- src/gam.py | 67 +++++--------- src/gapi/__init__.py | 99 ++++++++------------- src/gapi/__init___test.py | 34 -------- src/transport.py | 100 +++++++++++++++++++++ src/transport_test.py | 179 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 335 insertions(+), 144 deletions(-) create mode 100644 src/transport.py create mode 100644 src/transport_test.py diff --git a/src/gam.py b/src/gam.py index 7fc6f40f..767c99ec 100755 --- a/src/gam.py +++ b/src/gam.py @@ -65,7 +65,6 @@ import googleapiclient.http import google.oauth2.id_token import google.oauth2.service_account import google_auth_oauthlib.flow -import google_auth_httplib2 import httplib2 from cryptography import x509 @@ -81,6 +80,7 @@ import display import fileutils import gapi.errors import gapi +import transport import utils from var import * @@ -98,31 +98,6 @@ else: # Source code GM_Globals[GM_GAM_PATH] = os.path.dirname(os.path.realpath(__file__)) -# Override and wrap google_auth_httplib2 request methods so that the GAM -# user-agent string is inserted into HTTP request headers. -def _request_with_user_agent(request_method): - """Inserts the GAM user-agent header kwargs sent to a method.""" - GAM_USER_AGENT = GAM_INFO - - def wrapped_request_method(self, *args, **kwargs): - if kwargs.get('headers') is not None: - if kwargs['headers'].get('user-agent'): - if GAM_USER_AGENT not in kwargs['headers']['user-agent']: - # Save the existing user-agent header and tack on the GAM user-agent. - kwargs['headers']['user-agent'] = '%s %s' % (GAM_USER_AGENT, kwargs['headers']['user-agent']) - else: - kwargs['headers']['user-agent'] = GAM_USER_AGENT - else: - kwargs['headers'] = {'user-agent': GAM_USER_AGENT} - return request_method(self, *args, **kwargs) - - return wrapped_request_method - -google_auth_httplib2.Request.__call__ = _request_with_user_agent( - google_auth_httplib2.Request.__call__) -google_auth_httplib2.AuthorizedHttp.request = _request_with_user_agent( - google_auth_httplib2.AuthorizedHttp.request) - def showUsage(): doGAMVersion(checkForArgs=False) print(''' @@ -649,7 +624,7 @@ def getLocalGoogleTimeOffset(testLocation='www.googleapis.com'): # we disable SSL verify so we can still get time even if clock # is way off. This could be spoofed / MitM but we'll fail for those # situations everywhere else but here. - badhttp = gapi.create_http() + badhttp = transport.create_http() badhttp.disable_ssl_certificate_validation = True googleUTC = dateutil.parser.parse(badhttp.request('https://'+testLocation, 'HEAD')[0]['date']) except (httplib2.ServerNotFoundError, RuntimeError, ValueError) as e: @@ -682,7 +657,7 @@ def doGAMCheckForUpdates(forceCheck=False): return check_url = GAM_LATEST_RELEASE # latest full release headers = {'Accept': 'application/vnd.github.v3.text+json'} - simplehttp = gapi.create_http(timeout=10) + simplehttp = transport.create_http(timeout=10) try: (_, c) = simplehttp.request(check_url, 'GET', headers=headers) try: @@ -785,7 +760,7 @@ def _getServerTLSUsed(location): url = 'https://%s' % location _, netloc, _, _, _, _ = urlparse(url) conn = 'https:%s' % netloc - httpc = gapi.create_http() + httpc = transport.create_http() headers = {'user-agent': GAM_INFO} retries = 5 for n in range(1, retries+1): @@ -874,7 +849,7 @@ def getValidOauth2TxtCredentials(force_refresh=False): retries = 3 for n in range(1, retries+1): try: - credentials.refresh(google_auth_httplib2.Request(gapi.create_http())) + credentials.refresh(transport.create_request()) writeCredentials(credentials) break except google.auth.exceptions.RefreshError as e: @@ -949,7 +924,7 @@ def buildGAPIObject(api): GM_Globals[GM_CURRENT_API_USER] = None credentials = getValidOauth2TxtCredentials() credentials.user_agent = GAM_INFO - http = google_auth_httplib2.AuthorizedHttp(credentials, gapi.create_http(cache=GM_Globals[GM_CACHE_DIR])) + http = transport.AuthorizedHttp(credentials, transport.create_http(cache=GM_Globals[GM_CACHE_DIR])) service = getService(api, http) if GC_Values[GC_DOMAIN]: if not GC_Values[GC_CUSTOMER_ID]: @@ -1040,17 +1015,17 @@ def convertEmailAddressToUID(emailAddressOrUID, cd=None, email_type='user'): return normalizedEmailAddressOrUID def buildGAPIServiceObject(api, act_as, showAuthError=True): - http = gapi.create_http(cache=GM_Globals[GM_CACHE_DIR]) + http = transport.create_http(cache=GM_Globals[GM_CACHE_DIR]) service = getService(api, http) GM_Globals[GM_CURRENT_API_USER] = act_as GM_Globals[GM_CURRENT_API_SCOPES] = API_SCOPE_MAPPING.get(api, service._rootDesc['auth']['oauth2']['scopes']) credentials = getSvcAcctCredentials(GM_Globals[GM_CURRENT_API_SCOPES], act_as) - request = google_auth_httplib2.Request(http) + request = transport.create_request(http) retries = 3 for n in range(1, retries+1): try: credentials.refresh(request) - service._http = google_auth_httplib2.AuthorizedHttp(credentials, http=http) + service._http = transport.AuthorizedHttp(credentials, http=http) break except (httplib2.ServerNotFoundError, RuntimeError) as e: if n != retries: @@ -1125,13 +1100,13 @@ def doCheckServiceAccount(users): else: time_status = 'FAIL' printPassFail(MESSAGE_YOUR_SYSTEM_TIME_DIFFERS_FROM_GOOGLE_BY % nicetime, time_status) - oa2 = googleapiclient.discovery.build('oauth2', 'v1', gapi.create_http()) + oa2 = googleapiclient.discovery.build('oauth2', 'v1', transport.create_http()) print('Service Account Private Key Authentication:') # We are explicitly not doing DwD here, just confirming service account can auth auth_error = '' try: credentials = getSvcAcctCredentials([USERINFO_EMAIL_SCOPE], None) - request = google_auth_httplib2.Request(gapi.create_http()) + request = transport.create_request() credentials.refresh(request) sa_token_info = gapi.call(oa2, 'tokeninfo', access_token=credentials.token) if sa_token_info: @@ -1151,7 +1126,7 @@ def doCheckServiceAccount(users): for user in users: user = user.lower() all_scopes_pass = True - oa2 = googleapiclient.discovery.build('oauth2', 'v1', gapi.create_http()) + oa2 = googleapiclient.discovery.build('oauth2', 'v1', transport.create_http()) print('Domain-Wide Delegation authentication as %s:' % (user)) for scope in check_scopes: # try with and without email scope @@ -3759,7 +3734,7 @@ def doPhoto(users): filename = filename.replace('#username#', user[:user.find('@')]) print("Updating photo for %s with %s (%s/%s)" % (user, filename, i, count)) if re.match('^(ht|f)tps?://.*$', filename): - simplehttp = gapi.create_http() + simplehttp = transport.create_http() try: (_, image_data) = simplehttp.request(filename, 'GET') except (httplib2.HttpLib2Error, httplib2.ServerNotFoundError) as e: @@ -7315,7 +7290,7 @@ def getUserAttributes(i, cd, updateCmd): class ShortURLFlow(google_auth_oauthlib.flow.InstalledAppFlow): def authorization_url(self, **kwargs): long_url, state = super(ShortURLFlow, self).authorization_url(**kwargs) - simplehttp = gapi.create_http(timeout=10) + simplehttp = transport.create_http(timeout=10) url_shortnr = 'https://gam-shortn.appspot.com/create' headers = {'Content-Type': 'application/json', 'user-agent': GAM_INFO} @@ -7364,7 +7339,7 @@ def getCRMService(login_hint): client_id = '297408095146-fug707qsjv4ikron0hugpevbrjhkmsk7.apps.googleusercontent.com' client_secret = 'qM3dP8f_4qedwzWQE1VR4zzU' credentials = _run_oauth_flow(client_id, client_secret, scopes, 'online', login_hint) - httpc = google_auth_httplib2.AuthorizedHttp(credentials) + httpc = transport.AuthorizedHttp(credentials) return (googleapiclient.discovery.build('cloudresourcemanager', 'v1', http=httpc, cache_discovery=False, discoveryServiceUrl=googleapiclient.discovery.V2_DISCOVERY_URI), @@ -7378,7 +7353,7 @@ def getCRM2Service(httpc): def getGAMProjectFile(filepath): file_url = GAM_PROJECT_FILEPATH+filepath - httpObj = gapi.create_http() + httpObj = transport.create_http() _, c = httpObj.request(file_url, 'GET') return c.decode(UTF8) @@ -7523,7 +7498,7 @@ def _createClientSecretsOauth2service(httpObj, projectId): client_secret = input('Enter your Client Secret: ').strip() if not client_secret: client_secret = input().strip() - simplehttp = gapi.create_http() + simplehttp = transport.create_http() client_valid = _checkClientAndSecret(simplehttp, client_id, client_secret) if client_valid: break @@ -9910,8 +9885,8 @@ def doCreateResoldCustomer(): def _getValueFromOAuth(field, credentials=None): if not GC_Values[GC_DECODED_ID_TOKEN]: credentials = credentials if credentials is not None else getValidOauth2TxtCredentials() - http = google_auth_httplib2.Request(gapi.create_http()) - GC_Values[GC_DECODED_ID_TOKEN] = google.oauth2.id_token.verify_oauth2_token(credentials.id_token, http) + request = transport.create_request() + GC_Values[GC_DECODED_ID_TOKEN] = google.oauth2.id_token.verify_oauth2_token(credentials.id_token, request) return GC_Values[GC_DECODED_ID_TOKEN].get(field, 'Unknown') def doGetMemberInfo(): @@ -10587,7 +10562,7 @@ def doSiteVerifyAttempt(): print('Method: %s' % verify_data['method']) print('Expected Token: %s' % verify_data['token']) if verify_data['method'] in ['DNS_CNAME', 'DNS_TXT']: - simplehttp = gapi.create_http() + simplehttp = transport.create_http() base_url = 'https://dns.google/resolve?' query_params = {} if verify_data['method'] == 'DNS_CNAME': @@ -13182,7 +13157,7 @@ def doDeleteOAuth(): credentials = getOauth2TxtStorageCredentials() if credentials is None: return - simplehttp = gapi.create_http() + simplehttp = transport.create_http() params = {'token': credentials.refresh_token} revoke_uri = 'https://accounts.google.com/o/oauth2/revoke?%s' % urlencode(params) sys.stderr.write('This OAuth token will self-destruct in 3...') diff --git a/src/gapi/__init__.py b/src/gapi/__init__.py index 9bb780f5..cb66c006 100644 --- a/src/gapi/__init__.py +++ b/src/gapi/__init__.py @@ -9,42 +9,13 @@ import httplib2 import controlflow import display from gapi import errors -from var import (GC_CA_FILE, GC_Values, GC_TLS_MIN_VERSION, GC_TLS_MAX_VERSION, - GM_Globals, GM_CURRENT_API_SCOPES, GM_CURRENT_API_USER, +import transport +from var import (GM_Globals, GM_CURRENT_API_SCOPES, GM_CURRENT_API_USER, GM_EXTRA_ARGS_DICT, GM_OAUTH2SERVICE_ACCOUNT_CLIENT_ID, MAX_RESULTS_API_EXCEPTIONS, MESSAGE_API_ACCESS_CONFIG, MESSAGE_API_ACCESS_DENIED, MESSAGE_SERVICE_NOT_APPLICABLE) -def create_http(cache=None, - timeout=None, - override_min_tls=None, - override_max_tls=None): - """Creates a uniform HTTP transport object. - - Args: - cache: The HTTP cache to use. - timeout: The cache timeout, in seconds. - override_min_tls: The minimum TLS version to require. If not provided, the - default is used. - override_max_tls: The maximum TLS version to require. If not provided, the - default is used. - - Returns: - httplib2.Http with the specified options. - """ - tls_minimum_version = override_min_tls if override_min_tls else GC_Values[ - GC_TLS_MIN_VERSION] - tls_maximum_version = override_max_tls if override_max_tls else GC_Values[ - GC_TLS_MAX_VERSION] - return httplib2.Http( - ca_certs=GC_Values[GC_CA_FILE], - tls_maximum_version=tls_maximum_version, - tls_minimum_version=tls_minimum_version, - cache=cache, - timeout=timeout) - - def call(service, function, silent_errors=False, @@ -79,20 +50,20 @@ def call(service, method = getattr(service, function) retries = 10 parameters = dict( - list(kwargs.items()) + list(GM_Globals[GM_EXTRA_ARGS_DICT].items())) + list(kwargs.items()) + list(GM_Globals[GM_EXTRA_ARGS_DICT].items())) for n in range(1, retries + 1): try: return method(**parameters).execute() except googleapiclient.errors.HttpError as e: http_status, reason, message = errors.get_gapi_error_detail( - e, - soft_errors=soft_errors, - silent_errors=silent_errors, - retry_on_http_error=n < 3) + e, + soft_errors=soft_errors, + silent_errors=silent_errors, + retry_on_http_error=n < 3) if http_status == -1: # The error detail indicated that we should retry this request # We'll refresh credentials and make another pass - service._http.request.credentials.refresh(create_http()) + service._http.request.credentials.refresh(transport.create_http()) continue if http_status == 0: return None @@ -101,7 +72,7 @@ def call(service, if is_known_error_reason and errors.ErrorReason(reason) in throw_reasons: if errors.ErrorReason(reason) in errors.ERROR_REASON_TO_EXCEPTION: raise errors.ERROR_REASON_TO_EXCEPTION[errors.ErrorReason(reason)]( - message) + message) raise e if (n != retries) and (is_known_error_reason and errors.ErrorReason( reason) in errors.DEFAULT_RETRY_REASONS + retry_reasons): @@ -114,16 +85,16 @@ def call(service, ': Giving up.'][n > 1])) return None controlflow.system_error_exit( - int(http_status), '{0}: {1} - {2}'.format(http_status, message, - reason)) + int(http_status), '{0}: {1} - {2}'.format(http_status, message, + reason)) except google.auth.exceptions.RefreshError as e: handle_oauth_token_error( - e, soft_errors or - errors.ErrorReason.SERVICE_NOT_AVAILABLE in throw_reasons) + e, soft_errors or + errors.ErrorReason.SERVICE_NOT_AVAILABLE in throw_reasons) if errors.ErrorReason.SERVICE_NOT_AVAILABLE in throw_reasons: raise errors.GapiServiceNotAvailableError(str(e)) display.print_error('User {0}: {1}'.format( - GM_Globals[GM_CURRENT_API_USER], str(e))) + GM_Globals[GM_CURRENT_API_USER], str(e))) return None except ValueError as e: if hasattr(service._http, 'cache') and service._http.cache is not None: @@ -165,11 +136,11 @@ def get_items(service, The list of items in the first page of a response. """ results = call( - service, - function, - throw_reasons=throw_reasons, - retry_reasons=retry_reasons, - **kwargs) + service, + function, + throw_reasons=throw_reasons, + retry_reasons=retry_reasons, + **kwargs) if results: return results.get(items, []) return [] @@ -200,7 +171,7 @@ def _get_max_page_size_for_api_call(service, function, **kwargs): return None known_api_max = MAX_RESULTS_API_EXCEPTIONS.get(api_id) max_results = a_method['parameters']['maxResults'].get( - 'maximum', known_api_max) + 'maximum', known_api_max) return {'maxResults': max_results} return None @@ -258,13 +229,13 @@ def get_all_pages(service, total_items = 0 while True: page = call( - service, - function, - soft_errors=soft_errors, - throw_reasons=throw_reasons, - retry_reasons=retry_reasons, - pageToken=page_token, - **kwargs) + service, + function, + soft_errors=soft_errors, + throw_reasons=throw_reasons, + retry_reasons=retry_reasons, + pageToken=page_token, + **kwargs) if page: page_token = page.get('nextPageToken') page_items = page.get(items, []) @@ -282,9 +253,9 @@ def get_all_pages(service, first_item = page_items[0] if num_page_items > 0 else {} last_item = page_items[-1] if num_page_items > 1 else first_item show_message = show_message.replace( - '%%first_item%%', str(first_item.get(message_attribute, ''))) + '%%first_item%%', str(first_item.get(message_attribute, ''))) show_message = show_message.replace( - '%%last_item%%', str(last_item.get(message_attribute, ''))) + '%%last_item%%', str(last_item.get(message_attribute, ''))) sys.stderr.write('\r') sys.stderr.flush() sys.stderr.write(show_message) @@ -314,14 +285,14 @@ def handle_oauth_token_error(e, soft_errors): return if not GM_Globals[GM_CURRENT_API_USER]: display.print_error( - MESSAGE_API_ACCESS_DENIED.format( - GM_Globals[GM_OAUTH2SERVICE_ACCOUNT_CLIENT_ID], - ','.join(GM_Globals[GM_CURRENT_API_SCOPES]))) + MESSAGE_API_ACCESS_DENIED.format( + GM_Globals[GM_OAUTH2SERVICE_ACCOUNT_CLIENT_ID], + ','.join(GM_Globals[GM_CURRENT_API_SCOPES]))) controlflow.system_error_exit(12, MESSAGE_API_ACCESS_CONFIG) else: controlflow.system_error_exit( - 19, - MESSAGE_SERVICE_NOT_APPLICABLE.format( - GM_Globals[GM_CURRENT_API_USER])) + 19, + MESSAGE_SERVICE_NOT_APPLICABLE.format( + GM_Globals[GM_CURRENT_API_USER])) controlflow.system_error_exit(18, 'Authentication Token Error - {0}'.format(e)) diff --git a/src/gapi/__init___test.py b/src/gapi/__init___test.py index 8758446d..540ce11f 100644 --- a/src/gapi/__init___test.py +++ b/src/gapi/__init___test.py @@ -38,40 +38,6 @@ def create_http_error(status, reason, message): return gapi.googleapiclient.errors.HttpError(response, content_bytes) -class CreateHttpTest(unittest.TestCase): - - def setUp(self): - SetGlobalVariables() - super(CreateHttpTest, self).setUp() - - def test_create_http_sets_default_values_on_http(self): - http = gapi.create_http() - self.assertIsNone(http.cache) - self.assertIsNone(http.timeout) - self.assertEqual(http.tls_minimum_version, - gapi.GC_Values[gapi.GC_TLS_MIN_VERSION]) - self.assertEqual(http.tls_maximum_version, - gapi.GC_Values[gapi.GC_TLS_MAX_VERSION]) - self.assertEqual(http.ca_certs, gapi.GC_Values[gapi.GC_CA_FILE]) - - def test_create_http_sets_tls_min_version(self): - http = gapi.create_http(override_min_tls=1111) - self.assertEqual(http.tls_minimum_version, 1111) - - def test_create_http_sets_tls_max_version(self): - http = gapi.create_http(override_max_tls=9999) - self.assertEqual(http.tls_maximum_version, 9999) - - def test_create_http_sets_cache(self): - fake_cache = {} - http = gapi.create_http(cache=fake_cache) - self.assertEqual(http.cache, fake_cache) - - def test_create_http_sets_cache_timeout(self): - http = gapi.create_http(timeout=1234) - self.assertEqual(http.timeout, 1234) - - class GapiTest(unittest.TestCase): def setUp(self): diff --git a/src/transport.py b/src/transport.py new file mode 100644 index 00000000..f19dbff5 --- /dev/null +++ b/src/transport.py @@ -0,0 +1,100 @@ +"""Methods related to network transport.""" + +import google_auth_httplib2 +import httplib2 + +from var import GAM_INFO +from var import GC_CA_FILE +from var import GC_TLS_MAX_VERSION +from var import GC_TLS_MIN_VERSION +from var import GC_Values + + +def create_http(cache=None, + timeout=None, + override_min_tls=None, + override_max_tls=None): + """Creates a uniform HTTP transport object. + + Args: + cache: The HTTP cache to use. + timeout: The cache timeout, in seconds. + override_min_tls: The minimum TLS version to require. If not provided, the + default is used. + override_max_tls: The maximum TLS version to require. If not provided, the + default is used. + + Returns: + httplib2.Http with the specified options. + """ + tls_minimum_version = override_min_tls if override_min_tls else GC_Values[ + GC_TLS_MIN_VERSION] + tls_maximum_version = override_max_tls if override_max_tls else GC_Values[ + GC_TLS_MAX_VERSION] + return httplib2.Http( + ca_certs=GC_Values[GC_CA_FILE], + tls_maximum_version=tls_maximum_version, + tls_minimum_version=tls_minimum_version, + cache=cache, + timeout=timeout) + + +def create_request(http=None): + """Creates a uniform Request object with a default http, if not provided. + + Args: + http: Optional httplib2.Http compatible object to be used with the request. + If not provided, a default HTTP will be used. + + Returns: + Request: A google_auth_httplib2.Request compatible Request. + """ + if not http: + http = create_http() + return Request(http) + + +GAM_USER_AGENT = GAM_INFO + + +def _force_user_agent(user_agent): + """Creates a decorator which can force a user agent in HTTP headers.""" + + def decorator(request_method): + """Wraps a request method to insert a user-agent in HTTP headers.""" + + def wrapped_request_method(*args, **kwargs): + """Modifies HTTP headers to include a specified user-agent.""" + if kwargs.get('headers') is not None: + if kwargs['headers'].get('user-agent'): + if user_agent not in kwargs['headers']['user-agent']: + # Save the existing user-agent header and tack on our own. + kwargs['headers']['user-agent'] = '%s %s' % ( + user_agent, kwargs['headers']['user-agent']) + else: + kwargs['headers']['user-agent'] = user_agent + else: + kwargs['headers'] = {'user-agent': user_agent} + return request_method(*args, **kwargs) + + return wrapped_request_method + + return decorator + + +class Request(google_auth_httplib2.Request): + """A Request which forces a user agent.""" + + @_force_user_agent(GAM_USER_AGENT) + def __call__(self, *args, **kwargs): + """Inserts the GAM user-agent header in requests.""" + return super(Request, self).__call__(*args, **kwargs) + + +class AuthorizedHttp(google_auth_httplib2.AuthorizedHttp): + """An AuthorizedHttp which forces a user agent during requests.""" + + @_force_user_agent(GAM_USER_AGENT) + def request(self, *args, **kwargs): + """Inserts the GAM user-agent header in requests.""" + return super(AuthorizedHttp, self).request(*args, **kwargs) diff --git a/src/transport_test.py b/src/transport_test.py new file mode 100644 index 00000000..f6374ab9 --- /dev/null +++ b/src/transport_test.py @@ -0,0 +1,179 @@ +"""Tests for transport.""" + +import unittest +from unittest.mock import MagicMock +from unittest.mock import patch + +from gam import SetGlobalVariables +import google_auth_httplib2 +import httplib2 + +import transport + + +class CreateHttpTest(unittest.TestCase): + + def setUp(self): + SetGlobalVariables() + super(CreateHttpTest, self).setUp() + + def test_create_http_sets_default_values_on_http(self): + http = transport.create_http() + self.assertIsNone(http.cache) + self.assertIsNone(http.timeout) + self.assertEqual(http.tls_minimum_version, + transport.GC_Values[transport.GC_TLS_MIN_VERSION]) + self.assertEqual(http.tls_maximum_version, + transport.GC_Values[transport.GC_TLS_MAX_VERSION]) + self.assertEqual(http.ca_certs, transport.GC_Values[transport.GC_CA_FILE]) + + def test_create_http_sets_tls_min_version(self): + http = transport.create_http(override_min_tls='TLSv1_1') + self.assertEqual(http.tls_minimum_version, 'TLSv1_1') + + def test_create_http_sets_tls_max_version(self): + http = transport.create_http(override_max_tls='TLSv1_3') + self.assertEqual(http.tls_maximum_version, 'TLSv1_3') + + def test_create_http_sets_cache(self): + fake_cache = {} + http = transport.create_http(cache=fake_cache) + self.assertEqual(http.cache, fake_cache) + + def test_create_http_sets_cache_timeout(self): + http = transport.create_http(timeout=1234) + self.assertEqual(http.timeout, 1234) + + +class TransportTest(unittest.TestCase): + + def setUp(self): + self.mock_http = MagicMock(spec=httplib2.Http) + self.mock_response = MagicMock(spec=httplib2.Response) + self.mock_content = MagicMock() + self.mock_http.request.return_value = (self.mock_response, + self.mock_content) + self.mock_credentials = MagicMock() + self.test_uri = 'http://example.com' + super(TransportTest, self).setUp() + + @patch.object(transport, 'create_http') + def test_create_request_uses_default_http(self, mock_create_http): + request = transport.create_request() + self.assertEqual(request.http, mock_create_http.return_value) + + def test_create_request_uses_provided_http(self): + request = transport.create_request(http=self.mock_http) + self.assertEqual(request.http, self.mock_http) + + def test_create_request_returns_request_with_forced_user_agent(self): + request = transport.create_request() + self.assertIsInstance(request, transport.Request) + + def test_request_is_google_auth_httplib2_compatible(self): + request = transport.create_request() + self.assertIsInstance(request, google_auth_httplib2.Request) + + def test_request_call_returns_response_content(self): + request = transport.Request(self.mock_http) + response = request(self.test_uri) + self.assertEqual(self.mock_response.status, response.status) + self.assertEqual(self.mock_content, response.data) + + def test_request_call_forces_user_agent_no_provided_headers(self): + request = transport.Request(self.mock_http) + + request(self.test_uri) + headers = self.mock_http.request.call_args[1]['headers'] + self.assertIn('user-agent', headers) + self.assertIn(transport.GAM_USER_AGENT, headers['user-agent']) + + def test_request_call_forces_user_agent_no_agent_in_headers(self): + request = transport.Request(self.mock_http) + fake_request_headers = {'some-header-thats-not-a-user-agent': 'someData'} + + request(self.test_uri, headers=fake_request_headers) + final_headers = self.mock_http.request.call_args[1]['headers'] + self.assertIn('user-agent', final_headers) + self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent']) + self.assertIn('some-header-thats-not-a-user-agent', final_headers) + self.assertEqual('someData', + final_headers['some-header-thats-not-a-user-agent']) + + def test_request_call_forces_user_agent_with_another_agent_in_headers(self): + request = transport.Request(self.mock_http) + headers_with_user_agent = {'user-agent': 'existing-user-agent'} + + request(self.test_uri, headers=headers_with_user_agent) + final_headers = self.mock_http.request.call_args[1]['headers'] + self.assertIn('user-agent', final_headers) + self.assertIn('existing-user-agent', final_headers['user-agent']) + self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent']) + + def test_request_call_same_user_agent_already_in_headers(self): + request = transport.Request(self.mock_http) + same_user_agent_header = {'user-agent': transport.GAM_USER_AGENT} + + request(self.test_uri, headers=same_user_agent_header) + final_headers = self.mock_http.request.call_args[1]['headers'] + self.assertIn('user-agent', final_headers) + self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent']) + # Make sure the header wasn't duplicated + self.assertEqual( + len(transport.GAM_USER_AGENT), len(final_headers['user-agent'])) + + def test_authorizedhttp_is_google_auth_httplib2_compatible(self): + http = transport.AuthorizedHttp(self.mock_credentials) + self.assertIsInstance(http, google_auth_httplib2.AuthorizedHttp) + + def test_authorizedhttp_request_returns_response_content(self): + http = transport.AuthorizedHttp(self.mock_credentials, http=self.mock_http) + response, content = http.request(self.test_uri) + self.assertEqual(self.mock_response, response) + self.assertEqual(self.mock_content, content) + + def test_authorizedhttp_request_forces_user_agent_no_provided_headers(self): + authorized_http = transport.AuthorizedHttp( + self.mock_credentials, http=self.mock_http) + authorized_http.request(self.test_uri) + headers = self.mock_http.request.call_args[1]['headers'] + self.assertIn('user-agent', headers) + self.assertIn(transport.GAM_USER_AGENT, headers['user-agent']) + + def test_authorizedhttp_request_forces_user_agent_no_agent_in_headers(self): + authorized_http = transport.AuthorizedHttp( + self.mock_credentials, http=self.mock_http) + fake_request_headers = {'some-header-thats-not-a-user-agent': 'someData'} + + authorized_http.request(self.test_uri, headers=fake_request_headers) + final_headers = self.mock_http.request.call_args[1]['headers'] + self.assertIn('user-agent', final_headers) + self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent']) + self.assertIn('some-header-thats-not-a-user-agent', final_headers) + self.assertEqual('someData', + final_headers['some-header-thats-not-a-user-agent']) + + def test_authorizedhttp_request_forces_user_agent_with_another_agent_in_headers( + self): + authorized_http = transport.AuthorizedHttp( + self.mock_credentials, http=self.mock_http) + headers_with_user_agent = {'user-agent': 'existing-user-agent'} + + authorized_http.request(self.test_uri, headers=headers_with_user_agent) + final_headers = self.mock_http.request.call_args[1]['headers'] + self.assertIn('user-agent', final_headers) + self.assertIn('existing-user-agent', final_headers['user-agent']) + self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent']) + + def test_authorizedhttp_request_same_user_agent_already_in_headers(self): + authorized_http = transport.AuthorizedHttp( + self.mock_credentials, http=self.mock_http) + same_user_agent_header = {'user-agent': transport.GAM_USER_AGENT} + + authorized_http.request(self.test_uri, headers=same_user_agent_header) + final_headers = self.mock_http.request.call_args[1]['headers'] + self.assertIn('user-agent', final_headers) + self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent']) + # Make sure the header wasn't duplicated + self.assertEqual( + len(transport.GAM_USER_AGENT), len(final_headers['user-agent']))