// LIVE / thrust /cub /agent /agent_rle.cuh
// Xu Ma
// update
// 1c3c0d9
/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file
* cub::AgentRle implements a stateful abstraction of CUDA thread blocks for participating in device-wide run-length-encode.
*/
#pragma once
#include <iterator>
#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_exchange.cuh"
#include "../block/block_discontinuity.cuh"
#include "../config.cuh"
#include "../grid/grid_queue.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../iterator/constant_input_iterator.cuh"
/// Optional outer namespace(s)
CUB_NS_PREFIX
/// CUB namespace
namespace cub {
/******************************************************************************
* Tuning policy types
******************************************************************************/
/**
 * Compile-time bundle of tuning parameters for AgentRle.
 *
 * Every template argument is re-exported unchanged as a nested constant whose
 * name is the argument name without the leading underscore.
 */
template <
    int                 _BLOCK_THREADS,             ///< Threads per thread block
    int                 _ITEMS_PER_THREAD,          ///< Items per thread (per tile of input)
    BlockLoadAlgorithm  _LOAD_ALGORITHM,            ///< The BlockLoad algorithm to use
    CacheLoadModifier   _LOAD_MODIFIER,             ///< Cache load modifier for reading input elements
    bool                _STORE_WARP_TIME_SLICING,   ///< Whether a single warp's worth of exchange storage is time-sliced among the block's warps (versus one storage slot per warp)
    BlockScanAlgorithm  _SCAN_ALGORITHM>            ///< The BlockScan algorithm to use
struct AgentRlePolicy
{
    /// The BlockLoad algorithm to use
    static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM;

    /// Cache load modifier for reading input elements
    static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER;

    /// The BlockScan algorithm to use
    static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM;

    enum
    {
        /// Threads per thread block
        BLOCK_THREADS = _BLOCK_THREADS,

        /// Items per thread (per tile of input)
        ITEMS_PER_THREAD = _ITEMS_PER_THREAD,

        /// Whether a single warp's worth of exchange storage is time-sliced among the block's warps
        STORE_WARP_TIME_SLICING = _STORE_WARP_TIME_SLICING,
    };
};
/******************************************************************************
* Thread block abstractions
******************************************************************************/
/**
* \brief AgentRle implements a stateful abstraction of CUDA thread blocks for participating in device-wide run-length-encode
*/
template <
typename AgentRlePolicyT, ///< Parameterized AgentRlePolicyT tuning policy type
typename InputIteratorT, ///< Random-access input iterator type for data
typename OffsetsOutputIteratorT, ///< Random-access output iterator type for offset values
typename LengthsOutputIteratorT, ///< Random-access output iterator type for length values
typename EqualityOpT, ///< T equality operator type
typename OffsetT> ///< Signed integer type for global offsets
struct AgentRle
{
//---------------------------------------------------------------------
// Types and constants
//---------------------------------------------------------------------
/// The input value type (value_type of InputIteratorT as reported by std::iterator_traits)
typedef typename std::iterator_traits<InputIteratorT>::value_type T;
/// The lengths output value type
typedef typename If<(Equals<typename std::iterator_traits<LengthsOutputIteratorT>::value_type, void>::VALUE), // LengthT = (if output iterator's value type is void) ?
OffsetT, // ... then the OffsetT type,
typename std::iterator_traits<LengthsOutputIteratorT>::value_type>::Type LengthT; // ... else the output iterator's value type
/// Tuple type for scanning: in this agent .key (OffsetT) carries run counts / run offsets
/// and .value (LengthT) carries run lengths
typedef KeyValuePair<OffsetT, LengthT> LengthOffsetPair;
/// Tile status descriptor interface type
typedef ReduceByKeyScanTileState<LengthT, OffsetT> ScanTileStateT;
// Constants (all derived at compile time from the policy and the warp width)
enum
{
WARP_THREADS = CUB_WARP_THREADS(PTX_ARCH), // Warp width as given by CUB_WARP_THREADS for PTX_ARCH
BLOCK_THREADS = AgentRlePolicyT::BLOCK_THREADS, // Threads per block (from the policy)
ITEMS_PER_THREAD = AgentRlePolicyT::ITEMS_PER_THREAD, // Items per thread (from the policy)
WARP_ITEMS = WARP_THREADS * ITEMS_PER_THREAD, // Items covered by one warp
TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, // Items covered by one thread block (one tile)
WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS, // Warps per block, rounded up
/// Whether or not to sync after loading data (true for every load algorithm except BLOCK_LOAD_DIRECT)
SYNC_AFTER_LOAD = (AgentRlePolicyT::LOAD_ALGORITHM != BLOCK_LOAD_DIRECT),
/// Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage)
STORE_WARP_TIME_SLICING = AgentRlePolicyT::STORE_WARP_TIME_SLICING,
ACTIVE_EXCHANGE_WARPS = (STORE_WARP_TIME_SLICING) ? 1 : WARPS, // Number of exchange storage slots: 1 when time-sliced, else one per warp
};
/**
 * Flagging predicate passed to BlockDiscontinuity::FlagHeadsAndTails.
 *
 * For in-bounds positions it is simply the negation of the user's equality
 * operator. When LAST_TILE is set, any position whose index is not below
 * num_remaining is reported as "different" without consulting the equality
 * operator, which forces the last valid item to be tail-flagged and every
 * out-of-bounds item to be marked trivial.
 */
template <bool LAST_TILE>
struct OobInequalityOp
{
    OffsetT     num_remaining;  ///< Number of valid items in the tile being flagged
    EqualityOpT equality_op;    ///< User-supplied equality operator on T

