// ======================================================================== //
// Copyright 2009-2019 Intel Corporation                                    //
//                                                                          //
// Licensed under the Apache License, Version 2.0 (the "License");          //
// you may not use this file except in compliance with the License.         //
// You may obtain a copy of the License at                                  //
//                                                                          //
//     http://www.apache.org/licenses/LICENSE-2.0                           //
//                                                                          //
// Unless required by applicable law or agreed to in writing, software      //
// distributed under the License is distributed on an "AS IS" BASIS,        //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and      //
// limitations under the License.                                           //
// ======================================================================== //

// NOTE(review): this copy of the header appears to have lost every template
// argument list during extraction: the `template` introducer of Network has
// no parameter list, and Ref, std::map, std::shared_ptr and std::vector are
// all spelled without arguments (the vector even keeps a stray closing
// bracket). The code tokens below are intentionally left exactly as found;
// restore the argument lists from the authoritative source before compiling.

// NOTE(review): std::map, std::shared_ptr, std::string, std::vector and size_t
// are used below without a direct standard include here; presumably they
// arrive transitively through the project headers -- confirm.
#include "common/tensor.h"
#include "image.h"
#include "node.h"
#include "input_reorder.h"
#include "output_reorder.h"
#include "transfer_function.h"

// NOTE(review): the include guard pragma conventionally precedes the includes;
// here it follows them.
#pragma once

namespace oidn {

  // Progress state
  // Plain aggregate handed by const reference to Executable::execute().
  struct Progress
  {
    ProgressMonitorFunction func; // presumably the progress callback to invoke -- confirm against callers
    void* userPtr;                // opaque pointer; presumably forwarded to func -- confirm
    int taskCount;                // presumably the total number of tasks being tracked -- confirm
  };

  // Abstract interface for something that can be executed as one task.
  // The destructor is virtual, so instances may be destroyed through a
  // pointer to this base class.
  class Executable
  {
  public:
    virtual ~Executable() {}

    // Pure virtual entry point implemented by derived classes.
    //   progress:  progress state shared by the caller
    //   taskIndex: index of this task (presumably in [0, progress.taskCount) -- confirm)
    virtual void execute(const Progress& progress, int taskIndex) = 0;
  };

  // Executable that owns a list of nodes plus the device, engine, stream and
  // weight map they are built from (see the private members). Only the
  // declarations are visible here; the behavior described below is limited to
  // what the signatures show, and everything else is marked as an assumption.
  template class Network : public Executable
  {
  public:
    // Constructs the network from a device reference and a weight map. The
    // class has members with the same names, which presumably store these
    // arguments -- confirm in the definition.
    Network(const Ref& device, const std::map& weightMap);

    // Overrides Executable::execute().
    void execute(const Progress& progress, int taskIndex) override;

    // Tensor helpers ---------------------------------------------------------

    // Returns a shared pointer for the given dims. `format` defaults to
    // memory::format_tag::any and `data` defaults to nullptr.
    // NOTE(review): presumably a null `data` means the storage is allocated
    // internally, otherwise the caller's buffer is wrapped -- confirm.
    std::shared_ptr allocTensor(const memory::dims& dims, memory::format_tag format = memory::format_tag::any, void* data = nullptr);

    // Returns a shared pointer derived from `src`, with the given dims.
    // `srcOffset` is a scalar offset defaulting to 0 (units not visible here);
    // `format` defaults to memory::format_tag::any.
    std::shared_ptr castTensor(const memory::dims& dims, const std::shared_ptr& src, size_t srcOffset = 0, memory::format_tag format = memory::format_tag::any);

    // Overload of castTensor() taking the offset as memory::dims instead of a
    // scalar; it has no default arguments.
    std::shared_ptr castTensor(const memory::dims& dims, const std::shared_ptr& src, const memory::dims& srcOffset);

    // Presumably fills `dst` with zeros -- definition not visible here.
    void zeroTensor(const std::shared_ptr& dst);

    // Node builders ----------------------------------------------------------
    // Naming pattern visible below: several get*Dims() functions map source
    // dims to memory::dims, alongside add*() functions that return a shared
    // pointer. Presumably each add*() appends a node to `nodes` and the
    // matching get*Dims() reports that node's output dims -- confirm.
    // Several add*() functions take an optional `userDst` (default nullptr),
    // presumably a caller-provided destination -- confirm.

    memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment);

    // Takes three images (color, albedo, normal), a transfer function, an
    // alignment value and an optional userDst.
    std::shared_ptr addInputReorder(const Image& color, const Image& albedo, const Image& normal, const
                                    std::shared_ptr& transferFunc, int alignment, const std::shared_ptr& userDst = nullptr);

    // Takes a source, a transfer function and the output image.
    std::shared_ptr addOutputReorder(const std::shared_ptr& src, const std::shared_ptr& transferFunc, const Image& output);

    // `name` presumably selects the convolution's entries in weightMap -- confirm.
    memory::dims getConvDims(const std::string& name, const memory::dims& srcDims);

    // `relu` defaults to true; presumably enables a ReLU on the convolution
    // result -- confirm.
    std::shared_ptr addConv(const std::string& name, const std::shared_ptr& src, const std::shared_ptr& userDst = nullptr, bool relu = true);

    memory::dims getPoolDims(const memory::dims& srcDims);
    std::shared_ptr addPool(const std::shared_ptr& src, const std::shared_ptr& userDst = nullptr);

    memory::dims getUpsampleDims(const memory::dims& srcDims);
    std::shared_ptr addUpsample(const std::shared_ptr& src, const std::shared_ptr& userDst = nullptr);

    // Maps two source dims to one result; no matching add function is declared
    // in this class.
    memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims);

    // Takes the color image and a transfer function.
    std::shared_ptr addAutoexposure(const Image& color, const std::shared_ptr& transferFunc);

    // NOTE(review): presumably called once after all nodes have been added --
    // definition not visible here.
    void finalize();

  private:
    Ref device;            // device reference (same name as the constructor argument)
    engine eng;            // `engine` type is declared elsewhere (not visible here)
    stream sm;             // `stream` type is declared elsewhere (not visible here)
    std::vector> nodes;    // node list (element type lost in this copy, see note at top)
    std::map weightMap;    // weight map (same name as the constructor argument)

    // Memory allocation statistics
    size_t activationAllocBytes = 0; // number of allocated activation bytes
    size_t totalAllocBytes = 0; // total number of allocated bytes
  };

} // namespace oidn