Skip to content

Commit

Permalink
Run pre-commit hooks for the first time
Browse files Browse the repository at this point in the history
  • Loading branch information
lpsinger committed Sep 3, 2024
1 parent 752af12 commit df6013d
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 113 deletions.
25 changes: 16 additions & 9 deletions .github/workflows/runtime_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@
import matplotlib.pyplot as plt
import pandas as pd

with open('benchmark_results.json') as f:
with open("benchmark_results.json") as f:
data = json.load(f)

df = pd.json_normalize(data['benchmarks'])
df['func'] = df['name'].str.split('[', expand=True)[0]
df['fullfunc'] = df['fullname'].str.split('[', expand=True)[0]
df['param'] = pd.to_numeric(df['param'])
for func, group in df.groupby('func'):
df = pd.json_normalize(data["benchmarks"])
df["func"] = df["name"].str.split("[", expand=True)[0]
df["fullfunc"] = df["fullname"].str.split("[", expand=True)[0]
df["param"] = pd.to_numeric(df["param"])
for func, group in df.groupby("func"):
fig, ax = plt.subplots()
group.plot.scatter(x='param', y='stats.mean', yerr='stats.stddev',
loglog=True, xlabel='N', ylabel='Runtime (s)', ax=ax)
group.plot.scatter(
x="param",
y="stats.mean",
yerr="stats.stddev",
loglog=True,
xlabel="N",
ylabel="Runtime (s)",
ax=ax,
)
fig.tight_layout()
fig.savefig(f'plots/{func}')
fig.savefig(f"plots/{func}")
20 changes: 12 additions & 8 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
"""Pytest configuration for running the doctests in README.md."""

from unittest.mock import Mock
import sqlalchemy as sa

import pytest
import sqlalchemy as sa


@pytest.fixture
def engine(postgresql):
"""Create an SQLAlchemy engine with a disposable PostgreSQL database."""
return sa.create_engine('postgresql+psycopg://',
poolclass=sa.pool.StaticPool,
pool_reset_on_return=None,
creator=lambda: postgresql)
return sa.create_engine(
"postgresql+psycopg://",
poolclass=sa.pool.StaticPool,
pool_reset_on_return=None,
creator=lambda: postgresql,
)


@pytest.fixture(autouse=True)
def add_mock_create_engine(monkeypatch, request):
"""Monkey patch sqlalchemy.create_engine for doctests in README.md."""
if request.node.name == 'README.md':
engine = request.getfixturevalue('engine')
monkeypatch.setattr(sa, 'create_engine', Mock(return_value=engine))
if request.node.name == "README.md":
engine = request.getfixturevalue("engine")
monkeypatch.setattr(sa, "create_engine", Mock(return_value=engine))
4 changes: 2 additions & 2 deletions healpix_alchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .types import Point, Tile
from . import func # noqa: F401
from .types import Point, Tile

__all__ = ('Point', 'Tile')
__all__ = ("Point", "Tile")
8 changes: 4 additions & 4 deletions healpix_alchemy/constants.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from astropy.coordinates import ICRS
import sqlalchemy as sa
from astropy import units as u
from astropy_healpix import level_to_nside, HEALPix
from astropy.coordinates import ICRS
from astropy_healpix import HEALPix, level_to_nside
from mocpy import MOC
import sqlalchemy as sa

LEVEL = MOC.MAX_ORDER
"""Base HEALPix resolution. This is the maximum HEALPix level that can be
stored in a signed 8-byte integer data type."""

HPX = HEALPix(nside=level_to_nside(LEVEL), order='nested', frame=ICRS())
HPX = HEALPix(nside=level_to_nside(LEVEL), order="nested", frame=ICRS())
"""HEALPix projection object."""

