simet.dataset_loaders.dataset_loader

DatasetLoader

DatasetLoader(
    real_provider,
    synth_provider,
    provider_transform=None,
    feature_extractor=None,
    downstream_transform=None,
)

Load and preprocess real vs. synthetic datasets, and compute features.

Provides a unified interface to:

1) subsample providers consistently,
2) build DataLoaders with a given Transform, and
3) extract feature arrays for downstream evaluation (e.g., FID, Precision/Recall, ROC-AUC).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| real_provider | Provider | Provider that yields the real dataset (e.g., images). Must implement get_data(transform) returning a torch.utils.data.Dataset. | required |
| synth_provider | Provider | Provider that yields the synthetic dataset, same interface as real_provider. | required |
| provider_transform | Transform \| None | Transform applied to both providers for feature extraction. If not provided, defaults to InceptionTransform. This transform controls what the feature extractor will “see” (e.g., resizing/normalization for Inception). | None |
| feature_extractor | FeatureExtractor \| None | Feature extractor used to compute real_features and synth_features. Defaults to InceptionFeatureExtractor. | None |
| downstream_transform | Transform \| None | Optional transform for downstream tasks (separate from feature extraction). When given, the loader also builds real_downstream_dataloader and synth_downstream_dataloader using this transform. | None |
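
The provider contract described above (a get_data(transform) method returning a dataset) can be illustrated without simet or torch installed. The sketch below is only an assumption about the shape of that contract: ListDataset and ListProvider are hypothetical names, and a real provider would return a torch.utils.data.Dataset rather than this plain-Python stand-in. It does follow the map-style dataset convention of supporting len() and integer indexing.

```python
from typing import Any, Callable, Optional, Sequence


class ListDataset:
    """Hypothetical map-style dataset: supports len() and integer indexing."""

    def __init__(
        self,
        items: Sequence[Any],
        transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self._items = list(items)
        self._transform = transform

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int) -> Any:
        item = self._items[index]
        # Apply the transform lazily, one item at a time.
        return self._transform(item) if self._transform is not None else item


class ListProvider:
    """Hypothetical provider exposing the documented get_data(transform) method."""

    def __init__(self, items: Sequence[Any]) -> None:
        self._items = list(items)

    def get_data(
        self, transform: Optional[Callable[[Any], Any]] = None
    ) -> ListDataset:
        return ListDataset(self._items, transform)


real_provider = ListProvider([1, 2, 3])
dataset = real_provider.get_data(transform=lambda x: x * 10)
print(len(dataset), dataset[0])  # 3 10
```

Because the transform is passed to get_data rather than stored on the provider, the same provider can be asked for differently transformed views of its data, which is what the separate provider_transform and downstream_transform parameters rely on.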

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| provider_transform | Transform | The transform actually used for feature extraction (defaults to InceptionTransform if None). |
| real_dataloader | DataLoader[VisionDataset] | Dataloader for the real dataset with provider_transform. |
| synth_dataloader | DataLoader[VisionDataset] | Dataloader for the synthetic dataset with provider_transform. |
| real_features | ndarray | Feature matrix extracted from real_dataloader. Shape typically (N_real, D). |
| synth_features | ndarray | Feature matrix extracted from synth_dataloader. Shape typically (N_synth, D). |
| real_downstream_dataloader | DataLoader[VisionDataset] | Present only if downstream_transform is provided; uses that transform. |
| synth_downstream_dataloader | DataLoader[VisionDataset] | Present only if downstream_transform is provided; uses that transform. |
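
Since the two downstream dataloaders exist only when downstream_transform is provided, code that receives a loader may want to check for them before use. The following is a minimal sketch of that pattern: LoaderStub is a hypothetical stand-in that copies only the conditional-attribute behaviour, and downstream_loaders is an illustrative helper, not part of simet.

```python
from typing import Any, Optional


class LoaderStub:
    """Hypothetical stand-in mirroring the conditional attributes described above."""

    def __init__(self, downstream_transform: Optional[Any] = None) -> None:
        self.real_dataloader = ["real batch"]
        self.synth_dataloader = ["synth batch"]
        # The downstream loaders are only set when a transform is given.
        if downstream_transform:
            self.real_downstream_dataloader = ["real downstream batch"]
            self.synth_downstream_dataloader = ["synth downstream batch"]


def downstream_loaders(loader: Any) -> Optional[tuple]:
    """Return (real, synth) downstream loaders, or None when they were not built."""
    real = getattr(loader, "real_downstream_dataloader", None)
    synth = getattr(loader, "synth_downstream_dataloader", None)
    if real is None or synth is None:
        return None
    return real, synth


without = downstream_loaders(LoaderStub())
with_transform = downstream_loaders(LoaderStub(downstream_transform=object()))
print(without is None, with_transform is not None)  # True True
```

Using getattr with a default avoids an AttributeError on loaders that were built without a downstream transform.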

Notes
  • Providers are first subsampled consistently via SubsamplingService.subsample(...) using provider_transform, then converted into DataLoaders.
  • Feature extraction is performed eagerly at construction time by feature_extractor.
  • Private helpers (e.g., _to_dataloader, _compute_features) are internal and not part of the public API.
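
The internal _compute_features helper is not shown on this page, so the following is only a guess at the general pattern behind the (N, D) feature matrices: run an extractor over every batch once and stack the per-item vectors. The names compute_features and toy_extract are hypothetical, and plain lists stand in for the ndarray the real loader produces.

```python
from typing import Callable, Iterable, List, Sequence

Vector = List[float]


def compute_features(
    batches: Iterable[Sequence[float]],
    extract: Callable[[Sequence[float]], List[Vector]],
) -> List[Vector]:
    """Run `extract` on every batch and stack the per-item vectors into N rows."""
    rows: List[Vector] = []
    for batch in batches:
        rows.extend(extract(batch))
    return rows


def toy_extract(batch: Sequence[float]) -> List[Vector]:
    # Hypothetical extractor: maps each scalar item to a 2-dimensional vector.
    return [[x, x * x] for x in batch]


batches = [[1.0, 2.0], [3.0]]  # two batches, three items in total
features = compute_features(batches, toy_extract)
print(len(features), len(features[0]))  # 3 2
```

Because this work happens eagerly in the constructor, building a DatasetLoader pays the full extraction cost up front; the feature matrices are then available as plain attributes.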

Initialize the DatasetLoader; see the class documentation for details.

Source code in simet/dataset_loaders/dataset_loader.py
def __init__(
    self,
    real_provider: Provider,
    synth_provider: Provider,
    provider_transform: Transform | None = None,
    feature_extractor: FeatureExtractor | None = None,
    downstream_transform: Transform | None = None,
) -> None:
    """
    Initialize the DatasetLoader, see class documentation for details.
    """
    self.provider_transform = provider_transform or InceptionTransform()
    real_provider, synth_provider = SubsamplingService.subsample(
        real_provider, synth_provider, self.provider_transform
    )
    self.real_dataloader = self._to_dataloader(real_provider, self.provider_transform)
    self.synth_dataloader = self._to_dataloader(synth_provider, self.provider_transform)
    feature_extractor = feature_extractor or InceptionFeatureExtractor()
    self._compute_features(feature_extractor)

    # Downstream task specifics
    if downstream_transform:
        self.real_downstream_dataloader = self._to_dataloader(real_provider, downstream_transform)
        self.synth_downstream_dataloader = self._to_dataloader(synth_provider, downstream_transform)
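
The constructor above selects defaults with `or` and gates the downstream loaders with a plain truthiness test. In Python, both treat any falsy object as if it were missing. Whether simet's Transform classes can ever be falsy is not stated here, so the sketch below uses hypothetical classes (EmptyPipeline, DefaultPipeline) purely to show how `x or default` differs from an explicit `is None` check.

```python
class EmptyPipeline:
    """Hypothetical transform that is falsy because it reports length 0."""

    def __len__(self) -> int:
        return 0


class DefaultPipeline:
    """Hypothetical default, standing in for InceptionTransform."""


def pick_with_or(transform=None):
    # Mirrors the `transform or Default()` idiom: any falsy value is replaced.
    return transform or DefaultPipeline()


def pick_with_is_none(transform=None):
    # Only replaces the argument when it was actually omitted.
    return transform if transform is not None else DefaultPipeline()


custom = EmptyPipeline()
print(type(pick_with_or(custom)).__name__)       # DefaultPipeline
print(type(pick_with_is_none(custom)).__name__)  # EmptyPipeline
```

For ordinary transform objects that define neither __len__ nor __bool__, the two idioms behave identically.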