summaryrefslogtreecommitdiff
path: root/servers/rendering/renderer_rd/shaders/cluster_render.glsl
blob: 932312de825cce5a92f55885ab1794a777f11f0e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#[vertex]

#version 450

#VERSION_DEFINES

// Vertex position of the proxy mesh drawn for each element; main() scales it by the
// element's scale and transforms it with the element's transform_inv.
layout(location = 0) in vec3 vertex_attrib;

// Positive depth (negated transformed z) consumed by the fragment stage to pick a depth slice.
layout(location = 0) out float depth_interp;
// Index of the element this instance represents; flat because it is constant per instance.
layout(location = 1) out flat uint element_index;

// Per-draw constants. base_index is added to gl_InstanceIndex to get the element index,
// so several draws can cover different ranges of render_elements.
layout(push_constant, std430) uniform Params {
	uint base_index;
	uint pad0; // padding only, not read by this shader
	uint pad1;
	uint pad2;
}
params;

// Shared cluster-pass state. Must stay identical to the State block declared in the
// fragment stage (same set/binding); presumably mirrored by a CPU-side struct — keep in sync.
layout(set = 0, binding = 1, std140) uniform State {
	mat4 projection;

	float inv_z_far; // multiplies depth_interp in the fragment stage (presumably 1 / z_far)
	uint screen_to_clusters_shift; // shift to obtain coordinates in block indices
	uint cluster_screen_width; // number of clusters per screen row (row stride of the cluster grid)
	uint cluster_data_size; // how many uints the data of a single cluster takes

	uint cluster_depth_offset; // offset, inside one cluster's data, of the per-element depth masks
	uint pad0; // padding only
	uint pad1;
	uint pad2;
}
state;

// One culling element (one instance of the proxy mesh).
struct RenderElement {
	uint type; //0-4
	bool touches_near; // not read by this shader
	bool touches_far; // not read by this shader
	uint original_index; // not read by this shader
	mat3x4 transform_inv; // applied as vec4(vertex, 1.0) * transform_inv, i.e. its columns act as rows of an affine 3x4 matrix
	vec3 scale; // applied to the proxy mesh vertex before transform_inv
	uint pad; // fills the remainder of scale's 16-byte std430 slot
};

layout(set = 0, binding = 2, std430) buffer restrict readonly RenderElements {
	RenderElement data[];
}
render_elements;

void main() {
	// Every instance of the proxy mesh stands for one render element.
	element_index = params.base_index + uint(gl_InstanceIndex);

	vec3 element_scale = render_elements.data[element_index].scale;
	mat3x4 element_transform = render_elements.data[element_index].transform_inv;

	// Scale the proxy mesh first, then apply the element transform. The matrix is a mat3x4, so it
	// is applied as row-vector * matrix, yielding a vec3 (presumably in view space, despite the
	// "_inv" name — NOTE(review): confirm against the CPU side).
	vec3 transformed = vec4(vertex_attrib * element_scale, 1.0) * element_transform;

	// Negated z so that depth grows away from the camera (assumes a -Z-forward space, which the
	// fragment stage's inv_z_far scaling relies on).
	depth_interp = -transformed.z;

	gl_Position = state.projection * vec4(transformed, 1.0);
}

#[fragment]

#version 450

#VERSION_DEFINES
#ifndef MOLTENVK_USED // Metal will corrupt GPU state otherwise
#if defined(has_GL_KHR_shader_subgroup_ballot) && defined(has_GL_KHR_shader_subgroup_arithmetic) && defined(has_GL_KHR_shader_subgroup_vote)

#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_vote : enable

#define USE_SUBGROUPS
#endif
#endif

// Positive depth written by the vertex stage.
layout(location = 0) in float depth_interp;
// Element (instance) being rasterized; constant across the primitive.
layout(location = 1) in flat uint element_index;

