
simet.parser.restraint

RestraintParser

Factory wrapper that builds restraint objects from config dicts.

Converts a user/config value (e.g., from YAML/JSON) into a RestraintType and delegates construction to RestraintType.get_restraint.

parse_restraint staticmethod

parse_restraint(restraint_data)

Parse config and return a concrete restraint instance.

Parameters:

  • restraint_data (dict, required): Mapping that must include the key "type" matching a RestraintType value (e.g., "FIDRestraint"), plus the keys "upper_bound" and "lower_bound". Both bound keys must be present, though either value may be None. Their shape depends on the restraint:
      • FID / RocAuc: float | None
      • PrecisionRecall: Sequence[float] | None for each bound

Returns:

  • Restraint: An instance of the requested restraint class.

Raises:

  • KeyError: If "type" is missing, or if "upper_bound" or "lower_bound" is missing (raised while the restraint is built).
  • ValueError: If the "type" value is not a supported RestraintType value.

Source code in simet/parser/restraint.py
@staticmethod
def parse_restraint(restraint_data: dict) -> Restraint:
    """Parse config and return a concrete restraint instance.

    Args:
        restraint_data (dict):
            Mapping that must include the key `"type"` matching a
            :class:`RestraintType` value (e.g., `"FIDRestraint"`), plus
            `"upper_bound"` and `"lower_bound"` keys (both required,
            though either value may be `None`) whose shape depends on
            the restraint:
              - FID / RocAuc: `float | None`
              - PrecisionRecall: `Sequence[float] | None` for each bound

    Returns:
        Restraint: An instance of the requested restraint class.

    Raises:
        KeyError: If `"type"`, `"upper_bound"`, or `"lower_bound"` is missing.
        ValueError: If the `"type"` value is unknown/unsupported.
    """
    restraint_type = RestraintType(restraint_data["type"])
    return RestraintType.get_restraint(restraint_type, restraint_data)
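A minimal, self-contained sketch of the same parse-then-dispatch flow, starting from a JSON string. The dataclasses below are hypothetical stand-ins for simet's restraint classes, the `(str, Enum)` base stands in for `StrEnum`, the enum value "RocAucRestraint" is an assumption (only "FIDRestraint" appears in these docs), and a dict lookup replaces the `match` statement for brevity.

```python
from __future__ import annotations

import json
from dataclasses import dataclass
from enum import Enum


# Hypothetical stand-ins for simet's restraint classes (bounds only).
@dataclass
class FIDRestraint:
    upper_bound: float | None
    lower_bound: float | None


@dataclass
class RocAucRestraint:
    upper_bound: float | None
    lower_bound: float | None


class RestraintType(str, Enum):
    # Stand-in for simet's StrEnum. "FIDRestraint" comes from the docs;
    # "RocAucRestraint" is an assumed value for illustration.
    FID = "FIDRestraint"
    ROCAUC = "RocAucRestraint"


# Dict dispatch used here in place of the match statement.
_BUILDERS = {RestraintType.FID: FIDRestraint, RestraintType.ROCAUC: RocAucRestraint}


def parse_restraint(restraint_data: dict):
    # KeyError if "type" is missing; ValueError if its value is not a member.
    restraint_type = RestraintType(restraint_data["type"])
    builder = _BUILDERS[restraint_type]
    # Bounds are read by direct indexing, so both keys must be present.
    return builder(
        upper_bound=restraint_data["upper_bound"],
        lower_bound=restraint_data["lower_bound"],
    )


config = json.loads('{"type": "FIDRestraint", "upper_bound": 50.0, "lower_bound": null}')
print(parse_restraint(config))  # FIDRestraint(upper_bound=50.0, lower_bound=None)

# Both failure modes from the Raises section.
for bad in ({"upper_bound": 1.0, "lower_bound": None}, {"type": "NoSuchRestraint"}):
    try:
        parse_restraint(bad)
    except (KeyError, ValueError) as exc:
        print(type(exc).__name__, exc)
```

Because the enum constructor does the lookup by value, a misspelled "type" fails before any restraint object is built.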

RestraintType

Bases: StrEnum

Enum of supported restraint identifiers (string-valued).

get_restraint staticmethod

get_restraint(restraint_type, restraint_data)

Construct a restraint instance for the given enum value.

Expects restraint_data to provide both the upper_bound and lower_bound keys (either value may be None). For Precision/Recall, list-like bounds are coerced to tuples.

Parameters:

  • restraint_type (RestraintType, required): Enum member indicating which restraint to build.
  • restraint_data (dict, required): Source config (typically parsed from YAML/JSON).

Returns:

  • Restraint: A concrete restraint instance, one of:
      • FIDRestraint(upper_bound: float | None, lower_bound: float | None)
      • PrecisionRecallRestraint(upper_bound: tuple[float, ...] | None, lower_bound: tuple[float, ...] | None)
      • RocAucRestraint(upper_bound: float | None, lower_bound: float | None)

Raises:

  • KeyError: If restraint_data lacks the "upper_bound" or "lower_bound" key.
  • ValueError: If restraint_type does not match any supported restraint (the error is logged, then raised).

Notes
  • Both upper_bound and lower_bound must be present as keys, though either value may be None. They are read by direct indexing (e.g., restraint_data["upper_bound"]), so a missing key raises KeyError; switch to restraint_data.get(...) if you want the keys to be optional.
  • Ensure the order and length of the Precision/Recall bounds match that restraint's output semantics.
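The notes above suggest adapting the lookup if the bounds should be optional. Here is a minimal sketch of that adaptation, assuming you want a missing bound key to behave exactly like an explicit None. `read_bounds` is a hypothetical helper, not part of simet; it also applies the list-to-tuple coercion used for Precision/Recall.

```python
from typing import Any


def read_bounds(restraint_data: dict[str, Any], as_tuple: bool = False) -> tuple[Any, Any]:
    """Hypothetical helper: read optional bounds, treating a missing key like None."""
    # dict.get returns None for a missing key instead of raising KeyError.
    upper = restraint_data.get("upper_bound")
    lower = restraint_data.get("lower_bound")
    if as_tuple:
        # Precision/Recall: coerce list-like bounds (e.g., YAML/JSON lists) to tuples.
        upper = None if upper is None else tuple(upper)
        lower = None if lower is None else tuple(lower)
    return upper, lower


print(read_bounds({"upper_bound": 30.0}))  # (30.0, None)
print(read_bounds({"lower_bound": [0.5, 0.6]}, as_tuple=True))  # (None, (0.5, 0.6))
```

Tuples are immutable and hashable, so coercing the parsed lists keeps a restraint's bounds from being mutated through the original config object.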
Source code in simet/parser/restraint.py
@staticmethod
def get_restraint(restraint_type: "RestraintType", restraint_data: dict) -> Restraint:
    """Construct a restraint instance for the given enum value.

    Expects `restraint_data` to provide both the `upper_bound` and
    `lower_bound` keys (either value may be `None`). For Precision/Recall,
    list-like bounds are coerced to tuples.

    Args:
        restraint_type (RestraintType): Enum indicating which restraint to build.
        restraint_data (dict): Source config (typically parsed from YAML/JSON).

    Returns:
        Restraint: A concrete restraint instance:
            - `FIDRestraint(upper_bound: float | None, lower_bound: float | None)`
            - `PrecisionRecallRestraint(upper_bound: tuple[float, ...] | None, lower_bound: tuple[float, ...] | None)`
            - `RocAucRestraint(upper_bound: float | None, lower_bound: float | None)`

    Raises:
        KeyError: If `upper_bound` or `lower_bound` is missing.
        ValueError: If the `restraint_type` is unknown (logged, then raised).

    Notes:
        - Both `upper_bound` and `lower_bound` keys must be present; either
          value may be `None`. They are read by direct indexing (e.g.,
          `restraint_data["upper_bound"]`), so a missing key raises
          `KeyError`; switch to `dict.get` if they should be optional.
        - Ensure the order and length of the Precision/Recall bounds match
          that restraint’s output semantics.
    """
    match restraint_type:
        case RestraintType.FID:
            return FIDRestraint(
                upper_bound=restraint_data["upper_bound"],
                lower_bound=restraint_data["lower_bound"],
            )
        case RestraintType.PRECISIONRECALL:
            # Coerce list-like bounds (e.g., YAML/JSON lists) to tuples; keep None as-is.
            return PrecisionRecallRestraint(
                upper_bound=None if restraint_data["upper_bound"] is None
                else tuple(restraint_data["upper_bound"]),
                lower_bound=None if restraint_data["lower_bound"] is None
                else tuple(restraint_data["lower_bound"]),
            )
        case RestraintType.ROCAUC:
            return RocAucRestraint(
                upper_bound=restraint_data["upper_bound"],
                lower_bound=restraint_data["lower_bound"],
            )
        case _:
            # Without this fallback an unmatched value would silently return None.
            logger.error(f"Unknown restraint type: {restraint_type}")
            raise ValueError(f"Unknown restraint type: {restraint_type}")