[Patch Fix] Support arbitrary symmetric allocations and fix MFC time log in workers #60

Merged (13 commits) on Sep 3, 2024
85 changes: 85 additions & 0 deletions examples/scripts/local/ppo_symm.sh
@@ -0,0 +1,85 @@
# MODEL_FAMILY specifies how the pretrained checkpoint is loaded, e.g., as a LLaMA model or a GPT model.
# You can specify different model families for the SFT and the RW models, but you may need to
# re-tokenize the sequences accordingly.
MODEL_FAMILY=llama

# SFT_MODEL_PATH and RW_MODEL_PATH are the saved SFT and RW checkpoints.
# ReaL saves checkpoints with the same format as HuggingFace,
# so you don't need to convert or split checkpoints explicitly.
# You can also directly use a pre-trained HuggingFace checkpoint, but this
# does not guarantee optimal algorithm performance.
SFT_MODEL_PATH=/lustre/aigc/llm/checkpoints/fw/quickstart-sft/$MODEL_FAMILY-local-manual/default/epoch7epochstep5globalstep50/
RW_MODEL_PATH=/lustre/aigc/llm/checkpoints/fw/quickstart-rw/$MODEL_FAMILY-ray-manual/default/epoch1epochstep10globalstep10/

# Option 1: The experiment runs locally with subprocesses.
MODE=local
# Option 2: The experiment runs in a Ray cluster
# MODE=ray
# Option 3: The experiment runs in a SLURM + pyxis cluster
# Using the slurm mode requires a cluster spec file
# and setting CLUSTER_SPEC_PATH to its path.
# MODE=slurm

# `experiment_name` and `trial_name` can be arbitrary.
# Logs and saved checkpoints will be indexed by them.
EXP_NAME=quickstart-ppo
TRIAL_NAME=$MODEL_FAMILY-$MODE-manual

# When using the "manual" allocation mode, the user should specify the device allocation
# and parallel strategies for each model function call (MFC).
# The number of GPUs is `n_nodes` * `n_gpus_per_node` (not set explicitly here; defaults to 8).
# We provide a template in the following command and the user can modify it according to
# the specific model and the available GPUs.

# The `ppo` subcommand specifies that this is a PPO experiment.
# The `save_freq_steps` is set to `null` to disable saving checkpoints.
# Enable it if you want to save checkpoints.
# The `ppo` option is used to control the generation and PPO algorithm hyperparameters.
# Note that the performance of PPO is sensitive to the pre-trained model and hyperparameters.
# It is the user's responsibility to tune them appropriately.
# The allocation of model function calls is specified by a pattern `hostname:gpu_id1,gpu_id2,...`
# for slicing GPUs of a single node, and `hostname1,hostname2` for multiple nodes.
# Only 1, 2, 4, or 8 GPUs on a single node, or multiple complete nodes (e.g., 16 or 24 GPUs), are supported.
# If CLUSTER_SPEC_PATH is not set, the `hostname`s are NODE01, NODE02, etc.; otherwise they are
# the hostnames specified in the cluster spec file. The `gpu_id`s are the GPU indices on the host,
# from 0 to `n_gpus_per_node - 1` (`n_gpus_per_node` defaults to 8 and can be changed).
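# For illustration (hypothetical hostnames): `NODE01:0,1,2,3` slices the first four GPUs
# of a single node, while `NODE01,NODE02` allocates two complete nodes.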
# Once allocations are all set, parallel strategies can be specified as long as the world size
# equals the number of GPUs in the allocation.

# The following command shows an example of a symmetric allocation (`allocation_mode=m2d2p2`),
# but it can be modified according to the specific model and the available GPUs.
unset CLUSTER_SPEC_PATH
python3 -m realhf.apps.quickstart ppo \
mode=$MODE \
experiment_name=$EXP_NAME \
trial_name=$TRIAL_NAME \
exp_ctrl.total_train_epochs=1 \
exp_ctrl.save_freq_steps=null \
actor.type._class=$MODEL_FAMILY \
actor.path=$SFT_MODEL_PATH \
actor.optimizer.lr_scheduler_type=constant \
actor.optimizer.lr=1e-4 \
actor.optimizer.warmup_steps_proportion=0.0 \
critic.type._class=$MODEL_FAMILY \
critic.type.is_critic=True \
critic.path=$RW_MODEL_PATH \
ref.type._class=$MODEL_FAMILY \
ref.path=$SFT_MODEL_PATH \
rew.type._class=$MODEL_FAMILY \
rew.type.is_critic=True \
rew.path=$RW_MODEL_PATH \
dataset.path=/lustre/fw/datasets/imdb/rl/ppo_prompt.jsonl \
dataset.max_prompt_len=128 \
dataset.train_bs_n_seqs=128 \
ppo.gen.max_new_tokens=512 \
ppo.gen.min_new_tokens=512 \
ppo.gen.top_p=0.9 ppo.gen.top_k=1000 \
ppo.ppo_n_minibatches=4 \
ppo.kl_ctl=0.1 \
ppo.value_eps_clip=0.2 \
ppo.reward_output_scaling=10.0 \
ppo.adv_norm=True ppo.value_norm=True \
allocation_mode=m2d2p2 \
actor_gen.n_mbs=2 \
actor_train.n_mbs=4 \
ref_inf.n_mbs=2
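# Note (inferred from the allocation check added in realhf/experiments/common/common.py):
# `allocation_mode=m2d2p2` requests 2-way model, 2-way data, and 2-way pipeline parallelism
# for every model function call, so the product 2 * 2 * 2 = 8 must equal
# `n_nodes` * `n_gpus_per_node`.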
6 changes: 6 additions & 0 deletions realhf/base/constants.py
@@ -1,6 +1,7 @@
# log format constants
import contextlib
import copy
import datetime
import getpass
import os
import pathlib
@@ -53,6 +54,11 @@ def get_tensor(self, tensor_shape, dtype, name, force_zero: bool = False):
return res


# 30 minutes. Transferring super-large batches via NCCL broadcast
# for the first time may consume over 600 seconds, which is
# PyTorch's default timeout. Increase it to 30 minutes.
NCCL_DEFAULT_TIMEOUT = datetime.timedelta(seconds=1800)

# constants in experiment instance scope
MODEL_SAVE_ROOT = f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}"
LOG_ROOT = f"{cluster_spec.fileroot}/logs/{getpass.getuser()}"
5 changes: 4 additions & 1 deletion realhf/base/topology.py
@@ -6,6 +6,7 @@
import torch.distributed as dist

import realhf.base.logging as logging
from realhf.base.constants import NCCL_DEFAULT_TIMEOUT

logger = logging.getLogger("Topology")

@@ -19,7 +20,9 @@ def new_or_get_group(ranks: List[int], backend=None):
global GLOBAL_PROCESS_GROUP_REGISTRY
key = (ranks, backend)
if key not in GLOBAL_PROCESS_GROUP_REGISTRY:
GLOBAL_PROCESS_GROUP_REGISTRY[key] = dist.new_group(ranks, backend=backend)
GLOBAL_PROCESS_GROUP_REGISTRY[key] = dist.new_group(
ranks, backend=backend, timeout=NCCL_DEFAULT_TIMEOUT
)
return GLOBAL_PROCESS_GROUP_REGISTRY[key]


45 changes: 29 additions & 16 deletions realhf/experiments/common/common.py
@@ -40,6 +40,7 @@
)
from realhf.experiments.common.check import check_is_realhf_native_model_interface
from realhf.experiments.common.utils import (
extract_symmetric_allocation,
get_topo,
make_inf_backend_config,
make_train_backend_config,
@@ -87,9 +88,11 @@ class CommonExperimentConfig(Experiment):

- ``heuristic``\: allocate resources and configure parallel strategies with heuristic strategies previously obtained by search.

- ``pipe_data``\: identical parallelization (like DSChat) with pipe+data parallelism. A world size under 8 will use data parallelism only.
- ``pipe_data``\: identical parallelization (like DSChat) with pipe+data parallelism for all MFCs. A world size under 8 will use data parallelism only.

- ``pipe_model``\: identical parallelization (like DSChat) with pipe+model parallelism. A world size under 8 will use tensor-model parallelism only.
- ``pipe_model``\: identical parallelization (like DSChat) with pipe+model parallelism for all MFCs. A world size under 8 will use tensor-model parallelism only.

- A regex pattern like ``d${DP}p${PP}m${TP}``\: identical parallelization for all MFCs with ${DP}-way data parallelism, ${PP}-way pipeline parallelism, and ${TP}-way model parallelism.

:param experiment_name: Name of the experiment.
Arbitrary string without "_" and "/", e.g., ``ultra-chat-llama``.
@@ -335,27 +338,35 @@ def _get_rpc_allocations(self) -> List[RPCAllocation]:
else:
raise ValueError(f"RPC {rpc_alloc.rpc} not found in rpcs.")
elif (
self.allocation_mode == "pipe_data" or self.allocation_mode == "pipe_model"
self.allocation_mode == "pipe_data"
or self.allocation_mode == "pipe_model"
or extract_symmetric_allocation(self.allocation_mode)
):
if self.allocation_mode == "pipe_data":
dp, pp, mp = self.n_gpus_per_node, self.n_nodes, 1
elif self.allocation_mode == "pipe_model":
dp, pp, mp = 1, self.n_nodes, self.n_gpus_per_node
else:
para = extract_symmetric_allocation(self.allocation_mode)
dp, pp, mp = para["d"], para["p"], para["m"]
if dp * pp * mp != self.n_nodes * self.n_gpus_per_node:
raise ValueError(
"The product of the 3D parallel degrees "
"does not equal the number of GPUs. "
f"dp={dp}, pp={pp}, mp={mp}, "
f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
)
rpc_allocs: List[RPCAllocation] = [
RPCAllocation(
rpc=rpc,
device_mesh=self.global_device_mesh,
parallel=ParallelismConfig(
data_parallel_size=(
self.n_gpus_per_node
if self.allocation_mode == "pipe_data"
else 1
),
pipeline_parallel_size=self.n_nodes,
model_parallel_size=(
self.n_gpus_per_node
if self.allocation_mode == "pipe_model"
else 1
),
data_parallel_size=dp,
pipeline_parallel_size=pp,
model_parallel_size=mp,
use_sequence_parallel=(
rpc.interface_type == ModelInterfaceType.TRAIN_STEP
and self.allocation_mode == "pipe_model"
and mp > 1
),
),
)
@@ -380,7 +391,9 @@ def _get_rpc_allocations(self) -> List[RPCAllocation]:
elif self.allocation_mode == "heuristic":
rpc_allocs: List[RPCAllocation] = self._heuristic_rpc_allocation()
else:
raise NotImplementedError()
raise NotImplementedError(
f'Unknown allocation mode "{self.allocation_mode}".'
)
return rpc_allocs

def _get_model_worker_configs(
16 changes: 16 additions & 0 deletions realhf/experiments/common/utils.py
@@ -1,4 +1,6 @@
import collections
import itertools
import re
from typing import *

import numpy as np
@@ -194,3 +196,17 @@ def resolve_rpc_hooks(
):
rpc.add_post_hook(OffloadHook())
logger.info(f"Add offload hook for rpc {rpc.name} for role {rpc.role}")


def extract_symmetric_allocation(allocation_mode: str) -> Dict | None:
    # Parse a symmetric 3D-parallel allocation pattern such as "d2p2m2" or "m2d2p2":
    # the letters d (data), m (model), and p (pipeline) may appear in any order,
    # each followed by its parallel degree.
    for x, y, z in itertools.permutations(["d", "m", "p"]):
        pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
        m = re.match(pattern, allocation_mode)
        if not m:
            continue
        a, b, c = map(int, m.groups())
        # Key each degree by the letter it followed in the matched pattern.
        return {x: a, y: b, z: c}
    return None
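A minimal usage sketch of the new parser (the expected results are inferred from the regex above, not from test output; non-symmetric modes such as "heuristic" are handled by other branches and simply return None here):

from realhf.experiments.common.utils import extract_symmetric_allocation

# Any permutation of the three letters is accepted; each degree is keyed by its letter.
assert extract_symmetric_allocation("m2d2p2") == {"d": 2, "m": 2, "p": 2}
assert extract_symmetric_allocation("d4m1p2") == {"d": 4, "m": 1, "p": 2}
# Strings that do not match the pattern yield None.
assert extract_symmetric_allocation("heuristic") is None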
1 change: 1 addition & 0 deletions realhf/impl/model/comm/global_comm.py
@@ -112,6 +112,7 @@ def setup_global_comm(
rank=global_rank,
init_method=ddp_init_address,
backend=backend,
timeout=constants.NCCL_DEFAULT_TIMEOUT,
)
if torch.cuda.is_available():
torch.cuda.set_device(