ThreadGroupTensorSliceTransfer_v4r1_dequant< ThreadGroup, SrcElementwiseOperation, ScaleElementwiseOperation, DstElementwiseOperation, DstInMemOp, BlockSliceLengths, BlockScaleSliceLengths, ThreadClusterLengths, ThreadClusterArrangeOrder, SrcData, ScaleData, DstData, SrcDesc, ScaleDesc, DstDesc, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, ScaleScalarPerVector, DstScalarPerVector, SrcScalarStrideInVector, ScaleScalarStrideInVector, DstScalarStrideInVector, ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun, NumThreadScratch > Struct Template Reference#
ck::ThreadGroupTensorSliceTransfer_v4r1_dequant< ThreadGroup, SrcElementwiseOperation, ScaleElementwiseOperation, DstElementwiseOperation, DstInMemOp, BlockSliceLengths, BlockScaleSliceLengths, ThreadClusterLengths, ThreadClusterArrangeOrder, SrcData, ScaleData, DstData, SrcDesc, ScaleDesc, DstDesc, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, ScaleScalarPerVector, DstScalarPerVector, SrcScalarStrideInVector, ScaleScalarStrideInVector, DstScalarStrideInVector, ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun, NumThreadScratch > Struct Template Reference
Blockwise data transfer with dequantization. More...
#include <thread_group_tensor_slice_transfer_v4r1_dequant.hpp>
Public Types | |
| using | Index = MultiIndex<nDim> |
Public Member Functions | |
| __device__ constexpr | ThreadGroupTensorSliceTransfer_v4r1_dequant (const SrcDesc &src_desc, const Index &src_block_slice_origin, const SrcElementwiseOperation &src_element_op, const ScaleDesc &scale_desc, const Index &scale_block_slice_origin, const ScaleElementwiseOperation &scale_element_op, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const DstElementwiseOperation &dst_element_op) |
| template<typename SrcBuffer, index_t ThreadScratchId = 0> | |
| __device__ void | RunRead (const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{}) |
| template<typename ScaleBuffer> | |
| __device__ void | RunScaleRead (const ScaleDesc &scale_desc, const ScaleBuffer &scale_buf) |
| template<typename DstBuffer, index_t ThreadScratchId = 0> | |
| __device__ void | RunWrite (const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{}) |
| __device__ void | MoveSrcSliceWindow (const SrcDesc &src_desc, const Index &step) |
| __device__ void | MoveDstSliceWindow (const DstDesc &dst_desc, const Index &step) |
Static Public Attributes | |
| static constexpr index_t | nDim = remove_reference_t<SrcDesc>::GetNumOfDimension() |
| static constexpr auto | thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{} |
| static constexpr auto | scale_thread_slice_lengths |
Detailed Description
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
struct ck::ThreadGroupTensorSliceTransfer_v4r1_dequant< ThreadGroup, SrcElementwiseOperation, ScaleElementwiseOperation, DstElementwiseOperation, DstInMemOp, BlockSliceLengths, BlockScaleSliceLengths, ThreadClusterLengths, ThreadClusterArrangeOrder, SrcData, ScaleData, DstData, SrcDesc, ScaleDesc, DstDesc, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, ScaleScalarPerVector, DstScalarPerVector, SrcScalarStrideInVector, ScaleScalarStrideInVector, DstScalarStrideInVector, ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun, NumThreadScratch >
Blockwise data transfer with dequantization.
RunRead loads the low-precision source data, RunScaleRead loads the scale data, and RunWrite performs the dequantization. The scale is assumed to be identical along the K dimension.
This version does the following to avoid scratch-memory issues:
- Use StaticallyIndexedArray instead of C array for thread buffer
- ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
- ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
Member Typedef Documentation
◆ Index
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
| using ck::ThreadGroupTensorSliceTransfer_v4r1_dequant< ThreadGroup, SrcElementwiseOperation, ScaleElementwiseOperation, DstElementwiseOperation, DstInMemOp, BlockSliceLengths, BlockScaleSliceLengths, ThreadClusterLengths, ThreadClusterArrangeOrder, SrcData, ScaleData, DstData, SrcDesc, ScaleDesc, DstDesc, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim, SrcScalarPerVector, ScaleScalarPerVector, DstScalarPerVector, SrcScalarStrideInVector, ScaleScalarStrideInVector, DstScalarStrideInVector, ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun, NumThreadScratch >::Index = MultiIndex<nDim> |
Constructor & Destructor Documentation
◆ ThreadGroupTensorSliceTransfer_v4r1_dequant()
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
|
inline constexpr |
Member Function Documentation
◆ MoveDstSliceWindow()
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
|
inline |
◆ MoveSrcSliceWindow()
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
|
inline |
◆ RunRead()
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
template<typename SrcBuffer, index_t ThreadScratchId = 0>
|
inline |
◆ RunScaleRead()
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
template<typename ScaleBuffer>
|
inline |
◆ RunWrite()
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
template<typename DstBuffer, index_t ThreadScratchId = 0>
|
inline |
Member Data Documentation
◆ nDim
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
|
static constexpr |
◆ scale_thread_slice_lengths
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
|
static constexpr |
Initial value:
=
BlockScaleSliceLengths{} / ThreadClusterLengths{}
◆ thread_slice_lengths
template<typename ThreadGroup, typename SrcElementwiseOperation, typename ScaleElementwiseOperation, typename DstElementwiseOperation, InMemoryDataOperationEnum DstInMemOp, typename BlockSliceLengths, typename BlockScaleSliceLengths, typename ThreadClusterLengths, typename ThreadClusterArrangeOrder, typename SrcData, typename ScaleData, typename DstData, typename SrcDesc, typename ScaleDesc, typename DstDesc, typename SrcDimAccessOrder, typename DstDimAccessOrder, index_t SrcVectorDim, index_t DstVectorDim, index_t SrcScalarPerVector, index_t ScaleScalarPerVector, index_t DstScalarPerVector, index_t SrcScalarStrideInVector, index_t ScaleScalarStrideInVector, index_t DstScalarStrideInVector, bool ThreadTransferSrcResetCoordinateAfterRun, bool ThreadTransferDstResetCoordinateAfterRun, index_t NumThreadScratch = 1>
|
static constexpr |
The documentation for this struct was generated from the following file: