Добавил:
Опубликованный материал нарушает ваши авторские права? Сообщите нам.
Вуз: Предмет: Файл:
Шолле Ф. - Глубокое обучение на Python (Библиотека программиста) - 2023.pdf
Скачиваний:
6
Добавлен:
07.04.2024
Размер:
11.34 Mб
Скачать

7.4. Разработка своего цикла обучения и оценки    251

Рис. 7.7. С помощью TensorBoard можно наблюдать, как продвигается обучение и изменяются значения метрик

7.4. РАЗРАБОТКА СВОЕГО ЦИКЛА ОБУЧЕНИЯ И ОЦЕНКИ

Рабочий.процесс.на.основе.метода.fit() .обеспечивает.хороший.баланс.между. простотой.и.гибкостью..Именно.этот.подход.вы.будете.использовать.чаще.всего..Однако.он.не.предназначен.для.поддержки.нужд.исследователей.глубокого. обучения,.даже.несмотря.на.возможность.настройки.метрик,.функций.потерь. и.обратных.вызовов.

В.конце.концов,.подход.с.использованием.метода.fit().ориентирован.исключительно.на.обучение с учителем,.когда.заранее.известны.цели.(также.называемые. метками .или.аннотациями),.связанные.с.входными.данными,.а.потери.вычисляются.как.функция.этих.целей.и.прогнозов.модели..Однако.не.все.формы. машинного.обучения.попадают.в.эту.категорию..В.некоторых.случаях.нет.явных. целей.—.например,.в.генеративном.обучении.(которое.мы.обсудим.в.главе.12),. в.самоконтролируемом обучении.(когда.цели.извлекаются.из.входных.данных). и.в.обучении с подкреплением.(когда.обучение.подкрепляется.«вознаграждениями»,. что.очень.напоминает.дрессировку.собаки)..Даже.если.вы.регулярно.занимаетесь.

252    Глава 7. Работа с Keras: глубокое погружение

обучением.с.учителем,.вам,.как.исследователю,.может.понадобиться.добавить. несколько.новых.опций,.а.для.этого.нужна.гибкость.на.низком.уровне.

Всякий.раз,.оказавшись.в.ситуации,.когда.встроенного.метода.fit().недостаточно,.вам.нужно.будет.написать.свою.логику.обучения..Вы.уже.видели.простые. примеры.низкоуровневых.циклов.обучения.в.главах.2.и.3..Напомню.порядок. проведения.типичного.цикла.обучения.по.основным.этапам.

1.. Выполнить.прямой.проход.(вычислить.выходы.модели).внутри.GradientTape,. чтобы.получить.величину.потерь.для.текущего.пакета.данных.

2.. Получить.градиенты.потерь.с.учетом.весов.модели.

3.. Скорректировать.веса.модели,.чтобы.уменьшить.величину.потерь.на.текущем. пакете.данных.

Эти.шаги.повторяются.для.выбранного.количества.пакетов..Фактически.именно. так.и.действует.метод.fit()..Далее.вы.узнаете,.как.переопределить.fit() .и.написать.свою.реализацию.с.нуля,.что.позволит.вам.в.будущем.создать.любой. алгоритм.обучения,.который.только.вы.придумаете.

А.теперь.перейдем.к.деталям.

7.4.1. Обучение и прогнозирование

В.примерах.низкоуровневого.цикла.обучения,.которые.вы.видели.до.сих.пор,. шаг.1.(прямой.проход).выполнялся.инструкцией.predictions = model(input),. а.шаг.2.(получение.градиентов,.вычисленных.с.помощью.GradientTape).—.ин- струкцией.gradients = tape.gradient(loss, model.weights)..В.общем.случае. следует.учитывать.две.тонкости.

Некоторые.слои.Keras.(такие.как.Dropout).во.время.обучения.и.во.время.прогнози- рования.ведут.себя.по-разному..Метод.call().таких.слоев.принимает.логический. аргумент.training..Вызов.dropout(inputs, training=True) .приведет.к.сбросу. некоторых.активаций,.а.вызов.dropout(inputs, training=False) .—.нет..Кроме. того,.метод.call() .функциональных.и.последовательных.моделей.тоже.поддерживает.аргумент.training..Важно.не.забывать.передавать.training=True.при. выполнении.прямого.прохода.модели.Keras!.То.есть.прямой.проход.фактически. должен.выполняться.инструкцией.predictions = model(inputs, training=True).

Также.обратите.внимание,.что.для.получения.градиентов.весов.модели.следует. использовать.не.tape.gradients(loss, model.weights),.а.tape.gradients(loss, . model.trainable_weights)..В.действительности.слои.и.модели.обладают.двумя. видами.весов,.такими.как:

.обучаемые веса.—.предназначены.для.обновления.на.этапе.обратного.рас- пространения.ошибки,.чтобы.минимизировать.потери.модели,.такие.как. ядро.и.систематическая.ошибка.слоя.Dense;

7.4. Разработка своего цикла обучения и оценки    253

.необучаемые веса.—.предназначены.для.обновления.на.этапе.прямого.про- хода.слоями,.которым.они.принадлежат..Например,.если.вы.решите.добавить.в.свой.слой.счетчик.пакетов,.обработанных.к.данному.моменту,.то.эта. информация.будет.храниться.в.необучаемом.весе.и.после.обработки.каждого. пакета.ваш.слой.будет.увеличивать.счетчик.на.единицу.

Из.встроенных.слоев.Keras.необучаемые.веса.имеет.только.слой.BatchNorma­ lization, .который .мы .обсудим .в .главе .9. .Необучаемые .веса .нужны .слою. BatchNormalization .для.запоминания.среднего.и.стандартного.отклонения. обрабатываемых.данных,.чтобы.потом.динамически.выполнить.нормализацию признаков.(с.этой.идеей.вы.познакомились.в.главе.6).

Принимая.во.внимание.эти.две.детали,.этап.обучения.с.учителем.в.конечном. итоге.будет.выглядеть.следующим.образом:

def train_step(inputs, targets): with tf.GradientTape() as tape:

predictions = model(inputs, training=True) loss = loss_fn(targets, predictions)

gradients = tape.gradients(loss, model.trainable_weights) optimizer.apply_gradients(zip(model.trainable_weights, gradients))

7.4.2. Низкоуровневое использование метрик

В.низкоуровневом.цикле.обучения.часто.возникает.необходимость.использовать. метрики.Keras.(и.стандартные,.и.нестандартные)..Вы.уже.познакомились.с.методами.поддержки.метрик:.просто.вызовите.update_state(y_true, y_pred).для.каждого. пакета.целей.и.прогнозов.и.result().для.получения.текущего.значения.метрики:

metric = keras.metrics.SparseCategoricalAccuracy() targets = [0, 1, 2]

predictions = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] metric.update_state(targets, predictions) current_result = metric.result() print(f"result: {current_result:.2f}")

Вам.также.может.потребоваться.отслеживать.среднее.значение.скаляра,.например .величины .потери .модели..Это .можно .сделать .с .помощью .метрики. keras.metrics.Mean:

values = [0, 1, 2, 3, 4] mean_tracker = keras.metrics.Mean() for value in values:

mean_tracker.update_state(value)

print(f"Mean of values: {mean_tracker.result():.2f}")

Не.забудьте.вызвать.metric.reset_state(),.когда.понадобится.сбросить.текущий. результат.(в.начале.эпохи.обучения.или.в.начале.этапа.оценки).

Возврат текущих значений метрик и потерь

254    Глава 7. Работа с Keras: глубокое погружение

7.4.3. Полный цикл обучения и оценки

Давайте.теперь.объединим.прямой.проход,.обратное.распространение.ошибки. и.отслеживание.метрик.в.шаговую.функцию.(training.step.function),.подобную. fit(),.которая.принимает.пакет.данных.и.цели.и.возвращает.сведения,.которые. будут.отображаться.в.индикаторе.выполнения.fit().

Листинг 7.21. Разработка своего цикла обучения: шаговая функция

Подготовка

model = get_mnist_model() функции потерь Подготовка

оптимизатора

loss_fn = keras.losses.SparseCategoricalCrossentropy()

 

 

 

 

 

 

