gnss-sdr/src/algorithms/libs/opencl/fft_kernelstring.cc

1457 lines
62 KiB
C++

/*!
* \file fft_kernelstring.cc
*
* Version: <1.0>
*
* Copyright ( C ) 2008 Apple Inc. All Rights Reserved.
* SPDX-License-Identifier: LicenseRef-Apple-Permissive
*
*/
#include "clFFT.h"
#include "fft_internal.h"
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
using namespace std;
#define max(A, B) ((A) > (B) ? (A) : (B))
#define min(A, B) ((A) < (B) ? (A) : (B))
// Returns the decimal text representation of `num` (with a leading '-' for
// negative values), for splicing integer constants into generated kernel code.
static string
num2str(int num)
{
    return to_string(num);
}
// For any n, this function decomposes n into factors for local memory transpose
// based fft. Factors (radices) are sorted such that the first one (radixArray[0])
// is the largest. This base radix determines the number of registers used by each
// work item and the product of the remaining radices determines the size of work group needed.
// To make things concrete with an example, suppose n = 1024. It is decomposed into
// 1024 = 16 x 16 x 4. Hence the kernel uses float2 a[16] for local in-register fft and
// needs 16 x 4 = 64 work items per work group. So the kernel first performs 64 length
// 16 ffts (64 work items working in parallel) followed by a transpose using local
// memory followed by again 64 length 16 ffts followed by a transpose using local memory
// followed by 256 length 4 ffts. For the last step, since the size of the work group is
// 64 and each work item can hold an array of 16 values, 64 work items can compute 256 length
// 4 ffts by each work item computing 4 length 4 ffts.
// Similarly for n = 2048 = 8 x 8 x 8 x 4, each work group has 8 x 8 x 4 = 256 work
// items which each computes 256 (in-parallel) length 8 ffts in-register, followed
// by transpose using local memory, followed by 256 length 8 in-register ffts, followed
// by transpose using local memory, followed by 256 length 8 in-register ffts, followed
// by transpose using local memory, followed by 512 length 4 in-register ffts. Again,
// for the last step, each work item computes two length 4 in-register ffts and thus
// 256 work items are needed to compute all 512 ffts.
// For n = 32 = 8 x 4, 4 work items first compute 4 in-register
// length 8 ffts, followed by transpose using local memory followed by 8 in-register
// length 4 ffts, where each work item computes two length 4 ffts thus 4 work items
// can compute 8 length 4 ffts. However if a work group size of say 64 is chosen,
// each work group can compute 64 / 4 = 16 size 32 ffts (batched transform).
// Users can play with these parameters to figure out what gives best performance on
// their particular device, i.e. some devices have less register space, thus using a
// smaller base radix can avoid spilling ... some have small local memory, thus
// using a smaller work group size may be required, etc.
// Splits the transform length n into a list of radices.
//   n          - transform length to decompose.
//   radixArray - output array receiving the radices in order; no capacity
//                check is performed here, the caller must provide enough room.
//   numRadices - output; number of entries written to radixArray. Set to 0
//                (radixArray untouched) when maxRadix <= 1 and n is not one
//                of the tabulated lengths.
//   maxRadix   - when greater than 1, min(n, maxRadix) is emitted repeatedly
//                while it is smaller than the remaining quotient, and the
//                final quotient is emitted last. Otherwise a fixed table of
//                decompositions for n = 2 ... 2048 is consulted.
static void
getRadixArray(unsigned int n, unsigned int *radixArray, unsigned int *numRadices, unsigned int maxRadix)
{
    if (maxRadix > 1)
    {
        const unsigned int cap = min(n, maxRadix);
        unsigned int remaining = n;
        unsigned int used = 0;
        for (; remaining > cap; remaining /= cap)
        {
            radixArray[used++] = cap;
        }
        radixArray[used++] = remaining;
        *numRadices = used;
        return;
    }

    // Hand-picked decompositions used when no explicit radix cap is given.
    struct Decomposition
    {
        unsigned int length;
        unsigned int count;
        unsigned int radices[4];
    };
    static const Decomposition kTable[] = {
        {2, 1, {2, 0, 0, 0}},
        {4, 1, {4, 0, 0, 0}},
        {8, 1, {8, 0, 0, 0}},
        {16, 2, {8, 2, 0, 0}},
        {32, 2, {8, 4, 0, 0}},
        {64, 2, {8, 8, 0, 0}},
        {128, 3, {8, 4, 4, 0}},
        {256, 4, {4, 4, 4, 4}},
        {512, 3, {8, 8, 8, 0}},
        {1024, 3, {16, 16, 4, 0}},
        {2048, 4, {8, 8, 8, 4}},
    };

    *numRadices = 0;
    for (const Decomposition &entry : kTable)
    {
        if (entry.length != n)
        {
            continue;
        }
        for (unsigned int slot = 0; slot < entry.count; ++slot)
        {
            radixArray[slot] = entry.radices[slot];
        }
        *numRadices = entry.count;
        return;
    }
}
// Appends the OpenCL kernel signature line for `kernelName` to kernelString.
// Split-complex data gets separate real/imaginary input and output pointers;
// any other format gets a single float2 input and output pointer. Both
// variants end with the `int dir, int S` arguments.
static void
insertHeader(string &kernelString, string &kernelName, clFFT_DataFormat dataFormat)
{
    const char *const argumentList =
        (dataFormat == clFFT_SplitComplexFormat)
            ? "(__global float *in_real, __global float *in_imag, __global float *out_real, __global float *out_imag, int dir, int S)\n"
            : "(__global float2 *in, __global float2 *out, int dir, int S)\n";
    kernelString += "__kernel void " + kernelName + argumentList;
}
// Appends the local variable declarations that open every generated kernel
// body: integer scratch variables, the twiddle value and angles, the two
// __local pointers used for transposes, the float2 work array a[] with
// maxRadix entries, and the work-item / work-group ids.
static void
insertVariables(string &kStream, int maxRadix)
{
    static const char *const kLeadingDecls[] = {
        " int i, j, r, indexIn, indexOut, index, tid, bNum, xNum, k, l;\n",
        " int s, ii, jj, offset;\n",
        " float2 w;\n",
        " float ang, angf, ang1;\n",
        " __local float *lMemStore, *lMemLoad;\n",
    };
    for (const char *decl : kLeadingDecls)
    {
        kStream += decl;
    }
    kStream += " float2 a[" + num2str(maxRadix) + "];\n";
    kStream += " int lId = get_local_id( 0 );\n";
    kStream += " int groupId = get_group_id( 0 );\n";
}
// Appends the statement(s) that load element gIndex of the global input into
// work-array entry a[aIndex]. Interleaved data is read as one float2 from
// `in`; otherwise the x and y components are read from `in_real` / `in_imag`.
static void
formattedLoad(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat)
{
    const string target = " a[" + num2str(aIndex) + "]";
    const string source = "[" + num2str(gIndex) + "];\n";
    if (dataFormat != clFFT_InterleavedComplexFormat)
    {
        kernelString += target + ".x = in_real" + source;
        kernelString += target + ".y = in_imag" + source;
        return;
    }
    kernelString += target + " = in" + source;
}
// Appends the statement(s) that store work-array entry a[aIndex] to element
// gIndex of the global output. Interleaved data is written as one float2 to
// `out`; otherwise the x and y components go to `out_real` / `out_imag`.
static void
formattedStore(string &kernelString, int aIndex, int gIndex, clFFT_DataFormat dataFormat)
{
    const string slot = "[" + num2str(gIndex) + "] = a[" + num2str(aIndex) + "]";
    if (dataFormat != clFFT_InterleavedComplexFormat)
    {
        kernelString += " out_real" + slot + ".x;\n";
        kernelString += " out_imag" + slot + ".y;\n";
        return;
    }
    kernelString += " out" + slot + ";\n";
}
// Appends the statements that advance the kernel's global input and output
// pointers by the kernel variable `offset`: `in`/`out` for interleaved data,
// `in_real`/`in_imag`/`out_real`/`out_imag` otherwise.
static void
insertGlobalPointerAdvance(string &kernelString, clFFT_DataFormat dataFormat)
{
    if (dataFormat == clFFT_InterleavedComplexFormat)
    {
        kernelString += string(" in += offset;\n");
        kernelString += string(" out += offset;\n");
    }
    else
    {
        kernelString += string(" in_real += offset;\n");
        kernelString += string(" in_imag += offset;\n");
        kernelString += string(" out_real += offset;\n");
        kernelString += string(" out_imag += offset;\n");
    }
}

