解決Keras中循環(huán)使用K.ctc_decode內(nèi)存不釋放的問題
如下一段代碼,在多次調(diào)用了K.ctc_decode時(shí),會(huì)發(fā)現(xiàn)程序占用的內(nèi)存會(huì)越來(lái)越高,執(zhí)行速度越來(lái)越慢。
data = generator(...) model = init_model(...) for i in range(NUM): x, y = next(data) _y = model.predict(x) shape = _y.shape input_length = np.ones(shape[0]) * shape[1] ctc_decode = K.ctc_decode(_y, input_length)[0][0] out = K.get_value(ctc_decode)
原因
每次執(zhí)行ctc_decode時(shí)都會(huì)向計(jì)算圖中添加一個(gè)節(jié)點(diǎn),這樣會(huì)導(dǎo)致計(jì)算圖逐漸變大,從而影響計(jì)算速度和內(nèi)存。
PS:有資料說是由于get_value導(dǎo)致的,其中也給出了解決方案。
但是我將ctc_decode放在循環(huán)體之外就不再出現(xiàn)內(nèi)存和速度問題,這是否說明get_value影響其實(shí)不大呢?
解決方案
通過K.function封裝K.ctc_decode,只需初始化一次,只向計(jì)算圖中添加一個(gè)計(jì)算節(jié)點(diǎn),然后多次調(diào)用該節(jié)點(diǎn)(函數(shù))
data = generator(...) model = init_model(...) x = model.output # [batch_sizes, series_length, classes] input_length = KL.Input(batch_shape=[None], dtype='int32') ctc_decode = K.ctc_decode(x, input_length=input_length * K.shape(x)[1]) decode = K.function([model.input, input_length], [ctc_decode[0][0]]) for i in range(NUM): _x, _y = next(data) out = decode([_x, np.ones(1)])
補(bǔ)充知識(shí):CTC_loss和CTC_decode的模型封裝代碼避免節(jié)點(diǎn)不斷增加
該問題可以參考上面的描述,無(wú)論是CTC_decode還是CTC_loss,每次運(yùn)行都會(huì)創(chuàng)建節(jié)點(diǎn),避免的方法是將其封裝到model中,這樣就固定了計(jì)算節(jié)點(diǎn)。
測(cè)試方法: 在初始化節(jié)點(diǎn)后(注意是在運(yùn)行fit/predict至少一次后,因?yàn)檫@些方法也會(huì)更改計(jì)算圖狀態(tài)),運(yùn)行K.get_session().graph.finalize()鎖定節(jié)點(diǎn),此時(shí)如果圖節(jié)點(diǎn)變了會(huì)報(bào)錯(cuò)并提示出錯(cuò)代碼。
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
'''
用于計(jì)算CTC loss
'''
def ctc_lambda_func(self,args):
"""Runs CTC loss algorithm on each batch element.
# Arguments
y_true: tensor `(samples, max_string_length)` 真實(shí)標(biāo)簽
y_pred: tensor `(samples, time_steps, num_categories)` 預(yù)測(cè)前未經(jīng)過softmax的向量
input_length: tensor `(samples, 1)` 每一個(gè)y_pred的長(zhǎng)度
label_length: tensor `(samples, 1)` 每一個(gè)y_true的長(zhǎng)度
# Returns
Tensor with shape (samples,1) 包含了每一個(gè)樣本的ctc loss
"""
y_true, y_pred, input_length, label_length = args
# y_pred = y_pred[:, :, :]
# y_pred = y_pred[:, 2:, :]
return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)
def __call__(self, args):
'''
ctc_decode 每次創(chuàng)建會(huì)生成一個(gè)節(jié)點(diǎn),這里參考了上面的內(nèi)容
將ctc封裝成模型,是否會(huì)解決這個(gè)問題還沒有測(cè)試過這種方法是否還會(huì)出現(xiàn)創(chuàng)建節(jié)點(diǎn)的問題
'''
y_true = Input(shape=(None,))
y_pred = Input(shape=(None,None))
input_length = Input(shape=(1,))
label_length = Input(shape=(1,))
lamd = Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')([y_true,y_pred,input_length,label_length])
model = Model([y_true,y_pred,input_length,label_length],[lamd],name="ctc")
# return Lambda(self.ctc_lambda_func, output_shape=(1,), name='ctc')(args)
return model(args)
def ctc_batch_cost(self,y_true, y_pred, input_length, label_length):
"""Runs CTC loss algorithm on each batch element.
# Arguments
y_true: tensor `(samples, max_string_length)`
containing the truth labels.
y_pred: tensor `(samples, time_steps, num_categories)`
containing the prediction, or output of the softmax.
input_length: tensor `(samples, 1)` containing the sequence length for
each batch item in `y_pred`.
label_length: tensor `(samples, 1)` containing the sequence length for
each batch item in `y_true`.
# Returns
Tensor with shape (samples,1) containing the
CTC loss of each element.
"""
label_length = tf.to_int32(tf.squeeze(label_length, axis=-1))
input_length = tf.to_int32(tf.squeeze(input_length, axis=-1))
sparse_labels = tf.to_int32(K.ctc_label_dense_to_sparse(y_true, label_length))
y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-7)
# 注意這里的True是為了忽略解碼失敗的情況,此時(shí)loss會(huì)變成nan直到下一個(gè)個(gè)batch
return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
labels=sparse_labels,
sequence_length=input_length,
ignore_longer_outputs_than_inputs=True), 1)
# 使用方法:(注意shape)
loss_out = CTC_Batch_Cost()([y_true, y_pred, audio_length, label_length])
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTCDecodeLayer(Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _ctc_decode(self,args):
base_pred, in_len = args
in_len = K.squeeze(in_len,axis=-1)
r = K.ctc_decode(base_pred, in_len, greedy=True, beam_width=100, top_paths=1)
r1 = r[0][0]
prob = r[1][0]
return [r1,prob]
def call(self, inputs, **kwargs):
return self._ctc_decode(inputs)
def compute_output_shape(self, input_shape):
return [(None,None),(1,)]
class CTCDecode():
'''用與CTC 解碼,得到真實(shí)語(yǔ)音序列
2019年7月18日所寫,對(duì)ctc_decode使用模型進(jìn)行了封裝,從而在初始化完成后不會(huì)再有新節(jié)點(diǎn)的產(chǎn)生
'''
def __init__(self):
base_pred = Input(shape=[None,None],name="pred")
feature_len = Input(shape=[1,],name="feature_len")
r1, prob = CTCDecodeLayer()([base_pred,feature_len])
self.model = Model([base_pred,feature_len],[r1,prob])
pass
def ctc_decode(self,base_pred,in_len,return_prob = False):
'''
:param base_pred:[sample,timestamp,vector]
:param in_len: [sample,1]
:return:
'''
result,prob = self.model.predict([base_pred,in_len])
if return_prob:
return result,prob
return result
def __call__(self,base_pred,in_len,return_prob = False):
return self.ctc_decode(base_pred,in_len,return_prob)
# 使用方法:(注意shape,是batch級(jí)的輸入)
ctc_decoder = CTCDecode()
ctc_decoder.ctc_decode(result,feature_len)
以上這篇解決Keras中循環(huán)使用K.ctc_decode內(nèi)存不釋放的問題就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- Asp.net Core 3.1基于AspectCore實(shí)現(xiàn)AOP實(shí)現(xiàn)事務(wù)、緩存攔截器功能
- 使用keras框架cnn+ctc_loss識(shí)別不定長(zhǎng)字符圖片操作
- Asp.Net Core輕量級(jí)Aop解決方案:AspectCore
- Kotlin基礎(chǔ)教程之dataclass,objectclass,use函數(shù),類擴(kuò)展,socket
- IOS ObjectC與javascript交互詳解及實(shí)現(xiàn)代碼
- asp內(nèi)置對(duì)象 ObjectContext 事務(wù)管理 詳解
- python實(shí)現(xiàn)CTC以及案例講解
相關(guān)文章
python如何通過實(shí)例方法名字調(diào)用方法
這篇文章主要為大家詳細(xì)介紹了python如何通過實(shí)例方法名字調(diào)用方法,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-03-03
Python統(tǒng)計(jì)一個(gè)字符串中每個(gè)字符出現(xiàn)了多少次的方法【字符串轉(zhuǎn)換為列表再統(tǒng)計(jì)】
這篇文章主要介紹了Python統(tǒng)計(jì)一個(gè)字符串中每個(gè)字符出現(xiàn)了多少次的方法,涉及Python字符串轉(zhuǎn)換及列表遍歷、統(tǒng)計(jì)等相關(guān)操作技巧,需要的朋友可以參考下2019-05-05
使用python-cv2實(shí)現(xiàn)視頻的分解與合成的示例代碼
這篇文章主要介紹了使用python-cv2實(shí)現(xiàn)視頻的分解與合成的示例代碼,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-10-10
Python中處理無(wú)效數(shù)據(jù)的詳細(xì)教程
無(wú)效數(shù)據(jù)是指不符合數(shù)據(jù)收集目的或數(shù)據(jù)收集標(biāo)準(zhǔn)的數(shù)據(jù),這些數(shù)據(jù)可能來(lái)自于不準(zhǔn)確的測(cè)量、缺失值、錯(cuò)誤標(biāo)注、虛假的數(shù)據(jù)源或其他問題,本文就將帶大家學(xué)習(xí)Python中如何處理無(wú)效數(shù)據(jù),感興趣的同學(xué)可以跟著小編一起來(lái)學(xué)習(xí)2023-06-06
詳解Python如何解析JSON中的對(duì)象數(shù)組
這篇文章主要為大家詳細(xì)介紹了如何使用Python的JSON模塊傳輸和接收J(rèn)SON數(shù)據(jù),文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2023-10-10
Python爬蟲實(shí)例_利用百度地圖API批量獲取城市所有的POI點(diǎn)
下面小編就為大家分享一篇Python爬蟲實(shí)例_利用百度地圖API批量獲取城市所有的POI點(diǎn),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2018-01-01
python簡(jiǎn)單實(shí)現(xiàn)刷新智聯(lián)簡(jiǎn)歷
本文給大家分享的是個(gè)人弄的一個(gè)使用Python簡(jiǎn)單實(shí)現(xiàn)刷新智聯(lián)招聘簡(jiǎn)歷的小工具的代碼,非常的簡(jiǎn)單,給大家參考下吧。2016-03-03
Python中集合的創(chuàng)建及常用函數(shù)的使用詳解
這篇文章主要為大家詳細(xì)介紹了Python中集合的創(chuàng)建、使用和遍歷,集合常見的操作函數(shù),集合與列表,元組,字典的嵌套,感興趣的小伙伴可以了解一下2022-06-06

