Conversation

@pstjohn (Contributor) commented Dec 30, 2025

Adds tests to ensure that the parameter values after reset_parameters() match their initial distributions.

Fixes #2528, #2529

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@greptile-apps bot commented Dec 30, 2025

Greptile Summary

This PR adds a test validating that reset_parameters() does not change the statistical properties (mean and standard deviation) of module parameters. The test addresses issues #2528 (LayerNormLinear's layer_norm_weight incorrectly reset from the constant 1.0 to random values near 0) and #2529 (Linear's bias incorrectly reset from the constant 0.0 to random values near 0).
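
For context, a minimal sketch of how #2528 would manifest (the module sizes are illustrative, and the printed values describe the expected behavior, not captured output):

```python
import transformer_engine.pytorch as te

# LayerNormLinear initializes layer_norm_weight to the constant 1.0.
module = te.LayerNormLinear(512, 512, device="cuda")
print(module.layer_norm_weight.mean().item())  # 1.0 at initialization

# Before the fix, reset_parameters() re-drew layer_norm_weight from a
# random distribution near 0 instead of restoring the constant 1.0.
module.reset_parameters()
print(module.layer_norm_weight.mean().item())  # ~0.0 (bug), 1.0 (fixed)
```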

The test is well-structured and parameterized across all core modules (LayerNorm, RMSNorm, Linear, LayerNormLinear, LayerNormMLP), ensuring comprehensive coverage.

Key observations:

  • The test correctly captures parameter statistics before and after the reset_parameters() call (sketched below)
  • It uses appropriate tolerances (1e-3) for floating-point comparisons
  • One minor issue: the failure message is built with a lambda instead of a string (see the review comment below)
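
A minimal sketch of such a stats-preserving check (not the PR's exact code; the module choice, sizes, and test name are assumptions):

```python
import torch
import transformer_engine.pytorch as te

def test_reset_parameters_preserves_stats():
    module = te.Linear(512, 512, device="cuda")

    # Record each parameter's mean and std before the reset.
    stats = {
        name: (param.float().mean(), param.float().std())
        for name, param in module.named_parameters()
    }

    module.reset_parameters()

    # After the reset, every parameter should still follow its initial
    # distribution (e.g. bias stays at 0.0 rather than becoming random).
    for name, param in module.named_parameters():
        mean, std = stats[name]
        torch.testing.assert_close(
            param.float().mean(), mean, atol=1e-3, rtol=1e-3,
            msg=f"mean of {name} changed after reset_parameters()",
        )
        torch.testing.assert_close(
            param.float().std(), std, atol=1e-3, rtol=1e-3,
            msg=f"std of {name} changed after reset_parameters()",
        )
```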

Confidence Score: 4/5

Important Files Changed

Filename: tests/pytorch/test_deferred_init.py
Overview: Adds a test verifying that reset_parameters() preserves parameter mean and std, catching initialization bugs #2528 and #2529

Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as test_reset_parameters_doesnt_change_parameter_stats
    participant Module as TE Module (Linear/LayerNormLinear/etc.)
    participant Params as Parameters (weight, bias, layer_norm_weight)

    Test->>Module: Initialize module on cuda device
    Module->>Params: Create parameters with initial values
    Note over Params: layer_norm_weight=1.0<br/>bias=0.0<br/>weight=random
    Test->>Params: Capture initial stats (mean, std)
    Note over Test: Store param_stats dict
    Test->>Module: Call reset_parameters()
    Module->>Params: Re-initialize parameters
    Note over Params: Should preserve distributions
    Test->>Params: Capture post-reset stats (mean, std)
    Note over Test: Store param_stats_after dict
    Test->>Test: Assert mean unchanged (atol=1e-3, rtol=1e-3)
    Test->>Test: Assert std unchanged (atol=1e-3, rtol=1e-3)
    Note over Test: Catches bugs #2528 & #2529<br/>if distributions change
```

@greptile-apps bot left a comment

Additional Comments (1)

  1. tests/pytorch/test_deferred_init.py, lines 110-123

    syntax: The msg parameter is passed a lambda, but torch.testing.assert_close expects a string here, not a callable. This will cause the error message to display <lambda> instead of the actual error details.

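A hypothetical sketch of the suggested fix (the tensors and message text are illustrative, not the PR's code):

```python
import torch

actual = torch.zeros(4)
expected = torch.ones(4)

# Passing `msg` as a plain string keeps the failure message readable;
# per the comment above, a lambda here would render as "<lambda>".
torch.testing.assert_close(
    actual, expected,
    msg="parameter stats changed after reset_parameters()",
)  # raises AssertionError carrying the message above
```
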
1 file reviewed, 1 comment


Development

Successfully merging this pull request may close these issues:

LayerNormLinear reset_parameters() leads to the wrong initialization.
