-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathgemm_operation_3x.hpp
701 lines (607 loc) · 30.7 KB
/
gemm_operation_3x.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines operations for all GEMM operation kinds in CUTLASS Library.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/detail/collective.hpp"
#include "cutlass/array.h"
#include "cutlass/array_subbyte.h"
#include "cutlass/library/library.h"
#include "library_internal.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cute/tensor.hpp"
#include <unordered_map>
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::library {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Common base for CUTLASS 3.x GEMM library operations.
///
/// The constructor fills a GemmDescription from the compile-time traits exposed by `Operator_`
/// (tile/cluster/warp/instruction shapes, element types, layouts, alignments, arch range) so the
/// library can report and filter kernels without instantiating them.
///
/// \tparam Operator_  Device-level GEMM operator type providing the nested traits read below.
template <typename Operator_>
class GemmOperation3xBase : public Operation {
public:
  using Operator = Operator_;
  using OperatorArguments = typename Operator::Arguments;
  using ElementA = typename Operator::ElementA;
  using LayoutA = typename Operator::LayoutA;
  using ElementB = typename Operator::ElementB;
  using LayoutB = typename Operator::LayoutB;
  using ElementC = typename Operator::ElementC;
  using LayoutC = typename Operator::LayoutC;
  using ElementD = typename Operator::ElementD;
  using LayoutD = typename Operator::LayoutD;
  // assuming all tensors use same type for StrideIndex
  using StrideIndex = typename Operator::LayoutA::Index;
  using ElementAccumulator = typename Operator::ElementAccumulator;
  using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;

protected:
  /// Description populated once by the constructor and returned by the accessors below
  GemmDescription description_;

public:
  /// Constructor
  ///
  /// \param name        Name stored in the description.
  /// \param gemm_kind_  GEMM kind stored in the description.
  GemmOperation3xBase(char const *name = "unknown_gemm", GemmKind gemm_kind_ = GemmKind::kGemm) {
    description_.name = name;
    description_.provider = Provider::kCUTLASS;
    description_.kind = OperationKind::kGemm;
    description_.gemm_kind = gemm_kind_;

    description_.tile_description.threadblock_shape = make_Coord(
      Operator::ThreadblockShape::kM,
      Operator::ThreadblockShape::kN,
      Operator::ThreadblockShape::kK);

    // The cluster shape is only recorded for kernels whose ArchTag reports compute capability >= 90
    if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) {
      description_.tile_description.cluster_shape = make_Coord(
        Operator::ClusterShape::kM,
        Operator::ClusterShape::kN,
        Operator::ClusterShape::kK);
    }

    description_.tile_description.threadblock_stages = Operator::kStages;

    description_.tile_description.warp_count = make_Coord(
      Operator::WarpCount::kM,
      Operator::WarpCount::kN,
      Operator::WarpCount::kK);

    description_.tile_description.math_instruction.instruction_shape = make_Coord(
      Operator::InstructionShape::kM,
      Operator::InstructionShape::kN,
      Operator::InstructionShape::kK);

    description_.tile_description.math_instruction.element_accumulator =
      NumericTypeMap<ElementAccumulator>::kId;

    description_.tile_description.math_instruction.opcode_class =
      OpcodeClassMap<typename Operator::OperatorClass>::kId;

    description_.tile_description.math_instruction.math_operation =
      MathOperationMap<typename Operator::MathOperator>::kId;

    description_.tile_description.minimum_compute_capability =
      ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMin;

    description_.tile_description.maximum_compute_capability =
      ArchMap<typename Operator::ArchTag, typename Operator::OperatorClass>::kMax;

    description_.A = make_TensorDescription<ElementA, LayoutA>(Operator::kAlignmentA);
    description_.B = make_TensorDescription<ElementB, LayoutB>(Operator::kAlignmentB);
    description_.C = make_TensorDescription<ElementC, LayoutC>(Operator::kAlignmentC);
    description_.D = make_TensorDescription<ElementD, LayoutD>(Operator::kAlignmentD);
    description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;

    description_.split_k_mode = SplitKMode::kNone;
    description_.transform_A = ComplexTransformMap<Operator::kTransformA>::kId;
    description_.transform_B = ComplexTransformMap<Operator::kTransformB>::kId;
  }

  /// Returns the description of the GEMM operation (as the generic OperationDescription base)
  OperationDescription const & description() const override {
    return description_;
  }

  /// Returns the description of the GEMM operation with its full GemmDescription type
  GemmDescription const& get_gemm_description() const {
    return description_;
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Operator_>
class GemmUniversal3xOperation : public GemmOperation3xBase<Operator_> {
public:
using Operator = Operator_;
using OperatorArguments = typename Operator::Arguments;
using ElementA = typename Operator::ElementA;
using LayoutA = typename Operator::LayoutA;
using ElementB = typename Operator::ElementB;
using LayoutB = typename Operator::LayoutB;
using ElementC = typename Operator::ElementC;
using LayoutC = typename Operator::LayoutC;
using ElementD = typename Operator::ElementD;
using LayoutD = typename Operator::LayoutD;
using ElementAccumulator = typename Operator::ElementAccumulator;
using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute;
using CollectiveMainloop = typename Operator::CollectiveMainloop;
using CollectiveEpilogue = typename Operator::CollectiveEpilogue;
using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp;
static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>();
static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>();
static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) ||
(!IsRuntimeDataTypeA && !IsRuntimeDataTypeB),
"ElementA and ElementB in a GEMM kernel should be both runtime or both static.");
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
public:
/// Constructor
///
/// Registers the operation as GemmKind::kUniversal. For kernels whose ArchTag reports a minimum
/// compute capability of exactly 90, the device is additionally queried once for the maximum
/// number of active clusters and the result is cached in `max_active_clusters`.
GemmUniversal3xOperation(char const *name = "unknown_gemm"):
  GemmOperation3xBase<Operator_>(name, GemmKind::kUniversal) {
  if constexpr (Operator::ArchTag::kMinComputeCapability == 90) {
    using Kernel = typename Operator::GemmKernel;
    using KernelClusterShape = typename Kernel::ClusterShape;

    // Launch geometry handed to the occupancy query
    dim3 cluster_extent(
      cute::size<0>(KernelClusterShape{}),
      cute::size<1>(KernelClusterShape{}),
      cute::size<2>(KernelClusterShape{}));
    uint32_t block_threads = Kernel::MaxThreadsPerBlock;
    void const* kernel_entry = (void*)(device_kernel<Kernel>);

    max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters(
      cluster_extent, block_threads, kernel_entry);
  }
}
private:
int max_active_clusters{};
protected:
/// Seeds the kernel arguments from the static configuration.
///
/// Only the GEMM mode is copied here. GemmUniversalConfiguration carries neither problem shapes
/// nor batch strides, and TMA descriptors cannot be built until every argument is known, so the
/// remaining fields are filled in by update_arguments_.
///
/// \return Always Status::kSuccess.
static Status construct_arguments_(
    OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) {
  auto const &config = *configuration;
  operator_args.mode = config.mode;
  return Status::kSuccess;
}
/// Fallback used when FusionArgs has no `alpha` member (see the partial specialization).
/// Nothing is written: when a custom EVT is instantiated it is the user's responsibility
/// to ensure alpha and beta are updated appropriately.
template<class FusionArgs, class = void>
struct UpdateFusionArgs {
  /// No-op; always reports success.
  static Status update_(FusionArgs const& /*fusion_args*/, GemmUniversalArguments const& /*arguments*/) {
    return Status::kSuccess;
  }
};
/// Specialization selected when `FusionArgs{}.alpha` is a valid expression.
/// Copies alpha/beta into the fusion arguments either by value (host pointer mode) or by
/// pointer (device pointer mode).
template<class FusionArgs>
struct UpdateFusionArgs<FusionArgs, cute::void_t<decltype(FusionArgs{}.alpha)>> {
  /// \return Status::kSuccess for host/device pointer modes, Status::kErrorInvalidProblem otherwise.
  static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) {
    switch (arguments.pointer_mode) {
      case ScalarPointerMode::kHost: {
        // Scalars live on the host: read them now and clear the pointer form
        fusion_args.alpha = *static_cast<ElementCompute const *>(arguments.alpha);
        fusion_args.beta = *static_cast<ElementCompute const *>(arguments.beta);
        fusion_args.alpha_ptr = nullptr;
        fusion_args.beta_ptr = nullptr;
        return Status::kSuccess;
      }
      case ScalarPointerMode::kDevice: {
        // Scalars live on the device: zero the by-value form and forward the pointers
        fusion_args.alpha = 0;
        fusion_args.beta = 0;
        fusion_args.alpha_ptr = static_cast<ElementCompute const *>(arguments.alpha);
        fusion_args.beta_ptr = static_cast<ElementCompute const *>(arguments.beta);
        return Status::kSuccess;
      }
      default:
        return Status::kErrorInvalidProblem;
    }
  }
};
template<template<int, class, class> class Policy, int Stages, class ClusterShape, class KernelSchedule>
static constexpr bool is_mixed_dtype_mainloop_(Policy<Stages, ClusterShape, KernelSchedule> policy) {
return (cute::is_same_v<Policy<Stages, ClusterShape, KernelSchedule>,
cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<Stages, ClusterShape, KernelSchedule>>);
}
template <class DispatchPolicy>
static constexpr bool is_mixed_dtype_mainloop_(DispatchPolicy) {
return false;
}
/// Produces the dequantized copy of the narrow operand and, for int4 x fp8 kernels (`is_n4w8`),
/// also the encoded int4 tensor and the packed scales.
///
/// The narrow operand is B when `wider_operand` is A, and A otherwise. Its layout is built from
/// the packed stride `ActualStrideAB` over (problem_mn, problem_k, options_l).
///
/// \param operator_args  Kernel arguments; the narrow operand pointer is redirected to the encoded
///                       buffer when `is_n4w8` is set.
/// \param arguments      Library arguments providing the source and destination buffers.
/// \param stream         Stream passed to `dequantize`.
/// \param options_g      Group size passed to `dequantize`.
/// \param ptr_S, ptr_Z   Scale and zero-point buffers laid out as `layout_SZ`.
/// \param SZ_size        Element count passed to `pack_scale_fp8`.
template <
  typename ElementWide,
  typename ElementNarrow,
  typename ElementScaleMainloop,
  class ActualStrideAB,
  Sm90MixedInputWiderOperand wider_operand,
  bool is_n4w8,
  typename ElementScale,
  typename ElementZero,
  class Layout_SZ>
static void dequantize_encode_(
    OperatorArguments &operator_args,
    GemmUniversalArguments const *arguments,
    cudaStream_t stream,
    const int &problem_mn,
    const int &problem_k,
    const int &options_l,
    const int &options_g,
    ElementScale *ptr_S,
    ElementZero *ptr_Z,
    const size_t &SZ_size,
    Layout_SZ layout_SZ
) {
  constexpr bool kNarrowIsB = (wider_operand == Sm90MixedInputWiderOperand::A);

  // Layout of the narrow operand as stored by the caller
  auto narrow_extent = cute::make_shape(problem_mn, problem_k, options_l);
  auto narrow_layout = cute::make_layout(
    narrow_extent, cutlass::make_cute_packed_stride(ActualStrideAB{}, narrow_extent));

  const ElementNarrow *narrow_src = nullptr;
  if constexpr (kNarrowIsB) {
    narrow_src = static_cast<const ElementNarrow *>(arguments->B);
  }
  else {
    narrow_src = static_cast<const ElementNarrow *>(arguments->A);
  }

  auto *dequantized_dst = static_cast<ElementWide *>(arguments->dequantized_AB);
  dequantize(dequantized_dst, narrow_src, narrow_layout, ptr_S, ptr_Z, layout_SZ, options_g, stream);

  if constexpr (is_n4w8) {
    // int4 x fp8: the kernel reads an encoded int4 tensor and packed scales
    size_t narrow_count = cute::size(narrow_layout);
    auto *encoded_dst = static_cast<cutlass::int4b_t *>(arguments->encoded_AB);
    unified_encode_int4b(narrow_src, encoded_dst, narrow_count);

    if constexpr (kNarrowIsB) {
      operator_args.mainloop.ptr_B = static_cast<ElementNarrow const *>(encoded_dst);
    }
    else {
      operator_args.mainloop.ptr_A = static_cast<ElementNarrow const *>(encoded_dst);
    }

    auto *packed_scale_dst = static_cast<ElementScaleMainloop *>(arguments->packed_Scale);
    pack_scale_fp8(ptr_S, packed_scale_dst, SZ_size);
  }
}
/// Installs the reordered (shuffled) layout for the narrow operand and, when the caller asked for
/// dequantized data to be generated, physically reorders the narrow tensor into
/// `arguments->encoded_AB` and points the kernel at it.
///
/// The narrow operand is B when `wider_operand` is A, and A otherwise.
///
/// \param operator_args  Kernel arguments; `mainloop.dA`/`dB` receives the reordered layout and
///                       `mainloop.ptr_A`/`ptr_B` may be redirected to the reordered tensor.
/// \param arguments      Library arguments (reads `generate_dequantized_AB` and `encoded_AB`).
/// \param problem_mn     Extent of the narrow operand along its non-K mode.
/// \param problem_k      K extent.
/// \param options_l      Batch count.
template <
  typename ElementAB,
  class ActualStrideAB,
  class LayoutAB_Reordered,
  class LayoutAtomQuant,
  Sm90MixedInputWiderOperand wider_operand>
