Add Automicrobatching for Non-Powers-of-2 + Fixes to FSDP deadlocks using Adaptive Sync Hooks #3503
Rerequest once test passes!
first pass, design looks right but code needs some cleanup
@no_type_check
def unshard(self):
    """
    Run the unshard logic.
    This is an unpatched method from pytorch, meant to be reverted to
    whenever automicrobatching turns off its hooks for increased throughput.
    This includes all-gathering the flat parameter
    and switching to using the unsharded flat parameter. If the handle does
    not need unsharding, then this only switches to using the unsharded
    flat parameter. For ``NO_SHARD``, this is a no-op.
    If FSDP is in :meth:`summon_full_params` and the handle uses parameter
This should probably be in the if torch 2.3.1 section
if auto_microbatching:
can you add a comment on what this is doing?
def _double_device_train_microbatch_size(state: State):
    """Double device_train_microbatch_size when automicrobatching searches upward for a higher non-OOM microbatch size.
should this go into the automicrobatching utils folder?
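For readers following the thread, the upward search named in that docstring can be sketched roughly as below. This is a minimal illustration, not the PR's code: `SearchState` and `double_microbatch_size` are hypothetical stand-ins for Composer's `State` and the helper under review, and capping at the per-device batch size is an assumption.

```python
from dataclasses import dataclass


@dataclass
class SearchState:
    """Hypothetical stand-in for the two fields of Composer's State used here."""
    device_train_microbatch_size: int
    device_batch_size: int  # assumed upper bound for the search


def double_microbatch_size(state: SearchState) -> bool:
    """Double the microbatch size, capped at the device batch size.

    Returns True if the size changed, False if it was already at the cap.
    """
    current = state.device_train_microbatch_size
    proposed = min(current * 2, state.device_batch_size)
    state.device_train_microbatch_size = proposed
    return proposed != current


# Start from a non-power-of-2 size and keep doubling until the cap is hit.
state = SearchState(device_train_microbatch_size=3, device_batch_size=10)
sizes = []
while double_microbatch_size(state):
    sizes.append(state.device_train_microbatch_size)
print(sizes)
```

A helper this small has no dependency on the trainer, which is one argument for moving it into a utils module as suggested.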
    num_consecutive_thrashes = 0
    return num_consecutive_thrashes


def _handle_downward_search_in_automicrobatching(
    state: State,
    lowest_oom_microbatch_size: int,
    highest_non_oom_microbatch_size: int,
    lower_bound_microbatch_size: int,
    num_search_steps: int,
    max_search_steps: int,
):
same comment on moving to utils?
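The parameter names in that signature suggest a bounded bisection between the largest size known to fit and the smallest size known to OOM. A rough sketch under that assumption follows; `search_downward` and the `fits` probe are hypothetical and are not taken from the PR.

```python
from typing import Callable


def search_downward(
    fits: Callable[[int], bool],
    lowest_oom: int,
    highest_non_oom: int,
    max_search_steps: int = 10,
) -> int:
    """Bisect between a known-good and a known-bad microbatch size.

    `fits(size)` is a hypothetical probe returning False when `size` would OOM.
    `highest_non_oom` is assumed to fit and `lowest_oom` is assumed to OOM.
    """
    for _ in range(max_search_steps):
        if lowest_oom - highest_non_oom <= 1:
            break  # no untested size remains between the two bounds
        candidate = (lowest_oom + highest_non_oom) // 2
        if fits(candidate):
            highest_non_oom = candidate
        else:
            lowest_oom = candidate
    return highest_non_oom


# Pretend anything above 11 samples per microbatch runs out of memory.
def fits(size: int) -> bool:
    return size <= 11


best = search_downward(fits, lowest_oom=16, highest_non_oom=8)
print(best)
```

Because the midpoint is not restricted to powers of 2, the search can settle on a size such as 11 that a halving-only strategy would skip. The step budget bounds how many extra probes are spent.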
@@ -1251,6 +1421,7 @@ def __init__(
        if parallelism_config is not None:
            # Patch PyTorch to fix distributed bugs
            patch_pytorch()
            patch_unshard_for_automicrobatching(self.auto_microbatch_size_found)
this should be just part of patch_pytorch to simplify interface
we need to pass in a boolean variable telling it how to patch this one specific method though. i feel like it would be less readable if we passed self.auto_microbatch_size_found directly into patch_pytorch
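To make the trade-off in this exchange concrete, here is a small sketch of a boolean-controlled monkey patch that can be reverted. Everything in it is hypothetical: `Handle` stands in for the FSDP handle class, `_unshard_with_sync` for the hooked variant, and the reading that the stock method is restored once the microbatch size has been found is an assumption based on the docstring earlier in the thread.

```python
class Handle:
    """Hypothetical stand-in for the FSDP handle class whose `unshard` gets patched."""

    def unshard(self) -> str:
        return "original"


# Keep a reference to the stock method so the patch can be reverted later.
_ORIGINAL_UNSHARD = Handle.unshard


def _unshard_with_sync(self) -> str:
    # A real hook would synchronize ranks here before unsharding.
    return "synced+" + _ORIGINAL_UNSHARD(self)


def patch_unshard(auto_microbatch_size_found: bool) -> None:
    """Use the stock method once the size is found, the syncing variant while searching."""
    if auto_microbatch_size_found:
        Handle.unshard = _ORIGINAL_UNSHARD
    else:
        Handle.unshard = _unshard_with_sync


patch_unshard(auto_microbatch_size_found=False)
during_search = Handle().unshard()

patch_unshard(auto_microbatch_size_found=True)
after_search = Handle().unshard()

print(during_search, after_search)
```

Keeping the flag on a dedicated function, as the author argues, keeps a general-purpose patch entry point free of an argument that affects only one method. Folding it in, as the reviewer suggests, leaves a single call site. Either way the saved reference to the original method is what makes the revert possible.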
# Sync for OOMs
found_cuda_oom = _found_ooms_across_ranks(self.state, found_cuda_oom)
this block is really complicated. let's move it to a helper fn
with torch.no_grad(), model_eval_mode(self.state.model):
    if self.state.fsdp_enabled and self.first_batch_complete:
        print("readd hooks for eval")
remove?
@@ -8,6 +8,18 @@
    convert_nested_dict_to_flat_dict,
    extract_hparams,
)
from composer.utils.automicrobatching import (
    # _create_sync_hook,
remove?
@@ -164,4 +176,14 @@
    'validate_credentials',
    'build_remote_backend',
    'RemoteFilesExistingCheckStatus',
    # '_create_sync_hook',
remove?
No description provided.