-
Notifications
You must be signed in to change notification settings - Fork 2
/
setup.py
121 lines (101 loc) · 2.78 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import io
import os
import sys
import glob
from setuptools import find_packages, setup, Extension
import pkg_resources
# ---------------------------------------------------------------------------
# Package metadata
# ---------------------------------------------------------------------------
# Distribution name handed to setup(name=...).
NAME = "hanser"
# Name of the package directory; used to build paths inside the source tree.
IMPORT_NAME = "hanser"
DESCRIPTION = "A library to help with training for different tasks in TensorFlow."
URL = "https://github.com/sbl1996/hanser"
# NOTE(review): this value looks like a redaction placeholder rather than a
# real address -- confirm before publishing.
EMAIL = "[email protected]"
AUTHOR = "HrvvI"
REQUIRES_PYTHON = ">=3.6.0"
# When left as None the version is read from <IMPORT_NAME>/_version.py.
VERSION = None

# Base install requirements. This is a mutable list on purpose: a pinned
# tensorflow_probability requirement is appended to it further down.
REQUIRED = [
    "Pillow",
    "numpy",
    "toolz",
    "pybind11",
    "cerberus",
    "tensorflow_datasets>=4.3.0",
    "hhutil",
    "lark",
    "pendulum",
    "loguru",
    "typer",
    "typeguard",
    "pandas",
    "packaging",
]
# Maps a TensorFlow "major.minor" version to the tensorflow_probability
# release that gets pinned for it. Keys must match the string format returned
# by get_tf_version().
tfp_version_compat_table = {
    # TensorFlow: tensorflow_probability
    "2.8": "0.15.0",
    "2.7": "0.15.0",
    "2.6": "0.14.1",
    "2.5": "0.13.0",
    "2.4": "0.12.2",
    "2.3": "0.11.1",
}
def get_tf_version():
    """Return the installed TensorFlow version as a ``"major.minor"`` string.

    The ``tensorflow`` distribution is looked up first; when it is not
    installed, the ``tf_nightly`` distribution is used instead.

    Returns:
        The first two dot-separated components of the distribution version,
        e.g. ``"2.8"`` for ``"2.8.0"``.

    Raises:
        pkg_resources.DistributionNotFound: If neither ``tensorflow`` nor
            ``tf_nightly`` is installed.
    """
    try:
        version = pkg_resources.get_distribution("tensorflow").version
    except pkg_resources.DistributionNotFound:
        version = pkg_resources.get_distribution("tf_nightly").version
    # Keep exactly "major.minor". Dropping only the last component is not
    # enough for versions with more than three parts: a nightly build such as
    # "2.9.0.dev20220101" would otherwise become "2.9.0" instead of "2.9".
    return ".".join(version.split(".")[:2])
# Pin tensorflow_probability to the release listed for the installed
# TensorFlow version.
_tf_version = get_tf_version()
try:
    tfp_version = tfp_version_compat_table[_tf_version]
except KeyError:
    # Fail with an actionable message instead of a bare KeyError('x.y').
    _supported = ", ".join(tfp_version_compat_table)
    raise RuntimeError(
        f"Unsupported TensorFlow version {_tf_version!r}: no matching "
        f"tensorflow_probability release is known "
        f"(supported TensorFlow versions: {_supported})."
    ) from None
REQUIRED.append(f"tensorflow_probability=={tfp_version}")
# Absolute path of the directory containing this file; paths into the source
# tree are built relative to it.
here = os.path.dirname(os.path.abspath(__file__))

# Use README.md as the long description when it exists next to this file,
# otherwise fall back to the short DESCRIPTION.
readme_path = os.path.join(here, 'README.md')
try:
    readme = io.open(readme_path, encoding='utf-8')
except FileNotFoundError:
    long_description = DESCRIPTION
else:
    with readme:
        long_description = '\n' + readme.read()
# Resolve the package version into about['__version__']: an explicit VERSION
# wins; otherwise <IMPORT_NAME>/_version.py is executed and whatever it
# defines (expected to include __version__) lands in ``about``.
about = {}
if VERSION:
    about['__version__'] = VERSION
else:
    # Python source files are UTF-8 by default, so read the file as UTF-8
    # explicitly rather than with the platform's locale encoding (this also
    # matches how README.md is read).
    with open(os.path.join(here, IMPORT_NAME, '_version.py'),
              encoding='utf-8') as f:
        exec(f.read(), about)
def get_pybind_include(user=False):
    """Return the include directory reported by ``pybind11.get_include``.

    Args:
        user: Forwarded unchanged to ``pybind11.get_include``.

    Returns:
        Whatever ``pybind11.get_include(user)`` returns.
    """
    # pybind11 is imported at call time, so merely loading this file does not
    # require it to be installed.
    import pybind11

    include_dir = pybind11.get_include(user)
    return include_dir
def get_numpy_extensions():
    """Describe the extension built from the C++ sources in ``csrc/numpy``.

    Every ``*.cpp`` file directly inside ``<here>/<IMPORT_NAME>/csrc/numpy``
    is added as a source of a single extension named ``<IMPORT_NAME>._numpy``.
    The source directory and both pybind11 include directories (default and
    ``user=True``) are put on the include path.

    Returns:
        A one-element list holding the configured ``Extension``.
    """
    extensions_dir = os.path.join(here, IMPORT_NAME, 'csrc', 'numpy')
    # glob returns matches in file-system order; sort them so the list of
    # sources handed to the compiler is the same on every machine.
    sources = sorted(glob.glob(os.path.join(extensions_dir, '*.cpp')))

    # Extra compiler flags are only added on macOS.
    extra_compile_args = []
    if sys.platform == 'darwin':
        extra_compile_args += ['-stdlib=libc++',
                               '-mmacosx-version-min=10.9']

    include_dirs = [
        extensions_dir,
        get_pybind_include(),
        get_pybind_include(user=True),
    ]

    return [
        Extension(
            IMPORT_NAME + '._numpy',
            sources,
            include_dirs=include_dirs,
            extra_compile_args=extra_compile_args,
        )
    ]
setup(
    # --- identity ---------------------------------------------------------
    name=NAME,
    version=about['__version__'],
    url=URL,
    license='MIT',
    author=AUTHOR,
    author_email=EMAIL,
    # --- descriptions -----------------------------------------------------
    description=DESCRIPTION,
    long_description=long_description,
    long_description_content_type='text/markdown',
    # --- contents and requirements ----------------------------------------
    python_requires=REQUIRES_PYTHON,
    packages=find_packages(exclude=('tests',)),
    install_requires=REQUIRED,
    dependency_links=[],
    # --- currently disabled options ---------------------------------------
    # include_package_data=True,
    # ext_modules=get_numpy_extensions(),
)