セグメンテーション用の教師データ、インデックスカラー画像(PNG形式)

インデックスカラー画像(PNG形式)のインデックス及びカラーパレット確認方法、セグメンテーション用の教師データとして用いる方法について解説します

images/cards/PyTorch_logo_black.svg.webp

目次

1. セグメンテーション用の教師データ(ラベル画像)

セグメンテーション用の教師データとして良く用いられるラベル画像は、インデックスカラー画像(PNG 形式)で作られることが多いです。 インデックスカラーとは、色が順番に定義されたカラーパレットが存在し、0:黒、1:赤、2:オレンジのような形で表現されています。

インデックスカラー画像について
参照: CamVid

ダイレクトカラー画像

ダイレクトカラー画像は、JPEG や PNG 形式など、各ピクセル毎に色情報(例えば、RGB)を持っている形式の画像です。 各ピクセルで RGB 値を持つため、インデックスカラー画像と比較すると、ファイルサイズは大きくなります。

インデックスカラー画像

インデックスカラー画像は、GIF や PNG 形式など、カラーパレットを持つ形式の画像です。 例えば、256 色のカラーパレットには 0 番から 255 番までの番号が割り振られ、 この番号をインデックス番号と呼びます。 0 ~ 255 のインデックス番号は 8bit で表現可能なため、8bit のインデックスカラー画像になります。 各ピクセルは、カラーパレットのインデックス番号のみを持っているため、ダイレクトカラー画像と比較すると、ファイルサイズは小さくなります。

2. カラーパレットの確認

from PIL import Image

label = Image.open("test.png")
label = label.convert("P")
palette = label.getpalette()
palette = np.array(palette).reshape(-1, 3)
print("mode : " + label.mode)
print(palette)

3. インデックスの確認

import numpy as np
from PIL import Image

label = Image.open("test.png")
label_np = np.asarray(label)
count = np.unique(label_np)
print(count)

4. インデックスカラー画像の作り方

import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

label = np.array(
    [[0,0,0,0,0,0],
    [0,1,1,1,1,0],
    [0,1,2,2,1,0],
    [0,1,2,2,1,0],
    [0,1,1,1,1,0],
    [0,0,0,0,0,0]], dtype = 'int8'
)

print("label_shape:", label.shape)
plt.imshow(label)
plt.axis("off")
label = Image.fromarray(label, mode = "P")

color_palette = [
    255, 0, 0,
    255, 255, 0,
    0, 255, 255,
]
label.putpalette(color_palette)
label.save("./label.png")

4. インデックスカラー画像を読み出し、one-hot 表現に

one-hot 表現とは

K 次元ベクトルのうち 1 つの次元だけが 1 であり,他の次元の値は全て 0 であるベクトル表現です。 クラス識別の出力の「K クラス分類の正解ベクトル」としてよく用いる表現のエンコード手法です。

import numpy as np
from PIL import Image

label = Image.open("label.png")
label = np.asarray(label)
class_valus = [0, 1, 2]
labels = [(label == v) for v in class_values]
label = np.stack(labels, axis=-1).astype('float')

PyTorch の場合

import numpy as np
from PIL import Image
import torch

label = Image.open("label.png")
label = np.asarray(label)
label_tensor = torch.from_numpy(label.astype(np.float32)).clone() #tensor に変換
label_onehot = torch.nn.functional.one_hot(label_tensor.long(), num_classes=3)
print("label_tensor_shape:",label_tensor.size())
print(label_tensor)
plt.imshow(label_tensor)
plt.axis("off")
print("label_tensor_shape:",label_onehot.size())
print(label_onehot)

参照

関連記事