scikit-chainer
A scikit-learn-like interface to Chainer
How to install
$ pip install scikit-chainer
What's this?
This is a scikit-learn-like interface to the Chainer deep learning framework.
You can use it to build your network model and then use the model with scikit-learn APIs (e.g. `fit`, `predict`).
There are three base classes: `ChainerRegresser` for regression, `ChainerClassifier` for classification and `ChainerTransformer` for transformation.
You need to inherit from one of them and implement the following methods:

- `_setup_network`: the network definition (a `FunctionSet`, `Chain` or `ChainList` of Chainer)
- `forward`: emit the result `z` from the input `x` (note that this is not the final predicted value)
- `loss_func`: the loss function to minimize (e.g. `mean_squared_error`, `softmax_cross_entropy`, etc.)
- `output_func`: emit the final result `y` from the forwarded values `z` (e.g. `identity` for regression and `softmax` for classification)
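The hooks above fit together roughly like this: `fit` adjusts the network's parameters to minimize `loss_func(forward(x), t)`, and `predict` returns `output_func(forward(x))`. The sketch below illustrates only that control flow, in plain Python and under stated assumptions: `SketchEstimator` and `ToyLinearRegression` are hypothetical names, not scikit-chainer classes; the "network" is a dict of scalar parameters; and a central-difference gradient with plain gradient descent stands in for Chainer's backpropagation and optimizer. It is not the library's actual implementation.

```python
class SketchEstimator:
    """Hypothetical base class: shows how the four hooks could drive fit/predict."""

    def __init__(self, n_epoch=500, lr=0.1, **params):
        self.n_epoch = n_epoch
        self.lr = lr
        # the subclass builds the "network"; here it is a dict of scalar parameters
        self.network = self._setup_network(**params)

    def _mean_loss(self, X, T):
        # loss_func is applied to the raw forward output z, not to output_func(z)
        return sum(self.loss_func(self.forward(x), t) for x, t in zip(X, T)) / len(X)

    def fit(self, X, T):
        eps = 1e-6
        for _ in range(self.n_epoch):
            # central-difference gradient: a stand-in for Chainer's backprop
            grads = {}
            for name in list(self.network):
                value = self.network[name]
                self.network[name] = value + eps
                up = self._mean_loss(X, T)
                self.network[name] = value - eps
                down = self._mean_loss(X, T)
                self.network[name] = value
                grads[name] = (up - down) / (2 * eps)
            # plain gradient-descent step on every parameter
            for name, g in grads.items():
                self.network[name] -= self.lr * g
        return self  # scikit-learn convention: fit returns the estimator

    def predict(self, X):
        # the final prediction passes through output_func
        return [self.output_func(self.forward(x)) for x in X]


class ToyLinearRegression(SketchEstimator):
    def _setup_network(self, **params):
        return {"w": 0.0, "b": 0.0}

    def forward(self, x):
        return self.network["w"] * x + self.network["b"]

    def loss_func(self, z, t):
        return (z - t) ** 2

    def output_func(self, z):
        return z  # identity, as for regression


model = ToyLinearRegression().fit([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0])
print(model.network)         # w close to 2.0, b close to 1.0
print(model.predict([4.0]))  # close to [9.0]
```

With `lr=0.1` and 500 epochs, the toy model recovers the line y = 2x + 1 from the four training points. A subclass only supplies the four hooks; the training loop and the `fit`/`predict` surface live in the base class.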
Example
Linear Regression
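```python
import chainer.functions as F
import chainer.links as L
from chainer import Chain
from skchainer import ChainerRegresser  # assumed package import name

class LinearRegression(ChainerRegresser):
    def _setup_network(self, **params):
        # a single fully connected layer: n_dim inputs -> 1 output
        return Chain(l1=L.Linear(params["n_dim"], 1))

    def forward(self, x):
        return self.network.l1(x)  # raw output z

    def loss_func(self, y, t):
        return F.mean_squared_error(y, t)

    def output_func(self, h):
        return F.identity(h)  # regression: the prediction is z itself
```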
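To make the shapes in this example concrete: a Chainer `Linear(n_dim, 1)` layer computes `x W^T + b` with `W` of shape `(1, n_dim)`, so `forward` returns one column per sample, and `F.mean_squared_error` averages the squared differences over all elements (it expects both arguments to have the same shape, so the targets should be shaped `(N, 1)` as well). The plain-Python sketch below mimics those two computations on nested lists; `linear` and `mean_squared_error` here are illustrative re-implementations, not Chainer's code, and the weights are made up.

```python
def linear(x, W, b):
    """y = x W^T + b for a batch x of shape (N, in); W has shape (out, in)."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + bj for w_row, bj in zip(W, b)]
        for row in x
    ]


def mean_squared_error(y, t):
    """Average of squared differences over every element of y and t."""
    diffs = [(a - c) ** 2 for y_row, t_row in zip(y, t) for a, c in zip(y_row, t_row)]
    return sum(diffs) / len(diffs)


# n_dim = 2 inputs, 1 output, as in Linear(params["n_dim"], 1)
W = [[2.0, -1.0]]
b = [0.5]
x = [[1.0, 1.0], [3.0, 2.0]]
z = linear(x, W, b)  # shape (2, 1): [[1.5], [4.5]]
t = [[1.5], [5.5]]
print(mean_squared_error(z, t))  # 0.5
```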
Logistic Regression
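```python
import chainer.functions as F
import chainer.links as L
from chainer import Chain
from skchainer import ChainerClassifier  # assumed package import name

class LogisticRegression(ChainerClassifier):
    def _setup_network(self, **params):
        # n_dim inputs -> one score per class
        return Chain(l1=L.Linear(params["n_dim"], params["n_class"]))

    def forward(self, x):
        return self.network.l1(x)  # unnormalized class scores z

    def loss_func(self, y, t):
        # takes the raw scores; the softmax is applied inside the loss
        return F.softmax_cross_entropy(y, t)

    def output_func(self, h):
        return F.softmax(h)  # class probabilities
```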
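The split between `loss_func` and `output_func` is why `forward` returns raw scores: `F.softmax_cross_entropy` takes the unnormalized scores `z` and integer class labels `t` and applies the softmax internally, while `output_func` applies `F.softmax` only to produce probabilities. The plain-Python sketch below re-implements both computations for illustration (not Chainer's implementations). `predict_labels` is a hypothetical helper that picks the most probable class; that is the usual scikit-learn convention for a classifier's `predict`, but it is assumed here, not confirmed for `ChainerClassifier`.

```python
import math


def softmax(row):
    """Class probabilities from one row of raw scores (shifted by max for stability)."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_cross_entropy(z, t):
    """Batch mean of -log softmax(z_i)[t_i], computed with log-sum-exp."""
    loss = 0.0
    for row, label in zip(z, t):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        loss += log_sum_exp - row[label]
    return loss / len(z)


def predict_labels(z):
    """Hypothetical helper: index of the most probable class for each sample."""
    probs = [softmax(row) for row in z]
    return [max(range(len(p)), key=p.__getitem__) for p in probs]


z = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.0]]  # raw scores for 2 samples, 3 classes
t = [1, 0]                               # true class labels
print(predict_labels(z))                 # [1, 0]
print(softmax_cross_entropy(z, t))       # small: both true labels get the top score
```

Working on the raw scores lets the loss use the numerically stable log-sum-exp form; applying `softmax` first and then taking a logarithm can underflow for large score gaps.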
AutoEncoder
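```python
import chainer.functions as F
import chainer.links as L
from chainer import Chain
from skchainer import ChainerTransformer  # assumed package import name

class AutoEncoder(ChainerTransformer):
    def __init__(self, activation=F.relu, **params):
        # pass AutoEncoder (not the base class) so ChainerTransformer.__init__ runs
        super(AutoEncoder, self).__init__(**params)
        self.activation = activation

    def _setup_network(self, **params):
        return Chain(
            encoder=L.Linear(params["input_dim"], params["hidden_dim"]),
            decoder=L.Linear(params["hidden_dim"], params["input_dim"])
        )

    def _forward(self, x, train=False):
        # encode, then decode back to input_dim to reconstruct x
        z = self._transform(x, train)
        y = self.network.decoder(z)
        return y

    def _loss_func(self, y, t):
        # reconstruction error between the decoder output and the target
        return F.mean_squared_error(y, t)

    def _transform(self, x, train=False):
        # encoded (hidden) representation of x
        return self.activation(self.network.encoder(x))
```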
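For intuition about what the encoder and decoder compute, here is a plain-Python sketch of one sample passing through the same structure: the transform step is the encoder followed by the activation (ReLU by default), the forward step decodes that code back to `input_dim`, and the loss measures the reconstruction error. `TinyAutoEncoder`, its hand-picked weights, and the `affine`/`relu` helpers are hypothetical and exist only for illustration; the real model learns its weights through Chainer.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def affine(x, W, b):
    """W has one row per output unit: y_j = sum_i W[j][i] * x[i] + b[j]."""
    return [sum(w * xi for w, xi in zip(row, x)) + bj for row, bj in zip(W, b)]


class TinyAutoEncoder:
    def __init__(self, W_enc, b_enc, W_dec, b_dec, activation=relu):
        self.W_enc, self.b_enc = W_enc, b_enc  # input_dim -> hidden_dim
        self.W_dec, self.b_dec = W_dec, b_dec  # hidden_dim -> input_dim
        self.activation = activation

    def transform(self, x):
        # hidden code: activation(encoder(x))
        return self.activation(affine(x, self.W_enc, self.b_enc))

    def reconstruct(self, x):
        # decoder applied to the hidden code
        return affine(self.transform(x), self.W_dec, self.b_dec)

    def reconstruction_error(self, x):
        # mean squared error between the reconstruction and the input itself
        y = self.reconstruct(x)
        return sum((a - c) ** 2 for a, c in zip(y, x)) / len(x)


# input_dim = 3, hidden_dim = 2
ae = TinyAutoEncoder(
    W_enc=[[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]], b_enc=[0.0, 0.0],
    W_dec=[[1.0, 0.0], [0.0, 1.0], [0.0, -1.0]], b_dec=[0.0, 0.0, 0.0],
)
x = [2.0, 1.0, 3.0]
print(ae.transform(x))             # [2.0, 0.0]: the second unit is cut off by ReLU
print(ae.reconstruct(x))           # [2.0, 0.0, 0.0]
print(ae.reconstruction_error(x))  # (0 + 1 + 9) / 3
```

Because `hidden_dim` is smaller than `input_dim`, the hidden code cannot store every input exactly; training pushes the encoder and decoder toward a code that keeps the reconstruction error low.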