
simet.metrics.fid


FID

FID(base_eps=1e-06)

Bases: Metric[float]

Fréchet Inception Distance (FID) computed from pre-extracted features.

Expects 2D feature arrays of shape [num_samples, D] for the real and synthetic datasets, with the same feature dimensionality D. Computes the per-set mean and unbiased covariance, then evaluates:

$$\mathrm{FID} = \lVert \mu_r - \mu_s \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_s - 2\left(\Sigma_r^{1/2}\,\Sigma_s\,\Sigma_r^{1/2}\right)^{1/2}\right)$$
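As a sanity check on the formula, the one-dimensional case (D = 1) reduces to a closed form. Writing Σ_r = σ_r² and Σ_s = σ_s² with σ_r, σ_s ≥ 0, the sandwich term becomes √(σ_r σ_s² σ_r) = σ_r σ_s:

```latex
\begin{aligned}
\mathrm{FID} &= (\mu_r - \mu_s)^2 + \sigma_r^2 + \sigma_s^2 - 2\sqrt{\sigma_r\,\sigma_s^2\,\sigma_r} \\
             &= (\mu_r - \mu_s)^2 + \sigma_r^2 + \sigma_s^2 - 2\,\sigma_r\sigma_s \\
             &= (\mu_r - \mu_s)^2 + (\sigma_r - \sigma_s)^2
\end{aligned}
```

For example, μ_r = 1, σ_r = 2, μ_s = 3, σ_s = 1 gives (1 − 3)² + (2 − 1)² = 4 + 1 = 5.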
Numerical behavior
  • Uses double precision (float64) for statistics by default.
  • Runs linear algebra on CUDA if available; otherwise on CPU.
  • Adds a small diagonal term (base_eps · I) to both covariance matrices.
  • Tries a pure-PyTorch, eigendecomposition-based matrix square root with escalating eps; falls back to SciPy’s sqrtm if needed.
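The numerical strategy above can be sketched as follows. This is a minimal NumPy sketch, not the simet implementation (which runs in PyTorch): the helper names `psd_sqrt` and `frechet_distance`, the retry budget `max_tries`, and the ×10 escalation factor are all hypothetical choices for illustration; the source only states that eps escalates before the SciPy fallback is used.

```python
import numpy as np


def psd_sqrt(mat: np.ndarray, eps: float) -> np.ndarray:
    """Square root of a symmetric PSD matrix via eigendecomposition.

    Eigenvalues are clamped to a floor of `eps` so that tiny negative
    values caused by round-off cannot produce NaNs.
    """
    sym = 0.5 * (mat + mat.T)  # remove asymmetry introduced by round-off
    vals, vecs = np.linalg.eigh(sym)
    vals = np.clip(vals, eps, None)
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(mu_r, sigma_r, mu_s, sigma_s, base_eps=1e-6, max_tries=4):
    """FID between N(mu_r, sigma_r) and N(mu_s, sigma_s) using the sandwich form."""
    eye = np.eye(sigma_r.shape[0])
    diff = mu_r - mu_s
    eps = base_eps
    for _ in range(max_tries):
        try:
            sr = sigma_r + eps * eye  # diagonal regularization
            ss = sigma_s + eps * eye
            root_r = psd_sqrt(sr, eps)
            middle = psd_sqrt(root_r @ ss @ root_r, eps)  # (Σ_r^½ Σ_s Σ_r^½)^½
            fid = diff @ diff + np.trace(sr) + np.trace(ss) - 2.0 * np.trace(middle)
            if np.isfinite(fid):
                return max(float(fid), 0.0)
        except np.linalg.LinAlgError:
            pass  # eigendecomposition did not converge; retry with a larger eps
        eps *= 10.0  # hypothetical escalation factor

    # Fallback: SciPy's general matrix square root. Tr((Σ_r Σ_s)^½) equals the
    # trace of the sandwich square root, so only the trace is needed here.
    from scipy import linalg  # imported lazily: only this path needs SciPy

    covmean = np.real(linalg.sqrtm(sigma_r @ sigma_s))
    fid = diff @ diff + np.trace(sigma_r) + np.trace(sigma_s) - 2.0 * np.trace(covmean)
    return max(float(fid), 0.0)
```

Symmetrizing before `eigh` and flooring the eigenvalues keeps the square root real even when round-off pushes a covariance slightly off the PSD cone; the escalating eps only matters when that is not enough.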

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| base_eps | float | Diagonal regularization added to covariance matrices. Also used as a minimum eigenvalue floor. | 1e-06 |
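Why base_eps is needed: with n samples, an unbiased covariance has rank at most n − 1, so whenever n ≤ D (common with 2048-dimensional Inception features) the matrix is singular. The NumPy snippet below is an illustrative sketch; the helper name `regularize` is hypothetical and the sizes (3 samples, 5 features) were picked only to make the rank deficiency obvious.

```python
import numpy as np


def regularize(cov: np.ndarray, eps: float) -> np.ndarray:
    """Add eps to the diagonal: Σ + eps·I (hypothetical helper)."""
    return cov + eps * np.eye(cov.shape[0])


rng = np.random.default_rng(0)
feats = rng.normal(size=(3, 5))    # n = 3 samples, D = 5 features
cov = np.cov(feats, rowvar=False)  # unbiased (divides by n - 1); rank <= n - 1 = 2

raw_min = np.linalg.eigvalsh(cov).min()                    # ~0, possibly slightly negative
reg_min = np.linalg.eigvalsh(regularize(cov, 1e-6)).min()  # ~1e-6, strictly positive
```

After regularization every eigenvalue is at least about eps, which is what makes the subsequent square roots well defined.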

Attributes:

| Name | Type | Description |
|---|---|---|
| _LINALG_DEVICE | torch.device | Device for linear algebra ops. |
| _STAT_DTYPE | torch.dtype | Dtype for statistics (default: float64). |
| _BASE_EPS | float | Stabilizing epsilon for covariances/eigenvalues. |

Raises:

| Type | Description |
|---|---|
| ValueError | If feature arrays are not 2D, have mismatched feature dimensions, or have fewer than two samples (unbiased covariance requires n ≥ 2). |
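The three error conditions can be mirrored with plain NumPy shape checks, which is handy for validating features before they reach the metric. A sketch only: `validate_features` is a hypothetical helper, not part of simet, although the messages are copied from the source below.

```python
import numpy as np


def validate_features(real, synth) -> None:
    """Raise ValueError under the same conditions as FID.compute (sketch)."""
    real = np.asarray(real)
    synth = np.asarray(synth)
    if real.ndim != 2 or synth.ndim != 2:
        raise ValueError("Features must be 2D [num_samples, feat_dim].")
    if real.shape[1] != synth.shape[1]:
        raise ValueError("Real and synthetic features must have the same dimensionality.")
    if real.shape[0] < 2 or synth.shape[0] < 2:
        # np.cov divides by n - 1, which is undefined for a single sample.
        raise ValueError("Need at least two samples per set to compute unbiased covariance.")
```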

Source code in simet/metrics/fid.py
def __init__(self, base_eps: float = 1e-6) -> None:
    logger.info("Initializing FID metric")
    super().__init__()
    self._LINALG_DEVICE = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    self._STAT_DTYPE = torch.float64
    self._BASE_EPS = float(base_eps)

name property

name

Human-readable metric name.

compute

compute(loader)

Compute FID from loader.real_features and loader.synth_features.

Steps

1) Move features to _LINALG_DEVICE in _STAT_DTYPE.
2) Validate shapes: both 2D and same feature dimension.
3) Compute (μ, Σ) for the real and synthetic sets.
4) Compute the sandwich covariance square root and the final FID.
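The four steps can be traced end to end in a short NumPy sketch. It assumes NumPy float64 arrays stand in for the torch device/dtype move, uses a `SimpleNamespace` as a hypothetical stand-in for `DatasetLoader` (only the two attributes are assumed), and the names `compute_fid` and `_sqrtm_psd` are illustrative rather than simet's own. It omits the eps escalation and SciPy fallback.

```python
from types import SimpleNamespace

import numpy as np


def _sqrtm_psd(mat: np.ndarray) -> np.ndarray:
    """Symmetric PSD square root via eigh; negative round-off eigenvalues clamped to 0."""
    vals, vecs = np.linalg.eigh(0.5 * (mat + mat.T))
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def compute_fid(loader, eps: float = 1e-6) -> float:
    # 1) Load features as float64 (NumPy stands in for the torch device/dtype move).
    real = np.asarray(loader.real_features, dtype=np.float64)
    synth = np.asarray(loader.synth_features, dtype=np.float64)

    # 2) Validate shapes.
    if real.ndim != 2 or synth.ndim != 2:
        raise ValueError("Features must be 2D [num_samples, feat_dim].")
    if real.shape[1] != synth.shape[1]:
        raise ValueError("Real and synthetic features must have the same dimensionality.")
    if real.shape[0] < 2 or synth.shape[0] < 2:
        raise ValueError("Need at least two samples per set to compute unbiased covariance.")

    # 3) Per-set mean and unbiased covariance, regularized with eps * I.
    eye = np.eye(real.shape[1])
    mu_r, mu_s = real.mean(axis=0), synth.mean(axis=0)
    sigma_r = np.atleast_2d(np.cov(real, rowvar=False)) + eps * eye
    sigma_s = np.atleast_2d(np.cov(synth, rowvar=False)) + eps * eye

    # 4) Sandwich square root (Σ_r^½ Σ_s Σ_r^½)^½ and the final FID.
    root_r = _sqrtm_psd(sigma_r)
    covmean = _sqrtm_psd(root_r @ sigma_s @ root_r)
    diff = mu_r - mu_s
    fid = diff @ diff + np.trace(sigma_r) + np.trace(sigma_s) - 2.0 * np.trace(covmean)
    return max(float(fid), 0.0)


# Hypothetical loader: synthetic features are the real ones shifted by 0.5 in every dimension.
rng = np.random.default_rng(0)
real_feats = rng.normal(size=(500, 8))
loader = SimpleNamespace(real_features=real_feats, synth_features=real_feats + 0.5)
print(compute_fid(loader))  # ≈ 2.0 (8 dims × 0.5², identical covariances)
```

Because the shift leaves the covariance unchanged, only the mean term contributes, which makes this a convenient smoke test for any FID implementation.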

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| loader | DatasetLoader | Must expose real_features and synth_features as 2D arrays with the same second dimension. | required |
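Some feature extractors return pooled activations with trailing singleton axes (for example [N, 2048, 1, 1] from an InceptionV3 average-pool layer), which would fail the 2D check. A hedged sketch of a pre-processing step, assuming you control how the loader's arrays are built; `to_2d_features` is a hypothetical helper, not part of simet.

```python
import numpy as np


def to_2d_features(feats) -> np.ndarray:
    """Reduce [N, D, 1, 1]-style features to [N, D] by dropping trailing singleton axes."""
    arr = np.asarray(feats)
    # Only size-1 trailing axes are removed; real spatial maps are rejected, not flattened.
    while arr.ndim > 2 and arr.shape[-1] == 1:
        arr = arr[..., 0]
    if arr.ndim != 2:
        raise ValueError(f"cannot reduce shape {arr.shape} to [num_samples, feat_dim]")
    return arr
```

Rejecting non-singleton spatial axes is deliberate: silently flattening an [N, C, H, W] map would change what the covariance measures.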

Returns:

| Type | Description |
|---|---|
| float | Non-negative FID score (lower is better). |
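A limiting case shows what drives the score. When both sets share a covariance Σ, Σ^{1/2} commutes with Σ, so the trace term vanishes and only the mean gap remains:

```latex
\begin{aligned}
\Sigma_r = \Sigma_s = \Sigma
  &\;\Rightarrow\; \left(\Sigma^{1/2}\,\Sigma\,\Sigma^{1/2}\right)^{1/2} = \left(\Sigma^{2}\right)^{1/2} = \Sigma \\
\mathrm{FID} &= \lVert \mu_r - \mu_s \rVert_2^2 + \operatorname{Tr}(\Sigma + \Sigma - 2\Sigma) = \lVert \mu_r - \mu_s \rVert_2^2
\end{aligned}
```

If the means also agree, FID is 0. In general the formula is the squared 2-Wasserstein distance between the Gaussians N(μ_r, Σ_r) and N(μ_s, Σ_s), so it is never negative; with finite samples, two sets drawn from the same distribution still typically yield a small positive score.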

Raises:

| Type | Description |
|---|---|
| ValueError | If shapes are invalid or either set has fewer than 2 samples. |

Source code in simet/metrics/fid.py
@override
def compute(self, loader: DatasetLoader) -> float:
    """Compute FID from `loader.real_features` and `loader.synth_features`.

    Steps:
        1) Move features to `_LINALG_DEVICE` in `_STAT_DTYPE`.
        2) Validate shapes: both 2D and same feature dimension.
        3) Compute `(μ, Σ)` for real and synthetic sets.
        4) Compute the sandwich covariance square root and final FID.

    Args:
        loader (DatasetLoader): Must expose `real_features` and
            `synth_features` as 2D arrays with the same second dimension.

    Returns:
        float: Non-negative FID score (lower is better).

    Raises:
        ValueError: If shapes are invalid or sample counts are < 2.
    """
    with torch.no_grad():
        real = torch.as_tensor(
            loader.real_features, device=self._LINALG_DEVICE, dtype=self._STAT_DTYPE
        )
        logger.debug("Loaded real features into tensor")
        synth = torch.as_tensor(
            loader.synth_features,
            device=self._LINALG_DEVICE,
            dtype=self._STAT_DTYPE,
        )
        logger.debug("Loaded synthetic features into tensor")

        logger.debug("Checking feature shapes and types")
        if real.ndim != 2 or synth.ndim != 2:
            logger.error("Features must be 2D [num_samples, feat_dim].")
            raise ValueError("Features must be 2D [num_samples, feat_dim].")
        if real.size(1) != synth.size(1):
            logger.error(
                "Real and synthetic features must have the same dimensionality."
            )
            raise ValueError(
                "Real and synthetic features must have the same dimensionality."
            )
        if real.size(0) < 2 or synth.size(0) < 2:
            logger.error(
                "Need at least two samples per set to compute unbiased covariance."
            )
            raise ValueError(
                "Need at least two samples per set to compute unbiased covariance."
            )

        real_mu, real_sigma = self._compute_statistics(real)
        synth_mu, synth_sigma = self._compute_statistics(synth)
        logger.info("Computed statistics for real and synthetic features")

        fid = self._compute_fid(real_mu, real_sigma, synth_mu, synth_sigma)
        logger.info(f"Computed FID: {fid}")

        return fid