Skip to content

Commit

Permalink
Fix destination feature for download models/tests #885 (#889)
Browse files Browse the repository at this point in the history
* Fix destination feature for download models/tests

* Force actions

* Switch cli calls to functions

* Add overrite flag

* Add overrite flag

* Update .github/workflows/run_tests.yaml

Co-authored-by: Armand Collin <[email protected]>

* Update pyproject.toml

---------

Co-authored-by: Armand Collin <[email protected]>
  • Loading branch information
mathieuboudreau and hermancollin authored Feb 19, 2025
1 parent 59ec3d2 commit 75ab39a
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 20 deletions.
33 changes: 25 additions & 8 deletions AxonDeepSeg/download_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import AxonDeepSeg
from AxonDeepSeg.ads_utils import convert_path, download_data
from AxonDeepSeg.model_cards import MODELS
from pathlib import Path
Expand All @@ -10,17 +11,31 @@
# exit codes
SUCCESS, MODEL_NOT_FOUND, DOWNLOAD_ERROR = 0, 1, 2

def download_model(model='generalist', model_type='light', destination=None):

def download_model(model='generalist', model_type='light', destination=None, overwrite=True):
'''
Download a model for AxonDeepSeg.
Parameters
----------
model : str, optional
Name of the model, by default 'generalist'.
model_type : Literal['light', 'ensemble'], optional
Type of model, by default 'light'.
destination : str, optional
Directory to download the model to. Default: None.
'''
model_suffix = 'light' if model_type == 'light' else 'ensemble'
full_model_name = f'{MODELS[model]["name"]}_{model_suffix}'

if destination is None:
model_destination = Path(".") / "models" / full_model_name
package_dir = Path(AxonDeepSeg.__file__).parent # Get AxonDeepSeg installation path
model_destination = package_dir / "models" / full_model_name
else:
destination = Path(destination)
model_destination = destination / full_model_name

if model_destination.exists() and overwrite == False:
logger.info("Overwrite set to False - not deleting old model.")
return model_destination

url_model_destination = MODELS[model]['weights'][model_type]
if url_model_destination is None:
logger.error('Model not found.')
Expand All @@ -36,13 +51,16 @@ def download_model(model='generalist', model_type='light', destination=None):

# retrieving unknown model folder name
folder_name = list(set(files_after) - set(files_before))[0]
output_dir = model_destination.resolve()

if model_destination.exists():
logger.info("Model folder already existed - deleting old one")
shutil.rmtree(str(model_destination))

shutil.move(folder_name, str(model_destination))

return output_dir

