Skip to content

Commit

Permalink
Merge pull request #146 from Pennycook/compiler-argparse
Browse files Browse the repository at this point in the history
Implement common argument parsing with argparse
  • Loading branch information
Pennycook authored Feb 4, 2025
2 parents c9efb70 + 8eefddc commit 4dfb0c6
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 77 deletions.
135 changes: 58 additions & 77 deletions codebasin/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
defining a specific code base configuration.
"""

import argparse
import collections
import logging
import os
Expand All @@ -15,62 +16,41 @@
log = logging.getLogger(__name__)


def extract_defines(argv):
"""
Extract definitions from command-line arguments.
Recognizes two argument "-D MACRO" and one argument "-DMACRO".
"""
defines = []
prefix = ""
for a in argv:
if a == "-D":
prefix = "-D"
elif prefix:
defines += [a]
prefix = ""
elif a[0:2] == "-D":
defines += [a[2:]]
prefix = ""
return defines


def extract_include_paths(argv):
"""
Extract include paths from command-line arguments.
Recognizes two argument "-I path" and one argument "-Ipath".
"""
prefixes = ["-I", "-isystem"]

include_paths = []
prefix = ""
for a in argv:
if a in prefixes:
prefix = a
elif prefix in prefixes:
include_paths += [a]
prefix = ""
elif a[0:2] == "-I":
include_paths += [a[2:]]
return include_paths


def extract_include_files(argv):
"""
Extract include files from command-line arguments.
Recognizes two argument "-include file".
"""
includes = []
prefix = ""
for a in argv:
if a == "-include":
prefix = "-include"
elif prefix:
includes += [a]
prefix = ""
return includes
_importcfg = None


_importcfg = None
def _parse_compiler_args(argv: list[str]):
"""
Parameters
----------
argv: list[str]
A list of arguments passed to a compiler.
Returns
-------
argparse.Namespace
The result of parsing `argv[1:]`.
- defines: -D arguments
- include_paths: -I/-isystem arguments
- include_files: -include arguments
"""
parser = argparse.ArgumentParser()
parser.add_argument("-D", dest="defines", action="append", default=[])
parser.add_argument(
"-I",
"-isystem",
dest="include_paths",
action="append",
default=[],
)
parser.add_argument(
"-include",
dest="include_files",
action="append",
default=[],
)
args, _ = parser.parse_known_args(argv)
return args


def load_importcfg():
Expand Down Expand Up @@ -99,16 +79,17 @@ class Compiler:
- Implicitly defined macros, which may be flag-dependent.
"""

def __init__(self, args):
self.name = os.path.basename(args[0])
self.args = args
def __init__(self, argv: list[str]):
self.name = os.path.basename(argv[0])
self.argv = argv
self.passes = {"default"}

# Check for any user-defined compiler behavior.
# Currently, users can only override default defines.
if _importcfg is None:
load_importcfg()
self.defines = extract_defines(_importcfg[self.name])
args = _parse_compiler_args(_importcfg[self.name])
self.defines = args.defines

def get_passes(self):
return self.passes.copy()
Expand Down Expand Up @@ -153,14 +134,14 @@ class ClangCompiler(Compiler):
"spir64_fpga",
]

def __init__(self, args):
super().__init__(args)
def __init__(self, argv: list[str]):
super().__init__(argv)

self.sycl = False
self.omp = False
sycl_targets = []

for arg in args:
for arg in argv:
if arg == "-fsycl":
self.sycl = True
continue
Expand Down Expand Up @@ -203,10 +184,10 @@ class GnuCompiler(Compiler):
Represents the behavior of GNU-based compilers.
"""

def __init__(self, args):
super().__init__(args)
def __init__(self, argv: list[str]):
super().__init__(argv)

for arg in args:
for arg in argv:
if arg in ["-fopenmp"]:
self.defines.append("_OPENMP")
break
Expand All @@ -217,29 +198,29 @@ class HipCompiler(Compiler):
Represents the behavior of the HIP compiler.
"""

