Skip to content

Commit

Permalink
feat: testing template dir layouts
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Aug 2, 2023
1 parent 105e372 commit 065b636
Show file tree
Hide file tree
Showing 7 changed files with 301 additions and 109 deletions.
189 changes: 122 additions & 67 deletions litestar/contrib/sqlalchemy/alembic/commands.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,53 @@
from __future__ import annotations

import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Mapping, TextIO

from alembic import command as migration_command
from alembic.config import Config as _AlembicCommandConfig
from alembic.ddl.impl import DefaultImpl

from litestar.contrib.sqlalchemy.plugins.init.config.asyncio import SQLAlchemyAsyncConfig
from litestar.contrib.sqlalchemy.plugins.init.plugin import SQLAlchemyInitPlugin

if TYPE_CHECKING:
import os
from argparse import Namespace

from alembic.runtime.environment import ProcessRevisionDirectiveFn

from litestar.app import Litestar
from litestar.contrib.sqlalchemy.plugins.init.config.asyncio import AlembicAsyncConfig
from litestar.contrib.sqlalchemy.plugins.init.config.sync import AlembicSyncConfig


class AlembicCommandConfig(_AlembicCommandConfig):
def __init__(
self,
file_: str | os.PathLike[str] | None = None,
ini_section: str = "alembic",
output_buffer: TextIO | None = None,
stdout: TextIO = sys.stdout,
cmd_opts: Namespace | None = None,
config_args: Mapping[str, Any] = ..., # type: ignore[assignment]
attributes: dict | None = None,
template_directory: Path | None = None,
) -> None:
self.template_directory = template_directory
super().__init__(file_, ini_section, output_buffer, stdout, cmd_opts, config_args, attributes)

def get_template_directory(self) -> str:
"""Return the directory where Alembic setup templates are found.
This method is used by the alembic ``init`` and ``list_templates``
commands.
"""
from litestar.contrib.sqlalchemy import alembic
if self.template_directory is not None:
return str(self.template_directory)

package_dir = Path(alembic.__file__).parent.resolve()
package_dir = Path(__file__).parent.resolve()
return str(Path(package_dir / "templates"))


Expand All @@ -31,119 +57,131 @@ class AlembicSpannerImpl(DefaultImpl):
__dialect__ = "spanner+spanner"


