// Source: LIVE / thrust / cub / agent / agent_select_if.cuh (last commit 1c3c0d9 "update", by Xu Ma)
/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file
* cub::AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide select.
*/
#pragma once
#include <iterator>
#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_exchange.cuh"
#include "../block/block_discontinuity.cuh"
#include "../config.cuh"
#include "../grid/grid_queue.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
/// Optional outer namespace(s)
CUB_NS_PREFIX
/// CUB namespace
namespace cub {
/******************************************************************************
* Tuning policy types
******************************************************************************/
/**
 * \brief Tuning policy for AgentSelectIf.
 *
 * A pure compile-time bundle: every template argument is re-exported as a
 * nested constant so that the agent can read its configuration from a
 * single policy type.
 *
 * \tparam BlockThreads     Threads per thread block
 * \tparam ItemsPerThread   Items per thread (per tile of input)
 * \tparam LoadAlgorithm    The BlockLoad algorithm to use
 * \tparam LoadModifier     Cache load modifier for reading input elements
 * \tparam ScanAlgorithm    The BlockScan algorithm to use
 */
template <
    int                 BlockThreads,
    int                 ItemsPerThread,
    BlockLoadAlgorithm  LoadAlgorithm,
    CacheLoadModifier   LoadModifier,
    BlockScanAlgorithm  ScanAlgorithm>
struct AgentSelectIfPolicy
{
    // Integral tuning parameters
    enum
    {
        /// Threads per thread block
        BLOCK_THREADS    = BlockThreads,

        /// Items per thread (per tile of input)
        ITEMS_PER_THREAD = ItemsPerThread
    };

    /// The BlockLoad algorithm to use
    static const BlockLoadAlgorithm LOAD_ALGORITHM = LoadAlgorithm;

    /// Cache load modifier for reading input elements
    static const CacheLoadModifier  LOAD_MODIFIER  = LoadModifier;

    /// The BlockScan algorithm to use
    static const BlockScanAlgorithm SCAN_ALGORITHM = ScanAlgorithm;
};
/******************************************************************************
* Thread block abstractions
******************************************************************************/
/**
 * \brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide selection
 *
 * Performs functor-based selection if SelectOpT functor type != NullType
 * Otherwise performs flag-based selection if FlagsInputIterator's value type != NullType
 * Otherwise performs discontinuity selection (keep unique)
 *
 * Every item contributes a selection flag of exactly 0 or 1 to a block-wide
 * exclusive prefix sum; the resulting indices are the output offsets of the
 * selected items.  Slots of a partially-full last tile that lie beyond the
 * input are flagged 1 and then subtracted again from the tile aggregate.
 */
template <
    typename    AgentSelectIfPolicyT,       ///< Parameterized AgentSelectIfPolicy tuning policy type
    typename    InputIteratorT,             ///< Random-access input iterator type for selection items
    typename    FlagsInputIteratorT,        ///< Random-access input iterator type for selection flags (NullType* if a selection functor or discontinuity flagging is to be used for selection)
    typename    SelectedOutputIteratorT,    ///< Random-access output iterator type for selected items
    typename    SelectOpT,                  ///< Selection operator type (NullType if selection flags or discontinuity flagging is to be used for selection)
    typename    EqualityOpT,                ///< Equality operator type (NullType if selection functor or selection flags is to be used for selection)
    typename    OffsetT,                    ///< Signed integer type for global offsets
    bool        KEEP_REJECTS>               ///< Whether or not we push rejected items to the back of the output
struct AgentSelectIf
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    // The input value type
    typedef typename std::iterator_traits<InputIteratorT>::value_type InputT;

    // The output value type
    typedef typename If<(Equals<typename std::iterator_traits<SelectedOutputIteratorT>::value_type, void>::VALUE),  // OutputT =  (if output iterator's value type is void) ?
        typename std::iterator_traits<InputIteratorT>::value_type,                                                  // ... then the input iterator's value type,
        typename std::iterator_traits<SelectedOutputIteratorT>::value_type>::Type OutputT;                          // ... else the output iterator's value type

