Source code for provenaclient.clients.prov_client

'''
Created Date: Monday June 17th 2024 +1000
Author: Peter Baker
-----
Last Modified: Monday June 17th 2024 4:45:39 pm +1000
Modified By: Peter Baker
-----
Description: Prov API L2 Client.
-----
HISTORY:
Date      	By	Comments
----------	---	---------------------------------------------------------

18-06-2024 | Peter Baker | Note that this layer does not provide any file IO capabilities - see L3
'''

from typing import List, cast
from provenaclient.auth.manager import AuthManager
from provenaclient.utils.config import Config
from enum import Enum
from provenaclient.utils.helpers import *
from provenaclient.clients.client_helpers import *
from provenaclient.models.general import HealthCheckResponse
from ProvenaInterfaces.ProvenanceAPI import LineageResponse, ModelRunRecord, RegisterModelRunResponse, RegisterBatchModelRunRequest, RegisterBatchModelRunResponse, ConvertModelRunsResponse, PostUpdateModelRunResponse, PostUpdateModelRunInput, GenerateReportRequest
from ProvenaInterfaces.RegistryAPI import ItemModelRun


class ProvAPIEndpoints(str, Enum):
    """An ENUM containing the prov api endpoints.

    Each value is a URL path that the client appends to the configured Prov
    API base URL. Because the enum mixes in ``str``, members compare equal to
    their raw path strings.
    """

    # --- Endpoints with a client wrapper method ("Completed") ---

    # Model run registration / update.
    POST_MODEL_RUN_REGISTER = "/model_run/register"
    POST_MODEL_RUN_UPDATE = "/model_run/update"
    POST_MODEL_RUN_REGISTER_BATCH = "/model_run/register_batch"

    # Reporting.
    POST_GENERATE_REPORT = "/explore/generate/report"

    # Lineage exploration.
    GET_EXPLORE_UPSTREAM = "/explore/upstream"
    GET_EXPLORE_DOWNSTREAM = "/explore/downstream"
    GET_EXPLORE_SPECIAL_CONTRIBUTING_DATASETS = "/explore/special/contributing_datasets"
    GET_EXPLORE_SPECIAL_EFFECTED_DATASETS = "/explore/special/effected_datasets"
    GET_EXPLORE_SPECIAL_CONTRIBUTING_AGENTS = "/explore/special/contributing_agents"
    GET_EXPLORE_SPECIAL_EFFECTED_AGENTS = "/explore/special/effected_agents"

    # Service health.
    GET_HEALTH_CHECK = "/"

    # Bulk CSV tooling.
    GET_BULK_GENERATE_TEMPLATE_CSV = "/bulk/generate_template/csv"
    POST_BULK_CONVERT_MODEL_RUNS_CSV = "/bulk/convert_model_runs/csv"
    GET_BULK_REGENERATE_FROM_BATCH_CSV = "/bulk/regenerate_from_batch/csv"

    # --- Endpoints without a client wrapper yet ("Not completed") ---
    GET_CHECK_ACCESS_CHECK_GENERAL_ACCESS = "/check-access/check-general-access"
    GET_CHECK_ACCESS_CHECK_ADMIN_ACCESS = "/check-access/check-admin-access"
    GET_CHECK_ACCESS_CHECK_READ_ACCESS = "/check-access/check-read-access"
    GET_CHECK_ACCESS_CHECK_WRITE_ACCESS = "/check-access/check-write-access"
class ProvAPIAdminEndpoints(str, Enum):
    """An ENUM containing the prov api admin endpoints.

    Values are admin-only URL paths appended to the configured Prov API base
    URL. The ``str`` mixin lets members be compared directly to path strings.
    """

    # --- Endpoints with a client wrapper method ("Completed") ---
    GET_ADMIN_CONFIG = "/admin/config"
    POST_ADMIN_STORE_RECORD = "/admin/store_record"
    POST_ADMIN_STORE_RECORDS = "/admin/store_records"
    POST_ADMIN_STORE_ALL_REGISTRY_RECORDS = "/admin/store_all_registry_records"

    # --- Endpoints without a client wrapper yet ("Not completed yet, TODO") ---
    GET_ADMIN_SENTRY_DEBUG = "/admin/sentry-debug"
