1111 ColumnPairTrends as SingleTableColumnPairTrends ,
1212)
1313from sdmetrics .reports .utils import PlotConfig
14+ from sdmetrics .utils import _cast_to_iterable
1415
1516
1617class InterTableTrends (BaseMultiTableProperty ):
@@ -50,16 +51,16 @@ def _denormalize_tables(self, real_data, synthetic_data, relationship):
5051 """
5152 parent = relationship ['parent_table_name' ]
5253 child = relationship ['child_table_name' ]
53- foreign_key = relationship ['child_foreign_key' ]
54- primary_key = relationship ['parent_primary_key' ]
54+ foreign_key = _cast_to_iterable ( relationship ['child_foreign_key' ])
55+ primary_key = _cast_to_iterable ( relationship ['parent_primary_key' ])
5556
5657 real_parent = real_data [parent ].add_prefix (f'{ parent } .' )
5758 real_child = real_data [child ].add_prefix (f'{ child } .' )
5859 synthetic_parent = synthetic_data [parent ].add_prefix (f'{ parent } .' )
5960 synthetic_child = synthetic_data [child ].add_prefix (f'{ child } .' )
6061
61- child_index = f'{ child } .{ foreign_key } '
62- parent_index = f'{ parent } .{ primary_key } '
62+ child_index = [ f'{ child } .{ key_col } ' for key_col in foreign_key ]
63+ parent_index = [ f'{ parent } .{ key_col } ' for key_col in primary_key ]
6364
6465 denormalized_real = real_child .merge (
6566 real_parent , left_on = child_index , right_on = parent_index
@@ -101,7 +102,12 @@ def _merge_metadata(self, metadata, parent_table, child_table):
101102 merged_metadata ['columns' ] = {** child_cols , ** parent_cols }
102103 if 'primary_key' in merged_metadata :
103104 primary_key = merged_metadata ['primary_key' ]
104- merged_metadata ['primary_key' ] = f'{ child_table } .{ primary_key } '
105+ if isinstance (primary_key , list ):
106+ merged_metadata ['primary_key' ] = [
107+ f'{ child_table } .{ pk_col } ' for pk_col in primary_key
108+ ]
109+ else :
110+ merged_metadata ['primary_key' ] = f'{ child_table } .{ primary_key } '
105111
106112 return merged_metadata , list (parent_cols .keys ()), list (child_cols .keys ())
107113
@@ -123,6 +129,7 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
123129 parent = relationship ['parent_table_name' ]
124130 child = relationship ['child_table_name' ]
125131 foreign_key = relationship ['child_foreign_key' ]
132+ fk_tuple = tuple (foreign_key ) if isinstance (foreign_key , list ) else foreign_key
126133
127134 denormalized_real , denormalized_synthetic = self ._denormalize_tables (
128135 real_data , synthetic_data , relationship
@@ -132,14 +139,14 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
132139
133140 parent_child_pairs = itertools .product (parent_cols , child_cols )
134141
135- self ._properties [(parent , child , foreign_key )] = SingleTableColumnPairTrends ()
142+ self ._properties [(parent , child , fk_tuple )] = SingleTableColumnPairTrends ()
136143 self ._properties [
137- (parent , child , foreign_key )
144+ (parent , child , fk_tuple )
138145 ].real_correlation_threshold = self .real_correlation_threshold
139146 self ._properties [
140- (parent , child , foreign_key )
147+ (parent , child , fk_tuple )
141148 ].real_association_threshold = self .real_association_threshold
142- details = self ._properties [(parent , child , foreign_key )]._generate_details (
149+ details = self ._properties [(parent , child , fk_tuple )]._generate_details (
143150 denormalized_real ,
144151 denormalized_synthetic ,
145152 merged_metadata ,
@@ -149,7 +156,7 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
149156
150157 details ['Parent Table' ] = parent
151158 details ['Child Table' ] = child
152- details ['Foreign Key' ] = foreign_key
159+ details ['Foreign Key' ] = str ( foreign_key )
153160 if not details .empty :
154161 details ['Column 1' ] = details ['Column 1' ].str .replace (
155162 f'{ parent } .' , '' , n = 1 , regex = False
@@ -233,18 +240,15 @@ def _compute_average_score(self, to_plot):
233240 def get_visualization (self , table_name = None ):
234241 """Create a plot to show the inter table trends data.
235242
236- Returns:
237- plotly.graph_objects._figure.Figure
238-
239243 Args:
240244 table_name (str, optional):
241245 Table to plot. Defaults to None.
242246
243- Raises:
244- - ``ValueError`` if property has not been computed.
245-
246247 Returns:
247248 plotly.graph_objects._figure.Figure
249+
250+ Raises:
251+ - ``ValueError`` if property has not been computed.
248252 """
249253 if not self .is_computed :
250254 raise ValueError (
0 commit comments