PIXEL_AREA = HPX.pixel_area.to_value(u.sr)
Expand Down
1 change: 1 addition & 0 deletions healpix_alchemy/func.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""SQLAlchemy functions."""

from sqlalchemy import func as _func

from .types import Tile as _Tile
Expand Down
32 changes: 15 additions & 17 deletions healpix_alchemy/types.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
"""SQLAlchemy types for multiresolution HEALPix data."""

from collections.abc import Sequence
from numbers import Integral

import numpy as np
import sqlalchemy as sa
from astropy.coordinates import SkyCoord
from astropy_healpix import uniq_to_level_ipix
from mocpy import MOC
import numpy as np
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import INT8RANGE

from .constants import HPX, LEVEL, PIXEL_AREA_LITERAL

__all__ = ('Point', 'Tile')
__all__ = ("Point", "Tile")


class Point(sa.TypeDecorator):

cache_ok = True
impl = sa.BigInteger

Expand All @@ -30,7 +30,6 @@ def process_bind_param(self, value, dialect):


class Tile(sa.TypeDecorator):

cache_ok = True
impl = INT8RANGE

Expand All @@ -40,11 +39,10 @@ def process_bind_param(self, value, dialect):
shift = 2 * (LEVEL - level)
value = (ipix << shift, (ipix + 1) << shift)
if isinstance(value, Sequence) and len(value) == 2:
value = f'[{value[0]},{value[1]})'
value = f"[{value[0]},{value[1]})"
return value

class comparator_factory(INT8RANGE.comparator_factory):

@property
def lower(self):
return sa.func.lower(self, type_=Point)
Expand All @@ -68,29 +66,29 @@ def tiles_from(cls, obj):
elif isinstance(obj, SkyCoord):
return cls.tiles_from_polygon_skycoord(obj)
else:
raise TypeError('Unknown type')
raise TypeError("Unknown type")

@classmethod
def tiles_from_polygon_skycoord(cls, polygon):
return cls.tiles_from_moc(
MOC.from_polygon_skycoord(
polygon.transform_to(HPX.frame)))
MOC.from_polygon_skycoord(polygon.transform_to(HPX.frame))
)

@classmethod
def tiles_from_moc(cls, moc):
return (f'[{lo},{hi})' for lo, hi in moc.to_depth29_ranges)
return (f"[{lo},{hi})" for lo, hi in moc.to_depth29_ranges)


@sa.event.listens_for(sa.Index, 'after_parent_attach')
@sa.event.listens_for(sa.Index, "after_parent_attach")
def _create_indices(index, parent):
"""Set index method to SP-GiST_ for any indexed Tile or Region columns.
.. _SP-GiST: https://www.postgresql.org/docs/current/spgist.html
"""
if (
index._column_flag and
len(index.expressions) == 1 and
isinstance(index.expressions[0], sa.Column) and
isinstance(index.expressions[0].type, Tile)
index._column_flag
and len(index.expressions) == 1
and isinstance(index.expressions[0], sa.Column)
and isinstance(index.expressions[0].type, Tile)
):
index.dialect_options['postgresql']['using'] = 'spgist'
index.dialect_options["postgresql"]["using"] = "spgist"
2 changes: 1 addition & 1 deletion tests/benchmarks/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from sqlalchemy import orm
import pytest
from sqlalchemy import orm

from . import data, models

Expand Down
48 changes: 24 additions & 24 deletions tests/benchmarks/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,21 @@
We use the psycopg ``copy`` rather than SQLAlchemy for fast insertion.
"""
from astropy.coordinates import SkyCoord, uniform_spherical_random_surface
from astropy import units as u
from mocpy import MOC

import numpy as np
import pytest
from astropy import units as u
from astropy.coordinates import SkyCoord, uniform_spherical_random_surface
from mocpy import MOC

from healpix_alchemy.constants import HPX, LEVEL, PIXEL_AREA
from healpix_alchemy.types import Tile

from .models import Galaxy, Field, FieldTile, Skymap, SkymapTile
from .models import Field, FieldTile, Galaxy, Skymap, SkymapTile

(
RANDOM_GALAXIES_SEED,
RANDOM_FIELDS_SEED,
RANDOM_SKY_MAP_SEED
) = np.random.SeedSequence(12345).spawn(3)
(RANDOM_GALAXIES_SEED, RANDOM_FIELDS_SEED, RANDOM_SKY_MAP_SEED) = (
np.random.SeedSequence(12345).spawn(3)
)


