Skip to content

Commit 23b4f69

Browse files
authored
Add new get_save_dir() function (ultralytics#4602)
1 parent 1121ef2 commit 23b4f69

File tree

6 files changed

+35
-39
lines changed

6 files changed

+35
-39
lines changed

ultralytics/cfg/__init__.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from types import SimpleNamespace
99
from typing import Dict, List, Union
1010

11-
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, SETTINGS, SETTINGS_YAML,
12-
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
13-
yaml_print)
11+
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS,
12+
SETTINGS_YAML, IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn,
13+
yaml_load, yaml_print)
1414

1515
# Define valid tasks and modes
1616
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
@@ -146,8 +146,23 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
146146
return IterableSimpleNamespace(**cfg)
147147

148148

149+
def get_save_dir(args):
150+
"""Return save_dir as created from train/val/predict arguments."""
151+
152+
if getattr(args, 'save_dir', None):
153+
save_dir = args.save_dir
154+
else:
155+
from ultralytics.utils.files import increment_path
156+
157+
project = args.project or Path(SETTINGS['runs_dir']) / args.task
158+
name = args.name or f'{args.mode}'
159+
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
160+
161+
return Path(save_dir)
162+
163+
149164
def _handle_deprecation(custom):
150-
"""Hardcoded function to handle deprecated config keys"""
165+
"""Hardcoded function to handle deprecated config keys."""
151166

152167
for key in custom.copy().keys():
153168
if key == 'hide_labels':
@@ -171,6 +186,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
171186
Args:
172187
custom (dict): a dictionary of custom configuration options
173188
base (dict): a dictionary of base configuration options
189+
e (Error, optional): An optional error that is passed by the calling function.
174190
"""
175191
custom = _handle_deprecation(custom)
176192
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))

ultralytics/engine/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66
from typing import Union
77

8-
from ultralytics.cfg import get_cfg
8+
from ultralytics.cfg import get_cfg, get_save_dir
99
from ultralytics.engine.exporter import Exporter
1010
from ultralytics.hub.utils import HUB_WEB_ROOT
1111
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
@@ -239,7 +239,7 @@ def predict(self, source=None, stream=False, predictor=None, **kwargs):
239239
else: # only update args if predictor is already setup
240240
self.predictor.args = get_cfg(self.predictor.args, overrides)
241241
if 'project' in overrides or 'name' in overrides:
242-
self.predictor.save_dir = self.predictor.get_save_dir()
242+
self.predictor.save_dir = get_save_dir(self.predictor.args)
243243
# Set prompts for SAM/FastSAM
244244
if len and hasattr(self.predictor, 'set_prompts'):
245245
self.predictor.set_prompts(prompts)

ultralytics/engine/predictor.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@
3434
import numpy as np
3535
import torch
3636

37-
from ultralytics.cfg import get_cfg
37+
from ultralytics.cfg import get_cfg, get_save_dir
3838
from ultralytics.data import load_inference_source
3939
from ultralytics.data.augment import LetterBox, classify_transforms
4040
from ultralytics.nn.autobackend import AutoBackend
41-
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
41+
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
4242
from ultralytics.utils.checks import check_imgsz, check_imshow
4343
from ultralytics.utils.files import increment_path
4444
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
@@ -84,7 +84,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
8484
overrides (dict, optional): Configuration overrides. Defaults to None.
8585
"""
8686
self.args = get_cfg(cfg, overrides)
87-
self.save_dir = self.get_save_dir()
87+
self.save_dir = get_save_dir(self.args)
8888
if self.args.conf is None:
8989
self.args.conf = 0.25 # default conf=0.25
9090
self.done_warmup = False
@@ -108,11 +108,6 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
108108
self.txt_path = None
109109
callbacks.add_integration_callbacks(self)
110110

111-
def get_save_dir(self):
112-
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
113-
name = self.args.name or f'{self.args.mode}'
114-
return increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
115-
116111
def preprocess(self, im):
117112
"""Prepares input image before inference.
118113

ultralytics/engine/results.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,14 +323,10 @@ def save_crop(self, save_dir, file_name=Path('im.jpg')):
323323
if self.probs is not None:
324324
LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
325325
return
326-
if isinstance(save_dir, str):
327-
save_dir = Path(save_dir)
328-
if isinstance(file_name, str):
329-
file_name = Path(file_name)
330326
for d in self.boxes:
331327
save_one_box(d.xyxy,
332328
self.orig_img.copy(),
333-
file=save_dir / self.names[int(d.cls)] / f'{file_name.stem}.jpg',
329+
file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name).stem}.jpg',
334330
BGR=True)
335331

336332
def tojson(self, normalize=False):

ultralytics/engine/trainer.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@
2323
from torch.nn.parallel import DistributedDataParallel as DDP
2424
from tqdm import tqdm
2525

26-
from ultralytics.cfg import get_cfg
26+
from ultralytics.cfg import get_cfg, get_save_dir
2727
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
2828
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
29-
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url,
30-
colorstr, emojis, yaml_save)
29+
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM_BAR_FORMAT, __version__, callbacks, clean_url, colorstr,
30+
emojis, yaml_save)
3131
from ultralytics.utils.autobatch import check_train_batch_size
3232
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
3333
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
34-
from ultralytics.utils.files import get_latest_run, increment_path
34+
from ultralytics.utils.files import get_latest_run
3535
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
3636
strip_optimizer)
3737

@@ -91,13 +91,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
9191
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
9292

9393
# Dirs
94-
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
95-
name = self.args.name or f'{self.args.mode}'
96-
if hasattr(self.args, 'save_dir'):
97-
self.save_dir = Path(self.args.save_dir)
98-
else:
99-
self.save_dir = Path(
100-
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
94+
self.save_dir = get_save_dir(self.args)
10195
self.wdir = self.save_dir / 'weights' # weights dir
10296
if RANK in (-1, 0):
10397
self.wdir.mkdir(parents=True, exist_ok=True) # make dir

ultralytics/engine/validator.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,11 @@
2626
import torch
2727
from tqdm import tqdm
2828

29-
from ultralytics.cfg import get_cfg
29+
from ultralytics.cfg import get_cfg, get_save_dir
3030
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
3131
from ultralytics.nn.autobackend import AutoBackend
32-
from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
32+
from ultralytics.utils import LOGGER, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
3333
from ultralytics.utils.checks import check_imgsz
34-
from ultralytics.utils.files import increment_path
3534
from ultralytics.utils.ops import Profile
3635
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
3736

@@ -71,7 +70,7 @@ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callba
7170
7271
Args:
7372
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
74-
save_dir (Path): Directory to save results.
73+
save_dir (Path, optional): Directory to save results.
7574
pbar (tqdm.tqdm): Progress bar for displaying progress.
7675
args (SimpleNamespace): Configuration for the validator.
7776
_callbacks (dict): Dictionary to store various callback functions.
@@ -93,12 +92,8 @@ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callba
9392
self.jdict = None
9493
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
9594

96-
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
97-
name = self.args.name or f'{self.args.mode}'
98-
self.save_dir = save_dir or increment_path(Path(project) / name,
99-
exist_ok=self.args.exist_ok if RANK in (-1, 0) else True)
95+
self.save_dir = save_dir or get_save_dir(self.args)
10096
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
101-
10297
if self.args.conf is None:
10398
self.args.conf = 0.001 # default conf=0.001
10499

0 commit comments

Comments
 (0)