# L2 interface.
class ProvAdminClient(ClientService):
    """L2 client for the admin-only endpoints of the Prov API.

    Each public method wraps one endpoint from ``ProvAPIAdminEndpoints`` and
    delegates the HTTP call and error handling to the shared request helpers.
    """

    def __init__(self, auth: AuthManager, config: Config) -> None:
        """Initialises the ProvAdminClient with authentication and configuration.

        Parameters
        ----------
        auth : AuthManager
            An abstract interface containing the user's requested auth flow method.
        config : Config
            A config object which contains information related to the Provena instance.
        """
        self._auth = auth
        self._config = config

    def _build_endpoint(self, endpoint: ProvAPIAdminEndpoints) -> str:
        """Returns the absolute URL for an admin endpoint.

        Parameters
        ----------
        endpoint : ProvAPIAdminEndpoints
            The endpoint whose path is appended to the Prov API base URL.

        Returns
        -------
        str
            ``config.prov_api_endpoint`` concatenated with the endpoint path.
        """
        return self._config.prov_api_endpoint + endpoint.value

    async def generate_config_file(self, required_only: bool = True) -> str:
        """Generates a nicely formatted .env file of the current required/non supplied properties.

        Used to quickly bootstrap a local environment or to understand the
        currently deployed API.

        Parameters
        ----------
        required_only : bool, optional
            Only include required properties, by default True.

        Returns
        -------
        str
            The raw text body of the response (the generated .env contents).
        """
        response = await validated_get_request(
            client=self,
            url=self._build_endpoint(ProvAPIAdminEndpoints.GET_ADMIN_CONFIG),
            error_message="Failed to generate config file",
            params={"required_only": required_only},
        )
        return response.text

    async def store_record(self, registry_record: ItemModelRun, validate_record: bool = True) -> StatusResponse:
        """An admin only endpoint which enables the reupload/storage of an existing completed provenance record.

        Parameters
        ----------
        registry_record : ItemModelRun
            The completed registry record for the model run.
        validate_record : bool, optional
            Should the ids in the payload be validated? By default True.

        Returns
        -------
        StatusResponse
            A status response indicating the success of the request and any other details.
        """
        return await parsed_post_request_with_status(
            client=self,
            url=self._build_endpoint(ProvAPIAdminEndpoints.POST_ADMIN_STORE_RECORD),
            error_message=f"Failed to store record with display name {registry_record.display_name} and id {registry_record.id}",
            params={"validate_record": validate_record},
            json_body=py_to_dict(registry_record),
            model=StatusResponse,
        )

    async def store_multiple_records(self, registry_record: List[ItemModelRun], validate_record: bool = True) -> StatusResponse:
        """An admin only endpoint which enables the reupload/storage of multiple existing completed provenance records.

        Parameters
        ----------
        registry_record : List[ItemModelRun]
            The completed registry records for the model runs.
        validate_record : bool, optional
            Should the ids in the payload be validated? By default True.

        Returns
        -------
        StatusResponse
            A status response indicating the success of the request and any other details.
        """
        # Serialise every record individually; the endpoint accepts a JSON list.
        records_payload = cast(
            List[Dict[str, Any]], [py_to_dict(item) for item in registry_record])
        return await parsed_post_request_with_status(
            client=self,
            url=self._build_endpoint(ProvAPIAdminEndpoints.POST_ADMIN_STORE_RECORDS),
            error_message="Failed to complete multiple store record request.",
            params={"validate_record": validate_record},
            model=StatusResponse,
            json_body=records_payload,
        )

    async def store_all_registry_records(self, validate_record: bool = True) -> StatusResponse:
        """Applies the store record endpoint action across a list of ItemModelRuns
        which is found by querying the registry model run list endpoint directly.

        Parameters
        ----------
        validate_record : bool, optional
            Should the ids in the payload be validated? By default True.

        Returns
        -------
        StatusResponse
            A status response indicating the success of the request and any other details.
        """
        return await parsed_post_request_with_status(
            client=self,
            url=self._build_endpoint(ProvAPIAdminEndpoints.POST_ADMIN_STORE_ALL_REGISTRY_RECORDS),
            error_message="Failed to store all registry records.",
            params={"validate_record": validate_record},
            # No body: the server discovers the records from the registry itself.
            json_body=None,
            model=StatusResponse,
        )
