From 7d6dffc1388f16568ec6fe249621b9586b099fb9 Mon Sep 17 00:00:00 2001 From: Jay Lee Date: Sat, 4 Jul 2026 13:12:28 -0400 Subject: [PATCH] replace custom signjwt class with google-auth calls --- src/gam/util/api.py | 97 ++++++++++++--------------------------------- 1 file changed, 25 insertions(+), 72 deletions(-) diff --git a/src/gam/util/api.py b/src/gam/util/api.py index 9b44aaae..ff5a501b 100644 --- a/src/gam/util/api.py +++ b/src/gam/util/api.py @@ -20,6 +20,7 @@ import google.auth import google.auth._helpers import google.auth.compute_engine._metadata as gce_metadata import google.auth.crypt +import google.auth.iam import google.auth.exceptions import google.auth.jwt import google.auth.transport.requests @@ -190,43 +191,6 @@ def doGAMCheckForUpdates(forceCheck): if forceCheck: handleServerError(e) -class signjwtJWTCredentials(google.auth.jwt.Credentials): - ''' Class used for DASA ''' - def _make_jwt(self): - now = arrow.utcnow() - expiry = now.shift(seconds=self._token_lifetime) - payload = { - "iat": now.int_timestamp, - "exp": expiry.int_timestamp, - "iss": self._issuer, - "sub": self._subject, - } - if self._audience: - payload["aud"] = self._audience - payload.update(self._additional_claims) - jwt = self._signer.sign(payload) - return jwt, expiry.naive - -class signjwtCredentials(google.oauth2.service_account.Credentials): - ''' Class used for DwD ''' - - def _make_authorization_grant_assertion(self): - now = arrow.utcnow() - expiry = now.shift(seconds=_DEFAULT_TOKEN_LIFETIME_SECS) - payload = { - "iat": now.int_timestamp, - "exp": expiry.int_timestamp, - "iss": self._service_account_email, - "aud": API.GOOGLE_OAUTH2_TOKEN_ENDPOINT, - "scope": google.auth._helpers.scopes_to_string(self._scopes or ()), - } - payload.update(self._additional_claims) - # The subject can be a user email for domain-wide delegation. - if self._subject: - payload.setdefault("sub", self._subject) - token = self._signer(payload) - return token - def get_adc_request(): request = google.auth.transport.requests.Request() if GM.Globals[GM.IS_ON_GCE]: @@ -236,33 +200,22 @@ def get_adc_request(): return request return transportCreateRequest() -class signjwtSignJwt(google.auth.crypt.Signer): - ''' Signer class for SignJWT ''' - def __init__(self, service_account_info): - self.service_account_email = service_account_info['client_email'] - self.name = f'projects/-/serviceAccounts/{self.service_account_email}' - self._key_id = None +def _getIAMSigner(service_account_info): + '''Create an IAM-based signer using Application Default Credentials. - @property # type: ignore - def key_id(self): - return self._key_id - - def sign(self, message): - ''' Call IAM Credentials SignJWT API to get our signed JWT ''' - request = get_adc_request() - try: - credentials, _ = google.auth.default(scopes=[API.IAM_SCOPE], - request=request) - except (google.auth.exceptions.DefaultCredentialsError, google.auth.exceptions.RefreshError) as e: - systemErrorExit(API_ACCESS_DENIED_RC, str(e)) - httpObj = transportAuthorizedHttp(credentials, http=getHttpObj()) - # refresh here so we can use the proper request from above - httpObj.credentials.refresh(request) - iamc = getService(API.IAM_CREDENTIALS, httpObj) - response = callGAPI(iamc.projects().serviceAccounts(), 'signJwt', - name=self.name, body={'payload': json.dumps(message)}) - signed_jwt = response.get('signedJwt') - return signed_jwt + Returns a google.auth.iam.Signer that signs bytes via the IAM signBlob + API. This replaces the need for a local private key — signing is + delegated to Google\'s IAM service using ADC for authentication. + ''' + request = get_adc_request() + try: + credentials, _ = google.auth.default(scopes=[API.IAM_SCOPE], + request=request) + except (google.auth.exceptions.DefaultCredentialsError, google.auth.exceptions.RefreshError) as e: + systemErrorExit(API_ACCESS_DENIED_RC, str(e)) + credentials.refresh(request) + return google.auth.iam.Signer(request, credentials, + service_account_info['client_email']) def handleOAuthTokenError(e, softErrors, displayError=False, i=0, count=0): errMsg = str(e).replace('.', '') @@ -312,8 +265,8 @@ def getOauth2TxtCredentials(exitOnError=True, api=None, noDASA=False, refreshOnl yksigner = yubikey.YubiKey(jsonDict) return (True, JWTCredentials._from_signer_and_info(yksigner, jsonDict, audience=audience)) if key_type == 'signjwt': - sjsigner = signjwtSignJwt(jsonDict) - return (True, signjwtJWTCredentials._from_signer_and_info(sjsigner, jsonDict, audience=audience)) + sjsigner = _getIAMSigner(jsonDict) + return (True, JWTCredentials._from_signer_and_info(sjsigner, jsonDict, audience=audience)) except (IndexError, KeyError, SyntaxError, TypeError, ValueError) as e: invalidOauth2serviceJsonExit(str(e)) invalidOauth2serviceJsonExit(Msg.NO_DATA) @@ -672,9 +625,9 @@ def getSvcAcctCredentials(scopesOrAPI, userEmail, softErrors=False, forceOauth=F credentials = google.oauth2.service_account.Credentials._from_signer_and_info(yksigner, GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) elif key_type == 'signjwt': - sjsigner = signjwtSignJwt(GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) - credentials = signjwtCredentials._from_signer_and_info(sjsigner.sign, - GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) + sjsigner = _getIAMSigner(GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) + credentials = google.oauth2.service_account.Credentials._from_signer_and_info(sjsigner, + GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) except (ValueError, IndexError, KeyError) as e: if softErrors: return None @@ -692,10 +645,10 @@ def getSvcAcctCredentials(scopesOrAPI, userEmail, softErrors=False, forceOauth=F GM.Globals[GM.OAUTH2SERVICE_JSON_DATA], audience=audience) elif key_type == 'signjwt': - sjsigner = signjwtSignJwt(GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) - credentials = signjwtJWTCredentials._from_signer_and_info(sjsigner, - GM.Globals[GM.OAUTH2SERVICE_JSON_DATA], - audience=audience) + sjsigner = _getIAMSigner(GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) + credentials = JWTCredentials._from_signer_and_info(sjsigner, + GM.Globals[GM.OAUTH2SERVICE_JSON_DATA], + audience=audience) credentials.project_id = GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]['project_id'] except (ValueError, IndexError, KeyError) as e: if softErrors: