あかすくぱるふぇ

同人サークル「あかすくぱるふぇ」のブログです。

pytorch(torchvision)のdatasetsは有名なデータセットを簡単に扱うためのパッケージです。
その中にはCOCOデータセットも含まれているのですが、セグメンテーション情報については(元々のデータセットの時点で)符号化されていて、そのままでは利用することができません。
そこで本記事では、pytorchでCOCOデータセットのセグメンテーション情報をデコードして取得する方法を解説します。

まずは、普通にCOCOデータセットからデータをLoadする方法。
import torch
from torch.utils import data
from torchvision import transforms, datasets

if __name__ == '__main__':

# DataLoaderを生成
transform = transforms.Compose([
transforms.ToTensor()
])
train_set = datasets.CocoDetection(root='../data/coco/train2017',
annFile='../data/coco/instances_train2017.json',
transform=transform)
train_loader = torch.utils.data.DataLoader(train_set)

# DataLoad
for data, target in train_loader:
print('hoge')
データセットはあらかじめCOCOデータセットのwebサイトからダウンロードしておきます。
ImagesとAnnotationsをセットでダウンロードしてください。
http://cocodataset.org/#download

上記のコードで画像dataとアノテーションデータtargetをLoadできます。
ただし、前述したように、targetの'segmentation'キーの値は符号化されていて、そのままでは利用できません。

続いて、セグメンテーション情報のデコードです。
Loadしてからデコードしてもよいのですが、tensor型に変換されてLoadされてくるので、処理が煩雑になってしまいます。
そこで、CocoDetectionを継承したクラスCocoSegmentationを作って、その内部でセグメンテーション情報をデコードしてLoadするようにします。
以下がCocoSegmentationクラスのコードです。
from torchvision.datasets import coco
from pycocotools.mask import frPyObjects, decode

class CocoSegmentation(coco.CocoDetection):
def __getitem__(self, index):
img, target = super().__getitem__(index)
for category in target:
seg_rle = category['segmentation']
category['segmentation'] = decode(frPyObjects(seg_rle, img.shape[1], img.shape[2]))
return img, target
'segmentation'キーの値をデコードして書き換えているだけです。
デコード処理については以下のページが参考になります。
https://github.com/cocodataset/cocoapi/issues/4

あとは、最初に示したコードでCocoDetectionの代わりにCocoSegmentationクラスを利用すれば、デコードされたセグメンテーション情報を利用することができます。

以上です。



・追記
上のコードだとリサイズやバッチ処理が入ると動かないことが判明しました。
最終的なコードは以下のようになりました。
import numpy as np
from torchvision import transforms
from torchvision.datasets import coco
from pycocotools.mask import frPyObjects, decode
from PIL import Image
import os

class CocoSegmentation(coco.CocoDetection):
def __getitem__(self, index):

# データ入力
coco = self.coco
img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds=img_id)
target = coco.loadAnns(ann_ids)
path = coco.loadImgs(img_id)[0]['file_name']
img = Image.open(os.path.join(self.root, path)).convert('RGB')

# セグメンテーション情報のデコード
for category in target:
seg_rle = category['segmentation']
tmp = decode(frPyObjects(seg_rle, img.size[1], img.size[0]))
if tmp.ndim == 3:
tmp = np.sum(tmp, axis=2, dtype=np.uint8)
category['segmentation'] = tmp

# data_transform
if self.transform is not None:
img = self.transform(img)

# target_transform
for category in target:
pilImg = Image.fromarray(category['segmentation'])
tmp = pilImg.resize((img.shape[2], img.shape[1]), resample=Image.NEAREST)
target_transform = transforms.Compose([
transforms.ToTensor()
])
category['segmentation'] = target_transform(tmp)

return img, target

↑このページのトップヘ