Подготовка

 

 

 

 

 

 

 

 

 

optimizer = keras.optimizers.RMSprop()

 

 

 

 

 

 

 

 

списка метрик

 

 

 

 

 

 

 

metrics = [keras.metrics.SparseCategoricalAccuracy()]

 

 

 

 

 

 

для отслеживания

 

 

 

 

 

 

loss_tracking_metric = keras.metrics.Mean()

 

 

Подготовка метрики Mean для слежения

 

 

 

 

 

def train_step(inputs, targets):

 

за средним значением потерь

 

 

 

 

 

 

 

 

 

with tf.GradientTape() as tape:

 

 

Выполнение прямого

 

 

predictions = model(inputs, training=True)

 

 

прохода. Обратите внимание

loss = loss_fn(targets, predictions)

 

 

на аргумент training=True

gradients = tape.gradient(loss, model.trainable_weights) optimizer.apply_gradients(zip(gradients, model.trainable_weights))

logs = {}

for metric in metrics: metric.update_state(targets, predictions) logs[metric.name] = metric.result()

loss_tracking_metric.update_state(loss) logs["loss"] = loss_tracking_metric.result() return logs

Обратное распространение ошибки. Обратите внимание, что здесь используется model.trainable_weights

Слежение за метриками

Слежение за средним значением потерь

Важно.не.забыть.сбросить.состояние.метрик.в.начале.каждой.эпохи.и.перед.началом.этапа.оценки..Вот.вспомогательная.функция,.которая.сделает.это.

Листинг 7.22. Разработка своего цикла обучения: сброс метрик

def reset_metrics():

for metric in metrics: metric.reset_state()

loss_tracking_metric.reset_state()

Теперь.можно.закончить.реализацию.цикла.обучения..Обратите.внимание,.что. здесь.используется.объект.tf.data.Dataset,.превращающий.массив.NumPy. с.данными.в.итератор,.который.выполняет.итерации.по.данным.пакетами.размером.32.

7.4. Разработка своего цикла обучения и оценки    255

Листинг 7.23. Разработка своего цикла обучения: сам цикл

training_dataset = tf.data.Dataset.from_tensor_slices( (train_images, train_labels))

training_dataset = training_dataset.batch(32) epochs = 3

for epoch in range(epochs): reset_metrics()

for inputs_batch, targets_batch in training_dataset: logs = train_step(inputs_batch, targets_batch)

print(f"Results at the end of epoch {epoch}") for key, value in logs.items():

print(f"...{key}: {value:.4f}")

Ниже.приводится.цикл.оценки:.простой.цикл.for,.многократно.вызывающий. функцию.test_step(),.которая.обрабатывает.один.пакет.данных..Функция.test_ step().—.лишь.подмножество.логики.train_step()..В.ней.отсутствует.код,.обнов- ляющий.веса.модели,.то.есть.все,.что.связано.с.GradientTape.и.оптимизатором.

Листинг 7.24. Разработка своего цикла обучения: цикл оценки

def test_step(inputs, targets):

predictions = model(inputs, training=False) loss = loss_fn(targets, predictions)

logs = {}

for metric in metrics: metric.update_state(targets, predictions) logs["val_" + metric.name] = metric.result()

loss_tracking_metric.update_state(loss) logs["val_loss"] = loss_tracking_metric.result() return logs

Обратите внимание

на аргумент training=False

val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels)) val_dataset = val_dataset.batch(32)

reset_metrics()

for inputs_batch, targets_batch in val_dataset: logs = test_step(inputs_batch, targets_batch)

print("Evaluation results:") for key, value in logs.items():

print(f"...{key}: {value:.4f}")

Поздравляю,.вы.только.что.реализовали.свои.полноценные.версии.функций. fit().и.evaluate()!.Ну.или.почти.полноценные:.на.самом.деле.fit().и.evaluate() . реализуют.множество.других.возможностей,.включая.крупномасштабные.распределенные.вычисления,.которые.требуют.немного.больше.работы..Они.также. включают.некоторые.важные.оптимизации.производительности.

Давайте .рассмотрим .одну .из .таких .оптимизаций: .компиляцию .функции. TensorFlow.