Skip to content

Commit 330ef57

Browse files
authored
refactor(setup.py): refactor build system (#214)
1 parent 190ca72 commit 330ef57

File tree

2 files changed

+47
-36
lines changed

2 files changed

+47
-36
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717

1818
### Changed
1919

20-
-
20+
- Refactor the raw import statement in `setup.py` with `importlib` utilities by [@XuehaiPan](https://github.com/XuehaiPan) in [#214](https://github.com/metaopt/torchopt/pull/214).
2121

2222
### Fixed
2323

setup.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,18 @@
1+
import contextlib
12
import os
23
import pathlib
34
import platform
45
import re
56
import shutil
67
import sys
78
import sysconfig
9+
from importlib.util import module_from_spec, spec_from_file_location
810

9-
from setuptools import setup
11+
from setuptools import Extension, setup
12+
from setuptools.command.build_ext import build_ext
1013

1114

12-
try:
13-
from pybind11.setup_helpers import Pybind11Extension as Extension
14-
from pybind11.setup_helpers import build_ext
15-
except ImportError:
16-
from setuptools import Extension
17-
from setuptools.command.build_ext import build_ext
18-
1915
HERE = pathlib.Path(__file__).absolute().parent
20-
VERSION_FILE = HERE / 'torchopt' / 'version.py'
21-
22-
sys.path.insert(0, str(VERSION_FILE.parent))
23-
import version # noqa
2416

2517

2618
class CMakeExtension(Extension):
@@ -47,7 +39,6 @@ def build_extension(self, ext):
4739
build_temp.mkdir(parents=True, exist_ok=True)
4840

4941
config = 'Debug' if self.debug else 'Release'
50-
5142
cmake_args = [
5243
f'-DCMAKE_BUILD_TYPE={config}',
5344
f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}',
@@ -83,13 +74,53 @@ def build_extension(self, ext):
8374

8475
build_args.extend(['--target', ext.target, '--'])
8576

77+
cwd = os.getcwd()
8678
try:
8779
os.chdir(build_temp)
8880
self.spawn([cmake, ext.source_dir, *cmake_args])
8981
if not self.dry_run:
9082
self.spawn([cmake, '--build', '.', *build_args])
9183
finally:
92-
os.chdir(HERE)
84+
os.chdir(cwd)
85+
86+
87+
@contextlib.contextmanager
88+
def vcs_version(name, path):
89+
path = pathlib.Path(path).absolute()
90+
assert path.is_file()
91+
module_spec = spec_from_file_location(name=name, location=path)
92+
assert module_spec is not None
93+
assert module_spec.loader is not None
94+
module = sys.modules.get(name)
95+
if module is None:
96+
module = module_from_spec(module_spec)
97+
sys.modules[name] = module
98+
module_spec.loader.exec_module(module)
99+
100+
if module.__release__:
101+
yield module
102+
return
103+
104+
content = None
105+
try:
106+
try:
107+
content = path.read_text(encoding='utf-8')
108+
path.write_text(
109+
data=re.sub(
110+
r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
111+
f'__version__ = {module.__version__!r}',
112+
string=content,
113+
),
114+
encoding='utf-8',
115+
)
116+
except OSError:
117+
content = None
118+
119+
yield module
120+
finally:
121+
if content is not None:
122+
with path.open(mode='wt', encoding='utf-8', newline='') as file:
123+
file.write(content)
93124

94125

95126
CIBUILDWHEEL = os.getenv('CIBUILDWHEEL', '0') == '1'
@@ -112,29 +143,9 @@ def build_extension(self, ext):
112143
ext_kwargs.clear()
113144

114145

115-
VERSION_CONTENT = None
116-
117-
try:
118-
if not version.__release__:
119-
try:
120-
VERSION_CONTENT = VERSION_FILE.read_text(encoding='utf-8')
121-
VERSION_FILE.write_text(
122-
data=re.sub(
123-
r"""__version__\s*=\s*('[^']+'|"[^"]+")""",
124-
f'__version__ = {version.__version__!r}',
125-
string=VERSION_CONTENT,
126-
),
127-
encoding='utf-8',
128-
)
129-
except OSError:
130-
VERSION_CONTENT = None
131-
146+
with vcs_version(name='torchopt.version', path=(HERE / 'torchopt' / 'version.py')) as version:
132147
setup(
133148
name='torchopt',
134149
version=version.__version__,
135150
**ext_kwargs,
136151
)
137-
finally:
138-
if VERSION_CONTENT is not None:
139-
with VERSION_FILE.open(mode='wt', encoding='utf-8', newline='') as file:
140-
file.write(VERSION_CONTENT)

0 commit comments

Comments
 (0)