# Source code for provenaclient.auth.implementations

'''
Created Date: Tuesday June 18th 2024 +1000
Author: Peter Baker
-----
Last Modified: Tuesday June 18th 2024 12:44:03 pm +1000
Modified By: Peter Baker
-----
Description: Implementations of the Auth interface defined in auth/manager.py
-----
HISTORY:
Date      	By	Comments
----------	---	---------------------------------------------------------
'''

from typing import Any, Dict, Optional
from provenaclient.auth.manager import AuthManager, LogType
import requests
import webbrowser
import time
import os
from provenaclient.auth.helpers import AccessToken, Tokens, keycloak_refresh_token_request, validate_access_token, retrieve_keycloak_public_key
import json
from provenaclient.auth.manager import DEFAULT_LOG_LEVEL
from provenaclient.utils.config import Config


class DeviceFlow(AuthManager):
    """AuthManager implementation using the OAuth 2.0 device authorisation flow.

    Tokens obtained from Keycloak are cached in a local JSON file
    (``self.file_name``) and are validated/refreshed on access. A fresh
    device flow is only started when no usable tokens can be recovered.
    """

    keycloak_endpoint: str
    client_id: str
    scopes: list
    device_endpoint: str
    token_endpoint: str

    # RFC 8628: if the server supplies no polling interval, clients use 5s,
    # and a ``slow_down`` error means "add 5s to the interval and keep polling".
    _DEFAULT_POLL_INTERVAL_SECONDS = 5
    _SLOW_DOWN_INCREMENT_SECONDS = 5

    def __init__(self, config: Config, client_id: str, log_level: Optional[LogType] = None) -> None:
        """Create and generate a DeviceFlow object.

        The tokens are automatically refreshed when accessed through
        ``get_token``. Tokens are cached in local storage and are only
        regenerated through a new device flow if they cannot be loaded,
        validated or refreshed.

        Parameters
        ----------
        config : Config
            Client configuration; ``config.keycloak_endpoint`` is the keycloak
            endpoint used for all auth requests.
        client_id : str
            The client id for the keycloak authorisation.
        log_level : Optional[LogType]
            The logging level to use - when None the parent class default
            (``DEFAULT_LOG_LEVEL``) applies.

        Raises
        ------
        Exception
            If the Keycloak public key cannot be retrieved.
        """
        # NOTE: this must be a plain string literal - an f-string in this
        # position is evaluated as an expression and is NOT a docstring.

        # construct parent class and include log level
        super().__init__(log_level=log_level)

        self.keycloak_endpoint = config.keycloak_endpoint
        self.client_id = client_id
        self.scopes: list = []
        self.file_name = ".tokens.json"
        self.device_endpoint = f'{self.keycloak_endpoint}/protocol/openid-connect/auth/device'
        self.token_endpoint = f'{self.keycloak_endpoint}/protocol/openid-connect/token'

        try:
            # First thing to do here is obtain the keycloak public key.
            self.public_key = retrieve_keycloak_public_key(
                logger=self.logger,
                keycloak_endpoint=self.keycloak_endpoint,
            )
        except Exception as e:
            raise Exception(
                "Failed to retrieve the Keycloak public key, authentication cannot proceed.") from e

        # Second thing is to check whether the token cache file is present.
        # If present validate it (refreshing on failure), otherwise fetch new
        # tokens. self.tokens is always assigned here so that later
        # `self.tokens is None` checks cannot hit a missing attribute when the
        # device flow does not complete.
        token_file_present = os.path.exists(self.file_name)
        self.tokens = self.load_tokens() if token_file_present else None

        if not token_file_present:
            self.logger.info("No token file found, starting device flow.")
            self.start_device_flow()
        elif not self.tokens:
            self.logger.info("No tokens found or failed to load tokens.")
            self.start_device_flow()
        elif validate_access_token(
            logger=self.logger,
            access_token=self.tokens.access_token,
            public_key=self.public_key
        ):
            self.logger.info("Tokens are valid...")
        else:
            try:
                self.refresh_tokens()
            except Exception as e:
                self.logger.info(
                    f"Refresh token has expired or is invalid {e}")
                self.start_device_flow()

    def _valid_access_token(self) -> Optional[str]:
        """Return the current access token if it is set and validates, else None."""
        if self.tokens is None:
            return None
        if validate_access_token(
            public_key=self.public_key,
            access_token=self.tokens.access_token,
            logger=self.logger
        ):
            return self.tokens.access_token
        return None

    def get_token(self) -> str:
        """IMPLEMENTS BASE METHOD

        Uses the current token - validates it, refreshes if necessary, and
        returns the valid token ready to be used.

        Returns
        -------
        str
            The access token

        Raises
        ------
        Exception
            Raises exception if tokens/public_key are not setup - make sure
            that the object is instantiated properly before calling this
            function.
        Exception
            If no valid token could be obtained after refreshing and
            re-conducting the device flow.
        """
        if self.tokens is None or self.public_key is None:
            raise Exception(
                "Cannot generate token without access token or public key.")

        try:
            # Attempt to validate the current token.
            token = self._valid_access_token()
            if token:
                return token

            # didnt return, refresh and try again.
            self.logger.info("Token was invalid. Attempting Refresh")
            self.refresh_tokens()
            token = self._valid_access_token()
            if token:
                return token

            # still no good, restart flow
            self.logger.info(
                "Token was invalid after refresh. Re-iniating Device Flow")
            self.start_device_flow()
            token = self._valid_access_token()
            if token:
                return token
        except Exception as e:
            # Typically an expired/invalid refresh token - fall back to a
            # fresh device flow, keeping the cause visible in the log.
            self.logger.info(
                f"Something went wrong during get_token operation. Starting device flow. Error: {e}")
            self.start_device_flow()
            token = self._valid_access_token()
            if token:
                return token

        # no error, but also no valid token. Something is wrong.
        err_msg = "Failed to obtain a valid token after refreshing and initiating a new device flow."
        self.logger.error(err_msg)
        raise Exception(err_msg)

    def force_refresh(self) -> None:
        """IMPLEMENTS BASE METHOD

        A method to reset the current authentication state: removes the
        cached token file, clears the in-memory tokens and re-runs the
        device flow.
        """
        self.clear_token_storage()
        self.start_device_flow()

    def refresh_tokens(self) -> None:
        """Attempts to refresh the authentication tokens using the stored refresh token.

        On success the in-memory tokens are replaced and persisted through
        ``save_tokens``.

        Raises
        ------
        ValueError
            If no initial tokens are set, indicating that there is nothing to
            refresh.
        ValueError
            If the refresh response is missing the access or refresh token.
        Exception
            Anything raised by ``make_token_refresh_request`` propagates.
        """
        if self.tokens is None:
            raise ValueError(
                "Token refresh attempted with no initial tokens set. ")

        self.logger.info("Refreshing using refresh token")

        refreshed: Dict[str, Any] = self.make_token_refresh_request()
        access_token = refreshed.get('access_token')
        refresh_token = refreshed.get('refresh_token')

        if not access_token or not refresh_token:
            error_message = "Failed to refresh tokens: Missing access or refresh tokens"
            self.logger.info(error_message)
            raise ValueError(error_message)

        self.tokens = Tokens(
            access_token=access_token,
            refresh_token=refresh_token
        )
        self.save_tokens(self.tokens)

    def save_tokens(self, tokens: Tokens) -> None:
        """Saves authentication tokens to the local cache file in JSON format.

        Saving is best-effort: IO/serialisation failures are logged and not
        raised, so a read-only working directory does not break
        authentication (the tokens simply are not cached).

        Parameters
        ----------
        tokens : Tokens
            An object representing the authentication tokens containing the
            access and refresh tokens.
        """
        self.logger.info("Saving tokens into local storage.")
        try:
            with open(self.file_name, 'w') as file:
                json.dump(tokens.dict(), file)
            self.logger.info("Tokens saved to file successfully.")
        except Exception as e:
            self.logger.error(f"Failed to save tokens: {e}")

    def clear_token_storage(self) -> None:
        """Removes the cached token file (if any) and resets the in-memory tokens."""
        # Must target self.file_name - the same path save_tokens writes to -
        # otherwise stale tokens survive a forced refresh.
        try:
            os.remove(self.file_name)
        except FileNotFoundError:
            pass
        else:
            self.logger.info("Stored tokens have been clear.")
        self.tokens = None

    def load_tokens(self) -> Optional[Tokens]:
        """Loads authentication tokens from the local JSON cache file.

        Loading is best-effort: any IO/parsing/validation failure is logged
        and reported as None rather than raised.

        Returns
        -------
        Optional[Tokens]
            The cached access and refresh tokens, or None if they could not
            be loaded.
        """
        self.logger.info("Looking for existing tokens in local storage.")
        try:
            with open(self.file_name, 'r') as file:
                token_data = json.load(file)
            return Tokens(**token_data)
        except Exception as e:
            self.logger.error(f"Failed to load tokens: {e}")
            return None

    def make_token_refresh_request(self, tokens: Optional[Tokens] = None) -> Dict[str, Any]:
        """Performs the token refresh request against the token endpoint.

        Parameters
        ----------
        tokens : Optional[Tokens], optional
            An optional Tokens object containing the refresh token. If not
            provided, the tokens stored on this object are used. By default
            None.

        Returns
        -------
        Dict[str, Any]
            The response of ``keycloak_refresh_token_request``; callers read
            the ``access_token`` and ``refresh_token`` keys from it.

        Raises
        ------
        ValueError
            If no refresh token is provided or found on this object.
        Exception
            Anything raised by ``keycloak_refresh_token_request`` propagates.
        """
        # make sure we have tokens to use
        desired_tokens: Optional[Tokens] = tokens or self.tokens

        if not desired_tokens or not desired_tokens.refresh_token:
            raise ValueError("Refresh token is required but was not provided.")

        return keycloak_refresh_token_request(
            logger=self.logger,
            client_id=self.client_id,
            refresh_token=desired_tokens.refresh_token,
            scopes=self.scopes,
            token_endpoint=self.token_endpoint
        )

    def start_device_flow(self) -> None:
        """Initiates the device authorisation flow.

        Requests a device code from the server, prompts the user to
        authenticate through the web browser and then polls for the tokens
        via ``handle_auth_flow``.

        Raises
        ------
        Exception
            If the server response is not of status code 200, suggesting that
            the flow could not be initiated.
        """
        self.logger.info("Initiating auth device flow.")
        data = {
            "client_id": self.client_id,
            # RFC 8628 names this request parameter "scope" (singular), which
            # also matches the token request in handle_auth_flow.
            "scope": ' '.join(self.scopes)
        }
        response = requests.post(self.device_endpoint, data=data)

        if response.status_code != 200:
            self.logger.error(
                f"Device flow initiation failed with status code {response.status_code}.")
            raise Exception("Failed to initiate device flow auth.")

        response_data = response.json()
        self.device_code = response_data.get('device_code')
        self.interval = response_data.get('interval')
        # verification_uri_complete is optional in the spec; fall back to the
        # plain verification URI so the user always gets a URL to visit.
        verification_url = (
            response_data.get('verification_uri_complete')
            or response_data.get('verification_uri')
        )
        user_code = response_data.get('user_code')

        self.display_device_auth_flow(
            user_code=user_code, verification_url=verification_url)
        self.handle_auth_flow()

    def display_device_auth_flow(self, user_code: str, verification_url: str) -> None:
        """Displays the current device auth flow challenge.

        Prints the URL and user code, then tries to open a browser window -
        if this fails a suggestion to use the URL manually is printed.

        Parameters
        ----------
        user_code : str
            The user code
        verification_url : str
            The url which embeds challenge code
        """
        print(f"Verification URL: {verification_url}")
        print(f"User Code: {user_code}")
        try:
            # webbrowser.open reports most failures by returning False rather
            # than raising, so both outcomes must be checked.
            opened = webbrowser.open(verification_url)
        except Exception:
            opened = False
        if not opened:
            print("Tried to open web-browser but failed. Please visit URL above.")

    def handle_auth_flow(self) -> None:
        """Polls the token endpoint until the device flow completes or fails.

        Polling continues while the server answers ``authorization_pending``
        or ``slow_down``; any other error (including the server reporting the
        device code as expired) ends the flow. On success the tokens are
        stored on this object and saved to the cache file. Failure is logged,
        not raised.
        """
        device_grant_type = "urn:ietf:params:oauth:grant-type:device_code"
        data = {
            "grant_type": device_grant_type,
            "device_code": self.device_code,
            "client_id": self.client_id,
            "scope": " ".join(self.scopes)
        }

        # The server may omit the interval; time.sleep(None) would raise.
        poll_interval = (
            self.interval if self.interval is not None
            else self._DEFAULT_POLL_INTERVAL_SECONDS
        )

        # Setup success criteria
        succeeded = False
        misc_fail = False
        response_data: Optional[Dict[str, Any]] = None

        # Poll for success
        while not succeeded and not misc_fail:
            response = requests.post(self.token_endpoint, data=data)
            response_data = response.json()

            if response_data is None:
                misc_fail = True
                self.logger.info("No data received in the response.")
            elif response_data.get('error'):
                error = response_data['error']
                if error == 'authorization_pending':
                    # Wait appropriate OAuth poll interval
                    time.sleep(poll_interval)
                elif error == 'slow_down':
                    # Server asked us to back off; keep polling, more slowly.
                    poll_interval += self._SLOW_DOWN_INCREMENT_SECONDS
                    time.sleep(poll_interval)
                else:
                    # Terminal error - stop immediately, no point sleeping.
                    misc_fail = True
            else:
                # Successful as there was no error at the endpoint
                # We will produce a token object here.
                access_token = response_data.get("access_token")
                refresh_token = response_data.get("refresh_token")

                if not access_token:
                    misc_fail = True
                    self.logger.info("Missing or invalid access token.")
                    continue
                if not refresh_token:
                    misc_fail = True
                    self.logger.info("Missing or invalid refresh token.")
                    continue

                self.tokens = Tokens(
                    access_token=access_token,
                    refresh_token=refresh_token
                )
                # Save the tokens into the cache file ('.tokens.json')
                self.save_tokens(self.tokens)
                succeeded = True

        if not succeeded:
            if response_data and "error" in response_data:
                self.logger.info(f"Failed due to {response_data['error']}")
            else:
                self.logger.info(
                    "Failed with unknown error, failed to find error message.")
