Skip to content

Commit

Permalink
DROP: RHOAIENG-3771 - Reduce execution time of E2E tests
Browse files Browse the repository at this point in the history
By reducing number of epochs and number of training
samples in each epoch it was possible to reduce
test execution time from more than 10 minutes to
less than 2 minutes.
  • Loading branch information
jiripetrlik authored and sutaakar committed Mar 13, 2024
1 parent d757a96 commit 399c7fc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
7 changes: 4 additions & 3 deletions test/e2e/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torch.utils.data import DataLoader, random_split, RandomSampler
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
Expand Down Expand Up @@ -158,7 +158,7 @@ def setup(self, stage=None):
)

def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000))

def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
Expand All @@ -178,10 +178,11 @@ def test_dataloader(self):
trainer = Trainer(
accelerator="auto",
# devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
max_epochs=5,
max_epochs=3,
callbacks=[TQDMProgressBar(refresh_rate=20)],
num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
replace_sampler_ddp=False,
strategy="ddp",
)

Expand Down
7 changes: 4 additions & 3 deletions test/odh/resources/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torch.utils.data import DataLoader, random_split, RandomSampler
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
Expand Down Expand Up @@ -158,7 +158,7 @@ def setup(self, stage=None):
)

def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000))

def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
Expand All @@ -178,10 +178,11 @@ def test_dataloader(self):
trainer = Trainer(
accelerator="auto",
# devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
max_epochs=2,
max_epochs=3,
callbacks=[TQDMProgressBar(refresh_rate=20)],
num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
replace_sampler_ddp=False,
strategy="ddp",
)

Expand Down

0 comments on commit 399c7fc

Please sign in to comment.