torchbearer
Trial
- class torchbearer.trial.CallbackListInjection(callback, callback_list)
This class allows a callback to be injected into a callback list without masking the methods available for mutating the list. In this way, callbacks (such as printers) can be injected seamlessly into the methods of the trial class.
Parameters: - callback – The callback to inject
- callback_list (CallbackList) – The underlying callback list
- load_state_dict(state_dict)
Resume this callback list from the given state. Callbacks must be given in the same order for this to work.
Parameters: state_dict (dict) – The state dict to reload
Returns: self
Return type: CallbackList
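To make the injection idea concrete, here is a minimal plain-Python sketch. It is not torchbearer's implementation: `SimpleCallbackList`, `InjectedCallbackList`, `Recorder` and the single `on_event` hook are hypothetical names, and the only mutation method forwarded is `append`. It shows the wrapper adding one extra callback for event dispatch while list mutations still reach the wrapped list.

```python
class SimpleCallbackList:
    """Hypothetical minimal callback list: stores callbacks and fans out events."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def append(self, callback):
        self.callbacks.append(callback)

    def on_event(self, log):
        for callback in self.callbacks:
            callback.on_event(log)


class InjectedCallbackList:
    """Hypothetical injection wrapper: one extra callback, mutations still forwarded."""

    def __init__(self, callback, callback_list):
        self.callback = callback
        self.callback_list = callback_list

    def append(self, callback):
        # Mutations go straight to the underlying list, so nothing is masked
        self.callback_list.append(callback)

    def on_event(self, log):
        # Dispatch to the injected callback first, then to the wrapped list
        self.callback.on_event(log)
        self.callback_list.on_event(log)


class Recorder:
    """Toy callback that records its name when it sees an event."""

    def __init__(self, name):
        self.name = name

    def on_event(self, log):
        log.append(self.name)


base = SimpleCallbackList([Recorder("metrics")])
injected = InjectedCallbackList(Recorder("printer"), base)
injected.append(Recorder("checkpoint"))  # lands in the wrapped list
log = []
injected.on_event(log)
print(log)  # ['printer', 'metrics', 'checkpoint']
```

In this sketch the injected callback sees each event before the wrapped callbacks; that ordering is a choice made here, not a documented guarantee.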
- class torchbearer.trial.Sampler(batch_loader)
Sampler wraps a batch loader function and executes it when Sampler.sample() is called.
Parameters: batch_loader (function) – The batch loader to execute
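A minimal sketch of the wrapping described above, under stated assumptions: `SamplerSketch` is a hypothetical stand-in, `sample` is assumed to take the state dict and pass it to the loader, and the plain strings `"x"` and `"y_true"` stand in for the StateKeys torchbearer.state.X and Y_TRUE listed later on this page.

```python
class SamplerSketch:
    """Hypothetical stand-in: hold a batch-loader function and call it on demand."""

    def __init__(self, batch_loader):
        self.batch_loader = batch_loader

    def sample(self, state):
        # Delegate to the wrapped loader, which writes the batch into state
        self.batch_loader(state)


def load_next_pair(state):
    """Toy batch loader: pull the next (x, y) pair from the iterator in state."""
    state["x"], state["y_true"] = next(state["iterator"])


state = {"iterator": iter([(1, 10), (2, 20)])}
sampler = SamplerSketch(load_next_pair)
sampler.sample(state)
print(state["x"], state["y_true"])  # 1 10
```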
- class torchbearer.trial.Trial(model, optimizer=None, criterion=None, metrics=[], callbacks=[], pass_state=False, verbose=2)
The trial class contains all of the required hyper-parameters for running a model in torchbearer and presents an API for model fitting, evaluating and predicting.
Parameters: - model (torch.nn.Module) – The base PyTorch model
- optimizer (torch.optim.Optimizer) – The optimizer used for PyTorch model weight updates
- criterion (function or None) – The final loss criterion that provides a loss value to the optimizer
- metrics (list) – The list of torchbearer.Metric instances to process during fitting
- callbacks (list) – The list of torchbearer.Callback instances to call during fitting
- pass_state (bool) – If True, the torchbearer state will be passed to the model during fitting
- verbose (int) – Global verbosity. If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
- cuda(device=None)
Moves all model parameters and buffers to the GPU.
Parameters: device (int, optional) – If specified, all parameters will be copied to that device
Returns: self
Return type: Trial
- evaluate(verbose=-1, data_key=None)
Evaluate this trial on the validation data.
Parameters: - verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
- data_key (StateKey) – Optional key for the data to evaluate on. Default: torchbearer.VALIDATION_DATA
Returns: The final metric values
Return type: dict
- for_steps(train_steps=None, val_steps=None, test_steps=None)
Use this trial for the given number of train, val and test steps. Returns self so that methods can be chained for convenience.
Parameters: - train_steps (int, optional) – The number of training steps per epoch to run
- val_steps (int, optional) – The number of validation steps per epoch to run
- test_steps (int, optional) – The number of test steps per epoch to run (when using predict())
Returns: self
Return type: Trial
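Because the configuration methods return self, they can be chained into a single expression. The class below is a hypothetical stand-in (`ChainableTrialSketch` is not part of torchbearer) that stores generators and step counts in plain dicts, only to show the pattern.

```python
class ChainableTrialSketch:
    """Hypothetical stand-in showing why configuration methods return self."""

    def __init__(self):
        self.generators = {"train": None, "val": None, "test": None}
        self.steps = {"train": None, "val": None, "test": None}

    def with_generators(self, train_generator=None, val_generator=None, test_generator=None):
        self.generators.update(train=train_generator, val=val_generator, test=test_generator)
        return self  # returning self is what enables chaining

    def for_steps(self, train_steps=None, val_steps=None, test_steps=None):
        self.steps.update(train=train_steps, val=val_steps, test=test_steps)
        return self


trial = ChainableTrialSketch().with_generators([1, 2, 3], [4, 5]).for_steps(train_steps=2, val_steps=1)
print(trial.steps)  # {'train': 2, 'val': 1, 'test': None}
```

With torchbearer itself the same pattern would read `trial.with_generators(train_loader, val_loader).for_steps(train_steps=100).run(epochs=5)`, where `trial`, `train_loader` and `val_loader` are placeholders for objects you have built.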
- for_test_steps(steps)
Run this trial for the given number of test steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience.
Parameters: steps (int) – The number of test steps per epoch to run (when using predict())
Returns: self
Return type: Trial
- for_train_steps(steps)
Run this trial for the given number of training steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience.
Parameters: steps (int) – The number of training steps per epoch to run
Returns: self
Return type: Trial
- for_val_steps(steps)
Run this trial for the given number of validation steps. Note that the generator will output (None, None) if it has not been set. Useful for differentiable programming. Returns self so that methods can be chained for convenience.
Parameters: steps (int) – The number of validation steps per epoch to run
Returns: self
Return type: Trial
- load_state_dict(state_dict, resume=True, **kwargs)
Resume this trial from the given state. Expects that this trial was constructed in the same way. Optionally, just load the model state when resume=False.
Parameters: - state_dict (dict) – The state dict to reload
- resume – If True, resume from the given state. Else, just load in the model weights.
- kwargs – See: torch.nn.Module.load_state_dict
Returns: self
Return type: Trial
- predict(verbose=-1, data_key=None)
Determine predictions for this trial on the test data.
Parameters: - verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
- data_key (StateKey) – Optional key for the data to predict on. Default: torchbearer.TEST_DATA
Returns: Model outputs as a list
Return type: list
- replay(callbacks=[], verbose=2)
Replay the fit passes stored in history with the given callbacks, useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.
Parameters: - callbacks (list) – List of callbacks to be run during the replay
- verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
Returns: self
Return type: Trial
- run(epochs=1, verbose=-1)
Run this trial for the given number of epochs, starting from the last trained epoch.
Parameters: - epochs (int, optional) – The number of epochs to run for
- verbose (int, optional) – If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress, If -1: Automatic
State Requirements: - torchbearer.state.MODEL: Model should be callable and not None, set on Trial init
Returns: The model history (list of tuple of steps summary and epoch metric dicts)
Return type: list
- state_dict(**kwargs)
Get a dict containing the model and optimizer states, as well as the model history.
Parameters: kwargs – See: torch.nn.Module.state_dict
Returns: A dict containing parameters and persistent buffers.
Return type: dict
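The checkpoint round trip described by state_dict() and load_state_dict(resume=...) can be sketched with plain dicts. `CheckpointableTrialSketch` and its `"model"`, `"optimizer"` and `"history"` keys are assumptions for illustration; the real dict layout is not specified on this page. The sketch shows that resume=False restores only the model weights.

```python
import copy


class CheckpointableTrialSketch:
    """Hypothetical stand-in: plain dicts play the model and optimizer states."""

    def __init__(self):
        self.model_state = {"weight": 0.0}
        self.optimizer_state = {"lr": 0.1}
        self.history = []

    def state_dict(self):
        # Bundle model, optimizer and history into one independent copy
        return copy.deepcopy({
            "model": self.model_state,
            "optimizer": self.optimizer_state,
            "history": self.history,
        })

    def load_state_dict(self, state_dict, resume=True):
        # The model weights are always restored
        self.model_state = copy.deepcopy(state_dict["model"])
        if resume:
            # Only a full resume brings back optimizer state and history
            self.optimizer_state = copy.deepcopy(state_dict["optimizer"])
            self.history = copy.deepcopy(state_dict["history"])
        return self


trained = CheckpointableTrialSketch()
trained.model_state["weight"] = 1.5
trained.history.append(({"train_steps": 10}, {"loss": 0.3}))
checkpoint = trained.state_dict()

fresh = CheckpointableTrialSketch().load_state_dict(checkpoint, resume=False)
print(fresh.model_state, fresh.history)  # {'weight': 1.5} []
```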
- to(*args, **kwargs)
Moves and/or casts the parameters and buffers.
Parameters: - args – See: torch.nn.Module.to
- kwargs – See: torch.nn.Module.to
Returns: self
Return type: Trial
- with_generators(train_generator=None, val_generator=None, test_generator=None, train_steps=None, val_steps=None, test_steps=None)
Use this trial with the given generators. Returns self so that methods can be chained for convenience.
Parameters: - train_generator (DataLoader) – The training data generator to use during calls to run()
- val_generator (DataLoader) – The validation data generator to use during calls to run() and evaluate()
- test_generator (DataLoader) – The testing data generator to use during calls to predict()
- train_steps (int) – The number of steps per epoch to take when using the training generator
- val_steps (int) – The number of steps per epoch to take when using the validation generator
- test_steps (int) – The number of steps per epoch to take when using the testing generator
Returns: self
Return type: Trial
- with_test_data(x, batch_size=1, num_workers=1, steps=None)
Use this trial with the given test data. Returns self so that methods can be chained for convenience.
Parameters: - x (torch.Tensor) – The test x data to use during calls to predict()
- batch_size (int) – The size of each batch to sample from the data
- num_workers (int) – Number of worker processes to use in the data loader
- steps (int) – The number of steps per epoch to take when using this data
Returns: self
Return type: Trial
- with_test_generator(generator, steps=None)
Use this trial with the given test generator. Returns self so that methods can be chained for convenience.
Parameters: - generator (DataLoader) – The test data generator to use during calls to predict()
- steps (int) – The number of steps per epoch to take when using this generator
Returns: self
Return type: Trial
- with_train_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)
Use this trial with the given train data. Returns self so that methods can be chained for convenience.
Parameters: - x (torch.Tensor) – The train x data to use during calls to run()
- y (torch.Tensor) – The train labels to use during calls to run()
- batch_size (int) – The size of each batch to sample from the data
- shuffle (bool) – If True, then data will be shuffled each epoch
- num_workers (int) – Number of worker processes to use in the data loader
- steps (int) – The number of steps per epoch to take when using this data
Returns: self
Return type: Trial
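with_train_data() and with_val_data() take raw tensors plus batch_size and shuffle; given the num_workers parameter, the tensors presumably end up in a PyTorch data loader. The hypothetical generator below mimics only the batching and shuffling in plain Python, with lists standing in for tensors and an optional `seed` parameter added for reproducibility.

```python
import random


def iterate_minibatches(x, y, batch_size=1, shuffle=True, seed=None):
    """Hypothetical sketch: yield (x_batch, y_batch) pairs for one epoch."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    indices = list(range(len(x)))
    if shuffle:
        # Shuffle once per call; with seed=None each epoch gets a different order
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        yield [x[i] for i in chunk], [y[i] for i in chunk]


x = [0.0, 0.1, 0.2, 0.3, 0.4]
y = [0, 1, 0, 1, 0]
batches = list(iterate_minibatches(x, y, batch_size=2, shuffle=False))
print(len(batches))  # 3 (the last batch holds the one leftover sample)
print(batches[0])    # ([0.0, 0.1], [0, 1])
```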
- with_train_generator(generator, steps=None)
Use this trial with the given train generator. Returns self so that methods can be chained for convenience.
Parameters: - generator (DataLoader) – The train data generator to use during calls to run()
- steps (int) – The number of steps per epoch to take when using this generator
Returns: self
Return type: Trial
- with_val_data(x, y, batch_size=1, shuffle=True, num_workers=1, steps=None)
Use this trial with the given validation data. Returns self so that methods can be chained for convenience.
Parameters: - x (torch.Tensor) – The validation x data to use during calls to run() and evaluate()
- y (torch.Tensor) – The validation labels to use during calls to run() and evaluate()
- batch_size (int) – The size of each batch to sample from the data
- shuffle (bool) – If True, then data will be shuffled each epoch
- num_workers (int) – Number of worker processes to use in the data loader
- steps (int) – The number of steps per epoch to take when using this data
Returns: self
Return type: Trial
- with_val_generator(generator, steps=None)
Use this trial with the given validation generator. Returns self so that methods can be chained for convenience.
Parameters: - generator (DataLoader) – The validation data generator to use during calls to run() and evaluate()
- steps (int) – The number of steps per epoch to take when using this generator
Returns: self
Return type: Trial
- torchbearer.trial.deep_to(batch, device, dtype)
Static method to call to() on tensors or tuples. All items in the tuple will have deep_to() called.
Parameters: - batch (tuple, list, torch.Tensor) – The mini-batch which requires a to() call
- device (torch.device) – The desired device of the batch
- dtype (torch.dtype) – The desired datatype of the batch
Returns: The moved or cast batch
Return type: tuple, list, torch.Tensor
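The recursive behaviour of deep_to() can be sketched without PyTorch. `FakeTensor` is a hypothetical stand-in for torch.Tensor and dtypes are plain strings. Casting only floating-point tensors (so integer labels keep their type) is an assumption of this sketch, not something the docstring states.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: tracks value, device and dtype."""

    def __init__(self, value, device="cpu", dtype="float32"):
        self.value = value
        self.device = device
        self.dtype = dtype

    def to(self, device, dtype=None):
        # Return a new object carrying the requested device (and dtype, if given)
        return FakeTensor(self.value, device, dtype or self.dtype)


def deep_to(batch, device, dtype):
    """Sketch: move a tensor, or every tensor inside nested tuples/lists."""
    if isinstance(batch, (tuple, list)):
        # Rebuild the same container type from the recursively moved items
        return type(batch)(deep_to(item, device, dtype) for item in batch)
    # Assumption: only floating-point tensors are cast; others are only moved
    if batch.dtype.startswith("float"):
        return batch.to(device, dtype)
    return batch.to(device)


batch = (FakeTensor(1.0), [FakeTensor(2.0), FakeTensor(3, dtype="int64")])
moved = deep_to(batch, "cuda:0", "float16")
print(type(moved).__name__, type(moved[1]).__name__)  # tuple list
print(moved[1][0].device, moved[1][0].dtype)          # cuda:0 float16
print(moved[1][1].device, moved[1][1].dtype)          # cuda:0 int64
```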
- torchbearer.trial.inject_callback(callback)
Decorator to inject a callback into the callback list and remove the callback after the decorated function has executed.
Parameters: callback (Callback) – Callback to be injected
Returns: The decorator
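A minimal sketch of such a decorator, assuming the decorated object keeps a plain Python list in `self.callbacks` (torchbearer itself works on a CallbackList, see CallbackListInjection above). `TrialStub` and the string callbacks are hypothetical; the try/finally cleanup is a design choice of the sketch so that the callback is removed even if the method raises.

```python
import functools


def inject_callback(callback):
    """Sketch: add `callback` to self.callbacks for the duration of one call."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            self.callbacks.append(callback)
            try:
                return func(self, *args, **kwargs)
            finally:
                # Remove the injected callback even if func raises
                self.callbacks.remove(callback)
        return wrapper

    return decorator


class TrialStub:
    """Hypothetical object holding a plain list of callbacks."""

    def __init__(self):
        self.callbacks = ["metric_printer"]

    @inject_callback("timer")
    def run(self):
        # Snapshot the list as seen while the decorated method executes
        return list(self.callbacks)


stub = TrialStub()
print(stub.run())      # ['metric_printer', 'timer']
print(stub.callbacks)  # ['metric_printer']
```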
- torchbearer.trial.inject_printer(validation_label_letter='v')
The inject printer decorator is used to inject the appropriate printer callback, according to the verbosity level.
Parameters: validation_label_letter – The validation label letter to use
Returns: A decorator
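The verbosity levels documented for Trial, evaluate(), predict() and run() can be summarized by a small selection function. This is a hypothetical sketch, not torchbearer's code: it returns descriptive strings instead of printer callbacks, and it assumes that verbose=-1 ("Automatic") falls back to the trial-wide verbose value passed to the Trial constructor.

```python
def choose_printer(verbose, global_verbose=2):
    """Hypothetical helper: describe which progress printer a verbosity level selects."""
    if verbose == -1:
        # Assumption: "Automatic" defers to the trial-wide verbosity
        verbose = global_verbose
    if verbose == 2:
        return "tqdm per batch"
    if verbose == 1:
        return "tqdm per epoch"
    # 0 (or any other value): no progress display
    return None


for level in (2, 1, 0, -1):
    print(level, choose_printer(level, global_verbose=1))
```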
- torchbearer.trial.inject_sampler(data_key, predict=False)
Decorator to inject a Sampler into state[torchbearer.SAMPLER], along with the specified generator into state[torchbearer.GENERATOR] and the number of steps into state[torchbearer.STEPS].
Parameters: - data_key (StateKey) – Key for the data to inject
- predict (bool) – If True, the prediction batch loader is used; if False, the standard data loader is used
Returns: The decorator
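The sketch below shows the shape of such a decorator. Everything in it is hypothetical: plain strings replace the StateKeys, `self.data` is assumed to map a data key to a `(generator, steps)` pair, the two loaders are empty placeholders, and the chosen loader is stored directly rather than wrapped in a Sampler.

```python
import functools

# Plain strings stand in for the StateKeys SAMPLER, GENERATOR and STEPS
SAMPLER, GENERATOR, STEPS = "sampler", "generator", "steps"


def load_standard(state):
    """Placeholder for a loader of (input, target) batches."""


def load_predict(state):
    """Placeholder for a loader of prediction batches."""


def inject_sampler_sketch(data_key, predict=False):
    """Hypothetical decorator: set up loader, generator and steps before the call."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Assumption: self.data maps a data key to a (generator, steps) pair
            generator, steps = self.data[data_key]
            self.state[SAMPLER] = load_predict if predict else load_standard
            self.state[GENERATOR] = generator
            self.state[STEPS] = steps
            return func(self, *args, **kwargs)
        return wrapper

    return decorator


class TrialStub:
    """Hypothetical trial holding a state dict and registered data."""

    def __init__(self):
        self.state = {}
        self.data = {"test_data": ([[1.0], [2.0]], 2)}

    @inject_sampler_sketch("test_data", predict=True)
    def predict(self):
        return self.state[SAMPLER].__name__, self.state[STEPS]


print(TrialStub().predict())  # ('load_predict', 2)
```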
- torchbearer.trial.load_batch_none(state)
Load a (None, None) tuple mini-batch into state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
- torchbearer.trial.load_batch_predict(state)
Load a prediction (input data, target) or (input data) mini-batch from the iterator into state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
- torchbearer.trial.load_batch_standard(state)
Load a standard (input data, target) tuple mini-batch from the iterator into state.
Parameters: state (dict[str,any]) – The current state dict of the Trial.
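The three loaders differ only in what they expect from the iterator. Below is a plain-Python sketch of that difference, not torchbearer's source: string keys stand in for the StateKeys X, Y_TRUE and ITERATOR, and treating any list or tuple from the iterator as an (input, target) pair in the prediction loader is an assumption.

```python
# Plain strings stand in for the StateKeys X, Y_TRUE and ITERATOR
X, Y_TRUE, ITERATOR = "x", "y_true", "iterator"


def load_batch_standard(state):
    # Every item from the iterator must be an (input, target) pair
    state[X], state[Y_TRUE] = next(state[ITERATOR])


def load_batch_predict(state):
    # Accept either (input, target) pairs or bare inputs
    data = next(state[ITERATOR])
    if isinstance(data, (list, tuple)):
        state[X], state[Y_TRUE] = data
    else:
        state[X] = data


def load_batch_none(state):
    # No generator set: feed (None, None) instead of real data
    state[X], state[Y_TRUE] = None, None


state = {ITERATOR: iter([("img0", 0), ("img1", 1)])}
load_batch_standard(state)
print(state[X], state[Y_TRUE])  # img0 0

state = {ITERATOR: iter(["img2"])}
load_batch_predict(state)
print(state[X], state.get(Y_TRUE))  # img2 None
```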
- torchbearer.trial.update_device_and_dtype(state, *args, **kwargs)
Get the data type and device values from the args / kwargs and update state.
Parameters: - state (State) – The dict to update
- args – Arguments to the Trial.to() function
- kwargs – Keyword arguments to the Trial.to() function
Returns: device, dtype pair
Return type: tuple
Model (Deprecated)
- class torchbearer.torchbearer.Model(model, optimizer, criterion=None, metrics=[])
Deprecated since version 0.2.0: Use Trial instead.
Create a torchbearer model which wraps a base torch model and provides a training environment surrounding it.
Parameters: - model (torch.nn.Module) – The base PyTorch model
- optimizer (torch.optim.Optimizer) – The optimizer used for PyTorch model weight updates
- criterion (function or None) – The final loss criterion that provides a loss value to the optimizer
- metrics (list) – Additional metrics for display and use within callbacks
- cpu()
Moves all model parameters and buffers to the CPU.
Returns: Self torchbearer model
Return type: Model
- cuda(device=None)
Moves all model parameters and buffers to the GPU.
Parameters: device (int, optional) – If specified, all parameters will be copied to that device
Returns: Self torchbearer model
Return type: Model
- evaluate(x=None, y=None, batch_size=32, verbose=2, steps=None, pass_state=False)
Perform an evaluation loop on the given data and label tensors to evaluate metrics.
Parameters: - x (torch.Tensor) – The input data tensor
- y (torch.Tensor) – The target labels for data tensor x
- batch_size (int) – The mini-batch size (number of samples processed per evaluation step)
- verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
- steps (int) – The number of evaluation mini-batches to run
- pass_state (bool) – If True, the state dictionary is passed to the torch model forward method; if False, only the input data is passed
Returns: The dictionary containing final metrics
Return type: dict[str,any]
- evaluate_generator(generator, verbose=2, steps=None, pass_state=False)
Perform an evaluation loop on the given data generator to evaluate metrics.
Parameters: - generator (DataLoader) – The evaluation data generator (usually a PyTorch DataLoader)
- verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
- steps (int) – The number of evaluation mini-batches to run
- pass_state (bool) – If True, the state dictionary is passed to the torch model forward method; if False, only the input data is passed
Returns: The dictionary containing final metrics
Return type: dict[str,any]
- fit(x, y, batch_size=None, epochs=1, verbose=2, callbacks=[], validation_split=None, validation_data=None, shuffle=True, initial_epoch=0, steps_per_epoch=None, validation_steps=None, workers=1, pass_state=False)
Perform fitting of a model to the given data and label tensors.
Parameters: - x (torch.Tensor) – The input data tensor
- y (torch.Tensor) – The target labels for data tensor x
- batch_size (int) – The mini-batch size (number of samples processed for a single weight update)
- epochs (int) – The number of training epochs to be run (in each epoch, every sample from the dataset is viewed exactly once)
- verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
- callbacks (list) – The list of torchbearer callbacks to be called during training and validation
- validation_split (float) – Fraction of the training dataset to be set aside for validation
- validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data tensors (x_val, y_val)
- shuffle (bool) – If True, mini-batches of training/validation data are randomly selected; if False, mini-batch samples are selected in the order defined by the dataset
- initial_epoch (int) – The integer value representing the first epoch, useful for continuing training after a number of epochs
- steps_per_epoch (int) – The number of training mini-batches to run per epoch
- validation_steps (int) – The number of validation mini-batches to run per epoch
- workers (int) – The number of CPU workers devoted to batch loading and aggregating
- pass_state (bool) – If True, the state dictionary is passed to the torch model forward method; if False, only the input data is passed
Returns: The final state context dictionary
Return type: dict[str,any]
- fit_generator(generator, train_steps=None, epochs=1, verbose=2, callbacks=[], validation_generator=None, validation_steps=None, initial_epoch=0, pass_state=False)
Perform fitting of a model to the given data generator.
Parameters: - generator (DataLoader) – The training data generator (usually a PyTorch DataLoader)
- train_steps (int) – The number of training mini-batches to run per epoch
- epochs (int) – The number of training epochs to be run (in each epoch, every sample from the dataset is viewed exactly once)
- verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
- callbacks (list) – The list of torchbearer callbacks to be called during training and validation
- validation_generator (DataLoader) – The validation data generator (usually a PyTorch DataLoader)
- validation_steps (int) – The number of validation mini-batches to run per epoch
- initial_epoch (int) – The integer value representing the first epoch, useful for continuing training after a number of epochs
- pass_state (bool) – If True, the state dictionary is passed to the torch model forward method; if False, only the input data is passed
Returns: The final state context dictionary
Return type: dict[str,any]
- load_state_dict(state_dict, **kwargs)
Copies parameters and buffers from state_dict() into this module and its descendants.
Parameters: - state_dict (dict) – A dict containing parameters and persistent buffers.
- kwargs –
- predict(x=None, batch_size=32, verbose=2, steps=None, pass_state=False)
Perform a prediction loop on the given data tensor to predict labels.
Parameters: - x (torch.Tensor) – The input data tensor
- batch_size (int) – The mini-batch size (number of samples processed per prediction step)
- verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
- steps (int) – The number of prediction mini-batches to run
- pass_state (bool) – If True, the state dictionary is passed to the torch model forward method; if False, only the input data is passed
Returns: Tensor of final predicted labels
Return type: torch.Tensor
- predict_generator(generator, verbose=2, steps=None, pass_state=False)
Perform a prediction loop on the given data generator to predict labels.
Parameters: - generator (DataLoader) – The prediction data generator (usually a PyTorch DataLoader)
- verbose (int) – If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no progress
- steps (int) – The number of prediction mini-batches to run
- pass_state (bool) – If True, the state dictionary is passed to the torch model forward method; if False, only the input data is passed
Returns: Tensor of final predicted labels
Return type: torch.Tensor
- state_dict(**kwargs)
Parameters: kwargs –
Returns: A dict containing parameters and persistent buffers.
Return type: dict
- to(*args, **kwargs)
Moves and/or casts the parameters and buffers.
Parameters: - args – See: torch.nn.Module.to
- kwargs – See: torch.nn.Module.to
Returns: Self torchbearer model
Return type: Model
State
The state is central in torchbearer, storing all of the relevant intermediate values that may be changed or replaced during model fitting. This module defines classes for interacting with state and all of the built-in state keys used throughout torchbearer. The state_key() function can be used to create custom state keys for use in callbacks or metrics.
Example:
from torchbearer import state_key
MY_KEY = state_key('my_test_key')
- torchbearer.state.BACKWARD_ARGS = backward_args
The optional arguments which should be passed to the backward call
- torchbearer.state.BATCH = t
The current batch number
- torchbearer.state.CALLBACK_LIST = callback_list
The CallbackList object which is called by the Trial
- torchbearer.state.CRITERION = criterion
The criterion to use when fitting the model
- torchbearer.state.DATA = data
The string name of the current data
- torchbearer.state.DATA_TYPE = dtype
The data type of tensors in use by the model; match this to avoid type issues
- torchbearer.state.EPOCH = epoch
The current epoch number
- torchbearer.state.FINAL_PREDICTIONS = final_predictions
The key which maps to the predictions over the dataset when calling predict
- torchbearer.state.GENERATOR = generator
The current data generator (DataLoader)
- torchbearer.state.HISTORY = history
The history list of the Trial instance
- torchbearer.state.ITERATOR = iterator
The current iterator
- torchbearer.state.LOSS = loss
The current value for the loss
- torchbearer.state.MAX_EPOCHS = max_epochs
The total number of epochs to run for
- torchbearer.state.METRICS = metrics
The metric dict from the current batch of data
- torchbearer.state.MODEL = model
The PyTorch module / model that will be trained
- torchbearer.state.OPTIMIZER = optimizer
The optimizer to use when fitting the model
- torchbearer.state.SAMPLER = sampler
The sampler which loads data from the generator onto the correct device
- torchbearer.state.SELF = self
A self-reference to the Trial object for persistence etc.
- torchbearer.state.STEPS = steps
The current number of steps per epoch
- torchbearer.state.STOP_TRAINING = stop_training
A flag that can be set to True to stop the current fit call
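As an illustration of the STOP_TRAINING flag, here is a hypothetical callback and training loop in plain Python. The loop, the threshold callback, the on_end_epoch hook name and the string keys are assumptions made for the sketch; only the idea that a callback may set the flag and the fit call then stops comes from the description above.

```python
# Plain strings stand in for the StateKeys STOP_TRAINING, EPOCH and LOSS
STOP_TRAINING, EPOCH, LOSS = "stop_training", "epoch", "loss"


class StopWhenLossBelow:
    """Hypothetical callback: raise the stop flag once the epoch loss is small enough."""

    def __init__(self, threshold):
        self.threshold = threshold

    def on_end_epoch(self, state):
        if state[LOSS] < self.threshold:
            state[STOP_TRAINING] = True


def fit_loop(epoch_losses, callback, max_epochs):
    """Sketch of a fit loop that honours the flag; returns the last epoch run."""
    state = {STOP_TRAINING: False}
    for epoch in range(max_epochs):
        state[EPOCH] = epoch
        state[LOSS] = epoch_losses[epoch]  # stands in for a real training epoch
        callback.on_end_epoch(state)
        if state[STOP_TRAINING]:
            break
    return state[EPOCH]


print(fit_loop([0.9, 0.5, 0.2, 0.1], StopWhenLossBelow(0.3), max_epochs=4))  # 2
```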
class
torchbearer.state.
State
[source]¶ State dictionary that behaves like a python dict but accepts StateKeys
- class torchbearer.state.StateKey(key)
StateKey class that is a unique state key based on the input string key. State keys are also metrics which retrieve themselves from state.
Parameters: key (String) – Base key
- process(state)
Retrieve the value stored under this key from the given state.
- process_final(state)
Retrieve the value stored under this key from the given final state.
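The StateKey description combines two ideas: keys are unique even when created from the same string, and a key can act as a metric that reads its own value from state. `StateKeySketch` below illustrates both under assumptions: name clashes are resolved by appending underscores, and process() returns a `{name: value}` dict; neither detail is confirmed on this page.

```python
class StateKeySketch:
    """Hypothetical key: unique name, usable as a dict key, reads itself from state."""

    _used_names = set()

    def __init__(self, key):
        # Assumption: a name clash is resolved by appending underscores
        while key in StateKeySketch._used_names:
            key += "_"
        StateKeySketch._used_names.add(key)
        self.key = key

    def process(self, state):
        # Metric-style lookup; the {name: value} return shape is an assumption
        return {self.key: state[self]}

    def __repr__(self):
        return self.key


LOSS = StateKeySketch("loss")
OTHER_LOSS = StateKeySketch("loss")  # same base name, still a distinct key
state = {LOSS: 0.25, OTHER_LOSS: 0.5}
print(LOSS, OTHER_LOSS)     # loss loss_
print(LOSS.process(state))  # {'loss': 0.25}
```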
- torchbearer.state.TEST_DATA = test_data
The flag representing test data
- torchbearer.state.TEST_GENERATOR = test_generator
The test data generator in the Trial object
- torchbearer.state.TEST_STEPS = test_steps
The number of test steps to take
- torchbearer.state.TIMINGS = timings
The timings keys used by the timer callback
- torchbearer.state.TRAIN_DATA = train_data
The flag representing train data
- torchbearer.state.TRAIN_GENERATOR = train_generator
The train data generator in the Trial object
- torchbearer.state.TRAIN_STEPS = train_steps
The number of train steps to take
- torchbearer.state.VALIDATION_DATA = validation_data
The flag representing validation data
- torchbearer.state.VALIDATION_GENERATOR = validation_generator
The validation data generator in the Trial object
- torchbearer.state.VALIDATION_STEPS = validation_steps
The number of validation steps to take
- torchbearer.state.VERSION = torchbearer_version
The torchbearer version
- torchbearer.state.X = x
The current batch of inputs
- torchbearer.state.Y_PRED = y_pred
The current batch of predictions
- torchbearer.state.Y_TRUE = y_true
The current batch of ground truth data
Utilities
- class torchbearer.cv_utils.DatasetValidationSplitter(dataset_len, split_fraction, shuffle_seed=None)
- torchbearer.cv_utils.get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True)
Generate validation and training datasets from whole dataset tensors.
Parameters: - x (torch.Tensor) – Data tensor for dataset
- y (torch.Tensor) – Label tensor for dataset
- validation_data ((torch.Tensor, torch.Tensor)) – Optional validation data (x_val, y_val) to be used instead of splitting x and y tensors
- validation_split (float) – Fraction of dataset to be used for validation
- shuffle (bool) – If True, randomize tensor order before splitting; otherwise do not randomize
Returns: Training and validation datasets
Return type: tuple
- torchbearer.cv_utils.train_valid_splitter(x, y, split, shuffle=True)
Generate training and validation tensors from whole dataset data and label tensors.
Parameters: - x (torch.Tensor) – Data tensor for whole dataset
- y (torch.Tensor) – Label tensor for whole dataset
- split (float) – Fraction of dataset to be used for validation
- shuffle (bool) – If True, randomize tensor order before splitting; otherwise do not randomize
Returns: Training and validation tensors (training data, training labels, validation data, validation labels)
Return type: tuple
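The splitting logic of train_valid_splitter() and get_train_valid_sets() can be sketched with plain lists. Both functions below are hypothetical re-implementations: the rounding of the validation count, taking the validation samples from the end, the extra `seed` parameter, and returning tuples instead of datasets are assumptions of the sketch.

```python
import random


def train_valid_split_sketch(x, y, split, shuffle=True, seed=None):
    """Hypothetical: hold out a `split` fraction of (x, y) for validation."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    indices = list(range(len(x)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    # Assumptions: the validation count rounds down and is taken from the end
    cut = len(indices) - int(len(indices) * split)
    train_idx, valid_idx = indices[:cut], indices[cut:]

    def pick(data, idx):
        return [data[i] for i in idx]

    return pick(x, train_idx), pick(y, train_idx), pick(x, valid_idx), pick(y, valid_idx)


def get_train_valid_sketch(x, y, validation_data=None, validation_split=None, shuffle=True):
    """Hypothetical: explicit validation data wins, otherwise split (x, y)."""
    if validation_data is not None:
        return (x, y), validation_data
    if validation_split is not None:
        x_tr, y_tr, x_va, y_va = train_valid_split_sketch(x, y, validation_split, shuffle)
        return (x_tr, y_tr), (x_va, y_va)
    return (x, y), None


x = list(range(10))
y = [v % 2 for v in x]
train, valid = get_train_valid_sketch(x, y, validation_split=0.2, shuffle=False)
print(valid)          # ([8, 9], [0, 1])
print(len(train[0]))  # 8
```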