
support HF processor
garrett4wade committed Dec 24, 2024
1 parent 37b2a70 commit 3239867
Showing 12 changed files with 60 additions and 35 deletions.
33 changes: 22 additions & 11 deletions realhf/api/core/data_api.py
@@ -38,19 +38,30 @@

logger = logging.getLogger("api.data")

+TokenizerLike = Union[transformers.PreTrainedTokenizerFast, transformers.ProcessorMixin]


def load_hf_tokenizer(
model_name_or_path: str,
-    fast_tokenizer=True,
-    padding_side: Optional[str] = None,
-) -> transformers.PreTrainedTokenizerFast:
-    kwargs = {}
-    if padding_side is not None:
-        kwargs["padding_side"] = padding_side
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
+    **kwargs,
+) -> TokenizerLike:
+    """Load a HuggingFace processor.
+
+    It could be a text processor, which is a duck type of tokenizer,
+    or a multi-modal processor, which also has image/audio/video
+    processing capabilities. This function also sets the pad_token_id.
+
+    :param model_name_or_path: The HF model name or path.
+    :type model_name_or_path: str
+    :param kwargs: Additional keyword arguments passed to HF.
+    :type kwargs: Dict[str, Any]
+    """
+    kwargs.update(trust_remote_code=True, fast_tokenizer=True)
+    tokenizer = transformers.AutoProcessor.from_pretrained(
        model_name_or_path,
-        fast_tokenizer=fast_tokenizer,
-        trust_remote_code=True,
        **kwargs,
    )
if tokenizer.pad_token_id is None:
@@ -608,7 +619,7 @@ class DatasetUtility:
seed: int
dp_rank: int
world_size: int
-    tokenizer: transformers.PreTrainedTokenizerFast
+    tokenizer: TokenizerLike

def __post_init__(self):
if self.tokenizer.pad_token_id is None:
@@ -679,7 +690,7 @@ def make_dataset(
seed: int,
dp_rank: int,
world_size: int,
-    tokenizer_or_tokenizer_name: Union[transformers.PreTrainedTokenizerFast, str],
+    tokenizer_or_tokenizer_name: Union[TokenizerLike, str],
experiment_name: str,
trial_name: str,
cache_root: Optional[str] = None,
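For context, a hedged usage sketch of the new loader follows (the model names are illustrative, not part of the commit; the text-only fallback assumes a recent transformers release, where AutoProcessor falls back to AutoTokenizer when a checkpoint ships no processor config):

from realhf.api.core.data_api import load_hf_tokenizer

# Options such as padding_side now travel through **kwargs instead of a
# dedicated parameter and are forwarded to AutoProcessor.from_pretrained.
tok = load_hf_tokenizer("meta-llama/Llama-2-7b-hf", padding_side="left")
ids = tok("hello world")["input_ids"]  # duck-types as a text tokenizer

# A multi-modal checkpoint yields a ProcessorMixin that can also preprocess
# images/audio/video alongside text.
proc = load_hf_tokenizer("llava-hf/llava-1.5-7b-hf")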
10 changes: 5 additions & 5 deletions realhf/api/core/model_api.py
@@ -16,7 +16,7 @@
ModelName,
ModelWrapperAbstraction,
)
-from realhf.api.core.data_api import SequenceSample, load_hf_tokenizer
+from realhf.api.core.data_api import SequenceSample, TokenizerLike, load_hf_tokenizer

logger = logging.getLogger("model_api")

@@ -428,7 +428,7 @@ def forward(
def generate(
self,
input_: SequenceSample,
-        tokenizer: transformers.PreTrainedTokenizerFast,
+        tokenizer: TokenizerLike,
gconfig: GenerationHyperparameters = dataclasses.field(
default_factory=GenerationHyperparameters
),
@@ -440,7 +440,7 @@ def generate(
which includes the concatenated prompts.
:type input_: SequenceSample
:param tokenizer: The tokenizer for the model.
-    :type tokenizer: transformers.PreTrainedTokenizerFast
+    :type tokenizer: TokenizerLike
:param gconfig: The generation hyperparameters.
:type gconfig: GenerationHyperparameters
:param num_micro_batches: The number of micro-batches to split the batch into.
@@ -472,7 +472,7 @@ class Model:
sharded by tensor or pipeline parallelism.
:type module: PipelinableEngine | torch.nn.Module
:param tokenizer: The tokenizer associated with the model.
-    :type tokenizer: transformers.PreTrainedTokenizerFast
+    :type tokenizer: TokenizerLike
:param device: The device on which to run the model.
:type device: Union[str, torch.device]
:param dtype: The data type of the model. Defaults to torch.float16
@@ -487,7 +487,7 @@ class Model:

name: ModelName
module: PipelinableEngine | torch.nn.Module
-    tokenizer: transformers.PreTrainedTokenizerFast
+    tokenizer: TokenizerLike
device: Union[str, torch.device]
dtype: Optional[torch.dtype] = None
version: ModelVersion = dataclasses.field(default_factory=ModelVersion)
7 changes: 7 additions & 0 deletions realhf/api/quickstart/model.py
@@ -152,6 +152,10 @@ class ModelTrainEvalConfig:
:type optimizer: Optional[OptimizerConfig]
:param init_critic_from_actor: Whether to initialize a critic/reward model from a saved LM checkpoint.
:type init_critic_from_actor: bool
+    :param init_from_scratch: Whether to initialize the model from scratch.
+    :type init_from_scratch: bool
+    :param tokenizer_kwargs: Additional kwargs for the HuggingFace tokenizer/processor.
+    :type tokenizer_kwargs: Optional[Dict[str, Any]]
"""

type: ModelFamily = dataclasses.field(default=ModelFamily("llama", 7, False))
@@ -173,6 +177,7 @@ class ModelTrainEvalConfig:
)
init_from_scratch: bool = False
init_critic_from_actor: bool = False
+    tokenizer_kwargs: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)


def get_real_model_config(
@@ -182,6 +187,7 @@ def get_real_model_config(
init_from_scratch: bool,
init_critic_from_actor: bool,
dtype: Optional[str] = None,
+    tokenizer_kwargs: Optional[Dict[str, Any]] = None,
# LoRA config
lora: Optional[LoRAConfig] = None,
is_sft_lora: bool = False,
@@ -199,6 +205,7 @@ def get_real_model_config(
dtype=dtype,
hf_model_family=hf_model_family,
init_from_scratch=init_from_scratch,
+            tokenizer_kwargs=tokenizer_kwargs,
),
)
if is_sft_lora:
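A hedged sketch of how an experiment config might use the new field (the value shown is illustrative; all other fields keep their defaults):

from realhf.api.quickstart.model import ModelTrainEvalConfig

# tokenizer_kwargs is forwarded verbatim to load_hf_tokenizer, so the old
# explicit padding_side parameter can be recovered through it.
model_cfg = ModelTrainEvalConfig(
    tokenizer_kwargs={"padding_side": "left"},
)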
1 change: 1 addition & 0 deletions realhf/experiments/common/common.py
@@ -438,6 +438,7 @@ def _get_model_worker_configs(
init_critic_from_actor=model_cfg.init_critic_from_actor,
dtype="bf16" if model_cfg.enable_bf16 else "fp16",
lora=model_cfg.lora,
+        tokenizer_kwargs=model_cfg.tokenizer_kwargs,
)
mapping = rpc_alloc.device_mesh.mapping
gradient_checkpointing = model_cfg.gradient_checkpointing and any(
4 changes: 2 additions & 2 deletions realhf/impl/model/backend/deepspeed.py
@@ -21,9 +21,9 @@
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer

-import realhf.api.core.model_api as model_api
import realhf.base.constants as constants
import realhf.base.logging as logging
+from realhf.api.core import data_api, model_api
from realhf.api.core.data_api import SequenceSample
from realhf.base.datapack import flat2d
from realhf.base.monitor import CUDATimeMarkType, cuda_tmark, cuda_tmarked
@@ -264,7 +264,7 @@ def forward(
def generate(
self,
input_: SequenceSample,
-        tokenizer: transformers.PreTrainedTokenizerFast,
+        tokenizer: data_api.TokenizerLike,
gconfig: model_api.GenerationHyperparameters = dataclasses.field(
default_factory=model_api.GenerationHyperparameters
),
4 changes: 2 additions & 2 deletions realhf/impl/model/backend/inference.py
@@ -6,9 +6,9 @@
import torch.distributed as dist
import transformers

-import realhf.api.core.model_api as model_api
import realhf.base.constants as constants
import realhf.base.logging as logging
+from realhf.api.core import data_api, model_api
from realhf.api.core.data_api import SequenceSample
from realhf.base.datapack import flat2d
from realhf.impl.model.backend.pipe_runner import PipelineRunner
@@ -127,7 +127,7 @@ def forward(
def generate(
self,
input_: SequenceSample,
-        tokenizer: transformers.PreTrainedTokenizerFast,
+        tokenizer: data_api.TokenizerLike,
gconfig: model_api.GenerationHyperparameters = dataclasses.field(
default_factory=model_api.GenerationHyperparameters
),
4 changes: 2 additions & 2 deletions realhf/impl/model/backend/megatron.py
@@ -31,7 +31,7 @@ class DistributedOptimizer:
pass


-from realhf.api.core import model_api
+from realhf.api.core import data_api, model_api
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, logging
from realhf.base.datapack import flat2d
@@ -810,7 +810,7 @@ def forward(
def generate(
self,
input_: SequenceSample,
-        tokenizer: transformers.PreTrainedTokenizerFast,
+        tokenizer: data_api.TokenizerLike,
gconfig: model_api.GenerationHyperparameters = dataclasses.field(
default_factory=model_api.GenerationHyperparameters
),
4 changes: 2 additions & 2 deletions realhf/impl/model/backend/pipe_runner.py
@@ -12,7 +12,7 @@
import realhf.impl.model.parallelism.pipeline_parallel.p2p as p2p
import realhf.impl.model.parallelism.pipeline_parallel.static_schedule as schedule
import realhf.impl.model.utils.cuda_graph as cuda_graph
-from realhf.api.core.data_api import SequenceSample
+from realhf.api.core.data_api import SequenceSample, TokenizerLike
from realhf.api.core.model_api import GenerationHyperparameters
from realhf.base.datapack import flat2d
from realhf.impl.model.nn.real_llm_api import ReaLModel
@@ -847,7 +847,7 @@ def forward(
def generate(
self,
input_: SequenceSample,
-        tokenizer: transformers.PreTrainedTokenizerFast,
+        tokenizer: TokenizerLike,
gconfig: GenerationHyperparameters = dataclasses.field(
default_factory=GenerationHyperparameters
),
9 changes: 6 additions & 3 deletions realhf/impl/model/nn/real_llm_api.py
@@ -9,7 +9,7 @@
import torch.utils.checkpoint
import transformers

-from realhf.api.core import model_api
+from realhf.api.core import data_api, model_api
from realhf.api.core.config import ModelName
from realhf.base import constants, logging, topology
from realhf.base.monitor import CUDATimeMarkType, cuda_tmark, cuda_tmarked
@@ -788,7 +788,7 @@ def patch_reparallelization(self, x, eta):
# a helper function to make real_model look like huggingface model
def generate_helper(
self: ReaLModel,
-    tokenizer: transformers.PreTrainedTokenizerFast,
+    tokenizer: data_api.TokenizerLike,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
packed_input_ids: Optional[torch.Tensor] = None,
@@ -863,6 +863,7 @@ def make_real_model(
init_critic_from_actor: bool,
dtype: Optional[str] = None,
hf_model_family: Optional[str] = None,
+    tokenizer_kwargs: Optional[Dict[str, Any]] = None,
) -> model_api.Model:
if dtype == "fp16" or dtype == None:
dtype = torch.float16
@@ -873,7 +874,9 @@ def make_real_model(
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

-    tokenizer = model_api.load_hf_tokenizer(model_path)
+    if tokenizer_kwargs is None:
+        tokenizer_kwargs = {}
+    tokenizer = model_api.load_hf_tokenizer(model_path, **tokenizer_kwargs)
mconfig = getattr(ReaLModel, f"config_from_{hf_model_family}")(
model_path=model_path,
is_critic=is_critic or init_critic_from_actor,
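Because load_hf_tokenizer may now return a ProcessorMixin rather than a tokenizer, downstream call sites rely on duck typing for attributes such as pad_token_id. A hypothetical defensive helper, not part of the commit, for code that needs the underlying text tokenizer:

def unwrap_tokenizer(tok):
    # ProcessorMixin subclasses typically expose the wrapped text tokenizer
    # as a .tokenizer attribute; plain tokenizers pass through unchanged.
    return getattr(tok, "tokenizer", tok)

pad_id = unwrap_tokenizer(tokenizer).pad_token_id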
13 changes: 7 additions & 6 deletions realhf/impl/model/nn/real_llm_generate.py
@@ -10,6 +10,7 @@
import transformers

import realhf.impl.model.utils.cuda_graph as cuda_graph
+from realhf.api.core import data_api
from realhf.api.core.model_api import GenerationHyperparameters, ReaLModelConfig
from realhf.base import constants, logging
from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData
@@ -25,7 +26,7 @@

def genstep(
next_token_logits: torch.Tensor,
-    tokenizer: transformers.PreTrainedTokenizerFast,
+    tokenizer: data_api.TokenizerLike,
unfinished_sequences: torch.Tensor,
generated_idx: Union[torch.IntTensor, int],
gconfig: GenerationHyperparameters,
@@ -34,7 +35,7 @@ def genstep(
Args:
next_token_logits (torch.Tensor): Shape [bs, vocab_size].
-        tokenizer (transformers.PreTrainedTokenizerFast): .
+        tokenizer (data_api.TokenizerLike): .
unfinished_sequences (torch.Tensor): Bool tensor indicator of whether a sequence is finished.
Shape [bs].
generated_idx (int): The token index to be generated.
@@ -251,7 +252,7 @@ def maybe_capture_cudagraph(
@torch.no_grad()
def generate(
model: "ReaLModel",
-    tokenizer: transformers.PreTrainedTokenizerFast,
+    tokenizer: data_api.TokenizerLike,
packed_input_ids: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
@@ -533,7 +534,7 @@ def concat_prompt_to_generation_output(
@torch.no_grad()
def vanilla_packed_generate(
model: "ReaLModel",
-    tokenizer: transformers.PreTrainedTokenizerFast,
+    tokenizer: data_api.TokenizerLike,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
gconfig: GenerationHyperparameters = dataclasses.field(
@@ -600,7 +601,7 @@ def vanilla_packed_generate(
@torch.no_grad()
def vanilla_cpu_generate(
model: "ReaLModel",
-    tokenizer: transformers.PreTrainedTokenizerFast,
+    tokenizer: data_api.TokenizerLike,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
gconfig: GenerationHyperparameters = dataclasses.field(
@@ -668,7 +669,7 @@ def __init__(
inqueue: queue.Queue,
outqueue: queue.Queue,
model: "ReaLModel",
-        tokenizer: transformers.PreTrainedTokenizerFast,
+        tokenizer: data_api.TokenizerLike,
gconfig: GenerationHyperparameters,
batch_size: int,
max_prompt_len: int,
3 changes: 2 additions & 1 deletion realhf/impl/model/utils/functional.py
@@ -5,6 +5,7 @@
import torch.distributed as dist
import transformers

+from realhf.api.core import data_api
from realhf.base import constants, logging
from realhf.impl.model.utils.padding import pad_input, unpad_input

@@ -295,7 +296,7 @@ def masked_normalization(

def get_eos_indices(
input_ids: torch.LongTensor,
-    tokenizer: transformers.PreTrainedTokenizerFast,
+    tokenizer: data_api.TokenizerLike,
) -> Tuple[torch.LongTensor, torch.FloatTensor]:
if torch.any(input_ids[:, 0] == tokenizer.eos_token_id):
indices = (input_ids[:, 0] == tokenizer.eos_token_id).nonzero().flatten()
3 changes: 2 additions & 1 deletion realhf/search_engine/layers.py
@@ -13,6 +13,7 @@
import realhf.api.core.system_api as config_package
import realhf.base.constants as constants
import realhf.base.logging as logging
+from realhf.api.core import data_api
from realhf.api.core.model_api import ReaLModelConfig
from realhf.impl.model.utils.padding import unpad_input

@@ -59,7 +60,7 @@ def __init__(
self,
model_name: str,
config: ReaLModelConfig,
-        tokenizer: transformers.PreTrainedTokenizerFast = None,
+        tokenizer: data_api.TokenizerLike = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[str, torch.device]] = None,
):
