對(duì)tensorflow中tf.nn.conv1d和layers.conv1d的區(qū)別詳解
在用tensorflow做一維的卷積神經(jīng)網(wǎng)絡(luò)的時(shí)候會(huì)遇到tf.nn.conv1d和layers.conv1d這兩個(gè)函數(shù),但是這兩個(gè)函數(shù)有什么區(qū)別呢,通過(guò)計(jì)算得到一些規(guī)律。
1.關(guān)于tf.nn.conv1d的解釋?zhuān)韵率荰ensor Flow中關(guān)于tf.nn.conv1d的API注解:
Computes a 1-D convolution given 3-D input and filter tensors.
Given an input tensor of shape [batch, in_width, in_channels] if data_format is "NHWC", or [batch, in_channels, in_width] if data_format is "NCHW", and a filter / kernel tensor of shape [filter_width, in_channels, out_channels], this op reshapes the arguments to pass them to conv2d to perform the equivalent convolution operation.
Internally, this op reshapes the input tensors and invokes `tf.nn.conv2d`. For example, if `data_format` does not start with "NC", a tensor of shape [batch, in_width, in_channels] is reshaped to [batch, 1, in_width, in_channels], and the filter is reshaped to [1, filter_width, in_channels, out_channels]. The result is then reshaped back to [batch, out_width, out_channels] whereoutwidthisafunctionofthestrideandpaddingasinconv2dwhereoutwidthisafunctionofthestrideandpaddingasinconv2d and returned to the caller.
Args: value: A 3D `Tensor`. Must be of type `float32` or `float64`. filters: A 3D `Tensor`. Must have the same type as `input`. stride: An `integer`. The number of entries by which the filter is moved right at each step. padding: 'SAME' or 'VALID' use_cudnn_on_gpu: An optional `bool`. Defaults to `True`. data_format: An optional `string` from `"NHWC", "NCHW"`. Defaults to `"NHWC"`, the data is stored in the order of [batch, in_width, in_channels]. The `"NCHW"` format stores data as [batch, in_channels, in_width]. name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as input.
Raises:
ValueError: if `data_format` is invalid.
什么意思呢?就是說(shuō)conv1d的參數(shù)含義:(以NHWC格式為例,即,通道維在最后)
1、value:在注釋中,value的格式為:[batch, in_width, in_channels],batch為樣本維,表示多少個(gè)樣本,in_width為寬度維,表示樣本的寬度,in_channels維通道維,表示樣本有多少個(gè)通道。 事實(shí)上,也可以把格式看作如下:[batch, 行數(shù), 列數(shù)],把每一個(gè)樣本看作一個(gè)平鋪開(kāi)的二維數(shù)組。這樣的話(huà)可以方便理解。
2、filters:在注釋中,filters的格式為:[filter_width, in_channels, out_channels]。按照value的第二種看法,filter_width可以看作每次與value進(jìn)行卷積的行數(shù),in_channels表示value一共有多少列(與value中的in_channels相對(duì)應(yīng))。out_channels表示輸出通道,可以理解為一共有多少個(gè)卷積核,即卷積核的數(shù)目。
3、stride:一個(gè)整數(shù),表示步長(zhǎng),每次(向下)移動(dòng)的距離(TensorFlow中解釋是向右移動(dòng)的距離,這里可以看作向下移動(dòng)的距離)。
4、padding:同conv2d,value是否需要在下方填補(bǔ)0。
5、name:名稱(chēng)??墒÷浴?/p>
首先從參數(shù)列表可以看出value指的輸入的數(shù)據(jù),stride就是卷積的步長(zhǎng),這里我們最有疑問(wèn)的就是filters這個(gè)參數(shù),那么我們對(duì)filter進(jìn)行簡(jiǎn)單的說(shuō)明。從上面可以看到filters的格式為:[filter_width, in_channels, out_channels],這是一個(gè)數(shù)組的維度,對(duì)應(yīng)的是卷積核的大小,輸入的channel的格式,和卷積核的個(gè)數(shù),下面我們用例子說(shuō)明問(wèn)題:
import tensorflow as tf import numpy as np if __name__ == '__main__': inputs = tf.constant(np.arange(1, 6, dtype=np.float32), shape=[1, 5, 1]) w = np.array([1, 2], dtype=np.float32).reshape([2, 1, 1]) # filter width, filter channels and out channels(number of kernels) cov1 = tf.nn.conv1d(inputs, w, stride=1, padding='VALID') with tf.Session() as sess: sess.run(tf.global_variables_initializer()) out = sess.run(cov1) print(out)
其輸出為:
[[[ 5.], [ 8.], [11.], [14.]]]
我們分析一下,輸入的數(shù)據(jù)為[[[1],[2],[3],[4],[5]]],有5個(gè)特征,分別對(duì)應(yīng)的數(shù)值為1,2,3,4,5,那么經(jīng)過(guò)卷積的結(jié)果為5,8,11,14,那么這個(gè)結(jié)果是怎么來(lái)的呢,我們根據(jù)卷積的計(jì)算,可以得到5 = 1*1 + 2*2, 8=2*1+ 3*2, 11 = 3*1+4*2, 14=4*1+5*2, 也就是W1=1, W2=2,正好和我們先面filters設(shè)置的數(shù)值相等,
w = np.array([1, 2], dtype=np.float32).reshape([2, 1, 1])
所以可以看到這個(gè)filtes設(shè)置的是是卷積核矩陣的,換句話(huà)說(shuō),卷積核矩陣我們是可以設(shè)置的。
2. 1.關(guān)于tf.layers.conv1d,函數(shù)的定義如下
tf.layers.conv1d( inputs, filters, kernel_size, strides=1, padding='valid', data_format='channels_last', dilation_rate=1, activation=None, use_bias=True, kernel_initializer=None, bias_initializer=tf.zeros_initializer(), kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True, name=None, reuse=None )
比較重要的幾個(gè)參數(shù)是inputs, filters, kernel_size,下面分別說(shuō)明
inputs : 輸入tensor, 維度(None, a, b) 是一個(gè)三維的tensor
None : 一般是填充樣本的個(gè)數(shù),batch_size
a : 句子中的詞數(shù)或者字?jǐn)?shù)
b : 字或者詞的向量維度
filters : 過(guò)濾器的個(gè)數(shù)
kernel_size : 卷積核的大小,卷積核其實(shí)應(yīng)該是一個(gè)二維的,這里只需要指定一維,是因?yàn)榫矸e核的第二維與輸入的詞向量維度是一致的,因?yàn)閷?duì)于句子而言,卷積的移動(dòng)方向只能是沿著詞的方向,即只能在列維度移動(dòng)。一個(gè)例子:
import tensorflow as tf import numpy as np if __name__ == '__main__': inputs = tf.constant(np.arange(1, 6, dtype=np.float32), shape=[1, 5, 1]) cov2 = tf.layers.conv1d(inputs, filters=1, kernel_size=2, strides=1, padding='VALID') with tf.Session() as sess: sess.run(tf.global_variables_initializer()) out = sess.run(cov2) print(out)
輸出結(jié)果:
[[[-1.9953331] [-3.5520997] [-5.108866 ] [-6.6656327]]]
也許你得到的結(jié)果和我得到的結(jié)果不同,因?yàn)樵谶@個(gè)函數(shù)里面只是設(shè)置了卷積核的尺寸和步長(zhǎng),沒(méi)有設(shè)置具體的卷積核矩陣,所以這個(gè)卷積核矩陣是隨機(jī)生成的,就會(huì)出現(xiàn)可能運(yùn)行上面的程序出現(xiàn)不同結(jié)果的情況。
以上這篇對(duì)tensorflow中tf.nn.conv1d和layers.conv1d的區(qū)別詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python?Concurrent?Futures解鎖并行化編程的魔法示例
Python的concurrent.futures模塊為并行化編程提供了強(qiáng)大的工具,使得開(kāi)發(fā)者能夠輕松地利用多核心和異步執(zhí)行的能力,本文將深入探討concurrent.futures的各個(gè)方面,從基礎(chǔ)概念到高級(jí)用法,為讀者提供全面的了解和實(shí)用的示例代碼2023-12-12使用Python和百度語(yǔ)音識(shí)別生成視頻字幕的實(shí)現(xiàn)
這篇文章主要介紹了使用Python和百度語(yǔ)音識(shí)別生成視頻字幕,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-04-04Python confluent kafka客戶(hù)端配置kerberos認(rèn)證流程詳解
這篇文章主要介紹了Python confluent kafka客戶(hù)端配置kerberos認(rèn)證流程詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-10-10Python中函數(shù)的參數(shù)傳遞與可變長(zhǎng)參數(shù)介紹
這篇文章主要介紹了Python中函數(shù)的參數(shù)傳遞與可變長(zhǎng)參數(shù)介紹,本文分別給出多個(gè)代碼實(shí)例來(lái)講解多種多樣的函數(shù)參數(shù),需要的朋友可以參考下2015-06-06使用BeautifulSoup爬蟲(chóng)程序獲取百度搜索結(jié)果的標(biāo)題和url示例
這篇文章主要介紹了使用BeautifulSoup編寫(xiě)了一段爬蟲(chóng)程序獲取百度搜索結(jié)果的標(biāo)題和url的示例,大家參考使用吧2014-01-0113個(gè)最常用的Python深度學(xué)習(xí)庫(kù)介紹
這篇文章主要介紹了13個(gè)最常用的Python深度學(xué)習(xí)庫(kù)介紹,具有一定參考價(jià)值,需要的朋友可以參考下。2017-10-10解決pycharm不能自動(dòng)補(bǔ)全第三方庫(kù)的函數(shù)和屬性問(wèn)題
這篇文章主要介紹了解決pycharm不能自動(dòng)補(bǔ)全第三方庫(kù)的函數(shù)和屬性問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-03-03python的鏈表基礎(chǔ)知識(shí)點(diǎn)
在本篇文章里小編給大家整理的是一篇關(guān)于python的鏈表基礎(chǔ)知識(shí)點(diǎn)內(nèi)容,有興趣的朋友們可以參考學(xué)習(xí)下。2020-09-09