Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Oneshot] Oneshot Refactor #1041

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
276b779
init
horheynm Jan 7, 2025
c690043
decouple main and successful fp8 run
horheynm Jan 7, 2025
166e4df
remove stage runner
horheynm Jan 7, 2025
40c73eb
run calib
horheynm Jan 7, 2025
7747bd6
Merge branch 'main' into oneshot-refac-1
horheynm Jan 7, 2025
3b7fd6a
potential non use of session
horheynm Jan 7, 2025
b3031c0
Merge branch 'oneshot-refac-1' of github.com:vllm-project/llm-compres…
horheynm Jan 7, 2025
1cd3d90
get rid of session, use oneshotclass
horheynm Jan 7, 2025
a5d0fd7
pass existing tests
horheynm Jan 8, 2025
33e1b16
Merge branch 'main' into oneshot-refac-1
horheynm Jan 8, 2025
e7407b9
pass finetune tests not dep on HF release
horheynm Jan 8, 2025
d352e4c
Merge branch 'oneshot-refac-1' of github.com:vllm-project/llm-compres…
horheynm Jan 8, 2025
bc532e7
remove unnecessary changes 1
horheynm Jan 8, 2025
137c02e
remove duplicate code
horheynm Jan 8, 2025
6d5cdbc
remove duplicate code, set output_dir and save_tensors as training_ar…
horheynm Jan 9, 2025
2c7c5f0
pass tests pre HFQuantizer check
horheynm Jan 9, 2025
324fc99
lint
horheynm Jan 10, 2025
0e34ad3
oneshot
horheynm Jan 10, 2025
9a6a87f
add __all__
horheynm Jan 10, 2025
54e8fd0
add init
horheynm Jan 10, 2025
01eff29
Merge branch 'main' into oneshot-refac-1
horheynm Jan 14, 2025
b20d6b8
move private below non-prov
horheynm Jan 15, 2025
7e84319
Merge branch 'oneshot-refac-1' of github.com:vllm-project/llm-compres…
horheynm Jan 15, 2025
3547baf
pass tests/llmcompressor/transformers/finetune/test_oneshot_and_fine…
horheynm Jan 15, 2025
976814f
remove redundant code
horheynm Jan 15, 2025
59d5d63
remove training_args, use session not local lifecycle
horheynm Jan 15, 2025
b5f75d5
move args
horheynm Jan 15, 2025
bd1385e
simplify inputargs to oneshot
horheynm Jan 16, 2025
d52dbf3
clean up **kwargs of Oneshot
horheynm Jan 16, 2025
0060b63
better doc strings
horheynm Jan 16, 2025
9eaf4c2
add docstrings, retire apply
horheynm Jan 22, 2025
77d15a4
revert exampels script
horheynm Jan 22, 2025
d5d34f6
remove apply from sessionmixin:
horheynm Jan 22, 2025
73e4d7b
remove comments
horheynm Jan 22, 2025
e1bdffd
Merge branch 'main' into oneshot-refac-1
horheynm Jan 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from llmcompressor.modifiers import StageModifiers
from llmcompressor.recipe import RecipeContainer

__all__ = ["CompressionLifecycle"]
__all__ = [
"CompressionLifecycle",
]


@dataclass
Expand Down
3 changes: 3 additions & 0 deletions src/llmcompressor/transformers/calibration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .oneshot import Oneshot
164 changes: 164 additions & 0 deletions src/llmcompressor/transformers/calibration/oneshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from pathlib import PosixPath
from typing import Optional

from loguru import logger
from torch.utils.data import DataLoader

from llmcompressor.core.lifecycle import CompressionLifecycle
from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
get_calibration_dataloader,
)
from llmcompressor.transformers.finetune.model_args import ModelArguments
from llmcompressor.transformers.finetune.text_generation import (
initialize_model_from_path,
initialize_processor_from_path,
parse_args,
)
from llmcompressor.transformers.finetune.training_args import (
DEFAULT_OUTPUT_DIR,
TrainingArguments,
)
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
patch_tied_tensors_bug,
)
from llmcompressor.transformers.utils.recipe_args import RecipeArguments

__all__ = ["Oneshot"]


class Oneshot:
"""
Class responsible for carrying out oneshot calibration.

Usage:

```python
oneshot = Oneshot(model=model, recipe=recipe, dataset=dataset)
oneshot.run()

model = oneshot.model
tokenizer_or_processor = oneshot.tokenizer_or_processor
recipe = oneshot.recipe

```
"""

