diff options
Diffstat (limited to 'thirdparty/oidn/core/common.h')
-rw-r--r-- | thirdparty/oidn/core/common.h | 133 |
1 files changed, 133 insertions, 0 deletions
diff --git a/thirdparty/oidn/core/common.h b/thirdparty/oidn/core/common.h new file mode 100644 index 0000000000..6c87f377bc --- /dev/null +++ b/thirdparty/oidn/core/common.h @@ -0,0 +1,133 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common/platform.h" + +#include "mkl-dnn/include/mkldnn.hpp" +#include "mkl-dnn/include/mkldnn_debug.h" +#include "mkl-dnn/src/common/mkldnn_thread.hpp" +#include "mkl-dnn/src/common/type_helpers.hpp" +#include "mkl-dnn/src/cpu/jit_generator.hpp" + +#include "common/ref.h" +#include "common/exception.h" +#include "common/thread.h" +#include "math.h" + +namespace oidn { + + using namespace mkldnn; + using namespace mkldnn::impl::cpu; + using mkldnn::impl::parallel_nd; + using mkldnn::impl::memory_desc_matches_tag; + + + inline size_t getFormatBytes(Format format) + { + switch (format) + { + case Format::Undefined: return 1; + case Format::Float: return sizeof(float); + case Format::Float2: return sizeof(float)*2; + case Format::Float3: return sizeof(float)*3; + case Format::Float4: return sizeof(float)*4; + } + assert(0); + return 0; + } + + + inline memory::dims getTensorDims(const std::shared_ptr<memory>& mem) + { + const mkldnn_memory_desc_t& desc = mem->get_desc().data; + return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]); + } + + inline memory::data_type getTensorType(const std::shared_ptr<memory>& mem) + { + const mkldnn_memory_desc_t& desc = mem->get_desc().data; + return memory::data_type(desc.data_type); + } + + // Returns the number of values in a tensor + inline size_t getTensorSize(const memory::dims& dims) + { + size_t res = 1; + for (int i = 0; i < (int)dims.size(); ++i) + res *= dims[i]; + return res; + } + + inline memory::dims getMaxTensorDims(const std::vector<memory::dims>& dims) + { + memory::dims result; + size_t maxSize = 0; + + for (const auto& d : dims) + { + const size_t size = getTensorSize(d); + if (size > maxSize) + { + result = d; + maxSize = size; + } + } + + return result; + } + + inline size_t getTensorSize(const std::shared_ptr<memory>& mem) + { + return getTensorSize(getTensorDims(mem)); + } + + + template<int K> + inline int getPadded(int dim) + { + return (dim + (K-1)) & ~(K-1); + } + + template<int K> + inline memory::dims getPadded_nchw(const memory::dims& dims) + { + assert(dims.size() == 4); + memory::dims padDims = dims; + padDims[1] = getPadded<K>(dims[1]); // pad C + return padDims; + } + + + template<int K> + struct BlockedFormat; + + template<> + struct BlockedFormat<8> + { + static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c; + static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o; + }; + + template<> + struct BlockedFormat<16> + { + static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c; + static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o; + }; + +} // namespace oidn |