1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
|
/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef COMMON_H
#define COMMON_H
// Tuning constants for the AVX-512 integer GEMM driver.
// NOTE(review): the exact units (bytes vs. elements) of the values below are
// not visible in this header - confirm against the code that consumes them.

// Evaluates to 4096L * 32 (= 131072); the name suggests a budget for
// generated GEMM code.
#define GEMM_CODE_SIZE (4096L * 32)

// Unroll factors; presumably matched to blas_t::um / un / uk.
#define AVX512_UNROLL_M 48
#define AVX512_UNROLL_N 8
#define AVX512_UNROLL_K 1

// Blocking sizes; presumably matched to blas_t::bm / bn / bk and the
// *_small_k / *_traditional variants.
#define AVX512_BM 9984
#define AVX512_BN 384
#define AVX512_BK 768
#define AVX512_BK_VNNI 1536
#define AVX512_BK_TRADITIONAL 384
#define AVX512_BLOCKING_SMALL_K 48
#define AVX512_BN_SMALL_K 24

// Page granularity used by the two rounding helpers below.
#define PAGESIZE 4096

// Rounds (x * size) up to the next multiple of PAGESIZE.
// The whole expansion is parenthesized so the macro can be embedded in a
// larger expression (e.g. as a divisor) without precedence surprises.
#define PADD_BYTESIZE_ONPAGE(x, size) \
    ((((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE)

// Page-padded size of x items of `size` each, divided by `size` again.
// `size` and the full expansion are parenthesized so that compound arguments
// (e.g. `a + b`) and surrounding operators are evaluated as intended.
#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size) / (size))
#include "jit_generator.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {
// Partitioning selectors. PARTITION_2D is an alias for
// PARTITION_2D_COL_MAJOR (same numeric value).
// NOTE(review): names suggest 1D-by-row, 1D-by-column and 2D column-major
// work splitting - confirm against the threading code that uses them.
enum {
    PARTITION_1D_ROW = 0,
    PARTITION_1D_COL = 1,
    PARTITION_2D_COL_MAJOR = 2,
    PARTITION_2D = PARTITION_2D_COL_MAJOR,
};
// Copy-mode selectors.
// NOTE(review): presumably "no packing" vs. "pack matrix A" - confirm with
// the driver that reads these values.
enum {
    COPY_NONE = 0,
    COPY_A = 1,
};
// Offset-mode selectors.
// NOTE(review): presumably the values stored in blas_t::offsetc (none,
// single fixed value, per-column, per-row) - confirm against the callers.
enum {
    NO_OFFSET = 0,
    FIX_OFFSET = 1,
    COL_OFFSET = 2,
    ROW_OFFSET = 3,
};
// Integer type used for every dimension-like quantity (sizes, leading
// dimensions, blocking factors) in this header.
typedef long long dim_t;
// Bundle of everything the integer GEMM/GEMV code passes around: the
// caller-supplied interface arguments, the blocking/unroll parameters, and
// pointers to the copy and compute kernels.
// Member order is significant (layout / aggregate initialization).
typedef struct {
    // ---- Interface arguments -------------------------------------------
    // NOTE(review): names follow the usual BLAS GEMM convention
    // (trans flags, m/n/k sizes, leading dimensions, alpha/beta scalars);
    // confirm exact semantics against the driver.
    int transa;
    int transb;
    int offsetc;
    dim_t m;
    dim_t n;
    dim_t k;
    dim_t lda;
    dim_t ldb;
    dim_t ldc;
    const int8_t *a;
    const uint8_t *b;
    int32_t *c;
    const float *alpha;
    const float *beta;
    int8_t ao;
    int8_t bo;
    const int32_t *co;

    // ---- Kernel parameters ---------------------------------------------
    dim_t um;
    dim_t un;
    dim_t uk;
    dim_t bm;
    dim_t bn;
    dim_t bk;
    dim_t bn_small_k;
    dim_t bk_traditional;
    dim_t blocking_small_k;

    // Copy routines: signed 8-bit variant (copyA) and unsigned 8-bit
    // variant (copyB). Both share the same parameter shape.
    int (*copyA)(
            const dim_t *m, const dim_t *n,
            const int8_t *a, const dim_t *lda,
            const int8_t *alpha, int8_t *b,
            const dim_t *dummy1, const dim_t *dummy2,
            int32_t *row_col_sum);
    int (*copyB)(
            const dim_t *m, const dim_t *n,
            const uint8_t *a, const dim_t *lda,
            const uint8_t *alpha, uint8_t *b,
            const dim_t *dummy1, const dim_t *dummy2,
            int32_t *row_col_sum);

    // Compute kernels. All eight share one signature; note that ldc is
    // passed by value while m, n and k are passed by pointer.
    // NOTE(review): the _b / _r / _c suffixes and the b0 infix presumably
    // select the offset handling and a beta == 0 variant - confirm.
    int (*kernel)(
            const dim_t *m, const dim_t *n, const dim_t *k,
            const float *alpha, const int8_t *a, const uint8_t *b,
            int32_t *c, const dim_t ldc,
            const int32_t *col_offset, const int32_t *row_offset);
    int (*kernel_b)(
            const dim_t *m, const dim_t *n, const dim_t *k,
            const float *alpha, const int8_t *a, const uint8_t *b,
            int32_t *c, const dim_t ldc,
            const int32_t *col_offset, const int32_t *row_offset);
    int (*kernel_r)(
            const dim_t *m, const dim_t *n, const dim_t *k,
            const float *alpha, const int8_t *a, const uint8_t *b,
            int32_t *c, const dim_t ldc,
            const int32_t *col_offset, const int32_t *row_offset);
    int (*kernel_c)(
            const dim_t *m, const dim_t *n, const dim_t *k,
            const float *alpha, const int8_t *a, const uint8_t *b,
            int32_t *c, const dim_t ldc,
            const int32_t *col_offset, const int32_t *row_offset);
    int (*kernel_b0)(
            const dim_t *m, const dim_t *n, const dim_t *k,
            const float *alpha, const int8_t *a, const uint8_t *b,
            int32_t *c, const dim_t ldc,
            const int32_t *col_offset, const int32_t *row_offset);
    int (*kernel_b0_b)(
            const dim_t *m, const dim_t *n, const dim_t *k,
            const float *alpha, const int8_t *a, const uint8_t *b,
            int32_t *c, const dim_t ldc,
            const int32_t *col_offset, const int32_t *row_offset);
    int (*kernel_b0_r)(
            const dim_t *m, const dim_t *n, const dim_t *k,
            const float *alpha, const int8_t *a, const uint8_t *b,
            int32_t *c, const dim_t ldc,
            const int32_t *col_offset, const int32_t *row_offset);
    int (*kernel_b0_c)(
            const dim_t *m, const dim_t *n, const dim_t *k,
            const float *alpha, const int8_t *a, const uint8_t *b,
            int32_t *c, const dim_t ldc,
            const int32_t *col_offset, const int32_t *row_offset);

    // ---- Gemv kernels --------------------------------------------------
    // The two variants differ only in which operand is signed 8-bit and
    // which is unsigned 8-bit.
    void (*gemv_s8u8s32_kernel)(
            const dim_t, const dim_t, const float,
            const int8_t*, const dim_t, const uint8_t*,
            const float, int32_t*);
    void (*gemv_u8s8s32_kernel)(
            const dim_t, const dim_t, const float,
            const uint8_t*, const dim_t, const int8_t*,
            const float, int32_t*);

    // ---- Gemv parameters -----------------------------------------------
    int swap;
} blas_t;
// jit_generator subclass; only the default constructor is declared here and
// it is defined outside this header.
// NOTE(review): the name suggests a u8 copy kernel for non-transposed A on
// AVX-512 - confirm against the implementation file.
class jit_avx512_core_u8_copy_an_kern : public jit_generator {
public:
    jit_avx512_core_u8_copy_an_kern();

private:
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern);
};
// jit_generator subclass; only the default constructor is declared here and
// it is defined outside this header.
// NOTE(review): the name suggests a u8 copy kernel for transposed A on
// AVX-512 - confirm against the implementation file.
class jit_avx512_core_u8_copy_at_kern : public jit_generator {
public:
    jit_avx512_core_u8_copy_at_kern();

private:
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern);
};
// jit_generator subclass; only the default constructor is declared here and
// it is defined outside this header.
// NOTE(review): the name suggests a u8 copy kernel for non-transposed B on
// AVX-512 - confirm against the implementation file.
class jit_avx512_core_u8_copy_bn_kern : public jit_generator {
public:
    jit_avx512_core_u8_copy_bn_kern();

private:
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern);
};
// jit_generator subclass; only the default constructor is declared here and
// it is defined outside this header.
// NOTE(review): the name suggests a u8 copy kernel for transposed B on
// AVX-512 - confirm against the implementation file.
class jit_avx512_core_u8_copy_bt_kern : public jit_generator {
public:
    jit_avx512_core_u8_copy_bt_kern();

private:
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern);
};
// jit_generator subclass; only the default constructor is declared here and
// it is defined outside this header.
// NOTE(review): the name suggests a u8 copy kernel for non-transposed A that
// also accumulates sums, on AVX-512 - confirm against the implementation.
class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator {
public:
    jit_avx512_core_u8_copy_sum_an_kern();

private:
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern);
};
// jit_generator subclass; only the default constructor is declared here and
// it is defined outside this header.
// NOTE(review): the name suggests a u8 copy kernel for transposed A that
// also accumulates sums, on AVX-512 - confirm against the implementation.
class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator {
public:
    jit_avx512_core_u8_copy_sum_at_kern();

private:
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern);
};
// jit_generator subclass; only the default constructor is declared here and
// it is defined outside this header.
// NOTE(review): the name suggests a u8 copy kernel for non-transposed B that
// also accumulates sums, on AVX-512 - confirm against the implementation.
class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator {
public:
    jit_avx512_core_u8_copy_sum_bn_kern();

private:
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern);
};
// jit_generator subclass; only the default constructor is declared here and
// it is defined outside this header.
// NOTE(review): the name suggests a u8 copy kernel for transposed B that
// also accumulates sums, on AVX-512 - confirm against the implementation.
class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator {
public:
    jit_avx512_core_u8_copy_sum_bt_kern();

private:
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern);
};
}
}
}
#endif
|