#include "b3LauncherCL.h"

// Debug switch: when true, the b3LauncherCL constructor and destructor each
// print a numbered trace line containing the kernel name.
bool gDebugLauncherCL = false;

// Binds the launcher to a command queue and a kernel. Argument numbering starts
// at zero and argument serialization is switched off. The name pointer is kept
// as passed in (no copy is made) and is only used for the debug trace output.
b3LauncherCL::b3LauncherCL(cl_command_queue queue, cl_kernel kernel, const char* name)
	: m_commandQueue(queue),
	  m_kernel(kernel),
	  m_idx(0),
	  m_enableSerialization(false),
	  m_name(name)
{
	// The serialized stream starts with one int (the argument count), so that
	// much space is needed even before any argument has been recorded.
	m_serializationSizeInBytes = sizeof(int);

	if (!gDebugLauncherCL)
	{
		return;
	}

	// Running number of launchers constructed while tracing was enabled.
	static int s_prepareCount = 0;
	printf("[%d] Prepare to launch OpenCL kernel %s\n", s_prepareCount, name);
	++s_prepareCount;
}

// Deletes every array held in m_arrays (these are created by deserializeArgs)
// and, when tracing is enabled, prints a numbered "finished" line.
b3LauncherCL::~b3LauncherCL()
{
	const int ownedCount = m_arrays.size();
	for (int slot = 0; slot < ownedCount; ++slot)
	{
		delete m_arrays[slot];
	}
	m_arrays.clear();

	if (!gDebugLauncherCL)
	{
		return;
	}

	// Running number of launchers destroyed while tracing was enabled.
	static int s_finishCount = 0;
	printf("[%d] Finished launching OpenCL kernel %s\n", s_finishCount, m_name);
	++s_finishCount;
}

// Passes clBuffer to the kernel as the next argument. When serialization is
// enabled, the buffer handle and its byte size are also recorded so the
// argument can be written out later.
void b3LauncherCL::setBuffer(cl_mem clBuffer)
{
	if (m_enableSerialization)
	{
		// Ask OpenCL how many bytes the memory object holds.
		size_t bufferBytes;
		size_t bytesReturned;
		cl_int queryStatus = clGetMemObjectInfo(clBuffer, CL_MEM_SIZE, sizeof(size_t), &bufferBytes, &bytesReturned);
		b3Assert(queryStatus == CL_SUCCESS);

		b3KernelArgData record;
		record.m_argIndex = m_idx;
		record.m_isBuffer = 1;
		record.m_clBuffer = clBuffer;
		record.m_argSizeInBytes = bufferBytes;
		m_kernelArguments.push_back(record);

		// Serialized form of this argument: one record followed by the buffer content.
		m_serializationSizeInBytes += sizeof(b3KernelArgData);
		m_serializationSizeInBytes += bufferBytes;
	}

	cl_int status = clSetKernelArg(m_kernel, m_idx, sizeof(cl_mem), &clBuffer);
	++m_idx;
	b3Assert(status == CL_SUCCESS);
}

// Passes the n buffers in buffInfo to the kernel as consecutive arguments, in
// array order. When serialization is enabled, each buffer handle and its byte
// size are recorded as well.
void b3LauncherCL::setBuffers(b3BufferInfoCL* buffInfo, int n)
{
	for (int bufferIdx = 0; bufferIdx < n; ++bufferIdx)
	{
		b3BufferInfoCL& info = buffInfo[bufferIdx];

		if (m_enableSerialization)
		{
			// Ask OpenCL how many bytes the memory object holds.
			size_t bufferBytes;
			size_t bytesReturned;
			cl_int queryStatus = clGetMemObjectInfo(info.m_clBuffer, CL_MEM_SIZE, sizeof(size_t), &bufferBytes, &bytesReturned);
			b3Assert(queryStatus == CL_SUCCESS);

			b3KernelArgData record;
			record.m_argIndex = m_idx;
			record.m_isBuffer = 1;
			record.m_clBuffer = info.m_clBuffer;
			record.m_argSizeInBytes = bufferBytes;
			m_kernelArguments.push_back(record);

			// Serialized form of this argument: one record followed by the buffer content.
			m_serializationSizeInBytes += sizeof(b3KernelArgData);
			m_serializationSizeInBytes += bufferBytes;
		}

		cl_int status = clSetKernelArg(m_kernel, m_idx, sizeof(cl_mem), &info.m_clBuffer);
		++m_idx;
		b3Assert(status == CL_SUCCESS);
	}
}