    /// Captures the valid-item count and the equality operator by value
    __device__ __forceinline__ OobInequalityOp(
        OffsetT     num_remaining,
        EqualityOpT equality_op)
    :
        num_remaining(num_remaining),
        equality_op(equality_op)
    {}

    /// Returns true when \p first and \p second should be treated as belonging to different runs
    template <typename Index>
    __host__ __device__ __forceinline__ bool operator()(T first, T second, Index idx)
    {
        // Only the last tile can contain out-of-bounds positions
        const bool out_of_bounds = LAST_TILE && !(idx < num_remaining);

        // Short-circuit keeps equality_op from ever seeing out-of-bounds operands
        return out_of_bounds || !equality_op(first, second);
    }
};
// Cache-modified Input iterator wrapper type (for applying cache modifier) for data
typedef typename If<IsPointer<InputIteratorT>::VALUE,
CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>, // Wrap the native input pointer with CacheModifiedInputIterator
InputIteratorT>::Type // Directly use the supplied input iterator type
WrappedInputIteratorT;
// Parameterized BlockLoad type for data
typedef BlockLoad<
T,
AgentRlePolicyT::BLOCK_THREADS,
AgentRlePolicyT::ITEMS_PER_THREAD,
AgentRlePolicyT::LOAD_ALGORITHM>
BlockLoadT;
// Parameterized BlockDiscontinuity type for data
typedef BlockDiscontinuity<T, BLOCK_THREADS> BlockDiscontinuityT;
// Parameterized WarpScan type (scans LengthOffsetPair tuples within a warp)
typedef WarpScan<LengthOffsetPair> WarpScanPairs;
// Reduce-length-by-run scan operator
typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;
// Callback type for obtaining tile prefix during block scan
typedef TilePrefixCallbackOp<
LengthOffsetPair,
ReduceBySegmentOpT,
ScanTileStateT>
TilePrefixCallbackOpT;
// Warp exchange types
typedef WarpExchange<LengthOffsetPair, ITEMS_PER_THREAD> WarpExchangePairs;
// Storage for exchanging whole pairs is only a real type when time-slicing is enabled; otherwise it collapses to NullType
typedef typename If<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>::Type WarpExchangePairsStorage;
typedef WarpExchange<OffsetT, ITEMS_PER_THREAD> WarpExchangeOffsets;
typedef WarpExchange<LengthT, ITEMS_PER_THREAD> WarpExchangeLengths;
// One aggregate slot per warp
typedef LengthOffsetPair WarpAggregates[WARPS];
// Shared memory type for this thread block
struct _TempStorage
{
// Aliasable storage layout.
// This is a union: the anonymous scan-phase struct, the load storage and the
// scatter storage all occupy the same bytes, so a phase may only start writing
// its member once every thread has finished with the previous phase's member.
union Aliasable
{
// Members used between flagging and the tile-prefix computation (these do not alias each other)
struct
{
typename BlockDiscontinuityT::TempStorage discontinuity; // Smem needed for discontinuity detection
typename WarpScanPairs::TempStorage warp_scan[WARPS]; // Smem needed for warp-synchronous scans
Uninitialized<LengthOffsetPair[WARPS]> warp_aggregates; // Smem needed for sharing warp-wide aggregates
typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback
};
// Smem needed for input loading
typename BlockLoadT::TempStorage load;
// Aliasable layout needed for two-phase scatter
// (ACTIVE_EXCHANGE_WARPS slots: a single shared slot when time-slicing, otherwise one per warp)
union ScatterAliasable
{
unsigned long long align;
WarpExchangePairsStorage exchange_pairs[ACTIVE_EXCHANGE_WARPS];
typename WarpExchangeOffsets::TempStorage exchange_offsets[ACTIVE_EXCHANGE_WARPS];
typename WarpExchangeLengths::TempStorage exchange_lengths[ACTIVE_EXCHANGE_WARPS];
} scatter_aliasable;
} aliasable;
// The fields below live outside the union and are therefore never aliased
OffsetT tile_idx; // Shared tile index (not referenced by the code visible in this file)
LengthOffsetPair tile_inclusive; // Inclusive tile prefix (not referenced by the code visible in this file)
LengthOffsetPair tile_exclusive; // Exclusive tile prefix (written by thread 0 in ConsumeTile, read by all threads after a barrier)
};
// Alias wrapper allowing storage to be unioned; the agent constructor unwraps it via Alias()
struct TempStorage : Uninitialized<_TempStorage> {};
//---------------------------------------------------------------------
// Per-thread fields
//---------------------------------------------------------------------
_TempStorage& temp_storage; ///< Reference to temp_storage
WrappedInputIteratorT d_in; ///< Pointer to input sequence of data items (cache-modified wrapper when the input is a raw pointer)
OffsetsOutputIteratorT d_offsets_out; ///< Output run offsets
LengthsOutputIteratorT d_lengths_out; ///< Output run lengths
EqualityOpT equality_op; ///< T equality operator
ReduceBySegmentOpT scan_op; ///< Reduce-length-by-flag scan operator
OffsetT num_items; ///< Total number of input items
//---------------------------------------------------------------------
// Constructor
//---------------------------------------------------------------------
/**
 * Binds the agent to its shared storage, input/output iterators and problem size.
 *
 * \param temp_storage   [in]  Reference to temp_storage (unwrapped through Alias())
 * \param d_in           [in]  Pointer to input sequence of data items
 * \param d_offsets_out  [out] Pointer to output sequence of run offsets
 * \param d_lengths_out  [out] Pointer to output sequence of run lengths
 * \param equality_op    [in]  T equality operator
 * \param num_items      [in]  Total number of input items
 */
__device__ __forceinline__
AgentRle(
    TempStorage             &temp_storage,
    InputIteratorT          d_in,
    OffsetsOutputIteratorT  d_offsets_out,
    LengthsOutputIteratorT  d_lengths_out,
    EqualityOpT             equality_op,
    OffsetT                 num_items)
    : temp_storage(temp_storage.Alias())
    , d_in(d_in)
    , d_offsets_out(d_offsets_out)
    , d_lengths_out(d_lengths_out)
    , equality_op(equality_op)
    , scan_op(cub::Sum())          // the segmented-reduction operator sums lengths
    , num_items(num_items)
{
}
//---------------------------------------------------------------------
// Utility methods for initializing the selections
//---------------------------------------------------------------------
/**
 * Flags run heads and tails for this thread's items and converts the flags
 * into the (num-runs, length) tuples that seed the scan.
 *
 * Interior tiles fetch the neighbouring tiles' boundary items from d_in so
 * that runs spanning a tile boundary are flagged consistently: thread 0 reads
 * the item just before the tile, the last thread reads the item just after it.
 *
 * \param tile_offset            Global offset of the first item of this tile
 * \param num_remaining          Number of input items remaining (including this tile)
 * \param items                  This thread's input items
 * \param lengths_and_num_runs   [out] Per item: key = 1 if the item heads a run and is not also its tail,
 *                               value = 0 only if the item is both head and tail, 1 otherwise
 */
template <bool FIRST_TILE, bool LAST_TILE>
__device__ __forceinline__ void InitializeSelections(
    OffsetT             tile_offset,
    OffsetT             num_remaining,
    T                   (&items)[ITEMS_PER_THREAD],
    LengthOffsetPair    (&lengths_and_num_runs)[ITEMS_PER_THREAD])
{
    bool is_run_head[ITEMS_PER_THREAD];
    bool is_run_tail[ITEMS_PER_THREAD];

    OobInequalityOp<LAST_TILE>  flag_op(num_remaining, equality_op);
    BlockDiscontinuityT         discontinuity(temp_storage.aliasable.discontinuity);

    if (FIRST_TILE && LAST_TILE)
    {
        // Single-tile problem: no neighbouring tiles, so the first item is
        // always a head and the last item is always a tail
        discontinuity.FlagHeadsAndTails(is_run_head, is_run_tail, items, flag_op);
    }
    else if (FIRST_TILE)
    {
        // No predecessor tile: the first item is always a head.
        // The last thread supplies the first item of the following tile.
        T next_tile_first;
        if (threadIdx.x == BLOCK_THREADS - 1)
            next_tile_first = d_in[tile_offset + TILE_ITEMS];

        discontinuity.FlagHeadsAndTails(
            is_run_head, is_run_tail, next_tile_first, items, flag_op);
    }
    else if (LAST_TILE)
    {
        // No successor tile: the last item is always a tail.
        // Thread 0 supplies the last item of the preceding tile.
        T prev_tile_last;
        if (threadIdx.x == 0)
            prev_tile_last = d_in[tile_offset - 1];

        discontinuity.FlagHeadsAndTails(
            is_run_head, prev_tile_last, is_run_tail, items, flag_op);
    }
    else
    {
        // Interior tile: both neighbours are consulted
        T next_tile_first;
        if (threadIdx.x == BLOCK_THREADS - 1)
            next_tile_first = d_in[tile_offset + TILE_ITEMS];

        T prev_tile_last;
        if (threadIdx.x == 0)
            prev_tile_last = d_in[tile_offset - 1];

        discontinuity.FlagHeadsAndTails(
            is_run_head, prev_tile_last, is_run_tail, next_tile_first, items, flag_op);
    }

    // Zip counts and runs
    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
    {
        const bool head = is_run_head[ITEM];
        const bool tail = is_run_tail[ITEM];

        // An item that is both head and tail is a run of length one (trivial)
        lengths_and_num_runs[ITEM].key      = (head && !tail);
        lengths_and_num_runs[ITEM].value    = !(head && tail);
    }
}
//---------------------------------------------------------------------
// Scan utility methods
//---------------------------------------------------------------------
/**
 * Scan of allocations.
 *
 * Each thread reduces its own tuples, every warp scans those per-thread
 * aggregates, and the per-warp totals are then combined through shared
 * memory into a tile-wide aggregate and a per-warp exclusive prefix.
 *
 * Must be called by every thread of the block: it contains block-wide barriers.
 *
 * \param tile_aggregate            [out] Reduction over all warps of the tile
 * \param warp_aggregate            [out] Reduction over the calling thread's warp
 * \param warp_exclusive_in_tile    [out] Reduction over all warps preceding the calling thread's warp (identity for warp 0)
 * \param thread_exclusive_in_warp  [out] Exclusive warp-scan result for the calling thread
 * \param lengths_and_num_runs      [in]  The calling thread's tuples
 */
__device__ __forceinline__ void WarpScanAllocations(
    LengthOffsetPair    &tile_aggregate,
    LengthOffsetPair    &warp_aggregate,
    LengthOffsetPair    &warp_exclusive_in_tile,
    LengthOffsetPair    &thread_exclusive_in_warp,
    LengthOffsetPair    (&lengths_and_num_runs)[ITEMS_PER_THREAD])
{
    // Perform warpscans
    unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
    int lane_id = LaneId();

    LengthOffsetPair identity;
    identity.key = 0;
    identity.value = 0;

    LengthOffsetPair thread_inclusive;
    LengthOffsetPair thread_aggregate = internal::ThreadReduce(lengths_and_num_runs, scan_op);
    WarpScanPairs(temp_storage.aliasable.warp_scan[warp_id]).Scan(
        thread_aggregate,
        thread_inclusive,
        thread_exclusive_in_warp,
        identity,
        scan_op);

    // Last lane in each warp shares its warp-aggregate
    if (lane_id == WARP_THREADS - 1)
        temp_storage.aliasable.warp_aggregates.Alias()[warp_id] = thread_inclusive;

    // Make every warp's aggregate visible to the whole block
    CTA_SYNC();

    // Accumulate total selected and the warp-wide prefix
    warp_exclusive_in_tile          = identity;
    warp_aggregate                  = temp_storage.aliasable.warp_aggregates.Alias()[warp_id];
    tile_aggregate                  = temp_storage.aliasable.warp_aggregates.Alias()[0];

    #pragma unroll
    for (int WARP = 1; WARP < WARPS; ++WARP)
    {
        // The running total just before folding in WARP is that warp's exclusive prefix
        if (warp_id == WARP)
            warp_exclusive_in_tile = tile_aggregate;

        tile_aggregate = scan_op(tile_aggregate, temp_storage.aliasable.warp_aggregates.Alias()[WARP]);
    }

    // warp_aggregates shares its bytes with the scatter exchange storage (both
    // are members of the same union). Without this barrier a warp that returns
    // early could start the scatter phase and overwrite aggregates that a
    // slower warp has not finished reading yet.
    CTA_SYNC();
}
//---------------------------------------------------------------------
// Utility methods for scattering selections
//---------------------------------------------------------------------
/**
 * Two-phase scatter, specialized for warp time-slicing.
 *
 * Phase 1: the warps take turns (separated by block barriers) using the single
 * shared exchange slot to rearrange their (offset, length) pairs into a
 * warp-striped arrangement.
 * Phase 2: each lane writes the striped pairs it now holds to global memory.
 *
 * Must be called by every thread of the block (contains block-wide barriers
 * when WARPS > 1).
 */
template <bool FIRST_TILE>
__device__ __forceinline__ void ScatterTwoPhase(
    OffsetT             tile_num_runs_exclusive_in_global,
    OffsetT             warp_num_runs_aggregate,
    OffsetT             warp_num_runs_exclusive_in_tile,
    OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD],
    LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD],
    Int2Type<true>      is_warp_time_slice)
{
    unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
    int lane_id = LaneId();

    // Phase 1: locally compact items within each warp, one warp per time slice.
    // Every slice after the first waits for the previous warp to release the slot.
    #pragma unroll
    for (int SLICE = 0; SLICE < WARPS; ++SLICE)
    {
        if (SLICE > 0)
            CTA_SYNC();

        if (warp_id == SLICE)
        {
            WarpExchangePairs(temp_storage.aliasable.scatter_aliasable.exchange_pairs[0]).ScatterToStriped(
                lengths_and_offsets, thread_num_runs_exclusive_in_warp);
        }
    }

    // Phase 2: global scatter
    const OffsetT warp_base = tile_num_runs_exclusive_in_global + warp_num_runs_exclusive_in_tile;

    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
    {
        const int stripe_start = ITEM * WARP_THREADS;

        // Skip striped slots beyond the number of runs this warp produced
        if (!(stripe_start < warp_num_runs_aggregate - lane_id))
            continue;

        const OffsetT item_offset = warp_base + stripe_start + lane_id;

        // Scatter offset
        d_offsets_out[item_offset] = lengths_and_offsets[ITEM].key;

        // The length travelling with run i belongs to run i-1, so the very
        // first run of the first tile has nothing to write
        const bool is_first_global_run = FIRST_TILE && (ITEM == 0) && (threadIdx.x == 0);
        if (!is_first_global_run)
        {
            d_lengths_out[item_offset - 1] = lengths_and_offsets[ITEM].value;
        }
    }
}
/**
 * Two-phase scatter (each warp owns its own exchange slot).
 *
 * Offsets and lengths are split into separate arrays and exchanged one after
 * the other through the warp's slot; the warp-level sync in between keeps the
 * second exchange from overwriting data the first one is still reading.
 * Afterwards each lane writes the striped values it holds to global memory.
 */
template <bool FIRST_TILE>
__device__ __forceinline__ void ScatterTwoPhase(
    OffsetT             tile_num_runs_exclusive_in_global,
    OffsetT             warp_num_runs_aggregate,
    OffsetT             warp_num_runs_exclusive_in_tile,
    OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD],
    LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD],
    Int2Type<false>     is_warp_time_slice)
{
    unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
    int lane_id = LaneId();

    // Unzip the pairs into two parallel arrays
    OffsetT offsets_only[ITEMS_PER_THREAD];
    LengthT lengths_only[ITEMS_PER_THREAD];

    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
    {
        const LengthOffsetPair &pair = lengths_and_offsets[ITEM];
        offsets_only[ITEM] = pair.key;
        lengths_only[ITEM] = pair.value;
    }

    // Phase 1: exchange offsets, then lengths, through the same per-warp slot
    WarpExchangeOffsets(temp_storage.aliasable.scatter_aliasable.exchange_offsets[warp_id]).ScatterToStriped(
        offsets_only, thread_num_runs_exclusive_in_warp);

    WARP_SYNC(0xffffffff);

    WarpExchangeLengths(temp_storage.aliasable.scatter_aliasable.exchange_lengths[warp_id]).ScatterToStriped(
        lengths_only, thread_num_runs_exclusive_in_warp);

    // Phase 2: global scatter
    const OffsetT warp_base = tile_num_runs_exclusive_in_global + warp_num_runs_exclusive_in_tile;

    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
    {
        const int striped_rank = (ITEM * WARP_THREADS) + lane_id;

        // Skip striped slots beyond the number of runs this warp produced
        if (!(striped_rank < warp_num_runs_aggregate))
            continue;

        const OffsetT item_offset = warp_base + striped_rank;

        // Scatter offset
        d_offsets_out[item_offset] = offsets_only[ITEM];

        // The length travelling with run i belongs to run i-1, so the very
        // first run of the first tile has nothing to write
        const bool is_first_global_run = FIRST_TILE && (ITEM == 0) && (threadIdx.x == 0);
        if (!is_first_global_run)
        {
            d_lengths_out[item_offset - 1] = lengths_only[ITEM];
        }
    }
}
/**
 * Direct scatter: every thread writes its own kept items straight to global
 * memory, without any shared-memory exchange.
 *
 * An item is kept when its rank within the warp is below the warp's run count;
 * any other rank is skipped.
 */
