
simet.restraints.fid


FIDRestraint

FIDRestraint(lower_bound=0.0, upper_bound=500.0)

Bases: Restraint[float]

Restraint on the Fréchet Inception Distance (FID) between real and synthetic features.

Wraps the FID metric and checks that the resulting score falls within the inclusive interval [lower_bound, upper_bound]. A bound set to None is not enforced.

Semantics
  • Lower FID is better. A score of 0 means the two feature sets have identical means and covariances. By default this restraint accepts values in [0.0, 500.0]; tighten the bounds to match your quality targets.
Requirements
  • The provided DatasetLoader must expose real_features and synth_features as 2D arrays with the same feature dimension.
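For reference, FID is usually defined as the Fréchet distance between Gaussians fitted to the two feature sets: ‖μ_r − μ_s‖² + Tr(Σ_r + Σ_s − 2(Σ_r Σ_s)^{1/2}). The sketch below computes that standard formula with NumPy and enforces the requirements listed above. It assumes simet's FID metric follows this definition; frechet_distance is a hypothetical helper, not part of simet, and the library's actual computation may differ in details such as the matrix square-root method.

```python
import numpy as np


def frechet_distance(real: np.ndarray, synth: np.ndarray) -> float:
    """Hypothetical helper: standard FID between two (n_samples, n_features) arrays."""
    real = np.asarray(real, dtype=np.float64)
    synth = np.asarray(synth, dtype=np.float64)
    # Mirror the documented requirement: 2D arrays with matching feature dimension.
    if real.ndim != 2 or synth.ndim != 2:
        raise ValueError("features must be 2D arrays")
    if real.shape[1] != synth.shape[1]:
        raise ValueError("real and synthetic features must share the feature dimension")

    mu_r, mu_s = real.mean(axis=0), synth.mean(axis=0)
    # rowvar=False: rows are samples, columns are features.
    # atleast_2d keeps the single-feature case as a 1x1 matrix.
    cov_r = np.atleast_2d(np.cov(real, rowvar=False))
    cov_s = np.atleast_2d(np.cov(synth, rowvar=False))

    # Tr((cov_r @ cov_s)^(1/2)) equals the sum of square roots of the eigenvalues
    # of the symmetric PSD matrix sqrt(cov_r) @ cov_s @ sqrt(cov_r).
    w, v = np.linalg.eigh(cov_r)
    sqrt_r = (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T
    m = sqrt_r @ cov_s @ sqrt_r
    eig = np.linalg.eigvalsh((m + m.T) / 2.0)  # symmetrize to absorb rounding error
    tr_sqrt = np.sqrt(np.clip(eig, 0.0, None)).sum()

    diff = mu_r - mu_s
    fid = diff @ diff + np.trace(cov_r) + np.trace(cov_s) - 2.0 * tr_sqrt
    return float(max(fid, 0.0))  # clamp tiny negative round-off; FID is non-negative
```

Shifting every feature by a constant c leaves the covariances unchanged, so the score reduces to n_features · c², which makes a quick sanity check.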

Parameters:
  • lower_bound (float | None, default 0.0): Minimum acceptable FID (inclusive). Pass None to disable the lower bound.
  • upper_bound (float | None, default 500.0): Maximum acceptable FID (inclusive). Pass None to disable the upper bound.

Returns (from apply): tuple[bool, float]: (passes, value), where value is the FID score (always non-negative) and passes indicates whether it lies within the configured bounds.

Initialize the FID restraint and its underlying metric.

Source code in simet/restraints/fid.py
@override
def __init__(
    self,
    lower_bound: float | None = 0.0,
    upper_bound: float | None = 500.0,
) -> None:
    """Initialize the FID restraint and its underlying metric."""
    super().__init__(lower_bound, upper_bound)
    self.metric = FID()

apply

apply(loader)

Compute FID and evaluate it against the configured bounds.

Parameters:
  • loader (DatasetLoader, required): Must provide real_features and synth_features suitable for the FID computation.

Returns:
  • tuple[bool, float]: (passes, fid), where passes is True iff lower_bound <= fid <= upper_bound (treating None as unbounded).

Source code in simet/restraints/fid.py
@override
def apply(self, loader: DatasetLoader) -> tuple[bool, float]:
    """Compute FID and evaluate it against the configured bounds.

    Args:
        loader (DatasetLoader): Must provide `real_features` and `synth_features`
            suitable for the FID computation.

    Returns:
        tuple[bool, float]: ``(passes, fid)`` where ``passes`` is True iff
        ``lower_bound <= fid <= upper_bound`` (treating `None` as unbounded).
    """
    fid = self.metric.compute(loader)
    lower_ok = self.lower_bound is None or fid >= self.lower_bound
    upper_ok = self.upper_bound is None or fid <= self.upper_bound
    passes = lower_ok and upper_ok
    logger.info(f"FID Restraint passes: {passes}")
    return passes, fid