Funasr1.0 #1362

Merged · 12 commits · Feb 6, 2024
3 changes: 2 additions & 1 deletion funasr/auto/auto_model.py
@@ -174,7 +174,7 @@ def build_model(self, **kwargs):
        # build model
        model_class = tables.model_classes.get(kwargs["model"])
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
-        model.eval()

        model.to(device)

        # init_param
@@ -209,6 +209,7 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **
        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)
        model = self.model if model is None else model
+        model.eval()

        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
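Reviewer note: moving `model.eval()` out of `build_model` and into `inference` means every inference call re-asserts eval mode, even if a training loop has since switched the model back to train mode. A minimal sketch of why the mode matters, using a toy module rather than FunASR code:

```python
import torch

# Toy model: dropout is active in train mode and disabled in eval mode.
net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()
a, b = net(x), net(x)        # dropout resamples: outputs usually differ
print(torch.allclose(a, b))  # usually False

net.eval()
with torch.no_grad():
    c, d = net(x), net(x)    # dropout off: outputs are deterministic
print(torch.allclose(c, d))  # True
```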
54 changes: 52 additions & 2 deletions funasr/datasets/audio_datasets/index_ds.py
@@ -6,8 +6,8 @@
from funasr.register import tables


@tables.register("index_ds_classes", "IndexDSJsonl")
class IndexDSJsonl(torch.utils.data.Dataset):
@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
class IndexDSJsonlRankSplit(torch.utils.data.Dataset):

    def __init__(self, path):
        super().__init__()
@@ -66,3 +66,53 @@ def get_source_len(self, data_dict):
    def get_target_len(self, data_dict):

        return data_dict["target_len"] if "target_len" in data_dict else 0

@tables.register("index_ds_classes", "IndexDSJsonl")
@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
class IndexDSJsonlRankFull(torch.utils.data.Dataset):

def __init__(self, path):
super().__init__()

contents = []
with open(path, encoding='utf-8') as fin:
for line in fin:
data = json.loads(line.strip())
if "text" in data: # for sft
                    contents.append(data['text'])  # appended to the local list; assigned to self.contents below
if "source" in data: # for speech lab pretrain
prompt = data.get("prompt", "<ASR>")
source = data["source"]
target = data["target"]
source_len = data.get("source_len", 1)
target_len = data.get("target_len", 0)

contents.append({"source": source,
"prompt": prompt,
"target": target,
"source_len": source_len,
"target_len": target_len,
}
)

self.contents = contents

        logging.info(
            "total number of samples across ranks: {}".format(len(self.contents)))

def __len__(self):
return len(self.contents)

    def __getitem__(self, index):
        try:
            data = self.contents[index]
        except IndexError:
            logging.error("index out of range: {}".format(index))
            raise
        return data

def get_source_len(self, data_dict):
return data_dict.get("source_len", 1)

def get_target_len(self, data_dict):

return data_dict.get("target_len", 0)
193 changes: 193 additions & 0 deletions funasr/datasets/audio_datasets/samplers.py
@@ -1,5 +1,7 @@
import torch
import numpy as np
+import logging
+import torch.distributed as dist

from funasr.register import tables

@@ -82,3 +84,194 @@ def __iter__(self):
max_token = sample_len_cur_raw
num_sample = 1


@tables.register("batch_sampler_classes", "BatchSampler")
@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):

def __init__(self, dataset,
batch_type: str = "example",
batch_size: int = 100,
buffer_size: int = 30,
drop_last: bool = True,
shuffle: bool = True,
is_training: bool = True,
**kwargs):

self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
self.total_samples = len(dataset)
self.batch_type = batch_type
self.batch_size = int(batch_size)
self.buffer_size = buffer_size
self.max_token_length = kwargs.get("max_token_length", 1500)
self.shuffle_idx = np.arange(self.total_samples)
self.shuffle = shuffle and is_training
self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            # distributed process group not initialized; single-process fallback
            rank = 0
            world_size = 1
self.rank = rank
self.world_size = world_size

def __len__(self):
return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

def set_epoch(self, epoch):
np.random.seed(epoch)

def __iter__(self):

batch_size_total = self.batch_size * self.world_size

if self.shuffle:
np.random.shuffle(self.shuffle_idx)

batch = []
max_token = 0
num_sample = 0

iter_num = (self.total_samples - 1) // self.buffer_size + 1
# print("iter_num: ", iter_num)
for iter in range(self.pre_idx + 1, iter_num):
# if iter == iter_num -1 and self.drop_last:
# continue
datalen_with_index = []
for i in range(self.buffer_size):
idx = iter * self.buffer_size + i
if idx >= self.total_samples:
continue

idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]

source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
sample_len_cur = source_len + target_len

datalen_with_index.append([idx, sample_len_cur])

datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
for item in datalen_with_index_sort:
idx, sample_len_cur_raw = item
if sample_len_cur_raw > self.max_token_length:
continue

max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
# if self.batch_type != 'example':
# max_token_padding *= max_token_cur
if max_token_padding <= batch_size_total:
batch.append(idx)
max_token = max_token_cur
num_sample += 1
else:
batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
yield batch_rank
batch = [idx]
max_token = sample_len_cur_raw
num_sample = 1
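Unlike `DistributedSampler`-style sharding, this sampler has every rank walk the full index, assemble the same global batch of `batch_size * world_size` samples, and slice out its own `batch_size`-sized shard. A hedged wiring sketch; the `ToyDataset` stands in for FunASR's audio dataset wrapper, whose `get_source_len`/`get_target_len` the sampler calls with an integer index:

```python
import torch
from torch.utils.data import DataLoader

from funasr.datasets.audio_datasets.samplers import RankFullLocalShuffleBatchSampler

class ToyDataset(torch.utils.data.Dataset):
    """Stand-in dataset: the sampler queries lengths by integer index."""
    def __init__(self, lengths):
        self.lengths = lengths
    def __len__(self):
        return len(self.lengths)
    def __getitem__(self, index):
        return {"id": index, "source_len": self.lengths[index]}
    def get_source_len(self, index):
        return self.lengths[index]
    def get_target_len(self, index):
        return 0

dataset = ToyDataset(lengths=[3, 7, 2, 5, 4, 6, 1, 8])
sampler = RankFullLocalShuffleBatchSampler(dataset, batch_type="example",
                                           batch_size=2, buffer_size=4)
sampler.set_epoch(0)  # reseeds the shuffle so every rank draws the same order

# batch_sampler yields index lists for *this* rank only.
loader = DataLoader(dataset, batch_sampler=sampler,
                    collate_fn=lambda samples: samples)  # placeholder collator
for batch in loader:
    print([s["id"] for s in batch])  # at most batch_size samples per step
```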


@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):

def __init__(self, dataset,
batch_type: str = "example",
batch_size: int = 100,
buffer_size: int = 30,
drop_last: bool = True,
shuffle: bool = True,
is_training: bool = True,
**kwargs):

self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
self.total_samples = len(dataset)
self.batch_type = batch_type
self.batch_size = int(batch_size)
self.buffer_size = buffer_size
self.max_token_length = kwargs.get("max_token_length", 1500)
self.shuffle_idx = np.arange(self.total_samples)
self.shuffle = shuffle and is_training
self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            # distributed process group not initialized; single-process fallback
            rank = 0
            world_size = 1
self.rank = rank
self.world_size = world_size

def __len__(self):
return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1

def set_epoch(self, epoch):
np.random.seed(epoch)

def __iter__(self):

batch_size_total = self.batch_size * self.world_size
if self.shuffle:
np.random.shuffle(self.shuffle_idx)

batch_list_all_rank = []
batch_list_cur = []
max_token = 0
num_sample = 0

iter_num = (self.total_samples - 1) // self.buffer_size + 1
# print("iter_num: ", iter_num)
for iter in range(self.pre_idx + 1, iter_num):
# if iter == iter_num - 1 and self.drop_last:
# continue
datalen_with_index = []
for i in range(self.buffer_size):
idx = iter * self.buffer_size + i
if idx >= self.total_samples:
continue

idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]

source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
sample_len_cur = source_len + target_len

datalen_with_index.append([idx, sample_len_cur])

datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
for ii, item in enumerate(datalen_with_index_sort):
                is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort) - 1
idx, sample_len_cur_raw = item
if sample_len_cur_raw > self.max_token_length:
continue

max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample

if self.batch_type != 'example':
max_token_padding *= max_token_cur
if len(batch_list_all_rank) < self.world_size:

if max_token_padding <= self.batch_size:
batch_list_cur.append(idx)
max_token = max_token_cur
num_sample += 1
else:
batch_list_all_rank.append(batch_list_cur)
batch_list_cur = []
else:
batch_rank = batch_list_all_rank[self.rank]
yield batch_rank
batch_list_all_rank = [idx]
max_token = sample_len_cur_raw
num_sample = 1
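Both samplers share the same core idea: take a `buffer_size` window of the shuffled index, sort the window by sample length, then pack sorted samples until a budget is exceeded, so similarly sized utterances get padded together. A standalone sketch of that length-sorted packing under a `len(batch) * max_len` token budget (toy lengths, not FunASR code):

```python
def pack_by_token_budget(lengths, buffer_size=8, token_budget=16):
    """Yield batches of indices; padded cost = len(batch) * longest sample."""
    order = list(range(len(lengths)))
    for start in range(0, len(order), buffer_size):
        window = sorted(order[start:start + buffer_size], key=lambda i: lengths[i])
        batch, max_len = [], 0
        for i in window:
            new_max = max(max_len, lengths[i])
            if (len(batch) + 1) * new_max <= token_budget:
                batch.append(i)
                max_len = new_max
            else:
                yield batch
                batch, max_len = [i], lengths[i]
        if batch:
            yield batch

print(list(pack_by_token_budget([3, 7, 2, 5, 4, 6, 1, 8])))
# [[6, 2, 0, 4], [3, 5], [1, 7]]
```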
3 changes: 2 additions & 1 deletion funasr/models/fsmn_vad_streaming/model.py
@@ -575,7 +575,8 @@ def inference(self,

time1 = time.perf_counter()
is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True)
cfg = {"is_final": kwargs.get("is_final", False), "is_streaming_input": is_streaming_input}
is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
audio_sample_list = load_audio_text_image_video(data_in,
fs=frontend.fs,
audio_fs=kwargs.get("fs", 16000),
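The new `is_final` default reads as: streaming chunks are non-final unless the caller says otherwise, offline utterances are final unless the caller says otherwise. A small restatement of the two `kwargs.get` chains above (hypothetical helper; `chunk_size` in the same units as the surrounding code):

```python
def resolve_flags(chunk_size, **kwargs):
    # Long chunks (>= 15000) default to offline input; short ones to streaming.
    is_streaming_input = kwargs.get("is_streaming_input", chunk_size < 15000)
    # Streaming chunks default to not-final; offline utterances to final.
    is_final = kwargs.get("is_final", not is_streaming_input)
    return {"is_final": is_final, "is_streaming_input": is_streaming_input}

print(resolve_flags(600))                # {'is_final': False, 'is_streaming_input': True}
print(resolve_flags(20000))              # {'is_final': True, 'is_streaming_input': False}
print(resolve_flags(600, is_final=True)) # caller override wins
```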
2 changes: 1 addition & 1 deletion funasr/models/paraformer/cif_predictor.py
@@ -186,7 +186,7 @@ def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk
alphas = alphas.squeeze(-1)
mask = mask.squeeze(-1)
if target_label_length is not None:
-            target_length = target_label_length
+            target_length = target_label_length.squeeze(-1)
elif target_label is not None:
target_length = (target_label != ignore_id).float().sum(-1)
else:
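The added `.squeeze(-1)` guards against a silent broadcasting bug: if `target_label_length` arrives as `(batch, 1)` while the CIF alphas sum to `(batch,)`, elementwise ops between them expand to `(batch, batch)`. A tiny shape demo with made-up tensors:

```python
import torch

alphas_sum = torch.tensor([3.0, 5.0])                # shape (2,)
target_label_length = torch.tensor([[3.0], [5.0]])   # shape (2, 1)

# Broadcasting quietly produces a (2, 2) result instead of (2,).
print((alphas_sum / target_label_length).shape)              # torch.Size([2, 2])
print((alphas_sum / target_label_length.squeeze(-1)).shape)  # torch.Size([2])
```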
2 changes: 2 additions & 0 deletions funasr/models/paraformer/model.py
@@ -491,6 +491,8 @@ def inference(self,
b, n, d = decoder_out.size()
if isinstance(key[0], (list, tuple)):
key = key[0]
if len(key) < b:
key = key*b
for i in range(b):
x = encoder_out[i, :encoder_out_lens[i], :]
am_scores = decoder_out[i, :pre_token_length[i], :]
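The new guard pads the key list when a single utterance key is passed but the decoder output has batch dimension `b`, so the per-item loop below can index `key[i]` safely. A trivial sketch:

```python
key = ["utt1"]
b = 3  # e.g. decoder_out.size(0)

if len(key) < b:
    key = key * b  # list repetition: ["utt1", "utt1", "utt1"]

for i in range(b):
    print(key[i], "->", i)
```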
18 changes: 18 additions & 0 deletions funasr/train_utils/trainer.py
@@ -204,7 +204,25 @@ def _train_epoch(self, epoch):
my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
with my_context():
time2 = time.perf_counter()
print("before, GPU, memory: {:.1} MB, "
"{:.1} MB, "
"{:.1} MB, "
"{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024,
torch.cuda.max_memory_allocated()/1024/1024/1024,
torch.cuda.memory_reserved()/1024/1024/1024,
torch.cuda.max_memory_reserved()/1024/1024/1024,
))

retval = self.model(**batch)
torch.cuda.empty_cache()  # frees cached blocks so the "after" reading reflects live tensors; costs throughput if left in
print("after, GPU, memory: {:.1} MB, "
"{:.1} MB, "
"{:.1} MB, "
"{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024,
torch.cuda.max_memory_allocated()/1024/1024/1024,
torch.cuda.memory_reserved()/1024/1024/1024,
torch.cuda.max_memory_reserved()/1024/1024/1024,
))
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
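The two four-stat prints could be factored into one helper; a sketch with a hypothetical function name but the same `torch.cuda` calls as the diff:

```python
import torch

def cuda_mem_summary(tag: str) -> None:
    """Print allocated/reserved CUDA memory in GB at a labelled point."""
    if not torch.cuda.is_available():
        return
    gb = 1024 ** 3
    print("{}, GPU memory: {:.2f} GB allocated, {:.2f} GB max allocated, "
          "{:.2f} GB reserved, {:.2f} GB max reserved".format(
              tag,
              torch.cuda.memory_allocated() / gb,
              torch.cuda.max_memory_allocated() / gb,
              torch.cuda.memory_reserved() / gb,
              torch.cuda.max_memory_reserved() / gb))

# usage inside the training step:
#     cuda_mem_summary("before")
#     retval = self.model(**batch)
#     cuda_mem_summary("after")
```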