|
| 1 | +--- |
| 2 | +title: Wavenet実装の詳細 |
| 3 | +date: "2019-06-08T23:00:00.000Z" |
| 4 | +tags: ["Wavenet", "音声合成", "機械学習", "TensorFlow"] |
| 5 | +--- |
| 6 | + |
| 7 | +Wavenetの実装を公開しました. |
| 8 | + |
| 9 | + |
| 10 | +`card:https://github.com/kokeshing/WaveNet-Estimator` |
| 11 | + |
| 12 | + |
| 13 | +この記事では前の記事で説明していなかった実装した上で取り入れたこと,つまづいたことについて書こうと思います. |
| 14 | + |
| 15 | +## 1.ソースの構造について |
| 16 | + |
| 17 | +リポジトリの各ファイル,ディレクトリについて説明します. |
| 18 | + |
| 19 | +### 1. /hparams.py |
| 20 | + |
| 21 | +wavのサンプリングレートやモデルのinput_type,モデルの保存ディレクトリなど |
| 22 | +必要なhparamsを設定しています. |
| 23 | + |
| 24 | +### 2. /audio.py |
| 25 | + |
| 26 | +wavファイルからデータを抽出する,wavファイルに書き出す,メルスペクトログラムを抽出する,など |
| 27 | +音声ファイルや音声データに関わる関数を実装しています. |
| 28 | +このファイルもhparams.pyの設定値を直接参照しないようにしています. |
| 29 | + |
| 30 | + |
| 31 | +### 3. /preprocess.py |
| 32 | + |
| 33 | +その名の通り前処理を行います. |
| 34 | +audio.py内で実装した関数を用いてwavを前処理を施してからtfrecordに保存するようにしています. |
| 35 | +hparams.pyの設定値を参照し,適切にaudio.pyの関数へ引数として渡しています. |
| 36 | +tfrecordについてはまた後述します. |
| 37 | + |
| 38 | +### 4. /dataset.py |
| 39 | + |
| 40 | +preprocess.pyで作成したtfrecordを読み込んでパース,wavの入力長になるようランダムに切り出し, |
| 41 | +マスクの取得などをしてモデルへと流し込むinput_fn, eval_fnを実装しています. |
| 42 | +hparams.pyの設定値を直接参照します. |
| 43 | + |
| 44 | +### 5. /wavenet/model.py /wavenet/module.py |
| 45 | + |
| 46 | +wavenetディレクトリはwavenetのモデルとモデルを組む際に使うパーツ群を定義しています. |
| 47 | +model.pyでは |
| 48 | + |
| 49 | +- tf.keras.modelを継承したWavenetモデル |
| 50 | +- tf.estimatorでmodel\_fnとするwavenet\_fn |
| 51 | + |
| 52 | +を実装しています. |
| 53 | +model.py, module.pyともにhparams.pyを受け取るよう実装しています.(設定値を直接参照しない) |
| 54 | +Wavenetモデルでは基本的にモデルの構築しか行わず,lossやoptimaizerはwavenet\_fn内で学習,検証,推論 |
| 55 | +のフェーズに応じて呼び出すメソッドを出し分けるようにしています. |
| 56 | +入力データの整形などもwavenet\_fnで行っています. |
| 57 | + |
| 58 | +### 6. /train.py |
| 59 | + |
| 60 | +hparams.pyの設定値を参照しcustom estimatorを設定値を元に作成して学習をするファイルです. |
| 61 | + |
| 62 | +### 7. /synthesize.py |
| 63 | + |
| 64 | +hparams.pyの設定値を参照しcustom estimatorを設定値を元に作成して合成を行うファイルです. |
| 65 | +メルスペクトログラムと保存先のパスを受け取ってwavファイルを保存する関数をhparamsで設定した |
| 66 | +テストデータのディレクトリのwavファイルすべてに実行しています. |
| 67 | + |
| 68 | +## 2. 用いた機能について |
| 69 | + |
| 70 | +ここまで,ファイルの構造について書いてきました. |
| 71 | +次は意識して実装した部分ついて書きます. |
| 72 | + |
| 73 | +### 1. tf.estimatorを使う |
| 74 | + |
| 75 | +今回,tf.estimatorを使ってモデルを組みました.tf.estimatorはTensorFlowの高レベルAPIで |
| 76 | +学習や推論,モデルのexportなどが容易にできるように設計されています. |
| 77 | +マルチGPUについてもtf.contrib.distribute.MirroredStrategyを用いることによって |
| 78 | +モデル側のコードは一切変更せずにそのままマルチGPUで学習を行えます. |
| 79 | + |
| 80 | + |
| 81 | +### 2. tf.dataを用いて学習・検証データの読み込みを行う |
| 82 | + |
| 83 | +これはtf.estimatorの制約でもあるのですが |
| 84 | +preprocess.pyとdataset.pyではtf.dataを用いることを念頭に置いた設計になっています. |
| 85 | +公式が |
| 86 | + |
| 87 | +`card:https://www.tensorflow.org/guide/datasets` |
| 88 | + |
| 89 | +で説明されていますが,tf.dataはGPUでの計算中にCPUやメモリを読み込むなど |
| 90 | +いい感じに入力前の値の加工,シャッフル,バッチサイズに切り出しなどを行ってくれます. |
| 91 | +tf.estimatorでない通常のSessionを使うようなコードでも利用できるので是非利用してみてください. |
| 92 | + |
| 93 | +### 3. 学習時もBTC(NHWC)の形式で実装する |
| 94 | + |
| 95 | +これは趣味です. CUDAカーネルはBCT(BCHW)の形に最適化されているらしいので下手したら非効率です. |
| 96 | +TensorFlowの標準がBTC(BHWC)だからとかそういう理由です. |
| 97 | + |
| 98 | + |
| 99 | +以上の点を意識して実装しました. |
| 100 | + |
| 101 | +## 3. はまったところ |
| 102 | + |
| 103 | +いくつかはまったところがありましたので書いておきたいと思います. |
| 104 | + |
| 105 | +### 1. tf.estimatorが勝手に入力テンソルのランクを1にする |
| 106 | + |
| 107 | +tf.estimatorではinput_fnを(features, label)を返すと設計されており,たとえばfeaturesを |
| 108 | + |
| 109 | +```python |
| 110 | + features = {x: [1, 128, 128], y: [16, 24]} |
| 111 | + return features, label |
| 112 | +``` |
| 113 | + |
| 114 | +とした場合wavenet\_fnに入ってくるテンソルはflattenされた[batch\_size, 17223(128\*128+16\*24)]になってしまいます. |
| 115 | +TensorFlowのソースを覗いたところ,features.keys()でforを回して |
| 116 | +valueをすべて取得し一括にしているのでこういう動きになってしまうのかと思っています. |
| 117 | +そこでかなり強引ですがdataset.pyでは |
| 118 | + |
| 119 | +```python |
| 120 | + return {"x": {"x": inputs}, "c": {"c": mel_sp}, "mask": {"mask": mask}}, targets |
| 121 | +``` |
| 122 | + |
| 123 | +とすることによってwavenet_fn内で |
| 124 | + |
| 125 | +```python |
| 126 | + inputs = tf.feature_column.input_layer(features["x"], feature_columns[0]) |
| 127 | + max_time_len = feature_columns[0].shape[0] |
| 128 | + mask = tf.feature_column.input_layer(features["mask"], feature_columns[1]) |
| 129 | + c = tf.feature_column.input_layer(features["c"], feature_columns[2]) |
| 130 | + num_mels = feature_columns[2].shape[0] |
| 131 | + max_time_frames = feature_columns[2].shape[1] |
| 132 | +``` |
| 133 | + |
| 134 | +という感じで受け取ることによってwavとmel\_spが一緒にflatten()されて渡されてこない |
| 135 | +ようにしています.(ただし各wavやmel\_spでflattenされているのでreshapeで元の形に戻す必要あり) |
| 136 | + |
| 137 | +これについてはもっといい書き方をご存知でしたらぜひともご指摘お願いします. |
| 138 | + |
| 139 | +### 2. Fast WaveNetで提案された一度畳み込んだ結果を保持するQueueをLayerで持っていると変数の関係上怒られる |
| 140 | + |
| 141 | +はじめ, module.pyでのCasualConvでは |
| 142 | + |
| 143 | +```python |
| 144 | + class CasualConv(): |
| 145 | + def __init__(hoge): |
| 146 | + hoge |
| 147 | + |
| 148 | + def call(self, inputs, is_incremental=False): |
| 149 | + if is_incremental: |
| 150 | + enqueue = tf.concat([self.queue[:, 1:, :], |
| 151 | + tf.expand_dims(inputs[:, -1, :], axis=1)], |
| 152 | + axis=1) |
| 153 | + assign = self.queue.assign(enqueue) |
| 154 | + |
| 155 | + with tf.control_dependencies([assign]): |
| 156 | + if self.dilation_rate > 1: |
| 157 | + inputs_ = self.queue[:, 0::self.dilation_rate, :] |
| 158 | + else: |
| 159 | + inputs_ = self.queue |
| 160 | + |
| 161 | + def initialize(self, batch_size, residual_channels): |
| 162 | + with tf.variable_scope(self.scope) as scope: |
| 163 | + self.queue = tf.Variable(tf.zeros((batch_size, |
| 164 | + self.kw + (self.kw - 1) * (self.dilation_rate - 1), |
| 165 | + residual_channels), |
| 166 | + dtype=tf.float32), |
| 167 | + name="queue_{}".format(self.scope)) |
| 168 | +``` |
| 169 | + |
| 170 | +のようにしてqueueはCasualConvの各レイヤーのインスタンス側で初期化,保持するようにしていたのですが, |
| 171 | +checkpointからrestoreする際,queueはcheckpointにない値なためエラーが起きてしまいます. |
| 172 | +よって現在の実装のように外部にqueueを持ち,引数として受け取るように実装しました. |
| 173 | + |
| 174 | +### 4. 学習,推論について |
| 175 | + |
| 176 | +学習率とlossのグラフを貼っておきます.lossのグラフではオレンジがtrainで青がevalです. |
| 177 | + |
| 178 | + |
| 179 | + |
| 180 | + |
| 181 | +P100では1GPUで1日200k程度,2GPUで1日300k程度でした. |
| 182 | +lossは400kの時点で十分下がりきっているように見えるので1.2Mも回す必要はあまりないと思います. |
| 183 | + |
| 184 | +batch_sizeや学習率もいろいろ試しましたがGANのようにハイパーパラメータに |
| 185 | +そこまで敏感ではないように思えたのでWaveNetの学習はそんな難しくないと思います. |
| 186 | + |
| 187 | +推論速度についてはP100で150step/secほどでした.22.5kHzの3秒程度の音声の合成に10分程度かかります. |
| 188 | + |
| 189 | +最後に,実装や実装以外でも参考にしたリンクをまとめておきます. |
| 190 | +大変参考になりました.ありがとうございます. |
| 191 | + |
| 192 | +- [Wavenet](https://arxiv.org/abs/1609.03499) |
| 193 | +- [Tactron-2](https://github.com/Rayhane-mamah/Tacotron-2) |
| 194 | +- [r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder) |
| 195 | +- [WN-based TTSやりました](https://r9y9.github.io/blog/2018/05/20/tacotron2/) |
| 196 | +- [Synthesize Human Speech with WaveNet](https://chainer-colab-notebook.readthedocs.io/ja/latest/notebook/official_example/wavenet.html) |
| 197 | +- [VQ-VAEの追試で得たWaveNetのノウハウをまとめてみた。](https://www.monthly-hack.com/entry/2018/02/23/203208) |
| 198 | +- [複数話者WaveNetボコーダに関する調査](https://www.slideshare.net/t_koshikawa/wavenet-87105461) |
| 199 | + |
| 200 | + |
| 201 | +この記事がこれからの実装の参考になれば幸いです. |
0 commit comments