52   typename ElementOutput_,                             
    54   typename ElementAccumulator_ = ElementOutput_,       
    55   typename ElementCompute_ = ElementOutput_,           
99     ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
   108     ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
   160     ComputeFragment converted_accumulator = accumulator_converter(accumulator);
   171     intermediate = mul_add_source(beta_, converted_source);                             
   172     intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);    
   174     intermediate = max_accumulator(intermediate, threshold_);
   179     return destination_converter(intermediate);
   192   typename ElementOutput_,                             
   203   static int const kCount = Count;
   237     ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
   246     ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
   267     alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
   268     beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
   269     threshold_ = params.threshold;
   298     ComputeFragment converted_accumulator = accumulator_converter(accumulator);
   309     intermediate = mul_add_source(beta_, converted_source);                             
   310     intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);    
   313     intermediate = max_accumulator(intermediate, threshold_);
   319     for (int i = 0; i < kCount; ++i) {
   320       scaled_accumulator[i] = static_cast<int>(intermediate[i]);
   326     return destination_converter(scaled_accumulator);
 Fused multiply-add. 
Definition: functional.h:92
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source, ElementCompute uniform=ElementCompute(0)) const 
Computes linear scaling: D = alpha * accumulator + beta * source. 
Definition: linear_combination_relu.h:150
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination_relu.h:87
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:104
ElementCompute beta
scales source tensor 
Definition: linear_combination_relu.h:77
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:233
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination_relu.h:67
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination_relu.h:68
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory 
Definition: linear_combination_relu.h:218
Definition: linear_combination_relu.h:58
Definition: functional.h:235
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory 
Definition: linear_combination_relu.h:80
CUTLASS_HOST_DEVICE LinearCombinationRelu(Params const &params)
Constructs the function object, possibly loading from pointers in host memory. 
Definition: linear_combination_relu.h:265
ElementCompute_ ElementCompute
Definition: linear_combination_relu.h:63
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue. 
Definition: linear_combination_relu.h:280
Boost-like numeric conversion operator for CUTLASS numeric types. 
ElementCompute alpha
scales accumulators 
Definition: linear_combination_relu.h:214
CUTLASS_HOST_DEVICE LinearCombinationRelu(Params const &params)
Constructs the function object, possibly loading from pointers in host memory. 
Definition: linear_combination_relu.h:127
ElementCompute beta
scales source tensor 
Definition: linear_combination_relu.h:215
ElementCompute threshold
Relu threshold. 
Definition: linear_combination_relu.h:78
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination_relu.h:205
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:95
Definition: functional.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
static FloatRoundStyle const kRound
Definition: linear_combination_relu.h:71
Top-level include for all CUTLASS numeric types. 
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination_relu.h:69
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue. 
Definition: linear_combination_relu.h:142
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination_relu.h:225
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination_relu.h:207
ElementOutput_ ElementOutput
Definition: linear_combination_relu.h:199
CUTLASS_HOST_DEVICE bool is_source_needed() const 
Returns true if source is needed. 
Definition: linear_combination_relu.h:274
FloatRoundStyle
Definition: numeric_conversion.h:43
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory 
Definition: linear_combination_relu.h:217
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:242
ElementCompute threshold
Relu threshold. 
Definition: linear_combination_relu.h:216
Conversion operator for Array. 
Definition: numeric_conversion.h:294
ElementCompute alpha
scales accumulators 
Definition: linear_combination_relu.h:76
int ElementAccumulator
Definition: linear_combination_relu.h:200
float ElementCompute
Definition: linear_combination_relu.h:201
ElementAccumulator_ ElementAccumulator
Definition: linear_combination_relu.h:62
CUTLASS_HOST_DEVICE bool is_source_needed() const 
Returns true if source is needed. 
Definition: linear_combination_relu.h:136
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source, ElementCompute uniform=ElementCompute(0)) const 
Computes linear scaling: D = alpha * accumulator + beta * source. 
Definition: linear_combination_relu.h:288
static int const kCount
Definition: linear_combination_relu.h:65
Basic include for CUTLASS. 
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory 
Definition: linear_combination_relu.h:79
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination_relu.h:206
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
Host-constructable parameters structure. 
Definition: linear_combination_relu.h:74
ElementOutput_ ElementOutput
Definition: linear_combination_relu.h:61