Plant Disease Detection

This project uses deep learning to classify plant diseases from leaf images. It combines a custom Convolutional Neural Network (CNN) architecture, Grad-CAM for interpretability, and early stopping, which halts training once validation loss stops improving. The model is implemented in PyTorch and is trained and evaluated on the widely used PlantVillage dataset, where it reaches over 98% accuracy. Grad-CAM heatmaps and training curves help explain its behavior.

This README describes the code in detail, along with the outputs it produces.


Key Features

  1. Custom Dataset Loader: Handles multi-class image datasets with tailored data augmentation and transformations.
  2. CNN Architecture: Custom-designed CNN that uses batch normalization and dropout to stabilize training and reduce overfitting.
  3. Early Stopping: Prevents overfitting by halting training when validation performance plateaus.
  4. Learning Rate Scheduler: Automatically adjusts the learning rate when validation loss stagnates.
  5. Grad-CAM Visualization: Highlights key regions of input images that influence predictions, providing model interpretability.
  6. Performance Metrics:
    • Confusion matrix
    • Classification report
    • Training curves for loss, accuracy, and learning rate
  7. Inference: Predicts diseases from unseen images and provides confidence scores for top-3 predictions.
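The learning rate scheduler in item 4 follows the reduce-on-plateau idea implemented by PyTorch's torch.optim.lr_scheduler.ReduceLROnPlateau, which is stepped once per epoch with scheduler.step(val_loss). The README does not say which scheduler or settings the code uses. The pure-Python sketch below only illustrates the rule. It uses assumed values (factor=0.1, patience=2) and leaves out the threshold and cooldown options of the PyTorch version. The class name PlateauLR is hypothetical.

```python
class PlateauLR:
    """Reduce-on-plateau learning-rate rule (simplified sketch).

    Mirrors the core behaviour of torch.optim.lr_scheduler.ReduceLROnPlateau
    in "min" mode, without its threshold and cooldown options.
    """

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 2, min_lr: float = 0.0):
        self.lr = lr
        self.factor = factor        # multiply the LR by this on a plateau
        self.patience = patience    # epochs without improvement tolerated before reducing
        self.min_lr = min_lr        # never go below this value
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> float:
        """Record one epoch's validation loss and return the learning rate to use next."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.bad_epochs = 0
        return self.lr


# Example with made-up validation losses: the loss stops improving at epoch 3,
# and after three non-improving epochs the LR drops from 1e-3 to 1e-4 at epoch 5.
scheduler = PlateauLR(lr=1e-3, factor=0.1, patience=2)
for epoch, loss in enumerate([0.90, 0.80, 0.80, 0.81, 0.82, 0.70], start=1):
    print(f"epoch {epoch}: lr={scheduler.step(loss):.0e}")
```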

Dataset

The project utilizes the PlantVillage dataset, which contains images of healthy and diseased plant leaves across multiple classes. The dataset structure is expected as follows:

/root_dir
   ├── Class_1
   │       ├── img1.jpg
   │       ├── img2.jpg
   │       └── ...
   ├── Class_2
   │       ├── img1.jpg
   │       ├── img2.jpg
   │       └── ...
   └── ...

Update the root_dir in the code to the location of your dataset.
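Before training, it can help to confirm that root_dir matches this layout and that no class folder is empty. The sketch below counts the image files in each class folder. The accepted extensions (.jpg, .jpeg, .png) are an assumption, and count_images_per_class is a hypothetical helper, not part of the project code.

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # assumed image extensions (case-insensitive)


def count_images_per_class(root_dir: str) -> dict[str, int]:
    """Count image files in each immediate subfolder (class) of root_dir."""
    counts: dict[str, int] = {}
    # Sorted so the report lists classes in a stable, alphabetical order.
    for class_dir in sorted(p for p in Path(root_dir).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        )
    return counts
```

Calling count_images_per_class("/root_dir") returns a dictionary that maps each class folder name to its image count, in sorted order. A zero count usually points to a misplaced folder or an unsupported file type.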


Requirements

Install the following Python libraries before running the code:

pip install torch torchvision numpy pillow scikit-learn matplotlib seaborn plotly opencv-python

Code Structure Overview

1. Dataset Preparation

  • Class: CustomImageDataset
    • Responsible for loading images from the dataset.
    • Applies transformations (e.g., resizing, normalization, augmentation).
    • Maps class names to corresponding indices for model compatibility.
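The following is a minimal sketch of a dataset class like CustomImageDataset. It assumes that class folders sorted alphabetically define the label indices, which is the same convention torchvision.datasets.ImageFolder uses. The sketch implements __len__ and __getitem__, the two methods that torch.utils.data.Dataset subclasses provide. In the project code the class would normally subclass torch.utils.data.Dataset. The loader parameter and the extension list are assumptions, added so the sketch can be exercised without Pillow or real images.

```python
from pathlib import Path
from typing import Callable, Optional

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}  # assumed image extensions


def pil_loader(path: str):
    """Default loader: open an image with Pillow and force 3-channel RGB."""
    from PIL import Image  # imported lazily so other loaders can be used without Pillow
    with Image.open(path) as img:
        return img.convert("RGB")  # convert() returns a loaded copy, safe after close


class CustomImageDataset:
    """Map-style dataset over root_dir/<class_name>/<image> (sketch)."""

    def __init__(self, root_dir: str, transform: Optional[Callable] = None,
                 loader: Callable = pil_loader):
        root = Path(root_dir)
        # Sorted folder names give a stable class -> index mapping.
        self.classes = sorted(p.name for p in root.iterdir() if p.is_dir())
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        # One (file path, label index) pair per image.
        self.samples = [
            (str(f), self.class_to_idx[name])
            for name in self.classes
            for f in sorted((root / name).iterdir())
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        ]
        self.transform = transform
        self.loader = loader

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        path, label = self.samples[idx]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)  # e.g. resize, augment, convert to tensor, normalize
        return image, label
```

For training, transform would typically be a torchvision.transforms.Compose pipeline, for example Resize, RandomHorizontalFlip, ToTensor and Normalize. The exact transforms the project uses are not listed in this README.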

2. CNN Architecture

  • Class: CNNClassifier
    • Defines a Convolutional Neural Network (CNN) with:
      • Convolutional Layers: Extract spatial features from images.
      • Batch Normalization: Stabilizes training and accelerates convergence.
      • Dropout Layers: Reduce overfitting by randomly disabling neurons during training.
      • Fully Connected Layers: Perform final classification.
    • Designed for flexibility and robustness in classification tasks.
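When writing a CNN like CNNClassifier, the input size of the first fully connected layer must equal the number of features left after the convolution and pooling blocks. The sketch below computes that number with the standard output-size formula that nn.Conv2d and nn.MaxPool2d use in floor mode. Every layer setting in it is hypothetical, because the README does not list the actual configuration: a 224 x 224 input, four blocks, 3 x 3 convolutions with padding 1, 2 x 2 max-pooling, and the channel widths shown.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output size along one spatial dimension for nn.Conv2d / nn.MaxPool2d (floor mode)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def flattened_features(input_size: int, channels: list[int]) -> int:
    """Flattened feature count after len(channels) blocks of
    Conv2d(k=3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2).

    BatchNorm and ReLU do not change the shape. The block layout is hypothetical.
    """
    size = input_size
    for _ in channels:
        size = conv2d_out(size, kernel=3, stride=1, padding=1)  # 3x3 conv with padding 1 keeps H and W
        size = conv2d_out(size, kernel=2, stride=2)             # 2x2 max-pool halves H and W
    return channels[-1] * size * size


print(flattened_features(224, [32, 64, 128, 256]))  # 224 -> 112 -> 56 -> 28 -> 14, so 256 * 14 * 14 = 50176
```

Computing this value instead of hard-coding it avoids the common shape-mismatch error when the input resolution or the number of blocks changes.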

3. Training

  • Function: train_model
    • Trains the CNN model using a specified dataset.
    • Saves the best model weights based on validation accuracy.
    • Implements Early Stopping:
      • Stops training when validation loss has not improved for a predefined number of consecutive epochs (the patience).
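Below is a minimal early-stopping helper that matches the behavior described above. It is written in plain Python so any training loop can call it. The class name, the patience and min_delta parameters, and their defaults are assumptions. The commented loop shows one way the helper could sit alongside saving the best weights by validation accuracy. The helper functions and the checkpoint file name in that loop are hypothetical.

```python
class EarlyStopping:
    """Signal a stop once validation loss has not improved by more than
    min_delta for `patience` consecutive epochs (sketch)."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0  # consecutive epochs without sufficient improvement

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# How it might be used inside a training loop (model, optimizer and data omitted):
#
# stopper = EarlyStopping(patience=5)
# best_acc = 0.0
# for epoch in range(num_epochs):
#     train_one_epoch(...)                    # hypothetical helper
#     val_loss, val_acc = validate(...)       # hypothetical helper
#     if val_acc > best_acc:                  # keep the best weights by validation accuracy
#         best_acc = val_acc
#         torch.save(model.state_dict(), "best_model.pth")
#     if stopper.step(val_loss):              # stop on a validation-loss plateau
#         break

# Runnable demo with made-up losses:
stopper = EarlyStopping(patience=3)
for epoch, loss in enumerate([0.90, 0.70, 0.65, 0.66, 0.65, 0.67, 0.60], start=1):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}")  # prints "stopping after epoch 6"
        break
```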

4. Evaluation

  • Function: evaluate_model
    • Evaluates the trained model on the test dataset.
    • Generates:
      • A detailed classification report.
      • A confusion matrix to visualize performance across classes.
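scikit-learn's confusion_matrix and classification_report can produce both outputs directly, and the matrix can be plotted with, for example, seaborn.heatmap. The small NumPy sketch below shows what the numbers in those reports mean. It builds the matrix with the same convention as scikit-learn, where rows are true classes and columns are predicted classes. It then derives per-class precision, recall and F1 from the matrix. The function names are hypothetical.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes: int) -> np.ndarray:
    """Rows = true class, columns = predicted class (sklearn convention)."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)  # count each (true, pred) pair
    return cm


def per_class_report(cm: np.ndarray):
    """Per-class precision, recall and F1 from a confusion matrix (0 where undefined)."""
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)  # column sums: how often each class was predicted
    actual = cm.sum(axis=1)     # row sums: true support of each class
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1
```

Overall accuracy is the trace of the matrix divided by its total. Large off-diagonal cells show which diseases the model confuses with each other.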

5. Grad-CAM Visualization

  • Function: visualize_grad_cam
    • Uses Grad-CAM (Gradient-weighted Class Activation Mapping) to:
      • Highlight important regions of an image influencing the model's predictions.
      • Provide insights into the model's decision-making process.
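The core Grad-CAM computation is short once two arrays are available. The first is the activations of a convolutional layer, usually the last one. The second is the gradients of the target class score with respect to those activations. In PyTorch these are typically captured with module hooks such as register_forward_hook and register_full_backward_hook. The NumPy sketch below assumes both arrays have already been extracted for a single image with shape (channels, height, width). It averages the gradients over each feature map to get channel weights, forms the weighted sum of the maps, and keeps only positive values. The function name is hypothetical, and visualize_grad_cam may differ in details such as normalization.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM map from one layer's activations A (K, H, W) and the gradients
    d(class score)/dA (K, H, W). Returns an (H, W) map scaled to [0, 1]."""
    weights = gradients.mean(axis=(1, 2))                  # alpha_k: global-average-pooled gradients
    cam = np.tensordot(weights, activations, axes=(0, 0))  # sum_k alpha_k * A_k -> (H, W)
    cam = np.maximum(cam, 0.0)                             # ReLU keeps regions that support the class
    peak = cam.max()
    return cam / peak if peak > 0 else cam                 # scale to [0, 1] for display
```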

6. Inference

  • Function: predict_disease
    • Predicts labels for unseen images.
    • Displays Top-3 class probabilities for better interpretability.
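Top-3 probabilities come from applying softmax to the model's raw outputs (logits) and picking the three largest. In PyTorch this is torch.softmax(logits, dim=1) followed by torch.topk(probs, k=3). The NumPy sketch below does the same for a single image. The function name and the logits in the example are made up.

```python
import numpy as np


def top_k(logits, class_names, k: int = 3):
    """Return the k most likely (class_name, probability) pairs for one image."""
    z = np.asarray(logits, dtype=np.float64)
    z = z - z.max()                      # shift for numerical stability; softmax is unchanged
    probs = np.exp(z) / np.exp(z).sum()  # softmax
    order = np.argsort(probs)[::-1][:k]  # indices of the k largest probabilities
    return [(class_names[i], float(probs[i])) for i in order]


names = ["Tomato_Early_Blight", "Tomato_Bacterial_Spot", "Tomato_YellowLeaf_Curl_Virus"]
for rank, (name, p) in enumerate(top_k([1.0, 9.0, 0.5], names), start=1):  # made-up logits
    print(f"{rank}. {name}: {p:.2%}")
```

Rounding to two decimal places explains readings such as 100.00% and 0.00% in the prediction example below. A probability very close to 1 (for example 0.99999) is displayed as 100.00%, and one very close to 0 as 0.00%. This does not mean the other classes had exactly zero probability.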

Training Metrics

Visualization of Training Progress

  • Graphs: Show trends in:
    • Loss (Training & Validation).
    • Accuracy (Training & Validation).
    • Learning Rate Progression.
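Matplotlib can draw the three graphs from per-epoch values recorded during training. The sketch below assumes a hypothetical history dictionary with the keys train_loss, val_loss, train_acc, val_acc and lr. The key names, the figure layout and the output file name are all assumptions; this is not the project's own plotting code. The learning rate is plotted on a log scale because plateau-based schedulers change it in multiplicative steps.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; remove this line to show plots interactively
import matplotlib.pyplot as plt


def plot_history(history: dict, path: str = "training_curves.png"):
    """Plot per-epoch loss, accuracy and learning rate side by side and save to `path`."""
    epochs = range(1, len(history["train_loss"]) + 1)
    fig, (ax_loss, ax_acc, ax_lr) = plt.subplots(1, 3, figsize=(15, 4))

    ax_loss.plot(epochs, history["train_loss"], label="train")
    ax_loss.plot(epochs, history["val_loss"], label="validation")
    ax_loss.set(title="Loss", xlabel="epoch")
    ax_loss.legend()

    ax_acc.plot(epochs, history["train_acc"], label="train")
    ax_acc.plot(epochs, history["val_acc"], label="validation")
    ax_acc.set(title="Accuracy", xlabel="epoch")
    ax_acc.legend()

    ax_lr.plot(epochs, history["lr"])
    ax_lr.set(title="Learning rate", xlabel="epoch", yscale="log")

    fig.tight_layout()
    fig.savefig(path)
    return fig
```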

Example Metrics Analysis

  • Loss Curves:
    • Steady decrease in both training and validation loss.
    • Indicates effective learning and optimization.
  • Accuracy Curves:
    • Achieves over 98% accuracy.
    • Minimal overfitting:
      • Training and validation metrics are closely aligned throughout.


Grad-CAM Visualization

Overview

  • Purpose: Grad-CAM visualizations highlight the critical regions in the input image that significantly influence the model's predictions.
  • Benefit: Adds interpretability to the model and helps validate its predictions by showing where the model is "looking" when making decisions.
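To show where the model is "looking", the low-resolution Grad-CAM map is upsampled to the image size and blended over the leaf image. OpenCV is the usual tool for this, through cv2.resize, cv2.applyColorMap with cv2.COLORMAP_JET, and cv2.addWeighted. The dependency-free NumPy sketch below instead uses nearest-neighbour upsampling and a plain red tint rather than a colormap. It assumes an RGB uint8 image of shape (H, W, 3) and a map already scaled to [0, 1]. The function name and the alpha default are hypothetical.

```python
import numpy as np


def overlay_cam(image: np.ndarray, cam: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    """Blend a [0, 1] Grad-CAM map onto an RGB uint8 image of shape (H, W, 3)."""
    H, W = image.shape[:2]
    h, w = cam.shape
    rows = np.arange(H) * h // H                 # source row for each output row
    cols = np.arange(W) * w // W                 # source column for each output column
    cam_up = cam[rows[:, None], cols[None, :]]   # nearest-neighbour upsampling to (H, W)
    heat = np.zeros((H, W, 3), dtype=np.float64)
    heat[..., 0] = 255.0 * cam_up                # red channel carries the activation strength
    blended = (1.0 - alpha) * image.astype(np.float64) + alpha * heat
    return np.clip(blended, 0, 255).astype(np.uint8)
```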

Prediction Example

Input: A sample diseased tomato leaf image.

Output:

  • Predicted Disease: Tomato Bacterial Spot
  • Confidence: 100.00%

Top-3 Predictions:

  1. Tomato_Bacterial_Spot: 100.00%
  2. Tomato_Early_Blight: 0.00%
  3. Tomato_YellowLeaf_Curl_Virus: 0.00%

Grad-CAM Heatmap:

How to Run

Dataset Setup

  1. Download the PlantVillage dataset.
  2. Organize the dataset as described in the Dataset section.
  3. Update the root_dir path in the code to point to your dataset.

Training the Model

Run the following command to train the model:

python main.py

The script will:

  • Train the model and print the training metrics (loss, accuracy).
  • Save the best-performing model weights.
  • Display evaluation metrics such as confusion matrix and classification report.

Prediction on New Images

Replace sample_image_path with the path to your test image and run the inference function to predict the disease and visualize Grad-CAM heatmaps.

Summary

  • Dataset: PlantVillage Dataset
  • Model: Custom CNN with Early Stopping and Grad-CAM for visualization
  • Performance: Over 98% accuracy, with training and validation curves closely aligned (minimal overfitting)
  • Usage: Suitable for diagnosing plant diseases and assisting farmers with actionable insights