Low GPU Utilization during training #217

Open
ayushtues opened this issue Mar 15, 2024 · 10 comments

@ayushtues

Hi, I have been trying to train a StyleTTS2 model from scratch on the LibriTTS 460 dataset, currently going through the first stage via train_first.py

GPU utilization during training is very low (~30%). I am using a single H100 with batch_size = 8 and max_len = 300 to fit the model on a single GPU.

Such low utilization means the script is not using the GPU efficiently; there are likely bottlenecks that, if addressed, could make training faster.

Has anyone observed similar issues while training the model from scratch, or does anyone have ideas for improving GPU utilization?

cc @yl4579
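
One way to narrow this down might be to time the DataLoader fetch separately from the GPU work. This is only a rough sketch, assuming the train_dataloader and device objects set up in train_first.py:

import time
import torch

fetch_time, copy_time = 0.0, 0.0
it = iter(train_dataloader)
for _ in range(20):
    t0 = time.time()
    batch = next(it)                       # CPU-side loading / collation
    fetch_time += time.time() - t0

    t0 = time.time()
    _ = [b.to(device) for b in batch[1:]]  # same pattern as the training loop
    torch.cuda.synchronize()               # make the async copy time visible
    copy_time += time.time() - t0

print(f"data fetch: {fetch_time:.2f}s, host-to-device copy: {copy_time:.2f}s")

If the fetch time dominates, the bottleneck is on the data-loading side rather than in the model itself.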

@lucasgris

Yes, same here. It seems there is a bottleneck, but using accelerate seems to help a little. Are you using accelerate? Try setting num_processes.

[Screenshot: GPU utilization graph]

@ayushtues

ayushtues commented Mar 18, 2024

Yes @lucasgris, I am using accelerate and have played around with num_workers. Even in the graph you shared, utilization consistently drops to very low points (<25% GPU util). Any luck improving that?

@lucasgris

Not yet, but I think it is worth trying to identify where the code is slow. If I have any updates, I will share them here.

@Selectorrr

Confirming the problem of low GPU utilization:
[Screenshot 2024-03-27 at 17:30:06]
It seems that computation on a single CPU core is the bottleneck:
[Screenshot 2024-03-27 at 17:30:17]

@borrero-c

Also having this problem with train_finetune_accelerate.py. I haven't dug too deep, but the accelerator.backward() calls seemed to be taking a very long time; specifically, this code block:

loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
g_loss = loss_params.lambda_mel * loss_mel + \
         loss_params.lambda_F0 * loss_F0_rec + \
         loss_params.lambda_ce * loss_ce + \
         loss_params.lambda_norm * loss_norm_rec + \
         loss_params.lambda_dur * loss_dur + \
         loss_params.lambda_gen * loss_gen_all + \
         loss_params.lambda_slm * loss_lm + \
         loss_params.lambda_sty * loss_sty + \
         loss_params.lambda_diff * loss_diff + \
         loss_params.lambda_mono * loss_mono + \
         loss_params.lambda_s2s * loss_s2s
running_loss += loss_mel.item()
accelerator.backward(g_loss)
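
One caveat when timing this by hand: CUDA kernels run asynchronously, so a wall-clock timer around accelerator.backward() also absorbs the time of kernels queued earlier by the forward pass. A minimal sketch of isolating the backward pass, assuming the surrounding variables from train_finetune_accelerate.py:

import time
import torch

torch.cuda.synchronize()          # drain kernels queued by the forward pass
t0 = time.time()
accelerator.backward(g_loss)
torch.cuda.synchronize()          # wait for the backward kernels to finish
print(f"backward alone: {time.time() - t0:.3f}s")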

@Selectorrr

I tried the following options one by one, none of which helped:

  1. With and without the accelerator
  2. Increasing num_processes from 1 to 2
  3. Decreasing max_len from 600 to 290
  4. Switching the decoder from hifigan to istftnet

@borrero-c

borrero-c commented Mar 27, 2024

Also seeing low GPU utilization and high single-core CPU utilization:

[Screenshot from 2024-03-27 16-27-23]

It also seems like the issue goes away after the first epoch finishes: my GPU starts being utilized more consistently and the CPU load becomes more evenly distributed.

@ayushtues

@borrero-c thanks for looking into this. I didn't observe anything changing after 1 epoch; utilization stays low for me. Also, the accelerator.backward() call might be taking time simply because it is doing the backward pass, so that might be expected.

@Selectorrr

I did a little research and ran the line profiler. Pay attention to the % Time column.

MAIN LOOP
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
 162         2          8.8      4.4      0.0      for epoch in range(start_epoch, 5):
 163         1          0.3      0.3      0.0          running_loss = 0
 164         1          3.8      3.8      0.0          start_time = time.time()
 165
 166         1       7624.1   7624.1      0.0          _ = [model[key].train() for key in model]
 167
 168         2       2430.4   1215.2      0.0          pgbar = tqdm(desc=f"Epoch {epoch + 1}/{epochs}", unit='Step', total=len(train_list) // batch_size, smoothing=0,
 169         1          0.1      0.1      0.0                       initial=1)
 170       102     525418.3   5151.2      0.3          for i, batch in enumerate(train_dataloader):
 171       102         73.2      0.7      0.0              if i > 100:
 172         1     265667.4 265667.4      0.1                  break
 173       101      36354.2    359.9      0.0              pgbar.update(1)
 174       101        917.4      9.1      0.0              waves = batch[0]
 175       101       5605.0     55.5      0.0              batch = [b.to(device) for b in batch[1:]]
 176       101       2789.0     27.6      0.0              texts, input_lengths, _, _, mels, mel_input_length, _ = batch
 177
 178       202       1903.5      9.4      0.0              with torch.no_grad():
 179       101      77350.0    765.8      0.0                  mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
 180       101      12146.6    120.3      0.0                  text_mask = length_to_mask(input_lengths).to(texts.device)
 181
 182       101   20453215.4 202507.1     10.2              ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
 183
 184       101        911.2      9.0      0.0              s2s_attn = s2s_attn.transpose(-1, -2)
 185       101       1402.6     13.9      0.0              s2s_attn = s2s_attn[..., 1:]
 186       101        334.4      3.3      0.0              s2s_attn = s2s_attn.transpose(-1, -2)
 187
 188       202       2396.6     11.9      0.0              with torch.no_grad():
 189       101      33850.8    335.2      0.0                  attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
 190       101      16570.0    164.1      0.0                  attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
 191       101       5279.7     52.3      0.0                  attn_mask = (attn_mask < 1)
 192
 193       101       3047.2     30.2      0.0              s2s_attn.masked_fill_(attn_mask, 0.0)
 194
 195       202       1703.6      8.4      0.0              with torch.no_grad():
 196       101      48330.6    478.5      0.0                  mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
 197       101     416141.6   4120.2      0.2                  s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
 198
 199                                                       # encode
 200       101    1624539.8  16084.6      0.8              t_en = model.text_encoder(texts, input_lengths, text_mask)
 201
 202                                                       # 50% of chance of using monotonic version
 203       101        416.9      4.1      0.0              if bool(random.getrandbits(1)):
 204        43       4864.5    113.1      0.0                  asr = (t_en @ s2s_attn)
 205                                                       else:
 206        58      12170.7    209.8      0.0                  asr = (t_en @ s2s_attn_mono)
 207
 208                                                       # get clips
 209       101       5637.2     55.8      0.0              mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
 210       101      14759.9    146.1      0.0              mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
 211       101       2895.5     28.7      0.0              mel_len_st = int(mel_input_length.min().item() / 2 - 1)
 212
 213       101        499.9      4.9      0.0              en = []
 214       101        521.1      5.2      0.0              gt = []
 215       101        432.4      4.3      0.0              wav = []
 216       101        427.5      4.2      0.0              st = []
 217
 218       909       1352.1      1.5      0.0              for bib in range(len(mel_input_length)):
 219       808      17282.6     21.4      0.0                  mel_length = int(mel_input_length[bib].item() / 2)
 220
 221       808       6093.9      7.5      0.0                  random_start = np.random.randint(0, mel_length - mel_len)
 222       808      12116.3     15.0      0.0                  en.append(asr[bib, :, random_start:random_start+mel_len])
 223       808       6309.5      7.8      0.0                  gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
 224
 225       808       1675.8      2.1      0.0                  y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
 226       808      69490.9     86.0      0.0                  wav.append(torch.from_numpy(y).to(device))
 227
 228                                                           # style reference (better to be different from the GT)
 229       808       5077.0      6.3      0.0                  random_start = np.random.randint(0, mel_length - mel_len_st)
 230       808       8738.8     10.8      0.0                  st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
 231
 232       101       5238.2     51.9      0.0              en = torch.stack(en)
 233       101       2928.9     29.0      0.0              gt = torch.stack(gt).detach()
 234       101       2246.8     22.2      0.0              st = torch.stack(st).detach()
 235
 236       101       7146.6     70.8      0.0              wav = torch.stack(wav).float().detach()
 237
 238                                                       # clip too short to be used by the style encoder
 239       101        202.3      2.0      0.0              if gt.shape[-1] < 80:
 240                                                           continue
 241
 242       202       2124.9     10.5      0.0              with torch.no_grad():
 243       101      44210.4    437.7      0.0                  real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach()
 244       101    2671261.3  26448.1      1.3                  F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
 245
 246       101    2978410.2  29489.2      1.5              s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
 247
 248       101   17613113.7 174387.3      8.8              y_rec = model.decoder(en, F0_real, real_norm, s)
 249
 250                                                       # discriminator loss
 251
 252       101         70.8      0.7      0.0              if epoch >= TMA_epoch:
 253       101     565364.9   5597.7      0.3                  optimizer.zero_grad()
 254       101   11707820.2 115919.0      5.8                  d_loss = dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
 255       101   18847437.9 186608.3      9.4                  accelerator.backward(d_loss)
 256       101     313779.6   3106.7      0.2                  optimizer.step('msd')
 257       101     294492.0   2915.8      0.1                  optimizer.step('mpd')
 258                                                       else:
 259                                                           d_loss = 0
 260
 261                                                       # generator loss
 262       101     237334.6   2349.8      0.1              optimizer.zero_grad()
 263       101     282369.5   2795.7      0.1              loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
 264
 265       101         51.6      0.5      0.0              if epoch >= TMA_epoch: # start TMA training
 266       101        419.4      4.2      0.0                  loss_s2s = 0
 267       909      10903.2     12.0      0.0                  for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
 268       808      89627.5    110.9      0.0                      loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
 269       101       2396.6     23.7      0.0                  loss_s2s /= texts.size(0)
 270
 271       101      11985.0    118.7      0.0                  loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
 272
 273       101    6983523.9  69143.8      3.5                  loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
 274       101    4046595.1  40065.3      2.0                  loss_slm = wl(wav.detach(), y_rec).mean()
 275
 276       505     812033.9   1608.0      0.4                  g_loss = loss_params.lambda_mel * loss_mel + \
 277       101       1428.5     14.1      0.0                  loss_params.lambda_mono * loss_mono + \
 278       101       1285.5     12.7      0.0                  loss_params.lambda_s2s * loss_s2s + \
 279       101       1268.2     12.6      0.0                  loss_params.lambda_gen * loss_gen_all + \
 280       101       1230.7     12.2      0.0                  loss_params.lambda_slm * loss_slm
 281
 282                                                       else:
 283                                                           loss_s2s = 0
 284                                                           loss_mono = 0
 285                                                           loss_gen_all = 0
 286                                                           loss_slm = 0
 287                                                           g_loss = loss_mel
 288
 289       101      14339.2    142.0      0.0              running_loss += accelerator.gather(loss_mel).mean().item()
 290
 291       101   99737870.0 987503.7     49.6              accelerator.backward(g_loss)
 292
 293       101     199636.4   1976.6      0.1              optimizer.step('text_encoder')
 294       101     290944.4   2880.6      0.1              optimizer.step('style_encoder')
 295       101    2382230.7  23586.4      1.2              optimizer.step('decoder')
 296
 297       101         72.7      0.7      0.0              if epoch >= TMA_epoch:
 298       101     430973.2   4267.1      0.2                  optimizer.step('text_aligner')
 299                                                           # optimizer.step('pitch_extractor')
 300
 301       101         82.0      0.8      0.0              iters = iters + 1
 302
 303       101        386.2      3.8      0.0              if (i+1)%log_interval == 0 and accelerator.is_main_process:
 304        20       1296.7     64.8      0.0                  status = 'Epoch [%d/%d], Step [%d/%d], Mel Loss: %.5f, Gen Loss: %.5f, Disc Loss: %.5f, Mono Loss: %.5f, S2S Loss: %.5f, SLM Loss: %.5f' % (
 305        10         17.5      1.7      0.0                  epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, loss_gen_all,
 306        10          2.6      0.3      0.0                  d_loss, loss_mono, loss_s2s, loss_slm)
 307                                                           # log_print (status, logger)
 308        10       2629.4    262.9      0.0                  pgbar.set_postfix_str(status)
 309        10       1553.5    155.4      0.0                  writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
 310        10       2915.1    291.5      0.0                  writer.add_scalar('train/gen_loss', loss_gen_all, iters)
 311        10       1903.1    190.3      0.0                  writer.add_scalar('train/d_loss', d_loss, iters)
 312        10       2026.3    202.6      0.0                  writer.add_scalar('train/mono_loss', loss_mono, iters)
 313        10       1550.2    155.0      0.0                  writer.add_scalar('train/s2s_loss', loss_s2s, iters)
 314        10       1451.7    145.2      0.0                  writer.add_scalar('train/slm_loss', loss_slm, iters)
 315
 316        10         11.9      1.2      0.0                  running_loss = 0
 317
 318                                                           # print('Time elasped:', time.time()-start_time)
 319
 320         1          0.4      0.4      0.0          loss_test = 0
 321
 322         1       6907.6   6907.6      0.0          _ = [model[key].eval() for key in model]

If we exclude the expected culprit:
accelerator.backward()
GPU utilization increases by about 20%, but it remains uneven.
[Screenshot 2024-03-29 at 10:20:12]

If I additionally exclude line 182:
ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
This makes GPU utilization more uniform:
[Screenshot 2024-03-29 at 10:16:51]

Additional performance details about text_aligner

ASRCNN
File: /app/Utils/ASR/models.py
Function: forward at line 37

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    37                                               @profile
    38                                               def forward(self, x, src_key_padding_mask=None, text_input=None):
    39       101     116166.6   1150.2      0.5          x = self.to_mfcc(x)
    40       101     742848.0   7354.9      3.2          x = self.init_cnn(x)
    41       101    4113638.1  40729.1     17.6          x = self.cnns(x)
    42       101     576767.4   5710.6      2.5          x = self.projection(x)
    43       101       1441.8     14.3      0.0          x = x.transpose(1, 2)
    44       101     102451.0   1014.4      0.4          ctc_logit = self.ctc_linear(x)
    45       101         51.5      0.5      0.0          if text_input is not None:
    46       101   17664905.7 174900.1     75.8              _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
    47       101         66.9      0.7      0.0              return ctc_logit, s2s_logit, s2s_attn
    48                                                   else:
    49                                                       return ctc_logit
ASRS2S
File: /app/Utils/ASR/models.py
Function: forward at line 118

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   118                                               @profile
   119                                               def forward(self, memory, memory_mask, text_input):
   120                                                   """
   121                                                   moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
   122                                                   moemory_mask.shape = (B, L, )
   123                                                   texts_input.shape = (B, T)
   124                                                   """
   125       101      73718.9    729.9      0.5          self.initialize_decoder_states(memory, memory_mask)
   126                                                   # text random mask
   127       101       8880.2     87.9      0.1          random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
   128       101       2321.3     23.0      0.0          _text_input = text_input.clone()
   129       101     273189.4   2704.8      1.7          _text_input.masked_fill_(random_mask, self.unk_index)
   130       101       8951.4     88.6      0.1          decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel]
   131       202       3652.6     18.1      0.0          start_embedding = self.embedding(
   132       101       4901.0     48.5      0.0              torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
   133       101      29957.5    296.6      0.2          decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
   134
   135       101         50.4      0.5      0.0          hidden_outputs, logit_outputs, alignments = [], [], []
   136     12503      27916.0      2.2      0.2          while len(hidden_outputs) < decoder_inputs.size(0):
   137
   138     12402     124859.6     10.1      0.8              decoder_input = decoder_inputs[len(hidden_outputs)]
   139     12402   15663074.1   1262.9     95.9              hidden, logit, attention_weights = self.decode(decoder_input)
   140     12402      12834.5      1.0      0.1              hidden_outputs += [hidden]
   141     12402       4052.0      0.3      0.0              logit_outputs += [logit]
   142     12402       4427.6      0.4      0.0              alignments += [attention_weights]
   143
   144       101      57422.0    568.5      0.4          hidden_outputs, logit_outputs, alignments = \
   145       202      37085.3    183.6      0.2              self.parse_decoder_outputs(
   146       101         17.3      0.2      0.0                  hidden_outputs, logit_outputs, alignments)
   147
   148       101         40.2      0.4      0.0          return hidden_outputs, logit_outputs, alignments
File: /app/Utils/ASR/models.py
Function: decode at line 149

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   149                                               @profile
   150                                               def decode(self, decoder_input):
   151
   152     12077     451589.1     37.4      2.9          cell_input = torch.cat((decoder_input, self.attention_context), -1)
   153     24154    1601496.2     66.3     10.2          self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
   154     12077       1534.5      0.1      0.0              cell_input,
   155     12077       4154.4      0.3      0.0              (self.decoder_hidden, self.decoder_cell))
   156
   157     24154     395431.2     16.4      2.5          attention_weights_cat = torch.cat(
   158     24154      94829.2      3.9      0.6              (self.attention_weights.unsqueeze(1),
   159     24154      67172.2      2.8      0.4              self.attention_weights_cum.unsqueeze(1)),dim=1)
   160
   161     24154   10655347.3    441.1     68.1          self.attention_context, self.attention_weights = self.attention_layer(
   162     12077       2301.5      0.2      0.0              self.decoder_hidden,
   163     12077       2838.1      0.2      0.0              self.memory,
   164     12077       2822.3      0.2      0.0              self.processed_memory,
   165     12077       1527.4      0.1      0.0              attention_weights_cat,
   166     12077       3143.5      0.3      0.0              self.mask)
   167
   168     12077     231777.8     19.2      1.5          self.attention_weights_cum += self.attention_weights
   169
   170     12077     264389.2     21.9      1.7          hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
   171     12077    1005236.5     83.2      6.4          hidden = self.project_to_hidden(hidden_and_context)
   172
   173                                                   # dropout to increasing g
   174     12077     860518.1     71.3      5.5          logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
   175
   176     12077       5210.3      0.4      0.0          return hidden, logit, self.attention_weights
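
Judging by the profile, ASRS2S steps its attention decoder one timestep at a time in Python (over 12,000 decode() calls across 101 batches), so each step launches many small kernels and the GPU idles between launches. A rough sketch for separating CPU launch overhead from GPU kernel time with torch.profiler, assuming the same model, mels, mask, texts objects as in the training loop above:

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)

# Rows with large "CPU total" but small "CUDA total" suggest the GPU is idle,
# waiting on the per-timestep Python loop that launches the kernels.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))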

@borrero-c

Looked into it some more: my steps take 40-20 seconds, and the .backward() call takes 20-10 seconds of that, respectively.

When training starts to pick up after that first epoch (and the GPU is more consistently utilized), the steps are ~4 seconds each and the .backward() call takes ~2 seconds.

It is also interesting to see that this code block takes a good amount of time to complete:

for bib in range(len(mel_input_length)):
    mel_length = int(mel_input_length[bib].item())
    mel = mels[bib, :, :mel_input_length[bib]]
    s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
    ss.append(s)
    s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
    gs.append(s)

It seems that for each step ~25% of the time is spent in the loop above and ~50% is spent in the .backward() call on line 464. Not sure how/if those could be improved; this isn't really my area of expertise.
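
One idea, purely a sketch and not something I've verified: if model.predictor_encoder and model.style_encoder tolerate zero-padded frames (an assumption, since their pooling may be sensitive to padding), the per-utterance loop could be batched into a single call per encoder:

# Hypothetical rewrite of the loop above, batching the encoder calls.
# ASSUMPTION: the encoders give acceptable results on zero-padded inputs;
# if not, this changes the outputs compared to the loop version.
import torch
import torch.nn.functional as F

max_T = int(mel_input_length.max().item())
padded = torch.stack([
    F.pad(mels[bib, :, :int(mel_input_length[bib].item())],
          (0, max_T - int(mel_input_length[bib].item())))
    for bib in range(len(mel_input_length))
])                                                         # (B, n_mels, max_T)

ss_batched = model.predictor_encoder(padded.unsqueeze(1))  # one call instead of B
gs_batched = model.style_encoder(padded.unsqueeze(1))      # one call instead of B

The downstream code that stacks ss and gs would need adjusting, and the batched outputs should be checked against the loop version before trusting this.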
