深度学习理解概念系列VGG模型参数

文章资讯 2020-08-06 16:07:29

深度学习理解概念系列VGG模型参数

第一次看到VGG19模型的参数时我整个人是懵的(疯狂套娃那种),这个模型参数应该是使用MATLAB导出的,所以要可视化好像需要下MATLAB。但仍然可以导入后在python环境里大概看看长什么样子。
vgg = scipy.io.loadmat(path)
vgg.keys()
'__header__', '__version__', '__globals__', 'layers', 'classes', 'normalization']
这里最重要的是后面三个(normalization对图像预处理用的上)。
w,b参数的获取
由上面我们可以知道
获取第0索引层(也就是input层和第一层之间的)w:vgg['layers'][0][0][0][0][0][0][0]
获取第0索引层(也就是input层和第一层之间的)b:vgg['layers'][0][0][0][0][0][0][1]
也可以直接w,b=vgg['layers'][0][0][0][0][0][0]
把上面红色的0增加就可以获取到所有的层参数
网络层参数的准备代码
def net(data_path, input_image):
layers = (
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
'relu5_3', 'conv5_4', 'relu5_4'
)
data = scipy.io.loadmat(data_path)
weights = data['layers'][0]
net = {}
current = input_image
for i, name in enumerate(layers):
kind = name[:4]
if kind == 'conv':
kernels, bias = weights[i][0][0][0][0]
kernels = np.transpose(kernels, (1, 0, 2, 3))
bias = bias.reshape(-1)
current = _conv_layer(current, kernels, bias)
elif kind == 'relu':
current = tf.nn.relu(current)
elif kind == 'pool':
current = _pool_layer(current)
net[name] = current
assert len(net) == len(layers)
return net, mean_pixel, layers
print ("Network for VGG ready")