From ee28a0507fabc4aeffa1b3bcae1c07290f48cb79 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Thu, 14 Dec 2023 00:08:08 +0900
Subject: [PATCH] Feat: Add serde.py

---
 requirements.txt      |  3 +-
 torchglyph/env.py     |  3 +-
 torchglyph/io.py      | 89 +-------------------------------------------
 torchglyph/meter.py   |  2 +-
 torchglyph/serde.py   | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 torchglyph/summary.py |  4 +-
 6 files changed, 82 insertions(+), 93 deletions(-)
 create mode 100644 torchglyph/serde.py

diff --git a/requirements.txt b/requirements.txt
index 99ab649..a8833ba 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,5 @@ dataclasses
 colorlog
 tabulate
 matplotlib
-aku
\ No newline at end of file
+aku
+pyyaml
\ No newline at end of file
diff --git a/torchglyph/env.py b/torchglyph/env.py
index 57e52b3..1729230 100644
--- a/torchglyph/env.py
+++ b/torchglyph/env.py
@@ -6,8 +6,9 @@
 
 from torchglyph import DEBUG
 from torchglyph.dist import get_device, init_process, init_seed
-from torchglyph.io import hf_hash, lock_folder, save_args
+from torchglyph.io import hf_hash, lock_folder
 from torchglyph.logger import init_logger
+from torchglyph.serde import save_args
 
 
 def timestamp(*, time_format: str = r'%y%m%d-%H%M%S') -> str:
diff --git a/torchglyph/io.py b/torchglyph/io.py
index 2a965fb..91bfbd3 100644
--- a/torchglyph/io.py
+++ b/torchglyph/io.py
@@ -1,19 +1,12 @@
-import gzip
-import json
 import logging
-import shutil
-import tarfile
-import zipfile
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, List, Tuple
+from typing import List, Tuple
 
-import torch
 from datasets.config import DATASETDICT_JSON_FILENAME, DATASET_INFO_FILENAME
 from datasets.download import DownloadConfig, DownloadManager
 from datasets.fingerprint import Hasher
 from filelock import FileLock
-from torch import nn
 
 from torchglyph import DEBUG, data_dir
 
@@ -83,86 +76,6 @@ def lock_folder(path: Path):
         yield
 
 
-def load_json(path: Path) -> Any:
-    with path.open(mode='r', encoding='utf-8') as fp:
-        return json.load(fp=fp)
-
-
-def load_args(out_dir: Path, name: str = ARGS_JSON) -> Any:
-    return load_json(path=out_dir / name)
-
-
-def load_sota(out_dir: Path, name: str = SOTA_JSON) -> Any:
-    return load_json(path=out_dir / name)
-
-
-def save_json(path: Path, **kwargs) -> None:
-    data = {}
-    if not path.exists():
-        logger.info(f'saving to {path}')
-    else:
-        with path.open(mode='r', encoding='utf-8') as fp:
-            data = json.load(fp=fp)
-
-    with path.open(mode='w', encoding='utf-8') as fp:
-        json.dump({**data, **kwargs}, fp=fp, indent=2, ensure_ascii=False)
-
-
-def save_args(out_dir: Path, name: str = ARGS_JSON, **kwargs) -> None:
-    return save_json(path=out_dir / name, **kwargs)
-
-
-def save_sota(out_dir: Path, name: str = SOTA_JSON, **kwargs) -> None:
-    return save_json(path=out_dir / name, **kwargs)
-
-
-def load_checkpoint(name: str = CHECKPOINT_PT, strict: bool = True, *, out_dir: Path, **kwargs) -> None:
-    state_dict = torch.load(out_dir / name, map_location=torch.device('cpu'))
-
-    for name, module in kwargs.items():  # type: str, nn.Module
-        logger.info(f'loading {name}.checkpoint from {out_dir / name}')
-        missing_keys, unexpected_keys = module.load_state_dict(state_dict=state_dict[name], strict=strict)
-
-        if not strict:
-            for missing_key in missing_keys:
-                logger.warning(f'{name}.{missing_key} is missing')
-
-            for unexpected_key in unexpected_keys:
-                logger.warning(f'{name}.{unexpected_key} is unexpected')
-
-
-def save_checkpoint(name: str = CHECKPOINT_PT, *, out_dir: Path, **kwargs) -> None:
-    logger.info(f'saving checkpoint ({", ".join(kwargs.keys())}) to {out_dir / name}')
-    return torch.save({name: module.state_dict() for name, module in kwargs.items()}, f=out_dir / name)
-
-
-def extract(path: Path) -> Path:
-    logger.info(f'extracting files from {path}')
-
-    if path.name.endswith('.zip'):
-        with zipfile.ZipFile(path, 'r') as fp:
-            fp.extractall(path=path.parent)
-
-    elif path.name.endswith('.tar'):
-        with tarfile.open(path, 'r') as fp:
-            fp.extractall(path=path.parent)
-
-    elif path.name.endswith('.tar.gz') or path.name.endswith('.tgz'):
-        with tarfile.open(path, 'r:gz') as fp:
-            fp.extractall(path=path.parent)
-
-    elif path.name.endswith('.tar.bz2') or path.name.endswith('.tbz'):
-        with tarfile.open(path, 'r:bz2') as fp:
-            fp.extractall(path=path.parent)
-
-    elif path.name.endswith('.gz'):
-        with gzip.open(path, mode='rb') as fs:
-            with path.with_suffix('').open(mode='wb') as fd:
-                shutil.copyfileobj(fs, fd)
-
-    return path
-
-
 class DownloadMixin(object):
     name: str
 