template <bool FIRST_TILE>
__device__ __forceinline__ void ScatterDirect(
    OffsetT             tile_num_runs_exclusive_in_global,
    OffsetT             warp_num_runs_aggregate,
    OffsetT             warp_num_runs_exclusive_in_tile,
    OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD],
    LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD])
{
    const OffsetT warp_base = tile_num_runs_exclusive_in_global + warp_num_runs_exclusive_in_tile;

    #pragma unroll
    for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
    {
        const OffsetT rank_in_warp = thread_num_runs_exclusive_in_warp[ITEM];

        // Not a run head (or marked as discarded): nothing to write
        if (!(rank_in_warp < warp_num_runs_aggregate))
            continue;

        const OffsetT item_offset = warp_base + rank_in_warp;

        // Scatter offset
        d_offsets_out[item_offset] = lengths_and_offsets[ITEM].key;

        // The length travelling with run i belongs to run i-1; the first
        // global run therefore has no length to write
        if (item_offset >= 1)
        {
            d_lengths_out[item_offset - 1] = lengths_and_offsets[ITEM].value;
        }
    }
}
/**
 * Scatter dispatcher.
 *
 * Chooses the direct path when each thread holds a single item or when the
 * tile produced fewer runs than there are threads in the block; otherwise
 * uses the two-phase (shared-memory exchange) path, selecting the overload
 * that matches the STORE_WARP_TIME_SLICING policy.
 */
