Callbacks
The callbacks module contains classes and functions for handling callbacks during an event-driven process such as training. Callbacks make it easier to customize the behavior of the training loop and to add functionality to the training process without modifying the core code.
To use a custom callback, create a class that inherits from the Callback class and implement the hook methods you need. Callbacks can perform actions at different stages of the training process, such as at the beginning or end of an epoch, a batch, or the whole fitting process. Then pass an instance of the callback to the trainer.
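The hook pattern can be illustrated with a small, self-contained sketch. The Callback base class and the run_loop driver below are simplified stand-ins written for this example, not the library's code (the real base class is in openml_pytorch/callbacks/callback.py and the real loop is driven by the trainer). BatchCounter is a hypothetical custom callback. The hook names begin_fit, begin_epoch, after_batch and after_epoch are the ones documented on this page.

```python
class Callback:
    """Simplified stand-in for the library's Callback base class."""

    _order = 0  # lower values run first

    def __call__(self, cb_name):
        # Call the hook if this callback defines it; a falsy result means "keep going".
        hook = getattr(self, cb_name, None)
        return bool(hook and hook())


class BatchCounter(Callback):
    """Hypothetical custom callback that counts training batches per epoch."""

    def begin_fit(self):
        self.batches_per_epoch = []

    def begin_epoch(self):
        self.count = 0

    def after_batch(self):
        self.count += 1

    def after_epoch(self):
        self.batches_per_epoch.append(self.count)


def run_loop(callbacks, n_epochs, n_batches):
    """Hypothetical driver that fires the hooks in the order a training loop would."""
    cbs = sorted(callbacks, key=lambda cb: cb._order)

    def fire(event):
        for cb in cbs:
            cb(event)

    fire("begin_fit")
    for _ in range(n_epochs):
        fire("begin_epoch")
        for _ in range(n_batches):
            # A real loop would run the forward pass, backward pass and optimizer step here.
            fire("after_batch")
        fire("after_epoch")


counter = BatchCounter()
run_loop([counter], n_epochs=2, n_batches=3)
print(counter.batches_per_epoch)  # [3, 3]
```

With the real library, BatchCounter would subclass the library's Callback instead. An instance would then be passed in the callbacks list when constructing the trainer.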
How to Use:
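To customize a callback, for example to track a different metric with AvgStatsCallBack, pass a configured instance through the callbacks argument:
```python
# Assumes OpenMLTrainerModule, AvgStatsCallBack, data_module and accuracy
# have already been imported or defined.
trainer = OpenMLTrainerModule(
    data_module=data_module,
    verbose=True,
    epoch_count=1,
    callbacks=[
        AvgStatsCallBack([accuracy]),
    ],
)
```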
Useful Callbacks:
- TestCallback: Use when you are testing out new code and want to iterate through the training loop quickly. Stops training after 2 iterations.
AvgStats
The AvgStats class tracks and accumulates average statistics (such as the loss and other metrics) during the training and validation phases.
Attributes:
Name | Type | Description |
---|---|---|
metrics | list | A list of metric functions to be tracked. |
in_train | bool | A flag to indicate if the statistics are for the training phase. |
Methods:
Name | Description |
---|---|
__init__ | Initializes the AvgStats with metrics and in_train flag. |
reset | Resets the accumulated statistics. |
all_stats | Property that returns all accumulated statistics including loss and metrics. |
avg_stats | Property that returns the average of the accumulated statistics. |
accumulate | Accumulates the statistics using the data from the given run. |
__repr__ | Returns a string representation of the average statistics. |
Source code in openml_pytorch/callbacks/recording.py
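The accumulation logic can be sketched as batch-size-weighted running totals. This is a simplified, hypothetical version. AvgStatsSketch, the accumulate(loss, preds, targets) signature and the toy accuracy metric are assumptions made for illustration (the documented accumulate takes the current run instead), and plain Python lists stand in for tensors.

```python
class AvgStatsSketch:
    """Hypothetical simplified AvgStats: batch-size-weighted running averages."""

    def __init__(self, metrics, in_train):
        self.metrics, self.in_train = list(metrics), in_train
        self.reset()

    def reset(self):
        self.tot_loss, self.count = 0.0, 0
        self.tot_mets = [0.0] * len(self.metrics)

    @property
    def all_stats(self):
        # Accumulated (size-weighted) loss followed by each metric total.
        return [self.tot_loss] + self.tot_mets

    @property
    def avg_stats(self):
        return [s / self.count for s in self.all_stats]

    def accumulate(self, loss, preds, targets):
        bn = len(targets)  # weight each batch by its number of samples
        self.tot_loss += loss * bn
        self.count += bn
        for i, metric in enumerate(self.metrics):
            self.tot_mets[i] += metric(preds, targets) * bn

    def __repr__(self):
        if not self.count:
            return ""
        phase = "train" if self.in_train else "valid"
        return f"{phase}: {self.avg_stats}"


def accuracy(preds, targets):
    # Hypothetical metric: fraction of predictions equal to their targets.
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


stats = AvgStatsSketch([accuracy], in_train=True)
stats.accumulate(0.5, preds=[1, 0], targets=[1, 1])              # 2 samples, accuracy 0.5
stats.accumulate(0.2, preds=[1, 1, 0, 0], targets=[1, 1, 0, 0])  # 4 samples, accuracy 1.0
print(stats)  # train: [average loss, average accuracy]
```

Weighting by batch size makes the epoch average a per-sample average, even when the last batch is smaller than the others.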
AvgStatsCallback
Bases: Callback
AvgStatsCallBack class is a custom callback used to track and print average statistics for training and validation phases during the training loop.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metrics | | A list of metric functions to evaluate during training and validation. | required |
Methods:
Name | Description |
---|---|
__init__ | Initializes the callback with given metrics and sets up AvgStats objects for both training and validation phases. |
begin_epoch | Resets the statistics at the beginning of each epoch. |
after_loss | Accumulates the metrics after computing the loss, differentiating between training and validation phases. |
after_epoch | Prints the accumulated statistics for both training and validation phases after each epoch. |
Source code in openml_pytorch/callbacks/recording.py
Callback
Callback class is a base class designed for handling different callback functions during an event-driven process. It provides functionality to set a runner, retrieve the class name in snake_case format, directly call callback methods, and delegate attribute access to the runner if the attribute does not exist in the Callback class.
The _order attribute determines the order in which callbacks are run.
Source code in openml_pytorch/callbacks/callback.py
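The behavior described above can be sketched as follows: ordering by _order, a snake_case name, calling hooks by name, and falling back to the runner for unknown attributes. This is an illustrative reimplementation under assumptions, not the library's source. The camel2snake regexes, the SimpleNamespace runner and the EpochReader and LateCallback classes are hypothetical.

```python
import re
from types import SimpleNamespace


def camel2snake(name):
    # Insert underscores at lower/upper-case boundaries, then lower-case everything.
    s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1).lower()


class Callback:
    _order = 0

    def set_runner(self, run):
        self.run = run

    @property
    def name(self):
        return camel2snake(type(self).__name__)

    def __call__(self, cb_name):
        hook = getattr(self, cb_name, None)
        return bool(hook and hook())

    def __getattr__(self, k):
        # Only reached when normal attribute lookup fails: delegate to the runner.
        if k == "run":
            raise AttributeError(k)  # avoid infinite recursion before set_runner()
        return getattr(self.run, k)


class EpochReader(Callback):
    def begin_epoch(self):
        # `epoch` is not defined on the callback, so it is read from the runner.
        self.seen_epoch = self.epoch


class LateCallback(Callback):
    _order = 10  # runs after callbacks with a lower _order


runner = SimpleNamespace(epoch=3)
cbs = sorted([LateCallback(), EpochReader()], key=lambda cb: cb._order)
for cb in cbs:
    cb.set_runner(runner)
    cb("begin_epoch")  # LateCallback has no begin_epoch, so this is a no-op for it

print([cb.name for cb in cbs])  # ['epoch_reader', 'late_callback']
print(cbs[0].seen_epoch)        # 3
```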
ParamScheduler
Bases: Callback
Manages scheduling of parameter adjustments over the course of training.
Source code in openml_pytorch/callbacks/annealing.py
begin_batch()
Apply parameter adjustments at the beginning of each batch if in training mode.
Source code in openml_pytorch/callbacks/annealing.py
begin_fit()
Prepare the scheduler at the start of the fitting process. This method ensures that sched_funcs is a list with one function per parameter group.
Source code in openml_pytorch/callbacks/annealing.py
set_param()
Adjust the parameter value for each parameter group based on the scheduling function. Ensures the number of scheduling functions matches the number of parameter groups.
Source code in openml_pytorch/callbacks/annealing.py
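The three methods can be sketched together as follows. This is a hypothetical, dependency-free version. The runner is a SimpleNamespace whose opt.param_groups is a list of dicts, which mirrors the shape of a PyTorch optimizer's param_groups. Training progress is assumed to be n_epochs / epochs, a fractional epoch counter like the one TrainEvalCallback is described as updating after each batch.

```python
from types import SimpleNamespace


class ParamSchedulerSketch:
    """Hypothetical stand-in mirroring begin_fit, set_param and begin_batch."""

    _order = 1

    def __init__(self, pname, sched_funcs):
        self.pname, self.sched_funcs = pname, sched_funcs

    def begin_fit(self):
        # Ensure there is one scheduling function per parameter group.
        if not isinstance(self.sched_funcs, (list, tuple)):
            self.sched_funcs = [self.sched_funcs] * len(self.run.opt.param_groups)

    def set_param(self):
        groups = self.run.opt.param_groups
        assert len(groups) == len(self.sched_funcs), "one schedule per param group"
        pos = self.run.n_epochs / self.run.epochs  # fraction of training completed
        for pg, f in zip(groups, self.sched_funcs):
            pg[self.pname] = f(pos)

    def begin_batch(self):
        # Only adjust parameters while training, not during validation.
        if self.run.in_train:
            self.set_param()


def decaying_lr(pos):
    # Hypothetical schedule: linear decay from 0.1 to 0.0.
    return 0.1 * (1 - pos)


run = SimpleNamespace(
    opt=SimpleNamespace(param_groups=[{"lr": 0.0}, {"lr": 0.0}]),
    n_epochs=0.5,
    epochs=2,
    in_train=True,
)
sched = ParamSchedulerSketch("lr", decaying_lr)
sched.run = run  # stand-in for attaching the runner to the callback
sched.begin_fit()
sched.begin_batch()
print([pg["lr"] for pg in run.opt.param_groups])  # both groups get 0.1 * 0.75
```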
PutDataOnDeviceCallback
Bases: Callback
PutDataOnDevice class is a custom callback used to move the input data and target labels to the device (CPU or GPU) before passing them to the model.
Methods:
Name | Description |
---|---|
begin_fit | Moves the model to the device at the beginning of the fitting process. |
begin_batch | Moves the input data and target labels to the device at the beginning of each batch. |
Source code in openml_pytorch/callbacks/device_callbacks.py
Recorder
Bases: Callback
Recorder is a callback class used to record learning rates, losses, and metrics during the training process.
Source code in openml_pytorch/callbacks/recording.py
after_batch()
Handles operations to execute after each training batch.
Source code in openml_pytorch/callbacks/recording.py
after_epoch()
Records metrics at the end of each epoch.
Source code in openml_pytorch/callbacks/recording.py
begin_epoch()
Handles operations at the beginning of each epoch.
Source code in openml_pytorch/callbacks/recording.py
begin_fit()
Initializes attributes necessary for the fitting process.
Source code in openml_pytorch/callbacks/recording.py
get_metrics_history()
Returns a dictionary containing the history of all recorded metrics.
Returns:
Type | Description |
---|---|
dict | A dictionary with metric names as keys and lists of values as values. |
Source code in openml_pytorch/callbacks/recording.py
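The recording side of Recorder (leaving plotting aside) can be sketched like this. It is a hypothetical, dependency-free approximation. The runner attributes opt.param_groups, in_train and loss, and the epoch_metrics dict read in after_epoch, are assumptions made for the example.

```python
from collections import defaultdict
from types import SimpleNamespace


class RecorderSketch:
    """Hypothetical stand-in that records learning rates, losses and per-epoch metrics."""

    def begin_fit(self):
        self.lrs = [[] for _ in self.run.opt.param_groups]  # one list per param group
        self.losses = []
        self.metrics = defaultdict(list)

    def after_batch(self):
        if not self.run.in_train:
            return  # only training batches are recorded
        for pg, lrs in zip(self.run.opt.param_groups, self.lrs):
            lrs.append(pg["lr"])
        self.losses.append(self.run.loss)

    def after_epoch(self):
        for name, value in self.run.epoch_metrics.items():
            self.metrics[name].append(value)

    def get_metrics_history(self):
        # Plain dict: metric name -> list of per-epoch values.
        return dict(self.metrics)


run = SimpleNamespace(opt=SimpleNamespace(param_groups=[{"lr": 0.1}]),
                      in_train=True, loss=None, epoch_metrics={})
rec = RecorderSketch()
rec.run = run
rec.begin_fit()
for loss in (0.9, 0.7):
    run.loss = loss
    rec.after_batch()
run.epoch_metrics = {"accuracy": 0.8}
rec.after_epoch()
print(rec.losses)                 # [0.9, 0.7]
print(rec.get_metrics_history())  # {'accuracy': [0.8]}
```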
plot(skip_last=0, pgid=-1)
Generates a plot of the loss values against the learning rates.
Source code in openml_pytorch/callbacks/recording.py
plot_all_metrics(skip_last=0, save_path=None)
Plots all available metrics in subplots.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
skip_last | int | Number of last points to skip for all metrics | 0 |
Source code in openml_pytorch/callbacks/recording.py
plot_loss(skip_last=0, save_path=None)
Plots the loss values.
Source code in openml_pytorch/callbacks/recording.py
plot_lr(pgid=-1, save_path=None)
Plots the learning rate for a given parameter group.
Source code in openml_pytorch/callbacks/recording.py
plot_metric(metric_name, skip_last=0, save_path=None)
Plots a specific metric over epochs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metric_name | str | Name of the metric to plot | required |
skip_last | int | Number of last points to skip | 0 |
Source code in openml_pytorch/callbacks/recording.py
TensorBoardCallback
Bases: Callback
Logs selected information to TensorBoard:
- Model
Source code in openml_pytorch/callbacks/tensorboard.py
TestCallback
Bases: Callback
TestCallback class is a custom callback used to test the training loop by stopping the training process after 2 iterations. Useful for debugging and testing purposes, not intended for actual training.
Source code in openml_pytorch/callbacks/training_callbacks.py
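Stopping after a fixed number of iterations can be sketched with an exception that the training loop catches. The CancelTrainException class, the StopAfterN callback and the train loop below are hypothetical illustrations of the idea. The sketch does not reproduce the library's mechanism.

```python
class CancelTrainException(Exception):
    """Hypothetical signal used here to stop the loop early."""


class StopAfterN:
    """Hypothetical callback that cancels training after max_iters iterations."""

    def __init__(self, max_iters=2):
        self.max_iters = max_iters
        self.n_iter = 0

    def after_step(self):
        self.n_iter += 1
        if self.n_iter >= self.max_iters:
            raise CancelTrainException()


def train(n_batches, cb):
    done = 0
    try:
        for _ in range(n_batches):
            done += 1  # stand-in for the forward pass, backward pass and optimizer step
            cb.after_step()
    except CancelTrainException:
        pass  # training was cancelled on purpose
    return done


print(train(100, StopAfterN()))  # 2
```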
TrainEvalCallback
Bases: Callback
TrainEvalCallback class is a custom callback used during the training and validation phases of a machine learning model to perform specific actions at the beginning and after certain events.
Methods:
- begin_fit(): Initialize the number of epochs and iteration counts at the start of the fitting process.
- after_batch(): Update the epoch and iteration counts after each batch during training.
- begin_epoch(): Set the current epoch, switch the model to training mode, and indicate that the model is in training.
- begin_validate(): Switch the model to evaluation mode and indicate that the model is in validation.
Source code in openml_pytorch/callbacks/training_callbacks.py
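The bookkeeping described above can be sketched as follows. It is a hypothetical approximation. The runner attribute names (n_epochs, n_iter, iters, epoch, in_train, model) are assumptions, and n_epochs is assumed to be a fractional epoch counter. ToyModel stands in for a PyTorch module, which exposes the same train(), eval() and training members.

```python
from types import SimpleNamespace


class ToyModel:
    """Stand-in for torch.nn.Module's train()/eval() mode switching."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)


class TrainEvalSketch:
    def begin_fit(self):
        self.run.n_epochs = 0.0  # fractional progress, measured in epochs
        self.run.n_iter = 0

    def after_batch(self):
        if not self.run.in_train:
            return
        self.run.n_epochs += 1.0 / self.run.iters
        self.run.n_iter += 1

    def begin_epoch(self):
        self.run.n_epochs = float(self.run.epoch)
        self.run.model.train()
        self.run.in_train = True

    def begin_validate(self):
        self.run.model.eval()
        self.run.in_train = False


run = SimpleNamespace(model=ToyModel(), iters=4, epoch=0, in_train=False)
cb = TrainEvalSketch()
cb.run = run
cb.begin_fit()
for epoch in range(2):
    run.epoch = epoch
    cb.begin_epoch()
    for _ in range(run.iters):
        cb.after_batch()
    cb.begin_validate()

print(run.n_iter, run.n_epochs, run.model.training)  # 8 2.0 False
```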
annealer(f)
A decorator function for creating a partially applied function with predefined start and end arguments.
The inner function _inner captures the start and end parameters and returns a partial object that fixes these parameters for the decorated function f.
Source code in openml_pytorch/callbacks/annealing.py
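A minimal sketch of this decorator uses functools.partial, assuming the decorated function takes (start, end, pos). The linear schedule used to exercise it is included for illustration only.

```python
from functools import partial


def annealer(f):
    def _inner(start, end):
        # Fix start and end; the returned callable only needs `pos`.
        return partial(f, start, end)
    return _inner


@annealer
def sched_lin(start, end, pos):
    return start + pos * (end - start)


lr_sched = sched_lin(1.0, 2.0)
print(lr_sched(0.5))  # 1.5
```

This design lets a schedule be configured once with its endpoints and then queried with only the training position.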
camel2snake(name)
Convert name from camel case to snake case.
Source code in openml_pytorch/callbacks/helper.py
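A common regex-based way to do this conversion is sketched below. The exact patterns used in helper.py are an assumption.

```python
import re


def camel2snake(name):
    # "AvgStats" -> "Avg_Stats": underscore before each capitalized word.
    s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    # Underscore between a lower-case letter or digit and a following capital, then lower-case.
    return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1).lower()


print(camel2snake("TrainEvalCallback"))  # train_eval_callback
```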
combine_scheds(pcts, scheds)
Combine multiple scheduling functions.
Source code in openml_pytorch/callbacks/annealing.py
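One way to combine schedules is sketched below. It assumes that pcts gives the fraction of training assigned to each function in scheds (the fractions sum to 1), and that each function receives a position rescaled to [0, 1] within its own segment. The lin helper is hypothetical and is used only to build the example.

```python
import math
from itertools import accumulate


def combine_scheds(pcts, scheds):
    """Return a function of pos in [0, 1] that runs scheds[i] over the fraction pcts[i]."""
    assert len(pcts) == len(scheds)
    assert all(p > 0 for p in pcts) and math.isclose(sum(pcts), 1.0)
    bounds = [0.0] + list(accumulate(pcts))  # segment boundaries, e.g. [0.0, 0.25, 1.0]

    def _inner(pos):
        # Last segment whose start is <= pos (pos == 1.0 falls in the final segment).
        idx = max(i for i, b in enumerate(bounds[:-1]) if pos >= b)
        local_pos = (pos - bounds[idx]) / (bounds[idx + 1] - bounds[idx])
        return scheds[idx](local_pos)

    return _inner


def lin(start, end):
    # Hypothetical helper: linear schedule from start to end.
    return lambda pos: start + pos * (end - start)


# Warm up from 0.0 to 1.0 over the first 25% of training, then decay back to 0.0.
sched = combine_scheds([0.25, 0.75], [lin(0.0, 1.0), lin(1.0, 0.0)])
print(sched(0.125), sched(0.625))  # 0.5 0.5
```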
listify(o=None)
Convert o to a list. If o is None, return an empty list.
Source code in openml_pytorch/callbacks/helper.py
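A sketch of this helper follows. Treating strings as single items and expanding other iterables into lists are assumptions about edge cases that the description does not cover.

```python
from collections.abc import Iterable


def listify(o=None):
    """Sketch: wrap `o` in a list; None becomes an empty list."""
    if o is None:
        return []
    if isinstance(o, list):
        return o  # already a list: returned unchanged
    if isinstance(o, (str, bytes)):
        return [o]  # keep strings whole instead of splitting them into characters
    if isinstance(o, Iterable):
        return list(o)
    return [o]
```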
sched_cos(start, end, pos)
A cosine schedule function.
Source code in openml_pytorch/callbacks/annealing.py
sched_exp(start, end, pos)
An exponential schedule function.
Source code in openml_pytorch/callbacks/annealing.py
sched_lin(start, end, pos)
A linear schedule function.
Source code in openml_pytorch/callbacks/annealing.py
sched_no(start, end, pos)
Disabled scheduling.
Source code in openml_pytorch/callbacks/annealing.py
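The four schedules can be sketched with their usual formulas, where pos is training progress in [0, 1]. These are linear interpolation, a half-cosine curve, geometric (exponential) interpolation, and a constant. Both the formulas and the assumption that each function is wrapped by annealer (so that sched_cos(start, end) returns a function of pos) are illustrative and are not taken from annealing.py.

```python
import math
from functools import partial


def annealer(f):
    def _inner(start, end):
        return partial(f, start, end)
    return _inner


@annealer
def sched_lin(start, end, pos):
    # Straight line from start (pos=0) to end (pos=1).
    return start + pos * (end - start)


@annealer
def sched_cos(start, end, pos):
    # Half-cosine: flat near both ends, steepest in the middle.
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


@annealer
def sched_exp(start, end, pos):
    # Geometric interpolation; start and end must be non-zero with the same sign.
    return start * (end / start) ** pos


@annealer
def sched_no(start, end, pos):
    # Scheduling disabled: always return start.
    return start


for name, sched in [("lin", sched_lin(0.0, 10.0)), ("cos", sched_cos(0.0, 10.0)),
                    ("exp", sched_exp(1.0, 100.0)), ("no", sched_no(3.0, 7.0))]:
    print(name, [round(sched(p), 3) for p in (0.0, 0.5, 1.0)])
```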