diff --git a/torchglyph/meter.py b/torchglyph/meter.py
index 0035c12..a6f1ce0 100644
--- a/torchglyph/meter.py
+++ b/torchglyph/meter.py
@@ -12,7 +12,7 @@
 from torchrua import CattedSequence
 
 from torchglyph.dist import all_gather_object, is_master
-from torchglyph.io import save_sota
+from torchglyph.serde import save_sota
 
 logger = getLogger(__name__)
 
diff --git a/torchglyph/serde.py b/torchglyph/serde.py
new file mode 100644
index 0000000..2cf76b4
--- /dev/null
+++ b/torchglyph/serde.py
@@ -0,0 +1,74 @@
+import json
+from logging import getLogger
+from pathlib import Path
+from typing import Any
+
+import yaml
+
+logger = getLogger(__name__)
+
+
+def load_json(path: Path) -> Any:
+    """Parse the UTF-8 encoded JSON file at ``path`` and return its content."""
+    with path.open(mode='r', encoding='utf-8') as fp:
+        return json.load(fp=fp)
+
+
+def load_yaml(path: Path) -> Any:
+    """Parse the UTF-8 encoded YAML file at ``path`` and return its content."""
+    # CLoader only exists when PyYAML is built against libyaml, so fall back to Loader
+    # NOTE(review): both loaders can build arbitrary Python objects; never feed them untrusted files
+    loader = getattr(yaml, 'CLoader', yaml.Loader)
+    with path.open(mode='r', encoding='utf-8') as fp:
+        return yaml.load(stream=fp, Loader=loader)
+
+
+def save_json(path: Path, **kwargs) -> None:
+    """Write ``kwargs`` to ``path`` as a JSON object, replacing any previous content."""
+    if not path.exists():
+        logger.info(f'saving to {path}')
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+    with path.open(mode='w', encoding='utf-8') as fp:
+        return json.dump(obj=kwargs, fp=fp, indent=2, ensure_ascii=False)
+
+
+def save_yaml(path: Path, **kwargs) -> None:
+    """Write ``kwargs`` to ``path`` as a YAML mapping, replacing any previous content."""
+    if not path.exists():
+        logger.info(f'saving to {path}')
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+    with path.open(mode='w', encoding='utf-8') as fp:
+        return yaml.dump(data=kwargs, stream=fp, indent=2, allow_unicode=True)
+
+
+ARGS_FILENAME = 'args.json'
+SOTA_FILENAME = 'sota.json'
+
+
+def load_args(out_dir: Path, name: str = ARGS_FILENAME) -> Any:
+    """Load the JSON file ``name`` (``args.json`` by default) from ``out_dir``."""
+    return load_json(path=out_dir / name)
+
+
+def load_sota(out_dir: Path, name: str = SOTA_FILENAME) -> Any:
+    """Load the JSON file ``name`` (``sota.json`` by default) from ``out_dir``."""
+    return load_json(path=out_dir / name)
+
+
+def _update_json(path: Path, **kwargs) -> None:
+    """Merge ``kwargs`` into the JSON object stored at ``path``, creating the file if needed."""
+    # keep the keys that are already on disk, as the former torchglyph.io.save_json did
+    data = load_json(path=path) if path.exists() else {}
+    return save_json(path=path, **{**data, **kwargs})
+
+
+def save_args(out_dir: Path, name: str = ARGS_FILENAME, **kwargs) -> None:
+    """Merge ``kwargs`` into the JSON file ``name`` (``args.json`` by default) in ``out_dir``."""
+    return _update_json(path=out_dir / name, **kwargs)
+
+
+def save_sota(out_dir: Path, name: str = SOTA_FILENAME, **kwargs) -> None:
+    """Merge ``kwargs`` into the JSON file ``name`` (``sota.json`` by default) in ``out_dir``."""
+    return _update_json(path=out_dir / name, **kwargs)
diff --git a/torchglyph/summary.py b/torchglyph/summary.py
index f98f82e..9694f7b 100644
--- a/torchglyph/summary.py
+++ b/torchglyph/summary.py
@@ -5,8 +5,8 @@
 import torch
 from tabulate import tabulate
 
-from torchglyph.io import ARGS_JSON, SOTA_JSON, load_args, load_sota
 from torchglyph.logger import LOG_TXT
+from torchglyph.serde import ARGS_FILENAME, SOTA_FILENAME, load_args, load_sota
 
 logger = getLogger(__name__)
 
@@ -17,7 +17,7 @@
 def iter_dir(path: Path) -> Iterable[Path]:
     for path in path.iterdir():
         if path.is_dir():
-            if (path / ARGS_JSON).exists() and (path / SOTA_JSON).exists():
+            if (path / ARGS_FILENAME).exists() and (path / SOTA_FILENAME).exists():
                 yield path
             else:
                 yield from iter_dir(path)