Commit

Generate Hugging Face config.json
cbalioglu committed Jan 25, 2025
1 parent b29bd89 commit bfee716
Showing 18 changed files with 396 additions and 215 deletions.
3 changes: 3 additions & 0 deletions src/fairseq2/checkpoint/__init__.py
@@ -18,3 +18,6 @@
 from fairseq2.checkpoint._metadata_provider import (
     FileCheckpointMetadataProvider as FileCheckpointMetadataProvider,
 )
+from fairseq2.checkpoint._metadata_provider import (
+    FileCheckpointMetadataSaver as FileCheckpointMetadataSaver,
+)
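The re-export uses the explicit `Name as Name` alias form, which marks the symbol as re-exported public API for type checkers running under strict no-implicit-reexport settings. After this change the saver can be imported from the package root rather than the private `_metadata_provider` module; a one-line sketch:

    from fairseq2.checkpoint import FileCheckpointMetadataSaver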
93 changes: 34 additions & 59 deletions src/fairseq2/checkpoint/_manager.py
@@ -11,6 +11,7 @@
 from collections.abc import Iterable, Iterator, Mapping, Set
 from contextlib import AbstractContextManager, nullcontext
 from pathlib import Path
+from shutil import Error
 from typing import final
 from warnings import catch_warnings

@@ -32,19 +33,11 @@
     TensorLoader,
     TensorLoadError,
 )
-from fairseq2.utils.structured import unstructure
-from fairseq2.utils.yaml import YamlDumper


 class CheckpointManager(ABC):
     """Saves and loads training checkpoints."""

-    @abstractmethod
-    def save_checkpoint_card(
-        self, family: str, config: object, tokenizer_name: str | None = None
-    ) -> None:
-        """Save the checkpoint model metadata."""
-
     @abstractmethod
     def begin_checkpoint(self, step_nr: int) -> None:
         """Begin a transactional checkpoint operation.
@@ -166,8 +159,6 @@ class FileCheckpointManager(CheckpointManager):
     _file_system: FileSystem
     _tensor_loader: TensorLoader
     _tensor_dumper: TensorDumper
-    _yaml_dumper: YamlDumper
-    _num_shards: int
     _shard_suffix: str
     _checkpoint_step_nr: int | None

@@ -178,7 +169,6 @@ def __init__(
         file_system: FileSystem,
         tensor_loader: TensorLoader,
         tensor_dumper: TensorDumper,
-        yaml_dumper: YamlDumper,
     ) -> None:
         self._checkpoint_dir = checkpoint_dir.expanduser().resolve()

@@ -189,53 +179,13 @@ def __init__(
         self._tensor_loader = tensor_loader
         self._tensor_dumper = tensor_dumper

-        self._yaml_dumper = yaml_dumper
-
-        self._num_shards = gangs.tp.size
-
-        if self._num_shards > 1:
+        if gangs.tp.rank > 1:
             self._shard_suffix = f".{gangs.tp.rank}"
         else:
             self._shard_suffix = ""

         self._checkpoint_step_nr = None

-    @override
-    def save_checkpoint_card(
-        self, family: str, config: object, tokenizer_name: str | None = None
-    ) -> None:
-        if self._gangs.root.rank == 0:
-            metadata: dict[str, object] = {
-                "name": "checkpoint",
-                "model_family": family,
-                "model_config": unstructure(config),
-            }
-
-            if tokenizer_name is not None:
-                metadata["tokenizer_ref"] = tokenizer_name
-
-            if self._num_shards != 1:
-                metadata["num_shards"] = self._num_shards
-
-            metadata_file = self._checkpoint_dir.joinpath("model.yaml")
-
-            def save_error() -> CheckpointError:
-                return CheckpointError(
-                    f"The model metadata cannot be saved to the '{metadata_file}' file. See the nested exception for details."
-                )
-
-            try:
-                self._file_system.make_directory(self._checkpoint_dir)
-            except OSError as ex:
-                raise save_error() from ex
-
-            try:
-                self._yaml_dumper.dump(metadata, metadata_file)
-            except OSError as ex:
-                raise save_error() from ex
-
-            self._gangs.root.barrier()
-
     @override
     def begin_checkpoint(self, step_nr: int) -> None:
         if self._checkpoint_step_nr is not None:
@@ -270,6 +220,8 @@ def save_state(
         model_key: str = "model",
         replicated_keys: Set[str] | None = None,
     ) -> None:
+        gangs = self._gangs
+
         step_nr = self._get_checkpoint_step_nr()

         tmp_step_dir = self._checkpoint_dir.joinpath(f"step_{step_nr}.tmp")
@@ -280,7 +232,7 @@ def save_state(
         rank_part["model_key"] = model_key

         def model_replicated() -> bool:
-            if self._gangs.dp.size == 1:
+            if gangs.dp.size == 1:
                 return True

             if not replicated_keys:
@@ -294,7 +246,7 @@ def model_replicated() -> bool:
             if state_dict is not None:
                 del rank_part["model_key"]

-                if self._gangs.dp.rank == 0:
+                if gangs.dp.rank == 0:
                     model_file = tmp_step_dir.joinpath(f"model{self._shard_suffix}.pt")

                     try:
@@ -306,11 +258,11 @@
                             step_nr, f"The replicated model state of training step {step_nr} cannot be saved to the '{model_file}' file. See the nested exception for details."  # fmt: skip
                         ) from ex

-                gangs.root.barrier()
+                gangs.root.barrier()

         # Save the replicated state.
         if replicated_keys:
-            if self._gangs.dp.rank == 0:
+            if gangs.dp.rank == 0:
                 replicated_part = {}

                 if "*" in replicated_keys:
@@ -343,7 +295,7 @@ def model_replicated() -> bool:
                 except KeyError:
                     pass

-            self._gangs.root.barrier()
+            gangs.root.barrier()

         # Check if anything is left to save for the rank.
         skip_rank = len(rank_part) == 0
@@ -353,7 +305,7 @@ def model_replicated() -> bool:
         # Save the per-rank state.
         if not skip_rank:
             rank_file = tmp_step_dir.joinpath(
-                f"rank_{self._gangs.dp.rank}{self._shard_suffix}.pt"
+                f"rank_{gangs.dp.rank}{self._shard_suffix}.pt"
             )

             try:
@@ -363,7 +315,30 @@
                     step_nr, f"The checkpoint state of training step {step_nr} cannot be saved to the '{rank_file}' file. See the nested exception for details."  # fmt: skip
                 ) from ex

-        self._gangs.root.barrier()
+        gangs.root.barrier()
+
+        # Copy carbon-copy files to the checkpoint directory.
+        if gangs.root.rank == 0:
+            cc_dir = self._checkpoint_dir.joinpath("cc")
+
+            try:
+                cc_exists = self._file_system.exists(cc_dir)
+            except OSError as ex:
+                raise CheckpointSaveError(
+                    step_nr,
+                    "The checkpoint carbon copy directory cannot be accessed. See the nested exception for details.",
+                ) from ex
+
+            if cc_exists:
+                try:
+                    self._file_system.copy_directory(cc_dir, tmp_step_dir)
+                except (OSError, Error) as ex:
+                    raise CheckpointSaveError(
+                        step_nr,
+                        f"The checkpoint carbon copy directory cannot be copied to the '{tmp_step_dir}' directory. See the nested exception for details.",
+                    ) from ex
+
+        gangs.root.barrier()

     @override
     def save_metadata(self, metadata: Mapping[str, object]) -> None:
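The `from shutil import Error` added to the imports exists for this carbon-copy step: catching `(OSError, Error)` suggests `FileSystem.copy_directory` is backed by `shutil.copytree` (an assumption; the implementation is not part of this diff), whose `shutil.Error` aggregates per-file failures and is not an `OSError` subclass, so it would escape an `OSError`-only handler. A minimal sketch of that failure mode, with hypothetical paths:

    from shutil import Error, copytree

    def copy_carbon_copy_dir(cc_dir: str, tmp_step_dir: str) -> None:
        try:
            # dirs_exist_ok=True allows copying into the already-created
            # temporary step directory (available since Python 3.8).
            copytree(cc_dir, tmp_step_dir, dirs_exist_ok=True)
        except (OSError, Error) as ex:
            # shutil.Error.args holds a list of (src, dst, reason) tuples,
            # one per file that failed to copy.
            raise RuntimeError(
                f"cannot copy '{cc_dir}' to '{tmp_step_dir}'"
            ) from ex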
127 changes: 104 additions & 23 deletions src/fairseq2/checkpoint/_metadata_provider.py
@@ -6,6 +6,7 @@

 from __future__ import annotations

+import json
 from pathlib import Path
 from typing import Iterable, final
@@ -16,7 +17,106 @@
     AssetMetadataError,
     MetadataFileLoader,
 )
-from fairseq2.utils.file import FileSystem
+from fairseq2.gang import Gangs
+from fairseq2.models.llama import LLAMA_MODEL_FAMILY, LLaMAConfig
+from fairseq2.models.llama.integ import convert_to_hg_llama_config
+from fairseq2.utils.file import FileMode, FileSystem
+from fairseq2.utils.structured import unstructure
+from fairseq2.utils.yaml import YamlDumper
+
+
+@final
+class FileCheckpointMetadataSaver:
+    _checkpoint_dir: Path
+    _gangs: Gangs
+    _file_system: FileSystem
+    _yaml_dumper: YamlDumper
+
+    def __init__(
+        self,
+        checkpoint_dir: Path,
+        gangs: Gangs,
+        file_system: FileSystem,
+        yaml_dumper: YamlDumper,
+    ) -> None:
+        self._checkpoint_dir = checkpoint_dir
+        self._gangs = gangs
+        self._file_system = file_system
+        self._yaml_dumper = yaml_dumper
+
+    def save(
+        self, family: str, config: object, tokenizer_name: str | None = None
+    ) -> None:
+        if self._gangs.root.rank == 0:
+            unstructured_config = unstructure(config)
+
+            metadata: dict[str, object] = {
+                "name": "checkpoint",
+                "model_family": family,
+                "model_config": unstructured_config,
+            }
+
+            if tokenizer_name is not None:
+                metadata["tokenizer_ref"] = tokenizer_name
+
+            if self._gangs.tp.size != 1:
+                metadata["num_shards"] = self._gangs.tp.size
+
+            metadata_file = self._checkpoint_dir.joinpath("model.yaml")
+
+            def save_error() -> AssetMetadataError:
+                return AssetMetadataError(
+                    f"The model metadata cannot be saved to the '{metadata_file}' file. See the nested exception for details."
+                )
+
+            try:
+                self._file_system.make_directory(metadata_file.parent)
+            except OSError as ex:
+                raise save_error() from ex
+
+            try:
+                self._yaml_dumper.dump(metadata, metadata_file)
+            except OSError as ex:
+                raise save_error() from ex
+
+            self._save_huggingface_config(family, config)
+
+        self._gangs.root.barrier()
+
+    def _save_huggingface_config(self, family: str, config: object) -> None:
+        if family != LLAMA_MODEL_FAMILY:
+            return
+
+        if not isinstance(config, LLaMAConfig):
+            raise TypeError(
+                f"`config` must be of type `{LLaMAConfig}`, but is of type `{type(config)}` instead."
+            )
+
+        hg_config = convert_to_hg_llama_config(config)
+
+        hg_config_file = self._checkpoint_dir.joinpath("cc/config.json")
+
+        def save_error() -> AssetMetadataError:
+            return AssetMetadataError(
+                f"The Hugging Face model configuration cannot be saved to the '{hg_config_file}' file. See the nested exception for details."
+            )
+
+        try:
+            self._file_system.make_directory(hg_config_file.parent)
+        except OSError as ex:
+            raise save_error() from ex
+
+        try:
+            fp = self._file_system.open_text(hg_config_file, mode=FileMode.WRITE)
+        except OSError as ex:
+            raise save_error() from ex
+
+        try:
+            json.dump(hg_config, fp, indent=2, sort_keys=True)
+        except OSError as ex:
+            raise save_error() from ex
+        finally:
+            fp.close()


 @final
@@ -46,13 +146,6 @@ def __init__(
     def _load_cache(self) -> dict[str, dict[str, object]]:
         cache: dict[str, dict[str, object]] = {}

-        self._load_model(cache)
-
-        self._load_tokenizer(cache)
-
-        return cache
-
-    def _load_model(self, cache: dict[str, dict[str, object]]) -> None:
         metadata_file = self._checkpoint_dir.joinpath("model.yaml")

         for name, metadata in self._metadata_file_loader.load(metadata_file):
@@ -144,12 +237,12 @@ def load_error() -> AssetMetadataError:
             scores.append((score, step_nr))

         if max_step_nr == -1:
-            return
+            return cache

         add_checkpoint_metadata("last_checkpoint@", max_step_nr)

         if not scores:
-            return
+            return cache

         scores.sort()

@@ -160,16 +253,4 @@
         for idx, (_, step_nr) in enumerate(reversed(scores)):
             add_checkpoint_metadata(f"best_checkpoint_{idx}@", step_nr)

-    def _load_tokenizer(self, cache: dict[str, dict[str, object]]) -> None:
-        metadata_file = self._checkpoint_dir.joinpath("tokenizer.yaml")
-
-        try:
-            tokenizer_exists = self._file_system.exists(metadata_file)
-        except OSError as ex:
-            raise AssetMetadataError(
-                f"The '{metadata_file}' path cannot be accessed. See the nested exception for details."
-            ) from ex
-
-        if tokenizer_exists:
-            for name, metadata in self._metadata_file_loader.load(metadata_file):
-                cache[name] = metadata
+        return cache
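Taken together, `save()` writes the fairseq2-native `model.yaml`, and `_save_huggingface_config()` additionally drops a Hugging Face `config.json` under `cc/` for LLaMA models; the carbon-copy step added to `_manager.py` above then replicates that file into every step directory. A rough sketch of the JSON-writing behavior in isolation, with a made-up dictionary standing in for the real output of `convert_to_hg_llama_config`:

    import json
    from pathlib import Path

    # Hypothetical stand-in for convert_to_hg_llama_config(); the real keys
    # and values are derived from the LLaMAConfig being saved.
    hg_config = {"model_type": "llama", "hidden_size": 4096, "num_hidden_layers": 32}

    config_file = Path("checkpoints/cc/config.json")
    config_file.parent.mkdir(parents=True, exist_ok=True)

    with config_file.open("w") as fp:
        # indent=2 and sort_keys=True match the call above, making the
        # emitted file deterministic across runs.
        json.dump(hg_config, fp, indent=2, sort_keys=True)

Writing into `cc/` rather than into a specific step directory appears to be a deliberate split: the metadata is generated once on rank 0, and each subsequent checkpoint save copies it alongside the weights.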
2 changes: 1 addition & 1 deletion src/fairseq2/cli/commands/llama/__init__.py
@@ -10,5 +10,5 @@
     ConvertLLaMACheckpointHandler as ConvertLLaMACheckpointHandler,
 )
 from fairseq2.cli.commands.llama._write_hf_config import (
-    WriteLLaMAHFConfigHandler as WriteLLaMAHFConfigHandler,
+    WriteHFLLaMAConfigHandler as WriteHFLLaMAConfigHandler,
 )
(diffs for the remaining 14 changed files are not shown)
