Core

Basic

Photon.Workout (Type)

The Workout keeps track of the progress of the training session. At minimum, a model and a loss function must be provided. Optionally, an optimizer and one or more metrics can be specified.

If no optimizer is specified, SGD is used. If no metrics are provided, only the training and validation losses are recorded (:loss and :val_loss).
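
As a rough illustration of what recording the losses means, the sketch below keeps a per-epoch history in a dictionary keyed by metric name. History and record! are hypothetical names chosen for this example, not part of Photon, and the storage layout is an assumption.

```julia
# Hypothetical sketch of how per-epoch metric values could be recorded.
# `History` and `record!` are illustrative names, not Photon's API.
const History = Dict{Symbol,Vector{Float64}}

# Append `value` to the series stored under `name`, creating the series if needed.
function record!(h::History, name::Symbol, value::Real)
    push!(get!(h, name, Float64[]), Float64(value))
    return h
end

history = History()
# Without extra metrics, only the training and validation losses are recorded.
record!(history, :loss, 0.9)
record!(history, :val_loss, 1.1)
record!(history, :loss, 0.7)
record!(history, :val_loss, 0.95)
```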

The mover moves the data to the correct device; see also SmartMover. If no mover is required, you can pass (x) -> x or simply identity.
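
A mover is applied to each minibatch before it reaches the model, so any one-argument function can serve as one. The sketch below assumes a CPU-only setting: to_float32 is a hypothetical stand-in for a device mover that only converts element types, whereas a real GPU mover such as SmartMover would also copy the arrays to device memory.

```julia
# Hypothetical movers: any function that takes a minibatch and returns it,
# possibly converted. `to_float32` is an illustrative CPU-only stand-in for a
# device mover; a real GPU mover would also copy arrays to device memory.
to_float32(x::AbstractArray{<:Real}) = Float32.(x)
to_float32(x::Tuple) = map(to_float32, x)   # e.g. an (x, y) minibatch

batch = (rand(Float64, 4, 2), [1.0, 0.0])

moved = to_float32(batch)          # element types converted to Float32
same  = identity(batch)            # no mover required: data passes through
also  = ((x) -> x)(batch)          # equivalent anonymous function
```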

Usage

workout = Workout(model, L1Loss())

workout = Workout(model, CrossEntropy(), Adam())

workout = Workout(model, HingeLoss(); acc=BinaryAccuracy())

workout = Workout(model, L1Loss(), mover=identity)
Photon.saveWorkout: no docstring available.

Photon.loadWorkout: no docstring available.

Photon.validate (Function)

Validate a minibatch and calculate the loss and metrics. Typically this function is called from train!, but it can also be invoked directly if required.
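
Validating a minibatch amounts to a forward pass followed by computing the loss and metrics, without any gradient computation or weight update. This is a minimal sketch under that reading; validate_batch, mse and mae are hypothetical names, and storing results under :val_ keys is an assumption modeled on :val_loss.

```julia
# Hypothetical sketch of validating one minibatch: forward pass, loss and
# metrics, but no gradients and no weight update.
# `validate_batch`, `mse` and `mae` are illustrative names, not Photon's API.
mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)
mae(ŷ, y) = sum(abs, ŷ .- y) / length(y)

function validate_batch(model, lossfn, metrics, x, y)
    ŷ = model(x)                                    # forward pass only
    results = Dict{Symbol,Float64}(:val_loss => lossfn(ŷ, y))
    for (name, metric) in pairs(metrics)            # e.g. (mae = mae,)
        results[Symbol(:val_, name)] = metric(ŷ, y) # e.g. :val_mae
    end
    return results
end

W = [2.0 0.0; 0.0 1.0]
model(x) = W * x
x = [1.0 2.0; 1.0 1.0]
y = [2.0 4.0; 1.0 2.0]
r = validate_batch(model, mse, (mae = mae,), x, y)
```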

Photon.predict (Function)

Predict a sample, either a single value or a batch. Compared to invoking the model directly with model(x), predict takes care of:

  • Moving the data to the GPU if required.
  • Shaping the data into a batch (controlled by the makebatch parameter).

Usage

x = randn(Float32, 224, 224, 3)
predict(workout, x)
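
The batching step can be pictured as adding a trailing batch dimension of size 1, so a single sample looks like a batch of one. This sketch assumes the batch dimension comes last, as is conventional for Julia array layouts; make_batch and predict_sketch are hypothetical names, and the summing lambda merely stands in for a model.

```julia
# Hypothetical sketch of the batching step in predict: a single sample gets a
# trailing batch dimension of size 1 (assuming batch-last array layout).
# `make_batch` and `predict_sketch` are illustrative names, not Photon's API.
make_batch(x::AbstractArray) = reshape(x, size(x)..., 1)

function predict_sketch(model, x; makebatch = true, mover = identity)
    xb = mover(makebatch ? make_batch(x) : x)   # shape, then move the data
    return model(xb)
end

x = ones(Float32, 224, 224, 3)
xb = make_batch(x)
# Stand-in "model": sums each sample over its first three dimensions.
y = predict_sketch(b -> sum(b; dims = (1, 2, 3)), x)
```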
Photon.train! (Function)

Train the model on a supervised dataset for a number of epochs. train! can be called multiple times and will continue training where it left off.

By default, train! tries to ensure that the data is of the right type (e.g. Float32) and on the right device (e.g. the GPU) before feeding it to the model.
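
Resuming works naturally if the training state, including an epoch counter, lives in the workout object rather than in the loop. The sketch below shows that idea with a one-weight model trained by plain SGD; WorkoutSketch and train_sketch! are hypothetical names, and the design is an assumption, not Photon's implementation.

```julia
# Hypothetical sketch: the workout stores the model state and an epoch
# counter, so a second call continues from where the first one stopped.
# Names are illustrative, not Photon's API.
mutable struct WorkoutSketch
    w::Float64      # single weight of the model y = w * x
    lr::Float64     # learning rate for plain SGD
    epochs::Int     # number of epochs completed so far
end

function train_sketch!(wo::WorkoutSketch, data; epochs::Int = 1)
    for _ in 1:epochs
        for (x, y) in data
            grad = 2 * (wo.w * x - y) * x   # d/dw of (w*x - y)^2
            wo.w -= wo.lr * grad
        end
        wo.epochs += 1
    end
    return wo
end

data = [(1.0, 3.0), (2.0, 6.0)]         # samples of y = 3x
wo = WorkoutSketch(0.0, 0.05, 0)
train_sketch!(wo, data; epochs = 2)
train_sketch!(wo, data; epochs = 3)     # resumes: runs epochs 3 to 5
```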

Usage

train!(workout, traindata)
train!(workout, traindata, testdata, epochs=50)
Photon.freeze! (Function)

Freeze a parameter so that it is no longer updated during training.

Photon.unfreeze! (Function)

Unfreeze a parameter so it will be updated again during training.
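
One simple way to picture freeze! and unfreeze! is a set of frozen parameter names that the update step skips. The sketch below assumes that mechanism; ParamStore, freeze_sketch!, unfreeze_sketch! and sgd_update! are hypothetical names, not Photon's API.

```julia
# Hypothetical sketch of freezing: parameters live in a Dict, a Set records
# which ones are frozen, and the SGD update skips frozen parameters.
struct ParamStore
    params::Dict{Symbol,Vector{Float64}}
    frozen::Set{Symbol}
end
ParamStore(params) = ParamStore(params, Set{Symbol}())

freeze_sketch!(ps::ParamStore, name::Symbol) = (push!(ps.frozen, name); ps)
unfreeze_sketch!(ps::ParamStore, name::Symbol) = (delete!(ps.frozen, name); ps)

# Apply an SGD update in place to every parameter that is not frozen.
function sgd_update!(ps::ParamStore, grads, lr)
    for (name, p) in ps.params
        name in ps.frozen && continue
        p .-= lr .* grads[name]
    end
    return ps
end

ps = ParamStore(Dict(:W => [1.0, 1.0], :b => [0.5]))
grads = Dict(:W => [1.0, 2.0], :b => [1.0])
freeze_sketch!(ps, :b)
sgd_update!(ps, grads, 0.1)      # :b stays at 0.5
unfreeze_sketch!(ps, :b)
sgd_update!(ps, grads, 0.1)      # :b is updated again
```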

Photon.stop (Function)

Stop a training session. Typically invoked by a callback function that detects that the training is not progressing anymore.

If this function is called outside the scope of a training session, an exception is thrown.
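
A common way to let a callback end a loop early is to throw a dedicated exception that the loop catches. The sketch below assumes that approach, with a global flag marking an active session so that stopping outside a session raises an error. StopTraining, stop_sketch and train_loop are hypothetical names, not Photon's API.

```julia
# Hypothetical sketch of stopping a training loop from a callback.
# A global flag marks an active session; `stop_sketch` throws an exception
# that the loop catches, and raises an error outside a session.
struct StopTraining <: Exception end

const IN_SESSION = Ref(false)

function stop_sketch()
    IN_SESSION[] || error("stop called outside a training session")
    throw(StopTraining())
end

function train_loop(callback, max_epochs)
    IN_SESSION[] = true
    completed = 0
    try
        for epoch in 1:max_epochs
            completed = epoch      # one epoch of training would run here
            callback(epoch)
        end
    catch e
        e isa StopTraining || rethrow()
    finally
        IN_SESSION[] = false
    end
    return completed
end

# Early-stopping style callback: stop once epoch 3 has finished.
n = train_loop(epoch -> epoch >= 3 && stop_sketch(), 10)
```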

Photon.gradients (Function)

Utility function to calculate the gradients. Useful when checking for vanishing or exploding gradients. The returned value is a Vector of (Param, Gradient).
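
To see what checking for vanishing or exploding gradients looks like, the sketch below computes analytic gradients of a tiny linear model under mean squared error, returns them as a Vector of (parameter, gradient) tuples (using parameter names in place of parameter objects), and inspects their norms. gradients_sketch and the norm thresholds are hypothetical choices for illustration.

```julia
using LinearAlgebra: norm

# Hypothetical sketch: gradients of ŷ = W*x .+ b under loss mean((ŷ - y).^2),
# returned as a Vector of (parameter name, gradient) tuples.
function gradients_sketch(W, b, x, y)
    r = W * x .+ b .- y                     # residuals, size (out, batch)
    n = length(r)
    gW = (2 / n) .* (r * x')                # ∂loss/∂W
    gb = (2 / n) .* vec(sum(r; dims = 2))   # ∂loss/∂b
    return [(:W, gW), (:b, gb)]
end

W = [1.0 0.0; 0.0 1.0]
b = [0.0, 0.0]
x = [1.0 2.0; 3.0 4.0]
y = [2.0 3.0; 3.0 4.0]

# Flag suspiciously small or large gradient norms (thresholds are arbitrary).
for (name, g) in gradients_sketch(W, b, x, y)
    nrm = norm(g)
    status = nrm < 1e-6 ? "vanishing?" : nrm > 1e3 ? "exploding?" : "ok"
    println(name, ": norm = ", round(nrm; digits = 4), " (", status, ")")
end
```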


Internal

You normally won't have to invoke the following functions directly when training a model, but in some cases you might want to write a specialized version of them.

Photon.back! (Function)

Perform the backpropagation and the update of the weights in a single go.
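
Combining backpropagation with the weight update means computing one gradient per parameter array and immediately applying the optimizer step. The sketch below assumes plain SGD and takes a gradfn argument as a hypothetical stand-in for an automatic-differentiation call; back_sketch! is an illustrative name, not Photon's API.

```julia
# Hypothetical sketch of "backprop and update in a single go": obtain one
# gradient per parameter array, then apply a plain SGD update in place.
# `gradfn` stands in for an automatic-differentiation call.
function back_sketch!(params::Vector{<:AbstractArray}, gradfn, lr::Real)
    grads = gradfn(params)           # one gradient array per parameter
    for (p, g) in zip(params, grads)
        p .-= lr .* g                # w ← w - η ∇w
    end
    return params
end

# Loss L(w) = sum(w .^ 2) has gradient 2w.
w = [1.0, -2.0]
params = [w]
back_sketch!(params, ps -> [2 .* ps[1]], 0.25)
```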

Photon.step! (Function)

Take a single step in updating the weights of a model. This function is invoked from train! to do the actual learning.

For a minibatch (x,y) of data, the following sequence is executed:

  1. perform the forward pass
  2. calculate the loss
  3. update and remember the metrics, if any
  4. do the backpropagation and update the weights
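
The four steps above can be sketched for a model with a single weight, ŷ = w .* x, trained with mean squared error and plain SGD. TinyModel, step_sketch! and the history dictionary are hypothetical names, and the gradient is written out by hand where a real implementation would use automatic differentiation.

```julia
# Hypothetical sketch of one training step on a minibatch (x, y), following
# the four steps listed above. Names are illustrative, not Photon's API.
mutable struct TinyModel
    w::Float64
end
(m::TinyModel)(x) = m.w .* x

mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

function step_sketch!(model::TinyModel, x, y, history; lr = 0.1, metrics = NamedTuple())
    ŷ = model(x)                                    # 1. forward pass
    loss = mse(ŷ, y)                                # 2. calculate the loss
    push!(get!(history, :loss, Float64[]), loss)    # 3. remember loss and metrics
    for (name, metric) in pairs(metrics)
        push!(get!(history, name, Float64[]), metric(ŷ, y))
    end
    grad = 2 * sum((ŷ .- y) .* x) / length(y)       # 4. backpropagation (dloss/dw)
    model.w -= lr * grad                            #    and weight update
    return loss
end

m = TinyModel(0.0)
history = Dict{Symbol,Vector{Float64}}()
x = [1.0, 2.0]
y = [2.0, 4.0]
l1 = step_sketch!(m, x, y, history)
l2 = step_sketch!(m, x, y, history)
```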