PyTorchで画像セグメンテーション

segmentation-models.pytorchを用いた、マルチクラスの画像セグメンテーションの実装について解説します

images/cards/PyTorch_logo_black.svg.webp

目次

1. まずは、開発環境を準備

ここでは、Ubuntu 環境を想定して解説します。もちろん Windows でも問題ありません。

  • Ubuntu 20.04
  • GeForce RTX 2080 Ti
  • CUDA Version: 11.5
  • Python 3.8
  • PyTorch 1.4.0
  • PyTorchVision 0.5.0

2. segmentation-models.pytorch をインストール

Tensorflow(Keras)時代からお世話になっていますが、Segmentaion Models のライブラリを利用します。このライブラリの特徴を以下に列挙します。

  • High level API (just two lines to create a neural network)
  • 9 models architectures for binary and multi class segmentation (including legendary Unet)
  • 124 available encoders (and 500+ encoders from timm)
  • All encoders have pre-trained weights for faster and better convergence
  • Popular metrics and losses for training routines
pipenv install segmentation-models-pytorch

以下、examples/cars segmentation (camvid).ipynbの内容を少し改変して解説していきます。

3. データセットの準備

アノテーションしたラベル画像(インデックスカラー画像)を用意します。ない場合は、CamVid のデータセットを用いることもできます。用意できたデータセットを以下のようなフォルダ構成で分けておきます。train : val : test の比率は、ここでは 8:1:1 としました。

  • train: 訓練用の画像データ(jpg 形式)
  • trainannot: 訓練用の画像データに対するアノテーションしたラベル画像(png 形式、ファイル名は同じにしておく)
  • val: 検証用の画像データ(jpg 形式)
  • valannot: 検証用の画像データに対するアノテーションしたラベル画像(png 形式、ファイル名は同じにしておく)
  • test: 評価用の画像データ(jpg 形式)
  • testannot: 評価用の画像データに対するアノテーションしたラベル画像(png 形式、ファイル名は同じにしておく)
DATA_DIR = './dataset'
x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'trainannot')
x_valid_dir = os.path.join(DATA_DIR, 'val')
y_valid_dir = os.path.join(DATA_DIR, 'valannot')
x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'testannot')

# helper function for data visualization
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()

4. Dataloader

データセットから読み出すデータローダーでは、インデックスカラー画像から one-hot 表現に変換したデータを返すようにします。 以下を「dataloader.py」として保存します。

class Dataset(BaseDataset):
    CLASSES = ['0', '1', '2']
    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):

        # read data
        image = cv2.imread(self.images_fps[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_mask = Image.open(self.masks_fps[i].replace(".jpg", ".png"))
        mask = np.asarray(pil_mask)

        # extract certain classes from mask (e.g. cars)
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.ids)

5. Augmentaion

Albumentations とは

Albumentations は、機械学習用データ拡張用 の Python ライブラリで、Data augmentation でよく使われる機能が豊富に揃っています。以下を「preprocess.py」として保存します。

import albumentations as albu

def get_training_augmentation():
    train_transform = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
        albu.PadIfNeeded(min_height=320, min_width=320, always_apply=True, border_mode=0),
        albu.RandomCrop(height=320, width=320, always_apply=True),
        albu.IAAAdditiveGaussianNoise(p=0.2),
        albu.IAAPerspective(p=0.5),
        albu.OneOf(
            [
                albu.CLAHE(p=1),
                albu.RandomBrightness(p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),
        albu.OneOf(
            [
                albu.IAASharpen(p=1),
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),
        albu.OneOf(
            [
                albu.RandomContrast(p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return albu.Compose(train_transform)

def get_validation_augmentation():
    test_transform = [
        albu.PadIfNeeded(384, 480)
    ]
    return albu.Compose(test_transform)

def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

def get_preprocessing(preprocessing_fn):
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

6. セグメンテーション用のモデルを生成し、学習!

  • セグメンテーションモデルには、FPN (Feature Pyramid Networks)を用いています。
  • 損失関数には Dice 係数(DiceLoss)、評価には IoU(Jaccard 係数)、Optimizer には Adam を設定しています。検証用データに対して IoU スコアが更新された時に、モデル情報を保存しています。
  • エンコーダーの学習済パラメータをダウンロードする際にエラーが出たので、「import ssl」の 2 行を追加しています。
import torch
import numpy as np
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from dataloader import Dataset
from preprocess import get_training_augmentation, get_preprocessing, get_validation_augmentation
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['0', '1', '2']
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'

# create segmentation model with pretrained encoder
model = smp.FPN(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

train_dataset = Dataset(
    x_train_dir,
    y_train_dir,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

valid_dataset = Dataset(
    x_valid_dir,
    y_valid_dir,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=12)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=4)

loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.0001),
])

train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

max_score = 0

for i in range(0, 40):

    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # do something (save model, change lr, etc.)
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model.pth')
        print('Model saved!')

    if i == 25:
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease decoder learning rate to 1e-5!')

7. 学習したモデルで推論!

学習の結果は、「best_model.pth」に保存されているので、これを読み出して評価します。

best_model = torch.load('./best_model.pth')

test_dataset = Dataset(
x_test_dir,
y_test_dir,
augmentation=get_validation_augmentation(),
preprocessing=get_preprocessing(preprocessing_fn),
classes=CLASSES,
)

test_dataloader = DataLoader(test_dataset)

test_epoch = smp.utils.train.ValidEpoch(
model=best_model,
loss=loss,
metrics=metrics,
device=DEVICE,
)

logs = test_epoch.run(test_dataloader)

test_dataset_vis = Dataset(
x_test_dir, y_test_dir,
classes=CLASSES,
)

for i in range(5):
    n = np.random.choice(len(test_dataset))

    image_vis = test_dataset_vis[n][0].astype('uint8')
    gt_mask_vis = test_dataset_vis[n][1]
    image, gt_mask = test_dataset[n]

    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = (pr_mask.squeeze().cpu().numpy().round())
    # Channels First → Channels Last
    pr_mask = pr_mask.transpose(1,2,0)

    visualize(
        image=image_vis,
        ground_truth_mask=gt_mask_vis.squeeze(),
        predicted_mask=pr_mask
    )

参考

関連記事