class ProvClient(ClientService):
    """L2 client for the Prov API.

    Each public method wraps one endpoint from ``ProvAPIEndpoints``: it builds
    the URL, sends the request through the shared request helpers and returns
    the parsed response. Admin-only endpoints are exposed via ``admin``.
    """

    # Sub-client for admin-only endpoints, sharing this client's auth/config.
    admin: ProvAdminClient

    def __init__(self, auth: AuthManager, config: Config) -> None:
        """Initialises the ProvClient with authentication and configuration.

        Parameters
        ----------
        auth : AuthManager
            An abstract interface containing the user's requested auth flow method.
        config : Config
            A config object which contains information related to the Provena instance.
        """
        self._auth = auth
        self._config = config
        self.admin = ProvAdminClient(auth=auth, config=config)

    def _build_endpoint(self, endpoint: ProvAPIEndpoints) -> str:
        """Returns the absolute URL for an endpoint.

        Parameters
        ----------
        endpoint : ProvAPIEndpoints
            The endpoint whose path is appended to the Prov API base URL.

        Returns
        -------
        str
            ``config.prov_api_endpoint`` concatenated with the endpoint path.
        """
        return self._config.prov_api_endpoint + endpoint.value

    async def get_health_check(self) -> HealthCheckResponse:
        """Checks the health status of the PROV-API.

        Returns
        -------
        HealthCheckResponse
            Response containing the PROV-API health information.
        """
        return await parsed_get_request(
            client=self,
            url=self._build_endpoint(ProvAPIEndpoints.GET_HEALTH_CHECK),
            error_message="Health check failed!",
            params={},
            model=HealthCheckResponse,
        )

    async def post_update_model_run(self, model_run_id: str, reason: str, record: ModelRunRecord) -> PostUpdateModelRunResponse:
        """Updates an existing model run in the system.

        Parameters
        ----------
        model_run_id : str
            The ID of the model run to update.
        reason : str
            The reason for the update.
        record : ModelRunRecord
            The updated model run record.

        Returns
        -------
        PostUpdateModelRunResponse
            The response containing the job session ID.
        """
        update_payload = PostUpdateModelRunInput(
            model_run_id=model_run_id,
            reason=reason,
            record=record,
        )
        return await parsed_post_request(
            client=self,
            url=self._build_endpoint(ProvAPIEndpoints.POST_MODEL_RUN_UPDATE),
            error_message=f"Model run update failed for ID {model_run_id}!",
            params={},
            json_body=py_to_dict(update_payload),
            model=PostUpdateModelRunResponse,
        )

    # Explore Lineage endpoints

    async def _explore(self, endpoint: ProvAPIEndpoints, query_label: str, starting_id: str, depth: int) -> LineageResponse:
        """Shared implementation for the lineage exploration GET endpoints.

        Parameters
        ----------
        endpoint : ProvAPIEndpoints
            The lineage endpoint to query.
        query_label : str
            Human-readable query name used as the prefix of the error message.
        starting_id : str
            The ID of the entity to start at.
        depth : int
            The depth to traverse.

        Returns
        -------
        LineageResponse
            A response containing the status, node count, and networkx serialised graph response.
        """
        return await parsed_get_request_with_status(
            client=self,
            url=self._build_endpoint(endpoint),
            error_message=f"{query_label} query with starting id {starting_id} and depth {depth} failed!",
            params={"starting_id": starting_id, "depth": depth},
            model=LineageResponse,
        )

    async def explore_upstream(self, starting_id: str, depth: int) -> LineageResponse:
        """Explores in the upstream direction (inputs/associations) starting at the specified node handle ID.

        Parameters
        ----------
        starting_id : str
            The ID of the entity to start at.
        depth : int
            The depth to traverse in the upstream direction (the API documents a maximum of 100).

        Returns
        -------
        LineageResponse
            A response containing the status, node count, and networkx serialised graph response.
        """
        return await self._explore(ProvAPIEndpoints.GET_EXPLORE_UPSTREAM, "Upstream", starting_id, depth)

    async def explore_downstream(self, starting_id: str, depth: int) -> LineageResponse:
        """Explores in the downstream direction (towards derived/output entities) starting at the specified node handle ID.

        Parameters
        ----------
        starting_id : str
            The ID of the entity to start at.
        depth : int
            The depth to traverse in the downstream direction (the API documents a maximum of 100).

        Returns
        -------
        LineageResponse
            A response containing the status, node count, and networkx serialised graph response.
        """
        return await self._explore(ProvAPIEndpoints.GET_EXPLORE_DOWNSTREAM, "Downstream", starting_id, depth)

    async def get_contributing_datasets(self, starting_id: str, depth: int) -> LineageResponse:
        """Fetches datasets (inputs) which were involved in a model run, searching in the upstream direction.

        Parameters
        ----------
        starting_id : str
            The ID of the entity to start at.
        depth : int
            The depth to traverse in the upstream direction.

        Returns
        -------
        LineageResponse
            A response containing the status, node count, and networkx serialised graph response.
        """
        return await self._explore(
            ProvAPIEndpoints.GET_EXPLORE_SPECIAL_CONTRIBUTING_DATASETS, "Contributing datasets", starting_id, depth)

    async def get_effected_datasets(self, starting_id: str, depth: int) -> LineageResponse:
        """Fetches datasets (outputs) which are derived from the model run, searching in the downstream direction.

        Parameters
        ----------
        starting_id : str
            The ID of the entity to start at.
        depth : int
            The depth to traverse in the downstream direction.

        Returns
        -------
        LineageResponse
            A response containing the status, node count, and networkx serialised graph response.
        """
        return await self._explore(
            ProvAPIEndpoints.GET_EXPLORE_SPECIAL_EFFECTED_DATASETS, "Effected datasets", starting_id, depth)

    async def get_contributing_agents(self, starting_id: str, depth: int) -> LineageResponse:
        """Fetches agents (organisations or people) involved in the model run, searching in the upstream direction.

        Parameters
        ----------
        starting_id : str
            The ID of the entity to start at.
        depth : int
            The depth to traverse in the upstream direction.

        Returns
        -------
        LineageResponse
            A response containing the status, node count, and networkx serialised graph response.
        """
        return await self._explore(
            ProvAPIEndpoints.GET_EXPLORE_SPECIAL_CONTRIBUTING_AGENTS, "Contributing agents", starting_id, depth)

    async def get_effected_agents(self, starting_id: str, depth: int) -> LineageResponse:
        """Fetches agents (organisations or people) impacted by the model run, searching in the downstream direction.

        Parameters
        ----------
        starting_id : str
            The ID of the entity to start at.
        depth : int
            The depth to traverse in the downstream direction.

        Returns
        -------
        LineageResponse
            A response containing the status, node count, and networkx serialised graph response.
        """
        return await self._explore(
            ProvAPIEndpoints.GET_EXPLORE_SPECIAL_EFFECTED_AGENTS, "Effected agents", starting_id, depth)

    # Model run endpoints.

    async def register_batch_model_runs(self, model_run_batch_payload: RegisterBatchModelRunRequest) -> RegisterBatchModelRunResponse:
        """Registers multiple model runs in one go (batch) asynchronously.

        Note: You can utilise the returned session ID to poll on the JOB API to
        check the status of the model run registration(s).

        Parameters
        ----------
        model_run_batch_payload : RegisterBatchModelRunRequest
            A list of model runs (ModelRunRecord objects).

        Returns
        -------
        RegisterBatchModelRunResponse
            The job session id derived from job-api for the model-run batch.
        """
        return await parsed_post_request(
            client=self,
            url=self._build_endpoint(ProvAPIEndpoints.POST_MODEL_RUN_REGISTER_BATCH),
            error_message="Model run batch registration failed!",
            params={},
            json_body=py_to_dict(model_run_batch_payload),
            model=RegisterBatchModelRunResponse,
        )

    async def register_model_run(self, model_run_payload: ModelRunRecord) -> RegisterModelRunResponse:
        """Asynchronously registers a single model run.

        Note: You can utilise the returned session ID to poll on the JOB API to
        check the status of the model run registration.

        Parameters
        ----------
        model_run_payload : ModelRunRecord
            Contains information needed for the model run such as workflow
            template, inputs, outputs, description etc.

        Returns
        -------
        RegisterModelRunResponse
            The job session id derived from job-api for the model-run.
        """
        return await parsed_post_request(
            client=self,
            url=self._build_endpoint(ProvAPIEndpoints.POST_MODEL_RUN_REGISTER),
            error_message="Model run registration failed!",
            params={},
            json_body=py_to_dict(model_run_payload),
            model=RegisterModelRunResponse,
        )

    # CSV template tools endpoints

    async def generate_csv_template(self, workflow_template_id: str) -> str:
        """Generates a model run csv template to be utilised for creating model runs through csv format.

        Parameters
        ----------
        workflow_template_id : str
            An ID of a created and existing model run workflow template.

        Returns
        -------
        str
            The raw CSV template text returned by the API.
        """
        response = await validated_get_request(
            client=self,
            url=self._build_endpoint(ProvAPIEndpoints.GET_BULK_GENERATE_TEMPLATE_CSV),
            error_message="Failed to generate CSV file",
            params={"workflow_template_id": workflow_template_id},
        )
        return response.text

    async def convert_model_runs_to_csv(self, csv_file_contents: str) -> ConvertModelRunsResponse:
        """Uploads CSV contents to the bulk conversion endpoint and returns the parsed model run information.

        Despite the method name, the direction is CSV -> model runs: the CSV
        text is uploaded and the API's interpretation of it is returned.

        Parameters
        ----------
        csv_file_contents : str
            The CSV text containing the model run contents.

        Returns
        -------
        ConvertModelRunsResponse
            The model run information as an interactive python datatype.

        Raises
        ------
        Exception
            If the CSV contents cannot be encoded as UTF-8.
        """
        try:
            model_run_content_encoded: bytes = csv_file_contents.encode("utf-8")
        except UnicodeEncodeError as e:
            # Only lone surrogates etc. can fail here; keep the original cause attached.
            raise Exception(
                f"Exception has occurred while encoding model run content: {e}") from e

        # httpx multipart upload: the key is the form field name expected by the
        # API ("csv_file"); the value is a (filename, filedata, MIME type) tuple.
        csv_file: HttpxFileUpload = {"csv_file": (
            "upload.csv", model_run_content_encoded, "text/csv")}

        return await parsed_post_request_with_status(
            client=self,
            url=self._build_endpoint(ProvAPIEndpoints.POST_BULK_CONVERT_MODEL_RUNS_CSV),
            error_message="Failed to convert CSV file into model runs",
            files=csv_file,
            json_body=None,
            params={},
            model=ConvertModelRunsResponse,
        )

    async def regenerate_csv_from_model_run_batch(self, batch_id: str) -> str:
        """Regenerates a csv file containing model run information from a model run batch job.

        The batch id must exist in the system.

        Parameters
        ----------
        batch_id : str
            Obtained from creating a batch model run.

        Returns
        -------
        str
            The raw CSV text returned by the API.
        """
        response = await validated_get_request(
            client=self,
            url=self._build_endpoint(ProvAPIEndpoints.GET_BULK_REGENERATE_FROM_BATCH_CSV),
            error_message=f"Failed to generate CSV file from batch_id {batch_id}",
            params={"batch_id": batch_id},
        )
        return response.text

    async def generate_report(self, report_request: GenerateReportRequest) -> ByteString:
        """Generates a provenance report from a Study or Model Run Entity containing the associated inputs, model runs and outputs involved.

        The report is generated in `.docx` format by making a POST request to the API.

        Parameters
        ----------
        report_request : GenerateReportRequest
            The request object containing the parameters for generating the
            report, including the `id`, `item_subtype`, and `depth`.

        Returns
        -------
        ByteString
            The raw byte content of the generated `.docx` file, either `bytes`
            or `bytearray`, which can be directly saved to a file.

        Raises
        ------
        AssertionError
            If the response content is not found or is not in the expected
            `bytes` or `bytearray` format.
        """
        response = await validated_post_request(
            client=self,
            url=self._build_endpoint(ProvAPIEndpoints.POST_GENERATE_REPORT),
            error_message=f"Something has gone wrong during report generation for node with id {report_request.id}",
            json_body=py_to_dict(report_request),
            params=None,
            headers={
                "Content-Type": "application/json",  # Indicates the body is JSON
                "Accept": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",  # Indicates the response type
            },
        )

        # Validate explicitly rather than with `assert`, which is stripped under
        # `python -O`; AssertionError is kept because it is the documented contract.
        content = response.content
        if not content:
            raise AssertionError(
                f"Failed to generate report for node with id {report_request.id} - Response content not found!")
        if not isinstance(content, (bytes, bytearray)):
            raise AssertionError(
                "Unexpected content type from server. Expected bytes or bytearray!")
        return content