5.2. 参数管理¶ Open the notebook in SageMaker Studio Lab
在选择了架构并设置了超参数后,我们就进入了训练阶段。 此时,我们的目标是找到使损失函数最小化的模型参数值。 经过训练后,我们将需要使用这些参数来做出未来的预测。 此外,有时我们希望提取参数,以便在其他环境中复用它们, 将模型保存下来,以便它可以在其他软件中执行, 或者为了获得科学的理解而进行检查。
之前的介绍中,我们只依靠深度学习框架来完成训练的工作, 而忽略了操作参数的具体细节。 本节,我们将介绍以下内容:
访问参数,用于调试、诊断和可视化;
参数初始化;
在不同模型组件间共享参数。
我们首先看一下具有单隐藏层的多层感知机。
from mxnet import init, np, npx
from mxnet.gluon import nn
npx.set_np()
net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize() # 使用默认初始化方法
X = np.random.uniform(size=(2, 4))
net(X) # 正向传播
array([[0.0054572 ],
[0.00488594]])
import torch
from torch import nn
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(size=(2, 4))
net(X)
tensor([[-0.0619],
[-0.0489]], grad_fn=<AddmmBackward0>)
import tensorflow as tf
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(4, activation=tf.nn.relu),
tf.keras.layers.Dense(1),
])
X = tf.random.uniform((2, 4))
net(X)
<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[-0.37380144],
[-0.60269785]], dtype=float32)>
import warnings
warnings.filterwarnings(action='ignore')
import paddle
from paddle import nn
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = paddle.rand([2, 4])
net(X)
Tensor(shape=[2, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[0.08741230],
[0.30509907]])
5.2.1. 参数访问¶
我们从已有模型中访问参数。 当通过Sequential
类定义模型时,
我们可以通过索引来访问模型的任意层。
这就像模型是一个列表一样,每层的参数都在其属性中。
如下所示,我们可以检查第二个全连接层的参数。
print(net[1].params)
dense1_ (
Parameter dense1_weight (shape=(1, 8), dtype=float32)
Parameter dense1_bias (shape=(1,), dtype=float32)
)
print(net[2].state_dict())
OrderedDict([('weight', tensor([[ 0.3016, -0.1901, -0.1991, -0.1220, 0.1121, -0.1424, -0.3060, 0.3400]])), ('bias', tensor([-0.0291]))])
print(net.layers[2].weights)
[<tf.Variable 'dense_1/kernel:0' shape=(4, 1) dtype=float32, numpy=
array([[-0.3281002 ],
[-0.54713833],
[-0.59404033],
[-0.5690916 ]], dtype=float32)>, <tf.Variable 'dense_1/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]
print(net[2].state_dict())
OrderedDict([('weight', Parameter containing:
Tensor(shape=[8, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[-0.25485069],
[-0.04266340],
[-0.79025936],
[-0.75556862],
[-0.26889783],
[ 0.11903048],
[ 0.22622812],
[ 0.58539867]])), ('bias', Parameter containing:
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False,
[0.]))])
输出的结果告诉我们一些重要的事情: 首先,这个全连接层包含两个参数,分别是该层的权重和偏置。 两者都存储为单精度浮点数(float32)。 注意,参数名称允许唯一标识每个参数,即使在包含数百个层的网络中也是如此。
5.2.1.1. 目标参数¶
注意,每个参数都表示为参数类的一个实例。 要对参数执行任何操作,首先我们需要访问底层的数值。 有几种方法可以做到这一点。有些比较简单,而另一些则比较通用。 下面的代码从第二个全连接层(即第三个神经网络层)提取偏置, 提取后返回的是一个参数类实例,并进一步访问该参数的值。
print(type(net[1].bias))
print(net[1].bias)
print(net[1].bias.data())
<class 'mxnet.gluon.parameter.Parameter'>
Parameter dense1_bias (shape=(1,), dtype=float32)
[0.]
参数是复合的对象,包含值、梯度和额外信息。 这就是我们需要显式参数值的原因。 除了值之外,我们还可以访问每个参数的梯度。 在上面这个网络中,由于我们还没有调用反向传播,所以参数的梯度处于初始状态。
net[1].weight.grad()
array([[0., 0., 0., 0., 0., 0., 0., 0.]])
print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.data)
<class 'torch.nn.parameter.Parameter'>
Parameter containing:
tensor([-0.0291], requires_grad=True)
tensor([-0.0291])
参数是复合的对象,包含值、梯度和额外信息。 这就是我们需要显式参数值的原因。 除了值之外,我们还可以访问每个参数的梯度。 在上面这个网络中,由于我们还没有调用反向传播,所以参数的梯度处于初始状态。
net[2].weight.grad == None
True
print(type(net.layers[2].weights[1]))
print(net.layers[2].weights[1])
print(tf.convert_to_tensor(net.layers[2].weights[1]))
<class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>
<tf.Variable 'dense_1/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>
tf.Tensor([0.], shape=(1,), dtype=float32)
print(type(net[2].bias))
print(net[2].bias)
print(net[2].bias.value)
<class 'paddle.fluid.framework.ParamBase'>
Parameter containing:
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False,
[0.])
<bound method PyCapsule.value of Parameter containing:
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False,
[0.])>
参数是复合的对象,包含值、梯度和额外信息。 这就是我们需要显式参数值的原因。 除了值之外,我们还可以访问每个参数的梯度。 在上面这个网络中,由于我们还没有调用反向传播,所以参数的梯度处于初始状态。
net[2].weight.grad == None
True
5.2.1.2. 一次性访问所有参数¶
当我们需要对所有参数执行操作时,逐个访问它们可能会很麻烦。 当我们处理更复杂的块(例如,嵌套块)时,情况可能会变得特别复杂, 因为我们需要递归整个树来提取每个子块的参数。 下面,我们将通过演示来比较访问第一个全连接层的参数和访问所有层。
print(net[0].collect_params())
print(net.collect_params())
dense0_ (
Parameter dense0_weight (shape=(8, 4), dtype=float32)
Parameter dense0_bias (shape=(8,), dtype=float32)
)
sequential0_ (
Parameter dense0_weight (shape=(8, 4), dtype=float32)
Parameter dense0_bias (shape=(8,), dtype=float32)
Parameter dense1_weight (shape=(1, 8), dtype=float32)
Parameter dense1_bias (shape=(1,), dtype=float32)
)
print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print(*[(name, param.shape) for name, param in net.named_parameters()])
('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))
('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))
print(net.layers[1].weights)
print(net.get_weights())
[<tf.Variable 'dense/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[-0.0641762 , -0.10979968, -0.00594735, -0.36962172],
[-0.7772793 , 0.01906997, 0.79147226, 0.21822304],
[-0.19784456, -0.6576476 , 0.11548519, -0.6094498 ],
[-0.8143069 , 0.3029465 , 0.24064404, -0.3055349 ]],
dtype=float32)>, <tf.Variable 'dense/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>]
[array([[-0.0641762 , -0.10979968, -0.00594735, -0.36962172],
[-0.7772793 , 0.01906997, 0.79147226, 0.21822304],
[-0.19784456, -0.6576476 , 0.11548519, -0.6094498 ],
[-0.8143069 , 0.3029465 , 0.24064404, -0.3055349 ]],
dtype=float32), array([0., 0., 0., 0.], dtype=float32), array([[-0.3281002 ],
[-0.54713833],
[-0.59404033],
[-0.5690916 ]], dtype=float32), array([0.], dtype=float32)]
print(*[(name, param.shape) for name, param in net[0].named_parameters()])
print(*[(name, param.shape) for name, param in net.named_parameters()])
('weight', [4, 8]) ('bias', [8])
('0.weight', [4, 8]) ('0.bias', [8]) ('2.weight', [8, 1]) ('2.bias', [1])
这为我们提供了另一种访问网络参数的方式,如下所示。
net.collect_params()['dense1_bias'].data()
array([0.])
net.state_dict()['2.bias'].data
tensor([-0.0291])
net.get_weights()[1]
array([0., 0., 0., 0.], dtype=float32)
net.state_dict()['2.bias']
Parameter containing:
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False,
[0.])
5.2.1.3. 从嵌套块收集参数¶
让我们看看,如果我们将多个块相互嵌套,参数命名约定是如何工作的。 我们首先定义一个生成块的函数(可以说是“块工厂”),然后将这些块组合到更大的块中。
def block1():
net = nn.Sequential()
net.add(nn.Dense(32, activation='relu'))
net.add(nn.Dense(16, activation='relu'))
return net
def block2():
net = nn.Sequential()
for _ in range(4):
# 在这里嵌套
net.add(block1())
return net
rgnet = nn.Sequential()
rgnet.add(block2())
rgnet.add(nn.Dense(10))
rgnet.initialize()
rgnet(X)
array([[-6.3465846e-09, -1.1096752e-09, 6.4161787e-09, 6.6354140e-09,
-1.1265507e-09, 1.3284951e-10, 9.3619388e-09, 3.2229084e-09,
5.9429879e-09, 8.8181435e-09],
[-8.6219423e-09, -7.5150686e-10, 8.3133251e-09, 8.9321128e-09,
-1.6740003e-09, 3.2405989e-10, 1.2115976e-08, 4.4926449e-09,
8.0741742e-09, 1.2075874e-08]])
def block1():
return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),
nn.Linear(8, 4), nn.ReLU())
def block2():
net = nn.Sequential()
for i in range(4):
# 在这里嵌套
net.add_module(f'block {i}', block1())
return net
rgnet = nn.Sequential(block2(), nn.Linear(4, 1))
rgnet(X)
tensor([[-0.3078],
[-0.3078]], grad_fn=<AddmmBackward0>)
def block1(name):
return tf.keras.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(4, activation=tf.nn.relu)],
name=name)
def block2():
net = tf.keras.Sequential()
for i in range(4):
# 在这里嵌套
net.add(block1(name=f'block-{i}'))
return net
rgnet = tf.keras.Sequential()
rgnet.add(block2())
rgnet.add(tf.keras.layers.Dense(1))
rgnet(X)
<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[0.],
[0.]], dtype=float32)>
def block1():
return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),
nn.Linear(8, 4), nn.ReLU())
def block2():
net = nn.Sequential()
for i in range(4):
# 在这里嵌套
net.add_sublayer(f'block {i}', block1())
return net
rgnet = nn.Sequential(block2(), nn.Linear(4, 1))
rgnet(X)
Tensor(shape=[2, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[-0.00040114],
[-0.00059529]])
设计了网络后,我们看看它是如何工作的。
print(rgnet.collect_params)
print(rgnet.collect_params())
<bound method Block.collect_params of Sequential(
(0): Sequential(
(0): Sequential(
(0): Dense(4 -> 32, Activation(relu))
(1): Dense(32 -> 16, Activation(relu))
)
(1): Sequential(
(0): Dense(16 -> 32, Activation(relu))
(1): Dense(32 -> 16, Activation(relu))
)
(2): Sequential(
(0): Dense(16 -> 32, Activation(relu))
(1): Dense(32 -> 16, Activation(relu))
)
(3): Sequential(
(0): Dense(16 -> 32, Activation(relu))
(1): Dense(32 -> 16, Activation(relu))
)
)
(1): Dense(16 -> 10, linear)
)>
sequential1_ (
Parameter dense2_weight (shape=(32, 4), dtype=float32)
Parameter dense2_bias (shape=(32,), dtype=float32)
Parameter dense3_weight (shape=(16, 32), dtype=float32)
Parameter dense3_bias (shape=(16,), dtype=float32)
Parameter dense4_weight (shape=(32, 16), dtype=float32)
Parameter dense4_bias (shape=(32,), dtype=float32)
Parameter dense5_weight (shape=(16, 32), dtype=float32)
Parameter dense5_bias (shape=(16,), dtype=float32)
Parameter dense6_weight (shape=(32, 16), dtype=float32)
Parameter dense6_bias (shape=(32,), dtype=float32)
Parameter dense7_weight (shape=(16, 32), dtype=float32)
Parameter dense7_bias (shape=(16,), dtype=float32)
Parameter dense8_weight (shape=(32, 16), dtype=float32)
Parameter dense8_bias (shape=(32,), dtype=float32)
Parameter dense9_weight (shape=(16, 32), dtype=float32)
Parameter dense9_bias (shape=(16,), dtype=float32)
Parameter dense10_weight (shape=(10, 16), dtype=float32)
Parameter dense10_bias (shape=(10,), dtype=float32)
)
print(rgnet)
Sequential(
(0): Sequential(
(block 0): Sequential(
(0): Linear(in_features=4, out_features=8, bias=True)
(1): ReLU()
(2): Linear(in_features=8, out_features=4, bias=True)
(3): ReLU()
)
(block 1): Sequential(
(0): Linear(in_features=4, out_features=8, bias=True)
(1): ReLU()
(2): Linear(in_features=8, out_features=4, bias=True)
(3): ReLU()
)
(block 2): Sequential(
(0): Linear(in_features=4, out_features=8, bias=True)
(1): ReLU()
(2): Linear(in_features=8, out_features=4, bias=True)
(3): ReLU()
)
(block 3): Sequential(
(0): Linear(in_features=4, out_features=8, bias=True)
(1): ReLU()
(2): Linear(in_features=8, out_features=4, bias=True)
(3): ReLU()
)
)
(1): Linear(in_features=4, out_features=1, bias=True)
)
print(rgnet.summary())
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential_2 (Sequential) (2, 4) 80
dense_6 (Dense) (2, 1) 5
=================================================================
Total params: 85
Trainable params: 85
Non-trainable params: 0
_________________________________________________________________
None
print(rgnet)
Sequential(
(0): Sequential(
(block 0): Sequential(
(0): Linear(in_features=4, out_features=8, dtype=float32)
(1): ReLU()
(2): Linear(in_features=8, out_features=4, dtype=float32)
(3): ReLU()
)
(block 1): Sequential(
(0): Linear(in_features=4, out_features=8, dtype=float32)
(1): ReLU()
(2): Linear(in_features=8, out_features=4, dtype=float32)
(3): ReLU()
)
(block 2): Sequential(
(0): Linear(in_features=4, out_features=8, dtype=float32)
(1): ReLU()
(2): Linear(in_features=8, out_features=4, dtype=float32)
(3): ReLU()
)
(block 3): Sequential(
(0): Linear(in_features=4, out_features=8, dtype=float32)
(1): ReLU()
(2): Linear(in_features=8, out_features=4, dtype=float32)
(3): ReLU()
)
)
(1): Linear(in_features=4, out_features=1, dtype=float32)
)
因为层是分层嵌套的,所以我们也可以像通过嵌套列表索引一样访问它们。 下面,我们访问第一个主要的块中、第二个子块的第一层的偏置项。
rgnet[0][1][0].bias.data()
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
rgnet[0][1][0].bias.data
tensor([-0.2539, 0.4913, 0.3029, -0.4799, 0.2022, 0.3146, 0.0601, 0.3757])
rgnet.layers[0].layers[1].layers[1].weights[1]
<tf.Variable 'dense_3/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>
print(rgnet[0].state_dict()['block 0.0.bias'])
Parameter containing:
Tensor(shape=[8], dtype=float32, place=Place(cpu), stop_gradient=False,
[0., 0., 0., 0., 0., 0., 0., 0.])
5.2.2. 参数初始化¶
知道了如何访问参数后,现在我们看看如何正确地初始化参数。 我们在 4.8节中讨论了良好初始化的必要性。 深度学习框架提供默认随机初始化, 也允许我们创建自定义初始化方法, 满足我们通过其他规则实现初始化权重。
默认情况下,MXNet通过初始化权重参数的方法是
从均匀分布\(U(-0.07, 0.07)\)中随机采样权重,并将偏置参数设置为0。
MXNet的init
模块提供了多种预置初始化方法。
默认情况下,PyTorch会根据一个范围均匀地初始化权重和偏置矩阵,
这个范围是根据输入和输出维度计算出的。
PyTorch的nn.init
模块提供了多种预置初始化方法。
默认情况下,Keras会根据一个范围均匀地初始化权重矩阵,
这个范围是根据输入和输出维度计算出的。 偏置参数设置为0。
TensorFlow在根模块和keras.initializers
模块中提供了各种初始化方法。
默认情况下,PaddlePaddle会使用Xavier初始化权重矩阵, 偏置参数设置为0。
PaddlePaddle的nn.initializer
模块提供了多种预置初始化方法。
5.2.2.1. 内置初始化¶
让我们首先调用内置的初始化器。 下面的代码将所有权重参数初始化为标准差为0.01的高斯随机变量, 且将偏置参数设置为0。
# 这里的force_reinit确保参数会被重新初始化,不论之前是否已经被初始化
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]
array([-0.00324057, -0.00895028, -0.00698632, 0.01030831])
def init_normal(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight, mean=0, std=0.01)
nn.init.zeros_(m.bias)
net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([-0.0128, -0.0141, 0.0062, 0.0028]), tensor(0.))
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4, activation=tf.nn.relu,
kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.01),
bias_initializer=tf.zeros_initializer()),
tf.keras.layers.Dense(1)])
net(X)
net.weights[0], net.weights[1]
(<tf.Variable 'dense_7/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[-0.01661727, 0.0003547 , 0.01252732, 0.00514138],
[-0.01230722, -0.01045955, -0.01211653, -0.00016097],
[ 0.00710331, 0.0097906 , -0.00869265, -0.00833112],
[-0.00381531, 0.00480495, -0.00317094, -0.00495612]],
dtype=float32)>,
<tf.Variable 'dense_7/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>)
def init_normal(m):
if type(m) == nn.Linear:
paddle.nn.initializer.Normal(mean=0.0, std=0.01)
paddle.zeros(m.bias)
net.apply(init_normal)
net[0].weight[0],net[0].state_dict()['bias']
(Tensor(shape=[8], dtype=float32, place=Place(cpu), stop_gradient=False,
[-0.37556022, -0.56956506, 0.51928586, -0.62428892, -0.07560658,
-0.19561028, 0.45752531, -0.30620268]),
Parameter containing:
Tensor(shape=[8], dtype=float32, place=Place(cpu), stop_gradient=False,
[0., 0., 0., 0., 0., 0., 0., 0.]))
我们还可以将所有参数初始化为给定的常数,比如初始化为1。
net.initialize(init=init.Constant(1), force_reinit=True)
net[0].weight.data()[0]
array([1., 1., 1., 1.])
def init_constant(m):
if type(m) == nn.Linear:
nn.init.constant_(m.weight, 1)
nn.init.zeros_(m.bias)
net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([1., 1., 1., 1.]), tensor(0.))
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4, activation=tf.nn.relu,
kernel_initializer=tf.keras.initializers.Constant(1),
bias_initializer=tf.zeros_initializer()),
tf.keras.layers.Dense(1),
])
net(X)
net.weights[0], net.weights[1]
(<tf.Variable 'dense_9/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32)>,
<tf.Variable 'dense_9/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>)
def init_constant(m):
if type(m) == nn.Linear:
paddle.nn.initializer.Constant(value = 1)
paddle.zeros(m.bias)
net.apply(init_constant)
net[0].weight[0],net[0].state_dict()['bias']
(Tensor(shape=[8], dtype=float32, place=Place(cpu), stop_gradient=False,
[-0.37556022, -0.56956506, 0.51928586, -0.62428892, -0.07560658,
-0.19561028, 0.45752531, -0.30620268]),
Parameter containing:
Tensor(shape=[8], dtype=float32, place=Place(cpu), stop_gradient=False,
[0., 0., 0., 0., 0., 0., 0., 0.]))
我们还可以对某些块应用不同的初始化方法。 例如,下面我们使用Xavier初始化方法初始化第一个神经网络层, 然后将第三个神经网络层初始化为常量值42。
net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
net[1].initialize(init=init.Constant(42), force_reinit=True)
print(net[0].weight.data()[0])
print(net[1].weight.data())
[-0.17594433 0.02314097 -0.1992535 0.09509248]
[[42. 42. 42. 42. 42. 42. 42. 42.]]
def init_xavier(m):
if type(m) == nn.Linear:
nn.init.xavier_uniform_(m.weight)
def init_42(m):
if type(m) == nn.Linear:
nn.init.constant_(m.weight, 42)
net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)
tensor([ 0.3809, 0.5354, -0.4686, -0.2376])
tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4,
activation=tf.nn.relu,
kernel_initializer=tf.keras.initializers.GlorotUniform()),
tf.keras.layers.Dense(
1, kernel_initializer=tf.keras.initializers.Constant(1)),
])
net(X)
print(net.layers[1].weights[0])
print(net.layers[2].weights[0])
<tf.Variable 'dense_11/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[-0.8453125 , 0.5705376 , -0.4807635 , -0.32239312],
[-0.78933173, 0.13476151, 0.72802204, 0.34821087],
[-0.49852088, -0.67133677, 0.86592776, -0.383544 ],
[-0.4393313 , -0.17272401, -0.8350105 , -0.5326498 ]],
dtype=float32)>
<tf.Variable 'dense_12/kernel:0' shape=(4, 1) dtype=float32, numpy=
array([[1.],
[1.],
[1.],
[1.]], dtype=float32)>
def xavier(m):
if type(m) == nn.Linear:
paddle.nn.initializer.XavierUniform(m.weight)
def init_42(m):
if type(m) == nn.Linear:
paddle.nn.initializer.Constant(42)
net[0].apply(xavier)
net[2].apply(init_42)
print(net[0].weight[0])
print(net[2].weight)
Tensor(shape=[8], dtype=float32, place=Place(cpu), stop_gradient=False,
[-0.37556022, -0.56956506, 0.51928586, -0.62428892, -0.07560658,
-0.19561028, 0.45752531, -0.30620268])
Parameter containing:
Tensor(shape=[8, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
[[-0.25485069],
[-0.04266340],
[-0.79025936],
[-0.75556862],
[-0.26889783],
[ 0.11903048],
[ 0.22622812],
[ 0.58539867]])
5.2.2.2. 自定义初始化¶
有时,深度学习框架没有提供我们需要的初始化方法。 在下面的例子中,我们使用以下的分布为任意权重参数\(w\)定义初始化方法:
在这里,我们定义了Initializer
类的子类。
通常,我们只需要实现_init_weight
函数,
该函数接受张量参数(data
)并为其分配所需的初始化值。
class MyInit(init.Initializer):
def _init_weight(self, name, data):
print('Init', name, data.shape)
data[:] = np.random.uniform(-10, 10, data.shape)
data *= np.abs(data) >= 5
net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[:2]
Init dense0_weight (8, 4)
Init dense1_weight (1, 8)
array([[ 0. , -0. , -0. , 8.522827 ],
[ 0. , -8.828651 , -0. , -5.6012006]])
同样,我们实现了一个my_init
函数来应用到net
。
def my_init(m):
if type(m) == nn.Linear:
print("Init", *[(name, param.shape)
for name, param in m.named_parameters()][0])
nn.init.uniform_(m.weight, -10, 10)
m.weight.data *= m.weight.data.abs() >= 5
net.apply(my_init)
net[0].weight[:2]
Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])
tensor([[-0.0000, 0.0000, -0.0000, 0.0000],
[-0.0000, 9.3464, 5.5061, 6.8197]], grad_fn=<SliceBackward0>)
在这里,我们定义了一个Initializer
的子类,
并实现了__call__
函数。 该函数返回给定形状和数据类型的所需张量。
class MyInit(tf.keras.initializers.Initializer):
def __call__(self, shape, dtype=None):
data=tf.random.uniform(shape, -10, 10, dtype=dtype)
factor=(tf.abs(data) >= 5)
factor=tf.cast(factor, tf.float32)
return data * factor
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4,
activation=tf.nn.relu,
kernel_initializer=MyInit()),
tf.keras.layers.Dense(1),
])
net(X)
print(net.layers[1].weights[0])
<tf.Variable 'dense_13/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[ 0. , 5.3255177, 7.927185 , -6.057389 ],
[-0. , 0. , 6.7501526, 0. ],
[-8.107052 , 0. , -6.7142034, -0. ],
[-0. , -0. , 9.668493 , -0. ]], dtype=float32)>
同样,我们实现了一个my_init
函数来应用到net
。
def my_init(m):
if type(m) == nn.Linear:
print("Init", *[(name, param.shape)
for name, param in m.named_parameters()][0])
paddle.nn.initializer.XavierUniform(m.weight, -10, 10)
h = paddle.abs(m.weight) >= 5
h = paddle.to_tensor(h)
m = paddle.to_tensor(m.weight)
m *= h
net.apply(my_init)
net[0].weight[:2]
Init weight [4, 8]
Init weight [8, 1]
Tensor(shape=[2, 8], dtype=float32, place=Place(cpu), stop_gradient=False,
[[-0.37556022, -0.56956506, 0.51928586, -0.62428892, -0.07560658,
-0.19561028, 0.45752531, -0.30620268],
[ 0.03955257, 0.05066150, -0.38958645, 0.57687515, 0.21143413,
-0.51807344, -0.01132214, -0.65783387]])
注意,我们始终可以直接设置参数。
net[0].weight.data()[:] += 1
net[0].weight.data()[0, 0] = 42
net[0].weight.data()[0]
array([42. , 1. , 1. , 9.522827])
高级用户请注意:如果要在autograd
范围内调整参数,
则需要使用set_data
,以避免误导自动微分机制。
net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]
tensor([42., 1., 1., 1.])
net.layers[1].weights[0][:].assign(net.layers[1].weights[0] + 1)
net.layers[1].weights[0][0, 0].assign(42)
net.layers[1].weights[0]
<tf.Variable 'dense_13/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[42. , 6.3255177, 8.927185 , -5.057389 ],
[ 1. , 1. , 7.7501526, 1. ],
[-7.107052 , 1. , -5.7142034, 1. ],
[ 1. , 1. , 10.668493 , 1. ]], dtype=float32)>
net[0].weight.set_value(net[0].weight.numpy() + 1)
val = net[0].weight.numpy()
val[0, 0] = 42
net[0].weight.set_value(val)
net[0].weight[0]
Tensor(shape=[8], dtype=float32, place=Place(cpu), stop_gradient=False,
[42. , 0.43043494, 1.51928592, 0.37571108, 0.92439342, 0.80438972,
1.45752525, 0.69379735])
5.2.3. 参数绑定¶
有时我们希望在多个层间共享参数: 我们可以定义一个稠密层,然后使用它的参数来设置另一个层的参数。
net = nn.Sequential()
# 我们需要给共享层一个名称,以便可以引用它的参数
shared = nn.Dense(8, activation='relu')
net.add(nn.Dense(8, activation='relu'),
shared,
nn.Dense(8, activation='relu', params=shared.params),
nn.Dense(10))
net.initialize()
X = np.random.uniform(size=(2, 20))
net(X)
# 检查参数是否相同
print(net[1].weight.data()[0] == net[2].weight.data()[0])
net[1].weight.data()[0, 0] = 100
# 确保它们实际上是同一个对象,而不只是有相同的值
print(net[1].weight.data()[0] == net[2].weight.data()[0])
[ True True True True True True True True]
[ True True True True True True True True]
这个例子表明第二层和第三层的参数是绑定的。 它们不仅值相等,而且由相同的张量表示。 因此,如果我们改变其中一个参数,另一个参数也会改变。 这里有一个问题:当参数绑定时,梯度会发生什么情况? 答案是由于模型参数包含梯度, 因此在反向传播期间第二个隐藏层和第三个隐藏层的梯度会加在一起。
# 我们需要给共享层一个名称,以便可以引用它的参数
shared = nn.Linear(8, 8)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),
shared, nn.ReLU(),
shared, nn.ReLU(),
nn.Linear(8, 1))
net(X)
# 检查参数是否相同
print(net[2].weight.data[0] == net[4].weight.data[0])
net[2].weight.data[0, 0] = 100
# 确保它们实际上是同一个对象,而不只是有相同的值
print(net[2].weight.data[0] == net[4].weight.data[0])
tensor([True, True, True, True, True, True, True, True])
tensor([True, True, True, True, True, True, True, True])
这个例子表明第三个和第五个神经网络层的参数是绑定的。 它们不仅值相等,而且由相同的张量表示。 因此,如果我们改变其中一个参数,另一个参数也会改变。 这里有一个问题:当参数绑定时,梯度会发生什么情况? 答案是由于模型参数包含梯度,因此在反向传播期间第二个隐藏层 (即第三个神经网络层)和第三个隐藏层(即第五个神经网络层)的梯度会加在一起。
# tf.keras的表现有点不同。它会自动删除重复层
shared = tf.keras.layers.Dense(4, activation=tf.nn.relu)
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
shared,
shared,
tf.keras.layers.Dense(1),
])
net(X)
# 检查参数是否不同
print(len(net.layers) == 3)
True
# 我们需要给共享层一个名称,以便可以引用它的参数。
shared = nn.Linear(8, 8)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),
shared, nn.ReLU(),
shared, nn.ReLU(),
nn.Linear(8, 1))
net(X)
# 检查参数是否相同
print(net[2].weight[0] == net[4].weight[0])
Tensor(shape=[8], dtype=bool, place=Place(cpu), stop_gradient=False,
[True, True, True, True, True, True, True, True])
5.2.4. 小结¶
我们有几种方法可以访问、初始化和绑定模型参数。
我们可以使用自定义初始化方法。