template <bool FIRST_TILE>
__device__ __forceinline__ void Scatter(
    OffsetT             tile_num_runs_aggregate,
    OffsetT             tile_num_runs_exclusive_in_global,
    OffsetT             warp_num_runs_aggregate,
    OffsetT             warp_num_runs_exclusive_in_tile,
    OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD],
    LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD])
{
    // The choice depends only on tile-wide values
    const bool scatter_directly =
        (ITEMS_PER_THREAD == 1) || (tile_num_runs_aggregate < BLOCK_THREADS);

    if (!scatter_directly)
    {
        // Scatter two phase
        ScatterTwoPhase<FIRST_TILE>(
            tile_num_runs_exclusive_in_global,
            warp_num_runs_aggregate,
            warp_num_runs_exclusive_in_tile,
            thread_num_runs_exclusive_in_warp,
            lengths_and_offsets,
            Int2Type<STORE_WARP_TIME_SLICING>());
        return;
    }

    // Direct scatter, skipped entirely by warps that produced no runs
    if (warp_num_runs_aggregate)
    {
        ScatterDirect<FIRST_TILE>(
            tile_num_runs_exclusive_in_global,
            warp_num_runs_aggregate,
            warp_num_runs_exclusive_in_tile,
            thread_num_runs_exclusive_in_warp,
            lengths_and_offsets);
    }
}
//---------------------------------------------------------------------
// Cooperatively scan a device-wide sequence of tiles with other CTAs
//---------------------------------------------------------------------
/**
 * Process a tile of input (dynamic chained scan)
 *
 * Loads the tile, flags run boundaries, scans the resulting tuples within the
 * block, obtains the prefix contributed by preceding tiles (non-first tiles
 * only) and scatters run offsets and lengths to the output iterators.
 *
 * Must be called by every thread of the block: the body contains block-wide
 * barriers and cooperative CUB primitives.
 *
 * Returns the running total inclusive of this tile: the tile aggregate for
 * tile 0, otherwise prefix_op.inclusive_prefix. NOTE(review): prefix_op is only
 * invoked by threads of warp 0, and the caller in this file reads the result
 * on thread 0 only - other threads should presumably not rely on it.
 */
