mirror of
https://github.com/GAM-team/GAM.git
synced 2026-06-28 09:51:36 +00:00
Move transport customizations to their own module (#1071)
* Move transport customizations to their own module
This helps to accomplish a few things:
1) Makes the forced user-agent customization on HTTP requests a bit
clearer by subclassing the targeted objects (as opposed to hiding the
behavior behind a forced override of the google_auth_httplib2 object
methods)
2) Standardizes the creation of HTTP objects. These objects can still
largely be customized, but using a single creation mechanism will
standardize a default and streamline creation, thereby decreasing the
code that would otherwise be replicated in the caller
3) Moves create_http() to a more general purpose module, since it will
likely be used by more than gapi-related methods.
* Use string values for TLS version tests
More closely matches [existing
behavior](4fb73e6073/src/var.py (L853))
that sets the global default to a string value.
This commit is contained in:
67
src/gam.py
67
src/gam.py
@@ -65,7 +65,6 @@ import googleapiclient.http
|
||||
import google.oauth2.id_token
|
||||
import google.oauth2.service_account
|
||||
import google_auth_oauthlib.flow
|
||||
import google_auth_httplib2
|
||||
import httplib2
|
||||
|
||||
from cryptography import x509
|
||||
@@ -81,6 +80,7 @@ import display
|
||||
import fileutils
|
||||
import gapi.errors
|
||||
import gapi
|
||||
import transport
|
||||
import utils
|
||||
from var import *
|
||||
|
||||
@@ -98,31 +98,6 @@ else:
|
||||
# Source code
|
||||
GM_Globals[GM_GAM_PATH] = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
# Override and wrap google_auth_httplib2 request methods so that the GAM
|
||||
# user-agent string is inserted into HTTP request headers.
|
||||
def _request_with_user_agent(request_method):
|
||||
"""Inserts the GAM user-agent header kwargs sent to a method."""
|
||||
GAM_USER_AGENT = GAM_INFO
|
||||
|
||||
def wrapped_request_method(self, *args, **kwargs):
|
||||
if kwargs.get('headers') is not None:
|
||||
if kwargs['headers'].get('user-agent'):
|
||||
if GAM_USER_AGENT not in kwargs['headers']['user-agent']:
|
||||
# Save the existing user-agent header and tack on the GAM user-agent.
|
||||
kwargs['headers']['user-agent'] = '%s %s' % (GAM_USER_AGENT, kwargs['headers']['user-agent'])
|
||||
else:
|
||||
kwargs['headers']['user-agent'] = GAM_USER_AGENT
|
||||
else:
|
||||
kwargs['headers'] = {'user-agent': GAM_USER_AGENT}
|
||||
return request_method(self, *args, **kwargs)
|
||||
|
||||
return wrapped_request_method
|
||||
|
||||
google_auth_httplib2.Request.__call__ = _request_with_user_agent(
|
||||
google_auth_httplib2.Request.__call__)
|
||||
google_auth_httplib2.AuthorizedHttp.request = _request_with_user_agent(
|
||||
google_auth_httplib2.AuthorizedHttp.request)
|
||||
|
||||
def showUsage():
|
||||
doGAMVersion(checkForArgs=False)
|
||||
print('''
|
||||
@@ -649,7 +624,7 @@ def getLocalGoogleTimeOffset(testLocation='www.googleapis.com'):
|
||||
# we disable SSL verify so we can still get time even if clock
|
||||
# is way off. This could be spoofed / MitM but we'll fail for those
|
||||
# situations everywhere else but here.
|
||||
badhttp = gapi.create_http()
|
||||
badhttp = transport.create_http()
|
||||
badhttp.disable_ssl_certificate_validation = True
|
||||
googleUTC = dateutil.parser.parse(badhttp.request('https://'+testLocation, 'HEAD')[0]['date'])
|
||||
except (httplib2.ServerNotFoundError, RuntimeError, ValueError) as e:
|
||||
@@ -682,7 +657,7 @@ def doGAMCheckForUpdates(forceCheck=False):
|
||||
return
|
||||
check_url = GAM_LATEST_RELEASE # latest full release
|
||||
headers = {'Accept': 'application/vnd.github.v3.text+json'}
|
||||
simplehttp = gapi.create_http(timeout=10)
|
||||
simplehttp = transport.create_http(timeout=10)
|
||||
try:
|
||||
(_, c) = simplehttp.request(check_url, 'GET', headers=headers)
|
||||
try:
|
||||
@@ -785,7 +760,7 @@ def _getServerTLSUsed(location):
|
||||
url = 'https://%s' % location
|
||||
_, netloc, _, _, _, _ = urlparse(url)
|
||||
conn = 'https:%s' % netloc
|
||||
httpc = gapi.create_http()
|
||||
httpc = transport.create_http()
|
||||
headers = {'user-agent': GAM_INFO}
|
||||
retries = 5
|
||||
for n in range(1, retries+1):
|
||||
@@ -874,7 +849,7 @@ def getValidOauth2TxtCredentials(force_refresh=False):
|
||||
retries = 3
|
||||
for n in range(1, retries+1):
|
||||
try:
|
||||
credentials.refresh(google_auth_httplib2.Request(gapi.create_http()))
|
||||
credentials.refresh(transport.create_request())
|
||||
writeCredentials(credentials)
|
||||
break
|
||||
except google.auth.exceptions.RefreshError as e:
|
||||
@@ -949,7 +924,7 @@ def buildGAPIObject(api):
|
||||
GM_Globals[GM_CURRENT_API_USER] = None
|
||||
credentials = getValidOauth2TxtCredentials()
|
||||
credentials.user_agent = GAM_INFO
|
||||
http = google_auth_httplib2.AuthorizedHttp(credentials, gapi.create_http(cache=GM_Globals[GM_CACHE_DIR]))
|
||||
http = transport.AuthorizedHttp(credentials, transport.create_http(cache=GM_Globals[GM_CACHE_DIR]))
|
||||
service = getService(api, http)
|
||||
if GC_Values[GC_DOMAIN]:
|
||||
if not GC_Values[GC_CUSTOMER_ID]:
|
||||
@@ -1040,17 +1015,17 @@ def convertEmailAddressToUID(emailAddressOrUID, cd=None, email_type='user'):
|
||||
return normalizedEmailAddressOrUID
|
||||
|
||||
def buildGAPIServiceObject(api, act_as, showAuthError=True):
|
||||
http = gapi.create_http(cache=GM_Globals[GM_CACHE_DIR])
|
||||
http = transport.create_http(cache=GM_Globals[GM_CACHE_DIR])
|
||||
service = getService(api, http)
|
||||
GM_Globals[GM_CURRENT_API_USER] = act_as
|
||||
GM_Globals[GM_CURRENT_API_SCOPES] = API_SCOPE_MAPPING.get(api, service._rootDesc['auth']['oauth2']['scopes'])
|
||||
credentials = getSvcAcctCredentials(GM_Globals[GM_CURRENT_API_SCOPES], act_as)
|
||||
request = google_auth_httplib2.Request(http)
|
||||
request = transport.create_request(http)
|
||||
retries = 3
|
||||
for n in range(1, retries+1):
|
||||
try:
|
||||
credentials.refresh(request)
|
||||
service._http = google_auth_httplib2.AuthorizedHttp(credentials, http=http)
|
||||
service._http = transport.AuthorizedHttp(credentials, http=http)
|
||||
break
|
||||
except (httplib2.ServerNotFoundError, RuntimeError) as e:
|
||||
if n != retries:
|
||||
@@ -1125,13 +1100,13 @@ def doCheckServiceAccount(users):
|
||||
else:
|
||||
time_status = 'FAIL'
|
||||
printPassFail(MESSAGE_YOUR_SYSTEM_TIME_DIFFERS_FROM_GOOGLE_BY % nicetime, time_status)
|
||||
oa2 = googleapiclient.discovery.build('oauth2', 'v1', gapi.create_http())
|
||||
oa2 = googleapiclient.discovery.build('oauth2', 'v1', transport.create_http())
|
||||
print('Service Account Private Key Authentication:')
|
||||
# We are explicitly not doing DwD here, just confirming service account can auth
|
||||
auth_error = ''
|
||||
try:
|
||||
credentials = getSvcAcctCredentials([USERINFO_EMAIL_SCOPE], None)
|
||||
request = google_auth_httplib2.Request(gapi.create_http())
|
||||
request = transport.create_request()
|
||||
credentials.refresh(request)
|
||||
sa_token_info = gapi.call(oa2, 'tokeninfo', access_token=credentials.token)
|
||||
if sa_token_info:
|
||||
@@ -1151,7 +1126,7 @@ def doCheckServiceAccount(users):
|
||||
for user in users:
|
||||
user = user.lower()
|
||||
all_scopes_pass = True
|
||||
oa2 = googleapiclient.discovery.build('oauth2', 'v1', gapi.create_http())
|
||||
oa2 = googleapiclient.discovery.build('oauth2', 'v1', transport.create_http())
|
||||
print('Domain-Wide Delegation authentication as %s:' % (user))
|
||||
for scope in check_scopes:
|
||||
# try with and without email scope
|
||||
@@ -3759,7 +3734,7 @@ def doPhoto(users):
|
||||
filename = filename.replace('#username#', user[:user.find('@')])
|
||||
print("Updating photo for %s with %s (%s/%s)" % (user, filename, i, count))
|
||||
if re.match('^(ht|f)tps?://.*$', filename):
|
||||
simplehttp = gapi.create_http()
|
||||
simplehttp = transport.create_http()
|
||||
try:
|
||||
(_, image_data) = simplehttp.request(filename, 'GET')
|
||||
except (httplib2.HttpLib2Error, httplib2.ServerNotFoundError) as e:
|
||||
@@ -7315,7 +7290,7 @@ def getUserAttributes(i, cd, updateCmd):
|
||||
class ShortURLFlow(google_auth_oauthlib.flow.InstalledAppFlow):
|
||||
def authorization_url(self, **kwargs):
|
||||
long_url, state = super(ShortURLFlow, self).authorization_url(**kwargs)
|
||||
simplehttp = gapi.create_http(timeout=10)
|
||||
simplehttp = transport.create_http(timeout=10)
|
||||
url_shortnr = 'https://gam-shortn.appspot.com/create'
|
||||
headers = {'Content-Type': 'application/json',
|
||||
'user-agent': GAM_INFO}
|
||||
@@ -7364,7 +7339,7 @@ def getCRMService(login_hint):
|
||||
client_id = '297408095146-fug707qsjv4ikron0hugpevbrjhkmsk7.apps.googleusercontent.com'
|
||||
client_secret = 'qM3dP8f_4qedwzWQE1VR4zzU'
|
||||
credentials = _run_oauth_flow(client_id, client_secret, scopes, 'online', login_hint)
|
||||
httpc = google_auth_httplib2.AuthorizedHttp(credentials)
|
||||
httpc = transport.AuthorizedHttp(credentials)
|
||||
return (googleapiclient.discovery.build('cloudresourcemanager', 'v1',
|
||||
http=httpc, cache_discovery=False,
|
||||
discoveryServiceUrl=googleapiclient.discovery.V2_DISCOVERY_URI),
|
||||
@@ -7378,7 +7353,7 @@ def getCRM2Service(httpc):
|
||||
|
||||
def getGAMProjectFile(filepath):
|
||||
file_url = GAM_PROJECT_FILEPATH+filepath
|
||||
httpObj = gapi.create_http()
|
||||
httpObj = transport.create_http()
|
||||
_, c = httpObj.request(file_url, 'GET')
|
||||
return c.decode(UTF8)
|
||||
|
||||
@@ -7523,7 +7498,7 @@ def _createClientSecretsOauth2service(httpObj, projectId):
|
||||
client_secret = input('Enter your Client Secret: ').strip()
|
||||
if not client_secret:
|
||||
client_secret = input().strip()
|
||||
simplehttp = gapi.create_http()
|
||||
simplehttp = transport.create_http()
|
||||
client_valid = _checkClientAndSecret(simplehttp, client_id, client_secret)
|
||||
if client_valid:
|
||||
break
|
||||
@@ -9910,8 +9885,8 @@ def doCreateResoldCustomer():
|
||||
def _getValueFromOAuth(field, credentials=None):
|
||||
if not GC_Values[GC_DECODED_ID_TOKEN]:
|
||||
credentials = credentials if credentials is not None else getValidOauth2TxtCredentials()
|
||||
http = google_auth_httplib2.Request(gapi.create_http())
|
||||
GC_Values[GC_DECODED_ID_TOKEN] = google.oauth2.id_token.verify_oauth2_token(credentials.id_token, http)
|
||||
request = transport.create_request()
|
||||
GC_Values[GC_DECODED_ID_TOKEN] = google.oauth2.id_token.verify_oauth2_token(credentials.id_token, request)
|
||||
return GC_Values[GC_DECODED_ID_TOKEN].get(field, 'Unknown')
|
||||
|
||||
def doGetMemberInfo():
|
||||
@@ -10587,7 +10562,7 @@ def doSiteVerifyAttempt():
|
||||
print('Method: %s' % verify_data['method'])
|
||||
print('Expected Token: %s' % verify_data['token'])
|
||||
if verify_data['method'] in ['DNS_CNAME', 'DNS_TXT']:
|
||||
simplehttp = gapi.create_http()
|
||||
simplehttp = transport.create_http()
|
||||
base_url = 'https://dns.google/resolve?'
|
||||
query_params = {}
|
||||
if verify_data['method'] == 'DNS_CNAME':
|
||||
@@ -13182,7 +13157,7 @@ def doDeleteOAuth():
|
||||
credentials = getOauth2TxtStorageCredentials()
|
||||
if credentials is None:
|
||||
return
|
||||
simplehttp = gapi.create_http()
|
||||
simplehttp = transport.create_http()
|
||||
params = {'token': credentials.refresh_token}
|
||||
revoke_uri = 'https://accounts.google.com/o/oauth2/revoke?%s' % urlencode(params)
|
||||
sys.stderr.write('This OAuth token will self-destruct in 3...')
|
||||
|
||||
@@ -9,42 +9,13 @@ import httplib2
|
||||
import controlflow
|
||||
import display
|
||||
from gapi import errors
|
||||
from var import (GC_CA_FILE, GC_Values, GC_TLS_MIN_VERSION, GC_TLS_MAX_VERSION,
|
||||
GM_Globals, GM_CURRENT_API_SCOPES, GM_CURRENT_API_USER,
|
||||
import transport
|
||||
from var import (GM_Globals, GM_CURRENT_API_SCOPES, GM_CURRENT_API_USER,
|
||||
GM_EXTRA_ARGS_DICT, GM_OAUTH2SERVICE_ACCOUNT_CLIENT_ID,
|
||||
MAX_RESULTS_API_EXCEPTIONS, MESSAGE_API_ACCESS_CONFIG,
|
||||
MESSAGE_API_ACCESS_DENIED, MESSAGE_SERVICE_NOT_APPLICABLE)
|
||||
|
||||
|
||||
def create_http(cache=None,
|
||||
timeout=None,
|
||||
override_min_tls=None,
|
||||
override_max_tls=None):
|
||||
"""Creates a uniform HTTP transport object.
|
||||
|
||||
Args:
|
||||
cache: The HTTP cache to use.
|
||||
timeout: The cache timeout, in seconds.
|
||||
override_min_tls: The minimum TLS version to require. If not provided, the
|
||||
default is used.
|
||||
override_max_tls: The maximum TLS version to require. If not provided, the
|
||||
default is used.
|
||||
|
||||
Returns:
|
||||
httplib2.Http with the specified options.
|
||||
"""
|
||||
tls_minimum_version = override_min_tls if override_min_tls else GC_Values[
|
||||
GC_TLS_MIN_VERSION]
|
||||
tls_maximum_version = override_max_tls if override_max_tls else GC_Values[
|
||||
GC_TLS_MAX_VERSION]
|
||||
return httplib2.Http(
|
||||
ca_certs=GC_Values[GC_CA_FILE],
|
||||
tls_maximum_version=tls_maximum_version,
|
||||
tls_minimum_version=tls_minimum_version,
|
||||
cache=cache,
|
||||
timeout=timeout)
|
||||
|
||||
|
||||
def call(service,
|
||||
function,
|
||||
silent_errors=False,
|
||||
@@ -79,20 +50,20 @@ def call(service,
|
||||
method = getattr(service, function)
|
||||
retries = 10
|
||||
parameters = dict(
|
||||
list(kwargs.items()) + list(GM_Globals[GM_EXTRA_ARGS_DICT].items()))
|
||||
list(kwargs.items()) + list(GM_Globals[GM_EXTRA_ARGS_DICT].items()))
|
||||
for n in range(1, retries + 1):
|
||||
try:
|
||||
return method(**parameters).execute()
|
||||
except googleapiclient.errors.HttpError as e:
|
||||
http_status, reason, message = errors.get_gapi_error_detail(
|
||||
e,
|
||||
soft_errors=soft_errors,
|
||||
silent_errors=silent_errors,
|
||||
retry_on_http_error=n < 3)
|
||||
e,
|
||||
soft_errors=soft_errors,
|
||||
silent_errors=silent_errors,
|
||||
retry_on_http_error=n < 3)
|
||||
if http_status == -1:
|
||||
# The error detail indicated that we should retry this request
|
||||
# We'll refresh credentials and make another pass
|
||||
service._http.request.credentials.refresh(create_http())
|
||||
service._http.request.credentials.refresh(transport.create_http())
|
||||
continue
|
||||
if http_status == 0:
|
||||
return None
|
||||
@@ -101,7 +72,7 @@ def call(service,
|
||||
if is_known_error_reason and errors.ErrorReason(reason) in throw_reasons:
|
||||
if errors.ErrorReason(reason) in errors.ERROR_REASON_TO_EXCEPTION:
|
||||
raise errors.ERROR_REASON_TO_EXCEPTION[errors.ErrorReason(reason)](
|
||||
message)
|
||||
message)
|
||||
raise e
|
||||
if (n != retries) and (is_known_error_reason and errors.ErrorReason(
|
||||
reason) in errors.DEFAULT_RETRY_REASONS + retry_reasons):
|
||||
@@ -114,16 +85,16 @@ def call(service,
|
||||
': Giving up.'][n > 1]))
|
||||
return None
|
||||
controlflow.system_error_exit(
|
||||
int(http_status), '{0}: {1} - {2}'.format(http_status, message,
|
||||
reason))
|
||||
int(http_status), '{0}: {1} - {2}'.format(http_status, message,
|
||||
reason))
|
||||
except google.auth.exceptions.RefreshError as e:
|
||||
handle_oauth_token_error(
|
||||
e, soft_errors or
|
||||
errors.ErrorReason.SERVICE_NOT_AVAILABLE in throw_reasons)
|
||||
e, soft_errors or
|
||||
errors.ErrorReason.SERVICE_NOT_AVAILABLE in throw_reasons)
|
||||
if errors.ErrorReason.SERVICE_NOT_AVAILABLE in throw_reasons:
|
||||
raise errors.GapiServiceNotAvailableError(str(e))
|
||||
display.print_error('User {0}: {1}'.format(
|
||||
GM_Globals[GM_CURRENT_API_USER], str(e)))
|
||||
GM_Globals[GM_CURRENT_API_USER], str(e)))
|
||||
return None
|
||||
except ValueError as e:
|
||||
if hasattr(service._http, 'cache') and service._http.cache is not None:
|
||||
@@ -165,11 +136,11 @@ def get_items(service,
|
||||
The list of items in the first page of a response.
|
||||
"""
|
||||
results = call(
|
||||
service,
|
||||
function,
|
||||
throw_reasons=throw_reasons,
|
||||
retry_reasons=retry_reasons,
|
||||
**kwargs)
|
||||
service,
|
||||
function,
|
||||
throw_reasons=throw_reasons,
|
||||
retry_reasons=retry_reasons,
|
||||
**kwargs)
|
||||
if results:
|
||||
return results.get(items, [])
|
||||
return []
|
||||
@@ -200,7 +171,7 @@ def _get_max_page_size_for_api_call(service, function, **kwargs):
|
||||
return None
|
||||
known_api_max = MAX_RESULTS_API_EXCEPTIONS.get(api_id)
|
||||
max_results = a_method['parameters']['maxResults'].get(
|
||||
'maximum', known_api_max)
|
||||
'maximum', known_api_max)
|
||||
return {'maxResults': max_results}
|
||||
|
||||
return None
|
||||
@@ -258,13 +229,13 @@ def get_all_pages(service,
|
||||
total_items = 0
|
||||
while True:
|
||||
page = call(
|
||||
service,
|
||||
function,
|
||||
soft_errors=soft_errors,
|
||||
throw_reasons=throw_reasons,
|
||||
retry_reasons=retry_reasons,
|
||||
pageToken=page_token,
|
||||
**kwargs)
|
||||
service,
|
||||
function,
|
||||
soft_errors=soft_errors,
|
||||
throw_reasons=throw_reasons,
|
||||
retry_reasons=retry_reasons,
|
||||
pageToken=page_token,
|
||||
**kwargs)
|
||||
if page:
|
||||
page_token = page.get('nextPageToken')
|
||||
page_items = page.get(items, [])
|
||||
@@ -282,9 +253,9 @@ def get_all_pages(service,
|
||||
first_item = page_items[0] if num_page_items > 0 else {}
|
||||
last_item = page_items[-1] if num_page_items > 1 else first_item
|
||||
show_message = show_message.replace(
|
||||
'%%first_item%%', str(first_item.get(message_attribute, '')))
|
||||
'%%first_item%%', str(first_item.get(message_attribute, '')))
|
||||
show_message = show_message.replace(
|
||||
'%%last_item%%', str(last_item.get(message_attribute, '')))
|
||||
'%%last_item%%', str(last_item.get(message_attribute, '')))
|
||||
sys.stderr.write('\r')
|
||||
sys.stderr.flush()
|
||||
sys.stderr.write(show_message)
|
||||
@@ -314,14 +285,14 @@ def handle_oauth_token_error(e, soft_errors):
|
||||
return
|
||||
if not GM_Globals[GM_CURRENT_API_USER]:
|
||||
display.print_error(
|
||||
MESSAGE_API_ACCESS_DENIED.format(
|
||||
GM_Globals[GM_OAUTH2SERVICE_ACCOUNT_CLIENT_ID],
|
||||
','.join(GM_Globals[GM_CURRENT_API_SCOPES])))
|
||||
MESSAGE_API_ACCESS_DENIED.format(
|
||||
GM_Globals[GM_OAUTH2SERVICE_ACCOUNT_CLIENT_ID],
|
||||
','.join(GM_Globals[GM_CURRENT_API_SCOPES])))
|
||||
controlflow.system_error_exit(12, MESSAGE_API_ACCESS_CONFIG)
|
||||
else:
|
||||
controlflow.system_error_exit(
|
||||
19,
|
||||
MESSAGE_SERVICE_NOT_APPLICABLE.format(
|
||||
GM_Globals[GM_CURRENT_API_USER]))
|
||||
19,
|
||||
MESSAGE_SERVICE_NOT_APPLICABLE.format(
|
||||
GM_Globals[GM_CURRENT_API_USER]))
|
||||
controlflow.system_error_exit(18,
|
||||
'Authentication Token Error - {0}'.format(e))
|
||||
|
||||
@@ -38,40 +38,6 @@ def create_http_error(status, reason, message):
|
||||
return gapi.googleapiclient.errors.HttpError(response, content_bytes)
|
||||
|
||||
|
||||
class CreateHttpTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
SetGlobalVariables()
|
||||
super(CreateHttpTest, self).setUp()
|
||||
|
||||
def test_create_http_sets_default_values_on_http(self):
|
||||
http = gapi.create_http()
|
||||
self.assertIsNone(http.cache)
|
||||
self.assertIsNone(http.timeout)
|
||||
self.assertEqual(http.tls_minimum_version,
|
||||
gapi.GC_Values[gapi.GC_TLS_MIN_VERSION])
|
||||
self.assertEqual(http.tls_maximum_version,
|
||||
gapi.GC_Values[gapi.GC_TLS_MAX_VERSION])
|
||||
self.assertEqual(http.ca_certs, gapi.GC_Values[gapi.GC_CA_FILE])
|
||||
|
||||
def test_create_http_sets_tls_min_version(self):
|
||||
http = gapi.create_http(override_min_tls=1111)
|
||||
self.assertEqual(http.tls_minimum_version, 1111)
|
||||
|
||||
def test_create_http_sets_tls_max_version(self):
|
||||
http = gapi.create_http(override_max_tls=9999)
|
||||
self.assertEqual(http.tls_maximum_version, 9999)
|
||||
|
||||
def test_create_http_sets_cache(self):
|
||||
fake_cache = {}
|
||||
http = gapi.create_http(cache=fake_cache)
|
||||
self.assertEqual(http.cache, fake_cache)
|
||||
|
||||
def test_create_http_sets_cache_timeout(self):
|
||||
http = gapi.create_http(timeout=1234)
|
||||
self.assertEqual(http.timeout, 1234)
|
||||
|
||||
|
||||
class GapiTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
100
src/transport.py
Normal file
100
src/transport.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Methods related to network transport."""
|
||||
|
||||
import google_auth_httplib2
|
||||
import httplib2
|
||||
|
||||
from var import GAM_INFO
|
||||
from var import GC_CA_FILE
|
||||
from var import GC_TLS_MAX_VERSION
|
||||
from var import GC_TLS_MIN_VERSION
|
||||
from var import GC_Values
|
||||
|
||||
|
||||
def create_http(cache=None,
|
||||
timeout=None,
|
||||
override_min_tls=None,
|
||||
override_max_tls=None):
|
||||
"""Creates a uniform HTTP transport object.
|
||||
|
||||
Args:
|
||||
cache: The HTTP cache to use.
|
||||
timeout: The cache timeout, in seconds.
|
||||
override_min_tls: The minimum TLS version to require. If not provided, the
|
||||
default is used.
|
||||
override_max_tls: The maximum TLS version to require. If not provided, the
|
||||
default is used.
|
||||
|
||||
Returns:
|
||||
httplib2.Http with the specified options.
|
||||
"""
|
||||
tls_minimum_version = override_min_tls if override_min_tls else GC_Values[
|
||||
GC_TLS_MIN_VERSION]
|
||||
tls_maximum_version = override_max_tls if override_max_tls else GC_Values[
|
||||
GC_TLS_MAX_VERSION]
|
||||
return httplib2.Http(
|
||||
ca_certs=GC_Values[GC_CA_FILE],
|
||||
tls_maximum_version=tls_maximum_version,
|
||||
tls_minimum_version=tls_minimum_version,
|
||||
cache=cache,
|
||||
timeout=timeout)
|
||||
|
||||
|
||||
def create_request(http=None):
|
||||
"""Creates a uniform Request object with a default http, if not provided.
|
||||
|
||||
Args:
|
||||
http: Optional httplib2.Http compatible object to be used with the request.
|
||||
If not provided, a default HTTP will be used.
|
||||
|
||||
Returns:
|
||||
Request: A google_auth_httplib2.Request compatible Request.
|
||||
"""
|
||||
if not http:
|
||||
http = create_http()
|
||||
return Request(http)
|
||||
|
||||
|
||||
GAM_USER_AGENT = GAM_INFO
|
||||
|
||||
|
||||
def _force_user_agent(user_agent):
|
||||
"""Creates a decorator which can force a user agent in HTTP headers."""
|
||||
|
||||
def decorator(request_method):
|
||||
"""Wraps a request method to insert a user-agent in HTTP headers."""
|
||||
|
||||
def wrapped_request_method(*args, **kwargs):
|
||||
"""Modifies HTTP headers to include a specified user-agent."""
|
||||
if kwargs.get('headers') is not None:
|
||||
if kwargs['headers'].get('user-agent'):
|
||||
if user_agent not in kwargs['headers']['user-agent']:
|
||||
# Save the existing user-agent header and tack on our own.
|
||||
kwargs['headers']['user-agent'] = '%s %s' % (
|
||||
user_agent, kwargs['headers']['user-agent'])
|
||||
else:
|
||||
kwargs['headers']['user-agent'] = user_agent
|
||||
else:
|
||||
kwargs['headers'] = {'user-agent': user_agent}
|
||||
return request_method(*args, **kwargs)
|
||||
|
||||
return wrapped_request_method
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Request(google_auth_httplib2.Request):
|
||||
"""A Request which forces a user agent."""
|
||||
|
||||
@_force_user_agent(GAM_USER_AGENT)
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Inserts the GAM user-agent header in requests."""
|
||||
return super(Request, self).__call__(*args, **kwargs)
|
||||
|
||||
|
||||
class AuthorizedHttp(google_auth_httplib2.AuthorizedHttp):
|
||||
"""An AuthorizedHttp which forces a user agent during requests."""
|
||||
|
||||
@_force_user_agent(GAM_USER_AGENT)
|
||||
def request(self, *args, **kwargs):
|
||||
"""Inserts the GAM user-agent header in requests."""
|
||||
return super(AuthorizedHttp, self).request(*args, **kwargs)
|
||||
179
src/transport_test.py
Normal file
179
src/transport_test.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Tests for transport."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from gam import SetGlobalVariables
|
||||
import google_auth_httplib2
|
||||
import httplib2
|
||||
|
||||
import transport
|
||||
|
||||
|
||||
class CreateHttpTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
SetGlobalVariables()
|
||||
super(CreateHttpTest, self).setUp()
|
||||
|
||||
def test_create_http_sets_default_values_on_http(self):
|
||||
http = transport.create_http()
|
||||
self.assertIsNone(http.cache)
|
||||
self.assertIsNone(http.timeout)
|
||||
self.assertEqual(http.tls_minimum_version,
|
||||
transport.GC_Values[transport.GC_TLS_MIN_VERSION])
|
||||
self.assertEqual(http.tls_maximum_version,
|
||||
transport.GC_Values[transport.GC_TLS_MAX_VERSION])
|
||||
self.assertEqual(http.ca_certs, transport.GC_Values[transport.GC_CA_FILE])
|
||||
|
||||
def test_create_http_sets_tls_min_version(self):
|
||||
http = transport.create_http(override_min_tls='TLSv1_1')
|
||||
self.assertEqual(http.tls_minimum_version, 'TLSv1_1')
|
||||
|
||||
def test_create_http_sets_tls_max_version(self):
|
||||
http = transport.create_http(override_max_tls='TLSv1_3')
|
||||
self.assertEqual(http.tls_maximum_version, 'TLSv1_3')
|
||||
|
||||
def test_create_http_sets_cache(self):
|
||||
fake_cache = {}
|
||||
http = transport.create_http(cache=fake_cache)
|
||||
self.assertEqual(http.cache, fake_cache)
|
||||
|
||||
def test_create_http_sets_cache_timeout(self):
|
||||
http = transport.create_http(timeout=1234)
|
||||
self.assertEqual(http.timeout, 1234)
|
||||
|
||||
|
||||
class TransportTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_http = MagicMock(spec=httplib2.Http)
|
||||
self.mock_response = MagicMock(spec=httplib2.Response)
|
||||
self.mock_content = MagicMock()
|
||||
self.mock_http.request.return_value = (self.mock_response,
|
||||
self.mock_content)
|
||||
self.mock_credentials = MagicMock()
|
||||
self.test_uri = 'http://example.com'
|
||||
super(TransportTest, self).setUp()
|
||||
|
||||
@patch.object(transport, 'create_http')
|
||||
def test_create_request_uses_default_http(self, mock_create_http):
|
||||
request = transport.create_request()
|
||||
self.assertEqual(request.http, mock_create_http.return_value)
|
||||
|
||||
def test_create_request_uses_provided_http(self):
|
||||
request = transport.create_request(http=self.mock_http)
|
||||
self.assertEqual(request.http, self.mock_http)
|
||||
|
||||
def test_create_request_returns_request_with_forced_user_agent(self):
|
||||
request = transport.create_request()
|
||||
self.assertIsInstance(request, transport.Request)
|
||||
|
||||
def test_request_is_google_auth_httplib2_compatible(self):
|
||||
request = transport.create_request()
|
||||
self.assertIsInstance(request, google_auth_httplib2.Request)
|
||||
|
||||
def test_request_call_returns_response_content(self):
|
||||
request = transport.Request(self.mock_http)
|
||||
response = request(self.test_uri)
|
||||
self.assertEqual(self.mock_response.status, response.status)
|
||||
self.assertEqual(self.mock_content, response.data)
|
||||
|
||||
def test_request_call_forces_user_agent_no_provided_headers(self):
|
||||
request = transport.Request(self.mock_http)
|
||||
|
||||
request(self.test_uri)
|
||||
headers = self.mock_http.request.call_args[1]['headers']
|
||||
self.assertIn('user-agent', headers)
|
||||
self.assertIn(transport.GAM_USER_AGENT, headers['user-agent'])
|
||||
|
||||
def test_request_call_forces_user_agent_no_agent_in_headers(self):
|
||||
request = transport.Request(self.mock_http)
|
||||
fake_request_headers = {'some-header-thats-not-a-user-agent': 'someData'}
|
||||
|
||||
request(self.test_uri, headers=fake_request_headers)
|
||||
final_headers = self.mock_http.request.call_args[1]['headers']
|
||||
self.assertIn('user-agent', final_headers)
|
||||
self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent'])
|
||||
self.assertIn('some-header-thats-not-a-user-agent', final_headers)
|
||||
self.assertEqual('someData',
|
||||
final_headers['some-header-thats-not-a-user-agent'])
|
||||
|
||||
def test_request_call_forces_user_agent_with_another_agent_in_headers(self):
|
||||
request = transport.Request(self.mock_http)
|
||||
headers_with_user_agent = {'user-agent': 'existing-user-agent'}
|
||||
|
||||
request(self.test_uri, headers=headers_with_user_agent)
|
||||
final_headers = self.mock_http.request.call_args[1]['headers']
|
||||
self.assertIn('user-agent', final_headers)
|
||||
self.assertIn('existing-user-agent', final_headers['user-agent'])
|
||||
self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent'])
|
||||
|
||||
def test_request_call_same_user_agent_already_in_headers(self):
|
||||
request = transport.Request(self.mock_http)
|
||||
same_user_agent_header = {'user-agent': transport.GAM_USER_AGENT}
|
||||
|
||||
request(self.test_uri, headers=same_user_agent_header)
|
||||
final_headers = self.mock_http.request.call_args[1]['headers']
|
||||
self.assertIn('user-agent', final_headers)
|
||||
self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent'])
|
||||
# Make sure the header wasn't duplicated
|
||||
self.assertEqual(
|
||||
len(transport.GAM_USER_AGENT), len(final_headers['user-agent']))
|
||||
|
||||
def test_authorizedhttp_is_google_auth_httplib2_compatible(self):
|
||||
http = transport.AuthorizedHttp(self.mock_credentials)
|
||||
self.assertIsInstance(http, google_auth_httplib2.AuthorizedHttp)
|
||||
|
||||
def test_authorizedhttp_request_returns_response_content(self):
|
||||
http = transport.AuthorizedHttp(self.mock_credentials, http=self.mock_http)
|
||||
response, content = http.request(self.test_uri)
|
||||
self.assertEqual(self.mock_response, response)
|
||||
self.assertEqual(self.mock_content, content)
|
||||
|
||||
def test_authorizedhttp_request_forces_user_agent_no_provided_headers(self):
|
||||
authorized_http = transport.AuthorizedHttp(
|
||||
self.mock_credentials, http=self.mock_http)
|
||||
authorized_http.request(self.test_uri)
|
||||
headers = self.mock_http.request.call_args[1]['headers']
|
||||
self.assertIn('user-agent', headers)
|
||||
self.assertIn(transport.GAM_USER_AGENT, headers['user-agent'])
|
||||
|
||||
def test_authorizedhttp_request_forces_user_agent_no_agent_in_headers(self):
|
||||
authorized_http = transport.AuthorizedHttp(
|
||||
self.mock_credentials, http=self.mock_http)
|
||||
fake_request_headers = {'some-header-thats-not-a-user-agent': 'someData'}
|
||||
|
||||
authorized_http.request(self.test_uri, headers=fake_request_headers)
|
||||
final_headers = self.mock_http.request.call_args[1]['headers']
|
||||
self.assertIn('user-agent', final_headers)
|
||||
self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent'])
|
||||
self.assertIn('some-header-thats-not-a-user-agent', final_headers)
|
||||
self.assertEqual('someData',
|
||||
final_headers['some-header-thats-not-a-user-agent'])
|
||||
|
||||
def test_authorizedhttp_request_forces_user_agent_with_another_agent_in_headers(
|
||||
self):
|
||||
authorized_http = transport.AuthorizedHttp(
|
||||
self.mock_credentials, http=self.mock_http)
|
||||
headers_with_user_agent = {'user-agent': 'existing-user-agent'}
|
||||
|
||||
authorized_http.request(self.test_uri, headers=headers_with_user_agent)
|
||||
final_headers = self.mock_http.request.call_args[1]['headers']
|
||||
self.assertIn('user-agent', final_headers)
|
||||
self.assertIn('existing-user-agent', final_headers['user-agent'])
|
||||
self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent'])
|
||||
|
||||
def test_authorizedhttp_request_same_user_agent_already_in_headers(self):
|
||||
authorized_http = transport.AuthorizedHttp(
|
||||
self.mock_credentials, http=self.mock_http)
|
||||
same_user_agent_header = {'user-agent': transport.GAM_USER_AGENT}
|
||||
|
||||
authorized_http.request(self.test_uri, headers=same_user_agent_header)
|
||||
final_headers = self.mock_http.request.call_args[1]['headers']
|
||||
self.assertIn('user-agent', final_headers)
|
||||
self.assertIn(transport.GAM_USER_AGENT, final_headers['user-agent'])
|
||||
# Make sure the header wasn't duplicated
|
||||
self.assertEqual(
|
||||
len(transport.GAM_USER_AGENT), len(final_headers['user-agent']))
|
||||
Reference in New Issue
Block a user