Introduction

Coral reefs are essential to marine biodiversity, yet they are under increasing threat from climate change, pollution, and other human activities. Accurate monitoring of coral health is vital to conservation efforts, but manual methods can be labor-intensive and subjective. Deep learning offers a scalable, automated approach to coral classification that produces consistent results and, with suitable visualization tools, interpretable ones. In this project, a pre-trained ResNet50 architecture was fine-tuned to classify coral health into three categories: bleached, healthy, and partially bleached. The model's performance was evaluated in two iterations: an initial baseline model and a second, improved model. This report explains the technical reasoning behind key decisions, analyzes the results, and discusses the implications of this approach for coral reef conservation.

Data

The dataset was acquired using the Roboflow API, containing images categorized into three health states: bleached, healthy, and partially bleached. An initial analysis of the class distribution revealed a significant imbalance, with the 'bleached' class overrepresented and the 'partially bleached' class underrepresented. This imbalance posed a clear challenge: models trained on such data are likely to perform poorly on underrepresented categories, as evidenced in the initial confusion matrix (Figure 3).

Training Class Distribution
Figure 1. Distribution of classes in the dataset

Figure 1 shows the distribution of training data across categories. The disparity in category frequencies helps explain the poor performance of the initial model on the 'partially bleached' class: the model had fewer examples from which to learn the distinguishing features of this category.
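
A class-distribution check of the kind behind Figure 1 can be sketched as follows. This is a minimal illustration, not the project's code: `class_distribution` is a hypothetical helper, and the label list is a made-up placeholder rather than the actual dataset.

```python
from collections import Counter

def class_distribution(labels):
    """Return {class_name: (count, share_of_total)} for a list of labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {name: (n, n / total) for name, n in counts.items()}

# Placeholder labels for illustration only; real labels would come from the dataset.
example_labels = ["bleached"] * 6 + ["healthy"] * 3 + ["partially bleached"] * 1

dist = class_distribution(example_labels)
# Print classes from most to least frequent.
for name, (n, share) in sorted(dist.items(), key=lambda kv: -kv[1][0]):
    print(f"{name:>20}: {n:3d} images ({share:.0%})")
```

Running a check like this before training makes the imbalance visible early, so that mitigation (for example, class weighting) can be planned from the start.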

Figure 2. Visualization of coral dataset sample

To prepare the images for the model, preprocessing steps included resizing all images to 224x224 pixels, which is the standard input size for ResNet50, followed by data augmentation techniques such as random horizontal flipping and color jittering. These augmentations aimed to improve the model's ability to generalize to new data by simulating variability in lighting and orientation. Finally, normalization was applied to align pixel intensity distributions with the ImageNet dataset on which ResNet50 was pre-trained.


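        # Dataset Loading and Preprocessing
        from torchvision import transforms

        # Training-time pipeline: resize, augment, convert to tensor, normalize
        transform_train = transforms.Compose([
            transforms.Resize((224, 224)),  # standard ResNet50 input size
            transforms.RandomHorizontalFlip(),  # simulate variation in orientation
            transforms.ColorJitter(brightness=0.2, contrast=0.2),  # simulate variation in lighting
            transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
            # Per-channel ImageNet statistics, matching ResNet50's pre-training data
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])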
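
To make the normalization step concrete, the arithmetic that `ToTensor` followed by `Normalize` applies to a single RGB pixel can be reproduced in plain Python. This is an illustrative sketch only: `normalize_pixel` is a hypothetical helper, and the real transforms operate on whole image tensors rather than individual pixels.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb_uint8):
    """Mimic ToTensor (scale to [0, 1]) followed by Normalize for one RGB pixel."""
    scaled = [v / 255.0 for v in rgb_uint8]
    return [(s - m) / sd for s, m, sd in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]

white = normalize_pixel([255, 255, 255])
black = normalize_pixel([0, 0, 0])
print("white:", [round(v, 3) for v in white])
print("black:", [round(v, 3) for v in black])
```

After normalization, each channel is expressed in units of ImageNet standard deviations from the ImageNet mean, which is the input scale the pre-trained weights were fitted to.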
        

Modelling

The baseline model used a pre-trained ResNet50 architecture, with its final fully connected layer replaced by a custom classifier tailored to the three coral health categories. The model was trained for 10 epochs using the Adam optimizer with an initial learning rate of 0.001. A StepLR scheduler reduced the learning rate by a factor of 0.1 after 5 epochs to facilitate convergence.
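
The step schedule described above can be written out in closed form. The sketch below assumes a step size of 5 epochs and 0-indexed epochs, which matches the description but is not confirmed by the original code; `step_lr` is a hypothetical helper that mirrors the decay rule rather than calling PyTorch.

```python
def step_lr(epoch, base_lr=0.001, step_size=5, gamma=0.1):
    """Learning rate in effect during `epoch` (0-indexed) under a step decay rule."""
    return base_lr * gamma ** (epoch // step_size)

# Learning rate for each of the 10 training epochs.
schedule = [step_lr(e) for e in range(10)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.6f}")
```

Under these assumptions, the first five epochs run at the initial rate and the last five at one tenth of it.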

Figure 3. Confusion Matrix - Initial Model

While the training loss consistently decreased, indicating that the model was learning from the data, the validation loss plateaued early. This gap between training and validation performance suggested overfitting, where the model was overly reliant on the training data and unable to generalize effectively. The initial confusion matrix confirms these shortcomings, with significant misclassifications in the 'partially bleached' category. The model frequently confused 'partially bleached' with 'bleached,' likely due to overlapping visual features such as pale or degraded coral tissue. In contrast, the 'healthy' category performed relatively well, as its vibrant colors and well-defined coral structures were easier for the model to distinguish.

Figure 4. Training and Validation Loss - Initial Model

To address the limitations of the baseline model, several refinements were introduced. A key enhancement was the incorporation of class weights into the loss function during fine-tuning. These weights, calculated as the inverse of class frequencies, increase the loss contribution of underrepresented categories such as 'partially bleached.' This adjustment directly targeted the imbalance in the dataset, so that the model would not disproportionately favor the dominant 'bleached' class. Regularization techniques were also applied to improve generalization. A dropout layer with a rate of 50% was added before the final fully connected layer, reducing the risk of overfitting by preventing the model from relying too heavily on specific neurons. Additionally, the optimizer was changed to AdamW, which applies decoupled weight decay to penalize large parameter values, further promoting generalization. The learning rate scheduler was replaced with OneCycleLR, which ramps the learning rate up to a peak and then anneals it over the course of training, with the aim of smoother convergence and a lower risk of stagnation.
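
The inverse-frequency weighting described above can be sketched in a few lines. This is one common formulation (total samples divided by per-class count), not necessarily the exact one used in the project; `inverse_frequency_weights` is a hypothetical helper and the counts are placeholders, not the dataset's real numbers.

```python
def inverse_frequency_weights(class_counts):
    """Weight each class by 1 / (its share of the training set)."""
    total = sum(class_counts.values())
    return {name: total / count for name, count in class_counts.items()}

# Placeholder counts for illustration only, not the project's actual numbers.
example_counts = {"bleached": 600, "healthy": 300, "partially bleached": 100}

weights = inverse_frequency_weights(example_counts)
for name, w in weights.items():
    print(f"{name:>20}: weight = {w:.2f}")
```

The rarest class receives the largest weight, so each of its misclassified examples contributes more to the loss than a misclassified example from the dominant class.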


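        # Adding Dropout Layer to Reduce Overfitting
        import torch.nn as nn
        from torchvision import models

        num_classes = 3  # bleached, healthy, partially bleached
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)  # assumed: ImageNet weights, torchvision >= 0.13

        # Replace the fully connected (fc) layer with a Sequential layer that includes Dropout
        model.fc = nn.Sequential(
            nn.Dropout(0.5),  # Dropout rate of 50%: randomly zeroes activations during training
            nn.Linear(model.fc.in_features, num_classes)  # Fully connected layer for classification
        )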
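
The remaining refinements (weighted loss, AdamW, and OneCycleLR) could be wired together roughly as follows. This is a sketch under stated assumptions, not the project's training code: the learning rate, weight decay, `steps_per_epoch`, and class-weight values are placeholders, and a small stand-in head replaces the full ResNet50 so the example runs quickly without downloading weights.

```python
import torch
import torch.nn as nn

num_classes = 3
epochs = 10
steps_per_epoch = 20  # assumed; in practice this would be len(train_loader)

# Tiny stand-in for the fine-tuned ResNet50 head (2048 = ResNet50 feature size).
model = nn.Sequential(nn.Dropout(0.5), nn.Linear(2048, num_classes))

# Placeholder inverse-frequency weights, one per class, in class-index order.
class_weights = torch.tensor([1.67, 3.33, 10.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)

# AdamW applies decoupled weight decay; both values below are assumptions.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, epochs=epochs, steps_per_epoch=steps_per_epoch
)

# One training step on random data, to show the call order.
features = torch.randn(8, 2048)
targets = torch.randint(0, num_classes, (8,))
model.train()
optimizer.zero_grad()
loss = criterion(model(features), targets)
loss.backward()
optimizer.step()
scheduler.step()  # OneCycleLR is stepped once per batch, not once per epoch
print(f"loss = {loss.item():.4f}, lr = {scheduler.get_last_lr()[0]:.6f}")
```

One design point worth noting: because OneCycleLR is defined over the total number of batches, it needs `epochs` and `steps_per_epoch` (or `total_steps`) up front, and stepping it more times than that raises an error.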
      

Results

These adjustments resulted in significant performance improvements, as evidenced by the updated confusion matrix (Figure 5). The model demonstrated better recall for the 'partially bleached' category, reducing confusion with 'bleached' and 'healthy.' This improvement underscores the effectiveness of weighted loss and fine-tuning techniques in addressing dataset imbalance.
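
Per-class recall can be read directly off a confusion matrix, which makes the comparison between Figures 3 and 5 quantitative. The sketch below uses a placeholder matrix, not the project's results, and `per_class_recall` is a hypothetical helper; it assumes rows are true labels and columns are predictions.

```python
def per_class_recall(confusion, class_names):
    """Recall per class, where confusion[i][j] counts true class i predicted as j."""
    recalls = {}
    for i, name in enumerate(class_names):
        row_total = sum(confusion[i])
        recalls[name] = confusion[i][i] / row_total if row_total else 0.0
    return recalls

class_names = ["bleached", "healthy", "partially bleached"]

# Placeholder matrix for illustration only; rows are true labels, columns predictions.
example_confusion = [
    [45, 2, 3],
    [1, 38, 1],
    [4, 2, 14],
]

recalls = per_class_recall(example_confusion, class_names)
for name, value in recalls.items():
    print(f"{name:>20}: recall = {value:.2f}")
```

Reporting recall per class, rather than overall accuracy alone, keeps a weak minority class from being hidden by strong performance on the dominant one.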

Figure 5. Confusion matrix at epoch 10

The fine-tuned model exhibited smoother training and validation loss curves, as shown in Figure 6. Unlike the baseline model, the validation loss closely mirrored the training loss, indicating improved generalization. The integration of class weights ensured that the model paid more attention to underrepresented categories, leading to balanced performance across all classes.

Figure 6. Training and validation loss curves

A closer look at the curves in Figure 6, which trace the model's learning progression over 10 epochs, shows that this improvement has limits. The training loss decreased steadily throughout, indicating that the model was learning effectively from the training data. The validation loss, however, dropped initially, began to plateau around epoch 4, and increased slightly toward epoch 9. This divergence between training and validation loss suggests the onset of overfitting, where the model performs well on the training set but struggles to generalize to unseen validation data.

Grad-CAM

Grad-CAM (Gradient-weighted Class Activation Mapping) was employed to visualize the areas of the input images that influenced the model’s predictions. These heatmaps (Figure 7) provide insights into the decision-making process, revealing that the model consistently focused on coral regions rather than background noise. However, misclassifications highlighted by Grad-CAM suggest that mixed health states, where a coral exhibits features of both 'healthy' and 'partially bleached,' remain a challenge for the model.
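
For reference, the standard Grad-CAM computation can be stated compactly. The notation here is introduced for illustration and does not appear elsewhere in this report: y^c is the pre-softmax score for class c, A^k is the k-th feature map of the chosen convolutional layer, and Z is the number of spatial positions in that feature map.

```latex
\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left( \sum_k \alpha_k^c A^k \right)
```

Each feature map is weighted by the average gradient of the class score with respect to it, and the ReLU keeps only the regions that push the score for class c upward; the resulting map is then upsampled to the input resolution to produce heatmaps like those in Figure 7.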

Figure 7. Grad-CAM Visualizations

Conclusion

This project demonstrates the potential of deep learning for automated coral health classification, addressing key challenges such as class imbalance and overfitting. The baseline model highlighted the limitations of training on imbalanced datasets, with poor performance on underrepresented categories such as 'partially bleached.' By incorporating class weights, dropout regularization, and advanced optimization techniques, the fine-tuned model achieved substantial improvements, as reflected in the confusion matrix and loss curves. Grad-CAM visualizations provided valuable insights into the model's behavior, confirming that it learned to focus on relevant features while also identifying areas for further refinement. These findings underscore the importance of interpretability in deep learning, particularly for ecological applications where accuracy and transparency are critical. Future directions include expanding the dataset to cover a wider range of coral reef environments and integrating the model into a real-time monitoring system. Such advancements will enhance the ability of conservationists to monitor and protect these vital ecosystems.

References

  1. K. He, X. Zhang, S. Ren, and J. Sun, "Deep residual learning for image recognition," in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 770–778. Available: https://arxiv.org/abs/1512.03385.
  2. J. Gildenblat and contributors, "PyTorch library for CAM methods," GitHub repository, 2024. Available: https://github.com/jacobgil/pytorch-grad-cam.