[WIP] Use DTensor-based tensor parallel #180

Open: wants to merge 1 commit into base branch gh/kwen2501/1/base
model.py (16 changes: 4 additions & 12 deletions)
@@ -156,30 +156,22 @@ def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0

total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# key, query, value projections for all heads, but in a batch
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None

self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)

def load_hook(self, state_dict, prefix, *args):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
q, k, v = self.wq(x), self.wk(x), self.wv(x)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
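Note (not part of this diff): a likely motivation for dropping the fused wqkv projection is that, with grouped-query attention, its output concatenates three blocks of different sizes, so an even column shard cuts across the q/k/v boundaries and needs the manual split-then-shard logic in tp.py, which DTensor's ColwiseParallel does not know about. With separate projections, each weight's out_features is divisible by the TP degree on its own. A minimal shape sketch (dimensions are illustrative, not from a specific config):

import torch.nn as nn

dim, n_head, n_local_heads, head_dim = 4096, 32, 8, 128

# Fused layout: a single (q_size + 2 * kv_size, dim) weight; an even shard
# along dim 0 mixes pieces of q, k and v on each rank.
wqkv = nn.Linear(dim, (n_head + 2 * n_local_heads) * head_dim, bias=False)

# Separate layout: each out_features is individually divisible by the TP
# world size, so a plain colwise shard keeps whole heads on each rank.
wq = nn.Linear(dim, n_head * head_dim, bias=False)
wk = nn.Linear(dim, n_local_heads * head_dim, bias=False)
wv = nn.Linear(dim, n_local_heads * head_dim, bias=False)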
scripts/convert_hf_checkpoint.py (11 changes: 3 additions & 8 deletions)
@@ -83,7 +83,7 @@ def convert_hf_checkpoint(
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}


def permute(w, n_head):
dim = config.dim
@@ -116,13 +116,8 @@ def permute(w, n_head):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
final_result[key] = permute(q, config.n_head)
final_result[key.replace("wq", "wk")] = permute(k, config.n_local_heads)
else:
final_result = merged_result
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
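Note (not part of this diff): with this change the converter keeps the attention projections under their original keys and only applies the rotary permutation to wq/wk, so the converted checkpoint maps directly onto the new separate nn.Linear modules and the old wqkv-merging load hook in model.py is no longer needed. An illustrative sketch of the resulting key layout (layer index and shapes are examples only, not from a specific checkpoint):

import torch

dim, n_head, n_local_heads, head_dim = 4096, 32, 8, 128

converted = {
    # Before this change there was a single fused entry per layer:
    #   "layers.0.attention.wqkv.weight" of shape ((n_head + 2 * n_local_heads) * head_dim, dim)
    # After it, the three projections stay separate (wq/wk permuted, wv untouched):
    "layers.0.attention.wq.weight": torch.empty(n_head * head_dim, dim),
    "layers.0.attention.wk.weight": torch.empty(n_local_heads * head_dim, dim),
    "layers.0.attention.wv.weight": torch.empty(n_local_heads * head_dim, dim),
}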
tp.py (104 changes: 63 additions & 41 deletions)
@@ -4,10 +4,13 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from enum import Enum
from typing import List, Optional

import torch
import torch.distributed as dist
from torch.distributed import DeviceMesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch import nn
if os.uname().sysname != "Darwin":
from torch.distributed import _functional_collectives as funcol
@@ -16,7 +19,7 @@
funcol = None

from model import Attention, FeedForward, Transformer
from quantize import WeightOnlyInt4Linear
from quantize import WeightOnlyInt4Linear, WeightOnlyInt8Linear


def _get_rank() -> int:
@@ -33,6 +36,12 @@ def local_break():
def _get_world_size() -> int:
return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

global device_mesh

def _get_tp_mesh():
# device_mesh has only TP dimension for now
return device_mesh

def maybe_init_dist() -> Optional[int]:
try:
# provided by torchrun
@@ -48,86 +57,97 @@ def maybe_init_dist() -> Optional[int]:

torch.cuda.set_device(rank)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

global device_mesh
device_mesh = dist.init_device_mesh(
"cuda",
(world_size,), # Only TP dimension for now
)
return rank
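Note (not part of this diff): the mesh created here is one-dimensional, so every rank in the job belongs to the tensor-parallel group. For reference, the same setup can be written against the public device-mesh API with a named dimension; a sketch assuming torchrun has set the usual RANK/WORLD_SIZE/MASTER_ADDR variables, with the "tp" name purely illustrative:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# One mesh dimension -> pure tensor parallelism across all ranks.
tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))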

class TPMode(Enum):
MANUAL = 0
DTENSOR = 1

def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
def _apply_tp_linear(linear: nn.Linear, style: str) -> None:
rank = _get_rank()
world_size = _get_world_size()
tp_mesh = _get_tp_mesh()

# Linear's weight matrix is transposed, and is of shape
# (linear.out_features, linear.in_features)
dim_lookup = {
"colwise": (0, "out_features"),
"rowwise": (1, "in_features")
"colwise": (0, "out_features", ColwiseParallel()),
"rowwise": (1, "in_features", RowwiseParallel()),
}
assert style in dim_lookup
shard_dim, size_attr = dim_lookup[style]
shard_dim, size_attr, tp_plan = dim_lookup[style]

# ensure we can shard evenly
assert getattr(linear, size_attr) % world_size == 0
def shard(x, dim):
assert x.size(dim=dim) % world_size == 0
return torch.tensor_split(x, world_size, dim=dim)[rank]

def shard_qkv(qkv, dim, weight_splits):
q, k, v = qkv.split(weight_splits, dim=dim)
q = shard(q, dim)
k = shard(k, dim)
v = shard(v, dim)
return torch.cat((q,k,v), dim=dim)

# shard
if weight_splits:
# attention
assert len(weight_splits) == 3

if isinstance(linear, WeightOnlyInt4Linear):
sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
else:
sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
if hasattr(linear, "scales") and style == "colwise":
linear.scales = shard_qkv(linear.scales, 0, weight_splits)
else:
sharded_weight = shard(linear.weight, shard_dim)
if isinstance(linear, WeightOnlyInt4Linear):
def shard_scale(linear, shard_dim):
if hasattr(linear, "scales_and_zeros"):
linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
if style == "rowwise":
assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3]
assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
if hasattr(linear, "scales") and style == "colwise":
linear.scales = shard(linear.scales, 0)
elif hasattr(linear, "scale"):
if style == "colwise":
linear.scales = shard(linear.scales, 0)

# shard
tp_mode: TPMode
if isinstance(linear, (WeightOnlyInt4Linear, WeightOnlyInt8Linear)):
# TODO: DTensor doesn't have a way to distribute quantized tensor yet.
# Should revisit when that capability is added.
sharded_weight = shard(linear.weight, shard_dim)
linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
shard_scale(linear, shard_dim)
tp_mode = TPMode.MANUAL
else:
# Use DTensor based TP
parallelize_module(linear, tp_mesh, tp_plan)
tp_mode = TPMode.DTENSOR

# local_break()
linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
setattr(linear, size_attr, getattr(linear, size_attr) // world_size)

# shape info should still be synced
# assert linear.weight.shape == (linear.out_features, linear.in_features)
return tp_mode
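Note (not part of this diff): on the non-quantized path the sharding is delegated entirely to DTensor. parallelize_module replaces the linear's weight with a DTensor sharded over the mesh and handles input/output redistribution according to the plan, which is why that branch needs none of the manual slicing used for the quantized layers above. A self-contained sketch of the same colwise/rowwise pairing on a toy module (names and sizes are illustrative; assumes a process group is already initialized as earlier):

import torch
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class ToyMLP(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.up = nn.Linear(dim, hidden, bias=False)    # to be sharded colwise
        self.down = nn.Linear(hidden, dim, bias=False)  # to be sharded rowwise

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(torch.relu(self.up(x)))


def shard_toy_mlp(mlp: ToyMLP, world_size: int) -> ToyMLP:
    tp_mesh = init_device_mesh("cuda", (world_size,))
    # With the default layouts, ColwiseParallel leaves a column-sharded
    # activation that RowwiseParallel consumes and all-reduces on output,
    # so no manual collective hook is required on this path.
    return parallelize_module(
        mlp,
        tp_mesh,
        {"up": ColwiseParallel(), "down": RowwiseParallel()},
    )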


def _apply_tp_ffn(mlp: FeedForward) -> None:
assert hasattr(mlp, "w1")
assert hasattr(mlp, "w3")
assert hasattr(mlp, "w2")

_apply_tp_linear(mlp.w1, "colwise")
_apply_tp_linear(mlp.w3, "colwise")
_apply_tp_linear(mlp.w2, "rowwise")
tp_mode = _apply_tp_linear(mlp.w1, "colwise")
tp_mode = _apply_tp_linear(mlp.w3, "colwise")
tp_mode = _apply_tp_linear(mlp.w2, "rowwise")

world_size = _get_world_size()
mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
output, "sum", list(range(world_size))))
if tp_mode == TPMode.MANUAL:
# In manual mode, we need to manually add an all-reduce at the end
world_size = _get_world_size()
mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
output, "sum", list(range(world_size))))
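Note (not part of this diff): the hook is only needed on the manual path because each rank then holds a column shard of the first projections and a row shard of the last one, so every rank produces a partial output that must be summed; DTensor's RowwiseParallel performs the same reduction internally. A quick CPU check of that identity on a simplified two-linear MLP (no process group needed; shapes are illustrative, and the real FeedForward is gated, which does not change the argument):

import torch

torch.manual_seed(0)
world_size, dim, hidden = 2, 8, 16
x = torch.randn(4, dim)
w1 = torch.randn(hidden, dim)   # colwise: split along out_features (dim 0)
w2 = torch.randn(dim, hidden)   # rowwise: split along in_features (dim 1)

full = torch.relu(x @ w1.T) @ w2.T

partials = []
for r in range(world_size):
    w1_r = torch.tensor_split(w1, world_size, dim=0)[r]
    w2_r = torch.tensor_split(w2, world_size, dim=1)[r]
    partials.append(torch.relu(x @ w1_r.T) @ w2_r.T)

# Summing the per-rank partials reproduces the unsharded result -- exactly
# what the "sum" all-reduce computes across ranks.
assert torch.allclose(sum(partials), full, atol=1e-5)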


def _apply_tp_attn(attn: Attention) -> None:
assert hasattr(attn, "wqkv")
assert hasattr(attn, "wq")
assert hasattr(attn, "wk")
assert hasattr(attn, "wv")
assert hasattr(attn, "wo")

kv_size = attn.n_local_heads * attn.head_dim
_apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
_apply_tp_linear(attn.wo, "rowwise")
tp_mode = _apply_tp_linear(attn.wq, "colwise")
tp_mode = _apply_tp_linear(attn.wk, "colwise")
tp_mode = _apply_tp_linear(attn.wv, "colwise")
tp_mode = _apply_tp_linear(attn.wo, "rowwise")

# overwrite
world_size = _get_world_size()
@@ -136,8 +156,10 @@ def _apply_tp_attn(attn: Attention) -> None:
attn.head_dim = attn.dim // attn.n_head
attn.n_local_heads = attn.n_local_heads // world_size

attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
output[0], "sum", list(range(world_size))))
if tp_mode == TPMode.MANUAL:
# In manual mode, we need to manually add an all-reduce at the end
attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
output[0], "sum", list(range(world_size))))
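Note (not part of this diff): after the colwise shard of wq/wk/wv and the rowwise shard of wo, the overwrites above keep the module's head bookkeeping consistent with its per-rank weights. A quick arithmetic check with illustrative head counts:

# Example: 32 query heads, 8 kv heads, head_dim 128, TP degree 4.
dim, n_head, n_local_heads, world_size = 4096, 32, 8, 4

dim_per_rank = dim // world_size                    # 1024
n_head_per_rank = n_head // world_size              # 8 query heads per rank
head_dim = dim_per_rank // n_head_per_rank          # 128 -- unchanged by TP
n_kv_heads_per_rank = n_local_heads // world_size   # 2 kv heads per rank
assert head_dim == dim // n_head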


def _apply_tp_Transformer(Transformer: Transformer) -> None:
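Note (not part of this diff): for orientation, these helpers are typically wired together by a single entry point that runs after maybe_init_dist and walks the transformer blocks. A hypothetical driver sketch, assuming the gpt-fast attribute names model.layers, block.attention and block.feed_forward:

from model import Transformer

def apply_tp_sketch(model: Transformer) -> None:
    # Assumes maybe_init_dist() has already created the process group and mesh.
    _apply_tp_Transformer(model)           # adjust model-level dims for TP (body collapsed in this diff)
    for block in model.layers:
        _apply_tp_ffn(block.feed_forward)  # colwise w1/w3, rowwise w2
        _apply_tp_attn(block.attention)    # colwise wq/wk/wv, rowwise wo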