def main(argv=None):
ap = argparse.ArgumentParser()
ap.add_argument(
Expand All @@ -69,9 +87,8 @@ def main(argv=None):
ap.add_argument(
"-d", "--dir",
required=False,
help="Directory to download the model to. Default: current directory",
default=str(Path('.') / 'models'),
type=str,
help="Directory to download the model to. Default: AxonDeepSeg/models",
default = None,
)
args = vars(ap.parse_args(argv))

Expand All @@ -88,7 +105,7 @@ def main(argv=None):
pprint.pprint(model_details)
sys.exit(SUCCESS)
else:
download_model(args["model_name"], args["model_type"], args["dir"])
download_model(args["model_name"], args["model_type"], args["dir"], overwrite=True)

if __name__ == "__main__":
with logger.catch():
Expand Down
38 changes: 36 additions & 2 deletions AxonDeepSeg/download_tests.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,32 @@
import AxonDeepSeg
from AxonDeepSeg.ads_utils import download_data, convert_path
from pathlib import Path
from loguru import logger
import shutil
import argparse


def download_tests(destination=None):
def download_tests(destination=None, overwrite=True):
'''
Download test data for AxonDeepSeg.
Parameters
----------
destination : str
Directory to download the tests to. Default: test/
'''
# Get AxonDeepSeg installation path
package_dir = Path(AxonDeepSeg.__file__).parent
if destination is None:
test_files_destination = Path("test/__test_files__")
test_files_destination = package_dir.parent / "test" / "__test_files__"
else:
destination = convert_path(destination)
test_files_destination = destination / "__test_files__"

if test_files_destination.exists() and overwrite == False:
logger.info("Overwrite set to False - not deleting old test files.")
return test_files_destination

url_tests = "https://github.com/axondeepseg/data-testing/archive/refs/tags/r20250110.zip"
files_before = list(Path.cwd().iterdir())

Expand All @@ -27,6 +44,8 @@ def download_tests(destination=None):
# retrieving unknown test files names
test_folder = list(set(files_after)-set(files_before))
folder_name_test_files = ''.join([str(x) for x in test_folder if 'data-testing' in str(x)])
output_dir=test_files_destination.resolve()


if test_files_destination.exists():
print('Test files folder already existed - deleting old one.')
Expand All @@ -37,5 +56,20 @@ def download_tests(destination=None):
# remove temporary folder
shutil.rmtree(folder_name_test_files)

return output_dir

def main(argv=None):
download_tests()
ap = argparse.ArgumentParser()
ap.add_argument(
"-d", "--dir",
required=False,
help="Directory to download the tests to. Default: test/",
default = None,
)
args = vars(ap.parse_args(argv))
download_tests(args["dir"], overwrite=True)

if __name__ == "__main__":
with logger.catch():
main()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ classifiers = [
dependencies =[
"numpy<2",
"scipy",
"scikit-image<0.25",
"scikit-image!=0.25.0,!=0.25.1",
"tabulate",
"pandas",
"matplotlib",
Expand Down
14 changes: 6 additions & 8 deletions test/test_download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def test_download_valid_model_works(self):
download_model(self.valid_model, 'light', self.tmpPath)
assert self.valid_model_path.exists()

@pytest.mark.single
def test_main_cli_runs_succesfully_no_destination(self):
cli_test_model_path = Path(AxonDeepSeg.__file__).parent / 'models' / 'model_seg_generalist_light'
output_dir = download_model(destination=None, overwrite=False)
assert output_dir == cli_test_model_path

@pytest.mark.unit
def test_download_model_cli_throws_error_for_unavailable_model(self):
with pytest.raises(SystemExit) as pytest_wrapped_e:
Expand Down Expand Up @@ -80,14 +86,6 @@ def test_main_cli_runs_succesfully_for_list_models(self):

assert (pytest_wrapped_e.type == SystemExit) and (pytest_wrapped_e.value.code == 0)

@pytest.mark.integration
def test_main_cli_runs_succesfully_no_destination(self):
cli_test_model_path = Path('.') / 'models' / 'model_seg_generalist_light'

AxonDeepSeg.download_model.main(["-t","light"])

assert cli_test_model_path.exists()

@pytest.mark.integration
def test_main_cli_downloads_to_path(self):
cli_test_path = self.tmpPath / 'cli_test'
Expand Down
19 changes: 18 additions & 1 deletion test/test_download_tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding: utf-8

import AxonDeepSeg
from pathlib import Path
import shutil
import imageio
Expand Down Expand Up @@ -53,6 +53,23 @@ def test_download_tests_works(self):

assert self.test_files_path.exists()

@pytest.mark.integration
def test_download_tests_runs_succesfully_with_destination(self):
assert not self.test_files_path.exists()

download_tests(self.tmpPath)

assert self.test_files_path.exists()

@pytest.mark.single
def test_main_cli_runs_succesfully_no_destination(self):
cli_test_model_path = Path(AxonDeepSeg.__file__).parent.parent / 'test' / '__test_files__'

output_dir = download_tests(destination=None, overwrite=False)

assert output_dir == cli_test_model_path


@pytest.mark.unit
def test_redownload_test_files_multiple_times_works(self):

Expand Down

0 comments on commit 75ab39a

Please sign in to comment.