diff --git a/superset/db_engine_specs/doris.py b/superset/db_engine_specs/doris.py index e502f5bda2be7..4cc1d9ce554be 100644 --- a/superset/db_engine_specs/doris.py +++ b/superset/db_engine_specs/doris.py @@ -22,11 +22,13 @@ from flask_babel import gettext as __ from sqlalchemy import Float, Integer, Numeric, String, TEXT, types +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.sql.type_api import TypeEngine from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.errors import SupersetErrorType +from superset.models.core import Database from superset.utils.core import GenericDataType # Regular expressions to catch custom errors @@ -111,6 +113,7 @@ class DorisEngineSpec(MySQLEngineSpec): ) encryption_parameters = {"ssl": "0"} supports_dynamic_schema = True + supports_catalog = supports_dynamic_catalog = True column_type_mappings = ( # type: ignore ( @@ -245,17 +248,47 @@ def adjust_engine_params( catalog: Optional[str] = None, schema: Optional[str] = None, ) -> tuple[URL, dict[str, Any]]: - database = uri.database - if schema and database: - schema = parse.quote(schema, safe="") - if "." in database: - database = database.split(".")[0] + "." + schema - else: - database = "internal." + schema - uri = uri.set(database=database) - + if catalog: + pass + elif uri.database and "." in uri.database: + catalog, _ = uri.database.split(".", 1) + else: + catalog = "internal" + + # In Apache Doris, each catalog has an information_schema for BI tool + # compatibility. See: https://github.com/apache/doris/pull/28919 + if schema: + adjusted_database = ".".join([catalog or "", schema]) + else: + adjusted_database = ".".join([catalog or "", "information_schema"]) + uri = uri.set(database=adjusted_database) return uri, connect_args + @classmethod + def get_default_catalog(cls, database: Database) -> Optional[str]: + """ + Return the default catalog. + """ + if database.url_object.database is None: + return None + + return database.url_object.database.split(".")[0] + + @classmethod + def get_catalog_names( + cls, + database: Database, + inspector: Inspector, + ) -> set[str]: + """ + Get all catalogs. + For Doris, the SHOW CATALOGS command returns multiple columns: + CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment + We need to extract just the CatalogName column. + """ + result = inspector.bind.execute("SHOW CATALOGS") + return {row.CatalogName for row in result} + @classmethod def get_schema_from_engine_params( cls, diff --git a/tests/unit_tests/db_engine_specs/test_doris.py b/tests/unit_tests/db_engine_specs/test_doris.py index ced1a6862b83e..d79bc7dcbbd24 100644 --- a/tests/unit_tests/db_engine_specs/test_doris.py +++ b/tests/unit_tests/db_engine_specs/test_doris.py @@ -16,6 +16,7 @@ # under the License. from typing import Any, Optional +from unittest.mock import Mock import pytest from sqlalchemy import JSON, types @@ -85,25 +86,25 @@ def test_get_column_spec( ( "doris://user:password@host/db1", {"param1": "some_value"}, - "db1", + "internal.information_schema", {"param1": "some_value"}, ), ( "pydoris://user:password@host/db1", {"param1": "some_value"}, - "db1", + "internal.information_schema", {"param1": "some_value"}, ), ( "doris://user:password@host/catalog1.db1", {"param1": "some_value"}, - "catalog1.db1", + "catalog1.information_schema", {"param1": "some_value"}, ), ( "pydoris://user:password@host/catalog1.db1", {"param1": "some_value"}, - "catalog1.db1", + "catalog1.information_schema", {"param1": "some_value"}, ), ], @@ -120,11 +121,21 @@ def test_adjust_engine_params( returned_url, returned_connect_args = DorisEngineSpec.adjust_engine_params( url, connect_args ) + assert returned_url.database == return_schema assert returned_connect_args == return_connect_args -def test_get_schema_from_engine_params() -> None: +@pytest.mark.parametrize( + "url,expected_schema", + [ + ("doris://localhost:9030/hive.test", "test"), + ("doris://localhost:9030/hive", None), + ], +) +def test_get_schema_from_engine_params( + url: str, expected_schema: Optional[str] +) -> None: """ Test the ``get_schema_from_engine_params`` method. """ @@ -132,16 +143,74 @@ def test_get_schema_from_engine_params() -> None: assert ( DorisEngineSpec.get_schema_from_engine_params( - make_url("doris://localhost:9030/hive.test"), + make_url(url), {}, ) - == "test" + == expected_schema ) - assert ( - DorisEngineSpec.get_schema_from_engine_params( - make_url("doris://localhost:9030/hive"), - {}, - ) - is None - ) + +@pytest.mark.parametrize( + "database_value,expected_catalog", + [ + ("catalog1.schema1", "catalog1"), + ("catalog1", "catalog1"), + (None, None), + ], +) +def test_get_default_catalog( + database_value: Optional[str], expected_catalog: Optional[str] +) -> None: + """ + Test the ``get_default_catalog`` method. + """ + from superset.db_engine_specs.doris import DorisEngineSpec + from superset.models.core import Database + + database = Mock(spec=Database) + database.url_object.database = database_value + + assert DorisEngineSpec.get_default_catalog(database) == expected_catalog + + +@pytest.mark.parametrize( + "mock_catalogs,expected_result", + [ + ( + [ + Mock(CatalogName="catalog1"), + Mock(CatalogName="catalog2"), + Mock(CatalogName="catalog3"), + ], + {"catalog1", "catalog2", "catalog3"}, + ), + ( + [Mock(CatalogName="single_catalog")], + {"single_catalog"}, + ), + ( + [], + set(), + ), + ], +) +def test_get_catalog_names( + mock_catalogs: list[Mock], expected_result: set[str] +) -> None: + """ + Test the ``get_catalog_names`` method. + """ + from superset.db_engine_specs.doris import DorisEngineSpec + from superset.models.core import Database + + database = Mock(spec=Database) + inspector = Mock() + inspector.bind.execute.return_value = mock_catalogs + + catalogs = DorisEngineSpec.get_catalog_names(database, inspector) + + # Verify the SQL query + inspector.bind.execute.assert_called_once_with("SHOW CATALOGS") + + # Verify the returned catalog names + assert catalogs == expected_result