diff --git a/op_builder/builder.py b/op_builder/builder.py index cdd11f00cabf..ac850775cb11 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -612,8 +612,8 @@ def compute_capability_args(self, cross_compile_archs=None): - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples: - TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ... - TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ... + TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ... + TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ... - `cross_compile_archs` uses ; separator. @@ -651,9 +651,9 @@ def compute_capability_args(self, cross_compile_archs=None): args = [] self.enable_bf16 = True for cc in ccs: - num = cc[0] + cc[2] + num = cc[0] + cc[1].split('+')[0] args.append(f'-gencode=arch=compute_{num},code=sm_{num}') - if cc.endswith('+PTX'): + if cc[1].endswith('+PTX'): args.append(f'-gencode=arch=compute_{num},code=compute_{num}') if int(cc[0]) <= 7: @@ -666,7 +666,7 @@ def filter_ccs(self, ccs: List[str]): Prune any compute capabilities that are not compatible with the builder. Should log which CCs have been pruned. """ - return ccs + return [cc.split('.') for cc in ccs] def version_dependent_macros(self): # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py index daa41a8148f5..e42927bd065d 100644 --- a/op_builder/fp_quantizer.py +++ b/op_builder/fp_quantizer.py @@ -78,7 +78,7 @@ def is_compatible(self, verbose=False): def filter_ccs(self, ccs): ccs_retained = [] ccs_pruned = [] - for cc in ccs: + for cc in [cc.split('.') for cc in ccs]: if int(cc[0]) >= 8: ccs_retained.append(cc) else: diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index f7c0b47f92c6..b6665ebb7618 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -46,7 +46,7 @@ def is_compatible(self, verbose=False): def filter_ccs(self, ccs): ccs_retained = [] ccs_pruned = [] - for cc in ccs: + for cc in [cc.split('.') for cc in ccs]: if int(cc[0]) >= 6: ccs_retained.append(cc) else: diff --git a/op_builder/inference_cutlass_builder.py b/op_builder/inference_cutlass_builder.py index aa5294b1cbda..a4a607288ca8 100644 --- a/op_builder/inference_cutlass_builder.py +++ b/op_builder/inference_cutlass_builder.py @@ -45,7 +45,7 @@ def is_compatible(self, verbose=False): def filter_ccs(self, ccs): ccs_retained = [] ccs_pruned = [] - for cc in ccs: + for cc in [cc.split('.') for cc in ccs]: if int(cc[0]) >= 8: # Only support Ampere and newer ccs_retained.append(cc) diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py index 10afb193c738..0df28cc2282a 100644 --- a/op_builder/ragged_ops.py +++ b/op_builder/ragged_ops.py @@ -46,7 +46,7 @@ def is_compatible(self, verbose=False): def filter_ccs(self, ccs): ccs_retained = [] ccs_pruned = [] - for cc in ccs: + for cc in [cc.split('.') for cc in ccs]: if int(cc[0]) >= 8: # Blocked flash has a dependency on Ampere + newer ccs_retained.append(cc) diff --git a/op_builder/ragged_utils.py b/op_builder/ragged_utils.py index 654ba07e0879..208c9f833ebe 100755 --- a/op_builder/ragged_utils.py +++ b/op_builder/ragged_utils.py @@ -46,7 +46,7 @@ def is_compatible(self, verbose=False): def filter_ccs(self, ccs): ccs_retained = [] ccs_pruned = [] - for cc in ccs: + for cc in [cc.split('.') for cc in ccs]: if int(cc[0]) >= 6: ccs_retained.append(cc) else: diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 1b056ecef3ed..642aed56a192 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -44,7 +44,7 @@ def is_compatible(self, verbose=False): def filter_ccs(self, ccs): ccs_retained = [] ccs_pruned = [] - for cc in ccs: + for cc in [cc.split('.') for cc in ccs]: if int(cc[0]) >= 6: ccs_retained.append(cc) else: