PytorchのモデルをTFLiteに変換する際にうまく変換できない時に試したこと

現在、画像処理や自然言語処理など様々な分野でDeepLearningを用いた研究や開発がなされており、私も画像処理の分野と音の分野でモデル開発を行っています。

そして、ある程度モデルが完成してくるとスマートフォンに搭載したり、サーバーで処理を実行させたりするようになると思います。

そんな時に便利なツールがTFLiteです。

TFLiteは様々なデリゲートがあり、スマホのGPU処理などを丸投げすることができたり、Tensorflowで開発されたモデルをすぐにTFLiteモデルに変換してスマホで処理させることができたりといろいろと便利なツールです。

しかし、研究分野においては今現在Pytorchで開発されたモデルが多く、これには様々な理由があげられますが、その一つに入力サイズを固定しなくてもいいという利点があげられます。

ゆえに、最新のモデルなどはPytorchなどをベースに構築されているものが多いと言えるでしょう。

さて、こういった背景より、Pytorchで作成したモデルをTFLiteに変換してスマートフォンで実行したいというニーズが生まれるのは必然であり、しかしながら現時点においてはPytorchのモデルを直接TFLiteに変換する方法が存在しません。

そのため、現在主流となっている変換方法としてONNXを仲介させてTFLiteに変換する方法が採られています。

ONNXとは、Open Neural Network eXchangeの略で、機械学習向けに開発されたオープンフォーマットです。

現時点ではPytorchやTensorflow、ML Coreなど、さまざまな形式で機械学習の開発が進んでおり、それを統一しようという流れの一つです。OSに例えるとまさにLinuxのような立ち位置ですね。

ちなみに、ONNXの読み方は「オニキス」と呼ぶそうです。宝石のオニキスと同じ発音ですね。

ONNXを仲介したPytorchからTFLiteの変換

さて、ではモデルの変換をPytorch→ONNX→TFLiteという流れで行うわけですが、ここでいろいろと注意点があります。

まず、バージョンの互換性、そしてPytorchで使用する機能の互換性です。

Pytorchのバージョンに関しては開発の都合でバージョンが制限されているということがよくあるかと思います。しかし、そのバージョンと仲介するONNXのバージョン、またONNXからTFLiteへ変換するライブラリとのバージョンが不整合を起こし、うまく変換できないことがあります。

なので、私の場合はcondaの仮想環境を作成し、ネット上で動作しているバージョンなどをいろいろ試して対応しています。経験上、Pytorchのバージョンは基本的に最新のものを利用しても今のところ問題なさそうです。気を付けるのはONNXとonnx-to-kerasやonnx-to-tensorflowなどのバージョン互換性なので、動かない場合はいろいろとバージョンを上げながら調整したほうがよさそうです。

ここで、ONNX→TFLiteの変換ですが、実はここにも結構つまづくポイントが多いです。

ONNX→TFLiteには私が試した方法ではkerasを仲介するパターンとtensorflowを仲介するパターンがありました。

kerasにも数種類あり、一般的にはマルチバックエンドと呼ばれる形で用途によってユーザーが選択することができました。現在はTensorflow.keras に一本化する流れのようですが、開発環境によっては多少違いがあるかもしれないので注意してください。ここではTensorflowバックエンドのkerasを使用することを前提としてお話します。

ONNX→keras→TFLiteでうまくいくこともありますが、モデルによってはうまく変換できない場合があり、その場合にはONNX→Tensorflow→TFLiteという流れで変換を行うことでうまくいくことがあるようです。

しかし、変換できたからと言って本当にそれが正しく変換されているとは限らないというのはまた躓きやすいポイントです。

実際にはまったポイント1- ONNX2Kerasが動かない

ここで私が実際にはまったのは、UNETなどのモデルで最近よく用いられているアップサンプリングレイヤーの変換でした。

逆畳み込み層などとしても知られると思いますが、私の場合はもともとこの実装にnn.Upsampleを使っていました。

(※ちなみにnn.Upsampleは畳み込みは行っておらず、単に出力をbilinearやlinearなど指定した方法で拡張するだけです)

まず試したのがnn.Upsample実装の状態でPytorch→ONNX→kerasと変換をさせることでしたが、ONNX→kerasで止まってしまうという状態に陥りました。

この原因として、このnn.UpsampleをONNX→kerasと変換する際に使用するonnx2keras.onnx_to_keras()という処理でnn.Upsampleに対応した変換がまだ実装されていないらしく、

ここをonnx2kerasを使うのをやめてonnx2tfを使うことでとりあえず一時解決しました。

ちなみにonnx2kerasはpip install でインストール可能です。これも、とりあえず最新のものをインストールして、動かなかった場合にバージョンを調整するのがいいと思います。

onnx2tfも同様にpip でインストール可能です。

下のほうに後でコードを載せますが、コメントアウトしている部分がonnx2kerasとなります。

必要であれば参考にしてみてください。

