feat: add multi-task classification task. #191

Open · wants to merge 5 commits into base: master
8 changes: 3 additions & 5 deletions .gitignore
@@ -103,8 +103,6 @@ ENV/
.DS_Store

# logs and output
-**/output/
-**/log/
-
-# dataset
-**/dataset/
+output/
+log/
+dataset/
2 changes: 1 addition & 1 deletion plsc/core/recompute.py
@@ -33,7 +33,7 @@ def recompute_forward(func, *args, **kwargs):

def recompute_warp(model, layerlist_interval=1, names=[]):

-    for name, layer in model._sub_layers.items():
+    for name, layer in model.named_sublayers():
if isinstance(layer, nn.LayerList):
for idx, sub_layer in enumerate(layer):
if layerlist_interval >= 1 and idx % layerlist_interval == 0:
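Note: model._sub_layers.items() yields only a layer's direct children, whereas the public named_sublayers() API walks the layer tree recursively, so nested containers are reached as well. A minimal sketch of the difference (hypothetical toy model; assumes paddle is installed):

import paddle.nn as nn

class Toy(nn.Layer):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
        self.blocks = nn.LayerList([nn.Linear(8, 8) for _ in range(4)])

model = Toy()
print([name for name, _ in model._sub_layers.items()])
# direct children only: ['backbone', 'blocks']
print([name for name, _ in model.named_sublayers()])
# recursive: also yields 'backbone.0', 'backbone.1', 'blocks.0', ...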
1 change: 1 addition & 0 deletions plsc/data/dataset/__init__.py
@@ -64,3 +64,4 @@ def default_loader(path: str):
from .imagenet_dataset import ImageNetDataset
from .face_recognition_dataset import FaceIdentificationDataset, FaceVerificationDataset, FaceRandomDataset
from .imagefolder_dataset import ImageFolder
+from .mtl_dataset import SingleTaskDataset, MultiTaskDataset, ConcatDataset
189 changes: 189 additions & 0 deletions plsc/data/dataset/mtl_dataset.py
@@ -0,0 +1,189 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Single-task Dataset and ConcatDataset are realized.
Multi-task dataset(ConcatDataset) can be composed by multiple single-task datasets.
"""
from collections.abc import Iterable
import bisect
import random
from os.path import join

import numpy as np

import paddle
from paddle.io import Dataset
from plsc.data.utils import create_preprocess_operators


class SingleTaskDataset(Dataset):
"""
Single-task Dataset.
The input file includes single task dataset.
"""

def __init__(self, task_id, data_root, label_path, transform_ops):
self.task_id = task_id
self.data_root = data_root
self.transform_ops = None
if transform_ops is not None:
self.transform_ops = create_preprocess_operators(transform_ops)
self.data_list = []
with open(join(data_root, label_path), "r") as f:
for line in f:
img_path, label = line.strip().split(" ")
self.data_list.append(
(join(data_root, "images", img_path), int(label)))

def __getitem__(self, idx):
img_path, label = self.data_list[idx]
with open(img_path, 'rb') as f:
img = f.read()
if self.transform_ops:
img = self.transform_ops(img)
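        # -1 marks a missing annotation for this task; map it to class 0
        # (presumably masked out downstream on the loss side)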
if label == -1:
label = 0
label = paddle.to_tensor(np.array([label]), dtype=paddle.int32)
target = {"label": label, "task": self.task_id}
return img, target

def __len__(self):
return len(self.data_list)


class ConcatDataset(Dataset):
"""

Dataset that are composed by multiple datasets.
Multi-task Dataset can be the concatenation of single-task datasets.
"""

@staticmethod
def cumsum(sequence, ratio_list):
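        # each dataset contributes int(len(dataset) * ratio) virtual samples,
        # e.g. lengths [100, 200] with ratios {0: 1.0, 1: 0.5} -> [100, 200]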
r, s = [], 0
for i, e in enumerate(sequence):
l = int(len(e) * ratio_list[i])
r.append(l + s)
s += l
return r

def __init__(self, datasets, dataset_ratio=None):
super(ConcatDataset, self).__init__()
        assert isinstance(datasets,
                          Iterable), "datasets should be an iterable."
        assert len(datasets) > 0, "datasets length should be greater than 0."
self.instance_datasets(datasets)

if dataset_ratio is not None:
assert len(dataset_ratio) == len(self.datasets)
self.dataset_ratio = {
i: dataset_ratio[i]
for i in range(len(dataset_ratio))
}
else:
self.dataset_ratio = {i: 1. for i in range(len(self.datasets))}

self.cumulative_sizes = self.cumsum(self.datasets, self.dataset_ratio)
self.idx_ds_map = {
idx: bisect.bisect_right(self.cumulative_sizes, idx)
for idx in range(self.__len__())
}

    def instance_datasets(self, datasets):
        # instantiate dataset objects from config dicts
        dataset_list = []
        for ds in datasets:
            if isinstance(ds, SingleTaskDataset):
                continue
            if isinstance(ds, dict):
                name = list(ds.keys())[0]
                params = ds[name]
                task_ids = params.pop("task_ids", [0])
                if not isinstance(task_ids, list):
                    task_ids = [task_ids]
                label_paths = params.pop("label_path")
                if not isinstance(label_paths, list):
                    label_paths = [label_paths]
                assert len(label_paths) == len(task_ids), \
                    "Length of label_path should equal that of task_ids."
                for task_id, label_path in zip(task_ids, label_paths):
                    # eval() resolves the dataset class by its config name
                    dataset = eval(name)(task_id=task_id,
                                         label_path=label_path,
                                         **params)
                    dataset_list.append(dataset)
        if len(dataset_list) > 0:
            self.datasets = dataset_list
        else:
            self.datasets = list(datasets)

def __len__(self):
return self.cumulative_sizes[-1]

def __getitem__(self, idx):
dataset_idx = self.idx_ds_map[idx]
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
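        # ratios > 1 oversample a dataset: virtual indices past its real
        # length fall back to a random sample from the same dataset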
if sample_idx >= len(self.datasets[dataset_idx]):
sample_idx = random.choice(range(len(self.datasets[dataset_idx])))
return self.datasets[dataset_idx][sample_idx]

@property
def cummulative_sizes(self):
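        # misspelled alias mirroring torch's deprecated ConcatDataset.cummulative_sizes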
return self.cumulative_sizes


class MultiTaskDataset(Dataset):
"""
Multi-Task Dataset.
The input file includes multi-task datasets.
"""

def __init__(self, task_id, data_root, label_path, transform_ops):
"""

Args:
task_ids: task id list
data_root:
label_path:
transform_ops:
"""
self.task_id = task_id
self.data_root = data_root
self.transform_ops = None
if transform_ops is not None:
self.transform_ops = create_preprocess_operators(transform_ops)
self.data_list = []
with open(join(data_root, label_path), "r") as f:
for line in f:
img_path, labels = line.strip().split(" ", 1)
labels = [int(label) for label in labels.strip().split(" ")]
self.data_list.append(
(join(data_root, "images", img_path), labels))

def __getitem__(self, idx):
img_path, labels = self.data_list[idx]
with open(img_path, 'rb') as f:
img = f.read()
if self.transform_ops:
img = self.transform_ops(img)
labels = [0 if label == -1 else label for label in labels]
labels = paddle.to_tensor(np.array(labels), dtype=paddle.int32)
target = {"label": labels, "task": self.task_id}
return img, target

def __len__(self):
return len(self.data_list)
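For reference, a usage sketch of the classes above (hypothetical paths and label files, not part of the PR). SingleTaskDataset reads one "img_path label" pair per line, MultiTaskDataset reads one image path followed by one label per task, and ConcatDataset accepts either dataset instances or config dicts:

# assumed label file layouts (hypothetical files):
#   age_label.txt   : "0001.jpg 3"         one label per line
#   multi_task.txt  : "0001.jpg 3 0 -1"    one label per task, -1 = missing

from plsc.data.dataset import SingleTaskDataset, ConcatDataset

age = SingleTaskDataset(
    task_id=0,
    data_root="dataset/faces",    # images are read from dataset/faces/images/
    label_path="age_label.txt",
    transform_ops=None)           # None keeps the raw image bytes
gender = SingleTaskDataset(
    task_id=1,
    data_root="dataset/faces",
    label_path="gender_label.txt",
    transform_ops=None)

# keep all of task 0, draw only half of task 1 per virtual epoch
concat = ConcatDataset([age, gender], dataset_ratio=[1.0, 0.5])
img, target = concat[0]    # target == {"label": Tensor([...]), "task": 0}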
89 changes: 89 additions & 0 deletions plsc/data/sampler/mtl_sampler.py
@@ -0,0 +1,89 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from paddle.io import DistributedBatchSampler


class MTLSampler(DistributedBatchSampler):
def __init__(self,
dataset,
batch_size,
num_replicas=None,
rank=None,
shuffle=False,
drop_last=False,
idx_sample_p: dict=None):
super(MTLSampler, self).__init__(
dataset,
batch_size,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
drop_last=drop_last)
self.idx_sample_p = idx_sample_p

def resample(self):
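        # appears to be an unused placeholder: returns the natural sample order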
num_samples = len(self.dataset)
indices = np.arange(num_samples).tolist()

return indices

def __iter__(self):
num_samples = len(self.dataset)
indices = np.arange(num_samples).tolist()
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
if self.shuffle:
np.random.RandomState(self.epoch).shuffle(indices)
self.epoch += 1

# subsample
def _get_indices_by_batch_size(indices):
subsampled_indices = []
last_batch_size = self.total_size % (self.batch_size * self.nranks)
assert last_batch_size % self.nranks == 0
last_local_batch_size = last_batch_size // self.nranks

for i in range(self.local_rank * self.batch_size,
len(indices) - last_batch_size,
self.batch_size * self.nranks):
subsampled_indices.extend(indices[i:i + self.batch_size])

indices = indices[len(indices) - last_batch_size:]
subsampled_indices.extend(indices[
self.local_rank * last_local_batch_size:(
self.local_rank + 1) * last_local_batch_size])
return np.array(subsampled_indices)

if self.nranks > 1:
indices = _get_indices_by_batch_size(indices)

        assert len(indices) == self.num_samples
        _sample_iter = iter(indices)

        batch_indices = []
        if self.idx_sample_p is not None:
            assert len(self.idx_sample_p) == len(self.dataset), \
                "length of idx_sample_p must be equal to the dataset length"
            # restrict the distribution to this rank's indices and renormalize
            # so the probabilities passed to np.random.choice sum to 1
            sample_p = np.array([self.idx_sample_p[ind] for ind in indices],
                                dtype=np.float64)
            sample_p = sample_p / sample_p.sum()
            for _ in range(len(indices)):
                idx = np.random.choice(indices, replace=True, p=sample_p)
                batch_indices.append(idx)
                if len(batch_indices) == self.batch_size:
                    yield batch_indices
                    batch_indices = []
        else:
            # plain sequential batching when no sampling weights are given
            for idx in _sample_iter:
                batch_indices.append(idx)
                if len(batch_indices) == self.batch_size:
                    yield batch_indices
                    batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices
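Continuing the sketch above, the sampler could be wired to a DataLoader as follows (hypothetical uniform weights; assumes transform_ops is configured so samples collate into tensors). idx_sample_p maps every dataset index to a sampling weight, which is restricted to the current rank's indices and renormalized inside __iter__:

from paddle.io import DataLoader
from plsc.data.sampler.mtl_sampler import MTLSampler

# hypothetical: uniform weights, one entry per sample of the concat dataset
idx_sample_p = {i: 1.0 / len(concat) for i in range(len(concat))}

sampler = MTLSampler(
    concat,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    idx_sample_p=idx_sample_p)
loader = DataLoader(concat, batch_sampler=sampler, num_workers=4)
for img, target in loader:
    ...    # target["task"] identifies the source task for each sample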
2 changes: 2 additions & 0 deletions plsc/engine/__init__.py
@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from plsc.engine.engine import Engine
+from plsc.engine.multi_task_classfication import MTLEngine
15 changes: 15 additions & 0 deletions plsc/engine/multi_task_classfication/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from plsc.engine.multi_task_classfication.trainer import MTLEngine