[PyTorch] Remove unnecessary save of weights #2549
Draft
+142
−151
Description
MCore's fused wgrad accumulation feature requires setting the `grad_added_to_main_grad` attribute on the weight's Python object. This means the original Python object must be accessible and modifiable during the backward pass.

Currently, weights are saved via `save_for_backward`, with the assumption that no hooks substitute them with different tensors (e.g., during CPU offload/reload). For CPU offloading, we work around this by saving weights directly on `ctx`. However, this approach is incompatible with non-TE CPU offloading scenarios and potentially conflicts with FSDP, which also manages weight tensors.

This PR addresses these issues by saving weak references to weights for the backward pass instead. When modifications to the original Python object are needed (e.g., setting `grad_added_to_main_grad`), the weakref is dereferenced and the modification is applied. This is done conditionally, only when MCore FSDP or MCore fused wgrad accumulation is enabled.

Changes:
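- Save weights as `weakref` in forward pass
- Dereference the weakref in backward only when `fuse_wgrad_accumulation` is enabled
- Remove the direct save of weights on `ctx`
- Apply the change to `linear.py`, `layernorm_linear.py`, `grouped_linear.py`, and `layernorm_mlp.py`

Type of change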
Checklist:
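
For reviewers, the weak-reference pattern described under Description can be illustrated with a minimal pure-Python sketch. This is not the code in this PR: `FakeWeight`, `FakeCtx`, `forward`, and `backward` are hypothetical stand-ins for the weight parameter, the autograd context, and the module's forward/backward functions. Only `grad_added_to_main_grad` and `fuse_wgrad_accumulation` are names taken from the description, and the sketch assumes the weakref is always stored in forward and only dereferenced when the flag is set.

```python
import weakref


class FakeWeight:
    """Hypothetical stand-in for a weight parameter (not a real tensor)."""

    def __init__(self, name):
        self.name = name
        self.grad_added_to_main_grad = False


class FakeCtx:
    """Hypothetical stand-in for the autograd context object."""


def forward(ctx, weight, fuse_wgrad_accumulation):
    # Store only a weak reference: ctx does not keep the weight alive,
    # yet backward can still reach the original Python object.
    ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
    ctx.weight_ref = weakref.ref(weight)


def backward(ctx):
    # Dereference only when the attribute actually has to be set.
    if not ctx.fuse_wgrad_accumulation:
        return False
    weight = ctx.weight_ref()  # None if the weight no longer exists
    if weight is None:
        return False
    weight.grad_added_to_main_grad = True
    return True


# Flag enabled: the original object is modified during backward.
w_fused = FakeWeight("fc1.weight")
ctx_fused = FakeCtx()
forward(ctx_fused, w_fused, fuse_wgrad_accumulation=True)
fused_result = backward(ctx_fused)

# Flag disabled: the weakref is never dereferenced.
w_plain = FakeWeight("fc2.weight")
ctx_plain = FakeCtx()
forward(ctx_plain, w_plain, fuse_wgrad_accumulation=False)
plain_result = backward(ctx_plain)
```

Because the attribute is set on whatever the weakref resolves to, the update reaches the original weight object even if a hook has swapped the tensor that `save_for_backward` would return.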