実際にはまったポイント2- 変換されたモデルが動かない

さて、onnxからの変換を乗り越え、無事TFLiteに変換できたはいいものの、今度はこれが実機で動かないという問題が発生しました。

私の場合はAndroidにTFLiteをC++ベースで実装して動かしているのですが、先ほどのnn.Upsampleで実装したものだとモデルの読み込みに失敗してしまいました。

ということで今度はUpsampleを変更し、別の手法でUpsamplerを実装することにしました。

まず、私が試した方法がPixelShuffleレイヤでの実装です。

PixelShuffleの詳細は省きますが、nn.PixelShuffle(scale_factor=2)というような実装だと、

[N,C,H,W]→[N,C/4,H/2,W/2]という出力が得られます。チャネルがスケール係数の2乗分の1になるのでそこを気を付けないといけませんが、モデルアーキテクチャ的にはやりたいことは同じです。(厳密には違いますが、そこは各自対応してください)

これは最近主流になってきたアップサンプリング手法らしいので、変換する際には転置を用いて内部的に対応しているらしく、これを実装したものをAndroidで実行しようとした際に

“Transpose op only supports 1D-5D input arrays. “というエラーが出てきてしまうようになりました。

そもそも元のモデルだと[N, C, H, W]の4次元の構成なのにいったいどこで6次元以上になっているのか、と思って調べてみると、どうやらPixelShuffleをうまく変換できずに転置に変換した際に6次元入力になるように変換されていました。この転置処理はTFLiteライブラリにおいてtransposeという処理で実行されますが、私の作ったモデルを変換した際に6次元テンソル入力になるように変換されてしまったためにこのようなエラーが出るようになったということですね。

ちなみに、変換前後のモデルを視覚的に比較するために私はnetronというツールを使っています。これでモデルの重みを保存したファイル(pytorchの場合.pthファイル、など)を視覚化することができます。

netronはpth、onnx、tfliteなど様々な形式に対応しているため、変換したモデルの構造がおかしなことになっていないかを調査するときに役に立ちます。よかったらお試しください。

少し話は逸れましたが、とりあえずPixelShuffleもだめだということで次に試したのがConvTranspose2dという処理でした。結論から言うと、これは問題なくTFLiteに変換でき、ちゃんとAndroidでも動作しました。

このConvTranspose2dですが、結構パラメータが多くめんどくさいので取り扱いに注意です。

こちらの方が説明されている記事がわかりやすかったためこちらに紹介させていただきますが、nn.ConvTranspose2d(入力チャネル、出力チャネル、kernel_size=2, stride = 2)とすれば、画像入力を前提とした場合のHとWが2倍となるような出力が得られます。PixelShuffleと同じような出力にする場合には出力チャネルを入力の4分の1にするとおなじテンソル形状となります。

おそらく、この処理が一般的に言われるDeconvolution(逆畳み込み)だと思います。

ちなみに、nn.Upsampleの場合、HとWを倍にしますが、チャネルの数は入力と同じですので、書き換えの際は注意してください。

私の場合はモデルを変更することができますので、入出力チャネルの数を変更することも出来ました。しかし、PixelShuffleのように出力チャネル数が入力に依存してしまうものの場合は取り扱いに注意しなければなりません。

また、nn.Upsampleという処理は重みの計算は行っておらず、対してnn.PixelShuffleとnn.ConvtTranspose2d重みの計算が行われています。なので、おそらくPixelShuffleの処理をConvTranspose2dで書き換えて同じ重みのファイルを使うということはできますが、Upsampleで書かれたものを書き換えることはできない可能性が高いです。

ポイントまとめ

さて、アップサンプルレイヤを変更してこのモデルを変換したところ、無事Androidでも動いてくれるようになりました。

気を付けるべき点としては

・バージョンの互換性に気を付ける

・onnx2kerasでうまくいかない時はonnx2tfを使う

・モデル内のアップサンプルレイヤには気を付ける。できればとりあえずConvTranspose2dで実装する。

といったところでしょうか。

ダイレクトにPytorchからTFLiteに変換する手法が出てくればもっとトラブルは減るかと思いますが、途中にONNXを経由しているため何かとトラブルが多発するのは仕方ないかと思います。

変換処理の実装例

最後に、現在使用しているモデルをうまく変換できているコードを参考程度に書いておきます。

ここで変換できているモデルはVGG19のようなシンプルな構造と、nn.ConvTranspose2dでアップサンプル処理を加えたものが動作確認済みです。nn.PixelShuffleはうまく動きませんでした。

ファイル構造として以下のポイントに注意しつつ参考にしてみてください。

・PytorchのモデルをNETクラスとして定義しているファイルをmodels_pytorch/に配置

・学習済み重みファイル(~.pth)をmodels_pytorch/に配置

・その他出力用のディレクトリを自動で生成する処理は書いていないため各自対応

modelの例

class NET(nn.Module):
def __init__(self, opt):
super().__init__()
構造を定義

