
fix num_micro_batches in train_batch
nuzant committed Sep 2, 2024
1 parent 24a0f7d, commit 54e87b1
Showing 2 changed files with 10 additions and 36 deletions.
realhf/impl/model/backend/deepspeed.py: 23 changes (5 additions, 18 deletions)
@@ -201,34 +201,21 @@ def train_batch(
         version_steps: int,
         num_micro_batches: Optional[int] = None,
     ):
+        if num_micro_batches is None:
+            num_micro_batches = 1
         if constants.pipe_parallel_world_size() > 1:
             # Fusing the minibatched forward-backward in a pipeline training schedule.
-            if num_micro_batches is not None:
-                if (
-                    num_micro_batches < self.pipe_runner.default_train_mbs
-                    and constants.parallelism_rank() == 0
-                ):
-                    logger.warning(
-                        "When training with pipeline parallel, num micro batches should be "
-                        "larger than 2 x num_pipeline_stages to avoid idle time. "
-                        f"Setting num_micro_batches to {self.pipe_runner.default_train_mbs}"
-                    )
-                num_micro_batches = max(
-                    num_micro_batches, self.pipe_runner.default_train_mbs
-                )
-            else:
-                num_micro_batches = self.pipe_runner.default_train_mbs
             instr_set = PipeTrainSetForDeepSpeed(self.ds_engine)
+            # NOTE: When training with pipeline parallel, num micro batches should be
+            # larger than 2 x num_pipeline_stages to avoid idle time.
             return self.pipe_runner.train_batch(
                 instr_set=instr_set,
                 input_=input_,
                 loss_fn=loss_fn,
                 version_steps=version_steps,
-                n_pp_mbs=num_micro_batches,
+                n_pp_mbs=self.pipe_runner.default_train_mbs * num_micro_batches,
             )
         else:
-            if num_micro_batches is None:
-                num_micro_batches = 1
             self.ds_engine._config.gradient_accumulation_steps = num_micro_batches
             self.ds_engine.set_gradient_accumulation_boundary(False)
             if isinstance(
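The behavioral change in this hunk: num_micro_batches is normalized to 1 up front, and in the pipeline-parallel branch it now multiplies the runner's default_train_mbs instead of being clamped to it by max(...). A minimal sketch of the new arithmetic, not part of the commit, with made-up values for num_pipeline_stages and default_train_mbs:

# Illustration only: how n_pp_mbs is derived after this change (all numbers invented).
num_pipeline_stages = 4
default_train_mbs = 2 * num_pipeline_stages  # the "2 x num_pipeline_stages" from the NOTE above

num_micro_batches = 2  # caller argument; None is normalized to 1 at the top of train_batch
n_pp_mbs = default_train_mbs * num_micro_batches
print(n_pp_mbs)  # 16 micro-batches fed into the pipeline schedule

Before this commit the same call produced n_pp_mbs = max(2, 8) = 8, so a caller value below the default was clamped up (with a warning); now the caller value scales the default.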
realhf/impl/model/backend/megatron.py: 23 changes (5 additions, 18 deletions)
@@ -731,35 +731,22 @@ def train_batch(
         num_micro_batches: Optional[int] = None,
     ):
         with megatron_ctx():
+            if num_micro_batches is None:
+                num_micro_batches = 1
             self.engine.zero_grad()
             if constants.pipe_parallel_world_size() > 1:
                 # Fusing the minibatched forward-backward in a pipeline training schedule.
-                if num_micro_batches is not None:
-                    if (
-                        num_micro_batches < self.pipe_runner.default_train_mbs
-                        and constants.parallelism_rank() == 0
-                    ):
-                        logger.warning(
-                            "When training with pipeline parallel, num micro batches should be "
-                            "larger than 2 x num_pipeline_stages to avoid idle time. "
-                            f"Setting num_micro_batches to {self.pipe_runner.default_train_mbs}"
-                        )
-                    num_micro_batches = max(
-                        num_micro_batches, self.pipe_runner.default_train_mbs
-                    )
-                else:
-                    num_micro_batches = self.pipe_runner.default_train_mbs
                 instr_set = PipeTrainInstrSetForMegatron(self.engine, num_micro_batches)
+                # NOTE: When training with pipeline parallel, num micro batches should be
+                # larger than 2 x num_pipeline_stages to avoid idle time.
                 return self.pipe_runner.train_batch(
                     instr_set=instr_set,
                     input_=input_,
                     loss_fn=loss_fn,
                     version_steps=version_steps,
-                    n_pp_mbs=num_micro_batches,
+                    n_pp_mbs=self.pipe_runner.default_train_mbs * num_micro_batches,
                 )
             else:
-                if num_micro_batches is None:
-                    num_micro_batches = 1
                 no_sync_ctx = self.engine.ddp.no_sync()
                 no_sync_ctx.__enter__()
                 stat = collections.defaultdict(int)
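The Megatron backend receives the identical fix, so num_micro_batches now means the same thing on both backends: a multiplier on the pipeline runner's default micro-batch count when pipeline parallelism is enabled, and the number of gradient-accumulation steps otherwise. A simplified, standalone sketch of that dispatch (the helper name and values are illustrative, not the repository's code):

def resolve_micro_batches(num_micro_batches, pipe_world_size, default_train_mbs):
    # Simplified mirror of the normalization at the top of train_batch in both backends.
    if num_micro_batches is None:
        num_micro_batches = 1
    if pipe_world_size > 1:
        # Pipeline path: this many micro-batches go into the pipeline schedule (n_pp_mbs).
        return "n_pp_mbs", default_train_mbs * num_micro_batches
    # Non-pipeline path: accumulate gradients over num_micro_batches steps.
    return "gradient_accumulation_steps", num_micro_batches


print(resolve_micro_batches(None, pipe_world_size=1, default_train_mbs=8))  # ('gradient_accumulation_steps', 1)
print(resolve_micro_batches(2, pipe_world_size=4, default_train_mbs=8))     # ('n_pp_mbs', 16)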
