replace custom signjwt class with google-auth calls

This commit is contained in:
Jay Lee
2026-07-04 13:12:28 -04:00
parent da9875adde
commit 7d6dffc138

View File

@@ -20,6 +20,7 @@ import google.auth
import google.auth._helpers import google.auth._helpers
import google.auth.compute_engine._metadata as gce_metadata import google.auth.compute_engine._metadata as gce_metadata
import google.auth.crypt import google.auth.crypt
import google.auth.iam
import google.auth.exceptions import google.auth.exceptions
import google.auth.jwt import google.auth.jwt
import google.auth.transport.requests import google.auth.transport.requests
@@ -190,43 +191,6 @@ def doGAMCheckForUpdates(forceCheck):
if forceCheck: if forceCheck:
handleServerError(e) handleServerError(e)
class signjwtJWTCredentials(google.auth.jwt.Credentials):
''' Class used for DASA '''
def _make_jwt(self):
now = arrow.utcnow()
expiry = now.shift(seconds=self._token_lifetime)
payload = {
"iat": now.int_timestamp,
"exp": expiry.int_timestamp,
"iss": self._issuer,
"sub": self._subject,
}
if self._audience:
payload["aud"] = self._audience
payload.update(self._additional_claims)
jwt = self._signer.sign(payload)
return jwt, expiry.naive
class signjwtCredentials(google.oauth2.service_account.Credentials):
''' Class used for DwD '''
def _make_authorization_grant_assertion(self):
now = arrow.utcnow()
expiry = now.shift(seconds=_DEFAULT_TOKEN_LIFETIME_SECS)
payload = {
"iat": now.int_timestamp,
"exp": expiry.int_timestamp,
"iss": self._service_account_email,
"aud": API.GOOGLE_OAUTH2_TOKEN_ENDPOINT,
"scope": google.auth._helpers.scopes_to_string(self._scopes or ()),
}
payload.update(self._additional_claims)
# The subject can be a user email for domain-wide delegation.
if self._subject:
payload.setdefault("sub", self._subject)
token = self._signer(payload)
return token
def get_adc_request(): def get_adc_request():
request = google.auth.transport.requests.Request() request = google.auth.transport.requests.Request()
if GM.Globals[GM.IS_ON_GCE]: if GM.Globals[GM.IS_ON_GCE]:
@@ -236,33 +200,22 @@ def get_adc_request():
return request return request
return transportCreateRequest() return transportCreateRequest()
class signjwtSignJwt(google.auth.crypt.Signer): def _getIAMSigner(service_account_info):
''' Signer class for SignJWT ''' '''Create an IAM-based signer using Application Default Credentials.
def __init__(self, service_account_info):
self.service_account_email = service_account_info['client_email']
self.name = f'projects/-/serviceAccounts/{self.service_account_email}'
self._key_id = None
@property # type: ignore Returns a google.auth.iam.Signer that signs bytes via the IAM signBlob
def key_id(self): API. This replaces the need for a local private key — signing is
return self._key_id delegated to Google\'s IAM service using ADC for authentication.
'''
def sign(self, message): request = get_adc_request()
''' Call IAM Credentials SignJWT API to get our signed JWT ''' try:
request = get_adc_request() credentials, _ = google.auth.default(scopes=[API.IAM_SCOPE],
try: request=request)
credentials, _ = google.auth.default(scopes=[API.IAM_SCOPE], except (google.auth.exceptions.DefaultCredentialsError, google.auth.exceptions.RefreshError) as e:
request=request) systemErrorExit(API_ACCESS_DENIED_RC, str(e))
except (google.auth.exceptions.DefaultCredentialsError, google.auth.exceptions.RefreshError) as e: credentials.refresh(request)
systemErrorExit(API_ACCESS_DENIED_RC, str(e)) return google.auth.iam.Signer(request, credentials,
httpObj = transportAuthorizedHttp(credentials, http=getHttpObj()) service_account_info['client_email'])
# refresh here so we can use the proper request from above
httpObj.credentials.refresh(request)
iamc = getService(API.IAM_CREDENTIALS, httpObj)
response = callGAPI(iamc.projects().serviceAccounts(), 'signJwt',
name=self.name, body={'payload': json.dumps(message)})
signed_jwt = response.get('signedJwt')
return signed_jwt
def handleOAuthTokenError(e, softErrors, displayError=False, i=0, count=0): def handleOAuthTokenError(e, softErrors, displayError=False, i=0, count=0):
errMsg = str(e).replace('.', '') errMsg = str(e).replace('.', '')
@@ -312,8 +265,8 @@ def getOauth2TxtCredentials(exitOnError=True, api=None, noDASA=False, refreshOnl
yksigner = yubikey.YubiKey(jsonDict) yksigner = yubikey.YubiKey(jsonDict)
return (True, JWTCredentials._from_signer_and_info(yksigner, jsonDict, audience=audience)) return (True, JWTCredentials._from_signer_and_info(yksigner, jsonDict, audience=audience))
if key_type == 'signjwt': if key_type == 'signjwt':
sjsigner = signjwtSignJwt(jsonDict) sjsigner = _getIAMSigner(jsonDict)
return (True, signjwtJWTCredentials._from_signer_and_info(sjsigner, jsonDict, audience=audience)) return (True, JWTCredentials._from_signer_and_info(sjsigner, jsonDict, audience=audience))
except (IndexError, KeyError, SyntaxError, TypeError, ValueError) as e: except (IndexError, KeyError, SyntaxError, TypeError, ValueError) as e:
invalidOauth2serviceJsonExit(str(e)) invalidOauth2serviceJsonExit(str(e))
invalidOauth2serviceJsonExit(Msg.NO_DATA) invalidOauth2serviceJsonExit(Msg.NO_DATA)
@@ -672,9 +625,9 @@ def getSvcAcctCredentials(scopesOrAPI, userEmail, softErrors=False, forceOauth=F
credentials = google.oauth2.service_account.Credentials._from_signer_and_info(yksigner, credentials = google.oauth2.service_account.Credentials._from_signer_and_info(yksigner,
GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) GM.Globals[GM.OAUTH2SERVICE_JSON_DATA])
elif key_type == 'signjwt': elif key_type == 'signjwt':
sjsigner = signjwtSignJwt(GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) sjsigner = _getIAMSigner(GM.Globals[GM.OAUTH2SERVICE_JSON_DATA])
credentials = signjwtCredentials._from_signer_and_info(sjsigner.sign, credentials = google.oauth2.service_account.Credentials._from_signer_and_info(sjsigner,
GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) GM.Globals[GM.OAUTH2SERVICE_JSON_DATA])
except (ValueError, IndexError, KeyError) as e: except (ValueError, IndexError, KeyError) as e:
if softErrors: if softErrors:
return None return None
@@ -692,10 +645,10 @@ def getSvcAcctCredentials(scopesOrAPI, userEmail, softErrors=False, forceOauth=F
GM.Globals[GM.OAUTH2SERVICE_JSON_DATA], GM.Globals[GM.OAUTH2SERVICE_JSON_DATA],
audience=audience) audience=audience)
elif key_type == 'signjwt': elif key_type == 'signjwt':
sjsigner = signjwtSignJwt(GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]) sjsigner = _getIAMSigner(GM.Globals[GM.OAUTH2SERVICE_JSON_DATA])
credentials = signjwtJWTCredentials._from_signer_and_info(sjsigner, credentials = JWTCredentials._from_signer_and_info(sjsigner,
GM.Globals[GM.OAUTH2SERVICE_JSON_DATA], GM.Globals[GM.OAUTH2SERVICE_JSON_DATA],
audience=audience) audience=audience)
credentials.project_id = GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]['project_id'] credentials.project_id = GM.Globals[GM.OAUTH2SERVICE_JSON_DATA]['project_id']
except (ValueError, IndexError, KeyError) as e: except (ValueError, IndexError, KeyError) as e:
if softErrors: if softErrors: