Blockwise Scaling for FP8 #1932
Conversation
Force-pushed from 57896c9 to 4d45e57
Super cool! Thanks for upstreaming :) Will do a full review soon. One comment to start: please do not extend the existing primary types in CUTLASS. E.g., the new collective builder should just be another specialized dispatch of the existing one. If you need to pass in extra arguments, you can include them in the builder dispatch policy itself. We want to ensure that there is always a single point of entry at each conceptual level.
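As a rough, self-contained sketch of the pattern this comment asks for (all type names below are placeholders, not the actual CUTLASS classes): a single `CollectiveBuilder` entry point, with block scaling handled by one more partial specialization whose extra arguments travel inside the dispatch policy.

```cpp
#include <type_traits>

// Hypothetical dispatch policies; the extra block-scaling knob rides in the policy itself.
template <int Stages, class ClusterShape, class KernelSchedule>
struct MainloopWarpSpecializedPolicy {};

template <int Stages, class ClusterShape, class KernelSchedule, int ScaleGranularityM>
struct MainloopWarpSpecializedBlockScalingPolicy {};

// Single primary entry point at this conceptual level.
template <class DispatchPolicy, class ElementA, class ElementB>
struct CollectiveBuilder;

// Existing specialization (placeholder body).
template <int Stages, class Cluster, class Schedule, class ElementA, class ElementB>
struct CollectiveBuilder<MainloopWarpSpecializedPolicy<Stages, Cluster, Schedule>,
                         ElementA, ElementB> {
  using CollectiveOp = void;  // stand-in for the existing mainloop type
};

// Block scaling is just another specialized dispatch of the same builder.
template <int Stages, class Cluster, class Schedule, int ScaleGranularityM,
          class ElementA, class ElementB>
struct CollectiveBuilder<
    MainloopWarpSpecializedBlockScalingPolicy<Stages, Cluster, Schedule, ScaleGranularityM>,
    ElementA, ElementB> {
  using CollectiveOp = void;  // stand-in for the block-scaled mainloop type
};

int main() {
  // The caller only ever names CollectiveBuilder; the policy selects the specialization.
  using Blockwise = CollectiveBuilder<
      MainloopWarpSpecializedBlockScalingPolicy<4, void, void, 128>, float, float>;
  static_assert(std::is_same_v<Blockwise::CollectiveOp, void>, "placeholder only");
  return 0;
}
```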
using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBlockScalingBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
Is it implicit that `ElementBlockScale` will always be the same as `ElementAccumulator`? If so, it would be good to add a static assert in the collective.
For now it is the case that `ElementBlockScale` = `ElementAccumulator`. I still prefer to have an alias type in the code to increase readability and make the scale tensors easy to find while reading and searching. Maybe in the future we will need different datatypes for `ElementBlockScale` and `ElementAccumulator`, although I don't see it happening as long as accumulation is in F32. I have added the static_asserts in case a user tries to set them differently, since we don't have a `NumericConverter` in `GmmaFP8Accumulation::scale_core`.
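A minimal sketch of what such a check could look like, using a trimmed-down, hypothetical collective (not the actual CUTLASS class):

```cpp
#include <type_traits>

// Hypothetical collective: only the pieces needed to show the static check.
template <class ElementAccumulator_, class ElementBlockScale_>
struct CollectiveMainloopWithBlockScaling {
  using ElementAccumulator = ElementAccumulator_;
  // Alias kept for readability even though it must match the accumulator type today.
  using ElementBlockScale  = ElementBlockScale_;

  // Scaling is a plain FFMA on the accumulator (no NumericConverter in the scale path),
  // so the two types must match until a conversion is added.
  static_assert(std::is_same_v<ElementBlockScale, ElementAccumulator>,
                "ElementBlockScale must equal ElementAccumulator: "
                "the scale path has no NumericConverter.");
};

int main() {
  CollectiveMainloopWithBlockScaling<float, float> ok;       // compiles: both F32
  // CollectiveMainloopWithBlockScaling<float, double> bad;  // would trip the static_assert
  (void)ok;
  return 0;
}
```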
warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum_); ++i) {
  accum_(i) += accum_temp_(i) * scale;
Two suggestions:
- One can do much more than just scaling accumulators with this approach, so you might want to consider generically adding an interface that allows the user to apply a point-wise operation on the accumulator per block.
- This implementation places assumptions/restrictions on the types of the scale and the accumulator. It could be enhanced to call an appropriate sub-function/implementation to ensure optimal scaling; that way, if one decides to pass a different scale type in the future, it can just be extended with a new overload/specialization.
This is a great suggestion and we can consider it in the future. For now, we are focused on scaling the accumulators blockwise and possibly groupwise. Let us discuss this more soon.
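As a rough illustration of the first suggestion, a per-block point-wise hook over the accumulator might look like the sketch below; the `Fragment` type and `promote_block` name are made up for the example, not CUTLASS APIs. Blockwise scaling then becomes just one functor passed to the generic hook.

```cpp
#include <array>
#include <cstddef>

// Stand-in for a register-resident accumulator fragment.
template <class T, std::size_t N>
using Fragment = std::array<T, N>;

// Generic hook: apply any element-wise functor to the accumulator after each
// K-block of MMAs completes. Scaling becomes just one choice of Op.
template <class T, std::size_t N, class Op>
void promote_block(Fragment<T, N>& accum, Fragment<T, N> const& accum_temp, Op&& op) {
  for (std::size_t i = 0; i < N; ++i) {
    accum[i] = op(accum[i], accum_temp[i]);
  }
}

int main() {
  Fragment<float, 4> accum{};
  Fragment<float, 4> accum_temp{1.f, 2.f, 3.f, 4.f};
  float scale = 0.5f;

  // Blockwise scaling is the special case: accum += accum_temp * scale (one FFMA per element).
  promote_block(accum, accum_temp,
                [scale](float acc, float tmp) { return acc + tmp * scale; });

  // A different per-block point-wise op drops in without touching the mainloop structure.
  promote_block(accum, accum_temp,
                [](float acc, float tmp) { return acc + tmp * tmp; });
  return 0;
}
```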
Force-pushed from f746e3c to 2cdd89a
Pushed a few commits to address some of the comments. Please take a look. I kept the commits after the initial commit separate so the different comments can be reviewed easily. I will squash all commits before the merge.
Thanks @thakkarV for the comment. Please see the diff here. This diff makes the CollectiveBuilder entry point another specialization by adding a new dispatch policy. I had something similar in the TODO, which is also removed now. Please ignore the accidental CMakeLists change; you won't see that change in the full diff.
Hi @manishucsd, could you help resolve the conflicts with the main branch?
Conflicts can be resolved trivially and I have pushed the result to https://github.com/soundOfDestiny/cutlass/tree/f8_blockwise_scaling_pr_branch_resolve_conflict_blockwise. In addition, I have implemented groupwise scaling granularity along M in the A tensor with a slight modification on this PR (https://github.com/manishucsd/cutlass/pull/1/files). Is it possible for reviewers to take a look at manishucsd#1 alongside PR #1932? Thanks! I also resolved conflicts on my groupwise branch and pushed it.
Force-pushed from e4025bb to 6834abc
Force-pushed from 6834abc to 5ddebb9
@zhyncs and @soundOfDestiny, I have rebased it on top of the latest commit (CUTLASS 3.6). Apologies for the delay, I was AFK. Looking forward to your use of this feature. We believe this PR will be a good starting point for groupwise, as your commit shows. Do you only plan to have groups in the M-dim? Are there use cases for groups in the N-dim? Do you already have a non-CUTLASS groupwise kernel to show how much groupwise helps numerically? Thanks @IonThruster for the review and discussion on handling synchronization with multiple producers. Thanks in advance @hwu36 for helping merge this.
Thank you very much! I have also updated my branch in manishucsd#1.

Having a smaller granularity in the K-dim than this PR would hurt performance a lot due to the internal structure of WGMMA. We can have a larger granularity in the K-dim by repeating the scaling tensor as input.

I think it's very useful to have groups in the N-dim, and we can implement them in the future by loading the group scale of B into […].

A special use case for groupwise is when the group granularity along the M-dim equals 1, i.e. a per-row (in M) x per-block (in K) scale for matrix A. This is used in https://github.com/deepseek-ai/DeepSeek-V3, which also provides a Triton kernel demo for GEMM with the per-row granularity scale (https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py).
Great to see it merged!!
Hi all @manishucsd @hwu36 @soundOfDestiny, I think we also need this: https://github.com/manishucsd/cutlass/pull/1/files. May we consider raising another PR to support that?
Yes, please file a PR to us directly.
I got it. I've just moved https://github.com/manishucsd/cutlass/pull/1/files to #2037.
* F8 Blockwise Scaling
* two more NumProducerThreadEvents

Co-authored-by: Haicheng Wu <[email protected]>
Summary
As we adopt narrower datatypes, traditional scaling methods struggle to maintain accuracy, particularly with 8-bit floating-point types (e.g., `e5m2_t`, `e4m3_t`). The typical GEMM operation uses tensorwise scaling with `D = alpha * (A @ B) + beta * C`, but narrower datatypes necessitate finer-grained scaling techniques. This PR adds a blockwise scaling strategy to improve accuracy while making an effort not to lose performance. Before we dive deep into blockwise scaling, below is a glossary of the various scaling methods:

[Glossary table of scaling methods omitted in this extract; it references `EpilogueVisitorTree`.]

This enhancement focuses on improving GEMM accuracy for narrow datatypes, balancing the trade-off between performance and precision with the addition of blockwise scaling support.
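To make the blockwise semantics concrete, here is a rough host-side reference of the scaled accumulation as I read it from the description above: the K loop is split into blocks, and each block's partial accumulation is promoted into the final accumulator scaled by the corresponding A- and B-block scales. Tile sizes, layouts, and the exact scale-tensor shapes are assumptions for illustration, not the kernel's actual configuration.

```cpp
#include <algorithm>
#include <vector>

// Reference blockwise-scaled GEMM: D = sum over K-blocks of (A_blk @ B_blk) * sA * sB.
// A and B are given as float for simplicity (stand-ins for fp8 values).
void blockwise_scaled_gemm_ref(
    int M, int N, int K, int BLK_M, int BLK_N, int BLK_K,
    std::vector<float> const& A,        // M x K, row-major
    std::vector<float> const& B,        // K x N, row-major
    std::vector<float> const& scale_A,  // ceil(M/BLK_M) x ceil(K/BLK_K), row-major
    std::vector<float> const& scale_B,  // ceil(N/BLK_N) x ceil(K/BLK_K), row-major
    std::vector<float>& D) {            // M x N, row-major
  int kblocks = (K + BLK_K - 1) / BLK_K;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int kb = 0; kb < kblocks; ++kb) {
        float partial = 0.f;  // accumulation within one K block
        for (int k = kb * BLK_K; k < std::min(K, (kb + 1) * BLK_K); ++k) {
          partial += A[m * K + k] * B[k * N + n];
        }
        float sa = scale_A[(m / BLK_M) * kblocks + kb];
        float sb = scale_B[(n / BLK_N) * kblocks + kb];
        acc += partial * sa * sb;  // blockwise promotion: accum += accum_temp * scale
      }
      D[m * N + n] = acc;
    }
  }
}
```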
Blockwise Scaling
The figure below illustrates a blockwise scaled GEMM, with operand tensors A and B shown in grey, block scaling tensors in blue, and the output in green. In this implementation, we load the operand tensors using `UTMALDG` and the block scaling tensors using `LDGSTS`, transferring them from global memory to shared memory. Block scaling tensor loads are issued for the same stage as the operand tensor loads. To ensure proper synchronization for `LDGSTS`, we use `cutlass::arch::cpasync_barrier_arrive` with the `noinc` modifier. We have modified the `PipelineTmaAsync` class to accommodate a variable number of producer thread arrival events to support this functionality effectively.
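A toy model of that synchronization change, assuming one elected TMA-issuing thread plus a set of threads issuing `LDGSTS` for the scale tensors; the struct and counts below are illustrative, not the actual `PipelineTmaAsync` interface.

```cpp
#include <cstdio>

// Toy model of the producer-side barrier bookkeeping described above.
struct ProducerBarrierModel {
  int tma_arrivals;      // elected thread signalling on behalf of the TMA copies
  int cpasync_arrivals;  // threads issuing LDGSTS for the scale tensors (noinc arrive)
  int expected() const { return tma_arrivals + cpasync_arrivals; }
};

int main() {
  // Baseline FP8 GEMM: only the elected TMA thread signals the producer barrier.
  ProducerBarrierModel baseline{1, 0};
  // Blockwise scaling: the LDGSTS-issuing threads must also be counted, which is why the
  // pipeline now takes the number of producer thread arrival events as a parameter
  // (1 TMA arrival + 32 cp.async arrivals for one warp of scale loads, as an assumed example).
  ProducerBarrierModel blockwise{1, 32};
  std::printf("expected producer arrivals: baseline=%d blockwise=%d\n",
              baseline.expected(), blockwise.expected());
  return 0;
}
```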
Performance

For the graph below, I used CUDA Toolkit 12.3.2. Please note that with the latest toolkit, 12.6.2, I observe LDLs and STLs in the SASS and the performance of block scaling is terrible. Thus, I stick with 12.3.2 for further performance optimization and look for sources of improvement. Please find the SASS attached for the example `54_hopper_fp8_warp_specialized_gemm` (F8 with slow accumulation, FADDs after QGMMAs inside the mainloop) and the new `64_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling` (F8 with slow accumulation, FFMAs after QGMMAs inside the mainloop).

NCU Profile (Kernel compiled with CUDA Toolkit 12.3.2)
I see one major stall for both the FADD version and the FFMA version soon after the QGMMA, waiting for the accumulator to be ready to apply the promotions and scaling, respectively. I don't see any other difference, other than that this stall is larger for FFMA.
FADD Version (Example 54 with slow accumulation and tensorwise scaling, modified to have the same tiling and kernel schedule as 64)
FFMA Version (Example 64 with slow accumulation and blockwise scaling)
Technically, for a large GEMM shape with the cooperative schedule, we would expect both versions to run at the same performance. Let me know if you have more input on what we are missing here to match the performance. We will eventually need a good implementation of a blockwise scaling kernel in CUTLASS for plain F8 GEMMs, and we can also take these learnings to FlashAttention-3 F8 scaling.
Attn: @IonThruster, @hwu36, @thakkarV