    // The flag value type
    typedef typename std::iterator_traits<FlagsInputIteratorT>::value_type FlagT;

    // Tile status descriptor interface type
    typedef ScanTileState<OffsetT> ScanTileStateT;

    // Constants
    enum
    {
        // Selection strategies (used as Int2Type tags to pick an InitializeSelections overload)
        USE_SELECT_OP,
        USE_SELECT_FLAGS,
        USE_DISCONTINUITY,

        BLOCK_THREADS           = AgentSelectIfPolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD        = AgentSelectIfPolicyT::ITEMS_PER_THREAD,
        TILE_ITEMS              = BLOCK_THREADS * ITEMS_PER_THREAD,
        TWO_PHASE_SCATTER       = (ITEMS_PER_THREAD > 1),

        // Functor beats flags, flags beat discontinuity detection
        SELECT_METHOD           = (!Equals<SelectOpT, NullType>::VALUE) ?
                                    USE_SELECT_OP :
                                    (!Equals<FlagT, NullType>::VALUE) ?
                                        USE_SELECT_FLAGS :
                                        USE_DISCONTINUITY
    };

    // Cache-modified input iterator wrapper type (for applying cache modifier) for items
    typedef typename If<IsPointer<InputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,       // Wrap the native input pointer with CacheModifiedInputIterator
            InputIteratorT>::Type                                                                   // Directly use the supplied input iterator type
        WrappedInputIteratorT;

    // Cache-modified input iterator wrapper type (for applying cache modifier) for flags
    typedef typename If<IsPointer<FlagsInputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,        // Wrap the native input pointer with CacheModifiedInputIterator
            FlagsInputIteratorT>::Type                                                              // Directly use the supplied input iterator type
        WrappedFlagsInputIteratorT;

    // Parameterized BlockLoad type for input data
    typedef BlockLoad<
            OutputT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            AgentSelectIfPolicyT::LOAD_ALGORITHM>
        BlockLoadT;

    // Parameterized BlockLoad type for flags
    typedef BlockLoad<
            FlagT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            AgentSelectIfPolicyT::LOAD_ALGORITHM>
        BlockLoadFlags;

    // Parameterized BlockDiscontinuity type for items
    typedef BlockDiscontinuity<
            OutputT,
            BLOCK_THREADS>
        BlockDiscontinuityT;

    // Parameterized BlockScan type
    typedef BlockScan<
            OffsetT,
            BLOCK_THREADS,
            AgentSelectIfPolicyT::SCAN_ALGORITHM>
        BlockScanT;

    // Callback type for obtaining tile prefix during block scan
    typedef TilePrefixCallbackOp<
            OffsetT,
            cub::Sum,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Item exchange type
    typedef OutputT ItemExchangeT[TILE_ITEMS];

    // Shared memory type for this thread block.  The members are unioned, so
    // every switch from one member to another is separated by a CTA_SYNC().
    union _TempStorage
    {
        struct
        {
            typename BlockScanT::TempStorage                scan;           // Smem needed for tile scanning
            typename TilePrefixCallbackOpT::TempStorage     prefix;         // Smem needed for cooperative prefix callback
            typename BlockDiscontinuityT::TempStorage       discontinuity;  // Smem needed for discontinuity detection
        };

        // Smem needed for loading items
        typename BlockLoadT::TempStorage load_items;

        // Smem needed for loading flags
        typename BlockLoadFlags::TempStorage load_flags;