def __init__(self, args):
super().__init__(args)
def __init__(self, argv: list[str]):
super().__init__(argv)


class IntelCompiler(ClangCompiler):
"""
Represents the behavior of Intel compilers.
"""

def __init__(self, args):
super().__init__(args)
def __init__(self, argv: list[str]):
super().__init__(argv)


class NvccCompiler(Compiler):
"""
Represents the behavior of the NVCC compiler.
"""

def __init__(self, args):
super().__init__(args)
def __init__(self, argv: list[str]):
super().__init__(argv)
self.omp = False

for arg in args:
for arg in argv:
archs = re.findall("sm_(\\d+)", arg)
archs += re.findall("compute_(\\d+)", arg)
self.passes |= set(archs)
Expand Down Expand Up @@ -267,7 +248,7 @@ def get_defines(self, pass_):
_seen_compiler = collections.defaultdict(lambda: False)


def recognize_compiler(argv):
def recognize_compiler(argv: list[str]) -> Compiler:
"""
Attempt to recognize the compiler, given an argument list.
Return a Compiler object.
Expand Down Expand Up @@ -313,11 +294,11 @@ def load_database(dbpath, rootdir):
continue
argv = command.arguments

# Extract defines, include paths and include files
# from command-line arguments
defines = extract_defines(argv)
include_paths = extract_include_paths(argv)
include_files = extract_include_files(argv)
# Parse common command-line arguments.
args = _parse_compiler_args(argv[1:])
defines = args.defines
include_paths = args.include_paths
include_files = args.include_files

# Certain tools may have additional, implicit, behaviors
# (e.g., additional defines, multiple passes for multiple targets)
Expand Down
59 changes: 59 additions & 0 deletions tests/compilers/test_compilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
# SPDX-License-Identifier: BSD-3-Clause

import logging
import os
import tempfile
import unittest
from pathlib import Path

from codebasin import config

Expand All @@ -15,6 +18,40 @@ class TestCompilers(unittest.TestCase):

def setUp(self):
logging.disable()
self.cwd = os.getcwd()

def tearDown(self):
os.chdir(self.cwd)

def test_common(self):
"""compilers/common"""
argv = [
"c++",
"-I/path",
"-I",
"/path/after/space",
"-isystem",
"/system/path",
"-include",
"foo.inc",
"-include",
"bar.inc",
"-DMACRO",
"-DFUNCTION_MACRO=1",
"-D",
"MACRO_AFTER_SPACE",
"test.cpp",
]
args = config._parse_compiler_args(argv)
self.assertEqual(
args.defines,
["MACRO", "FUNCTION_MACRO=1", "MACRO_AFTER_SPACE"],
)
self.assertEqual(
args.include_paths,
["/path", "/path/after/space", "/system/path"],
)
self.assertEqual(args.include_files, ["foo.inc", "bar.inc"])

def test_clang(self):
"""compilers/clang"""
Expand Down Expand Up @@ -101,6 +138,28 @@ def test_nvcc(self):
defines = compiler.get_defines("52")
self.assertEqual(defines, defaults + ["__CUDA_ARCH__=520"])

def test_user_options(self):
"""Check that we import user-defined options"""
tmp = tempfile.TemporaryDirectory()
path = Path(tmp.name)
os.chdir(tmp.name)
os.mkdir(".cbi")
with open(path / ".cbi" / "config", mode="w") as f:
f.write('[compiler."c++"]\n')
f.write('options = ["-D", "ASDF"]\n')
config.load_importcfg()

args = [
"c++",
"test.cpp",
]

compiler = config.recognize_compiler(args)
defines = compiler.get_defines("default")
self.assertCountEqual(defines, ["ASDF"])

tmp.cleanup()


if __name__ == "__main__":
unittest.main()

0 comments on commit 4dfb0c6

Please sign in to comment.