template <
bool LAST_TILE>
__device__ __forceinline__ LengthOffsetPair ConsumeTile(
OffsetT num_items, ///< Total number of global input items (not referenced in the body)
OffsetT num_remaining, ///< Number of global input items remaining (including this tile)
int tile_idx, ///< Tile index
OffsetT tile_offset, ///< Tile offset
ScanTileStateT &tile_status) ///< Global list of tile status
{
if (tile_idx == 0)
{
// First tile
// Load items (the last tile is guarded by num_remaining and padded with T())
T items[ITEMS_PER_THREAD];
if (LAST_TILE)
BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items, num_remaining, T());
else
BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items);
// The load storage aliases the storage used next, so wait unless the load was direct
if (SYNC_AFTER_LOAD)
CTA_SYNC();
// Set flags
LengthOffsetPair lengths_and_num_runs[ITEMS_PER_THREAD];
InitializeSelections<true, LAST_TILE>(
tile_offset,
num_remaining,
items,
lengths_and_num_runs);
// Exclusive scan of lengths and runs
LengthOffsetPair tile_aggregate;
LengthOffsetPair warp_aggregate;
LengthOffsetPair warp_exclusive_in_tile;
LengthOffsetPair thread_exclusive_in_warp;
WarpScanAllocations(
tile_aggregate,
warp_aggregate,
warp_exclusive_in_tile,
thread_exclusive_in_warp,
lengths_and_num_runs);
// Update tile status if this is not the last tile
if (!LAST_TILE && (threadIdx.x == 0))
tile_status.SetInclusive(0, tile_aggregate);
// Update thread_exclusive_in_warp to fold in warp run-length
// (only when no run head precedes this thread within its warp)
if (thread_exclusive_in_warp.key == 0)
thread_exclusive_in_warp.value += warp_exclusive_in_tile.value;
LengthOffsetPair lengths_and_offsets[ITEMS_PER_THREAD];
OffsetT thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD];
LengthOffsetPair lengths_and_num_runs2[ITEMS_PER_THREAD];
// Downsweep scan through lengths_and_num_runs
internal::ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp);
// Zip: value = scanned length, key = global index of the item
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
lengths_and_offsets[ITEM].value = lengths_and_num_runs2[ITEM].value;
lengths_and_offsets[ITEM].key = tile_offset + (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
// Items that do not head a run get the rank WARP_THREADS * ITEMS_PER_THREAD as a discard marker
thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ?
lengths_and_num_runs2[ITEM].key : // keep
WARP_THREADS * ITEMS_PER_THREAD; // discard
}
OffsetT tile_num_runs_aggregate = tile_aggregate.key;
OffsetT tile_num_runs_exclusive_in_global = 0; // nothing precedes the first tile
OffsetT warp_num_runs_aggregate = warp_aggregate.key;
OffsetT warp_num_runs_exclusive_in_tile = warp_exclusive_in_tile.key;
// Scatter
Scatter<true>(
tile_num_runs_aggregate,
tile_num_runs_exclusive_in_global,
warp_num_runs_aggregate,
warp_num_runs_exclusive_in_tile,
thread_num_runs_exclusive_in_warp,
lengths_and_offsets);
// Return running total (inclusive of this tile)
return tile_aggregate;
}
else
{
// Not first tile
// Load items (the last tile is guarded by num_remaining and padded with T())
T items[ITEMS_PER_THREAD];
if (LAST_TILE)
BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items, num_remaining, T());
else
BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items);
// The load storage aliases the storage used next, so wait unless the load was direct
if (SYNC_AFTER_LOAD)
CTA_SYNC();
// Set flags
LengthOffsetPair lengths_and_num_runs[ITEMS_PER_THREAD];
InitializeSelections<false, LAST_TILE>(
tile_offset,
num_remaining,
items,
lengths_and_num_runs);
// Exclusive scan of lengths and runs
LengthOffsetPair tile_aggregate;
LengthOffsetPair warp_aggregate;
LengthOffsetPair warp_exclusive_in_tile;
LengthOffsetPair thread_exclusive_in_warp;
WarpScanAllocations(
tile_aggregate,
warp_aggregate,
warp_exclusive_in_tile,
thread_exclusive_in_warp,
lengths_and_num_runs);
// First warp computes tile prefix in lane 0
TilePrefixCallbackOpT prefix_op(tile_status, temp_storage.aliasable.prefix, Sum(), tile_idx);
unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
if (warp_id == 0)
{
prefix_op(tile_aggregate);
// Publish the exclusive prefix through non-aliased shared storage
if (threadIdx.x == 0)
temp_storage.tile_exclusive = prefix_op.exclusive_prefix;
}
// Make tile_exclusive visible to every thread of the block
CTA_SYNC();
LengthOffsetPair tile_exclusive_in_global = temp_storage.tile_exclusive;
// Update thread_exclusive_in_warp to fold in warp and tile run-lengths
// (only when no run head precedes this thread within its warp)
LengthOffsetPair thread_exclusive = scan_op(tile_exclusive_in_global, warp_exclusive_in_tile);
if (thread_exclusive_in_warp.key == 0)
thread_exclusive_in_warp.value += thread_exclusive.value;
// Downsweep scan through lengths_and_num_runs
LengthOffsetPair lengths_and_num_runs2[ITEMS_PER_THREAD];
LengthOffsetPair lengths_and_offsets[ITEMS_PER_THREAD];
OffsetT thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD];
internal::ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp);
// Zip: value = scanned length, key = global index of the item
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
lengths_and_offsets[ITEM].value = lengths_and_num_runs2[ITEM].value;
lengths_and_offsets[ITEM].key = tile_offset + (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
// Items that do not head a run get the rank WARP_THREADS * ITEMS_PER_THREAD as a discard marker
thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ?
lengths_and_num_runs2[ITEM].key : // keep
WARP_THREADS * ITEMS_PER_THREAD; // discard
}
OffsetT tile_num_runs_aggregate = tile_aggregate.key;
OffsetT tile_num_runs_exclusive_in_global = tile_exclusive_in_global.key; // runs emitted by all preceding tiles
OffsetT warp_num_runs_aggregate = warp_aggregate.key;
OffsetT warp_num_runs_exclusive_in_tile = warp_exclusive_in_tile.key;
// Scatter
Scatter<false>(
tile_num_runs_aggregate,
tile_num_runs_exclusive_in_global,
warp_num_runs_aggregate,
warp_num_runs_exclusive_in_tile,
thread_num_runs_exclusive_in_warp,
lengths_and_offsets);
// Return running total (inclusive of this tile)
return prefix_op.inclusive_prefix;
}
}
/**
 * Scan tiles of items as part of a dynamic chained scan.
 *
 * Each thread block processes exactly one tile, identified from its block
 * coordinates. The block that owns the last tile additionally writes the
 * total number of runs and the length of the final run.
 *
 * \param num_tiles        Total number of input tiles
 * \param tile_status      Global list of tile status
 * \param d_num_runs_out   Output pointer for total number of runs identified
 */