static void handle_shuffle_tensor_(
    OperatorArguments &operator_args,
    GemmUniversalArguments const *arguments,
    const int &problem_mn,
    const int &problem_k,
    const int &options_l) {
  auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l);
  auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB);
  auto layout_AB = cute::make_layout(shape_AB, stride_AB);
  LayoutAB_Reordered layout_AB_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_AB);

  if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
    operator_args.mainloop.dB = layout_AB_reordered;
  }
  else {
    operator_args.mainloop.dA = layout_AB_reordered;
  }

  if (arguments->generate_dequantized_AB) {
    size_t const AB_size = cute::size(layout_AB);

    // RAII scratch buffer: released on every exit path, including an exception thrown by the
    // reorder or the device-to-device copy below (a manual allocate/free pair would leak there).
    cutlass::device_memory::allocation<ElementAB> AB_reordered(AB_size);

    const ElementAB *AB_src = nullptr;
    if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
      AB_src = static_cast<const ElementAB *>(operator_args.mainloop.ptr_B);
    }
    else {
      AB_src = static_cast<const ElementAB *>(operator_args.mainloop.ptr_A);
    }

    reorder_tensor(AB_src, layout_AB, AB_reordered.get(), layout_AB_reordered);

    // Publish the reordered tensor in the caller-provided buffer
    ElementAB *AB_dst = static_cast<ElementAB *>(arguments->encoded_AB);
    cutlass::device_memory::copy_device_to_device(AB_dst, AB_reordered.get(), AB_size);

    if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
      operator_args.mainloop.ptr_B = AB_dst;
    }
    else {
      operator_args.mainloop.ptr_A = AB_dst;
    }
  }
}
/// Constructs the kernel arguments structure from the type-erased library arguments.
///
/// \param operator_args  Kernel arguments to populate.
/// \param arguments      Library-level GEMM arguments (pointers, leading dimensions, scalars, ...).
/// \param stream         CUDA stream forwarded to the mixed-input dequantization helper.
/// \return Status::kSuccess on success; the status reported by UpdateFusionArgs if the scalar
///         pointer mode is rejected; Status::kErrorInvalidDataType if a runtime A/B datatype has
///         no MXF8F6F4 mapping; Status::kErrorInvalidProblem if a scaled mixed-input kernel is
///         given a non-positive K extent.
Status update_arguments_(
    OperatorArguments& operator_args,
    GemmUniversalArguments const* arguments,
    cudaStream_t stream = nullptr) const {
  Status status = Status::kSuccess;

  status = UpdateFusionArgs<decltype(operator_args.epilogue.thread)>::update_(
    operator_args.epilogue.thread, *arguments);
  if (status != Status::kSuccess) {
    return status;
  }

  // TODO: type erase Arguments structure in 3.0 GEMM
  operator_args.problem_shape = cute::make_shape(
    arguments->problem_size.m(),
    arguments->problem_size.n(),
    arguments->problem_size.k(),
    arguments->batch_count);

  // update arguments
  if constexpr (IsRuntimeDataType) {
    using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA;
    using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB;
    operator_args.mainloop.ptr_A = static_cast<ArrayElementA const *>(arguments->A);
    operator_args.mainloop.ptr_B = static_cast<ArrayElementB const *>(arguments->B);

    // Built once and shared by all calls; the table never changes.
    static const std::unordered_map<RuntimeDatatype, cute::UMMA::MXF8F6F4Format> mapping = {
      {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3},
      {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2},
      {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2},
      {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1}
    };

    auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a);
    auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b);

    // An unmapped runtime datatype is reported to the caller. (Asserting on a string literal
    // would never fire, because a literal always converts to true.)
    if (iter_runtime_a == mapping.end() || iter_runtime_b == mapping.end()) {
      return Status::kErrorInvalidDataType;
    }
    operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second;
    operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second;
  }
  else {
    operator_args.mainloop.ptr_A = static_cast<ElementA const *>(arguments->A);
    operator_args.mainloop.ptr_B = static_cast<ElementB const *>(arguments->B);
  }
  operator_args.epilogue.ptr_C = static_cast<ElementC const *>(arguments->C);
  operator_args.epilogue.ptr_D = static_cast<ElementD *>(arguments->D);

  // Stride{A,B} is a Layout if and only if:
  // (1) This is a mixed dtype kernel, and
  // (2) This mixed dtype kernel is using shuffling, and
  // (3) sizeof(narrow_type) == 4 or 8 bits, and
  // (4) sizeof(wide_type) == 16 bits.
  // If A/B has the narrow data type, Stride{A/B} will be a Layout
  constexpr bool is_StrideA_Layout = cute::is_layout<typename CollectiveMainloop::StrideA>::value;
  constexpr bool is_StrideB_Layout = cute::is_layout<typename CollectiveMainloop::StrideB>::value;
  static_assert(!(is_StrideA_Layout && is_StrideB_Layout), "Incorrect kernel configuration: StrideA and StrideB are both cute::Layout");
  if constexpr(!is_StrideA_Layout) {
    operator_args.mainloop.dA = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideA>(
      arguments->lda, arguments->batch_stride_A);
  }
  if constexpr(!is_StrideB_Layout) {
    operator_args.mainloop.dB = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideB>(
      arguments->ldb, arguments->batch_stride_B);
  }
  operator_args.epilogue.dC = cute::make_int_tuple_from<typename Operator::GemmKernel::StrideC>(
    arguments->ldc, arguments->batch_stride_C);
  // D reuses the stride computed for C
  operator_args.epilogue.dD = operator_args.epilogue.dC;

  using MainloopPolicy = typename CollectiveMainloop::DispatchPolicy;
  if constexpr(is_mixed_dtype_mainloop_(MainloopPolicy{})) {
    int problem_m = arguments->problem_size.m();
    int problem_n = arguments->problem_size.n();
    int problem_k = arguments->problem_size.k();
    int options_l = arguments->batch_count;

    constexpr Sm90MixedInputWiderOperand wider_operand =
      (cutlass::sizeof_bits<ElementA>::value > cutlass::sizeof_bits<ElementB>::value) ?
      Sm90MixedInputWiderOperand::A : Sm90MixedInputWiderOperand::B;
    using ElementWide = std::conditional_t<wider_operand == Sm90MixedInputWiderOperand::A, ElementA, ElementB>;
    using ElementNarrow = std::conditional_t<wider_operand == Sm90MixedInputWiderOperand::A, ElementB, ElementA>;

    constexpr bool has_scale = !std::is_same_v<typename CollectiveMainloop::ElementScale, void>;
    constexpr bool has_zero = !std::is_same_v<typename CollectiveMainloop::ElementZero, void>;

    if constexpr(has_scale) {
      // The group size equals K, so a non-positive K would divide by zero below.
      if (problem_k <= 0) {
        return Status::kErrorInvalidProblem;
      }
      int options_g = problem_k;
      int scale_k = (problem_k + options_g - 1) / options_g;

      constexpr bool is_A4B8 = (
        cutlass::is_same_v<ElementA, cutlass::int4b_t> &&
        (cutlass::is_same_v<ElementB, cutlass::float_e4m3_t> ||
         cutlass::is_same_v<ElementB, cutlass::float_e5m2_t>));
      constexpr bool is_A8B4 = (
        cutlass::is_same_v<ElementB, cutlass::int4b_t> &&
        (cutlass::is_same_v<ElementA, cutlass::float_e4m3_t> ||
         cutlass::is_same_v<ElementA, cutlass::float_e5m2_t>));
      constexpr bool is_int4_x_fp8 = is_A4B8 || is_A8B4;

      // In int4 * fp8, ElementScale is a cutlass::Array, need to take out its real element
      using ElementScaleMainloop = typename CollectiveMainloop::ElementScale;
      using ElementScale = typename UnderlyingElement<typename CollectiveMainloop::ElementScale>::type;
      using StrideS = typename CollectiveMainloop::StrideScale;
      // In ScaleOnly mode, we have allocated the same size of memory for arguments->Z and arguments->S
      using ElementZero = std::conditional_t<
        has_zero,
        typename CollectiveMainloop::ElementZero,
        ElementScale
      >;

      // Scales follow the narrow operand: N-major when A is wider (narrow B), M-major otherwise
      const int SZ_1st_dim = (wider_operand == Sm90MixedInputWiderOperand::A) ? problem_n : problem_m;
      // Widen before multiplying so the element count cannot overflow `int`
      const size_t SZ_size =
        static_cast<size_t>(SZ_1st_dim) * static_cast<size_t>(scale_k) * static_cast<size_t>(options_l);
      auto shape_SZ = cute::make_shape(SZ_1st_dim, scale_k, options_l);
      ElementScale *ptr_S = static_cast<ElementScale *>(arguments->Scale);
      ElementZero *ptr_Z = static_cast<ElementZero *>(arguments->Zero);

      // 1. If arguments is initialized in profiler, S and Z needs to be allocated and filled
      if (arguments->generate_scale_and_zero) {
        // Need to fix max_dequant_val and min_dequant_val?
        const float elt_max_f = float(cutlass::platform::numeric_limits<ElementScale>::max());
        const float max_dequant_val = elt_max_f * 0.25f;
        const float min_dequant_val = 0.5f;
        const float scale_max = max_dequant_val / elt_max_f;
        const float scale_min = min_dequant_val / elt_max_f;
        uint64_t seed = 2023;
        cutlass::reference::device::BlockFillRandomUniform(
          ptr_S, SZ_size, seed, ElementScale(scale_max), ElementScale(scale_min));

        // In ScaleOnly mode, set Z as zero for generating dequantized A or B
        const float zero_max = has_zero ? 2.0f : 0.0f;
        const float zero_min = has_zero ? -2.0f : 0.0f;
        cutlass::reference::device::BlockFillRandomUniform(
          ptr_Z, SZ_size, seed, ElementZero(zero_max), ElementZero(zero_min));
      } // End of "if (arguments->generate_scale_and_zero)"

      // 2. Generate the dequantized A or B for verification
      if (arguments->generate_dequantized_AB) {
        StrideS stride_SZ = cutlass::make_cute_packed_stride(StrideS{}, shape_SZ);
        auto layout_SZ = cute::make_layout(shape_SZ, stride_SZ);

        if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) {
          // Narrow operand is B, whose non-K extent is N (matching SZ_1st_dim and the shuffle
          // handling below).
          if constexpr(is_StrideB_Layout) {
            // The generator only generates row-major A and col-major B at the moment
            // Need a way to read out the actual layout of B later
            using ActualLayoutB = cutlass::layout::ColumnMajor;
            using ActualStrideB = cutlass::detail::TagToStrideB_t<ActualLayoutB>;
            dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideB, wider_operand, is_A8B4>(
              operator_args, arguments, stream, problem_n, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
          }
          else {
            using ActualStrideB = typename CollectiveMainloop::StrideB;
            dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideB, wider_operand, is_A8B4>(
              operator_args, arguments, stream, problem_n, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
          }
        }
        else {
          // Narrow operand is A, whose non-K extent is M
          if constexpr(is_StrideA_Layout) {
            // The generator only generates row-major A and col-major B at the moment
            // Need a way to read out the actual layout of A later
            using ActualLayoutA = cutlass::layout::RowMajor;
            using ActualStrideA = cutlass::detail::TagToStrideA_t<ActualLayoutA>;
            dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideA, wider_operand, is_A4B8>(
              operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
          }
          else {
            using ActualStrideA = typename CollectiveMainloop::StrideA;
            dequantize_encode_<ElementWide, ElementNarrow, ElementScaleMainloop, ActualStrideA, wider_operand, is_A4B8>(
              operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ);
          }
        } // End of "if constexpr(wider_operand == Sm90MixedInputWiderOperand::A)"

        arguments->dequantized_AB_ready[0] = true;
      } // End of "if (arguments->generate_dequantized_AB)"

      // 3. Put arguments in mainloop
      if constexpr(is_int4_x_fp8) {
        operator_args.mainloop.ptr_S = static_cast<ElementScaleMainloop const*>(arguments->packed_Scale);
      }
      else {
        operator_args.mainloop.ptr_S = static_cast<ElementScale const*>(arguments->Scale);
      }
      operator_args.mainloop.dS = cutlass::make_cute_packed_stride(StrideS{}, shape_SZ);
      operator_args.mainloop.group_size = options_g;
      if constexpr(has_zero) {
        operator_args.mainloop.ptr_Z = static_cast<ElementZero const*>(arguments->Zero);
      }
    } // End of "if constexpr(has_scale)"

    // Handle the shuffling
    using ValueShuffle = std::conditional_t<
      cutlass::sizeof_bits<ElementNarrow>::value == 4,
      cute::Layout<cute::Shape<cute::_2,cute::_4>, cute::Stride<cute::_4,cute::_1>>,
      cute::Layout<cute::Shape<cute::_2,cute::_2>, cute::Stride<cute::_2,cute::_1>>
    >;
    constexpr int NumShuffleAtoms = 1;
    using MmaAtomShape = cute::Layout<cute::Shape<cute::_1,cute::Int<NumShuffleAtoms>>>;
    using LayoutAtomQuant = decltype(compute_memory_reordering_atom<ElementWide, MmaAtomShape, ValueShuffle>());
    // The generator only generates row-major A and col-major B at the moment
    // Need a way to read out the actual layout and stride of A/B later
    if constexpr(wider_operand == Sm90MixedInputWiderOperand::A && is_StrideB_Layout) {
      using ActualLayoutB = cutlass::layout::ColumnMajor;
      using ActualStrideB = cutlass::detail::TagToStrideB_t<ActualLayoutB>;
      using LayoutB_Reordered = typename CollectiveMainloop::StrideB;
      handle_shuffle_tensor_<ElementB, ActualStrideB, LayoutB_Reordered, LayoutAtomQuant, wider_operand>(
        operator_args, arguments, problem_n, problem_k, options_l);
    }
    if constexpr(wider_operand == Sm90MixedInputWiderOperand::B && is_StrideA_Layout) {
      using ActualLayoutA = cutlass::layout::RowMajor;
      using ActualStrideA = cutlass::detail::TagToStrideA_t<ActualLayoutA>;
      using LayoutA_Reordered = typename CollectiveMainloop::StrideA;
      handle_shuffle_tensor_<ElementA, ActualStrideA, LayoutA_Reordered, LayoutAtomQuant, wider_operand>(
        operator_args, arguments, problem_m, problem_k, options_l);
    }
  } // End of "if constexpr(is_mixed_dtype_mainloop_(MainloopPolicy{}))"

  /* Pass the SM count and (where needed) the cached max active clusters on to the kernel */
  operator_args.hw_info.sm_count = arguments->sm_count;
  if constexpr (Operator::ArchTag::kMinComputeCapability == 90) {
    operator_args.hw_info.max_active_clusters = max_active_clusters;
  }

  // Scheduler knobs are only written when the scheduler exposes them as mutable members
  if constexpr (!std::is_const_v<decltype(operator_args.scheduler.max_swizzle_size)>) {
    operator_args.scheduler.max_swizzle_size = arguments->swizzle_size;
  }
  if constexpr (!std::is_const_v<decltype(operator_args.scheduler.raster_order)>) {
    using Enum_t = decltype(operator_args.scheduler.raster_order);
    switch (arguments->raster_order) {
      case RasterOrder::kAlongN:
        operator_args.scheduler.raster_order = Enum_t::AlongN;
        break;
      case RasterOrder::kAlongM:
        operator_args.scheduler.raster_order = Enum_t::AlongM;
        break;
      default:
        operator_args.scheduler.raster_order = Enum_t::Heuristic;
    }
  }

  if constexpr (std::is_same_v<typename Operator::GemmKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
    operator_args.scheduler.splits = arguments->split_k_slices;
  }

  if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) {
    operator_args.hw_info.cluster_shape = dim3(
      arguments->cluster_shape.m(),
      arguments->cluster_shape.n(),
      arguments->cluster_shape.k());
    operator_args.hw_info.cluster_shape_fallback = dim3(
      arguments->cluster_shape_fallback.m(),
      arguments->cluster_shape_fallback.n(),
      arguments->cluster_shape_fallback.k());
  }

  return status;
}
public:
/// Returns success if the operation can proceed
///
/// Builds the kernel arguments from `arguments_ptr` and asks the operator whether it can run
/// them. A failure while building the arguments is returned unchanged.
Status can_implement(
    [[maybe_unused]] void const *configuration_ptr, void const *arguments_ptr) const override {
  auto const *gemm_args = static_cast<GemmUniversalArguments const *>(arguments_ptr);

  OperatorArguments kernel_args;
  Status const prepared = update_arguments_(kernel_args, gemm_args);
  if (prepared != Status::kSuccess) {
    return prepared;
  }

  return Operator::can_implement(kernel_args);
}
/// Gets the host-side workspace size in bytes: room for exactly one Operator object.
uint64_t get_host_workspace_size(void const *configuration) const override {
  constexpr uint64_t kOperatorBytes = sizeof(Operator);
  return kOperatorBytes;
}
/// Gets the device-side workspace
uint64_t get_device_workspace_size(
void const *configuration_ptr,void const *arguments_ptr) const override {
OperatorArguments args;
auto status = update_arguments_(
args, static_cast<GemmUniversalArguments const *>(arguments_ptr));
if (status != Status::kSuccess) {
return 0;
}
uint64_t size = Operator::get_workspace_size(args);
return size;
}
/// Initializes the workspace
///
/// Default-constructs the Operator in place inside `host_workspace`, which must provide at least
/// get_host_workspace_size() bytes. The configuration, device workspace and stream are not used
/// here.
///
/// \return Status::kErrorWorkspaceNull if `host_workspace` is null, otherwise Status::kSuccess.
// NOTE(review): no matching destructor call for the placement-new'd Operator is visible in this
// file — confirm the owner of the host workspace handles that if Operator is not trivially
// destructible.
Status initialize(
    [[maybe_unused]] void const *configuration_ptr,
    void *host_workspace,
    [[maybe_unused]] void *device_workspace,
    [[maybe_unused]] cudaStream_t stream = nullptr) const override {
  // Placement-new on a null pointer is undefined behavior; report it instead.
  if (host_workspace == nullptr) {
    return Status::kErrorWorkspaceNull;
  }
  new (host_workspace) Operator;
  return Status::kSuccess;
}
/// Runs the kernel
///
/// Rebuilds the kernel arguments from `arguments_ptr` on every call and launches through the
/// Operator object living in `host_workspace`. A failure while building the arguments is returned
/// without launching.
Status run(
    void const *arguments_ptr,
    void *host_workspace,
    void *device_workspace = nullptr,
    cudaStream_t stream = nullptr) const override {
  auto const *gemm_args = static_cast<GemmUniversalArguments const *>(arguments_ptr);

  OperatorArguments kernel_args;
  Status const prepared = update_arguments_(kernel_args, gemm_args, stream);
  if (prepared != Status::kSuccess) {
    return prepared;
  }

  // The argument-taking run overload is used because TMA descriptors have to be rebuilt for
  // every new set of args (per the original author's note).
  auto *gemm_op = static_cast<Operator *>(host_workspace);
  return gemm_op->run(kernel_args, device_workspace, stream, nullptr, gemm_args->use_pdl);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::library
///////////////////////////////////////////////////////////////////////////////////////////////////