Skip to content

Commit

Permalink
refactor(setup.py): refactor build system (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan authored May 10, 2024
1 parent 190ca72 commit 330ef57
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 36 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Refactor the raw import statement in `setup.py` with `importlib` utilities by [@XuehaiPan](https://github.com/XuehaiPan) in [#214](https://github.com/metaopt/torchopt/pull/214).

### Fixed

Expand Down
81 changes: 46 additions & 35 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
import contextlib
import os
import pathlib
import platform
import re
import shutil
import sys
import sysconfig
from importlib.util import module_from_spec, spec_from_file_location

from setuptools import setup
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext


try:
from pybind11.setup_helpers import Pybind11Extension as Extension
from pybind11.setup_helpers import build_ext
except ImportError:
from setuptools import Extension
from setuptools.command.build_ext import build_ext

HERE = pathlib.Path(__file__).absolute().parent
VERSION_FILE = HERE / 'torchopt' / 'version.py'

sys.path.insert(0, str(VERSION_FILE.parent))
import version # noqa


class CMakeExtension(Extension):
Expand All @@ -47,7 +39,6 @@ def build_extension(self, ext):
build_temp.mkdir(parents=True, exist_ok=True)

config = 'Debug' if self.debug else 'Release'

cmake_args = [
f'-DCMAKE_BUILD_TYPE={config}',
f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}',
Expand Down Expand Up @@ -83,13 +74,53 @@ def build_extension(self, ext):

build_args.extend(['--target', ext.target, '--'])

cwd = os.getcwd()
try:
os.chdir(build_temp)
self.spawn([cmake, ext.source_dir, *cmake_args])
if not self.dry_run:
self.spawn([cmake, '--build', '.', *build_args])
finally:
os.chdir(HERE)
os.chdir(cwd)


@contextlib.contextmanager
def vcs_version(name, path):
path = pathlib.Path(path).absolute()
assert path.is_file()
module_spec = spec_from_file_location(name=name, location=path)
assert module_spec is not None
assert module_spec.loader is not None
module = sys.modules.get(name)
if module is None:
module = module_from_spec(module_spec)
sys.modules[name] = module
module_spec.loader.exec_module(module)

if module.__release__:
yield module
return

content = None
try:
try:
content = path.read_text(encoding='utf-8')
path.write_text(
data=re.sub(
r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
f'__version__ = {module.__version__!r}',
string=content,
),
encoding='utf-8',
)
except OSError:
content = None

yield module
finally:
if content is not None:
with path.open(mode='wt', encoding='utf-8', newline='') as file:
file.write(content)


CIBUILDWHEEL = os.getenv('CIBUILDWHEEL', '0') == '1'
Expand All @@ -112,29 +143,9 @@ def build_extension(self, ext):
ext_kwargs.clear()


VERSION_CONTENT = None

try:
if not version.__release__:
try:
VERSION_CONTENT = VERSION_FILE.read_text(encoding='utf-8')
VERSION_FILE.write_text(
data=re.sub(
r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
f'__version__ = {version.__version__!r}',
string=VERSION_CONTENT,
),
encoding='utf-8',
)
except OSError:
VERSION_CONTENT = None

with vcs_version(name='torchopt.version', path=(HERE / 'torchopt' / 'version.py')) as version:
setup(
name='torchopt',
version=version.__version__,
**ext_kwargs,
)
finally:
if VERSION_CONTENT is not None:
with VERSION_FILE.open(mode='wt', encoding='utf-8', newline='') as file:
file.write(VERSION_CONTENT)

0 comments on commit 330ef57

Please sign in to comment.