diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py index 40e6abbc..e28ef343 100644 --- a/opensora/models/layers/blocks.py +++ b/opensora/models/layers/blocks.py @@ -467,6 +467,11 @@ def forward(self, x, cond, mask=None): # query/value: img tokens; key: condition; mask: if padding tokens B, N, C = x.shape + if mask is None: + Bc, Nc, _ = cond.shape + assert Bc == B + mask = [Nc] * B + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) k, v = kv.unbind(2) @@ -504,6 +509,11 @@ def forward(self, x, cond, mask=None): B, SUB_N, C = x.shape # [B, TS/p, C] N = SUB_N * sp_size + if mask is None: + Bc, Nc, _ = cond.shape + assert Bc == B + mask = [Nc] * B + # shape: # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM] q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim) diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py index d730981c..f8b0e3f3 100644 --- a/opensora/utils/ckpt_utils.py +++ b/opensora/utils/ckpt_utils.py @@ -197,8 +197,8 @@ def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", stri from safetensors.torch import load_file state_dict = load_file(ckpt_path) missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - print(f"Missing keys: {missing_keys}") - print(f"Unexpected keys: {unexpected_keys}") + get_logger().info("Missing keys: %s", missing_keys) + get_logger().info("Unexpected keys: %s", unexpected_keys) elif os.path.isdir(ckpt_path): load_from_sharded_state_dict(model, ckpt_path, model_name, strict=strict) get_logger().info("Model checkpoint loaded from %s", ckpt_path) diff --git a/opensora/utils/train_utils.py b/opensora/utils/train_utils.py index 95d0011b..21391d24 100644 --- a/opensora/utils/train_utils.py +++ b/opensora/utils/train_utils.py @@ -153,9 +153,9 @@ def get_mask(self, x): elif mask_name == "random": mask_ratio = random.uniform(0.1, 0.9) mask = torch.rand(num_frames, device=x.device) > mask_ratio - # if mask is all False, set the last frame to True - if not mask.any(): - mask[-1] = 1 + # if mask is all False, set the last frame to True + if not mask.any(): + mask[-1] = 1 return mask