class OfflineFlow(AuthManager):
    """AuthManager implementation that exchanges an offline token for access tokens.

    The offline token is supplied directly or read from a plain text file and
    is used as the refresh token in requests to the Keycloak token endpoint.
    """

    # The keycloak endpoint to target for tokens
    keycloak_endpoint: str
    # The offline token provided by user
    offline_token: str
    # The client ID to target for auth
    client_id: str
    # The token endpoint to use
    token_endpoint: str
    scopes: list
    public_key: str

    def __init__(self, config: Config, client_id: str, offline_token: Optional[str] = None, offline_token_file: Optional[str] = None, log_level: Optional[LogType] = None) -> None:
        """Create and generate an OfflineFlow object.

        Instantiate from the provided offline token, or attempt to read one
        from file, and generate the access token. The offline token can be
        provided directly or through a file in which it is stored as plain
        text. If both are given, ``offline_token`` takes precedence.

        Parameters
        ----------
        config : Config
            Client configuration; ``config.keycloak_endpoint`` is the keycloak
            endpoint to use. E.g.,
            https://auth.example.org/auth/realms/my_realm
        client_id : str
            The client to target for auth. E.g., landing-portal-ui
        offline_token : Optional[str], optional
            The offline token to use for generating access tokens from. If
            not provided, defaults to None and init will try to use
            offline_token_file to read an offline_token.
        offline_token_file : Optional[str], optional
            The file name to read the offline token from, where it is stored
            as plain text. Be sure to add this file to your .gitignore if
            using this parameter.
        log_level : Optional[LogType]
            The logging level to use - when None the parent class default
            (``DEFAULT_LOG_LEVEL``) applies.

        Raises
        ------
        Exception
            Fails to retrieve public key from keycloak endpoint.
        ValueError
            Neither offline_token nor offline_token_file was provided.
        Exception
            The offline token file does not exist or is empty.
        Exception
            Fails to generate an access token from the offline token.
        """
        # NOTE: this must be a plain string literal - an f-string in this
        # position is evaluated as an expression and is NOT a docstring.

        # construct parent class and include log level
        super().__init__(log_level=log_level)

        self.keycloak_endpoint = config.keycloak_endpoint
        self.client_id = client_id
        self.scopes: list = []
        self.token_endpoint = f'{self.keycloak_endpoint}/protocol/openid-connect/token'

        try:
            # First thing to do here is obtain the keycloak public key.
            self.public_key = retrieve_keycloak_public_key(
                logger=self.logger,
                keycloak_endpoint=self.keycloak_endpoint,
            )
        except Exception as e:
            raise Exception(
                "Failed to retrieve the Keycloak public key, authentication cannot proceed.") from e

        # Resolve the offline token; a directly supplied token wins over a file.
        if offline_token:
            self.logger.info(
                "Offline token provided, attempting to generate tokens from it.")
            self.offline_token = offline_token
        elif offline_token_file:
            self.logger.info(
                "Offline token file provided, attempting to generate tokens from it.")
            self.offline_token = self.load_offline_token(offline_token_file)
        else:
            err_msg = "Please provide a value or offline_token or offline_token_file."
            self.logger.error(err_msg)
            raise ValueError(err_msg)

        # Ok, got an offline token, now generate a temporary access token from it
        try:
            self.get_access_token_from_offline_token()
        except Exception as e:
            err_msg = "Failed to validate new tokens generated from offline token file."
            self.logger.error(err_msg)
            raise Exception(err_msg) from e

    def _current_token_is_valid(self) -> bool:
        """Return whether the currently held access token passes validation."""
        return bool(validate_access_token(
            logger=self.logger,
            access_token=self.tokens.access_token,
            public_key=self.public_key
        ))

    def get_token(self) -> str:
        """IMPLEMENTS BASE METHOD

        Uses the current token - validates it, regenerates it from the
        offline token if necessary, and returns the valid token ready to be
        used.

        Returns
        -------
        str
            The access token

        Raises
        ------
        Exception
            Raises exception if tokens/public_key are not setup - make sure
            that the object is instantiated properly before calling this
            function.
        Exception
            If validating or regenerating the token raises an error
            ("Failed to refresh token.", chained to the cause).
        Exception
            If the regenerated token still fails validation.
        """
        if self.tokens is None or self.public_key is None:
            raise Exception(
                "Cannot generate token without access token or public key.")

        try:
            # Attempt to validate the current token.
            if self._current_token_is_valid():
                return self.tokens.access_token

            self.logger.info("Token was invalid. Attempting Refresh")

            # regenerate from the offline token and attempt re validation
            self.get_access_token_from_offline_token()
            if self._current_token_is_valid():
                return self.tokens.access_token
        except Exception as e:
            err_msg = "Failed to refresh token."
            self.logger.error(err_msg)
            raise Exception(err_msg) from e

        # Raised outside the try block so this specific failure is not
        # re-wrapped by the generic handler above.
        err_msg = "Failed to produce a valid access token from the offline token."
        self.logger.error(err_msg)
        raise Exception(err_msg)

    def force_refresh(self) -> None:
        """IMPLEMENTS BASE METHOD

        A method to reset the current authentication state. Since the offline
        flow has no cached state - this just forces a refresh token request
        to be made.
        """
        self.get_access_token_from_offline_token()

    def get_access_token_from_offline_token(self) -> None:
        """Generates a new access token using the offline token as the refresh token.

        The resulting access token is stored on this object as ``self.tokens``.

        Raises
        ------
        ValueError
            If the token response does not contain an access token.
        Exception
            Anything raised by ``keycloak_refresh_token_request`` propagates.
        """
        tokens = keycloak_refresh_token_request(
            logger=self.logger,
            client_id=self.client_id,
            token_endpoint=self.token_endpoint,
            scopes=self.scopes,
            refresh_token=self.offline_token
        )

        access_token = tokens.get('access_token')

        if not access_token:
            err_msg = "Failed to geneate access token. Returned access token is None."
            self.logger.error(err_msg)
            raise ValueError(err_msg)

        self.tokens = AccessToken(
            access_token=access_token
        )

    def load_offline_token(self, file_name: str) -> str:
        """Loads the offline token from the provided file.

        Parameters
        ----------
        file_name : str
            The file name to load the offline token from.

        Returns
        -------
        str
            The offline token read from the file, with surrounding whitespace
            stripped.

        Raises
        ------
        Exception
            If the file does not exist or if the file is empty.
        """
        if not os.path.exists(file_name):
            err_msg = f"Offline token file named {file_name} does not exist."
            self.logger.error(err_msg)
            raise Exception(err_msg)

        with open(file_name, 'r') as f:
            offline_token = f.read().strip()

        # A whitespace-only file strips down to the empty string as well.
        if not offline_token:
            err_msg = f"File {file_name} is empty."
            self.logger.error(err_msg)
            raise Exception(err_msg)

        return offline_token