Skip to content

Commit efd6c80

Browse files
gwang111 and Ashish Gupta authored
feat: Optimize() validations across TRT, VLLM, Neuron container optimizations (#4927)
* changes for blackbird - model sharding
  changes for blackbird - model sharding
  add more tests
  fix sharded model flag
  add optimization validations
  fix formatting and msging
  fixing validation bugs
  add UTs
  simplify logic
  update messaging
  formatting
  fix UTs
  add more UTs
  fix validations
  update ruleset
  update formatting
  update validation logic
  update
  bug fixes
  Disable network isolation if using sharded models.
  check sharding + network iso pre optimization
  add more UTs for sharding
  add more UTs
* fix rebase issues
---------
Co-authored-by: Ashish Gupta <[email protected]>
1 parent 663bbb6 commit efd6c80

File tree

8 files changed: +1166 -21 lines changed

src/sagemaker/model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def __init__(
372372
self.endpoint_name = None
373373
self.inference_component_name = None
374374
self._is_compiled_model = False
375+
self._is_sharded_model = False
375376
self._compilation_job_name = None
376377
self._is_edge_packaged_model = False
377378
self.inference_recommender_job_results = None
@@ -1599,6 +1600,19 @@ def deploy(
15991600
if self._base_name is not None:
16001601
self._base_name = "-".join((self._base_name, compiled_model_suffix))
16011602

1603+
if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
1604+
logging.warning(
1605+
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
1606+
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
1607+
)
1608+
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
1609+
1610+
if self._is_sharded_model and self._enable_network_isolation:
1611+
raise ValueError(
1612+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1613+
"Loading of model requires network access."
1614+
)
1615+
16021616
# Support multiple models on same endpoint
16031617
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
16041618
if endpoint_name:

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,7 @@ def _optimize_for_jumpstart(
684684
quantization_config: Optional[Dict] = None,
685685
compilation_config: Optional[Dict] = None,
686686
speculative_decoding_config: Optional[Dict] = None,
687+
sharding_config: Optional[Dict] = None,
687688
env_vars: Optional[Dict] = None,
688689
vpc_config: Optional[Dict] = None,
689690
kms_key: Optional[str] = None,
@@ -705,6 +706,8 @@ def _optimize_for_jumpstart(
705706
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
706707
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
707708
Defaults to ``None``
709+
sharding_config (Optional[Dict]): Model sharding configuration.
710+
Defaults to ``None``
708711
env_vars (Optional[Dict]): Additional environment variables to run the optimization
709712
container. Defaults to ``None``.
710713
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -730,8 +733,13 @@ def _optimize_for_jumpstart(
730733
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
731734

732735
# optimization_config can contain configs for both quantization and compilation
733-
optimization_config, quantization_override_env, compilation_override_env = (
734-
_extract_optimization_config_and_env(quantization_config, compilation_config)
736+
(
737+
optimization_config,
738+
quantization_override_env,
739+
compilation_override_env,
740+
sharding_override_env,
741+
) = _extract_optimization_config_and_env(
742+
quantization_config, compilation_config, sharding_config
735743
)
736744

737745
if not optimization_config:
@@ -807,11 +815,20 @@ def _optimize_for_jumpstart(
807815
{
808816
**(quantization_override_env or {}),
809817
**(compilation_override_env or {}),
818+
**(sharding_override_env or {}),
810819
},
811820
)
812821
if optimization_env_vars:
813822
self.pysdk_model.env.update(optimization_env_vars)
814-
if quantization_config or is_compilation:
823+
824+
if sharding_config and self.pysdk_model._enable_network_isolation:
825+
logger.warning(
826+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
827+
"Loading of model requires network access. Setting it to False."
828+
)
829+
self.pysdk_model._enable_network_isolation = False
830+
831+
if quantization_config or sharding_config or is_compilation:
815832
return create_optimization_job_args
816833
return None
817834

src/sagemaker/serve/builder/model_builder.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
get_huggingface_model_metadata,
106106
download_huggingface_model_metadata,
107107
)
108+
from sagemaker.serve.validations.optimization import _validate_optimization_configuration
108109

109110
logger = logging.getLogger(__name__)
110111

@@ -1120,6 +1121,7 @@ def optimize(
11201121
quantization_config: Optional[Dict] = None,
11211122
compilation_config: Optional[Dict] = None,
11221123
speculative_decoding_config: Optional[Dict] = None,
1124+
sharding_config: Optional[Dict] = None,
11231125
env_vars: Optional[Dict] = None,
11241126
vpc_config: Optional[Dict] = None,
11251127
kms_key: Optional[str] = None,
@@ -1143,6 +1145,8 @@ def optimize(
11431145
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
11441146
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
11451147
Defaults to ``None``
1148+
sharding_config (Optional[Dict]): Model sharding configuration.
1149+
Defaults to ``None``
11461150
env_vars (Optional[Dict]): Additional environment variables to run the optimization
11471151
container. Defaults to ``None``.
11481152
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1171,6 +1175,7 @@ def optimize(
11711175
quantization_config=quantization_config,
11721176
compilation_config=compilation_config,
11731177
speculative_decoding_config=speculative_decoding_config,
1178+
sharding_config=sharding_config,
11741179
env_vars=env_vars,
11751180
vpc_config=vpc_config,
11761181
kms_key=kms_key,
@@ -1190,6 +1195,7 @@ def _model_builder_optimize_wrapper(
11901195
quantization_config: Optional[Dict] = None,
11911196
compilation_config: Optional[Dict] = None,
11921197
speculative_decoding_config: Optional[Dict] = None,
1198+
sharding_config: Optional[Dict] = None,
11931199
env_vars: Optional[Dict] = None,
11941200
vpc_config: Optional[Dict] = None,
11951201
kms_key: Optional[str] = None,
@@ -1213,6 +1219,8 @@ def _model_builder_optimize_wrapper(
12131219
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
12141220
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
12151221
Defaults to ``None``
1222+
sharding_config (Optional[Dict]): Model sharding configuration.
1223+
Defaults to ``None``
12161224
env_vars (Optional[Dict]): Additional environment variables to run the optimization
12171225
container. Defaults to ``None``.
12181226
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1227,6 +1235,27 @@ def _model_builder_optimize_wrapper(
12271235
Returns:
12281236
Model: A deployable ``Model`` object.
12291237
"""
1238+
if (
1239+
hasattr(self, "enable_network_isolation")
1240+
and self.enable_network_isolation
1241+
and sharding_config
1242+
):
1243+
raise ValueError(
1244+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1245+
"Loading of model requires network access."
1246+
)
1247+
1248+
# TODO: ideally these dictionaries need to be sagemaker_core shapes
1249+
# TODO: for organization, abstract all validation behind this fn
1250+
_validate_optimization_configuration(
1251+
is_jumpstart=self._is_jumpstart_model_id(),
1252+
instance_type=instance_type,
1253+
quantization_config=quantization_config,
1254+
compilation_config=compilation_config,
1255+
sharding_config=sharding_config,
1256+
speculative_decoding_config=speculative_decoding_config,
1257+
)
1258+
12301259
self.is_compiled = compilation_config is not None
12311260
self.is_quantized = quantization_config is not None
12321261
self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider(
@@ -1236,6 +1265,36 @@ def _model_builder_optimize_wrapper(
12361265
if self.mode != Mode.SAGEMAKER_ENDPOINT:
12371266
raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")
12381267

1268+
if sharding_config and (
1269+
quantization_config or compilation_config or speculative_decoding_config
1270+
):
1271+
raise ValueError(
1272+
(
1273+
"Sharding config is mutually exclusive "
1274+
"and cannot be combined with any other optimization."
1275+
)
1276+
)
1277+
1278+
if sharding_config:
1279+
has_tensor_parallel_degree_in_env_vars = (
1280+
env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
1281+
)
1282+
has_tensor_parallel_degree_in_overrides = (
1283+
sharding_config
1284+
and sharding_config.get("OverrideEnvironment")
1285+
and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config.get("OverrideEnvironment")
1286+
)
1287+
if (
1288+
not has_tensor_parallel_degree_in_env_vars
1289+
and not has_tensor_parallel_degree_in_overrides
1290+
):
1291+
raise ValueError(
1292+
(
1293+
"OPTION_TENSOR_PARALLEL_DEGREE is a required "
1294+
"environment variable with sharding config."
1295+
)
1296+
)
1297+
12391298
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
12401299
self.instance_type = instance_type or self.instance_type
12411300
self.role_arn = role_arn or self.role_arn
@@ -1252,6 +1311,7 @@ def _model_builder_optimize_wrapper(
12521311
quantization_config=quantization_config,
12531312
compilation_config=compilation_config,
12541313
speculative_decoding_config=speculative_decoding_config,
1314+
sharding_config=sharding_config,
12551315
env_vars=env_vars,
12561316
vpc_config=vpc_config,
12571317
kms_key=kms_key,
@@ -1270,12 +1330,16 @@ def _model_builder_optimize_wrapper(
12701330
quantization_config=quantization_config,
12711331
compilation_config=compilation_config,
12721332
speculative_decoding_config=speculative_decoding_config,
1333+
sharding_config=sharding_config,
12731334
env_vars=env_vars,
12741335
vpc_config=vpc_config,
12751336
kms_key=kms_key,
12761337
max_runtime_in_sec=max_runtime_in_sec,
12771338
)
12781339

1340+
if sharding_config:
1341+
self.pysdk_model._is_sharded_model = True
1342+
12791343
if input_args:
12801344
optimization_instance_type = input_args["DeploymentInstanceType"]
12811345

@@ -1325,6 +1389,7 @@ def _optimize_for_hf(
13251389
quantization_config: Optional[Dict] = None,
13261390
compilation_config: Optional[Dict] = None,
13271391
speculative_decoding_config: Optional[Dict] = None,
1392+
sharding_config: Optional[Dict] = None,
13281393
env_vars: Optional[Dict] = None,
13291394
vpc_config: Optional[Dict] = None,
13301395
kms_key: Optional[str] = None,
@@ -1340,6 +1405,8 @@ def _optimize_for_hf(
13401405
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
13411406
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
13421407
Defaults to ``None``
1408+
sharding_config (Optional[Dict]): Model sharding configuration.
1409+
Defaults to ``None``
13431410
env_vars (Optional[Dict]): Additional environment variables to run the optimization
13441411
container. Defaults to ``None``.
13451412
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1363,7 +1430,7 @@ def _optimize_for_hf(
13631430
self.pysdk_model, speculative_decoding_config, False
13641431
)
13651432

1366-
if quantization_config or compilation_config:
1433+
if quantization_config or compilation_config or sharding_config:
13671434
create_optimization_job_args = {
13681435
"OptimizationJobName": job_name,
13691436
"DeploymentInstanceType": self.instance_type,
@@ -1378,8 +1445,13 @@ def _optimize_for_hf(
13781445
model_source = _generate_model_source(self.pysdk_model.model_data, False)
13791446
create_optimization_job_args["ModelSource"] = model_source
13801447

1381-
optimization_config, quantization_override_env, compilation_override_env = (
1382-
_extract_optimization_config_and_env(quantization_config, compilation_config)
1448+
(
1449+
optimization_config,
1450+
quantization_override_env,
1451+
compilation_override_env,
1452+
sharding_override_env,
1453+
) = _extract_optimization_config_and_env(
1454+
quantization_config, compilation_config, sharding_config
13831455
)
13841456
create_optimization_job_args["OptimizationConfigs"] = [
13851457
{k: v} for k, v in optimization_config.items()
@@ -1388,6 +1460,7 @@ def _optimize_for_hf(
13881460
{
13891461
**(quantization_override_env or {}),
13901462
**(compilation_override_env or {}),
1463+
**(sharding_override_env or {}),
13911464
}
13921465
)
13931466

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -361,16 +361,19 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
361361

362362

363363
def _extract_optimization_config_and_env(
364-
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
365-
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
364+
quantization_config: Optional[Dict] = None,
365+
compilation_config: Optional[Dict] = None,
366+
sharding_config: Optional[Dict] = None,
367+
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
366368
"""Extracts optimization config and environment variables.
367369
368370
Args:
369371
quantization_config (Optional[Dict]): The quantization config.
370372
compilation_config (Optional[Dict]): The compilation config.
373+
sharding_config (Optional[Dict]): The sharding config.
371374
372375
Returns:
373-
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
376+
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
374377
The optimization config and environment variables.
375378
"""
376379
optimization_config = {}
@@ -380,18 +383,27 @@ def _extract_optimization_config_and_env(
380383
compilation_override_env = (
381384
compilation_config.get("OverrideEnvironment") if compilation_config else None
382385
)
386+
sharding_override_env = sharding_config.get("OverrideEnvironment") if sharding_config else None
383387

384388
if quantization_config is not None:
385389
optimization_config["ModelQuantizationConfig"] = quantization_config
386390

387391
if compilation_config is not None:
388392
optimization_config["ModelCompilationConfig"] = compilation_config
389393

394+
if sharding_config is not None:
395+
optimization_config["ModelShardingConfig"] = sharding_config
396+
390397
# Return optimization config dict and environment variables if either is present
391398
if optimization_config:
392-
return optimization_config, quantization_override_env, compilation_override_env
399+
return (
400+
optimization_config,
401+
quantization_override_env,
402+
compilation_override_env,
403+
sharding_override_env,
404+
)
393405

394-
return None, None, None
406+
return None, None, None, None
395407

396408

397409
def _custom_speculative_decoding(

0 commit comments

Comments
 (0)