@@ -1074,6 +1074,12 @@ def _patch_attributes_model(cls):
10741074 cls .AttributesModel = AttributesModel # type: ignore[misc]
10751075 cls .AttributesModel .model_rebuild (force = True )
10761076
1077+ @classmethod
1078+ def _get_patched_node_type_field (cls ):
1079+ """Return a copy of the `node_type` field cast as the literal type for this class."""
1080+ node_type_field = deepcopy (cls .BaseNodeModel .model_fields ['node_type' ])
1081+ return (Literal [cls .class_node_type ], node_type_field )
1082+
10771083 @classmethod
10781084 def _patch_read_model (cls ):
10791085 """Patch `ReadModel` by wiring the subclass-specific `attributes` model.
@@ -1107,9 +1113,8 @@ def _patch_read_model(cls):
11071113 }
11081114
11091115 attributes_field = deepcopy (cls .ReadModel .model_fields ['attributes' ])
1110- node_type_field = deepcopy (cls .BaseNodeModel .model_fields ['node_type' ])
1111- model_fields ['node_type' ] = (Literal [cls .class_node_type ], node_type_field )
11121116 model_fields ['attributes' ] = (cls .AttributesModel , attributes_field )
1117+ model_fields ['node_type' ] = cls ._get_patched_node_type_field ()
11131118
11141119 ReadModel = cast ( # noqa: N806
11151120 type [Node .ReadModel ],
@@ -1130,23 +1135,28 @@ def _patch_constructor_model(cls):
11301135 """Patch `ConstructorModel` by synthesizing it from `BaseNodeModel` and `ConstructorArgsModel`."""
11311136 if not cls .supports_constructor_model :
11321137 return
1133- node_type_field = deepcopy ( cls . BaseNodeModel . model_fields [ 'node_type' ])
1138+
11341139 args_field = OrmMetadataField (
11351140 description = 'The constructor arguments.' ,
11361141 write_only = True ,
11371142 )
1143+ model_fields : dict [str , Any ] = {
1144+ 'args' : (cls .ConstructorArgsModel , args_field ),
1145+ 'node_type' : cls ._get_patched_node_type_field (),
1146+ }
1147+
11381148 ConstructorModel = cast ( # noqa: N806
11391149 type [Node .BaseNodeModel ],
11401150 pdt .create_model (
11411151 'ConstructorModel' ,
11421152 __base__ = cls .BaseNodeModel ,
11431153 __module__ = cls .__module__ ,
1144- node_type = (Literal [cls .class_node_type ], node_type_field ),
1145- args = (cls .ConstructorArgsModel , args_field ),
1154+ ** model_fields ,
11461155 ),
11471156 )
11481157 ConstructorModel .__qualname__ = f'{ cls .__name__ } .ConstructorModel'
11491158 ConstructorModel .model_rebuild (force = True )
1159+
11501160 cls ._ConstructorModel = ConstructorModel
11511161
11521162 @classmethod
0 commit comments