Skip to content

Commit

Permalink
Make some internal utils functions private (#332)
Browse files Browse the repository at this point in the history
* make some methods private

* add option to overwrite database file
  • Loading branch information
wtbarnes authored Oct 2, 2024
1 parent f996aea commit 9dd7fac
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 48 deletions.
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@

# On Read the Docs and CI, download the database and build a minimal HDF5 version
if (ON_RTD or ON_GHA):
from fiasco.util import check_database, get_test_file_list
from fiasco.tests import get_test_file_list
from fiasco.util import check_database
from fiasco.util.setup_db import CHIANTI_URL, LATEST_VERSION
from fiasco.util.util import FIASCO_HOME, FIASCO_RC
FIASCO_HOME.mkdir(exist_ok=True, parents=True)
Expand Down
3 changes: 2 additions & 1 deletion fiasco/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from packaging.version import Version

from fiasco.util import check_database, get_test_file_list, read_chianti_version
from fiasco.tests import get_test_file_list
from fiasco.util import check_database, read_chianti_version

# Force MPL to use non-gui backends for testing.
try:
Expand Down
14 changes: 14 additions & 0 deletions fiasco/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,17 @@
"""
This module contains package tests.
"""
import json
import pathlib

from astropy.utils.data import get_pkg_data_path

__all__ = ['get_test_file_list']


def get_test_file_list():
    """
    Return the list of CHIANTI filenames used by the test suite.

    The list is read from the packaged ``test_file_list.json`` data file
    under the ``test_files`` key.
    """
    base = pathlib.Path(get_pkg_data_path('data', package='fiasco.tests'))
    contents = json.loads((base / 'test_file_list.json').read_text())
    return contents['test_files']
File renamed without changes.
File renamed without changes.
File renamed without changes.
82 changes: 39 additions & 43 deletions fiasco/util/setup_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
]
LATEST_VERSION = SUPPORTED_VERSIONS[-1]

__all__ = ['check_database', 'check_database_version', 'download_dbase', 'md5hash', 'get_test_file_list', 'build_hdf5_dbase']
__all__ = ['check_database', 'download_dbase', 'build_hdf5_dbase']


def check_database(hdf5_dbase_root, **kwargs):
Expand Down Expand Up @@ -85,19 +85,12 @@ def check_database(hdf5_dbase_root, **kwargs):
# NOTE: this check is only meant to be bypassed when testing new
# versions. Hence, this kwarg is not documented
if kwargs.get('check_chianti_version', True):
check_database_version(ascii_dbase_root)
_check_database_version(ascii_dbase_root)
# If we made it this far, build the HDF5 database
files = kwargs.get('files')
build_hdf5_dbase(ascii_dbase_root, hdf5_dbase_root, files=files, check_hash=kwargs.get('check_hash', False))


def check_database_version(ascii_dbase_root):
version = read_chianti_version(ascii_dbase_root)
if str(version) not in SUPPORTED_VERSIONS:
raise UnsupportedVersionError(
f'CHIANTI {version} is not in the list of supported versions {SUPPORTED_VERSIONS}.')


