Commit

Merge branch 'main' of github.com:openpsi-project/realhf into profile
garrett4wade committed Sep 2, 2024
2 parents 937ce88 + 6c18266 commit db4bc78
Showing 7 changed files with 169 additions and 81 deletions.
80 changes: 80 additions & 0 deletions examples/scripts/local/ppo_minibatched.sh
@@ -0,0 +1,80 @@
# MODEL_FAMILY specifies how the pretrained checkpoint is loaded, e.g., as a LLaMA model or a GPT model.
# You can specify different model families for the SFT and the RW model, but you may then
# need to re-tokenize the sequences.
MODEL_FAMILY=llama

# SFT_MODEL_PATH and RW_MODEL_PATH are the saved SFT and RW checkpoints.
# ReaL saves checkpoints with the same format as HuggingFace,
# so you don't need to convert or split checkpoints explicitly.
# You can also directly use a pre-trained HuggingFace checkpoint, but this
# does not guarantee optimal algorithm performance.
SFT_MODEL_PATH=/lustre/aigc/llm/checkpoints/fw/quickstart-sft/$MODEL_FAMILY-local-manual/default/epoch7epochstep5globalstep50/
RW_MODEL_PATH=/lustre/aigc/llm/checkpoints/fw/quickstart-rw/$MODEL_FAMILY-ray-manual/default/epoch1epochstep10globalstep10/

# Option 1: The experiment runs locally with subprocesses.
MODE=local
# Option 2: The experiment runs in a Ray cluster.
# MODE=ray
# Option 3: The experiment runs in a SLURM + pyxis cluster.
# The slurm mode requires a cluster spec file,
# with CLUSTER_SPEC_PATH set to its path.
# MODE=slurm

# `experiment_name` and `trial_name` can be arbitrary.
# Logs and saved checkpoints will be indexed by them.
EXP_NAME=quickstart-ppo
TRIAL_NAME=$MODEL_FAMILY-$MODE-heuristic

# We use the "heuristic" allocation mode here to automatically determine the parallelism strategy
# for each model function call, i.e., actor generation, critic inference, actor train, etc.
# The number of GPUs is `n_nodes` * `n_gpus_per_node` (not set explicitly here, defaults to 8).
# ReaL will make full use of the available GPUs when designing allocations.
# This does not guarantee optimal throughput, but it is a good starting point.

# The `heuristic` allocation mode is not guaranteed to work with every model configuration.
# For example, if the vocabulary size is an odd number, model parallelism may not work.
# In such cases, use `ppo_manual.sh` to specify the parallelism strategy manually.

# The `ppo` subcommand specifies that this is a PPO experiment.
# `save_freq_steps` is set to `null` to disable saving checkpoints.
# Enable it if you want to save checkpoints.
# The `ppo` option controls the generation and PPO algorithm hyperparameters.
# Note that the performance of PPO is sensitive to the pre-trained model and the hyperparameters.
# It's the user's responsibility to tune them appropriately.
python3 -m realhf.apps.quickstart ppo \
mode=$MODE \
experiment_name=$EXP_NAME \
trial_name=$TRIAL_NAME \
exp_ctrl.total_train_epochs=1 \
exp_ctrl.save_freq_steps=null \
n_nodes=1 \
allocation_mode=heuristic \
actor.type._class=$MODEL_FAMILY \
actor.path=$SFT_MODEL_PATH \
critic.type._class=$MODEL_FAMILY \
critic.type.is_critic=True \
critic.path=$RW_MODEL_PATH \
ref.type._class=$MODEL_FAMILY \
ref.path=$SFT_MODEL_PATH \
rew.type._class=$MODEL_FAMILY \
rew.type.is_critic=True \
rew.path=$RW_MODEL_PATH \
dataset.path=/lustre/fw/datasets/imdb/rl/ppo_prompt.jsonl \
dataset.max_prompt_len=128 \
dataset.train_bs_n_seqs=1024 \
ppo.gen.max_new_tokens=512 \
ppo.gen.min_new_tokens=512 \
ppo.gen.use_cuda_graph=True \
ppo.gen.force_no_logits_mask=True \
ppo.gen.top_p=0.9 ppo.gen.top_k=1000 \
ppo.ppo_n_minibatches=4 \
ppo.kl_ctl=0.1 \
ppo.value_eps_clip=0.2 \
ppo.reward_output_scaling=1.0 \
ppo.adv_norm=True ppo.value_norm=True \
actor_gen.n_mbs=2 \
actor_train.n_mbs=4 \
critic_inf.n_mbs=4 \
critic_train.n_mbs=4 \
rew_inf.n_mbs=2 \
ref_inf.n_mbs=8
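
For orientation, the batch-splitting knobs above compose roughly as follows. This is a back-of-the-envelope sketch only, assuming the training batch is split first into PPO mini-batches and then into per-call micro-batches and ignoring data parallelism; the variable names are hypothetical and simply mirror the flags above.

    # Rough arithmetic for the flags above (illustrative only, not ReaL code).
    train_bs_n_seqs = 1024     # dataset.train_bs_n_seqs: sequences per training batch
    ppo_n_minibatches = 4      # ppo.ppo_n_minibatches: mini-batches per PPO update
    actor_train_n_mbs = 4      # actor_train.n_mbs: micro-batches per actor train call

    seqs_per_ppo_minibatch = train_bs_n_seqs // ppo_n_minibatches            # 256
    seqs_per_train_microbatch = seqs_per_ppo_minibatch // actor_train_n_mbs  # 64
    print(seqs_per_ppo_minibatch, seqs_per_train_microbatch)                 # 256 64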
21 changes: 6 additions & 15 deletions realhf/impl/model/backend/deepspeed.py
@@ -201,30 +201,21 @@ def train_batch(
         version_steps: int,
         num_micro_batches: Optional[int] = None,
     ):
+        if num_micro_batches is None:
+            num_micro_batches = 1
         if constants.pipe_parallel_world_size() > 1:
-            if num_micro_batches is not None:
-                if num_micro_batches < self.pipe_runner.default_train_mbs:
-                    logger.warning(
-                        "When training with pipeline parallel, num micro batches should be "
-                        "larger than 2 x num_pipeline_stages to avoid idle time. "
-                        f"Setting num_micro_batches to {self.pipe_runner.default_train_mbs}"
-                    )
-                num_micro_batches = max(
-                    num_micro_batches, self.pipe_runner.default_train_mbs
-                )
-            else:
-                num_micro_batches = self.pipe_runner.default_train_mbs
+            # Fusing the minibatched forward-backward in a pipeline training schedule.
             instr_set = PipeTrainSetForDeepSpeed(self.ds_engine)
+            # NOTE: When training with pipeline parallel, num micro batches should be
+            # larger than 2 x num_pipeline_stages to avoid idle time.
             return self.pipe_runner.train_batch(
                 instr_set=instr_set,
                 input_=input_,
                 loss_fn=loss_fn,
                 version_steps=version_steps,
-                num_micro_batches=num_micro_batches,
+                n_pp_mbs=self.pipe_runner.default_train_mbs * num_micro_batches,
             )
         else:
-            if num_micro_batches is None:
-                num_micro_batches = 1
             self.ds_engine._config.gradient_accumulation_steps = num_micro_batches
             self.ds_engine.set_gradient_accumulation_boundary(False)
             if isinstance(
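
The change above (mirrored in the Megatron backend below) stops clamping the caller-requested `num_micro_batches` and instead folds it into the pipeline schedule: the runner now executes `n_pp_mbs = default_train_mbs * num_micro_batches` pipeline micro-batches per train call. A minimal sketch of that intent with made-up numbers; the helper name and the choice of `default_train_mbs = 2 * num_pipeline_stages` are assumptions for illustration, following the NOTE in the diff.

    # Sketch only: how the fused pipeline micro-batch count is derived, and why
    # it keeps satisfying the "2 x num_pipeline_stages" rule of thumb.
    def fused_pipe_microbatches(default_train_mbs: int, num_micro_batches: int) -> int:
        return default_train_mbs * num_micro_batches

    num_pipeline_stages = 4
    default_train_mbs = 2 * num_pipeline_stages                     # assumed default: 8
    n_pp_mbs = fused_pipe_microbatches(default_train_mbs, num_micro_batches=4)
    assert n_pp_mbs == 32 and n_pp_mbs >= 2 * num_pipeline_stages   # every stage stays busy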
4 changes: 2 additions & 2 deletions realhf/impl/model/backend/inference.py
@@ -96,7 +96,7 @@ def forward(
         for mb_input in input_.split(num_micro_batches):
             if constants.pipe_parallel_world_size() > 1:
                 model_output = self.pipe_runner.forward(
-                    input_=input_,
+                    input_=mb_input,
                     post_hook=post_hook,
                     aggregate_fn=aggregate_fn,
                 )
@@ -141,7 +141,7 @@ def generate(
         for mb_input in input_.split(num_micro_batches):
             if constants.pipe_parallel_world_size() > 1:
                 res = self.pipe_runner.generate(
-                    input_=input_,
+                    input_=mb_input,
                     tokenizer=tokenizer,
                     gconfig=gconfig,
                 )
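
Both hunks in this file fix the same slip: inside the loop over micro-batches, the pipelined path was handed the whole `input_` instead of the current `mb_input`, so every iteration processed the full batch again. A minimal stand-in illustration of the corrected pattern (not ReaL code; the fake runner and the interleaved split are invented for the example):

    # Stand-in runner that records the batch sizes it receives.
    class FakePipeRunner:
        def __init__(self):
            self.seen_batch_sizes = []

        def forward(self, input_):
            self.seen_batch_sizes.append(len(input_))
            return [x * 2 for x in input_]

    batch = list(range(8))
    num_micro_batches = 4
    runner = FakePipeRunner()

    outputs = []
    for mb_input in (batch[i::num_micro_batches] for i in range(num_micro_batches)):
        outputs += runner.forward(input_=mb_input)   # pass the micro-batch, not `batch`

    assert runner.seen_batch_sizes == [2, 2, 2, 2]   # the old bug would have given [8, 8, 8, 8]
    assert len(outputs) == len(batch)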
27 changes: 8 additions & 19 deletions realhf/impl/model/backend/megatron.py
@@ -731,34 +731,23 @@ def train_batch(
         num_micro_batches: Optional[int] = None,
     ):
         with megatron_ctx():
+            if num_micro_batches is None:
+                num_micro_batches = 1
             self.engine.zero_grad()
             if constants.pipe_parallel_world_size() > 1:
-                if num_micro_batches is not None:
-                    if (
-                        num_micro_batches < self.pipe_runner.default_train_mbs
-                        and constants.parallelism_rank() == 0
-                    ):
-                        logger.warning(
-                            "When training with pipeline parallel, num micro batches should be "
-                            "larger than 2 x num_pipeline_stages to avoid idle time. "
-                            f"Setting num_micro_batches to {self.pipe_runner.default_train_mbs}"
-                        )
-                    num_micro_batches = max(
-                        num_micro_batches, self.pipe_runner.default_train_mbs
-                    )
-                else:
-                    num_micro_batches = self.pipe_runner.default_train_mbs
-                instr_set = PipeTrainInstrSetForMegatron(self.engine, num_micro_batches)
+                # Fusing the minibatched forward-backward in a pipeline training schedule.
+                n_pp_mbs = self.pipe_runner.default_train_mbs * num_micro_batches
+                instr_set = PipeTrainInstrSetForMegatron(self.engine, n_pp_mbs)
+                # NOTE: When training with pipeline parallel, num micro batches should be
+                # larger than 2 x num_pipeline_stages to avoid idle time.
                 return self.pipe_runner.train_batch(
                     instr_set=instr_set,
                     input_=input_,
                     loss_fn=loss_fn,
                     version_steps=version_steps,
-                    num_micro_batches=num_micro_batches,
+                    n_pp_mbs=n_pp_mbs,
                 )
             else:
-                if num_micro_batches is None:
-                    num_micro_batches = 1
                 no_sync_ctx = self.engine.ddp.no_sync()
                 no_sync_ctx.__enter__()
                 stat = collections.defaultdict(int)
57 changes: 27 additions & 30 deletions realhf/impl/model/backend/pipe_runner.py
@@ -40,7 +40,7 @@ def _split_and_prefill_pipe_input(
     module: ReaLModel,
     input_: SequenceSample,
     tensor_buffer: TensorBuffer,
-    num_micro_batches: int,
+    n_pp_mbs: int,
     store_kv_cache: bool,
     store_input_cache: bool = False,
 ):
@@ -49,7 +49,7 @@ def _split_and_prefill_pipe_input(
     Basically, splitting all input tensors into micro batches for
     pipeline parallel.
     """
-    n_mbs = num_micro_batches
+    n_mbs = n_pp_mbs

     # Split sequence into several mini-batches.
     partition_min_size = input_.bs // n_mbs
@@ -672,7 +672,7 @@ def _exec_forward_pass(
                 "input_cache", micro_batch_id, remove=True
             )
             loss, stats = loss_fn(model_output, input_cache)
-            loss = loss / tensor_buffer.get("num_micro_batches", micro_batch_id)
+            loss = loss / tensor_buffer.get("n_pp_mbs", micro_batch_id)
             tensor_buffer.put("losses", micro_batch_id, loss)
             tensor_buffer.put("stats", micro_batch_id, stats)

@@ -797,32 +797,32 @@ def train(self, *args, **kwargs):
     def forward(
         self,
         input_: SequenceSample,
-        num_micro_batches: Optional[int] = None,
+        n_pp_mbs: Optional[int] = None,
         post_hook: Callable[[torch.Tensor, SequenceSample], Any] | None = None,
         aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
     ):
         """Run one forward step over a batch of tokens and return the
         logits."""

-        if num_micro_batches is None:
-            num_micro_batches = self.default_inf_mbs
+        if n_pp_mbs is None:
+            n_pp_mbs = self.default_inf_mbs

         tensor_buffer = TensorBuffer()
         if post_hook is not None:
-            for i in range(num_micro_batches):
+            for i in range(n_pp_mbs):
                 tensor_buffer.put("post_hook", i, post_hook)

         _split_and_prefill_pipe_input(
             module=self.module,
             tensor_buffer=tensor_buffer,
-            num_micro_batches=num_micro_batches,
+            n_pp_mbs=n_pp_mbs,
             input_=input_,
             store_kv_cache=False,
             store_input_cache=post_hook is not None,
         )

         sched = schedule.InferenceSchedule(
-            micro_batches=num_micro_batches,
+            micro_batches=n_pp_mbs,
             stages=constants.pipe_parallel_world_size(),
             stage_id=constants.pipe_parallel_rank(),
         )
@@ -836,7 +836,7 @@ def forward(
         agg_output = None
         if constants.is_last_pipe_stage():
             output_list = []
-            for i in range(num_micro_batches):
+            for i in range(n_pp_mbs):
                 output = tensor_buffer.get("output", i, remove=True)
                 output_list.append(output)
             agg_output = aggregate_fn(output_list)
@@ -851,28 +851,28 @@ def generate(
         gconfig: GenerationHyperparameters = dataclasses.field(
             default_factory=GenerationHyperparameters
         ),
-        num_micro_batches: Optional[int] = None,
+        n_pp_mbs: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[PipeCacheData]]:
         if constants.sequence_parallel():
             raise NotImplementedError(
                 "Sequence parallel is not supported for generation"
             )

-        if num_micro_batches is None:
-            num_micro_batches = self.default_inf_mbs
+        if n_pp_mbs is None:
+            n_pp_mbs = self.default_inf_mbs

         tensor_buffer = TensorBuffer()

         _split_and_prefill_pipe_input(
             module=self.module,
             tensor_buffer=tensor_buffer,
-            num_micro_batches=num_micro_batches,
+            n_pp_mbs=n_pp_mbs,
             input_=input_,
             store_kv_cache=True,
         )

         # for elegant generation termination
-        for mbid in range(num_micro_batches):
+        for mbid in range(n_pp_mbs):
             tensor_buffer.put("kv_cache_reserved", mbid, False)
             tensor_buffer.put(
                 "_terminate",
@@ -895,7 +895,7 @@ def generate(

         num_stages = constants.pipe_parallel_world_size()
         sched = schedule.GenerateSchedule(
-            micro_batches=num_micro_batches,
+            micro_batches=n_pp_mbs,
             stages=constants.pipe_parallel_world_size(),
             stage_id=constants.pipe_parallel_rank(),
             max_new_tokens=gconfig.max_new_tokens + num_stages // 2 + 10,
@@ -904,10 +904,7 @@ def generate(

         def terminate_condition():
             term = all(
-                [
-                    tensor_buffer.get("_terminate", mbid)
-                    for mbid in range(num_micro_batches)
-                ]
+                [tensor_buffer.get("_terminate", mbid) for mbid in range(n_pp_mbs)]
             )
             return term

@@ -920,15 +917,15 @@ def terminate_condition():
         )

         if gconfig.use_cuda_graph and gconfig.force_cudagraph_recapture:
-            for micro_batch_id in range(num_micro_batches):
+            for micro_batch_id in range(n_pp_mbs):
                 cuda_graph.destroy(f"decoding_{micro_batch_id}")

         if not constants.is_last_pipe_stage():
             return None

         # Gather generation outputs, including generated tokens, logprobs, and logits_mask.
         generate_output = []
-        for mbid in range(num_micro_batches):
+        for mbid in range(n_pp_mbs):
             generate_output += [
                 _gather_gen_output_from_list(
                     gen_token_ph=tensor_buffer.get("gen_token_ph", mbid, remove=True),
@@ -954,34 +951,34 @@ def train_batch(
         input_: SequenceSample,
         loss_fn: Callable,
         version_steps: int,
-        num_micro_batches: Optional[int] = None,
+        n_pp_mbs: Optional[int] = None,
     ):
         # TODO: return whether update success
         if not torch._C.is_grad_enabled():
             raise RuntimeError(
                 f"train_batch() requires gradients enabled. Use eval_batch() instead."
             )

-        if num_micro_batches is None:
-            num_micro_batches = self.default_train_mbs
+        if n_pp_mbs is None:
+            n_pp_mbs = self.default_train_mbs

         tensor_buffer = TensorBuffer()
-        for i in range(num_micro_batches):
-            tensor_buffer.put("num_micro_batches", i, num_micro_batches)
+        for i in range(n_pp_mbs):
+            tensor_buffer.put("n_pp_mbs", i, n_pp_mbs)
             tensor_buffer.put("version_steps", i, version_steps)
             tensor_buffer.put("loss_fn", i, loss_fn)

         _split_and_prefill_pipe_input(
             module=self.module,
             tensor_buffer=tensor_buffer,
-            num_micro_batches=num_micro_batches,
+            n_pp_mbs=n_pp_mbs,
             input_=input_,
             store_kv_cache=False,
             store_input_cache=True,
         )

         sched = schedule.TrainSchedule(
-            micro_batches=num_micro_batches,
+            micro_batches=n_pp_mbs,
             stages=constants.pipe_parallel_world_size(),
             stage_id=constants.pipe_parallel_rank(),
         )
@@ -995,7 +992,7 @@ def train_batch(
         agg_stats = None
         if constants.is_last_pipe_stage():
             stats = []
-            for mbid in range(num_micro_batches):
+            for mbid in range(n_pp_mbs):
                 stats.append(tensor_buffer.get("stats", mbid))
             agg_stats = dict()
             for key in stats[0].keys():
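
One detail worth calling out from `_exec_forward_pass` above: each micro-batch loss is divided by `n_pp_mbs` before backward, so the gradients accumulated across the fused pipeline micro-batches equal the gradient of the mean loss over those micro-batches. A small numeric sketch (standalone, not ReaL code; the scalar "model" and the per-micro-batch loss coefficients are made up):

    import torch

    # A single scalar "parameter"; each micro-batch contributes a loss t_i * w.
    w = torch.tensor(2.0, requires_grad=True)
    per_mb_coeffs = [1.0, 3.0, 5.0, 7.0]
    n_pp_mbs = len(per_mb_coeffs)

    # Accumulate gradients the way the pipeline runner does: scale each loss by 1 / n_pp_mbs.
    for t in per_mb_coeffs:
        loss = (t * w) / n_pp_mbs
        loss.backward()

    # Accumulated gradient == gradient of the average loss: mean(t_i) = 4.0.
    assert torch.isclose(w.grad, torch.tensor(sum(per_mb_coeffs) / n_pp_mbs))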