MODIFIER_LIFECYCLE_ACTIONS = (
"initialize",
"finalize",
)

def __init__(
self,
lifecycle: Optional[CompressionLifecycle] = None,
model_args: Optional["ModelArguments"] = None,
data_args: Optional["DataTrainingArguments"] = None,
recipe_args: Optional["RecipeArguments"] = None,
training_args: Optional["TrainingArguments"] = None,
**kwargs,
):
if any(
arg is not None
for arg in [model_args, data_args, recipe_args, training_args]
):
self.model_args, self.data_args, self.recipe_args, training_args = (
model_args,
data_args,
recipe_args,
training_args,
)
else:
self.model_args, self.data_args, self.recipe_args, training_args = (
parse_args(**kwargs)
)

self.lifecycle = (
lifecycle or CompressionLifecycle() # lifecycle from stage runner
)
self.output_dir = training_args.output_dir

# Preprocess the model and tokenizer/processor
self._pre_process()

# Set instance attributes
self.model = self.model_args.model
self.tokenizer_or_processor = self.model_args.processor
self.recipe = self.recipe_args.recipe
self.modifiers = self.lifecycle.modifiers

def run(self, **kwargs):
"""Perform oneshot calibration"""
calibration_dataloader = get_calibration_dataloader(
self.data_args, self.tokenizer_or_processor
)
self._apply_recipe_modifiers(
calibration_dataloader=calibration_dataloader, **kwargs
)
self._post_process()

def save(self):
"""Save the model and tokenizer/processor to the output directory"""
self.model.save_pretrained(
self.output_dir,
save_compressed=self.model_args.save_compressed,
stage_modifiers=self.lifecycle.modifiers,
)
if self.tokenizer_or_processor:
self.tokenizer_or_processor.save_pretrained(self.output_dir)

def _apply_recipe_modifiers(
self, calibration_dataloader: Optional[DataLoader], **kwargs
):
"""Apply recipe modifiers to the model"""
for action in self.MODIFIER_LIFECYCLE_ACTIONS:
lifecycle = getattr(self.lifecycle, action)
lifecycle(
model=self.model,
recipe=self.recipe,
recipe_args=self.recipe_args.recipe_args,
calib_data=calibration_dataloader,
start=-1, # oneshot-specific argument
copy_data=False,
min_tokens_per_module=getattr(self, "min_tokens_per_module", None),
**kwargs,
)

def _pre_process(self):
"""Preprocess model and tokenizer/processor"""
self._warn_tied_embeddings()

# Initialize model
if isinstance(self.model_args.model, (str, PosixPath)):
self.model_args.model, _ = initialize_model_from_path(self.model_args)

patch_tied_tensors_bug(self.model_args.model)
modify_save_pretrained(self.model_args.model)

# Initialize processor
if isinstance(self.model_args.processor, (str, type(None))):
self.model_args.processor = initialize_processor_from_path(
self.model_args, self.model_args.model
)

# Set minimum tokens per module if data arguments are provided
if self.data_args:
self.min_tokens_per_module = self.data_args.min_tokens_per_module

def _warn_tied_embeddings(self):
if self.model_args.tie_word_embeddings:
logger.debug(
"The tie_word_embeddings flag is by default set to False. "
"This guarantees that the one-shot algorithm saves the final "
"weights without errors. Detected tie_word_embeddings=True. "
"This may cause issues with the one-shot algorithm on save"
)

def _post_process(self):
"""Save model and reset the lifecycle if requested"""
if (
isinstance(self.model_args.model, str)
or self.output_dir != DEFAULT_OUTPUT_DIR
):
self.save()
14 changes: 10 additions & 4 deletions src/llmcompressor/transformers/compression/sparsity_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from torch import Tensor
from torch.nn import Module

