Skip to content

Commit

Permalink
fix(chat): model unload memory release (#418)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama authored Jun 24, 2024
1 parent 21c8ecc commit e0a9e7e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
18 changes: 11 additions & 7 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _load(
"cpu" if "mps" in str(device) else device
).eval()
assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
vocos.load_state_dict(torch.load(vocos_ckpt_path))
vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
self.vocos = vocos
if "mps" in str(self.device):
self._vocos_decode: Callable[[torch.Tensor], np.ndarray] = lambda spec: self.vocos.decode(
Expand All @@ -146,15 +146,15 @@ def _load(
dvae = DVAE(**cfg, coef=coef).to(device).eval()
coef = str(dvae)
assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device))
dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
self.dvae = dvae
self.logger.log(logging.INFO, 'dvae loaded.')

if gpt_config_path:
cfg = OmegaConf.load(gpt_config_path)
gpt = GPT(**cfg, device=device, logger=self.logger).eval()
assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
if compile and 'cuda' in str(device):
try:
gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
Expand All @@ -163,20 +163,20 @@ def _load(
self.gpt = gpt
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
self.pretrain_models['spk_stat'] = torch.load(spk_stat_path, map_location=device).to(device)
self.pretrain_models['spk_stat'] = torch.load(spk_stat_path, weights_only=True, mmap=True).to(device)
self.logger.log(logging.INFO, 'gpt loaded.')

if decoder_config_path:
cfg = OmegaConf.load(decoder_config_path)
decoder = DVAE(**cfg, coef=coef).to(device).eval()
coef = str(decoder)
assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device))
decoder.load_state_dict(torch.load(decoder_ckpt_path, weights_only=True, mmap=True))
self.decoder = decoder
self.logger.log(logging.INFO, 'decoder loaded.')

if tokenizer_path:
tokenizer = torch.load(tokenizer_path, map_location=device)
tokenizer = torch.load(tokenizer_path, map_location=device, mmap=True)
tokenizer.padding_side = 'left'
self.pretrain_models['tokenizer'] = tokenizer
self.logger.log(logging.INFO, 'tokenizer loaded.')
Expand All @@ -187,7 +187,11 @@ def _load(

def unload(self):
    # Release all loaded model components and reset this Chat instance to a
    # freshly-initialized state so the memory can be reclaimed.
    # NOTE(review): this span was extracted from a rendered commit diff with
    # the +/- markers stripped; `del_all(self)` appears to be the *removed*
    # side of the change and the lines below it the replacement — confirm
    # against the actual post-commit file before relying on this exact body.
    logger = self.logger
    del_all(self)
    del_all(self.pretrain_models)
    # Explicitly drop the heavyweight model attributes so their tensors lose
    # their last references and can be garbage-collected; hasattr() guards
    # against components that were never loaded on this instance.
    del_list = ["vocos", "_vocos_decode", 'gpt', 'decoder', 'dvae']
    for module in del_list:
        if hasattr(self, module):
            delattr(self, module)
    # Re-run __init__ with the preserved logger so the object stays usable
    # for a subsequent load() call.
    self.__init__(logger)

def _infer(
Expand Down
1 change: 1 addition & 0 deletions examples/web/funcs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import random
from typing import Optional

Expand Down

0 comments on commit e0a9e7e

Please sign in to comment.