#ifndef B3_NEW_CONTACT_REDUCTION_H
#define B3_NEW_CONTACT_REDUCTION_H

#include "Bullet3Common/shared/b3Float4.h"
#include "Bullet3Collision/NarrowPhaseCollision/shared/b3RigidBodyData.h"
#include "Bullet3Collision/NarrowPhaseCollision/shared/b3Contact4Data.h"

#define GET_NPOINTS(x) (x).m_worldNormalOnB.w


// Reduces a set of clipped contact points to at most four representative points.
//
// p          - candidate contact points in global memory; p[i].w is used as the
//              penetration value (the point with the smallest w is always kept).
// nPoints    - number of valid entries in p; at most the first 64 are examined.
// nearNormal - reference normal; two sampling directions u and v are built
//              perpendicular to it and the points extreme along +u, -u, +v, -v
//              are selected.
// contactIdx - receives the selected indices into p in its x, y, z, w components.
//              Entries may repeat when one point is extreme in several directions.
//
// Returns the number of selected points:
//   0        when nPoints <= 0,
//   nPoints  when 1 <= nPoints <= 4 (all points kept, contactIdx = (0,1,2,3)),
//   4        otherwise.
int b3ExtractManifoldSequentialGlobal(__global const b3Float4* p, int nPoints, b3Float4ConstArg nearNormal, b3Int4* contactIdx)
{
	// Identity selection: the answer for the "keep every point" path, and a
	// valid fallback for the sampling path (nPoints > 4 there, so 0..3 exist).
	contactIdx[0] = b3MakeInt4(0,1,2,3);

	if (nPoints <= 0)
		return 0;

	if (nPoints <= 4)
		return nPoints;

	// Only the first 64 points take part in the reduction.
	if (nPoints > 64)
		nPoints = 64;

	b3Float4 center = b3MakeFloat4(0,0,0,0);
	for (int i = 0; i < nPoints; i++)
		center += p[i];
	center /= (float)nPoints;

	// Sample 4 directions: u and v are both perpendicular to nearNormal.
	b3Float4 aVector = p[0] - center;
	b3Float4 u = b3Cross(nearNormal, aVector);
	b3Float4 v = b3Cross(nearNormal, u);
	u = b3Normalized(u);
	v = b3Normalized(v);

	// Keep the point with the smallest w ("deepest penetration").
	float minW = FLT_MAX;
	int minIndex = -1;

	// Smallest projection seen so far along +u, -u, +v, -v (stored in x, y, z, w).
	// FLT_MAX (not FLT_MIN, which is the smallest *positive* float) guarantees
	// that every direction selects some point.
	b3Float4 minDots = b3MakeFloat4(FLT_MAX, FLT_MAX, FLT_MAX, FLT_MAX);

	for (int ie = 0; ie < nPoints; ie++)
	{
		b3Float4 pt = p[ie];
		if (pt.w < minW)
		{
			minW = pt.w;
			minIndex = ie;
		}

		b3Float4 r = pt - center;
		float du = b3Dot(u, r);
		float dv = b3Dot(v, r);

		if (du < minDots.x)
		{
			minDots.x = du;
			contactIdx[0].x = ie;
		}
		if (-du < minDots.y)
		{
			minDots.y = -du;
			contactIdx[0].y = ie;
		}
		if (dv < minDots.z)
		{
			minDots.z = dv;
			contactIdx[0].z = ie;
		}
		if (-dv < minDots.w)
		{
			minDots.w = -dv;
			contactIdx[0].w = ie;
		}
	}

	// Ensure the smallest-w point is part of the manifold. minIndex stays -1 only
	// when no w compared below FLT_MAX (e.g. all +inf/NaN); never emit it as an index.
	if (minIndex >= 0 &&
		contactIdx[0].x != minIndex && contactIdx[0].y != minIndex &&
		contactIdx[0].z != minIndex && contactIdx[0].w != minIndex)
	{
		// Replace the first contact with the minimum (todo: replace the contact with least penetration).
		contactIdx[0].x = minIndex;
	}

	return 4;
}

// Reduces the clipped contact points of one pair to a single b3Contact4Data.
//
// The pair is chosen by the explicit pairIndex argument, not by get_global_id(0),
// so every work-item of a launch handles that same pair.
// NOTE(review): presumably enqueued with one work-item per call — confirm at the call site.
//
// When pairIndex < numPairs, hasSeparatingAxis[pairIndex] is non-zero and
// clippingFaces[pairIndex].w > 0, the points stored at
// worldVertsB2[pairIndex*vertexFaceCapacity ...] are reduced to at most four.
// An output slot is then reserved with an atomic increment of *nGlobalContactsOut.
// The counter is incremented even when the slot is >= contactCapacity; in that
// case nothing else is written. Otherwise the contact is filled in and
// pairs[pairIndex].w receives the slot index.
__kernel void   b3NewContactReductionKernel( __global b3Int4* pairs,
                                                   __global const b3RigidBodyData_t* rigidBodies,
                                                   __global const b3Float4* separatingNormals,
                                                   __global const int* hasSeparatingAxis,
                                                   __global struct b3Contact4Data* globalContactsOut,
                                                   __global b3Int4* clippingFaces,
                                                   __global b3Float4* worldVertsB2,
                                                   volatile __global int* nGlobalContactsOut,
                                                   int vertexFaceCapacity,
                                                   int contactCapacity,
                                                   int numPairs,
                                                   int pairIndex
                                                   )
{
	const int pair = pairIndex;

	if (pair >= numPairs)
		return;
	if (!hasSeparatingAxis[pair])
		return;

	const int numClipped = clippingFaces[pair].w;
	if (numClipped <= 0)
		return;

	__global b3Float4* pointsIn = &worldVertsB2[pair*vertexFaceCapacity];
	const b3Float4 normal = -separatingNormals[pair];

	b3Int4 contactIdx = b3MakeInt4(0,1,2,3);
	const int nReducedContacts = b3ExtractManifoldSequentialGlobal(pointsIn, numClipped, normal, &contactIdx);

	const int dstIdx = b3AtomicInc(nGlobalContactsOut);
	b3Assert(dstIdx < contactCapacity);
	if (dstIdx >= contactCapacity)
		return;

	__global struct b3Contact4Data* c = &globalContactsOut[dstIdx];

	// Written first: GET_NPOINTS later overwrites the .w of this same field.
	c->m_worldNormalOnB = -normal;
	c->m_restituitionCoeffCmp = (0.f*0xffff);
	c->m_frictionCoeffCmp = (0.7f*0xffff);
	c->m_batchIdx = pair;

	const int bodyA = pairs[pair].x;
	const int bodyB = pairs[pair].y;
	pairs[pair].w = dstIdx;

	// Bodies with zero inverse mass are stored with a negated index.
	c->m_bodyAPtrAndSignBit = (rigidBodies[bodyA].m_invMass == 0) ? -bodyA : bodyA;
	c->m_bodyBPtrAndSignBit = (rigidBodies[bodyB].m_invMass == 0) ? -bodyB : bodyB;
	c->m_childIndexA = -1;
	c->m_childIndexB = -1;

	// Slot k receives the point selected in component k of contactIdx;
	// only the first nReducedContacts slots (1..4) are written.
	if (nReducedContacts >= 1 && nReducedContacts <= 4)
	{
		const int selected[4] = { contactIdx.x, contactIdx.y, contactIdx.z, contactIdx.w };
		for (int k = 0; k < nReducedContacts; k++)
			c->m_worldPosB[k] = pointsIn[selected[k]];
	}

	GET_NPOINTS(*c) = nReducedContacts;
}
#endif