This project uses deep learning to classify plant diseases from leaf images. It combines a custom Convolutional Neural Network (CNN), Grad-CAM for interpretability, and early stopping to end training once validation performance stops improving. Implemented in PyTorch, the model is trained and evaluated on the PlantVillage dataset, reaching over 98% accuracy, and its decisions can be inspected through Grad-CAM visualizations.
This repository documents the code in detail, along with the outputs it produces.
- Custom Dataset Loader: Handles multi-class image datasets with tailored data augmentation and transformations.
- CNN Architecture: Custom-designed CNN model optimized with techniques like batch normalization and dropout.
- Early Stopping: Prevents overfitting by halting training when validation performance plateaus.
- Learning Rate Scheduler: Automatically adjusts the learning rate when validation loss stagnates.
- Grad-CAM Visualization: Highlights key regions of input images that influence predictions, providing model interpretability.
- Performance Metrics:
  - Confusion matrix
  - Classification report
  - Training curves for loss, accuracy, and learning rate
- Inference: Predicts the disease in unseen images and reports confidence scores for the top-3 predictions.
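The learning-rate behaviour described above maps naturally onto PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`. The project's exact scheduler and settings are not stated here, so the following is a minimal sketch assuming `ReduceLROnPlateau` with `factor=0.5` and `patience=2`, a placeholder `nn.Linear` model, and made-up validation losses:

```python
from torch import nn, optim

# Placeholder model, only so the optimizer has parameters to manage.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Halve the learning rate when validation loss has not improved for more
# than 2 consecutive epochs. mode="min" because lower loss is better.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

# Made-up validation losses that stop improving after the second epoch.
val_losses = [1.0, 0.8, 0.8, 0.8, 0.8, 0.8]
lr_history = []
for val_loss in val_losses:
    scheduler.step(val_loss)  # pass the monitored metric once per epoch
    lr_history.append(optimizer.param_groups[0]["lr"])

print(lr_history)
```

With these losses the rate stays at `1e-3` for four epochs and is halved on the fifth, after three consecutive epochs without improvement (more than `patience=2`).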
The project utilizes the PlantVillage dataset, which contains images of healthy and diseased plant leaves across multiple classes. The dataset structure is expected as follows:
```
/root_dir
├── Class_1
│   ├── img1.jpg
│   ├── img2.jpg
│   └── ...
├── Class_2
│   ├── img1.jpg
│   ├── img2.jpg
│   └── ...
└── ...
```
Update `root_dir` in the code to the location of your dataset.
Install the following Python libraries before running the code:
```bash
pip install torch torchvision numpy pillow scikit-learn matplotlib seaborn plotly opencv-python
```
- Class: `CustomImageDataset`
  - Responsible for loading images from the dataset.
  - Applies transformations (e.g., resizing, normalization, augmentation).
  - Maps class names to corresponding indices for model compatibility.
- Class: `CNNClassifier`
  - Defines a Convolutional Neural Network (CNN) with:
    - Convolutional Layers: Extract spatial features from images.
    - Batch Normalization: Stabilizes training and accelerates convergence.
    - Dropout Layers: Reduce overfitting by randomly disabling neurons during training.
    - Fully Connected Layers: Perform the final classification.
  - Designed for flexibility and robustness in classification tasks.
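The architecture described above (convolution, batch normalization, dropout, fully connected layers) can be sketched as follows. The number of blocks, channel widths, dropout rate, and the use of global average pooling are assumptions chosen for illustration; the project's actual `CNNClassifier` may differ.

```python
import torch
from torch import nn


class CNNClassifier(nn.Module):
    """Sketch: conv -> batch norm -> ReLU -> pool blocks, then a dropout-regularized head."""

    def __init__(self, num_classes, in_channels=3, dropout=0.5):
        super().__init__()

        def conv_block(c_in, c_out):
            # One feature-extraction stage; MaxPool2d halves height and width.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            )

        self.features = nn.Sequential(
            conv_block(in_channels, 32),
            conv_block(32, 64),
            conv_block(64, 128),
        )
        # Global average pooling makes the head independent of input size.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),  # raw logits, one per class
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.pool(self.features(x)))
```

Because `AdaptiveAvgPool2d(1)` collapses the spatial dimensions, the same head accepts 128x128 or 64x64 inputs; a flatten-then-linear head would instead need the feature-map size hard-coded.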
- Function: `train_model`
  - Trains the CNN model on the specified dataset.
  - Saves the best model weights based on validation accuracy.
  - Implements early stopping:
    - Stops training when the validation loss has not improved for a predefined number of consecutive epochs.
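The early-stopping rule can be kept separate from the training loop. A minimal sketch, assuming the monitored quantity is validation loss and using a hypothetical `EarlyStopping` helper whose `patience` and `min_delta` parameter names are chosen for this example:

```python
class EarlyStopping:
    """Hypothetical helper: stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.counter = 0              # epochs since the last improvement
        self.should_stop = False

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if it improved."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return True
        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return False


# Example with made-up validation losses.
stopper = EarlyStopping(patience=3)
losses = [0.9, 0.7, 0.72, 0.71, 0.73, 0.5]
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    improved = stopper.step(loss)
    # In a real loop, an improvement is where checkpointing would happen,
    # e.g. torch.save(model.state_dict(), path).
    if stopper.should_stop:
        stopped_at = epoch
        break
```

Here training halts after epoch 5, three epochs after the best loss of 0.7, so the later 0.5 is never seen. Since the project checkpoints on validation accuracy, the save condition in a full `train_model` would compare accuracy rather than reuse this loss-based flag.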
- Function: `evaluate_model`
  - Evaluates the trained model on the test dataset.
  - Generates:
    - A detailed classification report.
    - A confusion matrix to visualize performance across classes.
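The evaluation step can be sketched with scikit-learn's `confusion_matrix` and `classification_report`, since scikit-learn is already a dependency. The helper `collect_predictions` is hypothetical: it gathers true labels and argmax predictions from a data loader in eval mode. The three-class labels below are made-up data for illustration.

```python
import torch
from sklearn.metrics import classification_report, confusion_matrix


@torch.no_grad()  # gradients are not needed for evaluation
def collect_predictions(model, loader, device="cpu"):
    """Hypothetical helper: return (true labels, predicted labels) over a loader."""
    model.eval()  # disable dropout, use running batch-norm statistics
    y_true, y_pred = [], []
    for images, labels in loader:
        logits = model(images.to(device))
        y_pred.extend(logits.argmax(dim=1).cpu().tolist())
        y_true.extend(labels.tolist())
    return y_true, y_pred


# Made-up labels and predictions for three tomato classes.
class_names = ["Tomato_Bacterial_Spot", "Tomato_Early_Blight", "Tomato_YellowLeaf_Curl_Virus"]
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 2]

cm = confusion_matrix(y_true, y_pred)  # rows: true class, columns: predicted class
report = classification_report(y_true, y_pred, target_names=class_names, digits=3)
print(cm)
print(report)
```

To draw the matrix as a heatmap, `seaborn.heatmap(cm, annot=True, fmt="d", xticklabels=class_names, yticklabels=class_names)` is one option.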
- Function: `visualize_grad_cam`
  - Uses Grad-CAM (Gradient-weighted Class Activation Mapping) to:
    - Highlight important regions of an image influencing the model's predictions.
    - Provide insights into the model's decision-making process.
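Grad-CAM weights each feature map of a chosen convolutional layer by the spatially averaged gradient of the target class score, sums the weighted maps, and keeps the positive part. A minimal PyTorch sketch using a forward hook plus a tensor gradient hook follows; the function name `grad_cam` and its signature are assumptions, not the project's `visualize_grad_cam`. The target layer should be one whose output is not modified in place afterwards (for example, a whole convolutional block rather than a conv layer followed by `ReLU(inplace=True)`).

```python
import torch
import torch.nn.functional as F


def grad_cam(model, target_layer, image, class_idx=None):
    """Return (cam, class_idx): an (H, W) map in [0, 1] for one (C, H, W) image."""
    activations, gradients = {}, {}

    def save_grad(grad):
        gradients["value"] = grad

    def forward_hook(module, inputs, output):
        activations["value"] = output
        # Capture the gradient flowing back into this layer's output.
        output.register_hook(save_grad)

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.eval()
        logits = model(image.unsqueeze(0))            # add a batch dimension
        if class_idx is None:
            class_idx = int(logits.argmax(dim=1))     # explain the predicted class
        model.zero_grad()
        logits[0, class_idx].backward()               # d(class score)/d(activations)
    finally:
        handle.remove()

    acts = activations["value"][0]                    # (K, h, w)
    grads = gradients["value"][0]                     # (K, h, w)
    weights = grads.mean(dim=(1, 2))                  # one weight per feature map
    cam = F.relu((weights[:, None, None] * acts).sum(dim=0))
    # Upsample to the input resolution, then rescale to [0, 1].
    cam = F.interpolate(cam[None, None], size=image.shape[1:],
                        mode="bilinear", align_corners=False)[0, 0]
    cam = cam - cam.min()
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam.detach(), class_idx
```

The returned map has the same height and width as the input, so it can be overlaid directly on the (un-normalized) image.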
- Function: `predict_disease`
  - Predicts labels for unseen images.
  - Displays the top-3 class probabilities for better interpretability.
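Top-3 prediction reduces to a softmax over the logits followed by `torch.topk`. A minimal sketch, assuming the image has already been preprocessed into a `(C, H, W)` tensor with the same transforms used for validation; `predict_top_k` is a hypothetical name:

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def predict_top_k(model, image_tensor, class_names, k=3):
    """Return [(class_name, probability), ...] for the k most likely classes."""
    model.eval()
    logits = model(image_tensor.unsqueeze(0))   # (1, num_classes)
    probs = F.softmax(logits, dim=1)[0]         # probabilities summing to 1
    top_probs, top_idx = probs.topk(k)          # sorted, highest first
    return [(class_names[i], p) for i, p in zip(top_idx.tolist(), top_probs.tolist())]


# Usage, formatted like the example output in this README:
# for name, p in predict_top_k(model, image_tensor, class_names):
#     print(f"{name}: {p:.2%}")
```

The `:.2%` format multiplies by 100 and appends a percent sign, so a probability of 1.0 prints as `100.00%`; very small probabilities round to `0.00%`.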
- Graphs showing trends in:
  - Loss (training and validation).
  - Accuracy (training and validation).
  - Learning rate progression.
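The three training graphs can be produced from per-epoch values recorded during training. This sketch assumes a `history` dictionary with the hypothetical keys `train_loss`, `val_loss`, `train_acc`, `val_acc`, and `lr`; the project's own logging format is not specified here.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; omit this line in a notebook
import matplotlib.pyplot as plt


def plot_history(history):
    """Plot loss, accuracy, and learning rate per epoch from a dict of lists."""
    epochs = range(1, len(history["train_loss"]) + 1)
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    axes[0].plot(epochs, history["train_loss"], label="Training")
    axes[0].plot(epochs, history["val_loss"], label="Validation")
    axes[0].set(title="Loss", xlabel="Epoch", ylabel="Loss")

    axes[1].plot(epochs, history["train_acc"], label="Training")
    axes[1].plot(epochs, history["val_acc"], label="Validation")
    axes[1].set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy")

    # A log scale makes step-wise learning-rate reductions easy to see.
    axes[2].plot(epochs, history["lr"])
    axes[2].set(title="Learning rate", xlabel="Epoch", ylabel="LR", yscale="log")

    for ax in axes[:2]:
        ax.legend()
    fig.tight_layout()
    return fig
```

Calling `plot_history(history).savefig("training_curves.png")` would write the figure to disk (the file name is illustrative).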
- Loss Curves:
  - Steady decrease in both training and validation loss.
  - Indicates effective learning and optimization.
- Accuracy Curves:
  - Achieves over 98% accuracy.
- Minimal overfitting:
  - Training and validation metrics are closely aligned throughout.
- Purpose: Grad-CAM visualizations highlight the critical regions in the input image that significantly influence the model's predictions.
- Benefit: Adds interpretability to the model and helps validate its predictions by showing where the model is "looking" when making decisions.
Input: A sample diseased tomato leaf image.
Output:
- Predicted Disease: Tomato Bacterial Spot
- Confidence: 100.00%
- Top-3 Predictions:
  - Tomato_Bacterial_Spot: 100.00%
  - Tomato_Early_Blight: 0.00%
  - Tomato_YellowLeaf_Curl_Virus: 0.00%
Grad-CAM Heatmap:
Dataset Setup
- Download the PlantVillage dataset.
- Organize the dataset as described in the Dataset section.
- Update the `root_dir` path in the code to point to your dataset.
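The README does not say how the training, validation, and test subsets are formed. One common approach, shown here as an assumption, is a seeded `torch.utils.data.random_split`; the 70/15/15 proportions and the `split_dataset` helper are hypothetical.

```python
import torch
from torch.utils.data import random_split


def split_dataset(dataset, val_frac=0.15, test_frac=0.15, seed=42):
    """Hypothetical helper: split one dataset into train/val/test subsets."""
    n = len(dataset)
    n_val = int(n * val_frac)
    n_test = int(n * test_frac)
    n_train = n - n_val - n_test                    # remainder goes to training
    generator = torch.Generator().manual_seed(seed)  # reproducible split
    return random_split(dataset, [n_train, n_val, n_test], generator=generator)
```

If the dataset applies random augmentation in `__getitem__`, all three subsets inherit it. A common workaround is to build two dataset instances (augmented and non-augmented) and index them with the same split indices.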
Training the Model
Run the following command to train the model:
```bash
python main.py
```
The script will:
- Train the model and print the training metrics (loss, accuracy).
- Save the best-performing model weights.
- Display evaluation metrics such as confusion matrix and classification report.
Prediction on New Images
Replace `sample_image_path` with the path to your test image, then run the inference function to predict the disease and visualize the Grad-CAM heatmap.
- Dataset: PlantVillage Dataset
- Model: Custom CNN with Early Stopping and Grad-CAM for visualization
- Performance: Over 98% accuracy, with training and validation metrics closely aligned (minimal overfitting)
- Usage: Suitable for diagnosing plant diseases and assisting farmers with actionable insights