Home

Awesome

Go Report Card License

This is an old version of dnn for Go. Please see the new, updated version: https://github.com/go-ml-dev/nn

import (
	"github.com/sudachen/go-dnn/data/mnist"
	"github.com/sudachen/go-dnn/fu"
	"github.com/sudachen/go-dnn/mx"
	"github.com/sudachen/go-dnn/ng"
	"github.com/sudachen/go-dnn/nn"
	"gotest.tools/assert"
	"testing"
	"time"
)

// mnistConv0 is the network used by Test_mnistConv0: two convolution +
// max-pooling stages followed by two fully connected layers, the last of
// which has 10 outputs with a softmax activation.
var mnistConv0 = nn.Connect(
	// Stage 1: 3x3 convolution with 24 channels and ReLU, then 2x2 max pooling.
	&nn.Convolution{
		Channels:   24,
		Kernel:     mx.Dim(3, 3),
		Activation: nn.ReLU,
	},
	&nn.MaxPool{
		Kernel: mx.Dim(2, 2),
		Stride: mx.Dim(2, 2),
	},
	// Stage 2: 5x5 convolution with 32 channels, batch norm and ReLU, then 2x2 max pooling.
	&nn.Convolution{
		Channels:   32,
		Kernel:     mx.Dim(5, 5),
		Activation: nn.ReLU,
		BatchNorm:  true,
	},
	&nn.MaxPool{
		Kernel: mx.Dim(2, 2),
		Stride: mx.Dim(2, 2),
	},
	// Head: a 32-unit hidden layer (Swish, batch norm) feeding a 10-unit softmax output.
	&nn.FullyConnected{
		Size:       32,
		Activation: nn.Swish,
		BatchNorm:  true,
	},
	&nn.FullyConnected{
		Size:       10,
		Activation: nn.Softmax,
	},
)

// Test_mnistConv0 trains mnistConv0 on the MNIST dataset, saves the resulting
// parameters to a cache file, then binds a fresh network with a different
// batch size, reloads the saved parameters into it and measures it against the
// same accuracy threshold.
func Test_mnistConv0(t *testing.T) {
	const (
		// wantAccuracy is the threshold used for training, for the direct
		// assertion on the training result, and for the final measurement.
		wantAccuracy = 0.98
		// paramsFile is the cache-relative path the trained parameters are
		// written to and later read back from.
		paramsFile = "tests/mnistConv0.params"
	)

	gym := &ng.Gym{
		Optimizer: &nn.Adam{Lr: .001},
		Loss:      &nn.LabelCrossEntropyLoss{},
		Input:     mx.Dim(32, 1, 28, 28),
		Epochs:    5,
		Verbose:   ng.Printing,
		Every:     1 * time.Second,
		Dataset:   &mnist.Dataset{},
		Metric:    &ng.Classification{Accuracy: wantAccuracy},
		Seed:      42,
	}

	// Train on CPU and persist the learned parameters.
	acc, params, err := gym.Train(mx.CPU, mnistConv0)
	assert.NilError(t, err)
	assert.Assert(t, acc >= wantAccuracy)
	err = params.Save(fu.CacheFile(paramsFile))
	assert.NilError(t, err)

	// Bind the same network definition with an input of (10, 1, 28, 28)
	// instead of the (32, 1, 28, 28) used for training, then load the
	// parameters saved above.
	net, err := nn.Bind(mx.CPU, mnistConv0, mx.Dim(10, 1, 28, 28), nil)
	assert.NilError(t, err)
	err = net.LoadParamsFile(fu.CacheFile(paramsFile), false)
	assert.NilError(t, err)
	// The summary is informational output only; its result is deliberately ignored.
	_ = net.PrintSummary(false)

	// Check the measurement error before trusting ok, so a failure to
	// measure is reported with its cause instead of as a bare false.
	ok, err := ng.Measure(net, &mnist.Dataset{}, &ng.Classification{Accuracy: wantAccuracy}, ng.Printing)
	assert.NilError(t, err)
	assert.Assert(t, ok)
}
Network Identity: 158cf5bd604e12e7bd438084e135703bd89dc10f
Symbol              | Operation            | Output        |  Params #
----------------------------------------------------------------------
_input              | null                 | (32,1,28,28)  |         0
Convolution01       | Convolution((3,3)//) | (32,24,26,26) |       240
Convolution01$A     | Activation(relu)     | (32,24,26,26) |         0
MaxPool02           | Pooling(max)         | (32,24,13,13) |         0
Convolution03       | Convolution((5,5)//) | (32,32,9,9)   |     19232
Convolution03$BN    | BatchNorm            | (32,32,9,9)   |       128
Convolution03$A     | Activation(relu)     | (32,32,9,9)   |         0
MaxPool04           | Pooling(max)         | (32,32,4,4)   |         0
FullyConnected05    | FullyConnected       | (32,32)       |     16416
FullyConnected05$BN | BatchNorm            | (32,32)       |       128
sigmoid@sym07       | sigmoid              | (32,32)       |         0
FullyConnected05$A  | elemwise_mul         | (32,32)       |         0
FullyConnected06    | FullyConnected       | (32,10)       |       330
FullyConnected06$A  | SoftmaxActivation()  | (32,10)       |         0
BlockGrad@sym08     | BlockGrad            | (32,10)       |         0
make_loss@sym09     | make_loss            | (32,10)       |         0
pick@sym10          | pick                 | (32,1)        |         0
log@sym11           | log                  | (32,1)        |         0
_mul_scalar@sym12   | _mul_scalar          | (32,1)        |         0
mean@sym13          | mean                 | (1)           |         0
make_loss@sym14     | make_loss            | (1)           |         0
----------------------------------------------------------------------
Total params: 36474
[000] batch: 389, loss: 0.09991227
[000] batch: 1074, loss: 0.055281825
[000] batch: 1855, loss: 0.0760978
[000] metric: 0.988, final loss: 0.0515
Achieved reqired metric
Symbol              | Operation            | Output        |  Params #
----------------------------------------------------------------------
_input              | null                 | (10,1,28,28)  |         0
Convolution01       | Convolution((3,3)//) | (10,24,26,26) |       240
Convolution01$A     | Activation(relu)     | (10,24,26,26) |         0
MaxPool02           | Pooling(max)         | (10,24,13,13) |         0
Convolution03       | Convolution((5,5)//) | (10,32,9,9)   |     19232
Convolution03$BN    | BatchNorm            | (10,32,9,9)   |       128
Convolution03$A     | Activation(relu)     | (10,32,9,9)   |         0
MaxPool04           | Pooling(max)         | (10,32,4,4)   |         0
FullyConnected05    | FullyConnected       | (10,32)       |     16416
FullyConnected05$BN | BatchNorm            | (10,32)       |       128
sigmoid@sym07       | sigmoid              | (10,32)       |         0
FullyConnected05$A  | elemwise_mul         | (10,32)       |         0
FullyConnected06    | FullyConnected       | (10,10)       |       330
FullyConnected06$A  | SoftmaxActivation()  | (10,10)       |         0
----------------------------------------------------------------------
Total params: 36474
Accuracy over 1000*10 batchs: 0.988
--- PASS: Test_mnistConv0 (6.51s)