From 89e1c2328ee5e37ea763e6b1a68b770966829b03 Mon Sep 17 00:00:00 2001 From: Roy Wedge Date: Thu, 15 Aug 2024 10:08:26 -0400 Subject: [PATCH] Support null foreign keys in HMA Synthesizer (#2124) Co-authored-by: Frances Hartwell Co-authored-by: Gaurav Sheni --- sdv/multi_table/hma.py | 86 ++++++++++++++----- sdv/multi_table/utils.py | 7 +- sdv/sampling/hierarchical_sampler.py | 33 +++++-- tests/integration/multi_table/test_hma.py | 52 ++++++++--- tests/integration/utils/test_poc.py | 2 +- tests/unit/multi_table/test_hma.py | 4 + tests/unit/multi_table/test_utils.py | 2 +- .../sampling/test_hierarchical_sampler.py | 10 ++- 8 files changed, 149 insertions(+), 47 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 5df4dc358..5de322ac4 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -158,6 +158,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): self._table_sizes = {} self._max_child_rows = {} self._min_child_rows = {} + self._null_child_synthesizers = {} self._augmented_tables = [] self._learned_relationships = 0 self._default_parameters = {} @@ -310,10 +311,17 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc pbar_args = self._get_pbar_args(desc=progress_bar_desc) for foreign_key_value in tqdm(foreign_key_values, **pbar_args): - child_rows = child_table.loc[[foreign_key_value]] + try: + child_rows = child_table.loc[[foreign_key_value]] + except KeyError: + # pre pandas 2.1 df.loc[[np.nan]] causes error + if pd.isna(foreign_key_value): + child_rows = child_table[child_table.index.isna()] + else: + raise child_rows = child_rows[child_rows.columns.difference(foreign_key_columns)] try: - if child_rows.empty: + if child_rows.empty and not pd.isna(foreign_key_value): row = pd.Series({'num_rows': len(child_rows)}) row.index = f'__{child_name}__{foreign_key}__' + row.index else: @@ -324,19 +332,26 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc self._set_extended_columns_distributions( synthesizer, child_name, child_rows.columns ) - synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) - row = synthesizer._get_parameters() - row = pd.Series(row) - row.index = f'__{child_name}__{foreign_key}__' + row.index - - if scale_columns is None: - scale_columns = [column for column in row.index if column.endswith('scale')] - - if len(child_rows) == 1: - row.loc[scale_columns] = None - - extension_rows.append(row) - index.append(foreign_key_value) + if not child_rows.empty: + synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) + row = synthesizer._get_parameters() + row = pd.Series(row) + row.index = f'__{child_name}__{foreign_key}__' + row.index + + if not pd.isna(foreign_key_value): + if scale_columns is None: + scale_columns = [ + column for column in row.index if column.endswith('scale') + ] + + if len(child_rows) == 1: + row.loc[scale_columns] = None + + if pd.isna(foreign_key_value): + self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] = synthesizer + else: + extension_rows.append(row) + index.append(foreign_key_value) except Exception: # Skip children rows subsets that fail pass @@ -344,8 +359,11 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc return pd.DataFrame(extension_rows, index=index) @staticmethod - def _clear_nans(table_data): - for column in table_data.columns: + def _clear_nans(table_data, ignore_cols=None): + columns = set(table_data.columns) + if ignore_cols is not None: + columns = columns - set(ignore_cols) + for column in columns: column_data = table_data[column] if column_data.dtype in (int, float): fill_value = 0 if column_data.isna().all() else column_data.mean() @@ -405,6 +423,9 @@ def _augment_table(self, table, tables, table_name): table[num_rows_key] = table[num_rows_key].fillna(0) self._max_child_rows[num_rows_key] = table[num_rows_key].max() self._min_child_rows[num_rows_key] = table[num_rows_key].min() + self._null_foreign_key_percentages[f'__{child_name}__{foreign_key}'] = 1 - ( + table[num_rows_key].sum() / child_table.shape[0] + ) if len(extension.columns) > 0: self._parent_extended_columns[table_name].extend(list(extension.columns)) @@ -412,7 +433,9 @@ def _augment_table(self, table, tables, table_name): tables[table_name] = table self._learned_relationships += 1 self._augmented_tables.append(table_name) - self._clear_nans(table) + + foreign_keys = self.metadata._get_all_foreign_keys(table_name) + self._clear_nans(table, ignore_cols=foreign_keys) return table @@ -525,12 +548,17 @@ def _extract_parameters(self, parent_row, table_name, foreign_key): def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): # A child table is created based on only one foreign key. foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0] - parameters = self._extract_parameters(parent_row, child_name, foreign_key) - default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) - table_meta = self.metadata.tables[child_name] - synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) - synthesizer._set_parameters(parameters, default_parameters) + if parent_row is not None: + parameters = self._extract_parameters(parent_row, child_name, foreign_key) + default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) + + table_meta = self.metadata.tables[child_name] + synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) + synthesizer._set_parameters(parameters, default_parameters) + else: + synthesizer = self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] + synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor return synthesizer @@ -580,6 +608,9 @@ def _find_parent_id(likelihoods, num_rows): candidates.append(parent) candidate_weights.append(weight) + # cast candidates to series to ensure np.random.choice uses desired dtype + candidates = pd.Series(candidates, dtype=likelihoods.index.dtype) + # All available candidates were assigned 0 likelihood of being the parent id if sum(candidate_weights) == 0: chosen_parent = np.random.choice(candidates) @@ -629,6 +660,14 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): except (AttributeError, np.linalg.LinAlgError): likelihoods[parent_id] = None + null_child_synths = getattr(self, '_null_child_synthesizers', {}) + if f'__{table_name}__{foreign_key}' in null_child_synths: + try: + likelihoods[np.nan] = synthesizer._get_likelihood(table_rows) + + except (AttributeError, np.linalg.LinAlgError): + likelihoods[np.nan] = None + return pd.DataFrame(likelihoods, index=table_rows.index) def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key): @@ -657,6 +696,7 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f primary_key = self.metadata.tables[parent_name].primary_key parent_table = parent_table.set_index(primary_key) num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].copy() + num_rows.loc[np.nan] = child_table.shape[0] - num_rows.sum() likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key) return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows) diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index 60fd71053..069d9d629 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -461,9 +461,10 @@ def _subsample_disconnected_roots(data, metadata, table, ratio_to_keep, drop_mis def _subsample_table_and_descendants(data, metadata, table, num_rows, drop_missing_values): """Subsample the table and its descendants. - The logic is to first subsample all the NaN foreign keys of the table when ``drop_missing_values`` - is True. We raise an error if we cannot reach referential integrity while keeping - the number of rows. Then, we drop rows of the descendants to ensure referential integrity. + The logic is to first subsample all the NaN foreign keys of the table when + ``drop_missing_values`` is True. We raise an error if we cannot reach referential integrity + while keeping the number of rows. Then, we drop rows of the descendants to ensure referential + integrity. Args: data (dict): diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 641b6ccbf..0d5bf855d 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -3,6 +3,7 @@ import logging import warnings +import numpy as np import pandas as pd LOGGER = logging.getLogger(__name__) @@ -24,6 +25,7 @@ class BaseHierarchicalSampler: def __init__(self, metadata, table_synthesizers, table_sizes): self.metadata = metadata + self._null_foreign_key_percentages = {} self._table_synthesizers = table_synthesizers self._table_sizes = table_sizes @@ -103,7 +105,9 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num row_indices = sampled_rows.index sampled_rows[foreign_key].iloc[row_indices] = parent_row[parent_key] else: - sampled_rows[foreign_key] = parent_row[parent_key] + sampled_rows[foreign_key] = ( + parent_row[parent_key] if parent_row is not None else np.nan + ) previous = sampled_data.get(child_name) if previous is None: @@ -143,16 +147,19 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): """ total_num_rows = round(self._table_sizes[child_name] * scale) for foreign_key in self.metadata._get_foreign_keys(table_name, child_name): + null_fk_pctgs = getattr(self, '_null_foreign_key_percentages', {}) + null_fk_pctg = null_fk_pctgs.get(f'__{child_name}__{foreign_key}', 0) + total_parent_rows = round(total_num_rows * (1 - null_fk_pctg)) num_rows_key = f'__{child_name}__{foreign_key}__num_rows' min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key] max_rows = self._max_child_rows[num_rows_key] key_data = sampled_data[table_name][num_rows_key].fillna(0).round() sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows).astype(int) - while sum(sampled_data[table_name][num_rows_key]) != total_num_rows: + while sum(sampled_data[table_name][num_rows_key]) != total_parent_rows: num_rows_column = sampled_data[table_name][num_rows_key].argsort() - if sum(sampled_data[table_name][num_rows_key]) < total_num_rows: + if sum(sampled_data[table_name][num_rows_key]) < total_parent_rows: for i in num_rows_column: # If the number of rows is already at the maximum, skip # The exception is when the smallest value is already at the maximum, @@ -164,7 +171,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): break sampled_data[table_name].loc[i, num_rows_key] += 1 - if sum(sampled_data[table_name][num_rows_key]) == total_num_rows: + if sum(sampled_data[table_name][num_rows_key]) == total_parent_rows: break else: @@ -179,7 +186,7 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data): break sampled_data[table_name].loc[i, num_rows_key] -= 1 - if sum(sampled_data[table_name][num_rows_key]) == total_num_rows: + if sum(sampled_data[table_name][num_rows_key]) == total_parent_rows: break def _sample_children(self, table_name, sampled_data, scale=1.0): @@ -207,8 +214,9 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): sampled_data=sampled_data, ) + foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0] + if child_name not in sampled_data: # No child rows sampled, force row creation - foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0] num_rows_key = f'__{child_name}__{foreign_key}__num_rows' max_num_child_index = sampled_data[table_name][num_rows_key].idxmax() parent_row = sampled_data[table_name].iloc[max_num_child_index] @@ -221,6 +229,19 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): num_rows=1, ) + total_num_rows = round(self._table_sizes[child_name] * scale) + null_fk_pctgs = getattr(self, '_null_foreign_key_percentages', {}) + null_fk_pctg = null_fk_pctgs.get(f'__{child_name}__{foreign_key}', 0) + num_null_rows = round(total_num_rows * null_fk_pctg) + if num_null_rows > 0: + self._add_child_rows( + child_name=child_name, + parent_name=table_name, + parent_row=None, + sampled_data=sampled_data, + num_rows=num_null_rows, + ) + self._sample_children(table_name=child_name, sampled_data=sampled_data, scale=scale) def _finalize(self, sampled_data): diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 8482c5ed4..133a53681 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1319,49 +1319,68 @@ def test_null_foreign_keys(self): """Test that the synthesizer does not crash when there are null foreign keys.""" # Setup metadata = MultiTableMetadata() - metadata.add_table('parent_table') - metadata.add_column('parent_table', 'id', sdtype='id') - metadata.set_primary_key('parent_table', 'id') + metadata.add_table('parent_table1') + metadata.add_column('parent_table1', 'id', sdtype='id') + metadata.set_primary_key('parent_table1', 'id') + + metadata.add_table('parent_table2') + metadata.add_column('parent_table2', 'id', sdtype='id') + metadata.set_primary_key('parent_table2', 'id') metadata.add_table('child_table1') metadata.add_column('child_table1', 'id', sdtype='id') metadata.set_primary_key('child_table1', 'id') - metadata.add_column('child_table1', 'fk', sdtype='id') + metadata.add_column('child_table1', 'fk1', sdtype='id') + metadata.add_column('child_table1', 'fk2', sdtype='id') metadata.add_table('child_table2') metadata.add_column('child_table2', 'id', sdtype='id') metadata.set_primary_key('child_table2', 'id') metadata.add_column('child_table2', 'fk1', sdtype='id') metadata.add_column('child_table2', 'fk2', sdtype='id') + metadata.add_column('child_table2', 'cat_type', sdtype='categorical') metadata.add_relationship( - parent_table_name='parent_table', + parent_table_name='parent_table1', child_table_name='child_table1', parent_primary_key='id', - child_foreign_key='fk', + child_foreign_key='fk1', ) metadata.add_relationship( - parent_table_name='parent_table', + parent_table_name='parent_table2', + child_table_name='child_table1', + parent_primary_key='id', + child_foreign_key='fk2', + ) + + metadata.add_relationship( + parent_table_name='parent_table1', child_table_name='child_table2', parent_primary_key='id', child_foreign_key='fk1', ) metadata.add_relationship( - parent_table_name='parent_table', + parent_table_name='parent_table1', child_table_name='child_table2', parent_primary_key='id', child_foreign_key='fk2', ) data = { - 'parent_table': pd.DataFrame({'id': [1, 2, 3]}), - 'child_table1': pd.DataFrame({'id': [1, 2, 3], 'fk': [1, 2, np.nan]}), + 'parent_table1': pd.DataFrame({'id': [1, 2, 3]}), + 'parent_table2': pd.DataFrame({'id': ['alpha', 'beta', 'gamma']}), + 'child_table1': pd.DataFrame({ + 'id': [1, 2, 3], + 'fk1': pd.Series([np.nan, 2, np.nan], dtype='float64'), + 'fk2': pd.Series(['alpha', 'beta', np.nan], dtype='object'), + }), 'child_table2': pd.DataFrame({ 'id': [1, 2, 3], 'fk1': [1, 2, np.nan], - 'fk2': [1, 2, np.nan], + 'fk2': pd.Series([1, np.nan, np.nan], dtype='float64'), + 'cat_type': pd.Series(['siamese', 'persian', 'american shorthair'], dtype='object'), }), } @@ -1371,8 +1390,17 @@ def test_null_foreign_keys(self): metadata.validate() metadata.validate_data(data) - # Run and Assert + # Run synthesizer.fit(data) + sampled = synthesizer.sample() + + # Assert + assert len(sampled['parent_table1']) == 3 + assert len(sampled['parent_table2']) == 3 + assert sum(pd.isna(sampled['child_table1']['fk1'])) == 2 + assert sum(pd.isna(sampled['child_table1']['fk2'])) == 1 + assert sum(pd.isna(sampled['child_table2']['fk1'])) == 1 + assert sum(pd.isna(sampled['child_table2']['fk2'])) == 2 def test_sampling_with_unknown_sdtype_numerical_column(self): """Test that if a numerical column is detected as unknown in the metadata, diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index 0a3e02135..b3dfcffdc 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -242,4 +242,4 @@ def test_get_random_subset_with_missing_values(metadata, data): # Assert assert len(result['child']) == 3 - assert result['child']['parent_id'].isnull().sum() > 0 + assert result['child']['parent_id'].isna().sum() > 0 diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 0db06ea82..ef0085c19 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -511,6 +511,7 @@ def test__get_likelihoods(self): instance._table_parameters = {'child_table': {}} instance._extract_parameters = Mock() instance._synthesizer.return_value._get_likelihood.return_value = likelihoods + instance._null_child_synthesizers = {} # Run result = HMASynthesizer._get_likelihoods( @@ -550,6 +551,7 @@ def test__get_likelihoods_attribute_error(self): instance._table_synthesizers = {'child_table': child_synthesizer} instance._table_parameters = {'child_table': {}} instance._extract_parameters = Mock() + instance._null_child_synthesizers = {} instance._synthesizer.return_value._get_likelihood.side_effect = [ likelihoods, AttributeError(), @@ -594,6 +596,7 @@ def test__get_likelihoods_linalg_error(self): instance._table_synthesizers = {'child_table': child_synthesizer} instance._table_parameters = {'child_table': {}} instance._extract_parameters = Mock() + instance._null_child_synthesizers = {} instance._synthesizer.return_value._get_likelihood.side_effect = [ likelihoods, np.linalg.LinAlgError(), @@ -639,6 +642,7 @@ def test_get_likelihoods_filters_over_existing_columns(self, mock_concat): instance._table_synthesizers = {'child_table': child_synthesizer} instance._table_parameters = {'child_table': {}} instance._extract_parameters = Mock() + instance._null_child_synthesizers = {} likelihoods = np.array([0.1, 0.2, 0.3, 0.4]) instance._synthesizer.return_value._get_likelihood.return_value = likelihoods diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index cbea576b4..e01195f89 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -1990,7 +1990,7 @@ def test__subsample_data_with_null_foreing_keys(): # Assert assert len(result_with_nan['child']) == 4 - assert result_with_nan['child']['parent_id'].isnull().sum() > 0 + assert result_with_nan['child']['parent_id'].isna().sum() > 0 assert len(result_without_nan['child']) == 2 assert set(result_without_nan['child'].index) == {0, 1} diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py index 006b14ffd..de07a2634 100644 --- a/tests/unit/sampling/test_hierarchical_sampler.py +++ b/tests/unit/sampling/test_hierarchical_sampler.py @@ -177,7 +177,7 @@ def sample_children(table_name, sampled_data, scale): 'session_id': ['a', 'a', 'b'], }) - def _add_child_rows(child_name, parent_name, parent_row, sampled_data): + def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows=None): if parent_name == 'users': if parent_row['user_id'] == 1: sampled_data[child_name] = pd.DataFrame({ @@ -202,10 +202,13 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data): instance = Mock() instance.metadata._get_child_map.return_value = {'users': ['sessions', 'transactions']} instance.metadata._get_parent_map.return_value = {'users': []} + instance.metadata._get_foreign_keys.return_value = ['user_id'] instance._table_sizes = {'users': 10, 'sessions': 5, 'transactions': 3} instance._table_synthesizers = {'users': Mock()} instance._sample_children = sample_children instance._add_child_rows.side_effect = _add_child_rows + instance._null_child_synthesizers = {} + instance._null_foreign_key_percentages = {'__sessions__user_id': 0} # Run result = {'users': pd.DataFrame({'user_id': [1, 3]})} @@ -271,6 +274,7 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows= instance._table_synthesizers = {'users': Mock()} instance._sample_children = sample_children instance._add_child_rows.side_effect = _add_child_rows + instance._null_foreign_key_percentages = {'__sessions__user_id': 0} # Run result = {'users': pd.DataFrame({'user_id': [1], '__sessions__user_id__num_rows': [1]})} @@ -561,6 +565,7 @@ def test___enforce_table_size_too_many_rows(self): instance._min_child_rows = {'__child__fk__num_rows': 1} instance._max_child_rows = {'__child__fk__num_rows': 3} instance._table_sizes = {'child': 4} + instance._null_foreign_key_percentages = {'__child__fk': 0} # Run BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 1.0, data) @@ -580,6 +585,7 @@ def test___enforce_table_size_not_enough_rows(self): instance._min_child_rows = {'__child__fk__num_rows': 1} instance._max_child_rows = {'__child__fk__num_rows': 3} instance._table_sizes = {'child': 4} + instance._null_foreign_key_percentages = {'__child__fk': 0} # Run BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 1.0, data) @@ -599,6 +605,7 @@ def test___enforce_table_size_clipping(self): instance._min_child_rows = {'__child__fk__num_rows': 2} instance._max_child_rows = {'__child__fk__num_rows': 4} instance._table_sizes = {'child': 8} + instance._null_foreign_key_percentages = {'__child__fk': 0} # Run BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 1.0, data) @@ -618,6 +625,7 @@ def test___enforce_table_size_too_small_sample(self): instance._min_child_rows = {'__child__fk__num_rows': 1} instance._max_child_rows = {'__child__fk__num_rows': 3} instance._table_sizes = {'child': 4} + instance._null_foreign_key_percentages = {'__child__fk': 0} # Run BaseHierarchicalSampler._enforce_table_size(instance, 'child', 'parent', 0.001, data)