def download_dbase(ascii_dbase_url, ascii_dbase_root):
"""
Download the CHIANTI database in ASCII format
Expand All @@ -113,39 +106,7 @@ def download_dbase(ascii_dbase_url, ascii_dbase_root):
tar.extractall(path=ascii_dbase_root)


def md5hash(path):
# Use the md5 utility to generate this
path = pathlib.Path(path)
with path.open('rb') as f:
return hashlib.md5(f.read()).hexdigest()


def _get_hash_table(version):
data_dir = pathlib.Path(get_pkg_data_path('data', package='fiasco.util'))
file_path = data_dir / f'file_hashes_v{version}.json'
with open(file_path) as f:
hash_table = json.load(f)
return hash_table


def get_test_file_list():
data_dir = pathlib.Path(get_pkg_data_path('data', package='fiasco.util'))
file_path = data_dir / 'test_file_list.json'
with open(file_path) as f:
hash_table = json.load(f)
return hash_table['test_files']


def _check_hash(parser, hash_table):
actual = md5hash(parser.full_path)
key = '_'.join(parser.full_path.relative_to(parser.ascii_dbase_root).parts)
if hash_table[key] != actual:
raise RuntimeError(
f'Hash of {parser.full_path} ({actual}) did not match expected hash ({hash_table[key]})'
)


def build_hdf5_dbase(ascii_dbase_root, hdf5_dbase_root, files=None, check_hash=False):
def build_hdf5_dbase(ascii_dbase_root, hdf5_dbase_root, files=None, check_hash=False, overwrite=False):
"""
Assemble HDF5 file from raw ASCII CHIANTI database.
Expand All @@ -161,6 +122,9 @@ def build_hdf5_dbase(ascii_dbase_root, hdf5_dbase_root, files=None, check_hash=F
check_hash: `bool`, optional
If True, check the file hash before adding it to the database.
Building the database will fail if any of the hashes is not as expected.
overwrite: `bool`, optional
If True, overwrite existing database file. By default, this is false such
that an exception will be thrown if the database already exists.
"""
# Import the logger here to avoid circular imports
from fiasco import log
Expand All @@ -176,8 +140,9 @@ def build_hdf5_dbase(ascii_dbase_root, hdf5_dbase_root, files=None, check_hash=F
hash_table = _get_hash_table(version)
log.debug(f'Checking hashes for version {version}')
log.debug(f'Building HDF5 database in {hdf5_dbase_root}')
mode = 'w' if overwrite else 'x'
with ProgressBar(len(files)) as progress:
with h5py.File(hdf5_dbase_root, 'a') as hf:
with h5py.File(hdf5_dbase_root, mode=mode) as hf:
for f in files:
parser = fiasco.io.Parser(f, ascii_dbase_root=ascii_dbase_root)
try:
Expand All @@ -201,3 +166,34 @@ def build_hdf5_dbase(ascii_dbase_root, hdf5_dbase_root, files=None, check_hash=F
ion_list = list_ions(hdf5_dbase_root)
ds = hf.create_dataset('ion_index', data=np.array(ion_list).astype(np.bytes_))
ds.attrs['unit'] = 'SKIP'


def _check_database_version(ascii_dbase_root):
    """
    Raise `UnsupportedVersionError` if the CHIANTI database at
    *ascii_dbase_root* is not one of the supported versions.
    """
    detected = read_chianti_version(ascii_dbase_root)
    if str(detected) in SUPPORTED_VERSIONS:
        return
    raise UnsupportedVersionError(
        f'CHIANTI {detected} is not in the list of supported versions {SUPPORTED_VERSIONS}.')


def _md5hash(path):
# Use the md5 utility to generate this
path = pathlib.Path(path)
with path.open('rb') as f:
return hashlib.md5(f.read()).hexdigest()


def _get_hash_table(version):
    """
    Load the table of expected file hashes for CHIANTI *version*.

    Reads the packaged ``file_hashes_v{version}.json`` data file and
    returns the parsed dictionary mapping filename keys to MD5 digests.
    """
    base = pathlib.Path(get_pkg_data_path('data', package='fiasco.tests'))
    with (base / f'file_hashes_v{version}.json').open() as fh:
        return json.load(fh)


def _check_hash(parser, hash_table):
    """
    Verify that the file referenced by *parser* has the expected MD5 hash.

    The lookup key is the file's path relative to the database root with
    path separators replaced by underscores. Raises `RuntimeError` when the
    computed digest does not match the entry in *hash_table*.
    """
    rel_parts = parser.full_path.relative_to(parser.ascii_dbase_root).parts
    expected = hash_table['_'.join(rel_parts)]
    computed = _md5hash(parser.full_path)
    if computed != expected:
        raise RuntimeError(
            f'Hash of {parser.full_path} ({computed}) did not match expected hash ({expected})'
        )
4 changes: 2 additions & 2 deletions tools/generate_hash_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from astropy.utils.data import get_pkg_data_path
from itertools import chain

from fiasco.util.setup_db import md5hash
from fiasco.util.setup_db import _md5hash
from fiasco.util.util import get_chianti_catalog, read_chianti_version


Expand All @@ -29,7 +29,7 @@ def build_hash_table(dbase_root):
map(lambda x: pathlib.Path('dem') / x, catalogue['dem_files']),
)
filepaths = map(lambda x: dbase_root / x, filepaths)
return {'_'.join(f.relative_to(dbase_root).parts): md5hash(f) for f in filepaths}
return {'_'.join(f.relative_to(dbase_root).parts): _md5hash(f) for f in filepaths}


@click.command()
Expand Down
2 changes: 1 addition & 1 deletion tools/generate_test_file_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def sort_func(x):

if __name__ == '__main__':
# An example of how you might use this function to update the test file list
from fiasco.util import get_test_file_list
from fiasco.tests import get_test_file_list

test_files = get_test_file_list() # Read current files
test_files += ... # Add new files here
Expand Down

0 comments on commit 9dd7fac

Please sign in to comment.