// Emits the first part of the kernel body: loading the input from global
// memory into the a[] work array of every work item, going through a
// local-memory (sMem) exchange when too few work items cooperate on one
// transform to read mem_coalesce_width consecutive elements directly.
//
//   kernelString         - generated kernel text, appended to.
//   N                    - transform length.
//   numWorkItemsPerXForm - work items cooperating on one transform.
//   numXFormsPerWG       - transforms handled by one work group.
//   R0                   - number of a[] entries each work item loads.
//   mem_coalesce_width   - threshold used to choose between the three
//                          load strategies below.
//   dataFormat           - interleaved or split complex layout.
//
// Returns the number of sMem floats the emitted code indexes: 0 for the
// direct path, (N + numWorkItemsPerXForm) * numXFormsPerWG otherwise.
//
// When numXFormsPerWG > 1 the kernel variable `s` (S modulo numXFormsPerWG)
// is set first; the emitted guards use it so that the last work group only
// touches the transforms that actually exist.
static int
insertGlobalLoadsAndTranspose(string &kernelString, int N, int numWorkItemsPerXForm, int numXFormsPerWG, int R0, int mem_coalesce_width, clFFT_DataFormat dataFormat)
{
    int log2NumWorkItemsPerXForm = (int)log2(numWorkItemsPerXForm);
    int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
    int i;
    int j;
    int lMemSize = 0;
    if (numXFormsPerWG > 1)
    {
        kernelString += string(" s = S & ") + num2str(numXFormsPerWG - 1) + string(";\n");
    }
    if (numWorkItemsPerXForm >= mem_coalesce_width)
    {
        // Strategy 1: each work item reads its own R0 elements straight from
        // global memory with stride numWorkItemsPerXForm; no local memory.
        if (numXFormsPerWG > 1)
        {
            kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
            kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
            kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n");
            kernelString += string(" offset = mad24( mad24(groupId, ") + num2str(numXFormsPerWG) + string(", jj), ") + num2str(N) + string(", ii );\n");
            insertGlobalPointerAdvance(kernelString, dataFormat);
            for (i = 0; i < R0; i++)
            {
                formattedLoad(kernelString, i, i * numWorkItemsPerXForm, dataFormat);
            }
            kernelString += string(" }\n");
        }
        else
        {
            kernelString += string(" ii = lId;\n");
            kernelString += string(" jj = 0;\n");
            kernelString += string(" offset = mad24(groupId, ") + num2str(N) + string(", ii);\n");
            insertGlobalPointerAdvance(kernelString, dataFormat);
            for (i = 0; i < R0; i++)
            {
                formattedLoad(kernelString, i, i * numWorkItemsPerXForm, dataFormat);
            }
        }
    }
    else if (N >= mem_coalesce_width)
    {
        // Strategy 2: groups of mem_coalesce_width work items read consecutive
        // elements of a transform, then the data is redistributed through sMem
        // (x components first, then y) so each work item ends up with the
        // elements it needs at stride numWorkItemsPerXForm.
        int numInnerIter = N / mem_coalesce_width;
        int numOuterIter = numXFormsPerWG / (groupSize / mem_coalesce_width);
        kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n");
        kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n");
        kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
        kernelString += string(" offset = mad24( groupId, ") + num2str(numXFormsPerWG) + string(", jj);\n");
        kernelString += string(" offset = mad24( offset, ") + num2str(N) + string(", ii );\n");
        insertGlobalPointerAdvance(kernelString, dataFormat);
        // Last work group with a partial batch: guard every row of loads.
        kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
        for (i = 0; i < numOuterIter; i++)
        {
            kernelString += string(" if( jj < s ) {\n");
            for (j = 0; j < numInnerIter; j++)
            {
                formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * (groupSize / mem_coalesce_width) * N, dataFormat);
            }
            kernelString += string(" }\n");
            if (i != numOuterIter - 1)
            {
                kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n");
            }
        }
        kernelString += string("}\n ");
        kernelString += string("else {\n");
        for (i = 0; i < numOuterIter; i++)
        {
            for (j = 0; j < numInnerIter; j++)
            {
                formattedLoad(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * (groupSize / mem_coalesce_width) * N, dataFormat);
            }
        }
        kernelString += string("}\n");
        // Re-derive ii/jj for the per-transform layout used when reading back.
        kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
        kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
        kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii);\n");
        for (i = 0; i < numOuterIter; i++)
        {
            for (j = 0; j < numInnerIter; j++)
            {
                kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * (groupSize / mem_coalesce_width) * (N + numWorkItemsPerXForm)) + string("] = a[") +
                                num2str(i * numInnerIter + j) + string("].x;\n");
            }
        }
        kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
        for (i = 0; i < R0; i++)
        {
            kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
        }
        kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
        for (i = 0; i < numOuterIter; i++)
        {
            for (j = 0; j < numInnerIter; j++)
            {
                kernelString += string(" lMemStore[") + num2str(j * mem_coalesce_width + i * (groupSize / mem_coalesce_width) * (N + numWorkItemsPerXForm)) + string("] = a[") +
                                num2str(i * numInnerIter + j) + string("].y;\n");
            }
        }
        kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
        for (i = 0; i < R0; i++)
        {
            kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
        }
        kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
        lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
    }
    else
    {
        // Strategy 3: N is smaller than mem_coalesce_width, so consecutive
        // work items read consecutive elements across whole transforms
        // (stride groupSize per a[] entry), followed by the same sMem
        // redistribution as strategy 2.
        kernelString += string(" offset = mad24( groupId, ") + num2str(N * numXFormsPerWG) + string(", lId );\n");
        insertGlobalPointerAdvance(kernelString, dataFormat);
        kernelString += string(" ii = lId & ") + num2str(N - 1) + string(";\n");
        kernelString += string(" jj = lId >> ") + num2str((int)log2(N)) + string(";\n");
        kernelString += string(" lMemStore = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
        kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
        for (i = 0; i < R0; i++)
        {
            // The guard must be braced: for split-complex data formattedLoad
            // emits two statements (real and imaginary), and an unbraced `if`
            // would leave the imaginary load unguarded.
            kernelString += string(" if(jj < s ) {\n");
            formattedLoad(kernelString, i, i * groupSize, dataFormat);
            kernelString += string(" }\n");
            if (i != R0 - 1)
            {
                kernelString += string(" jj += ") + num2str(groupSize / N) + string(";\n");
            }
        }
        kernelString += string("}\n");
        kernelString += string("else {\n");
        for (i = 0; i < R0; i++)
        {
            formattedLoad(kernelString, i, i * groupSize, dataFormat);
        }
        kernelString += string("}\n");
        if (numWorkItemsPerXForm > 1)
        {
            kernelString += string(" ii = lId & ") + num2str(numWorkItemsPerXForm - 1) + string(";\n");
            kernelString += string(" jj = lId >> ") + num2str(log2NumWorkItemsPerXForm) + string(";\n");
            kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
        }
        else
        {
            kernelString += string(" ii = 0;\n");
            kernelString += string(" jj = lId;\n");
            kernelString += string(" lMemLoad = sMem + mul24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(");\n");
        }
        for (i = 0; i < R0; i++)
        {
            kernelString += string(" lMemStore[") + num2str(i * (groupSize / N) * (N + numWorkItemsPerXForm)) + string("] = a[") + num2str(i) + string("].x;\n");
        }
        kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
        for (i = 0; i < R0; i++)
        {
            kernelString += string(" a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
        }
        kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
        for (i = 0; i < R0; i++)
        {
            kernelString += string(" lMemStore[") + num2str(i * (groupSize / N) * (N + numWorkItemsPerXForm)) + string("] = a[") + num2str(i) + string("].y;\n");
        }
        kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
        for (i = 0; i < R0; i++)
        {
            kernelString += string(" a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("];\n");
        }
        kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
        lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
    }
    return lMemSize;
}
// Emits the last part of the kernel body: writing the a[] work array of every
// work item to the global output, going through a local-memory (sMem) exchange
// first unless numWorkItemsPerXForm >= mem_coalesce_width.
//
//   kernelString         - generated kernel text, appended to.
//   N                    - transform length.
//   maxRadix             - number of a[] entries each work item holds.
//   Nr                   - radix used to derive the register reorder index
//                          ind = (i % numIter) * Nr + i / numIter, with
//                          numIter = maxRadix / Nr.
//   numWorkItemsPerXForm - work items cooperating on one transform.
//   numXFormsPerWG       - transforms handled by one work group.
//   mem_coalesce_width   - threshold used to choose between the three
//                          store strategies below.
//   dataFormat           - interleaved or split complex layout.
//
// Returns the number of sMem floats the emitted code indexes: 0 for the
// direct path, (N + numWorkItemsPerXForm) * numXFormsPerWG otherwise.
//
// The emitted code relies on the kernel variables ii, jj and s still holding
// the values assigned by earlier generated statements.
static int
insertGlobalStoresAndTranspose(string &kernelString, int N, int maxRadix, int Nr, int numWorkItemsPerXForm, int numXFormsPerWG, int mem_coalesce_width, clFFT_DataFormat dataFormat)
{
int groupSize = numWorkItemsPerXForm * numXFormsPerWG;
int i;
int j;
int k;
int ind;
int lMemSize = 0;
int numIter = maxRadix / Nr;
// NOTE(review): `indent` is assigned below but never read in this function.
string indent = string("");
if (numWorkItemsPerXForm >= mem_coalesce_width)
{
// Strategy 1: every work item writes its own entries directly to global
// memory at stride numWorkItemsPerXForm. With several transforms per work
// group the stores are wrapped in a guard on s / groupId / jj.
if (numXFormsPerWG > 1)
{
kernelString += string(" if( !s || (groupId < get_num_groups(0)-1) || (jj < s) ) {\n");
indent = string(" ");
}
for (i = 0; i < maxRadix; i++)
{
j = i % numIter;
k = i / numIter;
ind = j * Nr + k;
formattedStore(kernelString, ind, i * numWorkItemsPerXForm, dataFormat);
}
if (numXFormsPerWG > 1)
{
kernelString += string(" }\n");
}
}
else if (N >= mem_coalesce_width)
{
// Strategy 2: redistribute through sMem so that groups of
// mem_coalesce_width work items write consecutive output elements.
int numInnerIter = N / mem_coalesce_width;
int numOuterIter = numXFormsPerWG / (groupSize / mem_coalesce_width);
// lMemLoad is computed from the current ii/jj before they are re-derived
// for the coalesced layout on the next two emitted lines.
kernelString += string(" lMemLoad = sMem + mad24( jj, ") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
kernelString += string(" ii = lId & ") + num2str(mem_coalesce_width - 1) + string(";\n");
kernelString += string(" jj = lId >> ") + num2str((int)log2(mem_coalesce_width)) + string(";\n");
kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
// x components: write reordered registers to sMem via lMemLoad ...
for (i = 0; i < maxRadix; i++)
{
j = i % numIter;
k = i / numIter;
ind = j * Nr + k;
kernelString += string(" lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n");
}
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
// ... and read them back via lMemStore in the coalesced layout.
for (i = 0; i < numOuterIter; i++)
{
for (j = 0; j < numInnerIter; j++)
{
kernelString += string(" a[") + num2str(i * numInnerIter + j) + string("].x = lMemStore[") + num2str(j * mem_coalesce_width + i * (groupSize / mem_coalesce_width) * (N + numWorkItemsPerXForm)) + string("];\n");
}
}
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
// Same exchange for the y components.
for (i = 0; i < maxRadix; i++)
{
j = i % numIter;
k = i / numIter;
ind = j * Nr + k;
kernelString += string(" lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n");
}
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
for (i = 0; i < numOuterIter; i++)
{
for (j = 0; j < numInnerIter; j++)
{
kernelString += string(" a[") + num2str(i * numInnerIter + j) + string("].y = lMemStore[") + num2str(j * mem_coalesce_width + i * (groupSize / mem_coalesce_width) * (N + numWorkItemsPerXForm)) + string("];\n");
}
}
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
// Last work group with a partial batch: each row of stores is guarded by
// `jj < s`, with jj advanced between rows.
kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
for (i = 0; i < numOuterIter; i++)
{
kernelString += string(" if( jj < s ) {\n");
for (j = 0; j < numInnerIter; j++)
{
formattedStore(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * (groupSize / mem_coalesce_width) * N, dataFormat);
}
kernelString += string(" }\n");
if (i != numOuterIter - 1)
{
kernelString += string(" jj += ") + num2str(groupSize / mem_coalesce_width) + string(";\n");
}
}
kernelString += string("}\n");
// All other cases: unguarded stores.
kernelString += string("else {\n");
for (i = 0; i < numOuterIter; i++)
{
for (j = 0; j < numInnerIter; j++)
{
formattedStore(kernelString, i * numInnerIter + j, j * mem_coalesce_width + i * (groupSize / mem_coalesce_width) * N, dataFormat);
}
}
kernelString += string("}\n");
lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
}
else
{
// Strategy 3: N is smaller than mem_coalesce_width; after the sMem exchange
// each a[] entry is written at stride groupSize.
// As above, lMemLoad uses the current ii/jj before they are re-derived.
kernelString += string(" lMemLoad = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
kernelString += string(" ii = lId & ") + num2str(N - 1) + string(";\n");
kernelString += string(" jj = lId >> ") + num2str((int)log2(N)) + string(";\n");
kernelString += string(" lMemStore = sMem + mad24( jj,") + num2str(N + numWorkItemsPerXForm) + string(", ii );\n");
// x components through sMem.
for (i = 0; i < maxRadix; i++)
{
j = i % numIter;
k = i / numIter;
ind = j * Nr + k;
kernelString += string(" lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].x;\n");
}
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
for (i = 0; i < maxRadix; i++)
{
kernelString += string(" a[") + num2str(i) + string("].x = lMemStore[") + num2str(i * (groupSize / N) * (N + numWorkItemsPerXForm)) + string("];\n");
}
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
// y components through sMem.
for (i = 0; i < maxRadix; i++)
{
j = i % numIter;
k = i / numIter;
ind = j * Nr + k;
kernelString += string(" lMemLoad[") + num2str(i * numWorkItemsPerXForm) + string("] = a[") + num2str(ind) + string("].y;\n");
}
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
for (i = 0; i < maxRadix; i++)
{
kernelString += string(" a[") + num2str(i) + string("].y = lMemStore[") + num2str(i * (groupSize / N) * (N + numWorkItemsPerXForm)) + string("];\n");
}
kernelString += string(" barrier( CLK_LOCAL_MEM_FENCE );\n");
// Last work group with a partial batch: each store is wrapped in a braced
// `jj < s` guard so both components of a split-format store are covered.
kernelString += string("if((groupId == get_num_groups(0)-1) && s) {\n");
for (i = 0; i < maxRadix; i++)
{
kernelString += string(" if(jj < s ) {\n");
formattedStore(kernelString, i, i * groupSize, dataFormat);
kernelString += string(" }\n");
if (i != maxRadix - 1)
{
kernelString += string(" jj +=") + num2str(groupSize / N) + string(";\n");
}
}
kernelString += string("}\n");
kernelString += string("else {\n");
for (i = 0; i < maxRadix; i++)
{
formattedStore(kernelString, i, i * groupSize, dataFormat);
}
kernelString += string("}\n");
lMemSize = (N + numWorkItemsPerXForm) * numXFormsPerWG;
}
return lMemSize;
}
// Appends numIter calls to the generated function fftKernel<Nr>, one per
// consecutive Nr-sized slice of the work array a[] (a+0, a+Nr, a+2*Nr, ...),
// each passing the kernel's `dir` argument.
static void
insertfftKernel(string &kernelString, int Nr, int numIter)
{
    const string callPrefix = " fftKernel" + num2str(Nr) + "(a+";
    for (int slice = 0; slice < numIter; ++slice)
    {
        kernelString += callPrefix + num2str(slice * Nr) + ", dir);\n";
    }
}
// Appends the twiddle-factor multiplications applied after a radix-Nr stage.
// For each of the numIter slices of a[], the emitted code first sets `angf`
// from the work-item index ii (offset by z * numWorkItemsPerXForm for slice z,
// and shifted right by log2(Nprev) when Nprev > 1), then for every entry
// k = 1 .. Nr-1 of the slice computes an angle, builds w = (cos, sin) with the
// native_* functions and multiplies a[z * Nr + k] by w via complexMul.
//   len - divisor used in the emitted angle expression 2*pi*k/len.
static void
insertTwiddleKernel(string &kernelString, int Nr, int numIter, int Nprev, int len, int numWorkItemsPerXForm)
{
    const int shift = (int)log2(Nprev);
    const bool shifted = Nprev > 1;
    const string lenText = num2str(len);
    for (int slice = 0; slice < numIter; ++slice)
    {
        if (shifted)
        {
            if (slice == 0)
            {
                kernelString += " angf = (float) (ii >> " + num2str(shift) + ");\n";
            }
            else
            {
                kernelString += " angf = (float) ((" + num2str(slice * numWorkItemsPerXForm) + " + ii) >>" + num2str(shift) + ");\n";
            }
        }
        else
        {
            if (slice == 0)
            {
                kernelString += " angf = (float) ii;\n";
            }
            else
            {
                kernelString += " angf = (float) (" + num2str(slice * numWorkItemsPerXForm) + " + ii);\n";
            }
        }
        // Entry 0 of each slice is left untouched; the loop starts at 1.
        for (int leg = 1; leg < Nr; ++leg)
        {
            const string entry = num2str(slice * Nr + leg);
            kernelString += " ang = dir * ( 2.0f * M_PI * " + num2str(leg) + ".0f / " + lenText + ".0f ) * angf;\n";
            kernelString += " w = (float2)(native_cos(ang), native_sin(ang));\n";
            kernelString += " a[" + entry + "] = complexMul(a[" + entry + "], w);\n";
        }
    }
}
// Computes the padding used for the local-memory transpose of one stage.
//   *offset - extra elements added to each row of numWorkItemsReq entries;
//             0 when numWorkItemsPerXForm <= Nprev or Nprev >= numBanks.
//   *midPad - extra elements inserted between the regions of consecutive
//             transforms; 0 when numWorkItemsPerXForm >= numBanks or when
//             there is a single transform per work group.
// Returns the total number of local-memory elements required:
//   (numWorkItemsReq + *offset) * Nr * numXFormsPerWG
//     + *midPad * (numXFormsPerWG - 1).
// The bank index is computed with `& (numBanks - 1)`.
static int
getPadding(int numWorkItemsPerXForm, int Nprev, int numWorkItemsReq, int numXFormsPerWG, int Nr, int numBanks, int *offset, int *midPad)
{
    int rowPad = 0;
    if (numWorkItemsPerXForm > Nprev && Nprev < numBanks)
    {
        const int activeItems = (numWorkItemsPerXForm < numBanks) ? numWorkItemsPerXForm : numBanks;
        const int rowsNeeded = activeItems / Nprev;
        const int columnGroups = (rowsNeeded > Nr) ? rowsNeeded / Nr : 1;
        rowPad = Nprev * columnGroups;
    }
    *offset = rowPad;

    const int paddedRow = numWorkItemsReq + rowPad;
    int gapPad = 0;
    if (numWorkItemsPerXForm < numBanks && numXFormsPerWG != 1)
    {
        const int bankIndex = (paddedRow * Nr) & (numBanks - 1);
        if (bankIndex < numWorkItemsPerXForm)
        {
            gapPad = numWorkItemsPerXForm - bankIndex;
        }
    }
    *midPad = gapPad;

    return paddedRow * Nr * numXFormsPerWG + gapPad * (numXFormsPerWG - 1);
}
// Appends the statements that copy one component (`comp`, e.g. "x" or "y") of
// every a[] entry into local memory through lMemStore, followed by a
// local-memory barrier. Entry k of slice z is written at index
// k * (numWorkItemsReq + offset) + z * numWorkItemsPerXForm.
static void
insertLocalStores(string &kernelString, int numIter, int Nr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp)
{
    const int rowPitch = numWorkItemsReq + offset;
    for (int slice = 0; slice < numIter; ++slice)
    {
        for (int leg = 0; leg < Nr; ++leg)
        {
            const int slot = leg * rowPitch + slice * numWorkItemsPerXForm;
            kernelString += " lMemStore[" + num2str(slot) + "] = a[" + num2str(slice * Nr + leg) + "]." + comp + ";\n";
        }
    }
    kernelString += " barrier(CLK_LOCAL_MEM_FENCE);\n";
}
// Appends the statements that read one component (`comp`) of the a[] entries
// for the next stage (radix Nrn) back from local memory through lMemLoad,
// followed by a local-memory barrier. The lMemLoad index of each entry is
// assembled from four strides derived from n, Nr, Nrn, Nprev,
// numWorkItemsPerXForm, numWorkItemsReq and offset. Ncurr is accepted for
// signature compatibility but not used.
static void
insertLocalLoads(string &kernelString, int n, int Nr, int Nrn, int Nprev, int Ncurr, int numWorkItemsPerXForm, int numWorkItemsReq, int offset, string &comp)
{
    const int itemsNextStage = n / Nrn;

    const int blocksAcrossRaw = Nprev / numWorkItemsPerXForm;
    const int blocksAcross = (blocksAcrossRaw > 1) ? blocksAcrossRaw : 1;
    const int blockStride = numWorkItemsPerXForm;

    const int widthRaw = numWorkItemsPerXForm / Nprev;
    int rowsPerBand = (widthRaw > 1) ? widthRaw : 1;
    if (Nr < rowsPerBand)
    {
        rowsPerBand = Nr;
    }
    const int bandCount = Nr / rowsPerBand;
    const int bandStride = (n / Nr + offset) * rowsPerBand;

    const int roundsRaw = itemsNextStage / numWorkItemsPerXForm;
    const int rounds = (roundsRaw > 1) ? roundsRaw : 1;

    const int innerRaw = numWorkItemsPerXForm / (Nprev * Nr);
    const int innerStride = ((innerRaw > 1) ? innerRaw : 1) * Nprev;

    const int legStride = numWorkItemsReq / Nrn;
    const int cellsPerSweep = blocksAcross * bandCount;

    for (int round = 0; round < rounds; ++round)
    {
        const int sweep = round / cellsPerSweep;
        const int cell = round % cellsPerSweep;
        const int column = cell % blocksAcross;
        const int band = cell / blocksAcross;
        const int base = band * bandStride + column * blockStride + sweep * innerStride;
        for (int leg = 0; leg < Nrn; ++leg)
        {
            kernelString += " a[" + num2str(round * Nrn + leg) + "]." + comp + " = lMemLoad[" + num2str(base + leg * legStride) + "];\n";
        }
    }
    kernelString += " barrier(CLK_LOCAL_MEM_FENCE);\n";
}
// Appends the statements that derive the kernel variables j and i from the
// work-item index ii and then set lMemLoad = sMem + j * (numWorkItemsReq +
// offset) + i. The form of j and i depends on whether Nprev * Nr is smaller
// than numWorkItemsPerXForm and on whether Nprev is 1. With several
// transforms per work group, i is additionally advanced by
// jj * ((numWorkItemsReq + offset) * Nr + midPad).
static void
insertLocalLoadIndexArithmatic(string &kernelString, int Nprev, int Nr, int numWorkItemsReq, int numWorkItemsPerXForm, int numXFormsPerWG, int offset, int midPad)
{
    const int span = Nprev * Nr;
    const int log2Span = (int)log2(span);
    const int log2Prev = (int)log2(Nprev);
    const int xformPitch = (numWorkItemsReq + offset) * Nr + midPad;
    const bool firstStage = (Nprev == 1);

    if (span < numWorkItemsPerXForm)
    {
        if (firstStage)
        {
            kernelString += " j = ii & " + num2str(span - 1) + ";\n";
            kernelString += " i = ii >> " + num2str(log2Span) + ";\n";
        }
        else
        {
            kernelString += " j = (ii & " + num2str(span - 1) + ") >> " + num2str(log2Prev) + ";\n";
            kernelString += " i = mad24(ii >> " + num2str(log2Span) + ", " + num2str(Nprev) + ", ii & " + num2str(Nprev - 1) + ");\n";
        }
    }
    else if (firstStage)
    {
        kernelString += " j = ii;\n";
        kernelString += " i = 0;\n";
    }
    else
    {
        kernelString += " j = ii >> " + num2str(log2Prev) + ";\n";
        kernelString += " i = ii & " + num2str(Nprev - 1) + ";\n";
    }

    if (numXFormsPerWG > 1)
    {
        kernelString += " i = mad24(jj, " + num2str(xformPitch) + ", i);\n";
    }
    kernelString += " lMemLoad = sMem + mad24(j, " + num2str(numWorkItemsReq + offset) + ", i);\n";
}
// Appends the statement that sets lMemStore for the local-memory stores of a
// stage: sMem + ii for a single transform per work group, otherwise
// sMem + jj * ((numWorkItemsReq + offset) * Nr + midPad) + ii.
static void
insertLocalStoreIndexArithmatic(string &kernelString, int numWorkItemsReq, int numXFormsPerWG, int Nr, int offset, int midPad)
{
    if (numXFormsPerWG != 1)
    {
        const int xformPitch = (numWorkItemsReq + offset) * Nr + midPad;
        kernelString += " lMemStore = sMem + mad24(jj, " + num2str(xformPitch) + ", ii);\n";
        return;
    }
    kernelString += " lMemStore = sMem + ii;\n";
}
// Generates the OpenCL source of a single kernel that computes a length
// plan->n.x FFT entirely with in-register butterflies and local-memory
// transposes, and records its launch parameters.
//
// Effects on `plan`:
//   - a new cl_fft_kernel_info node is malloc'ed and linked at the end of
//     plan->kernel_info; its kernel_name is a malloc'ed copy of "fft<k>",
//     where k is the number of nodes already in the list;
//   - the kernel text (signature, optional `__local float sMem[...]`
//     declaration, body) is appended to *plan->kernel_string.
//
// The radix decomposition comes from getRadixArray, first from its fixed
// table and, if that needs more work items per transform than the plan
// allows, with plan->max_radix as an explicit cap. Preconditions are checked
// with assert only.
static void
createLocalMemfftKernelString(cl_fft_plan *plan)
{
    unsigned int radixArray[10];
    unsigned int numRadix;
    unsigned int n = plan->n.x;
    assert(n <= plan->max_work_item_per_workgroup * plan->max_radix && "signal lenght too big for local mem fft\n");
    getRadixArray(n, radixArray, &numRadix, 0);
    assert(numRadix > 0 && "no radix array supplied\n");
    if (n / radixArray[0] > plan->max_work_item_per_workgroup)
    {
        getRadixArray(n, radixArray, &numRadix, plan->max_radix);
    }
    assert(radixArray[0] <= plan->max_radix && "max radix choosen is greater than allowed\n");
    assert(n / radixArray[0] <= plan->max_work_item_per_workgroup && "required work items per xform greater than maximum work items allowed per work group for local mem fft\n");

    // Every radix must be a power of two and their product must equal n.
    unsigned int tmpLen = 1;
    unsigned int i;
    for (i = 0; i < numRadix; i++)
    {
        assert(radixArray[i] && !((radixArray[i] - 1) & radixArray[i]));
        tmpLen *= radixArray[i];
    }
    assert(tmpLen == n && "product of radices choosen doesnt match the length of signal\n");

    int offset;
    int midPad;
    string localString("");
    string kernelName("");
    clFFT_DataFormat dataFormat = plan->format;
    string *kernelString = plan->kernel_string;

    // Walk to the tail of the kernel list; the node count names the kernel.
    cl_fft_kernel_info **kInfo = &plan->kernel_info;
    int kCount = 0;
    while (*kInfo)
    {
        kInfo = &(*kInfo)->next;
        kCount++;
    }
    kernelName = string("fft") + num2str(kCount);
    *kInfo = (cl_fft_kernel_info *)malloc(sizeof(cl_fft_kernel_info));
    (*kInfo)->kernel = nullptr;
    (*kInfo)->lmem_size = 0;
    (*kInfo)->num_workgroups = 0;
    (*kInfo)->num_workitems_per_workgroup = 0;
    (*kInfo)->dir = cl_fft_kernel_x;
    (*kInfo)->in_place_possible = 1;
    (*kInfo)->next = nullptr;
    const size_t kernelNameBytes = kernelName.size() + 1;
    (*kInfo)->kernel_name = (char *)malloc(sizeof(char) * kernelNameBytes);
    // Bound the copy by the allocation size (sizeof on the pointer would give
    // the pointer width, truncating the name where pointers are 4 bytes) and
    // pass the name as data rather than as the format string.
    snprintf((*kInfo)->kernel_name, kernelNameBytes, "%s", kernelName.c_str());

    // Work-group shape: at least 64 work items, split into whole transforms.
    unsigned int numWorkItemsPerXForm = n / radixArray[0];
    unsigned int numWorkItemsPerWG = numWorkItemsPerXForm <= 64 ? 64 : numWorkItemsPerXForm;
    assert(numWorkItemsPerWG <= plan->max_work_item_per_workgroup);
    int numXFormsPerWG = numWorkItemsPerWG / numWorkItemsPerXForm;
    (*kInfo)->num_workgroups = 1;
    (*kInfo)->num_xforms_per_workgroup = numXFormsPerWG;
    (*kInfo)->num_workitems_per_workgroup = numWorkItemsPerWG;

    unsigned int *N = radixArray;
    unsigned int maxRadix = N[0];
    unsigned int lMemSize = 0;

    // The body is built in localString first because the sMem declaration
    // that precedes it needs the maximum local-memory size over all stages.
    insertVariables(localString, maxRadix);
    lMemSize = insertGlobalLoadsAndTranspose(localString, n, numWorkItemsPerXForm, numXFormsPerWG, maxRadix, plan->min_mem_coalesce_width, dataFormat);
    (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;

    string xcomp = string("x");
    string ycomp = string("y");
    unsigned int Nprev = 1;
    unsigned int len = n;
    unsigned int r;
    for (r = 0; r < numRadix; r++)
    {
        int numIter = N[0] / N[r];
        int numWorkItemsReq = n / N[r];
        int Ncurr = Nprev * N[r];
        insertfftKernel(localString, N[r], numIter);
        // Every stage but the last is followed by twiddles and a transpose
        // through local memory, done separately for the x and y components.
        if (r < (numRadix - 1))
        {
            insertTwiddleKernel(localString, N[r], numIter, Nprev, len, numWorkItemsPerXForm);
            lMemSize = getPadding(numWorkItemsPerXForm, Nprev, numWorkItemsReq, numXFormsPerWG, N[r], plan->num_local_mem_banks, &offset, &midPad);
            (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;
            insertLocalStoreIndexArithmatic(localString, numWorkItemsReq, numXFormsPerWG, N[r], offset, midPad);
            insertLocalLoadIndexArithmatic(localString, Nprev, N[r], numWorkItemsReq, numWorkItemsPerXForm, numXFormsPerWG, offset, midPad);
            insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp);
            insertLocalLoads(localString, n, N[r], N[r + 1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, xcomp);
            insertLocalStores(localString, numIter, N[r], numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp);
            insertLocalLoads(localString, n, N[r], N[r + 1], Nprev, Ncurr, numWorkItemsPerXForm, numWorkItemsReq, offset, ycomp);
            Nprev = Ncurr;
            len = len / N[r];
        }
    }
    lMemSize = insertGlobalStoresAndTranspose(localString, n, maxRadix, N[numRadix - 1], numWorkItemsPerXForm, numXFormsPerWG, plan->min_mem_coalesce_width, dataFormat);
    (*kInfo)->lmem_size = (lMemSize > (*kInfo)->lmem_size) ? lMemSize : (*kInfo)->lmem_size;

    // Assemble the final kernel: signature, sMem declaration (only when any
    // stage needs local memory), then the body.
    insertHeader(*kernelString, kernelName, dataFormat);
    *kernelString += string("{\n");
    if ((*kInfo)->lmem_size)
    {
        *kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n");
    }
    *kernelString += localString;
    *kernelString += string("}\n");
}
// For n larger than what can be computed using local memory fft, global transposes
// and multiple kernel launches are needed. For these sizes, n can be decomposed using
// much larger base radices i.e. say n = 262144 = 128 x 64 x 32. Thus three kernel
// launches will be needed, first computing 64 x 32, length 128 ffts, second computing
// 128 x 32 length 64 ffts, and finally a kernel computing 128 x 64 length 32 ffts.
// Each of these base radices can further be divided into factors so that each of these
// base ffts can be computed within one kernel launch using in-register ffts and local
// memory transposes i.e. for the first kernel above which computes 64 x 32 ffts of length
// 128, 128 can be decomposed into 128 = 16 x 8 i.e. 8 work items can compute 8 length
// 16 ffts followed by transpose using local memory followed by each of these eight
// work items computing 2 length 8 ffts thus computing 16 length 8 ffts in total. This
// means only 8 work items are needed for computing one length 128 fft. If we choose
// a work group size of say 64, we can compute 64/8 = 8 length 128 ffts within one
// work group. Since we need to compute 64 x 32 length 128 ffts in the first kernel, this
// means we need to launch 64 x 32 / 8 = 256 work groups with 64 work items in each
// work group where each work group is computing 8 length 128 ffts where each length
// 128 fft is computed by 8 work items. The same logic can be applied to the other two kernels
// in this example. Users can play with different base radices and different
// decompositions of base radices to generate different kernels and see which gives
// best performance. The following function is just fixed to use 128 as base radix.
// Decomposes n for the multi-kernel (global) FFT.
//   radix      - output; min(n, 128) repeated while it is smaller than the
//                remaining quotient, followed by the final quotient.
//   R1, R2     - outputs, one pair per radix entry B: for B <= 8 the pair is
//                (B, 1); otherwise R1 starts at 2 and doubles until
//                B / R1 <= R1, and R2 = B / R1 (integer division).
//   numRadices - output; number of entries written to radix, R1 and R2.
// No capacity check is made on the output arrays.
void getGlobalRadixInfo(int n, int *radix, int *R1, int *R2, int *numRadices)
{
    const int baseRadix = min(n, 128);

    int count = 0;
    int remaining = n;
    for (; remaining > baseRadix; remaining /= baseRadix)
    {
        radix[count++] = baseRadix;
    }
    radix[count++] = remaining;
    *numRadices = count;

    for (int pass = 0; pass < count; ++pass)
    {
        const int whole = radix[pass];
        int major = whole;
        int minor = 1;
        if (whole > 8)
        {
            major = 2;
            minor = whole / major;
            while (minor > major)
            {
                major *= 2;
                minor = whole / major;
            }
        }
        R1[pass] = major;
        R2[pass] = minor;
    }
}
// Generates the OpenCL source for a multi-pass FFT of length n along one axis.
// For every pass returned by getGlobalRadixInfo() this function
//   * appends one kernel (named "fft<index>") to *plan->kernel_string, and
//   * appends one cl_fft_kernel_info node describing that kernel to the end of
//     the plan->kernel_info list.
//
//   plan   - plan being built; max_work_item_per_workgroup, max_radix,
//            min_mem_coalesce_width and format are read, kernel_string and
//            kernel_info are extended.
//   n      - transform length along the processed axis.
//   BS     - for cl_fft_kernel_x it multiplies the work-group count; for the other
//            directions it is the initial stride and an upper bound on the batch size.
//   dir    - axis the generated kernels operate on.
//   vertBS - multiplies the work-group count for directions other than cl_fft_kernel_x.
//
// NOTE(review): the results of both malloc() calls below are used unchecked.
static void
createGlobalFFTKernelString(cl_fft_plan *plan, int n, int BS, cl_fft_kernel_dir dir, int vertBS)
{
    int i;
    int j;
    int k;
    int t;
    int radixArr[10] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
    int R1Arr[10] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
    int R2Arr[10] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
    int radix;
    int R1;
    int R2;
    int numRadices;
    int maxThreadsPerBlock = plan->max_work_item_per_workgroup;
    int maxArrayLen = plan->max_radix;
    int batchSize = plan->min_mem_coalesce_width;
    clFFT_DataFormat dataFormat = plan->format;
    int vertical = (dir == cl_fft_kernel_x) ? 0 : 1;

    getGlobalRadixInfo(n, radixArr, R1Arr, R2Arr, &numRadices);
    int numPasses = numRadices;

    string localString("");
    string kernelName("");
    string *kernelString = plan->kernel_string;

    // Find the tail of the kernel-info list; kCount (the number of nodes already
    // present) is used to number the kernels generated here.
    cl_fft_kernel_info **kInfo = &plan->kernel_info;
    int kCount = 0;
    while (*kInfo)
        {
            kInfo = &(*kInfo)->next;
            kCount++;
        }

    int N = n;
    int m = (int)log2(n);
    int Rinit = vertical ? BS : 1;
    batchSize = vertical ? min(BS, batchSize) : batchSize;

    int passNum;
    for (passNum = 0; passNum < numPasses; passNum++)
        {
            localString.clear();
            kernelName.clear();

            radix = radixArr[passNum];
            R1 = R1Arr[passNum];
            R2 = R2Arr[passNum];

            // Input stride: Rinit times the product of the radices of all other passes.
            int strideI = Rinit;
            for (i = 0; i < numPasses; i++)
                {
                    if (i != passNum)
                        {
                            strideI *= radixArr[i];
                        }
                }

            // Output stride: Rinit times the product of the radices of the earlier passes.
            int strideO = Rinit;
            for (i = 0; i < passNum; i++)
                {
                    strideO *= radixArr[i];
                }

            int threadsPerXForm = R2;
            batchSize = R2 == 1 ? plan->max_work_item_per_workgroup : batchSize;
            batchSize = min(batchSize, strideI);
            int threadsPerBlock = batchSize * threadsPerXForm;
            threadsPerBlock = min(threadsPerBlock, maxThreadsPerBlock);
            batchSize = threadsPerBlock / threadsPerXForm;
            assert(R2 <= R1);
            assert(R1 * R2 == radix);
            assert(R1 <= maxArrayLen);
            assert(threadsPerBlock <= maxThreadsPerBlock);

            int numIter = R1 / R2;
            int gInInc = threadsPerBlock / batchSize;

            int lgStrideO = (int)log2(strideO);
            int numBlocksPerXForm = strideI / batchSize;
            int numBlocks = numBlocksPerXForm;
            if (!vertical)
                {
                    numBlocks *= BS;
                }
            else
                {
                    numBlocks *= vertBS;
                }

            kernelName = string("fft") + num2str(kCount);
            *kInfo = (cl_fft_kernel_info *)malloc(sizeof(cl_fft_kernel_info));
            (*kInfo)->kernel = nullptr;

            // lmem_size is the element count of the "__local float sMem[]" array
            // emitted at the end of this pass; 0 means no local array is declared.
            if (R2 == 1)
                {
                    (*kInfo)->lmem_size = 0;
                }
            else
                {
                    if (strideO == 1)
                        {
                            (*kInfo)->lmem_size = (radix + 1) * batchSize;
                        }
                    else
                        {
                            (*kInfo)->lmem_size = threadsPerBlock * R1;
                        }
                }
            (*kInfo)->num_workgroups = numBlocks;
            (*kInfo)->num_xforms_per_workgroup = 1;
            (*kInfo)->num_workitems_per_workgroup = threadsPerBlock;
            (*kInfo)->dir = dir;
            // Only the last pass is flagged in-place capable, and only when the
            // total number of passes is odd.
            if ((passNum == (numPasses - 1)) && (numPasses & 1))
                {
                    (*kInfo)->in_place_possible = 1;
                }
            else
                {
                    (*kInfo)->in_place_possible = 0;
                }
            (*kInfo)->next = nullptr;

            // kernel_name is a heap pointer, so sizeof() applied to it yields the
            // pointer size rather than the buffer size. Pass the allocated length
            // explicitly, and use a literal "%s" so the name is never interpreted
            // as a format string.
            const size_t kernelNameLen = kernelName.size() + 1;
            (*kInfo)->kernel_name = (char *)malloc(sizeof(char) * kernelNameLen);
            snprintf((*kInfo)->kernel_name, kernelNameLen, "%s", kernelName.c_str());

            insertVariables(localString, R1);

            // Index computation: derive indexIn / indexOut from the work-group id.
            if (vertical)
                {
                    localString += string("xNum = groupId >> ") + num2str((int)log2(numBlocksPerXForm)) + string(";\n");
                    localString += string("groupId = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n");
                    localString += string("indexIn = mad24(groupId, ") + num2str(batchSize) + string(", xNum << ") + num2str((int)log2(n * BS)) + string(");\n");
                    localString += string("tid = mul24(groupId, ") + num2str(batchSize) + string(");\n");
                    localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n");
                    localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n");
                    int stride = radix * Rinit;
                    for (i = 0; i < passNum; i++)
                        {
                            stride *= radixArr[i];
                        }
                    localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j + ") + string("(xNum << ") + num2str((int)log2(n * BS)) + string("));\n");
                    localString += string("bNum = groupId;\n");
                }
            else
                {
                    int lgNumBlocksPerXForm = (int)log2(numBlocksPerXForm);
                    localString += string("bNum = groupId & ") + num2str(numBlocksPerXForm - 1) + string(";\n");
                    localString += string("xNum = groupId >> ") + num2str(lgNumBlocksPerXForm) + string(";\n");
                    localString += string("indexIn = mul24(bNum, ") + num2str(batchSize) + string(");\n");
                    localString += string("tid = indexIn;\n");
                    localString += string("i = tid >> ") + num2str(lgStrideO) + string(";\n");
                    localString += string("j = tid & ") + num2str(strideO - 1) + string(";\n");
                    int stride = radix * Rinit;
                    for (i = 0; i < passNum; i++)
                        {
                            stride *= radixArr[i];
                        }
                    localString += string("indexOut = mad24(i, ") + num2str(stride) + string(", j);\n");
                    localString += string("indexIn += (xNum << ") + num2str(m) + string(");\n");
                    localString += string("indexOut += (xNum << ") + num2str(m) + string(");\n");
                }

            // Load Data
            int lgBatchSize = (int)log2(batchSize);
            localString += string("tid = lId;\n");
            localString += string("i = tid & ") + num2str(batchSize - 1) + string(";\n");
            localString += string("j = tid >> ") + num2str(lgBatchSize) + string(";\n");
            localString += string("indexIn += mad24(j, ") + num2str(strideI) + string(", i);\n");

            if (dataFormat == clFFT_SplitComplexFormat)
                {
                    localString += string("in_real += indexIn;\n");
                    localString += string("in_imag += indexIn;\n");
                    for (j = 0; j < R1; j++)
                        {
                            localString += string("a[") + num2str(j) + string("].x = in_real[") + num2str(j * gInInc * strideI) + string("];\n");
                        }
                    for (j = 0; j < R1; j++)
                        {
                            localString += string("a[") + num2str(j) + string("].y = in_imag[") + num2str(j * gInInc * strideI) + string("];\n");
                        }
                }
            else
                {
                    localString += string("in += indexIn;\n");
                    for (j = 0; j < R1; j++)
                        {
                            localString += string("a[") + num2str(j) + string("] = in[") + num2str(j * gInInc * strideI) + string("];\n");
                        }
                }

            // First in-register FFT of length R1.
            localString += string("fftKernel") + num2str(R1) + string("(a, dir);\n");

            if (R2 > 1)
                {
                    // twiddle
                    for (k = 1; k < R1; k++)
                        {
                            localString += string("ang = dir*(2.0f*M_PI*") + num2str(k) + string("/") + num2str(radix) + string(")*j;\n");
                            localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n");
                            localString += string("a[") + num2str(k) + string("] = complexMul(a[") + num2str(k) + string("], w);\n");
                        }

                    // shuffle: exchange the .x components through sMem, then the .y components
                    localString += string("indexIn = mad24(j, ") + num2str(threadsPerBlock * numIter) + string(", i);\n");
                    localString += string("lMemStore = sMem + tid;\n");
                    localString += string("lMemLoad = sMem + indexIn;\n");
                    for (k = 0; k < R1; k++)
                        {
                            localString += string("lMemStore[") + num2str(k * threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n");
                        }
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
                    for (k = 0; k < numIter; k++)
                        {
                            for (t = 0; t < R2; t++)
                                {
                                    localString += string("a[") + num2str(k * R2 + t) + string("].x = lMemLoad[") + num2str(t * batchSize + k * threadsPerBlock) + string("];\n");
                                }
                        }
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
                    for (k = 0; k < R1; k++)
                        {
                            localString += string("lMemStore[") + num2str(k * threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n");
                        }
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
                    for (k = 0; k < numIter; k++)
                        {
                            for (t = 0; t < R2; t++)
                                {
                                    localString += string("a[") + num2str(k * R2 + t) + string("].y = lMemLoad[") + num2str(t * batchSize + k * threadsPerBlock) + string("];\n");
                                }
                        }
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");

                    // Second stage: numIter in-register FFTs of length R2.
                    for (j = 0; j < numIter; j++)
                        {
                            localString += string("fftKernel") + num2str(R2) + string("(a + ") + num2str(j * R2) + string(", dir);\n");
                        }
                }

            // twiddle (emitted for every pass except the last one)
            if (passNum < (numPasses - 1))
                {
                    localString += string("l = ((bNum << ") + num2str(lgBatchSize) + string(") + i) >> ") + num2str(lgStrideO) + string(";\n");
                    localString += string("k = j << ") + num2str((int)log2(R1 / R2)) + string(";\n");
                    localString += string("ang1 = dir*(2.0f*M_PI/") + num2str(N) + string(")*l;\n");
                    for (t = 0; t < R1; t++)
                        {
                            localString += string("ang = ang1*(k + ") + num2str((t % R2) * R1 + (t / R2)) + string(");\n");
                            localString += string("w = (float2)(native_cos(ang), native_sin(ang));\n");
                            localString += string("a[") + num2str(t) + string("] = complexMul(a[") + num2str(t) + string("], w);\n");
                        }
                }

            // Store Data
            if (strideO == 1)
                {
                    // Rearrange through sMem (row length radix + 1) before writing out,
                    // first the .x components and then the .y components.
                    localString += string("lMemStore = sMem + mad24(i, ") + num2str(radix + 1) + string(", j << ") + num2str((int)log2(R1 / R2)) + string(");\n");
                    localString += string("lMemLoad = sMem + mad24(tid >> ") + num2str((int)log2(radix)) + string(", ") + num2str(radix + 1) + string(", tid & ") + num2str(radix - 1) + string(");\n");

                    for (i = 0; i < R1 / R2; i++)
                        {
                            for (j = 0; j < R2; j++)
                                {
                                    localString += string("lMemStore[ ") + num2str(i + j * R1) + string("] = a[") + num2str(i * R2 + j) + string("].x;\n");
                                }
                        }
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
                    if (threadsPerBlock >= radix)
                        {
                            for (i = 0; i < R1; i++)
                                {
                                    localString += string("a[") + num2str(i) + string("].x = lMemLoad[") + num2str(i * (radix + 1) * (threadsPerBlock / radix)) + string("];\n");
                                }
                        }
                    else
                        {
                            int innerIter = radix / threadsPerBlock;
                            int outerIter = R1 / innerIter;
                            for (i = 0; i < outerIter; i++)
                                {
                                    for (j = 0; j < innerIter; j++)
                                        {
                                            localString += string("a[") + num2str(i * innerIter + j) + string("].x = lMemLoad[") + num2str(j * threadsPerBlock + i * (radix + 1)) + string("];\n");
                                        }
                                }
                        }
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");

                    for (i = 0; i < R1 / R2; i++)
                        {
                            for (j = 0; j < R2; j++)
                                {
                                    localString += string("lMemStore[ ") + num2str(i + j * R1) + string("] = a[") + num2str(i * R2 + j) + string("].y;\n");
                                }
                        }
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");
                    if (threadsPerBlock >= radix)
                        {
                            for (i = 0; i < R1; i++)
                                {
                                    localString += string("a[") + num2str(i) + string("].y = lMemLoad[") + num2str(i * (radix + 1) * (threadsPerBlock / radix)) + string("];\n");
                                }
                        }
                    else
                        {
                            int innerIter = radix / threadsPerBlock;
                            int outerIter = R1 / innerIter;
                            for (i = 0; i < outerIter; i++)
                                {
                                    for (j = 0; j < innerIter; j++)
                                        {
                                            localString += string("a[") + num2str(i * innerIter + j) + string("].y = lMemLoad[") + num2str(j * threadsPerBlock + i * (radix + 1)) + string("];\n");
                                        }
                                }
                        }
                    localString += string("barrier(CLK_LOCAL_MEM_FENCE);\n");

                    localString += string("indexOut += tid;\n");
                    if (dataFormat == clFFT_SplitComplexFormat)
                        {
                            localString += string("out_real += indexOut;\n");
                            localString += string("out_imag += indexOut;\n");
                            for (k = 0; k < R1; k++)
                                {
                                    localString += string("out_real[") + num2str(k * threadsPerBlock) + string("] = a[") + num2str(k) + string("].x;\n");
                                }
                            for (k = 0; k < R1; k++)
                                {
                                    localString += string("out_imag[") + num2str(k * threadsPerBlock) + string("] = a[") + num2str(k) + string("].y;\n");
                                }
                        }
                    else
                        {
                            localString += string("out += indexOut;\n");
                            for (k = 0; k < R1; k++)
                                {
                                    localString += string("out[") + num2str(k * threadsPerBlock) + string("] = a[") + num2str(k) + string("];\n");
                                }
                        }
                }
            else
                {
                    // Write straight from registers with a strided output index.
                    localString += string("indexOut += mad24(j, ") + num2str(numIter * strideO) + string(", i);\n");
                    if (dataFormat == clFFT_SplitComplexFormat)
                        {
                            localString += string("out_real += indexOut;\n");
                            localString += string("out_imag += indexOut;\n");
                            for (k = 0; k < R1; k++)
                                {
                                    localString += string("out_real[") + num2str(((k % R2) * R1 + (k / R2)) * strideO) + string("] = a[") + num2str(k) + string("].x;\n");
                                }
                            for (k = 0; k < R1; k++)
                                {
                                    localString += string("out_imag[") + num2str(((k % R2) * R1 + (k / R2)) * strideO) + string("] = a[") + num2str(k) + string("].y;\n");
                                }
                        }
                    else
                        {
                            localString += string("out += indexOut;\n");
                            for (k = 0; k < R1; k++)
                                {
                                    localString += string("out[") + num2str(((k % R2) * R1 + (k / R2)) * strideO) + string("] = a[") + num2str(k) + string("];\n");
                                }
                        }
                }

            // Assemble the kernel: header, optional local array, generated body.
            insertHeader(*kernelString, kernelName, dataFormat);
            *kernelString += string("{\n");
            if ((*kInfo)->lmem_size)
                {
                    *kernelString += string(" __local float sMem[") + num2str((*kInfo)->lmem_size) + string("];\n");
                }
            *kernelString += localString;
            *kernelString += string("}\n");

            N /= radix;
            kInfo = &(*kInfo)->next;
            kCount++;
        }
}
// Generates the kernel source for the 1D transform along the axis given by dir.
//
// For cl_fft_kernel_y and cl_fft_kernel_z the global (multi-pass) generator is
// used whenever the corresponding dimension is larger than 1. For
// cl_fft_kernel_x the local-memory generator is used when the transform fits
// in one work group, and the global generator otherwise. Any other direction,
// or a dimension that is not larger than 1, generates nothing.
void FFT1D(cl_fft_plan *plan, cl_fft_kernel_dir dir)
{
    if (dir == cl_fft_kernel_y)
        {
            if (plan->n.y > 1)
                {
                    createGlobalFFTKernelString(plan, plan->n.y, plan->n.x, cl_fft_kernel_y, 1);
                }
            return;
        }

    if (dir == cl_fft_kernel_z)
        {
            if (plan->n.z > 1)
                {
                    createGlobalFFTKernelString(plan, plan->n.z, plan->n.x * plan->n.y, cl_fft_kernel_z, 1);
                }
            return;
        }

    if (dir != cl_fft_kernel_x)
        {
            return;
        }

    // X direction: lengths above the local-memory limit always take the global path.
    if (plan->n.x > plan->max_localmem_fft_size)
        {
            createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1);
            return;
        }
    if (!(plan->n.x > 1))
        {
            return;
        }

    unsigned int factors[10];
    unsigned int factorCount;

    // Try the default decomposition first; if it needs more work items than a
    // work group allows, retry with plan->max_radix passed as the last argument.
    getRadixArray(plan->n.x, factors, &factorCount, 0);
    bool fitsInWorkGroup = plan->n.x / factors[0] <= plan->max_work_item_per_workgroup;
    if (!fitsInWorkGroup)
        {
            getRadixArray(plan->n.x, factors, &factorCount, plan->max_radix);
            fitsInWorkGroup = plan->n.x / factors[0] <= plan->max_work_item_per_workgroup;
        }

    if (fitsInWorkGroup)
        {
            createLocalMemfftKernelString(plan);
        }
    else
        {
            createGlobalFFTKernelString(plan, plan->n.x, 1, cl_fft_kernel_x, 1);
        }
}