summaryrefslogtreecommitdiff
path: root/thirdparty/bullet/Bullet3OpenCL/ParallelPrimitives/b3PrefixScanCL.cpp
blob: 822b5116334699154e50dd4c94d14f667318e22e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include "b3PrefixScanCL.h"
#include "b3FillCL.h"
#define B3_PREFIXSCAN_PROG_PATH "src/Bullet3OpenCL/ParallelPrimitives/kernels/PrefixScanKernels.cl"

#include "b3LauncherCL.h"
#include "Bullet3OpenCL/Initialize/b3OpenCLUtils.h"
#include "kernels/PrefixScanKernelsCL.h"

// Builds the prefix-scan program from the embedded kernel source (prefixScanKernelsCL)
// and creates the three kernels used by execute():
//   LocalScanKernel -> m_localScanKernel
//   TopLevelScanKernel -> m_blockSumKernel
//   AddOffsetKernel -> m_propagationKernel
// `size` is the initial element capacity of the scratch buffer m_workBuffer;
// execute() resizes it to the input size on every call.
b3PrefixScanCL::b3PrefixScanCL(cl_context ctx, cl_device_id device, cl_command_queue queue, int size)
	: m_commandQueue(queue)
{
	const char* scanKernelSource = prefixScanKernelsCL;
	cl_int pErrNum;
	char* additionalMacros = 0;

	m_workBuffer = new b3OpenCLArray<unsigned int>(ctx, queue, size);
	cl_program scanProg = b3OpenCLUtils::compileCLProgramFromString(ctx, device, scanKernelSource, &pErrNum, additionalMacros, B3_PREFIXSCAN_PROG_PATH);
	b3Assert(scanProg);

	m_localScanKernel = b3OpenCLUtils::compileCLKernelFromString(ctx, device, scanKernelSource, "LocalScanKernel", &pErrNum, scanProg, additionalMacros);
	b3Assert(m_localScanKernel);
	m_blockSumKernel = b3OpenCLUtils::compileCLKernelFromString(ctx, device, scanKernelSource, "TopLevelScanKernel", &pErrNum, scanProg, additionalMacros);
	b3Assert(m_blockSumKernel);
	m_propagationKernel = b3OpenCLUtils::compileCLKernelFromString(ctx, device, scanKernelSource, "AddOffsetKernel", &pErrNum, scanProg, additionalMacros);
	b3Assert(m_propagationKernel);

	// Each kernel object keeps the program alive: OpenCL deletes a program only after
	// every kernel created from it has been released. Drop our own reference here so
	// the program is not leaked once the destructor releases the kernels.
	if (scanProg)
	{
		clReleaseProgram(scanProg);
	}
}

// Frees the scratch buffer allocated in the constructor, then drops this object's
// references to the three scan kernels it created.
b3PrefixScanCL::~b3PrefixScanCL()
{
	delete m_workBuffer;

	cl_kernel ownedKernels[] = {m_localScanKernel, m_blockSumKernel, m_propagationKernel};
	const int kernelCount = sizeof(ownedKernels) / sizeof(ownedKernels[0]);
	for (int k = 0; k < kernelCount; ++k)
	{
		clReleaseKernel(ownedKernels[k]);
	}
}

// Rounds n up to the nearest power of two; a value that already is a power of two
// is returned unchanged.
// Works by copying the highest set bit of (n - 1) into every lower bit and then
// adding one.
// Edge cases for unsigned T: n == 0 wraps around and yields 0. The result is also 0
// when the next power of two does not fit in T.
template <class T>
T b3NextPowerOf2(T n)
{
	n -= 1;
	// Doubling shift amounts (1, 2, 4, ...) spread the top bit to all lower bits in
	// log2(bit width) steps. The counter is unsigned so it compares cleanly with sizeof().
	for (unsigned int shift = 1; shift < sizeof(T) * 8; shift <<= 1)
		n = n | (n >> shift);
	return n + 1;
}

