diff --git a/src/gam.py b/src/gam.py index ffc21204..1ee3a344 100755 --- a/src/gam.py +++ b/src/gam.py @@ -80,6 +80,26 @@ Go to the following link in your browser: {address} """ +# Override and wrap google_auth_httplib2.Request.__call__ so that the GAM +# user-agent string is inserted into HTTP request headers. +google_auth_httplib2_request_call = google_auth_httplib2.Request.__call__ +def _request_with_user_agent(self, *args, **kwargs): + """Inserts the GAM user-agent header in all google_auth_httplib2 requests.""" + GAM_USER_AGENT = GAM_INFO + + if kwargs.get('headers') is not None: + if kwargs['headers'].get('user-agent'): + # Save the existing user-agent header and tack on the GAM user-agent. + kwargs['headers']['user-agent'] = '%s %s' % ( + GAM_USER_AGENT, kwargs.headers['user-agent']) + else: + kwargs['headers']['user-agent'] = GAM_USER_AGENT + else: + kwargs['headers'] = {'user-agent': GAM_USER_AGENT} + return google_auth_httplib2_request_call(self, *args, **kwargs) + +google_auth_httplib2.Request.__call__ = _request_with_user_agent + def showUsage(): doGAMVersion(checkForArgs=False) print u''' @@ -1064,10 +1084,10 @@ def buildGAPIServiceObject(api, act_as, showAuthError=True): GM_Globals[GM_CURRENT_API_USER] = act_as GM_Globals[GM_CURRENT_API_SCOPES] = API_SCOPE_MAPPING[api] credentials = getSvcAcctCredentials(GM_Globals[GM_CURRENT_API_SCOPES], act_as) - request = google_auth_httplib2.Request(http, user_agent=GAM_INFO) + request = google_auth_httplib2.Request(http) try: credentials.refresh(request) - service._http = google_auth_httplib2.AuthorizedHttp(credentials, http=http, user_agent=GAM_INFO) + service._http = google_auth_httplib2.AuthorizedHttp(credentials, http=http) except httplib2.ServerNotFoundError as e: systemErrorExit(4, e) except google.auth.exceptions.RefreshError as e: @@ -1129,7 +1149,7 @@ def doCheckServiceAccount(users): for scope in all_scopes: try: credentials = getSvcAcctCredentials([scope], user) - request = google_auth_httplib2.Request(httplib2.Http(disable_ssl_certificate_validation=GC_Values[GC_NO_VERIFY_SSL]), user_agent=GAM_INFO) + request = google_auth_httplib2.Request(httplib2.Http(disable_ssl_certificate_validation=GC_Values[GC_NO_VERIFY_SSL])) credentials.refresh(request) result = u'PASS' except httplib2.ServerNotFoundError as e: diff --git a/src/google_auth_httplib2.py b/src/google_auth_httplib2/__init__.py similarity index 91% rename from src/google_auth_httplib2.py rename to src/google_auth_httplib2/__init__.py index c49c136f..51c406eb 100644 --- a/src/google_auth_httplib2.py +++ b/src/google_auth_httplib2/__init__.py @@ -80,9 +80,8 @@ class Request(transport.Request): .. automethod:: __call__ """ - def __init__(self, http, user_agent=None): + def __init__(self, http): self.http = http - self.user_agent = user_agent def __call__(self, url, method='GET', body=None, headers=None, timeout=None, **kwargs): @@ -112,12 +111,6 @@ class Request(transport.Request): 'Set the timeout when constructing the httplib2.Http instance.' ) - if self.user_agent: - if headers.get('user-agent'): - headers['user-agent'] = '%s %s' % (self.user_agent, headers['user-agent']) - else: - headers['user-agent'] = self.user_agent - try: _LOGGER.debug('Making request: %s %s', method, url) response, data = self.http.request( @@ -154,7 +147,7 @@ class AuthorizedHttp(object): The underlying :meth:`request` implementation handles adding the credentials' headers to the request and refreshing credentials as needed. """ - def __init__(self, credentials, http=None, user_agent=None, + def __init__(self, credentials, http=None, refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS): """ @@ -164,7 +157,6 @@ class AuthorizedHttp(object): http (httplib2.Http): The underlying HTTP object to use to make requests. If not specified, a :class:`httplib2.Http` instance will be constructed. - user_agent: the user-agent header refresh_status_codes (Sequence[int]): Which HTTP status codes indicate that credentials should be refreshed and the request should be retried. @@ -177,12 +169,11 @@ class AuthorizedHttp(object): self.http = http self.credentials = credentials - self.user_agent = user_agent self._refresh_status_codes = refresh_status_codes self._max_refresh_attempts = max_refresh_attempts # Request instance used by internal methods (for example, # credentials.refresh). - self._request = Request(self.http, self.user_agent) + self._request = Request(self.http) def request(self, uri, method='GET', body=None, headers=None, **kwargs): @@ -195,12 +186,6 @@ class AuthorizedHttp(object): # and we want to pass the original headers if we recurse. request_headers = headers.copy() if headers is not None else {} - if self.user_agent: - if request_headers.get('user-agent'): - request_headers['user-agent'] = '%s %s' % (self.user_agent, request_headers['user-agent']) - else: - request_headers['user-agent'] = self.user_agent - self.credentials.before_request( self._request, method, uri, request_headers)