The Effect of Intrinsic Dataset Properties on Generalization:
Unraveling Learning Differences Between Natural and Medical Images

Duke University
ICLR 2024
Teaser figure showing generalization scaling laws across imaging domains

Neural networks generalize differently across imaging domains. Generalization error increases with the intrinsic dimension ddata of the training set, but the steepness of this relationship differs markedly between natural and medical images. We explain this discrepancy through our proposed label sharpness metric, and derive a theoretical generalization scaling law unifying both domains.

Abstract

This paper investigates discrepancies in how neural networks learn from different imaging domains, which are commonly overlooked when adopting computer vision techniques from the domain of natural images to other specialized domains such as medical images. Recent works have found that the generalization error of a trained network typically increases with the intrinsic dimension (ddata) of its training set. Yet, the steepness of this relationship varies significantly between medical (radiological) and natural imaging domains, with no existing theoretical explanation.


We address this gap in knowledge by establishing and empirically validating a generalization scaling law with respect to ddata, and propose that the substantial scaling discrepancy between the two considered domains may be at least partially attributed to the higher intrinsic label sharpness (KF) of medical imaging datasets, a metric which we propose. Next, we demonstrate an additional benefit of measuring the label sharpness of a training set: it is negatively correlated with the trained model's adversarial robustness, which notably leads to models for medical images having a substantially higher vulnerability to adversarial attack. Finally, we extend our ddata formalism to the related metric of learned representation intrinsic dimension (drepr), derive a generalization scaling law with respect to drepr, and show that ddata serves as an upper bound for drepr. Our theoretical results are supported by thorough experiments with six models and eleven natural and medical imaging datasets over a range of training set sizes.

Key Contributions

1. Generalization Scaling Law.  We derive a theoretical generalization scaling law as a function of the training set's intrinsic dimension ddata, and empirically validate it across natural and medical imaging datasets.

2. Label Sharpness (KF).  We propose label sharpness, a new dataset metric measuring the extent to which images can resemble each other while still having different labels. Medical imaging datasets have systematically higher label sharpness than natural image datasets, which we show partially explains the observed domain discrepancy in generalization scaling.

3. Adversarial Robustness Connection.  Label sharpness is negatively correlated with a model's adversarial robustness; as a result, medical image models are substantially more vulnerable to adversarial attack than their natural image counterparts.

4. Representation Intrinsic Dimension.  We extend our framework to the intrinsic dimension of learned representations (drepr), deriving a corresponding scaling law and showing that ddata serves as an upper bound for drepr.

5. Large-Scale Empirical Validation.  Results are validated with 6 network architectures and 11 natural and medical imaging datasets across a range of training set sizes.
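To build intuition for the label sharpness metric of contribution 2, here is a toy sketch that estimates a Lipschitz-style ratio |y_i − y_j| / ||x_i − x_j|| over oppositely labeled pairs: datasets where differently labeled images lie close together in pixel space score higher. This is an illustrative simplification with hypothetical names (`label_sharpness_estimate`, `top_frac`), not the exact estimator used in the paper or the package.

```python
import numpy as np

def label_sharpness_estimate(X, y, top_frac=0.01):
    """Toy Lipschitz-style label sharpness: average the largest
    |y_i - y_j| / ||x_i - x_j|| ratios over differently labeled pairs."""
    X = X.reshape(len(X), -1).astype(np.float64)
    y = np.asarray(y, dtype=np.float64)
    ratios = []
    for i in range(len(X)):
        for j in range(i + 1, len(X)):
            if y[i] != y[j]:
                dist = np.linalg.norm(X[i] - X[j])
                if dist > 0:
                    ratios.append(abs(y[i] - y[j]) / dist)
    ratios = np.sort(ratios)[::-1]
    k = max(1, int(top_frac * len(ratios)))
    return ratios[:k].mean()

# Two toy binary datasets: in the "sharp" one, oppositely labeled
# points sit much closer together than in the "smooth" one.
rng = np.random.default_rng(0)
X0 = rng.normal(0, 1, (50, 8))
X_smooth = np.concatenate([X0, X0 + 10.0])  # classes far apart
X_sharp  = np.concatenate([X0, X0 + 0.1])   # classes nearly overlap
y = np.array([0] * 50 + [1] * 50)

print(label_sharpness_estimate(X_sharp, y) >
      label_sharpness_estimate(X_smooth, y))  # prints True
```

The sharper toy dataset yields a larger estimate, mirroring the paper's observation that medical datasets, where small image changes can flip the label, have higher KF.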

Code

Our codebase lets you measure intrinsic properties of any PyTorch dataset: label sharpness KF, data intrinsic dimension ddata, and representation intrinsic dimension drepr.

from datasetproperties import (
    compute_labelsharpness,
    compute_intrinsic_datadim,
    compute_intrinsic_reprdim,
)

from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import Subset

# load your dataset
dataset = CIFAR10(root='data', download=True, transform=ToTensor())
classes = [0, 1]
# filter by label via .targets, avoiding decoding every image
dataset = Subset(dataset, [i for i, t in enumerate(dataset.targets) if t in classes])

# compute intrinsic properties
KF       = compute_labelsharpness(dataset)
datadim  = compute_intrinsic_datadim(dataset)

# pretrained feature extractor; eval mode for deterministic features
model    = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).to("cuda").eval()
reprdim  = compute_intrinsic_reprdim(dataset, model, model.layer4)

print(f"Label sharpness      = {KF:.3f}")
print(f"Data intrinsic dim   = {int(datadim)}")
print(f"Repr intrinsic dim   = {int(reprdim)}")
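For intuition on what the intrinsic dimension functions above compute, a common estimator in this literature is the Levina–Bickel MLE based on k-nearest-neighbor distances. The sketch below is a minimal, brute-force illustration of that estimator; the package's actual implementation may differ, and `mle_intrinsic_dim` is a hypothetical name.

```python
import numpy as np

def mle_intrinsic_dim(X, k=10):
    """Levina-Bickel MLE intrinsic dimension estimate from
    k-nearest-neighbor distances (brute force; small datasets only)."""
    X = X.reshape(len(X), -1).astype(np.float64)
    # full pairwise distance matrix
    dists = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(-1))
    # sort each row; drop column 0 (distance of a point to itself)
    knn = np.sort(dists, axis=1)[:, 1:k + 1]
    # per-point MLE: (k-1) / sum_j log(T_k / T_j), averaged over points
    log_ratios = np.log(knn[:, -1:] / knn[:, :-1])
    return ((k - 1) / log_ratios.sum(axis=1)).mean()

# Points on a 2-D plane embedded in 10-D: estimate should be near 2,
# far below the ambient dimension of 10.
rng = np.random.default_rng(0)
A = rng.normal(size=(2, 10))       # random 2-D subspace basis
Z = rng.normal(size=(500, 2)) @ A  # samples living on that plane
print(round(mle_intrinsic_dim(Z), 1))
```

The estimate tracks the manifold dimension rather than the pixel count, which is why ddata for high-resolution images can still be small.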


BibTeX

@inproceedings{konz2024intrinsicproperties,
  title     = {The Effect of Intrinsic Dataset Properties on Generalization:
               Unraveling Learning Differences Between Natural and Medical Images},
  author    = {Konz, Nicholas and Mazurowski, Maciej A},
  booktitle = {The Twelfth International Conference on Learning Representations (ICLR)},
  year      = {2024},
  url       = {https://openreview.net/forum?id=ixP76Y33y1}
}