summaryrefslogtreecommitdiff
path: root/thirdparty/oidn/common/tensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/common/tensor.h')
-rw-r--r--thirdparty/oidn/common/tensor.h66
1 files changed, 66 insertions, 0 deletions
diff --git a/thirdparty/oidn/common/tensor.h b/thirdparty/oidn/common/tensor.h
new file mode 100644
index 0000000000..48e7d1123d
--- /dev/null
+++ b/thirdparty/oidn/common/tensor.h
@@ -0,0 +1,66 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation //
+// //
+// Licensed under the Apache License, Version 2.0 (the "License"); //
+// you may not use this file except in compliance with the License. //
+// You may obtain a copy of the License at //
+// //
+// http://www.apache.org/licenses/LICENSE-2.0 //
+// //
+// Unless required by applicable law or agreed to in writing, software //
+// distributed under the License is distributed on an "AS IS" BASIS, //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and //
+// limitations under the License. //
+// ======================================================================== //
+
+#pragma once
+
+#include "platform.h"
+#include <vector>
+#include <map>
+
+namespace oidn {
+
+ template<typename T>
+ using shared_vector = std::shared_ptr<std::vector<T>>;
+
+ // Generic tensor
+ struct Tensor
+ {
+ float* data;
+ std::vector<int64_t> dims;
+ std::string format;
+ shared_vector<char> buffer; // optional, only for reference counting
+
+ __forceinline Tensor() : data(nullptr) {}
+
+ __forceinline Tensor(const std::vector<int64_t>& dims, const std::string& format)
+ : dims(dims),
+ format(format)
+ {
+ buffer = std::make_shared<std::vector<char>>(size() * sizeof(float));
+ data = (float*)buffer->data();
+ }
+
+ __forceinline operator bool() const { return data != nullptr; }
+
+ __forceinline int ndims() const { return (int)dims.size(); }
+
+ // Returns the number of values
+ __forceinline size_t size() const
+ {
+ size_t size = 1;
+ for (int i = 0; i < ndims(); ++i)
+ size *= dims[i];
+ return size;
+ }
+
+ __forceinline float& operator [](size_t i) { return data[i]; }
+ __forceinline const float& operator [](size_t i) const { return data[i]; }
+ };
+
+ // Parses tensors from a buffer
+ std::map<std::string, Tensor> parseTensors(void* buffer);
+
+} // namespace oidn