PyTorch implementation for the training procedure described in Forced Spatial Attention for Driver Foot Activity Classification.
- Clone this repository
- Install Pipenv:
pip3 install pipenv
- Install all requirements and dependencies in a new virtual environment using Pipenv:
cd Forced-Spatial-Attention
pipenv install
- Get the links for the desired PyTorch and Torchvision wheels from here and install them in the Pipenv virtual environment as follows:
pipenv install https://download.pytorch.org/whl/cu100/torch-1.2.0-cp36-cp36m-manylinux1_x86_64.whl
pipenv install https://download.pytorch.org/whl/cu100/torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl
- Download the trainval dataset for driver foot activity classification using this link.
- Extract the data.
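After the setup steps above, a quick check like the following can confirm that the required packages are visible from inside the virtual environment. This is a minimal sketch and not part of the repository: `check_environment` is a hypothetical helper name, and the Python 3.6 comparison is only inferred from the `cp36` tag in the wheel filenames above.

```python
import importlib.util
import sys


def check_environment(packages=("torch", "torchvision")):
    """Report which packages are installed, without actually importing them."""
    report = {
        "python": "{}.{}".format(*sys.version_info[:2]),
        # Assumption: the cp36 wheels listed above target CPython 3.6.
        "python_matches_cp36": sys.version_info[:2] == (3, 6),
    }
    for name in packages:
        # find_spec returns None when a top-level package cannot be found.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

Saved under a name of your choice (for example `check_env.py`, a hypothetical filename), it can be run with `pipenv run python check_env.py` from the repository root.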
The prescribed two-stage training procedure for the classification network can be carried out as follows:
pipenv shell # activate virtual environment
python train_stage1.py --dataset-root-path=/path/to/dataset/ --snapshot=./weights/squeezenet1_1_imagenet.pth --version=1_1 --FSA
python train_stage2.py --dataset-root-path=/path/to/dataset/ --snapshot=/path/to/snapshot/from/stage1/training --version=1_1 --FSA
exit # exit virtual environment
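The two stages above can also be chained from a small Python driver, sketched below. It only reuses the script names and flags shown in the commands above; the function names are hypothetical, and the stage 1 snapshot path is a placeholder that must point to whatever file stage 1 actually writes.

```python
import subprocess
import sys


def build_training_commands(dataset_root, imagenet_snapshot, stage1_snapshot,
                            version="1_1"):
    """Return the stage 1 and stage 2 command lines as argument lists."""
    def command(script, snapshot):
        return [
            sys.executable, script,
            "--dataset-root-path=" + dataset_root,
            "--snapshot=" + snapshot,
            "--version=" + version,
            "--FSA",
        ]

    # Stage 1 starts from the ImageNet weights, stage 2 from the stage 1 snapshot.
    return (command("train_stage1.py", imagenet_snapshot),
            command("train_stage2.py", stage1_snapshot))


def run_two_stage(dataset_root, imagenet_snapshot, stage1_snapshot):
    """Run both stages in order, stopping if either one fails."""
    for cmd in build_training_commands(dataset_root, imagenet_snapshot,
                                       stage1_snapshot):
        subprocess.run(cmd, check=True)
```

Because `check=True` raises on a non-zero exit code, stage 2 is not started if stage 1 fails.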
Pretrained weights for SqueezeNet v1.1 trained using the two-stage FSA loss can be found here. Inference can be carried out using demo.py as follows:
pipenv shell # activate virtual environment
python demo.py --video=/path/to/dataset/foot.mp4 --snapshot=/path/to/snapshot --version=1_1
exit # exit virtual environment
Config files, logs, results, and snapshots from running the above scripts will be stored in the Forced-Spatial-Attention/experiments folder by default.
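To pick up a snapshot for the next step, a helper like the one below can locate the most recently written snapshot under the experiments folder. This is a sketch under assumptions: the layout of the experiments folder is not documented here, so the helper searches recursively, and the `*.pth` pattern is only a guess based on the extension of the ImageNet weights file. `latest_snapshot` is a hypothetical name.

```python
from pathlib import Path


def latest_snapshot(experiments_dir="experiments", pattern="*.pth"):
    """Return the most recently modified file matching pattern, or None."""
    root = Path(experiments_dir)
    if not root.is_dir():
        return None
    # Assumption: snapshots are regular files matching the given pattern.
    candidates = [p for p in root.rglob(pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    print(latest_snapshot())
```

The returned path can then be passed as the `--snapshot` argument to `train_stage2.py` or `demo.py`.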