diff options
Diffstat (limited to 'thirdparty/oidn/core/node.h')
-rw-r--r-- | thirdparty/oidn/core/node.h | 142 |
1 files changed, 142 insertions, 0 deletions
diff --git a/thirdparty/oidn/core/node.h b/thirdparty/oidn/core/node.h new file mode 100644 index 0000000000..b9ffe906df --- /dev/null +++ b/thirdparty/oidn/core/node.h @@ -0,0 +1,142 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include <vector> + +namespace oidn { + + class Node + { + public: + virtual ~Node() = default; + + virtual void execute(stream& sm) = 0; + + virtual std::shared_ptr<memory> getDst() const { return nullptr; } + + virtual size_t getScratchpadSize() const { return 0; } + virtual void setScratchpad(const std::shared_ptr<memory>& mem) {} + + virtual void setTile(int h1, int w1, int h2, int w2, int H, int W) + { + assert(0); // not supported + } + }; + + // Node wrapping an MKL-DNN primitive + class MklNode : public Node + { + private: + primitive prim; + std::unordered_map<int, memory> args; + std::shared_ptr<memory> scratchpad; + + public: + MklNode(const primitive& prim, const std::unordered_map<int, memory>& args) + : prim(prim), + args(args) + {} + + size_t getScratchpadSize() const override + { + const auto primDesc = prim.get_primitive_desc(); + const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0); + if (scratchpadDesc == nullptr) + return 0; + return mkldnn_memory_desc_get_size(scratchpadDesc); + } + + void setScratchpad(const std::shared_ptr<memory>& mem) override + { + scratchpad = mem; + args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad)); + } + + void execute(stream& sm) override + { + prim.execute(sm, args); + } + }; + + // Convolution node + class ConvNode : public MklNode + { + private: + std::shared_ptr<memory> src; + std::shared_ptr<memory> weights; + std::shared_ptr<memory> bias; + std::shared_ptr<memory> dst; + + public: + ConvNode(const convolution_forward::primitive_desc& desc, + const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& weights, + const std::shared_ptr<memory>& bias, + const std::shared_ptr<memory>& dst) + : MklNode(convolution_forward(desc), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_WEIGHTS, *weights }, + { MKLDNN_ARG_BIAS, *bias }, + { MKLDNN_ARG_DST, *dst } }), + src(src), weights(weights), bias(bias), dst(dst) + {} + + std::shared_ptr<memory> getDst() const override { return dst; } + }; + + // Pooling node + class PoolNode : public MklNode + { + private: + std::shared_ptr<memory> src; + std::shared_ptr<memory> dst; + + public: + PoolNode(const pooling_forward::primitive_desc& desc, + const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& dst) + : MklNode(pooling_forward(desc), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_DST, *dst } }), + src(src), dst(dst) + {} + + std::shared_ptr<memory> getDst() const override { return dst; } + }; + + // Reorder node + class ReorderNode : public MklNode + { + private: + std::shared_ptr<memory> src; + std::shared_ptr<memory> dst; + + public: + ReorderNode(const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& dst) + : MklNode(reorder(reorder::primitive_desc(*src, *dst)), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_DST, *dst } }), + src(src), dst(dst) + {} + + std::shared_ptr<memory> getDst() const override { return dst; } + }; + +} // namespace oidn |