
Commit

Store loss values as float values not PyTorch objects (#337)
fealho authored Mar 1, 2024
1 parent d6ce31b commit 58686c5
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions ctgan/synthesizers/ctgan.py
@@ -175,8 +175,7 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
         self._transformer = None
         self._data_sampler = None
         self._generator = None
-
-        self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Distriminator Loss'])
+        self.loss_values = None
 
     @staticmethod
     def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
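With this change, loss_values starts out as None instead of an eagerly built empty DataFrame, and the table is assembled during training from per-epoch rows such as the epoch_loss_df visible in the second hunk below. A minimal sketch of one way that lazy construction can look; the epoch count and loss numbers are placeholders, since the full fit() body is truncated in this diff:

import pandas as pd

# Sketch only: a None sentinel plus per-epoch rows. The real fit() body
# is truncated in this diff, so treat names and values as stand-ins.
loss_values = None
for i in range(3):  # hypothetical epoch count
    epoch_loss_df = pd.DataFrame({
        'Epoch': [i],
        'Generator Loss': [0.5 - 0.1 * i],       # placeholder float
        'Discriminator Loss': [0.4 + 0.05 * i],  # placeholder float
    })
    # The first epoch starts the table; later epochs append one row each.
    loss_values = (
        epoch_loss_df if loss_values is None
        else pd.concat([loss_values, epoch_loss_df], ignore_index=True)
    )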
@@ -423,8 +422,8 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                 loss_g.backward()
                 optimizerG.step()
 
-            generator_loss = loss_g.detach().cpu()
-            discriminator_loss = loss_d.detach().cpu()
+            generator_loss = loss_g.detach().cpu().item()
+            discriminator_loss = loss_d.detach().cpu().item()
 
             epoch_loss_df = pd.DataFrame({
                 'Epoch': [i],
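The core of the fix is the trailing .item() call: .detach().cpu() still returns a zero-dimensional torch.Tensor, while .item() unwraps it into a plain Python float. A self-contained sketch (not part of the commit) of the difference:

import torch

loss_g = torch.tensor(0.7321, requires_grad=True)  # stand-in for a real loss

as_tensor = loss_g.detach().cpu()        # still a 0-dim torch.Tensor
as_float = loss_g.detach().cpu().item()  # a plain Python float

print(type(as_tensor))  # <class 'torch.Tensor'>
print(type(as_float))   # <class 'float'>

Storing plain floats keeps the loss_values DataFrame columns numeric, so the recorded losses can be aggregated and plotted directly instead of carrying tensor objects around.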
