Raspberry Piで、Pytorchを動かしてみよう!

教育用のマイコンRaspberry Piで、Pytorch のAI モデル(MobileNet v2)を動かしてみます

images/cards/pytorch-raspberrypi.webp

目次

1. PyTorch

PyTorch は、コンピュータビジョンや自然言語処理で利用されている Torch を元に作られた、Python のオープンソースの機械学習ライブラリである。 最初は Facebook の人工知能研究グループ AI Research lab により開発された。

2. 開発環境の確認

  • ハードウェア:Raspberry Pi 3 Model B Rev 1.2
  • OS:Raspbian 11(bullseye)
$ cat /proc/cpuinfo
Model           : Raspberry Pi 3 Model B Rev 1.2

$ lsb_release -a
Distributor ID: Raspbian
Description:    Raspbian GNU/Linux 11 (bullseye)
Release:        11
Codename:       bullseye

3. Python3.7 のインストール

choonkiatlee さんのchoonkiatlee/pi-torchリポジトリから、 Raspberry Pi (armv7l)用にコンパイルされた PyTorch イメージ(wheel)を入手することができます。 自分でコンパイルするのは、かなり大変です。よって、こちらを利用されて頂きます。 そのため、まずは Python3.7 のインストールが必要です。asdfを利用します。

git clone https://github.com/asdf-vm/asdf.git ~/.asdf --branch v0.12.0
. "$HOME/.asdf/asdf.sh"
asdf plugin add python
asdf install python 3.7.0
asdf local python 3.7.0
export PIPENV_VENV_IN_PROJECT=true
pipenv --python 3.7

4. Raspberry Pi (armv7l)用にコンパイルされた PyTorch のインストール

git clone https://github.com/choonkiatlee/pi-torch
cd pi-torch
pip install torch-1.4.0a0+7f73f1d-cp37-cp37m-linux_armv7l.whl
pip install torchvision-0.5.0a0+85b8fbf-cp37-cp37m-linux_armv7l.whl

5. OpenCV のインストール

私の環境だと、最新の OpenCV をインストールすると「libimath-2_2.so.23: cannot open shared object file」などのエラーが発生したので、少し古いバージョンをインストールします。

pip install opencv-python==3.4.18.65

パッケージは以下のような感じになります。

pip list

Package       Version
------------- ---------------
numpy         1.21.6
opencv-python 3.4.18.65
Pillow        9.5.0
pip           23.1.2
setuptools    67.8.0
six           1.16.0
torch         1.4.0a0+7f73f1d
torchvision   0.5.0a0+85b8fbf
wheel         0.40.0

6. Pytorch の AI モデル(MobileNet v2)を動かしてみよう!

PyTorch のチュートリアルREAL TIME INFERENCE ON RASPBERRY PI 4 (30 FPS!)を参考に作ります。 ImageNet のクラス表をダウンロードしておきます。

wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
python test.py

test.py

import time
import torch
import numpy as np
from torchvision import models, transforms
import cv2
from PIL import Image
import json

with open('imagenet_class_index.json', 'r') as f:
    classes = json.load(f)

torch.backends.quantized.engine = 'qnnpack'

cap = cv2.VideoCapture(0, cv2.CAP_V4L2)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 224)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 224)
cap.set(cv2.CAP_PROP_FPS, 36)

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

net = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
# jit model to take it from ~20fps to ~30fps
net = torch.jit.script(net)

started = time.time()
last_logged = time.time()
frame_count = 0

with torch.no_grad():
    while True:
        # read frame
        ret, image = cap.read()
        if not ret:
            raise RuntimeError("failed to read frame")

        # convert opencv output from BGR to RGB
        image = image[:, :, [2, 1, 0]]
        permuted = image

        # preprocess
        input_tensor = preprocess(image)

        # create a mini-batch as expected by the model
        input_batch = input_tensor.unsqueeze(0)

        # run model
        output = net(input_batch)
        # do something with output ...
        top = list(enumerate(output[0].softmax(dim=0)))
        top.sort(key=lambda x: x[1], reverse=True)
        for idx, val in top[:10]:
            print(f"{val.item()*100:.2f}% {classes[str(idx)][1]}")

        # log model performance
        frame_count += 1
        now = time.time()
        if now - last_logged > 1:
            print(f"{frame_count / (now-last_logged)} fps")
            last_logged = now
            frame_count = 0

参考

関連記事