def forward(self, input):
  処理を定義

ここで定義したものをpytorchの読み込み時にimportして使う。

converter.py

import models_pytorch.model as module # ex) models_pytorch/model.py
import torch
import onnx
import tensorflow as tf
import cv2
from tensorflow.python.keras import backend as K

def converter(model_path=”./path/to/weight.pth”,
tflite_path=”./path/to/output.tflite”,
output_check = False):
input_img=cv2.imread(“./imgs/test_image.jpg”)#test image
image_h,image_w = [128,128]
img=cv2.resize(input_img,(image_h,image_w))#test lr image
img=img.transpose([2,1,0])#transpose channel last to channel first
img=img.reshape([1,3,image_h,image_w])/256. #nparray to torchTensor
dummy_input = torch.Tensor(img)#torch.randn(1, 3, 128, 128)#same size as input of PytorchModel(n,c,h,w)

device = torch.device(“cuda” if torch.cuda.is_available() else “cpu”)
pytorch_model = module.NET().to(device)
checkpoint = torch.load(model_path)

pytorch_model.load_state_dict(checkpoint[‘state_dict’])
onnx_path=”./onnx.onnx”
input_names = [‘input_array’] #this name is to inputname of samed_model.pb
output_names = [‘out_array’]

torch_out = pytorch_model(dummy_input)
print(“start torch2onnx”)
torch.onnx.export(pytorch_model,dummy_input,onnx_path,verbose=False,
export_params=True,do_constant_folding=True,
input_names=input_names, output_names=output_names,
opset_version = 13,
)

#————-
#onnx2pb
keras_path_dir=”./keras_model” #temporary keras model output
onnx_model = onnx.load(onnx_path)

from onnx_tf.backend import prepare

tf_rep = prepare(onnx_model)
tf_rep.export_graph(keras_path_dir)

“””
K.set_learning_phase(0)
with K.get_session() as sess:
print(“start onnx2keras”)
k_model = onnx_to_keras(onnx_model=onnx_model, input_names=input_names,
change_ordering=True, verbose=False)
print(“onnx to keras convert finish”)
weights = k_model.get_weights()

# To avoid FailedPreconditionError
#init = tf.compat.v1.global_variables_initializer()
#sess.run(init)

# If this is left as it is, the weight information has disappeared, so reload the weight.
k_model.set_weights(weights)
if os.path.isdir(keras_path_dir):
shutil.rmtree(keras_path_dir)
#tf.compat.v1.saved_model.simple_save(
tf.saved_model.save(
#sess,
k_model,
keras_path_dir,
#inputs={‘image_array’: k_model.input},
#outputs={‘sr_array’: k_model.output}
)
“””
if output_check:

keras_out=K.run(torch.Tensor(img.trancepose([0,3,2,1])).trancepose())[0]
save_img(keras_out,”keras_out”)
#—-
#check model output between onnx and torch
import onnxruntime
import numpy as np
ort_session = onnxruntime.InferenceSession(onnx_path)

def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Inference with ONNX Runtime
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(dummy_input)}
ort_outs = ort_session.run(None, ort_inputs)

#compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print(“Exported model has been tested with ONNXRuntime, and the result looks good!”)

print(“———-“)
print(to_numpy(torch_out))

save_img(to_numpy(torch_out),”torch_out”)

print(“———-“)
print(ort_outs[0])
save_img(ort_outs[0],”onnx_out”)
#—-

#————
#pb2tflite
converter = tf.lite.TFLiteConverter.from_saved_model(keras_path_dir)

converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS]

tf_lite_model = converter.convert()
# Save the model.
with open(tflite_path, ‘wb’) as f:
f.write(tf_lite_model)
#—————

def save_img(arr,name):
out_img=(arr.transpose([0,3,2,1])).reshape([128,128,3])
out_img=out_img*256
out_img=cv2.resize(out_img,(256*2,256*2))
cv2.imwrite(“./imgs/”+name+”.jpg”, out_img)#-input_img)

if __name__==”__main__”:
converter(“path/to/weight.pth”, “path/to/output.tflite”, False)

torch.onnx.exportのopset_versionという引数ですが、ここは何かエラーが出た際にまず調整してみるといいかと思います。

デフォルトは9となっていますが、私の場合PixelShuffleを変換する際に13にして変換ができました(結果としてPixelShuffleは動かなかったので意味はありませんでしたが。)

13の設定のままConvTranspose2dの実装を変換する際にも使用したため、もしかするとデフォルトの9でも動くかもしれませんが、動作確認済みがopset_version = 13だったのでこちらで書いています。

変換したもモデルでちゃんと画像が変換できるかの確認のためにtest_images/test_image.jpgを用意しています。converter()の第三引数をTrueにすると実際に出力の確認が行えます。

以上、PytorchをTFLiteに変換する際の備忘録として今回は雑ではありますがまとめさせていただきました。

何か参考になれば幸いです。

コメント

タイトルとURLをコピーしました