openml.tasks.split

OpenMLSplit

OpenMLSplit(name: int | str, description: str, split: dict[int, dict[int, dict[int, tuple[ndarray, ndarray]]]])

OpenML Split object.

This class manages train-test splits for a dataset across multiple repetitions, folds, and samples.

Parameters

name : int or str
    The name or ID of the split.
description : str
    A description of the split.
split : dict
    A dictionary containing the splits, organized by repetition, fold, and sample.

Source code in openml/tasks/split.py
def __init__(
    self,
    name: int | str,
    description: str,
    split: dict[int, dict[int, dict[int, tuple[np.ndarray, np.ndarray]]]],
):
    self.description = description
    self.name = name
    self.split: dict[int, dict[int, dict[int, tuple[np.ndarray, np.ndarray]]]] = {}

    # Add splits according to repetition
    for repetition in split:
        _rep = int(repetition)
        self.split[_rep] = OrderedDict()
        for fold in split[_rep]:
            self.split[_rep][fold] = OrderedDict()
            for sample in split[_rep][fold]:
                self.split[_rep][fold][sample] = split[_rep][fold][sample]

    self.repeats = len(self.split)

    # TODO(eddiebergman): Better error message
    if any(len(self.split[0]) != len(self.split[i]) for i in range(self.repeats)):
        raise ValueError("")

    self.folds = len(self.split[0])
    self.samples = len(self.split[0][0])
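
The nested layout expected by the constructor can be illustrated with a small stand-alone sketch. It does not import openml; it only builds a dictionary of the shape `split[repeat][fold][sample]` and derives the same counts the constructor stores in `repeats`, `folds`, and `samples`. The row count, the rotation used to vary the repeats, and the (train, test) ordering of each tuple are assumptions made for illustration.

```python
import numpy as np

n_rows, n_repeats, n_folds = 6, 2, 3  # illustrative sizes, not from openml
indices = np.arange(n_rows)

# split[repeat][fold][sample] -> (train_indices, test_indices)
# The (train, test) ordering is an assumption for this sketch.
split = {}
for repeat in range(n_repeats):
    split[repeat] = {}
    # Rotate the rows so each repeat partitions them differently.
    rotated = np.roll(indices, repeat)
    for fold, test_idx in enumerate(np.array_split(rotated, n_folds)):
        # Every row that is not in the test part goes to the train part.
        train_idx = np.setdiff1d(indices, test_idx)
        split[repeat][fold] = {0: (train_idx, test_idx)}

# The same bookkeeping the constructor performs on its input.
repeats = len(split)
folds = len(split[0])
samples = len(split[0][0])

# Mirrors the constructor's check that every repeat has the same fold count.
consistent = all(len(split[r]) == folds for r in split)
```

Note that the constructor reads `self.split[0]` and `self.split[0][0]`, so a dictionary like this one, keyed from 0 at every level, is the shape it can process.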

get

get(repeat: int = 0, fold: int = 0, sample: int = 0) -> tuple[ndarray, ndarray]

Returns the data split stored in this OpenMLSplit object for the given repeat, fold, and sample.

Parameters

repeat : int
    Index of the repeat to retrieve. Defaults to 0.
fold : int
    Index of the fold to retrieve. Defaults to 0.
sample : int
    Index of the sample to retrieve. Defaults to 0.

Returns

tuple[numpy.ndarray, numpy.ndarray]
    The pair of arrays stored for the specified repeat, fold, and sample.

Raises

ValueError
    If the specified repeat, fold, or sample is not known.

Source code in openml/tasks/split.py
def get(self, repeat: int = 0, fold: int = 0, sample: int = 0) -> tuple[np.ndarray, np.ndarray]:
    """Returns the specified data split from the CrossValidationSplit object.

    Parameters
    ----------
    repeat : int
        Index of the repeat to retrieve.
    fold : int
        Index of the fold to retrieve.
    sample : int
        Index of the sample to retrieve.

    Returns
    -------
    numpy.ndarray
        The data split for the specified repeat, fold, and sample.

    Raises
    ------
    ValueError
        If the specified repeat, fold, or sample is not known.
    """
    if repeat not in self.split:
        raise ValueError(f"Repeat {repeat!s} not known")
    if fold not in self.split[repeat]:
        raise ValueError(f"Fold {fold!s} not known")
    if sample not in self.split[repeat][fold]:
        raise ValueError(f"Sample {sample!s} not known")
    return self.split[repeat][fold][sample]
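
The lookup order in `get` (repeat, then fold, then sample) can be tried without installing openml. In this sketch, `get_split` is a hypothetical free function that copies the three membership checks from the method above, and the small `split` dictionary with its index values is made up for illustration.

```python
import numpy as np


def get_split(split, repeat=0, fold=0, sample=0):
    """Hypothetical stand-alone version of the lookup performed by `get`."""
    if repeat not in split:
        raise ValueError(f"Repeat {repeat!s} not known")
    if fold not in split[repeat]:
        raise ValueError(f"Fold {fold!s} not known")
    if sample not in split[repeat][fold]:
        raise ValueError(f"Sample {sample!s} not known")
    return split[repeat][fold][sample]


# One repeat, two folds, one sample per fold (made-up indices).
split = {
    0: {
        0: {0: (np.array([2, 3]), np.array([0, 1]))},
        1: {0: (np.array([0, 1]), np.array([2, 3]))},
    }
}

# Omitted arguments fall back to index 0, as in the method signature.
first, second = get_split(split, fold=1)

# An unknown index is reported at the first level where it is missing.
try:
    get_split(split, fold=5)
except ValueError as err:
    message = str(err)
```

Because the checks run from the outermost level inward, a request with several unknown indices reports only the outermost one.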

Split

Bases: NamedTuple

A single split of a dataset.