Skip to content

Commit

Permalink
Feat: Add serde.py
Browse files Browse the repository at this point in the history
  • Loading branch information
speedcell4 committed Dec 13, 2023
1 parent 1ad8796 commit ee28a05
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 93 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ dataclasses
colorlog
tabulate
matplotlib
aku
aku
pyyaml
3 changes: 2 additions & 1 deletion torchglyph/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from torchglyph import DEBUG
from torchglyph.dist import get_device, init_process, init_seed
from torchglyph.io import hf_hash, lock_folder, save_args
from torchglyph.io import hf_hash, lock_folder
from torchglyph.logger import init_logger
from torchglyph.serde import save_args


def timestamp(*, time_format: str = r'%y%m%d-%H%M%S') -> str:
Expand Down
89 changes: 1 addition & 88 deletions torchglyph/io.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import gzip
import json
import logging
import shutil
import tarfile
import zipfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any, List, Tuple
from typing import List, Tuple

import torch
from datasets.config import DATASETDICT_JSON_FILENAME, DATASET_INFO_FILENAME
from datasets.download import DownloadConfig, DownloadManager
from datasets.fingerprint import Hasher
from filelock import FileLock
from torch import nn

from torchglyph import DEBUG, data_dir

Expand Down Expand Up @@ -83,86 +76,6 @@ def lock_folder(path: Path):
yield


def load_json(path: Path) -> Any:
with path.open(mode='r', encoding='utf-8') as fp:
return json.load(fp=fp)


def load_args(out_dir: Path, name: str = ARGS_JSON) -> Any:
return load_json(path=out_dir / name)


def load_sota(out_dir: Path, name: str = SOTA_JSON) -> Any:
return load_json(path=out_dir / name)


def save_json(path: Path, **kwargs) -> None:
data = {}
if not path.exists():
logger.info(f'saving to {path}')
else:
with path.open(mode='r', encoding='utf-8') as fp:
data = json.load(fp=fp)

with path.open(mode='w', encoding='utf-8') as fp:
json.dump({**data, **kwargs}, fp=fp, indent=2, ensure_ascii=False)


def save_args(out_dir: Path, name: str = ARGS_JSON, **kwargs) -> None:
return save_json(path=out_dir / name, **kwargs)


def save_sota(out_dir: Path, name: str = SOTA_JSON, **kwargs) -> None:
return save_json(path=out_dir / name, **kwargs)


def load_checkpoint(name: str = CHECKPOINT_PT, strict: bool = True, *, out_dir: Path, **kwargs) -> None:
state_dict = torch.load(out_dir / name, map_location=torch.device('cpu'))

for name, module in kwargs.items(): # type: str, nn.Module
logger.info(f'loading {name}.checkpoint from {out_dir / name}')
missing_keys, unexpected_keys = module.load_state_dict(state_dict=state_dict[name], strict=strict)

if not strict:
for missing_key in missing_keys:
logger.warning(f'{name}.{missing_key} is missing')

for unexpected_key in unexpected_keys:
logger.warning(f'{name}.{unexpected_key} is unexpected')


def save_checkpoint(name: str = CHECKPOINT_PT, *, out_dir: Path, **kwargs) -> None:
logger.info(f'saving checkpoint ({", ".join(kwargs.keys())}) to {out_dir / name}')
return torch.save({name: module.state_dict() for name, module in kwargs.items()}, f=out_dir / name)


def extract(path: Path) -> Path:
logger.info(f'extracting files from {path}')

if path.name.endswith('.zip'):
with zipfile.ZipFile(path, 'r') as fp:
fp.extractall(path=path.parent)

elif path.name.endswith('.tar'):
with tarfile.open(path, 'r') as fp:
fp.extractall(path=path.parent)

elif path.name.endswith('.tar.gz') or path.name.endswith('.tgz'):
with tarfile.open(path, 'r:gz') as fp:
fp.extractall(path=path.parent)

elif path.name.endswith('.tar.bz2') or path.name.endswith('.tbz'):
with tarfile.open(path, 'r:bz2') as fp:
fp.extractall(path=path.parent)

elif path.name.endswith('.gz'):
with gzip.open(path, mode='rb') as fs:
with path.with_suffix('').open(mode='wb') as fd:
shutil.copyfileobj(fs, fd)

return path


class DownloadMixin(object):
name: str

Expand Down
2 changes: 1 addition & 1 deletion torchglyph/meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torchrua import CattedSequence

from torchglyph.dist import all_gather_object, is_master
from torchglyph.io import save_sota
from torchglyph.serde import save_sota

logger = getLogger(__name__)

Expand Down
56 changes: 56 additions & 0 deletions torchglyph/serde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import json
from logging import getLogger
from pathlib import Path
from typing import Any

import yaml

logger = getLogger(__name__)


def load_json(path: Path) -> Any:
with path.open(mode='r', encoding='utf-8') as fp:
return json.load(fp=fp)


def load_yaml(path: Path) -> Any:
with path.open(mode='r', encoding='utf-8') as fp:
return yaml.load(stream=fp, Loader=yaml.CLoader)


def save_json(path: Path, **kwargs) -> None:
if not path.exists():
logger.info(f'saving to {path}')
path.parent.mkdir(parents=True, exist_ok=True)

with path.open(mode='w', encoding='utf-8') as fp:
return json.dump(obj=kwargs, fp=fp, indent=2, ensure_ascii=False)


def save_yaml(path: Path, **kwargs) -> None:
if not path.exists():
logger.info(f'saving to {path}')
path.parent.mkdir(parents=True, exist_ok=True)

with path.open(mode='w', encoding='utf-8') as fp:
return yaml.dump(data=kwargs, stream=fp, indent=2, allow_unicode=True)


ARGS_FILENAME = 'args.json'
SOTA_FILENAME = 'sota.json'


def load_args(out_dir: Path, name: str = ARGS_FILENAME) -> Any:
return load_json(path=out_dir / name)


def load_sota(out_dir: Path, name: str = SOTA_FILENAME) -> Any:
return load_json(path=out_dir / name)


def save_args(out_dir: Path, name: str = ARGS_FILENAME, **kwargs) -> None:
return save_json(path=load_json(out_dir / name), **kwargs)


def save_sota(out_dir: Path, name: str = SOTA_FILENAME, **kwargs) -> None:
return save_json(path=load_json(out_dir / name), **kwargs)
4 changes: 2 additions & 2 deletions torchglyph/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torch
from tabulate import tabulate

from torchglyph.io import ARGS_JSON, SOTA_JSON, load_args, load_sota
from torchglyph.logger import LOG_TXT
from torchglyph.serde import ARGS_FILENAME, SOTA_FILENAME, load_args, load_sota

logger = getLogger(__name__)

Expand All @@ -17,7 +17,7 @@
def iter_dir(path: Path) -> Iterable[Path]:
for path in path.iterdir():
if path.is_dir():
if (path / ARGS_JSON).exists() and (path / SOTA_JSON).exists():
if (path / ARGS_FILENAME).exists() and (path / SOTA_FILENAME).exists():
yield path
else:
yield from iter_dir(path)
Expand Down

0 comments on commit ee28a05

Please sign in to comment.