105
105
get_huggingface_model_metadata ,
106
106
download_huggingface_model_metadata ,
107
107
)
108
+ from sagemaker .serve .validations .optimization import _validate_optimization_configuration
108
109
109
110
logger = logging .getLogger (__name__ )
110
111
@@ -1120,6 +1121,7 @@ def optimize(
1120
1121
quantization_config : Optional [Dict ] = None ,
1121
1122
compilation_config : Optional [Dict ] = None ,
1122
1123
speculative_decoding_config : Optional [Dict ] = None ,
1124
+ sharding_config : Optional [Dict ] = None ,
1123
1125
env_vars : Optional [Dict ] = None ,
1124
1126
vpc_config : Optional [Dict ] = None ,
1125
1127
kms_key : Optional [str ] = None ,
@@ -1143,6 +1145,8 @@ def optimize(
1143
1145
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
1144
1146
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
1145
1147
Defaults to ``None``.
1148
+ sharding_config (Optional[Dict]): Model sharding configuration.
1149
+ Defaults to ``None``.
1146
1150
env_vars (Optional[Dict]): Additional environment variables to run the optimization
1147
1151
container. Defaults to ``None``.
1148
1152
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1171,6 +1175,7 @@ def optimize(
1171
1175
quantization_config = quantization_config ,
1172
1176
compilation_config = compilation_config ,
1173
1177
speculative_decoding_config = speculative_decoding_config ,
1178
+ sharding_config = sharding_config ,
1174
1179
env_vars = env_vars ,
1175
1180
vpc_config = vpc_config ,
1176
1181
kms_key = kms_key ,
@@ -1190,6 +1195,7 @@ def _model_builder_optimize_wrapper(
1190
1195
quantization_config : Optional [Dict ] = None ,
1191
1196
compilation_config : Optional [Dict ] = None ,
1192
1197
speculative_decoding_config : Optional [Dict ] = None ,
1198
+ sharding_config : Optional [Dict ] = None ,
1193
1199
env_vars : Optional [Dict ] = None ,
1194
1200
vpc_config : Optional [Dict ] = None ,
1195
1201
kms_key : Optional [str ] = None ,
@@ -1213,6 +1219,8 @@ def _model_builder_optimize_wrapper(
1213
1219
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
1214
1220
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
1215
1221
Defaults to ``None``.
1222
+ sharding_config (Optional[Dict]): Model sharding configuration.
1223
+ Defaults to ``None``. Cannot be combined with any other optimization config.
1216
1224
env_vars (Optional[Dict]): Additional environment variables to run the optimization
1217
1225
container. Defaults to ``None``.
1218
1226
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1227,6 +1235,27 @@ def _model_builder_optimize_wrapper(
1227
1235
Returns:
1228
1236
Model: A deployable ``Model`` object.
1229
1237
"""
1238
+ if (
1239
+ hasattr (self , "enable_network_isolation" )
1240
+ and self .enable_network_isolation
1241
+ and sharding_config
1242
+ ):
1243
+ raise ValueError (
1244
+ "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1245
+ "Loading of model requires network access."
1246
+ )
1247
+
1248
+ # TODO: ideally these dictionaries need to be sagemaker_core shapes
1249
+ # TODO: for organization, abstract all validation behind this fn
1250
+ _validate_optimization_configuration (
1251
+ is_jumpstart = self ._is_jumpstart_model_id (),
1252
+ instance_type = instance_type ,
1253
+ quantization_config = quantization_config ,
1254
+ compilation_config = compilation_config ,
1255
+ sharding_config = sharding_config ,
1256
+ speculative_decoding_config = speculative_decoding_config ,
1257
+ )
1258
+
1230
1259
self .is_compiled = compilation_config is not None
1231
1260
self .is_quantized = quantization_config is not None
1232
1261
self .speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider (
@@ -1236,6 +1265,36 @@ def _model_builder_optimize_wrapper(
1236
1265
if self .mode != Mode .SAGEMAKER_ENDPOINT :
1237
1266
raise ValueError ("Model optimization is only supported in Sagemaker Endpoint Mode." )
1238
1267
1268
+ if sharding_config and (
1269
+ quantization_config or compilation_config or speculative_decoding_config
1270
+ ):
1271
+ raise ValueError (
1272
+ (
1273
+ "Sharding config is mutually exclusive "
1274
+ "and cannot be combined with any other optimization."
1275
+ )
1276
+ )
1277
+
1278
+ if sharding_config :
1279
+ has_tensor_parallel_degree_in_env_vars = (
1280
+ env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
1281
+ )
1282
+ has_tensor_parallel_degree_in_overrides = (
1283
+ sharding_config
1284
+ and sharding_config .get ("OverrideEnvironment" )
1285
+ and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config .get ("OverrideEnvironment" )
1286
+ )
1287
+ if (
1288
+ not has_tensor_parallel_degree_in_env_vars
1289
+ and not has_tensor_parallel_degree_in_overrides
1290
+ ):
1291
+ raise ValueError (
1292
+ (
1293
+ "OPTION_TENSOR_PARALLEL_DEGREE is a required "
1294
+ "environment variable with sharding config."
1295
+ )
1296
+ )
1297
+
1239
1298
self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
1240
1299
self .instance_type = instance_type or self .instance_type
1241
1300
self .role_arn = role_arn or self .role_arn
@@ -1252,6 +1311,7 @@ def _model_builder_optimize_wrapper(
1252
1311
quantization_config = quantization_config ,
1253
1312
compilation_config = compilation_config ,
1254
1313
speculative_decoding_config = speculative_decoding_config ,
1314
+ sharding_config = sharding_config ,
1255
1315
env_vars = env_vars ,
1256
1316
vpc_config = vpc_config ,
1257
1317
kms_key = kms_key ,
@@ -1270,12 +1330,16 @@ def _model_builder_optimize_wrapper(
1270
1330
quantization_config = quantization_config ,
1271
1331
compilation_config = compilation_config ,
1272
1332
speculative_decoding_config = speculative_decoding_config ,
1333
+ sharding_config = sharding_config ,
1273
1334
env_vars = env_vars ,
1274
1335
vpc_config = vpc_config ,
1275
1336
kms_key = kms_key ,
1276
1337
max_runtime_in_sec = max_runtime_in_sec ,
1277
1338
)
1278
1339
1340
+ if sharding_config :
1341
+ self .pysdk_model ._is_sharded_model = True
1342
+
1279
1343
if input_args :
1280
1344
optimization_instance_type = input_args ["DeploymentInstanceType" ]
1281
1345
@@ -1325,6 +1389,7 @@ def _optimize_for_hf(
1325
1389
quantization_config : Optional [Dict ] = None ,
1326
1390
compilation_config : Optional [Dict ] = None ,
1327
1391
speculative_decoding_config : Optional [Dict ] = None ,
1392
+ sharding_config : Optional [Dict ] = None ,
1328
1393
env_vars : Optional [Dict ] = None ,
1329
1394
vpc_config : Optional [Dict ] = None ,
1330
1395
kms_key : Optional [str ] = None ,
@@ -1340,6 +1405,8 @@ def _optimize_for_hf(
1340
1405
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
1341
1406
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
1342
1407
Defaults to ``None``.
1408
+ sharding_config (Optional[Dict]): Model sharding configuration.
1409
+ Defaults to ``None``.
1343
1410
env_vars (Optional[Dict]): Additional environment variables to run the optimization
1344
1411
container. Defaults to ``None``.
1345
1412
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1363,7 +1430,7 @@ def _optimize_for_hf(
1363
1430
self .pysdk_model , speculative_decoding_config , False
1364
1431
)
1365
1432
1366
- if quantization_config or compilation_config :
1433
+ if quantization_config or compilation_config or sharding_config :
1367
1434
create_optimization_job_args = {
1368
1435
"OptimizationJobName" : job_name ,
1369
1436
"DeploymentInstanceType" : self .instance_type ,
@@ -1378,8 +1445,13 @@ def _optimize_for_hf(
1378
1445
model_source = _generate_model_source (self .pysdk_model .model_data , False )
1379
1446
create_optimization_job_args ["ModelSource" ] = model_source
1380
1447
1381
- optimization_config , quantization_override_env , compilation_override_env = (
1382
- _extract_optimization_config_and_env (quantization_config , compilation_config )
1448
+ (
1449
+ optimization_config ,
1450
+ quantization_override_env ,
1451
+ compilation_override_env ,
1452
+ sharding_override_env ,
1453
+ ) = _extract_optimization_config_and_env (
1454
+ quantization_config , compilation_config , sharding_config
1383
1455
)
1384
1456
create_optimization_job_args ["OptimizationConfigs" ] = [
1385
1457
{k : v } for k , v in optimization_config .items ()
@@ -1388,6 +1460,7 @@ def _optimize_for_hf(
1388
1460
{
1389
1461
** (quantization_override_env or {}),
1390
1462
** (compilation_override_env or {}),
1463
+ ** (sharding_override_env or {}),
1391
1464
}
1392
1465
)
1393
1466
0 commit comments