simet.metrics.precision_recall

PrecisionRecall

PrecisionRecall(
    *,
    metric="l2",
    index_type="flat",
    nlist=1024,
    use_gpu=True,
    num_gpus=None,
    batch_size=None,
    use_fp16=False,
    random_state=1234,
)

Bases: Metric[tuple[float, float]]

Precision/Recall between real and synthetic feature sets via FAISS k-NN.

Computes
  • Precision: fraction of synthetic samples whose distance to their nearest real sample (1-NN in the real set) is at most that real sample’s k-th neighbor radius within the real set.
  • Recall: fraction of real samples whose distance to their nearest synthetic sample (1-NN in the synthetic set) is at most that synthetic sample’s k-th neighbor radius within the synthetic set.
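
As an illustration of the two definitions above, the following brute-force NumPy sketch computes the same quantities without FAISS. It is not the library's implementation: the helper names (pairwise_sq_l2, kth_neighbor_radius, coverage) are hypothetical, squared L2 is assumed, and excluding each sample from its own neighbor list is an assumption that this page does not confirm.

```python
import numpy as np


def pairwise_sq_l2(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Squared L2 distance between every row of `a` and every row of `b`."""
    diff = a[:, None, :] - b[None, :, :]
    return np.einsum("ijk,ijk->ij", diff, diff)


def kth_neighbor_radius(x: np.ndarray, k: int) -> np.ndarray:
    """Distance from each row of `x` to its k-th nearest *other* row.

    Assumption: a sample is not counted as its own neighbor.
    """
    d = pairwise_sq_l2(x, x)
    np.fill_diagonal(d, np.inf)  # exclude self-matches
    return np.sort(d, axis=1)[:, k - 1]


def coverage(queries: np.ndarray, database: np.ndarray, k: int) -> float:
    """Fraction of queries lying within the radius of their nearest database row."""
    radii = kth_neighbor_radius(database, k)
    d = pairwise_sq_l2(queries, database)
    nearest = d.argmin(axis=1)
    nearest_dist = d[np.arange(len(queries)), nearest]
    return float((nearest_dist <= radii[nearest]).mean())


def precision_recall(real: np.ndarray, synth: np.ndarray, k: int = 5):
    precision = coverage(synth, real, k)  # synthetic queries against real radii
    recall = coverage(real, synth, k)     # real queries against synthetic radii
    return precision, recall
```

The broadcasted distance matrix needs memory proportional to n_queries × n_database × n_dims, so this form only suits small arrays; it is meant for checking intuition, not for replacing the indexed search.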

Distance & index backends
  • metric="l2": uses squared L2 distances.
  • metric="cosine": uses inner-product search on L2-normalized vectors (both sets are L2-normalized in place) and converts similarities to cosine distances as 1 - sim.
  • index_type="flat": exact search (IndexFlatL2 / IndexFlatIP).
  • index_type="ivf": coarse-quantized IVF (IndexIVFFlat); requires training.
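
To make the cosine conversion concrete, here is a small NumPy sketch of "inner product on unit vectors, then 1 - sim". The function names and the eps guard are hypothetical, and unlike the in-place normalization described above this version returns normalized copies.

```python
import numpy as np


def l2_normalize_rows(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Return a row-normalized copy; `eps` guards against zero-norm rows."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)


def cosine_distances(queries, database) -> np.ndarray:
    """Cosine distance (1 - cosine similarity) between all query/database rows."""
    q = l2_normalize_rows(np.asarray(queries, dtype=np.float32))
    d = l2_normalize_rows(np.asarray(database, dtype=np.float32))
    sims = q @ d.T  # inner product of unit vectors equals cosine similarity
    return 1.0 - sims
```

For unit vectors a and b, the squared L2 distance equals 2 · (1 - a·b), so both metric settings rank neighbors identically once the features are normalized; only the scale of the distances differs.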

GPU
  • If use_gpu=True and GPUs are available, the FAISS index is cloned to GPU.
  • With num_gpus > 1, uses a sharded multi-GPU index.
  • use_fp16=True stores/searches in fp16 on GPU (memory/speed trade-offs).
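
The GPU options above map naturally onto FAISS's multi-GPU cloner. The sketch below is one way this could be wired up, not necessarily how this class does it: resolve_gpu_count and maybe_to_gpu are hypothetical helpers, and the behavior of capping num_gpus at the number of visible devices is an assumption.

```python
def resolve_gpu_count(use_gpu: bool, num_gpus, available: int) -> int:
    """Number of GPUs to use; 0 means the index stays on the CPU."""
    if not use_gpu or available <= 0:
        return 0
    if num_gpus is None:  # None => all available devices
        return available
    return max(0, min(int(num_gpus), available))


def maybe_to_gpu(cpu_index, use_gpu=True, num_gpus=None, use_fp16=False):
    """Clone a CPU FAISS index to one or more GPUs when possible."""
    import faiss  # imported lazily so the helper above works without FAISS

    # CPU-only builds may not expose GPU helpers, so probe defensively.
    get_num_gpus = getattr(faiss, "get_num_gpus", None)
    available = get_num_gpus() if get_num_gpus is not None else 0

    n = resolve_gpu_count(use_gpu, num_gpus, available)
    if n == 0:
        return cpu_index

    co = faiss.GpuMultipleClonerOptions()
    co.shard = n > 1            # split the database across GPUs instead of replicating
    co.useFloat16 = use_fp16    # store vectors in fp16 on the device
    return faiss.index_cpu_to_all_gpus(cpu_index, co=co, ngpu=n)
```

Sharding keeps a distinct slice of the database on each device, which lowers per-GPU memory use; replication (shard = False) would instead copy the full database to every device.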

Parameters:

  • metric (Literal['l2', 'cosine']): Distance/similarity type. Defaults to "l2".
  • index_type (Literal['flat', 'ivf']): FAISS index type. Defaults to "flat".
  • nlist (int): Number of IVF lists if index_type="ivf". Defaults to 1024.
  • use_gpu (bool): Enable GPU indices when available. Defaults to True.
  • num_gpus (int | None): Number of GPUs to use (None → all). Defaults to None.
  • batch_size (int | None): Query batch size for FAISS search (None → auto). Defaults to None.
  • use_fp16 (bool): Use fp16 storage/search on GPU. Defaults to False.
  • random_state (int | None): Seed for FAISS training (IVF). Defaults to 1234.

Notes
  • If either feature set is empty, returns (0.0, 0.0) with a warning.
  • If k ≥ min(|real|, |synth|), k is clamped down to avoid degeneracy.
  • For metric="cosine", features are L2-normalized in place.
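
One practical consequence of the in-place normalization: compute obtains its arrays via np.ascontiguousarray(..., dtype=np.float32), which does not copy an input that is already C-contiguous float32. The sketch below illustrates that NumPy behavior with a stand-in normalization (it does not call the library's safe_norm); keeping an explicit copy is one way to preserve the raw features.

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.normal(size=(4, 3)).astype(np.float32)  # C-contiguous float32
backup = features.copy()  # explicit copy, independent of `features`

# No conversion is needed here, so the result shares memory with `features`.
alias = np.ascontiguousarray(features, dtype=np.float32)
shares = np.shares_memory(alias, features)

# Stand-in for an in-place row normalization (not the library's safe_norm).
alias /= np.linalg.norm(alias, axis=1, keepdims=True)

# `features` now holds unit-norm rows, while `backup` keeps the raw values.
row_norms = np.linalg.norm(features, axis=1)

# A float64 input must be converted, which allocates a new float32 buffer.
features64 = backup.astype(np.float64)
converted = np.ascontiguousarray(features64, dtype=np.float32)
converted_shares = np.shares_memory(converted, features64)
```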
Source code in simet/metrics/precision_recall.py
def __init__(
    self,
    *,
    metric: MetricType = "l2",
    index_type: IndexType = "flat",
    nlist: int = 1024,  # IVF lists if index_type="ivf"
    use_gpu: bool = True,
    num_gpus: Optional[int] = None,  # None => all available
    batch_size: Optional[int] = None,  # None => faiss decides
    use_fp16: bool = False,  # GPU only: store/search in fp16
    random_state: Optional[int] = 1234,
) -> None:
    logger.info("Initializing Precision/Recall metric")
    super().__init__()
    self.metric = metric
    self.index_type = index_type
    self.nlist = int(nlist)
    self.use_gpu = bool(use_gpu)
    self.num_gpus = num_gpus
    self.batch_size = batch_size
    self.use_fp16 = use_fp16
    self.random_state = random_state

name property

name

Human-readable metric name.

compute

compute(loader, k=5)

Compute (precision, recall) between real/synth feature sets.

Steps

1) Validate shapes: both 2D and same D.
2) (Cosine only) L2-normalize real and synth in place.
3) Build k-th neighbor radii per-domain (real→real, synth→synth).
4) Precision: 1-NN of synth in real vs real radii.
5) Recall: 1-NN of real in synth vs synth radii.

Parameters:

  • loader (DatasetLoader): Provides real_features and synth_features arrays. Required.
  • k (int): Neighborhood size for radii (k-th neighbor). Defaults to 5.

Returns:

  • Tuple[float, float]: (precision, recall), each in [0.0, 1.0].

Raises:

  • ValueError: If either feature array is not 2D, the two arrays differ in dimensionality, any value is NaN/Inf, or k < 1.

Notes
  • If k >= min(len(real), len(synth)), k is clamped to min(len(real), len(synth)) - 1, but never below 1.
  • For IVF indices, training is performed on the same domain used to build radii/DB.
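
The clamping rule in the notes above can be written as a small standalone function. The name clamp_k is hypothetical; the expression mirrors the one used in compute, applied unconditionally, which gives the same result because a k smaller than both set sizes is left unchanged.

```python
def clamp_k(k: int, n_real: int, n_synth: int) -> int:
    """Clamp k so the k-th neighbor exists in both sets; never return less than 1."""
    if k < 1:
        raise ValueError("k must be >= 1.")
    return max(1, min(k, n_real - 1, n_synth - 1))
```

For example, with k=5, three real samples and 100 synthetic samples, the effective k becomes 2. With a single-sample set the lower bound takes over and k stays at 1.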
Source code in simet/metrics/precision_recall.py
@override
def compute(self, loader: DatasetLoader, k: int = 5) -> Tuple[float, float]:
    """Compute `(precision, recall)` between real/synth feature sets.

    Steps:
        1) Validate shapes: both 2D and same `D`.
        2) (Cosine only) L2-normalize `real` and `synth` **in place**.
        3) Build **k-th neighbor radii** per-domain (real→real, synth→synth).
        4) Precision: 1-NN of synth in real vs real radii.
        5) Recall: 1-NN of real in synth vs synth radii.

    Args:
        loader (DatasetLoader): Provides `real_features` and `synth_features` arrays.
        k (int, optional): Neighborhood size for radii (k-th neighbor). Defaults to 5.

    Returns:
        Tuple[float, float]: `(precision, recall)` each in `[0.0, 1.0]`.

    Raises:
        ValueError: On invalid shapes, NaN/Inf, or invalid `k`.

    Notes:
        - If `k >= min(len(real), len(synth))`, `k` will be clamped to `min(n)-1`.
        - For IVF indices, training is performed on the same domain used to build radii/DB.
    """
    real = np.ascontiguousarray(loader.real_features, dtype=np.float32)
    synth = np.ascontiguousarray(loader.synth_features, dtype=np.float32)
    logger.debug("Loaded features from loader")

    logger.debug("Checking feature shapes and types")
    if real.ndim != 2 or synth.ndim != 2:
        logger.error("Features must be 2D: (n_samples, n_dims).")
        raise ValueError("Features must be 2D: (n_samples, n_dims).")
    if real.shape[1] != synth.shape[1]:
        logger.error(
            "Real and synthetic features must have the same dimensionality."
        )
        raise ValueError(
            f"Dim mismatch: real {real.shape[1]} vs synth {synth.shape[1]}"
        )
    if len(real) == 0 or len(synth) == 0:
        logger.warning("One of the feature sets is empty; returning 0.0, 0.0")
        return 0.0, 0.0
    if not np.isfinite(real).all() or not np.isfinite(synth).all():
        logger.error("Features contain NaN/Inf.")
        raise ValueError("Features contain NaN/Inf.")
    if k < 1:
        logger.error("k must be >= 1.")
        raise ValueError("k must be >= 1.")
    if k >= len(real) or k >= len(synth):
        logger.warning(
            "k is greater than or equal to the number of samples; clamping."
        )
        # With k >= n, radius degenerates; clamp safely
        k = max(1, min(k, len(real) - 1, len(synth) - 1))

    if self.metric == "cosine":
        logger.debug("Normalizing features for cosine metric")
        PrecisionRecallService.safe_norm(real)
        PrecisionRecallService.safe_norm(synth)

    # --- build radii for each domain
    logger.debug("Computing k-th neighbor radii")
    real_rad = self._kth_neighbor_radius(real, k=k)
    synth_rad = self._kth_neighbor_radius(synth, k=k)

    # --- precision: synth -> real (1-NN)
    logger.debug("Computing precision and recall")
    d_sr, idx_sr = self._nn_search(real, synth, k=1)
    # Compare squared L2 (or cosine distance converted from IP below)
    precision = float((d_sr[:, 0] <= real_rad[idx_sr[:, 0]]).mean())
    logger.info(f"Computed precision: {precision}")

    # --- recall: real -> synth (1-NN)
    d_rs, idx_rs = self._nn_search(synth, real, k=1)
    recall = float((d_rs[:, 0] <= synth_rad[idx_rs[:, 0]]).mean())
    logger.info(f"Computed recall: {recall}")

    return precision, recall