        // Smem needed for compacting items (allows non POD items in this union)
        Uninitialized<ItemExchangeT> raw_exchange;
    };

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to temp_storage
    WrappedInputIteratorT           d_in;               ///< Input items
    SelectedOutputIteratorT         d_selected_out;     ///< Selected output items
    WrappedFlagsInputIteratorT      d_flags_in;         ///< Input selection flags (if applicable)
    InequalityWrapper<EqualityOpT>  inequality_op;      ///< Inequality operator (wraps the supplied equality operator)
    SelectOpT                       select_op;          ///< Selection operator
    OffsetT                         num_items;          ///< Total number of input items


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    /**
     * Binds the agent to its shared-memory storage and to the input/output
     * iterators and operators.  The mem-initializers are listed in member
     * declaration order, which is the order in which C++ runs them.
     */
    __device__ __forceinline__
    AgentSelectIf(
        TempStorage                 &temp_storage,      ///< Reference to temp_storage
        InputIteratorT              d_in,               ///< Input data
        FlagsInputIteratorT         d_flags_in,         ///< Input selection flags (if applicable)
        SelectedOutputIteratorT     d_selected_out,     ///< Output data
        SelectOpT                   select_op,          ///< Selection operator
        EqualityOpT                 equality_op,        ///< Equality operator
        OffsetT                     num_items)          ///< Total number of input items
    :
        temp_storage(temp_storage.Alias()),
        d_in(d_in),
        d_selected_out(d_selected_out),
        d_flags_in(d_flags_in),
        inequality_op(equality_op),
        select_op(select_op),
        num_items(num_items)
    {}


    //---------------------------------------------------------------------
    // Utility methods for initializing the selections
    //---------------------------------------------------------------------

