1
+ import contextlib
1
2
import os
2
3
import pathlib
3
4
import platform
4
5
import re
5
6
import shutil
6
7
import sys
7
8
import sysconfig
9
+ from importlib .util import module_from_spec , spec_from_file_location
8
10
9
- from setuptools import setup
11
+ from setuptools import Extension , setup
12
+ from setuptools .command .build_ext import build_ext
10
13
11
14
12
- try :
13
- from pybind11 .setup_helpers import Pybind11Extension as Extension
14
- from pybind11 .setup_helpers import build_ext
15
- except ImportError :
16
- from setuptools import Extension
17
- from setuptools .command .build_ext import build_ext
18
-
19
15
HERE = pathlib .Path (__file__ ).absolute ().parent
20
- VERSION_FILE = HERE / 'torchopt' / 'version.py'
21
-
22
- sys .path .insert (0 , str (VERSION_FILE .parent ))
23
- import version # noqa
24
16
25
17
26
18
class CMakeExtension (Extension ):
@@ -47,7 +39,6 @@ def build_extension(self, ext):
47
39
build_temp .mkdir (parents = True , exist_ok = True )
48
40
49
41
config = 'Debug' if self .debug else 'Release'
50
-
51
42
cmake_args = [
52
43
f'-DCMAKE_BUILD_TYPE={ config } ' ,
53
44
f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{ config .upper ()} ={ ext_path .parent } ' ,
@@ -83,13 +74,53 @@ def build_extension(self, ext):
83
74
84
75
build_args .extend (['--target' , ext .target , '--' ])
85
76
77
+ cwd = os .getcwd ()
86
78
try :
87
79
os .chdir (build_temp )
88
80
self .spawn ([cmake , ext .source_dir , * cmake_args ])
89
81
if not self .dry_run :
90
82
self .spawn ([cmake , '--build' , '.' , * build_args ])
91
83
finally :
92
- os .chdir (HERE )
84
+ os .chdir (cwd )
85
+
86
+
87
+ @contextlib .contextmanager
88
+ def vcs_version (name , path ):
89
+ path = pathlib .Path (path ).absolute ()
90
+ assert path .is_file ()
91
+ module_spec = spec_from_file_location (name = name , location = path )
92
+ assert module_spec is not None
93
+ assert module_spec .loader is not None
94
+ module = sys .modules .get (name )
95
+ if module is None :
96
+ module = module_from_spec (module_spec )
97
+ sys .modules [name ] = module
98
+ module_spec .loader .exec_module (module )
99
+
100
+ if module .__release__ :
101
+ yield module
102
+ return
103
+
104
+ content = None
105
+ try :
106
+ try :
107
+ content = path .read_text (encoding = 'utf-8' )
108
+ path .write_text (
109
+ data = re .sub (
110
+ r"""__version__\s*=\s*('[^']+'|"[^"]+")""" ,
111
+ f'__version__ = { module .__version__ !r} ' ,
112
+ string = content ,
113
+ ),
114
+ encoding = 'utf-8' ,
115
+ )
116
+ except OSError :
117
+ content = None
118
+
119
+ yield module
120
+ finally :
121
+ if content is not None :
122
+ with path .open (mode = 'wt' , encoding = 'utf-8' , newline = '' ) as file :
123
+ file .write (content )
93
124
94
125
95
126
CIBUILDWHEEL = os .getenv ('CIBUILDWHEEL' , '0' ) == '1'
@@ -112,29 +143,9 @@ def build_extension(self, ext):
112
143
ext_kwargs .clear ()
113
144
114
145
115
- VERSION_CONTENT = None
116
-
117
- try :
118
- if not version .__release__ :
119
- try :
120
- VERSION_CONTENT = VERSION_FILE .read_text (encoding = 'utf-8' )
121
- VERSION_FILE .write_text (
122
- data = re .sub (
123
- r"""__version__\s*=\s*('[^']+'|"[^"]+")""" ,
124
- f'__version__ = { version .__version__ !r} ' ,
125
- string = VERSION_CONTENT ,
126
- ),
127
- encoding = 'utf-8' ,
128
- )
129
- except OSError :
130
- VERSION_CONTENT = None
131
-
146
+ with vcs_version (name = 'torchopt.version' , path = (HERE / 'torchopt' / 'version.py' )) as version :
132
147
setup (
133
148
name = 'torchopt' ,
134
149
version = version .__version__ ,
135
150
** ext_kwargs ,
136
151
)
137
- finally :
138
- if VERSION_CONTENT is not None :
139
- with VERSION_FILE .open (mode = 'wt' , encoding = 'utf-8' , newline = '' ) as file :
140
- file .write (VERSION_CONTENT )
0 commit comments