fix: bugs and code quality improvements across prototype modules #4332
tatavishnurao wants to merge 3 commits into pytorch:main
Conversation
- Replace print() with logger.info() in fusion_utils.py (console output that cannot be suppressed during torch.compile)
- Add kernel availability guard to Int4OpaqueTensor.from_hp_da8w4() with clear error message
- Add kernel availability guard to QuantizedLinear._forward_2d() to prevent AttributeError
- Replace mutable default kwargs={} in _replace_embedding_with_quantized_embedding()
- Fix fragile stdout capture in test_rope_fusion_detection.py to use logger capture
- Add public API exports to embedding/__init__.py (EmbeddingQuantizer, QuantizedLinear, etc.)
- Remove unused _is_blackwell() from attention/utils.py
- Remove misconfigured @triton.autotune decorators (empty configs, constexpr key)
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4332
Xia-Weiwen
left a comment
The change to torchao/prototype/quantization/int4/int4_opaque_tensor.py looks good to me with the suggested fix of the build flag.
Co-authored-by: Xia Weiwen <xia.weiwen@hotmail.com>
Thanks for the review! I have applied the suggested change.
Diagnostics: Bug fixes and code quality improvements across prototype modules
Summary
Fixes several bugs and code quality issues found across torchao/prototype/attention/, torchao/prototype/quantization/int4/, and torchao/prototype/quantization/embedding/.
Changes
Bug fixes
- Replace print() with logger in fusion_utils.py — The production fusion pass used print() to report RoPE fusion results, which cannot be suppressed and floods console output during torch.compile. Switched to logger.info(). torchao/prototype/attention/shared_utils/fusion_utils.py:965-1108
- Add kernel availability guard to from_hp_da8w4() — Calling Int4OpaqueTensor.from_hp_da8w4() directly bypassed the _dispatch_dump assertion that protects the config handler path, resulting in a confusing RuntimeError from the C++ kernel. Added an early check with a clear error message. torchao/prototype/quantization/int4/int4_opaque_tensor.py
- Replace mutable default argument kwargs={} — _replace_embedding_with_quantized_embedding() used a mutable dict as a default argument (kwargs={}). While harmless in this case (the dict is only read, never mutated), it is a well-known Python antipattern and a source of subtle bugs. torchao/prototype/quantization/embedding/api.py:145
- Add kernel availability guard to QuantizedLinear — QuantizedLinear.forward() dynamically resolves torch.ops.torchao._linear_8bit_act_{N}bit_weight, which only exists when the C++ kernels are built. Unlike QuantizedEmbedding, there was no guard to prevent an AttributeError crash. Added an _is_kernel_library_loaded() assertion (see the sketch after this list). torchao/prototype/quantization/embedding/api.py:253
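For illustration only, a minimal sketch of the three patterns these fixes apply, a module-level logger in place of print(), an explicit kernel-availability check before resolving the custom op, and a None default in place of a mutable dict. The function bodies below are placeholders, not the actual torchao implementations; only the names quoted above are taken from the PR.

```python
import logging

import torch

logger = logging.getLogger(__name__)  # replaces print(); silent unless logging is configured


def _is_kernel_library_loaded(op_name: str) -> bool:
    # Assumption: the C++ kernels register custom ops under the torch.ops.torchao namespace.
    return hasattr(torch.ops.torchao, op_name)


def quantized_linear_forward(x: torch.Tensor, nbit: int):
    op_name = f"_linear_8bit_act_{nbit}bit_weight"
    if not _is_kernel_library_loaded(op_name):
        # Fail early with a clear message instead of an AttributeError deep inside forward().
        raise RuntimeError(
            f"torch.ops.torchao.{op_name} is not available; "
            "build torchao with its C++ kernels to use QuantizedLinear."
        )
    logger.info("dispatching to torch.ops.torchao.%s", op_name)
    op = getattr(torch.ops.torchao, op_name)
    ...


def _replace_embedding_with_quantized_embedding(module, kwargs=None):
    # None default instead of kwargs={}: a shared mutable default would leak state
    # across calls if it were ever mutated.
    kwargs = {} if kwargs is None else kwargs
    ...
```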
Test robustness
- Fix fragile stdout capture in test_rope_fusion_detection.py — The test asserted "1 fused with RoPE" appeared in stdout, which couples the test to the logging mechanism. Updated to capture logger output instead, keeping the test working after the print() → logger migration (a sketch of the pattern follows). test/prototype/attention/test_rope_fusion_detection.py:134-141
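A minimal pytest sketch of capturing logger output with the built-in caplog fixture. The fuse_rope helper, logger name, and test name here are placeholders, not the actual torchao test or fusion pass.

```python
import logging

logger = logging.getLogger("fusion_utils")  # placeholder logger name


def fuse_rope(model):
    # Placeholder for the real fusion pass; it reports results via the logger.
    logger.info("1 fused with RoPE")
    return model


def test_rope_fusion_is_reported(caplog):
    # Assert on captured log records instead of stdout, so the test
    # survives the print() -> logger migration.
    with caplog.at_level(logging.INFO, logger="fusion_utils"):
        fuse_rope(object())
    assert "1 fused with RoPE" in caplog.text
```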
Code quality
- Add public API exports to embedding/__init__.py — The module was empty, forcing users to import from api.py directly. Added re-exports for EmbeddingQuantizer, TiedEmbeddingQuantizer, QuantizedEmbedding, QuantizedEmbeddingFallback, and QuantizedLinear to match the pattern used by int4/__init__.py. torchao/prototype/quantization/embedding/__init__.py
- Remove unused _is_blackwell() — _is_blackwell() in attention/utils.py was defined but never called anywhere. Removed dead code. torchao/prototype/attention/utils.py
- Fix misconfigured triton @autotune decorators — hadamard_single_phase1_kernel, hadamard_v_phase1_kernel, and hadamard_rope_single_phase1_kernel had @triton.autotune configs with an empty {} (no tunable parameters) and key=["D"] where D is passed as a tl.constexpr. These decorators add launch overhead for zero tuning benefit. Removed the @triton.autotune decorators and inlined the num_warps value from the first config (see the sketch after this list). torchao/prototype/attention/quantization/triton_hadamard_qkv_quantization.py, torchao/prototype/attention/quantization/triton_hadamard_rope_qkv_quantization.py
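A minimal sketch of the resulting launch pattern, using a toy copy kernel rather than the Hadamard kernels: with no tunable meta-parameters and a key that is already a tl.constexpr, @triton.autotune only adds overhead, so the num_warps value from the single config is passed directly at launch.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, y_ptr, D: tl.constexpr):
    # Toy stand-in for the phase-1 kernels: each program copies D contiguous elements.
    pid = tl.program_id(0)
    offs = pid * D + tl.arange(0, D)
    tl.store(y_ptr + offs, tl.load(x_ptr + offs))


def copy(x: torch.Tensor, D: int = 128) -> torch.Tensor:
    y = torch.empty_like(x)
    grid = (x.numel() // D,)
    # num_warps inlined from the former single autotune config instead of @triton.autotune.
    copy_kernel[grid](x, y, D=D, num_warps=4)
    return y
```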
Issues not fixed (noted for future work)
- torch._C._dispatch_dump is an undocumented PyTorch internal API — Used in inference_workflow.py:135 and test_ops.py:118-119, 360, 383, 405-407 to check whether a C++ kernel is registered. Works reliably but is not part of PyTorch's public API. No better alternative exists currently (a sketch of the check is below).
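For reference, a sketch of this kind of registration check. The helper name and example op name are made up, and it assumes that torch._C._dispatch_dump raises a RuntimeError when the qualified operator name is not registered; that is undocumented internal behavior and may change between PyTorch releases.

```python
import torch


def _op_is_registered(qualified_name: str) -> bool:
    # Assumption: _dispatch_dump raises RuntimeError for an unknown operator name.
    try:
        torch._C._dispatch_dump(qualified_name)  # e.g. "torchao::_linear_8bit_act_4bit_weight"
        return True
    except RuntimeError:
        return False
```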
Test plan
- pytest test/prototype/attention/test_rope_fusion_detection.py — verify fusion detection tests still pass after the print() → logger migration
- pytest test/prototype/quantization/test_int4_opaque_tensor.py — verify the A16W4 path still works
- pytest test/prototype/quantization/test_embedding.py — verify the embedding quantization paths
- pytest test/quantization/test_quant_api.py — verify PrototypeInt4WeightOnlyConfig and Int8DynamicActivationInt4WeightConfig still register correctly