Skip to content

Commit

Permalink
remove test + update other test
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Apr 8, 2024
1 parent 0d4708a commit 29f6afa
Showing 1 changed file with 18 additions and 37 deletions.
55 changes: 18 additions & 37 deletions tests/integration/synthesizer/test_tvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,10 @@

import numpy as np
import pandas as pd
from sklearn import datasets

from ctgan.synthesizers.tvae import TVAE


def test_tvae(tmpdir, capsys):
    """Check that a fitted ``TVAE`` survives a save/load round trip.

    The model is trained for 10 epochs on the iris dataset (four numeric
    features plus a string ``class`` column declared as discrete), pickled
    to ``tmpdir``, reloaded, and sampled. The sampled frame and the
    recorded loss values are then validated, and the text captured from
    stderr during fitting must contain the final loss value.
    """
    # Setup
    iris_bunch = datasets.load_iris()
    frame = pd.DataFrame(iris_bunch.data, columns=iris_bunch.feature_names)
    class_names = pd.Series(iris_bunch.target).map(iris_bunch.target_names.__getitem__)
    frame['class'] = class_names
    model = TVAE(epochs=10, verbose=True)

    # Run
    model.fit(frame, ['class'])
    # Read the captured stderr right after fitting, before anything else runs.
    progress_text = capsys.readouterr().err

    pickle_path = str(tmpdir / 'test_tvae.pkl')
    model.save(pickle_path)
    restored = TVAE.load(pickle_path)
    synthetic = restored.sample(100)

    # Assert
    assert synthetic.shape == (100, 5)
    assert isinstance(synthetic, pd.DataFrame)
    assert set(synthetic.columns) == set(frame.columns)
    assert set(synthetic.dtypes) == set(frame.dtypes)

    # Loss values are read from the reloaded model, not the original one.
    losses = restored.loss_values
    assert len(losses) == 10
    assert set(losses.columns) == {'Epoch', 'Batch', 'Loss'}
    assert all(losses['Batch'] == 0)

    final_loss = losses['Loss'].iloc[-1]
    assert f'Loss: {round(final_loss, 3):.3f}: 100%' in progress_text


def test_drop_last_false():
"""Test the TVAE predicts the correct values."""
data = pd.DataFrame({
Expand Down Expand Up @@ -122,17 +90,18 @@ def test_fixed_random_seed():
np.testing.assert_array_equal(sampled_0_1, sampled_1_1)


def test_tvae_save(tmpdir):
def test_tvae_save(tmpdir, capsys):
"""Test that the ``TVAE`` model can be saved and loaded."""
# Setup
data = pd.DataFrame({
'continuous': np.random.random(100),
'discrete': np.random.choice(['a', 'b', 'c'], 100)
})
discrete_columns = [1]
discrete_columns = ['discrete']

tvae = TVAE(epochs=1)
tvae.fit(data.to_numpy(), discrete_columns)
tvae = TVAE(epochs=10, verbose=True)
tvae.fit(data, discrete_columns)
captured_out = capsys.readouterr().err
tvae.set_random_state(0)

tvae.sample(100)
Expand All @@ -143,4 +112,16 @@ def test_tvae_save(tmpdir):

# Load
loaded_instance = TVAE.load(str(model_path))
loaded_instance.sample(100)
sampled = loaded_instance.sample(100)

# Assert
assert sampled.shape == (100, 2)
assert isinstance(sampled, pd.DataFrame)
assert set(sampled.columns) == set(data.columns)
assert set(sampled.dtypes) == set(data.dtypes)
loss_values = tvae.loss_values
assert len(loss_values) == 10
assert set(loss_values.columns) == {'Epoch', 'Batch', 'Loss'}
assert all(loss_values['Batch'] == 0)
last_loss_val = loss_values['Loss'].iloc[-1]
assert f'Loss: {round(last_loss_val, 3):.3f}: 100%' in captured_out

0 comments on commit 29f6afa

Please sign in to comment.