Project 2 of the Machine Learning course given at EPFL in Fall 2021.
The goal of this project is to segment satellite images by detecting roads. Our classifier consists of a convolutional neural network called UNet.
- Quentin Deschamps
- Emilien Seiler
- Louis Le Guillouzic
Model | Data augmentation | Postprocessing | F1 score | Accuracy | Submission |
---|---|---|---|---|---|
UNet | Yes | Yes | 0.901 | 0.946 | #169349 |
UNet | Yes | No | 0.900 | 0.945 | #168760 |
Nested UNet | Yes | No | 0.896 | 0.943 | #169077 |
SegNet | Yes | No | 0.895 | 0.944 | #169078 |
UNet | No | No | 0.853 | 0.922 | #169073 |
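
The two scores in the table can be illustrated with a small, dependency-free sketch. It is not the code in src/metrics.py; it only assumes that each pixel (or patch) carries a binary label, 1 for road and 0 for background, and applies the standard definitions of F1 score and accuracy. The function names are hypothetical.

```python
def confusion_counts(pred, truth):
    """Count true/false positives and negatives over two sequences of 0/1 labels."""
    tp = fp = fn = tn = 0
    for p, t in zip(pred, truth):
        if p and t:
            tp += 1
        elif p and not t:
            fp += 1
        elif not p and t:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn


def f1_score(pred, truth):
    """F1 = 2*TP / (2*TP + FP + FN), the harmonic mean of precision and recall."""
    tp, fp, fn, _ = confusion_counts(pred, truth)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def accuracy(pred, truth):
    """Fraction of labels predicted correctly, road and background alike."""
    tp, fp, fn, tn = confusion_counts(pred, truth)
    total = tp + fp + fn + tn
    return (tp + tn) / total if total else 0.0


# Toy example: 8 flattened labels, 3 of which are road in the ground truth.
pred = [1, 1, 0, 0, 1, 0, 0, 0]
truth = [1, 0, 0, 0, 1, 1, 0, 0]
print(f"F1 = {f1_score(pred, truth):.3f}, accuracy = {accuracy(pred, truth):.3f}")
```

Because background labels usually outnumber road labels, accuracy is helped by the many easy true negatives, while F1 ignores them. This is one plausible reason the accuracy column is consistently higher than the F1 column.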
To run the code of this project, you need to install the libraries listed in the requirements.txt file. You can perform the installation using this command:
pip3 install -r requirements.txt
Dependencies:
- matplotlib
- numpy
- pillow
- scikit-image
- torch
- torchvision
- tqdm
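
After installation, a quick check like the one below can confirm that each dependency is importable. This is a minimal sketch and not part of the repository. The mapping from package name to import name is the only subtle point: pillow is imported as PIL and scikit-image as skimage. The helper name missing_packages is hypothetical.

```python
from importlib.util import find_spec

# pip package name -> top-level module name used in "import ..."
IMPORT_NAMES = {
    "matplotlib": "matplotlib",
    "numpy": "numpy",
    "pillow": "PIL",
    "scikit-image": "skimage",
    "torch": "torch",
    "torchvision": "torchvision",
    "tqdm": "tqdm",
}


def missing_packages(import_names=IMPORT_NAMES):
    """Return the pip names of packages whose module cannot be found."""
    return [pkg for pkg, module in import_names.items() if find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All dependencies are importable.")
```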
The scripts directory contains scripts to perform the different tasks of the project.
To reproduce our submission on AIcrowd, move to the scripts folder and run:
python3 run.py
This command will create the predicted mask for each test image in the out/submission directory. The CSV file produced for submission will be out/submission.csv.
To create the augmented training dataset, you can run:
python3 augment_data.py
The images created will be in the data/training_augmented directory. If this directory already exists, the images it contains will be overwritten.
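
This README does not say which transformations augment_data.py applies. As an illustration only, the sketch below assumes a common choice for aerial imagery: the four 90-degree rotations, each with and without a horizontal flip. The key point it shows is that every transformation must be applied identically to the image and to its ground-truth mask. Images are represented as plain nested lists to keep the example dependency-free, and all function names are hypothetical.

```python
def rotate90(grid):
    """Rotate a 2D grid (list of rows) by 90 degrees clockwise."""
    return [list(row) for row in zip(*grid[::-1])]


def flip_horizontal(grid):
    """Mirror a 2D grid left to right."""
    return [row[::-1] for row in grid]


def augment_pair(image, mask):
    """Return 8 (image, mask) pairs: 4 rotations, each plain and flipped.

    The same transformation is applied to both grids so that every
    road pixel in the mask stays aligned with the image.
    """
    pairs = []
    for _ in range(4):
        pairs.append((image, mask))
        pairs.append((flip_horizontal(image), flip_horizontal(mask)))
        image, mask = rotate90(image), rotate90(mask)
    return pairs


# Toy 2x2 "image" and its mask (1 marks a road pixel).
image = [[1, 2], [3, 4]]
mask = [[0, 1], [0, 0]]
pairs = augment_pair(image, mask)
print(len(pairs), "augmented pairs")
```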
To train a model, you can use the train.py script:
python3 train.py
To see the different options, run python3 train.py --help.
To create the predicted masks using a trained model, you can use the predict.py script:
python3 predict.py
To see the different options, run python3 predict.py --help.
The pickle files created during training can be visualized using the plot_metrics.py script:
python3 plot_metrics.py --file FILE
FILE must be a pickle file.
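
The layout of these pickle files is not described here. To inspect one without plotting, a sketch along the following lines may help. It assumes, purely for illustration, that the file holds a dictionary mapping metric names to lists with one value per epoch. The keys, the values and the best_epoch helper are all made up for the example.

```python
import pickle
import tempfile
from pathlib import Path


def best_epoch(history, metric="f1"):
    """Return (epoch, value) for the highest value of a metric; epochs start at 1."""
    values = history[metric]
    index = max(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]


# Hypothetical training history; real files may be structured differently.
history = {"loss": [0.62, 0.41, 0.35], "f1": [0.71, 0.84, 0.82]}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "metrics.pickle"
    # Write the history the way a training loop might.
    with open(path, "wb") as f:
        pickle.dump(history, f)
    # Read it back, as a plotting or inspection script would.
    with open(path, "rb") as f:
        loaded = pickle.load(f)

epoch, value = best_epoch(loaded)
print(f"Best F1 at epoch {epoch}: {value}")
```

Since unpickling can execute arbitrary code, only load pickle files that you created yourself or that come from a trusted source.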
This is the structure of the repository:
- data: contains the datasets
- docs: contains the documentation
- figs: contains the figures
- notebooks: contains the notebooks
- scripts: contains the main scripts
  - augment_data.py: create the augmented dataset
  - config.py: helper functions to configure paths
  - plot_metrics.py: plot metrics
  - predict.py: make predictions using a trained model
  - run.py: make predictions for AIcrowd
  - train.py: train the model
- src: source code
  - models: neural network models
    - nested_unet.py: nested UNet implementation
    - segnet.py: SegNet implementation
    - unet.py: UNet implementation
  - data_augmentation.py: creation of the augmented dataset
  - datasets.py: custom dataset class for satellite images
  - loss.py: custom loss functions
  - metrics.py: score and performance functions
  - path.py: paths and archives management
  - plot_utils.py: plot utils using matplotlib
  - postprocessing.py: postprocessing functions to improve predictions
  - predicter.py: predicter class to make predictions using a trained model
  - submission.py: submission utils
  - trainer.py: trainer class to train a model
See references.