Project 2 of the Machine Learning course given at EPFL in Fall 2021.
The goal of this project is to segment satellite images by detecting roads. Our classifier is a convolutional neural network based on the UNet architecture.
Team members:
- Quentin Deschamps
- Emilien Seiler
- Louis Le Guillouzic

The table below summarizes the F1 score and accuracy obtained by our AIcrowd submissions:

| Model | Data augmentation | Postprocessing | F1 score | Accuracy | AIcrowd submission |
|---|---|---|---|---|---|
| UNet | Yes | Yes | 0.901 | 0.946 | #169349 |
| UNet | Yes | No | 0.900 | 0.945 | #168760 |
| Nested UNet | Yes | No | 0.896 | 0.943 | #169077 |
| SegNet | Yes | No | 0.895 | 0.944 | #169078 |
| UNet | No | No | 0.853 | 0.922 | #169073 |
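
The F1 score and accuracy in the table above are standard binary classification metrics. As an illustration only (this is not the code of metrics.py, and AIcrowd may compute the scores per patch rather than per pixel), the minimal sketch below computes both from flattened 0/1 labels, where 1 marks road and 0 background. The function name `f1_and_accuracy` is hypothetical.

```python
def f1_and_accuracy(y_true, y_pred):
    """Return (F1 score, accuracy) for two equal-length sequences of 0/1 labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # road predicted as road
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # background predicted as road
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # road missed
    correct = sum(1 for t, p in pairs if t == p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = correct / len(pairs) if pairs else 0.0
    return f1, accuracy


# Example: 4 road pixels and 6 background pixels, one miss and one false alarm
truth = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
pred = [1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
print(f1_and_accuracy(truth, pred))  # → (0.75, 0.8)
```

Since road pixels are usually a minority of each image, accuracy alone can stay high for a classifier that misses many roads, which is why the F1 score is the more informative of the two columns.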
To run the code of this project, install the libraries listed in the
requirements.txt file with this command:
pip3 install -r requirements.txt
Dependencies:
- matplotlib
- numpy
- pillow
- scikit-image
- torch
- torchvision
- tqdm
The scripts directory contains the scripts that perform the different tasks of
the project.
To reproduce our submission on AIcrowd, move into the scripts folder and run:
python3 run.py
This command creates the predicted mask of each test image in the
out/submission directory and writes the submission file to
out/submission.csv.
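
The layout of out/submission.csv is set by the AIcrowd challenge and is not described in this README. The sketch below assumes the format of the course's sample submission helper: a header `id,prediction`, then one line per 16x16 patch with an id of the form `<image number>_<x>_<y>`, where a patch is labeled road (1) when its mean pixel value exceeds 0.25. The patch size, threshold, and the names `patch_to_label`, `mask_to_rows`, and `write_submission` are assumptions, not the code of submission.py.

```python
import csv
import io

PATCH_SIZE = 16              # assumed patch size in pixels
FOREGROUND_THRESHOLD = 0.25  # assumed mean value above which a patch counts as road


def patch_to_label(mask, top, left, size=PATCH_SIZE):
    """Return 1 if the mean value of the patch exceeds the threshold, else 0."""
    values = [v for row in mask[top:top + size] for v in row[left:left + size]]
    return int(sum(values) / len(values) > FOREGROUND_THRESHOLD)


def mask_to_rows(image_number, mask, size=PATCH_SIZE):
    """Yield (id, label) pairs, one per patch, scanning column by column."""
    height, width = len(mask), len(mask[0])
    for left in range(0, width, size):
        for top in range(0, height, size):
            yield f"{image_number:03d}_{left}_{top}", patch_to_label(mask, top, left, size)


def write_submission(masks, stream):
    """Write masks (dict: image number -> 2D list of floats in [0, 1]) as CSV."""
    writer = csv.writer(stream, lineterminator="\n")
    writer.writerow(["id", "prediction"])
    for number, mask in sorted(masks.items()):
        writer.writerows(mask_to_rows(number, mask))


# Example: a 32x32 mask whose left half is road
mask = [[1.0 if x < 16 else 0.0 for x in range(32)] for _ in range(32)]
buffer = io.StringIO()
write_submission({1: mask}, buffer)
print(buffer.getvalue())
```

With this example the output lists `001_0_0,1` and `001_0_16,1` for the left patches and `001_16_0,0` and `001_16_16,0` for the right ones.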
To create the augmented training dataset, you can run:
python3 augment_data.py
The generated images are saved in the data/training_augmented directory. If
this directory already exists, its images are overwritten.
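
This README does not list the transformations applied by augment_data.py. As a hypothetical illustration, a common choice for satellite tiles is to generate the four rotations by multiples of 90 degrees and their mirror images, applying the same transform to an image and to its ground-truth mask so that the labels stay aligned. The sketch below does this on images stored as 2D lists; `rotate90`, `flip_horizontal`, `dihedral_variants`, and `augment_pair` are hypothetical names.

```python
def rotate90(img):
    """Rotate a 2D list 90 degrees clockwise."""
    return [list(row) for row in zip(*img[::-1])]


def flip_horizontal(img):
    """Mirror a 2D list from left to right."""
    return [row[::-1] for row in img]


def dihedral_variants(img):
    """Return the 4 rotations of img and the mirror image of each (8 variants)."""
    variants = []
    current = img
    for _ in range(4):
        variants.append(current)
        variants.append(flip_horizontal(current))
        current = rotate90(current)
    return variants


def augment_pair(image, mask):
    """Apply identical transforms to an image and its mask, keeping them aligned."""
    return list(zip(dihedral_variants(image), dihedral_variants(mask)))


pairs = augment_pair([[1, 2], [3, 4]], [[0, 1], [0, 0]])
print(len(pairs))  # → 8
```

Rotations and flips preserve the road geometry, so they increase the diversity of the training set without creating unrealistic images.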
To train a model, you can use the train.py script:
python3 train.py
To see the different options, run python3 train.py --help.
To create the predicted masks using a trained model, you can use the
predict.py script:
python3 predict.py
To see the different options, run python3 predict.py --help.
The pickle files created during training can be plotted with the
plot_metrics.py script:
python3 plot_metrics.py --file FILE
where FILE is the path to one of these pickle files.
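
The content of these pickle files is not documented here. Assuming, hypothetically, that each file stores a dictionary mapping a metric name to its per-epoch values, the sketch below shows how such a file could be written and read back with the standard pickle module. The helper names, the file name, and the keys `train_loss` and `val_f1` are illustrative only.

```python
import os
import pickle
import tempfile


def save_metrics(metrics, path):
    """Serialize a dict of metric name -> list of per-epoch values."""
    with open(path, "wb") as f:
        pickle.dump(metrics, f)


def load_metrics(path):
    """Load the dict written by save_metrics."""
    with open(path, "rb") as f:
        return pickle.load(f)


def best_epoch(metrics, key):
    """Return (epoch number starting at 1, value) of the highest value for key."""
    values = metrics[key]
    index = max(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]


# Hypothetical training history
history = {"train_loss": [0.41, 0.29, 0.25], "val_f1": [0.81, 0.86, 0.85]}
path = os.path.join(tempfile.gettempdir(), "metrics_example.pickle")
save_metrics(history, path)
print(best_epoch(load_metrics(path), "val_f1"))  # → (2, 0.86)
```

Only load pickle files you trust: unpickling can execute arbitrary code.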
This is the structure of the repository:
- data: contains the datasets
- docs: contains the documentation
- figs: contains the figures
- notebooks: contains the notebooks
- scripts: contains the main scripts
  - augment_data.py: creates the augmented dataset
  - config.py: helper functions to configure paths
  - plot_metrics.py: plots metrics
  - predict.py: makes predictions using a trained model
  - run.py: makes predictions for AIcrowd
  - train.py: trains the model
- src: source code
  - models: neural network models
    - nested_unet.py: Nested UNet implementation
    - segnet.py: SegNet implementation
    - unet.py: UNet implementation
  - data_augmentation.py: creation of the augmented dataset
  - datasets.py: custom dataset class for satellite images
  - loss.py: custom loss functions
  - metrics.py: score and performance functions
  - path.py: management of paths and archives
  - plot_utils.py: plotting utilities based on matplotlib
  - postprocessing.py: postprocessing functions to improve predictions
  - predicter.py: predicter class to make predictions using a trained model
  - submission.py: submission utilities
  - trainer.py: trainer class to train a model
See references.