Implement Weight Normalization, addressing issue #1888 #1921
base: main
Conversation
Improve the weight normalization implementation by:
- Use the optimized C++ `mx.weight_norm()` in `WeightNormWrapper.call`
- Add comprehensive tests for `WeightNormConv2d`
- Verify direct API usage matches module wrapper results
- Test normalization over multiple axes and edge cases
- Add specific test for GitHub issue ml-explore#1888

This change ensures maximum performance by leveraging the C++ implementation with its optimized handling of >2 axes normalization.
Thanks a lot @cavit99, this is great work! One tiny nit:
Agreed from my side, so I pushed that change to the PR. Thank you!
Perfect! 🤩 Now we wait for @awni :)
He's gonna look and say "meh, maybe if you stick it in normalization.py", isn't he
I'm not certain about including this. Either way, we should not make free functions in C++ and Python for this; it should just be a layer in `mlx.nn`.
I gather it's making a comeback with realtime audio use cases (vocoders, TTS), because it is lighter than batch norm and because stability and convergence are critical there. Other than @Blaizzy's mlx-audio, which also implemented weight norm manually, I see torch's weight_norm being used in, for example, spark-tts, Nvidia NeMo, and coqui tts, and maintained in. Regarding the layer, of course you're right; if you agree it's useful, I will happily refactor it fully into mlx.nn, as it should have been from the start.
That would be great! |
Proposed changes
This PR implements weight normalization for MLX, addressing issue #1888. Weight normalization is a reparameterization technique that decouples the magnitude of a weight tensor from its direction, making optimization more efficient by improving the conditioning of the optimization problem. It is particularly important for audio processing, among other applications.
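The reparameterization can be sketched in a few lines of NumPy (an illustration of the math only, not the MLX implementation):

```python
import numpy as np

# Weight normalization: w = g * v / ||v||.
# The direction of w comes entirely from v, its magnitude entirely from g.
v = np.array([3.0, 4.0])           # direction carrier, ||v|| = 5
g = 2.0                            # magnitude parameter
w = g * v / np.linalg.norm(v)

print(w)                   # [1.2 1.6]
print(np.linalg.norm(w))   # ≈ 2.0 (equal to g, independent of ||v||)
```

Because gradient descent now updates `g` and `v` separately, the magnitude of the effective weight is learned directly rather than being entangled with its direction.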
Key Features
- `mx.weight_norm` with optimized paths for different dimensions
- `weight_norm.py` with a user-friendly API and layer wrappers
- Handling beyond the `linalg::norm` 2-axes limitation

Implementation Details
Core C++ Implementation

The core `weight_norm` operation is implemented with three different paths based on the number of axes to normalize over, reusing the existing `linalg::norm` kernels where possible.

Python Layer
The Python implementation provides a `weight_norm` function that wraps MLX modules.

Testing and Verification
Testing follows a comprehensive two-pronged approach:
1. Mathematical Property Tests
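As a sketch of the kind of property these tests check (using a NumPy reference implementation, since the exact test code is not reproduced here):

```python
import numpy as np

def weight_norm(v, g, axes):
    # NumPy reference used only for these property checks.
    n = np.sqrt(np.sum(np.square(v), axis=axes, keepdims=True))
    return g * v / n

rng = np.random.default_rng(42)
v = rng.standard_normal((16, 3, 3, 8))     # Conv2d-like weight
g = rng.uniform(0.5, 2.0, (16, 1, 1, 1))
axes = (1, 2, 3)                           # exercises the >2-axes path
w = weight_norm(v, g, axes)

# Property 1: the norm over the normalized axes reproduces g.
assert np.allclose(np.sqrt(np.sum(w ** 2, axis=axes, keepdims=True)), g)

# Property 2: w is invariant to positive rescaling of v,
# since the magnitude is carried by g alone.
assert np.allclose(weight_norm(3.5 * v, g, axes), w)
```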
2. Cross-Framework Verification
3. Performance Benchmarks
Benchmarks on an Apple M3 Max show MLX outperforming PyTorch MPS.
Usage Examples
Core API
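The original snippets did not survive the page export. Below is a hedged sketch of direct usage, with a NumPy stand-in for the proposed `mx.weight_norm` (argument names follow the PR description; the real signature may differ):

```python
import numpy as np

def mx_weight_norm(v, g, axes):
    # NumPy stand-in for the proposed mx.weight_norm(v, g, axes).
    n = np.sqrt(np.sum(np.square(v), axis=tuple(axes), keepdims=True))
    return g * v / n

# Direct use on a Linear-style weight: normalize each output row.
v = np.random.default_rng(7).standard_normal((32, 64))
g = np.ones((32, 1))
w = mx_weight_norm(v, g, axes=[1])  # each row of w now has norm g
```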
Module API
Resolves #1888.
Checklist
- I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes