Skip to content

Commit

Permalink
Merge branch 'master' into disable_logger_for_PT2.6
Browse files — browse the repository at this point in the history
  • Loading branch information
loadams authored Feb 4, 2025
2 parents a9dc0f9 + f4caed6 commit d3ad00a
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,8 @@ def compute_capability_args(self, cross_compile_archs=None):
- `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ...
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ...
- `cross_compile_archs` uses ; separator.
Expand Down Expand Up @@ -651,9 +651,9 @@ def compute_capability_args(self, cross_compile_archs=None):
args = []
self.enable_bf16 = True
for cc in ccs:
num = cc[0] + cc[2]
num = cc[0] + cc[1].split('+')[0]
args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
if cc.endswith('+PTX'):
if cc[1].endswith('+PTX'):
args.append(f'-gencode=arch=compute_{num},code=compute_{num}')

if int(cc[0]) <= 7:
Expand All @@ -666,7 +666,7 @@ def filter_ccs(self, ccs: List[str]):
Prune any compute capabilities that are not compatible with the builder. Should log
which CCs have been pruned.
"""
return ccs
return [cc.split('.') for cc in ccs]

def version_dependent_macros(self):
# Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
Expand Down
2 changes: 1 addition & 1 deletion op_builder/fp_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def is_compatible(self, verbose=False):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 8:
ccs_retained.append(cc)
else:
Expand Down
2 changes: 1 addition & 1 deletion op_builder/inference_core_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def is_compatible(self, verbose=False):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 6:
ccs_retained.append(cc)
else:
Expand Down
2 changes: 1 addition & 1 deletion op_builder/inference_cutlass_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def is_compatible(self, verbose=False):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 8:
# Only support Ampere and newer
ccs_retained.append(cc)
Expand Down
2 changes: 1 addition & 1 deletion op_builder/ragged_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def is_compatible(self, verbose=False):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 8:
# Blocked flash has a dependency on Ampere + newer
ccs_retained.append(cc)
Expand Down
2 changes: 1 addition & 1 deletion op_builder/ragged_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def is_compatible(self, verbose=False):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 6:
ccs_retained.append(cc)
else:
Expand Down
2 changes: 1 addition & 1 deletion op_builder/transformer_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def is_compatible(self, verbose=False):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 6:
ccs_retained.append(cc)
else:
Expand Down

0 comments on commit d3ad00a

Please sign in to comment.