Skip to content

Commit ebe87bb

Browse files
committed
add post
1 parent 910283d commit ebe87bb

File tree

4 files changed

+270
-0
lines changed

4 files changed

+270
-0
lines changed

.gitignore

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Logs
2+
logs
3+
*.log
4+
npm-debug.log*
5+
yarn-debug.log*
6+
yarn-error.log*
7+
8+
# Runtime data
9+
pids
10+
*.pid
11+
*.seed
12+
*.pid.lock
13+
14+
# Directory for instrumented libs generated by jscoverage/JSCover
15+
lib-cov
16+
17+
# Coverage directory used by tools like istanbul
18+
coverage
19+
20+
# nyc test coverage
21+
.nyc_output
22+
23+
# Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
24+
.grunt
25+
26+
# Bower dependency directory (https://bower.io/)
27+
bower_components
28+
29+
# node-waf configuration
30+
.lock-wscript
31+
32+
# Compiled binary addons (http://nodejs.org/api/addons.html)
33+
build/Release
34+
35+
# Dependency directories
36+
node_modules/
37+
jspm_packages/
38+
39+
# Typescript v1 declaration files
40+
typings/
41+
42+
# Optional npm cache directory
43+
.npm
44+
45+
# Optional eslint cache
46+
.eslintcache
47+
48+
# Optional REPL history
49+
.node_repl_history
50+
51+
# Output of 'npm pack'
52+
*.tgz
53+
54+
# dotenv environment variables file
55+
.env
56+
57+
# gatsby files
58+
.cache/
59+
public
60+
61+
# Mac files
62+
.DS_Store
63+
64+
# Yarn
65+
yarn-error.log
66+
.pnp/
67+
.pnp.js
68+
# Yarn Integrity file
69+
.yarn-integrity

content/blog/wavenet-detail/index.md

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
---
2+
title: Wavenet実装の詳細
3+
date: "2019-06-08T23:00:00.000Z"
4+
tags: ["Wavenet", "音声合成", "機械学習", "TensorFlow"]
5+
---
6+
7+
Wavenetの実装を公開しました.
8+
9+
10+
`card:https://github.com/kokeshing/WaveNet-Estimator`
11+
12+
13+
この記事では前の記事で説明していなかった実装した上で取り入れたこと,つまづいたことについて書こうと思います.
14+
15+
## 1.ソースの構造について
16+
17+
リポジトリの各ファイル,ディレクトリについて説明します.
18+
19+
### 1. /hparams.py
20+
21+
wavのサンプリングレートやモデルのinput_type,モデルの保存ディレクトリなど
22+
必要なhparamsを設定しています.
23+
24+
### 2. /audio.py
25+
26+
wavファイルからデータを抽出する,wavファイルに書き出す,メルスペクトログラムを抽出する,など
27+
音声ファイルや音声データに関わる関数を実装しています.
28+
このファイルもhparams.pyの設定値を直接参照しないようにしています.
29+
30+
31+
### 3. /preprocess.py
32+
33+
その名の通り前処理を行います.
34+
audio.py内で実装した関数を用いてwavを前処理を施してからtfrecordに保存するようにしています.
35+
hparams.pyの設定値を参照し,適切にaudio.pyの関数へ引数として渡しています.
36+
tfrecordについてはまた後述します.
37+
38+
### 4. /dataset.py
39+
40+
preprocess.pyで作成したtfrecordを読み込んでパース,wavの入力長になるようランダムに切り出し,
41+
マスクの取得などをしてモデルへと流し込むinput_fn, eval_fnを実装しています.
42+
hparams.pyの設定値を直接参照します.
43+
44+
### 5. /wavenet/model.py /wavenet/module.py
45+
46+
wavenetディレクトリはwavenetのモデルとモデルを組む際に使うパーツ群を定義しています.
47+
model.pyでは
48+
49+
- tf.keras.modelを継承したWavenetモデル
50+
- tf.estimatorでmodel\_fnとするwavenet\_fn
51+
52+
を実装しています.
53+
model.py, module.pyともにhparams.pyを受け取るよう実装しています.(設定値を直接参照しない)
54+
Wavenetモデルでは基本的にモデルの構築しか行わず,lossやoptimaizerはwavenet\_fn内で学習,検証,推論
55+
のフェーズに応じて呼び出すメソッドを出し分けるようにしています.
56+
入力データの整形などもwavenet\_fnで行っています.
57+
58+
### 6. /train.py
59+
60+
hparams.pyの設定値を参照しcustom estimatorを設定値を元に作成して学習をするファイルです.
61+
62+
### 7. /synthesize.py
63+
64+
hparams.pyの設定値を参照しcustom estimatorを設定値を元に作成して合成を行うファイルです.
65+
メルスペクトログラムと保存先のパスを受け取ってwavファイルを保存する関数をhparamsで設定した
66+
テストデータのディレクトリのwavファイルすべてに実行しています.
67+
68+
## 2. 用いた機能について
69+
70+
ここまで,ファイルの構造について書いてきました.
71+
次は意識して実装した部分ついて書きます.
72+
73+
### 1. tf.estimatorを使う
74+
75+
今回,tf.estimatorを使ってモデルを組みました.tf.estimatorはTensorFlowの高レベルAPIで
76+
学習や推論,モデルのexportなどが容易にできるように設計されています.
77+
マルチGPUについてもtf.contrib.distribute.MirroredStrategyを用いることによって
78+
モデル側のコードは一切変更せずにそのままマルチGPUで学習を行えます.
79+
80+
81+
### 2. tf.dataを用いて学習・検証データの読み込みを行う
82+
83+
これはtf.estimatorの制約でもあるのですが
84+
preprocess.pyとdataset.pyではtf.dataを用いることを念頭に置いた設計になっています.
85+
公式が
86+
87+
`card:https://www.tensorflow.org/guide/datasets`
88+
89+
で説明されていますが,tf.dataはGPUでの計算中にCPUやメモリを読み込むなど
90+
いい感じに入力前の値の加工,シャッフル,バッチサイズに切り出しなどを行ってくれます.
91+
tf.estimatorでない通常のSessionを使うようなコードでも利用できるので是非利用してみてください.
92+
93+
### 3. 学習時もBTC(NHWC)の形式で実装する
94+
95+
これは趣味です. CUDAカーネルはBCT(BCHW)の形に最適化されているらしいので下手したら非効率です.
96+
TensorFlowの標準がBTC(BHWC)だからとかそういう理由です.
97+
98+
99+
以上の点を意識して実装しました.
100+
101+
## 3. はまったところ
102+
103+
いくつかはまったところがありましたので書いておきたいと思います.
104+
105+
### 1. tf.estimatorが勝手に入力テンソルのランクを1にする
106+
107+
tf.estimatorではinput_fnを(features, label)を返すと設計されており,たとえばfeaturesを
108+
109+
```python
110+
features = {x: [1, 128, 128], y: [16, 24]}
111+
return features, label
112+
```
113+
114+
とした場合wavenet\_fnに入ってくるテンソルはflattenされた[batch\_size, 17223(128\*128+16\*24)]になってしまいます.
115+
TensorFlowのソースを覗いたところ,features.keys()でforを回して
116+
valueをすべて取得し一括にしているのでこういう動きになってしまうのかと思っています.
117+
そこでかなり強引ですがdataset.pyでは
118+
119+
```python
120+
return {"x": {"x": inputs}, "c": {"c": mel_sp}, "mask": {"mask": mask}}, targets
121+
```
122+
123+
とすることによってwavenet_fn内で
124+
125+
```python
126+
inputs = tf.feature_column.input_layer(features["x"], feature_columns[0])
127+
max_time_len = feature_columns[0].shape[0]
128+
mask = tf.feature_column.input_layer(features["mask"], feature_columns[1])
129+
c = tf.feature_column.input_layer(features["c"], feature_columns[2])
130+
num_mels = feature_columns[2].shape[0]
131+
max_time_frames = feature_columns[2].shape[1]
132+
```
133+
134+
という感じで受け取ることによってwavとmel\_spが一緒にflatten()されて渡されてこない
135+
ようにしています.(ただし各wavやmel\_spでflattenされているのでreshapeで元の形に戻す必要あり)
136+
137+
これについてはもっといい書き方をご存知でしたらぜひともご指摘お願いします.
138+
139+
### 2. Fast WaveNetで提案された一度畳み込んだ結果を保持するQueueをLayerで持っていると変数の関係上怒られる
140+
141+
はじめ, module.pyでのCasualConvでは
142+
143+
```python
144+
class CasualConv():
145+
def __init__(hoge):
146+
hoge
147+
148+
def call(self, inputs, is_incremental=False):
149+
if is_incremental:
150+
enqueue = tf.concat([self.queue[:, 1:, :],
151+
tf.expand_dims(inputs[:, -1, :], axis=1)],
152+
axis=1)
153+
assign = self.queue.assign(enqueue)
154+
155+
with tf.control_dependencies([assign]):
156+
if self.dilation_rate > 1:
157+
inputs_ = self.queue[:, 0::self.dilation_rate, :]
158+
else:
159+
inputs_ = self.queue
160+
161+
def initialize(self, batch_size, residual_channels):
162+
with tf.variable_scope(self.scope) as scope:
163+
self.queue = tf.Variable(tf.zeros((batch_size,
164+
self.kw + (self.kw - 1) * (self.dilation_rate - 1),
165+
residual_channels),
166+
dtype=tf.float32),
167+
name="queue_{}".format(self.scope))
168+
```
169+
170+
のようにしてqueueはCasualConvの各レイヤーのインスタンス側で初期化,保持するようにしていたのですが,
171+
checkpointからrestoreする際,queueはcheckpointにない値なためエラーが起きてしまいます.
172+
よって現在の実装のように外部にqueueを持ち,引数として受け取るように実装しました.
173+
174+
### 4. 学習,推論について
175+
176+
学習率とlossのグラフを貼っておきます.lossのグラフではオレンジがtrainで青がevalです.
177+
178+
![学習率](./lr.png)
179+
![loss](./loss.png)
180+
181+
P100では1GPUで1日200k程度,2GPUで1日300k程度でした.
182+
lossは400kの時点で十分下がりきっているように見えるので1.2Mも回す必要はあまりないと思います.
183+
184+
batch_sizeや学習率もいろいろ試しましたがGANのようにハイパーパラメータに
185+
そこまで敏感ではないように思えたのでWaveNetの学習はそんな難しくないと思います.
186+
187+
推論速度についてはP100で150step/secほどでした.22.5kHzの3秒程度の音声の合成に10分程度かかります.
188+
189+
最後に,実装や実装以外でも参考にしたリンクをまとめておきます.
190+
大変参考になりました.ありがとうございます.
191+
192+
- [Wavenet](https://arxiv.org/abs/1609.03499)
193+
- [Tactron-2](https://github.com/Rayhane-mamah/Tacotron-2)
194+
- [r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder)
195+
- [WN-based TTSやりました](https://r9y9.github.io/blog/2018/05/20/tacotron2/)
196+
- [Synthesize Human Speech with WaveNet](https://chainer-colab-notebook.readthedocs.io/ja/latest/notebook/official_example/wavenet.html)
197+
- [VQ-VAEの追試で得たWaveNetのノウハウをまとめてみた。](https://www.monthly-hack.com/entry/2018/02/23/203208)
198+
- [複数話者WaveNetボコーダに関する調査](https://www.slideshare.net/t_koshikawa/wavenet-87105461)
199+
200+
201+
この記事がこれからの実装の参考になれば幸いです.

content/blog/wavenet-detail/loss.png

38.1 KB
Loading

content/blog/wavenet-detail/lr.png

16.8 KB
Loading

0 commit comments

Comments
 (0)