SGDOptim

A Julia package for Stochastic Gradient Descent (SGD) and its variants.

With the advent of Big Data, Stochastic Gradient Descent (SGD) has become increasingly popular in recent years, especially in machine learning and related areas. This package implements the SGD algorithm and its variants under a generic setting to facilitate the use of SGD in practice.

Here is an example that demonstrates the use of this package in solving a ridge regression problem.
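
Below is a minimal sketch of such an example (adapted from the package's documented usage; treat the exact names and keyword arguments — `riskmodel`, `LinearPred`, `SqrLoss`, `SqrL2Reg`, `minibatch_seq`, `lrate`, `cbinterval`, `callback`, `simple_trace` — as assumptions to be checked against the current API):

```julia
using EmpiricalRisks
using SGDOptim

# generate a synthetic regression problem
d = 5                        # feature dimension
n = 10000                    # number of samples
w = randn(d)                 # ground-truth coefficients
X = randn(d, n)              # inputs, one sample per column
y = X'w + 0.1 * randn(n)     # noisy linear responses

# ridge regression = squared loss over a linear predictor,
# plus a squared L2 regularizer
theta = sgd(
    riskmodel(LinearPred(d), SqrLoss()),  # the risk model
    zeros(d),                             # initial guess of the solution
    minibatch_seq(X, y, 10),              # data as a stream of minibatches of size 10
    reg = SqrL2Reg(1.0e-4),               # the ridge regularizer
    lrate = t -> 1.0 / (100.0 + t),       # learning rate as a function of t
    cbinterval = 5,                       # invoke the callback every 5 iterations
    callback = simple_trace)              # print progress along the way
```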

Optimization Algorithms

This package depends on EmpiricalRisks.jl, which provides the basic components, including predictors, loss functions, and regularizers.

On top of that, we provide a variety of algorithms, including SGD and its variants for both streaming and distributed settings, and you may choose one that suits your needs.

Learning rate:

The setting of the learning rate has a significant impact on the algorithm's behavior. This package allows the learning-rate schedule to be supplied as a function of t via a keyword argument.

The default setting is `t -> 1.0 / (1.0 + t)`.
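
For example, a schedule that decays ten times more slowly than the default can be written as a plain anonymous function of `t` and passed through that keyword argument (named `lrate` in the sketch above, an assumption):

```julia
# decays ten times more slowly than the default t -> 1.0 / (1.0 + t)
slow_rate = t -> 1.0 / (1.0 + 0.1 * t)

slow_rate(0)     # => 1.0
slow_rate(90)    # => 0.1
```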

Key Functions

Streams

Unlike conventional methods, SGD and its variants look at a single sample or a small batch of samples at each iteration. In other words, data are viewed as a stream of samples or minibatches.

This package provides a variety of ways to construct data streams. Each data stream is essentially an iterator that implements the `start`, `done`, and `next` methods (see [here](http://julia.readthedocs.org/en/latest/stdlib/collections/#iteration) for details of Julia's iteration patterns). Each item from a data stream can be either a sample (as a pair of input and output) or a mini-batch (as a pair of multi-input array and multi-output array).
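
For illustration, here is how any such iterator is consumed with that protocol; a plain vector of (input, output) pairs stands in for a package-constructed stream (a generic sketch, not a package API):

```julia
# a vector of (input, output) pairs standing in for a sample stream
stream = [([1.0, 2.0], 3.0), ([4.0, 5.0], 6.0)]

state = start(stream)
while !done(stream, state)
    (item, state) = next(stream, state)   # fetch the next sample
    x, y = item                           # unpack input and output
    println("input = $x, output = $y")
end
```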

Note: All SGD algorithms in this package support both sample streams and mini-batch streams. At each iteration, the algorithm works on a single item from the stream, which can be either a sample or a mini-batch.

The package provides several methods to construct streams of samples or minibatches.
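
For instance, per-sample and minibatch streams over the same data may be constructed as follows (`sample_seq` and `minibatch_seq` are the constructor names used in the package's examples; verify them against the current API):

```julia
using SGDOptim

X = randn(5, 100)    # 100 samples, each with 5 features (one per column)
y = randn(100)       # 100 scalar outputs

# a stream that yields one (input, output) sample at a time
s1 = sample_seq(X, y)

# a stream that yields minibatches of 10 samples, i.e. pairs of a
# 5-by-10 input matrix and a length-10 output vector
s2 = minibatch_seq(X, y, 10)
```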

Callbacks

The algorithms provided in this package interoperate with the rest of the world through callbacks. In particular, callbacks allow a third party (e.g. a higher-level script, a user, or a GUI) to monitor the progress of the optimization and take appropriate actions.

Generally, a callback is an arbitrary function (or closure) that can be called in the following way:

`callback(theta, t, n, v)`

| params | descriptions |
|--------|--------------|
| `theta` | The current solution. |
| `t` | The number of elapsed iterations. |
| `n` | The number of samples that have been used. |
| `v` | The objective value of the last item, which can be the objective evaluated on a single sample or the total objective evaluated on the last batch of samples. |

The package already provides some callbacks for simple use.
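
Writing your own is also straightforward: any function matching the signature above will do. A minimal sketch:

```julia
# a custom callback following the callback(theta, t, n, v) convention:
# report elapsed iterations, samples used, and the current objective
function my_trace(theta, t, n, v)
    println("iter $t: $n samples used, objective = $v")
end
```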