// Shared cluster-pass state. Must stay identical to the State block declared in the
// vertex stage (same set/binding).
layout(set = 0, binding = 1, std140) uniform State {
	mat4 projection;
	float inv_z_far; // multiplies depth_interp to get a unit depth (presumably 1 / z_far)
	uint screen_to_clusters_shift; // shift to obtain coordinates in block indices
	uint cluster_screen_width; // number of clusters per screen row (row stride of the cluster grid)
	uint cluster_data_size; // how many uints the data of a single cluster takes
	uint cluster_depth_offset; // offset, inside one cluster's data, of the per-element depth masks
	uint pad0; // padding only
	uint pad1;
	uint pad2;
}
state;

// Cluster data is laid out linearly; each cluster owns state.cluster_data_size uints holding:
// - a usage bitmask with one bit per element to mark it as used (element_index >> 5 selects the
//   uint, element_index & 31 selects the bit), so about max_elem_count / 32 uints;
// - starting at state.cluster_depth_offset, one uint per element whose bits (0-31) mark the
//   depth slices the element covers in this cluster.

layout(set = 0, binding = 3, std430) buffer restrict ClusterRender {
	uint data[];
}
cluster_render;

void main() {
	// Convert the fragment's screen position to 2D cluster (tile) coordinates.
	uvec2 cluster = uvec2(gl_FragCoord.xy) >> state.screen_to_clusters_shift;

	// Linear cluster index, scaled by the per-cluster data size, gives the offset (in uints)
	// of this cluster's data in cluster_render.data.
	uint cluster_offset = (cluster.x + state.cluster_screen_width * cluster.y) * state.cluster_data_size;

	// Usage bitmask: one bit per element, 32 elements per uint.
	uint usage_write_offset = cluster_offset + (element_index >> 5);
	uint usage_write_bit = 1u << (element_index & 0x1Fu);

	// Depth mask: one uint per element, one bit per depth slice (32 slices of the range that
	// inv_z_far normalizes to [0, 1]).
	float unit_depth = depth_interp * state.inv_z_far;
	// Clamp in float space before converting: converting a negative float to uint is undefined
	// in GLSL, so clamping the uint afterwards is too late.
	uint z_bit = uint(clamp(floor(unit_depth * 32.0), 0.0, 31.0));

	uint z_write_offset = cluster_offset + state.cluster_depth_offset + element_index;
	uint z_write_bit = 1u << z_bit;

#ifdef USE_SUBGROUPS
	// Helper invocations must not write memory.
	if (!gl_HelperInvocation) {
		//https://advances.realtimerendering.com/s2017/2017_Sig_Improved_Culling_final.pdf
		// Waterfall loop: each iteration peels off the invocations that share the first active
		// invocation's (cluster, element) pair, merges their depth bits and issues one set of
		// atomics for the whole group. The key must include the element, not just the cluster:
		// a subgroup may hold fragments of several elements in the same cluster, and each of
		// them needs its own usage bit and depth mask. z_write_offset is distinct for every
		// (cluster, element) pair that owns distinct depth data, so it serves as that key.
		while (true) {
			uint first = subgroupBroadcastFirst(z_write_offset);
			if (first == z_write_offset) {
				// Only invocations of this group are active inside this branch, so the reduction
				// and the leader election below are scoped to the group rather than the subgroup.
				uint merged_z_bits = subgroupOr(z_write_bit);
				uvec4 group_mask = subgroupBallot(true);
				if (subgroupBallotExclusiveBitCount(group_mask) == 0u) {
					// The usage bit is identical for every member of the group.
					atomicOr(cluster_render.data[usage_write_offset], usage_write_bit);
					atomicOr(cluster_render.data[z_write_offset], merged_z_bits);
				}
				// Threads that break go inactive and no longer take part in later iterations.
				break;
			}
		}
	}
#else
	// Helper invocations must not write memory.
	if (!gl_HelperInvocation) {
		atomicOr(cluster_render.data[usage_write_offset], usage_write_bit);
		atomicOr(cluster_render.data[z_write_offset], z_write_bit);
	}
#endif
}