// Computes a prefix scan of the first n elements of src into dst on the device.
// The disabled EXCLUSIVE assert and executeHost() indicate the exclusive variant.
//
// The scan is issued as up to three kernel launches on m_commandQueue:
//   1. LocalScanKernel: numBlocks work-groups of BLOCK_SIZE work-items, each group
//      covering BLOCK_SIZE * 2 elements. It is given dst, src and m_workBuffer;
//      presumably it writes per-block totals to m_workBuffer (TODO confirm in the .cl).
//   2. TopLevelScanKernel: a single work-group operating on m_workBuffer.
//   3. AddOffsetKernel: launched only when numBlocks > 1, over numBlocks - 1 groups.
//      Presumably it adds the scanned block totals into dst.
// All three kernels receive constBuffer = {n, numBlocks, nextPowerOf2(numBlocks)}.
//
// src: input; must hold at least n elements.
// dst: output; resized to src.size().
// n:   number of elements to scan. When n <= 0 no kernel is launched.
// sum: optional. When non-null it receives dst[n - 1] read back from the device,
//      or 0 when n <= 0. This call blocks until the queue has drained.
void b3PrefixScanCL::execute(b3OpenCLArray<unsigned int>& src, b3OpenCLArray<unsigned int>& dst, int n, unsigned int* sum)
{
	//	b3Assert( data->m_option == EXCLUSIVE );
	b3Assert(n <= static_cast<int>(src.size()));

	dst.resize(src.size());
	m_workBuffer->resize(src.size());

	if (n <= 0)
	{
		// Nothing to scan. OpenCL 1.x rejects a global work size of 0
		// (CL_INVALID_GLOBAL_WORK_SIZE), and the read-back below would index dst[-1].
		if (sum)
		{
			*sum = 0;
		}
		return;
	}

	// Each work-group of BLOCK_SIZE work-items handles BLOCK_SIZE * 2 elements.
	const unsigned int numBlocks = static_cast<unsigned int>((n + BLOCK_SIZE * 2 - 1) / (BLOCK_SIZE * 2));

	b3Int4 constBuffer;
	constBuffer.x = n;
	constBuffer.y = numBlocks;
	constBuffer.z = static_cast<int>(b3NextPowerOf2(numBlocks));

	{
		b3BufferInfoCL bInfo[] = {b3BufferInfoCL(dst.getBufferCL()), b3BufferInfoCL(src.getBufferCL()), b3BufferInfoCL(m_workBuffer->getBufferCL())};

		b3LauncherCL launcher(m_commandQueue, m_localScanKernel, "m_localScanKernel");
		launcher.setBuffers(bInfo, sizeof(bInfo) / sizeof(b3BufferInfoCL));
		launcher.setConst(constBuffer);
		launcher.launch1D(numBlocks * BLOCK_SIZE, BLOCK_SIZE);
	}

	{
		b3BufferInfoCL bInfo[] = {b3BufferInfoCL(m_workBuffer->getBufferCL())};

		b3LauncherCL launcher(m_commandQueue, m_blockSumKernel, "m_blockSumKernel");
		launcher.setBuffers(bInfo, sizeof(bInfo) / sizeof(b3BufferInfoCL));
		launcher.setConst(constBuffer);
		launcher.launch1D(BLOCK_SIZE, BLOCK_SIZE);
	}

	// The first block needs no offset, so propagation is only required when there are several blocks.
	if (numBlocks > 1)
	{
		b3BufferInfoCL bInfo[] = {b3BufferInfoCL(dst.getBufferCL()), b3BufferInfoCL(m_workBuffer->getBufferCL())};
		b3LauncherCL launcher(m_commandQueue, m_propagationKernel, "m_propagationKernel");
		launcher.setBuffers(bInfo, sizeof(bInfo) / sizeof(b3BufferInfoCL));
		launcher.setConst(constBuffer);
		launcher.launch1D((numBlocks - 1) * BLOCK_SIZE, BLOCK_SIZE);
	}

	if (sum)
	{
		clFinish(m_commandQueue);
		// NOTE(review): for an exclusive scan dst[n - 1] excludes src[n - 1]; executeHost()
		// reports the same element, so callers presumably account for this.
		dst.copyToHostPointer(sum, 1, n - 1, true);
	}
}

// CPU reference implementation of the scan: an exclusive prefix sum over the first
// n elements of src. It sets dst[0] = 0 and dst[i] = src[0] + ... + src[i - 1].
//
// src: input; must hold at least n elements.
// dst: output; must already hold at least n elements (it is not resized here).
// n:   number of elements to scan; n <= 0 writes nothing to dst.
// sum: optional. When non-null it receives dst[n - 1], the same element execute()
//      reads back, or 0 when n <= 0 instead of reading dst[-1].
// Only the exclusive variant is implemented.
void b3PrefixScanCL::executeHost(b3AlignedObjectArray<unsigned int>& src, b3AlignedObjectArray<unsigned int>& dst, int n, unsigned int* sum)
{
	unsigned int running = 0;
	for (int i = 0; i < n; i++)
	{
		dst[i] = running;
		running += src[i];
	}

	if (sum)
	{
		*sum = (n > 0) ? dst[n - 1] : 0u;
	}
}