From fe741f489d2ee07837485e6373f08c3b6175748c Mon Sep 17 00:00:00 2001 From: "Jiang, Yanbing" Date: Tue, 8 Oct 2024 01:19:19 -0700 Subject: [PATCH] refactor --- generate.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/generate.py b/generate.py index 18ed3e1..980fd2c 100644 --- a/generate.py +++ b/generate.py @@ -216,6 +216,14 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device): tokens = [tokenizer.bos_id()] + tokens return torch.tensor(tokens, dtype=torch.int, device=device) +def _convert_weight(model): + from quantize import WeightOnlyInt4Linear + for fqn, mod in model.named_modules(): + if isinstance(mod, WeightOnlyInt4Linear): + weight = mod.weight.data + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles) + mod.weight = weight_int4pack + def _load_model(checkpoint_path, device, precision, use_tp): use_cuda = 'cuda' in device with torch.device('meta'): @@ -240,19 +248,15 @@ def _load_model(checkpoint_path, device, precision, use_tp): checkpoint = checkpoint["model"] model.load_state_dict(checkpoint, assign=True) + model = model.to(device=device, dtype=precision) + # int4 packed weight needs to be converted after model loading to the specific device + if "int4" in str(checkpoint_path): + _convert_weight(model) + if use_tp: from tp import apply_tp print("Applying tensor parallel to model ...") apply_tp(model) - - model = model.to(device=device, dtype=precision) - if "int4" in str(checkpoint_path): - from quantize import WeightOnlyInt4Linear - for fqn, mod in model.named_modules(): - if isinstance(mod, WeightOnlyInt4Linear): - weight = mod.weight.data - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles) - mod.weight = weight_int4pack return model.eval() def _get_model_size(model):