From dd732d0d462b1c3bffc305b992b712c1cfe34e25 Mon Sep 17 00:00:00 2001
From: Pedro Silva
Date: Thu, 25 Jul 2024 20:06:14 +0100
Subject: [PATCH] feat(cli): Make consistent use of DataHubGraphClientConfig (#10466)

Deprecates get_url_and_token() in favor of a more complete option, load_client_config(), which returns a full DatahubClientConfig. This change is propagated across all previous usages of get_url_and_token() so that client connections to the DataHub server respect the full breadth of configuration specified by DatahubClientConfig. For example, you can now specify disable_ssl_verification: true in your ~/.datahubenv file so that all CLI calls to the server work when SSL certificate verification is disabled.

Fixes #9705
---
 docs/how/updating-datahub.md | 1 +
 .../src/datahub/cli/cli_utils.py | 321 +-----------------
 .../src/datahub/cli/config_utils.py | 70 +---
 .../src/datahub/cli/delete_cli.py | 7 +-
 metadata-ingestion/src/datahub/cli/get_cli.py | 10 +-
 .../src/datahub/cli/ingest_cli.py | 15 +-
 .../src/datahub/cli/lite_cli.py | 6 +-
 metadata-ingestion/src/datahub/cli/migrate.py | 83 ++++-
 .../src/datahub/cli/migration_utils.py | 34 +-
 metadata-ingestion/src/datahub/cli/put_cli.py | 5 +
 .../src/datahub/cli/timeline_cli.py | 6 +-
 metadata-ingestion/src/datahub/entrypoints.py | 8 +-
 .../src/datahub/ingestion/graph/client.py | 101 +++++-
 .../datahub/ingestion/run/pipeline_config.py | 34 +-
 .../datahub/ingestion/sink/datahub_rest.py | 2 -
 .../ingestion/source/metadata/lineage.py | 6 +
 .../src/datahub/upgrade/upgrade.py | 25 +-
 .../tests/unit/test_cli_utils.py | 11 +-
 smoke-test/tests/cli/datahub_cli.py | 53 ++-
 smoke-test/tests/cli/datahub_graph_test.py | 6 +-
 smoke-test/tests/delete/delete_test.py | 36 +-
 smoke-test/tests/lineage/test_lineage.py | 14 +-
 .../tests/patch/test_dataset_patches.py | 17 +-
 smoke-test/tests/telemetry/telemetry_test.py | 21 +-
 smoke-test/tests/timeline/timeline_test.py | 4 +-
 smoke-test/tests/utils.py | 4 +-
 26 files changed, 396 insertions(+), 504 deletions(-)

diff --git a/docs/how/updating-datahub.md b/docs/how/updating-datahub.md
index ffceb7a5d1b02..a9c24849544a3 100644
--- a/docs/how/updating-datahub.md
+++ b/docs/how/updating-datahub.md
@@ -80,6 +80,7 @@ New (optional fields `systemMetadata` and `headers`):
 ### Deprecations
 ### Other Notable Change
+- #10466 - Extends configuration in `~/.datahubenv` to match the `DatahubClientConfig` object definition. See the full configuration at https://datahubproject.io/docs/python-sdk/clients/. The CLI now respects the updated configuration specified in `~/.datahubenv` across its functions and utilities. This means that for systems where SSL certificate verification is disabled, setting `disable_ssl_verification: true` in `~/.datahubenv` will apply to all CLI calls.
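For illustration, here is a minimal sketch (not part of the diff) of the extended `~/.datahubenv` this enables and how the new `DatahubConfig`/`DatahubClientConfig` models introduced in `graph/client.py` below would parse it; the `disable_ssl_verification` field is assumed from the commit message rather than shown in the hunks.

```python
# Illustrative sketch only -- not an excerpt from this patch.
# Assumes disable_ssl_verification is a field on DatahubClientConfig, as the
# commit message describes; the other names come from the code changed below.
import yaml

from datahub.ingestion.graph.client import DatahubConfig

EXAMPLE_DATAHUBENV = """
gms:
  server: http://localhost:8080
  token: null
  disable_ssl_verification: true
"""

# ~/.datahubenv keeps its existing `gms` section, but that section is now parsed
# as a full DatahubClientConfig, so any of its connection options can be set here.
config = DatahubConfig.parse_obj(yaml.safe_load(EXAMPLE_DATAHUBENV))
print(config.gms.server)  # -> http://localhost:8080
```

The new load_client_config() below performs essentially this parse and then applies any environment-based overrides before get_default_graph() builds the client.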
## 0.13.1 diff --git a/metadata-ingestion/src/datahub/cli/cli_utils.py b/metadata-ingestion/src/datahub/cli/cli_utils.py index bda351a4de6b1..b0039b5f87b34 100644 --- a/metadata-ingestion/src/datahub/cli/cli_utils.py +++ b/metadata-ingestion/src/datahub/cli/cli_utils.py @@ -2,15 +2,12 @@ import logging import os import os.path -import sys import typing from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import click import requests -from deprecated import deprecated -from requests.models import Response from requests.sessions import Session import datahub @@ -28,46 +25,14 @@ log = logging.getLogger(__name__) -ENV_METADATA_HOST_URL = "DATAHUB_GMS_URL" -ENV_METADATA_HOST = "DATAHUB_GMS_HOST" -ENV_METADATA_PORT = "DATAHUB_GMS_PORT" -ENV_METADATA_PROTOCOL = "DATAHUB_GMS_PROTOCOL" -ENV_METADATA_TOKEN = "DATAHUB_GMS_TOKEN" ENV_DATAHUB_SYSTEM_CLIENT_ID = "DATAHUB_SYSTEM_CLIENT_ID" ENV_DATAHUB_SYSTEM_CLIENT_SECRET = "DATAHUB_SYSTEM_CLIENT_SECRET" -config_override: Dict = {} - # TODO: Many of the methods in this file duplicate logic that already lives # in the DataHubGraph client. We should refactor this to use the client instead. # For the methods that aren't duplicates, that logic should be moved to the client. -def set_env_variables_override_config(url: str, token: Optional[str]) -> None: - """Should be used to override the config when using rest emitter""" - config_override[ENV_METADATA_HOST_URL] = url - if token is not None: - config_override[ENV_METADATA_TOKEN] = token - - -def get_details_from_env() -> Tuple[Optional[str], Optional[str]]: - host = os.environ.get(ENV_METADATA_HOST) - port = os.environ.get(ENV_METADATA_PORT) - token = os.environ.get(ENV_METADATA_TOKEN) - protocol = os.environ.get(ENV_METADATA_PROTOCOL, "http") - url = os.environ.get(ENV_METADATA_HOST_URL) - if port is not None: - url = f"{protocol}://{host}:{port}" - return url, token - # The reason for using host as URL is backward compatibility - # If port is not being used we assume someone is using host env var as URL - if url is None and host is not None: - log.warning( - f"Do not use {ENV_METADATA_HOST} as URL. Use {ENV_METADATA_HOST_URL} instead" - ) - return url or host, token - - def first_non_null(ls: List[Optional[str]]) -> Optional[str]: return next((el for el in ls if el is not None and el.strip() != ""), None) @@ -80,72 +45,6 @@ def get_system_auth() -> Optional[str]: return None -def get_url_and_token(): - gms_host_env, gms_token_env = get_details_from_env() - if len(config_override.keys()) > 0: - gms_host = config_override.get(ENV_METADATA_HOST_URL) - gms_token = config_override.get(ENV_METADATA_TOKEN) - elif config_utils.should_skip_config(): - gms_host = gms_host_env - gms_token = gms_token_env - else: - config_utils.ensure_datahub_config() - gms_host_conf, gms_token_conf = config_utils.get_details_from_config() - gms_host = first_non_null([gms_host_env, gms_host_conf]) - gms_token = first_non_null([gms_token_env, gms_token_conf]) - return gms_host, gms_token - - -def get_token(): - return get_url_and_token()[1] - - -def get_session_and_host(): - session = requests.Session() - - gms_host, gms_token = get_url_and_token() - - if gms_host is None or gms_host.strip() == "": - log.error( - f"GMS Host is not set. 
Use datahub init command or set {ENV_METADATA_HOST_URL} env var" - ) - return None, None - - session.headers.update( - { - "X-RestLi-Protocol-Version": "2.0.0", - "Content-Type": "application/json", - } - ) - if isinstance(gms_token, str) and len(gms_token) > 0: - session.headers.update( - {"Authorization": f"Bearer {gms_token.format(**os.environ)}"} - ) - - return session, gms_host - - -def test_connection(): - (session, host) = get_session_and_host() - url = f"{host}/config" - response = session.get(url) - response.raise_for_status() - - -def test_connectivity_complain_exit(operation_name: str) -> None: - """Test connectivity to metadata-service, log operation name and exit""" - # First test connectivity - try: - test_connection() - except Exception as e: - click.secho( - f"Failed to connect to DataHub server at {get_session_and_host()[1]}. Run with datahub --debug {operation_name} ... to get more information.", - fg="red", - ) - log.debug(f"Failed to connect with {e}") - sys.exit(1) - - def parse_run_restli_response(response: requests.Response) -> dict: response_json = response.json() if response.status_code != 200: @@ -195,10 +94,11 @@ def format_aspect_summaries(summaries: list) -> typing.List[typing.List[str]]: def post_rollback_endpoint( + session: Session, + gms_host: str, payload_obj: dict, path: str, ) -> typing.Tuple[typing.List[typing.List[str]], int, int, int, int, typing.List[dict]]: - session, gms_host = get_session_and_host() url = gms_host + path payload = json.dumps(payload_obj) @@ -229,212 +129,13 @@ def post_rollback_endpoint( ) -@deprecated(reason="Use DataHubGraph.get_urns_by_filter instead") -def get_urns_by_filter( - platform: Optional[str], - env: Optional[str] = None, - entity_type: str = "dataset", - search_query: str = "*", - include_removed: bool = False, - only_soft_deleted: Optional[bool] = None, -) -> Iterable[str]: - # TODO: Replace with DataHubGraph call - session, gms_host = get_session_and_host() - endpoint: str = "/entities?action=search" - url = gms_host + endpoint - filter_criteria = [] - entity_type_lower = entity_type.lower() - if env and entity_type_lower != "container": - filter_criteria.append({"field": "origin", "value": env, "condition": "EQUAL"}) - if ( - platform is not None - and entity_type_lower == "dataset" - or entity_type_lower == "dataflow" - or entity_type_lower == "datajob" - or entity_type_lower == "container" - ): - filter_criteria.append( - { - "field": "platform.keyword", - "value": f"urn:li:dataPlatform:{platform}", - "condition": "EQUAL", - } - ) - if platform is not None and entity_type_lower in {"chart", "dashboard"}: - filter_criteria.append( - { - "field": "tool", - "value": platform, - "condition": "EQUAL", - } - ) - - if only_soft_deleted: - filter_criteria.append( - { - "field": "removed", - "value": "true", - "condition": "EQUAL", - } - ) - elif include_removed: - filter_criteria.append( - { - "field": "removed", - "value": "", # accept anything regarding removed property (true, false, non-existent) - "condition": "EQUAL", - } - ) - - search_body = { - "input": search_query, - "entity": entity_type, - "start": 0, - "count": 10000, - "filter": {"or": [{"and": filter_criteria}]}, - } - payload = json.dumps(search_body) - log.debug(payload) - response: Response = session.post(url, payload) - if response.status_code == 200: - assert response._content - results = json.loads(response._content) - num_entities = results["value"]["numEntities"] - entities_yielded: int = 0 - for x in results["value"]["entities"]: - 
entities_yielded += 1 - log.debug(f"yielding {x['entity']}") - yield x["entity"] - if entities_yielded != num_entities: - log.warning( - f"Discrepancy in entities yielded {entities_yielded} and num entities {num_entities}. This means all entities may not have been deleted." - ) - else: - log.error(f"Failed to execute search query with {str(response.content)}") - response.raise_for_status() - - -def get_container_ids_by_filter( - env: Optional[str], - entity_type: str = "container", - search_query: str = "*", -) -> Iterable[str]: - session, gms_host = get_session_and_host() - endpoint: str = "/entities?action=search" - url = gms_host + endpoint - - container_filters = [] - for container_subtype in ["Database", "Schema", "Project", "Dataset"]: - filter_criteria = [] - - filter_criteria.append( - { - "field": "customProperties", - "value": f"instance={env}", - "condition": "EQUAL", - } - ) - - filter_criteria.append( - { - "field": "typeNames", - "value": container_subtype, - "condition": "EQUAL", - } - ) - container_filters.append({"and": filter_criteria}) - search_body = { - "input": search_query, - "entity": entity_type, - "start": 0, - "count": 10000, - "filter": {"or": container_filters}, - } - payload = json.dumps(search_body) - log.debug(payload) - response: Response = session.post(url, payload) - if response.status_code == 200: - assert response._content - log.debug(response._content) - results = json.loads(response._content) - num_entities = results["value"]["numEntities"] - entities_yielded: int = 0 - for x in results["value"]["entities"]: - entities_yielded += 1 - log.debug(f"yielding {x['entity']}") - yield x["entity"] - assert ( - entities_yielded == num_entities - ), "Did not delete all entities, try running this command again!" - else: - log.error(f"Failed to execute search query with {str(response.content)}") - response.raise_for_status() - - -def batch_get_ids( - ids: List[str], -) -> Iterable[Dict]: - session, gms_host = get_session_and_host() - endpoint: str = "/entitiesV2" - url = gms_host + endpoint - ids_to_get = [Urn.url_encode(id) for id in ids] - response = session.get( - f"{url}?ids=List({','.join(ids_to_get)})", - ) - - if response.status_code == 200: - assert response._content - log.debug(response._content) - results = json.loads(response._content) - num_entities = len(results["results"]) - entities_yielded: int = 0 - for x in results["results"].values(): - entities_yielded += 1 - log.debug(f"yielding {x}") - yield x - assert ( - entities_yielded == num_entities - ), "Did not delete all entities, try running this command again!" 
- else: - log.error(f"Failed to execute batch get with {str(response.content)}") - response.raise_for_status() - - -def get_incoming_relationships(urn: str, types: List[str]) -> Iterable[Dict]: - yield from get_relationships(urn=urn, types=types, direction="INCOMING") - - -def get_outgoing_relationships(urn: str, types: List[str]) -> Iterable[Dict]: - yield from get_relationships(urn=urn, types=types, direction="OUTGOING") - - -def get_relationships(urn: str, types: List[str], direction: str) -> Iterable[Dict]: - session, gms_host = get_session_and_host() - encoded_urn: str = Urn.url_encode(urn) - types_param_string = "List(" + ",".join(types) + ")" - endpoint: str = f"{gms_host}/relationships?urn={encoded_urn}&direction={direction}&types={types_param_string}" - response: Response = session.get(endpoint) - if response.status_code == 200: - results = response.json() - log.debug(f"Relationship response: {results}") - num_entities = results["count"] - entities_yielded: int = 0 - for x in results["relationships"]: - entities_yielded += 1 - yield x - if entities_yielded != num_entities: - log.warn("Yielded entities differ from num entities") - else: - log.error(f"Failed to execute relationships query with {str(response.content)}") - response.raise_for_status() - - def get_entity( + session: Session, + gms_host: str, urn: str, aspect: Optional[List] = None, cached_session_host: Optional[Tuple[Session, str]] = None, ) -> Dict: - session, gms_host = cached_session_host or get_session_and_host() if urn.startswith("urn%3A"): # we assume the urn is already encoded encoded_urn: str = urn @@ -457,6 +158,8 @@ def get_entity( def post_entity( + session: Session, + gms_host: str, urn: str, entity_type: str, aspect_name: str, @@ -464,7 +167,6 @@ def post_entity( cached_session_host: Optional[Tuple[Session, str]] = None, is_async: Optional[str] = "false", ) -> int: - session, gms_host = cached_session_host or get_session_and_host() endpoint: str = "/aspects/?action=ingestProposal" proposal = { @@ -502,11 +204,12 @@ def _get_pydantic_class_from_aspect_name(aspect_name: str) -> Optional[Type[_Asp def get_latest_timeseries_aspect_values( + session: Session, + gms_host: str, entity_urn: str, timeseries_aspect_name: str, cached_session_host: Optional[Tuple[Session, str]], ) -> Dict: - session, gms_host = cached_session_host or get_session_and_host() query_body = { "urn": entity_urn, "entity": guess_entity_type(entity_urn), @@ -524,6 +227,8 @@ def get_latest_timeseries_aspect_values( def get_aspects_for_entity( + session: Session, + gms_host: str, entity_urn: str, aspects: List[str], typed: bool = False, @@ -533,7 +238,7 @@ def get_aspects_for_entity( # Process non-timeseries aspects non_timeseries_aspects = [a for a in aspects if a not in TIMESERIES_ASPECT_MAP] entity_response = get_entity( - entity_urn, non_timeseries_aspects, cached_session_host + session, gms_host, entity_urn, non_timeseries_aspects, cached_session_host ) aspect_list: Dict[str, dict] = entity_response["aspects"] @@ -541,7 +246,7 @@ def get_aspects_for_entity( timeseries_aspects: List[str] = [a for a in aspects if a in TIMESERIES_ASPECT_MAP] for timeseries_aspect in timeseries_aspects: timeseries_response: Dict = get_latest_timeseries_aspect_values( - entity_urn, timeseries_aspect, cached_session_host + session, gms_host, entity_urn, timeseries_aspect, cached_session_host ) values: List[Dict] = timeseries_response.get("value", {}).get("values", []) if values: diff --git a/metadata-ingestion/src/datahub/cli/config_utils.py 
b/metadata-ingestion/src/datahub/cli/config_utils.py index 8cddc41551038..7a3fee1c760da 100644 --- a/metadata-ingestion/src/datahub/cli/config_utils.py +++ b/metadata-ingestion/src/datahub/cli/config_utils.py @@ -4,12 +4,10 @@ import logging import os -import sys -from typing import Optional, Union +from typing import Optional import click import yaml -from pydantic import BaseModel, ValidationError from datahub.cli.env_utils import get_boolean_env_variable @@ -22,82 +20,20 @@ ENV_SKIP_CONFIG = "DATAHUB_SKIP_CONFIG" -class GmsConfig(BaseModel): - server: str - token: Optional[str] = None - - -class DatahubConfig(BaseModel): - gms: GmsConfig - - def persist_datahub_config(config: dict) -> None: with open(DATAHUB_CONFIG_PATH, "w+") as outfile: yaml.dump(config, outfile, default_flow_style=False) return None -def write_gms_config( - host: str, token: Optional[str], merge_with_previous: bool = True -) -> None: - config = DatahubConfig(gms=GmsConfig(server=host, token=token)) - if merge_with_previous: - try: - previous_config = get_client_config(as_dict=True) - assert isinstance(previous_config, dict) - except Exception as e: - # ok to fail on this - previous_config = {} - log.debug( - f"Failed to retrieve config from file {DATAHUB_CONFIG_PATH}: {e}. This isn't fatal." - ) - config_dict = {**previous_config, **config.dict()} - else: - config_dict = config.dict() - persist_datahub_config(config_dict) - - -def get_details_from_config(): - datahub_config = get_client_config(as_dict=False) - assert isinstance(datahub_config, DatahubConfig) - if datahub_config is not None: - gms_config = datahub_config.gms - - gms_host = gms_config.server - gms_token = gms_config.token - return gms_host, gms_token - else: - return None, None - - def should_skip_config() -> bool: return get_boolean_env_variable(ENV_SKIP_CONFIG, False) -def ensure_datahub_config() -> None: - if not os.path.isfile(DATAHUB_CONFIG_PATH): - click.secho( - f"No {CONDENSED_DATAHUB_CONFIG_PATH} file found, generating one for you...", - bold=True, - ) - write_gms_config(DEFAULT_GMS_HOST, None) - - -def get_client_config(as_dict: bool = False) -> Union[Optional[DatahubConfig], dict]: +def get_client_config() -> Optional[dict]: with open(DATAHUB_CONFIG_PATH) as stream: try: - config_json = yaml.safe_load(stream) - if as_dict: - return config_json - try: - datahub_config = DatahubConfig.parse_obj(config_json) - return datahub_config - except ValidationError as e: - click.echo( - f"Received error, please check your {CONDENSED_DATAHUB_CONFIG_PATH}" - ) - click.echo(e, err=True) - sys.exit(1) + return yaml.safe_load(stream) except yaml.YAMLError as exc: click.secho(f"{DATAHUB_CONFIG_PATH} malformed, error: {exc}", bold=True) return None diff --git a/metadata-ingestion/src/datahub/cli/delete_cli.py b/metadata-ingestion/src/datahub/cli/delete_cli.py index 9332b701fed39..b5cc67532a9dd 100644 --- a/metadata-ingestion/src/datahub/cli/delete_cli.py +++ b/metadata-ingestion/src/datahub/cli/delete_cli.py @@ -123,6 +123,8 @@ def by_registry( Delete all metadata written using the given registry id and version pair. """ + client = get_default_graph() + if soft and not dry_run: raise click.UsageError( "Soft-deleting with a registry-id is not yet supported. 
Try --dry-run to see what you will be deleting, before issuing a hard-delete using the --hard flag" @@ -138,7 +140,10 @@ def by_registry( unsafe_entity_count, unsafe_entities, ) = cli_utils.post_rollback_endpoint( - registry_delete, "/entities?action=deleteAll" + client._session, + client.config.server, + registry_delete, + "/entities?action=deleteAll", ) if not dry_run: diff --git a/metadata-ingestion/src/datahub/cli/get_cli.py b/metadata-ingestion/src/datahub/cli/get_cli.py index 46e2fdf5b1f79..b6ff5f39a2c14 100644 --- a/metadata-ingestion/src/datahub/cli/get_cli.py +++ b/metadata-ingestion/src/datahub/cli/get_cli.py @@ -6,6 +6,7 @@ from click_default_group import DefaultGroup from datahub.cli.cli_utils import get_aspects_for_entity +from datahub.ingestion.graph.client import get_default_graph from datahub.telemetry import telemetry from datahub.upgrade import upgrade @@ -44,10 +45,17 @@ def urn(ctx: Any, urn: Optional[str], aspect: List[str], details: bool) -> None: raise click.UsageError("Nothing for me to get. Maybe provide an urn?") urn = ctx.args[0] logger.debug(f"Using urn from args {urn}") + + client = get_default_graph() + click.echo( json.dumps( get_aspects_for_entity( - entity_urn=urn, aspects=aspect, typed=False, details=details + session=client._session, + gms_host=client.config.server, + entity_urn=urn, + aspects=aspect, + typed=False, ), sort_keys=True, indent=2, diff --git a/metadata-ingestion/src/datahub/cli/ingest_cli.py b/metadata-ingestion/src/datahub/cli/ingest_cli.py index bb8d67f8439ab..75760f3dbd95d 100644 --- a/metadata-ingestion/src/datahub/cli/ingest_cli.py +++ b/metadata-ingestion/src/datahub/cli/ingest_cli.py @@ -427,7 +427,9 @@ def mcps(path: str) -> None: def list_runs(page_offset: int, page_size: int, include_soft_deletes: bool) -> None: """List recent ingestion runs to datahub""" - session, gms_host = cli_utils.get_session_and_host() + client = get_default_graph() + session = client._session + gms_host = client.config.server url = f"{gms_host}/runs?action=list" @@ -476,7 +478,9 @@ def show( run_id: str, start: int, count: int, include_soft_deletes: bool, show_aspect: bool ) -> None: """Describe a provided ingestion run to datahub""" - session, gms_host = cli_utils.get_session_and_host() + client = get_default_graph() + session = client._session + gms_host = client.config.server url = f"{gms_host}/runs?action=describe" @@ -524,8 +528,7 @@ def rollback( run_id: str, force: bool, dry_run: bool, safe: bool, report_dir: str ) -> None: """Rollback a provided ingestion run to datahub""" - - cli_utils.test_connectivity_complain_exit("ingest") + client = get_default_graph() if not force and not dry_run: click.confirm( @@ -541,7 +544,9 @@ def rollback( aspects_affected, unsafe_entity_count, unsafe_entities, - ) = cli_utils.post_rollback_endpoint(payload_obj, "/runs?action=rollback") + ) = cli_utils.post_rollback_endpoint( + client._session, client.config.server, payload_obj, "/runs?action=rollback" + ) click.echo( "Rolling back deletes the entities created by a run and reverts the updated aspects" diff --git a/metadata-ingestion/src/datahub/cli/lite_cli.py b/metadata-ingestion/src/datahub/cli/lite_cli.py index 7e2ad23a7753f..7000cdbd73094 100644 --- a/metadata-ingestion/src/datahub/cli/lite_cli.py +++ b/metadata-ingestion/src/datahub/cli/lite_cli.py @@ -11,12 +11,12 @@ from datahub.cli.config_utils import ( DATAHUB_ROOT_FOLDER, - DatahubConfig, get_client_config, persist_datahub_config, ) from datahub.ingestion.api.common import PipelineContext, RecordEnvelope from 
datahub.ingestion.api.sink import NoopWriteCallback +from datahub.ingestion.graph.client import DatahubConfig from datahub.ingestion.run.pipeline import Pipeline from datahub.ingestion.sink.file import FileSink, FileSinkConfig from datahub.lite.duckdb_lite_config import DuckDBLiteConfig @@ -45,7 +45,7 @@ class LiteCliConfig(DatahubConfig): def get_lite_config() -> LiteLocalConfig: - client_config_dict = get_client_config(as_dict=True) + client_config_dict = get_client_config() lite_config = LiteCliConfig.parse_obj(client_config_dict) return lite_config.lite @@ -309,7 +309,7 @@ def search( def write_lite_config(lite_config: LiteLocalConfig) -> None: - cli_config = get_client_config(as_dict=True) + cli_config = get_client_config() assert isinstance(cli_config, dict) cli_config["lite"] = lite_config.dict() persist_datahub_config(cli_config) diff --git a/metadata-ingestion/src/datahub/cli/migrate.py b/metadata-ingestion/src/datahub/cli/migrate.py index 30f82987a6b65..ea5375c947128 100644 --- a/metadata-ingestion/src/datahub/cli/migrate.py +++ b/metadata-ingestion/src/datahub/cli/migrate.py @@ -1,7 +1,8 @@ +import json import logging import random import uuid -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, Iterable, List, Tuple, Union import click import progressbar @@ -23,7 +24,11 @@ SchemaKey, ) from datahub.emitter.rest_emitter import DatahubRestEmitter -from datahub.ingestion.graph.client import DataHubGraph, get_default_graph +from datahub.ingestion.graph.client import ( + DataHubGraph, + RelatedEntity, + get_default_graph, +) from datahub.metadata.schema_classes import ( ContainerKeyClass, ContainerPropertiesClass, @@ -31,6 +36,7 @@ SystemMetadataClass, ) from datahub.telemetry import telemetry +from datahub.utilities.urns.urn import Urn log = logging.getLogger(__name__) @@ -143,15 +149,17 @@ def dataplatform2instance_func( graph = get_default_graph() - urns_to_migrate = [] + urns_to_migrate: List[str] = [] # we first calculate all the urns we will be migrating - for src_entity_urn in cli_utils.get_urns_by_filter(platform=platform, env=env): + for src_entity_urn in graph.get_urns_by_filter(platform=platform, env=env): key = dataset_urn_to_key(src_entity_urn) assert key # Does this urn already have a platform instance associated with it? 
- response = cli_utils.get_aspects_for_entity( - entity_urn=src_entity_urn, aspects=["dataPlatformInstance"], typed=True + response = graph.get_aspects_for_entity( + entity_urn=src_entity_urn, + aspects=["dataPlatformInstance"], + aspect_types=[DataPlatformInstanceClass], ) if "dataPlatformInstance" in response: assert isinstance( @@ -229,14 +237,14 @@ def dataplatform2instance_func( migration_report.on_entity_create(new_urn, "dataPlatformInstance") for relationship in relationships: - target_urn = relationship["entity"] + target_urn = relationship.urn entity_type = _get_type_from_urn(target_urn) - relationshipType = relationship["type"] + relationshipType = relationship.relationship_type aspect_name = migration_utils.get_aspect_name_from_relationship( relationshipType, entity_type ) aspect_map = cli_utils.get_aspects_for_entity( - target_urn, aspects=[aspect_name], typed=True + graph._session, graph.config.server, target_urn, aspects=[aspect_name] ) if aspect_name in aspect_map: aspect = aspect_map[aspect_name] @@ -378,13 +386,16 @@ def migrate_containers( def get_containers_for_migration(env: str) -> List[Any]: - containers_to_migrate = list(cli_utils.get_container_ids_by_filter(env=env)) + client = get_default_graph() + containers_to_migrate = list( + client.get_urns_by_filter(entity_types=["container"], env=env) + ) containers = [] increment = 20 for i in range(0, len(containers_to_migrate), increment): - for container in cli_utils.batch_get_ids( - containers_to_migrate[i : i + increment] + for container in batch_get_ids( + client, containers_to_migrate[i : i + increment] ): log.debug(container) containers.append(container) @@ -392,6 +403,37 @@ def get_containers_for_migration(env: str) -> List[Any]: return containers +def batch_get_ids( + client: DataHubGraph, + ids: List[str], +) -> Iterable[Dict]: + session = client._session + gms_host = client.config.server + endpoint: str = "/entitiesV2" + url = gms_host + endpoint + ids_to_get = [Urn.url_encode(id) for id in ids] + response = session.get( + f"{url}?ids=List({','.join(ids_to_get)})", + ) + + if response.status_code == 200: + assert response._content + log.debug(response._content) + results = json.loads(response._content) + num_entities = len(results["results"]) + entities_yielded: int = 0 + for x in results["results"].values(): + entities_yielded += 1 + log.debug(f"yielding {x}") + yield x + assert ( + entities_yielded == num_entities + ), "Did not delete all entities, try running this command again!" 
+ else: + log.error(f"Failed to execute batch get with {str(response.content)}") + response.raise_for_status() + + def process_container_relationships( container_id_map: Dict[str, str], dry_run: bool, @@ -400,22 +442,29 @@ def process_container_relationships( migration_report: MigrationReport, rest_emitter: DatahubRestEmitter, ) -> None: - relationships = migration_utils.get_incoming_relationships(urn=src_urn) + relationships: Iterable[RelatedEntity] = migration_utils.get_incoming_relationships( + urn=src_urn + ) + client = get_default_graph() for relationship in relationships: log.debug(f"Incoming Relationship: {relationship}") - target_urn = relationship["entity"] + target_urn: str = relationship.urn # We should use the new id if we already migrated it if target_urn in container_id_map: - target_urn = container_id_map.get(target_urn) + target_urn = container_id_map[target_urn] entity_type = _get_type_from_urn(target_urn) - relationshipType = relationship["type"] + relationshipType = relationship.relationship_type aspect_name = migration_utils.get_aspect_name_from_relationship( relationshipType, entity_type ) aspect_map = cli_utils.get_aspects_for_entity( - target_urn, aspects=[aspect_name], typed=True + client._session, + client.config.server, + target_urn, + aspects=[aspect_name], + typed=True, ) if aspect_name in aspect_map: aspect = aspect_map[aspect_name] diff --git a/metadata-ingestion/src/datahub/cli/migration_utils.py b/metadata-ingestion/src/datahub/cli/migration_utils.py index 09bf58bf4ec76..a3dfcfe2ac403 100644 --- a/metadata-ingestion/src/datahub/cli/migration_utils.py +++ b/metadata-ingestion/src/datahub/cli/migration_utils.py @@ -1,12 +1,17 @@ import logging import uuid -from typing import Dict, Iterable, List +from typing import Iterable, List from avrogen.dict_wrapper import DictWrapper from datahub.cli import cli_utils from datahub.emitter.mce_builder import Aspect from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.graph.client import ( + DataHubGraph, + RelatedEntity, + get_default_graph, +) from datahub.metadata.schema_classes import ( ChartInfoClass, ContainerClass, @@ -238,8 +243,13 @@ def clone_aspect( run_id: str = str(uuid.uuid4()), dry_run: bool = False, ) -> Iterable[MetadataChangeProposalWrapper]: + client = get_default_graph() aspect_map = cli_utils.get_aspects_for_entity( - entity_urn=src_urn, aspects=aspect_names, typed=True + client._session, + client.config.server, + entity_urn=src_urn, + aspects=aspect_names, + typed=True, ) if aspect_names is not None: @@ -263,10 +273,11 @@ def clone_aspect( log.debug(f"did not find aspect {a} in response, continuing...") -def get_incoming_relationships(urn: str) -> Iterable[Dict]: - yield from cli_utils.get_incoming_relationships( - urn, - types=[ +def get_incoming_relationships(urn: str) -> Iterable[RelatedEntity]: + client = get_default_graph() + yield from client.get_related_entities( + entity_urn=urn, + relationship_types=[ "DownstreamOf", "Consumes", "Produces", @@ -274,13 +285,15 @@ def get_incoming_relationships(urn: str) -> Iterable[Dict]: "DerivedFrom", "IsPartOf", ], + direction=DataHubGraph.RelationshipDirection.INCOMING, ) -def get_outgoing_relationships(urn: str) -> Iterable[Dict]: - yield from cli_utils.get_outgoing_relationships( - urn, - types=[ +def get_outgoing_relationships(urn: str) -> Iterable[RelatedEntity]: + client = get_default_graph() + yield from client.get_related_entities( + entity_urn=urn, + relationship_types=[ "DownstreamOf", "Consumes", "Produces", @@ 
-288,4 +301,5 @@ def get_outgoing_relationships(urn: str) -> Iterable[Dict]: "DerivedFrom", "IsPartOf", ], + direction=DataHubGraph.RelationshipDirection.OUTGOING, ) diff --git a/metadata-ingestion/src/datahub/cli/put_cli.py b/metadata-ingestion/src/datahub/cli/put_cli.py index 324d7f94db258..40af54c7c7e2e 100644 --- a/metadata-ingestion/src/datahub/cli/put_cli.py +++ b/metadata-ingestion/src/datahub/cli/put_cli.py @@ -46,7 +46,12 @@ def aspect(urn: str, aspect: str, aspect_data: str) -> None: aspect_data, allow_stdin=True, resolve_env_vars=False, process_directives=False ) + client = get_default_graph() + + # TODO: Replace with client.emit, requires figuring out the correct subsclass of _Aspect to create from the data status = post_entity( + client._session, + client.config.server, urn=urn, aspect_name=aspect, entity_type=entity_type, diff --git a/metadata-ingestion/src/datahub/cli/timeline_cli.py b/metadata-ingestion/src/datahub/cli/timeline_cli.py index ca0486043f157..63e05aa65d9a5 100644 --- a/metadata-ingestion/src/datahub/cli/timeline_cli.py +++ b/metadata-ingestion/src/datahub/cli/timeline_cli.py @@ -8,8 +8,8 @@ from requests import Response from termcolor import colored -import datahub.cli.cli_utils from datahub.emitter.mce_builder import dataset_urn_to_key, schema_field_urn_to_key +from datahub.ingestion.graph.client import get_default_graph from datahub.telemetry import telemetry from datahub.upgrade import upgrade from datahub.utilities.urns.urn import Urn @@ -63,7 +63,9 @@ def get_timeline( end_time: Optional[int], diff: bool, ) -> Any: - session, host = datahub.cli.cli_utils.get_session_and_host() + client = get_default_graph() + session = client._session + host = client.config.server if urn.startswith("urn%3A"): # we assume the urn is already encoded encoded_urn: str = urn diff --git a/metadata-ingestion/src/datahub/entrypoints.py b/metadata-ingestion/src/datahub/entrypoints.py index 72b19d882a45f..d6b888b391bfb 100644 --- a/metadata-ingestion/src/datahub/entrypoints.py +++ b/metadata-ingestion/src/datahub/entrypoints.py @@ -13,11 +13,7 @@ generate_access_token, make_shim_command, ) -from datahub.cli.config_utils import ( - DATAHUB_CONFIG_PATH, - get_boolean_env_variable, - write_gms_config, -) +from datahub.cli.config_utils import DATAHUB_CONFIG_PATH, get_boolean_env_variable from datahub.cli.delete_cli import delete from datahub.cli.docker_cli import docker from datahub.cli.exists_cli import exists @@ -37,7 +33,7 @@ from datahub.cli.telemetry import telemetry as telemetry_cli from datahub.cli.timeline_cli import timeline from datahub.configuration.common import should_show_stack_trace -from datahub.ingestion.graph.client import get_default_graph +from datahub.ingestion.graph.client import get_default_graph, write_gms_config from datahub.telemetry import telemetry from datahub.utilities._custom_package_loader import model_version_name from datahub.utilities.logging_manager import configure_logging diff --git a/metadata-ingestion/src/datahub/ingestion/graph/client.py b/metadata-ingestion/src/datahub/ingestion/graph/client.py index 221136a897956..b2a768099c481 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/client.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/client.py @@ -3,6 +3,8 @@ import functools import json import logging +import os +import sys import textwrap import time from dataclasses import dataclass @@ -22,12 +24,13 @@ Union, ) +import click from avro.schema import RecordSchema from deprecated import deprecated -from pydantic import BaseModel 
+from pydantic import BaseModel, ValidationError from requests.models import HTTPError -from datahub.cli.cli_utils import get_url_and_token +from datahub.cli import config_utils from datahub.configuration.common import ConfigModel, GraphError, OperationalError from datahub.emitter.aspect import TIMESERIES_ASPECT_MAP from datahub.emitter.mce_builder import DEFAULT_ENV, Aspect @@ -87,6 +90,12 @@ _MISSING_SERVER_ID = "missing" _GRAPH_DUMMY_RUN_ID = "__datahub-graph-client" +ENV_METADATA_HOST_URL = "DATAHUB_GMS_URL" +ENV_METADATA_TOKEN = "DATAHUB_GMS_TOKEN" +ENV_METADATA_HOST = "DATAHUB_GMS_HOST" +ENV_METADATA_PORT = "DATAHUB_GMS_PORT" +ENV_METADATA_PROTOCOL = "DATAHUB_GMS_PROTOCOL" + class DatahubClientConfig(ConfigModel): """Configuration class for holding connectivity to datahub gms""" @@ -583,6 +592,9 @@ def _relationships_endpoint(self): def _aspect_count_endpoint(self): return f"{self.config.server}/aspects?action=getCount" + # def _session(self) -> Session: + # return super()._session + def get_domain_urn_by_name(self, domain_name: str) -> Optional[str]: """Retrieve a domain urn based on its name. Returns None if there is no match found""" @@ -1763,7 +1775,88 @@ def close(self) -> None: def get_default_graph() -> DataHubGraph: - (url, token) = get_url_and_token() - graph = DataHubGraph(DatahubClientConfig(server=url, token=token)) + graph_config = load_client_config() + graph = DataHubGraph(graph_config) graph.test_connection() return graph + + +class DatahubConfig(BaseModel): + gms: DatahubClientConfig + + +config_override: Dict = {} + + +def get_details_from_env() -> Tuple[Optional[str], Optional[str]]: + host = os.environ.get(ENV_METADATA_HOST) + port = os.environ.get(ENV_METADATA_PORT) + token = os.environ.get(ENV_METADATA_TOKEN) + protocol = os.environ.get(ENV_METADATA_PROTOCOL, "http") + url = os.environ.get(ENV_METADATA_HOST_URL) + if port is not None: + url = f"{protocol}://{host}:{port}" + return url, token + # The reason for using host as URL is backward compatibility + # If port is not being used we assume someone is using host env var as URL + if url is None and host is not None: + logger.warning( + f"Do not use {ENV_METADATA_HOST} as URL. Use {ENV_METADATA_HOST_URL} instead" + ) + return url or host, token + + +def load_client_config() -> DatahubClientConfig: + try: + ensure_datahub_config() + client_config_dict = config_utils.get_client_config() + datahub_config: DatahubClientConfig = DatahubConfig.parse_obj( + client_config_dict + ).gms + except ValidationError as e: + click.echo( + f"Received error, please check your {config_utils.CONDENSED_DATAHUB_CONFIG_PATH}" + ) + click.echo(e, err=True) + sys.exit(1) + + # Override gms & token configs if specified. 
+ if len(config_override.keys()) > 0: + datahub_config.server = str(config_override.get(ENV_METADATA_HOST_URL)) + datahub_config.token = config_override.get(ENV_METADATA_TOKEN) + elif config_utils.should_skip_config(): + gms_host_env, gms_token_env = get_details_from_env() + if gms_host_env: + datahub_config.server = gms_host_env + datahub_config.token = gms_token_env + + return datahub_config + + +def ensure_datahub_config() -> None: + if not os.path.isfile(config_utils.DATAHUB_CONFIG_PATH): + click.secho( + f"No {config_utils.CONDENSED_DATAHUB_CONFIG_PATH} file found, generating one for you...", + bold=True, + ) + write_gms_config(config_utils.DEFAULT_GMS_HOST, None) + + +def write_gms_config( + host: str, token: Optional[str], merge_with_previous: bool = True +) -> None: + config = DatahubConfig(gms=DatahubClientConfig(server=host, token=token)) + if merge_with_previous: + try: + previous_config = config_utils.get_client_config() + assert isinstance(previous_config, dict) + except Exception as e: + # ok to fail on this + previous_config = {} + logger.debug( + f"Failed to retrieve config from file {config_utils.DATAHUB_CONFIG_PATH}: {e}. This isn't fatal." + ) + config_dict = {**previous_config, **config.dict()} + else: + config_dict = config.dict() + config_utils.persist_datahub_config(config_dict) diff --git a/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py b/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py index 98629ba030695..51b657ad5bf77 100644 --- a/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py +++ b/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py @@ -1,12 +1,14 @@ import datetime import logging +import os import uuid from typing import Any, Dict, List, Optional -from pydantic import Field, validator +from pydantic import Field, root_validator, validator +from datahub.configuration import config_loader from datahub.configuration.common import ConfigModel, DynamicTypedConfig -from datahub.ingestion.graph.client import DatahubClientConfig +from datahub.ingestion.graph.client import DatahubClientConfig, load_client_config from datahub.ingestion.sink.file import FileSinkConfig logger = logging.getLogger(__name__) @@ -101,6 +103,34 @@ def run_id_should_be_semantic( assert v is not None return v + @root_validator(pre=True) + def default_sink_is_datahub_rest(cls, values: Dict[str, Any]) -> Any: + if "sink" not in values: + config = load_client_config() + # update this + default_sink_config = { + "type": "datahub-rest", + "config": config.dict(exclude_defaults=True), + } + # resolve env variables if present + default_sink_config = config_loader.resolve_env_variables( + default_sink_config, environ=os.environ + ) + values["sink"] = default_sink_config + + return values + + @validator("datahub_api", always=True) + def datahub_api_should_use_rest_sink_as_default( + cls, v: Optional[DatahubClientConfig], values: Dict[str, Any], **kwargs: Any + ) -> Optional[DatahubClientConfig]: + if v is None and "sink" in values and hasattr(values["sink"], "type"): + sink_type = values["sink"].type + if sink_type == "datahub-rest": + sink_config = values["sink"].config + v = DatahubClientConfig.parse_obj_allow_extras(sink_config) + return v + @classmethod def from_dict( cls, resolved_dict: dict, raw_dict: Optional[dict] = None diff --git a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py index 33a8f4a126182..a9f788acf66d3 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py @@ -9,7 +9,6 @@ from enum import auto from typing import List, Optional, Tuple, Union -from datahub.cli.cli_utils import set_env_variables_override_config from datahub.configuration.common import ( ConfigEnum, ConfigurationError, @@ -120,7 +119,6 @@ def __post_init__(self) -> None: ) self.report.max_threads = self.config.max_threads logger.debug("Setting env variables to override config") - set_env_variables_override_config(self.config.server, self.config.token) logger.debug("Setting gms config") set_gms_config(gms_config) diff --git a/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py index 8bd2e70b2d478..08ed7677c7ab4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py @@ -35,6 +35,7 @@ auto_workunit_reporter, ) from datahub.ingestion.api.workunit import MetadataWorkUnit +from datahub.ingestion.graph.client import get_default_graph from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( FineGrainedLineageDownstreamType, FineGrainedLineageUpstreamType, @@ -209,7 +210,12 @@ def _get_lineage_mcp( # extract the old lineage and save it for the new mcp if preserve_upstream: + + client = get_default_graph() + old_upstream_lineage = get_aspects_for_entity( + client._session, + client.config.server, entity_urn=entity_urn, aspects=["upstreamLineage"], typed=True, diff --git a/metadata-ingestion/src/datahub/upgrade/upgrade.py b/metadata-ingestion/src/datahub/upgrade/upgrade.py index 446f1a05b71a6..cdb7df57efe79 100644 --- a/metadata-ingestion/src/datahub/upgrade/upgrade.py +++ b/metadata-ingestion/src/datahub/upgrade/upgrade.py @@ -12,8 +12,7 @@ from termcolor import colored from datahub import __version__ -from datahub.cli import cli_utils -from datahub.ingestion.graph.client import DataHubGraph +from datahub.ingestion.graph.client import DataHubGraph, load_client_config log = logging.getLogger(__name__) @@ -101,16 +100,18 @@ async def get_github_stats(): return (latest_server_version, latest_server_date) -async def get_server_config(gms_url: str, token: str) -> dict: +async def get_server_config(gms_url: str, token: Optional[str]) -> dict: import aiohttp - async with aiohttp.ClientSession( - headers={ - "X-RestLi-Protocol-Version": "2.0.0", - "Content-Type": "application/json", - "Authorization": f"Bearer {token}", - } - ) as session: + headers = { + "X-RestLi-Protocol-Version": "2.0.0", + "Content-Type": "application/json", + } + + if token: + headers["Authorization"] = f"Bearer {token}" + + async with aiohttp.ClientSession() as session: config_endpoint = f"{gms_url}/config" async with session.get(config_endpoint) as dh_response: dh_response_json = await dh_response.json() @@ -126,7 +127,9 @@ async def get_server_version_stats( if not server: try: # let's get the server from the cli config - host, token = cli_utils.get_url_and_token() + client_config = load_client_config() + host = client_config.server + token = client_config.token server_config = await get_server_config(host, token) log.debug(f"server_config:{server_config}") except Exception as e: diff --git a/metadata-ingestion/tests/unit/test_cli_utils.py b/metadata-ingestion/tests/unit/test_cli_utils.py index bc1826d422e38..68cb985af4734 100644 --- a/metadata-ingestion/tests/unit/test_cli_utils.py +++ 
b/metadata-ingestion/tests/unit/test_cli_utils.py @@ -2,6 +2,7 @@ from unittest import mock from datahub.cli import cli_utils +from datahub.ingestion.graph.client import get_details_from_env def test_first_non_null(): @@ -16,14 +17,14 @@ def test_first_non_null(): @mock.patch.dict(os.environ, {"DATAHUB_GMS_HOST": "http://localhost:9092"}) def test_correct_url_when_gms_host_in_old_format(): - assert cli_utils.get_details_from_env() == ("http://localhost:9092", None) + assert get_details_from_env() == ("http://localhost:9092", None) @mock.patch.dict( os.environ, {"DATAHUB_GMS_HOST": "localhost", "DATAHUB_GMS_PORT": "8080"} ) def test_correct_url_when_gms_host_and_port_set(): - assert cli_utils.get_details_from_env() == ("http://localhost:8080", None) + assert get_details_from_env() == ("http://localhost:8080", None) @mock.patch.dict( @@ -35,7 +36,7 @@ def test_correct_url_when_gms_host_and_port_set(): }, ) def test_correct_url_when_gms_host_port_url_set(): - assert cli_utils.get_details_from_env() == ("http://localhost:8080", None) + assert get_details_from_env() == ("http://localhost:8080", None) @mock.patch.dict( @@ -48,7 +49,7 @@ def test_correct_url_when_gms_host_port_url_set(): }, ) def test_correct_url_when_gms_host_port_url_protocol_set(): - assert cli_utils.get_details_from_env() == ("https://localhost:8080", None) + assert get_details_from_env() == ("https://localhost:8080", None) @mock.patch.dict( @@ -58,7 +59,7 @@ def test_correct_url_when_gms_host_port_url_protocol_set(): }, ) def test_correct_url_when_url_set(): - assert cli_utils.get_details_from_env() == ("https://example.com", None) + assert get_details_from_env() == ("https://example.com", None) def test_fixup_gms_url(): diff --git a/smoke-test/tests/cli/datahub_cli.py b/smoke-test/tests/cli/datahub_cli.py index d1620d03c88b2..a57cfd0b7b2be 100644 --- a/smoke-test/tests/cli/datahub_cli.py +++ b/smoke-test/tests/cli/datahub_cli.py @@ -1,7 +1,11 @@ import json import pytest -from datahub.cli.cli_utils import get_aspects_for_entity, get_session_and_host +from datahub.ingestion.graph.client import get_default_graph +from datahub.metadata.schema_classes import ( + BrowsePathsV2Class, + EditableDatasetPropertiesClass, +) from tests.utils import ingest_file_via_rest, wait_for_writes_to_sync @@ -22,23 +26,19 @@ def test_setup(): env = "PROD" dataset_urn = f"urn:li:dataset:({platform},{dataset_name},{env})" - session, gms_host = get_session_and_host() + client = get_default_graph() + session = client._session + gms_host = client.config.server - assert "browsePathsV2" not in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["browsePathsV2"], typed=False - ) - assert "editableDatasetProperties" not in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False - ) + assert client.get_aspect(dataset_urn, BrowsePathsV2Class) is None + assert client.get_aspect(dataset_urn, EditableDatasetPropertiesClass) is None ingested_dataset_run_id = ingest_file_via_rest( "tests/cli/cli_test_data.json" ).config.run_id print("Setup ingestion id: " + ingested_dataset_run_id) - assert "browsePathsV2" in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["browsePathsV2"], typed=False - ) + assert client.get_aspect(dataset_urn, BrowsePathsV2Class) is not None yield @@ -58,12 +58,8 @@ def test_setup(): ), ) - assert "browsePathsV2" not in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["browsePathsV2"], typed=False - ) - assert "editableDatasetProperties" not in get_aspects_for_entity( - 
entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False - ) + assert client.get_aspect(dataset_urn, BrowsePathsV2Class) is None + assert client.get_aspect(dataset_urn, EditableDatasetPropertiesClass) is None @pytest.mark.dependency() @@ -75,13 +71,14 @@ def test_rollback_editable(): env = "PROD" dataset_urn = f"urn:li:dataset:({platform},{dataset_name},{env})" - session, gms_host = get_session_and_host() + client = get_default_graph() + session = client._session + gms_host = client.config.server print("Ingested dataset id:", ingested_dataset_run_id) # Assert that second data ingestion worked - assert "browsePathsV2" in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["browsePathsV2"], typed=False - ) + + assert client.get_aspect(dataset_urn, BrowsePathsV2Class) is not None # Make editable change ingested_editable_run_id = ingest_file_via_rest( @@ -89,9 +86,8 @@ def test_rollback_editable(): ).config.run_id print("ingested editable id:", ingested_editable_run_id) # Assert that second data ingestion worked - assert "editableDatasetProperties" in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False - ) + + assert client.get_aspect(dataset_urn, EditableDatasetPropertiesClass) is not None # rollback ingestion 1 rollback_url = f"{gms_host}/runs?action=rollback" @@ -107,10 +103,7 @@ def test_rollback_editable(): wait_for_writes_to_sync() # EditableDatasetProperties should still be part of the entity that was soft deleted. - assert "editableDatasetProperties" in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False - ) + assert client.get_aspect(dataset_urn, EditableDatasetPropertiesClass) is not None + # But first ingestion aspects should not be present - assert "browsePathsV2" not in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["browsePathsV2"], typed=False - ) + assert client.get_aspect(dataset_urn, BrowsePathsV2Class) is None diff --git a/smoke-test/tests/cli/datahub_graph_test.py b/smoke-test/tests/cli/datahub_graph_test.py index 1e324477adb6b..0af5572c7d1d9 100644 --- a/smoke-test/tests/cli/datahub_graph_test.py +++ b/smoke-test/tests/cli/datahub_graph_test.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest import tenacity from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph @@ -36,9 +38,9 @@ def test_healthchecks(wait_for_healthchecks): @pytest.mark.dependency(depends=["test_healthchecks"]) def test_get_aspect_v2(frontend_session, ingest_cleanup_data): - graph: DataHubGraph = DataHubGraph(DatahubClientConfig(server=get_gms_url())) + client: DataHubGraph = DataHubGraph(DatahubClientConfig(server=get_gms_url())) urn = "urn:li:dataset:(urn:li:dataPlatform:kafka,test-rollback,PROD)" - schema_metadata: SchemaMetadataClass = graph.get_aspect_v2( + schema_metadata: Optional[SchemaMetadataClass] = client.get_aspect_v2( urn, aspect="schemaMetadata", aspect_type=SchemaMetadataClass ) diff --git a/smoke-test/tests/delete/delete_test.py b/smoke-test/tests/delete/delete_test.py index 3a80e05d0cc4b..3a999224fd3e6 100644 --- a/smoke-test/tests/delete/delete_test.py +++ b/smoke-test/tests/delete/delete_test.py @@ -2,7 +2,7 @@ import os import pytest -from datahub.cli.cli_utils import get_aspects_for_entity, get_session_and_host +from datahub.cli.cli_utils import get_aspects_for_entity from tests.utils import ( delete_urns_from_file, @@ -38,14 +38,24 @@ def test_setup(): env = "PROD" dataset_urn = 
f"urn:li:dataset:({platform},{dataset_name},{env})" - session, gms_host = get_session_and_host() + client = get_datahub_graph() + session = client._session + gms_host = client.config.server try: assert "institutionalMemory" not in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["institutionalMemory"], typed=False + session, + gms_host, + entity_urn=dataset_urn, + aspects=["institutionalMemory"], + typed=False, ) assert "editableDatasetProperties" not in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False + session, + gms_host, + entity_urn=dataset_urn, + aspects=["editableDatasetProperties"], + typed=False, ) except Exception as e: delete_urns_from_file("tests/delete/cli_test_data.json") @@ -56,7 +66,11 @@ def test_setup(): ).config.run_id assert "institutionalMemory" in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["institutionalMemory"], typed=False + session, + gms_host, + entity_urn=dataset_urn, + aspects=["institutionalMemory"], + typed=False, ) yield @@ -71,10 +85,18 @@ def test_setup(): wait_for_writes_to_sync() assert "institutionalMemory" not in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["institutionalMemory"], typed=False + session, + gms_host, + entity_urn=dataset_urn, + aspects=["institutionalMemory"], + typed=False, ) assert "editableDatasetProperties" not in get_aspects_for_entity( - entity_urn=dataset_urn, aspects=["editableDatasetProperties"], typed=False + session, + gms_host, + entity_urn=dataset_urn, + aspects=["editableDatasetProperties"], + typed=False, ) diff --git a/smoke-test/tests/lineage/test_lineage.py b/smoke-test/tests/lineage/test_lineage.py index a24a700593378..c9895568a7140 100644 --- a/smoke-test/tests/lineage/test_lineage.py +++ b/smoke-test/tests/lineage/test_lineage.py @@ -6,10 +6,8 @@ import datahub.emitter.mce_builder as builder import networkx as nx import pytest -from datahub.cli.cli_utils import get_url_and_token from datahub.emitter.mcp import MetadataChangeProposalWrapper -from datahub.ingestion.graph.client import DataHubGraph # get_default_graph, -from datahub.ingestion.graph.client import DatahubClientConfig +from datahub.ingestion.graph.client import DataHubGraph, get_default_graph from datahub.metadata.schema_classes import ( AuditStampClass, ChangeAuditStampsClass, @@ -847,10 +845,7 @@ def test_lineage_via_node( ) # Create an emitter to the GMS REST API. 
- (url, token) = get_url_and_token() - with DataHubGraph( - DatahubClientConfig(server=url, token=token, retry_max_times=0) - ) as graph: + with get_default_graph() as graph: emitter = graph # emitter = DataHubConsoleEmitter() @@ -891,14 +886,11 @@ def destination_urn_fixture(): def ingest_multipath_metadata( chart_urn_fixture, intermediates_fixture, destination_urn_fixture ): - (url, token) = get_url_and_token() fake_auditstamp = AuditStampClass( time=int(time.time() * 1000), actor="urn:li:corpuser:datahub", ) - with DataHubGraph( - DatahubClientConfig(server=url, token=token, retry_max_times=0) - ) as graph: + with get_default_graph() as graph: chart_urn = chart_urn_fixture intermediates = intermediates_fixture destination_urn = destination_urn_fixture diff --git a/smoke-test/tests/patch/test_dataset_patches.py b/smoke-test/tests/patch/test_dataset_patches.py index ec6b4a91fa6be..0c161fb0e6607 100644 --- a/smoke-test/tests/patch/test_dataset_patches.py +++ b/smoke-test/tests/patch/test_dataset_patches.py @@ -3,7 +3,7 @@ from datahub.emitter.mce_builder import make_dataset_urn, make_tag_urn, make_term_urn from datahub.emitter.mcp import MetadataChangeProposalWrapper -from datahub.ingestion.graph.client import DataHubGraph, DataHubGraphConfig +from datahub.ingestion.graph.client import DataHubGraph, get_default_graph from datahub.metadata.schema_classes import ( DatasetLineageTypeClass, DatasetPropertiesClass, @@ -72,13 +72,16 @@ def test_dataset_upstream_lineage_patch(wait_for_healthchecks): ) mcpw = MetadataChangeProposalWrapper(entityUrn=dataset_urn, aspect=upstream_lineage) - with DataHubGraph(DataHubGraphConfig()) as graph: + with get_default_graph() as graph: graph.emit_mcp(mcpw) upstream_lineage_read = graph.get_aspect_v2( entity_urn=dataset_urn, aspect_type=UpstreamLineageClass, aspect="upstreamLineage", ) + + assert upstream_lineage_read is not None + assert len(upstream_lineage_read.upstreams) > 0 assert upstream_lineage_read.upstreams[0].dataset == other_dataset_urn for patch_mcp in ( @@ -94,6 +97,8 @@ def test_dataset_upstream_lineage_patch(wait_for_healthchecks): aspect_type=UpstreamLineageClass, aspect="upstreamLineage", ) + + assert upstream_lineage_read is not None assert len(upstream_lineage_read.upstreams) == 2 assert upstream_lineage_read.upstreams[0].dataset == other_dataset_urn assert upstream_lineage_read.upstreams[1].dataset == patch_dataset_urn @@ -111,6 +116,8 @@ def test_dataset_upstream_lineage_patch(wait_for_healthchecks): aspect_type=UpstreamLineageClass, aspect="upstreamLineage", ) + + assert upstream_lineage_read is not None assert len(upstream_lineage_read.upstreams) == 1 assert upstream_lineage_read.upstreams[0].dataset == other_dataset_urn @@ -148,7 +155,7 @@ def test_field_terms_patch(wait_for_healthchecks): ) mcpw = MetadataChangeProposalWrapper(entityUrn=dataset_urn, aspect=editable_field) - with DataHubGraph(DataHubGraphConfig()) as graph: + with get_default_graph() as graph: graph.emit_mcp(mcpw) field_info = get_field_info(graph, dataset_urn, field_path) assert field_info @@ -209,7 +216,7 @@ def test_field_tags_patch(wait_for_healthchecks): ) mcpw = MetadataChangeProposalWrapper(entityUrn=dataset_urn, aspect=editable_field) - with DataHubGraph(DataHubGraphConfig()) as graph: + with get_default_graph() as graph: graph.emit_mcp(mcpw) field_info = get_field_info(graph, dataset_urn, field_path) assert field_info @@ -299,7 +306,7 @@ def test_custom_properties_patch(wait_for_healthchecks): base_aspect=orig_dataset_properties, ) - with 
DataHubGraph(DataHubGraphConfig()) as graph: + with get_default_graph() as graph: # Patch custom properties along with name for patch_mcp in ( DatasetPatchBuilder(dataset_urn) diff --git a/smoke-test/tests/telemetry/telemetry_test.py b/smoke-test/tests/telemetry/telemetry_test.py index 963d85baef3bb..96f2fa69014cf 100644 --- a/smoke-test/tests/telemetry/telemetry_test.py +++ b/smoke-test/tests/telemetry/telemetry_test.py @@ -1,6 +1,7 @@ import json from datahub.cli.cli_utils import get_aspects_for_entity +from datahub.ingestion.graph.client import get_default_graph def test_no_client_id(): @@ -9,8 +10,16 @@ def test_no_client_id(): "clientId" ] # this is checking for the removal of the invalid aspect RemoveClientIdAspectStep.java + client = get_default_graph() + res_data = json.dumps( - get_aspects_for_entity(entity_urn=client_id_urn, aspects=aspect, typed=False) + get_aspects_for_entity( + session=client._session, + gms_host=client.config.server, + entity_urn=client_id_urn, + aspects=aspect, + typed=False, + ) ) assert res_data == "{}" @@ -19,7 +28,15 @@ def test_no_telemetry_client_id(): client_id_urn = "urn:li:telemetry:clientId" aspect = ["telemetryClientId"] # telemetry expected to be disabled for tests + client = get_default_graph() + res_data = json.dumps( - get_aspects_for_entity(entity_urn=client_id_urn, aspects=aspect, typed=False) + get_aspects_for_entity( + session=client._session, + gms_host=client.config.server, + entity_urn=client_id_urn, + aspects=aspect, + typed=False, + ) ) assert res_data == "{}" diff --git a/smoke-test/tests/timeline/timeline_test.py b/smoke-test/tests/timeline/timeline_test.py index f8a0e425c3781..4573514f7806c 100644 --- a/smoke-test/tests/timeline/timeline_test.py +++ b/smoke-test/tests/timeline/timeline_test.py @@ -179,11 +179,13 @@ def test_ownership(): def put(urn: str, aspect: str, aspect_data: str) -> None: """Update a single aspect of an entity""" - + client = get_datahub_graph() entity_type = guess_entity_type(urn) with open(aspect_data) as fp: aspect_obj = json.load(fp) post_entity( + session=client._session, + gms_host=client.config.server, urn=urn, aspect_name=aspect, entity_type=entity_type, diff --git a/smoke-test/tests/utils.py b/smoke-test/tests/utils.py index 29b956bde9ab8..0895056fe3ddd 100644 --- a/smoke-test/tests/utils.py +++ b/smoke-test/tests/utils.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Tuple from datahub.cli import cli_utils, env_utils -from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph +from datahub.ingestion.graph.client import DataHubGraph, get_default_graph from datahub.ingestion.run.pipeline import Pipeline from joblib import Parallel, delayed @@ -120,7 +120,7 @@ def ingest_file_via_rest(filename: str) -> Pipeline: @functools.lru_cache(maxsize=1) def get_datahub_graph() -> DataHubGraph: - return DataHubGraph(DatahubClientConfig(server=get_gms_url())) + return get_default_graph() def delete_urn(urn: str) -> None:
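Taken together, the pattern this patch converges on at CLI call sites looks roughly like the following sketch (not an excerpt from the diff): obtain a configured client via get_default_graph() and hand its session and server URL to the remaining cli_utils helpers, instead of calling the removed get_url_and_token()/get_session_and_host().

```python
# Sketch of the post-patch calling pattern used across ingest_cli.py,
# delete_cli.py, get_cli.py, etc. -- not code from the diff above.
from datahub.ingestion.graph.client import get_default_graph, load_client_config

# load_client_config() resolves ~/.datahubenv (or, when DATAHUB_SKIP_CONFIG is
# set, the DATAHUB_GMS_* environment variables) into a DatahubClientConfig.
client_config = load_client_config()
print(f"Configured server: {client_config.server}")

# get_default_graph() wraps that config in a DataHubGraph and verifies
# connectivity, replacing the old get_session_and_host()/test_connection() pair.
graph = get_default_graph()
session = graph._session        # requests.Session shared with cli_utils helpers
gms_host = graph.config.server  # base URL passed alongside the session

# Rough equivalent of the removed cli_utils.test_connection():
response = session.get(f"{gms_host}/config")
response.raise_for_status()
```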