def get_ztf_footprint_corners():
Expand Down Expand Up @@ -67,16 +66,16 @@ def get_footprints_grid(lon, lat, offsets):

def get_random_points(n, seed):
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(np, 'random', np.random.default_rng(seed))
monkeypatch.setattr(np, "random", np.random.default_rng(seed))
return uniform_spherical_random_surface(n)


def get_random_galaxies(n, cursor):
points = SkyCoord(get_random_points(n, RANDOM_GALAXIES_SEED))
hpx = HPX.skycoord_to_healpix(points)

with cursor.copy(f'COPY {Galaxy.__tablename__} (hpx) FROM STDIN') as copy:
copy.write('\n'.join(f'{i}' for i in hpx))
with cursor.copy(f"COPY {Galaxy.__tablename__} (hpx) FROM STDIN") as copy:
copy.write("\n".join(f"{i}" for i in hpx))

return points

Expand All @@ -86,14 +85,15 @@ def get_random_fields(n, cursor):
footprints = get_footprints_grid(*get_ztf_footprint_corners(), centers)
mocs = [MOC.from_polygon_skycoord(footprint) for footprint in footprints]

with cursor.copy(f'COPY {Field.__tablename__} FROM STDIN') as copy:
copy.write('\n'.join(f'{i}' for i in range(len(mocs))))
with cursor.copy(f"COPY {Field.__tablename__} FROM STDIN") as copy:
copy.write("\n".join(f"{i}" for i in range(len(mocs))))

with cursor.copy(f'COPY {FieldTile.__tablename__} FROM STDIN') as copy:
with cursor.copy(f"COPY {FieldTile.__tablename__} FROM STDIN") as copy:
copy.write(
'\n'.join(
f'{i}\t{hpx}'
for i, moc in enumerate(mocs) for hpx in Tile.tiles_from(moc)
"\n".join(
f"{i}\t{hpx}"
for i, moc in enumerate(mocs)
for hpx in Tile.tiles_from(moc)
)
)

Expand All @@ -104,7 +104,7 @@ def get_random_sky_map(n, cursor):
rng = np.random.default_rng(RANDOM_SKY_MAP_SEED)
# Make a randomly subdivided sky map
npix = HPX.npix
tiles = np.arange(0, npix + 1, 4 ** LEVEL).tolist()
tiles = np.arange(0, npix + 1, 4**LEVEL).tolist()
while len(tiles) < n:
i = rng.integers(len(tiles))
lo = 0 if i == 0 else tiles[i - 1]
Expand All @@ -119,13 +119,13 @@ def get_random_sky_map(n, cursor):
probdensity = rng.uniform(0, 1, size=len(tiles) - 1)
probdensity /= np.sum(np.diff(tiles) * probdensity) * PIXEL_AREA

with cursor.copy(f'COPY {Skymap.__tablename__} FROM STDIN') as copy:
copy.write('1')
with cursor.copy(f"COPY {Skymap.__tablename__} FROM STDIN") as copy:
copy.write("1")

with cursor.copy(f'COPY {SkymapTile.__tablename__} FROM STDIN') as copy:
with cursor.copy(f"COPY {SkymapTile.__tablename__} FROM STDIN") as copy:
copy.write(
'\n'.join(
f'1\t[{lo},{hi})\t{p}'
"\n".join(
f"1\t[{lo},{hi})\t{p}"
for lo, hi, p in zip(tiles[:-1], tiles[1:], probdensity)
)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/benchmarks/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""SQLAlchemy ORM models for unit tests."""

import sqlalchemy as sa
from sqlalchemy import orm

Expand All @@ -7,7 +8,6 @@

@orm.as_declarative()
class Base:

@orm.declared_attr
def __tablename__(cls):
return cls.__name__.lower()
Expand Down
Loading

0 comments on commit df6013d

Please sign in to comment.