Skip to content

Commit

Permalink
CTGAN error during fit if continuous training data contains null valu…
Browse files Browse the repository at this point in the history
…es (#428)
  • Loading branch information
rwedge authored Jan 15, 2025
1 parent 6ed1f19 commit 42ca6f3
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ctgan/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Custom errors for CTGAN."""


class InvalidDataError(Exception):
"""Error to raise when data is not valid."""
27 changes: 27 additions & 0 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ctgan.data_sampler import DataSampler
from ctgan.data_transformer import DataTransformer
from ctgan.errors import InvalidDataError
from ctgan.synthesizers.base import BaseSynthesizer, random_state


Expand Down Expand Up @@ -289,6 +290,31 @@ def _validate_discrete_columns(self, train_data, discrete_columns):
if invalid_columns:
raise ValueError(f'Invalid columns found: {invalid_columns}')

def _validate_null_data(self, train_data, discrete_columns):
"""Check whether null values exist in continuous ``train_data``.
Args:
train_data (numpy.ndarray or pandas.DataFrame):
Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
discrete_columns (list-like):
List of discrete columns to be used to generate the Conditional
Vector. If ``train_data`` is a Numpy array, this list should
contain the integer indices of the columns. Otherwise, if it is
a ``pandas.DataFrame``, this list should contain the column names.
"""
if isinstance(train_data, pd.DataFrame):
continuous_cols = list(set(train_data.columns) - set(discrete_columns))
any_nulls = train_data[continuous_cols].isna().any().any()
else:
continuous_cols = [i for i in range(train_data.shape[1]) if i not in discrete_columns]
any_nulls = pd.DataFrame(train_data)[continuous_cols].isna().any().any()

if any_nulls:
raise InvalidDataError(
'CTGAN does not support null values in the continuous training data. '
'Please remove all null values from your continuous training data.'
)

@random_state
def fit(self, train_data, discrete_columns=(), epochs=None):
"""Fit the CTGAN Synthesizer models to the training data.
Expand All @@ -303,6 +329,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
a ``pandas.DataFrame``, this list should contain the column names.
"""
self._validate_discrete_columns(train_data, discrete_columns)
self._validate_null_data(train_data, discrete_columns)

if epochs is None:
epochs = self._epochs
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/synthesizer/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pandas as pd
import pytest

from ctgan.errors import InvalidDataError
from ctgan.synthesizers.ctgan import CTGAN


Expand Down Expand Up @@ -132,6 +133,25 @@ def test_categorical_nan():
assert {'b', 'c'}.issubset(values)


def test_continuous_nan():
"""Test the CTGAN with missing numerical values."""
# Setup
data = pd.DataFrame({
'continuous': [np.nan, 1.0, 2.0] * 10,
'discrete': ['a', 'b', 'c'] * 10,
})
discrete_columns = ['discrete']
error_message = (
'CTGAN does not support null values in the continuous training data. '
'Please remove all null values from your continuous training data.'
)

# Run and Assert
ctgan = CTGAN(epochs=1)
with pytest.raises(InvalidDataError, match=error_message):
ctgan.fit(data, discrete_columns)


def test_synthesizer_sample():
"""Test the CTGAN samples the correct datatype."""
data = pd.DataFrame({'discrete': np.random.choice(['a', 'b', 'c'], 100)})
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/synthesizer/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from unittest import TestCase
from unittest.mock import Mock

import numpy as np
import pandas as pd
import pytest
import torch

from ctgan.data_transformer import SpanInfo
from ctgan.errors import InvalidDataError
from ctgan.synthesizers.ctgan import CTGAN, Discriminator, Generator, Residual


Expand Down Expand Up @@ -289,3 +291,42 @@ def test__validate_discrete_columns(self):
ctgan = CTGAN(epochs=1)
with pytest.raises(ValueError, match=r'Invalid columns found: {\'doesnt exist\'}'):
ctgan.fit(data, discrete_columns)

def test__validate_null_data(self):
"""Test `_validate_null_data` with pandas and numpy data.
Check the appropriate error is raised if null values are present in
continuous columns, both for numpy arrays and dataframes.
"""
# Setup
discrete_df = pd.DataFrame({'discrete': ['a', 'b']})
discrete_array = np.array([['a'], ['b']])
continuous_no_nulls_df = pd.DataFrame({'continuous': [0, 1]})
continuous_no_nulls_array = np.array([[0], [1]])
continuous_with_null_df = pd.DataFrame({'continuous': [1, np.nan]})
continuous_with_null_array = np.array([[1], [np.nan]])
ctgan = CTGAN(epochs=1)
error_message = (
'CTGAN does not support null values in the continuous training data. '
'Please remove all null values from your continuous training data.'
)

# Test discrete DataFrame fits without error
ctgan.fit(discrete_df, ['discrete'])

# Test discrete array fits without error
ctgan.fit(discrete_array, [0])

# Test continuous DataFrame without nulls fits without error
ctgan.fit(continuous_no_nulls_df)

# Test continuous array without nulls fits without error
ctgan.fit(continuous_no_nulls_array)

# Test nulls in continuous columns DataFrame errors on fit
with pytest.raises(InvalidDataError, match=error_message):
ctgan.fit(continuous_with_null_df)

# Test nulls in continuous columns array errors on fit
with pytest.raises(InvalidDataError, match=error_message):
ctgan.fit(continuous_with_null_array)

0 comments on commit 42ca6f3

Please sign in to comment.