Skip to content

Commit

Permalink
install nvidia apex via cli
Browse files Browse the repository at this point in the history
  • Loading branch information
ibro45 committed Aug 31, 2021
1 parent 3718e83 commit a4e374d
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 116 deletions.
11 changes: 0 additions & 11 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,7 @@ venv.bak/
checkpoints/
\.wandb
\wandb
core.login18*
IGNORE/
output_*
log.txt
pad_crop_info.json

output/

apex
results
*.nrrd
Untitled.ipynb
core.*

projects/maastro_lung_proton_cbct_to_ct/experiments/2d_vnet_local.yaml
97 changes: 0 additions & 97 deletions environment.yml

This file was deleted.

11 changes: 8 additions & 3 deletions ganslate/nn/gans/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from loguru import logger
import os
import sys
from abc import ABC, abstractmethod
from pathlib import Path

Expand Down Expand Up @@ -115,9 +116,13 @@ def setup(self):
"The (main) generator has to be named `G` or `G_AB`."

if self.conf[self.conf.mode].mixed_precision:
from apex import amp

# Allow the methods to access AMP that was imported here
try:
from apex import amp
except ModuleNotFoundError:
sys.exit("\nMixed precision not installed! "
"Install Nvidia Apex mixed precision support "
"by running `ganslate install-nvidia-apex`'\n")
# Allow the methods to access AMP that's imported here
globals()["amp"] = amp

# Initialize Generators and Discriminators
Expand Down
16 changes: 16 additions & 0 deletions ganslate/utils/cli/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,21 @@ def download_dataset(name, path):
download_script_path = "ganslate/utils/scripts/download_cyclegan_datasets.sh"
subprocess.call(["bash", download_script_path, name, path])

# Install Nvidia Apex
@interface.command(help="Install Nvidia Apex for mixed precision support.")
@click.option(
"--cpp/--python",
default=True,
help=("C++ support is faster and preferred, use Python fallback "
"only when CUDA is not installed natively.")
)
def install_nvidia_apex(cpp):
# TODO: (Ibro) I need to verify this in a few days when I have access to the GPU
cmd = 'pip install -v --disable-pip-version-check --no-cache-dir'
if cpp:
cmd += ' --global-option="--cpp_ext" --global-option="--cuda_ext"'
cmd += ' git+https://github.com/NVIDIA/apex.git'
subprocess.run(cmd.split(' '))

if __name__ == "__main__":
interface()
3 changes: 0 additions & 3 deletions install_apex.sh

This file was deleted.

2 changes: 0 additions & 2 deletions requirements.txt

This file was deleted.

0 comments on commit a4e374d

Please sign in to comment.