// Record type that deserializeArgs() uses to interpret one serialized kernel
// argument taken from a raw byte buffer; the record's bytes are afterwards
// memcpy'd into a b3KernelArgData.
// NOTE(review): presumably this mirrors b3KernelArgData field-for-field without
// that type's alignment requirement - confirm against b3LauncherCL.h before
// changing either layout.
struct b3KernelArgDataUnaligned
{
	// Non-zero when the argument is an OpenCL buffer; zero when the value is stored inline in m_argData.
	int m_isBuffer;
	// Not read by deserializeArgs(); presumably the kernel argument index - confirm against b3KernelArgData.
	int m_argIndex;
	// Buffer arguments: number of content bytes that follow the record in the stream.
	// Other arguments: number of bytes of m_argData handed to clSetKernelArg.
	int m_argSizeInBytes;
	// Not read in this file.
	int m_unusedPadding;
	union {
		// Buffer arguments: deserializeArgs() stores the handle of the buffer it creates here.
		cl_mem m_clBuffer;
		// Non-buffer arguments: the raw argument value.
		unsigned char m_argData[B3_CL_MAX_ARG_SIZE];
	};
};
#include <string.h>

int b3LauncherCL::deserializeArgs(unsigned char* buf, int bufSize, cl_context ctx)
{
	int index = 0;

	int numArguments = *(int*)&buf[index];
	index += sizeof(int);

	for (int i = 0; i < numArguments; i++)
	{
		b3KernelArgDataUnaligned* arg = (b3KernelArgDataUnaligned*)&buf[index];

		index += sizeof(b3KernelArgData);
		if (arg->m_isBuffer)
		{
			b3OpenCLArray<unsigned char>* clData = new b3OpenCLArray<unsigned char>(ctx, m_commandQueue, arg->m_argSizeInBytes);
			clData->resize(arg->m_argSizeInBytes);

			clData->copyFromHostPointer(&buf[index], arg->m_argSizeInBytes);

			arg->m_clBuffer = clData->getBufferCL();

			m_arrays.push_back(clData);

			cl_int status = clSetKernelArg(m_kernel, m_idx++, sizeof(cl_mem), &arg->m_clBuffer);
			b3Assert(status == CL_SUCCESS);
			index += arg->m_argSizeInBytes;
		}
		else
		{
			cl_int status = clSetKernelArg(m_kernel, m_idx++, arg->m_argSizeInBytes, &arg->m_argData);
			b3Assert(status == CL_SUCCESS);
		}
		b3KernelArgData b;
		memcpy(&b, arg, sizeof(b3KernelArgDataUnaligned));
		m_kernelArguments.push_back(b);
	}
	m_serializationSizeInBytes = index;
	return index;
}

