ENGINEER BLOG

ENGINEER BLOG

AIと一緒にコロナ禍の運動不足を解消しよう!

はじめに

こんにちは。イノベーション本部の樋口です。

コロナ禍が続いておりますが、皆さんいかがお過ごしでしょうか。

私はほとんど外出しなくなったせいか、1年で体重が5kgも増えてしまいました(泣)

最近になってこれじゃマズいなぁ。と思い始め、スキマ時間で筋トレしています。

せっかくやるなら、正しいフォームで質のいい筋トレをしたいと思い、AIを使った「姿勢推定」に注目。

姿勢推定

姿勢推定とは、人体や動物などの姿勢を推定する技術です。
例えば人体であれば、首やお尻や膝などを検出し、それぞれをつなげることで姿勢を表現できます。

姿勢推定
引用元:Diana.grytsku – jp.freepik.com によって作成された technology 写真

筋トレに応用できそう!ということで、
今回は普段やっているスクワットのフォームで姿勢推定したいと思います。

0. 環境

  • Google Coraboratory

1. やってみる

独断と偏見で、深い / 浅い スクワットのフォームを仮定してそれぞれ撮影しました。

スクワット比較

極端なフォームで撮影したので、姿勢推定なんかしなくても見りゃ分かるじゃん!と思うかもしれませんが、(笑)

この二つの画像の姿勢推定をしてみます。

書籍「つくりながら学ぶ! PyTorchによる発展ディープラーニング」の実装コードGitHubと、こちらのサイト「ディープラーニングでスクワットのフォームをチェックする」を参考にしました。

今回は学習させたOpenPoseで姿勢推定を行います。ソースは概略です。

# モデルの定義
net = OpenPoseNet()

# 学習済みモデルと、OpenPoseNetでネットワークの層の名前が違うので、
# 対応させてロードする
net_weights = torch.load('./pose_model_scratch.pth', map_location={'cuda:0': 'cpu'})
keys = list(net_weights.keys())

weights_load = {}

for i in range(len(keys)):
    weights_load[list(net.state_dict().keys())[i]
                 ] = net_weights[list(keys)[i]]

state = net.state_dict()
state.update(weights_load)
net.load_state_dict(state)

# OpenPoseでheatmapsとPAFsを求めます
net.eval()
predicted_outputs, _ = net(x)

# 画像をテンソルからNumPyに変化し、サイズを戻します
pafs = predicted_outputs[0][0].detach().numpy().transpose(1, 2, 0)
heatmaps = predicted_outputs[1][0].detach().numpy().transpose(1, 2, 0)

pafs = cv2.resize(pafs, size, interpolation=cv2.INTER_CUBIC)
heatmaps = cv2.resize(heatmaps, size, interpolation=cv2.INTER_CUBIC)

