Добавил:
Опубликованный материал нарушает ваши авторские права? Сообщите нам.
Вуз: Предмет: Файл:
Шолле Ф. - Глубокое обучение на Python (Библиотека программиста) - 2023.pdf
Скачиваний:
6
Добавлен:
07.04.2024
Размер:
11.34 Mб
Скачать

7.3. Встроенные циклы обучения и оценки    243

быть.представлены.в.виде.графов.слоев..Однако.мы.будем.часто.использовать. подклассы.слоев..В.целом.функциональные.модели,.включающие.подклассы,. позволяют.сочетать.лучшее.из.обоих.миров:.высокую.гибкость.разработки.при. сохранении.преимуществ.функционального.API.

7.3. ВСТРОЕННЫЕ ЦИКЛЫ ОБУЧЕНИЯ И ОЦЕНКИ

Принцип.постепенного.раскрытия.сложности.—.доступ.к.спектру.рабочих.процес- сов.шаг.за.шагом,.от.предельно.простых.до.бесконечно.гибких.—.также.применим. для.обучения.моделей..Библиотека.Keras.предлагает.разные.подходы.к.обучению. моделей,.от.несложных,.таких.как.вызов.fit() .с.обучающими.данными,.до.продвинутых,.связанных.с.разработкой.новых.алгоритмов.обучения.с.нуля.

Вы.уже.знакомы.с.последовательностью.вызовов.compile(),.fit(),.evaluate(),. predict()..Чтобы.вспомнить.ее,.взгляните.на.следующий.листинг.

Листинг 7.17. Стандартный рабочий процесс: compile(), fit(), evaluate(), predict()

 

from tensorflow.keras.datasets import mnist

 

Создание модели (вынесено

 

 

 

 

 

 

 

 

в отдельную функцию, чтобы иметь

 

def get_mnist_model():

 

 

возможность использовать ее позже)

 

 

 

 

inputs = keras.Input(shape=(28 * 28,))

 

 

 

 

 

 

 

 

 

 

features = layers.Dense(512, activation="relu")(inputs)

 

features = layers.Dropout(0.5)(features)

 

 

 

 

 

 

 

 

 

 

outputs = layers.Dense(10, activation="softmax")(features)

 

model = keras.Model(inputs, outputs)

 

 

 

 

 

Загрузка данных

 

 

 

 

 

 

 

 

 

return model

 

 

 

 

 

для обучения

 

 

 

 

 

 

 

 

и проверки

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

(images, labels), (test_images, test_labels) = mnist.load_data()

 

 

 

 

 

 

 

 

images = images.reshape((60000, 28 * 28)).astype("float32") / 255

 

test_images = test_images.reshape((10000, 28 * 28)).astype("float32") / 255

 

train_images, val_images = images[10000:], images[:10000]

 

 

 

 

 

 

train_labels, val_labels = labels[10000:], labels[:10000]

 

 

 

 

 

 

model = get_mnist_model()

 

 

 

 

Компиляция модели

 

 

 

 

 

 

 

 

 

 

с указанными

 

model.compile(optimizer="rmsprop",

 

 

 

 

 

 

 

 

 

оптимизатором, функцией

 

loss="sparse_categorical_crossentropy",

 

 

потерь, которая должна

 

 

 

 

metrics=["accuracy"])

 

 

 

 

минимизироваться,

 

model.fit(train_images, train_labels,

 

 

 

 

и метриками для оценки

 

 

 

 

 

 

epochs=3,

 

 

 

 

 

Вызов evaluate()

 

validation_data=(val_images, val_labels))

 

 

test_metrics = model.evaluate(test_images, test_labels)

 

 

для вычисления потерь

 

 

 

и метрик на контрольных

 

predictions = model.predict(test_images)

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

данных

Вызов fit() для обучения модели; при необходимости можно

 

 

 

 

 

 

 

Вызов predict() с контрольными

передать проверочные данные для мониторинга качества

 

данными для вычисления вероятностей

модели на данных, отличных от обучающих

 

