summaryrefslogtreecommitdiff
path: root/thirdparty/oidn/core/node.h
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/core/node.h')
-rw-r--r--thirdparty/oidn/core/node.h142
1 files changed, 142 insertions, 0 deletions
diff --git a/thirdparty/oidn/core/node.h b/thirdparty/oidn/core/node.h
new file mode 100644
index 0000000000..b9ffe906df
--- /dev/null
+++ b/thirdparty/oidn/core/node.h
@@ -0,0 +1,142 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation //
+// //
+// Licensed under the Apache License, Version 2.0 (the "License"); //
+// you may not use this file except in compliance with the License. //
+// You may obtain a copy of the License at //
+// //
+// http://www.apache.org/licenses/LICENSE-2.0 //
+// //
+// Unless required by applicable law or agreed to in writing, software //
+// distributed under the License is distributed on an "AS IS" BASIS, //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and //
+// limitations under the License. //
+// ======================================================================== //
+
+#pragma once
+
+#include "common.h"
+#include <vector>
+
+namespace oidn {
+
+ class Node
+ {
+ public:
+ virtual ~Node() = default;
+
+ virtual void execute(stream& sm) = 0;
+
+ virtual std::shared_ptr<memory> getDst() const { return nullptr; }
+
+ virtual size_t getScratchpadSize() const { return 0; }
+ virtual void setScratchpad(const std::shared_ptr<memory>& mem) {}
+
+ virtual void setTile(int h1, int w1, int h2, int w2, int H, int W)
+ {
+ assert(0); // not supported
+ }
+ };
+
+ // Node wrapping an MKL-DNN primitive
+ class MklNode : public Node
+ {
+ private:
+ primitive prim;
+ std::unordered_map<int, memory> args;
+ std::shared_ptr<memory> scratchpad;
+
+ public:
+ MklNode(const primitive& prim, const std::unordered_map<int, memory>& args)
+ : prim(prim),
+ args(args)
+ {}
+
+ size_t getScratchpadSize() const override
+ {
+ const auto primDesc = prim.get_primitive_desc();
+ const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0);
+ if (scratchpadDesc == nullptr)
+ return 0;
+ return mkldnn_memory_desc_get_size(scratchpadDesc);
+ }
+
+ void setScratchpad(const std::shared_ptr<memory>& mem) override
+ {
+ scratchpad = mem;
+ args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad));
+ }
+
+ void execute(stream& sm) override
+ {
+ prim.execute(sm, args);
+ }
+ };
+
+ // Convolution node
+ class ConvNode : public MklNode
+ {
+ private:
+ std::shared_ptr<memory> src;
+ std::shared_ptr<memory> weights;
+ std::shared_ptr<memory> bias;
+ std::shared_ptr<memory> dst;
+
+ public:
+ ConvNode(const convolution_forward::primitive_desc& desc,
+ const std::shared_ptr<memory>& src,
+ const std::shared_ptr<memory>& weights,
+ const std::shared_ptr<memory>& bias,
+ const std::shared_ptr<memory>& dst)
+ : MklNode(convolution_forward(desc),
+ { { MKLDNN_ARG_SRC, *src },
+ { MKLDNN_ARG_WEIGHTS, *weights },
+ { MKLDNN_ARG_BIAS, *bias },
+ { MKLDNN_ARG_DST, *dst } }),
+ src(src), weights(weights), bias(bias), dst(dst)
+ {}
+
+ std::shared_ptr<memory> getDst() const override { return dst; }
+ };
+
+ // Pooling node
+ class PoolNode : public MklNode
+ {
+ private:
+ std::shared_ptr<memory> src;
+ std::shared_ptr<memory> dst;
+
+ public:
+ PoolNode(const pooling_forward::primitive_desc& desc,
+ const std::shared_ptr<memory>& src,
+ const std::shared_ptr<memory>& dst)
+ : MklNode(pooling_forward(desc),
+ { { MKLDNN_ARG_SRC, *src },
+ { MKLDNN_ARG_DST, *dst } }),
+ src(src), dst(dst)
+ {}
+
+ std::shared_ptr<memory> getDst() const override { return dst; }
+ };
+
+ // Reorder node
+ class ReorderNode : public MklNode
+ {
+ private:
+ std::shared_ptr<memory> src;
+ std::shared_ptr<memory> dst;
+
+ public:
+ ReorderNode(const std::shared_ptr<memory>& src,
+ const std::shared_ptr<memory>& dst)
+ : MklNode(reorder(reorder::primitive_desc(*src, *dst)),
+ { { MKLDNN_ARG_SRC, *src },
+ { MKLDNN_ARG_DST, *dst } }),
+ src(src), dst(dst)
+ {}
+
+ std::shared_ptr<memory> getDst() const override { return dst; }
+ };
+
+} // namespace oidn