Commit

Merge branch 'main' of github.com:openpsi-project/realhf into v0.3.0-docs
garrett4wade committed Sep 3, 2024
2 parents db071d1 + 4c68d90 commit ec1158d
Showing 22 changed files with 351 additions and 167 deletions.
examples/profiling/allocations.jsonl (3 changes: 1 addition & 2 deletions)
@@ -1,2 +1 @@
{"data_parallel_size": 2, "model_parallel_size": 2, "pipeline_parallel_size": 2, "use_sequence_parallel": true}
{"data_parallel_size": 1, "model_parallel_size": 2, "pipeline_parallel_size": 4, "use_sequence_parallel": false}
{"data_parallel_size": 2, "model_parallel_size": 4, "pipeline_parallel_size": 1, "use_sequence_parallel": true}
examples/profiling/datasets.jsonl (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
{"type_": "prompt", "args": {"max_length": 2048, "pad_to_max_length": true, "dataset_path": "/lustre/fw/datasets/imdb/rl/ppo_prompt.jsonl"}}
{"type_": "prompt", "args": {"max_length": 1024, "pad_to_max_length": true, "dataset_path": "/lustre/fw/datasets/imdb/rl/ppo_prompt.jsonl"}}
examples/profiling/interfaces.jsonl (3 changes: 1 addition & 2 deletions)
@@ -1,2 +1 @@
{"type_": "ppo_actor", "args": {"generation_config": {"max_new_tokens": 2048,"min_new_tokens": 2048,"use_cuda_graph": true,"force_no_logits_mask": true,"force_cudagraph_recapture": true,"top_p": 1.0,"top_k": 1000000},"enable_save": false,"n_minibatches": 4}}
{"type_": "ppo_critic", "args": {"enable_save": false,"n_minibatches": 4}}
{"type_": "ppo_actor", "args": {"generation_config": {"max_new_tokens": 1024,"min_new_tokens": 1024,"use_cuda_graph": true,"force_no_logits_mask": true,"force_cudagraph_recapture": true,"top_p": 1.0,"top_k": 1000000},"enable_save": false,"n_minibatches": 8}}
examples/profiling/profile.sh (8 changes: 4 additions & 4 deletions)
@@ -3,7 +3,7 @@ MODEL_FAMILY=llama
 SFT_MODEL_PATH=/lustre/public/pretrained_model_weights/Llama-2-7b-hf

 EXP_NAME=profile-example
-TRIAL_NAME=ppo-gen-n2-p2048g2048b1024
+TRIAL_NAME=test

 export CLUSTER_SPEC_PATH="/lustre/aigc/llm/cluster/qh.json"

@@ -30,7 +30,7 @@ export CLUSTER_SPEC_PATH="/lustre/aigc/llm/cluster/qh.json"
 # all within the same experiment_name and trial_name. Instead of re-launching the whole experiment, workers will
 # be paused and reconfigured to run the next experiment setup.

-REAL_DUMP_TRACE=0 REAL_DUMP_MEMORY=0 \
+REAL_DUMP_TRACE=1 REAL_DUMP_MEMORY=1 \
 python3 -m realhf.apps.quickstart profile \
     mode=local \
     experiment_name=$EXP_NAME \
@@ -39,10 +39,10 @@ REAL_DUMP_TRACE=0 REAL_DUMP_MEMORY=0 \
     exp_ctrl.save_freq_steps=null \
     exp_ctrl.eval_freq_steps=null \
     n_nodes=1 \
-    'handle_names=[generate]' \
+    'handle_names=[train_step]' \
     interfaces_jsonl=./examples/profiling/interfaces.jsonl \
     models_jsonl=./examples/profiling/models.jsonl \
     datasets_jsonl=./examples/profiling/datasets.jsonl \
     allocations_jsonl=./examples/profiling/allocations.jsonl \
     'n_mbs=[1, 2, 4]' \
-    'batch_sizes=[1024]'
+    'batch_sizes=[512]'
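The list-valued options (handle_names, n_mbs, batch_sizes) define a sweep: per the comment above, workers are paused and reconfigured for each setup instead of relaunching the whole experiment. A hedged sketch of how many setups the options above imply, assuming the profiler enumerates the Cartesian product of these lists (an assumption based on that comment, not on the profiler's actual code):

    # Sketch: enumerate the profiling setups implied by the lists above.
    import itertools
    import json

    with open("examples/profiling/allocations.jsonl") as f:
        allocations = [json.loads(line) for line in f]
    handle_names = ["train_step"]
    n_mbs = [1, 2, 4]
    batch_sizes = [512]

    setups = list(itertools.product(allocations, handle_names, n_mbs, batch_sizes))
    print(len(setups))  # 1 allocation x 1 handle x 3 microbatch settings x 1 batch size = 3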
realhf/api/core/data_api.py (16 changes: 8 additions & 8 deletions)
@@ -602,7 +602,7 @@ class DataBatchMeta:
 @dataclasses.dataclass
 class DatasetUtility:
     seed: int
-    ddp_rank: int
+    dp_rank: int
     world_size: int
     tokenizer: transformers.PreTrainedTokenizerFast

@@ -654,7 +654,7 @@ def load_shuffle_split_dataset(
         util.seed, datasize_per_rank * util.world_size
     )
     subset_indices = shuffle_indices[
-        util.ddp_rank * datasize_per_rank : (util.ddp_rank + 1) * datasize_per_rank
+        util.dp_rank * datasize_per_rank : (util.dp_rank + 1) * datasize_per_rank
     ]
     data: List[Dict[str, str]] = [data[i] for i in subset_indices]

@@ -673,7 +673,7 @@ def register_dataset(name, dataset_cls):
 def make_dataset(
     cfg: Union[str, config_api.DatasetAbstraction],
     seed: int,
-    ddp_rank: int,
+    dp_rank: int,
     world_size: int,
     tokenizer_or_tokenizer_name: Union[transformers.PreTrainedTokenizerFast, str],
     experiment_name: str,
@@ -691,7 +691,7 @@ def make_dataset(
         tokenizer = tokenizer_or_tokenizer_name
     util = DatasetUtility(
         seed,
-        ddp_rank,
+        dp_rank,
         world_size,
         tokenizer,
     )
@@ -717,7 +717,7 @@
         cfg.type_,
         f"seed{seed}",
         f"world_size{world_size}",
-        f"rank{ddp_rank}",
+        f"rank{dp_rank}",
     )
     os.makedirs(output_path, exist_ok=True)

@@ -726,11 +726,11 @@

     tik = time.perf_counter()
     if not cache_found:
-        logger.info(f"No data cache found for rank {ddp_rank}. Create it from scratch.")
-        dataset = ALL_DATASET_CLASSES[cfg.type_](seed, ddp_rank, world_size, **cfg.args)
+        logger.info(f"No data cache found for rank {dp_rank}. Create it from scratch.")
+        dataset = ALL_DATASET_CLASSES[cfg.type_](seed, dp_rank, world_size, **cfg.args)
         torch.save(dataset, os.path.join(output_path, fname))
     else:
-        logger.info(f"Rank {ddp_rank} find existing data cache, load it.")
+        logger.info(f"Rank {dp_rank} find existing data cache, load it.")
         dataset = torch.load(os.path.join(output_path, fname))
     logger.info(f"Dataset creation/loading time: {time.perf_counter() - tik:.3f}s")

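The rename from ddp_rank to dp_rank makes explicit that dataset splitting happens along the data-parallel dimension only: each data-parallel rank loads one disjoint shard of the shuffled dataset, as load_shuffle_split_dataset does above. A minimal standalone sketch of that sharding idea follows; it mirrors the slicing shown in the hunk but uses numpy directly, which is an assumption about the repository's shuffle helper:

    # Sketch: deterministic, disjoint per-rank sharding along the data-parallel dimension.
    import numpy as np

    def shard_for_dp_rank(data: list, seed: int, dp_rank: int, world_size: int) -> list:
        datasize_per_rank = len(data) // world_size
        # Every rank draws the same shuffled order from the shared seed...
        shuffle_indices = np.random.RandomState(seed).permutation(datasize_per_rank * world_size)
        # ...and keeps only its own contiguous slice of that order.
        subset = shuffle_indices[dp_rank * datasize_per_rank : (dp_rank + 1) * datasize_per_rank]
        return [data[int(i)] for i in subset]

    # Four data-parallel ranks together cover the dataset exactly once.
    shards = [shard_for_dp_rank(list(range(100)), seed=7, dp_rank=r, world_size=4) for r in range(4)]
    assert sorted(sum(shards, [])) == list(range(100))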
realhf/apps/main.py (14 changes: 5 additions & 9 deletions)
@@ -16,7 +16,6 @@
 logger = logging.getLogger("main", "system")

 CONTROLLER_TIME_LIMIT = None
-TRACE_TIMEOUT = 360 # Should be larger than TRACER_SAVE_INTERVAL_SECONDS defined in system/worker_base.py


 def scheduler_mode(mode: str) -> str:
@@ -193,9 +192,6 @@ def main_start(args, recover_count: int = 0):
         args.image_name,
     )

-    timeout = (
-        None if os.getenv("REAL_TRACE", "0") == "0" else TRACE_TIMEOUT
-    ) # run 5 mins to collect trace
     try:
         sched.wait(
             check_status=(
@@ -205,13 +201,8 @@
                 JobState.COMPLETED,
             ),
             remove_status=(),
-            timeout=timeout,
         )
     except (KeyboardInterrupt, JobException, TimeoutError) as e:
-        if os.getenv("REAL_TRACE", "0") != "0" and isinstance(e, TimeoutError):
-            s = "#" * 30 + " Trace complete. Killing all processes... " + "#" * 30
-            logger.info("\n" + "#" * len(s) + "\n" + s + "\n" + "#" * len(s))
-
         recover_states = [
             JobState.CANCELLED,
             JobState.FAILED,
@@ -284,6 +275,11 @@ def _main_profile_layers(model_family, model_path):
     )

     if check_slurm_availability():
+        if not os.environ.get("CLUSTER_SPEC_PATH", ""):
+            raise ValueError(
+                "Environment variable CLUSTER_SPEC_PATH must be set for slurm mode! "
+                "See example/cluster_config.json for a template."
+            )
         BASE_ENVIRONS = constants.get_env_vars(
             WANDB_MODE="disabled",
             REAL_MODE="slurm",
realhf/base/constants.py (5 changes: 3 additions & 2 deletions)
@@ -76,7 +76,6 @@ def get_tensor(self, tensor_shape, dtype, name, force_zero: bool = False):
 QUICKSTART_EXPR_CACHE_PATH = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/"
 BASE_ENVIRONS = {
     "PYTHONPATH": "/realhf",
-    "REAL_TRACE": os.getenv("REAL_TRACE", "0"),
     "REAL_IS_REMOTE": "1",
     # "NCCL_P2P_DISABLE": "1",
     # "NCCL_IB_DISABLE": "1",
@@ -107,7 +106,9 @@ def get_tensor(self, tensor_shape, dtype, name, force_zero: bool = False):
     # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
     "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
     # Whether to enable time mark to plot timelines.
-    # "REAL_CUDA_TMARK": "1",
+    "REAL_CUDA_TMARK": os.getenv("REAL_CUDA_TMARK", "0"),
+    "REAL_DUMP_TRACE": os.getenv("REAL_DUMP_TRACE", "0"),
+    "REAL_DUMP_MEMORY": os.getenv("REAL_DUMP_MEMORY", "0"),
 }


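With these entries, the launcher forwards REAL_CUDA_TMARK, REAL_DUMP_TRACE, and REAL_DUMP_MEMORY from its own environment into the workers' BASE_ENVIRONS, so the flags set in profile.sh above reach remote processes. Below is a hedged sketch of how a worker could act on the two dump flags; it is not realhf's actual hook, only standard PyTorch profiling APIs wired to the same environment variables (torch.cuda.memory._record_memory_history is a private PyTorch API available in recent releases):

    # Sketch: react to the forwarded flags inside a worker process (illustrative only).
    import os
    import torch

    def maybe_enable_dumps():
        if os.getenv("REAL_DUMP_MEMORY", "0") != "0":
            # Start recording allocator history so a memory snapshot can be dumped later.
            torch.cuda.memory._record_memory_history(max_entries=100_000)
        prof = None
        if os.getenv("REAL_DUMP_TRACE", "0") != "0":
            prof = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ]
            )
            prof.start()  # call prof.stop() and prof.export_chrome_trace(...) when done
        return prof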
realhf/base/gpu_utils.py (4 changes: 1 addition & 3 deletions)
@@ -54,7 +54,7 @@ def set_cuda_device(device):
     torch.cuda.set_device(device)


-def reveal_ddp_identity(expr_name, trial_name, worker_index):
+def reveal_pg_identity(expr_name, trial_name, worker_index):
     master_group_name = names.distributed_peer(
         expr_name, trial_name, GLOBAL_PROCESS_GROUP_NAME
     )
@@ -76,8 +76,6 @@ def isolate_cuda_device(
     then CUDA_VISIBLE_DEVICES of these jobsteps will be 0,1, instead of 0 and 1.
     We use this function in `apps.remote` to isolate CUDA_VISIBLE_DEVICES for each jobstep.
     Note that this function is completely independent of `setup_ddp`.
-    Args:
-        worker_type (str): .
         rank (int): Rank of the **jobstep**.
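As the docstring explains, when Slurm packs several jobsteps onto one node they can all inherit the same multi-GPU CUDA_VISIBLE_DEVICES, so each jobstep has to narrow it down to its own device. realhf does this in isolate_cuda_device (the coordination details are not shown in this hunk); the standalone sketch below only illustrates the idea, and assigning GPUs by a local jobstep index is an assumption made for illustration:

    # Sketch: give each jobstep on a node exclusive visibility of one GPU.
    import os

    def isolate_gpu(local_jobstep_index: int, gpus_per_node: int = 8) -> None:
        gpu = local_jobstep_index % gpus_per_node
        # After this, torch sees a single device and indexes it as cuda:0.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

    isolate_gpu(local_jobstep_index=3)  # this jobstep now only sees physical GPU 3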