pafs = cv2.resize(
    pafs, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
heatmaps = cv2.resize(
    heatmaps, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)

_, result_img, _, _ = decode_pose(oriImg, heatmaps, pafs)

# 結果を描画
plt.imshow(result_img)
plt.show()

スクワットサンプル

いい感じですね。

この状態だと比較するとき若干見づらいので、
首、腰、右膝、右足首だけを残した姿勢推定にしてみます。

necessary_parts=[1,8,9,10] 
fig, ax = plt.subplots(2, 2, figsize=(16, 10))

for i, part in enumerate(necessary_parts):
    heat_map = heatmaps[:, :, part]
    heat_map = np.uint8(cm.jet(heat_map)*255)
    heat_map = cv2.cvtColor(heat_map, cv2.COLOR_RGBA2RGB)
    blend_img = cv2.addWeighted(oriImg, 0.5, heat_map, 0.5, 0)
    ax[int(i/2), i%2].imshow(blend_img)

joints = []
param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
for part in necessary_parts:
    heat_map = heatmaps[:, :, part]
    peaks = find_peaks(param, heat_map)

    if len(peaks) == 0:
      joints.append([np.nan, np.nan])
    
    if peaks.shape[0]>1:
      max_peak = None
      for peak in peaks: 
        val = heat_map[peak[1], peak[0]]
        if max_peak is None or heat_map[max_peak[1], max_peak[0]] < val:
          max_peak = peak
    else:
      max_peak = peaks[0]

    joints.append(max_peak)

img = oriImg.copy()
for i in range(len(joints)-1):
  img = cv2.line(img,tuple(joints[i].astype(int)),tuple(joints[i+1].astype(int)),(255,0,0),3)

img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
cv2_imshow(img)

姿勢推定後の 深い / 浅いフォームを並べて比較してみます。

スクワット比較_analysis

線があることでよりはっきりとフォームの違いが分かりますね。
ここまでで、二つの画像を見比べて簡易的にフォームのチェックが出来ちゃいます。

2. 撮影した動画に応用してみる

せっかくなので、スクワットしている動画でも姿勢推定してみます。

まずは、動画保存のために必要なライブラリのインストールをします。

# 動画保存のために必要なライブラリのインストール
!pip install scikit-video
import skvideo.io

test_video_path = './squat_good.mov'
video_data = skvideo.io.vread(test_video_path)

つぎに、動画の各フレームに対して姿勢推定を行います。


batch_size = 10
orgsize_wh = (video_data[0].shape[1], video_data[0].shape[0])
out_video = []
person_to_joint_assoc_list = []
joint_list_list = []
to_plot_list = []
result_img_list = []
heatmaps_list = []
orgimg_list = []

model.eval()
with torch.no_grad():
  mb_num = np.ceil(len(video_data) / batch_size).astype(int)
  for index in range(mb_num):

    mb_img_list = video_data[index * batch_size:(index+1)*batch_size]
    imgs = torch.empty((len(mb_img_list), 3, 368, 368), dtype=torch.float)

    for i, frame in enumerate(mb_img_list):
      img = preprocess(frame)
      imgs[i, ...] = img[...]

    predicted_outputs, _ = model(imgs.to(device))

    for pafs_tensor, heatmaps_tensor, img in zip(predicted_outputs[0], predicted_outputs[1], mb_img_list):
      heatmaps = heatmaps_tensor.cpu().detach().numpy().transpose(1, 2, 0)
      orgimg_list.append(img)
      heatmaps_list.append(heatmaps)

ただ姿勢検知するだけだとつまらないので、フォーム判定もして見ようと思います。
今回は、腰の位置が設定した高さよりも高かったら赤い枠、低かったら緑の枠を出力するように工夫してみます。

設定した高さは見やすいよう、青線で描画します。
これも動画の各フレームに対して描画していきます。

def create_joints(img, heatmap, necessary_parts):
  img = img.copy()

  # ジョイントの座標を検出する
  joints = []
  param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
  for part in necessary_parts:
      heat_map = heatmaps[:, :, part]
      peaks = find_peaks(param, heat_map)
      if len(peaks) == 0:
        joints.append([np.nan, np.nan])
      if peaks.shape[0]>1:
        max_peak = None
        for peak in peaks: 
          val = heat_map[peak[1], peak[0]]
          if max_peak is None or heat_map[max_peak[1], max_peak[0]] < val:
            max_peak = peak
      else:
        max_peak = peaks[0]
      joints.append(max_peak)

  # 各ジョインを結ぶ
  for i in range(len(joints)-1):
    img = cv2.line(img,tuple(joints[i].astype(int)),tuple(joints[i+1].astype(int)),(255,0,0),3)

  #位置(高さ)を指定
  height_border = 1100
  #ボーダーの線を引く
  img = cv2.line(img, (0, height_border), (img.shape[1], height_border), (0,0,255), 10)
  #分かりやすいように腰部分に点を描画
  img = cv2.circle(img,tuple(joints[1].astype(int)), 1, (255,0,0), -1)

  #指定した位置より低い場合はフレームの色を変える
  frame = [[0,0],[0,img.shape[0]],[img.shape[1],img.shape[0]],[img.shape[1],0],[0,0]]
  if (joints[1][1] <= height_border):
    for i in range(len(frame)-1):
      img = cv2.line(img,tuple(frame[i]),tuple(frame[i+1]),(255,0,0),100)
  else:
    for i in range(len(frame)-1):
      img = cv2.line(img,tuple(frame[i]),tuple(frame[i+1]),(0,255,0),100)

  return img

imgs =[]

for img, hm in zip(orgimg_list, heatmaps_list):
  hm = cv2.resize(hm, orgsize_wh, interpolation=cv2.INTER_CUBIC)
  img = cv2.resize(img, orgsize_wh, interpolation=cv2.INTER_CUBIC)

  img = create_joints(img, hm, necessary_parts)
  imgs.append(img)

出来た画像を動画にします。

def create_video(imgs, out_video_path, size_wh, fps=30):

  vid_out = skvideo.io.FFmpegWriter(out_video_path,
      inputdict={
          '-r': str(fps)
      },
      outputdict={
          '-r': str(fps)
      })

  for img in imgs:
    img = cv2.resize(img, size_wh)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    vid_out.writeFrame(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

  vid_out.close()

  create_video(imgs, out_video_path, size_wh=orgsize_wh, fps=30)

結果がこちら。

スクワット

動画でも姿勢検知することができました。

3. おわりに

今回はスクワットのフォームを姿勢検出 + モーション検知のようなこともできました。今度は違う動きも姿勢検知してみたいです。

これで、理想のフォームを追求していきます。

最後までお読みいただきありがとうございました!