
simet.metrics.downstream_task.sample.sample_trts

SampleTRTS

Bases: SampleDownstreamTask

Toy downstream task using Train on Real, Test on Synth splits.

Trains the underlying model on loader.real_downstream_dataloader and evaluates on loader.synth_downstream_dataloader, returning the final accuracy reported by the parent implementation.
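On this page the actual training is delegated to the parent's `_compute`, and the model it uses is not documented here. The sketch below only illustrates the TRTS protocol itself. It fits a nearest-centroid classifier on "real" (features, label) pairs and scores it on "synthetic" pairs. The function names, the in-memory data format, the choice of classifier, and returning 0.0 for an empty test set are all assumptions. None of them describe SampleDownstreamTask's implementation.

```python
from collections import defaultdict
from typing import Sequence

# Assumed sample format for this sketch: (feature vector, class label).
Sample = tuple[list[float], int]


def fit_centroids(train: Sequence[Sample]) -> dict[int, list[float]]:
    """'Train' a nearest-centroid model: average the feature vectors per class."""
    sums: dict[int, list[float]] = {}
    counts: dict[int, int] = defaultdict(int)
    for x, y in train:
        if y not in sums:
            sums[y] = [0.0] * len(x)
        sums[y] = [s + v for s, v in zip(sums[y], x)]
        counts[y] += 1
    return {y: [s / counts[y] for s in vec] for y, vec in sums.items()}


def predict(centroids: dict[int, list[float]], x: list[float]) -> int:
    """Return the label whose centroid is closest in squared Euclidean distance."""
    return min(
        centroids,
        key=lambda y: sum((a - b) ** 2 for a, b in zip(centroids[y], x)),
    )


def trts_accuracy(real: Sequence[Sample], synth: Sequence[Sample]) -> float:
    """Train on real, test on synth; return accuracy in [0.0, 1.0]."""
    centroids = fit_centroids(real)  # the model only ever sees real data
    if not synth:
        return 0.0  # assumption of this sketch: empty test set scores 0.0
    correct = sum(predict(centroids, x) == y for x, y in synth)
    return correct / len(synth)


real = [([0.0, 0.0], 0), ([0.2, 0.1], 0), ([1.0, 1.0], 1), ([0.9, 1.1], 1)]
synth = [([0.1, 0.0], 0), ([1.0, 0.9], 1), ([0.8, 0.9], 0)]
score = trts_accuracy(real, synth)  # 2 of 3 synthetic samples classified correctly
```

A high TRTS score suggests that the synthetic samples fall where a model trained on real data expects each class to be. A low score suggests that they do not.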

Requirements
  • DatasetLoader must have been constructed with a downstream_transform so that real_downstream_dataloader and synth_downstream_dataloader are available. Otherwise, an AttributeError will be raised when accessing these attributes.
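If you want to fail early instead of hitting the AttributeError inside `compute`, one option is to check for both attributes up front. The helper `has_downstream_loaders` and the two stand-in loader classes below are hypothetical and are not part of simet. The check relies only on the built-in `hasattr`, which returns False when attribute access raises AttributeError. That includes a property that raises it.

```python
REQUIRED_ATTRS = ("real_downstream_dataloader", "synth_downstream_dataloader")


def has_downstream_loaders(loader: object) -> bool:
    """Hypothetical helper: True if `loader` exposes both downstream dataloaders."""
    return all(hasattr(loader, name) for name in REQUIRED_ATTRS)


class LoaderWithoutTransform:
    """Hypothetical stand-in for a DatasetLoader built without downstream_transform."""

    @property
    def real_downstream_dataloader(self):
        raise AttributeError("no downstream_transform was provided")


class LoaderWithTransform:
    """Hypothetical stand-in for a DatasetLoader built with downstream_transform."""

    real_downstream_dataloader = ["real batch"]
    synth_downstream_dataloader = ["synth batch"]
```

In practice you would guard the call, for example `if has_downstream_loaders(loader): score = SampleTRTS().compute(loader)`.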
Example

>>> task = SampleTRTS()
>>> score = task.compute(loader)  # uses real for train, synth for test
>>> 0.0 <= score <= 1.0
True

name (property)

Human-readable task name.

compute

compute(loader)

Train on real, test on synth, and return accuracy.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| loader | DatasetLoader | Must expose real_downstream_dataloader and synth_downstream_dataloader (present only if a downstream_transform was provided). | required |

Returns:

| Type | Description |
|---|---|
| float | Final test accuracy on the synthetic set, in [0.0, 1.0]. |

Source code in simet/metrics/downstream_task/sample/sample_trts.py
@override
def compute(self, loader: DatasetLoader) -> float:
    """Train on real, test on synth, and return accuracy.

    Args:
        loader (DatasetLoader): Must expose
            `real_downstream_dataloader` and `synth_downstream_dataloader`
            (present only if a `downstream_transform` was provided).

    Returns:
        float: Final test accuracy on the synthetic set in `[0.0, 1.0]`.
    """
    return self._compute(
        train_set=loader.real_downstream_dataloader,
        test_set=loader.synth_downstream_dataloader,
    )
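The method above contains no training logic of its own. It only decides which split becomes `train_set` and which becomes `test_set`, and hands both to the parent's `_compute`. The sketch below reproduces that delegation pattern with hypothetical classes (`ToyDownstreamTask`, `ToyTRTS`, `ToyTSTR`). The parent's `_compute` here is a placeholder that records the splits it receives instead of fitting a model. The mirror-image `ToyTSTR` is included only to show how swapping the two arguments changes the protocol. `typing.override` exists from Python 3.12, so the sketch falls back to a no-op decorator on older versions.

```python
from types import SimpleNamespace
from typing import Any, Iterable

try:
    from typing import override  # available in Python 3.12+
except ImportError:  # older Pythons: a no-op stand-in with the same role
    def override(func):
        return func


class ToyDownstreamTask:
    """Hypothetical parent: subclasses choose the splits, _compute does the work."""

    def _compute(self, train_set: Iterable[Any], test_set: Iterable[Any]) -> float:
        # Placeholder: record which splits were used; a real task would train and evaluate here.
        self.last_split = (list(train_set), list(test_set))
        return 0.0

    def compute(self, loader: Any) -> float:
        raise NotImplementedError


class ToyTRTS(ToyDownstreamTask):
    @override
    def compute(self, loader: Any) -> float:
        # Train on real, test on synth (same wiring as SampleTRTS.compute above).
        return self._compute(
            train_set=loader.real_downstream_dataloader,
            test_set=loader.synth_downstream_dataloader,
        )


class ToyTSTR(ToyDownstreamTask):
    @override
    def compute(self, loader: Any) -> float:
        # Hypothetical mirror image: train on synth, test on real.
        return self._compute(
            train_set=loader.synth_downstream_dataloader,
            test_set=loader.real_downstream_dataloader,
        )


loader = SimpleNamespace(
    real_downstream_dataloader=["real"],
    synth_downstream_dataloader=["synth"],
)
trts, tstr = ToyTRTS(), ToyTSTR()
trts.compute(loader)
tstr.compute(loader)
```

Keeping the split selection in a small `compute` override means that one shared `_compute` can serve several evaluation protocols.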