from llmcompressor.core import active_session
from llmcompressor.core import CompressionLifecycle, active_session
from llmcompressor.modifiers.stage import StageModifiers
from llmcompressor.pytorch.utils import ModuleSparsificationInfo
from llmcompressor.transformers.compression.helpers import (
infer_sparse_targets_and_ignores,
Expand Down Expand Up @@ -40,7 +41,10 @@ def infer_global_sparsity(
return global_sparsity

@staticmethod
def infer_sparsity_structure(model: Optional[Module] = None) -> str:
def infer_sparsity_structure(
model: Optional[Module] = None,
stage_modifiers: Optional[CompressionLifecycle] = None,
) -> str:
"""
Determines what sparsity structure, if any, was applied.

Expand All @@ -58,7 +62,7 @@ def infer_sparsity_structure(model: Optional[Module] = None) -> str:
sparsity_structure = None

current_session = active_session()
stage_modifiers = current_session.lifecycle.modifiers
stage_modifiers = stage_modifiers or current_session.lifecycle.modifiers
if stage_modifiers:
sparsity_structure = infer_sparsity_structure_from_stage_modifiers(
stage_modifiers
Expand All @@ -74,6 +78,7 @@ def from_pretrained(
model: Module,
state_dict: Optional[Dict[str, Tensor]] = None,
compress: bool = False,
stage_modifiers: Optional[StageModifiers] = None,
) -> Optional["SparsityCompressionConfig"]:
"""
Determines compression type and informational parameters for a given model
Expand All @@ -93,7 +98,8 @@ def from_pretrained(
return None

sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
model=model
model=model,
stage_modifiers=stage_modifiers,
)
if is_model_quantized(model):
# compressing a sparse quantized model is not supported yet
Expand Down
76 changes: 76 additions & 0 deletions src/llmcompressor/transformers/finetune/data/data_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
import os
import re
from typing import Any, Callable, Dict, List, Optional

import torch
from datasets import Dataset, load_dataset
from loguru import logger
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.data import default_data_collator

Expand All @@ -15,6 +17,7 @@
"get_raw_dataset",
"make_dataset_splits",
"get_custom_datasets_from_path",
"get_calibration_dataloader",
]


Expand Down Expand Up @@ -243,3 +246,76 @@ def do_transform(candidate: str) -> bool:
transform_dataset_key(dataset_key)

return data_files


def get_calibration_dataloader(
data_args,
processor,
add_labels: bool = False, # for oneshot
do_oneshot=True,
):
"""
Loads datasets for each flow based on data_args, stores a Dataset for each
enabled flow in self.datasets

:param processor: processor or tokenizer to use for dataset tokenization
:param add_labels: if True, add labels column to dataset splits
"""
if data_args.dataset is None:
logger.info(
"Running oneshot without calibration data. This is expected for "
"weight-only and dynamic quantization"
)
return

splits = data_args.splits
tokenized_datasets = {}

def _get_split_name(inp_str):
# strip out split name, for ex train[60%:] -> train
match = re.match(r"(\w*)\[.*\]", inp_str)
if match is not None:
return match.group(1)
return inp_str

if splits is None:
splits = {"all": None}
elif isinstance(splits, str):
splits = {_get_split_name(splits): splits}
elif isinstance(splits, List):
splits = {_get_split_name(s): s for s in splits}

# default to custom dataset if dataset provided isn't a string
registry_id = data_args.dataset if isinstance(data_args.dataset, str) else "custom"
for split_name, split_str in splits.items():
dataset = data_args.dataset
if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
# dataset is already tokenized
tokenized_datasets[split_name] = dataset
else:
# dataset needs to be tokenized
from llmcompressor.transformers.finetune.data.base import (
TextGenerationDataset,
)

dataset_manager = TextGenerationDataset.load_from_registry(
registry_id,
data_args=data_args,
split=split_str,
processor=processor,
)
tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels)

datasets = make_dataset_splits(
tokenized_datasets,
do_oneshot=do_oneshot,
)

calibration_dataset = datasets.get("calibration")

return format_calibration_data(
tokenized_dataset=calibration_dataset,
num_calibration_samples=data_args.num_calibration_samples,
do_shuffle=data_args.shuffle_calibration_samples,
collate_fn=data_args.data_collator,
)
31 changes: 19 additions & 12 deletions src/llmcompressor/transformers/finetune/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from
Model variables used for oneshot calibration, training or finetuning and
stage runners (combination of oneshot and finetune going back and forth)

"""

model: str = field(
Expand Down Expand Up @@ -44,17 +46,7 @@ class ModelArguments:
default=None,
metadata={"help": "Where to store the pretrained data from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizers. Default True"},
)
model_revision: str = field(
default="main",
metadata={
"help": "The specific model version to use "
"(can be a branch name, tag name or commit id)"
},
)

use_auth_token: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -83,3 +75,18 @@ class ModelArguments:
"repositories you trust and in which you have read the code"
},
)
save_compressed: Optional[bool] = field(
default=True,
metadata={"help": "Whether to compress sparse models during save"},
)
oneshot_device: Optional[str] = field(
default="cuda:0",
metadata={"help": "Device to run oneshot calibration on"},
)
model_revision: str = field(
default="main",
metadata={
"help": "The specific model version to use "
"(can be a branch name, tag name or commit id)"
},
)
Loading
Loading