simet.restraints.downstream_task.sample_tstr

SampleTSTRRestraint

SampleTSTRRestraint(lower_bound=0.0, upper_bound=1.0)

Bases: Restraint[float]

Restraint on the SampleTSTR downstream score (Train Synth → Test Real).
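To make the Train Synth → Test Real idea concrete, here is a minimal, self-contained sketch. It is not SampleTSTR's implementation. This page does not show which classifier SampleTSTR trains, so a nearest-centroid classifier on toy 2-D points stands in for it. The names fit_centroids, predict and tstr_accuracy are hypothetical.

```python
from collections import defaultdict
from math import dist

# Hypothetical stand-in: SampleTSTR's real classifier is not shown on this page.
Point = tuple[float, ...]


def fit_centroids(xs: list[Point], ys: list[int]) -> dict[int, Point]:
    """Fit a nearest-centroid classifier: one mean vector per class label."""
    groups: dict[int, list[Point]] = defaultdict(list)
    for x, y in zip(xs, ys):
        groups[y].append(x)
    return {
        label: tuple(sum(coord) / len(pts) for coord in zip(*pts))
        for label, pts in groups.items()
    }


def predict(centroids: dict[int, Point], x: Point) -> int:
    """Return the label whose centroid is closest (Euclidean) to x."""
    return min(centroids, key=lambda label: dist(centroids[label], x))


def tstr_accuracy(synth_x, synth_y, real_x, real_y) -> float:
    """Train on synthetic samples only, then report accuracy on real samples."""
    centroids = fit_centroids(synth_x, synth_y)
    correct = sum(predict(centroids, x) == y for x, y in zip(real_x, real_y))
    return correct / len(real_y)


# Toy data: two well-separated classes in the synthetic set.
synth_x = [(0.0, 0.1), (0.2, 0.0), (1.0, 0.9), (0.9, 1.1)]
synth_y = [0, 0, 1, 1]
real_x = [(0.1, 0.1), (1.0, 1.0), (0.6, 0.6), (0.45, 0.4)]
real_y = [0, 1, 1, 1]
print(tstr_accuracy(synth_x, synth_y, real_x, real_y))  # 0.75
```

The last real point, labelled 1, lies closer to the synthetic class-0 centroid and is misclassified, giving 3 correct out of 4.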

Wraps SampleTSTR and checks that the resulting accuracy lies within the inclusive interval [lower_bound, upper_bound]. A bound set to None is not checked.

Requirements
  • The provided DatasetLoader must have been built with a downstream_transform so that synth_downstream_dataloader and real_downstream_dataloader are available to the underlying task.
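Because this requirement is easy to miss, a caller could check it up front. The helper below is hypothetical. It assumes the two dataloaders are exposed as attributes named synth_downstream_dataloader and real_downstream_dataloader, and that they are missing or None when no downstream_transform was given. This page does not confirm how DatasetLoader represents that case, so adapt the check to the real class. types.SimpleNamespace stands in for a DatasetLoader here.

```python
from types import SimpleNamespace

# Attribute names taken from the requirement above; how they are stored is assumed.
REQUIRED_ATTRS = ("synth_downstream_dataloader", "real_downstream_dataloader")


def missing_downstream_loaders(loader: object) -> list[str]:
    """Return the names of downstream dataloaders that are absent or None.

    Hypothetical helper: assumes missing/None means no downstream_transform.
    """
    return [name for name in REQUIRED_ATTRS if getattr(loader, name, None) is None]


ready = SimpleNamespace(synth_downstream_dataloader=[], real_downstream_dataloader=[])
not_ready = SimpleNamespace(synth_downstream_dataloader=None)
print(missing_downstream_loaders(ready))      # []
print(missing_downstream_loaders(not_ready))  # ['synth_downstream_dataloader', 'real_downstream_dataloader']
```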

Parameters:

Name         Type          Default  Description
lower_bound  float | None  0.0      Minimum acceptable accuracy (inclusive). None disables the lower check.
upper_bound  float | None  1.0      Maximum acceptable accuracy (inclusive). None disables the upper check.

Returns (from apply): tuple[bool, float]: (passes, value) where value is the TSTR accuracy in [0.0, 1.0] and passes indicates whether the value is within the configured bounds.

Example

>>> r = SampleTSTRRestraint(lower_bound=0.70)
>>> passed, score = r.apply(loader)
>>> passed and score >= 0.70
True

Initialize the TSTR restraint and its underlying metric.

Source code in simet/restraints/downstream_task/sample_tstr.py
@override
def __init__(
    self,
    lower_bound: float | None = 0.0,
    upper_bound: float | None = 1.0,
) -> None:
    """Initialize the TSTR restraint and its underlying metric."""
    super().__init__(lower_bound, upper_bound)
    # Also store the bounds on the instance; apply() reads them from self.
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    # Underlying downstream metric: train on synthetic data, test on real data.
    self.metric = SampleTSTR()

apply

apply(loader)

Compute TSTR accuracy and evaluate it against the bounds.

Parameters:

Name    Type           Default   Description
loader  DatasetLoader  required  Dataset context with downstream dataloaders present.

Returns:

Type                Description
tuple[bool, float]  (passes, accuracy), where passes is True iff the accuracy is within the inclusive bounds.

Source code in simet/restraints/downstream_task/sample_tstr.py
@override
def apply(self, loader: DatasetLoader) -> tuple[bool, float]:
    """Compute TSTR accuracy and evaluate it against the bounds.

    Args:
        loader: Dataset context with downstream dataloaders present.

    Returns:
        tuple[bool, float]: ``(passes, accuracy)`` where ``passes`` is True
        iff the accuracy is within the inclusive bounds.
    """
    # TSTR accuracy in [0.0, 1.0].
    accuracy = self.metric.compute(loader)
    # A bound set to None disables the check on that side.
    lower_ok = self.lower_bound is None or accuracy >= self.lower_bound
    upper_ok = self.upper_bound is None or accuracy <= self.upper_bound
    passes = lower_ok and upper_ok
    logger.info(f"SampleTSTR restraint passes: {passes}")
    return passes, accuracy