template <typename NumRunsIteratorT>    ///< Output iterator type for recording number of items selected
__device__ __forceinline__ void ConsumeRange(
    int                 num_tiles,
    ScanTileStateT&     tile_status,
    NumRunsIteratorT    d_num_runs_out)
{
    // Blocks are launched in increasing order, so just assign one tile per block
    int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y;   // Current tile index

    // Widen before multiplying: tile_idx and TILE_ITEMS are both int-sized, so
    // their product would otherwise be computed in int and could overflow
    // before being stored into a wider OffsetT
    OffsetT tile_offset     = static_cast<OffsetT>(tile_idx) * static_cast<OffsetT>(TILE_ITEMS);    // Global offset for the current tile
    OffsetT num_remaining   = num_items - tile_offset;                                              // Remaining items (including this tile)

    if (tile_idx < num_tiles - 1)
    {
        // Not the last tile (full)
        ConsumeTile<false>(num_items, num_remaining, tile_idx, tile_offset, tile_status);
    }
    else if (num_remaining > 0)
    {
        // The last tile (possibly partially-full)
        LengthOffsetPair running_total = ConsumeTile<true>(num_items, num_remaining, tile_idx, tile_offset, tile_status);

        if (threadIdx.x == 0)
        {
            // Output the total number of items selected
            *d_num_runs_out = running_total.key;

            // The inclusive prefix contains accumulated length reduction for the last run
            if (running_total.key > 0)
                d_lengths_out[running_total.key - 1] = running_total.value;
        }
    }
}
};
} // CUB namespace
CUB_NS_POSTFIX // Optional outer namespace(s)