принадлежности к тому или иному классу

Согласно нашей модели MNIST ожидаются категориальные прогнозы и целочисленные метки

244    Глава 7. Работа с Keras: глубокое погружение

Данный.простой.процесс.можно.скорректировать:

. указав.свои.метрики;

.передав.функции обратного вызова.методу.fit(),.чтобы.выполнить.некоторые. действия.в.определенные.моменты.обучения.

Посмотрим,.как.это.сделать.

7.3.1. Использование собственных метрик

Метрики.являются.ключом.к.оценке.качества.модели.—.в.частности,.они.позво- ляют.измерить.разницу.качества.модели.на.обучающих.и.контрольных.данных..

Метрики,.обычно.используемые.для.классификации.и.регрессии,.уже.включены. в.стандартный.модуль.keras.metrics,.и.в.большинстве.случаев.вы.будете.брать. именно.их..Но.иногда,.особенно.при.решении.необычных.задач,.вам.может.понадобиться.умение.писать.свои.метрики..Это.просто!

Метрики.в.Keras.являются.подклассом.класса.keras.metrics.Metric..Подобно.слоям,.метрики.имеют.внутреннее.состояние,.хранящееся.в.переменных. TensorFlow..Но,.в.отличие.от.слоев,.эти.переменные.не.обновляются.на.этапе. обратного.распространения,.поэтому.вам.придется.написать.свою.логику.их. обновления.в.методе.update_state().

Вот.пример.простой.метрики,.измеряющей.среднеквадратичную.ошибку.(Root. Mean.Squared.Error,.RMSE).

Листинг 7.18. Реализация метрики путем создания класса, производного от класса Metric

import tensorflow as tf

Подкласс

 

 

class RootMeanSquaredError(keras.metrics.Metric):

 

класса Metric

 

 

def __init__(self, name="rmse", **kwargs): super().__init__(name=name, **kwargs)

self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros") self.total_samples = self.add_weight(

name="total_samples", initializer="zeros", dtype="int32")

def update_state(self, y_true, y_pred, sample_weight=None): y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1]) mse = tf.reduce_sum(tf.square(y_true - y_pred)) self.mse_sum.assign_add(mse)

num_samples = tf.shape(y_pred)[0] self.total_samples.assign_add(num_samples)

Определение в конструкторе переменных

update_state() реализует логику обновления состояния. Аргумент

для хранения состояния. По аналогии

y_true — это цели (или метки) для одного пакета, а y_pred

со слоями есть возможность использовать

представляет соответствующие прогнозы модели. Аргумент

метод add_weight()

sample_weight можно игнорировать — здесь он не используется

7.3. Встроенные циклы обучения и оценки    245

Для.получения.текущего.значения.метрики.нужно.реализовать.метод.result():

def result(self):

return tf.sqrt(self.mse_sum / tf.cast(self.total_samples, tf.float32))

Также.следует.предоставить.возможность.сбросить.метрику.в.исходное.состояние.без.необходимости.создавать.ее.повторно..Это.позволит.использовать. одни.и.те.же.объекты.метрик.в.разные.эпохи.обучения.или.на.этапах.и.обучения,. и.оценки..Для.этого.достаточно.реализовать.метод.reset_state():

def reset_state(self): self.mse_sum.assign(0.) self.total_samples.assign(0)

Нестандартные.метрики.используются.точно.так.же,.как.встроенные..Давайте. протестируем.нашу.метрику:

model = get_mnist_model() model.compile(optimizer="rmsprop",

loss="sparse_categorical_crossentropy", metrics=["accuracy", RootMeanSquaredError()])

model.fit(train_images, train_labels, epochs=3,

validation_data=(val_images, val_labels)) test_metrics = model.evaluate(test_images, test_labels)

Теперь.индикатор.выполнения.fit() .будет.отображать.среднеквадратичную. ошибку.модели.

7.3.2. Использование обратных вызовов

