
Callbacks

The callbacks module contains classes and functions for handling callbacks during an event-driven training process. Callbacks make it easy to customize the behavior of the training loop and add functionality to the training process without modifying the core code.

To write a callback, create a class that inherits from the Callback class and implement the methods for the events you want to handle. Callbacks can perform actions at different stages of training, such as the beginning or end of an epoch, a batch, or the whole fitting process (for example begin_fit, begin_epoch, or after_batch). Then pass the callback object to the trainer.
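
The following is a minimal sketch of this pattern. So that it runs without openml_pytorch, it uses hypothetical stand-ins. StubCallback copies the set_runner and __call__ logic of the Callback base class documented on this page. StubRunner is a toy loop that fires only begin_fit, begin_epoch, after_batch and after_fit. The real trainer fires more events, and its hook order may differ. In real code, subclass Callback instead of StubCallback.

```python
class StubCallback:
    """Hypothetical stand-in for openml_pytorch's Callback base class."""

    _order = 0

    def set_runner(self, run):
        self.run = run

    def __call__(self, cb_name):
        # Look up the hook by name; a hook the callback does not define is skipped.
        f = getattr(self, cb_name, None)
        return bool(f and f())


class BatchCounterCallback(StubCallback):
    """Example custom callback: counts training batches per epoch."""

    def begin_fit(self):
        self.counts = []

    def begin_epoch(self):
        self.counts.append(0)

    def after_batch(self):
        self.counts[-1] += 1


class StubRunner:
    """Hypothetical toy training loop that fires a few hooks in a fixed order."""

    def __init__(self, cbs, epochs, batches_per_epoch):
        self.cbs = sorted(cbs, key=lambda cb: cb._order)
        self.epochs, self.batches = epochs, batches_per_epoch
        for cb in self.cbs:
            cb.set_runner(self)

    def __call__(self, cb_name):
        # Build the full list so that every callback runs (no short-circuiting).
        return any([cb(cb_name) for cb in self.cbs])

    def fit(self):
        self("begin_fit")
        for _ in range(self.epochs):
            self("begin_epoch")
            for _ in range(self.batches):
                # A real loop would predict, compute the loss and step here.
                self("after_batch")
        self("after_fit")


counter = BatchCounterCallback()
StubRunner([counter], epochs=2, batches_per_epoch=3).fit()
print(counter.counts)  # [3, 3]
```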

How to Use:

```python
trainer = OpenMLTrainerModule(
    data_module=data_module,
    verbose=True,
    epoch_count=1,
    callbacks=[<insert your callback object(s) here>],
)
```

To customize a callback, pass parameters to its constructor. For example, to track a different metric with AvgStatsCallback:

```python
trainer = OpenMLTrainerModule(
    data_module=data_module,
    verbose=True,
    epoch_count=1,
    # `accuracy` is a metric function that is called as accuracy(pred, yb)
    callbacks=[AvgStatsCallback([accuracy])],
)
```

Useful Callbacks:

  • TestCallback: Use when you are testing out new code and want to iterate through the training loop quickly. Stops training after 2 iterations.

AvgStats

The AvgStats class tracks and accumulates average statistics (such as the loss and other metrics) during the training and validation phases. Each batch contributes in proportion to its batch size.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `metrics` | list | A list of metric functions to be tracked. |
| `in_train` | bool | A flag to indicate if the statistics are for the training phase. |

Methods:

| Name | Description |
| --- | --- |
| `__init__` | Initializes the AvgStats with metrics and in_train flag. |
| `reset` | Resets the accumulated statistics. |
| `all_stats` | Property that returns all accumulated statistics including loss and metrics. |
| `avg_stats` | Property that returns the average of the accumulated statistics. |
| `accumulate` | Accumulates the statistics using the data from the given run. |
| `__repr__` | Returns a string representation of the average statistics. |

Source code in openml_pytorch/callbacks/recording.py
class AvgStats:
    """
    AvgStats class is used to track and accumulate average statistics (like loss and other metrics) during training and validation phases.

    Attributes:
        metrics (list): A list of metric functions to be tracked.
        in_train (bool): A flag to indicate if the statistics are for the training phase.

    Methods:
        __init__(metrics, in_train):
            Initializes the AvgStats with metrics and in_train flag.

        reset():
            Resets the accumulated statistics.

        all_stats:
            Property that returns all accumulated statistics including loss and metrics.

        avg_stats:
            Property that returns the average of the accumulated statistics.

        accumulate(run):
            Accumulates the statistics using the data from the given run.

        __repr__():
            Returns a string representation of the average statistics.
    """

    def __init__(self, metrics, in_train):
        self.metrics, self.in_train = listify(metrics), in_train

    def reset(self):
        self.tot_loss, self.count = 0.0, 0
        self.tot_mets = [0.0] * len(self.metrics)

    @property
    def all_stats(self):
        return [self.tot_loss.item()] + self.tot_mets

    @property
    def avg_stats(self):
        return [o / self.count for o in self.all_stats]

    def accumulate(self, run):
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i, m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn

    def __repr__(self):
        if not self.count:
            return ""
        phase = "train" if self.in_train else "valid"
        try:
            return f"{phase} loss: {self.avg_stats[0]:.4f} | accuracy: {self.avg_stats[1]:.4f} | other metrics: {self.avg_stats[2:]}"
        except IndexError:
            return f"{phase} loss: {self.avg_stats[0]:.4f} | other metrics: {self.avg_stats[1:]}"
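
The example below shows what accumulate and avg_stats compute. Each batch's loss and metric values are multiplied by the batch size bn before they are summed, so the averages are weighted by batch size. This sketch reproduces that arithmetic with plain floats instead of tensors. SimpleAvgStats is a hypothetical stand-in for the class above. The (batch_size, loss, metric_values) arguments stand in for the runner's xb, loss, pred and yb.

```python
class SimpleAvgStats:
    """Hypothetical float-only re-implementation of AvgStats' weighted averaging."""

    def __init__(self, n_metrics):
        self.n_metrics = n_metrics
        self.reset()

    def reset(self):
        self.tot_loss, self.count = 0.0, 0
        self.tot_mets = [0.0] * self.n_metrics

    def accumulate(self, batch_size, loss, metric_values):
        # Weight every per-batch value by the number of samples in the batch.
        self.tot_loss += loss * batch_size
        self.count += batch_size
        for i, value in enumerate(metric_values):
            self.tot_mets[i] += value * batch_size

    @property
    def avg_stats(self):
        # Loss first, then one entry per metric, as in AvgStats.avg_stats.
        return [o / self.count for o in [self.tot_loss] + self.tot_mets]


stats = SimpleAvgStats(n_metrics=1)
stats.accumulate(batch_size=32, loss=0.5, metric_values=[0.75])
stats.accumulate(batch_size=8, loss=1.0, metric_values=[0.5])
print(stats.avg_stats)  # [0.6, 0.7]
```

For comparison, the unweighted mean of the two batch losses would be 0.75. The larger batch dominates the weighted average.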

AvgStatsCallback

Bases: Callback

AvgStatsCallback is a callback that tracks and prints average statistics for the training and validation phases of the training loop.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metrics` | | A list of metric functions to evaluate during training and validation. | required |

Methods:

| Name | Description |
| --- | --- |
| `__init__` | Initializes the callback with given metrics and sets up AvgStats objects for both training and validation phases. |
| `begin_epoch` | Resets the statistics at the beginning of each epoch. |
| `after_loss` | Accumulates the metrics after computing the loss, differentiating between training and validation phases. |
| `after_epoch` | Prints the accumulated statistics for both training and validation phases after each epoch. |

Source code in openml_pytorch/callbacks/recording.py
class AvgStatsCallback(Callback):
    """
    AvgStatsCallBack class is a custom callback used to track and print average statistics for training and validation phases during the training loop.

    Arguments:
        metrics: A list of metric functions to evaluate during training and validation.

    Methods:
        __init__: Initializes the callback with given metrics and sets up AvgStats objects for both training and validation phases.
        begin_epoch: Resets the statistics at the beginning of each epoch.
        after_loss: Accumulates the metrics after computing the loss, differentiating between training and validation phases.
        after_epoch: Prints the accumulated statistics for both training and validation phases after each epoch.
    """

    def __init__(self, metrics):
        self.train_stats, self.valid_stats = (
            AvgStats(metrics, True),
            AvgStats(metrics, False),
        )

    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()

    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad():
            stats.accumulate(self.run)

    def after_epoch(self):
        current_epoch = self.current_epoch
        print(f"\n{'='*40}")
        print(f"Epoch {current_epoch}")
        print(f"{'-'*40}")
        print(f"Train: {self.train_stats}")
        print(f"Valid: {self.valid_stats}")
        print(f"{'='*40}\n")
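
The sketch below shows roughly what after_epoch prints. It reproduces the formatting logic of AvgStats.__repr__ for a given list of averages, with the loss first and the metrics after it. format_stats and epoch_report are hypothetical helpers. The epoch number and the averages are passed in explicitly here; the real callback reads them from the runner and from its two AvgStats objects. One detail is worth noticing: the first metric is always labelled "accuracy", whatever function it actually is.

```python
def format_stats(in_train, avg_stats):
    """Hypothetical helper mirroring AvgStats.__repr__ for a list of averages."""
    phase = "train" if in_train else "valid"
    if len(avg_stats) > 1:
        # The first metric is labelled "accuracy" regardless of what it computes.
        return (
            f"{phase} loss: {avg_stats[0]:.4f} | accuracy: {avg_stats[1]:.4f}"
            f" | other metrics: {avg_stats[2:]}"
        )
    return f"{phase} loss: {avg_stats[0]:.4f} | other metrics: {avg_stats[1:]}"


def epoch_report(epoch, train_avg, valid_avg):
    """Hypothetical helper building the block that after_epoch prints."""
    return "\n".join(
        [
            f"\n{'=' * 40}",
            f"Epoch {epoch}",
            f"{'-' * 40}",
            f"Train: {format_stats(True, train_avg)}",
            f"Valid: {format_stats(False, valid_avg)}",
            f"{'=' * 40}\n",
        ]
    )


print(epoch_report(0, [0.6, 0.7], [0.65, 0.68]))
```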

Callback

The Callback class is the base class for all callbacks in the event-driven training process. It lets you set a runner, get the class name in snake_case (with any trailing "Callback" removed), call a callback method by name, and delegate attribute access to the runner when an attribute is not defined on the callback itself.

The _order class attribute (0 by default) determines the order in which callbacks are run.

Source code in openml_pytorch/callbacks/callback.py
class Callback:
    """

    Callback class is a base class designed for handling different callback functions during
    an event-driven process. It provides functionality to set a runner, retrieve the class
    name in snake_case format, directly call callback methods, and delegate attribute access
    to the runner if the attribute does not exist in the Callback class.

    The _order is used to decide the order of Callbacks.

    """

    _order = 0

    def set_runner(self, run) -> None:
        self.run = run

    @property
    def name(self):
        name = re.sub(r"Callback$", "", self.__class__.__name__)
        return camel2snake(name or "callback")

    def __call__(self, cb_name):
        f = getattr(self, cb_name, None)
        if f and f():
            return True
        return False

    def __getattr__(self, k):
        return getattr(self.run, k)

ParamScheduler

Bases: Callback

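Manages scheduling of parameter adjustments over the course of training. At the start of every training batch, it sets the optimizer hyperparameter named pname (for example "lr") in each parameter group to the value returned by that group's scheduling function.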

Source code in openml_pytorch/callbacks/annealing.py
class ParamScheduler(Callback):
    """
    Manages scheduling of parameter adjustments over the course of training.
    """

    _order = 1

    def __init__(self, pname, sched_funcs):
        self.pname, self.sched_funcs = pname, sched_funcs

    def begin_fit(self):
        """
        Prepare the scheduler at the start of the fitting process.
        This method ensures that sched_funcs is a list with one function per parameter group.
        """
        if not isinstance(self.sched_funcs, (list, tuple)):
            self.sched_funcs = [self.sched_funcs] * len(self.opt.param_groups)

    def set_param(self):
        """
        Adjust the parameter value for each parameter group based on the scheduling function.
        Ensures the number of scheduling functions matches the number of parameter groups.
        """
        assert len(self.opt.param_groups) == len(self.sched_funcs)
        for pg, f in zip(self.opt.param_groups, self.sched_funcs):
            pg[self.pname] = f(self.n_epochs / self.epochs)

    def begin_batch(self):
        """
        Apply parameter adjustments at the beginning of each batch if in training mode.
        """
        if self.in_train:
            self.set_param()

begin_batch()

Apply parameter adjustments at the beginning of each batch if in training mode.

begin_fit()

Prepares the scheduler at the start of the fitting process. If sched_funcs is a single function, it is replicated so that there is one function per parameter group.

set_param()

Sets the parameter value for each parameter group by calling its scheduling function with the training progress n_epochs / epochs. Asserts that the number of scheduling functions matches the number of parameter groups.

PutDataOnDeviceCallback

Bases: Callback

PutDataOnDeviceCallback moves the model, the input data, and the target labels to the given device (CPU or GPU), so that the model and each batch are on the same device.

Methods:

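| Name | Description |
| --- | --- |
| `begin_fit` | Moves the model to the device at the beginning of the fitting process. |
| `begin_batch` | Moves the input data and target labels to the device at the beginning of each batch. |
| `after_pred` | Moves the predictions and target labels to the device after the forward pass. |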

Source code in openml_pytorch/callbacks/device_callbacks.py
class PutDataOnDeviceCallback(Callback):
    """
    PutDataOnDevice class is a custom callback used to move the input data and target labels to the device (CPU or GPU) before passing them to the model.

    Methods:
        begin_fit: Moves the model to the device at the beginning of the fitting process.
        begin_batch: Moves the input data and target labels to the device at the beginning of each batch.
    """

    def __init__(self, device):
        self.device = device

    def begin_fit(self):
        self.model.to(self.device)

    def begin_batch(self):
        self.run.xb, self.run.yb = self.xb.to(self.device), self.yb.to(self.device)

    def after_pred(self):
        self.run.pred = self.run.pred.to(self.device)
        self.run.yb = self.run.yb.to(self.device)

Recorder

Bases: Callback

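Recorder is a callback that records learning rates, losses, and metrics during training. Learning rates (one list per parameter group) and losses are recorded after every training batch. At the end of every epoch, metrics are read from an AvgStatsCallback, if one is registered.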

Source code in openml_pytorch/callbacks/recording.py
class Recorder(Callback):
    """
    Recorder is a callback class used to record learning rates, losses, and metrics during the training process.
    """

    def begin_fit(self):
        """
        Initializes attributes necessary for the fitting process.
        """
        self.lrs = [[] for _ in self.opt.param_groups]
        self.losses = []
        self.metrics = (
            {metric.__name__: [] for metric in self.metrics}
            if hasattr(self, "metrics")
            else {}
        )
        self.epochs = []
        self.current_epoch = 0

    def begin_epoch(self):
        """
        Handles operations at the beginning of each epoch.
        """
        self.current_epoch += 1

    def after_batch(self):
        """
        Handles operations to execute after each training batch.
        """
        if not self.in_train:
            return

        for pg, lr in zip(self.opt.param_groups, self.lrs):
            lr.append(pg["lr"])

        self.losses.append(self.loss.detach().cpu())

    def after_epoch(self):
        """
        Records metrics at the end of each epoch.
        """
        self.epochs.append(self.current_epoch)
        # Record metrics from AvgStatsCallback if available
        if hasattr(self, "run"):
            for cb in self.run.cbs:
                if isinstance(cb, AvgStatsCallback):
                    for i, metric_fn in enumerate(cb.train_stats.metrics):
                        metric_name = metric_fn.__name__
                        if metric_name == "":
                            metric_name = str(metric_fn)
                        if metric_name not in self.metrics:
                            self.metrics[metric_name] = []
                        # Store both train and valid metrics
                        self.metrics[metric_name].append(
                            {
                                "train": cb.train_stats.avg_stats[
                                    i + 1
                                ],  # +1 because avg_stats includes loss as first element
                                "valid": cb.valid_stats.avg_stats[i + 1],
                            }
                        )

    def plot_lr(self, pgid=-1, save_path=None):
        """
        Plots the learning rate for a given parameter group.
        """
        # check if empty
        if not self.lrs[pgid]:
            print("No learning rates recorded.")
            return
        plot = plt.plot(self.lrs[pgid])
        plt.xlabel("Iterations")
        plt.ylabel("Learning Rate")
        if save_path:
            plt.savefig(save_path)
        return plot

    def plot_loss(self, skip_last=0, save_path=None):
        """
        Plots the loss values.
        """
        # check if empty
        if not self.losses:
            print("No losses recorded.")
            return
        plot = plt.plot(self.losses[: len(self.losses) - skip_last])
        plt.xlabel("Iterations")
        plt.ylabel("Loss")

        if save_path:
            plt.savefig(save_path)
        return plot

    def plot(self, skip_last=0, pgid=-1):
        """
        Generates a plot of the loss values against the learning rates.
        """
        losses = [o.item() for o in self.losses]
        lrs = self.lrs[pgid]
        n = len(losses) - skip_last
        plt.xscale("log")
        plt.xlabel("Learning Rate")
        plt.ylabel("Loss")
        return plt.plot(lrs[:n], losses[:n])

    def plot_metric(self, metric_name, skip_last=0, save_path=None):
        """
        Plots a specific metric over epochs.

        Args:
            metric_name (str): Name of the metric to plot
            skip_last (int): Number of last points to skip
        """
        # check if empty
        if not self.metrics:
            print("No metrics recorded.")
            return
        if metric_name not in self.metrics:
            print(
                f"Metric '{metric_name}' not found. Available metrics: {list(self.metrics.keys())}"
            )
            return

        train_vals = [d["train"] for d in self.metrics[metric_name]]
        valid_vals = [d["valid"] for d in self.metrics[metric_name]]

        # convert to cpu numpy if necessary
        train_vals = [
            val.item() if isinstance(val, torch.Tensor) else val for val in train_vals
        ]
        valid_vals = [
            val.item() if isinstance(val, torch.Tensor) else val for val in valid_vals
        ]

        plt.figure(figsize=(10, 6))
        plt.plot(
            self.epochs[:-skip_last] if skip_last > 0 else self.epochs,
            train_vals[:-skip_last] if skip_last > 0 else train_vals,
            label=f"Train {metric_name}",
        )
        plt.plot(
            self.epochs[:-skip_last] if skip_last > 0 else self.epochs,
            valid_vals[:-skip_last] if skip_last > 0 else valid_vals,
            label=f"Valid {metric_name}",
        )
        plt.xlabel("Epochs")
        plt.ylabel(metric_name)
        plt.title(f"{metric_name} vs. Epochs")
        plt.legend()
        if save_path:
            plt.savefig(save_path)
        plt.show()
        return plt

    def plot_all_metrics(self, skip_last=0, save_path=None):
        """
        Plots all available metrics in subplots.

        Args:
            skip_last (int): Number of last points to skip for all metrics
        """
        # check if empty
        if not self.metrics:
            print("No metrics recorded.")
            return
        if len(self.metrics) == 0:
            print("No metrics recorded.")
            return
        num_metrics = len(self.metrics)
        fig, axes = plt.subplots(num_metrics, 1, figsize=(10, 6 * num_metrics))

        # If there's only one metric, axes is not a list, so make sure we handle that.
        if num_metrics == 1:
            axes = [axes]

        for i, (metric_name, metric_data) in enumerate(self.metrics.items()):
            train_vals = [d["train"] for d in metric_data]
            valid_vals = [d["valid"] for d in metric_data]

            # convert to cpu numpy if necessary
            train_vals = [
                val.item() if isinstance(val, torch.Tensor) else val
                for val in train_vals
            ]
            valid_vals = [
                val.item() if isinstance(val, torch.Tensor) else val
                for val in valid_vals
            ]

            # Plot the data
            axes[i].plot(
                self.epochs[:-skip_last] if skip_last > 0 else self.epochs,
                train_vals[:-skip_last] if skip_last > 0 else train_vals,
                label=f"Train {metric_name}",
            )
            axes[i].plot(
                self.epochs[:-skip_last] if skip_last > 0 else self.epochs,
                valid_vals[:-skip_last] if skip_last > 0 else valid_vals,
                label=f"Valid {metric_name}",
            )
            axes[i].set_xlabel("Epochs")
            axes[i].set_ylabel(metric_name)
            axes[i].set_title(f"{metric_name} vs. Epochs")
            axes[i].legend()

        # Adjust layout
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path)
        plt.show()
        return plt

    def get_metrics_history(self):
        """
        Returns a dictionary containing the history of all recorded metrics.

        Returns:
            dict: A dictionary with metric names as keys and lists of values as values
        """
        return self.metrics
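
get_metrics_history returns the dictionary that after_epoch fills in. It has one key per metric, and each key holds a list with one {"train": ..., "valid": ...} entry per epoch. The sketch below builds a history of that shape by hand, using illustrative values only. It then extracts the series the way plot_metric does, including the skip_last trimming, without calling matplotlib. series_for is a hypothetical helper.

```python
def series_for(history, epochs, metric_name, skip_last=0):
    """Hypothetical helper: (epochs, train, valid) lists as plot_metric plots them."""
    train_vals = [d["train"] for d in history[metric_name]]
    valid_vals = [d["valid"] for d in history[metric_name]]
    if skip_last > 0:
        # Drop the last `skip_last` epochs from every series.
        return epochs[:-skip_last], train_vals[:-skip_last], valid_vals[:-skip_last]
    return epochs, train_vals, valid_vals


# Illustrative history for a metric named "accuracy" over three epochs.
history = {
    "accuracy": [
        {"train": 0.60, "valid": 0.55},
        {"train": 0.70, "valid": 0.62},
        {"train": 0.78, "valid": 0.64},
    ]
}
epochs = [1, 2, 3]  # Recorder numbers epochs from 1
print(series_for(history, epochs, "accuracy", skip_last=1))
# ([1, 2], [0.6, 0.7], [0.55, 0.62])
```

The explicit skip_last > 0 check matters. values[:-0] is the same as values[:0], which is an empty list.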

after_batch()

Handles operations to execute after each training batch.
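
after_batch records one learning rate per parameter group and one loss per training batch. It skips validation batches entirely. The minimal sketch below shows that bookkeeping. record_batch is a hypothetical helper, the param_groups list stands in for the optimizer's, and losses are plain floats, so no .detach().cpu() is needed.

```python
def record_batch(lrs, losses, param_groups, loss, in_train):
    """Hypothetical helper mirroring Recorder.after_batch on plain Python data."""
    if not in_train:
        return  # validation batches are not recorded
    for pg, lr in zip(param_groups, lrs):
        lr.append(pg["lr"])
    losses.append(loss)


param_groups = [{"lr": 0.1}, {"lr": 0.01}]
lrs = [[] for _ in param_groups]  # one list per parameter group, as in begin_fit
losses = []
record_batch(lrs, losses, param_groups, loss=0.9, in_train=True)
record_batch(lrs, losses, param_groups, loss=0.7, in_train=False)
print(lrs, losses)  # [[0.1], [0.01]] [0.9]
```

By default, plot_lr(pgid=-1) plots lrs[-1], which is the last parameter group.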

after_epoch()

Records metrics at the end of each epoch.

begin_epoch()

Handles operations at the beginning of each epoch.

begin_fit()

Initializes attributes necessary for the fitting process.

get_metrics_history()

Returns a dictionary containing the history of all recorded metrics.

Returns:

| Type | Description |
| --- | --- |
| dict | A dictionary with metric names as keys and lists of values as values |

plot(skip_last=0, pgid=-1)

Generates a plot of the loss values against the learning rates.

plot_all_metrics(skip_last=0, save_path=None)

Plots all available metrics in subplots.

Parameters:

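| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `skip_last` | int | Number of last points to skip for all metrics. | 0 |
| `save_path` | str or None | If given, the figure is also saved to this path with `plt.savefig`. | None |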

plot_loss(skip_last=0, save_path=None)

Plots the loss values.

plot_lr(pgid=-1, save_path=None)

Plots the learning rate for a given parameter group.

plot_metric(metric_name, skip_last=0, save_path=None)

Plots a specific metric over epochs.

Parameters:

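| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metric_name` | str | Name of the metric to plot. | required |
| `skip_last` | int | Number of last points to skip. | 0 |
| `save_path` | str or None | If given, the figure is also saved to this path with `plt.savefig`. | None |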

TensorBoardCallback

Bases: Callback

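Logs training information to TensorBoard through the given writer (for example, a torch.utils.tensorboard SummaryWriter). The model graph is added on the first batch. The loss and the learning rate of the first parameter group are written once at the end of fitting, and then the writer is closed.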

Source code in openml_pytorch/callbacks/tensorboard.py
class TensorBoardCallback(Callback):
    """
    Log specific things to TensorBoard.
    - Model
    """

    def __init__(self, writer):
        self.writer = writer

    def begin_batch(self):
        if "saved_graph" not in self.__dict__ or not self.saved_graph:
            self.writer.add_graph(self.model, self.xb)
            self.saved_graph = True

    def after_fit(self):
        # check if tensorboard writer is available
        try:
            # add loss and learning rate  to tensorboard
            self.writer.add_scalar("Loss", self.run.loss, self.n_iter)
            self.writer.add_scalar(
                "Learning rate", self.run.opt.param_groups[0]["lr"], self.n_iter
            )
        except Exception as e:
            print(f"Error: {e}")
        self.writer.close()

TestCallback

Bases: Callback

TestCallback is a callback for testing the training loop. It stops training after 2 iterations by raising CancelTrainException. It is useful for debugging and testing, not for actual training.

Source code in openml_pytorch/callbacks/training_callbacks.py
class TestCallback(Callback):
    """
    TestCallback class is a custom callback used to test the training loop by stopping the training process after 2 iterations. Useful for debugging and testing purposes, not intended for actual training.
    """

    def after_step(self):
        if self.n_iter >= 1:
            raise CancelTrainException()
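
The early stop depends on two things: the training loop catching CancelTrainException, and TrainEvalCallback incrementing n_iter in after_batch. The stand-in loop below assumes that after_batch runs after after_step, as in the usual batch loop. It defines its own CancelTrainException and a toy fit function, so neither is the library's implementation. It shows why the check n_iter >= 1 lets exactly two optimizer steps run.

```python
class CancelTrainException(Exception):
    """Hypothetical stand-in for the exception the real training loop catches."""


def fit(total_batches):
    """Hypothetical toy loop: returns how many optimizer steps ran before cancellation."""
    n_iter = 0
    steps = 0
    try:
        for _ in range(total_batches):
            steps += 1  # opt.step() would happen here
            # TestCallback.after_step
            if n_iter >= 1:
                raise CancelTrainException()
            # TrainEvalCallback.after_batch (assumed to run after after_step)
            n_iter += 1
    except CancelTrainException:
        pass  # training is cancelled quietly
    return steps


print(fit(total_batches=100))  # 2
```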

TrainEvalCallback

Bases: Callback

TrainEvalCallback switches the model between training and evaluation modes. It also keeps the epoch and iteration counters up to date during the training and validation phases.

Methods:

begin_fit(): Initialize the number of epochs and iteration counts at the start of the fitting process.

after_batch(): Update the epoch and iteration counts after each batch during training.

begin_epoch(): Set the current epoch, switch the model to training mode, and indicate that the model is in training.

begin_validate(): Switch the model to evaluation mode and indicate that the model is in validation.

Source code in openml_pytorch/callbacks/training_callbacks.py
class TrainEvalCallback(Callback):
    """
    TrainEvalCallback is a custom callback that manages the training and
    validation phases: it keeps the epoch and iteration counters up to date
    and switches the model between training and evaluation mode.

    Methods:

    begin_fit():
        Initialize the number of epochs and iteration counts at the start
        of the fitting process.

    after_batch():
        Update the epoch and iteration counts after each batch during
        training.

    begin_epoch():
        Set the current epoch, switch the model to training mode, and
        indicate that the model is in training.

    begin_validate():
        Switch the model to evaluation mode and indicate that the model
        is in validation.
    """

    def begin_fit(self):
        self.run.n_epochs = 0
        self.run.n_iter = 0

    def after_batch(self):
        if not self.in_train:
            return
        self.run.n_epochs += 1.0 / self.iters
        self.run.n_iter += 1

    def begin_epoch(self):
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False
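The counters are easiest to follow in a simulation. This sketch drives a copy of the callback by hand, with a hypothetical `StubModel` in place of a PyTorch module and a `SimpleNamespace` as the run object. The copy reads everything from `self.run` explicitly, whereas the library version appears to rely on attribute lookups being forwarded to the run (an assumption based on the listing).

```python
from types import SimpleNamespace


class StubModel:
    """Hypothetical stand-in for a torch.nn.Module; only tracks the mode."""

    def __init__(self):
        self.training = False

    def train(self):
        self.training = True

    def eval(self):
        self.training = False


class TrainEvalCallback:
    # same logic as the listing, but reading everything explicitly from self.run
    def __init__(self, run):
        self.run = run

    def begin_fit(self):
        self.run.n_epochs = 0
        self.run.n_iter = 0

    def after_batch(self):
        if not self.run.in_train:
            return
        self.run.n_epochs += 1.0 / self.run.iters
        self.run.n_iter += 1

    def begin_epoch(self):
        self.run.n_epochs = self.run.epoch
        self.run.model.train()
        self.run.in_train = True

    def begin_validate(self):
        self.run.model.eval()
        self.run.in_train = False


# hypothetical run state: 4 training batches per epoch
run = SimpleNamespace(model=StubModel(), iters=4, epoch=0, in_train=False)
cb = TrainEvalCallback(run)
cb.begin_fit()
for epoch in range(2):
    run.epoch = epoch
    cb.begin_epoch()
    for _ in range(run.iters):   # training batches advance the counters
        cb.after_batch()
    cb.begin_validate()
    for _ in range(3):           # validation batches leave them unchanged
        cb.after_batch()
print(run.n_epochs, run.n_iter, run.model.training)  # 2.0 8 False
```

`n_epochs` grows by `1 / iters` per training batch, so mid-epoch it reads as a fractional epoch. `begin_epoch` then snaps it back to the integer epoch number, which also discards any floating-point drift from adding `1 / iters` repeatedly.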

annealer(f)

A decorator that turns a schedule function f(start, end, pos) into a factory: calling the decorated function with start and end returns a partial object that fixes these two arguments, leaving a function of pos alone. The inner function _inner captures start and end and builds this partial.

Source code in openml_pytorch/callbacks/annealing.py
def annealer(f) -> callable:
    """
    A decorator function for creating a partially applied function with predefined start and end arguments.
    The inner function `_inner` captures the `start` and `end` parameters and returns a `partial` object that fixes these parameters for the decorated function `f`.
    """

    def _inner(start, end):
        return partial(f, start, end)

    return _inner
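A short illustration of what the decorator does to a schedule function, using `sched_lin` from further down this page (both are copied so the block runs on its own; `warmup` is an illustrative name).

```python
from functools import partial


def annealer(f):
    # same decorator as the listing: fix start and end, leave pos free
    def _inner(start, end):
        return partial(f, start, end)

    return _inner


@annealer
def sched_lin(start, end, pos):
    return start + pos * (end - start)


# the decorated sched_lin takes (start, end) and returns a function of pos alone
warmup = sched_lin(0.1, 1.0)
print(type(warmup).__name__)  # partial
print([warmup(p) for p in (0.0, 0.5, 1.0)])
```

The decorated name therefore has a different call signature from its `def`: `sched_lin(start, end, pos)` is called as `sched_lin(start, end)(pos)`.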

camel2snake(name)

Convert name from camel case to snake case.

Source code in openml_pytorch/callbacks/helper.py
def camel2snake(name: str) -> str:
    """
    Convert `name` from camel case to snake case.
    """
    s1 = re.sub(_camel_re1, r"\1_\2", name)
    return re.sub(_camel_re2, r"\1_\2", s1).lower()
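The regular expressions `_camel_re1` and `_camel_re2` are defined elsewhere in `helper.py` and are not shown on this page. The sketch below assumes two common patterns for this conversion, so its output for unusual names may differ from the library's.

```python
import re

# Assumed patterns; the library's actual _camel_re1/_camel_re2 are not shown here.
_camel_re1 = re.compile("(.)([A-Z][a-z]+)")   # any char followed by a capitalised word
_camel_re2 = re.compile("([a-z0-9])([A-Z])")  # lowercase letter or digit, then uppercase


def camel2snake(name: str) -> str:
    """Convert `name` from camel case to snake case."""
    s1 = re.sub(_camel_re1, r"\1_\2", name)
    return re.sub(_camel_re2, r"\1_\2", s1).lower()


print(camel2snake("TrainEvalCallback"))  # train_eval_callback
print(camel2snake("HTTPResponse"))       # http_response
```

The first pass splits before each capitalised word; the second splits the remaining lowercase-to-uppercase boundaries, which the first pass can miss when matches are adjacent.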

combine_scheds(pcts, scheds)

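Combine multiple scheduling functions into a single schedule. pcts gives the fraction of training covered by each function in scheds (the fractions must be non-negative and sum to 1); within its segment, each function receives pos rescaled to the range 0 to 1.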

Source code in openml_pytorch/callbacks/annealing.py
def combine_scheds(pcts: Iterable[float], scheds: Iterable[callable]) -> callable:
    """
    Combine multiple scheduling functions.
    """
    assert sum(pcts) == 1.0
    pcts = torch.tensor([0] + listify(pcts))
    assert torch.all(pcts >= 0)
    pcts = torch.cumsum(pcts, 0)

    def _inner(pos):
        idx = (pos >= pcts).nonzero().max()
        actual_pos = (pos - pcts[idx]) / (pcts[idx + 1] - pcts[idx])
        return scheds[idx](actual_pos)

    return _inner
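A dependency-free sketch of the same idea, for illustration: it replaces the torch tensors with a plain list of cumulative boundaries and uses `bisect.bisect_right` to find the active segment, and it copies `annealer` and `sched_cos` from this page. The two-phase learning-rate shape and its values are illustrative, not a library default.

```python
import math
from bisect import bisect_right
from functools import partial


def annealer(f):
    def _inner(start, end):
        return partial(f, start, end)
    return _inner


@annealer
def sched_cos(start, end, pos):
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def combine_scheds(pcts, scheds):
    # pcts[i] is the fraction of training handed to scheds[i]
    assert math.isclose(sum(pcts), 1.0)
    assert all(p >= 0 for p in pcts)
    bounds = [0.0]
    for p in pcts:
        bounds.append(bounds[-1] + p)  # cumulative segment boundaries

    def _inner(pos):
        # last boundary <= pos, clamped so that pos == 1.0 stays in the last segment
        idx = min(bisect_right(bounds, pos) - 1, len(scheds) - 1)
        actual_pos = (pos - bounds[idx]) / (bounds[idx + 1] - bounds[idx])
        return scheds[idx](actual_pos)

    return _inner


# illustrative two-phase schedule: rise for the first 30%, then fall
one_cycle = combine_scheds([0.3, 0.7], [sched_cos(0.3, 0.6), sched_cos(0.6, 0.2)])
print([round(one_cycle(p), 3) for p in (0.0, 0.15, 0.3, 0.65, 1.0)])
```

One difference from the listing: at pos = 1.0 the listing selects the last boundary as `idx` and then reads `pcts[idx + 1]`, which is out of range, whereas the sketch clamps the index so pos = 1.0 returns the final value of the last schedule. The sum check uses `math.isclose` because fractions such as `[0.1] * 10` do not sum to exactly 1.0 in floating point.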

listify(o=None)

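Convert o to a list: None becomes an empty list, a list is returned unchanged, a string is wrapped in a one-element list, any other iterable is converted with list(), and any other object is wrapped in a one-element list.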

Source code in openml_pytorch/callbacks/helper.py
def listify(o=None) -> list:
    """
    Convert `o` to list. If `o` is None, return empty list.
    """
    if o is None:
        return []
    if isinstance(o, list):
        return o
    if isinstance(o, str):
        return [o]
    if isinstance(o, Iterable):
        return list(o)
    return [o]
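A quick check of each branch, using a copy of the helper. Importing `Iterable` from `collections.abc` is an assumption, since the module's imports are not shown.

```python
from collections.abc import Iterable


def listify(o=None) -> list:
    # copy of the helper above
    if o is None:
        return []
    if isinstance(o, list):
        return o
    if isinstance(o, str):
        return [o]  # strings are iterable, but are kept whole
    if isinstance(o, Iterable):
        return list(o)
    return [o]


print(listify())            # []
print(listify("lr"))        # ['lr']
print(listify((0.3, 0.7)))  # [0.3, 0.7]
print(listify(range(3)))    # [0, 1, 2]
print(listify(0.5))         # [0.5]
pcts = [0.3, 0.7]
print(listify(pcts) is pcts)  # True
```

Because a list is returned as-is rather than copied, `combine_scheds` accepts `pcts` as a list or a tuple; its `[0] + listify(pcts)` builds a new list, so the caller's list is not modified.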

sched_cos(start, end, pos)

A cosine schedule function: moves from start (pos = 0) to end (pos = 1) along a half cosine curve.

Source code in openml_pytorch/callbacks/annealing.py
@annealer
def sched_cos(start: float, end: float, pos: float) -> float:
    """
    A cosine schedule function.
    """
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2
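Evaluating the formula at the endpoints and the midpoint shows how the cosine schedule behaves. This is a worked check of the listing's formula, writing a for start, b for end and p for pos.

```latex
\begin{aligned}
f(p) &= a + \tfrac{1}{2}\bigl(1 + \cos(\pi(1 - p))\bigr)(b - a) \\
f(0) &= a + \tfrac{1}{2}(1 + \cos\pi)(b - a) = a \\
f(\tfrac{1}{2}) &= a + \tfrac{1}{2}\bigl(1 + \cos\tfrac{\pi}{2}\bigr)(b - a) = \tfrac{a + b}{2} \\
f(1) &= a + \tfrac{1}{2}(1 + \cos 0)(b - a) = b \\
f'(p) &= \tfrac{\pi}{2}(b - a)\sin(\pi(1 - p)), \qquad f'(0) = f'(1) = 0
\end{aligned}
```

The zero slope at both ends means the value changes slowly near start and end and fastest around the midpoint, unlike sched_lin, which changes at a constant rate.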

sched_exp(start, end, pos)

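An exponential schedule function: moves from start (pos = 0) to end (pos = 1) by multiplying start by (end / start) ** pos, so start must be nonzero.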

Source code in openml_pytorch/callbacks/annealing.py
@annealer
def sched_exp(start: float, end: float, pos: float) -> float:
    """
    Exponential schedule function.
    """
    return start * (end / start) ** pos
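Rewriting the listing's expression shows that the exponential schedule is linear interpolation in log space. Here a is start, b is end and p is pos; the logarithmic form assumes a, b > 0.

```latex
\begin{aligned}
f(p) &= a\left(\frac{b}{a}\right)^{p} = a^{1 - p}\, b^{p} \\
\ln f(p) &= (1 - p)\ln a + p \ln b \\
f(0) &= a, \qquad f(\tfrac{1}{2}) = \sqrt{a b}, \qquad f(1) = b
\end{aligned}
```

Equal steps in pos therefore multiply the value by the same factor, and the midpoint is the geometric mean of the endpoints. With start and end of opposite signs, the fractional power does not give a real result.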

sched_lin(start, end, pos)

A linear schedule function: interpolates linearly from start (pos = 0) to end (pos = 1).

Source code in openml_pytorch/callbacks/annealing.py
@annealer
def sched_lin(start: float, end: float, pos: float) -> float:
    """
    A linear schedule function.
    """
    return start + pos * (end - start)

sched_no(start, end, pos)

Disabled scheduling: always returns start, regardless of pos.

Source code in openml_pytorch/callbacks/annealing.py
@annealer
def sched_no(start: float, end: float, pos: float) -> float:
    """
    Disabled scheduling.
    """
    return start
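The four schedule functions on this page can be compared side by side. This block copies them, together with `annealer`, so it runs on its own, and samples each one for the same illustrative start and end values.

```python
import math
from functools import partial


def annealer(f):
    def _inner(start, end):
        return partial(f, start, end)
    return _inner


@annealer
def sched_no(start, end, pos):
    return start


@annealer
def sched_lin(start, end, pos):
    return start + pos * (end - start)


@annealer
def sched_cos(start, end, pos):
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


@annealer
def sched_exp(start, end, pos):
    return start * (end / start) ** pos


# the same illustrative start/end pair (1.0 -> 0.01) through each schedule
schedules = {
    "no": sched_no(1.0, 0.01),
    "lin": sched_lin(1.0, 0.01),
    "cos": sched_cos(1.0, 0.01),
    "exp": sched_exp(1.0, 0.01),
}
for name, sched in schedules.items():
    print(name, [round(sched(p), 3) for p in (0.0, 0.25, 0.5, 0.75, 1.0)])
```

The linear and cosine schedules meet at the midpoint, while the exponential one has already dropped to the geometric mean of the endpoints. sched_no stays at start throughout, so it can serve as a "no annealing" segment wherever a schedule is expected, for example in combine_scheds.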