diff --git a/.gitignore b/.gitignore
index 4d685bf..a0d5335 100644
--- a/.gitignore
+++ b/.gitignore
@@ -103,8 +103,6 @@ ENV/
 .DS_Store
 
 # logs and output
-**/output/
-**/log/
-
-# dataset
-**/dataset/
\ No newline at end of file
+output/
+log/
+dataset/
diff --git a/plsc/core/recompute.py b/plsc/core/recompute.py
index 8b7bb73..4f08a67 100644
--- a/plsc/core/recompute.py
+++ b/plsc/core/recompute.py
@@ -33,7 +33,7 @@ def recompute_forward(func, *args, **kwargs):
 
 
 def recompute_warp(model, layerlist_interval=1, names=[]):
-    for name, layer in model._sub_layers.items():
+    for name, layer in model.named_sublayers():
         if isinstance(layer, nn.LayerList):
             for idx, sub_layer in enumerate(layer):
                 if layerlist_interval >= 1 and idx % layerlist_interval == 0:
diff --git a/plsc/data/dataset/__init__.py b/plsc/data/dataset/__init__.py
index 3238751..8b96482 100644
--- a/plsc/data/dataset/__init__.py
+++ b/plsc/data/dataset/__init__.py
@@ -64,3 +64,4 @@ def default_loader(path: str):
 from .imagenet_dataset import ImageNetDataset
 from .face_recognition_dataset import FaceIdentificationDataset, FaceVerificationDataset, FaceRandomDataset
 from .imagefolder_dataset import ImageFolder
+from .mtl_dataset import SingleTaskDataset, MultiTaskDataset, ConcatDataset
diff --git a/plsc/data/dataset/mtl_dataset.py b/plsc/data/dataset/mtl_dataset.py
new file mode 100644
index 0000000..72c49df
--- /dev/null
+++ b/plsc/data/dataset/mtl_dataset.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module implements SingleTaskDataset, MultiTaskDataset and ConcatDataset.
+A multi-task dataset (ConcatDataset) can be composed of multiple single-task datasets.
+"""
+from collections.abc import Iterable
+import warnings
+import bisect
+import cv2
+from os.path import join
+import numpy as np
+import random
+
+import paddle
+from paddle.io import Dataset
+from plsc.data.utils import create_preprocess_operators
+
+
+class SingleTaskDataset(Dataset):
+    """
+    Single-task Dataset.
+    The label file contains samples for a single task.
+    """
+
+    def __init__(self, task_id, data_root, label_path, transform_ops):
+        self.task_id = task_id
+        self.data_root = data_root
+        self.transform_ops = None
+        if transform_ops is not None:
+            self.transform_ops = create_preprocess_operators(transform_ops)
+        self.data_list = []
+        with open(join(data_root, label_path), "r") as f:
+            for line in f:
+                img_path, label = line.strip().split(" ")
+                self.data_list.append(
+                    (join(data_root, "images", img_path), int(label)))
+
+    def __getitem__(self, idx):
+        img_path, label = self.data_list[idx]
+        with open(img_path, 'rb') as f:
+            img = f.read()
+        if self.transform_ops:
+            img = self.transform_ops(img)
+        if label == -1:
+            label = 0
+        label = paddle.to_tensor(np.array([label]), dtype=paddle.int32)
+        target = {"label": label, "task": self.task_id}
+        return img, target
+
+    def __len__(self):
+        return len(self.data_list)
+
+
+class ConcatDataset(Dataset):
+    """
+    Dataset that is composed of multiple datasets.
+    A multi-task dataset can be built as the concatenation of single-task datasets.
+    """
+
+    @staticmethod
+    def cumsum(sequence, ratio_list):
+        r, s = [], 0
+        for i, e in enumerate(sequence):
+            l = int(len(e) * ratio_list[i])
+            r.append(l + s)
+            s += l
+        return r
+
+    def __init__(self, datasets, dataset_ratio=None):
+        super(ConcatDataset, self).__init__()
+        assert isinstance(datasets,
+                          Iterable), "datasets should be iterable."
+        assert len(datasets) > 0, "datasets length should be greater than 0."
+        self.instance_datasets(datasets)
+
+        if dataset_ratio is not None:
+            assert len(dataset_ratio) == len(self.datasets)
+            self.dataset_ratio = {
+                i: dataset_ratio[i]
+                for i in range(len(dataset_ratio))
+            }
+        else:
+            self.dataset_ratio = {i: 1. for i in range(len(self.datasets))}
+
+        self.cumulative_sizes = self.cumsum(self.datasets, self.dataset_ratio)
+        self.idx_ds_map = {
+            idx: bisect.bisect_right(self.cumulative_sizes, idx)
+            for idx in range(self.__len__())
+        }
+
+    def instance_datasets(self, datasets):
+        # get class instances from config dicts
+        dataset_list = []
+        for ds in datasets:
+            if isinstance(ds, SingleTaskDataset):
+                continue
+            if isinstance(ds, dict):
+                name = list(ds.keys())[0]
+                params = ds[name]
+                task_ids = params.pop("task_ids", [0])
+                if not isinstance(task_ids, list):
+                    task_ids = [task_ids]
+                label_path = params.pop("label_path")
+                if not isinstance(label_path, list):
+                    label_path = [label_path]
+                assert len(label_path) == len(
+                    task_ids), "Length of label_path should be equal to that of task_ids."
+                for task_id, lp in zip(task_ids, label_path):
+                    dataset = eval(name)(task_id=task_id,
+                                         label_path=lp,
+                                         **params)
+                    dataset_list.append(dataset)
+        if len(dataset_list) > 0:
+            self.datasets = dataset_list
+        else:
+            self.datasets = list(datasets)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idx):
+        dataset_idx = self.idx_ds_map[idx]
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        if sample_idx >= len(self.datasets[dataset_idx]):
+            sample_idx = random.choice(range(len(self.datasets[dataset_idx])))
+        return self.datasets[dataset_idx][sample_idx]
+
+    @property
+    def cummulative_sizes(self):
+        return self.cumulative_sizes
+
+
+class MultiTaskDataset(Dataset):
+    """
+    Multi-Task Dataset.
+    The input file includes multi-task datasets.
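+    Each line of the label file contains an image path (relative to
+    ``<data_root>/images``) followed by one integer label per task,
+    separated by spaces.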
+ """ + + def __init__(self, task_id, data_root, label_path, transform_ops): + """ + + Args: + task_ids: task id list + data_root: + label_path: + transform_ops: + """ + self.task_id = task_id + self.data_root = data_root + self.transform_ops = None + if transform_ops is not None: + self.transform_ops = create_preprocess_operators(transform_ops) + self.data_list = [] + with open(join(data_root, label_path), "r") as f: + for line in f: + img_path, labels = line.strip().split(" ", 1) + labels = [int(label) for label in labels.strip().split(" ")] + self.data_list.append( + (join(data_root, "images", img_path), labels)) + + def __getitem__(self, idx): + img_path, labels = self.data_list[idx] + with open(img_path, 'rb') as f: + img = f.read() + if self.transform_ops: + img = self.transform_ops(img) + labels = [0 if label == -1 else label for label in labels] + labels = paddle.to_tensor(np.array(labels), dtype=paddle.int32) + target = {"label": labels, "task": self.task_id} + return img, target + + def __len__(self): + return len(self.data_list) diff --git a/plsc/data/sampler/mtl_sampler.py b/plsc/data/sampler/mtl_sampler.py new file mode 100644 index 0000000..703487f --- /dev/null +++ b/plsc/data/sampler/mtl_sampler.py @@ -0,0 +1,89 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +import numpy as np + +from paddle.io import DistributedBatchSampler + + +class MTLSampler(DistributedBatchSampler): + def __init__(self, + dataset, + batch_size, + num_replicas=None, + rank=None, + shuffle=False, + drop_last=False, + idx_sample_p: dict=None): + super(MTLSampler, self).__init__( + dataset, + batch_size, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + drop_last=drop_last) + self.idx_sample_p = idx_sample_p + + def resample(self): + num_samples = len(self.dataset) + indices = np.arange(num_samples).tolist() + + return indices + + def __iter__(self): + num_samples = len(self.dataset) + indices = np.arange(num_samples).tolist() + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + if self.shuffle: + np.random.RandomState(self.epoch).shuffle(indices) + self.epoch += 1 + + # subsample + def _get_indices_by_batch_size(indices): + subsampled_indices = [] + last_batch_size = self.total_size % (self.batch_size * self.nranks) + assert last_batch_size % self.nranks == 0 + last_local_batch_size = last_batch_size // self.nranks + + for i in range(self.local_rank * self.batch_size, + len(indices) - last_batch_size, + self.batch_size * self.nranks): + subsampled_indices.extend(indices[i:i + self.batch_size]) + + indices = indices[len(indices) - last_batch_size:] + subsampled_indices.extend(indices[ + self.local_rank * last_local_batch_size:( + self.local_rank + 1) * last_local_batch_size]) + return np.array(subsampled_indices) + + if self.nranks > 1: + indices = _get_indices_by_batch_size(indices) + + assert len(indices) == self.num_samples + _sample_iter = iter(indices) + if self.idx_sample_p is not None: + assert len(self.idx_sample_p) == len(self.dataset), \ + "length of idx_sample_p must be equal to dataset" + batch_indices = [] + sample_p = [self.idx_sample_p[ind] for ind in indices] + for _ in range(len(indices)): + idx = np.random.choice(indices, replace=True, p=sample_p) + batch_indices.append(idx) + if len(batch_indices) == self.batch_size: + yield batch_indices + batch_indices = [] + if not self.drop_last and len(batch_indices) > 0: + yield batch_indices diff --git a/plsc/engine/__init__.py b/plsc/engine/__init__.py index 97043fd..93b77c3 100644 --- a/plsc/engine/__init__.py +++ b/plsc/engine/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from plsc.engine.engine import Engine +from plsc.engine.multi_task_classfication import MTLEngine diff --git a/plsc/engine/multi_task_classfication/__init__.py b/plsc/engine/multi_task_classfication/__init__.py new file mode 100644 index 0000000..41a2aa4 --- /dev/null +++ b/plsc/engine/multi_task_classfication/__init__.py @@ -0,0 +1,15 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from plsc.engine.multi_task_classfication.trainer import MTLEngine diff --git a/plsc/engine/multi_task_classfication/trainer.py b/plsc/engine/multi_task_classfication/trainer.py new file mode 100644 index 0000000..0dd763d --- /dev/null +++ b/plsc/engine/multi_task_classfication/trainer.py @@ -0,0 +1,316 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Multi-task learning trainer +""" +import copy +import math +import numpy as np +import os +import random +import time + +import paddle +import paddle.distributed as dist + +from plsc.utils import logger, io +from plsc.utils.config import print_config +from plsc.models import build_model +from plsc.loss import build_mtl_loss +from plsc.core import GradScaler, param_sync +from plsc.core import grad_sync, recompute_warp +from plsc.optimizer import build_optimizer +from plsc.metric import build_metrics +from plsc.data import build_dataloader +from plsc.scheduler import build_lr_scheduler + + +class MTLEngine(object): + def __init__(self, config, mode="Train"): + self.mode = mode + self.finetune = False + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.config = config + self.parse_config() + self.build_modules() + + @staticmethod + def params_counts(model): + n_parameters = sum(p.numel() for p in model.parameters() + if not p.stop_gradient).item() + i = int(math.log(n_parameters, 10) // 3) + size_unit = ['', 'K', 'M', 'B', 'T', 'Q'] + param_size = n_parameters / math.pow(1000, i) + return param_size, size_unit[i] + + def _init_worker(self, worker_id): + """ set seed in subproces for dataloader when num_workers > 0""" + if self.seed: + np.random.seed(self.seed + worker_id) + random.seed(self.seed + worker_id) + + def parse_config(self): + + # parse global params + for key in self.config["Global"]: + setattr(self, key, self.config["Global"][key]) + + self.model_name = self.config["Model"].get("name", None) + assert self.model_name, "model must be defined!" + + # init logger + self.output_dir = self.config['Global']['output_dir'] + log_file = os.path.join(self.output_dir, self.model_name, + f"{self.mode}.log") + logger.init_logger(log_file=log_file) + + # record + print_config(self.config) + + # set device + assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu"] + self.device = paddle.set_device(self.config["Global"]["device"]) + logger.info('train with paddle {}, commit id {} and device {}'.format( + paddle.__version__, paddle.__git_commit__[:8], self.device)) + + # set seed + self.seed = self.config["Global"].get("seed", False) + if self.seed: + assert isinstance(self.seed, int), "The 'seed' must be a integer!" 
+ self.seed += self.rank + paddle.seed(self.seed) + np.random.seed(self.seed) + random.seed(self.seed) + self.worker_init_fn = self._init_worker if self.seed else None + + # distributed strategy + cfg_dist = self.config.get("DistributedStrategy", None) + if cfg_dist.get("data_parallel", None): + self.dp = True + dist.init_parallel_env() + + self.recompute = False + self.recompute_params = {} + if cfg_dist.get("recompute", None): + self.recompute = True + self.recompute_params = cfg_dist["recompute"] + + # amp + cfg_fp16 = self.config.get("FP16", False) + self.fp16_params = {"enable": False} + if cfg_fp16: + self.fp16_params["level"] = cfg_fp16.get("level", "O1") + if self.fp16_params["level"] != 'O0': + self.fp16_params["enable"] = True + cfg_scaler = cfg_fp16.get("GradScaler", {}) + self.scaler = GradScaler(self.fp16_params["enable"], **cfg_scaler) + self.fp16_params["custom_white_list"] = cfg_fp16.get( + "fp16_custom_white_list", None) + self.fp16_params["custom_black_list"] = cfg_fp16.get( + "fp16_custom_black_list", None) + + def build_modules(self): + # dataset + if self.mode == "Train": + for mode in ["Train", "Eval"]: + data_loader = build_dataloader( + self.config["DataLoader"], + mode, + self.device, + worker_init_fn=self.worker_init_fn) + setattr(self, f"{mode.lower()}_dataloader", data_loader) + self.eval_metrics = build_metrics(self.config["Metric"]["Eval"]) + else: + data_loader = build_dataloader( + self.config["DataLoader"], + self.mode, + self.device, + worker_init_fn=self.worker_init_fn) + setattr(self, f"{self.mode.lower()}_dataloader", data_loader) + + metrics = build_metrics(self.config["Metric"][self.mode]) + setattr(self, f"{self.mode.lower()}_metrics", metrics) + + # build model + self.model = build_model( + self.config["Model"], task_names=self.task_names) + if self.recompute: + recompute_warp(self.model, **self.recompute_params) + param_size, size_unit = self.params_counts(self.model) + logger.info( + f"The number of parameters is: {param_size:.3f}{size_unit}.") + if self.dp: + param_sync(self.model) + logger.info("DDP model: sync parameters finished.") + + # build lr, opt, loss + if self.mode == 'Train': + # lr scheduler + lr_config = copy.deepcopy(self.config.get("LRScheduler", None)) + self.lr_decay_unit = lr_config.get("decay_unit", "step") + self.lr_scheduler = None + if lr_config is not None: + self.lr_scheduler = build_lr_scheduler( + lr_config, self.epochs, len(self.train_dataloader)) + # optimizer + self.optimizer = build_optimizer(self.config["Optimizer"], + self.lr_scheduler, self.model) + + self.loss_func = build_mtl_loss(self.task_names, + self.config["Loss"][self.mode]) + + def load_model(self): + if self.checkpoint: + io.load_checkpoint(self.checkpoint, self.model, self.optimizer, + self.scaler) + elif self.pretrained_model: + self.model.load_pretrained(self.pretrained_model, self.rank) + + def train(self): + self.load_model() + # train loop + for epoch in range(self.epochs): + self.train_one_epoch(epoch) + # eval + metric_results = {} + if self.eval_during_train and self.eval_unit == "epoch" \ + and (epoch + 1) % self.eval_interval == 0: + metric_results = self.eval() + # save model + if (epoch + 1) % self.save_interval == 0 or (epoch + 1 + ) == self.epochs: + model_prefix = "final" if ( + epoch + 1) == self.epochs else f"model_epoch{epoch}" + self.save_model(model_prefix, metric_results) + # update lr + if self.lr_decay_unit == "epoch": + self.optimizer.lr_step() + + def train_one_epoch(self, epoch): + step = 0 + avg_loss = 0 # average loss in the 
lasted `self.print_batch_step` steps + for images, targets in self.train_dataloader: + start = time.time() + step += 1 + # compute loss + with paddle.amp.auto_cast(self.fp16_params): + logits = self.model(images) + _, total_loss = self.loss_func(logits, targets) + scaled = self.scaler.scale(total_loss) + scaled.backward() + grad_sync(self.optimizer.param_groups) + # update params + if (step + 1) % self.accum_steps == 0: + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.clear_grad() + if self.lr_decay_unit == "step": + self.optimizer.lr_step() + # show loss + avg_loss += total_loss.cpu().numpy()[0] + if (step + 1) % self.print_batch_step == 0: + logger.info(f"epoch: {epoch}, step: {step}, " + f"total loss: {avg_loss / self.print_batch_step}") + avg_loss = 0 + end = time.time() + logger.debug(f"one step time = {(end - start): .3f}s") + + def save_model(self, model_prefix, metric_results=None): + io.save_checkpoint( + self.model, + self.optimizer, + self.scaler, + metric_results, + self.output_dir, + model_name=self.model_name, + prefix=model_prefix, + max_num_checkpoint=self.max_num_latest_checkpoint) + + @paddle.no_grad() + def eval(self): + step = 0 + results = {} + bs = {} + self.model.eval() + for images, targets in self.eval_dataloader: + step += 1 + labels = targets["label"] + tasks = targets["task"] + logits = self.model(images) + for idx, task_name in enumerate(self.task_names): + cond = tasks == idx + if not paddle.any(cond): + continue + preds = logits[task_name][cond] + labels = labels[cond] + results[idx] = results.get(idx, {}) + task_metric = self.eval_metrics(preds, labels) + for metric_name in task_metric: + results[idx][metric_name] = results[idx].get(metric_name, + {}) + for key in task_metric[metric_name]: + results[idx][metric_name][key] = \ + results[idx][metric_name].get(key, 0) + task_metric[metric_name][key] + bs[idx] = bs.get(idx, 0) + 1 + self.model.train() + for idx in results: + for metric in results[idx]: + for key in results[idx][metric]: + results[idx][metric][key] /= bs[idx] + for task_id in results: + logger.info(f"metrics - task{task_id}: {results[task_id]}") + return results + + @paddle.no_grad() + def test(self): + step = 0 + results = {} + bs = {} + self.model.eval() + for images, targets in self.test_dataloader: + step += 1 + labels = targets["label"] + tasks = targets["task"] + logits = self.model(images) + for idx in range(len(tasks)): + task_id = tasks[idx][0] + preds_i = logits[self.task_names[task_id]] + labels_i = labels[:, idx] + results[idx] = results.get(idx, {}) + task_metric = self.test_metrics(preds_i, labels_i) + for metric_name in task_metric: + results[idx][metric_name] = results[idx].get(metric_name, + {}) + for key in task_metric[metric_name]: + results[idx][metric_name][key] = \ + results[idx][metric_name].get(key, 0) + task_metric[metric_name][key] + bs[idx] = bs.get(idx, 0) + 1 + self.model.train() + for idx in results: + for metric in results[idx]: + for key in results[idx][metric]: + results[idx][metric][key] /= bs[idx] + for task_id in results: + logger.info(f"metrics - task{task_id}: {results[task_id]}") + return results + + @paddle.no_grad() + def export(self): + assert self.mode in ["Export", "export"] + assert self.config.get("Export", None) is not None + assert self.pretrained_model is not None + self.model.eval() + path = os.path.join(self.output_dir, self.model_name) + io.export(self.config["Export"], self.model, path) diff --git a/plsc/loss/MTLoss.py b/plsc/loss/MTLoss.py new file mode 100644 index 
0000000..41657fd --- /dev/null +++ b/plsc/loss/MTLoss.py @@ -0,0 +1,88 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from collections import Iterable +import paddle +import paddle.nn as nn + +from plsc.loss.celoss import ViTCELoss, CELoss +from plsc.loss.distill_loss import MSELoss + + +class MTLoss(nn.Layer): + """ + multi-task loss framework + """ + + def __init__(self, task_names, losses, weights=1.0): + super().__init__() + self.loss_func = {} + self.task_names = task_names + self.instance_losses(losses) + self.loss_weight = {} + if isinstance(weights, float): + weights = len(self.loss_func) * [weights] + assert len(self.loss_func) == len( + weights), "Length of loss_func should be equal to weights" + weight_sum = sum(weights) + for task_id in self.loss_func: + self.loss_weight[task_id] = weights[task_id] / weight_sum + + def instance_losses(self, losses): + assert isinstance(losses, Iterable) and len(losses) > 0, \ + "losses should be iterable and length greater than 0" + self.loss_func = {} + self.loss_names = {} + + for loss_item in losses: + assert isinstance(loss_item, dict) and len(loss_item.keys()) == 1, \ + "item in losses should be config dict whose length is one(loss class config)" + name = list(loss_item.keys())[0] + params = loss_item[name] + task_ids = params.pop("task_ids", [0]) + if not isinstance(task_ids, list): + task_ids = [task_ids] + for task_id in task_ids: + self.loss_func[task_id] = eval(name)(**params) + self.loss_names[task_id] = name + + @staticmethod + def cast_fp32(input): + if input.dtype != paddle.float32: + input = paddle.cast(input, 'float32') + return input + + def __call__(self, input, target): + # target: [label, task] + loss_dict = {} + total_loss = 0.0 + assert isinstance( + target, dict), "target shold be a dict including keys(label, task)" + label = target["label"] + task = target["task"] + for idx in self.loss_func: + cond = task == idx + logits = input[self.task_names[idx]][cond] + if isinstance(label, dict): + if self.task_names[idx] in label: + labels = label[self.task_names[idx]][cond] + else: + print("label should be a tensor, not dict") + else: + labels = label[cond] + loss = self.loss_func[idx](logits, labels) + loss_dict[idx] = loss[self.loss_names[idx]] + total_loss += loss_dict[idx] * self.loss_weight[idx] + return loss_dict, total_loss diff --git a/plsc/loss/__init__.py b/plsc/loss/__init__.py index 9b50253..2844e35 100644 --- a/plsc/loss/__init__.py +++ b/plsc/loss/__init__.py @@ -19,6 +19,7 @@ from plsc.utils import logger from .celoss import CELoss, ViTCELoss +from .MTLoss import MTLoss from .marginloss import MarginLoss @@ -59,3 +60,10 @@ def build_loss(config): module_class = CombinedLoss(copy.deepcopy(config)) logger.debug("build loss {} success.".format(module_class)) return module_class + + +def build_mtl_loss(task_names, cfg_loss): + loss_name = cfg_loss.pop("name") + module_class = eval(loss_name)(task_names, **cfg_loss) + logger.debug("build 
loss {} success.".format(module_class)) + return module_class diff --git a/plsc/loss/distill_loss.py b/plsc/loss/distill_loss.py new file mode 100644 index 0000000..18874ef --- /dev/null +++ b/plsc/loss/distill_loss.py @@ -0,0 +1,34 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# multiple distillation loss functions +# + +import paddle.nn as nn +import paddle.nn.functional as F + + +class MSELoss(nn.Layer): + def __init__(self, temperature): + super().__init__() + self.temperature = temperature + + def forward(self, student_logits, teacher_logits): + + student_sfm = F.log_softmax(student_logits / self.temperature) + teacher_sfm = F.log_softmax(teacher_logits / self.temperature) + loss = nn.functional.mse_loss( + student_sfm, teacher_sfm, reduction="mean") + return {"MSELoss": loss} diff --git a/plsc/metric/__init__.py b/plsc/metric/__init__.py index a7d8e77..f739c22 100644 --- a/plsc/metric/__init__.py +++ b/plsc/metric/__init__.py @@ -39,7 +39,9 @@ def __init__(self, config_list): def __call__(self, *args, **kwargs): metric_dict = OrderedDict() for idx, metric_func in enumerate(self.metric_func_list): - metric_dict.update(metric_func(*args, **kwargs)) + metric_dict[str(metric_func).replace("()", "")] = metric_func( + *args, **kwargs) + # metric_dict.update(metric_func(*args, **kwargs)) return metric_dict diff --git a/plsc/models/__init__.py b/plsc/models/__init__.py index fc7ba50..27a5a2a 100644 --- a/plsc/models/__init__.py +++ b/plsc/models/__init__.py @@ -22,14 +22,16 @@ from .face_vit import * from .mobilefacenet import * from .cait import * +from .multi_task.MTLModel import * __all__ = ["build_model"] -def build_model(config): +def build_model(config, **kwargs): config = copy.deepcopy(config) model_type = config.pop("name") mod = importlib.import_module(__name__) + config.update(kwargs) model = getattr(mod, model_type)(**config) assert isinstance( model, Model), 'model must inherit from plsc.models.layers.Model' diff --git a/plsc/models/multi_task/MTLModel.py b/plsc/models/multi_task/MTLModel.py new file mode 100644 index 0000000..22d2321 --- /dev/null +++ b/plsc/models/multi_task/MTLModel.py @@ -0,0 +1,126 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Multi-task Model Framework is realized. +Combine backbone and multiple encoder layers in this framework. 
+""" + +import os +from typing import List, Dict +from collections import OrderedDict + +import paddle +from paddle.nn import Layer, LayerDict, LayerList +from plsc.models.base_model import Model +from plsc.models.multi_task.ResNet_backbone import * +from plsc.models.multi_task.head import * + + +class MTLModel(Model): + """ + Multi-task Model Framework. + Recomputing can be turned on. + """ + + def __init__(self, task_names, backbone, heads): + """ + + Args: + task_names: task name list + backbone: backbone for feature extraction (or config dict) + encoder_heads: Dict (or config list) + recompute_on: if recompute is used + recompute_params: recompute layers + """ + super(MTLModel, self).__init__() + self.task_names = task_names + + if isinstance(backbone, Model): + self.backbone = backbone + else: + self.backbone = self.instances_from_cfg(backbone) + # {task_names: Layer} + heads = self.instances_from_cfg(heads) + self.heads = LayerDict(sublayers=heads) + + def instances_from_cfg(self, cfg): + # instantiate layer from config dict + if isinstance(cfg, dict): + name = cfg.pop("name", None) + if name is not None: + try: + module = eval(name)(**cfg) + except Exception as e: + print("instance cfg error: ", e) + else: + return module + if isinstance(cfg, list): + module_dic = {} + for item in cfg: + if isinstance(item, dict) and len(item) == 1: + name = list(item.keys())[0] + params = item[name] + task_ids = params.pop("task_ids", None) + class_nums = params.pop("class_nums", None) + if task_ids and class_nums: + for task_id, class_num in zip(task_ids, class_nums): + module = eval(name)(class_num=class_num, **params) + module_dic[self.task_names[task_id]] = module + if len(module_dic) > 0: + return module_dic + return None + + def forward(self, inputs, output_task_names=None): + output = {} + features = self.backbone(inputs) + if output_task_names is not None: + for task_name in output_task_names: + output[task_name] = self.heads[task_name](features) + else: + for task_name in self.heads: + output[task_name] = self.heads[task_name](features) + return output + + def save(self, path, local_rank=0, rank=0): + # save model + dist_state_dict = OrderedDict() + state_dict = self.state_dict() + for name, param in list(state_dict.items()): + if param.is_distributed: + dist_state_dict[name] = state_dict.pop(name) + + if local_rank == 0: + paddle.save(state_dict, path + ".pdparams") + + if len(dist_state_dict) > 0: + paddle.save(dist_state_dict, + path + "_rank{}.pdparams".format(rank)) + + def load_pretrained(self, path, rank=0, finetune=False): + # load pretrained model + if not os.path.exists(path + '.pdparams'): + raise ValueError("Model pretrain path {} does not " + "exists.".format(path)) + + state_dict = paddle.load(path + ".pdparams") + + dist_param_path = path + "_rank{}.pdparams".format(rank) + if os.path.exists(dist_param_path): + dist_state_dict = paddle.load(dist_param_path) + state_dict.update(dist_state_dict) + # clear + dist_state_dict.clear() + + if not finetune: + self.set_dict(state_dict) diff --git a/plsc/models/multi_task/ResNet_backbone.py b/plsc/models/multi_task/ResNet_backbone.py new file mode 100644 index 0000000..183f043 --- /dev/null +++ b/plsc/models/multi_task/ResNet_backbone.py @@ -0,0 +1,339 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import collections +import numpy as np +import paddle +import paddle.nn as nn + +from plsc.nn import init +from plsc.models.base_model import Model + +import math + +__all__ = ["IResNet18", "IResNet34", "IResNet50", "IResNet100", "IResNet200"] + + +def conv3x3(in_planes, + out_planes, + stride=1, + groups=1, + dilation=1, + data_format="NCHW"): + """3x3 convolution with padding""" + return nn.Conv2D( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + dilation=dilation, + bias_attr=False, + data_format=data_format) + + +def conv1x1(in_planes, out_planes, stride=1, data_format="NCHW"): + """1x1 convolution""" + return nn.Conv2D( + in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias_attr=False, + data_format=data_format) + + +class IBasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=\ + 1, base_width=64, dilation=1, data_format="NCHW"): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError( + 'BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError( + 'Dilation > 1 not supported in BasicBlock') + self.bn1 = nn.BatchNorm2D( + inplanes, epsilon=1e-05, data_format=data_format) + self.conv1 = conv3x3(inplanes, planes, data_format=data_format) + self.bn2 = nn.BatchNorm2D( + planes, epsilon=1e-05, data_format=data_format) + self.prelu = nn.PReLU(planes, data_format=data_format) + self.conv2 = conv3x3(planes, planes, stride, data_format=data_format) + self.bn3 = nn.BatchNorm2D( + planes, epsilon=1e-05, data_format=data_format) + self.downsample = downsample + self.stride = stride + + def forward_impl(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + def forward(self, x): + return self.forward_impl(x) + + +class IResNet(Model): + def __init__(self, + block, + layers, + dropout=0, + num_features=512, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + class_num=93431, + pfc_config={"model_parallel": True, + "sample_ratio": 1.0}, + input_image_channel=3, + input_image_width=112, + input_image_height=112, + data_format="NCHW"): + super(IResNet, self).__init__() + + self.layers = layers + self.data_format = data_format + self.input_image_channel = input_image_channel + + assert input_image_width % 16 == 0 + assert input_image_height % 16 == 0 + feat_w = input_image_width // 16 + feat_h = input_image_height // 16 + self.fc_scale = feat_w * feat_h + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + 'replace_stride_with_dilation should be None or a 3-element 
tuple, got {}' + .format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2D( + self.input_image_channel, + self.inplanes, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False, + data_format=data_format) + self.bn1 = nn.BatchNorm2D( + self.inplanes, epsilon=1e-05, data_format=data_format) + self.prelu = nn.PReLU(self.inplanes, data_format=data_format) + self.layer1 = self._make_layer( + block, 64, layers[0], stride=2, data_format=data_format) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0], + data_format=data_format) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1], + data_format=data_format) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2], + data_format=data_format) + self.bn2 = nn.BatchNorm2D( + 512 * block.expansion, epsilon=1e-05, data_format=data_format) + self.dropout = nn.Dropout(p=dropout) + + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, + num_features) + self.features = nn.BatchNorm1D(num_features, epsilon=1e-05) + # self.features = nn.BatchNorm1D(num_features, epsilon=1e-05, weight_attr=False) + + for m in self.sublayers(): + if isinstance(m, paddle.nn.Conv2D): + init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)): + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + if zero_init_residual: + for m in self.sublayers(): + if isinstance(m, IBasicBlock): + init.constant_(m.bn2.weight, 0) + + def _make_layer(self, + block, + planes, + blocks, + stride=1, + dilate=False, + data_format="NCHW"): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1( + self.inplanes, + planes * block.expansion, + stride, + data_format=data_format), + nn.BatchNorm2D( + planes * block.expansion, + epsilon=1e-05, + data_format=data_format)) + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + data_format=data_format)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + data_format=data_format)) + return nn.Sequential(*layers) + + def forward(self, inputs): + + if self.training: + with paddle.no_grad(): + # Note(GuoxiaWang) + # self.features = nn.BatchNorm1D(num_features, epsilon=1e-05, weight_attr=False) + self.features.weight.fill_(1.0) + + if isinstance(inputs, dict): + x = inputs['data'] + else: + x = inputs + + x.stop_gradient = True + if self.data_format == "NHWC": + x = paddle.tensor.transpose(x, [0, 2, 3, 1]) + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + + x = self.layer3(x) + x = self.layer4(x) + x = self.bn2(x) + + if self.data_format == "NHWC": + x = paddle.tensor.transpose(x, [0, 3, 1, 2]) + + # return embedding feature + if isinstance(inputs, dict): + res = {'logits': x} + if 'targets' in inputs: + res['targets'] = inputs['targets'] + else: + res = x + return res + + def load_pretrained(self, path, rank=0, finetune=False): + if not os.path.exists(path + '.pdparams'): + raise ValueError("Model pretrain path {} 
does not " + "exists.".format(path)) + + state_dict = paddle.load(path + ".pdparams") + + dist_param_path = path + "_rank{}.pdparams".format(rank) + if os.path.exists(dist_param_path): + dist_state_dict = paddle.load(dist_param_path) + state_dict.update(dist_state_dict) + + # clear + dist_state_dict.clear() + + if not finetune: + self.set_dict(state_dict) + return + + return + + def save(self, path, local_rank=0, rank=0): + dist_state_dict = collections.OrderedDict() + state_dict = self.state_dict() + for name, param in list(state_dict.items()): + if param.is_distributed: + dist_state_dict[name] = state_dict.pop(name) + + if local_rank == 0: + paddle.save(state_dict, path + ".pdparams") + + if len(dist_state_dict) > 0: + paddle.save(dist_state_dict, + path + "_rank{}.pdparams".format(rank)) + + +def IResNet18(**kwargs): + model = IResNet(IBasicBlock, [2, 2, 2, 2], **kwargs) + return model + + +def IResNet34(**kwargs): + model = IResNet(IBasicBlock, [3, 4, 6, 3], **kwargs) + return model + + +def IResNet50(**kwargs): + model = IResNet(IBasicBlock, [3, 4, 14, 3], **kwargs) + return model + + +def IResNet100(**kwargs): + model = IResNet(IBasicBlock, [3, 13, 30, 3], **kwargs) + return model + + +def IResNet200(**kwargs): + model = IResNet(IBasicBlock, [6, 26, 60, 6], **kwargs) + return model diff --git a/plsc/models/multi_task/__init__.py b/plsc/models/multi_task/__init__.py new file mode 100644 index 0000000..4d4247d --- /dev/null +++ b/plsc/models/multi_task/__init__.py @@ -0,0 +1,13 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/plsc/models/multi_task/head.py b/plsc/models/multi_task/head.py new file mode 100644 index 0000000..dc7b436 --- /dev/null +++ b/plsc/models/multi_task/head.py @@ -0,0 +1,213 @@ +# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Multi-task heads. +Only defined convolution blocks. 
+""" +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn import Conv2D, BatchNorm, Linear, \ + MaxPool2D, Dropout, PReLU + + +class ConvBNLayer(nn.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None, + data_format="NCHW"): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False, + data_format=data_format) + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self._batch_norm = BatchNorm( + num_filters, + act=act, + epsilon=1e-05, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(bn_name + "_offset"), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance", + data_layout=data_format) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class ConvBNLayerAttr(nn.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + padding=0, + act=None, + name=None): + super(ConvBNLayerAttr, self).__init__() + + self._conv = Conv2D( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + weight_attr=ParamAttr(), + bias_attr=False) + self._batch_norm = BatchNorm(num_filters, act=act) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, + num_channels, + num_filters, + stride, + shortcut=True, + name=None, + data_format="NCHW"): + super(BasicBlock, self).__init__() + self.stride = stride + bn_name = "bn_" + name[3:] + "_before" + self._batch_norm = BatchNorm( + num_channels, + act=None, + epsilon=1e-05, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(bn_name + "_offset"), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance", + data_layout=data_format) + + self.conv0 = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=3, + stride=1, + act=None, + name=name + "_branch2a", + data_format=data_format) + self.prelu = PReLU( + num_parameters=num_filters, + data_format=data_format, + name=name + "_branch2a_prelu") + self.conv1 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + stride=stride, + act=None, + name=name + "_branch2b", + data_format=data_format) + + if shortcut: + self.short = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=1, + stride=stride, + act=None, + name=name + "_branch1", + data_format=data_format) + + self.shortcut = shortcut + + def forward(self, inputs): + y = self._batch_norm(inputs) + y = self.conv0(y) + y = self.prelu(y) + conv1 = self.conv1(y) + + if self.shortcut: + short = self.short(inputs) + else: + short = inputs + y = paddle.add(x=short, y=conv1) + return y + + +class TaskBlock(nn.Layer): + def __init__(self, + num_channels, + num_filters, + stride=1, + shortcut=True, + padding=0, + name=None, + class_num=10, + task_occlu=False, + data_format="NCHW"): + super(TaskBlock, self).__init__() + self.conv_for_fc = ConvBNLayerAttr( + num_channels=num_channels, + num_filters=64, + filter_size=3, + stride=1, + padding=1, + act=None, + 
name="conv_for_fc") + self.prelu_bottom = PReLU( + num_parameters=64, data_format=data_format, name="prelu_bottom") + + self.conv0 = ConvBNLayerAttr( + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + stride=1, + padding=padding, + act=None) + self.pool = MaxPool2D(kernel_size=2, stride=2, padding=0) + self.fc0 = Linear(256, 128) + self.prelu0 = PReLU(num_parameters=64) + self.prelu1 = PReLU(num_parameters=128) + self.task_occlu = task_occlu + self.fc1 = Linear(128, class_num) + + def forward(self, inputs): + y = self.conv_for_fc(inputs) + # y = self.prelu_bottom(y) + y = self.conv0(y) + y = self.prelu0(y) + N = y.shape[0] + y = self.pool(y) + y = paddle.reshape(y, [N, -1]) # 128, 64 * 3 * 3 + y = self.fc0(y) + y = self.prelu1(y) + out = self.fc1(y) + return out diff --git a/plsc/utils/config.py b/plsc/utils/config.py index 09d8529..397ee56 100644 --- a/plsc/utils/config.py +++ b/plsc/utils/config.py @@ -188,7 +188,12 @@ def parse_args(): '--config', type=str, default='configs/config.yaml', - help='config file path') + help='config file path.') + parser.add_argument( + '-t', + '--mtl', + action='store_true', + help='The option of multi-task learning.') parser.add_argument( '-o', '--override', diff --git a/task/multi_task_classification/configs/multi_task_resnet18_dp_fp16o2_demo.yaml b/task/multi_task_classification/configs/multi_task_resnet18_dp_fp16o2_demo.yaml new file mode 100644 index 0000000..4f0f8f3 --- /dev/null +++ b/task/multi_task_classification/configs/multi_task_resnet18_dp_fp16o2_demo.yaml @@ -0,0 +1,199 @@ +# global configs +Global: + checkpoint: null # current main model + teacher_checkpoint: output/teacher/MTLModel/model_epoch0 + pretrained_model: null + output_dir: ./output/student + device: gpu + save_interval: 1 + max_num_latest_checkpoint: 3 + eval_during_train: True + eval_interval: 1 + eval_unit: "epoch" + accum_steps: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + seed: 2023 + task_names: [Arched_Eyebrows, Attractive, Bags_Under_Eyes, Bald, Clock_Shadow] + +# FP16 setting +FP16: + level: O0 + GradScaler: + init_loss_scaling: 65536.0 + + +DistributedStrategy: + data_parallel: True + recompute: + layerlist_interval: 4 + names: [] + + +# model architecture +Model: + name: MTLModel + backbone: + name: IResNet50 + num_features: 512 + data_format: "NHWC" + heads: + - TaskBlock: # head 类型,一个类型的head可以支持多个任务,但是每个任务有一个head实例 + task_ids: [0, 1, 2, 3, 4] + class_nums: [2, 10, 6, 5, 3] + num_channels: 512 + num_filters: 64 + data_format : "NHWC" + + +# loss function config for traing/eval process +Loss: + Train: + name: MTLoss + weights: [2, 2, 2, 2, 1] + losses: + - ViTCELoss: + task_ids: [0, 1, 2, 3, 4] + epsilon: 0.0001 + Eval: + name: MTLoss + weights: [2, 2, 2, 2, 1] + losses: + - CELoss: + task_ids: [0, 1, 2, 3, 4] + epsilon: 0.0001 + +LRScheduler: + name: ViTLRScheduler + learning_rate: 3e-3 + decay_type: cosine + warmup_steps: 10000 + +Optimizer: + name: AdamW + betas: (0.9, 0.999) + epsilon: 1e-8 + weight_decay: 0.3 + grad_clip: + name: ClipGradByGlobalNorm + clip_norm: 1.0 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ConcatDataset + dataset_ratio: [2, 1, 2, 20, 4] + datasets: + - SingleTaskDataset: # dataset类型,多个任务可以共用一种类型的dataset,但每个任务有自己的dataset实例,最终会concat成为一个整体的dataset + data_root: ./datasets/ + task_ids: [0, 1, 2, 3, 4] + label_path: [Arched_Eyebrows_label.txt, Attractive_label.txt, Bags_Under_Eyes_label.txt, Bald_label.txt, Clock_Shadow_label.txt] + transform_ops: + - DecodeImage: + to_rgb: 
True + channel_first: False + - RandCropImage: + size: 112 + scale: [0.05, 1.0] + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + Eval: + dataset: + name: ConcatDataset + datasets: + - SingleTaskDataset: + data_root: ./datasets/ + task_ids: [0] + label_path: [Arched_Eyebrows_label.txt] + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 112 + scale: [0.05, 1.0] + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + order: '' + - ToCHWImage: + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + Test: + dataset: + name: ConcatDataset + datasets: + - MultiTaskDataset: + data_root: ./datasets/ + label_path: [test.txt] + task_ids: [[0, 1, 2, 3, 4]] + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 112 + scale: [ 0.05, 1.0 ] + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] + order: '' + - ToCHWImage: + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] + Test: + - TopkAcc: + topk: [1, 5] + +Export: + export_type: paddle + input_shape: [None, 3, 112, 112] diff --git a/tools/train.py b/tools/train.py index c7fcb43..486084c 100644 --- a/tools/train.py +++ b/tools/train.py @@ -20,7 +20,7 @@ paddle.disable_static() from plsc.utils import config as cfg_util -from plsc.engine.engine import Engine +from plsc.engine import Engine, MTLEngine def main(): @@ -28,7 +28,10 @@ def main(): config = cfg_util.get_config( args.config, overrides=args.override, show=False) config.profiler_options = args.profiler_options - engine = Engine(config, mode="train") + if args.mtl: + engine = MTLEngine(config, mode="Train") + else: + engine = Engine(config, mode="train") engine.train()
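
Note on usage (illustrative, not part of the diff): with these changes, multi-task training is selected through the new -t/--mtl flag parsed in plsc/utils/config.py, which makes tools/train.py build an MTLEngine instead of the default Engine. Assuming the demo dataset layout referenced by the YAML above (./datasets/images/ plus the *_label.txt files) has been prepared, a single-node launch might look like:

    # illustrative launch command; GPU ids and paths are assumptions
    python -m paddle.distributed.launch --gpus=0,1,2,3 tools/train.py -t \
        -c task/multi_task_classification/configs/multi_task_resnet18_dp_fp16o2_demo.yaml

Each sample yielded by the new datasets is an (image, {"label": ..., "task": task_id}) pair; MTLoss uses the "task" field to route the logits produced by the per-task heads of MTLModel to the matching per-task loss.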