diff --git a/.dvc/config b/.dvc/config index df8f9a9..2b52c50 100644 --- a/.dvc/config +++ b/.dvc/config @@ -2,3 +2,4 @@ remote = gdrive ['remote "gdrive"'] url = gdrive://155tBftKDG8VSAWojOWT3exax3hz0Xuwg + gdrive_acknowledge_abuse = true diff --git a/.github/workflows/clf.yml b/.github/workflows/clf.yml index a7fac4d..f2d0a33 100644 --- a/.github/workflows/clf.yml +++ b/.github/workflows/clf.yml @@ -1,5 +1,9 @@ name: autobuild_clf on: + pull_request: + branches: + - master + - dev push: branches: - master @@ -14,10 +18,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.10' architecture: 'x64' - uses: actions/setup-node@v3 with: diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index a2e0794..11f0e73 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -3,8 +3,6 @@ on: pull_request: branches: - master - - 'V**' - - dev env: REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -16,10 +14,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.10' architecture: 'x64' - uses: iterative/setup-cml@v1 - uses: iterative/setup-dvc@v1 diff --git a/.github/workflows/segm.yml b/.github/workflows/segm.yml index a955226..581bcb9 100644 --- a/.github/workflows/segm.yml +++ b/.github/workflows/segm.yml @@ -1,5 +1,9 @@ name: autobuild_segm on: + pull_request: + branches: + - master + - dev push: branches: - master @@ -14,10 +18,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.10' architecture: 'x64' - uses: actions/setup-node@v3 with: @@ -48,5 +52,7 @@ jobs: cat runs/pytest_segm/Validation/metrics.json | md-table >> report.md echo >> report.md echo "#### Prediction" >> report.md - echo "![Prediction](runs/pytest_segm/Validation/prediction.png)" >> report.md + echo "![Prediction](runs/pytest_segm/overlays/fc9399fafb30_01.jpg)" >> report.md + echo "![Prediction](runs/pytest_segm/overlays/fcac2903b622_06.jpg)" >> report.md + echo "![Prediction](runs/pytest_segm/overlays/fcac2903b622_11.jpg)" >> report.md cml comment create report.md diff --git a/.github/workflows/tablr.yml b/.github/workflows/tablr.yml index c05faee..734e657 100644 --- a/.github/workflows/tablr.yml +++ b/.github/workflows/tablr.yml @@ -1,5 +1,9 @@ name: autobuild_tablr on: + pull_request: + branches: + - master + - dev push: branches: - master @@ -14,10 +18,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.10' architecture: 'x64' - uses: actions/setup-node@v3 with: @@ -29,7 +33,7 @@ jobs: run: | python -m pip install --upgrade pip pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu - pip install -e '.[dev,tabular,tabular_classification]' + pip install -e '.[dev,ml]' pip install dvc dvc-gdrive npm install -g markdown-table-cli - name: Download data diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 79044bc..fdae9e4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/isort - rev: 5.11.2 + rev: 5.12.0 hooks: - id: isort args: ["--profile", "black"] diff --git a/README.md b/README.md index af340f7..f8aa8cc 100644 --- a/README.md +++ b/README.md @@ -9,37 +9,40 @@ # :pencil: Instructions ### Installation -- Install Pytorch: - - For conda: `conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia` - - For PIP: `pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116` +- Create virtual environment: `conda create -n myenv python=3.10` +- Install [Pytorch](https://pytorch.org/). Currently support `torch==1.13.1` - Inside your project, install this package by `git+https://github.com/kaylode/theseus.git@master#egg=theseus[cv,cv.classification,cv.detection,cv.semantic]` -***extra packages can be identified from the project's folder structure***. +***extra packages can be identified from the pyproject.toml***. ### To adapt for personal project 1. Create your own dataset, dataloader, model, loss function, metric function, ... and register it to the registry so that it can be generated from config at runtime. 2. Customize inherited trainer and pipeline to your need, such as what to do before/after training/validating step,... -3. Modify configuration file +3. Write custom callbacks (recommended!), follow [Lightning](https://lightning.ai/docs/pytorch/latest/) +4. Modify configuration file -*See ```theseus/classification``` for example* +*See ```theseus/cv/classification``` for example* ### To execute scripts with arguments -- Run the script with `-c` flag with specified config file. Example: +- Run the script with `--config-dir` flag with a specified config folder that contains the yaml file. And `--config-name` is that file's name. +Example: ``` -python train.py -c pipeline.yaml +python train.py \ + --config-dir configs \ + --config-name pipeline.yaml ``` -- To override arguments inside the .yaml file, use flag `-o` with key and value. For example, to train 50 epochs and resume training from checkpoints: +- To override arguments inside the .yaml file, follow the instructions from [Hydra](https://hydra.cc/docs/intro/). For example, to train 50 epochs and resume training from checkpoints: ``` python train.py \ - -c pipeline.yaml \ - -o trainer.args.num_iterations=5000 \ + --config-dir configs \ + --config-name pipeline.yaml \ + trainer.args.max_epochs=5000 \ global.resume=checkpoint.pth ``` **Notice: There are no spaces between keys and values in -o flag** -- Also, if you want to do inference, you need to write your own script. For example see ```configs/classification/infer.py``` # :school_satchel: Resources - Example colab notebooks for classification tasks: [![Notebook](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mZmT1B5zI1j_0w1MbP-kq8_Tbcx_tIFq?usp=sharing) diff --git a/configs/base/globals.yaml b/configs/base/globals.yaml deleted file mode 100644 index 79370ae..0000000 --- a/configs/base/globals.yaml +++ /dev/null @@ -1,15 +0,0 @@ -global: - exp_name: null - exist_ok: false - debug: false - save_dir: runs - device: cuda:0 - pretrained: null - resume: null -trainer: - name: SupervisedTrainer - args: - num_iterations: 10000 - clip_grad: 1.0 - evaluate_interval: 1 - use_fp16: true diff --git a/configs/classification/eval.py b/configs/classification/eval.py deleted file mode 100644 index 3ca0ffd..0000000 --- a/configs/classification/eval.py +++ /dev/null @@ -1,10 +0,0 @@ -import matplotlib as mpl - -mpl.use("Agg") -from theseus.cv.classification.pipeline import ClassificationPipeline -from theseus.opt import Opts - -if __name__ == "__main__": - opts = Opts().parse_args() - val_pipeline = ClassificationPipeline(opts) - val_pipeline.evaluate() diff --git a/configs/classification/optuna/pipeline.yaml b/configs/classification/optuna/pipeline.yaml deleted file mode 100644 index e4e1415..0000000 --- a/configs/classification/optuna/pipeline.yaml +++ /dev/null @@ -1,24 +0,0 @@ -includes: - - configs/base/globals.yaml - - configs/base/optimizer.yaml - - configs/classification/transform.yaml - - configs/classification/pipeline.yaml - -trainer: - name: SupervisedTrainer - args: - num_iterations: 10 - clip_grad: null - evaluate_interval: 0 - use_fp16: false - -callbacks: [] - -optimizer: - name: AdamW - args: - lr: [0.0001, 0.001] - -optuna: - float: - - optimizer.args.lr diff --git a/configs/classification/train.py b/configs/classification/train.py deleted file mode 100644 index 3997446..0000000 --- a/configs/classification/train.py +++ /dev/null @@ -1,10 +0,0 @@ -import matplotlib as mpl - -mpl.use("Agg") -from theseus.cv.classification.pipeline import ClassificationPipeline -from theseus.opt import Opts - -if __name__ == "__main__": - opts = Opts().parse_args() - train_pipeline = ClassificationPipeline(opts) - train_pipeline.fit() diff --git a/configs/detection/eval.py b/configs/detection/eval.py deleted file mode 100644 index 9a4369c..0000000 --- a/configs/detection/eval.py +++ /dev/null @@ -1,10 +0,0 @@ -import matplotlib as mpl - -mpl.use("Agg") -from theseus.cv.detection.pipeline import DetectionPipeline -from theseus.opt import Opts - -if __name__ == "__main__": - opts = Opts().parse_args() - val_pipeline = DetectionPipeline(opts) - val_pipeline.evaluate() diff --git a/configs/detection/pipeline.yaml b/configs/detection/pipeline.yaml deleted file mode 100644 index d138d53..0000000 --- a/configs/detection/pipeline.yaml +++ /dev/null @@ -1,70 +0,0 @@ -includes: - - configs/base/globals.yaml - - configs/base/optimizer.yaml - - configs/detection/transform.yaml - -callbacks: - - name: TorchCheckpointCallbacks - - name: DetectionVisualizerCallbacks - - name: TensorboardCallbacks -model: - name: DETRConvnext - args: - model_name: detr - backbone_name: resnet50 - num_queries: 100 - min_conf: 0.25 - hidden_dim: 256 - position_embedding: sine - freeze_backbone: false - dilation: false - dropout: 0.1 - nheads: 8 - dim_feedforward: 2048 - enc_layers: 6 - dec_layers: 6 - pre_norm: false - aux_loss: true -loss: - name: DETRLosses - args: - loss_ce: 1 - loss_bbox: 5 - loss_giou: 2 - loss_mask: null - loss_dice: null - cost_class: 1 - cost_bbox: 5 - cost_giou: 2 - eos_coef: 0.1 -metrics: - - name: DetectionPrecisionRecall - args: - min_conf: 0.2 - min_iou: 0.5 - eps: 1e-6 -data: - dataset: - train: - name: COCODataset - args: - image_dir: ./data/coco/train2017 - label_path: ./data/coco/annotations/instances_train2017.json - val: - name: COCODataset - args: - image_dir: ./data/coco/val2017 - label_path: ./data/coco/annotations/instances_val2017.json - dataloader: - train: - name: DataLoaderWithCollator - args: - batch_size: 2 - drop_last: true - shuffle: true - val: - name: DataLoaderWithCollator - args: - batch_size: 2 - drop_last: false - shuffle: false diff --git a/configs/detection/train.py b/configs/detection/train.py deleted file mode 100644 index 949e16f..0000000 --- a/configs/detection/train.py +++ /dev/null @@ -1,10 +0,0 @@ -import matplotlib as mpl - -mpl.use("Agg") -from theseus.cv.detection.pipeline import DetectionPipeline -from theseus.opt import Opts - -if __name__ == "__main__": - opts = Opts().parse_args() - train_pipeline = DetectionPipeline(opts) - train_pipeline.fit() diff --git a/configs/detection/transform.yaml b/configs/detection/transform.yaml deleted file mode 100644 index c1b2fca..0000000 --- a/configs/detection/transform.yaml +++ /dev/null @@ -1,69 +0,0 @@ -augmentations: - train: - name: DetCompose - args: - transforms: - - name: DetCompose - args: - transforms: - - name: BoxOrder - args: - order: xywh2cxcywh - - name: BoxNormalize - args: - order: cxcywh - - name: AlbCompose - args: - bbox_params: - name: AlbBboxParams - args: - format: 'yolo' - min_area: 0 - min_visibility: 0 - label_fields: ['class_labels'] - transforms: - - name: AlbResize - args: - width: 640 - height: 640 - - name: AlbNormalize - args: - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - max_pixel_value: 1.0 - p: 1.0 - - name: AlbToTensorV2 - val: - name: DetCompose - args: - transforms: - - name: DetCompose - args: - transforms: - - name: BoxOrder - args: - order: xywh2cxcywh - - name: BoxNormalize - args: - order: cxcywh - - name: AlbCompose - args: - bbox_params: - name: AlbBboxParams - args: - format: 'yolo' - min_area: 0 - min_visibility: 0 - label_fields: ['class_labels'] - transforms: - - name: AlbResize - args: - width: 640 - height: 640 - - name: AlbNormalize - args: - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - max_pixel_value: 1.0 - p: 1.0 - - name: AlbToTensorV2 diff --git a/configs/semantic/eval.py b/configs/semantic/eval.py deleted file mode 100644 index dc2f427..0000000 --- a/configs/semantic/eval.py +++ /dev/null @@ -1,10 +0,0 @@ -import matplotlib as mpl - -mpl.use("Agg") -from theseus.cv.semantic.pipeline import SemanticPipeline -from theseus.opt import Opts - -if __name__ == "__main__": - opts = Opts().parse_args() - val_pipeline = SemanticPipeline(opts) - val_pipeline.evaluate() diff --git a/configs/semantic/train.py b/configs/semantic/train.py deleted file mode 100644 index 522e47d..0000000 --- a/configs/semantic/train.py +++ /dev/null @@ -1,10 +0,0 @@ -import matplotlib as mpl - -mpl.use("Agg") -from theseus.cv.semantic.pipeline import SemanticPipeline -from theseus.opt import Opts - -if __name__ == "__main__": - opts = Opts().parse_args() - train_pipeline = SemanticPipeline(opts) - train_pipeline.fit() diff --git a/configs/tabular/optuna/tune.py b/configs/tabular/optuna/tune.py deleted file mode 100644 index 6b94435..0000000 --- a/configs/tabular/optuna/tune.py +++ /dev/null @@ -1,16 +0,0 @@ -from theseus.base.utilities.optuna_tuner import OptunaWrapper -from theseus.opt import Config -from theseus.tabular.classification.pipeline import TabularPipeline - -if __name__ == "__main__": - config = Config("configs/tabular/optuna.yaml") - tuner = OptunaWrapper() - - tuner.tune( - config=config, - pipeline_class=TabularPipeline, - best_key="bl_acc", - n_trials=5, - direction="maximize", - save_dir="runs/optuna/", - ) diff --git a/pyproject.toml b/pyproject.toml index 0679c95..fdf241a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,10 +9,10 @@ packages = ["theseus"] [project] name = "theseus" -version = "1.1.0" +version = "1.5.0" description = "A general template for various Deep Learning tasks. Strongly relies on Pytorch" readme = "README.md" -requires-python = ">=3.6" +requires-python = ">=3.10" license = {file = "LICENSE"} keywords = ["pytorch", "template", "deep learning"] authors = [ @@ -43,7 +43,9 @@ dependencies = [ "loguru>=0.6.0", "kaleido>=0.2.1", "optuna>=3.0.5", - "deepdiff>=6.2.3" + "deepdiff>=6.2.3", + "hydra-core>=1.3.2", + "lightning>=2.0.0" ] [project.optional-dependencies] @@ -58,6 +60,7 @@ cv = [ ] cv_classification = [ "timm", + "scikit-plot", "grad-cam>=1.4.5" ] cv_semantic = [ @@ -80,25 +83,24 @@ nlp_retrieval = [ "rank_bm25>=0.2.2", "elasticsearch>=7.17.7" ] -tabular = [ - "pandas>=1.5.1", +ml = [ + "pandas", "pandarallel>=1.6.3", - "numpy>=1.23.4", + "numpy", "scikit-learn>=1.0.0", "scipy>=1.7.0", "optuna>=3.0.5", "psycopg2-binary>=2.9.5", - "gunicorn>=20.1.0" -] -tabular_classification = [ + "gunicorn>=20.1.0", "lightgbm>=3.3.3", - "xgboost>=1.7.1", + "xgboost<=1.7.1", "catboost", "shap>=0.41.0", - "lime>=0.2.0.1" + "lime>=0.2.0.1", + "scikit-plot", ] all = [ - "theseus[cv,cv_classification,cv_semantic,cv_detection,nlp,nlp_retrieval,tabular,tabular_classification]", + "theseus[cv,cv_classification,cv_semantic,cv_detection,nlp,nlp_retrieval,ml]", ] [project.urls] # Optional diff --git a/tests/classification/configs/base/globals.yaml b/tests/classification/configs/base/globals.yaml new file mode 100644 index 0000000..094d2b5 --- /dev/null +++ b/tests/classification/configs/base/globals.yaml @@ -0,0 +1,18 @@ +global: + exp_name: null + exist_ok: false + save_dir: runs + resume: null + pretrained: null +trainer: + name: plTrainer + args: + devices: 1 + accelerator: gpu + enable_progress_bar: false + precision: 32 + max_epochs: 30 + check_val_every_n_epoch: 1 + num_sanity_val_steps: 0 + gradient_clip_val: 1.0 + deterministic: True diff --git a/tests/classification/configs/base/hydra.yaml b/tests/classification/configs/base/hydra.yaml new file mode 100644 index 0000000..2dc96a2 --- /dev/null +++ b/tests/classification/configs/base/hydra.yaml @@ -0,0 +1,3 @@ +hydra: + run: + dir: ./runs/hydra/${now:%Y-%m-%d-%H-%M-%S} diff --git a/configs/base/optimizer.yaml b/tests/classification/configs/base/optimizer.yaml similarity index 100% rename from configs/base/optimizer.yaml rename to tests/classification/configs/base/optimizer.yaml diff --git a/configs/classification/pipeline.yaml b/tests/classification/configs/optuna.yaml similarity index 77% rename from configs/classification/pipeline.yaml rename to tests/classification/configs/optuna.yaml index 27d86d4..c2d1191 100644 --- a/configs/classification/pipeline.yaml +++ b/tests/classification/configs/optuna.yaml @@ -1,14 +1,10 @@ -includes: - - configs/base/globals.yaml - - configs/base/optimizer.yaml - - configs/classification/transform.yaml +defaults: + - base/hydra@_here_ + - base/globals@_here_ + - base/optimizer@_here_ + - transform@_here_ + - _self_ -callbacks: - - name: TorchCheckpointCallbacks - args: - best_key: bl_acc - - name: ClassificationVisualizerCallbacks - - name: TensorboardCallbacks model: name: BaseTimmModel args: @@ -22,8 +18,6 @@ metrics: - name: F1ScoreMetric args: average: weighted - - name: ConfusionMatrix - - name: ErrorCases data: dataset: train: @@ -58,3 +52,15 @@ data: batch_size: 16 drop_last: false shuffle: true + + +callbacks: [] + +optimizer: + name: AdamW + args: + lr: [0.0001, 0.001] + +optuna: + float: + - optimizer.args.lr diff --git a/tests/classification/configs/pipeline.yaml b/tests/classification/configs/pipeline.yaml new file mode 100644 index 0000000..cfa0298 --- /dev/null +++ b/tests/classification/configs/pipeline.yaml @@ -0,0 +1,71 @@ +defaults: + - base/hydra@_here_ + - base/globals@_here_ + - base/optimizer@_here_ + - transform@_here_ + - _self_ + +callbacks: + - name: TorchCheckpointCallback + args: + filename: best + save_top_k: 1 + save_last: true + monitor: bl_acc + mode: max + - name: RichModelSummary + args: + - name: LearningRateMonitor + args: + logging_interval: step + - name: TensorboardCallback + +model: + name: BaseTimmModel + args: + model_name: efficientnet_b0 + from_pretrained: true +loss: + name: ClassificationCELoss +metrics: + - name: Accuracy + - name: BalancedAccuracyMetric + - name: F1ScoreMetric + args: + average: weighted + - name: ConfusionMatrix + - name: ErrorCases +data: + dataset: + train: + name: ClassificationImageFolderDataset + args: + image_dir: samples/dog-vs-cats/train + txt_classnames: samples/dog-vs-cats/classnames.txt + val: + name: ClassificationImageFolderDataset + args: + image_dir: samples/dog-vs-cats/val + txt_classnames: samples/dog-vs-cats/classnames.txt + dataloader: + train: + name: DataLoaderWithCollator + args: + batch_size: 16 + drop_last: false + shuffle: false + collate_fn: + name: MixupCutmixCollator + args: + mixup_alpha: 0.4 + cutmix_alpha: 1.0 + weight: [0.2, 0.2] + sampler: + name: BalanceSampler + args: + val: + name: DataLoaderWithCollator + args: + batch_size: 16 + drop_last: false + shuffle: false diff --git a/configs/classification/test.yaml b/tests/classification/configs/test.yaml similarity index 83% rename from configs/classification/test.yaml rename to tests/classification/configs/test.yaml index 6c4d133..a51ac5e 100644 --- a/configs/classification/test.yaml +++ b/tests/classification/configs/test.yaml @@ -1,11 +1,13 @@ -includes: - - configs/classification/transform.yaml +defaults: + - base/hydra@_here_ + - base/globals@_here_ + - transform@_here_ + - _self_ + global: exp_name: null exist_ok: false - debug: True save_dir: runs - device: cuda:0 weights: null model: name: BaseTimmModel diff --git a/configs/classification/transform.yaml b/tests/classification/configs/transform.yaml similarity index 100% rename from configs/classification/transform.yaml rename to tests/classification/configs/transform.yaml diff --git a/tests/classification/conftest.py b/tests/classification/conftest.py index 5048f0b..e3f57b1 100644 --- a/tests/classification/conftest.py +++ b/tests/classification/conftest.py @@ -2,44 +2,69 @@ import optuna import pytest +from hydra import compose, initialize, initialize_config_module +from omegaconf import OmegaConf from optuna.storages import JournalFileStorage, JournalStorage from theseus.base.utilities.optuna_tuner import OptunaWrapper -from theseus.opt import Config @pytest.fixture(scope="session") def override_config(): - config = Config("./configs/classification/pipeline.yaml") - config["global"]["exp_name"] = "pytest_clf" - config["global"]["exist_ok"] = True - config["global"]["save_dir"] = "runs" - config["global"]["device"] = "cpu" - config["trainer"]["args"]["use_fp16"] = False - config["trainer"]["args"]["num_iterations"] = 10 - config["data"]["dataloader"]["train"]["args"]["batch_size"] = 1 - config["data"]["dataloader"]["val"]["args"]["batch_size"] = 1 + with initialize(config_path="configs"): + config = compose( + config_name="pipeline", + overrides=[ + "global.exp_name=pytest_clf", + "global.exist_ok=True", + "global.save_dir=runs", + "trainer.args.max_epochs=5", + "trainer.args.precision=32", + "trainer.args.accelerator=cpu", + "trainer.args.devices=1", + "data.dataloader.train.args.batch_size=1", + "data.dataloader.val.args.batch_size=1", + ], + ) + return config @pytest.fixture(scope="session") def override_test_config(): - config = Config("./configs/classification/test.yaml") - config["global"]["exp_name"] = "pytest_clf" - config["global"]["exist_ok"] = True - config["global"]["save_dir"] = "runs" - config["global"]["device"] = "cpu" - config["data"]["dataloader"]["args"]["batch_size"] = 1 + with initialize(config_path="configs"): + config = compose( + config_name="test", + overrides=[ + "global.exp_name=pytest_clf", + "global.exist_ok=True", + "global.save_dir=runs", + "trainer.args.precision=32", + "trainer.args.accelerator=cpu", + "trainer.args.devices=1", + "data.dataloader.args.batch_size=1", + ], + ) + return config @pytest.fixture(scope="session") def override_tuner_config(): - config = Config(f"./configs/classification/optuna/pipeline.yaml") - config["global"]["exp_name"] = "pytest_clf_optuna" - config["global"]["exist_ok"] = True - config["global"]["save_dir"] = "runs" - config["global"]["device"] = "cpu" + + with initialize(config_path="configs"): + config = compose( + config_name="optuna", + overrides=[ + "global.exp_name=pytest_clf", + "global.exist_ok=True", + "global.save_dir=runs", + "trainer.args.precision=32", + "trainer.args.accelerator=cpu", + "trainer.args.devices=1", + ], + ) + return config diff --git a/configs/classification/infer.py b/tests/classification/inference.py similarity index 85% rename from configs/classification/infer.py rename to tests/classification/inference.py index d7b83f1..220d294 100644 --- a/configs/classification/infer.py +++ b/tests/classification/inference.py @@ -4,6 +4,7 @@ import os import pandas as pd +from omegaconf import DictConfig from tqdm import tqdm from theseus.base.pipeline import BaseTestPipeline @@ -11,11 +12,10 @@ from theseus.cv.classification.augmentations import TRANSFORM_REGISTRY from theseus.cv.classification.datasets import DATALOADER_REGISTRY, DATASET_REGISTRY from theseus.cv.classification.models import MODEL_REGISTRY -from theseus.opt import Config, Opts class TestPipeline(BaseTestPipeline): - def __init__(self, opt: Config): + def __init__(self, opt: DictConfig): super(TestPipeline, self).__init__(opt) self.opt = opt @@ -38,7 +38,7 @@ def inference(self): for idx, batch in enumerate(tqdm(self.dataloader)): img_names = batch["img_names"] - outputs = self.model.get_prediction(batch, self.device) + outputs = self.model.predict_step(batch) preds = outputs["names"] probs = outputs["confidences"] @@ -50,9 +50,3 @@ def inference(self): df = pd.DataFrame(df_dict) savepath = os.path.join(self.savedir, "prediction.csv") df.to_csv(savepath, index=False) - - -if __name__ == "__main__": - opts = Opts().parse_args() - val_pipeline = TestPipeline(opts) - val_pipeline.inference() diff --git a/tests/classification/test_clf.py b/tests/classification/test_clf.py index 0039103..56f3126 100644 --- a/tests/classification/test_clf.py +++ b/tests/classification/test_clf.py @@ -1,6 +1,6 @@ import pytest -from configs.classification.infer import TestPipeline +from tests.classification.inference import TestPipeline from theseus.cv.classification.pipeline import ClassificationPipeline @@ -12,13 +12,13 @@ def test_train_clf(override_config): @pytest.mark.order(2) def test_eval_clf(override_config): - override_config["global"]["pretrained"] = "runs/pytest_clf/checkpoints/best.pth" + override_config["global"]["resume"] = "runs/pytest_clf/checkpoints/best.ckpt" val_pipeline = ClassificationPipeline(override_config) val_pipeline.evaluate() @pytest.mark.order(2) def test_infer_clf(override_test_config): - override_test_config["global"]["weights"] = "runs/pytest_clf/checkpoints/best.pth" + override_test_config["global"]["weights"] = "runs/pytest_clf/checkpoints/best.ckpt" test_pipeline = TestPipeline(override_test_config) test_pipeline.inference() diff --git a/tests/classification/test_tuner_clf.py b/tests/classification/test_tuner_clf.py index 6d760b3..0f9d6fa 100644 --- a/tests/classification/test_tuner_clf.py +++ b/tests/classification/test_tuner_clf.py @@ -2,6 +2,7 @@ import pytest +from theseus.base.callbacks.optuna_callback import OptunaCallback from theseus.cv.classification.pipeline import ClassificationPipeline @@ -10,6 +11,7 @@ def test_train_clf_tune(override_tuner_config, override_tuner_tuner): override_tuner_tuner.tune( config=override_tuner_config, pipeline_class=ClassificationPipeline, + optuna_callback=OptunaCallback, trial_user_attrs={ "best_key": "bl_acc", "model_name": override_tuner_config["model"]["args"]["model_name"], diff --git a/tests/semantic/configs/base/globals.yaml b/tests/semantic/configs/base/globals.yaml new file mode 100644 index 0000000..094d2b5 --- /dev/null +++ b/tests/semantic/configs/base/globals.yaml @@ -0,0 +1,18 @@ +global: + exp_name: null + exist_ok: false + save_dir: runs + resume: null + pretrained: null +trainer: + name: plTrainer + args: + devices: 1 + accelerator: gpu + enable_progress_bar: false + precision: 32 + max_epochs: 30 + check_val_every_n_epoch: 1 + num_sanity_val_steps: 0 + gradient_clip_val: 1.0 + deterministic: True diff --git a/tests/semantic/configs/base/hydra.yaml b/tests/semantic/configs/base/hydra.yaml new file mode 100644 index 0000000..2dc96a2 --- /dev/null +++ b/tests/semantic/configs/base/hydra.yaml @@ -0,0 +1,3 @@ +hydra: + run: + dir: ./runs/hydra/${now:%Y-%m-%d-%H-%M-%S} diff --git a/tests/semantic/configs/base/optimizer.yaml b/tests/semantic/configs/base/optimizer.yaml new file mode 100644 index 0000000..73e4f17 --- /dev/null +++ b/tests/semantic/configs/base/optimizer.yaml @@ -0,0 +1,16 @@ +optimizer: + name: AdamW + args: + lr: 0.001 + weight_decay: 0.0005 + betas: + - 0.937 + - 0.999 +scheduler: + name: SchedulerWrapper + args: + scheduler_name: cosine2 + t_initial: 7 + t_mul: 0.9 + eta_mul: 0.9 + eta_min: 1.0e-06 diff --git a/configs/semantic/pipeline.yaml b/tests/semantic/configs/pipeline.yaml similarity index 76% rename from configs/semantic/pipeline.yaml rename to tests/semantic/configs/pipeline.yaml index 7efff83..e40e346 100644 --- a/configs/semantic/pipeline.yaml +++ b/tests/semantic/configs/pipeline.yaml @@ -1,13 +1,25 @@ -includes: - - configs/base/globals.yaml - - configs/base/optimizer.yaml - - configs/semantic/transform.yaml +defaults: + - base/hydra@_here_ + - base/globals@_here_ + - base/optimizer@_here_ + - transform@_here_ + - _self_ + callbacks: - - name: TorchCheckpointCallbacks + - name: TorchCheckpointCallback args: - best_key: dice - - name: SemanticVisualizerCallbacks - - name: TensorboardCallbacks + filename: best + save_top_k: 1 + save_last: true + monitor: dice + mode: max + - name: RichModelSummary + args: + - name: LearningRateMonitor + args: + logging_interval: step + - name: TensorboardCallback + model: name: BaseSegModel args: @@ -64,4 +76,4 @@ data: args: batch_size: 32 drop_last: false - shuffle: true + shuffle: false diff --git a/configs/semantic/test.yaml b/tests/semantic/configs/test.yaml similarity index 83% rename from configs/semantic/test.yaml rename to tests/semantic/configs/test.yaml index b50fd43..7c1c628 100644 --- a/configs/semantic/test.yaml +++ b/tests/semantic/configs/test.yaml @@ -1,11 +1,13 @@ -includes: - - configs/semantic/transform.yaml +defaults: + - base/hydra@_here_ + - base/globals@_here_ + - transform@_here_ + - _self_ + global: exp_name: null exist_ok: false - debug: True save_dir: runs - device: cuda:0 weights: null model: name: BaseSegModel diff --git a/configs/semantic/transform.yaml b/tests/semantic/configs/transform.yaml similarity index 86% rename from configs/semantic/transform.yaml rename to tests/semantic/configs/transform.yaml index f83fe4e..5b53f48 100644 --- a/configs/semantic/transform.yaml +++ b/tests/semantic/configs/transform.yaml @@ -20,12 +20,6 @@ augmentations: args: brightness_limit: 0.3 contrast_limit: 0.3 - # - name: AlbRandomRotate90 - # args: - # - name: AlbShiftScaleRotate - # args: - # border_mode: 0 - # value: 0 - name: AlbNormalize args: mean: [0.485, 0.456, 0.406] diff --git a/tests/semantic/conftest.py b/tests/semantic/conftest.py index 29093d4..196b19c 100644 --- a/tests/semantic/conftest.py +++ b/tests/semantic/conftest.py @@ -1,28 +1,42 @@ import pytest - -from theseus.opt import Config +from hydra import compose, initialize @pytest.fixture(scope="session") def override_config(): - config = Config("./configs/semantic/pipeline.yaml") - config["global"]["exp_name"] = "pytest_segm" - config["global"]["exist_ok"] = True - config["global"]["save_dir"] = "runs" - config["global"]["device"] = "cpu" - config["trainer"]["args"]["use_fp16"] = False - config["trainer"]["args"]["num_iterations"] = 10 - config["data"]["dataloader"]["train"]["args"]["batch_size"] = 1 - config["data"]["dataloader"]["val"]["args"]["batch_size"] = 1 + with initialize(config_path="configs"): + config = compose( + config_name="pipeline", + overrides=[ + "global.exp_name=pytest_segm", + "global.exist_ok=True", + "global.save_dir=runs", + "trainer.args.max_epochs=5", + "trainer.args.precision=32", + "trainer.args.accelerator=cpu", + "trainer.args.devices=1", + "data.dataloader.train.args.batch_size=1", + "data.dataloader.val.args.batch_size=1", + ], + ) + return config @pytest.fixture(scope="session") def override_test_config(): - config = Config("./configs/semantic/test.yaml") - config["global"]["exp_name"] = "pytest_segm" - config["global"]["exist_ok"] = True - config["global"]["save_dir"] = "runs" - config["global"]["device"] = "cpu" - config["data"]["dataloader"]["args"]["batch_size"] = 1 + with initialize(config_path="configs"): + config = compose( + config_name="test", + overrides=[ + "global.exp_name=pytest_segm", + "global.exist_ok=True", + "global.save_dir=runs", + "trainer.args.precision=32", + "trainer.args.accelerator=cpu", + "trainer.args.devices=1", + "data.dataloader.args.batch_size=1", + ], + ) + return config diff --git a/configs/semantic/infer.py b/tests/semantic/inference.py similarity index 90% rename from configs/semantic/infer.py rename to tests/semantic/inference.py index f057690..394d962 100644 --- a/configs/semantic/infer.py +++ b/tests/semantic/inference.py @@ -5,6 +5,7 @@ import cv2 import torch +from omegaconf import DictConfig from theseus.base.pipeline import BaseTestPipeline from theseus.base.utilities.loggers import LoggerObserver @@ -12,11 +13,10 @@ from theseus.cv.semantic.augmentations import TRANSFORM_REGISTRY from theseus.cv.semantic.datasets import DATALOADER_REGISTRY, DATASET_REGISTRY from theseus.cv.semantic.models import MODEL_REGISTRY -from theseus.opt import Config, Opts class TestPipeline(BaseTestPipeline): - def __init__(self, opt: Config): + def __init__(self, opt: DictConfig): super(TestPipeline, self).__init__(opt) self.opt = opt @@ -49,7 +49,7 @@ def inference(self): img_names = batch["img_names"] ori_sizes = batch["ori_sizes"] - outputs = self.model.get_prediction(batch, self.device) + outputs = self.model.predict_step(batch) preds = outputs["masks"] for (inpt, pred, filename, ori_size) in zip( @@ -70,9 +70,3 @@ def inference(self): cv2.imwrite(savepath, overlay) self.logger.text(f"Save image at {savepath}", level=LoggerObserver.INFO) - - -if __name__ == "__main__": - opts = Opts().parse_args() - val_pipeline = TestPipeline(opts) - val_pipeline.inference() diff --git a/tests/semantic/test_segm.py b/tests/semantic/test_segm.py index dd4e4b9..d9b0bd7 100644 --- a/tests/semantic/test_segm.py +++ b/tests/semantic/test_segm.py @@ -1,6 +1,6 @@ import pytest -from configs.semantic.infer import TestPipeline +from tests.semantic.inference import TestPipeline from theseus.cv.semantic.pipeline import SemanticPipeline @@ -12,13 +12,15 @@ def test_train_clf(override_config): @pytest.mark.order(2) def test_eval_clf(override_config): - override_config["global"]["pretrained"] = "runs/pytest_segm/checkpoints/best.pth" + override_config["global"]["resume"] = "runs/pytest_segm/checkpoints/best.ckpt" val_pipeline = SemanticPipeline(override_config) val_pipeline.evaluate() @pytest.mark.order(2) def test_infer_clf(override_test_config): - override_test_config["global"]["weights"] = "runs/pytest_segm/checkpoints/best.pth" + override_test_config["global"][ + "pretrained" + ] = "runs/pytest_segm/checkpoints/best.ckpt" test_pipeline = TestPipeline(override_test_config) test_pipeline.inference() diff --git a/configs/tabular/base/data.yaml b/tests/tabular/configs/base/data.yaml similarity index 100% rename from configs/tabular/base/data.yaml rename to tests/tabular/configs/base/data.yaml diff --git a/configs/tabular/base/globals.yaml b/tests/tabular/configs/base/globals.yaml similarity index 91% rename from configs/tabular/base/globals.yaml rename to tests/tabular/configs/base/globals.yaml index 53094c9..9e038a2 100644 --- a/configs/tabular/base/globals.yaml +++ b/tests/tabular/configs/base/globals.yaml @@ -5,7 +5,6 @@ global: save_dir: runs pretrained: null resume: null - device: cpu trainer: name: MLTrainer args: @@ -16,7 +15,6 @@ callbacks: plot_type: bar check_additivity: False # - name: PermutationImportance - - name: TensorboardCallbacks metrics: - name: SKLAccuracy - name: SKLBalancedAccuracyMetric diff --git a/tests/tabular/configs/base/hydra.yaml b/tests/tabular/configs/base/hydra.yaml new file mode 100644 index 0000000..2dc96a2 --- /dev/null +++ b/tests/tabular/configs/base/hydra.yaml @@ -0,0 +1,3 @@ +hydra: + run: + dir: ./runs/hydra/${now:%Y-%m-%d-%H-%M-%S} diff --git a/configs/tabular/base/transform.yaml b/tests/tabular/configs/base/transform.yaml similarity index 100% rename from configs/tabular/base/transform.yaml rename to tests/tabular/configs/base/transform.yaml diff --git a/configs/tabular/catboost.yaml b/tests/tabular/configs/catboost.yaml similarity index 70% rename from configs/tabular/catboost.yaml rename to tests/tabular/configs/catboost.yaml index 5ec509f..a51d057 100644 --- a/configs/tabular/catboost.yaml +++ b/tests/tabular/configs/catboost.yaml @@ -1,7 +1,9 @@ -includes: - - configs/tabular/base/globals.yaml - - configs/tabular/base/data.yaml - - configs/tabular/base/transform.yaml +defaults: + - base/hydra@_here_ + - base/globals@_here_ + - base/data@_here_ + - base/transform@_here_ + - _self_ model: name: GBClassifiers args: diff --git a/configs/tabular/lightgbm.yaml b/tests/tabular/configs/lightgbm.yaml similarity index 72% rename from configs/tabular/lightgbm.yaml rename to tests/tabular/configs/lightgbm.yaml index fc22d1f..adb3280 100644 --- a/configs/tabular/lightgbm.yaml +++ b/tests/tabular/configs/lightgbm.yaml @@ -1,7 +1,9 @@ -includes: - - configs/tabular/base/globals.yaml - - configs/tabular/base/data.yaml - - configs/tabular/base/transform.yaml +defaults: + - base/hydra@_here_ + - base/globals@_here_ + - base/data@_here_ + - base/transform@_here_ + - _self_ model: name: GBClassifiers args: diff --git a/configs/tabular/optuna/catboost_tune.yaml b/tests/tabular/configs/optuna/catboost_tune.yaml similarity index 83% rename from configs/tabular/optuna/catboost_tune.yaml rename to tests/tabular/configs/optuna/catboost_tune.yaml index 186767a..d0c228a 100644 --- a/configs/tabular/optuna/catboost_tune.yaml +++ b/tests/tabular/configs/optuna/catboost_tune.yaml @@ -1,8 +1,10 @@ -includes: - - configs/tabular/base/globals.yaml - - configs/tabular/base/data.yaml - - configs/tabular/base/transform.yaml - - configs/tabular/optuna/optuna.yaml +defaults: + - base/hydra@_here_ + - base/globals@_here_ + - base/data@_here_ + - base/transform@_here_ + - optuna@_here_ + - _self_ model: name: GBClassifiers args: diff --git a/configs/tabular/optuna/lightgbm_tune.yaml b/tests/tabular/configs/optuna/lightgbm_tune.yaml similarity index 82% rename from configs/tabular/optuna/lightgbm_tune.yaml rename to tests/tabular/configs/optuna/lightgbm_tune.yaml index 5ec5c32..d02f14c 100644 --- a/configs/tabular/optuna/lightgbm_tune.yaml +++ b/tests/tabular/configs/optuna/lightgbm_tune.yaml @@ -1,8 +1,10 @@ -includes: - - configs/tabular/base/globals.yaml - - configs/tabular/base/data.yaml - - configs/tabular/base/transform.yaml - - configs/tabular/optuna/optuna.yaml +defaults: + - base/hydra@_here_ + - base/globals@_here_ + - base/data@_here_ + - base/transform@_here_ + - optuna@_here_ + - _self_ model: name: GBClassifiers args: diff --git a/configs/tabular/optuna/optuna.yaml b/tests/tabular/configs/optuna/optuna.yaml similarity index 100% rename from configs/tabular/optuna/optuna.yaml rename to tests/tabular/configs/optuna/optuna.yaml diff --git a/configs/tabular/optuna/xgboost_tune.yaml b/tests/tabular/configs/optuna/xgboost_tune.yaml similarity index 83% rename from configs/tabular/optuna/xgboost_tune.yaml rename to tests/tabular/configs/optuna/xgboost_tune.yaml index 063978f..53f160d 100644 --- a/configs/tabular/optuna/xgboost_tune.yaml +++ b/tests/tabular/configs/optuna/xgboost_tune.yaml @@ -1,8 +1,11 @@ -includes: - - configs/tabular/base/globals.yaml - - configs/tabular/base/data.yaml - - configs/tabular/base/transform.yaml - - configs/tabular/optuna/optuna.yaml +defaults: + - ../base/hydra@_here_ + - ../base/globals@_here_ + - ../base/data@_here_ + - ../base/transform@_here_ + - optuna@_here_ + - _self_ + model: name: GBClassifiers args: @@ -15,8 +18,8 @@ model: reg_alpha: [0.001, 1.0] #This will anyways be tuned later. reg_lambda: [0.001, 1.0] #This will anyways be tuned later. early_stopping_rounds: 30 - eval_metric: ['auc'] objective: "multi:softprob" + # eval_metric: ['auc'] optuna: int: - model.args.model_config.n_estimators diff --git a/configs/tabular/xgboost.yaml b/tests/tabular/configs/xgboost.yaml similarity index 80% rename from configs/tabular/xgboost.yaml rename to tests/tabular/configs/xgboost.yaml index b0632b9..844bb7d 100644 --- a/configs/tabular/xgboost.yaml +++ b/tests/tabular/configs/xgboost.yaml @@ -1,7 +1,10 @@ -includes: - - configs/tabular/base/globals.yaml - - configs/tabular/base/data.yaml - - configs/tabular/base/transform.yaml +defaults: + - base/hydra@_here_ + - base/globals@_here_ + - base/data@_here_ + - base/transform@_here_ + - _self_ + model: name: GBClassifiers args: @@ -14,5 +17,5 @@ model: reg_alpha: 0 #This will anyways be tuned later. reg_lambda: 1 #This will anyways be tuned later. early_stopping_rounds: 30 - eval_metric: ['auc'] objective: "multi:softprob" + # eval_metric: ['auc'] diff --git a/tests/tabular/conftest.py b/tests/tabular/conftest.py index c6a558f..7fccd21 100644 --- a/tests/tabular/conftest.py +++ b/tests/tabular/conftest.py @@ -1,10 +1,10 @@ import os import pytest +from hydra import compose, initialize from optuna.storages import JournalFileStorage, JournalStorage from theseus.base.utilities.optuna_tuner import OptunaWrapper -from theseus.opt import Config MODELS = ["xgboost"] # , "catboost", 'lightgbm'] TUNER_MODELS = ["xgboost_tune"] # , "catboost_tune"] #, 'lightgbm_tune'] @@ -12,27 +12,36 @@ @pytest.fixture(scope="session", params=MODELS) def override_config(request): - config = Config(f"./configs/tabular/{request.param}.yaml") - config["global"]["exp_name"] = "pytest_tablr" - config["global"]["exist_ok"] = True - config["global"]["save_dir"] = "runs" - config["global"]["device"] = "cpu" + with initialize(config_path="configs"): + config = compose( + config_name=f"{request.param}", + overrides=[ + "global.exp_name=pytest_tablr", + "global.exist_ok=True", + "global.save_dir=runs", + ], + ) + return config @pytest.fixture(scope="function", params=TUNER_MODELS) def override_tuner_config(request): - config = Config(f"./configs/tabular/optuna/{request.param}.yaml") - config["global"]["exp_name"] = "pytest_tablr_optuna" - config["global"]["exist_ok"] = True - config["global"]["save_dir"] = "runs" - config["global"]["device"] = "cpu" + with initialize(config_path="configs/optuna"): + config = compose( + config_name=f"{request.param}", + overrides=[ + "global.exp_name=pytest_tablr_optuna", + "global.exist_ok=True", + "global.save_dir=runs", + ], + ) + return config @pytest.fixture(scope="session") def override_tuner_tuner(): - os.makedirs("runs/optuna/tablr", exist_ok=True) database = JournalStorage( JournalFileStorage("runs/optuna/tablr/pytest_tablr_optuna.log") diff --git a/tests/tabular/test_tablr.py b/tests/tabular/test_tablr.py index e839d38..5637788 100644 --- a/tests/tabular/test_tablr.py +++ b/tests/tabular/test_tablr.py @@ -1,19 +1,19 @@ import pytest # from configs.tabular.infer import TestPipeline -from theseus.tabular.classification.pipeline import TabularPipeline +from theseus.ml.pipeline import MLPipeline @pytest.mark.order(1) def test_train_tblr(override_config): - train_pipeline = TabularPipeline(override_config) + train_pipeline = MLPipeline(override_config) train_pipeline.fit() @pytest.mark.order(2) def test_eval_tblr(override_config): override_config["global"]["pretrained"] = "runs/pytest_tablr/checkpoints/last" - val_pipeline = TabularPipeline(override_config) + val_pipeline = MLPipeline(override_config) val_pipeline.evaluate() diff --git a/tests/tabular/test_tuner_tblr.py b/tests/tabular/test_tuner_tblr.py index 98373b1..c5e094b 100644 --- a/tests/tabular/test_tuner_tblr.py +++ b/tests/tabular/test_tuner_tblr.py @@ -2,14 +2,16 @@ import pytest -from theseus.tabular.classification.pipeline import TabularPipeline +from theseus.ml.callbacks.optuna_callbacks import OptunaCallbacks +from theseus.ml.pipeline import MLPipeline @pytest.mark.order(1) def test_train_tblr_tune(override_tuner_config, override_tuner_tuner): override_tuner_tuner.tune( config=override_tuner_config, - pipeline_class=TabularPipeline, + pipeline_class=MLPipeline, + optuna_callback=OptunaCallbacks, trial_user_attrs={ "best_key": "bl_acc", "model_name": override_tuner_config["model"]["args"]["model_name"], @@ -18,7 +20,6 @@ def test_train_tblr_tune(override_tuner_config, override_tuner_tuner): leaderboard_df = override_tuner_tuner.leaderboard() os.makedirs("runs/optuna/tablr/overview", exist_ok=True) - # leaderboard_df.to_csv("runs/optuna/tablr/overview/leaderboard.csv", index=False) leaderboard_df.to_json( "runs/optuna/tablr/overview/leaderboard.json", orient="records" ) diff --git a/theseus/__init__.py b/theseus/__init__.py index 167706f..6f39759 100644 --- a/theseus/__init__.py +++ b/theseus/__init__.py @@ -11,7 +11,7 @@ __author__ = "kaylode" __license__ = "MIT" __copyright__ = "Copyright 2020-present Kaylode" -__version__ = "1.1.0" +__version__ = "1.5.0" from .base import * from .registry import Registry diff --git a/theseus/base/callbacks/__init__.py b/theseus/base/callbacks/__init__.py index b55adf1..a9b0c11 100644 --- a/theseus/base/callbacks/__init__.py +++ b/theseus/base/callbacks/__init__.py @@ -1,22 +1,28 @@ +from lightning.pytorch.callbacks import ( + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, + RichModelSummary, +) + from theseus.registry import Registry -from .base_callbacks import Callbacks, CallbacksList -from .checkpoint_callbacks import TorchCheckpointCallbacks -from .debug_callbacks import DebugCallbacks -from .loss_logging_callbacks import LossLoggerCallbacks -from .lr_autofind import AutoFindLRCallbacks -from .metric_logging_callbacks import MetricLoggerCallbacks -from .timer_callbacks import TimerCallbacks -from .tsb_callbacks import TensorboardCallbacks -from .wandb_callbacks import WandbCallbacks +from .checkpoint_callback import TorchCheckpointCallback +from .loss_logging_callback import LossLoggerCallback +from .metric_logging_callback import MetricLoggerCallback +from .timer_callback import TimerCallback +from .tsb_callback import TensorboardCallback +from .wandb_callback import WandbCallback CALLBACKS_REGISTRY = Registry("CALLBACKS") -CALLBACKS_REGISTRY.register(TimerCallbacks) -CALLBACKS_REGISTRY.register(TorchCheckpointCallbacks) -CALLBACKS_REGISTRY.register(TensorboardCallbacks) -CALLBACKS_REGISTRY.register(WandbCallbacks) -CALLBACKS_REGISTRY.register(DebugCallbacks) -CALLBACKS_REGISTRY.register(LossLoggerCallbacks) -CALLBACKS_REGISTRY.register(MetricLoggerCallbacks) -CALLBACKS_REGISTRY.register(AutoFindLRCallbacks) +CALLBACKS_REGISTRY.register(TimerCallback) +CALLBACKS_REGISTRY.register(TensorboardCallback) +CALLBACKS_REGISTRY.register(WandbCallback) +CALLBACKS_REGISTRY.register(ModelCheckpoint) +CALLBACKS_REGISTRY.register(RichModelSummary) +CALLBACKS_REGISTRY.register(LearningRateMonitor) +CALLBACKS_REGISTRY.register(EarlyStopping) +CALLBACKS_REGISTRY.register(LossLoggerCallback) +CALLBACKS_REGISTRY.register(MetricLoggerCallback) +CALLBACKS_REGISTRY.register(TorchCheckpointCallback) diff --git a/theseus/base/callbacks/checkpoint_callback.py b/theseus/base/callbacks/checkpoint_callback.py new file mode 100644 index 0000000..e544a53 --- /dev/null +++ b/theseus/base/callbacks/checkpoint_callback.py @@ -0,0 +1,62 @@ +import inspect +import os +import os.path as osp + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import ModelCheckpoint + +from theseus.base.utilities.loggers.observer import LoggerObserver + +LOGGER = LoggerObserver.getLogger("main") + + +class TorchCheckpointCallback(ModelCheckpoint): + def __init__(self, save_dir: str, **kwargs) -> None: + + save_dir = osp.join(save_dir, "checkpoints") + os.makedirs(save_dir, exist_ok=True) + inspection = inspect.signature(ModelCheckpoint) + class_kwargs = inspection.parameters.keys() + filtered_kwargs = {k: v for k, v in kwargs.items() if k in class_kwargs} + + super().__init__(dirpath=save_dir, **filtered_kwargs) + + def setup( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str + ) -> None: + super().setup(trainer, pl_module, stage) + self.params = {} + trainloader = pl_module.datamodule.trainloader + if trainloader is not None: + batch_size = trainloader.batch_size + self.params["trainloader_length"] = len(trainloader) + self.params["num_iterations"] = len(trainloader) * trainer.max_epochs + + if self._every_n_train_steps is None or self._every_n_train_steps == 0: + LOGGER.text( + "Save interval not specified. Auto calculating...", + level=LoggerObserver.DEBUG, + ) + self._every_n_train_steps = self.auto_get_save_interval() + + def auto_get_save_interval(self, train_fraction=0.5): + """ + Automatically decide the number of save interval + """ + save_interval = max(int(train_fraction * self.params["trainloader_length"]), 1) + return save_interval + + def _save_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: + super()._save_checkpoint(trainer, filepath) + + if filepath in self.best_k_models.keys(): + if self.best_k_models[filepath] == self.best_model_score: + LOGGER.text( + f"Evaluation improved to {self.current_score}", + level=LoggerObserver.SUCCESS, + ) + + LOGGER.text( + f"Save checkpoints to {filepath}", + level=LoggerObserver.INFO, + ) diff --git a/theseus/base/callbacks/checkpoint_callbacks.py b/theseus/base/callbacks/checkpoint_callbacks.py deleted file mode 100644 index b646b77..0000000 --- a/theseus/base/callbacks/checkpoint_callbacks.py +++ /dev/null @@ -1,184 +0,0 @@ -import os -import os.path as osp -from typing import Dict - -import torch - -from theseus.base.callbacks import Callbacks -from theseus.base.utilities.loading import load_state_dict -from theseus.base.utilities.loggers.observer import LoggerObserver - -LOGGER = LoggerObserver.getLogger("main") - - -class TorchCheckpointCallbacks(Callbacks): - """ - Callbacks for saving checkpoints. - Features: - - Load checkpoint at start - - Save checkpoint every save_interval - - Save checkpoint if metric value is improving - - save_dir: `str` - save directory - save_interval: `int` - iteration cycle to save checkpoint - best_key: `str` - save best based on metric key - resume: `str` - path to .pth to resume checkpoints - - """ - - def __init__( - self, - save_dir: str = "runs", - save_interval: int = None, - best_key: str = None, - resume: str = None, - **kwargs, - ) -> None: - super().__init__() - - self.best_value = 0 - self.best_key = best_key - self.save_dir = osp.join(save_dir, "checkpoints") - os.makedirs(self.save_dir, exist_ok=True) - self.save_interval = save_interval - self.resume = resume - - def auto_get_save_interval(self, train_fraction=0.5): - """ - Automatically decide the number of save interval - """ - trainloader = self.params["trainer"].trainloader - num_iterations_per_epoch = len(trainloader) - save_interval = max(int(train_fraction * num_iterations_per_epoch), 1) - return save_interval - - def load_checkpoint(self, path, trainer): - """ - Load all information the current iteration from checkpoint - """ - LOGGER.text("Loading checkpoints...", level=LoggerObserver.INFO) - state_dict = torch.load(path, map_location="cpu") - trainer.iters = load_state_dict(trainer.iters, state_dict, "iters") - if trainer.scaler: - trainer.scaler = load_state_dict( - trainer.scaler, state_dict, trainer.scaler.state_dict_key - ) - self.best_value = load_state_dict(self.best_value, state_dict, "best_value") - - def save_checkpoint(self, trainer, iters, outname="last"): - """ - Save all information of the current iteration - """ - weights = { - "model": trainer.model.model.state_dict(), - "optimizer": trainer.optimizer.state_dict(), - "scheduler": trainer.scheduler.state_dict(), - "iters": iters, - "best_value": self.best_value, - } - - if trainer.scheduler: - weights["scheduler"] = trainer.scheduler.state_dict() - - if trainer.scaler: - weights[trainer.scaler.state_dict_key] = trainer.scaler.state_dict() - - torch.save(weights, os.path.join(self.save_dir, outname) + ".pth") - LOGGER.text( - f"Save checkpoints to {os.path.join(self.save_dir, outname)}" + ".pth", - level=LoggerObserver.INFO, - ) - - def sanitycheck(self, logs: Dict = None): - """ - Sanitycheck before starting. Run only when debug=True - """ - if self.resume is not None: - self.load_checkpoint(self.resume, self.params["trainer"]) - self.resume = None # Turn off so that on_start would not be called - - def on_start(self, logs: Dict = None): - """ - Before going to the main loop - """ - if self.resume is not None: - self.load_checkpoint(self.resume, self.params["trainer"]) - - if self.save_interval is None: - self.save_interval = self.auto_get_save_interval() - LOGGER.text( - "Save interval not specified. Auto calculating...", - level=LoggerObserver.DEBUG, - ) - - def on_finish(self, logs: Dict = None): - """ - After finish training - """ - - iters = logs["iters"] - num_iterations = logs["num_iterations"] - - self.save_checkpoint(self.params["trainer"], iters=iters) - LOGGER.text( - f"Save model at [{iters}|{num_iterations}] to last.pth", - LoggerObserver.INFO, - ) - - def on_train_batch_end(self, logs: Dict = None): - """ - On training batch (iteration) end - """ - - iters = logs["iters"] - num_iterations = logs["num_iterations"] - - # Saving checkpoint - if iters % self.save_interval == 0 or iters == num_iterations - 1: - self.save_checkpoint(self.params["trainer"], iters=iters) - LOGGER.text( - f"Save model at [{iters}|{num_iterations}] to last.pth", - LoggerObserver.INFO, - ) - - def on_val_epoch_end(self, logs: Dict = None): - """ - On validation batch (iteration) end - """ - - iters = logs["iters"] - num_iterations = logs["num_iterations"] - metric_dict = logs["metric_dict"] - - if self.best_key is None: - return - - if not self.best_key in metric_dict.keys(): - LOGGER.text( - f"{self.best_key} key does not present in metric. Available keys are: {metric_dict.keys()}", - LoggerObserver.WARN, - ) - return - - # Saving checkpoint - if metric_dict[self.best_key] > self.best_value: - if ( - iters > 0 - ): # Have been training, else in evaluation-only mode or just sanity check - LOGGER.text( - f"Evaluation improved from {self.best_value} to {metric_dict[self.best_key]}", - level=LoggerObserver.SUCCESS, - ) - self.best_value = metric_dict[self.best_key] - self.save_checkpoint( - self.params["trainer"], iters=iters, outname="best" - ) - - LOGGER.text( - f"Save model at [{iters}|{num_iterations}] to best.pth", - LoggerObserver.INFO, - ) diff --git a/theseus/base/callbacks/debug_callbacks.py b/theseus/base/callbacks/debug_callbacks.py deleted file mode 100644 index 7125665..0000000 --- a/theseus/base/callbacks/debug_callbacks.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Dict - -from theseus.base.callbacks import Callbacks -from theseus.base.utilities.loggers.observer import LoggerObserver - -LOGGER = LoggerObserver.getLogger("main") - - -class DebugCallbacks(Callbacks): - """ - Callbacks for debugging. - """ - - def __init__( - self, - **kwargs, - ) -> None: - super().__init__() - - def sanitycheck(self, logs: Dict = None): - """ - Sanitycheck before starting. Run only when debug=True - """ - - LOGGER.text("Start sanity checks", level=LoggerObserver.DEBUG) - self.params["trainer"].evaluate_epoch() diff --git a/theseus/base/callbacks/loss_logging_callbacks.py b/theseus/base/callbacks/loss_logging_callback.py similarity index 62% rename from theseus/base/callbacks/loss_logging_callbacks.py rename to theseus/base/callbacks/loss_logging_callback.py index ee86660..b973c33 100644 --- a/theseus/base/callbacks/loss_logging_callbacks.py +++ b/theseus/base/callbacks/loss_logging_callback.py @@ -1,15 +1,17 @@ import time -from typing import Dict, List +from typing import Any, Dict, List +import lightning.pytorch as pl import numpy as np +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver LOGGER = LoggerObserver.getLogger("main") -class LossLoggerCallbacks(Callbacks): +class LossLoggerCallback(Callback): """ Callbacks for logging running loss while training Features: @@ -26,49 +28,86 @@ def __init__(self, print_interval: int = None, **kwargs) -> None: self.running_loss = {} self.print_interval = print_interval - def auto_get_print_interval(self, train_fraction=0.1): + def setup( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str + ) -> None: """ - Automatically decide the number of print interval + Setup the callback """ - trainloader = self.params["trainer"].trainloader - num_iterations_per_epoch = len(trainloader) - print_interval = max(int(train_fraction * num_iterations_per_epoch), 1) - return print_interval + self.params = {} + + trainloader = pl_module.datamodule.trainloader + if trainloader is not None: + batch_size = trainloader.batch_size + self.params["num_iterations"] = len(trainloader) * trainer.max_epochs + self.params["trainloader_length"] = len(trainloader) + else: + self.params["num_iterations"] = None + self.params["trainloader_length"] = None + + valloader = pl_module.datamodule.valloader + if valloader is not None: + batch_size = valloader.batch_size + self.params["valloader_length"] = len(valloader) + else: + self.params["valloader_length"] = None - def on_start(self, logs: Dict = None): - """ - Before going to the main loop - """ if self.print_interval is None: - self.print_interval = self.auto_get_print_interval() + self.print_interval = self.auto_get_print_interval(pl_module) LOGGER.text( "Print interval not specified. Auto calculating...", level=LoggerObserver.DEBUG, ) - def on_train_epoch_start(self, logs: Dict = None): + def auto_get_print_interval( + self, pl_module: pl.LightningModule, train_fraction: float = 0.1 + ): + """ + Automatically decide the number of print interval + """ + + num_iterations_per_epoch = ( + self.params["trainloader_length"] + if self.params["trainloader_length"] is not None + else self.params["valloader_length"] + ) + print_interval = max(int(train_fraction * num_iterations_per_epoch), 1) + return print_interval + + def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ Before going to the training loop """ self.running_loss = {} self.running_time_list = [] - def on_train_batch_start(self, logs: Dict = None): + def on_train_batch_start( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + batch: Any, + batch_idx: int, + ): """ Before going to the training loop """ self.running_time = time.time() - def on_train_batch_end(self, logs: Dict = None): + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ): """ After finish a batch """ - lr = logs["lr"] - iters = logs["iters"] - loss_dict = logs["loss_dict"] - num_iterations = logs["num_iterations"] - trainloader_length = len(self.params["trainer"].trainloader) + lr = pl_module.lr + iters = trainer.global_step + loss_dict = outputs["loss_dict"] # Update running loss of batch for (key, value) in loss_dict.items(): @@ -81,7 +120,10 @@ def on_train_batch_end(self, logs: Dict = None): self.running_time_list.append(batch_time) # Logging - if iters % self.print_interval == 0 or (iters + 1) % trainloader_length == 0: + if ( + iters % self.print_interval == 0 + or (iters + 1) % self.params["trainloader_length"] == 0 + ): # Running loss since last interval for key in self.running_loss.keys(): @@ -98,7 +140,7 @@ def on_train_batch_end(self, logs: Dict = None): LOGGER.text( "[{}|{}] || {} || Time: {:10.4f} (it/s)".format( iters, - num_iterations, + self.params["num_iterations"], loss_string, running_time, ), @@ -141,19 +183,29 @@ def on_train_batch_end(self, logs: Dict = None): self.running_loss = {} self.running_time_list = [] - def on_val_epoch_start(self, logs: Dict = None): + def on_validation_epoch_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ): """ Before main validation loops """ self.running_time = time.time() self.running_loss = {} - def on_val_batch_end(self, logs: Dict = None): + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ): """ After finish a batch """ - loss_dict = logs["loss_dict"] + loss_dict = outputs["loss_dict"] # Update batch loss for (key, value) in loss_dict.items(): @@ -161,15 +213,14 @@ def on_val_batch_end(self, logs: Dict = None): self.running_loss[key] = [] self.running_loss[key].append(value) - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ After finish validation """ - iters = logs["iters"] - num_iterations = logs["num_iterations"] + iters = trainer.global_step + num_iterations = self.params["num_iterations"] epoch_time = time.time() - self.running_time - valloader = self.params["trainer"].valloader # Log loss for key in self.running_loss.keys(): @@ -179,7 +230,10 @@ def on_val_epoch_end(self, logs: Dict = None): ) LOGGER.text( "[{}|{}] || {} || Time: {:10.4f} (it/s)".format( - iters, num_iterations, loss_string, len(valloader) / epoch_time + iters, + num_iterations, + loss_string, + self.params["valloader_length"] / epoch_time, ), level=LoggerObserver.INFO, ) diff --git a/theseus/base/callbacks/lr_autofind.py b/theseus/base/callbacks/lr_autofind.py deleted file mode 100644 index 413e3a5..0000000 --- a/theseus/base/callbacks/lr_autofind.py +++ /dev/null @@ -1,124 +0,0 @@ -import time -from typing import Dict, List - -import numpy as np - -from theseus.base.callbacks.base_callbacks import Callbacks -from theseus.base.utilities.loggers.observer import LoggerObserver - -LOGGER = LoggerObserver.getLogger("main") - - -class AutoFindLRCallbacks(Callbacks): - """ - Callbacks for auto finding LR - :params: - lr_range: List - learning rate search space - gamma: int - number of iterations per lr step - """ - - def __init__( - self, lr_range: List[float], num_steps: int, num_epochs: int = 1, **kwargs - ) -> None: - super().__init__() - - self.lr_range = lr_range - self.num_steps = num_steps - self.num_epochs = num_epochs - - assert ( - self.lr_range[1] > self.lr_range[0] - ), "Learning rate range should be from low to high" - assert self.num_epochs > 0, "Num epochs should be higher than 0" - - def auto_get_interval(self): - """ - Automatically decide the number of interval - """ - trainloader = self.params["trainer"].trainloader - num_iterations = len(trainloader) * self.num_epochs - - num_iterations_per_steps = (num_iterations - 1) // self.num_steps - step_iters = [ - int(round(x * num_iterations_per_steps)) for x in range(0, self.num_steps) - ] - - gamma = (self.lr_range[1] - self.lr_range[0]) / float(self.num_steps - 1) - lrs = [self.lr_range[0] + x * gamma for x in range(0, self.num_steps)] - - return step_iters, lrs - - def on_start(self, logs: Dict = None): - """ - Before going to the main loop - """ - - LOGGER.text( - "Autofinding LR is activated. Running for 1 epoch only...", - level=LoggerObserver.DEBUG, - ) - - trainloader = self.params["trainer"].trainloader - num_iterations = len(trainloader) * self.num_epochs - self.params["trainer"].num_iterations = num_iterations - - self.step_iters, self.lrs = self.auto_get_interval() - self.current_idx = 0 - LOGGER.text( - "Interval for Learning Rate AutoFinding not specified. Auto calculating...", - level=LoggerObserver.DEBUG, - ) - - self.tracking_loss = [] - self.tracking_lr = [] - - optim = self.params["trainer"].optimizer - for g in optim.param_groups: - g["lr"] = self.lrs[self.current_idx] - self.current_idx += 1 - - def on_train_batch_end(self, logs: Dict = None): - """ - After finish a batch - """ - - lr = logs["lr"] - iters = logs["iters"] - loss_dict = logs["loss_dict"] - optim = self.params["trainer"].optimizer - - log_dict = [ - { - "tag": f"AutoLR/{k} Loss", - "value": v, - "type": LoggerObserver.SCALAR, - "kwargs": {"step": iters}, - } - for k, v in loss_dict.items() - ] - - # Log learning rates - log_dict.append( - { - "tag": "AutoLR/Learning rate", - "value": lr, - "type": LoggerObserver.SCALAR, - "kwargs": {"step": iters}, - } - ) - - LOGGER.log(log_dict) - - self.tracking_loss.append(sum([v for v in loss_dict.values()])) - self.tracking_lr.append(lr) - - # Logging - if ( - self.current_idx < len(self.step_iters) - and iters == self.step_iters[self.current_idx] - ): - for g in optim.param_groups: - g["lr"] = self.lrs[self.current_idx] - self.current_idx += 1 diff --git a/theseus/base/callbacks/metric_logging_callback.py b/theseus/base/callbacks/metric_logging_callback.py new file mode 100644 index 0000000..8dfe9f4 --- /dev/null +++ b/theseus/base/callbacks/metric_logging_callback.py @@ -0,0 +1,77 @@ +import json +import os +import os.path as osp +from typing import Dict + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import Callback + +from theseus.base.utilities.loggers.observer import LoggerObserver + +LOGGER = LoggerObserver.getLogger("main") + + +class MetricLoggerCallback(Callback): + """ + Callbacks for logging running metric while training every epoch end + Features: + - Only do logging + """ + + def __init__(self, save_json: bool = True, **kwargs) -> None: + super().__init__() + self.save_json = save_json + if self.save_json: + self.save_dir = kwargs.get("save_dir", None) + if self.save_dir is not None: + self.save_dir = osp.join(self.save_dir, "Validation") + os.makedirs(self.save_dir, exist_ok=True) + self.output_dict = [] + + def on_validation_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """ + After finish validation + """ + iters = trainer.global_step + metric_dict = pl_module.metric_dict + + # Save json + if self.save_json: + item = {} + for metric, score in metric_dict.items(): + if isinstance(score, (int, float)): + item[metric] = float(f"{score:.5f}") + if len(item.keys()) > 0: + item["iters"] = iters + self.output_dict.append(item) + + # Log metric + metric_string = "" + for metric, score in metric_dict.items(): + if isinstance(score, (int, float)): + metric_string += metric + ": " + f"{score:.5f}" + " | " + metric_string += "\n" + + LOGGER.text(metric_string, level=LoggerObserver.INFO) + + # Call other loggers + log_dict = [ + {"tag": f"Validation/{k}", "value": v, "kwargs": {"step": iters}} + for k, v in metric_dict.items() + ] + + LOGGER.log(log_dict) + + def teardown( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str + ) -> None: + """ + After finish everything + """ + if self.save_json: + save_json = osp.join(self.save_dir, "metrics.json") + if len(self.output_dict) > 0: + with open(save_json, "w") as f: + json.dump(self.output_dict, f) diff --git a/theseus/base/callbacks/optuna_callback.py b/theseus/base/callbacks/optuna_callback.py new file mode 100644 index 0000000..f93c714 --- /dev/null +++ b/theseus/base/callbacks/optuna_callback.py @@ -0,0 +1,36 @@ +from typing import Dict, List + +import lightning.pytorch as pl +import optuna +from lightning.pytorch.callbacks import Callback + +from theseus.base.utilities.loggers.observer import LoggerObserver + +LOGGER = LoggerObserver.getLogger("main") + + +class OptunaCallback(Callback): + """ + Callbacks for reporting value to optuna trials to decide whether to prune + """ + + def __init__(self, trial: optuna.Trial, **kwargs) -> None: + super().__init__() + self.trial = trial + + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + """ + After finish validation + """ + + iters = trainer.global_step + metric_dict = pl_module.metric_dict + + best_key = self.trial.user_attrs["best_key"] + self.trial.report(value=metric_dict[best_key], step=iters) + + if self.trial.should_prune(): + LOGGER.text( + f"Trial {self.trial.number} has been pruned", level=LoggerObserver.DEBUG + ) + raise optuna.TrialPruned() diff --git a/theseus/base/callbacks/timer_callbacks.py b/theseus/base/callbacks/timer_callback.py similarity index 78% rename from theseus/base/callbacks/timer_callbacks.py rename to theseus/base/callbacks/timer_callback.py index a719f6b..796d812 100644 --- a/theseus/base/callbacks/timer_callbacks.py +++ b/theseus/base/callbacks/timer_callback.py @@ -1,7 +1,8 @@ import time -from typing import Dict, List -from theseus.base.callbacks.base_callbacks import Callbacks +import lightning.pytorch as pl +from lightning.pytorch.callbacks import Callback + from theseus.base.utilities.loggers.observer import LoggerObserver LOGGER = LoggerObserver.getLogger("main") @@ -14,7 +15,7 @@ def seconds_to_hours(seconds): return h, m, s -class TimerCallbacks(Callbacks): +class TimerCallback(Callback): """ Callbacks for logging running loss/metric/time while training Features: @@ -26,7 +27,7 @@ def __init__(self, **kwargs) -> None: self.running_time = 0 self.start_time = 0 - def on_start(self, logs: Dict = None): + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """ Before going to the main loop """ @@ -36,7 +37,7 @@ def on_start(self, logs: Dict = None): level=LoggerObserver.INFO, ) - def on_finish(self, logs: Dict = None): + def on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """ After the main loop """ @@ -49,13 +50,13 @@ def on_finish(self, logs: Dict = None): level=LoggerObserver.INFO, ) - def on_train_epoch_start(self, logs: Dict = None): + def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ Before going to the training loop """ self.train_epoch_start_time = time.time() - def on_train_epoch_end(self, logs: Dict = None): + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ After going to the training loop """ @@ -66,7 +67,7 @@ def on_train_epoch_end(self, logs: Dict = None): level=LoggerObserver.INFO, ) - def on_val_epoch_start(self, logs: Dict = None): + def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ Before main validation loops """ @@ -76,7 +77,7 @@ def on_val_epoch_start(self, logs: Dict = None): LoggerObserver.INFO, ) - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ After finish validation """ diff --git a/theseus/base/callbacks/tsb_callbacks.py b/theseus/base/callbacks/tsb_callback.py similarity index 91% rename from theseus/base/callbacks/tsb_callbacks.py rename to theseus/base/callbacks/tsb_callback.py index 443d39b..49b07d1 100644 --- a/theseus/base/callbacks/tsb_callbacks.py +++ b/theseus/base/callbacks/tsb_callback.py @@ -1,6 +1,7 @@ import os -from theseus.base.callbacks.base_callbacks import Callbacks +from lightning.pytorch.callbacks import Callback + from theseus.base.utilities.loading import find_old_tflog from theseus.base.utilities.loggers.observer import LoggerObserver from theseus.base.utilities.loggers.tsb_logger import TensorboardLogger @@ -8,7 +9,7 @@ LOGGER = LoggerObserver.getLogger("main") -class TensorboardCallbacks(Callbacks): +class TensorboardCallback(Callback): """ Callbacks for logging running loss/metric/time while training to tensorboard Features: diff --git a/theseus/base/callbacks/wandb_callbacks.py b/theseus/base/callbacks/wandb_callback.py similarity index 86% rename from theseus/base/callbacks/wandb_callbacks.py rename to theseus/base/callbacks/wandb_callback.py index 01dfc3c..699deb2 100644 --- a/theseus/base/callbacks/wandb_callbacks.py +++ b/theseus/base/callbacks/wandb_callback.py @@ -4,12 +4,13 @@ from datetime import datetime from typing import Dict +import lightning.pytorch as pl from deepdiff import DeepDiff +from lightning.pytorch.callbacks import Callback +from omegaconf import DictConfig, OmegaConf -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver from theseus.base.utilities.loggers.wandb_logger import WandbLogger, find_run_id -from theseus.opt import Config try: import wandb as wandblogger @@ -39,7 +40,7 @@ def pretty_print_diff(diff): return "\n".join(texts) -class WandbCallbacks(Callbacks): +class WandbCallback(Callback): """ Callbacks for logging running loss/metric/time while training to wandb server Features: @@ -60,7 +61,7 @@ def __init__( group_name: str = None, save_dir: str = None, resume: str = None, - config_dict: Dict = None, + config_dict: DictConfig = None, **kwargs, ) -> None: super().__init__() @@ -97,10 +98,11 @@ def __init__( ) # Check if the config remains the same, if not, create new run id - old_config_dict = Config(old_config_path) + old_config_dict = OmegaConf.load(old_config_path) tmp_config_dict = deepcopy(self.config_dict) ## strip off global key because `resume` will always different old_config_dict.pop("global", None) + OmegaConf.set_struct(tmp_config_dict, False) tmp_config_dict.pop("global", None) if old_config_dict == tmp_config_dict: self.id = run_id @@ -120,6 +122,7 @@ def __init__( """Run configuration changes since the last run. Decide: (1) Terminate run (2) Create new run + (3) Override run (not recommended) """, LoggerObserver.WARN, ) @@ -132,9 +135,15 @@ def __init__( LoggerObserver.WARN, ) self.id = wandblogger.util.generate_id() - else: + elif answer == 1: LOGGER.text("Terminating run...", level=LoggerObserver.ERROR) raise InterruptedError() + else: + LOGGER.text( + "Overriding run...", + LoggerObserver.WARN, + ) + self.id = run_id except ValueError as e: LOGGER.text( @@ -158,7 +167,9 @@ def __init__( ) LOGGER.subscribe(self.wandb_logger) - def on_start(self, logs: Dict = None): + def setup( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str + ) -> None: """ Before going to the main loop. Save run id """ @@ -173,7 +184,7 @@ def on_start(self, logs: Dict = None): value=osp.join(self.save_dir, "*.yaml"), ) - def on_finish(self, logs: Dict = None): + def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str): """ After finish training """ @@ -181,10 +192,10 @@ def on_finish(self, logs: Dict = None): self.wandb_logger.log_file( tag="checkpoint", base_folder=self.save_dir, - value=osp.join(base_folder, "*.pth"), + value=osp.join(base_folder, "*.ckpt"), ) - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ On validation batch (iteration) end """ @@ -192,5 +203,5 @@ def on_val_epoch_end(self, logs: Dict = None): self.wandb_logger.log_file( tag="checkpoint", base_folder=self.save_dir, - value=osp.join(base_folder, "*.pth"), + value=osp.join(base_folder, "*.ckpt"), ) diff --git a/theseus/base/datasets/__init__.py b/theseus/base/datasets/__init__.py index 6036d2d..792e42f 100644 --- a/theseus/base/datasets/__init__.py +++ b/theseus/base/datasets/__init__.py @@ -18,3 +18,5 @@ DATALOADER_REGISTRY.register(BalanceSampler) DATALOADER_REGISTRY.register(ChainCollatorWrapper) DATALOADER_REGISTRY.register(DataLoaderWithCollator) + +from .wrapper import LightningDataModuleWrapper diff --git a/theseus/base/datasets/sampler.py b/theseus/base/datasets/sampler.py index 8a8a603..0f1fd66 100644 --- a/theseus/base/datasets/sampler.py +++ b/theseus/base/datasets/sampler.py @@ -24,7 +24,7 @@ def __init__(self, dataset: torch.utils.data.Dataset, **kwargs): class_weighting = 1.0 / class_count sample_weights = np.array([class_weighting[t] for t in labels.squeeze()]) sample_weights = torch.from_numpy(sample_weights) - super().__init__(sample_weights, len(sample_weights)) + super().__init__(sample_weights, len(sample_weights), replacement=True) def _load_labels(self, dataset): op = getattr(dataset, "_calculate_classes_dist", None) diff --git a/theseus/base/datasets/wrapper.py b/theseus/base/datasets/wrapper.py new file mode 100644 index 0000000..a9b2de1 --- /dev/null +++ b/theseus/base/datasets/wrapper.py @@ -0,0 +1,24 @@ +import lightning as L +import torch + + +class LightningDataModuleWrapper(L.LightningDataModule): + def __init__( + self, + trainloader: torch.utils.data.DataLoader, + valloader: torch.utils.data.DataLoader, + testloader: torch.utils.data.DataLoader = None, + ): + super().__init__() + self.trainloader = trainloader + self.valloader = valloader + self.testloader = testloader + + def train_dataloader(self): + return self.trainloader + + def val_dataloader(self): + return self.valloader + + def test_dataloader(self): + return self.testloader diff --git a/theseus/base/losses/__init__.py b/theseus/base/losses/__init__.py index f670a14..48731b2 100644 --- a/theseus/base/losses/__init__.py +++ b/theseus/base/losses/__init__.py @@ -2,6 +2,7 @@ LOSS_REGISTRY = Registry("LOSS") +from .bce_loss import BCELoss from .ce_loss import * from .focal_loss import FocalLoss from .mse_loss import MeanSquaredErrorLoss @@ -12,3 +13,4 @@ LOSS_REGISTRY.register(FocalLoss) LOSS_REGISTRY.register(MeanSquaredErrorLoss) LOSS_REGISTRY.register(ClassificationSmoothCELoss) +LOSS_REGISTRY.register(BCELoss) diff --git a/theseus/base/losses/bce_loss.py b/theseus/base/losses/bce_loss.py new file mode 100644 index 0000000..c9c22b4 --- /dev/null +++ b/theseus/base/losses/bce_loss.py @@ -0,0 +1,40 @@ +from typing import Any, Dict + +import torch +from torch import nn + +from theseus.base.utilities.cuda import move_to + + +class BCELoss(nn.Module): + r"""CELoss is warper of cross-entropy loss""" + + def __init__(self, **kwargs): + super(BCELoss, self).__init__() + if "weight" in kwargs: + weight = torch.FloatTensor(kwargs.get("weight")) + else: + weight = None + self.criterion = nn.BCELoss( + weight=weight, + ) + + def forward( + self, + outputs: Dict[str, Any], + batch: Dict[str, Any], + device: torch.device = None, + ): + pred = outputs["outputs"] + if device is not None: + target = move_to(batch["targets"], device) + else: + target = batch["targets"].float() + + if pred.shape == target.shape: + loss = self.criterion(pred, target) + else: + loss = self.criterion(pred, target.view(-1).contiguous()) + + loss_dict = {"BCE": loss.item()} + return loss, loss_dict diff --git a/theseus/base/losses/ce_loss.py b/theseus/base/losses/ce_loss.py index a19b7e0..61b52cf 100644 --- a/theseus/base/losses/ce_loss.py +++ b/theseus/base/losses/ce_loss.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import * import torch from torch import nn @@ -11,18 +11,27 @@ class ClassificationCELoss(nn.Module): r"""CELoss is warper of cross-entropy loss""" - def __init__(self, **kwargs): + def __init__(self, weight: List = None, **kwargs): super(ClassificationCELoss, self).__init__() - self.criterion = nn.CrossEntropyLoss() + if weight is not None: + weight = torch.tensor(weight) + self.criterion = nn.CrossEntropyLoss( + weight=weight, + ignore_index=kwargs.get("ignore_index", -100), + label_smoothing=kwargs.get("label_smoothing", 0.0), + ) def forward( self, outputs: Dict[str, Any], batch: Dict[str, Any], - device: torch.device, + device: torch.device = None, ): pred = outputs["outputs"] - target = move_to(batch["targets"], device) + if device is not None: + target = move_to(batch["targets"], device) + else: + target = batch["targets"] if pred.shape == target.shape: loss = self.criterion(pred, target) @@ -45,7 +54,7 @@ def forward( self, outputs: Dict[str, Any], batch: Dict[str, Any], - device: torch.device, + device: torch.device = None, ): pred = outputs["outputs"] target = batch["targets"] diff --git a/theseus/base/losses/focal_loss.py b/theseus/base/losses/focal_loss.py index 90c8fd0..84df8bc 100644 --- a/theseus/base/losses/focal_loss.py +++ b/theseus/base/losses/focal_loss.py @@ -20,10 +20,13 @@ def forward( self, outputs: Dict[str, Any], batch: Dict[str, Any], - device: torch.device, + device: torch.device = None, ): outputs = outputs["outputs"] - targets = move_to(batch["targets"], device) + if device is not None: + targets = move_to(batch["targets"], device) + else: + targets = batch["targets"] num_classes = outputs.shape[-1] # Need to be one hot encoding diff --git a/theseus/base/losses/mse_loss.py b/theseus/base/losses/mse_loss.py index 874b731..1144532 100644 --- a/theseus/base/losses/mse_loss.py +++ b/theseus/base/losses/mse_loss.py @@ -17,10 +17,13 @@ def forward( self, outputs: Dict[str, Any], batch: Dict[str, Any], - device: torch.device, + device: torch.device = None, ): pred = outputs["outputs"] - target = move_to(batch["targets"], device) + if device is not None: + target = move_to(batch["targets"], device) + else: + target = batch["targets"] if pred.shape == target.shape: loss = self.criterion(pred, target) diff --git a/theseus/base/losses/multi_loss.py b/theseus/base/losses/multi_loss.py index b262a42..5da01a9 100644 --- a/theseus/base/losses/multi_loss.py +++ b/theseus/base/losses/multi_loss.py @@ -16,7 +16,7 @@ def forward( self, outputs: Dict[str, Any], batch: Dict[str, Any], - device: torch.device, + device: torch.device = None, ): """ Forward inputs and targets through multiple losses diff --git a/theseus/base/losses/smoothing.py b/theseus/base/losses/smoothing.py index cc90a72..133072a 100644 --- a/theseus/base/losses/smoothing.py +++ b/theseus/base/losses/smoothing.py @@ -24,11 +24,14 @@ def forward( self, outputs: Dict[str, Any], batch: Dict[str, Any], - device: torch.device, + device: torch.device = None, ): pred = outputs["outputs"] - target = move_to(batch["targets"], device) + if device is not None: + target = move_to(batch["targets"], device) + else: + target = batch["targets"] logprobs = F.log_softmax(pred, dim=-1) nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) @@ -49,11 +52,14 @@ def forward( self, outputs: Dict[str, Any], batch: Dict[str, Any], - device: torch.device, + device: torch.device = None, ): pred = outputs["outputs"] - target = move_to(batch["targets"], device) + if device is not None: + target = move_to(batch["targets"], device) + else: + target = batch["targets"] loss = torch.sum(-target * F.log_softmax(pred, dim=-1), dim=-1) loss_dict = {"SoftCE": loss.item()} diff --git a/theseus/base/metrics/__init__.py b/theseus/base/metrics/__init__.py index 589ef4a..7a2d263 100644 --- a/theseus/base/metrics/__init__.py +++ b/theseus/base/metrics/__init__.py @@ -8,10 +8,14 @@ from .bl_accuracy import * from .confusion_matrix import * from .f1 import * +from .mcc import * from .precision_recall import * +from .roc_auc_score import * METRIC_REGISTRY.register(Accuracy) METRIC_REGISTRY.register(BalancedAccuracyMetric) METRIC_REGISTRY.register(F1ScoreMetric) METRIC_REGISTRY.register(ConfusionMatrix) METRIC_REGISTRY.register(PrecisionRecall) +METRIC_REGISTRY.register(ROCAUCScore) +METRIC_REGISTRY.register(MCC) diff --git a/theseus/base/metrics/accuracy.py b/theseus/base/metrics/accuracy.py index 2a8d7a7..89830a8 100644 --- a/theseus/base/metrics/accuracy.py +++ b/theseus/base/metrics/accuracy.py @@ -15,15 +15,14 @@ def __init__(self, label_type: str = "multiclass", **kwargs): self.threshold = kwargs.get("threshold", 0.5) self.reset() - def update(self, output: Dict[str, Any], batch: Dict[str, Any]): + def update(self, outputs: Dict[str, Any], batch: Dict[str, Any]): """ Perform calculation based on prediction and targets """ - output = output["outputs"] - target = batch["targets"] - + outputs = outputs["outputs"].detach().cpu() + target = batch["targets"].cpu() prediction = logits2labels( - output, label_type=self.type, threshold=self.threshold + outputs, label_type=self.type, threshold=self.threshold ) correct = (prediction.view(-1) == target.view(-1)).sum() diff --git a/theseus/base/metrics/bl_accuracy.py b/theseus/base/metrics/bl_accuracy.py index ad3aa07..fc56c61 100644 --- a/theseus/base/metrics/bl_accuracy.py +++ b/theseus/base/metrics/bl_accuracy.py @@ -1,9 +1,10 @@ from typing import Any, Dict import numpy as np -import torch +from sklearn.metrics import balanced_accuracy_score from theseus.base.metrics.metric_template import Metric +from theseus.base.utilities.logits import logits2labels def compute_multiclass(outputs, targets, index): @@ -22,8 +23,10 @@ class BalancedAccuracyMetric(Metric): Balanced Accuracy metric for classification """ - def __init__(self, **kwargs): + def __init__(self, label_type: str = "multiclass", **kwargs): super().__init__(**kwargs) + self.type = label_type + self.threshold = kwargs.get("threshold", 0.5) self.reset() def update(self, outputs: Dict[str, Any], batch: Dict[str, Any]): @@ -32,7 +35,8 @@ def update(self, outputs: Dict[str, Any], batch: Dict[str, Any]): """ outputs = outputs["outputs"] targets = batch["targets"] - outputs = torch.argmax(outputs, dim=1) + outputs = logits2labels(outputs, label_type=self.type, threshold=self.threshold) + outputs = outputs.detach().cpu() targets = targets.detach().cpu().view(-1) @@ -51,19 +55,21 @@ def value(self): self.corrects = {str(k): 0 for k in self.unique_ids} self.total = {str(k): 0 for k in self.unique_ids} - - # Calculate accuracy for each class index - for i in self.unique_ids: - correct, sample_size = compute_multiclass(self.outputs, self.targets, i) - self.corrects[str(i)] += correct - self.total[str(i)] += sample_size - each_acc = [ - self.corrects[str(i)] * 1.0 / (self.total[str(i)]) - for i in self.unique_ids - if self.total[str(i)] > 0 - ] - - # Get mean accuracy across classes - values = sum(each_acc) / len(self.unique_ids) + if self.type == "binary": + values = balanced_accuracy_score(self.targets, self.outputs) + else: + # Calculate accuracy for each class index + for i in self.unique_ids: + correct, sample_size = compute_multiclass(self.outputs, self.targets, i) + self.corrects[str(i)] += correct + self.total[str(i)] += sample_size + each_acc = [ + self.corrects[str(i)] * 1.0 / (self.total[str(i)]) + for i in self.unique_ids + if self.total[str(i)] > 0 + ] + + # Get mean accuracy across classes + values = sum(each_acc) / len(self.unique_ids) return {"bl_acc": values} diff --git a/theseus/base/metrics/confusion_matrix.py b/theseus/base/metrics/confusion_matrix.py index 4d24b2f..c3cb04a 100644 --- a/theseus/base/metrics/confusion_matrix.py +++ b/theseus/base/metrics/confusion_matrix.py @@ -6,8 +6,11 @@ from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix from theseus.base.metrics.metric_template import Metric +from theseus.base.utilities.loggers.observer import LoggerObserver from theseus.base.utilities.logits import logits2labels +LOGGER = LoggerObserver.getLogger("main") + def plot_cfm(cm, ax, labels: List): """ @@ -78,8 +81,8 @@ def update(self, outputs: Dict[str, Any], batch: Dict[str, Any]): Perform calculation based on prediction and targets """ # in torchvision models, pred is a dict[key=out, value=Tensor] - outputs = outputs["outputs"] - targets = batch["targets"] + targets = batch["targets"].cpu() + outputs = outputs["outputs"].detach().cpu() outputs = logits2labels(outputs, label_type=self.type, threshold=self.threshold) @@ -91,18 +94,26 @@ def reset(self): self.targets = [] def value(self): - if self.type == "multiclass": - values = confusion_matrix( - self.outputs, - self.targets, - labels=self.num_classes, - normalize="pred", - ) - values = values[np.newaxis, :, :] - else: - values = multilabel_confusion_matrix( - self.outputs, self.targets, labels=self.num_classes + try: + if self.type == "multiclass": + values = confusion_matrix( + self.outputs, + self.targets, + labels=self.num_classes, + normalize="pred", + ) + values = values[np.newaxis, :, :] + + else: + values = multilabel_confusion_matrix( + self.outputs, self.targets, labels=self.num_classes + ) + fig = make_cm_fig(values, self.classnames) + except ValueError as e: + LOGGER.text( + f"Confusion Matrix could not be calculated: {e}", + level=LoggerObserver.WARN, ) + fig = 0 - fig = make_cm_fig(values, self.classnames) return {"cfm": fig} diff --git a/theseus/base/metrics/f1.py b/theseus/base/metrics/f1.py index 5bffd5f..9ca85af 100644 --- a/theseus/base/metrics/f1.py +++ b/theseus/base/metrics/f1.py @@ -22,8 +22,8 @@ def update(self, outputs: Dict[str, Any], batch: Dict[str, Any]): """ Perform calculation based on prediction and targets """ - targets = batch["targets"] - outputs = outputs["outputs"] + targets = batch["targets"].cpu().view(-1) + outputs = outputs["outputs"].detach().cpu() outputs = logits2labels(outputs, label_type=self.type, threshold=self.threshold) diff --git a/theseus/base/metrics/mcc.py b/theseus/base/metrics/mcc.py new file mode 100644 index 0000000..dc1cb43 --- /dev/null +++ b/theseus/base/metrics/mcc.py @@ -0,0 +1,38 @@ +from typing import Any, Dict + +from sklearn.metrics import matthews_corrcoef + +from theseus.base.metrics.metric_template import Metric +from theseus.base.utilities.logits import logits2labels + + +class MCC(Metric): + """ + Mathew Correlation Coefficient + """ + + def __init__(self, label_type: str = "multiclass", **kwargs): + super().__init__(**kwargs) + self.type = label_type + self.reset() + + def update(self, outputs: Dict[str, Any], batch: Dict[str, Any]): + """ + Perform calculation based on prediction and targets + """ + targets = batch["targets"].cpu() + outputs = outputs["outputs"].detach().cpu() + outputs = logits2labels(outputs, label_type=self.type) + + self.preds += outputs.numpy().tolist() + self.targets += targets.numpy().tolist() + + def value(self): + score = matthews_corrcoef(self.targets, self.preds) + return { + f"mcc": score, + } + + def reset(self): + self.targets = [] + self.preds = [] diff --git a/theseus/base/metrics/precision_recall.py b/theseus/base/metrics/precision_recall.py index 2852e2f..92234bf 100644 --- a/theseus/base/metrics/precision_recall.py +++ b/theseus/base/metrics/precision_recall.py @@ -22,8 +22,8 @@ def update(self, outputs: Dict[str, Any], batch: Dict[str, Any]): """ Perform calculation based on prediction and targets """ - targets = batch["targets"] - outputs = outputs["outputs"] + targets = batch["targets"].cpu() + outputs = outputs["outputs"].detach().cpu() outputs = logits2labels(outputs, label_type=self.type, threshold=self.threshold) self.preds += outputs.numpy().tolist() diff --git a/theseus/base/metrics/roc_auc_score.py b/theseus/base/metrics/roc_auc_score.py new file mode 100644 index 0000000..ef3bca3 --- /dev/null +++ b/theseus/base/metrics/roc_auc_score.py @@ -0,0 +1,93 @@ +from typing import Any, Dict + +import torch + +try: + from scikitplot.metrics import plot_precision_recall, plot_roc + + has_scikitplot = True +except: + has_scikitplot = False +from sklearn.metrics import roc_auc_score + +from theseus.base.metrics.metric_template import Metric +from theseus.base.utilities.cuda import detach, move_to +from theseus.base.utilities.loggers.observer import LoggerObserver +from theseus.base.utilities.logits import logits2labels + +LOGGER = LoggerObserver.getLogger("main") + + +class ROCAUCScore(Metric): + """ + Area Under Curve, ROC Curve Score + """ + + def __init__( + self, + average: str = "weighted", + label_type: str = "multiclass", + plot_curve: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.type = label_type + self.average = average + self.plot_curve = plot_curve + + if self.type == "multiclass": + self.label_type = "ovr" + elif self.type == "multilabel": + self.label_type = "ovr" + else: + self.label_type = "raise" + + self.reset() + + def update(self, outputs: Dict[str, Any], batch: Dict[str, Any]): + """ + Perform calculation based on prediction and targets + """ + targets = batch["targets"].cpu() + outputs = move_to(outputs["outputs"], torch.device("cpu")) + + if self.type == "multiclass": + probs = torch.softmax(outputs, dim=1) + self.preds.extend(probs.numpy().tolist()) + else: + _, probs = logits2labels(outputs, label_type=self.type, return_probs=True) + self.preds += probs.numpy().tolist() + self.targets += targets.view(-1).numpy().tolist() + + def value(self): + try: + roc_auc_scr = roc_auc_score( + self.targets, + self.preds, + average=self.average, + multi_class=self.label_type, + ) + except Exception as e: + LOGGER.text( + f"AUC score could not be calculated: {e}", level=LoggerObserver.WARN + ) + roc_auc_scr = 0 + + results = { + f"{self.average}-roc_auc_score": roc_auc_scr, + } + if has_scikitplot and self.plot_curve: + roc_curve_fig = plot_roc(self.targets, self.preds).get_figure() + pr_fig = plot_precision_recall(self.targets, self.preds).get_figure() + results.update( + { + "roc_curve": roc_curve_fig, + "precision_recall_curve": pr_fig, + } + ) + + return results + + def reset(self): + self.targets = [] + self.preds = [] diff --git a/theseus/base/models/__init__.py b/theseus/base/models/__init__.py index 1ff5b97..4c6d050 100644 --- a/theseus/base/models/__init__.py +++ b/theseus/base/models/__init__.py @@ -1,3 +1,5 @@ from theseus.registry import Registry +from .wrapper import LightningModelWrapper + MODEL_REGISTRY = Registry("MODEL") diff --git a/theseus/base/models/wrapper.py b/theseus/base/models/wrapper.py index 4dafa24..93b39a1 100644 --- a/theseus/base/models/wrapper.py +++ b/theseus/base/models/wrapper.py @@ -1,29 +1,77 @@ -import torch -from torch import nn +from typing import Any, Callable, Dict, List, Mapping, Optional, Union +import lightning.pytorch as pl +import torch +import torch.nn as nn +from lightning.pytorch.utilities.types import _METRIC, STEP_OUTPUT -class ModelWithLoss(nn.Module): - """Add utilitarian functions for module to work with pipeline +from theseus.base.datasets import LightningDataModuleWrapper +from theseus.base.optimizers import OPTIM_REGISTRY, SCHEDULER_REGISTRY +from theseus.base.utilities.getter import get_instance - Args: - model (Module): Base Model without loss - loss (Module): Base loss function with stat +class LightningModelWrapper(pl.LightningModule): + """ + Wrapper for Lightning Module + Instansiates the model, criterion, optimizer and scheduler """ - def __init__(self, model: nn.Module, criterion: nn.Module, device: torch.device): + def __init__( + self, + model: nn.Module, + criterion: nn.Module = None, + metrics: List[Any] = None, + optimizer_config: Dict = None, + scheduler_config: Dict = None, + scheduler_kwargs: Dict = None, + datamodule: LightningDataModuleWrapper = None, + ): super().__init__() self.model = model self.criterion = criterion - self.device = device + self.metrics = metrics + self.optimizer_config = optimizer_config + self.scheduler_config = scheduler_config + self.scheduler_kwargs = scheduler_kwargs + self.datamodule = datamodule + self.lr = 0 + self.metric_dict = {} + + def log_dict(self, dictionary: Mapping[str, Any], **kwargs) -> None: + filtered_dict = { + key: value + for key, value in dictionary.items() + if isinstance(value, (torch.Tensor, float, int)) + } + return super().log_dict(filtered_dict, **kwargs) - def forward_batch(self, batch, metrics=None): + def on_train_batch_end( + self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + lrl = [x["lr"] for x in self.optimizer.param_groups] + self.lr = sum(lrl) / len(lrl) + + def on_validation_epoch_end(self) -> None: + self.metric_dict = {} + if self.metrics is not None: + for metric in self.metrics: + self.metric_dict.update(metric.value()) + metric.reset() + + self.log_dict( + self.metric_dict, + prog_bar=True, + batch_size=self.datamodule.valloader.batch_size, + ) + + def _forward(self, batch: Dict, metrics: List[Any] = None): """ Forward the batch through models, losses and metrics If some parameters are needed, it's best to include in the batch """ - outputs = self.model.forward_batch(batch, self.device) - loss, loss_dict = self.criterion(outputs, batch, self.device) + + outputs = self.model.forward_batch(batch) + loss, loss_dict = self.criterion(outputs, batch) if metrics is not None: for metric in metrics: @@ -31,14 +79,46 @@ def forward_batch(self, batch, metrics=None): return {"loss": loss, "loss_dict": loss_dict, "model_outputs": outputs} - def training_step(self, batch): - return self.forward_batch(batch) + def trainable_parameters(self): + return sum(p.numel() for p in self.parameters() if p.requires_grad) - def evaluate_step(self, batch, metrics=None): - return self.forward_batch(batch, metrics) + def training_step(self, batch, batch_idx): + # training_step defines the train loop. + outputs = self._forward(batch) + self.log_dict(outputs["loss_dict"], prog_bar=True, on_step=True, on_epoch=False) + return outputs - def state_dict(self): - return self.model.state_dict() + def validation_step(self, batch, batch_idx): + # this is the validation loop + outputs = self._forward(batch, metrics=self.metrics) + self.log_dict(outputs["loss_dict"], prog_bar=True, on_step=True, on_epoch=False) + return outputs - def trainable_parameters(self): - return sum(p.numel() for p in self.parameters() if p.requires_grad) + def predict_step(self, batch, batch_idx=None): + pred = self.model.get_prediction(batch) + return pred + + def configure_optimizers(self): + if self.optimizer_config is not None: + self.optimizer = get_instance( + self.optimizer_config, + registry=OPTIM_REGISTRY, + params=self.model.parameters(), + ) + + if self.scheduler_config is not None: + self.scheduler = get_instance( + self.scheduler_config, + registry=SCHEDULER_REGISTRY, + optimizer=self.optimizer, + **self.scheduler_kwargs, + ) + else: + return self.optimizer + + scheduler_interval = "epoch" if self.scheduler.step_per_epoch else "step" + scheduler = { + "scheduler": self.scheduler.scheduler, + "interval": scheduler_interval, + } + return [self.optimizer], [scheduler] diff --git a/theseus/base/optimizers/schedulers/cosine.py b/theseus/base/optimizers/schedulers/cosine.py index 50c3e82..47542ff 100644 --- a/theseus/base/optimizers/schedulers/cosine.py +++ b/theseus/base/optimizers/schedulers/cosine.py @@ -1,6 +1,5 @@ # code from AllenNLP -import logging from typing import Any, Dict import numpy as np @@ -11,7 +10,7 @@ LOGGER = LoggerObserver.getLogger("main") -class CosineWithRestarts: +class CosineWithRestarts(torch.optim.lr_scheduler._LRScheduler): """ Cosine annealing with restarts. This is described in the paper https://arxiv.org/abs/1608.03983. Note that diff --git a/theseus/base/optimizers/schedulers/wrapper.py b/theseus/base/optimizers/schedulers/wrapper.py index 6c4e389..3f5d23b 100644 --- a/theseus/base/optimizers/schedulers/wrapper.py +++ b/theseus/base/optimizers/schedulers/wrapper.py @@ -51,7 +51,7 @@ def one_cycle(y1=0.0, y2=1.0, steps=100): gamma=kwargs["gamma"], last_epoch=kwargs["last_epoch"], ) - step_per_epoch = False + step_per_epoch = True elif scheduler_name == "plateau": scheduler = ReduceLROnPlateau( @@ -90,6 +90,18 @@ def one_cycle(y1=0.0, y2=1.0, steps=100): ) step_per_epoch = True + elif scheduler_name == "tf_cosinewarmup": + from transformers import get_cosine_schedule_with_warmup + + scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=kwargs["num_warmup_steps"], + num_training_steps=kwargs["num_iterations"], + num_cycles=kwargs.get("num_cycles", 0.5), + last_epoch=kwargs["last_epoch"], + ) + step_per_epoch = False + self.scheduler = scheduler self.step_per_epoch = step_per_epoch diff --git a/theseus/base/pipeline.py b/theseus/base/pipeline.py index e4d9cd3..2aba725 100644 --- a/theseus/base/pipeline.py +++ b/theseus/base/pipeline.py @@ -2,29 +2,29 @@ from datetime import datetime import torch +from omegaconf import DictConfig, OmegaConf from theseus.base.augmentations import TRANSFORM_REGISTRY from theseus.base.callbacks import CALLBACKS_REGISTRY -from theseus.base.datasets import DATALOADER_REGISTRY, DATASET_REGISTRY +from theseus.base.datasets import ( + DATALOADER_REGISTRY, + DATASET_REGISTRY, + LightningDataModuleWrapper, +) from theseus.base.losses import LOSS_REGISTRY from theseus.base.metrics import METRIC_REGISTRY -from theseus.base.models import MODEL_REGISTRY -from theseus.base.models.wrapper import ModelWithLoss -from theseus.base.optimizers import OPTIM_REGISTRY, SCHEDULER_REGISTRY +from theseus.base.models import MODEL_REGISTRY, LightningModelWrapper from theseus.base.trainer import TRAINER_REGISTRY -from theseus.base.utilities.cuda import get_device, get_devices_info, move_to from theseus.base.utilities.folder import get_new_folder_name from theseus.base.utilities.getter import get_instance, get_instance_recursively -from theseus.base.utilities.loading import load_state_dict from theseus.base.utilities.loggers import FileLogger, ImageWriter, LoggerObserver from theseus.base.utilities.seed import seed_everything -from theseus.opt import Config class BasePipeline(object): """docstring for BasePipeline.""" - def __init__(self, opt: Config): + def __init__(self, opt: DictConfig): super(BasePipeline, self).__init__() self.opt = opt self.seed = self.opt["global"].get("seed", 1702) @@ -40,7 +40,6 @@ def init_globals(self): self.exp_name = self.opt["global"].get("exp_name", None) self.exist_ok = self.opt["global"].get("exist_ok", False) self.debug = self.opt["global"].get("debug", False) - self.device_name = self.opt["global"].get("device", "cpu") self.resume = self.opt["global"].get("resume", None) self.pretrained = self.opt["global"].get("pretrained", None) self.transform_cfg = self.opt["global"].get("cfg_transform", None) @@ -67,20 +66,10 @@ def init_globals(self): image_logger = ImageWriter(self.savedir) self.logger.subscribe(image_logger) - if self.transform_cfg is not None: - self.logger.text( - "cfg_transform is deprecated, please use 'includes' instead", - level=LoggerObserver.WARN, - ) - self.transform_cfg = Config.load_yaml(self.transform_cfg) - self.opt["augmentations"] = self.transform_cfg - else: - self.transform_cfg = self.opt.get("augmentations", None) - - self.device = get_device(self.device_name) + self.transform_cfg = self.opt.get("augmentations", None) # Logging out configs - self.logger.text(self.opt, level=LoggerObserver.INFO) + self.logger.text("\n" + OmegaConf.to_yaml(self.opt), level=LoggerObserver.INFO) self.logger.text( f"Everything will be saved to {self.savedir}", level=LoggerObserver.INFO, @@ -144,6 +133,13 @@ def init_validation_dataloader(self): level=LoggerObserver.INFO, ) + def init_datamodule(self): + self.datamodule = LightningDataModuleWrapper( + trainloader=getattr(self, "train_dataloader", None), + valloader=getattr(self, "val_dataloader", None), + testloader=getattr(self, "test_dataloader", None), + ) + def init_model(self): CLASSNAMES = getattr(self.val_dataset, "classnames", None) model = get_instance( @@ -152,7 +148,6 @@ def init_model(self): num_classes=len(CLASSNAMES) if CLASSNAMES is not None else None, classnames=CLASSNAMES, ) - model = move_to(model, self.device) return model def init_criterion(self): @@ -163,19 +158,45 @@ def init_criterion(self): num_classes=len(CLASSNAMES) if CLASSNAMES is not None else None, classnames=CLASSNAMES, ) - self.criterion = move_to(self.criterion, self.device) return self.criterion - def init_model_with_loss(self): - model = self.init_model() + def init_model_with_loss(self, is_train=True): + self.model = self.init_model() criterion = self.init_criterion() - self.model = ModelWithLoss(model, criterion, self.device) - self.logger.text( - f"Number of trainable parameters: {self.model.trainable_parameters():,}", - level=LoggerObserver.INFO, + num_epochs = self.opt["trainer"]["args"]["max_epochs"] + batch_size = self.opt["data"]["dataloader"]["val"]["args"]["batch_size"] + + self.model = LightningModelWrapper( + self.model, + criterion, + datamodule=getattr(self, "datamodule", None), + metrics=getattr(self, "metrics", None), + optimizer_config=self.opt["optimizer"] if is_train else None, + scheduler_config=self.opt["scheduler"] if is_train else None, + scheduler_kwargs={ + "num_epochs": num_epochs, + "num_iterations": num_epochs * len(self.train_dataloader), + "batch_size": batch_size, + "last_epoch": getattr(self, "last_epoch", -1), + } + if is_train + else None, ) - device_info = get_devices_info(self.device_name) - self.logger.text("Using " + device_info, level=LoggerObserver.INFO) + + pretrained = self.opt["global"].get("pretrained", None) + if pretrained: + state_dict = torch.load(pretrained, map_location="cpu") + try: + self.model.load_state_dict(state_dict["state_dict"], strict=False) + self.logger.text( + f"Loaded pretrained model from {pretrained}", + level=LoggerObserver.SUCCESS, + ) + except Exception as e: + self.logger.text( + f"Loaded pretrained model from {pretrained}. Mismatched keys: {e}", + level=LoggerObserver.WARN, + ) def init_metrics(self): CLASSNAMES = getattr(self.val_dataset, "classnames", None) @@ -186,51 +207,6 @@ def init_metrics(self): classnames=CLASSNAMES, ) - def init_optimizer(self): - self.optimizer = get_instance( - self.opt["optimizer"], - registry=self.optimizer_registry, - params=self.model.parameters(), - ) - - def init_loading(self): - self.last_epoch = -1 - if getattr(self, "pretrained", None): - state_dict = torch.load(self.pretrained, map_location="cpu") - self.model.model = load_state_dict(self.model.model, state_dict, "model") - - if getattr(self, "resume", None): - state_dict = torch.load(self.resume, map_location="cpu") - self.model.model = load_state_dict(self.model.model, state_dict, "model") - self.optimizer = load_state_dict(self.optimizer, state_dict, "optimizer") - iters = load_state_dict(None, state_dict, "iters") - self.last_epoch = iters // len(self.train_dataloader) - 1 - - def init_scheduler(self): - if "scheduler" in self.opt.keys() and self.opt["scheduler"] is not None: - self.scheduler = get_instance( - self.opt["scheduler"], - registry=self.scheduler_registry, - optimizer=self.optimizer, - **{ - "num_epochs": self.opt["trainer"]["args"]["num_iterations"] - // len(self.train_dataloader), - "trainset": self.train_dataset, - "batch_size": self.opt["data"]["dataloader"]["val"]["args"][ - "batch_size" - ], - "last_epoch": getattr(self, "last_epoch", -1), - }, - ) - - if getattr(self, "resume", None): - state_dict = torch.load(self.resume) - self.scheduler = load_state_dict( - self.scheduler, state_dict, "scheduler" - ) - else: - self.scheduler = None - def init_callbacks(self): callbacks = get_instance_recursively( self.opt["callbacks"], @@ -244,19 +220,15 @@ def init_callbacks(self): def init_trainer(self, callbacks): self.trainer = get_instance( self.opt["trainer"], - model=self.model, - trainloader=getattr(self, "train_dataloader", None), - valloader=getattr(self, "val_dataloader", None), - metrics=getattr(self, "metrics", None), - optimizer=getattr(self, "optimizer", None), - scheduler=getattr(self, "scheduler", None), - debug=getattr(self, "debug", False), - registry=self.trainer_registry, + default_root_dir=getattr(self, "savedir", "runs"), + deterministic="warn", callbacks=callbacks, + registry=self.trainer_registry, ) def save_configs(self): - self.opt.save_yaml(os.path.join(self.savedir, "pipeline.yaml")) + with open(os.path.join(self.savedir, "pipeline.yaml"), "w") as f: + OmegaConf.save(config=self.opt, f=f) def init_registry(self): self.model_registry = MODEL_REGISTRY @@ -264,8 +236,6 @@ def init_registry(self): self.dataloader_registry = DATALOADER_REGISTRY self.metric_registry = METRIC_REGISTRY self.loss_registry = LOSS_REGISTRY - self.optimizer_registry = OPTIM_REGISTRY - self.scheduler_registry = SCHEDULER_REGISTRY self.callbacks_registry = CALLBACKS_REGISTRY self.trainer_registry = TRAINER_REGISTRY self.transform_registry = TRANSFORM_REGISTRY @@ -282,52 +252,60 @@ def init_pipeline(self, train=False): if train: self.init_train_dataloader() self.init_validation_dataloader() - self.init_model_with_loss() + self.init_datamodule() self.init_metrics() - self.init_optimizer() - self.init_loading() - self.init_scheduler() + self.init_model_with_loss() callbacks = self.init_callbacks() self.save_configs() else: self.init_validation_dataloader() - self.init_model_with_loss() + self.init_datamodule() self.init_metrics() - self.init_loading() + self.init_model_with_loss(is_train=train) callbacks = [] - if getattr(self, "metrics", None): + if getattr(self.model, "metrics", None): callbacks.insert( 0, - self.callbacks_registry.get("MetricLoggerCallbacks")( + self.callbacks_registry.get("MetricLoggerCallback")( save_dir=self.savedir ), ) - if getattr(self, "criterion", None): + if getattr(self.model, "criterion", None): callbacks.insert( 0, - self.callbacks_registry.get("LossLoggerCallbacks")( - print_interval=self.opt["global"].get("print_interval", None), + self.callbacks_registry.get("LossLoggerCallback")( + print_interval=self.opt["trainer"]["args"].get( + "log_every_n_steps", None + ), ), ) - if self.debug: - callbacks.insert(0, self.callbacks_registry.get("DebugCallbacks")()) - callbacks.insert(0, self.callbacks_registry.get("TimerCallbacks")()) + callbacks.insert(0, self.callbacks_registry.get("TimerCallback")()) + self.init_trainer(callbacks) self.initialized = True def fit(self): self.init_pipeline(train=True) - self.trainer.fit() + self.trainer.fit( + model=self.model, + datamodule=self.datamodule, + ckpt_path=self.resume, + ) def evaluate(self): self.init_pipeline(train=False) - self.logger.text("Evaluating...", level=LoggerObserver.INFO) - return self.trainer.evaluate_epoch() + self.trainer.validate( + model=self.model, + datamodule=self.datamodule, + ckpt_path=self.resume, + ) + + return self.trainer.callback_metrics class BaseTestPipeline(object): - def __init__(self, opt: Config): + def __init__(self, opt: DictConfig): super(BaseTestPipeline, self).__init__() self.opt = opt @@ -342,33 +320,23 @@ def init_globals(self): self.exp_name = self.opt["global"].get("exp_name", None) self.exist_ok = self.opt["global"].get("exist_ok", False) self.debug = self.opt["global"].get("debug", False) - self.device_name = self.opt["global"].get("device", "cpu") self.transform_cfg = self.opt["global"].get("cfg_transform", None) - self.device = get_device(self.device_name) # Experiment name if self.exp_name: self.savedir = os.path.join( - self.opt["global"].get("save_dir", "tests"), self.exp_name + self.opt["global"].get("save_dir", "runs"), self.exp_name ) if not self.exist_ok: self.savedir = get_new_folder_name(self.savedir) else: self.savedir = os.path.join( - self.opt["global"].get("save_dir", "tests"), + self.opt["global"].get("save_dir", "runs"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), ) os.makedirs(self.savedir, exist_ok=True) - if self.transform_cfg is not None: - self.logger.text( - "cfg_transform is deprecated, please use 'includes' instead", - level=LoggerObserver.WARN, - ) - self.transform_cfg = Config.load_yaml(self.transform_cfg) - self.opt["augmentations"] = self.transform_cfg - else: - self.transform_cfg = self.opt.get("augmentations", None) + self.transform_cfg = self.opt.get("augmentations", None) # Logging to files file_logger = FileLogger(__name__, self.savedir, debug=self.debug) @@ -422,12 +390,6 @@ def init_test_dataloader(self): level=LoggerObserver.INFO, ) - def init_loading(self): - self.weights = self.opt["global"].get("weights", None) - if self.weights: - state_dict = torch.load(self.weights, map_location="cpu") - self.model = load_state_dict(self.model, state_dict, "model") - def init_model(self): CLASSNAMES = getattr(self.dataset, "classnames", None) self.model = get_instance( @@ -436,9 +398,15 @@ def init_model(self): num_classes=len(CLASSNAMES) if CLASSNAMES is not None else None, classnames=CLASSNAMES, ) - self.model = move_to(self.model, self.device) + self.model = LightningModelWrapper(self.model) self.model.eval() + def init_loading(self): + self.weights = self.opt["global"].get("pretrained", None) + if self.weights: + state_dict = torch.load(self.weights, map_location="cpu") + self.model.load_state_dict(state_dict["state_dict"]) + def init_pipeline(self): self.init_globals() self.init_registry() diff --git a/theseus/base/trainer/__init__.py b/theseus/base/trainer/__init__.py index 37dc665..ef53029 100644 --- a/theseus/base/trainer/__init__.py +++ b/theseus/base/trainer/__init__.py @@ -1,8 +1,7 @@ +from lightning.pytorch.trainer import Trainer + from theseus.registry import Registry -from .base_trainer import BaseTrainer -from .supervised_trainer import SupervisedTrainer +TRAINER_REGISTRY = Registry("trainer") -TRAINER_REGISTRY = Registry("TRAINER") -TRAINER_REGISTRY.register(BaseTrainer) -TRAINER_REGISTRY.register(SupervisedTrainer) +TRAINER_REGISTRY.register(Trainer, prefix="pl") diff --git a/theseus/base/trainer/base_trainer.py b/theseus/base/trainer/base_trainer.py deleted file mode 100644 index 7c344ea..0000000 --- a/theseus/base/trainer/base_trainer.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import List, Optional, Tuple - -from theseus.base.callbacks import CallbacksList, TimerCallbacks -from theseus.base.callbacks.base_callbacks import Callbacks -from theseus.base.optimizers.scalers import NativeScaler -from theseus.base.utilities.loggers.observer import LoggerObserver - -LOGGER = LoggerObserver.getLogger("main") - - -class BaseTrainer: - """Base class for trainer - - use_fp16: `bool` - whether to use 16bit floating-point precision - num_iterations: `int` - total number of running epochs - clip_grad: `float` - Gradient clipping - evaluate_interval: `int` - Number of epochs to perform validation - resume: `str` - Path to checkpoint for continue training - """ - - def __init__( - self, - use_fp16: bool = False, - num_iterations: int = 10000, - clip_grad: float = 10.0, - evaluate_interval: int = 1, - callbacks: List[Callbacks] = [ - TimerCallbacks(), - ], - debug: bool = False, - **kwargs - ): - - self.num_iterations = num_iterations - self.use_amp = True if use_fp16 else False - self.scaler = NativeScaler(use_fp16) - self.clip_grad = clip_grad - self.evaluate_interval = evaluate_interval - self.iters = 0 - self.debug = debug - self.shutdown_all = False # Flag to stop trainer imediately - - if not isinstance(callbacks, CallbacksList): - callbacks = callbacks if isinstance(callbacks, list) else [callbacks] - callbacks = CallbacksList(callbacks) - self.callbacks = callbacks - self.callbacks.set_params({"trainer": self}) - - def fit(self): - - # Sanity check if debug is set - if self.debug: - self.callbacks.run( - "sanitycheck", - {"iters": self.iters, "num_iterations": self.num_iterations}, - ) - - # On start callbacks - self.callbacks.run("on_start") - - while self.iters < self.num_iterations: - try: - - # Check if shutdown flag has been turned on - if self.shutdown_all: - break - - # On epoch start callbacks - self.callbacks.run("on_epoch_start", {"iters": self.iters}) - - # Start training - self.training_epoch() - - # Start evaluation - if self.evaluate_interval != 0: - if self.iters % self.evaluate_interval == 0 and self.iters > 0: - self.evaluate_epoch() - - # On epoch end callbacks - self.callbacks.run("on_epoch_end", {"iters": self.iters}) - - except KeyboardInterrupt: - break - - # On finish callbacks - self.callbacks.run( - "on_finish", - {"iters": self.iters, "num_iterations": self.num_iterations}, - ) diff --git a/theseus/base/trainer/supervised_trainer.py b/theseus/base/trainer/supervised_trainer.py deleted file mode 100644 index 04a1837..0000000 --- a/theseus/base/trainer/supervised_trainer.py +++ /dev/null @@ -1,175 +0,0 @@ -import time - -import numpy as np -import torch -from torch.cuda import amp -from tqdm import tqdm - -from theseus.base.utilities.loggers.observer import LoggerObserver - -from .base_trainer import BaseTrainer - -LOGGER = LoggerObserver.getLogger("main") - - -class SupervisedTrainer(BaseTrainer): - """Trainer for supervised tasks - - model : `torch.nn.Module` - Wrapper model with loss - trainloader : `torch.utils.DataLoader` - DataLoader for training - valloader : `torch.utils.DataLoader` - DataLoader for validation - metrics: `List[Metric]` - list of metrics for evaluation - optimizer: `torch.optim.Optimizer` - optimizer for parameters update - scheduler: `torch.optim.lr_scheduler.Scheduler` - learning rate schedulers - - """ - - def __init__( - self, model, trainloader, valloader, metrics, optimizer, scheduler, **kwargs - ): - - super().__init__(**kwargs) - - self.model = model - self.metrics = metrics - self.optimizer = optimizer - self.scheduler = scheduler - self.trainloader = trainloader - self.valloader = valloader - self.use_cuda = next(self.model.parameters()).is_cuda - - if self.scheduler: - self.step_per_epoch = self.scheduler.step_per_epoch - - # Flags for shutting down training or validation stages - self.shutdown_training = False - self.shutdown_validation = False - - def training_epoch(self): - """ - Perform training one epoch - """ - self.model.train() - self.callbacks.run("on_train_epoch_start") - self.optimizer.zero_grad() - for i, batch in enumerate(self.trainloader): - - # Check if shutdown flag has been turned on - if self.shutdown_training or self.shutdown_all: - break - - self.callbacks.run( - "on_train_batch_start", - { - "batch": batch, - "iters": self.iters, - "num_iterations": self.num_iterations, - }, - ) - - # Gradient scaler - with amp.autocast(enabled=self.use_amp): - outputs = self.model.training_step(batch) - loss = outputs["loss"] - loss_dict = outputs["loss_dict"] - - # Backward loss - self.scaler(loss, self.optimizer) - - # Optmizer step - self.scaler.step( - self.optimizer, - clip_grad=self.clip_grad, - parameters=self.model.parameters(), - ) - if self.scheduler and not self.step_per_epoch: - self.scheduler.step() - self.optimizer.zero_grad() - - if self.use_cuda: - torch.cuda.synchronize() - - # Get learning rate - lrl = [x["lr"] for x in self.optimizer.param_groups] - lr = sum(lrl) / len(lrl) - - self.callbacks.run( - "on_train_batch_end", - { - "loss_dict": loss_dict, - "iters": self.iters, - "num_iterations": self.num_iterations, - "lr": lr, - }, - ) - - # Calculate current iteration - self.iters = self.iters + 1 - - if self.scheduler and self.step_per_epoch: - self.scheduler.step() - - self.callbacks.run( - "on_train_epoch_end", {"last_batch": batch, "iters": self.iters} - ) - - @torch.no_grad() - def evaluate_epoch(self): - """ - Perform validation one epoch - """ - self.model.eval() - - self.callbacks.run("on_val_epoch_start") - for batch in tqdm(self.valloader): - - # Check if shutdown flag has been turned on - if self.shutdown_validation or self.shutdown_all: - break - - self.callbacks.run( - "on_val_batch_start", - { - "batch": batch, - "iters": self.iters, - "num_iterations": self.num_iterations, - }, - ) - - # Gradient scaler - with amp.autocast(enabled=self.use_amp): - outputs = self.model.evaluate_step(batch, self.metrics) - loss_dict = outputs["loss_dict"] - - self.callbacks.run( - "on_val_batch_end", - { - "loss_dict": loss_dict, - "iters": self.iters, - "num_iterations": self.num_iterations, - }, - ) - - metric_dict = {} - for metric in self.metrics: - metric_dict.update(metric.value()) - metric.reset() - - self.callbacks.run( - "on_val_epoch_end", - { - "metric_dict": metric_dict, - "iters": self.iters, - "num_iterations": self.num_iterations, - "last_batch": batch, - "last_outputs": outputs["model_outputs"], - }, - ) - - return metric_dict diff --git a/theseus/base/utilities/download.py b/theseus/base/utilities/download.py index 16553a1..e05c4c0 100644 --- a/theseus/base/utilities/download.py +++ b/theseus/base/utilities/download.py @@ -1,10 +1,9 @@ import os import os.path as osp import urllib.request as urlreq +from pathlib import Path import gdown -import os -from pathlib import Path from theseus.base.utilities.loggers.observer import LoggerObserver @@ -62,28 +61,37 @@ def download_from_url(url, root=None, filename=None): return fpath -def download_from_wandb(filename, run_path, save_dir, rename=None, generate_id_text_file=False): +def download_from_wandb( + filename, run_path, save_dir, rename=None, generate_id_text_file=False +): import wandb - + try: path = wandb.restore(filename, run_path=run_path, root=save_dir) - LOGGER.text("Successfully download {} from wandb run path {}".format(filename, run_path), level=LoggerObserver.INFO) - + LOGGER.text( + "Successfully download {} from wandb run path {}".format( + filename, run_path + ), + level=LoggerObserver.INFO, + ) + # Save run id to wandb_id.txt if generate_id_text_file: wandb_id = osp.basename(run_path) with open(osp.join(save_dir, "wandb_id.txt"), "w") as f: f.write(wandb_id) - + if rename: new_name = str(Path(path.name).resolve().parent / rename) os.rename(Path(path.name).resolve(), new_name) LOGGER.text("Saved to {}".format(new_name), level=LoggerObserver.INFO) return new_name - - - LOGGER.text("Saved to {}".format((Path(save_dir) / path.name).resolve()), level=LoggerObserver.INFO) + + LOGGER.text( + "Saved to {}".format((Path(save_dir) / path.name).resolve()), + level=LoggerObserver.INFO, + ) return path.name except Exception as e: LOGGER.text(f"Failed to download from wandb. {e}", level=LoggerObserver.ERROR) diff --git a/theseus/base/utilities/getter.py b/theseus/base/utilities/getter.py index bef2d84..ab76e7a 100644 --- a/theseus/base/utilities/getter.py +++ b/theseus/base/utilities/getter.py @@ -1,45 +1,61 @@ +import inspect + +from omegaconf import DictConfig, ListConfig + from theseus.registry import Registry +def get_instance_with_kwargs(registry, name, args: list = None, kwargs: dict = {}): + # get keyword arguments from class signature + inspection = inspect.signature(registry.get(name)) + class_kwargs = inspection.parameters.keys() + + if isinstance(args, (dict, DictConfig)): + # override kwargs (from parent) with args (from config) + kwargs.update(args) + args = None + + if "kwargs" in class_kwargs: + if args is None: + return registry.get(name)(**kwargs) + else: + return registry.get(name)(*args, **kwargs) + else: + filtered_kwargs = {k: v for k, v in kwargs.items() if k in class_kwargs} + if args is None: + return registry.get(name)(**filtered_kwargs) + else: + return registry.get(name)(*args, **filtered_kwargs) + + def get_instance(config, registry: Registry, **kwargs): # ref https://github.com/vltanh/torchan/blob/master/torchan/utils/getter.py assert "name" in config - config.setdefault("args", {}) - if config.get("args", None) is None: - config["args"] = {} + args = config.get("args", []) - return registry.get(config["name"])(**config.get("args", {}), **kwargs) + return get_instance_with_kwargs(registry, config["name"], args, kwargs) def get_instance_recursively(config, registry: Registry, **kwargs): - if isinstance(config, (list, tuple)): + if isinstance(config, (list, tuple, ListConfig)): out = [ get_instance_recursively(item, registry=registry, **kwargs) for item in config ] return out - if isinstance(config, dict): + if isinstance(config, (dict, DictConfig)): if "name" in config.keys(): if registry: args = get_instance_recursively( config.get("args", {}), registry, **kwargs ) - if args is None: - return registry.get(config["name"])(**kwargs) - if isinstance(args, list): - return registry.get(config["name"])(*args, **kwargs) - if isinstance(args, dict): - kwargs.update( - args - ) # override kwargs (from parent) with args (from config) - return registry.get(config["name"])(**kwargs) - raise ValueError(f"Unknown type: {type(args)}") + return get_instance_with_kwargs(registry, config["name"], args, kwargs) + else: out = {} for k, v in config.items(): out[k] = get_instance_recursively(v, registry=registry, **kwargs) return out - return globals()[config["name"]](**config["args"], **kwargs) return config diff --git a/theseus/base/utilities/loggers/observer.py b/theseus/base/utilities/loggers/observer.py index d84399e..ad2742c 100644 --- a/theseus/base/utilities/loggers/observer.py +++ b/theseus/base/utilities/loggers/observer.py @@ -31,12 +31,8 @@ def get_type(value): return LoggerObserver.HTML else: return LoggerObserver.TEXT - - LoggerObserver.text( - f"Fail to log undefined type: {type(value)}", - level=LoggerObserver.CRITICAL, - ) - raise ValueError() + else: + raise ValueError(f"Fail to log undefined type: {type(value)}") class LoggerObserver(object): diff --git a/theseus/base/utilities/loggers/wandb_logger.py b/theseus/base/utilities/loggers/wandb_logger.py index c3ff34e..6194314 100644 --- a/theseus/base/utilities/loggers/wandb_logger.py +++ b/theseus/base/utilities/loggers/wandb_logger.py @@ -94,11 +94,14 @@ def log_figure(self, tag, value, step=0, **kwargs): :param step: (int) logging step """ - if isinstance(value, torch.Tensor): - image = wandb_logger.Image(value) - wandb_logger.log({tag: image, "iterations": step}) - else: - wandb_logger.log({tag: value, "iterations": step}) + try: + if isinstance(value, torch.Tensor): + image = wandb_logger.Image(value) + wandb_logger.log({tag: image, "iterations": step}) + else: + wandb_logger.log({tag: value, "iterations": step}) + except Exception as e: + pass def log_torch_module(self, tag, value, log_freq, **kwargs): """ diff --git a/theseus/base/utilities/logits.py b/theseus/base/utilities/logits.py index 65f80d1..3867428 100644 --- a/theseus/base/utilities/logits.py +++ b/theseus/base/utilities/logits.py @@ -10,7 +10,7 @@ def multiclass_logits2labels(outputs, return_probs: bool = False): outputs = move_to(detach(outputs), torch.device("cpu")) if return_probs: - return outputs.long(), probs + return outputs.long().view(-1), probs return outputs @@ -27,6 +27,16 @@ def multilabel_logits2labels(outputs, threshold=0.5, return_probs: bool = False) return outputs +def binary_logits2labels(outputs, threshold=0.5, return_probs: bool = False): + assert threshold is not None, "Please specify threshold value for sigmoid" + preds = (outputs.view(-1) > threshold).long() + preds = move_to(detach(preds), torch.device("cpu")) + if return_probs: + probs = move_to(detach(outputs), torch.device("cpu")) + return preds.long(), probs.view(-1) + return preds + + def logits2labels( outputs, label_type="multiclass", @@ -37,4 +47,6 @@ def logits2labels( return multiclass_logits2labels(outputs, return_probs) if label_type == "multilabel": return multilabel_logits2labels(outputs, threshold, return_probs) + if label_type == "binary": + return binary_logits2labels(outputs, threshold, return_probs) return outputs diff --git a/theseus/base/utilities/optuna_tuner.py b/theseus/base/utilities/optuna_tuner.py index 0b184fb..f35079e 100644 --- a/theseus/base/utilities/optuna_tuner.py +++ b/theseus/base/utilities/optuna_tuner.py @@ -3,6 +3,7 @@ from copy import deepcopy import optuna +from omegaconf import DictConfig, OmegaConf from optuna.visualization import ( plot_contour, plot_edf, @@ -13,10 +14,8 @@ plot_slice, ) -from theseus.base.callbacks.optuna_callbacks import OptunaCallbacks from theseus.base.pipeline import BasePipeline from theseus.base.utilities.loggers import LoggerObserver -from theseus.opt import Config class OptunaWrapper: @@ -53,8 +52,9 @@ def __init__( def tune( self, - config: Config, + config: DictConfig, pipeline_class: BasePipeline, + optuna_callback: callable = None, trial_user_attrs: dict = {}, ): @@ -66,7 +66,7 @@ def tune( raise ValueError() wrapped_objective = lambda trial: self.objective( - trial, config, pipeline_class, trial_user_attrs + trial, config, pipeline_class, trial_user_attrs, optuna_callback ) self.study.optimize(wrapped_objective, n_trials=self.n_trials) @@ -75,7 +75,7 @@ def tune( self._rename_params() return best_trial - def save_best_config(self, save_dir: str, config: Config, best_params: dict): + def save_best_config(self, save_dir: str, config: DictConfig, best_params: dict): for param_str, param_val in best_params.items(): here = config keys = param_str.split(".") @@ -84,13 +84,16 @@ def save_best_config(self, save_dir: str, config: Config, best_params: dict): here[keys[-1]] = param_val save_dir = osp.join(save_dir, "best_configs") os.makedirs(save_dir, exist_ok=True) - config.save_yaml(osp.join(save_dir, "best_pipeline.yaml")) + + with open(os.path.join(save_dir, "best_pipeline.yaml"), "w") as f: + OmegaConf.save(config=config, f=f) + self.logger.text( f"Best configuration saved at {save_dir}", level=LoggerObserver.INFO ) def _override_dict_with_optuna( - self, trial, config: Config, param_str: str, variable_type: str + self, trial, config: DictConfig, param_str: str, variable_type: str ): """ Override config with optuna suggested params @@ -134,9 +137,10 @@ def _override_dict_with_optuna( def objective( self, trial: optuna.Trial, - config: Config, + config: DictConfig, pipeline_class: BasePipeline, trial_user_attrs: dict = {}, + optuna_callback: callable = None, ): """Define the objective function""" @@ -160,7 +164,9 @@ def objective( # Hook a callback inside pipeline pipeline = pipeline_class(tmp_config) pipeline.init_trainer = self.callback_hook( - trial=trial, init_trainer_function=pipeline.init_trainer + trial=trial, + init_trainer_function=pipeline.init_trainer, + callback_fn=optuna_callback, ) # Start training and evaluation @@ -170,11 +176,11 @@ def objective( best_key = trial_user_attrs.get("best_key", None) if best_key is not None: - return score_dict[best_key] + return float(score_dict[best_key]) return score_dict - def callback_hook(self, trial, init_trainer_function): - callback = OptunaCallbacks(trial=trial) + def callback_hook(self, trial, init_trainer_function, callback_fn): + callback = callback_fn(trial=trial) def hook_optuna_callback(callbacks): callbacks.append(callback) diff --git a/theseus/cv/classification/callbacks/__init__.py b/theseus/cv/classification/callbacks/__init__.py index 13a8fba..e985a9c 100644 --- a/theseus/cv/classification/callbacks/__init__.py +++ b/theseus/cv/classification/callbacks/__init__.py @@ -1,7 +1,7 @@ from theseus.base.callbacks import CALLBACKS_REGISTRY -from .gradcam_callbacks import GradCAMVisualizationCallbacks -from .visualize_callbacks import ClassificationVisualizerCallbacks +from .gradcam_callback import GradCAMVisualizationCallback +from .visualize_callback import ClassificationVisualizerCallback -CALLBACKS_REGISTRY.register(ClassificationVisualizerCallbacks) -CALLBACKS_REGISTRY.register(GradCAMVisualizationCallbacks) +CALLBACKS_REGISTRY.register(ClassificationVisualizerCallback) +CALLBACKS_REGISTRY.register(GradCAMVisualizationCallback) diff --git a/theseus/cv/classification/callbacks/gradcam_callbacks.py b/theseus/cv/classification/callbacks/gradcam_callback.py similarity index 79% rename from theseus/cv/classification/callbacks/gradcam_callbacks.py rename to theseus/cv/classification/callbacks/gradcam_callback.py index de0bbf9..04e8ea8 100644 --- a/theseus/cv/classification/callbacks/gradcam_callbacks.py +++ b/theseus/cv/classification/callbacks/gradcam_callback.py @@ -1,10 +1,12 @@ -from typing import Dict, List +from typing import Any, Dict, List, Optional +import lightning.pytorch as pl import matplotlib.pyplot as plt import torch +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT from torchvision.transforms import functional as TFF -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver from theseus.cv.base.utilities.visualization.visualizer import Visualizer from theseus.cv.classification.utilities.gradcam import CAMWrapper, show_cam_on_image @@ -12,7 +14,7 @@ LOGGER = LoggerObserver.getLogger("main") -class GradCAMVisualizationCallbacks(Callbacks): +class GradCAMVisualizationCallback(Callback): """ Callbacks for visualizing stuff during training Features: @@ -32,17 +34,29 @@ def __init__( self.mean = mean self.std = std + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self.params = {} + self.params["last_batch"] = batch + @torch.enable_grad() # enable grad for CAM - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ After finish validation """ - iters = logs["iters"] - last_batch = logs["last_batch"] - model = self.params["trainer"].model - valloader = self.params["trainer"].valloader - optimizer = self.params["trainer"].optimizer + iters = trainer.global_step + last_batch = self.params["last_batch"] + model = pl_module.model + valloader = pl_module.datamodule.valloader + optimizer = pl_module.optimizer # Zeroing gradients in model and optimizer for supress warning optimizer.zero_grad() diff --git a/theseus/cv/classification/callbacks/visualize_callbacks.py b/theseus/cv/classification/callbacks/visualize_callback.py similarity index 84% rename from theseus/cv/classification/callbacks/visualize_callbacks.py rename to theseus/cv/classification/callbacks/visualize_callback.py index 67049b7..f3d084d 100644 --- a/theseus/cv/classification/callbacks/visualize_callbacks.py +++ b/theseus/cv/classification/callbacks/visualize_callback.py @@ -1,10 +1,12 @@ -from typing import Dict, List +from typing import Any, Dict, List +import lightning.pytorch as pl import matplotlib.pyplot as plt import torch +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT from torchvision.transforms import functional as TFF -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.cuda import move_to from theseus.base.utilities.loggers.observer import LoggerObserver from theseus.cv.base.utilities.visualization.visualizer import Visualizer @@ -13,7 +15,7 @@ LOGGER = LoggerObserver.getLogger("main") -class ClassificationVisualizerCallbacks(Callbacks): +class ClassificationVisualizerCallback(Callback): """ Callbacks for visualizing stuff during training Features: @@ -34,19 +36,20 @@ def __init__( self.mean = mean self.std = std - def sanitycheck(self, logs: Dict = None): + def on_sanity_check_start( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> None: + """ Sanitycheck before starting. Run only when debug=True """ - iters = logs["iters"] - model = self.params["trainer"].model - valloader = self.params["trainer"].valloader - trainloader = self.params["trainer"].trainloader + iters = trainer.global_step + model = pl_module.model + valloader = pl_module.datamodule.valloader + trainloader = pl_module.datamodule.trainloader train_batch = next(iter(trainloader)) val_batch = next(iter(valloader)) - trainset = trainloader.dataset - valset = valloader.dataset try: self.visualize_model(model, train_batch) @@ -133,16 +136,28 @@ def visualize_gt(self, train_batch, val_batch, iters): plt.clf() # Clear figure plt.close() + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self.params = {} + self.params["last_batch"] = batch + @torch.no_grad() # enable grad for CAM - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ After finish validation """ - iters = logs["iters"] - last_batch = logs["last_batch"] - model = self.params["trainer"].model - valloader = self.params["trainer"].valloader + iters = trainer.global_step + last_batch = self.params["last_batch"] + model = pl_module.model + valloader = pl_module.datamodule.valloader # Vizualize model predictions LOGGER.text("Visualizing model predictions...", level=LoggerObserver.DEBUG) diff --git a/theseus/cv/classification/metrics/projection.py b/theseus/cv/classification/metrics/projection.py index beaec98..2f7dba0 100644 --- a/theseus/cv/classification/metrics/projection.py +++ b/theseus/cv/classification/metrics/projection.py @@ -44,7 +44,8 @@ def update(self, outputs: Dict[str, Any], batch: Dict[str, Any]): torch.argmax(outputs["outputs"].detach().cpu(), dim=1).numpy().tolist() ) inputs = batch["inputs"] - targets = batch["targets"].numpy().tolist() + if self.has_labels: + targets = batch["targets"].numpy().tolist() img_names = batch["img_names"] for i, _ in enumerate(features): diff --git a/theseus/cv/classification/models/huggingface_models.py b/theseus/cv/classification/models/huggingface_models.py index 8e7c4c0..15dd9fc 100644 --- a/theseus/cv/classification/models/huggingface_models.py +++ b/theseus/cv/classification/models/huggingface_models.py @@ -74,7 +74,7 @@ def get_model(self): """ return self.model - def forward_features(self, batch: Dict, device: torch.device): + def forward_features(self, batch: Dict, device: torch.device = None): input_ids, attention_mask = batch["input_ids"], batch["attention_mask"] transformer_out = self.model(input_ids=input_ids, attention_mask=attention_mask) @@ -90,14 +90,15 @@ def forward_features(self, batch: Dict, device: torch.device): return features - def forward_batch(self, batch: Dict, device: torch.device): - batch = move_to(batch, device) + def forward_batch(self, batch: Dict, device: torch.device = None): + if device is not None: + batch = move_to(batch, device) features = self.forward_features(batch, device) outputs = self.head(features) return {"outputs": outputs, "features": features} - def get_prediction(self, adict: Dict[str, Any], device: torch.device): + def get_prediction(self, adict: Dict[str, Any], device: torch.device = None): """ Inference using the model. diff --git a/theseus/cv/classification/models/timm_models.py b/theseus/cv/classification/models/timm_models.py index ed8c0f5..c5dc562 100644 --- a/theseus/cv/classification/models/timm_models.py +++ b/theseus/cv/classification/models/timm_models.py @@ -89,15 +89,18 @@ def get_model(self): """ return self.model - def forward_batch(self, batch: Dict, device: torch.device): - x = move_to(batch["inputs"], device) + def forward_batch(self, batch: Dict, device: torch.device = None): + if device is not None: + x = move_to(batch["inputs"], device) + else: + x = batch["inputs"] self.features = None # Clear current features outputs = self.model(x) if self.num_classes == 0: self.features = outputs return {"outputs": outputs, "features": self.features} - def get_prediction(self, adict: Dict[str, Any], device: torch.device): + def get_prediction(self, adict: Dict[str, Any], device: torch.device = None): """ Inference using the model. diff --git a/theseus/cv/classification/pipeline.py b/theseus/cv/classification/pipeline.py index bb1a834..d53ee06 100644 --- a/theseus/cv/classification/pipeline.py +++ b/theseus/cv/classification/pipeline.py @@ -1,3 +1,5 @@ +from omegaconf import DictConfig + from theseus.base.pipeline import BasePipeline from theseus.base.utilities.loggers import LoggerObserver from theseus.cv.classification.augmentations import TRANSFORM_REGISTRY @@ -7,13 +9,12 @@ from theseus.cv.classification.metrics import METRIC_REGISTRY from theseus.cv.classification.models import MODEL_REGISTRY from theseus.cv.classification.trainer import TRAINER_REGISTRY -from theseus.opt import Config class ClassificationPipeline(BasePipeline): """docstring for Pipeline.""" - def __init__(self, opt: Config): + def __init__(self, opt: DictConfig): super(ClassificationPipeline, self).__init__(opt) self.opt = opt diff --git a/theseus/cv/detection/callbacks/__init__.py b/theseus/cv/detection/callbacks/__init__.py index dd8ad87..606ff86 100644 --- a/theseus/cv/detection/callbacks/__init__.py +++ b/theseus/cv/detection/callbacks/__init__.py @@ -1,5 +1,5 @@ from theseus.base.callbacks import CALLBACKS_REGISTRY -from .visualization import DetectionVisualizerCallbacks +from .visualization import DetectionVisualizerCallback -CALLBACKS_REGISTRY.register(DetectionVisualizerCallbacks) +CALLBACKS_REGISTRY.register(DetectionVisualizerCallback) diff --git a/theseus/cv/detection/callbacks/visualization.py b/theseus/cv/detection/callbacks/visualization.py index f77624c..b30e2ec 100644 --- a/theseus/cv/detection/callbacks/visualization.py +++ b/theseus/cv/detection/callbacks/visualization.py @@ -1,12 +1,14 @@ -from typing import Dict, List +from typing import Any, Dict, List +import lightning.pytorch as pl import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np import torch +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT from torchvision.transforms import functional as TFF -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver from theseus.cv.base.utilities.visualization.colors import color_list from theseus.cv.base.utilities.visualization.visualizer import Visualizer @@ -14,7 +16,7 @@ LOGGER = LoggerObserver.getLogger("main") -class DetectionVisualizerCallbacks(Callbacks): +class DetectionVisualizerCallback(Callback): """ Callbacks for visualizing stuff during training Features: @@ -35,18 +37,17 @@ def __init__( self.mean = mean self.std = std - def sanitycheck(self, logs: Dict = None): + def on_sanity_check_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ Sanitycheck before starting. Run only when debug=True """ - iters = logs["iters"] - model = self.params["trainer"].model - valloader = self.params["trainer"].valloader - trainloader = self.params["trainer"].trainloader + iters = trainer.iterations + model = pl_module.model + valloader = pl_module.datamodule.valloader + trainloader = pl_module.datamodule.trainloader train_batch = next(iter(trainloader)) val_batch = next(iter(valloader)) - trainset = trainloader.dataset valset = valloader.dataset classnames = valset.classnames @@ -156,16 +157,28 @@ def visualize_gt(self, train_batch, val_batch, iters, classnames): plt.clf() # Clear figure plt.close() + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self.params = {} + self.params["last_batch"] = batch + @torch.no_grad() - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ After finish validation """ - iters = logs["iters"] - last_batch = logs["last_batch"] - model = self.params["trainer"].model - valloader = self.params["trainer"].valloader + iters = trainer.global_step + last_batch = self.params["last_batch"] + model = pl_module.model + valloader = pl_module.datamodule.valloader # Vizualize model predictions LOGGER.text("Visualizing model predictions...", level=LoggerObserver.DEBUG) diff --git a/theseus/cv/detection/pipeline.py b/theseus/cv/detection/pipeline.py index e60a5eb..c584507 100644 --- a/theseus/cv/detection/pipeline.py +++ b/theseus/cv/detection/pipeline.py @@ -1,3 +1,5 @@ +from omegaconf import DictConfig + from theseus.base.pipeline import BasePipeline from theseus.base.utilities.cuda import get_devices_info from theseus.base.utilities.loggers import LoggerObserver @@ -8,13 +10,12 @@ from theseus.cv.detection.metrics import METRIC_REGISTRY from theseus.cv.detection.models import MODEL_REGISTRY, ModelWithLossandPostprocess from theseus.cv.detection.trainer import TRAINER_REGISTRY -from theseus.opt import Config class DetectionPipeline(BasePipeline): """docstring for Pipeline.""" - def __init__(self, opt: Config): + def __init__(self, opt: DictConfig): super(DetectionPipeline, self).__init__(opt) self.opt = opt diff --git a/theseus/cv/semantic/callbacks/__init__.py b/theseus/cv/semantic/callbacks/__init__.py index 3a6c659..2f439d2 100644 --- a/theseus/cv/semantic/callbacks/__init__.py +++ b/theseus/cv/semantic/callbacks/__init__.py @@ -1,6 +1,4 @@ from theseus.base.callbacks import CALLBACKS_REGISTRY -from theseus.cv.semantic.callbacks.visualize_callbacks import ( - SemanticVisualizerCallbacks, -) +from theseus.cv.semantic.callbacks.visualize_callbacks import SemanticVisualizerCallback -CALLBACKS_REGISTRY.register(SemanticVisualizerCallbacks) +CALLBACKS_REGISTRY.register(SemanticVisualizerCallback) diff --git a/theseus/cv/semantic/callbacks/visualize_callbacks.py b/theseus/cv/semantic/callbacks/visualize_callbacks.py index c7adf69..cb5be96 100644 --- a/theseus/cv/semantic/callbacks/visualize_callbacks.py +++ b/theseus/cv/semantic/callbacks/visualize_callbacks.py @@ -1,12 +1,14 @@ -from typing import Dict +from typing import Any, Dict +import lightning.pytorch as pl import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np import torch +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT from torchvision.transforms import functional as TFF -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver from theseus.cv.base.utilities.visualization.colors import color_list from theseus.cv.base.utilities.visualization.visualizer import Visualizer @@ -14,7 +16,7 @@ LOGGER = LoggerObserver.getLogger("main") -class SemanticVisualizerCallbacks(Callbacks): +class SemanticVisualizerCallback(Callback): """ Callbacks for visualizing stuff during training Features: @@ -28,18 +30,17 @@ def __init__(self, **kwargs) -> None: self.visualizer = Visualizer() - def sanitycheck(self, logs: Dict = None): + def on_sanity_check_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ Sanitycheck before starting. Run only when debug=True """ - iters = logs["iters"] - model = self.params["trainer"].model - valloader = self.params["trainer"].valloader - trainloader = self.params["trainer"].trainloader + iters = trainer.global_step + model = pl_module.model + valloader = pl_module.datamodule.valloader + trainloader = pl_module.datamodule.trainloader train_batch = next(iter(trainloader)) val_batch = next(iter(valloader)) - trainset = trainloader.dataset valset = valloader.dataset classnames = valset.classnames @@ -155,16 +156,28 @@ def visualize_gt(self, train_batch, val_batch, iters, classnames): plt.clf() # Clear figure plt.close() + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self.params = {} + self.params["last_batch"] = batch + @torch.no_grad() - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): """ After finish validation """ - iters = logs["iters"] - last_batch = logs["last_batch"] - model = self.params["trainer"].model - valloader = self.params["trainer"].valloader + iters = trainer.global_step + last_batch = self.params["last_batch"] + model = pl_module.model + valloader = pl_module.datamodule.valloader # Vizualize model predictions LOGGER.text("Visualizing model predictions...", level=LoggerObserver.DEBUG) diff --git a/theseus/cv/semantic/models/segmodels.py b/theseus/cv/semantic/models/segmodels.py index 009fc38..eeee478 100644 --- a/theseus/cv/semantic/models/segmodels.py +++ b/theseus/cv/semantic/models/segmodels.py @@ -57,14 +57,17 @@ def get_model(self): """ return self.model - def forward_batch(self, batch: Dict, device: torch.device): - x = move_to(batch["inputs"], device) + def forward_batch(self, batch: Dict, device: torch.device = None): + if device is not None: + x = move_to(batch["inputs"], device) + else: + x = batch["inputs"] outputs = self.model(x) return { "outputs": outputs, } - def get_prediction(self, adict: Dict[str, Any], device: torch.device): + def get_prediction(self, adict: Dict[str, Any], device: torch.device = None): """ Inference using the model. adict: `Dict[str, Any]` diff --git a/theseus/cv/semantic/pipeline.py b/theseus/cv/semantic/pipeline.py index 501aabe..8d8bf65 100644 --- a/theseus/cv/semantic/pipeline.py +++ b/theseus/cv/semantic/pipeline.py @@ -1,3 +1,5 @@ +from omegaconf import DictConfig + from theseus.base.pipeline import BasePipeline from theseus.base.utilities.loggers import LoggerObserver from theseus.cv.semantic.augmentations import TRANSFORM_REGISTRY @@ -7,13 +9,12 @@ from theseus.cv.semantic.metrics import METRIC_REGISTRY from theseus.cv.semantic.models import MODEL_REGISTRY from theseus.cv.semantic.trainer import TRAINER_REGISTRY -from theseus.opt import Config class SemanticPipeline(BasePipeline): """docstring for Pipeline.""" - def __init__(self, opt: Config): + def __init__(self, opt: DictConfig): super(SemanticPipeline, self).__init__(opt) self.opt = opt diff --git a/theseus/tabular/base/__init__.py b/theseus/ml/__init__.py similarity index 100% rename from theseus/tabular/base/__init__.py rename to theseus/ml/__init__.py diff --git a/theseus/ml/callbacks/__init__.py b/theseus/ml/callbacks/__init__.py new file mode 100644 index 0000000..33c6032 --- /dev/null +++ b/theseus/ml/callbacks/__init__.py @@ -0,0 +1,22 @@ +from theseus.registry import Registry + +from .base_callbacks import Callbacks, CallbacksList +from .checkpoint_callbacks import SKLearnCheckpointCallbacks +from .explainer import ( + LIMEExplainer, + PartialDependencePlots, + PermutationImportance, + ShapValueExplainer, +) +from .metric_callbacks import MetricLoggerCallbacks +from .optuna_callbacks import OptunaCallbacks + +CALLBACKS_REGISTRY = Registry("CALLBACKS") + +CALLBACKS_REGISTRY.register(SKLearnCheckpointCallbacks) +CALLBACKS_REGISTRY.register(ShapValueExplainer) +CALLBACKS_REGISTRY.register(PermutationImportance) +CALLBACKS_REGISTRY.register(PartialDependencePlots) +CALLBACKS_REGISTRY.register(LIMEExplainer) +CALLBACKS_REGISTRY.register(OptunaCallbacks) +CALLBACKS_REGISTRY.register(MetricLoggerCallbacks) diff --git a/theseus/base/callbacks/base_callbacks.py b/theseus/ml/callbacks/base_callbacks.py similarity index 97% rename from theseus/base/callbacks/base_callbacks.py rename to theseus/ml/callbacks/base_callbacks.py index 4eab38c..bbb6b77 100644 --- a/theseus/base/callbacks/base_callbacks.py +++ b/theseus/ml/callbacks/base_callbacks.py @@ -19,11 +19,11 @@ "on_train_batch_start", "on_train_batch_end", "on_train_step", - "on_val_epoch_start", - "on_val_epoch_end", - "on_val_batch_start", - "on_val_batch_end", - "on_val_step", + "on_validation_epoch_start", + "on_validation_epoch_end", + "on_validation_batch_start", + "on_validation_batch_end", + "on_validation_step", ] diff --git a/theseus/tabular/classification/callbacks/checkpoint_callbacks.py b/theseus/ml/callbacks/checkpoint_callbacks.py similarity index 95% rename from theseus/tabular/classification/callbacks/checkpoint_callbacks.py rename to theseus/ml/callbacks/checkpoint_callbacks.py index 98be09d..2e711e3 100644 --- a/theseus/tabular/classification/callbacks/checkpoint_callbacks.py +++ b/theseus/ml/callbacks/checkpoint_callbacks.py @@ -2,8 +2,8 @@ import os.path as osp from typing import Dict -from theseus.base.callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver +from theseus.ml.callbacks import Callbacks LOGGER = LoggerObserver.getLogger("main") diff --git a/theseus/tabular/classification/callbacks/explainer/__init__.py b/theseus/ml/callbacks/explainer/__init__.py similarity index 100% rename from theseus/tabular/classification/callbacks/explainer/__init__.py rename to theseus/ml/callbacks/explainer/__init__.py diff --git a/theseus/tabular/classification/callbacks/explainer/lime.py b/theseus/ml/callbacks/explainer/lime.py similarity index 93% rename from theseus/tabular/classification/callbacks/explainer/lime.py rename to theseus/ml/callbacks/explainer/lime.py index b0e5c0a..ddfccf8 100644 --- a/theseus/tabular/classification/callbacks/explainer/lime.py +++ b/theseus/ml/callbacks/explainer/lime.py @@ -4,8 +4,8 @@ from lime import lime_tabular -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver +from theseus.ml.callbacks import Callbacks LOGGER = LoggerObserver.getLogger("main") @@ -26,13 +26,14 @@ def explain_instance( feature_names=feature_names, class_names=class_names, mode="classification" if class_names is not None else "regression", + discretize_continuous=False, ) return self.explainer.explain_instance( data_row=item, predict_fn=model.predict_proba ) - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_epoch_end(self, logs: Dict = None): """ After finish validation """ diff --git a/theseus/tabular/classification/callbacks/explainer/pdp.py b/theseus/ml/callbacks/explainer/pdp.py similarity index 96% rename from theseus/tabular/classification/callbacks/explainer/pdp.py rename to theseus/ml/callbacks/explainer/pdp.py index 1bb7b23..5940636 100644 --- a/theseus/tabular/classification/callbacks/explainer/pdp.py +++ b/theseus/ml/callbacks/explainer/pdp.py @@ -4,8 +4,8 @@ import matplotlib.pyplot as plt from sklearn.inspection import PartialDependenceDisplay, partial_dependence -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver +from theseus.ml.callbacks import Callbacks LOGGER = LoggerObserver.getLogger("main") @@ -78,7 +78,7 @@ def on_train_epoch_end(self, logs: Dict = None): ) plt.clf() - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_epoch_end(self, logs: Dict = None): """ After finish validation """ diff --git a/theseus/tabular/classification/callbacks/explainer/permutation.py b/theseus/ml/callbacks/explainer/permutation.py similarity index 96% rename from theseus/tabular/classification/callbacks/explainer/permutation.py rename to theseus/ml/callbacks/explainer/permutation.py index 7d520b3..e9d4c2b 100644 --- a/theseus/tabular/classification/callbacks/explainer/permutation.py +++ b/theseus/ml/callbacks/explainer/permutation.py @@ -5,8 +5,8 @@ import plotly.graph_objects as go from sklearn.inspection import permutation_importance -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver +from theseus.ml.callbacks import Callbacks LOGGER = LoggerObserver.getLogger("main") @@ -60,7 +60,7 @@ def on_train_epoch_end(self, logs: Dict = None): ) plt.clf() - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_epoch_end(self, logs: Dict = None): """ After finish validation """ diff --git a/theseus/tabular/classification/callbacks/explainer/shapley.py b/theseus/ml/callbacks/explainer/shapley.py similarity index 96% rename from theseus/tabular/classification/callbacks/explainer/shapley.py rename to theseus/ml/callbacks/explainer/shapley.py index 6ebf106..efb50ba 100644 --- a/theseus/tabular/classification/callbacks/explainer/shapley.py +++ b/theseus/ml/callbacks/explainer/shapley.py @@ -6,8 +6,8 @@ import shap from sklearn.inspection import permutation_importance -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver +from theseus.ml.callbacks import Callbacks LOGGER = LoggerObserver.getLogger("main") @@ -64,7 +64,7 @@ def on_train_epoch_end(self, logs: Dict = None): ) plt.clf() - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_epoch_end(self, logs: Dict = None): """ After finish validation """ @@ -76,6 +76,7 @@ def on_val_epoch_end(self, logs: Dict = None): shap_values = self.explainer.shap_values( x_val, check_additivity=self.check_additivity ) + plt.clf() shap.summary_plot( shap_values, plot_type=self.plot_type, diff --git a/theseus/base/callbacks/metric_logging_callbacks.py b/theseus/ml/callbacks/metric_callbacks.py similarity index 95% rename from theseus/base/callbacks/metric_logging_callbacks.py rename to theseus/ml/callbacks/metric_callbacks.py index b692412..dd49c27 100644 --- a/theseus/base/callbacks/metric_logging_callbacks.py +++ b/theseus/ml/callbacks/metric_callbacks.py @@ -3,8 +3,8 @@ import os.path as osp from typing import Dict, List -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver +from theseus.ml.callbacks import Callbacks LOGGER = LoggerObserver.getLogger("main") @@ -26,7 +26,7 @@ def __init__(self, save_json: bool = True, **kwargs) -> None: os.makedirs(self.save_dir, exist_ok=True) self.output_dict = [] - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_epoch_end(self, logs: Dict = None): """ After finish validation """ diff --git a/theseus/base/callbacks/optuna_callbacks.py b/theseus/ml/callbacks/optuna_callbacks.py similarity index 88% rename from theseus/base/callbacks/optuna_callbacks.py rename to theseus/ml/callbacks/optuna_callbacks.py index 874eef0..16f284a 100644 --- a/theseus/base/callbacks/optuna_callbacks.py +++ b/theseus/ml/callbacks/optuna_callbacks.py @@ -2,8 +2,8 @@ import optuna -from theseus.base.callbacks.base_callbacks import Callbacks from theseus.base.utilities.loggers.observer import LoggerObserver +from theseus.ml.callbacks import Callbacks LOGGER = LoggerObserver.getLogger("main") @@ -17,7 +17,7 @@ def __init__(self, trial: optuna.Trial, **kwargs) -> None: super().__init__() self.trial = trial - def on_val_epoch_end(self, logs: Dict = None): + def on_validation_epoch_end(self, logs: Dict = None): """ After finish validation """ diff --git a/theseus/tabular/classification/datasets/__init__.py b/theseus/ml/datasets/__init__.py similarity index 100% rename from theseus/tabular/classification/datasets/__init__.py rename to theseus/ml/datasets/__init__.py diff --git a/theseus/tabular/classification/datasets/csv_dataset.py b/theseus/ml/datasets/csv_dataset.py similarity index 100% rename from theseus/tabular/classification/datasets/csv_dataset.py rename to theseus/ml/datasets/csv_dataset.py diff --git a/theseus/tabular/classification/metrics/__init__.py b/theseus/ml/metrics/__init__.py similarity index 66% rename from theseus/tabular/classification/metrics/__init__.py rename to theseus/ml/metrics/__init__.py index fb0e3e8..c396538 100644 --- a/theseus/tabular/classification/metrics/__init__.py +++ b/theseus/ml/metrics/__init__.py @@ -1,12 +1,18 @@ from theseus.base.metrics import METRIC_REGISTRY from .acccuracy import SKLAccuracy, SKLBalancedAccuracyMetric +from .confusion_matrix import SKLConfusionMatrix from .f1_score import SKLF1ScoreMetric +from .mcc import SKLMCC from .precision_recall import SKLPrecisionRecall from .projection import SKLEmbeddingProjection +from .roc_auc_score import SKLROCAUCScore METRIC_REGISTRY.register(SKLPrecisionRecall) METRIC_REGISTRY.register(SKLF1ScoreMetric) METRIC_REGISTRY.register(SKLAccuracy) METRIC_REGISTRY.register(SKLBalancedAccuracyMetric) METRIC_REGISTRY.register(SKLEmbeddingProjection) +METRIC_REGISTRY.register(SKLMCC) +METRIC_REGISTRY.register(SKLROCAUCScore) +METRIC_REGISTRY.register(SKLConfusionMatrix) diff --git a/theseus/tabular/classification/metrics/acccuracy.py b/theseus/ml/metrics/acccuracy.py similarity index 68% rename from theseus/tabular/classification/metrics/acccuracy.py rename to theseus/ml/metrics/acccuracy.py index 825ec06..558f59e 100644 --- a/theseus/tabular/classification/metrics/acccuracy.py +++ b/theseus/ml/metrics/acccuracy.py @@ -2,6 +2,7 @@ import numpy as np from scipy.special import softmax +from sklearn.metrics import balanced_accuracy_score from theseus.base.metrics.metric_template import Metric @@ -54,23 +55,6 @@ def value(self, outputs: Dict[str, Any], batch: Dict[str, Any]): targets = batch["targets"] predictions = np.argmax(outputs, axis=-1).reshape(-1).tolist() targets = targets.reshape(-1).tolist() + blacc_score = balanced_accuracy_score(targets, predictions) - unique_ids = np.unique(targets) - corrects = {str(k): 0 for k in unique_ids} - total = {str(k): 0 for k in unique_ids} - - # Calculate accuracy for each class index - for i in unique_ids: - correct, sample_size = compute_multiclass(predictions, targets, i) - corrects[str(i)] += correct - total[str(i)] += sample_size - each_acc = [ - corrects[str(i)] * 1.0 / (total[str(i)]) - for i in unique_ids - if total[str(i)] > 0 - ] - - # Get mean accuracy across classes - values = sum(each_acc) / len(unique_ids) - - return {"bl_acc": values} + return {"bl_acc": blacc_score} diff --git a/theseus/ml/metrics/confusion_matrix.py b/theseus/ml/metrics/confusion_matrix.py new file mode 100644 index 0000000..bc1d769 --- /dev/null +++ b/theseus/ml/metrics/confusion_matrix.py @@ -0,0 +1,98 @@ +from typing import Any, Dict, List, Optional + +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +from scipy.special import softmax +from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix + +from theseus.base.metrics.metric_template import Metric + + +def plot_cfm(cm, ax, labels: List): + """ + Make confusion matrix figure + labels: `Optional[List]` + classnames for visualization + """ + + ax = sns.heatmap(cm, annot=False, fmt="", cmap="Blues", ax=ax) + + ax.set_xlabel("\nActual") + ax.set_ylabel("Predicted ") + + ax.xaxis.set_ticklabels(labels) + ax.yaxis.set_ticklabels(labels, rotation=0) + + +def make_cm_fig(cms, labels: Optional[List] = None): + + if cms.shape[0] > 1: # multilabel + num_classes = cms.shape[0] + else: + num_classes = cms.shape[1] + + ## Ticket labels - List must be in alphabetical order + if not labels: + labels = [str(i) for i in range(num_classes)] + + ## + num_cfms = cms.shape[0] + nrow = int(np.ceil(np.sqrt(num_cfms))) + + # Clear figures first to prevent memory-consuming + plt.cla() + plt.clf() + plt.close() + + fig, axes = plt.subplots(nrow, nrow, figsize=(8, 8)) + + if num_cfms > 1: + for ax, cfs_matrix, label in zip(axes.flatten(), cms, labels): + ax.set_title(f"{label}\n\n") + plot_cfm(cfs_matrix, ax, labels=["N", "Y"]) + else: + plot_cfm(cms[0], axes, labels=labels) + + fig.tight_layout() + return fig + + +class SKLConfusionMatrix(Metric): + """ + Confusion Matrix metric for classification + """ + + def __init__(self, classnames=None, label_type: str = "multiclass", **kwargs): + super().__init__(**kwargs) + self.type = label_type + self.classnames = classnames + self.num_classes = ( + [i for i in range(len(self.classnames))] if classnames is not None else None + ) + + def value(self, output: Dict[str, Any], batch: Dict[str, Any]): + """ + Perform calculation based on prediction and targets + """ + output = output["outputs"] + target = batch["targets"] + + probs = softmax(output, axis=-1) + predictions = np.argmax(probs, axis=-1) + + if self.type == "multiclass": + values = confusion_matrix( + predictions, + target, + labels=self.num_classes, + normalize="pred", + ) + values = values[np.newaxis, :, :] + else: + values = multilabel_confusion_matrix( + predictions, target, labels=self.num_classes + ) + + fig = make_cm_fig(values, self.classnames) + return {"cfm": fig} diff --git a/theseus/tabular/classification/metrics/f1_score.py b/theseus/ml/metrics/f1_score.py similarity index 80% rename from theseus/tabular/classification/metrics/f1_score.py rename to theseus/ml/metrics/f1_score.py index 2299731..8301f7b 100644 --- a/theseus/tabular/classification/metrics/f1_score.py +++ b/theseus/ml/metrics/f1_score.py @@ -14,8 +14,6 @@ class SKLF1ScoreMetric(Metric): def __init__(self, average="weighted", **kwargs): super().__init__(**kwargs) self.average = average - self.preds = [] - self.targets = [] def value(self, outputs: Dict[str, Any], batch: Dict[str, Any]): """ @@ -24,8 +22,8 @@ def value(self, outputs: Dict[str, Any], batch: Dict[str, Any]): targets = batch["targets"] outputs = outputs["outputs"] - self.preds += np.argmax(outputs, axis=1).reshape(-1).tolist() - self.targets += targets.reshape(-1).tolist() + self.preds = np.argmax(outputs, axis=1).reshape(-1).tolist() + self.targets = targets.reshape(-1).tolist() score = f1_score(self.targets, self.preds, average=self.average) return {f"{self.average}-f1": score} diff --git a/theseus/ml/metrics/mcc.py b/theseus/ml/metrics/mcc.py new file mode 100644 index 0000000..21a15c9 --- /dev/null +++ b/theseus/ml/metrics/mcc.py @@ -0,0 +1,28 @@ +from typing import Any, Dict + +import numpy as np +from sklearn.metrics import matthews_corrcoef + +from theseus.base.metrics.metric_template import Metric + + +class SKLMCC(Metric): + """ + Mathew Correlation Coefficient + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def value(self, outputs: Dict[str, Any], batch: Dict[str, Any]): + """ + Perform calculation based on prediction and targets + """ + targets = batch["targets"] + outputs = outputs["outputs"] + + self.preds = np.argmax(outputs, axis=1).reshape(-1).tolist() + self.targets = targets.reshape(-1).tolist() + + score = matthews_corrcoef(self.targets, self.preds) + return {f"mcc": score} diff --git a/theseus/tabular/classification/metrics/precision_recall.py b/theseus/ml/metrics/precision_recall.py similarity index 100% rename from theseus/tabular/classification/metrics/precision_recall.py rename to theseus/ml/metrics/precision_recall.py diff --git a/theseus/tabular/classification/metrics/projection.py b/theseus/ml/metrics/projection.py similarity index 100% rename from theseus/tabular/classification/metrics/projection.py rename to theseus/ml/metrics/projection.py diff --git a/theseus/ml/metrics/roc_auc_score.py b/theseus/ml/metrics/roc_auc_score.py new file mode 100644 index 0000000..17b511b --- /dev/null +++ b/theseus/ml/metrics/roc_auc_score.py @@ -0,0 +1,69 @@ +from typing import Any, Dict + +import numpy as np +import scipy + +from theseus.base.metrics.metric_template import Metric + +try: + from scikitplot.metrics import plot_precision_recall_curve, plot_roc_curve + + has_scikitplot = True +except: + has_scikitplot = False +from sklearn.metrics import roc_auc_score + + +class SKLROCAUCScore(Metric): + """ + ROC AUC Score + """ + + def __init__( + self, + average: str = "weighted", + label_type: str = "ovr", + plot_curve: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.plot_curve = plot_curve + self.label_type = label_type + self.average = average + assert self.label_type in [ + "raise", + "ovr", + "ovo", + ], "Invalid type for multiclass ROC AUC score" + + def value(self, outputs: Dict[str, Any], batch: Dict[str, Any]): + """ + Perform calculation based on prediction and targets + """ + targets = batch["targets"] + outputs = outputs["outputs"] + + if self.label_type == "ovr": + outputs = scipy.special.softmax(outputs, axis=-1) + + self.preds = outputs.tolist() + self.targets = targets.reshape(-1).tolist() + + roc_auc_scr = roc_auc_score( + self.targets, self.preds, average=self.average, multi_class=self.label_type + ) + results = { + f"{self.average}-roc_auc_score": roc_auc_scr, + } + + if has_scikitplot and self.plot_curve: + roc_curve_fig = plot_roc_curve(self.targets, self.preds).get_figure() + pr_fig = plot_precision_recall_curve(self.targets, self.preds).get_figure() + results.update( + { + "roc_curve": roc_curve_fig, + "precision_recall_curve": pr_fig, + } + ) + + return results diff --git a/theseus/tabular/classification/models/__init__.py b/theseus/ml/models/__init__.py similarity index 100% rename from theseus/tabular/classification/models/__init__.py rename to theseus/ml/models/__init__.py diff --git a/theseus/tabular/classification/models/gbms.py b/theseus/ml/models/gbms.py similarity index 87% rename from theseus/tabular/classification/models/gbms.py rename to theseus/ml/models/gbms.py index 5038433..6c09ff9 100644 --- a/theseus/tabular/classification/models/gbms.py +++ b/theseus/ml/models/gbms.py @@ -1,6 +1,7 @@ import catboost as cb import lightgbm as lgb import xgboost as xgb +from omegaconf import DictConfig, OmegaConf from theseus.base.utilities.loggers.observer import LoggerObserver @@ -9,8 +10,14 @@ class GBClassifiers: def __init__( - self, model_name, num_classes, model_config={}, training_params={}, **kwargs + self, + model_name, + num_classes, + model_config: DictConfig = {}, + training_params={}, + **kwargs, ): + OmegaConf.set_struct(model_config, False) self.training_params = training_params self.model_name = model_name self.num_classes = num_classes @@ -31,8 +38,8 @@ def get_model(self): def fit(self, trainset, valset, **kwargs): X, y = trainset self.model.fit( - X, - y, + X.copy(), + y.copy(), eval_set=[trainset, valset], # eval_set=[(trainset, 'train'), (valset, 'validation')], **self.training_params, diff --git a/theseus/ml/pipeline.py b/theseus/ml/pipeline.py new file mode 100644 index 0000000..6ea1b63 --- /dev/null +++ b/theseus/ml/pipeline.py @@ -0,0 +1,191 @@ +import os +from datetime import datetime + +from omegaconf import DictConfig, OmegaConf + +from theseus.base.utilities.folder import get_new_folder_name +from theseus.base.utilities.getter import get_instance, get_instance_recursively +from theseus.base.utilities.loggers import FileLogger, ImageWriter, LoggerObserver +from theseus.base.utilities.seed import seed_everything +from theseus.ml.callbacks import CALLBACKS_REGISTRY +from theseus.ml.datasets import DATALOADER_REGISTRY, DATASET_REGISTRY +from theseus.ml.metrics import METRIC_REGISTRY +from theseus.ml.models import MODEL_REGISTRY +from theseus.ml.preprocessors import TRANSFORM_REGISTRY +from theseus.ml.trainer import TRAINER_REGISTRY + + +class MLPipeline(object): + """docstring for Pipeline.""" + + def __init__(self, opt: DictConfig): + self.opt = opt + self.seed = self.opt["global"].get("seed", 1702) + seed_everything(self.seed) + self.initialized = False + + def init_globals(self): + # Main Loggers + self.logger = LoggerObserver.getLogger("main") + + # Global variables + self.exp_name = self.opt["global"].get("exp_name", None) + self.exist_ok = self.opt["global"].get("exist_ok", False) + self.debug = self.opt["global"].get("debug", False) + self.resume = self.opt["global"].get("resume", None) + self.pretrained = self.opt["global"].get("pretrained", None) + self.transform_cfg = self.opt["global"].get("cfg_transform", None) + + # Experiment name + if self.exp_name: + self.savedir = os.path.join( + self.opt["global"].get("save_dir", "runs"), self.exp_name + ) + if not self.exist_ok: + self.savedir = get_new_folder_name(self.savedir) + else: + self.savedir = os.path.join( + self.opt["global"].get("save_dir", "runs"), + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), + ) + os.makedirs(self.savedir, exist_ok=True) + + # Logging to files + file_logger = FileLogger(__name__, self.savedir, debug=self.debug) + self.logger.subscribe(file_logger) + + # Logging images + image_logger = ImageWriter(self.savedir) + self.logger.subscribe(image_logger) + + self.transform_cfg = self.opt.get("augmentations", None) + + # Logging out configs + self.logger.text("\n" + OmegaConf.to_yaml(self.opt), level=LoggerObserver.INFO) + self.logger.text( + f"Everything will be saved to {self.savedir}", + level=LoggerObserver.INFO, + ) + + def init_registry(self): + self.callbacks_registry = CALLBACKS_REGISTRY + self.transform_registry = TRANSFORM_REGISTRY + self.model_registry = MODEL_REGISTRY + self.metric_registry = METRIC_REGISTRY + self.trainer_registry = TRAINER_REGISTRY + self.dataset_registry = DATASET_REGISTRY + self.dataloader_registry = DATALOADER_REGISTRY + self.logger.text("Overidding registry in pipeline...", LoggerObserver.INFO) + + def init_model(self): + classnames = self.val_dataset["classnames"] + num_classes = len(classnames) + self.model = get_instance( + self.opt["model"], num_classes=num_classes, registry=self.model_registry + ) + + def init_train_dataloader(self): + self.transform = get_instance_recursively( + self.transform_cfg, registry=self.transform_registry + ) + self.train_dataset = get_instance_recursively( + self.opt["data"]["dataset"]["train"], + registry=self.dataset_registry, + transform=self.transform["train"], + ).load_data() + + self.logger.text( + f"Training shape: {self.train_dataset['inputs'].shape}", + level=LoggerObserver.INFO, + ) + + def init_metrics(self): + CLASSNAMES = getattr(self.val_dataset, "classnames", None) + self.metrics = get_instance_recursively( + self.opt["metrics"], + registry=self.metric_registry, + num_classes=len(CLASSNAMES) if CLASSNAMES is not None else None, + classnames=CLASSNAMES, + ) + + def init_callbacks(self): + callbacks = get_instance_recursively( + self.opt["callbacks"], + save_dir=getattr(self, "savedir", "runs"), + resume=getattr(self, "resume", None), + config_dict=self.opt, + registry=self.callbacks_registry, + ) + return callbacks + + def init_validation_dataloader(self): + self.transform = get_instance_recursively( + self.transform_cfg, registry=self.transform_registry + ) + self.val_dataset = get_instance_recursively( + self.opt["data"]["dataset"]["val"], + registry=self.dataset_registry, + transform=self.transform["val"], + ).load_data() + + classnames = self.val_dataset["classnames"] + num_classes = len(classnames) + + self.logger.text( + f"Validation shape: {self.val_dataset['inputs'].shape}", + level=LoggerObserver.INFO, + ) + self.logger.text( + f"Number of classes: {num_classes}", + level=LoggerObserver.INFO, + ) + + def init_trainer(self, callbacks=None): + self.trainer = get_instance( + self.opt["trainer"], + model=self.model, + trainset=getattr(self, "train_dataset", None), + valset=getattr(self, "val_dataset", None), + metrics=self.metrics, + callbacks=callbacks, + registry=self.trainer_registry, + ) + + def init_loading(self): + if getattr(self, "pretrained", None): + self.model.load_model(self.pretrained) + + def init_pipeline(self, train=False): + if self.initialized: + return + self.init_globals() + self.init_registry() + if train: + self.init_train_dataloader() + self.init_validation_dataloader() + self.init_model() + self.init_loading() + self.init_metrics() + callbacks = self.init_callbacks() + self.save_configs() + else: + self.init_validation_dataloader() + self.init_model() + self.init_metrics() + self.init_loading() + callbacks = [] + + self.init_trainer(callbacks=callbacks) + self.initialized = True + + def save_configs(self): + with open(os.path.join(self.savedir, "pipeline.yaml"), "w") as f: + OmegaConf.save(config=self.opt, f=f) + + def fit(self): + self.init_pipeline(train=True) + self.trainer.fit() + + def evaluate(self): + self.init_pipeline(train=False) + return self.trainer.validate() diff --git a/theseus/tabular/base/preprocessors/__init__.py b/theseus/ml/preprocessors/__init__.py similarity index 94% rename from theseus/tabular/base/preprocessors/__init__.py rename to theseus/ml/preprocessors/__init__.py index f8aee1d..1f2e0e8 100644 --- a/theseus/tabular/base/preprocessors/__init__.py +++ b/theseus/ml/preprocessors/__init__.py @@ -15,7 +15,7 @@ ) from .encoder import LabelEncode from .fill_nan import FillNaN -from .mapping import MapScreenToBinary +from .mapping import MapValue from .new_col import LambdaCreateColumn from .sort import SortBy from .splitter import Splitter @@ -39,4 +39,4 @@ TRANSFORM_REGISTRY.register(LambdaDropRows) TRANSFORM_REGISTRY.register(LambdaCreateColumn) TRANSFORM_REGISTRY.register(SortBy) -TRANSFORM_REGISTRY.register(MapScreenToBinary) +TRANSFORM_REGISTRY.register(MapValue) diff --git a/theseus/tabular/base/preprocessors/aggregation.py b/theseus/ml/preprocessors/aggregation.py similarity index 100% rename from theseus/tabular/base/preprocessors/aggregation.py rename to theseus/ml/preprocessors/aggregation.py diff --git a/theseus/ml/preprocessors/base.py b/theseus/ml/preprocessors/base.py new file mode 100644 index 0000000..41b4ada --- /dev/null +++ b/theseus/ml/preprocessors/base.py @@ -0,0 +1,63 @@ +import pandas as pd +from tqdm import tqdm + +from theseus.base.utilities.loggers import LoggerObserver + +tqdm.pandas() +from .name_filter import FilterColumnNames + +LOGGER = LoggerObserver.getLogger("main") + +try: + from pandarallel import pandarallel + + pandarallel.initialize(progress_bar=True) + use_parallel = True +except: + use_parallel = False + LOGGER.text( + "pandarallel should be installed for parallerization. Using normal apply-function instead", + level=LoggerObserver.WARN, + ) + + +class Preprocessor: + def __init__( + self, column_names=None, exclude_columns=None, verbose=False, **kwargs + ): + self.verbose = verbose + self.column_names = column_names + + self.filter = None + if column_names is not None: + self.filter = FilterColumnNames( + patterns=column_names, excludes=exclude_columns + ) + + def apply(self, df, function, parallel=True, axis=0, show_progress=True): + + df_func = df.apply + if use_parallel and parallel: + if not isinstance(df, pd.core.groupby.SeriesGroupBy): + df_func = df.parallel_apply + else: + if show_progress: + df_func = df.progress_apply + + if isinstance(df, pd.DataFrame): + kwargs = {"axis": axis} + else: + kwargs = {} + + return df_func(function, **kwargs) + + def prerun(self, df): + if self.filter is not None: + self.column_names = self.filter.run(df) + + def run(self, df): + return df + + def log(self, text, level=LoggerObserver.INFO): + if self.verbose: + LOGGER.text(text, level=level) diff --git a/theseus/tabular/base/preprocessors/categorize.py b/theseus/ml/preprocessors/categorize.py similarity index 100% rename from theseus/tabular/base/preprocessors/categorize.py rename to theseus/ml/preprocessors/categorize.py diff --git a/theseus/tabular/base/preprocessors/compose.py b/theseus/ml/preprocessors/compose.py similarity index 100% rename from theseus/tabular/base/preprocessors/compose.py rename to theseus/ml/preprocessors/compose.py diff --git a/theseus/tabular/base/preprocessors/csv_saver.py b/theseus/ml/preprocessors/csv_saver.py similarity index 100% rename from theseus/tabular/base/preprocessors/csv_saver.py rename to theseus/ml/preprocessors/csv_saver.py diff --git a/theseus/tabular/base/preprocessors/datetime.py b/theseus/ml/preprocessors/datetime.py similarity index 100% rename from theseus/tabular/base/preprocessors/datetime.py rename to theseus/ml/preprocessors/datetime.py diff --git a/theseus/tabular/base/preprocessors/drop_col.py b/theseus/ml/preprocessors/drop_col.py similarity index 95% rename from theseus/tabular/base/preprocessors/drop_col.py rename to theseus/ml/preprocessors/drop_col.py index 85d5cf4..1e2f351 100644 --- a/theseus/tabular/base/preprocessors/drop_col.py +++ b/theseus/ml/preprocessors/drop_col.py @@ -23,9 +23,8 @@ def __init__(self, lambda_func, **kwargs): def run(self, df): self.prerun(df) - ori_size = df.shape[0] - df = df.drop(df[df.apply(self.lambda_func, axis=1)].index) + df = df.drop(df[self.apply(df, self.lambda_func, parallel=True, axis=1)].index) dropped_size = ori_size - df.shape[0] self.log(f"Dropped {dropped_size} rows based on lambda function") return df diff --git a/theseus/tabular/base/preprocessors/encoder.py b/theseus/ml/preprocessors/encoder.py similarity index 68% rename from theseus/tabular/base/preprocessors/encoder.py rename to theseus/ml/preprocessors/encoder.py index ad5dd50..38906c9 100644 --- a/theseus/tabular/base/preprocessors/encoder.py +++ b/theseus/ml/preprocessors/encoder.py @@ -23,6 +23,7 @@ def __init__(self, encoder_type="le", save_folder=None, **kwargs): self.encoder_type = encoder_type self.save_folder = save_folder + self.mapping_dict = {} if self.encoder_type == "le": self.encoder = LabelEncoder() @@ -31,6 +32,10 @@ def __init__(self, encoder_type="le", save_folder=None, **kwargs): else: self.encoder = OrdinalEncoder() + @classmethod + def from_json(cls, json_path: str): + return cls(json_path=json_path, encoder_type="json_mapping") + def create_mapping_dict(self, column_name): le_name_mapping = dict( zip( @@ -45,22 +50,32 @@ def create_mapping_dict(self, column_name): open(osp.join(self.save_folder, column_name + ".json"), "w"), indent=4, ) + return le_name_mapping + + def encode_corpus(self, df): + for column_name in self.column_names: + df[column_name] = self.encoder.fit_transform(df[column_name].values).copy() + mapping_dict = self.create_mapping_dict(column_name) + self.mapping_dict[column_name] = mapping_dict + return df + + def encode_query(self, df): + for column_name in self.column_names: + df[column_name] = self.apply( + df[column_name], lambda x: self.mapping_dict[column_name].get(x, -1) + ).copy() + return df def run(self, df): self.prerun(df) - if self.column_names is not None: - for column_name in self.column_names: - df[column_name] = self.encoder.fit_transform(df[column_name].values) - self.create_mapping_dict(column_name) - else: + + if self.column_names is None: self.log( "Column names not specified. Automatically label encode columns with non-defined types", level=LoggerObserver.WARN, ) self.column_names = [col for col, dt in df.dtypes.items() if dt == object] - for column_name in self.column_names: - df[column_name] = self.encoder.fit_transform(df[column_name].values) - self.create_mapping_dict(column_name) + self.encode_corpus(df) self.log(f"Label-encoded columns: {self.column_names}") return df diff --git a/theseus/tabular/base/preprocessors/fill_nan.py b/theseus/ml/preprocessors/fill_nan.py similarity index 100% rename from theseus/tabular/base/preprocessors/fill_nan.py rename to theseus/ml/preprocessors/fill_nan.py diff --git a/theseus/tabular/base/preprocessors/mapping.py b/theseus/ml/preprocessors/mapping.py similarity index 94% rename from theseus/tabular/base/preprocessors/mapping.py rename to theseus/ml/preprocessors/mapping.py index ecf63f9..c8bf335 100644 --- a/theseus/tabular/base/preprocessors/mapping.py +++ b/theseus/ml/preprocessors/mapping.py @@ -5,7 +5,7 @@ LOGGER = LoggerObserver.getLogger("main") -class MapScreenToBinary(Preprocessor): +class MapValue(Preprocessor): """ mapping_dict should be dict of dicts; each inside dict is a mapping dict """ diff --git a/theseus/tabular/base/preprocessors/name_filter.py b/theseus/ml/preprocessors/name_filter.py similarity index 100% rename from theseus/tabular/base/preprocessors/name_filter.py rename to theseus/ml/preprocessors/name_filter.py diff --git a/theseus/tabular/base/preprocessors/new_col.py b/theseus/ml/preprocessors/new_col.py similarity index 100% rename from theseus/tabular/base/preprocessors/new_col.py rename to theseus/ml/preprocessors/new_col.py diff --git a/theseus/tabular/base/preprocessors/sort.py b/theseus/ml/preprocessors/sort.py similarity index 100% rename from theseus/tabular/base/preprocessors/sort.py rename to theseus/ml/preprocessors/sort.py diff --git a/theseus/tabular/base/preprocessors/splitter.py b/theseus/ml/preprocessors/splitter.py similarity index 75% rename from theseus/tabular/base/preprocessors/splitter.py rename to theseus/ml/preprocessors/splitter.py index d9163e0..3a45a30 100644 --- a/theseus/tabular/base/preprocessors/splitter.py +++ b/theseus/ml/preprocessors/splitter.py @@ -1,5 +1,6 @@ import os import os.path as osp +import random from sklearn.model_selection import StratifiedKFold, train_test_split @@ -27,6 +28,7 @@ def __init__( "default", "stratified", "stratifiedkfold", + "unique", ], "splitter type not supported" self.splitter_type = splitter_type @@ -40,6 +42,7 @@ def __init__( if self.splitter_type == "stratified": assert label_column is not None, "Label column should be specified" self.splitter = train_test_split + self.ratio = ratio elif self.splitter_type == "stratifiedkfold": assert label_column is not None, "Label column should be specified" assert n_splits is not None, "number of splits should be specified" @@ -49,6 +52,10 @@ def __init__( elif self.splitter_type == "default": assert ratio is not None, "should specify ratio" self.ratio = ratio + elif self.splitter_type == "unique": + assert ratio is not None, "should specify ratio" + self.splitter = random.sample + self.ratio = ratio def run(self, df): num_samples, num_features = df.shape @@ -59,8 +66,21 @@ def run(self, df): val_df.to_csv(osp.join(self.save_folder, "val.csv"), index=False) elif self.splitter_type == "stratified": train_df, val_df = self.splitter( - df, stratify=df[[self.label_column]], random_state=self.seed + df, + stratify=df[[self.label_column]], + random_state=self.seed, + train_size=self.ratio, + ) + train_df.to_csv(osp.join(self.save_folder, "train.csv"), index=False) + val_df.to_csv(osp.join(self.save_folder, "val.csv"), index=False) + elif self.splitter_type == "unique": + unique_values = df[self.label_column].unique().tolist() + num_unique_samples = len(unique_values) + train_idx = self.splitter( + unique_values, int(num_unique_samples * self.ratio) ) + train_df = df[df[self.label_column].isin(train_idx)] + val_df = df[~df[self.label_column].isin(train_idx)] train_df.to_csv(osp.join(self.save_folder, "train.csv"), index=False) val_df.to_csv(osp.join(self.save_folder, "val.csv"), index=False) else: diff --git a/theseus/tabular/base/preprocessors/standardize.py b/theseus/ml/preprocessors/standardize.py similarity index 100% rename from theseus/tabular/base/preprocessors/standardize.py rename to theseus/ml/preprocessors/standardize.py diff --git a/theseus/tabular/base/reduction/lda.py b/theseus/ml/reduction/lda.py similarity index 100% rename from theseus/tabular/base/reduction/lda.py rename to theseus/ml/reduction/lda.py diff --git a/theseus/tabular/base/reduction/pca.py b/theseus/ml/reduction/pca.py similarity index 100% rename from theseus/tabular/base/reduction/pca.py rename to theseus/ml/reduction/pca.py diff --git a/theseus/tabular/base/reduction/tsne.py b/theseus/ml/reduction/tsne.py similarity index 100% rename from theseus/tabular/base/reduction/tsne.py rename to theseus/ml/reduction/tsne.py diff --git a/theseus/tabular/classification/trainer/__init__.py b/theseus/ml/trainer/__init__.py similarity index 100% rename from theseus/tabular/classification/trainer/__init__.py rename to theseus/ml/trainer/__init__.py diff --git a/theseus/tabular/classification/trainer/ml_trainer.py b/theseus/ml/trainer/ml_trainer.py similarity index 90% rename from theseus/tabular/classification/trainer/ml_trainer.py rename to theseus/ml/trainer/ml_trainer.py index aa1f72b..00bbe0a 100644 --- a/theseus/tabular/classification/trainer/ml_trainer.py +++ b/theseus/ml/trainer/ml_trainer.py @@ -1,5 +1,5 @@ -from theseus.base.callbacks import CallbacksList from theseus.base.utilities.loggers.observer import LoggerObserver +from theseus.ml.callbacks import CallbacksList LOGGER = LoggerObserver.getLogger("main") @@ -34,10 +34,10 @@ def fit(self): {"trainset": self.trainset, "valset": self.valset}, ) - self.callbacks.run("on_val_epoch_start") - metric_dict = self.evaluate_epoch() + self.callbacks.run("on_validation_epoch_start") + metric_dict = self.validate() self.callbacks.run( - "on_val_epoch_end", + "on_validation_epoch_end", { "iters": 0, "trainset": self.trainset, @@ -47,7 +47,7 @@ def fit(self): ) self.callbacks.run("on_finish") - def evaluate_epoch(self): + def validate(self): """ Perform validation one epoch """ diff --git a/theseus/tabular/base/utilities/pprint.py b/theseus/ml/utilities/pprint.py similarity index 100% rename from theseus/tabular/base/utilities/pprint.py rename to theseus/ml/utilities/pprint.py diff --git a/theseus/nlp/base/preprocessors/basic_processors.py b/theseus/nlp/base/preprocessors/basic_processors.py index 9731f42..5c65a80 100644 --- a/theseus/nlp/base/preprocessors/basic_processors.py +++ b/theseus/nlp/base/preprocessors/basic_processors.py @@ -2,6 +2,8 @@ import string import nltk + +nltk.download("punkt") from nltk.corpus import stopwords from nltk.stem import SnowballStemmer from nltk.stem.wordnet import WordNetLemmatizer diff --git a/theseus/nlp/base/preprocessors/vocabulary.py b/theseus/nlp/base/preprocessors/vocabulary.py index 9c7ad4c..e7e3a23 100644 --- a/theseus/nlp/base/preprocessors/vocabulary.py +++ b/theseus/nlp/base/preprocessors/vocabulary.py @@ -1,5 +1,6 @@ import os.path as osp import pickle +from typing import * from theseus.base.utilities.loggers import LoggerObserver @@ -12,11 +13,14 @@ def __init__( max_size=None, min_freq=None, max_freq=None, - special_tokens={}, + special_tokens=None, replace=False, pkl_path=None, unk_word="", pad_word="", + sos_word="", + eos_word="", + use_special_tokens=True, ): self.pkl_path = pkl_path @@ -27,13 +31,33 @@ def __init__( self.max_size = max_size self.unk_word = unk_word self.pad_word = pad_word + self.sos_word = sos_word + self.eos_word = eos_word + self.use_special_tokens = use_special_tokens + self.truncation_side = "right" self.init_vocab() if self.pkl_path is not None: - with open(self.pkl_path, "rb") as f: - vocab = pickle.load(f) - self.word2idx = vocab.word2idx - self.idx2word = vocab.idx2word + self.load_pickle(self.pkl_path) + + def load_pickle(self, vocab_path): + with open(vocab_path, "rb") as f: + vocab = pickle.load(f) + self.word2idx = vocab.word2idx + self.idx2word = vocab.idx2word + self.frequency = vocab.frequency + self.special_tokens = vocab.special_tokens + self.replace = vocab.replace + self.min_freq = vocab.min_freq + self.max_freq = vocab.max_freq + self.max_size = vocab.max_size + self.unk_word = vocab.unk_word + self.pad_word = vocab.pad_word + self.sos_word = vocab.sos_word + self.eos_word = vocab.eos_word + self.vocab_size = vocab.vocab_size + self.truncation_side = self.truncation_side + LOGGER.text( "Vocabulary successfully loaded from vocab.pkl file!", level=LoggerObserver.INFO, @@ -52,9 +76,8 @@ def save_vocab(self, save_path): f.write(term + "\n") LOGGER.text(f"Save pickle to {save_path}", level=LoggerObserver.INFO) - def build_vocab(self, list_tokens): + def build_vocab(self, list_tokens, add_special_tokens=True): """Populate the dictionaries for converting tokens to integers (and vice-versa).""" - for tok in list_tokens: if not tok in self.frequency: self.frequency[tok] = 0 @@ -77,17 +100,19 @@ def build_vocab(self, list_tokens): if self.max_size is not None: list_tokens = list_tokens[: self.max_size] + if self.use_special_tokens: + self.add_special_tokens() for tok in list_tokens: self.add_word(tok) - self.add_special_tokens() - def init_vocab(self): """Initialize the dictionaries for converting tokens to integers (and vice-versa).""" self.word2idx = {} self.idx2word = {} self.frequency = {} - self.idx = 0 + self.vocab_size = 0 + if self.special_tokens is None: + self.special_tokens = {} def add_word(self, word, index=None): """Add a token to the vocabulary.""" @@ -98,18 +123,18 @@ def add_word(self, word, index=None): assert isinstance(index, int), "Index must be type int" if index is None: - index = self.idx + index = self.vocab_size if not word in self.word2idx.keys() and not index in self.idx2word.keys(): - self.word2idx[word] = self.idx - self.idx2word[self.idx] = word - self.idx += 1 + self.word2idx[word] = self.vocab_size + self.idx2word[self.vocab_size] = word + self.vocab_size += 1 elif not word in self.word2idx.keys() and index in self.idx2word.keys(): if self.replace: old_word = self.idx2word[index] - self.word2idx[old_word] = self.idx - self.idx2word[self.idx] = old_word - self.idx += 1 + self.word2idx[old_word] = self.vocab_size + self.idx2word[self.vocab_size] = old_word + self.vocab_size += 1 self.word2idx[word] = index self.idx2word[index] = word @@ -140,16 +165,21 @@ def add_word(self, word, index=None): raise ValueError() def add_special_tokens(self): - if self.unk_word not in self.special_tokens.keys(): - self.special_tokens.update({self.unk_word: self.idx}) - self.idx += 1 + if self.sos_word not in self.special_tokens.keys(): + self.add_word(self.sos_word) + self.special_tokens.update({self.sos_word: self.vocab_size}) + + if self.eos_word not in self.special_tokens.keys(): + self.add_word(self.eos_word) + self.special_tokens.update({self.eos_word: self.vocab_size}) if self.pad_word not in self.special_tokens.keys(): - self.special_tokens.update({self.pad_word: self.idx}) - self.idx += 1 + self.add_word(self.pad_word) + self.special_tokens.update({self.pad_word: self.vocab_size}) - for token, index in self.special_tokens.items(): - self.add_word(token, index) + if self.unk_word not in self.special_tokens.keys(): + self.add_word(self.unk_word) + self.special_tokens.update({self.unk_word: self.vocab_size}) def get_pad_token_id(self): return self.word2idx[self.pad_word] @@ -157,21 +187,116 @@ def get_pad_token_id(self): def get_unk_token_id(self): return self.word2idx[self.unk_word] - def encode_tokens(self, lists_of_tokens): + def get_sos_token_id(self): + return self.word2idx[self.sos_word] + + def get_eos_token_id(self): + return self.word2idx[self.eos_word] + + def encode_tokens(self, lists_of_tokens, **kwargs): """ Batch of list of tokens """ + + add_special_tokens = kwargs.get("add_special_tokens", False) + max_length = kwargs.get("max_length", None) + return_token_type_ids = kwargs.get("return_token_type_ids", False) + truncation = kwargs.get("truncation", False) + + if return_token_type_ids: + token_type_idss = [] + + if max_length == "max": + max_length = max([len(x) for x in lists_of_tokens]) + encoded_list = [] for token_list in lists_of_tokens: - batch = [] + if add_special_tokens and self.use_special_tokens: + batch = [self.__call__(self.sos_word)] + else: + batch = [] for token in token_list: batch.append(self.__call__(token)) + + if add_special_tokens and self.use_special_tokens: + batch.append(self.__call__(self.eos_word)) + + if max_length is not None: + if len(batch) > max_length: + if truncation: + if add_special_tokens and self.use_special_tokens: + if self.truncation_side == "right": + batch = batch[: max_length - 1] + batch.append(self.__call__(self.eos_word)) + else: + batch = batch[1 - max_length :] + batch.insert(0, self.__call__(self.sos_word)) + else: + if self.truncation_side == "right": + batch = batch[:max_length] + else: + batch = batch[-max_length:] + else: + LOGGER.text( + f"Sequence is longer than max_length. Please use truncation=True", + level=LoggerObserver.ERROR, + ) + raise ValueError() + if ( + len(batch) < max_length + and add_special_tokens + and self.use_special_tokens + ): + batch += [self.__call__(self.pad_word)] * (max_length - len(batch)) + + if return_token_type_ids: + token_type_ids = [ + 0 if batch[tk] != self.__call__(self.pad_word) else 1 + for tk in range(len(batch)) + ] + token_type_idss.append(token_type_ids) + encoded_list.append(batch) - return encoded_list + + if return_token_type_ids: + return {"input_ids": encoded_list, "token_type_ids": token_type_idss} + else: + return { + "input_ids": encoded_list, + } + + def decode_tokens(self, list_of_ids: List, remove_special_tokens: bool = True): + """ + Batch of list of ids + """ + decoded_list = [] + for ids in list_of_ids: + if remove_special_tokens: + batch = [ + self.itos(idx) + for idx in ids + if idx + not in [ + self.get_pad_token_id(), + self.get_sos_token_id(), + self.get_eos_token_id(), + ] + ] + else: + batch = [self.itos(idx) for idx in ids] + decoded_list.append(batch) + return decoded_list + + def encode_texts(self, text, **kwargs): + if isinstance(text, str): + text = [text] + + tokenized_texts = [s.split(kwargs.get("delimeter", " ")) for s in text] + return self.encode_tokens(tokenized_texts, **kwargs) def itos(self, idx): if not idx in self.idx2word: - return self.idx2word[self.unk_word] + return self.idx2word[self.__call__(self.unk_word)] return self.idx2word[idx] def __call__(self, word): @@ -180,4 +305,4 @@ def __call__(self, word): return self.word2idx[word] def __len__(self): - return max(list(self.word2idx.values())) + return max(list(self.word2idx.values())) + 1 # add zero value diff --git a/theseus/nlp/retrieval/models/__init__.py b/theseus/nlp/retrieval/models/__init__.py index baa3861..eb83465 100644 --- a/theseus/nlp/retrieval/models/__init__.py +++ b/theseus/nlp/retrieval/models/__init__.py @@ -1,4 +1,5 @@ # from .bm25 import BM25Retrieval # from .ensembler import EnsembleRetriever # from .retrieval_tfms import RetrievalModel +from .spacy_encoder import SpacyEncoder from .tf_idf import TFIDFEncoder diff --git a/theseus/nlp/retrieval/models/spacy_encoder.py b/theseus/nlp/retrieval/models/spacy_encoder.py new file mode 100644 index 0000000..5846ff6 --- /dev/null +++ b/theseus/nlp/retrieval/models/spacy_encoder.py @@ -0,0 +1,16 @@ +import spacy + +from .base import BaseRetrieval + + +class SpacyEncoder(BaseRetrieval): + def __init__(self, name) -> None: + super().__init__() + # "en_core_sci_lg" "en_core_med7_lg" + self.encoder = spacy.load(name) + + def encode_corpus(self, text): + return self.encoder(text).vector + + def encode_query(self, text): + return self.encoder(text).vector diff --git a/theseus/nlp/retrieval/models/tf_idf.py b/theseus/nlp/retrieval/models/tf_idf.py index 7ce61f8..c4cd43d 100644 --- a/theseus/nlp/retrieval/models/tf_idf.py +++ b/theseus/nlp/retrieval/models/tf_idf.py @@ -1,6 +1,6 @@ import os.path as osp import pickle -from typing import Dict, List +from typing import * from sklearn.feature_extraction.text import TfidfVectorizer @@ -12,7 +12,13 @@ def identity_tokenizer(text): class TFIDFEncoder(BaseRetrieval): - def __init__(self, min_df: int = 0, max_df: int = 1.0, model_path: str = None): + def __init__( + self, + min_df: int = 0, + max_df: int = 1.0, + model_path: str = None, + ngram_range: Tuple[int] = (1, 1), + ): super().__init__() @@ -29,6 +35,7 @@ def __init__(self, min_df: int = 0, max_df: int = 1.0, model_path: str = None): max_df=max_df, lowercase=True, norm="l2", + ngram_range=ngram_range, ) def save_model(self, save_path): diff --git a/theseus/opt.py b/theseus/opt.py deleted file mode 100644 index 543770c..0000000 --- a/theseus/opt.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -Modified from https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/tools/program.py -""" - -import json -from argparse import ArgumentParser, RawDescriptionHelpFormatter -from copy import deepcopy - -import yaml - -from theseus.base.utilities.loading import load_yaml -from theseus.base.utilities.loggers.observer import LoggerObserver - -LOGGER = LoggerObserver.getLogger("main") - - -class Config(dict): - """Single level attribute dict, recursive""" - - _depth = 0 - _yaml_paths = [] - - # def __new__(class_, yaml_path, *args, **kwargs): - # if yaml_path in class_._yaml_paths: - # LOGGER.text( - # "Circular includes detected in YAML initialization!", - # level=LoggerObserver.CRITICAL, - # ) - # raise ValueError() - # class_._yaml_paths.append(yaml_path) - # return dict.__new__(class_, yaml_path, *args, **kwargs) - - def __init__(self, yaml_path): - super(Config, self).__init__() - - config = load_yaml(yaml_path) - - if "includes" in config.keys(): - final_config = {} - for include_yaml_path in config["includes"]: - tmp_config = Config(include_yaml_path) - final_config.update(tmp_config) - - final_config.update(config) - final_config.pop("includes") - super(Config, self).update(final_config) - else: - super(Config, self).update(config) - - # self._yaml_paths.pop(-1) # the last successful yaml will be popped out - - def __getattr__(self, key): - if key in self: - return self[key] - raise AttributeError("object has no attribute '{}'".format(key)) - - def save_yaml(self, path): - LOGGER.text(f"Saving config to {path}...", level=LoggerObserver.DEBUG) - with open(path, "w") as f: - yaml.dump(dict(self), f, default_flow_style=False, sort_keys=False) - - @classmethod - def load_yaml(cls, path): - LOGGER.text(f"Loading config from {path}...", level=LoggerObserver.DEBUG) - return cls(path) - - def __repr__(self) -> str: - return str(json.dumps(dict(self), sort_keys=False, indent=4)) - - -class Opts(ArgumentParser): - def __init__(self): - super(Opts, self).__init__(formatter_class=RawDescriptionHelpFormatter) - self.add_argument("-c", "--config", help="configuration file to use") - self.add_argument( - "-o", "--opt", nargs="+", help="override configuration options" - ) - - def parse_args(self, argv=None): - args = super(Opts, self).parse_args(argv) - assert args.config is not None, "Please specify --config=configure_file_path." - args.opt = self._parse_opt(args.opt) - - config = Config(args.config) - config = self.override(config, args.opt) - return config - - def _parse_opt(self, opts): - config = {} - if not opts: - return config - for s in opts: - s = s.strip() - try: - k, v = s.split("=") - except ValueError: - LOGGER.text( - "Invalid option: {}, options should be in the format of key=value".format( - s - ), - level=LoggerObserver.ERROR, - ) - raise ValueError() - - config[k] = yaml.load(v, Loader=yaml.Loader) - return config - - def override(self, global_config, overriden): - """ - Merge config into global config. - Args: - config (dict): Config to be merged. - Returns: global config - """ - LOGGER.text("Overriding configuration...", LoggerObserver.DEBUG) - for key, value in overriden.items(): - if "." not in key: - if isinstance(value, dict) and key in global_config: - global_config[key].update(value) - else: - if key in global_config.keys(): - global_config[key] = value - else: - LOGGER.text( - f"'{key}' not found in config", - level=LoggerObserver.WARN, - ) - else: - sub_keys = key.split(".") - assert ( - sub_keys[0] in global_config - ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format( - global_config.keys(), sub_keys[0] - ) - cur = global_config[sub_keys[0]] - for idx, sub_key in enumerate(sub_keys[1:]): - if idx == len(sub_keys) - 2: - if sub_key in cur.keys(): - cur[sub_key] = value - else: - LOGGER.text( - f"'{key}' not found in config", - level=LoggerObserver.WARN, - ) - else: - cur = cur[sub_key] - return global_config diff --git a/theseus/tabular/base/preprocessors/base.py b/theseus/tabular/base/preprocessors/base.py deleted file mode 100644 index 432f60d..0000000 --- a/theseus/tabular/base/preprocessors/base.py +++ /dev/null @@ -1,51 +0,0 @@ -from theseus.base.utilities.loggers import LoggerObserver - -from .name_filter import FilterColumnNames - -# try: -# from pandarallel import pandarallel - -# pandarallel.initialize() -# use_parallel = True -# except: -use_parallel = False - -LOGGER = LoggerObserver.getLogger("main") - - -class Preprocessor: - def __init__( - self, column_names=None, exclude_columns=None, verbose=False, **kwargs - ): - self.verbose = verbose - self.column_names = column_names - - self.filter = None - if column_names is not None: - self.filter = FilterColumnNames( - patterns=column_names, excludes=exclude_columns - ) - - def apply(self, df, function, parallel=True, axis=0): - if parallel: - if not use_parallel: - LOGGER.text( - "pandarallel should be installed for parallerization. Using normal apply-function instead", - level=LoggerObserver.WARN, - ) - return df.apply(function, axis=axis) - else: - return df.parallel_apply(function, axis=axis) - else: - return df.apply(function, axis=axis) - - def prerun(self, df): - if self.filter is not None: - self.column_names = self.filter.run(df) - - def run(self, df): - return df - - def log(self, text, level=LoggerObserver.INFO): - if self.verbose: - LOGGER.text(text, level=level) diff --git a/theseus/tabular/classification/__init__.py b/theseus/tabular/classification/__init__.py deleted file mode 100644 index ecb2a2a..0000000 --- a/theseus/tabular/classification/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .callbacks import * -from .metrics import * -from .models import * -from .pipeline import TabularPipeline -from .trainer import * diff --git a/theseus/tabular/classification/callbacks/__init__.py b/theseus/tabular/classification/callbacks/__init__.py deleted file mode 100644 index a9fcb54..0000000 --- a/theseus/tabular/classification/callbacks/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from theseus.base.callbacks import CALLBACKS_REGISTRY - -from .checkpoint_callbacks import SKLearnCheckpointCallbacks -from .explainer import * - -CALLBACKS_REGISTRY.register(SKLearnCheckpointCallbacks) -CALLBACKS_REGISTRY.register(ShapValueExplainer) -CALLBACKS_REGISTRY.register(PermutationImportance) -CALLBACKS_REGISTRY.register(PartialDependencePlots) -CALLBACKS_REGISTRY.register(LIMEExplainer) diff --git a/theseus/tabular/classification/pipeline.py b/theseus/tabular/classification/pipeline.py deleted file mode 100644 index 2a3cc69..0000000 --- a/theseus/tabular/classification/pipeline.py +++ /dev/null @@ -1,131 +0,0 @@ -from theseus.base.pipeline import BasePipeline -from theseus.base.utilities.getter import get_instance, get_instance_recursively -from theseus.base.utilities.loggers import LoggerObserver -from theseus.opt import Config -from theseus.tabular.base.preprocessors import TRANSFORM_REGISTRY -from theseus.tabular.classification.callbacks import CALLBACKS_REGISTRY -from theseus.tabular.classification.datasets import ( - DATALOADER_REGISTRY, - DATASET_REGISTRY, -) -from theseus.tabular.classification.metrics import METRIC_REGISTRY -from theseus.tabular.classification.models import MODEL_REGISTRY -from theseus.tabular.classification.trainer import TRAINER_REGISTRY - - -class TabularPipeline(BasePipeline): - """docstring for Pipeline.""" - - def __init__(self, opt: Config): - super(TabularPipeline, self).__init__(opt) - - def init_registry(self): - super().init_registry() - self.callbacks_registry = CALLBACKS_REGISTRY - self.transform_registry = TRANSFORM_REGISTRY - self.model_registry = MODEL_REGISTRY - self.metric_registry = METRIC_REGISTRY - self.trainer_registry = TRAINER_REGISTRY - self.dataset_registry = DATASET_REGISTRY - self.dataloader_registry = DATALOADER_REGISTRY - self.logger.text("Overidding registry in pipeline...", LoggerObserver.INFO) - - def init_model(self): - classnames = self.val_dataset["classnames"] - num_classes = len(classnames) - self.model = get_instance( - self.opt["model"], num_classes=num_classes, registry=self.model_registry - ) - - def init_train_dataloader(self): - self.transform = get_instance_recursively( - self.transform_cfg, registry=self.transform_registry - ) - self.train_dataset = get_instance_recursively( - self.opt["data"]["dataset"]["train"], - registry=self.dataset_registry, - transform=self.transform["train"], - ).load_data() - - self.logger.text( - f"Training shape: {self.train_dataset['inputs'].shape}", - level=LoggerObserver.INFO, - ) - - def init_validation_dataloader(self): - self.transform = get_instance_recursively( - self.transform_cfg, registry=self.transform_registry - ) - self.val_dataset = get_instance_recursively( - self.opt["data"]["dataset"]["val"], - registry=self.dataset_registry, - transform=self.transform["val"], - ).load_data() - - classnames = self.val_dataset["classnames"] - num_classes = len(classnames) - - self.logger.text( - f"Validation shape: {self.val_dataset['inputs'].shape}", - level=LoggerObserver.INFO, - ) - self.logger.text( - f"Number of classes: {num_classes}", - level=LoggerObserver.INFO, - ) - - def init_trainer(self, callbacks=None): - self.trainer = get_instance( - self.opt["trainer"], - model=self.model, - trainset=getattr(self, "train_dataset", None), - valset=getattr(self, "val_dataset", None), - metrics=self.metrics, - callbacks=callbacks, - registry=self.trainer_registry, - ) - - def init_loading(self): - if getattr(self, "pretrained", None): - self.model.load_model(self.pretrained) - - def init_pipeline(self, train=False): - if self.initialized: - return - self.init_globals() - self.init_registry() - if train: - self.init_train_dataloader() - self.init_validation_dataloader() - self.init_model() - self.init_loading() - self.init_metrics() - callbacks = self.init_callbacks() - self.save_configs() - else: - self.init_validation_dataloader() - self.init_model() - self.init_metrics() - self.init_loading() - callbacks = [] - - if getattr(self, "metrics", None): - callbacks.insert( - 0, - self.callbacks_registry.get("MetricLoggerCallbacks")( - save_dir=self.savedir - ), - ) - if getattr(self, "criterion", None): - callbacks.insert( - 0, - self.callbacks_registry.get("LossLoggerCallbacks")( - print_interval=self.opt["global"].get("print_interval", None), - ), - ) - if self.debug: - callbacks.insert(0, self.callbacks_registry.get("DebugCallbacks")()) - callbacks.insert(0, self.callbacks_registry.get("TimerCallbacks")()) - - self.init_trainer(callbacks=callbacks) - self.initialized = True