diff --git a/codebasin/config.py b/codebasin/config.py index 0923587..4441543 100644 --- a/codebasin/config.py +++ b/codebasin/config.py @@ -5,6 +5,7 @@ defining a specific code base configuration. """ +import argparse import collections import logging import os @@ -15,62 +16,41 @@ log = logging.getLogger(__name__) -def extract_defines(argv): - """ - Extract definitions from command-line arguments. - Recognizes two argument "-D MACRO" and one argument "-DMACRO". - """ - defines = [] - prefix = "" - for a in argv: - if a == "-D": - prefix = "-D" - elif prefix: - defines += [a] - prefix = "" - elif a[0:2] == "-D": - defines += [a[2:]] - prefix = "" - return defines - - -def extract_include_paths(argv): - """ - Extract include paths from command-line arguments. - Recognizes two argument "-I path" and one argument "-Ipath". - """ - prefixes = ["-I", "-isystem"] - - include_paths = [] - prefix = "" - for a in argv: - if a in prefixes: - prefix = a - elif prefix in prefixes: - include_paths += [a] - prefix = "" - elif a[0:2] == "-I": - include_paths += [a[2:]] - return include_paths - - -def extract_include_files(argv): - """ - Extract include files from command-line arguments. - Recognizes two argument "-include file". - """ - includes = [] - prefix = "" - for a in argv: - if a == "-include": - prefix = "-include" - elif prefix: - includes += [a] - prefix = "" - return includes +_importcfg = None -_importcfg = None +def _parse_compiler_args(argv: list[str]): + """ + Parameters + ---------- + argv: list[str] + A list of arguments passed to a compiler. + + Returns + ------- + argparse.Namespace + The result of parsing `argv[1:]`. + - defines: -D arguments + - include_paths: -I/-isystem arguments + - include_files: -include arguments + """ + parser = argparse.ArgumentParser() + parser.add_argument("-D", dest="defines", action="append", default=[]) + parser.add_argument( + "-I", + "-isystem", + dest="include_paths", + action="append", + default=[], + ) + parser.add_argument( + "-include", + dest="include_files", + action="append", + default=[], + ) + args, _ = parser.parse_known_args(argv) + return args def load_importcfg(): @@ -99,16 +79,17 @@ class Compiler: - Implicitly defined macros, which may be flag-dependent. """ - def __init__(self, args): - self.name = os.path.basename(args[0]) - self.args = args + def __init__(self, argv: list[str]): + self.name = os.path.basename(argv[0]) + self.argv = argv self.passes = {"default"} # Check for any user-defined compiler behavior. # Currently, users can only override default defines. if _importcfg is None: load_importcfg() - self.defines = extract_defines(_importcfg[self.name]) + args = _parse_compiler_args(_importcfg[self.name]) + self.defines = args.defines def get_passes(self): return self.passes.copy() @@ -153,14 +134,14 @@ class ClangCompiler(Compiler): "spir64_fpga", ] - def __init__(self, args): - super().__init__(args) + def __init__(self, argv: list[str]): + super().__init__(argv) self.sycl = False self.omp = False sycl_targets = [] - for arg in args: + for arg in argv: if arg == "-fsycl": self.sycl = True continue @@ -203,10 +184,10 @@ class GnuCompiler(Compiler): Represents the behavior of GNU-based compilers. """ - def __init__(self, args): - super().__init__(args) + def __init__(self, argv: list[str]): + super().__init__(argv) - for arg in args: + for arg in argv: if arg in ["-fopenmp"]: self.defines.append("_OPENMP") break @@ -217,8 +198,8 @@ class HipCompiler(Compiler): Represents the behavior of the HIP compiler. """ - def __init__(self, args): - super().__init__(args) + def __init__(self, argv: list[str]): + super().__init__(argv) class IntelCompiler(ClangCompiler): @@ -226,8 +207,8 @@ class IntelCompiler(ClangCompiler): Represents the behavior of Intel compilers. """ - def __init__(self, args): - super().__init__(args) + def __init__(self, argv: list[str]): + super().__init__(argv) class NvccCompiler(Compiler): @@ -235,11 +216,11 @@ class NvccCompiler(Compiler): Represents the behavior of the NVCC compiler. """ - def __init__(self, args): - super().__init__(args) + def __init__(self, argv: list[str]): + super().__init__(argv) self.omp = False - for arg in args: + for arg in argv: archs = re.findall("sm_(\\d+)", arg) archs += re.findall("compute_(\\d+)", arg) self.passes |= set(archs) @@ -267,7 +248,7 @@ def get_defines(self, pass_): _seen_compiler = collections.defaultdict(lambda: False) -def recognize_compiler(argv): +def recognize_compiler(argv: list[str]) -> Compiler: """ Attempt to recognize the compiler, given an argument list. Return a Compiler object. @@ -313,11 +294,11 @@ def load_database(dbpath, rootdir): continue argv = command.arguments - # Extract defines, include paths and include files - # from command-line arguments - defines = extract_defines(argv) - include_paths = extract_include_paths(argv) - include_files = extract_include_files(argv) + # Parse common command-line arguments. + args = _parse_compiler_args(argv[1:]) + defines = args.defines + include_paths = args.include_paths + include_files = args.include_files # Certain tools may have additional, implicit, behaviors # (e.g., additional defines, multiple passes for multiple targets) diff --git a/tests/compilers/test_compilers.py b/tests/compilers/test_compilers.py index 02a3bec..bfcf90f 100644 --- a/tests/compilers/test_compilers.py +++ b/tests/compilers/test_compilers.py @@ -2,7 +2,10 @@ # SPDX-License-Identifier: BSD-3-Clause import logging +import os +import tempfile import unittest +from pathlib import Path from codebasin import config @@ -15,6 +18,40 @@ class TestCompilers(unittest.TestCase): def setUp(self): logging.disable() + self.cwd = os.getcwd() + + def tearDown(self): + os.chdir(self.cwd) + + def test_common(self): + """compilers/common""" + argv = [ + "c++", + "-I/path", + "-I", + "/path/after/space", + "-isystem", + "/system/path", + "-include", + "foo.inc", + "-include", + "bar.inc", + "-DMACRO", + "-DFUNCTION_MACRO=1", + "-D", + "MACRO_AFTER_SPACE", + "test.cpp", + ] + args = config._parse_compiler_args(argv) + self.assertEqual( + args.defines, + ["MACRO", "FUNCTION_MACRO=1", "MACRO_AFTER_SPACE"], + ) + self.assertEqual( + args.include_paths, + ["/path", "/path/after/space", "/system/path"], + ) + self.assertEqual(args.include_files, ["foo.inc", "bar.inc"]) def test_clang(self): """compilers/clang""" @@ -101,6 +138,28 @@ def test_nvcc(self): defines = compiler.get_defines("52") self.assertEqual(defines, defaults + ["__CUDA_ARCH__=520"]) + def test_user_options(self): + """Check that we import user-defined options""" + tmp = tempfile.TemporaryDirectory() + path = Path(tmp.name) + os.chdir(tmp.name) + os.mkdir(".cbi") + with open(path / ".cbi" / "config", mode="w") as f: + f.write('[compiler."c++"]\n') + f.write('options = ["-D", "ASDF"]\n') + config.load_importcfg() + + args = [ + "c++", + "test.cpp", + ] + + compiler = config.recognize_compiler(args) + defines = compiler.get_defines("default") + self.assertCountEqual(defines, ["ASDF"]) + + tmp.cleanup() + if __name__ == "__main__": unittest.main()