@@ -103,6 +103,40 @@ def add_evidence(self, evidences: Union[Evidence, List[Evidence]]):
103
103
"""
104
104
self ._evidences += wrap_list (evidences )
105
105
106
+ def apply_model_changes (self ):
107
+ """
108
+ Update Platform model's tags and version to CredoAI Governance if changed
109
+
110
+ This function will update the platform model associated with the assessment plan with
111
+ the tags and version associated with the local model associated with Governance. If no
112
+ model has been registered on the platform, nothing will be updated.
113
+ """
114
+ # association between keys and api calls:
115
+ api_calls = {
116
+ "tags" : self ._api .update_use_case_model_link_tags ,
117
+ "model_version" : self ._api .update_use_case_model_link_version ,
118
+ }
119
+
120
+ # find model_link with model name from assessment plan
121
+ plan_model = self ._find_plan_model ()
122
+ if plan_model is None :
123
+ return
124
+
125
+ model_info = self .get_model_info ()
126
+ plan_model_info = self ._get_model_info (plan_model )
127
+ for key in model_info .keys ():
128
+ model_value = model_info [key ]
129
+ plan_model_value = plan_model [key ]
130
+ if model_value != plan_model_value :
131
+ global_logger .info (
132
+ "%s\n %s" ,
133
+ f"Platform model and local model { key } do not match. Platform { key } : { plan_model_value } , Local { key } : { model_value } \n " ,
134
+ f"Updated platform model { key } ..." ,
135
+ )
136
+ api_call = api_calls [key ]
137
+ api_call (self ._use_case_id , plan_model ["id" ], model_value )
138
+ plan_model [key ] = model_value
139
+
106
140
def clear_evidence (self ):
107
141
self .set_evidence ([])
108
142
@@ -174,7 +208,7 @@ def get_evidence_requirements(self, tags: dict = None, verbose=False):
174
208
List[EvidenceRequirement]
175
209
"""
176
210
if tags is None :
177
- tags = self .get_model_tags ()
211
+ tags = self .get_model_info ()[ "tags" ]
178
212
179
213
reqs = [e for e in self ._evidence_requirements if check_subset (e .tags , tags )]
180
214
if verbose :
@@ -185,12 +219,9 @@ def get_requirement_tags(self):
185
219
"""Return the unique tags used for all evidence requirements"""
186
220
return self ._unique_tags
187
221
188
- def get_model_tags (self ):
189
- """Get the tags for the associated model"""
190
- if self ._model :
191
- return self ._model ["tags" ]
192
- else :
193
- return {}
222
+ def get_model_info (self ):
223
+ """Get the tags and version for the associated model"""
224
+ return self ._get_model_info (self ._model )
194
225
195
226
def register (
196
227
self ,
@@ -290,6 +321,7 @@ def set_artifacts(
290
321
self ,
291
322
model : str ,
292
323
model_tags : dict ,
324
+ model_version : str = None ,
293
325
training_dataset : str = None ,
294
326
assessment_dataset : str = None ,
295
327
):
@@ -311,15 +343,21 @@ def set_artifacts(
311
343
"""
312
344
313
345
global_logger .info (
314
- f"Adding model ({ model } ) to governance. Model has tags: { model_tags } "
346
+ f"Adding model ({ model } ) to governance. Model has tags: { model_tags } and version: { model_version } "
315
347
)
316
- prepared_model = {"name" : model , "tags" : model_tags }
348
+ prepared_model = {
349
+ "name" : model ,
350
+ "tags" : model_tags ,
351
+ "model_version" : model_version ,
352
+ }
317
353
if training_dataset :
318
354
prepared_model ["training_dataset_name" ] = training_dataset
319
355
if assessment_dataset :
320
356
prepared_model ["assessment_dataset_name" ] = assessment_dataset
321
357
self ._model = prepared_model
322
358
359
+ self ._print_model_changes_log ()
360
+
323
361
def set_evidence (self , evidences : List [Evidence ]):
324
362
"""
325
363
Update evidences
@@ -348,6 +386,9 @@ def _api_export(self):
348
386
f"Uploading { len (self ._evidences )} evidences.. for use_case_id={ self ._use_case_id } policy_pack_id={ self ._policy_pack_id } "
349
387
)
350
388
389
+ # update when model tags are changed
390
+ self .apply_model_changes ()
391
+
351
392
assessment = self ._api .create_assessment (
352
393
self ._use_case_id , self ._prepare_export_data ()
353
394
)
@@ -379,6 +420,32 @@ def _api_export(self):
379
420
error = assessment ["error" ]
380
421
global_logger .error (f"Error in uploading evidences : { error } " )
381
422
423
+ def _print_model_changes_log (self ):
424
+ # find model_link with model name from assessment plan
425
+ plan_model = self ._find_plan_model ()
426
+ if plan_model is None :
427
+ return
428
+
429
+ model_info = self .get_model_info ()
430
+ plan_model_info = self ._get_model_info (plan_model )
431
+ match = True
432
+ for key in model_info .keys ():
433
+ model_value = model_info [key ]
434
+ plan_model_value = plan_model [key ]
435
+ if model_value != plan_model_value :
436
+ match = False
437
+ global_logger .info (
438
+ f"Platform model and local model { key } do not match. Platform { key } : { plan_model_value } , Local { key } : { model_value } "
439
+ )
440
+ if not match :
441
+ global_logger .info (
442
+ """
443
+ You can apply changes to governance by calling the following method:
444
+ gov.apply_model_changes()
445
+ Alternatively, calling gov.export() method will automatically apply changes to governance.
446
+ """
447
+ )
448
+
382
449
def _check_inclusion (self , label , evidence ):
383
450
matching_evidence = []
384
451
for e in evidence :
@@ -408,6 +475,31 @@ def _file_export(self, filename):
408
475
with open (filename , "w" ) as f :
409
476
f .write (data )
410
477
478
+ def _find_plan_model (self ):
479
+ """Return model from assessment plan who matches name of associated model"""
480
+ if self .model is None or self ._plan is None :
481
+ return None
482
+
483
+ model_name = self .model .get ("name" , None )
484
+ if model_name is None :
485
+ return None
486
+
487
+ for link in self ._plan .get ("model_links" , []):
488
+ if link ["model_name" ] == model_name :
489
+ return link
490
+
491
+ return None
492
+
493
+ def _get_model_info (self , model ):
494
+ """Get the tags and version for a model"""
495
+ if model :
496
+ return {
497
+ "tags" : model .get ("tags" , {}),
498
+ "model_version" : model .get ("model_version" , None ),
499
+ }
500
+ else :
501
+ return {"tags" : {}, "model_version" : None }
502
+
411
503
def _match_requirements (self ):
412
504
missing = []
413
505
required_labels = [e .label for e in self .get_evidence_requirements ()]
0 commit comments