diff --git a/litestar/contrib/sqlalchemy/alembic/commands.py b/litestar/contrib/sqlalchemy/alembic/commands.py index 6eae8266c0..29d9acfc5c 100644 --- a/litestar/contrib/sqlalchemy/alembic/commands.py +++ b/litestar/contrib/sqlalchemy/alembic/commands.py @@ -1,17 +1,42 @@ from __future__ import annotations +import sys from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Mapping, TextIO from alembic import command as migration_command from alembic.config import Config as _AlembicCommandConfig from alembic.ddl.impl import DefaultImpl +from litestar.contrib.sqlalchemy.plugins.init.config.asyncio import SQLAlchemyAsyncConfig +from litestar.contrib.sqlalchemy.plugins.init.plugin import SQLAlchemyInitPlugin + if TYPE_CHECKING: + import os + from argparse import Namespace + from alembic.runtime.environment import ProcessRevisionDirectiveFn + from litestar.app import Litestar + from litestar.contrib.sqlalchemy.plugins.init.config.asyncio import AlembicAsyncConfig + from litestar.contrib.sqlalchemy.plugins.init.config.sync import AlembicSyncConfig + class AlembicCommandConfig(_AlembicCommandConfig): + def __init__( + self, + file_: str | os.PathLike[str] | None = None, + ini_section: str = "alembic", + output_buffer: TextIO | None = None, + stdout: TextIO = sys.stdout, + cmd_opts: Namespace | None = None, + config_args: Mapping[str, Any] = ..., # type: ignore[assignment] + attributes: dict | None = None, + template_directory: Path | None = None, + ) -> None: + self.template_directory = template_directory + super().__init__(file_, ini_section, output_buffer, stdout, cmd_opts, config_args, attributes) + def get_template_directory(self) -> str: """Return the directory where Alembic setup templates are found. @@ -19,9 +44,10 @@ def get_template_directory(self) -> str: commands. """ - from litestar.contrib.sqlalchemy import alembic + if self.template_directory is not None: + return str(self.template_directory) - package_dir = Path(alembic.__file__).parent.resolve() + package_dir = Path(__file__).parent.resolve() return str(Path(package_dir / "templates")) @@ -31,119 +57,131 @@ class AlembicSpannerImpl(DefaultImpl): __dialect__ = "spanner+spanner" -def get_alembic_config( - migration_config: str | None = None, script_location: str = "migrations" +def get_alembic_command_config( + alembic_config: str | None = None, script_location: str = "migrations" ) -> AlembicCommandConfig: kwargs = {} - if migration_config: - kwargs.update({"file_": migration_config}) + if alembic_config: + kwargs.update({"file_": alembic_config}) alembic_cfg = AlembicCommandConfig(**kwargs) # type: ignore alembic_cfg.set_main_option("script_location", script_location) return alembic_cfg -async def upgrade( - migration_config: str | None = None, - script_location: str = "migrations", +def get_alembic_config(app: Litestar) -> AlembicAsyncConfig | AlembicSyncConfig: + return app.plugins.get(SQLAlchemyInitPlugin)._alembic_config + + +def upgrade( + app: Litestar, revision: str = "head", sql: bool = False, tag: str | None = None, ) -> None: """Create or upgrade a database.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.upgrade(config=alembic_cfg, revision=revision, tag=tag, sql=sql) -async def downgrade( - migration_config: str | None = None, - script_location: str = "migrations", +def downgrade( + app: Litestar, revision: str = "head", sql: bool = False, tag: str | None = None, ) -> None: """Downgrade a database to a specific revision.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.downgrade(config=alembic_cfg, revision=revision, tag=tag, sql=sql) -async def check( - migration_config: str | None = None, - script_location: str = "migrations", +def check( + app: Litestar, ) -> None: """Check if revision command with autogenerate has pending upgrade ops.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.check(config=alembic_cfg) -async def current( - migration_config: str | None = None, script_location: str = "migrations", verbose: bool = False -) -> None: +def current(app: Litestar, verbose: bool = False) -> None: """Display the current revision for a database.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.current(alembic_cfg, verbose=verbose) -async def edit( - revision: str, - migration_config: str | None = None, - script_location: str = "migrations", -) -> None: +def edit(app: Litestar, revision: str) -> None: """Edit revision script(s) using $EDITOR.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.edit(config=alembic_cfg, rev=revision) -async def ensure_version( - migration_config: str | None = None, script_location: str = "migrations", sql: bool = False -) -> None: +def ensure_version(app: Litestar, sql: bool = False) -> None: """Create the alembic version table if it doesn't exist already.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.ensure_version(config=alembic_cfg, sql=sql) -async def heads( - migration_config: str | None = None, - script_location: str = "migrations", - verbose: bool = False, - resolve_dependencies: bool = False, -) -> None: +def heads(app: Litestar, verbose: bool = False, resolve_dependencies: bool = False) -> None: """Show current available heads in the script directory.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.heads(config=alembic_cfg, verbose=verbose, resolve_dependencies=resolve_dependencies) # type: ignore[no-untyped-call] -async def history( - migration_config: str | None = None, - script_location: str = "migrations", +def history( + app: Litestar, rev_range: str | None = None, verbose: bool = False, indicate_current: bool = False, ) -> None: """List changeset scripts in chronological order.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.history( config=alembic_cfg, rev_range=rev_range, verbose=verbose, indicate_current=indicate_current ) -async def merge( +def merge( + app: Litestar, revisions: str, - migration_config: str | None = None, - script_location: str = "migrations", message: str | None = None, branch_label: str | None = None, rev_id: str | None = None, ) -> None: """Merge two revisions together. Creates a new migration file.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.merge( config=alembic_cfg, revisions=revisions, message=message, branch_label=branch_label, rev_id=rev_id ) -async def revision( - migration_config: str | None = None, - script_location: str = "migrations", +def revision( + app: Litestar, message: str | None = None, autogenerate: bool = False, sql: bool = False, @@ -156,7 +194,10 @@ async def revision( process_revision_directives: ProcessRevisionDirectiveFn | None = None, ) -> None: """Create a new revision file.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.revision( config=alembic_cfg, message=message, @@ -172,43 +213,57 @@ async def revision( ) -async def show( +def show( + app: Litestar, rev: Any, - migration_config: str | None = None, - script_location: str = "migrations", ) -> None: """Show the revision(s) denoted by the given symbol.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.show(config=alembic_cfg, rev=rev) # type: ignore[no-untyped-call] -async def init( +def init( + app: Litestar, directory: str, - migration_config: str | None = None, template_path: str | None = None, - script_location: str = "migrations", - template: str = "generic", package: bool = False, + multidb: bool = False, ) -> None: """Initialize a new scripts directory.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) + template = "sync" + if isinstance(plugin._config, SQLAlchemyAsyncConfig): + template = "asyncio" + if multidb: + template = f"{template}-multidb" migration_command.init(config=alembic_cfg, directory=directory, template=template, package=package) -async def list_templates(migration_config: str | None, script_location: str = "migrations") -> None: +def list_templates(app: Litestar) -> None: """List available templates.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.list_templates(config=alembic_cfg) -async def stamp( - migration_config: str, +def stamp( + app: Litestar, revision: str, - script_location: str = "migrations", sql: bool = False, tag: str | None[str] = None, purge: bool = False, ) -> None: """'stamp' the revision table with the given revision; don't run any migrations.""" - alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location) + plugin = app.plugins.get(SQLAlchemyInitPlugin) + alembic_cfg = get_alembic_command_config( + alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location + ) migration_command.stamp(config=alembic_cfg, revision=revision, sql=sql, tag=tag, purge=purge) diff --git a/litestar/contrib/sqlalchemy/alembic/templates/script.py.mako b/litestar/contrib/sqlalchemy/alembic/templates/asyncio-multidb/script.py.mako similarity index 100% rename from litestar/contrib/sqlalchemy/alembic/templates/script.py.mako rename to litestar/contrib/sqlalchemy/alembic/templates/asyncio-multidb/script.py.mako diff --git a/litestar/contrib/sqlalchemy/alembic/templates/asyncio/script.py.mako b/litestar/contrib/sqlalchemy/alembic/templates/asyncio/script.py.mako new file mode 100644 index 0000000000..c9dbf81b7b --- /dev/null +++ b/litestar/contrib/sqlalchemy/alembic/templates/asyncio/script.py.mako @@ -0,0 +1,55 @@ +# type: ignore +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +import warnings + +import sqlalchemy as sa +from alembic import op +from litestar.contrib.sqlalchemy.types import GUID, ORA_JSONB, DateTimeUTC +${imports if imports else ""} + +__all__ = ["downgrade", "upgrade", "schema_upgrades", "schema_downgrades", "data_upgrades", "data_downgrades"] + +sa.GUID = GUID +sa.DateTimeUTC = DateTimeUTC +sa.ORA_JSONB = ORA_JSONB + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + with op.get_context().autocommit_block(): + schema_upgrades() + data_upgrades() + +def downgrade() -> None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + with op.get_context().autocommit_block(): + data_downgrades() + schema_downgrades() + +def schema_upgrades() -> None: + """schema upgrade migrations go here.""" + ${upgrades if upgrades else "pass"} + +def schema_downgrades() -> None: + """schema downgrade migrations go here.""" + ${downgrades if downgrades else "pass"} + +def data_upgrades() -> None: + """Add any optional data upgrade migrations here!""" + +def data_downgrades() -> None: + """Add any optional data downgrade migrations here!""" diff --git a/litestar/contrib/sqlalchemy/alembic/templates/sync-multidb/script.py.mako b/litestar/contrib/sqlalchemy/alembic/templates/sync-multidb/script.py.mako new file mode 100644 index 0000000000..c9dbf81b7b --- /dev/null +++ b/litestar/contrib/sqlalchemy/alembic/templates/sync-multidb/script.py.mako @@ -0,0 +1,55 @@ +# type: ignore +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +import warnings + +import sqlalchemy as sa +from alembic import op +from litestar.contrib.sqlalchemy.types import GUID, ORA_JSONB, DateTimeUTC +${imports if imports else ""} + +__all__ = ["downgrade", "upgrade", "schema_upgrades", "schema_downgrades", "data_upgrades", "data_downgrades"] + +sa.GUID = GUID +sa.DateTimeUTC = DateTimeUTC +sa.ORA_JSONB = ORA_JSONB + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + with op.get_context().autocommit_block(): + schema_upgrades() + data_upgrades() + +def downgrade() -> None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + with op.get_context().autocommit_block(): + data_downgrades() + schema_downgrades() + +def schema_upgrades() -> None: + """schema upgrade migrations go here.""" + ${upgrades if upgrades else "pass"} + +def schema_downgrades() -> None: + """schema downgrade migrations go here.""" + ${downgrades if downgrades else "pass"} + +def data_upgrades() -> None: + """Add any optional data upgrade migrations here!""" + +def data_downgrades() -> None: + """Add any optional data downgrade migrations here!""" diff --git a/litestar/contrib/sqlalchemy/alembic/templates/sync/script.py.mako b/litestar/contrib/sqlalchemy/alembic/templates/sync/script.py.mako new file mode 100644 index 0000000000..c9dbf81b7b --- /dev/null +++ b/litestar/contrib/sqlalchemy/alembic/templates/sync/script.py.mako @@ -0,0 +1,55 @@ +# type: ignore +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +import warnings + +import sqlalchemy as sa +from alembic import op +from litestar.contrib.sqlalchemy.types import GUID, ORA_JSONB, DateTimeUTC +${imports if imports else ""} + +__all__ = ["downgrade", "upgrade", "schema_upgrades", "schema_downgrades", "data_upgrades", "data_downgrades"] + +sa.GUID = GUID +sa.DateTimeUTC = DateTimeUTC +sa.ORA_JSONB = ORA_JSONB + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + with op.get_context().autocommit_block(): + schema_upgrades() + data_upgrades() + +def downgrade() -> None: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + with op.get_context().autocommit_block(): + data_downgrades() + schema_downgrades() + +def schema_upgrades() -> None: + """schema upgrade migrations go here.""" + ${upgrades if upgrades else "pass"} + +def schema_downgrades() -> None: + """schema downgrade migrations go here.""" + ${downgrades if downgrades else "pass"} + +def data_upgrades() -> None: + """Add any optional data upgrade migrations here!""" + +def data_downgrades() -> None: + """Add any optional data downgrade migrations here!""" diff --git a/litestar/contrib/sqlalchemy/cli.py b/litestar/contrib/sqlalchemy/cli.py index 0d6c788bf2..aab5fc0656 100644 --- a/litestar/contrib/sqlalchemy/cli.py +++ b/litestar/contrib/sqlalchemy/cli.py @@ -2,16 +2,11 @@ from typing import TYPE_CHECKING -import anyio - from litestar.cli._utils import RICH_CLICK_INSTALLED, LitestarGroup from litestar.contrib.sqlalchemy.alembic import commands as db_utils -from litestar.exceptions import LitestarException if TYPE_CHECKING: from litestar import Litestar - from litestar.contrib.sqlalchemy.plugins.init.config.asyncio import AlembicAsyncConfig - from litestar.contrib.sqlalchemy.plugins.init.config.sync import AlembicSyncConfig if TYPE_CHECKING or not RICH_CLICK_INSTALLED: @@ -25,16 +20,6 @@ def database_group() -> None: """Manage SQLAlchemy database components.""" -def get_alembic_config(app: Litestar) -> AlembicAsyncConfig | AlembicSyncConfig: - config: AlembicAsyncConfig | AlembicSyncConfig | None = None - for cli_plugin in app.cli_plugins: - if hasattr(cli_plugin, "_alembic_config"): - config = cli_plugin._alembic_config - if config is None: - raise LitestarException("Could not find SQLAlchemy configuration.") - return config - - @database_group.command( name="current-revision", help="Shows the current revision for the database.", @@ -43,8 +28,7 @@ def get_alembic_config(app: Litestar) -> AlembicAsyncConfig | AlembicSyncConfig: def show_database_revision(app: Litestar, verbose: bool) -> None: """Show current database revision.""" - config = get_alembic_config(app) - anyio.run(db_utils.current, config.alembic_config, config.script_location, verbose) + db_utils.current(app=app, verbose=verbose) @database_group.command( @@ -55,7 +39,7 @@ def show_database_revision(app: Litestar, verbose: bool) -> None: "--revision", type=str, help="Revision to upgrade to", - default="head", + default="-1", ) @option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True) @option( @@ -64,11 +48,10 @@ def show_database_revision(app: Litestar, verbose: bool) -> None: type=str, default=None, ) -def downgrade_database(app: Litestar, revision: str | None, sql: bool, tag: str | None) -> None: +def downgrade_database(app: Litestar, revision: str, sql: bool, tag: str | None) -> None: """Downgrade the database to the latest revision.""" - config = get_alembic_config(app) - anyio.run(db_utils.downgrade, config.alembic_config, config.script_location, revision, sql, tag) + db_utils.downgrade(app=app, revision=revision, sql=sql, tag=tag) @database_group.command( @@ -88,11 +71,10 @@ def downgrade_database(app: Litestar, revision: str | None, sql: bool, tag: str type=str, default=None, ) -def upgrade_database(app: Litestar, revision: str | None, sql: bool, tag: str | None) -> None: +def upgrade_database(app: Litestar, revision: str, sql: bool, tag: str | None) -> None: """Upgrade the database to the latest revision.""" - config = get_alembic_config(app) - anyio.run(db_utils.upgrade, config.alembic_config, config.script_location, revision, sql, tag) + db_utils.upgrade(app=app, revision=revision, sql=sql, tag=tag) @database_group.command( @@ -100,20 +82,11 @@ def upgrade_database(app: Litestar, revision: str | None, sql: bool, tag: str | help="Initialize migrations for the project.", ) @option( - "--revision", - type=str, - help="Revision to upgrade to", - default="head", -) -@option("--sql", type=bool, help="Generate SQL output for offline migrations.", default=False, is_flag=True) -@option( - "--tag", - help="an arbitrary 'tag' that can be intercepted by custom env.py scripts via the .EnvironmentContext.get_tag_argument method.", - type=str, - default=None, + "-d", "--directory", default="migrations", help="Location to save migration scripts. The default is 'migrations/'" ) -def init_alembic(app: Litestar, revision: str | None, sql: bool, tag: str | None) -> None: +@option("--multidb", is_flag=True, default=False, help="Support multiple databases") +@option("--package", is_flag=True, default=True, help="Create `__init__.py` for created folder") +def init_alembic(app: Litestar, directory: str, multidb: bool, package: bool) -> None: """Upgrade the database to the latest revision.""" - config = get_alembic_config(app) - anyio.run(db_utils.init, config.alembic_config, config.script_location, revision, sql, tag) + db_utils.init(app=app, directory=directory, multidb=multidb, package=package) diff --git a/litestar/contrib/sqlalchemy/plugins/init/config/common.py b/litestar/contrib/sqlalchemy/plugins/init/config/common.py index 02fd5ef5f2..10b8895b4c 100644 --- a/litestar/contrib/sqlalchemy/plugins/init/config/common.py +++ b/litestar/contrib/sqlalchemy/plugins/init/config/common.py @@ -15,7 +15,6 @@ from .engine import EngineConfig if TYPE_CHECKING: - from pathlib import Path from typing import Any from sqlalchemy import Connection, Engine, MetaData @@ -263,14 +262,14 @@ class GenericAlembicConfig: For details see: https://alembic.sqlalchemy.org/en/latest/api/config.html """ - alembic_config: str | Path | EmptyType = Empty + alembic_config: str | None = None """A path to the Alembic configuration file such as ``alembic.ini``. If left unset, the default configuration will be used. """ - version_table_name: str | EmptyType = Empty + version_table_name: str | None = None """Configure the name of the table used to hold the applied alembic revisions. Defaults to ``alembic``. THe name of the table """ - script_location: str | Path | EmptyType = Empty + script_location: str = "migrations" """A path to save generated migrations. """ target_metadata: MetaData = orm_registry.metadata