keras中的loss、optimizer、metrics用法
用keras搭好模型架構之后的下一步,就是執(zhí)行編譯操作。在編譯時,經(jīng)常需要指定三個參數(shù)
loss
optimizer
metrics
這三個參數(shù)有兩類選擇:
使用字符串
使用標識符,如keras.losses,keras.optimizers,metrics包下面的函數(shù)
例如:
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
因為有時可以使用字符串,有時可以使用標識符,令人很想知道背后是如何操作的。下面分別針對optimizer,loss,metrics三種對象的獲取進行研究。
optimizer
一個模型只能有一個optimizer,在執(zhí)行編譯的時候只能指定一個optimizer。
在keras.optimizers.py中,有一個get函數(shù),用于根據(jù)用戶傳進來的optimizer參數(shù)獲取優(yōu)化器的實例:
def get(identifier): # 如果后端是tensorflow并且使用的是tensorflow自帶的優(yōu)化器實例,可以直接使用tensorflow原生的優(yōu)化器 if K.backend() == 'tensorflow': # Wrap TF optimizer instances if isinstance(identifier, tf.train.Optimizer): return TFOptimizer(identifier) # 如果以json串的形式定義optimizer并進行參數(shù)配置 if isinstance(identifier, dict): return deserialize(identifier) elif isinstance(identifier, six.string_types): # 如果以字符串形式指定optimizer,那么使用優(yōu)化器的默認配置參數(shù) config = {'class_name': str(identifier), 'config': {}} return deserialize(config) if isinstance(identifier, Optimizer): # 如果使用keras封裝的Optimizer的實例 return identifier else: raise ValueError('Could not interpret optimizer identifier: ' + str(identifier))
其中,deserilize(config)函數(shù)的作用就是把optimizer反序列化制造一個實例。
loss
keras.losses函數(shù)也有一個get(identifier)方法。其中需要注意以下一點:
如果identifier是可調(diào)用的一個函數(shù)名,也就是一個自定義的損失函數(shù),這個損失函數(shù)返回值是一個張量。這樣就輕而易舉的實現(xiàn)了自定義損失函數(shù)。除了使用str和dict類型的identifier,我們也可以直接使用keras.losses包下面的損失函數(shù)。
def get(identifier): if identifier is None: return None if isinstance(identifier, six.string_types): identifier = str(identifier) return deserialize(identifier) if isinstance(identifier, dict): return deserialize(identifier) elif callable(identifier): return identifier else: raise ValueError('Could not interpret ' 'loss function identifier:', identifier)
metrics
在model.compile()函數(shù)中,optimizer和loss都是單數(shù)形式,只有metrics是復數(shù)形式。因為一個模型只能指明一個optimizer和loss,卻可以指明多個metrics。metrics也是三者中處理邏輯最為復雜的一個。
在keras最核心的地方keras.engine.train.py中有如下處理metrics的函數(shù)。這個函數(shù)其實就做了兩件事:
根據(jù)輸入的metric找到具體的metric對應的函數(shù)
計算metric張量
在尋找metric對應函數(shù)時,有兩種步驟:
使用字符串形式指明準確率和交叉熵
使用keras.metrics.py中的函數(shù)
def handle_metrics(metrics, weights=None): metric_name_prefix = 'weighted_' if weights is not None else '' for metric in metrics: # 如果metrics是最常見的那種:accuracy,交叉熵 if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): # custom handling of accuracy/crossentropy # (because of class mode duality) output_shape = K.int_shape(self.outputs[i]) # 如果輸出維度是1或者損失函數(shù)是二分類損失函數(shù),那么說明是個二分類問題,應該使用二分類的accuracy和二分類的的交叉熵 if (output_shape[-1] == 1 or self.loss_functions[i] == losses.binary_crossentropy): # case: binary accuracy/crossentropy if metric in ('accuracy', 'acc'): metric_fn = metrics_module.binary_accuracy elif metric in ('crossentropy', 'ce'): metric_fn = metrics_module.binary_crossentropy # 如果損失函數(shù)是sparse_categorical_crossentropy,那么目標y_input就不是one-hot的,所以就需要使用sparse的多類準去率和sparse的多類交叉熵 elif self.loss_functions[i] == losses.sparse_categorical_crossentropy: # case: categorical accuracy/crossentropy # with sparse targets if metric in ('accuracy', 'acc'): metric_fn = metrics_module.sparse_categorical_accuracy elif metric in ('crossentropy', 'ce'): metric_fn = metrics_module.sparse_categorical_crossentropy else: # case: categorical accuracy/crossentropy if metric in ('accuracy', 'acc'): metric_fn = metrics_module.categorical_accuracy elif metric in ('crossentropy', 'ce'): metric_fn = metrics_module.categorical_crossentropy if metric in ('accuracy', 'acc'): suffix = 'acc' elif metric in ('crossentropy', 'ce'): suffix = 'ce' weighted_metric_fn = weighted_masked_objective(metric_fn) metric_name = metric_name_prefix + suffix else: # 如果輸入的metric不是字符串,那么就調(diào)用metrics模塊獲取 metric_fn = metrics_module.get(metric) weighted_metric_fn = weighted_masked_objective(metric_fn) # Get metric name as string if hasattr(metric_fn, 'name'): metric_name = metric_fn.name else: metric_name = metric_fn.__name__ metric_name = metric_name_prefix + metric_name with K.name_scope(metric_name): metric_result = weighted_metric_fn(y_true, y_pred, weights=weights, mask=masks[i]) # Append to self.metrics_names, self.metric_tensors, # self.stateful_metric_names if len(self.output_names) > 1: metric_name = self.output_names[i] + '_' + metric_name # Dedupe name j = 1 base_metric_name = metric_name while metric_name in self.metrics_names: metric_name = base_metric_name + '_' + str(j) j += 1 self.metrics_names.append(metric_name) self.metrics_tensors.append(metric_result) # Keep track of state updates created by # stateful metrics (i.e. metrics layers). if isinstance(metric_fn, Layer) and metric_fn.stateful: self.stateful_metric_names.append(metric_name) self.stateful_metric_functions.append(metric_fn) self.metrics_updates += metric_fn.updates
無論怎么使用metric,最終都會變成metrics包下面的函數(shù)。當使用字符串形式指明accuracy和crossentropy時,keras會非常智能地確定應該使用metrics包下面的哪個函數(shù)。因為metrics包下的那些metric函數(shù)有不同的使用場景,例如:
有的處理的是one-hot形式的y_input(數(shù)據(jù)的類別),有的處理的是非one-hot形式的y_input
有的處理的是二分類問題的metric,有的處理的是多分類問題的metric
當使用字符串“accuracy”和“crossentropy”指明metric時,keras會根據(jù)損失函數(shù)、輸出層的shape來確定具體應該使用哪個metric函數(shù)。在任何情況下,直接使用metrics下面的函數(shù)名是總不會出錯的。
keras.metrics.py文件中也有一個get(identifier)函數(shù)用于獲取metric函數(shù)。
def get(identifier): if isinstance(identifier, dict): config = {'class_name': str(identifier), 'config': {}} return deserialize(config) elif isinstance(identifier, six.string_types): return deserialize(str(identifier)) elif callable(identifier): return identifier else: raise ValueError('Could not interpret ' 'metric function identifier:', identifier)
如果identifier是字符串或者字典,那么會根據(jù)identifier反序列化出一個metric函數(shù)。
如果identifier本身就是一個函數(shù)名,那么就直接返回這個函數(shù)名。這種方式就為自定義metric提供了巨大便利。
keras中的設計哲學堪稱完美。
以上這篇keras中的loss、optimizer、metrics用法就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
詳解python3 + Scrapy爬蟲學習之創(chuàng)建項目
這篇文章主要介紹了python3 Scrapy爬蟲創(chuàng)建項目,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2019-04-04Scrapy中詭異xpath的匹配內(nèi)容失效問題及解決
這篇文章主要介紹了Scrapy中詭異xpath的匹配內(nèi)容失效問題及解決方案,具有很好的價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-12-12使用Python創(chuàng)建多功能文件管理器的代碼示例
在本文中,我們將探索一個使用Python的wxPython庫開發(fā)的文件管理器應用程序,這個應用程序不僅能夠瀏覽和選擇文件,還支持文件預覽、壓縮、圖片轉(zhuǎn)換以及生成PPT演示文稿的功能,需要的朋友可以參考下2024-08-08python 插入Null值數(shù)據(jù)到Postgresql的操作
這篇文章主要介紹了python 插入Null值數(shù)據(jù)到Postgresql的操作,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-03-03