-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_mlp.go
89 lines (70 loc) · 1.51 KB
/
model_mlp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package whale
import (
"encoding/gob"
"math"
"os"
"github.com/hidetatz/whale/tensor"
)
// MLP is a multi-layer perceptron: a stack of linear layers with an
// activation applied between them (see Train for the forward pass).
type MLP struct {
// Weights holds one weight Variable per layer, in forward order.
Weights []*Variable
// Biases holds one bias Variable per layer. It stays nil when the
// model was constructed without biases (see NewMLP's bias flag).
Biases []*Variable
// Loss is the loss calculator exposed via LossFn.
Loss LossCalculator
// Optim is the optimizer exposed via Optimizer.
Optim Optimizer
// Activation is applied after every layer except the last.
Activation Activation
}
// NewMLP constructs a multi-layer perceptron. Each element of layers is
// a size pair — index 0 is the layer's input width, index 1 its output
// width — and produces one weight matrix (and, when bias is true, one
// zero-initialized bias vector). Weights are drawn from a normal
// distribution and scaled down by 1/sqrt(output width).
func NewMLP(layers [][]int, bias bool, act Activation, loss LossCalculator, optim Optimizer) *MLP {
	m := &MLP{Activation: act, Loss: loss, Optim: optim}
	for _, layer := range layers {
		in, out := layer[0], layer[1]
		// NOTE(review): the scale uses the output width (layer[1]);
		// Xavier-style init normally scales by 1/sqrt(fan_in), i.e.
		// layer[0] — confirm this is intentional.
		scale := float32(math.Sqrt(float64(1.0 / float32(out))))
		w := NewVar(tensor.RandNorm(in, out).Mul(tensor.Scalar(scale)))
		m.Weights = append(m.Weights, w)
		if bias {
			m.Biases = append(m.Biases, NewVar(tensor.Zeros(out)))
		}
	}
	return m
}
// Train runs the forward pass: each layer applies Linear(x, W, b),
// followed by the configured activation on every layer except the last.
// The last layer's linear output is returned untouched. When the model
// has no weights, Train returns (nil, nil).
func (m *MLP) Train(in *Variable) (*Variable, error) {
	out := in
	last := len(m.Weights) - 1
	for i, w := range m.Weights {
		// Biases is nil when the model was built without bias terms;
		// Linear receives a nil bias in that case.
		var b *Variable
		if m.Biases != nil {
			b = m.Biases[i]
		}
		var err error
		out, err = Linear(out, w, b)
		if err != nil {
			return nil, err
		}
		// The final layer skips the activation.
		if i == last {
			return out, nil
		}
		out, err = m.Activation.Activate(out)
		if err != nil {
			return nil, err
		}
	}
	return nil, nil
}
// LossFn returns the loss calculator this model was configured with.
func (m *MLP) LossFn() LossCalculator {
return m.Loss
}
// Optimizer returns the optimizer this model was configured with.
func (m *MLP) Optimizer() Optimizer {
return m.Optim
}
// Params returns all trainable parameters (weights followed by biases)
// as a freshly allocated slice.
//
// The original implementation returned append(m.Weights, m.Biases...),
// which can write the bias elements into spare capacity of m.Weights'
// backing array — silently corrupting the model's own Weights slice.
// Copying into a new slice removes that aliasing hazard.
func (m *MLP) Params() []*Variable {
	params := make([]*Variable, 0, len(m.Weights)+len(m.Biases))
	params = append(params, m.Weights...)
	params = append(params, m.Biases...)
	return params
}
// SaveGobFile gob-encodes the model and writes it to filename,
// creating or truncating the file.
//
// The original implementation never closed the created file, leaking
// the descriptor and dropping any error surfaced at close time (which
// on many systems is where write failures are reported). The file is
// now closed on every path and the close error is propagated.
func (m *MLP) SaveGobFile(filename string) error {
	f, err := os.Create(filename)
	if err != nil {
		return err
	}
	if err := gob.NewEncoder(f).Encode(m); err != nil {
		// Best-effort close; the encode failure is the primary error.
		f.Close()
		return err
	}
	return f.Close()
}