diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index d2267790..c858ea51 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -175,8 +175,7 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di self._transformer = None self._data_sampler = None self._generator = None - - self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Distriminator Loss']) + self.loss_values = None @staticmethod def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): @@ -423,8 +422,8 @@ def fit(self, train_data, discrete_columns=(), epochs=None): loss_g.backward() optimizerG.step() - generator_loss = loss_g.detach().cpu() - discriminator_loss = loss_d.detach().cpu() + generator_loss = loss_g.detach().cpu().item() + discriminator_loss = loss_d.detach().cpu().item() epoch_loss_df = pd.DataFrame({ 'Epoch': [i],