Добавил:
Опубликованный материал нарушает ваши авторские права? Сообщите нам.
Вуз: Предмет: Файл:
Шолле Ф. - Глубокое обучение на Python (Библиотека программиста) - 2023.pdf
Скачиваний:
6
Добавлен:
07.04.2024
Размер:
11.34 Mб
Скачать

4.1. Классификация отзывов к фильмам    143

Что.касается.выбора.оптимизатора,.то.в.этой.модели.мы.будем.использовать. rmsprop .—.хороший.вариант.по.умолчанию.для.большинства.задач.

Теперь.настроим.модель,.передав.ей.оптимизатор.rmsprop .и.функцию.потерь. binary_crossentropy..Обратите.внимание,.что.мы.также.задали.мониторинг. точности.во.время.обучения.

Листинг 4.5. Компиляция модели

model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

4.1.4. Проверка решения

Как.отмечалось.в.главе.3,.качество.модели.глубокого.обучения.никогда.не.долж- но.оцениваться.на.обучающих.данных.—.обычно.для.мониторинга.изменения. точности.модели.во.время.обучения.используется.проверочная.выборка..Создадим.такую.выборку,.включив.в.нее.10.000.образцов.из.оригинального.набора. обучающих.данных.

Листинг 4.6. Создание проверочного набора

x_val = x_train[:10000] partial_x_train = x_train[10000:] y_val = y_train[:10000] partial_y_train = y_train[10000:]

Теперь.проведем.обучение.модели.в.течение.20.эпох.(выполнив.20.итераций. по.всем.образцам.в.обучающей.выборке).пакетами.по.512.образцов..В.то.же. время .будем .следить .за .потерями .и .точностью .по .10 .000 .отложенных .образцов..Для .этого .достаточно .передать .проверочные .данные .в .аргументе. validation_data.

Листинг 4.7. Обучение модели

history = model.fit(partial_x_train, partial_y_train, epochs=20, batch_size=512,

validation_data=(x_val, y_val))

При.использовании.CPU.на.каждую.эпоху.будет.потрачено.менее.2.секунд.—. а.все.обучение.закончится.через.20.секунд..В.конце.каждой.эпохи.обучение. приостанавливается.для.вычисления.потерь.и.точности.на.10.000.образцах. проверочных.данных.

b означает solid blue line — «сплошная синяя линия»
bo означает blue dot — «синяя точка»

144    Глава 4. Начало работы с нейронными сетями: классификация и регрессия

Обратите.внимание,.что.вызов.model.fit() .возвращает.объект.History,.с.ко- торым.вы.познакомились.в.главе.3..Этот.объект.имеет.поле.history .—.словарь. с.данными.обо.всем.происходящем.в.процессе.обучения..Заглянем.в.него:

>>>history_dict = history.history

>>>history_dict.keys()

[u"accuracy", u"loss", u"val_accuracy", u"val_loss"]

Словарь.содержит.четыре.элемента.—.по.одному.на.метрику,.—.за.которыми. осуществлялся.мониторинг.в.процессе.обучения.и.проверки..В.следующих.двух. листингах.используется.библиотека.Matplotlib.для.вывода.графиков.потерь. (рис..4.4).и.графиков.точности.на.этапах.обучения.и.проверки.(рис..4.5)..Имейте.в.виду,.что.ваши.результаты.могут.несколько.отличаться:.это.обусловлено. различием.в.случайных.числах,.использовавшихся.для.инициализации.сети.

Рис. 4.4. Потери на этапах обучения и проверки

Листинг 4.8. Формирование графиков потерь на этапах обучения и проверки

import matplotlib.pyplot as plt history_dict = history.history loss_values = history_dict["loss"] val_loss_values = history_dict["val_loss"] epochs = range(1, len(loss_values) + 1)

plt.plot(epochs, loss_values, "bo", label="Потери на этапе обучения") plt.plot(epochs, val_loss_values, "b", label="Потери на этапе проверки") plt.title("Потери на этапах обучения и проверки")

plt.xlabel("Эпохи") plt.ylabel("Потери") plt.legend() plt.show()

4.1. Классификация отзывов к фильмам 

145

Рис. 4.5. Точность на этапах обучения и проверки

 

 

Листинг 4.9. Формирование графиков точности на этапах обучения и проверки

Очистить

рисунок

plt.clf()

acc = history_dict["accuracy"] val_acc = history_dict["val_accuracy"]

plt.plot(epochs, acc, "bo", label="Точность на этапе обучения") plt.plot(epochs, val_acc, "b", label="Точность на этапе проверки") plt.title("Точность на этапах обучения и проверки") plt.xlabel("Эпохи")

plt.ylabel("Точность") plt.legend() plt.show()

Как.видите,.на.этапе.обучения.потери.снижаются.с.каждой.эпохой,.а.точность. растет..Именно.такое.поведение.ожидается.от.оптимизации.градиентным.спуском:.величина,.которую.вы.пытаетесь.минимизировать,.должна.становиться. все.меньше.с.каждой.итерацией..Но.это.не.относится.к.потерям.и.точности. на.этапе.проверки:.похоже,.что.они.достигли.пика.в.четвертую.эпоху..Перед. вами.пример.того,.о.чем.мы.говорили.выше:.модель,.показывающая.хорошие. результаты.на.обучающих.данных,.не.обязательно.даст.такие.же.на.данных,. которые.не.видела.прежде..Выражаясь.точнее,.в.данном.случае.наблюдается. переобучение:.после.четвертой.эпохи.произошла.чрезмерная.оптимизация.на. обучающих.данных.—.в.результате.получилось.представление,.характерное. для.обучающих.данных.и.не.обобщающее.данные.за.пределами.обучающего. набора.

146    Глава 4. Начало работы с нейронными сетями: классификация и регрессия

В.данном.случае.для.предотвращения.переобучения.можно.прекратить.обучение.после.четвертой.эпохи..Вообще,.есть.целый.спектр.приемов,.ослабляющих. эффект.переобучения,.—.их.мы.рассмотрим.в.главе.5.

А.теперь.обучим.новую.модель.с.нуля.в.течение.четырех.эпох.и.затем.оценим. получившийся.результат.на.контрольных.данных.

Листинг 4.10. Обучение новой модели с нуля

model = keras.Sequential([ layers.Dense(16, activation="relu"), layers.Dense(16, activation="relu"), layers.Dense(1, activation="sigmoid")

])

model.compile(optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"])

model.fit(x_train, y_train, epochs=4, batch_size=512) results = model.evaluate(x_test, y_test)

Конечные.результаты:

>>> results

 

Первое число (0,29) — это потери

 

[0.2929924130630493, 0.88327999999999995]

 

 

на контрольной выборке; второе число (0,88) —

 

 

 

 

 

точность на контрольной выборке

Это.простейшее.решение.позволило.достичь.точности.88.%..При.использовании.же.самых.современных.подходов.точность.может.доходить.до.95.%.

4.1.5. Использование обученной сети для предсказаний на новых данных

После.обучения.модели.ее.можно.использовать.для.решения.практических. задач..Например,.попробуем.оценить.вероятность.того,.что.отзывы.будут.положительными,.с.помощью.метода.predict:

>>> model.predict(x_test)

array([[

0.98006207]

[

0.99758697]

[

0.99975556]

...,

[

0.82167041]

[

0.02885115]

[

0.65371346]], dtype=float32)

Как.видите,.модель.уверена.в.одних.образцах.(0,99.или.выше.или.0,01.или.ниже),. но.не.так.уверена.в.других.(0,6;.0,4).