    /**
     * Initialize selections (specialized for selection operator)
     *
     * The functor's result is collapsed to exactly 0 or 1 before it is stored,
     * because the flags are summed by the subsequent prefix scan: a truthy
     * return value other than 1 would otherwise be counted several times.
     */
    template <bool IS_FIRST_TILE, bool IS_LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT                     /*tile_offset*/,
        OffsetT                     num_tile_items,
        OutputT                     (&items)[ITEMS_PER_THREAD],
        OffsetT                     (&selection_flags)[ITEMS_PER_THREAD],
        Int2Type<USE_SELECT_OP>     /*select_method*/)
    {
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            // Out-of-bounds items are flagged as selected (discounted again by the caller)
            selection_flags[ITEM] = 1;

            if (!IS_LAST_TILE || (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM < num_tile_items))
                selection_flags[ITEM] = select_op(items[ITEM]) ? OffsetT(1) : OffsetT(0);
        }
    }


    /**
     * Initialize selections (specialized for valid flags)
     *
     * Loads one flag per item from d_flags_in and collapses each to exactly
     * 0 or 1, so that non-boolean flag values (e.g. 2 or 0xFF) count as a
     * single selection in the subsequent prefix scan.
     */
    template <bool IS_FIRST_TILE, bool IS_LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT                     tile_offset,
        OffsetT                     num_tile_items,
        OutputT                     (&/*items*/)[ITEMS_PER_THREAD],
        OffsetT                     (&selection_flags)[ITEMS_PER_THREAD],
        Int2Type<USE_SELECT_FLAGS>  /*select_method*/)
    {
        // load_flags is unioned with load_items: wait until the item load is done
        CTA_SYNC();

        FlagT flags[ITEMS_PER_THREAD];

        if (IS_LAST_TILE)
        {
            // Guarded load: out-of-bounds items are flagged as selected (default value 1)
            BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags, num_tile_items, 1);
        }
        else
        {
            BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags);
        }

        // Convert flag type to a 0/1 selection flag
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            selection_flags[ITEM] = flags[ITEM] ? OffsetT(1) : OffsetT(0);
        }
    }


    /**
     * Initialize selections (specialized for discontinuity detection)
     *
     * An item is selected when inequality_op reports it as different from its
     * predecessor.  For every tile but the first, thread 0 fetches the
     * predecessor of the tile's first item from d_in.
     */
    template <bool IS_FIRST_TILE, bool IS_LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT                     tile_offset,
        OffsetT                     num_tile_items,
        OutputT                     (&items)[ITEMS_PER_THREAD],
        OffsetT                     (&selection_flags)[ITEMS_PER_THREAD],
        Int2Type<USE_DISCONTINUITY> /*select_method*/)
    {
        if (IS_FIRST_TILE)
        {
            CTA_SYNC();

            // Set head flags.  First tile sets the first flag for the first item
            BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op);
        }
        else
        {
            // Only thread 0 loads the predecessor; the other threads pass theirs unset
            OutputT tile_predecessor;
            if (threadIdx.x == 0)
                tile_predecessor = d_in[tile_offset - 1];

            CTA_SYNC();

            BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op, tile_predecessor);
        }

        // Out-of-bounds items are flagged as selected (discounted again by the caller)
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if ((IS_LAST_TILE) && (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM >= num_tile_items))
                selection_flags[ITEM] = 1;
        }
    }


    //---------------------------------------------------------------------
    // Scatter utility methods
    //---------------------------------------------------------------------

    /**
     * Scatter flagged items to output offsets (specialized for direct scattering)
     *
     * Each thread writes its own selected items straight to d_selected_out.
     * In the last tile, indices at or beyond num_selections are skipped.
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void ScatterDirect(
        OutputT (&items)[ITEMS_PER_THREAD],
        OffsetT (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT (&selection_indices)[ITEMS_PER_THREAD],
        OffsetT num_selections)
    {
        // Scatter flagged items
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (selection_flags[ITEM])
            {
                if ((!IS_LAST_TILE) || selection_indices[ITEM] < num_selections)
                {
                    d_selected_out[selection_indices[ITEM]] = items[ITEM];
                }
            }
        }
    }


    /**
     * Scatter flagged items to output offsets (specialized for two-phase scattering, rejects discarded)
     *
     * Phase one compacts the tile's selected items into shared memory; phase
     * two copies them to d_selected_out with consecutive threads writing
     * consecutive output slots.
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OutputT         (&items)[ITEMS_PER_THREAD],
        OffsetT         (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT         (&selection_indices)[ITEMS_PER_THREAD],
        int             /*num_tile_items*/,                         ///< Number of valid items in this tile
        int             num_tile_selections,                        ///< Number of selections in this tile
        OffsetT         num_selections_prefix,                      ///< Total number of selections prior to this tile
        OffsetT         /*num_rejected_prefix*/,                    ///< Total number of rejections prior to this tile
        Int2Type<false> /*is_keep_rejects*/)                        ///< Marker type indicating whether to keep rejected items in the second partition
    {
        // raw_exchange is unioned with the scan storage: wait until the scan is done
        CTA_SYNC();

        // Compact items into shared memory
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix;
            if (selection_flags[ITEM])
            {
                temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM];
            }
        }

        CTA_SYNC();

        // Copy the compacted items to global memory
        for (int item = threadIdx.x; item < num_tile_selections; item += BLOCK_THREADS)
        {
            d_selected_out[num_selections_prefix + item] = temp_storage.raw_exchange.Alias()[item];
        }
    }


    /**
     * Scatter flagged items to output offsets (specialized for two-phase scattering, rejects kept)
     *
     * Shared memory receives the tile's rejected items first, followed by its
     * selected items.  Selected items are then written forward starting at
     * num_selections_prefix, while rejected items are written backward from
     * the end of the output (index num_items - 1 - number of earlier rejects).
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OutputT         (&items)[ITEMS_PER_THREAD],
        OffsetT         (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT         (&selection_indices)[ITEMS_PER_THREAD],
        int             num_tile_items,                             ///< Number of valid items in this tile
        int             num_tile_selections,                        ///< Number of selections in this tile
        OffsetT         num_selections_prefix,                      ///< Total number of selections prior to this tile
        OffsetT         num_rejected_prefix,                        ///< Total number of rejections prior to this tile
        Int2Type<true>  /*is_keep_rejects*/)                        ///< Marker type indicating whether to keep rejected items in the second partition
    {
        // raw_exchange is unioned with the scan storage: wait until the scan is done
        CTA_SYNC();

        int tile_num_rejections = num_tile_items - num_tile_selections;

        // Scatter items to shared memory (rejections first)
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int item_idx                = (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
            int local_selection_idx     = selection_indices[ITEM] - num_selections_prefix;
            int local_rejection_idx     = item_idx - local_selection_idx;
            int local_scatter_offset    = (selection_flags[ITEM]) ?
                                            tile_num_rejections + local_selection_idx :
                                            local_rejection_idx;

            temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM];
        }

        CTA_SYNC();

        // Gather items from shared memory and scatter to global
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int item_idx                = (ITEM * BLOCK_THREADS) + threadIdx.x;
            int rejection_idx           = item_idx;
            int selection_idx           = item_idx - tile_num_rejections;
            OffsetT scatter_offset      = (item_idx < tile_num_rejections) ?
                                            num_items - num_rejected_prefix - rejection_idx - 1 :
                                            num_selections_prefix + selection_idx;

            OutputT item = temp_storage.raw_exchange.Alias()[item_idx];

            if (!IS_LAST_TILE || (item_idx < num_tile_items))
            {
                d_selected_out[scatter_offset] = item;
            }
        }
    }


    /**
     * Scatter flagged items, choosing between the direct and the two-phase strategy
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void Scatter(
        OutputT         (&items)[ITEMS_PER_THREAD],
        OffsetT         (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT         (&selection_indices)[ITEMS_PER_THREAD],
        int             num_tile_items,                             ///< Number of valid items in this tile
        int             num_tile_selections,                        ///< Number of selections in this tile
        OffsetT         num_selections_prefix,                      ///< Total number of selections prior to this tile
        OffsetT         num_rejected_prefix,                        ///< Total number of rejections prior to this tile
        OffsetT         num_selections)                             ///< Total number of selections including this tile
    {
        // Do a two-phase scatter if (a) keeping both partitions or (b) two-phase is enabled and the average number of selected items per thread is greater than one
        if (KEEP_REJECTS || (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS)))
        {
            ScatterTwoPhase<IS_LAST_TILE, IS_FIRST_TILE>(
                items,
                selection_flags,
                selection_indices,
                num_tile_items,
                num_tile_selections,
                num_selections_prefix,
                num_rejected_prefix,
                Int2Type<KEEP_REJECTS>());
        }
        else
        {
            ScatterDirect<IS_LAST_TILE, IS_FIRST_TILE>(
                items,
                selection_flags,
                selection_indices,
                num_selections);
        }
    }


    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------

    /**
     * Process first tile of input (dynamic chained scan).  Returns the running count of selections (including this tile)
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ OffsetT ConsumeFirstTile(
        int                 num_tile_items,     ///< Number of input items comprising this tile
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        OutputT     items[ITEMS_PER_THREAD];
        OffsetT     selection_flags[ITEMS_PER_THREAD];
        OffsetT     selection_indices[ITEMS_PER_THREAD];

        // Load items (guarded for a partially-full last tile)
        if (IS_LAST_TILE)
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items);
        else
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items);

        // Initialize selection_flags
        InitializeSelections<true, IS_LAST_TILE>(
            tile_offset,
            num_tile_items,
            items,
            selection_flags,
            Int2Type<SELECT_METHOD>());

        CTA_SYNC();

        // Exclusive scan of selection_flags
        OffsetT num_tile_selections;
        BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, num_tile_selections);

        if (threadIdx.x == 0)
        {
            // Update tile status if this is not the last tile
            if (!IS_LAST_TILE)
                tile_state.SetInclusive(0, num_tile_selections);
        }

        // Discount any out-of-bounds selections
        if (IS_LAST_TILE)
            num_tile_selections -= (TILE_ITEMS - num_tile_items);

        // Scatter flagged items (no selections or rejections precede the first tile)
        Scatter<IS_LAST_TILE, true>(
            items,
            selection_flags,
            selection_indices,
            num_tile_items,
            num_tile_selections,
            0,
            0,
            num_tile_selections);

        return num_tile_selections;
    }


    /**
     * Process subsequent tile of input (dynamic chained scan).  Returns the running count of selections (including this tile)
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ OffsetT ConsumeSubsequentTile(
        int                 num_tile_items,     ///< Number of input items comprising this tile
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        OutputT     items[ITEMS_PER_THREAD];
        OffsetT     selection_flags[ITEMS_PER_THREAD];
        OffsetT     selection_indices[ITEMS_PER_THREAD];

        // Load items (guarded for a partially-full last tile)
        if (IS_LAST_TILE)
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items);
        else
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items);

        // Initialize selection_flags
        InitializeSelections<false, IS_LAST_TILE>(
            tile_offset,
            num_tile_items,
            items,
            selection_flags,
            Int2Type<SELECT_METHOD>());

        CTA_SYNC();

        // Exclusive scan of selection_flags, seeded with the prefix from the preceding tiles
        TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, cub::Sum(), tile_idx);
        BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, prefix_op);

        OffsetT num_tile_selections     = prefix_op.GetBlockAggregate();
        OffsetT num_selections          = prefix_op.GetInclusivePrefix();
        OffsetT num_selections_prefix   = prefix_op.GetExclusivePrefix();

        // Widen before multiplying so the item count of the preceding tiles
        // cannot overflow int when OffsetT is wider than int
        OffsetT num_rejected_prefix     = (static_cast<OffsetT>(tile_idx) * static_cast<OffsetT>(TILE_ITEMS)) - num_selections_prefix;

        // Discount any out-of-bounds selections
        if (IS_LAST_TILE)
        {
            int num_discount    = TILE_ITEMS - num_tile_items;
            num_selections      -= num_discount;
            num_tile_selections -= num_discount;
        }

        // Scatter flagged items
        Scatter<IS_LAST_TILE, false>(
            items,
            selection_flags,
            selection_indices,
            num_tile_items,
            num_tile_selections,
            num_selections_prefix,
            num_rejected_prefix,
            num_selections);

        return num_selections;
    }


    /**
     * Process a tile of input.  Returns the running count of selections (including this tile)
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ OffsetT ConsumeTile(
        int                 num_tile_items,     ///< Number of input items comprising this tile
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        OffsetT num_selections;
        if (tile_idx == 0)
        {
            num_selections = ConsumeFirstTile<IS_LAST_TILE>(num_tile_items, tile_offset, tile_state);
        }
        else
        {
            num_selections = ConsumeSubsequentTile<IS_LAST_TILE>(num_tile_items, tile_idx, tile_offset, tile_state);
        }

        return num_selections;
    }


    /**
     * Scan tiles of items as part of a dynamic chained scan
     *
     * Each thread block handles the single tile derived from its block index.
     * The block that processes the last tile also writes the total number of
     * selected items to d_num_selected_out (thread 0 only).
     */
    template <typename NumSelectedIteratorT>        ///< Output iterator type for recording the number of items selected
    __device__ __forceinline__ void ConsumeRange(
        int                     num_tiles,          ///< Total number of input tiles
        ScanTileStateT&         tile_state,         ///< Global tile state descriptor
        NumSelectedIteratorT    d_num_selected_out) ///< Output total number of items selected
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx    = (blockIdx.x * gridDim.y) + blockIdx.y;    // Current tile index

        // Global offset for the current tile.  Widen before multiplying so the
        // product cannot overflow int when OffsetT is wider than int
        OffsetT tile_offset = static_cast<OffsetT>(tile_idx) * static_cast<OffsetT>(TILE_ITEMS);

        if (tile_idx < num_tiles - 1)
        {
            // Not the last tile (full)
            ConsumeTile<false>(TILE_ITEMS, tile_idx, tile_offset, tile_state);
        }
        else
        {
            // The last tile (possibly partially-full)
            OffsetT num_remaining   = num_items - tile_offset;
            OffsetT num_selections  = ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state);

            if (threadIdx.x == 0)
            {
                // Output the total number of items selected
                *d_num_selected_out = num_selections;
            }
        }
    }

};
} // CUB namespace
CUB_NS_POSTFIX // Optional outer namespace(s)