simet.restraints.downstream_task.sample_tstr
SampleTSTRRestraint
SampleTSTRRestraint(lower_bound=0.0, upper_bound=1.0)
Bases: Restraint[float]
Restraint on the SampleTSTR downstream score (Train Synth → Test Real).
Wraps SampleTSTR and checks that the resulting accuracy lies
within the inclusive interval [lower_bound, upper_bound] when those
bounds are provided.
Requirements
- The provided DatasetLoader must have been built with a downstream_transform so that synth_downstream_dataloader and real_downstream_dataloader are available to the underlying task.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lower_bound | float \| None | Minimum acceptable accuracy (inclusive). Defaults to 0.0. | 0.0 |
| upper_bound | float \| None | Maximum acceptable accuracy (inclusive). Defaults to 1.0. | 1.0 |
Returns (from apply):
tuple[bool, float]: (passes, value) where value is the TSTR
accuracy in [0.0, 1.0] and passes indicates whether the value
is within the configured bounds.
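
The inclusive bound check described above can be sketched as follows. This is only an illustration, not the library's implementation: `check_bounds` is a hypothetical helper, and treating a `None` bound as "not checked" is an assumption drawn from the `float | None` parameter types and the phrase "when those bounds are provided".

```python
from typing import Optional, Tuple


def check_bounds(
    value: float,
    lower_bound: Optional[float] = 0.0,
    upper_bound: Optional[float] = 1.0,
) -> Tuple[bool, float]:
    """Hypothetical helper: return (passes, value) for an inclusive interval.

    Assumption: a bound set to None is skipped rather than enforced.
    """
    passes = True
    # Inclusive lower bound: value == lower_bound still passes.
    if lower_bound is not None and value < lower_bound:
        passes = False
    # Inclusive upper bound: value == upper_bound still passes.
    if upper_bound is not None and value > upper_bound:
        passes = False
    return passes, value


passed, score = check_bounds(0.70, lower_bound=0.70, upper_bound=None)
```

Because both ends are inclusive, a score exactly equal to a bound passes; only values strictly outside the interval fail.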
Example
    >>> r = SampleTSTRRestraint(lower_bound=0.70)
    >>> passed, score = r.apply(loader)
    >>> passed and score >= 0.70
    True
Initialize the TSTR restraint and its underlying metric.
Source code in simet/restraints/downstream_task/sample_tstr.py
apply
apply(loader)
Compute TSTR accuracy and evaluate it against the bounds.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| loader | DatasetLoader | Dataset context with downstream dataloaders present. | required |
Returns:
| Type | Description |
|---|---|
| tuple[bool, float] | (passes, value), where value is the TSTR accuracy and passes is True iff the accuracy is within the inclusive bounds. |
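
For readers unfamiliar with the Train Synth → Test Real idea, here is a minimal, self-contained sketch of how such an accuracy could be computed. It does not reflect how SampleTSTR is implemented: the nearest-centroid classifier, the function names (`fit_centroids`, `predict`, `tstr_accuracy`), and the toy data are all assumptions chosen for illustration.

```python
from collections import defaultdict


def fit_centroids(features, labels):
    """Train step: compute one mean vector (centroid) per class label."""
    sums = {}
    counts = defaultdict(int)
    for x, y in zip(features, labels):
        if y not in sums:
            sums[y] = [0.0] * len(x)
        sums[y] = [s + v for s, v in zip(sums[y], x)]
        counts[y] += 1
    return {y: [s / counts[y] for s in sums[y]] for y in sums}


def predict(centroids, x):
    """Return the label whose centroid is closest to x (squared Euclidean)."""
    def sq_dist(c):
        return sum((a - b) ** 2 for a, b in zip(c, x))
    return min(centroids, key=lambda y: sq_dist(centroids[y]))


def tstr_accuracy(synth_x, synth_y, real_x, real_y):
    """Train on synthetic samples only, then score accuracy on real samples."""
    centroids = fit_centroids(synth_x, synth_y)
    correct = sum(predict(centroids, x) == y for x, y in zip(real_x, real_y))
    return correct / len(real_y)


# Toy data (made up): two well-separated classes in 2D.
synth_x = [[0.0, 0.1], [0.2, 0.0], [1.0, 0.9], [0.8, 1.0]]
synth_y = [0, 0, 1, 1]
real_x = [[0.1, 0.2], [0.9, 0.8], [0.4, 0.3], [0.7, 0.6]]
real_y = [0, 1, 0, 1]

accuracy = tstr_accuracy(synth_x, synth_y, real_x, real_y)
```

The resulting accuracy is a fraction in [0.0, 1.0], which is the kind of value the restraint compares against its bounds.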
Source code in simet/restraints/downstream_task/sample_tstr.py