Запуск.процедуры.обучения.продолжительностью.в.десятки.эпох.на.большом. наборе.данных.вызовом.model.fit().напоминает.запуск.бумажного.самолетика:. придав.начальный.импульс,.вы.больше.никак.не.управляете.ни.траекторией.его. полета,.ни.местом.приземления..Чтобы.избежать.отрицательных.результатов. (и.потери.самолетика),.лучше.использовать.не.бумажный.самолетик,.а.управляемый.беспилотник,.анализирующий.окружающую.обстановку,.посылающий. информацию.о.ней.обратно.оператору.и.автоматически.управляющий.рулями.

в.зависимости.от.своего.текущего.состояния..Поддержка.обратных вызовов.

в.Keras.поможет.превратить.вызов.model.fit() .из.бумажного.самолетика.в.интеллектуальный,.автономный.беспилотник,.способный.оценивать.свое.состояние. и.своевременно.выполнять.управляющие.действия.

Обратный вызов.—.это.объект.(экземпляр.класса,.реализующего.конкретные. методы),.который.передается.в.модель.через.вызов.fit() .и.который.будет.вызываться.моделью.в.разные.моменты.обучения..Он.имеет.доступ.ко.всей.информации.о.состоянии.модели.и.ее.качестве.и.может.предпринимать.следующие.

246    Глава 7. Работа с Keras: глубокое погружение

действия:.прерывать.обучение,.сохранять.модель,.загружать.разные.наборы. весов.или.как-то.иначе.изменять.состояние.модели.

Вот.несколько.примеров.использования.обратных.вызовов:

.фиксация состояния модели в контрольных точках.—.сохранение.текущего. состояния.модели.в.разные.моменты.в.ходе.обучения;

.ранняя остановка.—.прерывание.обучения,.когда.оценка.потерь.на.провероч- ных.данных.перестает.улучшаться.(и,.конечно,.сохранение.лучшего.варианта. модели,.полученного.в.ходе.обучения);

.динамическая корректировка значений некоторых параметров в процессе обучения.—.например,.шага.обучения.оптимизатора;

.журналирование оценок для обучающего и проверочного наборов данных в ходе обучения или визуализация представлений, получаемых моделью, по мере их обновления.—.индикатор.выполнения.в.fit(),.с.которым.вы.уже.знакомы,.—. это.на.самом.деле.обратный.вызов!

Модуль.keras.callbacks .включает.ряд.встроенных.обратных.вызовов..Вот.далеко.не.полный.список:

keras.callbacks.ModelCheckpoint

keras.callbacks.EarlyStopping

keras.callbacks.LearningRateScheduler

keras.callbacks.ReduceLROnPlateau

keras.callbacks.CSVLogger

Рассмотрим.некоторые.из.них,.чтобы.понять,.как.ими.пользоваться:.EarlyStopping . и.ModelCheckpoint.

Обратные вызовы EarlyStopping и ModelCheckpoint

Многие.аспекты.обучения.модели.нельзя.предсказать.заранее.—.например,.ко- личество.эпох,.обеспечивающее.оптимальное.значение.потерь.на.проверочном. наборе..В.примерах,.приводившихся.до.сих.пор,.использовалась.стратегия.обучения.с.достаточно.большим.количеством.эпох..Таким.способом.достигался. эффект.переобучения:.когда.сначала.выполнялся.первый.прогон,.чтобы.выяс- нить.необходимое.количество.эпох.обучения,.а.затем.второй.—.новый.—.с.этим. количеством..Конечно,.данная.стратегия.довольно.расточительная..Гораздо. лучше.было.бы.остановить.обучение,.как.только.выяснится,.что.оценка.потерь. на.проверочном.наборе.перестала.улучшаться..Это.можно.реализовать.с.использованием.обратного.вызова.EarlyStopping.

Обратный.вызов.EarlyStopping.прерывает.процесс.обучения,.если.находящаяся. под.наблюдением.целевая.метрика.не.улучшалась.на.протяжении.заданного. количества.эпох..Он.позволит.остановить.обучение.после.наступления.эффекта.