summaryrefslogtreecommitdiff
path: root/thirdparty/oidn/core/network.h
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/core/network.h')
-rw-r--r--thirdparty/oidn/core/network.h112
1 files changed, 112 insertions, 0 deletions
diff --git a/thirdparty/oidn/core/network.h b/thirdparty/oidn/core/network.h
new file mode 100644
index 0000000000..7a696fd355
--- /dev/null
+++ b/thirdparty/oidn/core/network.h
@@ -0,0 +1,112 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation //
+// //
+// Licensed under the Apache License, Version 2.0 (the "License"); //
+// you may not use this file except in compliance with the License. //
+// You may obtain a copy of the License at //
+// //
+// http://www.apache.org/licenses/LICENSE-2.0 //
+// //
+// Unless required by applicable law or agreed to in writing, software //
+// distributed under the License is distributed on an "AS IS" BASIS, //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and //
+// limitations under the License. //
+// ======================================================================== //
+
+#include "common/tensor.h"
+#include "image.h"
+#include "node.h"
+#include "input_reorder.h"
+#include "output_reorder.h"
+#include "transfer_function.h"
+
+#pragma once
+
+namespace oidn {
+
+ // Progress state
+ struct Progress
+ {
+ ProgressMonitorFunction func;
+ void* userPtr;
+ int taskCount;
+ };
+
+ class Executable
+ {
+ public:
+ virtual ~Executable() {}
+ virtual void execute(const Progress& progress, int taskIndex) = 0;
+ };
+
+ template<int K>
+ class Network : public Executable
+ {
+ public:
+ Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap);
+
+ void execute(const Progress& progress, int taskIndex) override;
+
+ std::shared_ptr<memory> allocTensor(const memory::dims& dims,
+ memory::format_tag format = memory::format_tag::any,
+ void* data = nullptr);
+
+ std::shared_ptr<memory> castTensor(const memory::dims& dims,
+ const std::shared_ptr<memory>& src,
+ size_t srcOffset = 0,
+ memory::format_tag format = memory::format_tag::any);
+
+ std::shared_ptr<memory> castTensor(const memory::dims& dims,
+ const std::shared_ptr<memory>& src,
+ const memory::dims& srcOffset);
+
+ void zeroTensor(const std::shared_ptr<memory>& dst);
+
+ memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment);
+
+ std::shared_ptr<Node> addInputReorder(const Image& color,
+ const Image& albedo,
+ const Image& normal,
+ const std::shared_ptr<TransferFunction>& transferFunc,
+ int alignment,
+ const std::shared_ptr<memory>& userDst = nullptr);
+
+ std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src,
+ const std::shared_ptr<TransferFunction>& transferFunc,
+ const Image& output);
+
+ memory::dims getConvDims(const std::string& name, const memory::dims& srcDims);
+ std::shared_ptr<Node> addConv(const std::string& name,
+ const std::shared_ptr<memory>& src,
+ const std::shared_ptr<memory>& userDst = nullptr,
+ bool relu = true);
+
+ memory::dims getPoolDims(const memory::dims& srcDims);
+ std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src,
+ const std::shared_ptr<memory>& userDst = nullptr);
+
+ memory::dims getUpsampleDims(const memory::dims& srcDims);
+ std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src,
+ const std::shared_ptr<memory>& userDst = nullptr);
+
+ memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims);
+
+ std::shared_ptr<Node> addAutoexposure(const Image& color,
+ const std::shared_ptr<HDRTransferFunction>& transferFunc);
+
+ void finalize();
+
+ private:
+ Ref<Device> device;
+ engine eng;
+ stream sm;
+ std::vector<std::shared_ptr<Node>> nodes;
+ std::map<std::string, Tensor> weightMap;
+
+ // Memory allocation statistics
+ size_t activationAllocBytes = 0; // number of allocated activation bytes
+ size_t totalAllocBytes = 0; // total number of allocated bytes
+ };
+
+} // namespace oidn