Conversation
Signed-off-by: Ziang Li <ziangli@umich.edu>
for more information, see https://pre-commit.ci
Greptile Overview

Greptile Summary

Critical Issues:
Implementation:
Confidence Score: 1/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User
participant Recipe
participant Linear
participant BasicLinear
participant Quantize
User->>Recipe: Set NVTE_KEEP_BACKWARD_UNQUANTIZED=1
Recipe->>Recipe: quantize_backward = False
Note over Recipe: DelayedScaling: CRASHES HERE<br/>(assertion at line 220)
User->>Linear: forward(input)
Linear->>Linear: keep_backward_unquantized = True
Linear->>Linear: save_original_input = True
Linear->>Quantize: quantize(input)
Quantize->>Quantize: Check recipe.quantize_forward
Note over Quantize: Potential crash if recipe is None
Quantize-->>Linear: quantized_input (FP8)
Linear->>BasicLinear: forward(quantized_input, weight)
BasicLinear->>BasicLinear: Save high-precision input for backward
BasicLinear-->>Linear: output
User->>Linear: backward(grad_output)
Linear->>BasicLinear: backward(grad_output)
Note over BasicLinear: Uses high-precision saved tensors<br/>Skip quantization in backward
BasicLinear->>BasicLinear: wgrad = grad_output @ input_hp
BasicLinear->>BasicLinear: dgrad = grad_output @ weight_hp
BasicLinear-->>Linear: grad_input (high precision)
Linear-->>User: gradients (BF16/FP32)
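For intuition, here is a hedged numerical sanity check of the flow the diagram describes. This is only a sketch, not the PR's test: it assumes an FP8-capable GPU, that the env var is read before the forward pass, and it uses Float8CurrentScaling because the thread notes DelayedScaling rejects an unquantized backward.

import os
os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"  # assumed to be read before the forward pass

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

torch.manual_seed(0)
layer = te.Linear(128, 128, params_dtype=torch.bfloat16).cuda()
ref = torch.nn.Linear(128, 128, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    ref.weight.copy_(layer.weight)
    ref.bias.copy_(layer.bias)

x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
x_ref = x.detach().clone().requires_grad_()

with te.fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
    y = layer(x)        # forward GEMM runs quantized
y.sum().backward()      # wgrad/dgrad use the saved high-precision tensors

ref(x_ref).sum().backward()

# With an unquantized backward, gradients should track the BF16 reference closely.
torch.testing.assert_close(x.grad, x_ref.grad, rtol=5e-2, atol=5e-2)
torch.testing.assert_close(layer.weight.grad, ref.weight.grad, rtol=5e-2, atol=5e-2)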
I'll work on potential unit test breakage.
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
for more information, see https://pre-commit.ci
Signed-off-by: Ziang Li <ziangli@umich.edu>
… is used

Signed-off-by: Ziang Li <ziangli@umich.edu>
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
ln_out_hp = ln_out if keep_backward_unquantized else None
Storing both ln_out (quantized) and ln_out_hp (high precision) roughly doubles the memory footprint for this activation. Verify that this overhead is acceptable for your target models, especially during training with large batch sizes or long sequences.
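For a rough sense of scale (hypothetical shape, not from the PR), the extra BF16 copy alone costs:

# Hypothetical activation shape; the added cost is one BF16 copy of ln_out
# (2 bytes per element), on top of the quantized tensor that is saved today.
batch, seq, hidden = 8, 4096, 8192
extra_gib = 2 * batch * seq * hidden / 2**30
print(f"extra activation memory per layer: {extra_gib:.2f} GiB")  # 0.50 GiB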
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
# Check if FP8 is enabled
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
quantize_forward = fp8_enabled and self._quantize_forward
quantize_backward = fp8_enabled and self._quantize_backward
quantize_backward = (
Recipe None crash
FP8GlobalStateManager.get_fp8_recipe() can be None (e.g., if FP8 is enabled but no recipe was set), so ...get_fp8_recipe().quantize_backward will raise an AttributeError. This makes Quantize.forward() crash in that configuration; please guard for None (e.g., default to quantized backward when the recipe is unset, or raise an explicit error with a clearer message).
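A minimal sketch of such a guard inside the Quantize op (variable names follow the hunk above; falling back to the op's own flags when no recipe is set is just one possible policy):

# Sketch: None-safe recipe override.
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_enabled else None

quantize_forward = fp8_enabled and self._quantize_forward
quantize_backward = fp8_enabled and self._quantize_backward
if recipe is not None:
    # Apply recipe-level overrides only when a recipe has actually been set.
    quantize_forward = quantize_forward and recipe.quantize_forward
    quantize_backward = quantize_backward and recipe.quantize_backward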
Signed-off-by: Ziang Li <ziangli@umich.edu>
for more information, see https://pre-commit.ci
assert not (
    not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
This assertion prevents using NVTE_KEEP_BACKWARD_UNQUANTIZED=1 with the DelayedScaling recipe: when the env var is set, quantize_backward becomes False, so the assert fails and blocks the entire feature for this recipe type.
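Since the diff elsewhere notes that the env var is ignored when delayed scaling is used, one hedged option is to warn and fall back instead of asserting. A sketch only, assuming these checks live in DelayedScaling.__post_init__ and the dataclass is mutable:

import warnings

# Sketch: ignore the env var for DelayedScaling instead of crashing.
if not self.quantize_backward:
    warnings.warn(
        "DelayedScaling does not support quantize_backward=False; "
        "ignoring NVTE_KEEP_BACKWARD_UNQUANTIZED for this recipe."
    )
    self.quantize_backward = True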
assert (
    not keep_backward_unquantized
), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
This assertion makes LayerNormMLP crash immediately when NVTE_KEEP_BACKWARD_UNQUANTIZED=1 is set. If this module doesn't support the feature, either implement it or handle the case more gracefully, with a clear error message, before reaching this point.
Additional Comments (1)
Signed-off-by: Ziang Li <ziangli@umich.edu>
# Save state for backward pass
if ctx.requires_grad:
    saved_input = input_ if keep_backward_unquantized else x_local
    saved_weight = self.weight if keep_backward_unquantized else w
    if is_cpu_offload_enabled():
        mark_activation_offload(x_local)
    ctx.save_for_backward(x_local, w)
    ctx.with_quantized_compute = with_quantized_compute
        mark_activation_offload(saved_input)
    ctx.save_for_backward(saved_input, saved_weight)
Unnecessary saved tensors
The forward path now saves saved_input/saved_weight whenever ctx.requires_grad is true, even when weight_requires_grad or input_requires_grad is false. In cases like frozen weights (common for finetuning) or when only one side needs grads, this saves extra tensors and can materially increase activation memory. The prior logic of setting saved_input = None when not weight_requires_grad, and saved_weight = None when not input_requires_grad, avoided that.
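A sketch of how the earlier pruning could coexist with the new flag (names follow the hunk and the comment above; weight_requires_grad / input_requires_grad as in the prior code):

# Sketch: honor keep_backward_unquantized, but still drop tensors that
# backward will never read (no wgrad -> input not needed, no dgrad -> weight not needed).
if ctx.requires_grad:
    saved_input = input_ if keep_backward_unquantized else x_local
    saved_weight = self.weight if keep_backward_unquantized else w
    if not weight_requires_grad:
        saved_input = None
    if not input_requires_grad:
        saved_weight = None
    if is_cpu_offload_enabled() and saved_input is not None:
        mark_activation_offload(saved_input)
    ctx.save_for_backward(saved_input, saved_weight)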
Signed-off-by: Ziang Li <ziangli@umich.edu>
Additional Comments (1)
Currently, without
quantize_forward : bool, default = True
    Whether to quantize tensors in the forward pass.
quantize_backward : bool, default = True
    Whether to quantize tensors in the backward pass.
Not sure we need that for the custom recipe, since there we can just specify the quantizers we want, but sure, we can have it to keep the API consistent.
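For reference, a hypothetical sketch of how the flags might be used on a recipe, assuming they end up as plain recipe fields as the docstring above describes (the recipe class and keyword argument here are illustrative, not confirmed API):

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

# Hypothetical: quantize the forward GEMMs only, keep wgrad/dgrad in high precision.
recipe = Float8CurrentScaling(quantize_backward=False)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = linear_layer(inp)  # linear_layer and inp are placeholders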
)
assert (
    not keep_backward_unquantized
), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
)
if keep_backward_unquantized:
    # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used
    save_original_input = True
We should also make sure that we don't create the columnwise version of the input.
The input_quantizer's columnwise usage is disabled here: https://github.com/NVIDIA/TransformerEngine/pull/2644/changes/BASE..253873a4560b2c2a2c909918cc3ee26500e5b43d#diff-864ad36a21c571fb178499535cfada611df4a82223c9ffbfea872dda39972eaeR335-R342
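For context, a sketch of the kind of call involved, assuming the quantizer's set_usage(rowwise=..., columnwise=...) helper from current TE:

# Sketch: wgrad will use the saved high-precision input, so the quantized
# input only needs its rowwise (forward-GEMM) layout.
if keep_backward_unquantized and input_quantizer is not None:
    input_quantizer.set_usage(rowwise=True, columnwise=False)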
@zianglih Thank you for your contribution!
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
for more information, see https://pre-commit.ci
assert not (
    not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
Blocks NVTE_KEEP_BACKWARD_UNQUANTIZED=1 with DelayedScaling: when the env var is set, quantize_backward becomes False, triggering this assertion and preventing the feature from working with this recipe type.
assert (
    not keep_backward_unquantized
Hard crash when NVTE_KEEP_BACKWARD_UNQUANTIZED=1 is set: LayerNormMLP becomes completely unusable with this env var.
Hi @zhongbozhu @timmoon10 @ptrendx, thank you so much for reviewing! I have implemented and added the unit test. All new tests passed:
Hi @timmoon10, @ptrendx,
This design was from @timmoon10's comment here: #2644 (comment)
Which way do we prefer? Thanks!
Signed-off-by: Ziang Li <ziangli@umich.edu>
assert not (
    not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
Blocks NVTE_KEEP_BACKWARD_UNQUANTIZED=1 with the DelayedScaling recipe: when the env var is set, quantize_backward becomes False, but this assertion requires it to be True, so the feature cannot work with this recipe type at all.
Suggested change:
- assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
+ # Note: DelayedScaling does not support quantize_backward=False yet
assert (
    not keep_backward_unquantized
), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
Hard crash when NVTE_KEEP_BACKWARD_UNQUANTIZED=1: setting the env var makes LayerNormMLP completely unusable; it crashes immediately on first use.
Suggested change:
- assert (
-     not keep_backward_unquantized
- ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
+ if keep_backward_unquantized:
+     raise NotImplementedError(
+         "NVTE_KEEP_BACKWARD_UNQUANTIZED is not yet implemented in LayerNormMLP"
+     )
# Recipe quantize overrides
if FP8GlobalStateManager.get_fp8_recipe() is not None:
    quantize_forward = (
        quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward
    )
    quantize_backward = (
        quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward
get_fp8_recipe() returns None when FP8 is enabled but no recipe has been set; calling .quantize_backward on None will crash with an AttributeError.
Suggested change:
- # Recipe quantize overrides
- if FP8GlobalStateManager.get_fp8_recipe() is not None:
-     quantize_forward = (
-         quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward
-     )
-     quantize_backward = (
-         quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward
+ # Recipe quantize overrides
+ recipe = FP8GlobalStateManager.get_fp8_recipe()
+ if recipe is not None:
+     quantize_forward = quantize_forward and recipe.quantize_forward
+     quantize_backward = quantize_backward and recipe.quantize_backward
Full unit test results, with the newly added
Description
Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.
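A hypothetical end-to-end usage sketch (the recipe choice is illustrative; per the review discussion, DelayedScaling does not support the unquantized backward path):

import os
os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"  # quantized fprop, high-precision bprop

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
    y = layer(x)    # forward GEMM in FP8
y.sum().backward()  # wgrad and dgrad computed from the saved BF16 tensors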
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: