基于Keras的格式化輸出Loss實(shí)現(xiàn)方式
在win7 64位,Anaconda安裝的Python3.6.1下安裝的TensorFlow與Keras,Keras的backend為TensorFlow。在運(yùn)行Mask R-CNN時(shí),在進(jìn)行調(diào)試時(shí)想知道PyCharm (Python IDE)底部窗口輸出的Loss格式是在哪里定義的,如下圖紅框中所示:

圖1 訓(xùn)練過(guò)程的Loss格式化輸出
在上圖紅框中,Loss的輸出格式是在哪里定義的呢?有一點(diǎn)是明確的,即上圖紅框中的內(nèi)容是在訓(xùn)練的時(shí)候輸出的。那么先來(lái)看一下Mask R-CNN的訓(xùn)練過(guò)程。Keras以Numpy數(shù)組作為輸入數(shù)據(jù)和標(biāo)簽的數(shù)據(jù)類型。訓(xùn)練模型一般使用 fit 函數(shù)。然而由于Mask R-CNN訓(xùn)練數(shù)據(jù)巨大,不能一次性全部載入,否則太消耗內(nèi)存。于是采用生成器的方式一次載入一個(gè)batch的數(shù)據(jù),而且是在用到這個(gè)batch的數(shù)據(jù)才開(kāi)始載入的,那么它的訓(xùn)練函數(shù)如下:
self.keras_model.fit_generator( train_generator, initial_epoch=self.epoch, epochs=epochs, steps_per_epoch=self.config.STEPS_PER_EPOCH, callbacks=callbacks, validation_data=val_generator, validation_steps=self.config.VALIDATION_STEPS, max_queue_size=100, workers=workers, use_multiprocessing=False, )
這里訓(xùn)練模型的函數(shù)相應(yīng)的為 fit_generator 函數(shù)。注意其中的參數(shù)callbacks=callbacks,這個(gè)參數(shù)在輸出紅框中的內(nèi)容起到了關(guān)鍵性的作用。下面看一下callbacks的值:
# Callbacks
callbacks = [
keras.callbacks.TensorBoard(log_dir=self.log_dir,
histogram_freq=0, write_graph=True, write_images=False),
keras.callbacks.ModelCheckpoint(self.checkpoint_path,
verbose=0, save_weights_only=True),
]
在輸出紅框中的內(nèi)容所需的數(shù)據(jù)均保存在self.log_dir下。然后調(diào)試進(jìn)入self.keras_model.fit_generator函數(shù),進(jìn)入keras,legacy.interfaces的legacy_support(func)函數(shù),如下所示:
def legacy_support(func):
@six.wraps(func)
def wrapper(*args, **kwargs):
if object_type == 'class':
object_name = args[0].__class__.__name__
else:
object_name = func.__name__
if preprocessor:
args, kwargs, converted = preprocessor(args, kwargs)
else:
converted = []
if check_positional_args:
if len(args) > len(allowed_positional_args) + 1:
raise TypeError('`' + object_name +
'` can accept only ' +
str(len(allowed_positional_args)) +
' positional arguments ' +
str(tuple(allowed_positional_args)) +
', but you passed the following '
'positional arguments: ' +
str(list(args[1:])))
for key in value_conversions:
if key in kwargs:
old_value = kwargs[key]
if old_value in value_conversions[key]:
kwargs[key] = value_conversions[key][old_value]
for old_name, new_name in conversions:
if old_name in kwargs:
value = kwargs.pop(old_name)
if new_name in kwargs:
raise_duplicate_arg_error(old_name, new_name)
kwargs[new_name] = value
converted.append((new_name, old_name))
if converted:
signature = '`' + object_name + '('
for i, value in enumerate(args[1:]):
if isinstance(value, six.string_types):
signature += '"' + value + '"'
else:
if isinstance(value, np.ndarray):
str_val = 'array'
else:
str_val = str(value)
if len(str_val) > 10:
str_val = str_val[:10] + '...'
signature += str_val
if i < len(args[1:]) - 1 or kwargs:
signature += ', '
for i, (name, value) in enumerate(kwargs.items()):
signature += name + '='
if isinstance(value, six.string_types):
signature += '"' + value + '"'
else:
if isinstance(value, np.ndarray):
str_val = 'array'
else:
str_val = str(value)
if len(str_val) > 10:
str_val = str_val[:10] + '...'
signature += str_val
if i < len(kwargs) - 1:
signature += ', '
signature += ')`'
warnings.warn('Update your `' + object_name +
'` call to the Keras 2 API: ' + signature, stacklevel=2)
return func(*args, **kwargs)
wrapper._original_function = func
return wrapper
return legacy_support
在上述代碼的倒數(shù)第4行的return func(*args, **kwargs)處返回func,func為fit_generator函數(shù),現(xiàn)調(diào)試進(jìn)入fit_generator函數(shù),該函數(shù)定義在keras.engine.training模塊內(nèi)的fit_generator函數(shù),調(diào)試進(jìn)入函數(shù)callbacks.on_epoch_begin(epoch),如下所示:
# Construct epoch logs.
epoch_logs = {}
while epoch < epochs:
for m in self.stateful_metric_functions:
m.reset_states()
callbacks.on_epoch_begin(epoch)
調(diào)試進(jìn)入到callbacks.on_epoch_begin(epoch)函數(shù),進(jìn)入on_epoch_begin函數(shù),如下所示:
def on_epoch_begin(self, epoch, logs=None):
"""Called at the start of an epoch.
# Arguments
epoch: integer, index of epoch.
logs: dictionary of logs.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_epoch_begin(epoch, logs)
self._delta_t_batch = 0.
self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
self._delta_ts_batch_end = deque([], maxlen=self.queue_length)
在上述函數(shù)on_epoch_begin中調(diào)試進(jìn)入callback.on_epoch_begin(epoch, logs)函數(shù),轉(zhuǎn)到類ProgbarLogger(Callback)中定義的on_epoch_begin函數(shù),如下所示:
class ProgbarLogger(Callback):
"""Callback that prints metrics to stdout.
# Arguments
count_mode: One of "steps" or "samples".
Whether the progress bar should
count samples seen or steps (batches) seen.
stateful_metrics: Iterable of string names of metrics that
should *not* be averaged over an epoch.
Metrics in this list will be logged as-is.
All others will be averaged over time (e.g. loss, etc).
# Raises
ValueError: In case of invalid `count_mode`.
"""
def __init__(self, count_mode='samples',
stateful_metrics=None):
super(ProgbarLogger, self).__init__()
if count_mode == 'samples':
self.use_steps = False
elif count_mode == 'steps':
self.use_steps = True
else:
raise ValueError('Unknown `count_mode`: ' + str(count_mode))
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()
def on_train_begin(self, logs=None):
self.verbose = self.params['verbose']
self.epochs = self.params['epochs']
def on_epoch_begin(self, epoch, logs=None):
if self.verbose:
print('Epoch %d/%d' % (epoch + 1, self.epochs))
if self.use_steps:
target = self.params['steps']
else:
target = self.params['samples']
self.target = target
self.progbar = Progbar(target=self.target,
verbose=self.verbose,
stateful_metrics=self.stateful_metrics)
self.seen = 0
在上述代碼的
print('Epoch %d/%d' % (epoch + 1, self.epochs))
輸出
Epoch 1/40(如紅框中所示內(nèi)容的第一行)。
然后返回到keras.engine.training模塊內(nèi)的fit_generator函數(shù),執(zhí)行到self.train_on_batch函數(shù),如下所示:
outs = self.train_on_batch(x, y,
sample_weight=sample_weight,
class_weight=class_weight)
if not isinstance(outs, list):
outs = [outs]
for l, o in zip(out_labels, outs):
batch_logs[l] = o
callbacks.on_batch_end(batch_index, batch_logs)
batch_index += 1
steps_done += 1
調(diào)試進(jìn)入上述代碼中的callbacks.on_batch_end(batch_index, batch_logs)函數(shù),進(jìn)入到on_batch_end函數(shù)后,該函數(shù)的定義如下所示:
def on_batch_end(self, batch, logs=None):
"""Called at the end of a batch.
# Arguments
batch: integer, index of batch within the current epoch.
logs: dictionary of logs.
"""
logs = logs or {}
if not hasattr(self, '_t_enter_batch'):
self._t_enter_batch = time.time()
self._delta_t_batch = time.time() - self._t_enter_batch
t_before_callbacks = time.time()
for callback in self.callbacks:
callback.on_batch_end(batch, logs)
self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
delta_t_median = np.median(self._delta_ts_batch_end)
if (self._delta_t_batch > 0. and
(delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
warnings.warn('Method on_batch_end() is slow compared '
'to the batch update (%f). Check your callbacks.'
% delta_t_median)
接著繼續(xù)調(diào)試進(jìn)入上述代碼中的callback.on_batch_end(batch, logs)函數(shù),進(jìn)入到在類中ProgbarLogger(Callback)定義的on_batch_end函數(shù),如下所示:
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
if self.use_steps:
self.seen += 1
else:
self.seen += batch_size
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
# Skip progbar update for the last batch;
# will be handled by on_epoch_end.
if self.verbose and self.seen < self.target:
self.progbar.update(self.seen, self.log_values)
然后執(zhí)行到上述代碼的最后一行self.progbar.update(self.seen, self.log_values),調(diào)試進(jìn)入update函數(shù),該函數(shù)定義在模塊keras.utils.generic_utils中的類Progbar(object)定義的函數(shù)。類的定義及方法如下所示:
class Progbar(object):
"""Displays a progress bar.
# Arguments
target: Total number of steps expected, None if unknown.
width: Progress bar width on screen.
verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
stateful_metrics: Iterable of string names of metrics that
should *not* be averaged over time. Metrics in this list
will be displayed as-is. All others will be averaged
by the progbar before display.
interval: Minimum visual progress update interval (in seconds).
"""
def __init__(self, target, width=30, verbose=1, interval=0.05,
stateful_metrics=None):
self.target = target
self.width = width
self.verbose = verbose
self.interval = interval
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()
self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
sys.stdout.isatty()) or
'ipykernel' in sys.modules)
self._total_width = 0
self._seen_so_far = 0
self._values = collections.OrderedDict()
self._start = time.time()
self._last_update = 0
def update(self, current, values=None):
"""Updates the progress bar.
# Arguments
current: Index of current step.
values: List of tuples:
`(name, value_for_last_step)`.
If `name` is in `stateful_metrics`,
`value_for_last_step` will be displayed as-is.
Else, an average of the metric over time will be displayed.
"""
values = values or []
for k, v in values:
if k not in self.stateful_metrics:
if k not in self._values:
self._values[k] = [v * (current - self._seen_so_far),
current - self._seen_so_far]
else:
self._values[k][0] += v * (current - self._seen_so_far)
self._values[k][1] += (current - self._seen_so_far)
else:
# Stateful metrics output a numeric value. This representation
# means "take an average from a single value" but keeps the
# numeric formatting.
self._values[k] = [v, 1]
self._seen_so_far = current
now = time.time()
info = ' - %.0fs' % (now - self._start)
if self.verbose == 1:
if (now - self._last_update < self.interval and
self.target is not None and current < self.target):
return
prev_total_width = self._total_width
if self._dynamic_display:
sys.stdout.write('\b' * prev_total_width)
sys.stdout.write('\r')
else:
sys.stdout.write('\n')
if self.target is not None:
numdigits = int(np.floor(np.log10(self.target))) + 1
barstr = '%%%dd/%d [' % (numdigits, self.target)
bar = barstr % current
prog = float(current) / self.target
prog_width = int(self.width * prog)
if prog_width > 0:
bar += ('=' * (prog_width - 1))
if current < self.target:
bar += '>'
else:
bar += '='
bar += ('.' * (self.width - prog_width))
bar += ']'
else:
bar = '%7d/Unknown' % current
self._total_width = len(bar)
sys.stdout.write(bar)
if current:
time_per_unit = (now - self._start) / current
else:
time_per_unit = 0
if self.target is not None and current < self.target:
eta = time_per_unit * (self.target - current)
if eta > 3600:
eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60)
elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60)
else:
eta_format = '%ds' % eta
info = ' - ETA: %s' % eta_format
else:
if time_per_unit >= 1:
info += ' %.0fs/step' % time_per_unit
elif time_per_unit >= 1e-3:
info += ' %.0fms/step' % (time_per_unit * 1e3)
else:
info += ' %.0fus/step' % (time_per_unit * 1e6)
for k in self._values:
info += ' - %s:' % k
if isinstance(self._values[k], list):
avg = np.mean(
self._values[k][0] / max(1, self._values[k][1]))
if abs(avg) > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
else:
info += ' %s' % self._values[k]
self._total_width += len(info)
if prev_total_width > self._total_width:
info += (' ' * (prev_total_width - self._total_width))
if self.target is not None and current >= self.target:
info += '\n'
sys.stdout.write(info)
sys.stdout.flush()
elif self.verbose == 2:
if self.target is None or current >= self.target:
for k in self._values:
info += ' - %s:' % k
avg = np.mean(
self._values[k][0] / max(1, self._values[k][1]))
if avg > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
info += '\n'
sys.stdout.write(info)
sys.stdout.flush()
self._last_update = now
def add(self, n, values=None):
self.update(self._seen_so_far + n, values)
重點(diǎn)是上述代碼中的update(self, current, values=None)函數(shù),在該函數(shù)內(nèi)設(shè)置斷點(diǎn),即可調(diào)入該函數(shù)。下面重點(diǎn)分析上述代碼中的幾個(gè)輸出條目:
1. sys.stdout.write('\n') #換行
2. sys.stdout.write('bar') #輸出 [..................],其中bar= [..................];
3. sys.stdout.write(info) #輸出loss格式,其中info='- ETA:...';
4. sys.stdout.flush() #刷新緩存,立即得到輸出。
通過(guò)對(duì)Mask R-CNN代碼的調(diào)試分析可知,圖1中的紅框中的訓(xùn)練過(guò)程中的Loss格式化輸出是由built-in模塊實(shí)現(xiàn)的。若想得到類似的格式化輸出,關(guān)鍵在self.keras_model.fit_generator函數(shù)中傳入callbacks參數(shù)和callbacks中內(nèi)容的定義。
以上這篇基于Keras的格式化輸出Loss實(shí)現(xiàn)方式就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
web自動(dòng)化測(cè)試Selenium點(diǎn)擊元素的常用方法
在Web自動(dòng)化測(cè)試中,Selenium提供多種點(diǎn)擊方法,常用的click()方法通過(guò)選中元素并觸發(fā)點(diǎn)擊事件,若click()方法不穩(wěn)定,可以采用JavaScript執(zhí)行點(diǎn)擊或使用ActionChains類模擬鼠標(biāo)點(diǎn)擊,需要的朋友可以參考下2024-09-09
python使用matplotlib繪圖時(shí)圖例顯示問(wèn)題的解決
matplotlib 是python最著名的繪圖庫(kù),它提供了一整套和matlab相似的命令A(yù)PI,十分適合交互式地進(jìn)行制圖。下面這篇文章主要給大家介紹了在python使用matplotlib繪圖時(shí)圖例顯示問(wèn)題的解決方法,需要的朋友可以參考學(xué)習(xí),下面來(lái)一起看看吧。2017-04-04
使用Python來(lái)開(kāi)發(fā)Markdown腳本擴(kuò)展的實(shí)例分享
這篇文章主要介紹了使用Python來(lái)開(kāi)發(fā)Markdown腳本擴(kuò)展的實(shí)例分享,文中的示例是用來(lái)簡(jiǎn)單地轉(zhuǎn)換文檔結(jié)構(gòu),主要為了體現(xiàn)一個(gè)思路,需要的朋友可以參考下2016-03-03
Python將一個(gè)Excel拆分為多個(gè)Excel
這篇文章主要為大家詳細(xì)介紹了Python將一個(gè)Excel拆分為多個(gè)Excel,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-11-11
OpenCV實(shí)現(xiàn)手勢(shì)虛擬拖拽的使用示例(附demo)
本文主要介紹了OpenCV實(shí)現(xiàn)手勢(shì)虛擬拖拽的使用示例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-11-11
TensorFlow2.1.0安裝過(guò)程中setuptools、wrapt等相關(guān)錯(cuò)誤指南
這篇文章主要介紹了TensorFlow2.1.0安裝時(shí)setuptools、wrapt等相關(guān)錯(cuò)誤指南,本文通過(guò)安裝錯(cuò)誤分析給出大家解決方案,感興趣的朋友跟隨小編一起看看吧2020-04-04

