Time, stand still, you are beautiful

Still calming my mind by copying out tutorial code. Elve here.

snack.elve.club
www.tensorflow.org

It more or less runs and the classification works, so I tried downloading just the trained results to run them locally, and that's where I'm stuck.
Google Colab's Python was version 3.7.12, you see.
My machine is already on 3.10+ (ノД`) Development, just stop for a moment already!! (title callback)
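
(For reference, a quick way to compare the two sides; this little snippet is my own addition, not part of the tutorial.)

import sys
import tensorflow as tf

# Print the interpreter and TensorFlow versions on Colab and locally to compare
print(sys.version)
print(tf.__version__)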

Anyway, here's the code that currently works.
Apparently fit_generator is going to be removed at some point, so I'll have to deal with that eventually (there's a sketch of the model.fit version right after the code) (^_^;)

import tensorflow as tf

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing.image import img_to_array, load_img
import os
import numpy as np
import matplotlib.pyplot as plt
train_dir = os.path.join("./drive/MyDrive/img/", 'train')
validation_dir = os.path.join("./drive/MyDrive/img/", 'validation')

train_pic_dir = os.path.join(train_dir, '学習用ピクミン画像')  # training Pikmin images
train_etc_dir = os.path.join(train_dir, '学習用その他')  # training "other" images
validation_pic_dir = os.path.join(validation_dir, 'ピクミン')  # validation Pikmin images
validation_etc_dir = os.path.join(validation_dir, 'その他')  # validation "other" images

num_pic_tr = len(os.listdir(train_pic_dir))
num_etc_tr = len(os.listdir(train_etc_dir))

num_pic_val = len(os.listdir(validation_pic_dir))
num_etc_val = len(os.listdir(validation_etc_dir))

total_train = num_pic_tr + num_etc_tr
total_val = num_pic_val + num_etc_val

print('total training pikmin images:', num_pic_tr)
print('total training other images:', num_etc_tr)

print('total validation pikmin images:', num_pic_val)
print('total validation other images:', num_etc_val)
print("--")
print("Total training images:", total_train)
print("Total validation images:", total_val)


batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150
train_image_generator = ImageDataGenerator(rescale=1./255) # generator for training data
validation_image_generator = ImageDataGenerator(rescale=1./255) # generator for validation data

train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary')
val_data_gen = validation_image_generator.flow_from_directory(batch_size=batch_size,
                                                              directory=validation_dir,
                                                              target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                              class_mode='binary')

sample_training_images, _ = next(train_data_gen)
# This function plots images in a 1x5 grid, one image per column.
def plotImages(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip( images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()
plotImages(sample_training_images[:5])

model = Sequential([
    Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH ,3)),
    MaxPooling2D(),
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.summary()

history = model.fit_generator(
    train_data_gen,
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size
)

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
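
Since fit_generator is deprecated, here's a minimal sketch of the model.fit version I'll probably switch to at some point. My assumption: TF 2.1 or later, where fit accepts the generators directly; the variable names are the ones from the code above, and I haven't actually run this yet.

# Sketch: model.fit takes the same generators in TF 2.x (assumption, untested here)
history = model.fit(
    train_data_gen,
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size
)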

Running this, the accuracy/loss plots come out like this:
f:id:elve:20220115134321p:plain
Doesn't really seem to need any tuning, but I'll keep copying the tutorial (I hit the snag before getting the augmentation into actual training).

# Horizontal flip
image_gen = ImageDataGenerator(rescale=1./255, horizontal_flip=True)

train_data_gen = image_gen.flow_from_directory(batch_size=batch_size,
                                               directory=train_dir,
                                               shuffle=True,
                                               target_size=(IMG_HEIGHT, IMG_WIDTH))
augmented_images = [train_data_gen[0][0][0] for i in range(5)]
# Reuse the same custom plotting function defined above for visualizing the training images
plotImages(augmented_images)

# Randomly rotate within a range of ±45 degrees
image_gen = ImageDataGenerator(rescale=1./255, rotation_range=45)
train_data_gen = image_gen.flow_from_directory(batch_size=batch_size,
                                               directory=train_dir,
                                               shuffle=True,
                                               target_size=(IMG_HEIGHT, IMG_WIDTH))

augmented_images = [train_data_gen[0][0][0] for i in range(5)]
plotImages(augmented_images)
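
The tutorial's next step combines these options into a single training generator, which is the part I haven't gotten to yet. A rough sketch of what that would look like; the exact parameter values here are my guess, not something I ran.

# Sketch: combine rescaling, rotation, horizontal flip, and zoom in one training generator
image_gen_train = ImageDataGenerator(
    rescale=1./255,
    rotation_range=45,
    horizontal_flip=True,
    zoom_range=0.5
)
train_data_gen = image_gen_train.flow_from_directory(
    batch_size=batch_size,
    directory=train_dir,
    shuffle=True,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary'
)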

Here I got bored, so the plan is to save the training results and load them next time.

p = './drive/MyDrive/'
model_json_str = model.to_json()
with open(p + 'model.json', 'w') as f:
    f.write(model_json_str)
model.save_weights(p + 'weights.h5')
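
As an aside, a sketch of saving the whole model (architecture plus weights) in one file with model.save instead, which might make reloading simpler. The filename here is just a placeholder of mine.

# Sketch: save architecture + weights together in one HDF5 file (placeholder filename)
model.save(p + 'pikmin_model.h5')

# To reload later:
# from tensorflow.keras.models import load_model
# model = load_model(p + 'pikmin_model.h5')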

Let's try classifying a Pikmin image.

## Loading the model next time:
# from tensorflow.keras.models import model_from_json
# from tensorflow.keras.preprocessing.image import img_to_array, load_img
# p = './drive/MyDrive/'
# print(p+'model.json')
## Load the model architecture
# model = model_from_json(open(p+'model.json').read())
## Load the trained weights
# model.load_weights(p + 'weights.h5')

img_path = os.path.join(validation_dir,  '20220113_115516000_iOS.png')
print(img_path)
img = load_img(img_path, target_size=(150, 150))
img_nad = img_to_array(img) / 255  # convert to array and rescale to [0, 1]
img_nad = img_nad[None, ...]       # add a batch dimension

label = ['etc', 'pic']  # index order matches train_data_gen.class_indices (folders sorted by name)
pred = model.predict(img_nad, batch_size=1, verbose=0)
# The model has a single sigmoid output, so argmax over one value is always 0;
# threshold at 0.5 instead (>= 0.5 means class 1)
pred_label = label[int(pred[0][0] >= 0.5)]
if pred_label == 'pic':
  print("pic!!")
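
By the way, to double-check which folder ended up as class 0 and which as class 1, printing the generator's class_indices works (flow_from_directory assigns indices in sorted folder-name order; this assumes the training generator is still around in the session).

# Sketch: confirm the label-to-index mapping used during training
print(train_data_gen.class_indices)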

It detected the Pikmin with over 90% probability, no problem. Nice!