def get_alembic_config(
migration_config: str | None = None, script_location: str = "migrations"
def get_alembic_command_config(
alembic_config: str | None = None, script_location: str = "migrations"
) -> AlembicCommandConfig:
kwargs = {}
if migration_config:
kwargs.update({"file_": migration_config})
if alembic_config:
kwargs.update({"file_": alembic_config})
alembic_cfg = AlembicCommandConfig(**kwargs) # type: ignore
alembic_cfg.set_main_option("script_location", script_location)
return alembic_cfg


async def upgrade(
migration_config: str | None = None,
script_location: str = "migrations",
def get_alembic_config(app: Litestar) -> AlembicAsyncConfig | AlembicSyncConfig:
return app.plugins.get(SQLAlchemyInitPlugin)._alembic_config


def upgrade(
app: Litestar,
revision: str = "head",
sql: bool = False,
tag: str | None = None,
) -> None:
"""Create or upgrade a database."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.upgrade(config=alembic_cfg, revision=revision, tag=tag, sql=sql)


async def downgrade(
migration_config: str | None = None,
script_location: str = "migrations",
def downgrade(
app: Litestar,
revision: str = "head",
sql: bool = False,
tag: str | None = None,
) -> None:
"""Downgrade a database to a specific revision."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.downgrade(config=alembic_cfg, revision=revision, tag=tag, sql=sql)


async def check(
migration_config: str | None = None,
script_location: str = "migrations",
def check(
app: Litestar,
) -> None:
"""Check if revision command with autogenerate has pending upgrade ops."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.check(config=alembic_cfg)


async def current(
migration_config: str | None = None, script_location: str = "migrations", verbose: bool = False
) -> None:
def current(app: Litestar, verbose: bool = False) -> None:
"""Display the current revision for a database."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.current(alembic_cfg, verbose=verbose)


async def edit(
revision: str,
migration_config: str | None = None,
script_location: str = "migrations",
) -> None:
def edit(app: Litestar, revision: str) -> None:
"""Edit revision script(s) using $EDITOR."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.edit(config=alembic_cfg, rev=revision)


async def ensure_version(
migration_config: str | None = None, script_location: str = "migrations", sql: bool = False
) -> None:
def ensure_version(app: Litestar, sql: bool = False) -> None:
"""Create the alembic version table if it doesn't exist already."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.ensure_version(config=alembic_cfg, sql=sql)


async def heads(
migration_config: str | None = None,
script_location: str = "migrations",
verbose: bool = False,
resolve_dependencies: bool = False,
) -> None:
def heads(app: Litestar, verbose: bool = False, resolve_dependencies: bool = False) -> None:
"""Show current available heads in the script directory."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.heads(config=alembic_cfg, verbose=verbose, resolve_dependencies=resolve_dependencies) # type: ignore[no-untyped-call]


async def history(
migration_config: str | None = None,
script_location: str = "migrations",
def history(
app: Litestar,
rev_range: str | None = None,
verbose: bool = False,
indicate_current: bool = False,
) -> None:
"""List changeset scripts in chronological order."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.history(
config=alembic_cfg, rev_range=rev_range, verbose=verbose, indicate_current=indicate_current
)


async def merge(
def merge(
app: Litestar,
revisions: str,
migration_config: str | None = None,
script_location: str = "migrations",
message: str | None = None,
branch_label: str | None = None,
rev_id: str | None = None,
) -> None:
"""Merge two revisions together. Creates a new migration file."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.merge(
config=alembic_cfg, revisions=revisions, message=message, branch_label=branch_label, rev_id=rev_id
)


async def revision(
migration_config: str | None = None,
script_location: str = "migrations",
def revision(
app: Litestar,
message: str | None = None,
autogenerate: bool = False,
sql: bool = False,
Expand All @@ -156,7 +194,10 @@ async def revision(
process_revision_directives: ProcessRevisionDirectiveFn | None = None,
) -> None:
"""Create a new revision file."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.revision(
config=alembic_cfg,
message=message,
Expand All @@ -172,43 +213,57 @@ async def revision(
)


async def show(
def show(
app: Litestar,
rev: Any,
migration_config: str | None = None,
script_location: str = "migrations",
) -> None:
"""Show the revision(s) denoted by the given symbol."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.show(config=alembic_cfg, rev=rev) # type: ignore[no-untyped-call]


async def init(
def init(
app: Litestar,
directory: str,
migration_config: str | None = None,
template_path: str | None = None,
script_location: str = "migrations",
template: str = "generic",
package: bool = False,
multidb: bool = False,
) -> None:
"""Initialize a new scripts directory."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
template = "sync"
if isinstance(plugin._config, SQLAlchemyAsyncConfig):
template = "asyncio"
if multidb:
template = f"{template}-multidb"
migration_command.init(config=alembic_cfg, directory=directory, template=template, package=package)


async def list_templates(migration_config: str | None, script_location: str = "migrations") -> None:
def list_templates(app: Litestar) -> None:
"""List available templates."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.list_templates(config=alembic_cfg)


async def stamp(
migration_config: str,
def stamp(
app: Litestar,
revision: str,
script_location: str = "migrations",
sql: bool = False,
tag: str | None[str] = None,
purge: bool = False,
) -> None:
"""'stamp' the revision table with the given revision; don't run any migrations."""
alembic_cfg = get_alembic_config(migration_config=migration_config, script_location=script_location)
plugin = app.plugins.get(SQLAlchemyInitPlugin)
alembic_cfg = get_alembic_command_config(
alembic_config=plugin._alembic_config.alembic_config, script_location=plugin._alembic_config.script_location
)
migration_command.stamp(config=alembic_cfg, revision=revision, sql=sql, tag=tag, purge=purge)
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# type: ignore
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
import warnings

import sqlalchemy as sa
from alembic import op
from litestar.contrib.sqlalchemy.types import GUID, ORA_JSONB, DateTimeUTC
${imports if imports else ""}

__all__ = ["downgrade", "upgrade", "schema_upgrades", "schema_downgrades", "data_upgrades", "data_downgrades"]

sa.GUID = GUID
sa.DateTimeUTC = DateTimeUTC
sa.ORA_JSONB = ORA_JSONB

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
schema_upgrades()
data_upgrades()

def downgrade() -> None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
with op.get_context().autocommit_block():
data_downgrades()
schema_downgrades()

def schema_upgrades() -> None:
"""schema upgrade migrations go here."""
${upgrades if upgrades else "pass"}

def schema_downgrades() -> None:
"""schema downgrade migrations go here."""
${downgrades if downgrades else "pass"}

def data_upgrades() -> None:
"""Add any optional data upgrade migrations here!"""

def data_downgrades() -> None:
"""Add any optional data downgrade migrations here!"""
Loading

0 comments on commit 065b636

Please sign in to comment.