int b3LauncherCL::validateResults(unsigned char* goldBuffer, int goldBufferCapacity, cl_context ctx)
{
	int index = 0;

	int numArguments = *(int*)&goldBuffer[index];
	index += sizeof(int);

	if (numArguments != m_kernelArguments.size())
	{
		printf("failed validation: expected %d arguments, found %d\n", numArguments, m_kernelArguments.size());
		return -1;
	}

	for (int ii = 0; ii < numArguments; ii++)
	{
		b3KernelArgData* argGold = (b3KernelArgData*)&goldBuffer[index];

		if (m_kernelArguments[ii].m_argSizeInBytes != argGold->m_argSizeInBytes)
		{
			printf("failed validation: argument %d sizeInBytes expected: %d, found %d\n", ii, argGold->m_argSizeInBytes, m_kernelArguments[ii].m_argSizeInBytes);
			return -2;
		}

		{
			int expected = argGold->m_isBuffer;
			int found = m_kernelArguments[ii].m_isBuffer;

			if (expected != found)
			{
				printf("failed validation: argument %d isBuffer expected: %d, found %d\n", ii, expected, found);
				return -3;
			}
		}
		index += sizeof(b3KernelArgData);

		if (argGold->m_isBuffer)
		{
			unsigned char* memBuf = (unsigned char*)malloc(m_kernelArguments[ii].m_argSizeInBytes);
			unsigned char* goldBuf = &goldBuffer[index];
			for (int j = 0; j < m_kernelArguments[j].m_argSizeInBytes; j++)
			{
				memBuf[j] = 0xaa;
			}

			cl_int status = 0;
			status = clEnqueueReadBuffer(m_commandQueue, m_kernelArguments[ii].m_clBuffer, CL_TRUE, 0, m_kernelArguments[ii].m_argSizeInBytes,
										 memBuf, 0, 0, 0);
			b3Assert(status == CL_SUCCESS);
			clFinish(m_commandQueue);

			for (int b = 0; b < m_kernelArguments[ii].m_argSizeInBytes; b++)
			{
				int expected = goldBuf[b];
				int found = memBuf[b];
				if (expected != found)
				{
					printf("failed validation: argument %d OpenCL data at byte position %d expected: %d, found %d\n",
						   ii, b, expected, found);
					return -4;
				}
			}

			index += argGold->m_argSizeInBytes;
		}
		else
		{
			//compare content
			for (int b = 0; b < m_kernelArguments[ii].m_argSizeInBytes; b++)
			{
				int expected = argGold->m_argData[b];
				int found = m_kernelArguments[ii].m_argData[b];
				if (expected != found)
				{
					printf("failed validation: argument %d const data at byte position %d expected: %d, found %d\n",
						   ii, b, expected, found);
					return -5;
				}
			}
		}
	}
	return index;
}

int b3LauncherCL::serializeArguments(unsigned char* destBuffer, int destBufferCapacity)
{
	//initialize to known values
	for (int i = 0; i < destBufferCapacity; i++)
		destBuffer[i] = 0xec;

	assert(destBufferCapacity >= m_serializationSizeInBytes);

	//todo: use the b3Serializer for this to allow for 32/64bit, endianness etc
	int numArguments = m_kernelArguments.size();
	int curBufferSize = 0;
	int* dest = (int*)&destBuffer[curBufferSize];
	*dest = numArguments;
	curBufferSize += sizeof(int);

	for (int i = 0; i < this->m_kernelArguments.size(); i++)
	{
		b3KernelArgData* arg = (b3KernelArgData*)&destBuffer[curBufferSize];
		*arg = m_kernelArguments[i];
		curBufferSize += sizeof(b3KernelArgData);
		if (arg->m_isBuffer == 1)
		{
			//copy the OpenCL buffer content
			cl_int status = 0;
			status = clEnqueueReadBuffer(m_commandQueue, arg->m_clBuffer, 0, 0, arg->m_argSizeInBytes,
										 &destBuffer[curBufferSize], 0, 0, 0);
			b3Assert(status == CL_SUCCESS);
			clFinish(m_commandQueue);
			curBufferSize += arg->m_argSizeInBytes;
		}
	}
	return curBufferSize;
}

void b3LauncherCL::serializeToFile(const char* fileName, int numWorkItems)
{
	int num = numWorkItems;
	int buffSize = getSerializationBufferSize();
	unsigned char* buf = new unsigned char[buffSize + sizeof(int)];
	for (int i = 0; i < buffSize + 1; i++)
	{
		unsigned char* ptr = (unsigned char*)&buf[i];
		*ptr = 0xff;
	}
	//	int actualWrite = serializeArguments(buf,buffSize);

	//	unsigned char* cptr = (unsigned char*)&buf[buffSize];
	//            printf("buf[buffSize] = %d\n",*cptr);

	assert(buf[buffSize] == 0xff);  //check for buffer overrun
	int* ptr = (int*)&buf[buffSize];

	*ptr = num;

	FILE* f = fopen(fileName, "wb");
	fwrite(buf, buffSize + sizeof(int), 1, f);
	fclose(f);

	delete[] buf;
}