diff options
1130 files changed, 140393 insertions, 13007 deletions
diff --git a/.appveyor.yml b/.appveyor.yml index 3f6e932050..c39a485d35 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -29,7 +29,9 @@ cache: install: - SET "PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" - - pip install scons==3.1.2 + - pip install -U wheel # needed for pip install scons to work, otherwise a flag is missing + - pip install scons # use stable scons + - if defined VS call "%VS%" %ARCH% # if defined - so we can also use mingw before_build: - echo %GD_PLATFORM% @@ -39,3 +41,15 @@ before_build: build_script: - scons platform=%GD_PLATFORM% target=%TARGET% tools=%TOOLS% %OPTIONS% %EXTRA_ARGS% + +after_build: + - git rev-parse --short=9 HEAD > VERSION_HASH.txt + - set /P VERSION_HASH= < VERSION_HASH.txt + - cd bin + - mv godot.windows.opt.tools.64.exe godot_%APPVEYOR_REPO_BRANCH%-%VERSION_HASH%_win64.exe + - 7z a -mx9 godot_%APPVEYOR_REPO_BRANCH%-%VERSION_HASH%_win64.zip *.exe + +artifacts: + - path: bin/godot_${APPVEYOR_REPO_BRANCH}-${VERSION_HASH}_win64.zip + name: Win64 release_debug editor build + type: zip
\ No newline at end of file diff --git a/.clang-format b/.clang-format index b3f1954545..75715a3173 100644 --- a/.clang-format +++ b/.clang-format @@ -13,9 +13,9 @@ AlignAfterOpenBracket: DontAlign AlignTrailingComments: false AllowAllParametersOfDeclarationOnNextLine: false # AllowShortBlocksOnASingleLine: false -AllowShortCaseLabelsOnASingleLine: true +# AllowShortCaseLabelsOnASingleLine: false AllowShortFunctionsOnASingleLine: Inline -AllowShortIfStatementsOnASingleLine: true +# AllowShortIfStatementsOnASingleLine: false # AllowShortLoopsOnASingleLine: false # AlwaysBreakAfterDefinitionReturnType: None # AlwaysBreakAfterReturnType: None diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 0000000000..e6ee412ead --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,44 @@ +--- +Checks: 'clang-diagnostic-*,clang-analyzer-*,-*,modernize-redundant-void-arg,modernize-use-bool-literals,modernize-use-default-member-init,modernize-use-nullptr' +WarningsAsErrors: '' +HeaderFilterRegex: '.*' +AnalyzeTemporaryDtors: false +FormatStyle: none +CheckOptions: +CheckOptions: + - key: cert-dcl16-c.NewSuffixes + value: 'L;LL;LU;LLU' + - key: cert-oop54-cpp.WarnOnlyIfThisHasSuspiciousField + value: '0' + - key: cppcoreguidelines-explicit-virtual-functions.IgnoreDestructors + value: '1' + - key: cppcoreguidelines-non-private-member-variables-in-classes.IgnoreClassesWithAllMemberVariablesBeingPublic + value: '1' + - key: google-readability-braces-around-statements.ShortStatementLines + value: '1' + - key: google-readability-function-size.StatementThreshold + value: '800' + - key: google-readability-namespace-comments.ShortNamespaceLines + value: '10' + - key: google-readability-namespace-comments.SpacesBeforeComments + value: '2' + - key: modernize-loop-convert.MaxCopySize + value: '16' + - key: modernize-loop-convert.MinConfidence + value: reasonable + - key: modernize-loop-convert.NamingStyle + value: CamelCase + - key: modernize-pass-by-value.IncludeStyle + value: llvm + - key: modernize-replace-auto-ptr.IncludeStyle + value: llvm + - key: modernize-use-bool-literals.IgnoreMacros + value: '0' + - key: modernize-use-default-member-init.IgnoreMacros + value: '0' + - key: modernize-use-default-member-init.UseAssignment + value: '1' + - key: modernize-use-nullptr.NullMacros + value: 'NULL' +... + diff --git a/.gitattributes b/.gitattributes index 40a5e6183f..88c00855d8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -13,3 +13,4 @@ thirdparty/* linguist-vendored *.jar binary *.png binary *.ttf binary +*.tza binary diff --git a/.travis.yml b/.travis.yml index 8cfd7a1a7f..8878935e52 100644 --- a/.travis.yml +++ b/.travis.yml @@ -100,15 +100,15 @@ matrix: packages: - *linux_deps -# - name: Javascript export template (release, emscripten latest) -# stage: build -# env: PLATFORM=javascript TOOLS=no TARGET=release CACHE_NAME=${PLATFORM}-emcc-latest EXTRA_ARGS="use_closure_compiler=yes" -# os: linux -# compiler: clang -# addons: -# apt: -# packages: -# - *linux_deps + - name: Javascript export template (release, emscripten latest) + stage: build + env: PLATFORM=javascript TOOLS=no TARGET=release CACHE_NAME=${PLATFORM}-emcc-latest EXTRA_ARGS="use_closure_compiler=yes" + os: linux + compiler: clang + addons: + apt: + packages: + - *linux_deps before_install: - eval "${MATRIX_EVAL}" diff --git a/COPYRIGHT.txt b/COPYRIGHT.txt index 81da02ab39..443e6fee97 100644 --- a/COPYRIGHT.txt +++ b/COPYRIGHT.txt @@ -303,6 +303,11 @@ Comment: NanoSVG Copyright: 2013-2014, Mikko Mononen License: Zlib +Files: ./thirdparty/oidn/ +Comment: Intel Open Image Denoise +Copyright: 2009-2020, Intel Corporation +License: Apache-2.0 + Files: ./thirdparty/opus/ Comment: Opus Copyright: 2001-2011, Xiph.Org, Skype Limited, Octasic, diff --git a/SConstruct b/SConstruct index e866a4c998..86014b8160 100644 --- a/SConstruct +++ b/SConstruct @@ -272,6 +272,15 @@ if selected_platform in platform_list: else: env = env_base.Clone() + # Custom tools are loaded automatically by SCons from site_scons/site_tools, + # but we want to use a different folder, so we register it manually. + from SCons.Script.Main import _load_site_scons_dir + + _load_site_scons_dir(".", "misc/scons") + + env.Tool("compilation_db") + env.Alias("compiledb", env.CompilationDatabase("compile_commands.json")) + if env["dev"]: env["verbose"] = True env["warnings"] = "extra" @@ -598,6 +607,13 @@ if selected_platform in platform_list: ) } ) + env.Append( + BUILDERS={ + "GLSL_HEADER": env.Builder( + action=run_in_subprocess(gles_builders.build_raw_headers), suffix="glsl.gen.h", src_suffix=".glsl" + ) + } + ) scons_cache_path = os.environ.get("SCONS_CACHE") if scons_cache_path != None: diff --git a/core/SCsub b/core/SCsub index 0ab4f11d87..d53988fae7 100644 --- a/core/SCsub +++ b/core/SCsub @@ -49,6 +49,7 @@ thirdparty_misc_dir = "#thirdparty/misc/" thirdparty_misc_sources = [ # C sources "fastlz.c", + "r128.c", "smaz.c", # C++ sources "hq2x.cpp", diff --git a/core/bind/core_bind.cpp b/core/bind/core_bind.cpp index e2774deb3c..ed0e7b1018 100644 --- a/core/bind/core_bind.cpp +++ b/core/bind/core_bind.cpp @@ -62,6 +62,8 @@ static const unsigned int MONTH_DAYS_TABLE[2][12] = { { 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31 } }; +////// _ResourceLoader ////// + _ResourceLoader *_ResourceLoader::singleton = nullptr; Error _ResourceLoader::load_threaded_request(const String &p_path, const String &p_type_hint, bool p_use_sub_threads) { @@ -150,10 +152,7 @@ void _ResourceLoader::_bind_methods() { BIND_ENUM_CONSTANT(THREAD_LOAD_LOADED); } -_ResourceLoader::_ResourceLoader() { - - singleton = this; -} +////// _ResourceSaver ////// Error _ResourceSaver::save(const String &p_path, const RES &p_resource, SaverFlags p_flags) { ERR_FAIL_COND_V_MSG(p_resource.is_null(), ERR_INVALID_PARAMETER, "Can't save empty resource to path '" + String(p_path) + "'."); @@ -189,10 +188,7 @@ void _ResourceSaver::_bind_methods() { BIND_ENUM_CONSTANT(FLAG_REPLACE_SUBRESOURCE_PATHS); } -_ResourceSaver::_ResourceSaver() { - - singleton = this; -} +////// _OS ////// PackedStringArray _OS::get_connected_midi_inputs() { return OS::get_singleton()->get_connected_midi_inputs(); @@ -319,50 +315,6 @@ bool _OS::has_feature(const String &p_feature) const { return OS::get_singleton()->has_feature(p_feature); } -/* -enum Weekday { - DAY_SUNDAY, - DAY_MONDAY, - DAY_TUESDAY, - DAY_WEDNESDAY, - DAY_THURSDAY, - DAY_FRIDAY, - DAY_SATURDAY -}; - -enum Month { - MONTH_JANUARY, - MONTH_FEBRUARY, - MONTH_MARCH, - MONTH_APRIL, - MONTH_MAY, - MONTH_JUNE, - MONTH_JULY, - MONTH_AUGUST, - MONTH_SEPTEMBER, - MONTH_OCTOBER, - MONTH_NOVEMBER, - MONTH_DECEMBER -}; -*/ -/* -struct Date { - - int year; - Month month; - int day; - Weekday weekday; - bool dst; -}; - -struct Time { - - int hour; - int min; - int sec; -}; -*/ - uint64_t _OS::get_static_memory_usage() const { return OS::get_singleton()->get_static_memory_usage(); @@ -783,6 +735,7 @@ Vector<String> _OS::get_granted_permissions() const { String _OS::get_unique_id() const { return OS::get_singleton()->get_unique_id(); } + _OS *_OS::singleton = nullptr; void _OS::_bind_methods() { @@ -839,8 +792,6 @@ void _OS::_bind_methods() { ClassDB::bind_method(D_METHOD("is_debug_build"), &_OS::is_debug_build); - //ClassDB::bind_method(D_METHOD("get_mouse_button_state"),&_OS::get_mouse_button_state); - ClassDB::bind_method(D_METHOD("dump_memory_to_file", "file"), &_OS::dump_memory_to_file); ClassDB::bind_method(D_METHOD("dump_resources_to_file", "file"), &_OS::dump_resources_to_file); ClassDB::bind_method(D_METHOD("print_resources_in_use", "short"), &_OS::print_resources_in_use, DEFVAL(false)); @@ -914,12 +865,7 @@ void _OS::_bind_methods() { BIND_ENUM_CONSTANT(SYSTEM_DIR_RINGTONES); } -_OS::_OS() { - - singleton = this; -} - -///////////////////// GEOMETRY +////// _Geometry ////// _Geometry *_Geometry::singleton = nullptr; @@ -1296,11 +1242,7 @@ void _Geometry::_bind_methods() { BIND_ENUM_CONSTANT(END_ROUND); } -_Geometry::_Geometry() { - singleton = this; -} - -///////////////////////// FILE +////// _File ////// Error _File::open_encrypted(const String &p_path, ModeFlags p_mode_flags, const Vector<uint8_t> &p_key) { @@ -1736,19 +1678,12 @@ void _File::_bind_methods() { BIND_ENUM_CONSTANT(COMPRESSION_GZIP); } -_File::_File() { - - f = nullptr; - eswap = false; -} - _File::~_File() { - if (f) memdelete(f); } -/////////////////////////////////////////////////////// +////// _Directory ////// Error _Directory::open(const String &p_path) { Error err; @@ -1929,16 +1864,16 @@ void _Directory::_bind_methods() { } _Directory::_Directory() { - d = DirAccess::create(DirAccess::ACCESS_RESOURCES); } _Directory::~_Directory() { - if (d) memdelete(d); } +////// _Marshalls ////// + _Marshalls *_Marshalls::singleton = nullptr; _Marshalls *_Marshalls::get_singleton() { @@ -2046,7 +1981,7 @@ void _Marshalls::_bind_methods() { ClassDB::bind_method(D_METHOD("base64_to_utf8", "base64_str"), &_Marshalls::base64_to_utf8); }; -//////////////// +////// _Semaphore ////// void _Semaphore::wait() { @@ -2070,7 +2005,7 @@ void _Semaphore::_bind_methods() { ClassDB::bind_method(D_METHOD("post"), &_Semaphore::post); } -/////////////// +////// _Mutex ////// void _Mutex::lock() { @@ -2094,7 +2029,7 @@ void _Mutex::_bind_methods() { ClassDB::bind_method(D_METHOD("unlock"), &_Mutex::unlock); } -/////////////// +////// _Thread ////// void _Thread::_start_func(void *ud) { @@ -2204,19 +2139,12 @@ void _Thread::_bind_methods() { BIND_ENUM_CONSTANT(PRIORITY_NORMAL); BIND_ENUM_CONSTANT(PRIORITY_HIGH); } -_Thread::_Thread() { - - active = false; - thread = nullptr; - target_instance = nullptr; -} _Thread::~_Thread() { - ERR_FAIL_COND_MSG(active, "Reference to a Thread object was lost while the thread is still running..."); } -///////////////////////////////////// +////// _ClassDB ////// PackedStringArray _ClassDB::get_class_list() const { @@ -2425,11 +2353,7 @@ void _ClassDB::_bind_methods() { ClassDB::bind_method(D_METHOD("is_class_enabled", "class"), &_ClassDB::is_class_enabled); } -_ClassDB::_ClassDB() { -} -_ClassDB::~_ClassDB() { -} -/////////////////////////////// +////// _Engine ////// void _Engine::set_iterations_per_second(int p_ips) { @@ -2588,9 +2512,7 @@ void _Engine::_bind_methods() { _Engine *_Engine::singleton = nullptr; -_Engine::_Engine() { - singleton = this; -} +////// _JSON ////// void JSONParseResult::_bind_methods() { ClassDB::bind_method(D_METHOD("get_error"), &JSONParseResult::get_error); @@ -2663,7 +2585,3 @@ Ref<JSONParseResult> _JSON::parse(const String &p_json) { } _JSON *_JSON::singleton = nullptr; - -_JSON::_JSON() { - singleton = this; -} diff --git a/core/bind/core_bind.h b/core/bind/core_bind.h index d5f44cdc44..44e573ccbe 100644 --- a/core/bind/core_bind.h +++ b/core/bind/core_bind.h @@ -69,7 +69,7 @@ public: bool has_cached(const String &p_path); bool exists(const String &p_path, const String &p_type_hint = ""); - _ResourceLoader(); + _ResourceLoader() { singleton = this; } }; VARIANT_ENUM_CAST(_ResourceLoader::ThreadLoadStatus); @@ -98,7 +98,7 @@ public: Error save(const String &p_path, const RES &p_resource, SaverFlags p_flags); Vector<String> get_recognized_extensions(const RES &p_resource); - _ResourceSaver(); + _ResourceSaver() { singleton = this; } }; VARIANT_ENUM_CAST(_ResourceSaver::SaverFlags); @@ -245,7 +245,7 @@ public: static _OS *get_singleton() { return singleton; } - _OS(); + _OS() { singleton = this; } }; VARIANT_ENUM_CAST(_OS::VideoDriver); @@ -327,7 +327,7 @@ public: Dictionary make_atlas(const Vector<Size2> &p_rects); - _Geometry(); + _Geometry() { singleton = this; } }; VARIANT_ENUM_CAST(_Geometry::PolyBooleanOperation); @@ -335,10 +335,10 @@ VARIANT_ENUM_CAST(_Geometry::PolyJoinType); VARIANT_ENUM_CAST(_Geometry::PolyEndType); class _File : public Reference { - GDCLASS(_File, Reference); - FileAccess *f; - bool eswap; + + FileAccess *f = nullptr; + bool eswap = false; protected: static void _bind_methods(); @@ -429,7 +429,7 @@ public: uint64_t get_modified_time(const String &p_file) const; - _File(); + _File() {} virtual ~_File(); }; @@ -538,10 +538,10 @@ class _Thread : public Reference { protected: Variant ret; Variant userdata; - volatile bool active; - Object *target_instance; + volatile bool active = false; + Object *target_instance = nullptr; StringName target_method; - Thread *thread; + Thread *thread = nullptr; static void _bind_methods(); static void _start_func(void *ud); @@ -559,7 +559,7 @@ public: bool is_active() const; Variant wait_to_finish(); - _Thread(); + _Thread() {} ~_Thread(); }; @@ -600,8 +600,8 @@ public: bool is_class_enabled(StringName p_class) const; - _ClassDB(); - ~_ClassDB(); + _ClassDB() {} + ~_ClassDB() {} }; class _Engine : public Object { @@ -649,7 +649,7 @@ public: void set_editor_hint(bool p_enabled); bool is_editor_hint() const; - _Engine(); + _Engine() { singleton = this; } }; class _JSON; @@ -661,7 +661,7 @@ class JSONParseResult : public Reference { Error error; String error_string; - int error_line; + int error_line = -1; Variant result; @@ -681,8 +681,7 @@ public: void set_result(const Variant &p_result); Variant get_result() const; - JSONParseResult() : - error_line(-1) {} + JSONParseResult() {} }; class _JSON : public Object { @@ -698,7 +697,7 @@ public: String print(const Variant &p_value, const String &p_indent = "", bool p_sort_keys = false); Ref<JSONParseResult> parse(const String &p_json); - _JSON(); + _JSON() { singleton = this; } }; #endif // CORE_BIND_H diff --git a/core/callable.cpp b/core/callable.cpp index 6a5dc151e5..447cf78bea 100644 --- a/core/callable.cpp +++ b/core/callable.cpp @@ -29,6 +29,7 @@ /*************************************************************************/ #include "callable.h" + #include "core/script_language.h" #include "message_queue.h" #include "object.h" @@ -255,12 +256,7 @@ Callable::~Callable() { } } -Callable::Callable() { - object = 0; -} - CallableCustom::CallableCustom() { - referenced = false; ref_count.init(); } @@ -349,6 +345,7 @@ Array Signal::get_connections() const { } return arr; } + Signal::Signal(const Object *p_object, const StringName &p_name) { ERR_FAIL_COND_MSG(p_object == nullptr, "Object argument to Signal constructor must be non-null"); @@ -356,10 +353,9 @@ Signal::Signal(const Object *p_object, const StringName &p_name) { object = p_object->get_instance_id(); name = p_name; } + Signal::Signal(ObjectID p_object, const StringName &p_name) { object = p_object; name = p_name; } -Signal::Signal() { -} diff --git a/core/callable.h b/core/callable.h index 7fa024dccd..1f6ff48d4f 100644 --- a/core/callable.h +++ b/core/callable.h @@ -49,7 +49,7 @@ class Callable { //needs to be max 16 bytes in 64 bits StringName method; union { - uint64_t object; + uint64_t object = 0; CallableCustom *custom; }; @@ -75,7 +75,7 @@ public: return method == StringName() && object == 0; } _FORCE_INLINE_ bool is_custom() const { - return method == StringName() && custom != 0; + return method == StringName() && custom != nullptr; } _FORCE_INLINE_ bool is_standard() const { return method != StringName(); @@ -100,14 +100,14 @@ public: Callable(ObjectID p_object, const StringName &p_method); Callable(CallableCustom *p_custom); Callable(const Callable &p_callable); - Callable(); + Callable() {} ~Callable(); }; class CallableCustom { friend class Callable; SafeRefCount ref_count; - bool referenced; + bool referenced = false; public: typedef bool (*CompareEqualFunc)(const CallableCustom *p_a, const CallableCustom *p_b); @@ -156,7 +156,7 @@ public: Array get_connections() const; Signal(const Object *p_object, const StringName &p_name); Signal(ObjectID p_object, const StringName &p_name); - Signal(); + Signal() {} }; #endif // CALLABLE_H diff --git a/core/class_db.cpp b/core/class_db.cpp index 5e49688e9b..dd9fba16d3 100644 --- a/core/class_db.cpp +++ b/core/class_db.cpp @@ -258,19 +258,6 @@ HashMap<StringName, ClassDB::ClassInfo> ClassDB::classes; HashMap<StringName, StringName> ClassDB::resource_base_extensions; HashMap<StringName, StringName> ClassDB::compat_classes; -ClassDB::ClassInfo::ClassInfo() { - - api = API_NONE; - class_ptr = nullptr; - creation_func = nullptr; - inherits_ptr = nullptr; - disabled = false; - exposed = false; -} - -ClassDB::ClassInfo::~ClassInfo() { -} - bool ClassDB::is_parent_class(const StringName &p_class, const StringName &p_inherits) { OBJTYPE_RLOCK; @@ -1456,16 +1443,19 @@ Variant ClassDB::class_get_default_property_value(const StringName &p_class, con } if (!default_values.has(p_class)) { - if (r_valid != nullptr) *r_valid = false; + if (r_valid != nullptr) + *r_valid = false; return Variant(); } if (!default_values[p_class].has(p_property)) { - if (r_valid != nullptr) *r_valid = false; + if (r_valid != nullptr) + *r_valid = false; return Variant(); } - if (r_valid != nullptr) *r_valid = true; + if (r_valid != nullptr) + *r_valid = true; return default_values[p_class][p_property]; } diff --git a/core/class_db.h b/core/class_db.h index f760aa1738..32d2148048 100644 --- a/core/class_db.h +++ b/core/class_db.h @@ -114,9 +114,10 @@ public: struct ClassInfo { - APIType api; - ClassInfo *inherits_ptr; - void *class_ptr; + APIType api = API_NONE; + ClassInfo *inherits_ptr = nullptr; + void *class_ptr = nullptr; + HashMap<StringName, MethodBind *> method_map; HashMap<StringName, int> constant_map; HashMap<StringName, List<StringName>> enum_map; @@ -133,11 +134,12 @@ public: StringName inherits; StringName name; - bool disabled; - bool exposed; - Object *(*creation_func)(); - ClassInfo(); - ~ClassInfo(); + bool disabled = false; + bool exposed = false; + Object *(*creation_func)() = nullptr; + + ClassInfo() {} + ~ClassInfo() {} }; template <class T> diff --git a/core/color.cpp b/core/color.cpp index 03aeb2085b..79b9f70a99 100644 --- a/core/color.cpp +++ b/core/color.cpp @@ -406,7 +406,8 @@ bool Color::html_is_valid(const String &p_color) { } Color Color::named(const String &p_name) { - if (_named_colors.empty()) _populate_named_colors(); // from color_names.inc + if (_named_colors.empty()) + _populate_named_colors(); // from color_names.inc String name = p_name; // Normalize name name = name.replace(" ", ""); diff --git a/core/color.h b/core/color.h index 8b689fdde1..066a3f6696 100644 --- a/core/color.h +++ b/core/color.h @@ -44,7 +44,7 @@ struct Color { float b; float a; }; - float components[4]; + float components[4] = { 0, 0, 0, 1.0 }; }; bool operator==(const Color &p_color) const { return (r == p_color.r && g == p_color.g && b == p_color.b && a == p_color.a); } @@ -204,15 +204,7 @@ struct Color { _FORCE_INLINE_ bool operator<(const Color &p_color) const; //used in set keys operator String() const; - /** - * No construct parameters, r=0, g=0, b=0. a=1 - */ - _FORCE_INLINE_ Color() { - r = 0; - g = 0; - b = 0; - a = 1.0; - } + _FORCE_INLINE_ Color() {} /** * RGB / RGBA construct parameters. Alpha is optional, but defaults to 1.0 diff --git a/core/color_names.inc b/core/color_names.inc index 428a8473fe..f249ee9868 100644 --- a/core/color_names.inc +++ b/core/color_names.inc @@ -3,7 +3,8 @@ static Map<String, Color> _named_colors; static void _populate_named_colors() { - if (!_named_colors.empty()) return; + if (!_named_colors.empty()) + return; _named_colors.insert("aliceblue", Color(0.94, 0.97, 1.00)); _named_colors.insert("antiquewhite", Color(0.98, 0.92, 0.84)); _named_colors.insert("aqua", Color(0.00, 1.00, 1.00)); diff --git a/core/command_queue_mt.cpp b/core/command_queue_mt.cpp index 3ce769c72c..60ab5d133b 100644 --- a/core/command_queue_mt.cpp +++ b/core/command_queue_mt.cpp @@ -100,24 +100,11 @@ tryagain: } CommandQueueMT::CommandQueueMT(bool p_sync) { - - read_ptr = 0; - write_ptr = 0; - dealloc_ptr = 0; - command_mem = (uint8_t *)memalloc(COMMAND_MEM_SIZE); - - for (int i = 0; i < SYNC_SEMAPHORES; i++) { - - sync_sems[i].in_use = false; - } if (p_sync) sync = memnew(Semaphore); - else - sync = nullptr; } CommandQueueMT::~CommandQueueMT() { - if (sync) memdelete(sync); memfree(command_mem); diff --git a/core/command_queue_mt.h b/core/command_queue_mt.h index 558453bdf5..af8bbb24c6 100644 --- a/core/command_queue_mt.h +++ b/core/command_queue_mt.h @@ -253,7 +253,8 @@ cmd->method = p_method; \ SEMIC_SEP_LIST(CMD_ASSIGN_PARAM, N); \ unlock(); \ - if (sync) sync->post(); \ + if (sync) \ + sync->post(); \ } #define CMD_RET_TYPE(N) CommandRet##N<T, M, COMMA_SEP_LIST(TYPE_ARG, N) COMMA(N) R> @@ -269,7 +270,8 @@ cmd->ret = r_ret; \ cmd->sync_sem = ss; \ unlock(); \ - if (sync) sync->post(); \ + if (sync) \ + sync->post(); \ ss->sem.wait(); \ ss->in_use = false; \ } @@ -286,7 +288,8 @@ SEMIC_SEP_LIST(CMD_ASSIGN_PARAM, N); \ cmd->sync_sem = ss; \ unlock(); \ - if (sync) sync->post(); \ + if (sync) \ + sync->post(); \ ss->sem.wait(); \ ss->in_use = false; \ } @@ -298,14 +301,14 @@ class CommandQueueMT { struct SyncSemaphore { Semaphore sem; - bool in_use; + bool in_use = false; }; struct CommandBase { virtual void call() = 0; - virtual void post(){}; - virtual ~CommandBase(){}; + virtual void post() {} + virtual ~CommandBase() {} }; struct SyncCommand : public CommandBase { @@ -336,13 +339,13 @@ class CommandQueueMT { SYNC_SEMAPHORES = 8 }; - uint8_t *command_mem; - uint32_t read_ptr; - uint32_t write_ptr; - uint32_t dealloc_ptr; + uint8_t *command_mem = (uint8_t *)memalloc(COMMAND_MEM_SIZE); + uint32_t read_ptr = 0; + uint32_t write_ptr = 0; + uint32_t dealloc_ptr = 0; SyncSemaphore sync_sems[SYNC_SEMAPHORES]; Mutex mutex; - Semaphore *sync; + Semaphore *sync = nullptr; template <class T> T *allocate() { @@ -418,12 +421,14 @@ class CommandQueueMT { } bool flush_one(bool p_lock = true) { - if (p_lock) lock(); + if (p_lock) + lock(); tryagain: // tried to read an empty queue if (read_ptr == write_ptr) { - if (p_lock) unlock(); + if (p_lock) + unlock(); return false; } @@ -442,15 +447,18 @@ class CommandQueueMT { read_ptr += size; - if (p_lock) unlock(); + if (p_lock) + unlock(); cmd->call(); - if (p_lock) lock(); + if (p_lock) + lock(); cmd->post(); cmd->~CommandBase(); *(uint32_t *)&command_mem[size_ptr] &= ~1; - if (p_lock) unlock(); + if (p_lock) + unlock(); return true; } diff --git a/core/compressed_translation.cpp b/core/compressed_translation.cpp index 0225524bc8..9e6ba6cde2 100644 --- a/core/compressed_translation.cpp +++ b/core/compressed_translation.cpp @@ -288,6 +288,3 @@ void PHashTranslation::_bind_methods() { ClassDB::bind_method(D_METHOD("generate", "from"), &PHashTranslation::generate); } - -PHashTranslation::PHashTranslation() { -} diff --git a/core/compressed_translation.h b/core/compressed_translation.h index d599240dfe..fff4350caa 100644 --- a/core/compressed_translation.h +++ b/core/compressed_translation.h @@ -86,7 +86,7 @@ public: virtual StringName get_message(const StringName &p_src_text) const; //overridable for other implementations void generate(const Ref<Translation> &p_from); - PHashTranslation(); + PHashTranslation() {} }; #endif // COMPRESSED_TRANSLATION_H diff --git a/core/container_type_validate.h b/core/container_type_validate.h index 3721668033..7809e5f385 100644 --- a/core/container_type_validate.h +++ b/core/container_type_validate.h @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* container_type_validate.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #ifndef CONTAINER_TYPE_VALIDATE_H #define CONTAINER_TYPE_VALIDATE_H diff --git a/core/cowdata.h b/core/cowdata.h index 975a572906..b63a407511 100644 --- a/core/cowdata.h +++ b/core/cowdata.h @@ -54,7 +54,7 @@ class CowData { friend class VMap; private: - mutable T *_ptr; + mutable T *_ptr = nullptr; // internal helpers @@ -132,7 +132,7 @@ public: } _FORCE_INLINE_ void clear() { resize(0); } - _FORCE_INLINE_ bool empty() const { return _ptr == 0; } + _FORCE_INLINE_ bool empty() const { return _ptr == nullptr; } _FORCE_INLINE_ void set(int p_index, const T &p_elem) { @@ -183,7 +183,7 @@ public: int find(const T &p_val, int p_from = 0) const; - _FORCE_INLINE_ CowData(); + _FORCE_INLINE_ CowData() {} _FORCE_INLINE_ ~CowData(); _FORCE_INLINE_ CowData(CowData<T> &p_from) { _ref(p_from); }; }; @@ -367,12 +367,6 @@ void CowData<T>::_ref(const CowData &p_from) { } template <class T> -CowData<T>::CowData() { - - _ptr = nullptr; -} - -template <class T> CowData<T>::~CowData() { _unref(_ptr); diff --git a/core/crypto/crypto.cpp b/core/crypto/crypto.cpp index 233f62bd15..585731ac9f 100644 --- a/core/crypto/crypto.cpp +++ b/core/crypto/crypto.cpp @@ -94,9 +94,6 @@ Ref<X509Certificate> Crypto::generate_self_signed_certificate(Ref<CryptoKey> p_k ERR_FAIL_V_MSG(nullptr, "generate_self_signed_certificate is not available when mbedtls module is disabled."); } -Crypto::Crypto() { -} - /// Resource loader/saver RES ResourceFormatLoaderCrypto::load(const String &p_path, const String &p_original_path, Error *r_error, bool p_use_sub_threads, float *r_progress, bool p_no_cache) { diff --git a/core/crypto/crypto.h b/core/crypto/crypto.h index d9becab958..cf21648a4a 100644 --- a/core/crypto/crypto.h +++ b/core/crypto/crypto.h @@ -31,11 +31,10 @@ #ifndef CRYPTO_H #define CRYPTO_H -#include "core/reference.h" -#include "core/resource.h" - #include "core/io/resource_loader.h" #include "core/io/resource_saver.h" +#include "core/reference.h" +#include "core/resource.h" class CryptoKey : public Resource { GDCLASS(CryptoKey, Resource); @@ -80,7 +79,7 @@ public: virtual Ref<CryptoKey> generate_rsa(int p_bytes); virtual Ref<X509Certificate> generate_self_signed_certificate(Ref<CryptoKey> p_key, String p_issuer_name, String p_not_before, String p_not_after); - Crypto(); + Crypto() {} }; class ResourceFormatLoaderCrypto : public ResourceFormatLoader { diff --git a/core/crypto/hashing_context.cpp b/core/crypto/hashing_context.cpp index af43bc9bad..0b21dead74 100644 --- a/core/crypto/hashing_context.cpp +++ b/core/crypto/hashing_context.cpp @@ -128,10 +128,6 @@ void HashingContext::_bind_methods() { BIND_ENUM_CONSTANT(HASH_SHA256); } -HashingContext::HashingContext() { - ctx = nullptr; -} - HashingContext::~HashingContext() { if (ctx != nullptr) _delete_ctx(); diff --git a/core/crypto/hashing_context.h b/core/crypto/hashing_context.h index 230ba7ee85..f9454fa891 100644 --- a/core/crypto/hashing_context.h +++ b/core/crypto/hashing_context.h @@ -44,7 +44,7 @@ public: }; private: - void *ctx; + void *ctx = nullptr; HashType type; protected: @@ -57,7 +57,7 @@ public: Error update(PackedByteArray p_chunk); PackedByteArray finish(); - HashingContext(); + HashingContext() {} ~HashingContext(); }; diff --git a/core/debugger/debugger_marshalls.h b/core/debugger/debugger_marshalls.h index 04229c0afc..9ba316d997 100644 --- a/core/debugger/debugger_marshalls.h +++ b/core/debugger/debugger_marshalls.h @@ -42,11 +42,8 @@ struct DebuggerMarshalls { String format; String type; RID id; - int vram; + int vram = 0; bool operator<(const ResourceInfo &p_img) const { return vram == p_img.vram ? id < p_img.id : vram > p_img.vram; } - ResourceInfo() { - vram = 0; - } }; struct ResourceUsage { @@ -119,10 +116,7 @@ struct DebuggerMarshalls { struct ScriptStackVariable { String name; Variant value; - int type; - ScriptStackVariable() { - type = -1; - } + int type = -1; Array serialize(int max_size = 1 << 20); // 1 MiB default. bool deserialize(const Array &p_arr); @@ -137,27 +131,18 @@ struct DebuggerMarshalls { }; struct OutputError { - int hr; - int min; - int sec; - int msec; + int hr = -1; + int min = -1; + int sec = -1; + int msec = -1; String source_file; String source_func; - int source_line; + int source_line = -1; String error; String error_descr; - bool warning; + bool warning = false; Vector<ScriptLanguage::StackInfo> callstack; - OutputError() { - hr = -1; - min = -1; - sec = -1; - msec = -1; - source_line = -1; - warning = false; - } - Array serialize(); bool deserialize(const Array &p_arr); }; diff --git a/core/debugger/engine_debugger.cpp b/core/debugger/engine_debugger.cpp index bfe38d0f4a..04eba84b30 100644 --- a/core/debugger/engine_debugger.cpp +++ b/core/debugger/engine_debugger.cpp @@ -32,6 +32,7 @@ #include "core/debugger/local_debugger.h" #include "core/debugger/remote_debugger.h" +#include "core/debugger/remote_debugger_peer.h" #include "core/debugger/script_debugger.h" #include "core/os/os.h" @@ -40,6 +41,7 @@ ScriptDebugger *EngineDebugger::script_debugger = nullptr; Map<StringName, EngineDebugger::Profiler> EngineDebugger::profilers; Map<StringName, EngineDebugger::Capture> EngineDebugger::captures; +Map<String, EngineDebugger::CreatePeerFunc> EngineDebugger::protocols; void EngineDebugger::register_profiler(const StringName &p_name, const Profiler &p_func) { ERR_FAIL_COND_MSG(profilers.has(p_name), "Profiler already registered: " + p_name); @@ -66,6 +68,11 @@ void EngineDebugger::unregister_message_capture(const StringName &p_name) { captures.erase(p_name); } +void EngineDebugger::register_uri_handler(const String &p_protocol, CreatePeerFunc p_func) { + ERR_FAIL_COND_MSG(protocols.has(p_protocol), "Protocol handler already registered: " + p_protocol); + protocols.insert(p_protocol, p_func); +} + void EngineDebugger::profiler_enable(const StringName &p_name, bool p_enabled, const Array &p_opts) { ERR_FAIL_COND_MSG(!profilers.has(p_name), "Can't change profiler state, no profiler: " + p_name); Profiler &p = profilers[p_name]; @@ -125,6 +132,7 @@ void EngineDebugger::iteration(uint64_t p_frame_ticks, uint64_t p_idle_ticks, ui } void EngineDebugger::initialize(const String &p_uri, bool p_skip_breakpoints, Vector<String> p_breakpoints) { + register_uri_handler("tcp://", RemoteDebuggerPeerTCP::create); // TCP is the default protocol. Platforms/modules can add more. if (p_uri.empty()) return; if (p_uri == "local://") { @@ -132,10 +140,14 @@ void EngineDebugger::initialize(const String &p_uri, bool p_skip_breakpoints, Ve script_debugger = memnew(ScriptDebugger); // Tell the OS that we want to handle termination signals. OS::get_singleton()->initialize_debugging(); - } else { - singleton = RemoteDebugger::create_for_uri(p_uri); - if (!singleton) + } else if (p_uri.find("://") >= 0) { + const String proto = p_uri.substr(0, p_uri.find("://") + 3); + if (!protocols.has(proto)) + return; + RemoteDebuggerPeer *peer = protocols[proto](p_uri); + if (!peer) return; + singleton = memnew(RemoteDebugger(Ref<RemoteDebuggerPeer>(peer))); script_debugger = memnew(ScriptDebugger); // Notify editor of our pid (to allow focus stealing). Array msg; @@ -160,22 +172,24 @@ void EngineDebugger::initialize(const String &p_uri, bool p_skip_breakpoints, Ve } void EngineDebugger::deinitialize() { - if (!singleton) - return; - - // Stop all profilers - for (Map<StringName, Profiler>::Element *E = profilers.front(); E; E = E->next()) { - if (E->get().active) - singleton->profiler_enable(E->key(), false); + if (singleton) { + // Stop all profilers + for (Map<StringName, Profiler>::Element *E = profilers.front(); E; E = E->next()) { + if (E->get().active) + singleton->profiler_enable(E->key(), false); + } + + // Flush any remaining message + singleton->poll_events(false); + + memdelete(singleton); + singleton = nullptr; } - // Flush any remaining message - singleton->poll_events(false); - - memdelete(singleton); - singleton = nullptr; + // Clear profilers/captuers/protocol handlers. profilers.clear(); captures.clear(); + protocols.clear(); } EngineDebugger::~EngineDebugger() { diff --git a/core/debugger/engine_debugger.h b/core/debugger/engine_debugger.h index 7b6b77ca9b..8d5ebb2394 100644 --- a/core/debugger/engine_debugger.h +++ b/core/debugger/engine_debugger.h @@ -38,6 +38,7 @@ #include "core/variant.h" #include "core/vector.h" +class RemoteDebuggerPeer; class ScriptDebugger; class EngineDebugger { @@ -45,8 +46,11 @@ public: typedef void (*ProfilingToggle)(void *p_user, bool p_enable, const Array &p_opts); typedef void (*ProfilingTick)(void *p_user, float p_frame_time, float p_idle_time, float p_physics_time, float p_physics_frame_time); typedef void (*ProfilingAdd)(void *p_user, const Array &p_arr); + typedef Error (*CaptureFunc)(void *p_user, const String &p_msg, const Array &p_args, bool &r_captured); + typedef RemoteDebuggerPeer *(*CreatePeerFunc)(const String &p_uri); + class Profiler { friend class EngineDebugger; @@ -94,6 +98,7 @@ protected: static Map<StringName, Profiler> profilers; static Map<StringName, Capture> captures; + static Map<String, CreatePeerFunc> protocols; public: _FORCE_INLINE_ static EngineDebugger *get_singleton() { return singleton; } @@ -113,6 +118,8 @@ public: static void unregister_message_capture(const StringName &p_name); static bool has_capture(const StringName &p_name); + static void register_uri_handler(const String &p_protocol, CreatePeerFunc p_func); + void iteration(uint64_t p_frame_ticks, uint64_t p_idle_ticks, uint64_t p_physics_ticks, float p_physics_frame_time); void profiler_enable(const StringName &p_name, bool p_enabled, const Array &p_opts = Array()); Error capture_parse(const StringName &p_name, const String &p_msg, const Array &p_args, bool &r_captured); diff --git a/core/debugger/remote_debugger.cpp b/core/debugger/remote_debugger.cpp index 97d5c71b6f..18c9602eb2 100644 --- a/core/debugger/remote_debugger.cpp +++ b/core/debugger/remote_debugger.cpp @@ -673,6 +673,9 @@ void RemoteDebugger::debug(bool p_can_continue, bool p_is_error_breakpoint) { ERR_FAIL_COND_MSG(!is_peer_connected(), "Script Debugger failed to connect, but being used anyway."); + if (!peer->can_block()) + return; // Peer does not support blocking IO. We could at least send the error though. + ScriptLanguage *script_lang = script_debugger->get_break_language(); const String error_str = script_lang ? script_lang->debug_get_error() : ""; Array msg; @@ -886,13 +889,6 @@ Error RemoteDebugger::_profiler_capture(const String &p_cmd, const Array &p_data return OK; } -RemoteDebugger *RemoteDebugger::create_for_uri(const String &p_uri) { - Ref<RemoteDebuggerPeer> peer = RemoteDebuggerPeer::create_from_uri(p_uri); - if (peer.is_valid()) - return memnew(RemoteDebugger(peer)); - return nullptr; -} - RemoteDebugger::RemoteDebugger(Ref<RemoteDebuggerPeer> p_peer) { peer = p_peer; max_chars_per_second = GLOBAL_GET("network/limits/debugger/max_chars_per_second"); diff --git a/core/debugger/remote_debugger.h b/core/debugger/remote_debugger.h index cac0bc3730..dc7e4436e1 100644 --- a/core/debugger/remote_debugger.h +++ b/core/debugger/remote_debugger.h @@ -108,8 +108,6 @@ private: Error _try_capture(const String &p_name, const Array &p_data, bool &r_captured); public: - static RemoteDebugger *create_for_uri(const String &p_uri); - // Overrides void poll_events(bool p_is_idle); void send_message(const String &p_message, const Array &p_args); diff --git a/core/debugger/remote_debugger_peer.cpp b/core/debugger/remote_debugger_peer.cpp index ed04431177..458c46df93 100644 --- a/core/debugger/remote_debugger_peer.cpp +++ b/core/debugger/remote_debugger_peer.cpp @@ -218,9 +218,8 @@ void RemoteDebuggerPeerTCP::_poll() { } } -Ref<RemoteDebuggerPeer> RemoteDebuggerPeer::create_from_uri(const String p_uri) { - if (!p_uri.begins_with("tcp://")) - return Ref<RemoteDebuggerPeer>(); // Only TCP supported for now, more to come. +RemoteDebuggerPeer *RemoteDebuggerPeerTCP::create(const String &p_uri) { + ERR_FAIL_COND_V(!p_uri.begins_with("tcp://"), nullptr); String debug_host = p_uri.replace("tcp://", ""); uint16_t debug_port = 6007; @@ -230,10 +229,13 @@ Ref<RemoteDebuggerPeer> RemoteDebuggerPeer::create_from_uri(const String p_uri) debug_port = debug_host.substr(sep_pos + 1).to_int(); debug_host = debug_host.substr(0, sep_pos); } - Ref<RemoteDebuggerPeerTCP> peer = Ref<RemoteDebuggerPeer>(memnew(RemoteDebuggerPeerTCP)); + + RemoteDebuggerPeerTCP *peer = memnew(RemoteDebuggerPeerTCP); Error err = peer->connect_to_host(debug_host, debug_port); - if (err != OK) - return Ref<RemoteDebuggerPeer>(); + if (err != OK) { + memdelete(peer); + return nullptr; + } return peer; } diff --git a/core/debugger/remote_debugger_peer.h b/core/debugger/remote_debugger_peer.h index e4b838f145..3a75a2a02b 100644 --- a/core/debugger/remote_debugger_peer.h +++ b/core/debugger/remote_debugger_peer.h @@ -42,7 +42,6 @@ protected: int max_queued_messages = 4096; public: - static Ref<RemoteDebuggerPeer> create_from_uri(const String p_uri); virtual bool is_peer_connected() = 0; virtual bool has_message() = 0; virtual Error put_message(const Array &p_arr) = 0; @@ -50,6 +49,7 @@ public: virtual void close() = 0; virtual void poll() = 0; virtual int get_max_message_size() const = 0; + virtual bool can_block() const { return true; } // If blocking io is allowed on main thread (debug). RemoteDebuggerPeer(); }; @@ -77,6 +77,8 @@ private: void _read_in(); public: + static RemoteDebuggerPeer *create(const String &p_uri); + Error connect_to_host(const String &p_host, uint16_t p_port); void poll(); diff --git a/core/engine.cpp b/core/engine.cpp index 5361e09a8a..86ce0395b9 100644 --- a/core/engine.cpp +++ b/core/engine.cpp @@ -217,23 +217,7 @@ Engine *Engine::get_singleton() { bool Engine::is_abort_on_gpu_errors_enabled() const { return abort_on_gpu_errors; } -Engine::Engine() { +Engine::Engine() { singleton = this; - frames_drawn = 0; - ips = 60; - physics_jitter_fix = 0.5; - _physics_interpolation_fraction = 0.0f; - _frame_delay = 0; - _fps = 1; - _target_fps = 0; - _time_scale = 1.0; - _pixel_snap = false; - _physics_frames = 0; - _idle_frames = 0; - _in_physics = false; - _frame_ticks = 0; - _frame_step = 0; - editor_hint = false; - abort_on_gpu_errors = false; } diff --git a/core/engine.h b/core/engine.h index 8512779d4c..aa28b35814 100644 --- a/core/engine.h +++ b/core/engine.h @@ -51,28 +51,28 @@ public: private: friend class Main; - uint64_t frames_drawn; - uint32_t _frame_delay; - uint64_t _frame_ticks; - float _frame_step; - - int ips; - float physics_jitter_fix; - float _fps; - int _target_fps; - float _time_scale; - bool _pixel_snap; - uint64_t _physics_frames; - float _physics_interpolation_fraction; - bool abort_on_gpu_errors; - - uint64_t _idle_frames; - bool _in_physics; + uint64_t frames_drawn = 0; + uint32_t _frame_delay = 0; + uint64_t _frame_ticks = 0; + float _frame_step = 0; + + int ips = 60; + float physics_jitter_fix = 0.5; + float _fps = 1; + int _target_fps = 0; + float _time_scale = 1.0; + bool _pixel_snap = false; + uint64_t _physics_frames = 0; + float _physics_interpolation_fraction = 0.0f; + bool abort_on_gpu_errors = false; + + uint64_t _idle_frames = 0; + bool _in_physics = false; List<Singleton> singletons; Map<StringName, Object *> singleton_ptrs; - bool editor_hint; + bool editor_hint = false; static Engine *singleton; diff --git a/core/error_macros.h b/core/error_macros.h index 18c46c9e7d..eb2cc5215d 100644 --- a/core/error_macros.h +++ b/core/error_macros.h @@ -48,16 +48,12 @@ typedef void (*ErrorHandlerFunc)(void *, const char *, const char *, int p_line, struct ErrorHandlerList { - ErrorHandlerFunc errfunc; - void *userdata; + ErrorHandlerFunc errfunc = nullptr; + void *userdata = nullptr; - ErrorHandlerList *next; + ErrorHandlerList *next = nullptr; - ErrorHandlerList() { - errfunc = 0; - next = 0; - userdata = 0; - } + ErrorHandlerList() {} }; void add_error_handler(ErrorHandlerList *p_handler); @@ -530,19 +526,19 @@ void _err_print_index_error(const char *p_function, const char *p_file, int p_li * Prints `m_msg`. */ #define ERR_PRINT(m_msg) \ - _err_print_error(FUNCTION_STR, __FILE__, __LINE__, DEBUG_STR(m_msg)) + _err_print_error(FUNCTION_STR, __FILE__, __LINE__, m_msg) /** * Prints `m_msg` once during the application lifetime. */ -#define ERR_PRINT_ONCE(m_msg) \ - if (1) { \ - static bool first_print = true; \ - if (first_print) { \ - _err_print_error(FUNCTION_STR, __FILE__, __LINE__, DEBUG_STR(m_msg)); \ - first_print = false; \ - } \ - } else \ +#define ERR_PRINT_ONCE(m_msg) \ + if (1) { \ + static bool first_print = true; \ + if (first_print) { \ + _err_print_error(FUNCTION_STR, __FILE__, __LINE__, m_msg); \ + first_print = false; \ + } \ + } else \ ((void)0) // Print warning message macros. @@ -553,21 +549,21 @@ void _err_print_index_error(const char *p_function, const char *p_file, int p_li * If warning about deprecated usage, use `WARN_DEPRECATED` or `WARN_DEPRECATED_MSG` instead. */ #define WARN_PRINT(m_msg) \ - _err_print_error(FUNCTION_STR, __FILE__, __LINE__, DEBUG_STR(m_msg), ERR_HANDLER_WARNING) + _err_print_error(FUNCTION_STR, __FILE__, __LINE__, m_msg, ERR_HANDLER_WARNING) /** * Prints `m_msg` once during the application lifetime. * * If warning about deprecated usage, use `WARN_DEPRECATED` or `WARN_DEPRECATED_MSG` instead. */ -#define WARN_PRINT_ONCE(m_msg) \ - if (1) { \ - static bool first_print = true; \ - if (first_print) { \ - _err_print_error(FUNCTION_STR, __FILE__, __LINE__, DEBUG_STR(m_msg), ERR_HANDLER_WARNING); \ - first_print = false; \ - } \ - } else \ +#define WARN_PRINT_ONCE(m_msg) \ + if (1) { \ + static bool first_print = true; \ + if (first_print) { \ + _err_print_error(FUNCTION_STR, __FILE__, __LINE__, m_msg, ERR_HANDLER_WARNING); \ + first_print = false; \ + } \ + } else \ ((void)0) // Print deprecated warning message macros. diff --git a/core/func_ref.cpp b/core/func_ref.cpp index 338c17946b..ad29f4488d 100644 --- a/core/func_ref.cpp +++ b/core/func_ref.cpp @@ -94,6 +94,3 @@ void FuncRef::_bind_methods() { ClassDB::bind_method(D_METHOD("set_function", "name"), &FuncRef::set_function); ClassDB::bind_method(D_METHOD("is_valid"), &FuncRef::is_valid); } - -FuncRef::FuncRef() { -} diff --git a/core/func_ref.h b/core/func_ref.h index 8cb3be6e61..07b361db2d 100644 --- a/core/func_ref.h +++ b/core/func_ref.h @@ -48,7 +48,8 @@ public: void set_instance(Object *p_obj); void set_function(const StringName &p_func); bool is_valid() const; - FuncRef(); + + FuncRef() {} }; #endif // FUNC_REF_H diff --git a/core/global_constants.cpp b/core/global_constants.cpp index afcb283e05..4f415a2056 100644 --- a/core/global_constants.cpp +++ b/core/global_constants.cpp @@ -88,7 +88,8 @@ static Vector<_GlobalConstant> _global_constants; VARIANT_ENUM_CAST(KeyList); VARIANT_ENUM_CAST(KeyModifierMask); VARIANT_ENUM_CAST(ButtonList); -VARIANT_ENUM_CAST(JoystickList); +VARIANT_ENUM_CAST(JoyButtonList); +VARIANT_ENUM_CAST(JoyAxisList); VARIANT_ENUM_CAST(MidiMessageList); void register_global_constants() { @@ -388,90 +389,70 @@ void register_global_constants() { BIND_GLOBAL_ENUM_CONSTANT(BUTTON_MASK_XBUTTON1); BIND_GLOBAL_ENUM_CONSTANT(BUTTON_MASK_XBUTTON2); - //joypads - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_0); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_1); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_2); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_3); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_4); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_5); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_6); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_7); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_8); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_9); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_10); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_11); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_12); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_13); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_14); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_15); - BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_MAX); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_CIRCLE); + // Joypad buttons + BIND_GLOBAL_ENUM_CONSTANT(JOY_INVALID_BUTTON); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_A); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_B); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_X); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_Y); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_BACK); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_GUIDE); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_START); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_LEFT_STICK); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_RIGHT_STICK); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_LEFT_SHOULDER); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_RIGHT_SHOULDER); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_DPAD_UP); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_DPAD_DOWN); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_DPAD_LEFT); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_DPAD_RIGHT); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SDL_BUTTONS); BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_X); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_CROSS); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_CIRCLE); BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_SQUARE); BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_TRIANGLE); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_B); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_SELECT); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_START); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_PS); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_L1); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_R1); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_L3); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SONY_R3); BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_A); + BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_B); BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_X); BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_Y); + BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_BACK); + BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_START); + BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_HOME); + BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_LS); + BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_RS); + BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_LB); + BIND_GLOBAL_ENUM_CONSTANT(JOY_XBOX_RB); + BIND_GLOBAL_ENUM_CONSTANT(JOY_BUTTON_MAX); - BIND_GLOBAL_ENUM_CONSTANT(JOY_DS_A); - BIND_GLOBAL_ENUM_CONSTANT(JOY_DS_B); - BIND_GLOBAL_ENUM_CONSTANT(JOY_DS_X); - BIND_GLOBAL_ENUM_CONSTANT(JOY_DS_Y); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_VR_GRIP); - BIND_GLOBAL_ENUM_CONSTANT(JOY_VR_PAD); - BIND_GLOBAL_ENUM_CONSTANT(JOY_VR_TRIGGER); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_OCULUS_AX); - BIND_GLOBAL_ENUM_CONSTANT(JOY_OCULUS_BY); - BIND_GLOBAL_ENUM_CONSTANT(JOY_OCULUS_MENU); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_OPENVR_MENU); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_SELECT); - BIND_GLOBAL_ENUM_CONSTANT(JOY_START); - BIND_GLOBAL_ENUM_CONSTANT(JOY_DPAD_UP); - BIND_GLOBAL_ENUM_CONSTANT(JOY_DPAD_DOWN); - BIND_GLOBAL_ENUM_CONSTANT(JOY_DPAD_LEFT); - BIND_GLOBAL_ENUM_CONSTANT(JOY_DPAD_RIGHT); - BIND_GLOBAL_ENUM_CONSTANT(JOY_L); - BIND_GLOBAL_ENUM_CONSTANT(JOY_L2); - BIND_GLOBAL_ENUM_CONSTANT(JOY_L3); - BIND_GLOBAL_ENUM_CONSTANT(JOY_R); - BIND_GLOBAL_ENUM_CONSTANT(JOY_R2); - BIND_GLOBAL_ENUM_CONSTANT(JOY_R3); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_0); - BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_1); - BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_2); - BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_3); - BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_4); - BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_5); - BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_6); - BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_7); - BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_8); - BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_9); + // Joypad axes + BIND_GLOBAL_ENUM_CONSTANT(JOY_INVALID_AXIS); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_LEFT_X); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_LEFT_Y); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_RIGHT_X); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_RIGHT_Y); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_TRIGGER_LEFT); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_TRIGGER_RIGHT); + BIND_GLOBAL_ENUM_CONSTANT(JOY_SDL_AXES); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_0_X); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_0_Y); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_1_X); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_1_Y); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_2_X); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_2_Y); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_3_X); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_3_Y); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_4_X); + BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_4_Y); BIND_GLOBAL_ENUM_CONSTANT(JOY_AXIS_MAX); - BIND_GLOBAL_ENUM_CONSTANT(JOY_ANALOG_LX); - BIND_GLOBAL_ENUM_CONSTANT(JOY_ANALOG_LY); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_ANALOG_RX); - BIND_GLOBAL_ENUM_CONSTANT(JOY_ANALOG_RY); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_ANALOG_L2); - BIND_GLOBAL_ENUM_CONSTANT(JOY_ANALOG_R2); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_VR_ANALOG_TRIGGER); - BIND_GLOBAL_ENUM_CONSTANT(JOY_VR_ANALOG_GRIP); - - BIND_GLOBAL_ENUM_CONSTANT(JOY_OPENVR_TOUCHPADX); - BIND_GLOBAL_ENUM_CONSTANT(JOY_OPENVR_TOUCHPADY); - // midi BIND_GLOBAL_ENUM_CONSTANT(MIDI_MESSAGE_NOTE_OFF); BIND_GLOBAL_ENUM_CONSTANT(MIDI_MESSAGE_NOTE_ON); diff --git a/core/hash_map.h b/core/hash_map.h index f27a86cc02..cfc0f87010 100644 --- a/core/hash_map.h +++ b/core/hash_map.h @@ -75,8 +75,8 @@ public: friend class HashMap; uint32_t hash; - Element *next; - Element() { next = 0; } + Element *next = nullptr; + Element() {} Pair pair; public: @@ -94,9 +94,9 @@ public: }; private: - Element **hash_table; - uint8_t hash_table_power; - uint32_t elements; + Element **hash_table = nullptr; + uint8_t hash_table_power = 0; + uint32_t elements = 0; void make_hash_table() { @@ -107,7 +107,7 @@ private: hash_table_power = MIN_HASH_TABLE_POWER; elements = 0; for (int i = 0; i < (1 << MIN_HASH_TABLE_POWER); i++) - hash_table[i] = 0; + hash_table[i] = nullptr; } void erase_hash_table() { @@ -115,7 +115,7 @@ private: ERR_FAIL_COND_MSG(elements, "Cannot erase hash table if there are still elements inside."); memdelete_arr(hash_table); - hash_table = 0; + hash_table = nullptr; hash_table_power = 0; elements = 0; } @@ -155,7 +155,7 @@ private: for (int i = 0; i < (1 << new_hash_table_power); i++) { - new_hash_table[i] = 0; + new_hash_table[i] = nullptr; } if (hash_table) { @@ -541,7 +541,7 @@ public: memdelete_arr(hash_table); } - hash_table = 0; + hash_table = nullptr; hash_table_power = 0; elements = 0; } @@ -551,12 +551,6 @@ public: copy_from(p_table); } - HashMap() { - hash_table = nullptr; - elements = 0; - hash_table_power = 0; - } - void get_key_value_ptr_array(const Pair **p_pairs) const { if (unlikely(!hash_table)) return; @@ -584,17 +578,13 @@ public: } } - HashMap(const HashMap &p_table) { - - hash_table = nullptr; - elements = 0; - hash_table_power = 0; + HashMap() {} + HashMap(const HashMap &p_table) { copy_from(p_table); } ~HashMap() { - clear(); } }; diff --git a/core/image.cpp b/core/image.cpp index 58351ae2e5..c88ed7e3cb 100644 --- a/core/image.cpp +++ b/core/image.cpp @@ -114,23 +114,36 @@ int Image::get_format_pixel_size(Format p_format) { return 1; //luminance case FORMAT_LA8: return 2; //luminance-alpha - case FORMAT_R8: return 1; - case FORMAT_RG8: return 2; - case FORMAT_RGB8: return 3; - case FORMAT_RGBA8: return 4; - case FORMAT_RGBA4444: return 2; - case FORMAT_RGB565: return 2; + case FORMAT_R8: + return 1; + case FORMAT_RG8: + return 2; + case FORMAT_RGB8: + return 3; + case FORMAT_RGBA8: + return 4; + case FORMAT_RGBA4444: + return 2; + case FORMAT_RGB565: + return 2; case FORMAT_RF: return 4; //float - case FORMAT_RGF: return 8; - case FORMAT_RGBF: return 12; - case FORMAT_RGBAF: return 16; + case FORMAT_RGF: + return 8; + case FORMAT_RGBF: + return 12; + case FORMAT_RGBAF: + return 16; case FORMAT_RH: return 2; //half float - case FORMAT_RGH: return 4; - case FORMAT_RGBH: return 6; - case FORMAT_RGBAH: return 8; - case FORMAT_RGBE9995: return 4; + case FORMAT_RGH: + return 4; + case FORMAT_RGBH: + return 6; + case FORMAT_RGBAH: + return 8; + case FORMAT_RGBE9995: + return 4; case FORMAT_DXT1: return 1; //s3tc bc1 case FORMAT_DXT3: @@ -149,22 +162,32 @@ int Image::get_format_pixel_size(Format p_format) { return 1; //unsigned float case FORMAT_PVRTC2: return 1; //pvrtc - case FORMAT_PVRTC2A: return 1; - case FORMAT_PVRTC4: return 1; - case FORMAT_PVRTC4A: return 1; + case FORMAT_PVRTC2A: + return 1; + case FORMAT_PVRTC4: + return 1; + case FORMAT_PVRTC4A: + return 1; case FORMAT_ETC: return 1; //etc1 case FORMAT_ETC2_R11: return 1; //etc2 case FORMAT_ETC2_R11S: return 1; //signed: return 1; NOT srgb. - case FORMAT_ETC2_RG11: return 1; - case FORMAT_ETC2_RG11S: return 1; - case FORMAT_ETC2_RGB8: return 1; - case FORMAT_ETC2_RGBA8: return 1; - case FORMAT_ETC2_RGB8A1: return 1; - case FORMAT_ETC2_RA_AS_RG: return 1; - case FORMAT_DXT5_RA_AS_RG: return 1; + case FORMAT_ETC2_RG11: + return 1; + case FORMAT_ETC2_RG11S: + return 1; + case FORMAT_ETC2_RGB8: + return 1; + case FORMAT_ETC2_RGBA8: + return 1; + case FORMAT_ETC2_RGB8A1: + return 1; + case FORMAT_ETC2_RA_AS_RG: + return 1; + case FORMAT_DXT5_RA_AS_RG: + return 1; case FORMAT_MAX: { } } @@ -451,7 +474,7 @@ void Image::convert(Format p_new_format) { } else if (format > FORMAT_RGBA8 || p_new_format > FORMAT_RGBA8) { //use put/set pixel which is slower but works with non byte formats - Image new_img(width, height, 0, p_new_format); + Image new_img(width, height, false, p_new_format); for (int i = 0; i < width; i++) { for (int j = 0; j < height; j++) { @@ -469,7 +492,7 @@ void Image::convert(Format p_new_format) { return; } - Image new_img(width, height, 0, p_new_format); + Image new_img(width, height, false, p_new_format); const uint8_t *rptr = data.ptr(); uint8_t *wptr = new_img.data.ptrw(); @@ -478,36 +501,96 @@ void Image::convert(Format p_new_format) { switch (conversion_type) { - case FORMAT_L8 | (FORMAT_LA8 << 8): _convert<1, false, 1, true, true, true>(width, height, rptr, wptr); break; - case FORMAT_L8 | (FORMAT_R8 << 8): _convert<1, false, 1, false, true, false>(width, height, rptr, wptr); break; - case FORMAT_L8 | (FORMAT_RG8 << 8): _convert<1, false, 2, false, true, false>(width, height, rptr, wptr); break; - case FORMAT_L8 | (FORMAT_RGB8 << 8): _convert<1, false, 3, false, true, false>(width, height, rptr, wptr); break; - case FORMAT_L8 | (FORMAT_RGBA8 << 8): _convert<1, false, 3, true, true, false>(width, height, rptr, wptr); break; - case FORMAT_LA8 | (FORMAT_L8 << 8): _convert<1, true, 1, false, true, true>(width, height, rptr, wptr); break; - case FORMAT_LA8 | (FORMAT_R8 << 8): _convert<1, true, 1, false, true, false>(width, height, rptr, wptr); break; - case FORMAT_LA8 | (FORMAT_RG8 << 8): _convert<1, true, 2, false, true, false>(width, height, rptr, wptr); break; - case FORMAT_LA8 | (FORMAT_RGB8 << 8): _convert<1, true, 3, false, true, false>(width, height, rptr, wptr); break; - case FORMAT_LA8 | (FORMAT_RGBA8 << 8): _convert<1, true, 3, true, true, false>(width, height, rptr, wptr); break; - case FORMAT_R8 | (FORMAT_L8 << 8): _convert<1, false, 1, false, false, true>(width, height, rptr, wptr); break; - case FORMAT_R8 | (FORMAT_LA8 << 8): _convert<1, false, 1, true, false, true>(width, height, rptr, wptr); break; - case FORMAT_R8 | (FORMAT_RG8 << 8): _convert<1, false, 2, false, false, false>(width, height, rptr, wptr); break; - case FORMAT_R8 | (FORMAT_RGB8 << 8): _convert<1, false, 3, false, false, false>(width, height, rptr, wptr); break; - case FORMAT_R8 | (FORMAT_RGBA8 << 8): _convert<1, false, 3, true, false, false>(width, height, rptr, wptr); break; - case FORMAT_RG8 | (FORMAT_L8 << 8): _convert<2, false, 1, false, false, true>(width, height, rptr, wptr); break; - case FORMAT_RG8 | (FORMAT_LA8 << 8): _convert<2, false, 1, true, false, true>(width, height, rptr, wptr); break; - case FORMAT_RG8 | (FORMAT_R8 << 8): _convert<2, false, 1, false, false, false>(width, height, rptr, wptr); break; - case FORMAT_RG8 | (FORMAT_RGB8 << 8): _convert<2, false, 3, false, false, false>(width, height, rptr, wptr); break; - case FORMAT_RG8 | (FORMAT_RGBA8 << 8): _convert<2, false, 3, true, false, false>(width, height, rptr, wptr); break; - case FORMAT_RGB8 | (FORMAT_L8 << 8): _convert<3, false, 1, false, false, true>(width, height, rptr, wptr); break; - case FORMAT_RGB8 | (FORMAT_LA8 << 8): _convert<3, false, 1, true, false, true>(width, height, rptr, wptr); break; - case FORMAT_RGB8 | (FORMAT_R8 << 8): _convert<3, false, 1, false, false, false>(width, height, rptr, wptr); break; - case FORMAT_RGB8 | (FORMAT_RG8 << 8): _convert<3, false, 2, false, false, false>(width, height, rptr, wptr); break; - case FORMAT_RGB8 | (FORMAT_RGBA8 << 8): _convert<3, false, 3, true, false, false>(width, height, rptr, wptr); break; - case FORMAT_RGBA8 | (FORMAT_L8 << 8): _convert<3, true, 1, false, false, true>(width, height, rptr, wptr); break; - case FORMAT_RGBA8 | (FORMAT_LA8 << 8): _convert<3, true, 1, true, false, true>(width, height, rptr, wptr); break; - case FORMAT_RGBA8 | (FORMAT_R8 << 8): _convert<3, true, 1, false, false, false>(width, height, rptr, wptr); break; - case FORMAT_RGBA8 | (FORMAT_RG8 << 8): _convert<3, true, 2, false, false, false>(width, height, rptr, wptr); break; - case FORMAT_RGBA8 | (FORMAT_RGB8 << 8): _convert<3, true, 3, false, false, false>(width, height, rptr, wptr); break; + case FORMAT_L8 | (FORMAT_LA8 << 8): + _convert<1, false, 1, true, true, true>(width, height, rptr, wptr); + break; + case FORMAT_L8 | (FORMAT_R8 << 8): + _convert<1, false, 1, false, true, false>(width, height, rptr, wptr); + break; + case FORMAT_L8 | (FORMAT_RG8 << 8): + _convert<1, false, 2, false, true, false>(width, height, rptr, wptr); + break; + case FORMAT_L8 | (FORMAT_RGB8 << 8): + _convert<1, false, 3, false, true, false>(width, height, rptr, wptr); + break; + case FORMAT_L8 | (FORMAT_RGBA8 << 8): + _convert<1, false, 3, true, true, false>(width, height, rptr, wptr); + break; + case FORMAT_LA8 | (FORMAT_L8 << 8): + _convert<1, true, 1, false, true, true>(width, height, rptr, wptr); + break; + case FORMAT_LA8 | (FORMAT_R8 << 8): + _convert<1, true, 1, false, true, false>(width, height, rptr, wptr); + break; + case FORMAT_LA8 | (FORMAT_RG8 << 8): + _convert<1, true, 2, false, true, false>(width, height, rptr, wptr); + break; + case FORMAT_LA8 | (FORMAT_RGB8 << 8): + _convert<1, true, 3, false, true, false>(width, height, rptr, wptr); + break; + case FORMAT_LA8 | (FORMAT_RGBA8 << 8): + _convert<1, true, 3, true, true, false>(width, height, rptr, wptr); + break; + case FORMAT_R8 | (FORMAT_L8 << 8): + _convert<1, false, 1, false, false, true>(width, height, rptr, wptr); + break; + case FORMAT_R8 | (FORMAT_LA8 << 8): + _convert<1, false, 1, true, false, true>(width, height, rptr, wptr); + break; + case FORMAT_R8 | (FORMAT_RG8 << 8): + _convert<1, false, 2, false, false, false>(width, height, rptr, wptr); + break; + case FORMAT_R8 | (FORMAT_RGB8 << 8): + _convert<1, false, 3, false, false, false>(width, height, rptr, wptr); + break; + case FORMAT_R8 | (FORMAT_RGBA8 << 8): + _convert<1, false, 3, true, false, false>(width, height, rptr, wptr); + break; + case FORMAT_RG8 | (FORMAT_L8 << 8): + _convert<2, false, 1, false, false, true>(width, height, rptr, wptr); + break; + case FORMAT_RG8 | (FORMAT_LA8 << 8): + _convert<2, false, 1, true, false, true>(width, height, rptr, wptr); + break; + case FORMAT_RG8 | (FORMAT_R8 << 8): + _convert<2, false, 1, false, false, false>(width, height, rptr, wptr); + break; + case FORMAT_RG8 | (FORMAT_RGB8 << 8): + _convert<2, false, 3, false, false, false>(width, height, rptr, wptr); + break; + case FORMAT_RG8 | (FORMAT_RGBA8 << 8): + _convert<2, false, 3, true, false, false>(width, height, rptr, wptr); + break; + case FORMAT_RGB8 | (FORMAT_L8 << 8): + _convert<3, false, 1, false, false, true>(width, height, rptr, wptr); + break; + case FORMAT_RGB8 | (FORMAT_LA8 << 8): + _convert<3, false, 1, true, false, true>(width, height, rptr, wptr); + break; + case FORMAT_RGB8 | (FORMAT_R8 << 8): + _convert<3, false, 1, false, false, false>(width, height, rptr, wptr); + break; + case FORMAT_RGB8 | (FORMAT_RG8 << 8): + _convert<3, false, 2, false, false, false>(width, height, rptr, wptr); + break; + case FORMAT_RGB8 | (FORMAT_RGBA8 << 8): + _convert<3, false, 3, true, false, false>(width, height, rptr, wptr); + break; + case FORMAT_RGBA8 | (FORMAT_L8 << 8): + _convert<3, true, 1, false, false, true>(width, height, rptr, wptr); + break; + case FORMAT_RGBA8 | (FORMAT_LA8 << 8): + _convert<3, true, 1, true, false, true>(width, height, rptr, wptr); + break; + case FORMAT_RGBA8 | (FORMAT_R8 << 8): + _convert<3, true, 1, false, false, false>(width, height, rptr, wptr); + break; + case FORMAT_RGBA8 | (FORMAT_RG8 << 8): + _convert<3, true, 2, false, false, false>(width, height, rptr, wptr); + break; + case FORMAT_RGBA8 | (FORMAT_RGB8 << 8): + _convert<3, true, 3, false, false, false>(width, height, rptr, wptr); + break; } bool gen_mipmaps = mipmaps; @@ -908,7 +991,7 @@ void Image::resize(int p_width, int p_height, Interpolation p_interpolation) { if (p_width == width && p_height == height) return; - Image dst(p_width, p_height, 0, format); + Image dst(p_width, p_height, false, format); // Setup mipmap-aware scaling Image dst2; @@ -928,7 +1011,7 @@ void Image::resize(int p_width, int p_height, Interpolation p_interpolation) { } bool interpolate_mipmaps = mipmap_aware && mip1 != mip2; if (interpolate_mipmaps) { - dst2.create(p_width, p_height, 0, format); + dst2.create(p_width, p_height, false, format); } bool had_mipmaps = mipmaps; @@ -949,25 +1032,49 @@ void Image::resize(int p_width, int p_height, Interpolation p_interpolation) { if (format >= FORMAT_L8 && format <= FORMAT_RGBA8) { switch (get_format_pixel_size(format)) { - case 1: _scale_nearest<1, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 2: _scale_nearest<2, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 3: _scale_nearest<3, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 4: _scale_nearest<4, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; + case 1: + _scale_nearest<1, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 2: + _scale_nearest<2, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 3: + _scale_nearest<3, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 4: + _scale_nearest<4, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; } } else if (format >= FORMAT_RF && format <= FORMAT_RGBAF) { switch (get_format_pixel_size(format)) { - case 4: _scale_nearest<1, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 8: _scale_nearest<2, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 12: _scale_nearest<3, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 16: _scale_nearest<4, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; + case 4: + _scale_nearest<1, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 8: + _scale_nearest<2, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 12: + _scale_nearest<3, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 16: + _scale_nearest<4, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; } } else if (format >= FORMAT_RH && format <= FORMAT_RGBAH) { switch (get_format_pixel_size(format)) { - case 2: _scale_nearest<1, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 4: _scale_nearest<2, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 6: _scale_nearest<3, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 8: _scale_nearest<4, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; + case 2: + _scale_nearest<1, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 4: + _scale_nearest<2, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 6: + _scale_nearest<3, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 8: + _scale_nearest<4, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; } } @@ -1013,24 +1120,48 @@ void Image::resize(int p_width, int p_height, Interpolation p_interpolation) { if (format >= FORMAT_L8 && format <= FORMAT_RGBA8) { switch (get_format_pixel_size(format)) { - case 1: _scale_bilinear<1, uint8_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; - case 2: _scale_bilinear<2, uint8_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; - case 3: _scale_bilinear<3, uint8_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; - case 4: _scale_bilinear<4, uint8_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; + case 1: + _scale_bilinear<1, uint8_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; + case 2: + _scale_bilinear<2, uint8_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; + case 3: + _scale_bilinear<3, uint8_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; + case 4: + _scale_bilinear<4, uint8_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; } } else if (format >= FORMAT_RF && format <= FORMAT_RGBAF) { switch (get_format_pixel_size(format)) { - case 4: _scale_bilinear<1, float>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; - case 8: _scale_bilinear<2, float>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; - case 12: _scale_bilinear<3, float>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; - case 16: _scale_bilinear<4, float>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; + case 4: + _scale_bilinear<1, float>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; + case 8: + _scale_bilinear<2, float>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; + case 12: + _scale_bilinear<3, float>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; + case 16: + _scale_bilinear<4, float>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; } } else if (format >= FORMAT_RH && format <= FORMAT_RGBAH) { switch (get_format_pixel_size(format)) { - case 2: _scale_bilinear<1, uint16_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; - case 4: _scale_bilinear<2, uint16_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; - case 6: _scale_bilinear<3, uint16_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; - case 8: _scale_bilinear<4, uint16_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); break; + case 2: + _scale_bilinear<1, uint16_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; + case 4: + _scale_bilinear<2, uint16_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; + case 6: + _scale_bilinear<3, uint16_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; + case 8: + _scale_bilinear<4, uint16_t>(src_ptr, w_ptr, src_width, src_height, p_width, p_height); + break; } } } @@ -1046,24 +1177,48 @@ void Image::resize(int p_width, int p_height, Interpolation p_interpolation) { if (format >= FORMAT_L8 && format <= FORMAT_RGBA8) { switch (get_format_pixel_size(format)) { - case 1: _scale_cubic<1, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 2: _scale_cubic<2, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 3: _scale_cubic<3, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 4: _scale_cubic<4, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; + case 1: + _scale_cubic<1, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 2: + _scale_cubic<2, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 3: + _scale_cubic<3, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 4: + _scale_cubic<4, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; } } else if (format >= FORMAT_RF && format <= FORMAT_RGBAF) { switch (get_format_pixel_size(format)) { - case 4: _scale_cubic<1, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 8: _scale_cubic<2, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 12: _scale_cubic<3, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 16: _scale_cubic<4, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; + case 4: + _scale_cubic<1, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 8: + _scale_cubic<2, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 12: + _scale_cubic<3, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 16: + _scale_cubic<4, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; } } else if (format >= FORMAT_RH && format <= FORMAT_RGBAH) { switch (get_format_pixel_size(format)) { - case 2: _scale_cubic<1, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 4: _scale_cubic<2, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 6: _scale_cubic<3, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 8: _scale_cubic<4, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; + case 2: + _scale_cubic<1, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 4: + _scale_cubic<2, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 6: + _scale_cubic<3, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 8: + _scale_cubic<4, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; } } } break; @@ -1071,24 +1226,48 @@ void Image::resize(int p_width, int p_height, Interpolation p_interpolation) { if (format >= FORMAT_L8 && format <= FORMAT_RGBA8) { switch (get_format_pixel_size(format)) { - case 1: _scale_lanczos<1, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 2: _scale_lanczos<2, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 3: _scale_lanczos<3, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 4: _scale_lanczos<4, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; + case 1: + _scale_lanczos<1, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 2: + _scale_lanczos<2, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 3: + _scale_lanczos<3, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 4: + _scale_lanczos<4, uint8_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; } } else if (format >= FORMAT_RF && format <= FORMAT_RGBAF) { switch (get_format_pixel_size(format)) { - case 4: _scale_lanczos<1, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 8: _scale_lanczos<2, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 12: _scale_lanczos<3, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 16: _scale_lanczos<4, float>(r_ptr, w_ptr, width, height, p_width, p_height); break; + case 4: + _scale_lanczos<1, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 8: + _scale_lanczos<2, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 12: + _scale_lanczos<3, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 16: + _scale_lanczos<4, float>(r_ptr, w_ptr, width, height, p_width, p_height); + break; } } else if (format >= FORMAT_RH && format <= FORMAT_RGBAH) { switch (get_format_pixel_size(format)) { - case 2: _scale_lanczos<1, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 4: _scale_lanczos<2, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 6: _scale_lanczos<3, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; - case 8: _scale_lanczos<4, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); break; + case 2: + _scale_lanczos<1, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 4: + _scale_lanczos<2, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 6: + _scale_lanczos<3, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; + case 8: + _scale_lanczos<4, uint16_t>(r_ptr, w_ptr, width, height, p_width, p_height); + break; } } } break; @@ -1125,7 +1304,7 @@ void Image::crop_from_point(int p_x, int p_y, int p_width, int p_height) { uint8_t pdata[16]; //largest is 16 uint32_t pixel_size = get_format_pixel_size(format); - Image dst(p_width, p_height, 0, format); + Image dst(p_width, p_height, false, format); { const uint8_t *r = data.ptr(); @@ -1403,23 +1582,51 @@ void Image::shrink_x2() { switch (format) { case FORMAT_L8: - case FORMAT_R8: _generate_po2_mipmap<uint8_t, 1, false, Image::average_4_uint8, Image::renormalize_uint8>(r, w, width, height); break; - case FORMAT_LA8: _generate_po2_mipmap<uint8_t, 2, false, Image::average_4_uint8, Image::renormalize_uint8>(r, w, width, height); break; - case FORMAT_RG8: _generate_po2_mipmap<uint8_t, 2, false, Image::average_4_uint8, Image::renormalize_uint8>(r, w, width, height); break; - case FORMAT_RGB8: _generate_po2_mipmap<uint8_t, 3, false, Image::average_4_uint8, Image::renormalize_uint8>(r, w, width, height); break; - case FORMAT_RGBA8: _generate_po2_mipmap<uint8_t, 4, false, Image::average_4_uint8, Image::renormalize_uint8>(r, w, width, height); break; - - case FORMAT_RF: _generate_po2_mipmap<float, 1, false, Image::average_4_float, Image::renormalize_float>(reinterpret_cast<const float *>(r), reinterpret_cast<float *>(w), width, height); break; - case FORMAT_RGF: _generate_po2_mipmap<float, 2, false, Image::average_4_float, Image::renormalize_float>(reinterpret_cast<const float *>(r), reinterpret_cast<float *>(w), width, height); break; - case FORMAT_RGBF: _generate_po2_mipmap<float, 3, false, Image::average_4_float, Image::renormalize_float>(reinterpret_cast<const float *>(r), reinterpret_cast<float *>(w), width, height); break; - case FORMAT_RGBAF: _generate_po2_mipmap<float, 4, false, Image::average_4_float, Image::renormalize_float>(reinterpret_cast<const float *>(r), reinterpret_cast<float *>(w), width, height); break; - - case FORMAT_RH: _generate_po2_mipmap<uint16_t, 1, false, Image::average_4_half, Image::renormalize_half>(reinterpret_cast<const uint16_t *>(r), reinterpret_cast<uint16_t *>(w), width, height); break; - case FORMAT_RGH: _generate_po2_mipmap<uint16_t, 2, false, Image::average_4_half, Image::renormalize_half>(reinterpret_cast<const uint16_t *>(r), reinterpret_cast<uint16_t *>(w), width, height); break; - case FORMAT_RGBH: _generate_po2_mipmap<uint16_t, 3, false, Image::average_4_half, Image::renormalize_half>(reinterpret_cast<const uint16_t *>(r), reinterpret_cast<uint16_t *>(w), width, height); break; - case FORMAT_RGBAH: _generate_po2_mipmap<uint16_t, 4, false, Image::average_4_half, Image::renormalize_half>(reinterpret_cast<const uint16_t *>(r), reinterpret_cast<uint16_t *>(w), width, height); break; - - case FORMAT_RGBE9995: _generate_po2_mipmap<uint32_t, 1, false, Image::average_4_rgbe9995, Image::renormalize_rgbe9995>(reinterpret_cast<const uint32_t *>(r), reinterpret_cast<uint32_t *>(w), width, height); break; + case FORMAT_R8: + _generate_po2_mipmap<uint8_t, 1, false, Image::average_4_uint8, Image::renormalize_uint8>(r, w, width, height); + break; + case FORMAT_LA8: + _generate_po2_mipmap<uint8_t, 2, false, Image::average_4_uint8, Image::renormalize_uint8>(r, w, width, height); + break; + case FORMAT_RG8: + _generate_po2_mipmap<uint8_t, 2, false, Image::average_4_uint8, Image::renormalize_uint8>(r, w, width, height); + break; + case FORMAT_RGB8: + _generate_po2_mipmap<uint8_t, 3, false, Image::average_4_uint8, Image::renormalize_uint8>(r, w, width, height); + break; + case FORMAT_RGBA8: + _generate_po2_mipmap<uint8_t, 4, false, Image::average_4_uint8, Image::renormalize_uint8>(r, w, width, height); + break; + + case FORMAT_RF: + _generate_po2_mipmap<float, 1, false, Image::average_4_float, Image::renormalize_float>(reinterpret_cast<const float *>(r), reinterpret_cast<float *>(w), width, height); + break; + case FORMAT_RGF: + _generate_po2_mipmap<float, 2, false, Image::average_4_float, Image::renormalize_float>(reinterpret_cast<const float *>(r), reinterpret_cast<float *>(w), width, height); + break; + case FORMAT_RGBF: + _generate_po2_mipmap<float, 3, false, Image::average_4_float, Image::renormalize_float>(reinterpret_cast<const float *>(r), reinterpret_cast<float *>(w), width, height); + break; + case FORMAT_RGBAF: + _generate_po2_mipmap<float, 4, false, Image::average_4_float, Image::renormalize_float>(reinterpret_cast<const float *>(r), reinterpret_cast<float *>(w), width, height); + break; + + case FORMAT_RH: + _generate_po2_mipmap<uint16_t, 1, false, Image::average_4_half, Image::renormalize_half>(reinterpret_cast<const uint16_t *>(r), reinterpret_cast<uint16_t *>(w), width, height); + break; + case FORMAT_RGH: + _generate_po2_mipmap<uint16_t, 2, false, Image::average_4_half, Image::renormalize_half>(reinterpret_cast<const uint16_t *>(r), reinterpret_cast<uint16_t *>(w), width, height); + break; + case FORMAT_RGBH: + _generate_po2_mipmap<uint16_t, 3, false, Image::average_4_half, Image::renormalize_half>(reinterpret_cast<const uint16_t *>(r), reinterpret_cast<uint16_t *>(w), width, height); + break; + case FORMAT_RGBAH: + _generate_po2_mipmap<uint16_t, 4, false, Image::average_4_half, Image::renormalize_half>(reinterpret_cast<const uint16_t *>(r), reinterpret_cast<uint16_t *>(w), width, height); + break; + + case FORMAT_RGBE9995: + _generate_po2_mipmap<uint32_t, 1, false, Image::average_4_rgbe9995, Image::renormalize_rgbe9995>(reinterpret_cast<const uint32_t *>(r), reinterpret_cast<uint32_t *>(w), width, height); + break; default: { } } @@ -1485,9 +1692,13 @@ Error Image::generate_mipmaps(bool p_renormalize) { switch (format) { case FORMAT_L8: - case FORMAT_R8: _generate_po2_mipmap<uint8_t, 1, false, Image::average_4_uint8, Image::renormalize_uint8>(&wp[prev_ofs], &wp[ofs], prev_w, prev_h); break; + case FORMAT_R8: + _generate_po2_mipmap<uint8_t, 1, false, Image::average_4_uint8, Image::renormalize_uint8>(&wp[prev_ofs], &wp[ofs], prev_w, prev_h); + break; case FORMAT_LA8: - case FORMAT_RG8: _generate_po2_mipmap<uint8_t, 2, false, Image::average_4_uint8, Image::renormalize_uint8>(&wp[prev_ofs], &wp[ofs], prev_w, prev_h); break; + case FORMAT_RG8: + _generate_po2_mipmap<uint8_t, 2, false, Image::average_4_uint8, Image::renormalize_uint8>(&wp[prev_ofs], &wp[ofs], prev_w, prev_h); + break; case FORMAT_RGB8: if (p_renormalize) _generate_po2_mipmap<uint8_t, 3, true, Image::average_4_uint8, Image::renormalize_uint8>(&wp[prev_ofs], &wp[ofs], prev_w, prev_h); @@ -1911,12 +2122,24 @@ void Image::create(const char **p_xpm) { break; switch (i) { - case 0: col_r = v << 4; break; - case 1: col_r |= v; break; - case 2: col_g = v << 4; break; - case 3: col_g |= v; break; - case 4: col_b = v << 4; break; - case 5: col_b |= v; break; + case 0: + col_r = v << 4; + break; + case 1: + col_r |= v; + break; + case 2: + col_g = v << 4; + break; + case 3: + col_g |= v; + break; + case 4: + col_b = v << 4; + break; + case 5: + col_b |= v; + break; }; } @@ -1934,7 +2157,7 @@ void Image::create(const char **p_xpm) { if (line == colormap_size) { status = READING_PIXELS; - create(size_width, size_height, 0, has_alpha ? FORMAT_RGBA8 : FORMAT_RGB8); + create(size_width, size_height, false, has_alpha ? FORMAT_RGBA8 : FORMAT_RGB8); w = data.ptrw(); pixel_size = has_alpha ? 4 : 3; } @@ -2901,12 +3124,24 @@ Image::UsedChannels Image::detect_used_channels(CompressSource p_source) { void Image::optimize_channels() { switch (detect_used_channels()) { - case USED_CHANNELS_L: convert(FORMAT_L8); break; - case USED_CHANNELS_LA: convert(FORMAT_LA8); break; - case USED_CHANNELS_R: convert(FORMAT_R8); break; - case USED_CHANNELS_RG: convert(FORMAT_RG8); break; - case USED_CHANNELS_RGB: convert(FORMAT_RGB8); break; - case USED_CHANNELS_RGBA: convert(FORMAT_RGBA8); break; + case USED_CHANNELS_L: + convert(FORMAT_L8); + break; + case USED_CHANNELS_LA: + convert(FORMAT_LA8); + break; + case USED_CHANNELS_R: + convert(FORMAT_R8); + break; + case USED_CHANNELS_RG: + convert(FORMAT_RG8); + break; + case USED_CHANNELS_RGB: + convert(FORMAT_RGB8); + break; + case USED_CHANNELS_RGBA: + convert(FORMAT_RGBA8); + break; } } @@ -3094,7 +3329,7 @@ Ref<Image> Image::rgbe_to_srgb() { Ref<Image> new_image; new_image.instance(); - new_image->create(width, height, 0, Image::FORMAT_RGB8); + new_image->create(width, height, false, Image::FORMAT_RGB8); for (int row = 0; row < height; row++) { for (int col = 0; col < width; col++) { @@ -3152,11 +3387,13 @@ void Image::bumpmap_to_normalmap(float bump_scale) { for (int ty = 0; ty < height; ty++) { int py = ty + 1; - if (py >= height) py -= height; + if (py >= height) + py -= height; for (int tx = 0; tx < width; tx++) { int px = tx + 1; - if (px >= width) px -= width; + if (px >= width) + px -= width; float here = read_ptr[ty * width + tx]; float to_right = read_ptr[ty * width + px]; float above = read_ptr[py * width + tx]; @@ -3431,13 +3668,6 @@ Ref<Resource> Image::duplicate(bool p_subresources) const { return copy; } -Image::Image() { - - width = 0; - height = 0; - mipmaps = false; - format = FORMAT_L8; -} - -Image::~Image() { +void Image::set_as_black() { + zeromem(data.ptrw(), data.size()); } diff --git a/core/image.h b/core/image.h index 5bd73fa677..dbdfaa917b 100644 --- a/core/image.h +++ b/core/image.h @@ -33,7 +33,6 @@ #include "core/color.h" #include "core/math/rect2.h" - #include "core/resource.h" /** @@ -172,10 +171,11 @@ private: create(p_width, p_height, p_use_mipmaps, p_format, p_data); } - Format format; + Format format = FORMAT_L8; Vector<uint8_t> data; - int width, height; - bool mipmaps; + int width = 0; + int height = 0; + bool mipmaps = false; void _copy_internals_from(const Image &p_image) { format = p_image.format; @@ -286,7 +286,7 @@ public: /** * create an empty image */ - Image(); + Image() {} /** * create an empty image of a specific size and format */ @@ -296,6 +296,8 @@ public: */ Image(int p_width, int p_height, bool p_mipmaps, Format p_format, const Vector<uint8_t> &p_data); + ~Image() {} + enum AlphaMode { ALPHA_NONE, ALPHA_BIT, @@ -376,6 +378,8 @@ public: void set_pixelv(const Point2 &p_dst, const Color &p_color); void set_pixel(int p_x, int p_y, const Color &p_color); + void set_as_black(); + void copy_internals_from(const Ref<Image> &p_image) { ERR_FAIL_COND_MSG(p_image.is_null(), "It's not a reference to a valid Image object."); format = p_image->format; @@ -384,8 +388,6 @@ public: mipmaps = p_image->mipmaps; data = p_image->data; } - - ~Image(); }; VARIANT_ENUM_CAST(Image::Format) diff --git a/core/input/godotcontrollerdb.txt b/core/input/godotcontrollerdb.txt index adc3649d64..51ddda1e4e 100644 --- a/core/input/godotcontrollerdb.txt +++ b/core/input/godotcontrollerdb.txt @@ -4,6 +4,9 @@ # Windows __XINPUT_DEVICE__,XInput Gamepad,a:b12,b:b13,x:b14,y:b15,start:b4,back:b5,leftstick:b6,rightstick:b7,leftshoulder:b8,rightshoulder:b9,dpup:b0,dpdown:b1,dpleft:b2,dpright:b3,leftx:a0,lefty:a1,rightx:a2,righty:a3,lefttrigger:a4,righttrigger:a5,platform:Windows, +# Android +Default Android Gamepad,Default Controller,leftx:a0,lefty:a1,dpdown:h0.4,rightstick:b8,rightshoulder:b10,rightx:a2,start:b6,righty:a3,dpleft:h0.8,lefttrigger:a4,x:b2,dpup:h0.1,back:b4,leftstick:b7,leftshoulder:b9,y:b3,a:b0,dpright:h0.2,righttrigger:a5,b:b1,platform:Android, + # Javascript Default HTML5 Gamepad, Default Mapping,leftx:a0,lefty:a1,dpdown:b13,rightstick:b11,rightshoulder:b5,rightx:a2,start:b9,righty:a3,dpleft:b14,lefttrigger:a6,x:b2,dpup:b12,back:b8,leftstick:b10,leftshoulder:b4,y:b3,a:b0,dpright:b15,righttrigger:a7,b:b1,platform:Javascript, c2a94d6963726f736f66742058626f78,Wireless X360 Controller,leftx:a0,lefty:a1,dpdown:b14,rightstick:b10,rightshoulder:b5,rightx:a3,start:b7,righty:a4,dpleft:b11,lefttrigger:a2,x:b2,dpup:b13,back:b6,leftstick:b9,leftshoulder:b4,y:b3,a:b0,dpright:b12,righttrigger:a5,b:b1,platform:Javascript, diff --git a/core/input/input.cpp b/core/input/input.cpp index 357e4c06c1..38a71994d8 100644 --- a/core/input/input.cpp +++ b/core/input/input.cpp @@ -39,6 +39,87 @@ #include "editor/editor_settings.h" #endif +static const char *_joy_buttons[JOY_SDL_BUTTONS + 1] = { + "a", + "b", + "x", + "y", + "back", + "guide", + "start", + "leftstick", + "rightstick", + "leftshoulder", + "rightshoulder", + "dpup", + "dpdown", + "dpleft", + "dpright", + nullptr +}; + +static const char *_joy_button_names[JOY_BUTTON_MAX] = { + "Face Bottom", + "Face Right", + "Face Left", + "Face Top", + "Select", + "Guide", + "Start", + "Left Stick", + "Right Stick", + "Left Shoulder", + "Right Shoulder", + "D-Pad Up", + "D-Pad Down", + "D-Pad Left", + "D-Pad Right", + "Button 15", + "Button 16", + "Button 17", + "Button 18", + "Button 19", + "Button 20", + "Button 21", + "Button 22", + "Button 23", + "Button 24", + "Button 25", + "Button 26", + "Button 27", + "Button 28", + "Button 29", + "Button 30", + "Button 31", + "Button 32", + "Button 33", + "Button 34", + "Button 35" +}; + +static const char *_joy_axes[JOY_SDL_AXES + 1] = { + "leftx", + "lefty", + "rightx", + "righty", + "lefttrigger", + "righttrigger", + nullptr +}; + +static const char *_joy_axis_names[JOY_AXIS_MAX] = { + "Left Stick X", + "Left Stick Y", + "Right Stick X", + "Right Stick Y", + "Left Trigger", + "Right Trigger", + "Joystick 3 Stick X", + "Joystick 3 Stick Y", + "Joystick 4 Stick X", + "Joystick 4 Stick Y" +}; + Input *Input::singleton = nullptr; void (*Input::set_mouse_mode_func)(Input::MouseMode) = nullptr; @@ -336,13 +417,12 @@ void Input::joy_connection_changed(int p_idx, bool p_connected, String p_name, S } else { js.connected = false; for (int i = 0; i < JOY_BUTTON_MAX; i++) { - - if (i < JOY_AXIS_MAX) - set_joy_axis(p_idx, i, 0.0f); - int c = _combine_device(i, p_idx); joy_buttons_pressed.erase(c); - }; + } + for (int i = 0; i < JOY_AXIS_MAX; i++) { + set_joy_axis(p_idx, i, 0.0f); + } }; joy_names[p_idx] = js; @@ -831,21 +911,9 @@ void Input::joy_button(int p_device, int p_button, bool p_pressed) { return; } - const Map<int, JoyEvent>::Element *el = map_db[joy.mapping].buttons.find(p_button); - if (!el) { - //don't process un-mapped events for now, it could mess things up badly for devices with additional buttons/axis - //return _button_event(p_last_id, p_device, p_button, p_pressed); - return; - } + JoyEvent map = _get_mapped_button_event(map_db[joy.mapping], p_button); - JoyEvent map = el->get(); if (map.type == TYPE_BUTTON) { - //fake additional axis event for triggers - if (map.index == JOY_L2 || map.index == JOY_R2) { - float value = p_pressed ? 1.0f : 0.0f; - int axis = map.index == JOY_L2 ? JOY_ANALOG_L2 : JOY_ANALOG_R2; - _axis_event(p_device, axis, value); - } _button_event(p_device, map.index, p_pressed); return; } @@ -901,32 +969,20 @@ void Input::joy_axis(int p_device, int p_axis, const JoyAxis &p_value) { return; }; - const Map<int, JoyEvent>::Element *el = map_db[joy.mapping].axis.find(p_axis); - if (!el) { - //return _axis_event(p_last_id, p_device, p_axis, p_value); - return; - }; - - JoyEvent map = el->get(); + JoyEvent map = _get_mapped_axis_event(map_db[joy.mapping], p_axis, p_value); if (map.type == TYPE_BUTTON) { - //send axis event for triggers - if (map.index == JOY_L2 || map.index == JOY_R2) { - float value = p_value.min == 0 ? p_value.value : 0.5f + p_value.value / 2.0f; - int axis = map.index == JOY_L2 ? JOY_ANALOG_L2 : JOY_ANALOG_R2; - _axis_event(p_device, axis, value); - } - if (map.index == JOY_DPAD_UP || map.index == JOY_DPAD_DOWN) { + if (map.index == JOY_BUTTON_DPAD_UP || map.index == JOY_BUTTON_DPAD_DOWN) { bool pressed = p_value.value != 0.0f; - int button = p_value.value < 0 ? JOY_DPAD_UP : JOY_DPAD_DOWN; + int button = p_value.value < 0 ? JOY_BUTTON_DPAD_UP : JOY_BUTTON_DPAD_DOWN; if (!pressed) { - if (joy_buttons_pressed.has(_combine_device(JOY_DPAD_UP, p_device))) { - _button_event(p_device, JOY_DPAD_UP, false); + if (joy_buttons_pressed.has(_combine_device(JOY_BUTTON_DPAD_UP, p_device))) { + _button_event(p_device, JOY_BUTTON_DPAD_UP, false); } - if (joy_buttons_pressed.has(_combine_device(JOY_DPAD_DOWN, p_device))) { - _button_event(p_device, JOY_DPAD_DOWN, false); + if (joy_buttons_pressed.has(_combine_device(JOY_BUTTON_DPAD_DOWN, p_device))) { + _button_event(p_device, JOY_BUTTON_DPAD_DOWN, false); } } if (pressed == joy_buttons_pressed.has(_combine_device(button, p_device))) { @@ -935,16 +991,17 @@ void Input::joy_axis(int p_device, int p_axis, const JoyAxis &p_value) { _button_event(p_device, button, true); return; } - if (map.index == JOY_DPAD_LEFT || map.index == JOY_DPAD_RIGHT) { + + if (map.index == JOY_BUTTON_DPAD_LEFT || map.index == JOY_BUTTON_DPAD_RIGHT) { bool pressed = p_value.value != 0.0f; - int button = p_value.value < 0 ? JOY_DPAD_LEFT : JOY_DPAD_RIGHT; + int button = p_value.value < 0 ? JOY_BUTTON_DPAD_LEFT : JOY_BUTTON_DPAD_RIGHT; if (!pressed) { - if (joy_buttons_pressed.has(_combine_device(JOY_DPAD_LEFT, p_device))) { - _button_event(p_device, JOY_DPAD_LEFT, false); + if (joy_buttons_pressed.has(_combine_device(JOY_BUTTON_DPAD_LEFT, p_device))) { + _button_event(p_device, JOY_BUTTON_DPAD_LEFT, false); } - if (joy_buttons_pressed.has(_combine_device(JOY_DPAD_RIGHT, p_device))) { - _button_event(p_device, JOY_DPAD_RIGHT, false); + if (joy_buttons_pressed.has(_combine_device(JOY_BUTTON_DPAD_RIGHT, p_device))) { + _button_event(p_device, JOY_BUTTON_DPAD_RIGHT, false); } } if (pressed == joy_buttons_pressed.has(_combine_device(button, p_device))) { @@ -953,19 +1010,21 @@ void Input::joy_axis(int p_device, int p_axis, const JoyAxis &p_value) { _button_event(p_device, button, true); return; } + float deadzone = p_value.min == 0 ? 0.5f : 0.0f; bool pressed = p_value.value > deadzone; if (pressed == joy_buttons_pressed.has(_combine_device(map.index, p_device))) { // button already pressed or released, this is an axis bounce value return; } + _button_event(p_device, map.index, pressed); return; } if (map.type == TYPE_AXIS) { - _axis_event(p_device, map.index, val); + _axis_event(p_device, map.index, map.value); return; } //printf("invalid mapping\n"); @@ -976,12 +1035,26 @@ void Input::joy_hat(int p_device, int p_val) { _THREAD_SAFE_METHOD_; const Joypad &joy = joy_names[p_device]; - const JoyEvent *map; + JoyEvent map[HAT_MAX]; - if (joy.mapping == -1) { - map = hat_map_default; - } else { - map = map_db[joy.mapping].hat; + map[HAT_UP].type = TYPE_BUTTON; + map[HAT_UP].index = JOY_BUTTON_DPAD_UP; + map[HAT_UP].value = 0; + + map[HAT_RIGHT].type = TYPE_BUTTON; + map[HAT_RIGHT].index = JOY_BUTTON_DPAD_RIGHT; + map[HAT_RIGHT].value = 0; + + map[HAT_DOWN].type = TYPE_BUTTON; + map[HAT_DOWN].index = JOY_BUTTON_DPAD_DOWN; + map[HAT_DOWN].value = 0; + + map[HAT_LEFT].type = TYPE_BUTTON; + map[HAT_LEFT].index = JOY_BUTTON_DPAD_LEFT; + map[HAT_LEFT].value = 0; + + if (joy.mapping != -1) { + _get_mapped_hat_events(map_db[joy.mapping], 0, map); }; int cur_val = joy_names[p_device].hat_current; @@ -1025,50 +1098,149 @@ void Input::_axis_event(int p_device, int p_axis, float p_value) { parse_input_event(ievent); }; -Input::JoyEvent Input::_find_to_event(String p_to) { +Input::JoyEvent Input::_get_mapped_button_event(const JoyDeviceMapping &mapping, int p_button) { + + JoyEvent event; + event.type = TYPE_MAX; + + for (int i = 0; i < mapping.bindings.size(); i++) { + const JoyBinding binding = mapping.bindings[i]; + if (binding.inputType == TYPE_BUTTON && binding.input.button == p_button) { + event.type = binding.outputType; + switch (binding.outputType) { + case TYPE_BUTTON: + event.index = binding.output.button; + return event; + case TYPE_AXIS: + event.index = binding.output.axis.axis; + return event; + default: + ERR_PRINT_ONCE("Joypad button mapping error."); + } + } + } + return event; +} + +Input::JoyEvent Input::_get_mapped_axis_event(const JoyDeviceMapping &mapping, int p_axis, const JoyAxis &p_value) { + + JoyEvent event; + event.type = TYPE_MAX; + + for (int i = 0; i < mapping.bindings.size(); i++) { + const JoyBinding binding = mapping.bindings[i]; + if (binding.inputType == TYPE_AXIS && binding.input.axis.axis == p_axis) { + float value = p_value.value; + if (binding.input.axis.invert) + value = -value; + if (binding.input.axis.range == FULL_AXIS || + (binding.input.axis.range == POSITIVE_HALF_AXIS && value > 0) || + (binding.input.axis.range == NEGATIVE_HALF_AXIS && value < 0)) { + event.type = binding.outputType; + switch (binding.outputType) { + case TYPE_BUTTON: + event.index = binding.output.button; + return event; + case TYPE_AXIS: + event.index = binding.output.axis.axis; + event.value = value; + if (binding.output.axis.range != binding.input.axis.range) { + float shifted_positive_value = 0; + switch (binding.input.axis.range) { + case POSITIVE_HALF_AXIS: + shifted_positive_value = value; + break; + case NEGATIVE_HALF_AXIS: + shifted_positive_value = value + 1; + break; + case FULL_AXIS: + shifted_positive_value = (value + 1) / 2; + break; + } + switch (binding.output.axis.range) { + case POSITIVE_HALF_AXIS: + event.value = shifted_positive_value; + break; + case NEGATIVE_HALF_AXIS: + event.value = shifted_positive_value - 1; + break; + case FULL_AXIS: + event.value = (shifted_positive_value * 2) - 1; + break; + } + } + return event; + default: + ERR_PRINT_ONCE("Joypad axis mapping error."); + } + } + } + } + return event; +} - // string names of the SDL buttons in the same order as input_event.h godot buttons - static const char *buttons[] = { "a", "b", "x", "y", "leftshoulder", "rightshoulder", "lefttrigger", "righttrigger", "leftstick", "rightstick", "back", "start", "dpup", "dpdown", "dpleft", "dpright", "guide", nullptr }; +void Input::_get_mapped_hat_events(const JoyDeviceMapping &mapping, int p_hat, JoyEvent r_events[]) { - static const char *axis[] = { "leftx", "lefty", "rightx", "righty", nullptr }; + for (int i = 0; i < mapping.bindings.size(); i++) { + const JoyBinding binding = mapping.bindings[i]; + if (binding.inputType == TYPE_HAT && binding.input.hat.hat == p_hat) { - JoyEvent ret; - ret.type = -1; - ret.index = 0; + int index; + switch (binding.input.hat.hat_mask) { + case HAT_MASK_UP: + index = 0; + break; + case HAT_MASK_RIGHT: + index = 1; + break; + case HAT_MASK_DOWN: + index = 2; + break; + case HAT_MASK_LEFT: + index = 3; + break; + default: + ERR_PRINT_ONCE("Joypad button mapping error."); + continue; + } - int i = 0; - while (buttons[i]) { + r_events[index].type = binding.outputType; + switch (binding.outputType) { + case TYPE_BUTTON: + r_events[index].index = binding.output.button; + break; + case TYPE_AXIS: + r_events[index].index = binding.output.axis.axis; + break; + default: + ERR_PRINT_ONCE("Joypad button mapping error."); + } + } + } +} - if (p_to == buttons[i]) { - ret.type = TYPE_BUTTON; - ret.index = i; - ret.value = 0; - return ret; - }; - ++i; - }; +JoyButtonList Input::_get_output_button(String output) { - i = 0; - while (axis[i]) { + for (int i = 0; _joy_buttons[i]; i++) { + if (output == _joy_buttons[i]) + return JoyButtonList(i); + } + return JoyButtonList::JOY_INVALID_BUTTON; +} - if (p_to == axis[i]) { - ret.type = TYPE_AXIS; - ret.index = i; - ret.value = 0; - return ret; - }; - ++i; - }; +JoyAxisList Input::_get_output_axis(String output) { - return ret; -}; + for (int i = 0; _joy_axes[i]; i++) { + if (output == _joy_axes[i]) + return JoyAxisList(i); + } + return JoyAxisList::JOY_INVALID_AXIS; +} void Input::parse_mapping(String p_mapping) { _THREAD_SAFE_METHOD_; JoyDeviceMapping mapping; - for (int i = 0; i < HAT_MAX; ++i) - mapping.hat[i].index = 1024 + i; Vector<String> entry = p_mapping.split(","); if (entry.size() < 2) { @@ -1087,45 +1259,79 @@ void Input::parse_mapping(String p_mapping) { if (entry[idx] == "") continue; - String from = entry[idx].get_slice(":", 1).replace(" ", ""); - String to = entry[idx].get_slice(":", 0).replace(" ", ""); + String output = entry[idx].get_slice(":", 0).replace(" ", ""); + String input = entry[idx].get_slice(":", 1).replace(" ", ""); + ERR_CONTINUE_MSG(output.length() < 1 || input.length() < 2, + String(entry[idx] + "\nInvalid device mapping entry: " + entry[idx])); - JoyEvent to_event = _find_to_event(to); - if (to_event.type == -1) + if (output == "platform") continue; - String etype = from.substr(0, 1); - if (etype == "a") { - - int aid = from.substr(1, from.length() - 1).to_int(); - mapping.axis[aid] = to_event; - - } else if (etype == "b") { + JoyAxisRange output_range = FULL_AXIS; + if (output[0] == '+' || output[0] == '-') { + ERR_CONTINUE_MSG(output.length() < 2, String(entry[idx] + "\nInvalid output: " + entry[idx])); + output = output.right(1); + if (output[0] == '+') + output_range = POSITIVE_HALF_AXIS; + else if (output[0] == '-') + output_range = NEGATIVE_HALF_AXIS; + } - int bid = from.substr(1, from.length() - 1).to_int(); - mapping.buttons[bid] = to_event; + JoyAxisRange input_range = FULL_AXIS; + if (input[0] == '+') { + input_range = POSITIVE_HALF_AXIS; + input = input.right(1); + } else if (input[0] == '-') { + input_range = NEGATIVE_HALF_AXIS; + input = input.right(1); + } + bool invert_axis = false; + if (input[input.length() - 1] == '~') + invert_axis = true; + + JoyButtonList output_button = _get_output_button(output); + JoyAxisList output_axis = _get_output_axis(output); + ERR_CONTINUE_MSG(output_button == JOY_INVALID_BUTTON && output_axis == JOY_INVALID_AXIS, + String(entry[idx] + "\nUnrecognised output string: " + output)); + ERR_CONTINUE_MSG(output_button != JOY_INVALID_BUTTON && output_axis != JOY_INVALID_AXIS, + String("BUG: Output string matched both button and axis: " + output)); + + JoyBinding binding; + if (output_button != JOY_INVALID_BUTTON) { + binding.outputType = TYPE_BUTTON; + binding.output.button = output_button; + } else if (output_axis != JOY_INVALID_AXIS) { + binding.outputType = TYPE_AXIS; + binding.output.axis.axis = output_axis; + binding.output.axis.range = output_range; + } - } else if (etype == "h") { + switch (input[0]) { + case 'b': + binding.inputType = TYPE_BUTTON; + binding.input.button = input.right(1).to_int(); + break; + case 'a': + binding.inputType = TYPE_AXIS; + binding.input.axis.axis = input.right(1).to_int(); + binding.input.axis.range = input_range; + binding.input.axis.invert = invert_axis; + break; + case 'h': + ERR_CONTINUE_MSG(input.length() != 4 || input[2] != '.', + String(entry[idx] + "\nInvalid hat input: " + input)); + binding.inputType = TYPE_HAT; + binding.input.hat.hat = input.substr(1, 1).to_int(); + binding.input.hat.hat_mask = static_cast<HatMask>(input.right(3).to_int()); + break; + default: + ERR_CONTINUE_MSG(true, String(entry[idx] + "\nUnrecognised input string: " + input)); + } - int hat_value = from.get_slice(".", 1).to_int(); - switch (hat_value) { - case 1: - mapping.hat[HAT_UP] = to_event; - break; - case 2: - mapping.hat[HAT_RIGHT] = to_event; - break; - case 4: - mapping.hat[HAT_DOWN] = to_event; - break; - case 8: - mapping.hat[HAT_LEFT] = to_event; - break; - }; - }; + mapping.bindings.push_back(binding); }; + map_db.push_back(mapping); - //printf("added mapping with uuid %ls\n", mapping.uid.c_str()); }; void Input::add_joy_mapping(String p_mapping, bool p_update_existing) { @@ -1187,50 +1393,18 @@ Array Input::get_connected_joypads() { return ret; } -static const char *_buttons[JOY_BUTTON_MAX] = { - "Face Button Bottom", - "Face Button Right", - "Face Button Left", - "Face Button Top", - "L", - "R", - "L2", - "R2", - "L3", - "R3", - "Select", - "Start", - "DPAD Up", - "DPAD Down", - "DPAD Left", - "DPAD Right" -}; - -static const char *_axes[JOY_AXIS_MAX] = { - "Left Stick X", - "Left Stick Y", - "Right Stick X", - "Right Stick Y", - "", - "", - "L2", - "R2", - "", - "" -}; - String Input::get_joy_button_string(int p_button) { - ERR_FAIL_INDEX_V(p_button, JOY_BUTTON_MAX, ""); - return _buttons[p_button]; + ERR_FAIL_INDEX_V(p_button, JOY_BUTTON_MAX, "Invalid button"); + return _joy_button_names[p_button]; } int Input::get_joy_button_index_from_string(String p_button) { for (int i = 0; i < JOY_BUTTON_MAX; i++) { - if (p_button == _buttons[i]) { + if (p_button == _joy_button_names[i]) { return i; } } - ERR_FAIL_V(-1); + ERR_FAIL_V(JOY_INVALID_BUTTON); } int Input::get_unused_joy_id() { @@ -1243,48 +1417,22 @@ int Input::get_unused_joy_id() { } String Input::get_joy_axis_string(int p_axis) { - ERR_FAIL_INDEX_V(p_axis, JOY_AXIS_MAX, ""); - return _axes[p_axis]; + ERR_FAIL_INDEX_V(p_axis, JOY_AXIS_MAX, "Invalid axis"); + return _joy_axis_names[p_axis]; } int Input::get_joy_axis_index_from_string(String p_axis) { for (int i = 0; i < JOY_AXIS_MAX; i++) { - if (p_axis == _axes[i]) { + if (p_axis == _joy_axis_names[i]) { return i; } } - ERR_FAIL_V(-1); + ERR_FAIL_V(JOY_INVALID_AXIS); } Input::Input() { singleton = this; - use_accumulated_input = true; - mouse_button_mask = 0; - mouse_window = 0; - emulate_touch_from_mouse = false; - emulate_mouse_from_touch = false; - mouse_from_touch_index = -1; - event_dispatch_function = nullptr; - default_shape = CURSOR_ARROW; - - hat_map_default[HAT_UP].type = TYPE_BUTTON; - hat_map_default[HAT_UP].index = JOY_DPAD_UP; - hat_map_default[HAT_UP].value = 0; - - hat_map_default[HAT_RIGHT].type = TYPE_BUTTON; - hat_map_default[HAT_RIGHT].index = JOY_DPAD_RIGHT; - hat_map_default[HAT_RIGHT].value = 0; - - hat_map_default[HAT_DOWN].type = TYPE_BUTTON; - hat_map_default[HAT_DOWN].index = JOY_DPAD_DOWN; - hat_map_default[HAT_DOWN].value = 0; - - hat_map_default[HAT_LEFT].type = TYPE_BUTTON; - hat_map_default[HAT_LEFT].index = JOY_DPAD_LEFT; - hat_map_default[HAT_LEFT].value = 0; - - fallback_mapping = -1; // Parse default mappings. { diff --git a/core/input/input.h b/core/input/input.h index 2e136dbf02..f3150a8127 100644 --- a/core/input/input.h +++ b/core/input/input.h @@ -36,7 +36,6 @@ #include "core/os/thread_safe.h" class Input : public Object { - GDCLASS(Input, Object); _THREAD_SAFE_CLASS_ @@ -100,7 +99,7 @@ public: typedef void (*EventDispatchFunc)(const Ref<InputEvent> &p_event); private: - int mouse_button_mask; + int mouse_button_mask = 0; Set<int> keys_pressed; Set<int> joy_buttons_pressed; @@ -111,7 +110,7 @@ private: Vector3 magnetometer; Vector3 gyroscope; Vector2 mouse_pos; - int64_t mouse_window; + int64_t mouse_window = 0; struct Action { uint64_t physics_frame; @@ -122,10 +121,11 @@ private: Map<StringName, Action> action_state; - bool emulate_touch_from_mouse; - bool emulate_mouse_from_touch; + bool emulate_touch_from_mouse = false; + bool emulate_mouse_from_touch = false; + bool use_accumulated_input = false; - int mouse_from_touch_index; + int mouse_from_touch_index = -1; struct SpeedTrack { @@ -144,37 +144,21 @@ private: struct Joypad { StringName name; StringName uid; - bool connected; - bool last_buttons[JOY_BUTTON_MAX + 19]; //apparently SDL specifies 35 possible buttons on android - float last_axis[JOY_AXIS_MAX]; - float filter; - int last_hat; - int mapping; - int hat_current; - - Joypad() { - for (int i = 0; i < JOY_AXIS_MAX; i++) { - - last_axis[i] = 0.0f; - } - for (int i = 0; i < JOY_BUTTON_MAX + 19; i++) { - - last_buttons[i] = false; - } - connected = false; - last_hat = HAT_MASK_CENTER; - filter = 0.01f; - mapping = -1; - hat_current = 0; - } + bool connected = false; + bool last_buttons[JOY_BUTTON_MAX] = { false }; + float last_axis[JOY_AXIS_MAX] = { 0.0f }; + float filter = 0.01f; + int last_hat = HAT_MASK_CENTER; + int mapping = -1; + int hat_current = 0; }; SpeedTrack mouse_speed_track; Map<int, SpeedTrack> touch_speed_track; Map<int, Joypad> joy_names; - int fallback_mapping; + int fallback_mapping = -1; - CursorShape default_shape; + CursorShape default_shape = CURSOR_ARROW; enum JoyType { TYPE_BUTTON, @@ -183,26 +167,61 @@ private: TYPE_MAX, }; + enum JoyAxisRange { + NEGATIVE_HALF_AXIS = -1, + FULL_AXIS = 0, + POSITIVE_HALF_AXIS = 1 + }; + struct JoyEvent { int type; int index; - int value; + float value; }; - struct JoyDeviceMapping { + struct JoyBinding { + JoyType inputType; + union { + int button; + + struct { + int axis; + JoyAxisRange range; + bool invert; + } axis; + + struct { + int hat; + HatMask hat_mask; + } hat; + + } input; + JoyType outputType; + union { + JoyButtonList button; + + struct { + JoyAxisList axis; + JoyAxisRange range; + } axis; + + } output; + }; + + struct JoyDeviceMapping { String uid; String name; - Map<int, JoyEvent> buttons; - Map<int, JoyEvent> axis; - JoyEvent hat[HAT_MAX]; + Vector<JoyBinding> bindings; }; - JoyEvent hat_map_default[HAT_MAX]; - Vector<JoyDeviceMapping> map_db; - JoyEvent _find_to_event(String p_to); + JoyEvent _get_mapped_button_event(const JoyDeviceMapping &mapping, int p_button); + JoyEvent _get_mapped_axis_event(const JoyDeviceMapping &mapping, int p_axis, const JoyAxis &p_value); + void _get_mapped_hat_events(const JoyDeviceMapping &mapping, int p_hat, JoyEvent r_events[HAT_MAX]); + JoyButtonList _get_output_button(String output); + JoyAxisList _get_output_axis(String output); void _button_event(int p_device, int p_index, bool p_pressed); void _axis_event(int p_device, int p_axis, float p_value); float _handle_deadzone(int p_device, int p_axis, float p_value); @@ -210,7 +229,7 @@ private: void _parse_input_event_impl(const Ref<InputEvent> &p_event, bool p_is_emulated); List<Ref<InputEvent>> accumulated_events; - bool use_accumulated_input; + friend class DisplayServer; static void (*set_mouse_mode_func)(MouseMode); @@ -220,7 +239,7 @@ private: static CursorShape (*get_current_cursor_shape_func)(); static void (*set_custom_mouse_cursor_func)(const RES &, CursorShape, const Vector2 &); - EventDispatchFunc event_dispatch_function; + EventDispatchFunc event_dispatch_function = nullptr; protected: struct VibrationInfo { diff --git a/core/input/input_builders.py b/core/input/input_builders.py index 53b90f2073..748ec06133 100644 --- a/core/input/input_builders.py +++ b/core/input/input_builders.py @@ -42,18 +42,7 @@ def make_default_controller_mappings(target, source, env): src_path, current_platform, platform_mappings[current_platform][guid] ) ) - valid_mapping = True - for input_map in line_parts[2:]: - if "+" in input_map or "-" in input_map or "~" in input_map: - g.write( - "// WARNING - DISCARDED UNSUPPORTED MAPPING TYPE FROM DATABASE {}: {} {}\n".format( - src_path, current_platform, line - ) - ) - valid_mapping = False - break - if valid_mapping: - platform_mappings[current_platform][guid] = line + platform_mappings[current_platform][guid] = line platform_variables = { "Linux": "#if X11_ENABLED", diff --git a/core/input/input_event.cpp b/core/input/input_event.cpp index 4b8c104f39..9d3f8f9424 100644 --- a/core/input/input_event.cpp +++ b/core/input/input_event.cpp @@ -132,11 +132,7 @@ void InputEvent::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::INT, "device"), "set_device", "get_device"); } -InputEvent::InputEvent() { - - device = 0; -} -//////////////// +/////////////////////////////////// void InputEventFromWindow::_bind_methods() { @@ -152,11 +148,7 @@ int64_t InputEventFromWindow::get_window_id() const { return window_id; } -InputEventFromWindow::InputEventFromWindow() { - window_id = 0; -} - -////////////////// +/////////////////////////////////// void InputEventWithModifiers::set_shift(bool p_enabled) { @@ -236,15 +228,7 @@ void InputEventWithModifiers::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::BOOL, "command"), "set_command", "get_command"); } -InputEventWithModifiers::InputEventWithModifiers() { - - alt = false; - shift = false; - control = false; - meta = false; -} - -////////////////////////////////// +/////////////////////////////////// void InputEventKey::set_pressed(bool p_pressed) { @@ -411,16 +395,7 @@ void InputEventKey::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::BOOL, "echo"), "set_echo", "is_echo"); } -InputEventKey::InputEventKey() { - - pressed = false; - keycode = 0; - physical_keycode = 0; - unicode = 0; ///unicode - echo = false; -} - -//////////////////////////////////////// +/////////////////////////////////// void InputEventMouse::set_button_mask(int p_mask) { @@ -465,12 +440,7 @@ void InputEventMouse::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::VECTOR2, "global_position"), "set_global_position", "get_global_position"); } -InputEventMouse::InputEventMouse() { - - button_mask = 0; -} - -/////////////////////////////////////// +/////////////////////////////////// void InputEventMouseButton::set_factor(float p_factor) { @@ -608,15 +578,7 @@ void InputEventMouseButton::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::BOOL, "doubleclick"), "set_doubleclick", "is_doubleclick"); } -InputEventMouseButton::InputEventMouseButton() { - - factor = 1; - button_index = 0; - pressed = false; - doubleclick = false; -} - -//////////////////////////////////////////// +/////////////////////////////////// void InputEventMouseMotion::set_tilt(const Vector2 &p_tilt) { @@ -773,12 +735,7 @@ void InputEventMouseMotion::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::VECTOR2, "speed"), "set_speed", "get_speed"); } -InputEventMouseMotion::InputEventMouseMotion() { - - pressure = 0; -} - -//////////////////////////////////////// +/////////////////////////////////// void InputEventJoypadMotion::set_axis(int p_axis) { @@ -849,12 +806,7 @@ void InputEventJoypadMotion::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "axis_value"), "set_axis_value", "get_axis_value"); } -InputEventJoypadMotion::InputEventJoypadMotion() { - - axis = 0; - axis_value = 0; -} -///////////////////////////////// +/////////////////////////////////// void InputEventJoypadButton::set_button_index(int p_index) { @@ -931,14 +883,7 @@ void InputEventJoypadButton::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::BOOL, "pressed"), "set_pressed", "is_pressed"); } -InputEventJoypadButton::InputEventJoypadButton() { - - button_index = 0; - pressure = 0; - pressed = false; -} - -////////////////////////////////////////////// +/////////////////////////////////// void InputEventScreenTouch::set_index(int p_index) { @@ -1001,13 +946,7 @@ void InputEventScreenTouch::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::BOOL, "pressed"), "set_pressed", "is_pressed"); } -InputEventScreenTouch::InputEventScreenTouch() { - - index = 0; - pressed = false; -} - -///////////////////////////// +/////////////////////////////////// void InputEventScreenDrag::set_index(int p_index) { @@ -1088,11 +1027,7 @@ void InputEventScreenDrag::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::VECTOR2, "speed"), "set_speed", "get_speed"); } -InputEventScreenDrag::InputEventScreenDrag() { - - index = 0; -} -///////////////////////////// +/////////////////////////////////// void InputEventAction::set_action(const StringName &p_action) { @@ -1171,11 +1106,7 @@ void InputEventAction::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "strength", PROPERTY_HINT_RANGE, "0,1,0.01"), "set_strength", "get_strength"); } -InputEventAction::InputEventAction() { - pressed = false; - strength = 1.0f; -} -///////////////////////////// +/////////////////////////////////// void InputEventGesture::set_position(const Vector2 &p_pos) { @@ -1194,7 +1125,8 @@ Vector2 InputEventGesture::get_position() const { return pos; } -///////////////////////////// + +/////////////////////////////////// void InputEventMagnifyGesture::set_factor(real_t p_factor) { @@ -1235,11 +1167,7 @@ void InputEventMagnifyGesture::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "factor"), "set_factor", "get_factor"); } -InputEventMagnifyGesture::InputEventMagnifyGesture() { - - factor = 1.0; -} -///////////////////////////// +/////////////////////////////////// void InputEventPanGesture::set_delta(const Vector2 &p_delta) { @@ -1279,11 +1207,7 @@ void InputEventPanGesture::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::VECTOR2, "delta"), "set_delta", "get_delta"); } -InputEventPanGesture::InputEventPanGesture() { - - delta = Vector2(0, 0); -} -///////////////////////////// +/////////////////////////////////// void InputEventMIDI::set_channel(const int p_channel) { @@ -1390,15 +1314,3 @@ void InputEventMIDI::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::INT, "controller_number"), "set_controller_number", "get_controller_number"); ADD_PROPERTY(PropertyInfo(Variant::INT, "controller_value"), "set_controller_value", "get_controller_value"); } - -InputEventMIDI::InputEventMIDI() { - - channel = 0; - message = 0; - pitch = 0; - velocity = 0; - instrument = 0; - pressure = 0; - controller_number = 0; - controller_value = 0; -} diff --git a/core/input/input_event.h b/core/input/input_event.h index 8774b3b1db..18792076f5 100644 --- a/core/input/input_event.h +++ b/core/input/input_event.h @@ -59,98 +59,84 @@ enum ButtonList { BUTTON_MASK_XBUTTON2 = (1 << (BUTTON_XBUTTON2 - 1)) }; -enum JoystickList { - - JOY_BUTTON_0 = 0, - JOY_BUTTON_1 = 1, - JOY_BUTTON_2 = 2, - JOY_BUTTON_3 = 3, - JOY_BUTTON_4 = 4, - JOY_BUTTON_5 = 5, - JOY_BUTTON_6 = 6, - JOY_BUTTON_7 = 7, - JOY_BUTTON_8 = 8, - JOY_BUTTON_9 = 9, - JOY_BUTTON_10 = 10, - JOY_BUTTON_11 = 11, - JOY_BUTTON_12 = 12, - JOY_BUTTON_13 = 13, - JOY_BUTTON_14 = 14, - JOY_BUTTON_15 = 15, - JOY_BUTTON_MAX = 16, - - JOY_L = JOY_BUTTON_4, - JOY_R = JOY_BUTTON_5, - JOY_L2 = JOY_BUTTON_6, - JOY_R2 = JOY_BUTTON_7, - JOY_L3 = JOY_BUTTON_8, - JOY_R3 = JOY_BUTTON_9, - JOY_SELECT = JOY_BUTTON_10, - JOY_START = JOY_BUTTON_11, - JOY_DPAD_UP = JOY_BUTTON_12, - JOY_DPAD_DOWN = JOY_BUTTON_13, - JOY_DPAD_LEFT = JOY_BUTTON_14, - JOY_DPAD_RIGHT = JOY_BUTTON_15, - - JOY_SONY_CIRCLE = JOY_BUTTON_1, - JOY_SONY_X = JOY_BUTTON_0, - JOY_SONY_SQUARE = JOY_BUTTON_2, - JOY_SONY_TRIANGLE = JOY_BUTTON_3, - - JOY_XBOX_A = JOY_BUTTON_0, - JOY_XBOX_B = JOY_BUTTON_1, - JOY_XBOX_X = JOY_BUTTON_2, - JOY_XBOX_Y = JOY_BUTTON_3, - - JOY_DS_A = JOY_BUTTON_1, - JOY_DS_B = JOY_BUTTON_0, - JOY_DS_X = JOY_BUTTON_3, - JOY_DS_Y = JOY_BUTTON_2, - - JOY_WII_C = JOY_BUTTON_5, - JOY_WII_Z = JOY_BUTTON_6, - - JOY_WII_MINUS = JOY_BUTTON_10, - JOY_WII_PLUS = JOY_BUTTON_11, - - JOY_VR_GRIP = JOY_BUTTON_2, - JOY_VR_PAD = JOY_BUTTON_14, - JOY_VR_TRIGGER = JOY_BUTTON_15, - - JOY_OCULUS_AX = JOY_BUTTON_7, - JOY_OCULUS_BY = JOY_BUTTON_1, - JOY_OCULUS_MENU = JOY_BUTTON_3, - - JOY_OPENVR_MENU = JOY_BUTTON_1, - - // end of history - - JOY_AXIS_0 = 0, - JOY_AXIS_1 = 1, - JOY_AXIS_2 = 2, - JOY_AXIS_3 = 3, - JOY_AXIS_4 = 4, - JOY_AXIS_5 = 5, - JOY_AXIS_6 = 6, - JOY_AXIS_7 = 7, - JOY_AXIS_8 = 8, - JOY_AXIS_9 = 9, - JOY_AXIS_MAX = 10, - - JOY_ANALOG_LX = JOY_AXIS_0, - JOY_ANALOG_LY = JOY_AXIS_1, - - JOY_ANALOG_RX = JOY_AXIS_2, - JOY_ANALOG_RY = JOY_AXIS_3, - - JOY_ANALOG_L2 = JOY_AXIS_6, - JOY_ANALOG_R2 = JOY_AXIS_7, - - JOY_VR_ANALOG_TRIGGER = JOY_AXIS_2, - JOY_VR_ANALOG_GRIP = JOY_AXIS_4, - - JOY_OPENVR_TOUCHPADX = JOY_AXIS_0, - JOY_OPENVR_TOUCHPADY = JOY_AXIS_1, +enum JoyButtonList { + + JOY_INVALID_BUTTON = -1, + + // SDL Buttons + JOY_BUTTON_A = 0, + JOY_BUTTON_B = 1, + JOY_BUTTON_X = 2, + JOY_BUTTON_Y = 3, + JOY_BUTTON_BACK = 4, + JOY_BUTTON_GUIDE = 5, + JOY_BUTTON_START = 6, + JOY_BUTTON_LEFT_STICK = 7, + JOY_BUTTON_RIGHT_STICK = 8, + JOY_BUTTON_LEFT_SHOULDER = 9, + JOY_BUTTON_RIGHT_SHOULDER = 10, + JOY_BUTTON_DPAD_UP = 11, + JOY_BUTTON_DPAD_DOWN = 12, + JOY_BUTTON_DPAD_LEFT = 13, + JOY_BUTTON_DPAD_RIGHT = 14, + JOY_SDL_BUTTONS = 15, + + // Sony Buttons + JOY_SONY_X = JOY_BUTTON_A, + JOY_SONY_CROSS = JOY_BUTTON_A, + JOY_SONY_CIRCLE = JOY_BUTTON_B, + JOY_SONY_SQUARE = JOY_BUTTON_X, + JOY_SONY_TRIANGLE = JOY_BUTTON_Y, + JOY_SONY_SELECT = JOY_BUTTON_BACK, + JOY_SONY_START = JOY_BUTTON_START, + JOY_SONY_PS = JOY_BUTTON_GUIDE, + JOY_SONY_L1 = JOY_BUTTON_LEFT_SHOULDER, + JOY_SONY_R1 = JOY_BUTTON_RIGHT_SHOULDER, + JOY_SONY_L3 = JOY_BUTTON_LEFT_STICK, + JOY_SONY_R3 = JOY_BUTTON_RIGHT_STICK, + + // Xbox Buttons + JOY_XBOX_A = JOY_BUTTON_A, + JOY_XBOX_B = JOY_BUTTON_B, + JOY_XBOX_X = JOY_BUTTON_X, + JOY_XBOX_Y = JOY_BUTTON_Y, + JOY_XBOX_BACK = JOY_BUTTON_BACK, + JOY_XBOX_START = JOY_BUTTON_START, + JOY_XBOX_HOME = JOY_BUTTON_GUIDE, + JOY_XBOX_LS = JOY_BUTTON_LEFT_STICK, + JOY_XBOX_RS = JOY_BUTTON_RIGHT_STICK, + JOY_XBOX_LB = JOY_BUTTON_LEFT_SHOULDER, + JOY_XBOX_RB = JOY_BUTTON_RIGHT_SHOULDER, + + JOY_BUTTON_MAX = 36 // Apparently Android supports up to 36 buttons. +}; + +enum JoyAxisList { + + JOY_INVALID_AXIS = -1, + + // SDL Axes + JOY_AXIS_LEFT_X = 0, + JOY_AXIS_LEFT_Y = 1, + JOY_AXIS_RIGHT_X = 2, + JOY_AXIS_RIGHT_Y = 3, + JOY_AXIS_TRIGGER_LEFT = 4, + JOY_AXIS_TRIGGER_RIGHT = 5, + JOY_SDL_AXES = 6, + + // Joystick axes. + JOY_AXIS_0_X = 0, + JOY_AXIS_0_Y = 1, + JOY_AXIS_1_X = 2, + JOY_AXIS_1_Y = 3, + JOY_AXIS_2_X = 4, + JOY_AXIS_2_Y = 5, + JOY_AXIS_3_X = 6, + JOY_AXIS_3_Y = 7, + JOY_AXIS_4_X = 8, + JOY_AXIS_4_Y = 9, + + JOY_AXIS_MAX = 10 // OpenVR supports up to 5 Joysticks making a total of 10 axes. }; enum MidiMessageList { @@ -171,7 +157,7 @@ enum MidiMessageList { class InputEvent : public Resource { GDCLASS(InputEvent, Resource); - int device; + int device = 0; protected: static void _bind_methods(); @@ -191,7 +177,6 @@ public: // To be removed someday, since they do not make sense for all events virtual bool is_pressed() const; virtual bool is_echo() const; - // ...-. virtual String as_text() const; @@ -202,14 +187,15 @@ public: virtual bool is_action_type() const; virtual bool accumulate(const Ref<InputEvent> &p_event) { return false; } - InputEvent(); + + InputEvent() {} }; class InputEventFromWindow : public InputEvent { GDCLASS(InputEventFromWindow, InputEvent); - int64_t window_id; + int64_t window_id = 0; protected: static void _bind_methods(); @@ -218,28 +204,27 @@ public: void set_window_id(int64_t p_id); int64_t get_window_id() const; - InputEventFromWindow(); + InputEventFromWindow() {} }; class InputEventWithModifiers : public InputEventFromWindow { GDCLASS(InputEventWithModifiers, InputEventFromWindow); - bool shift; - bool alt; + bool shift = false; + bool alt = false; #ifdef APPLE_STYLE_KEYS union { bool command; - bool meta; //< windows/mac key + bool meta = false; //< windows/mac key }; - bool control; + bool control = false; #else union { bool command; //< windows/mac key - bool control; + bool control = false; }; - bool meta; //< windows/mac key - + bool meta = false; //< windows/mac key #endif protected: @@ -263,20 +248,20 @@ public: void set_modifiers_from_event(const InputEventWithModifiers *event); - InputEventWithModifiers(); + InputEventWithModifiers() {} }; class InputEventKey : public InputEventWithModifiers { GDCLASS(InputEventKey, InputEventWithModifiers); - bool pressed; /// otherwise release + bool pressed = false; /// otherwise release - uint32_t keycode; ///< check keyboard.h , KeyCode enum, without modifier masks - uint32_t physical_keycode; - uint32_t unicode; ///unicode + uint32_t keycode = 0; ///< check keyboard.h , KeyCode enum, without modifier masks + uint32_t physical_keycode = 0; + uint32_t unicode = 0; ///unicode - bool echo; /// true if this is an echo key + bool echo = false; /// true if this is an echo key protected: static void _bind_methods(); @@ -307,14 +292,14 @@ public: virtual String as_text() const; - InputEventKey(); + InputEventKey() {} }; class InputEventMouse : public InputEventWithModifiers { GDCLASS(InputEventMouse, InputEventWithModifiers); - int button_mask; + int button_mask = 0; Vector2 pos; Vector2 global_pos; @@ -332,17 +317,17 @@ public: void set_global_position(const Vector2 &p_global_pos); Vector2 get_global_position() const; - InputEventMouse(); + InputEventMouse() {} }; class InputEventMouseButton : public InputEventMouse { GDCLASS(InputEventMouseButton, InputEventMouse); - float factor; - int button_index; - bool pressed; //otherwise released - bool doubleclick; //last even less than doubleclick time + float factor = 1; + int button_index = 0; + bool pressed = false; //otherwise released + bool doubleclick = false; //last even less than doubleclick time protected: static void _bind_methods(); @@ -366,7 +351,7 @@ public: virtual bool is_action_type() const { return true; } virtual String as_text() const; - InputEventMouseButton(); + InputEventMouseButton() {} }; class InputEventMouseMotion : public InputEventMouse { @@ -374,7 +359,7 @@ class InputEventMouseMotion : public InputEventMouse { GDCLASS(InputEventMouseMotion, InputEventMouse); Vector2 tilt; - float pressure; + float pressure = 0; Vector2 relative; Vector2 speed; @@ -399,14 +384,14 @@ public: virtual bool accumulate(const Ref<InputEvent> &p_event); - InputEventMouseMotion(); + InputEventMouseMotion() {} }; class InputEventJoypadMotion : public InputEvent { GDCLASS(InputEventJoypadMotion, InputEvent); - int axis; ///< Joypad axis - float axis_value; ///< -1 to 1 + int axis = 0; ///< Joypad axis + float axis_value = 0; ///< -1 to 1 protected: static void _bind_methods(); @@ -425,15 +410,15 @@ public: virtual bool is_action_type() const { return true; } virtual String as_text() const; - InputEventJoypadMotion(); + InputEventJoypadMotion() {} }; class InputEventJoypadButton : public InputEvent { GDCLASS(InputEventJoypadButton, InputEvent); - int button_index; - bool pressed; - float pressure; //0 to 1 + int button_index = 0; + bool pressed = false; + float pressure = 0; //0 to 1 protected: static void _bind_methods(); @@ -453,14 +438,14 @@ public: virtual bool is_action_type() const { return true; } virtual String as_text() const; - InputEventJoypadButton(); + InputEventJoypadButton() {} }; class InputEventScreenTouch : public InputEventFromWindow { GDCLASS(InputEventScreenTouch, InputEventFromWindow); - int index; + int index = 0; Vector2 pos; - bool pressed; + bool pressed = false; protected: static void _bind_methods(); @@ -478,13 +463,13 @@ public: virtual Ref<InputEvent> xformed_by(const Transform2D &p_xform, const Vector2 &p_local_ofs = Vector2()) const; virtual String as_text() const; - InputEventScreenTouch(); + InputEventScreenTouch() {} }; class InputEventScreenDrag : public InputEventFromWindow { GDCLASS(InputEventScreenDrag, InputEventFromWindow); - int index; + int index = 0; Vector2 pos; Vector2 relative; Vector2 speed; @@ -508,7 +493,7 @@ public: virtual Ref<InputEvent> xformed_by(const Transform2D &p_xform, const Vector2 &p_local_ofs = Vector2()) const; virtual String as_text() const; - InputEventScreenDrag(); + InputEventScreenDrag() {} }; class InputEventAction : public InputEvent { @@ -516,8 +501,8 @@ class InputEventAction : public InputEvent { GDCLASS(InputEventAction, InputEvent); StringName action; - bool pressed; - float strength; + bool pressed = false; + float strength = 1.0f; protected: static void _bind_methods(); @@ -540,7 +525,7 @@ public: virtual bool is_action_type() const { return true; } virtual String as_text() const; - InputEventAction(); + InputEventAction() {} }; class InputEventGesture : public InputEventWithModifiers { @@ -560,7 +545,7 @@ public: class InputEventMagnifyGesture : public InputEventGesture { GDCLASS(InputEventMagnifyGesture, InputEventGesture); - real_t factor; + real_t factor = 1.0; protected: static void _bind_methods(); @@ -572,7 +557,7 @@ public: virtual Ref<InputEvent> xformed_by(const Transform2D &p_xform, const Vector2 &p_local_ofs = Vector2()) const; virtual String as_text() const; - InputEventMagnifyGesture(); + InputEventMagnifyGesture() {} }; class InputEventPanGesture : public InputEventGesture { @@ -590,20 +575,20 @@ public: virtual Ref<InputEvent> xformed_by(const Transform2D &p_xform, const Vector2 &p_local_ofs = Vector2()) const; virtual String as_text() const; - InputEventPanGesture(); + InputEventPanGesture() {} }; class InputEventMIDI : public InputEvent { GDCLASS(InputEventMIDI, InputEvent); - int channel; - int message; - int pitch; - int velocity; - int instrument; - int pressure; - int controller_number; - int controller_value; + int channel = 0; + int message = 0; + int pitch = 0; + int velocity = 0; + int instrument = 0; + int pressure = 0; + int controller_number = 0; + int controller_value = 0; protected: static void _bind_methods(); @@ -635,7 +620,7 @@ public: virtual String as_text() const; - InputEventMIDI(); + InputEventMIDI() {} }; #endif // INPUT_EVENT_H diff --git a/core/io/compression.h b/core/io/compression.h index 8354b581fa..3e7c125d8e 100644 --- a/core/io/compression.h +++ b/core/io/compression.h @@ -53,7 +53,7 @@ public: static int get_max_compressed_buffer_size(int p_src_size, Mode p_mode = MODE_ZSTD); static int decompress(uint8_t *p_dst, int p_dst_max_size, const uint8_t *p_src, int p_src_size, Mode p_mode = MODE_ZSTD); - Compression(); + Compression() {} }; #endif // COMPRESSION_H diff --git a/core/io/dtls_server.cpp b/core/io/dtls_server.cpp index 5bda06e5b9..7cbf5c618e 100644 --- a/core/io/dtls_server.cpp +++ b/core/io/dtls_server.cpp @@ -29,6 +29,7 @@ /*************************************************************************/ #include "dtls_server.h" + #include "core/os/file_access.h" #include "core/project_settings.h" @@ -49,6 +50,3 @@ void DTLSServer::_bind_methods() { ClassDB::bind_method(D_METHOD("setup", "key", "certificate", "chain"), &DTLSServer::setup, DEFVAL(Ref<X509Certificate>())); ClassDB::bind_method(D_METHOD("take_connection", "udp_peer"), &DTLSServer::take_connection); } - -DTLSServer::DTLSServer() { -} diff --git a/core/io/dtls_server.h b/core/io/dtls_server.h index 7b08138f7f..ae1d3bcd98 100644 --- a/core/io/dtls_server.h +++ b/core/io/dtls_server.h @@ -51,7 +51,7 @@ public: virtual void stop() = 0; virtual Ref<PacketPeerDTLS> take_connection(Ref<PacketPeerUDP> p_peer) = 0; - DTLSServer(); + DTLSServer() {} }; #endif // DTLS_SERVER_H diff --git a/core/io/file_access_buffered.cpp b/core/io/file_access_buffered.cpp index ab0fb3943c..2df91a4dd8 100644 --- a/core/io/file_access_buffered.cpp +++ b/core/io/file_access_buffered.cpp @@ -166,11 +166,3 @@ Error FileAccessBuffered::get_error() const { return last_error; } - -FileAccessBuffered::FileAccessBuffered() { - - cache_size = DEFAULT_CACHE_SIZE; -} - -FileAccessBuffered::~FileAccessBuffered() { -} diff --git a/core/io/file_access_buffered.h b/core/io/file_access_buffered.h index a6177c20be..f473886330 100644 --- a/core/io/file_access_buffered.h +++ b/core/io/file_access_buffered.h @@ -43,7 +43,7 @@ public: }; private: - int cache_size; + int cache_size = DEFAULT_CACHE_SIZE; int cache_data_left() const; mutable Error last_error; @@ -66,7 +66,7 @@ protected: int offset; } cache; - virtual int read_data_block(int p_offset, int p_size, uint8_t *p_dest = 0) const = 0; + virtual int read_data_block(int p_offset, int p_size, uint8_t *p_dest = nullptr) const = 0; void set_cache_size(int p_size); int get_cache_size(); @@ -87,8 +87,8 @@ public: virtual Error get_error() const; - FileAccessBuffered(); - virtual ~FileAccessBuffered(); + FileAccessBuffered() {} + virtual ~FileAccessBuffered() {} }; #endif diff --git a/core/io/file_access_buffered_fa.h b/core/io/file_access_buffered_fa.h index 6ec77d503b..edb4ff9a9f 100644 --- a/core/io/file_access_buffered_fa.h +++ b/core/io/file_access_buffered_fa.h @@ -132,12 +132,6 @@ public: set_error(OK); }; - /* - static void make_default() { - FileAccess::create_func = FileAccessBufferedFA<T>::create; - }; - */ - virtual uint64_t _get_modified_time(const String &p_file) { return f._get_modified_time(p_file); @@ -151,9 +145,7 @@ public: return f._set_unix_permissions(p_file, p_permissions); } - FileAccessBufferedFA(){ - - }; + FileAccessBufferedFA() {} }; #endif // FILE_ACCESS_BUFFERED_FA_H diff --git a/core/io/file_access_compressed.cpp b/core/io/file_access_compressed.cpp index c76142d22d..f2827b519e 100644 --- a/core/io/file_access_compressed.cpp +++ b/core/io/file_access_compressed.cpp @@ -389,25 +389,6 @@ Error FileAccessCompressed::_set_unix_permissions(const String &p_file, uint32_t return FAILED; } -FileAccessCompressed::FileAccessCompressed() : - cmode(Compression::MODE_ZSTD), - writing(false), - write_ptr(nullptr), - write_buffer_size(0), - write_max(0), - block_size(0), - read_eof(false), - at_end(false), - read_ptr(nullptr), - read_block(0), - read_block_count(0), - read_block_size(0), - read_pos(0), - read_total(0), - magic("GCMP"), - f(nullptr) { -} - FileAccessCompressed::~FileAccessCompressed() { if (f) diff --git a/core/io/file_access_compressed.h b/core/io/file_access_compressed.h index 0bb311faa8..f192be0883 100644 --- a/core/io/file_access_compressed.h +++ b/core/io/file_access_compressed.h @@ -36,15 +36,15 @@ class FileAccessCompressed : public FileAccess { - Compression::Mode cmode; - bool writing; - uint32_t write_pos; - uint8_t *write_ptr; - uint32_t write_buffer_size; - uint32_t write_max; - uint32_t block_size; - mutable bool read_eof; - mutable bool at_end; + Compression::Mode cmode = Compression::MODE_ZSTD; + bool writing = false; + uint32_t write_pos = 0; + uint8_t *write_ptr = nullptr; + uint32_t write_buffer_size = 0; + uint32_t write_max = 0; + uint32_t block_size = 0; + mutable bool read_eof = false; + mutable bool at_end = false; struct ReadBlock { int csize; @@ -52,17 +52,17 @@ class FileAccessCompressed : public FileAccess { }; mutable Vector<uint8_t> comp_buffer; - uint8_t *read_ptr; - mutable int read_block; - int read_block_count; - mutable int read_block_size; - mutable int read_pos; + uint8_t *read_ptr = nullptr; + mutable int read_block = 0; + int read_block_count = 0; + mutable int read_block_size = 0; + mutable int read_pos = 0; Vector<ReadBlock> read_blocks; - uint32_t read_total; + uint32_t read_total = 0; - String magic; + String magic = "GCMP"; mutable Vector<uint8_t> buffer; - FileAccess *f; + FileAccess *f = nullptr; public: void configure(const String &p_magic, Compression::Mode p_mode = Compression::MODE_ZSTD, int p_block_size = 4096); @@ -94,7 +94,7 @@ public: virtual uint32_t _get_unix_permissions(const String &p_file); virtual Error _set_unix_permissions(const String &p_file, uint32_t p_permissions); - FileAccessCompressed(); + FileAccessCompressed() {} virtual ~FileAccessCompressed(); }; diff --git a/core/io/file_access_encrypted.cpp b/core/io/file_access_encrypted.cpp index a5b3807789..271c34ec4a 100644 --- a/core/io/file_access_encrypted.cpp +++ b/core/io/file_access_encrypted.cpp @@ -317,15 +317,6 @@ Error FileAccessEncrypted::_set_unix_permissions(const String &p_file, uint32_t return ERR_UNAVAILABLE; } -FileAccessEncrypted::FileAccessEncrypted() { - - file = nullptr; - pos = 0; - eofed = false; - mode = MODE_MAX; - writing = false; -} - FileAccessEncrypted::~FileAccessEncrypted() { if (file) diff --git a/core/io/file_access_encrypted.h b/core/io/file_access_encrypted.h index 7a9f4ecdd8..e269c1e30c 100644 --- a/core/io/file_access_encrypted.h +++ b/core/io/file_access_encrypted.h @@ -42,15 +42,15 @@ public: }; private: - Mode mode; + Mode mode = MODE_MAX; Vector<uint8_t> key; - bool writing; - FileAccess *file; + bool writing = false; + FileAccess *file = nullptr; size_t base; size_t length; Vector<uint8_t> data; - mutable int pos; - mutable bool eofed; + mutable int pos = 0; + mutable bool eofed = false; public: Error open_and_parse(FileAccess *p_base, const Vector<uint8_t> &p_key, Mode p_mode); @@ -85,7 +85,7 @@ public: virtual uint32_t _get_unix_permissions(const String &p_file); virtual Error _set_unix_permissions(const String &p_file, uint32_t p_permissions); - FileAccessEncrypted(); + FileAccessEncrypted() {} ~FileAccessEncrypted(); }; diff --git a/core/io/file_access_memory.cpp b/core/io/file_access_memory.cpp index a2379ce88f..a3e04a4538 100644 --- a/core/io/file_access_memory.cpp +++ b/core/io/file_access_memory.cpp @@ -193,8 +193,3 @@ void FileAccessMemory::store_buffer(const uint8_t *p_src, int p_length) { copymem(&data[pos], p_src, write); pos += p_length; } - -FileAccessMemory::FileAccessMemory() { - - data = nullptr; -} diff --git a/core/io/file_access_memory.h b/core/io/file_access_memory.h index 2db14db265..d8be989b20 100644 --- a/core/io/file_access_memory.h +++ b/core/io/file_access_memory.h @@ -35,7 +35,7 @@ class FileAccessMemory : public FileAccess { - uint8_t *data; + uint8_t *data = nullptr; int length; mutable int pos; @@ -73,7 +73,7 @@ public: virtual uint32_t _get_unix_permissions(const String &p_file) { return 0; } virtual Error _set_unix_permissions(const String &p_file, uint32_t p_permissions) { return FAILED; } - FileAccessMemory(); + FileAccessMemory() {} }; #endif // FILE_ACCESS_MEMORY_H diff --git a/core/io/file_access_network.cpp b/core/io/file_access_network.cpp index a3f307393f..00f504c391 100644 --- a/core/io/file_access_network.cpp +++ b/core/io/file_access_network.cpp @@ -222,17 +222,11 @@ Error FileAccessNetworkClient::connect(const String &p_host, int p_port, const S FileAccessNetworkClient *FileAccessNetworkClient::singleton = nullptr; FileAccessNetworkClient::FileAccessNetworkClient() { - - thread = nullptr; - quit = false; singleton = this; - last_id = 0; client.instance(); - lockcount = 0; } FileAccessNetworkClient::~FileAccessNetworkClient() { - if (thread) { quit = true; sem.post(); @@ -513,9 +507,6 @@ void FileAccessNetwork::configure() { FileAccessNetwork::FileAccessNetwork() { - eof_flag = false; - opened = false; - pos = 0; FileAccessNetworkClient *nc = FileAccessNetworkClient::singleton; nc->lock_mutex(); id = nc->last_id++; @@ -523,9 +514,6 @@ FileAccessNetwork::FileAccessNetwork() { nc->unlock_mutex(); page_size = GLOBAL_GET("network/remote_fs/page_size"); read_ahead = GLOBAL_GET("network/remote_fs/page_read_ahead"); - last_activity_val = 0; - waiting_on_page = -1; - last_page = -1; } FileAccessNetwork::~FileAccessNetwork() { diff --git a/core/io/file_access_network.h b/core/io/file_access_network.h index 7f664b46f7..6cdd6af0b4 100644 --- a/core/io/file_access_network.h +++ b/core/io/file_access_network.h @@ -50,13 +50,14 @@ class FileAccessNetworkClient { List<BlockRequest> block_requests; Semaphore sem; - Thread *thread; - bool quit; + Thread *thread = nullptr; + bool quit = false; Mutex mutex; Mutex blockrequest_mutex; Map<int, FileAccessNetwork *> accesses; Ref<StreamPeerTCP> client; - int last_id; + int last_id = 0; + int lockcount = 0; Vector<uint8_t> block; @@ -67,7 +68,6 @@ class FileAccessNetworkClient { void put_64(int64_t p_64); int get_32(); int64_t get_64(); - int lockcount; void lock_mutex(); void unlock_mutex(); @@ -88,27 +88,23 @@ class FileAccessNetwork : public FileAccess { Semaphore sem; Semaphore page_sem; Mutex buffer_mutex; - bool opened; + bool opened = false; size_t total_size; - mutable size_t pos; + mutable size_t pos = 0; int id; - mutable bool eof_flag; - mutable int last_page; - mutable uint8_t *last_page_buff; + mutable bool eof_flag = false; + mutable int last_page = -1; + mutable uint8_t *last_page_buff = nullptr; int page_size; int read_ahead; - mutable int waiting_on_page; - mutable int last_activity_val; + mutable int waiting_on_page = -1; + struct Page { - int activity; - bool queued; + int activity = 0; + bool queued = false; Vector<uint8_t> buffer; - Page() { - activity = 0; - queued = false; - } }; mutable Vector<Page> pages; diff --git a/core/io/file_access_pack.cpp b/core/io/file_access_pack.cpp index 0a7dee9444..d70f2ba445 100644 --- a/core/io/file_access_pack.cpp +++ b/core/io/file_access_pack.cpp @@ -109,8 +109,6 @@ PackedData::PackedData() { singleton = this; root = memnew(PackedDir); - root->parent = nullptr; - disabled = false; add_pack_source(memnew(PackedSourcePCK)); } @@ -414,7 +412,8 @@ Error DirAccessPack::change_dir(String p_dir) { nd = nd.simplify_path(); - if (nd == "") nd = "."; + if (nd == "") + nd = "."; if (nd.begins_with("/")) { nd = nd.replace_first("/", ""); @@ -505,10 +504,5 @@ String DirAccessPack::get_filesystem_type() const { } DirAccessPack::DirAccessPack() { - current = PackedData::get_singleton()->root; - cdir = false; -} - -DirAccessPack::~DirAccessPack() { } diff --git a/core/io/file_access_pack.h b/core/io/file_access_pack.h index 8df6826ac9..aa3a14272b 100644 --- a/core/io/file_access_pack.h +++ b/core/io/file_access_pack.h @@ -61,15 +61,15 @@ public: private: struct PackedDir { - PackedDir *parent; + PackedDir *parent = nullptr; String name; Map<String, PackedDir *> subdirs; Set<String> files; }; struct PathMD5 { - uint64_t a; - uint64_t b; + uint64_t a = 0; + uint64_t b = 0; bool operator<(const PathMD5 &p_md5) const { if (p_md5.a == a) { @@ -83,14 +83,12 @@ private: return a == p_md5.a && b == p_md5.b; }; - PathMD5() { - a = b = 0; - }; + PathMD5() {} PathMD5(const Vector<uint8_t> p_buf) { a = *((uint64_t *)&p_buf[0]); b = *((uint64_t *)&p_buf[8]); - }; + } }; Map<PathMD5, PackedFile> files; @@ -98,10 +96,9 @@ private: Vector<PackSource *> sources; PackedDir *root; - //Map<String,PackedDir*> dirs; static PackedData *singleton; - bool disabled; + bool disabled = false; void _free_packed_dirs(PackedDir *p_dir); @@ -203,7 +200,7 @@ class DirAccessPack : public DirAccess { List<String> list_dirs; List<String> list_files; - bool cdir; + bool cdir = false; public: virtual Error list_dir_begin(); @@ -231,7 +228,7 @@ public: virtual String get_filesystem_type() const; DirAccessPack(); - ~DirAccessPack(); + ~DirAccessPack() {} }; #endif // FILE_ACCESS_PACK_H diff --git a/core/io/file_access_zip.cpp b/core/io/file_access_zip.cpp index 57de66afaf..9d068fe809 100644 --- a/core/io/file_access_zip.cpp +++ b/core/io/file_access_zip.cpp @@ -241,9 +241,7 @@ ZipArchive *ZipArchive::get_singleton() { } ZipArchive::ZipArchive() { - instance = this; - //fa_create_func = FileAccess::get_create_func(); } ZipArchive::~ZipArchive() { @@ -369,14 +367,12 @@ bool FileAccessZip::file_exists(const String &p_name) { return false; } -FileAccessZip::FileAccessZip(const String &p_path, const PackedData::PackedFile &p_file) : - zfile(nullptr) { +FileAccessZip::FileAccessZip(const String &p_path, const PackedData::PackedFile &p_file) { _open(p_path, FileAccess::READ); } FileAccessZip::~FileAccessZip() { - close(); } -#endif +#endif // MINIZIP_ENABLED diff --git a/core/io/file_access_zip.h b/core/io/file_access_zip.h index d5ce7d7a8d..17a3d085b6 100644 --- a/core/io/file_access_zip.h +++ b/core/io/file_access_zip.h @@ -45,18 +45,15 @@ class ZipArchive : public PackSource { public: struct File { - int package; + int package = -1; unz_file_pos file_pos; - File() { - - package = -1; - }; + File() {} }; private: struct Package { String filename; - unzFile zfile; + unzFile zfile = nullptr; }; Vector<Package> packages; diff --git a/core/io/http_client.cpp b/core/io/http_client.cpp index 56f8f1ff91..940bac0009 100644 --- a/core/io/http_client.cpp +++ b/core/io/http_client.cpp @@ -283,7 +283,7 @@ void HTTPClient::close() { body_size = -1; body_left = 0; chunk_left = 0; - chunk_trailer_part = 0; + chunk_trailer_part = false; read_until_eof = false; response_num = 0; handshaking = false; @@ -729,27 +729,10 @@ int HTTPClient::get_read_chunk_size() const { } HTTPClient::HTTPClient() { - tcp_connection.instance(); - resolving = IP::RESOLVER_INVALID_ID; - status = STATUS_DISCONNECTED; - head_request = false; - conn_port = -1; - body_size = -1; - chunked = false; - body_left = 0; - read_until_eof = false; - chunk_left = 0; - chunk_trailer_part = false; - response_num = 0; - ssl = false; - blocking = false; - handshaking = false; - read_chunk_size = 4096; } -HTTPClient::~HTTPClient() { -} +HTTPClient::~HTTPClient() {} #endif // #ifndef JAVASCRIPT_ENABLED diff --git a/core/io/http_client.h b/core/io/http_client.h index 03ba20f8dd..05690534ae 100644 --- a/core/io/http_client.h +++ b/core/io/http_client.h @@ -158,32 +158,32 @@ private: }; #ifndef JAVASCRIPT_ENABLED - Status status; - IP::ResolverID resolving; - int conn_port; + Status status = STATUS_DISCONNECTED; + IP::ResolverID resolving = IP::RESOLVER_INVALID_ID; + int conn_port = -1; String conn_host; - bool ssl; - bool ssl_verify_host; - bool blocking; - bool handshaking; - bool head_request; + bool ssl = false; + bool ssl_verify_host = false; + bool blocking = false; + bool handshaking = false; + bool head_request = false; Vector<uint8_t> response_str; - bool chunked; + bool chunked = false; Vector<uint8_t> chunk; - int chunk_left; - bool chunk_trailer_part; - int body_size; - int body_left; - bool read_until_eof; + int chunk_left = 0; + bool chunk_trailer_part = false; + int body_size = -1; + int body_left = 0; + bool read_until_eof = false; Ref<StreamPeerTCP> tcp_connection; Ref<StreamPeer> connection; - int response_num; + int response_num = 0; Vector<String> response_headers; - int read_chunk_size; + int read_chunk_size = 4096; Error _get_http_data(uint8_t *p_buffer, int p_bytes, int &r_received); diff --git a/core/io/ip_address.h b/core/io/ip_address.h index 89cf37ff8f..a59178063d 100644 --- a/core/io/ip_address.h +++ b/core/io/ip_address.h @@ -52,16 +52,20 @@ protected: public: //operator Variant() const; bool operator==(const IP_Address &p_ip) const { - if (p_ip.valid != valid) return false; - if (!valid) return false; + if (p_ip.valid != valid) + return false; + if (!valid) + return false; for (int i = 0; i < 4; i++) if (field32[i] != p_ip.field32[i]) return false; return true; } bool operator!=(const IP_Address &p_ip) const { - if (p_ip.valid != valid) return true; - if (!valid) return true; + if (p_ip.valid != valid) + return true; + if (!valid) + return true; for (int i = 0; i < 4; i++) if (field32[i] != p_ip.field32[i]) return true; diff --git a/core/io/json.cpp b/core/io/json.cpp index 3a0edceb81..0186547dd2 100644 --- a/core/io/json.cpp +++ b/core/io/json.cpp @@ -67,10 +67,14 @@ String JSON::_print_var(const Variant &p_var, const String &p_indent, int p_cur_ switch (p_var.get_type()) { - case Variant::NIL: return "null"; - case Variant::BOOL: return p_var.operator bool() ? "true" : "false"; - case Variant::INT: return itos(p_var); - case Variant::FLOAT: return rtos(p_var); + case Variant::NIL: + return "null"; + case Variant::BOOL: + return p_var.operator bool() ? "true" : "false"; + case Variant::INT: + return itos(p_var); + case Variant::FLOAT: + return rtos(p_var); case Variant::PACKED_INT32_ARRAY: case Variant::PACKED_INT64_ARRAY: case Variant::PACKED_FLOAT32_ARRAY: @@ -116,7 +120,8 @@ String JSON::_print_var(const Variant &p_var, const String &p_indent, int p_cur_ s += end_statement + _make_indent(p_indent, p_cur_indent) + "}"; return s; }; - default: return "\"" + String(p_var).json_escape() + "\""; + default: + return "\"" + String(p_var).json_escape() + "\""; } } @@ -199,11 +204,21 @@ Error JSON::_get_token(const CharType *p_str, int &index, int p_len, Token &r_to switch (next) { - case 'b': res = 8; break; - case 't': res = 9; break; - case 'n': res = 10; break; - case 'f': res = 12; break; - case 'r': res = 13; break; + case 'b': + res = 8; + break; + case 't': + res = 9; + break; + case 'n': + res = 10; + break; + case 'f': + res = 12; + break; + case 'r': + res = 13; + break; case 'u': { // hex number for (int j = 0; j < 4; j++) { diff --git a/core/io/logger.cpp b/core/io/logger.cpp index ad0cc81023..23165b575e 100644 --- a/core/io/logger.cpp +++ b/core/io/logger.cpp @@ -49,11 +49,21 @@ void Logger::log_error(const char *p_function, const char *p_file, int p_line, c const char *err_type = "ERROR"; switch (p_type) { - case ERR_ERROR: err_type = "ERROR"; break; - case ERR_WARNING: err_type = "WARNING"; break; - case ERR_SCRIPT: err_type = "SCRIPT ERROR"; break; - case ERR_SHADER: err_type = "SHADER ERROR"; break; - default: ERR_PRINT("Unknown error type"); break; + case ERR_ERROR: + err_type = "ERROR"; + break; + case ERR_WARNING: + err_type = "WARNING"; + break; + case ERR_SCRIPT: + err_type = "SCRIPT ERROR"; + break; + case ERR_SHADER: + err_type = "SHADER ERROR"; + break; + default: + ERR_PRINT("Unknown error type"); + break; } const char *err_details; @@ -92,8 +102,6 @@ void Logger::logf_error(const char *p_format, ...) { va_end(argp); } -Logger::~Logger() {} - void RotatedFileLogger::close_file() { if (file) { memdelete(file); @@ -170,8 +178,7 @@ void RotatedFileLogger::rotate_file() { RotatedFileLogger::RotatedFileLogger(const String &p_base_path, int p_max_files) : base_path(p_base_path.simplify_path()), - max_files(p_max_files > 0 ? p_max_files : 1), - file(nullptr) { + max_files(p_max_files > 0 ? p_max_files : 1) { rotate_file(); } @@ -226,8 +233,6 @@ void StdLogger::logv(const char *p_format, va_list p_list, bool p_err) { } } -StdLogger::~StdLogger() {} - CompositeLogger::CompositeLogger(Vector<Logger *> p_loggers) : loggers(p_loggers) { } diff --git a/core/io/logger.h b/core/io/logger.h index 7028551185..54f1a42da9 100644 --- a/core/io/logger.h +++ b/core/io/logger.h @@ -55,7 +55,7 @@ public: void logf(const char *p_format, ...) _PRINTF_FORMAT_ATTRIBUTE_2_3; void logf_error(const char *p_format, ...) _PRINTF_FORMAT_ATTRIBUTE_2_3; - virtual ~Logger(); + virtual ~Logger() {} }; /** @@ -65,7 +65,7 @@ class StdLogger : public Logger { public: virtual void logv(const char *p_format, va_list p_list, bool p_err) _PRINTF_FORMAT_ATTRIBUTE_2_0; - virtual ~StdLogger(); + virtual ~StdLogger() {} }; /** @@ -78,7 +78,7 @@ class RotatedFileLogger : public Logger { String base_path; int max_files; - FileAccess *file; + FileAccess *file = nullptr; void rotate_file_without_closing(); void close_file(); diff --git a/core/io/marshalls.cpp b/core/io/marshalls.cpp index 81bc45b2f7..abf27954b8 100644 --- a/core/io/marshalls.cpp +++ b/core/io/marshalls.cpp @@ -53,9 +53,6 @@ ObjectID EncodedObjectAsID::get_object_id() const { return id; } -EncodedObjectAsID::EncodedObjectAsID() { -} - #define _S(a) ((int32_t)a) #define ERR_FAIL_ADD_OF(a, b, err) ERR_FAIL_COND_V(_S(b) < 0 || _S(a) < 0 || _S(a) > INT_MAX - _S(b), err) #define ERR_FAIL_MUL_OF(a, b, err) ERR_FAIL_COND_V(_S(a) < 0 || _S(b) <= 0 || _S(a) > INT_MAX / _S(b), err) diff --git a/core/io/marshalls.h b/core/io/marshalls.h index d029ed238c..1ba786d5d9 100644 --- a/core/io/marshalls.h +++ b/core/io/marshalls.h @@ -121,7 +121,8 @@ static inline int encode_cstring(const char *p_string, uint8_t *p_data) { len++; }; - if (p_data) *p_data = 0; + if (p_data) + *p_data = 0; return len + 1; } @@ -196,7 +197,7 @@ public: void set_object_id(ObjectID p_id); ObjectID get_object_id() const; - EncodedObjectAsID(); + EncodedObjectAsID() {} }; Error decode_variant(Variant &r_variant, const uint8_t *p_buffer, int p_len, int *r_len = nullptr, bool p_allow_objects = false); diff --git a/core/io/multiplayer_api.cpp b/core/io/multiplayer_api.cpp index 3bec52416e..998bcfd3f3 100644 --- a/core/io/multiplayer_api.cpp +++ b/core/io/multiplayer_api.cpp @@ -33,6 +33,7 @@ #include "core/debugger/engine_debugger.h" #include "core/io/marshalls.h" #include "scene/main/node.h" + #include <stdint.h> #define NODE_ID_COMPRESSION_SHIFT 3 @@ -144,7 +145,8 @@ void MultiplayerAPI::set_root_node(Node *p_node) { void MultiplayerAPI::set_network_peer(const Ref<NetworkedMultiplayerPeer> &p_peer) { - if (p_peer == network_peer) return; // Nothing to do + if (p_peer == network_peer) + return; // Nothing to do ERR_FAIL_COND_MSG(p_peer.is_valid() && p_peer->get_connection_status() == NetworkedMultiplayerPeer::CONNECTION_DISCONNECTED, "Supplied NetworkedMultiplayerPeer must be connecting or connected."); @@ -787,8 +789,9 @@ void MultiplayerAPI::_send_rpc(Node *p_from, int p_to, bool p_unreliable, bool p int ofs = 0; -#define MAKE_ROOM(m_amount) \ - if (packet_cache.size() < m_amount) packet_cache.resize(m_amount); +#define MAKE_ROOM(m_amount) \ + if (packet_cache.size() < m_amount) \ + packet_cache.resize(m_amount); // Encode meta. // The meta is composed by a single byte that contains (starting from the least segnificant bit): @@ -1259,10 +1262,7 @@ void MultiplayerAPI::_bind_methods() { BIND_ENUM_CONSTANT(RPC_MODE_PUPPETSYNC); } -MultiplayerAPI::MultiplayerAPI() : - allow_object_decoding(false) { - rpc_sender_id = 0; - root_node = nullptr; +MultiplayerAPI::MultiplayerAPI() { clear(); } diff --git a/core/io/multiplayer_api.h b/core/io/multiplayer_api.h index 4eb4a53e99..2603fb1e27 100644 --- a/core/io/multiplayer_api.h +++ b/core/io/multiplayer_api.h @@ -56,14 +56,14 @@ private: }; Ref<NetworkedMultiplayerPeer> network_peer; - int rpc_sender_id; + int rpc_sender_id = 0; Set<int> connected_peers; HashMap<NodePath, PathSentCache> path_send_cache; Map<int, PathGetCache> path_get_cache; int last_send_cache_id; Vector<uint8_t> packet_cache; - Node *root_node; - bool allow_object_decoding; + Node *root_node = nullptr; + bool allow_object_decoding = false; protected: static void _bind_methods(); diff --git a/core/io/networked_multiplayer_peer.cpp b/core/io/networked_multiplayer_peer.cpp index b2f810d212..332beb4c8c 100644 --- a/core/io/networked_multiplayer_peer.cpp +++ b/core/io/networked_multiplayer_peer.cpp @@ -66,6 +66,3 @@ void NetworkedMultiplayerPeer::_bind_methods() { ADD_SIGNAL(MethodInfo("connection_succeeded")); ADD_SIGNAL(MethodInfo("connection_failed")); } - -NetworkedMultiplayerPeer::NetworkedMultiplayerPeer() { -} diff --git a/core/io/networked_multiplayer_peer.h b/core/io/networked_multiplayer_peer.h index c1f1924051..8792886ff3 100644 --- a/core/io/networked_multiplayer_peer.h +++ b/core/io/networked_multiplayer_peer.h @@ -74,7 +74,7 @@ public: virtual ConnectionStatus get_connection_status() const = 0; - NetworkedMultiplayerPeer(); + NetworkedMultiplayerPeer() {} }; VARIANT_ENUM_CAST(NetworkedMultiplayerPeer::TransferMode) diff --git a/core/io/packet_peer.cpp b/core/io/packet_peer.cpp index 38abb5c0d6..6d3e1341a7 100644 --- a/core/io/packet_peer.cpp +++ b/core/io/packet_peer.cpp @@ -35,11 +35,6 @@ /* helpers / binders */ -PacketPeer::PacketPeer() : - last_get_error(OK), - encode_buffer_max_size(8 * 1024 * 1024) { -} - void PacketPeer::set_encode_buffer_max_size(int p_max_size) { ERR_FAIL_COND_MSG(p_max_size < 1024, "Max encode buffer must be at least 1024 bytes"); diff --git a/core/io/packet_peer.h b/core/io/packet_peer.h index 62144259cc..b69efa531f 100644 --- a/core/io/packet_peer.h +++ b/core/io/packet_peer.h @@ -47,9 +47,9 @@ class PacketPeer : public Reference { Vector<uint8_t> _get_packet(); Error _get_packet_error() const; - mutable Error last_get_error; + mutable Error last_get_error = OK; - int encode_buffer_max_size; + int encode_buffer_max_size = 8 * 1024 * 1024; Vector<uint8_t> encode_buffer; public: @@ -70,7 +70,7 @@ public: void set_encode_buffer_max_size(int p_max_size); int get_encode_buffer_max_size() const; - PacketPeer(); + PacketPeer() {} ~PacketPeer() {} }; diff --git a/core/io/packet_peer_dtls.cpp b/core/io/packet_peer_dtls.cpp index 6da115eed2..ada3cb10a2 100644 --- a/core/io/packet_peer_dtls.cpp +++ b/core/io/packet_peer_dtls.cpp @@ -57,6 +57,3 @@ void PacketPeerDTLS::_bind_methods() { BIND_ENUM_CONSTANT(STATUS_ERROR); BIND_ENUM_CONSTANT(STATUS_ERROR_HOSTNAME_MISMATCH); } - -PacketPeerDTLS::PacketPeerDTLS() { -} diff --git a/core/io/packet_peer_dtls.h b/core/io/packet_peer_dtls.h index 4f9f4535bc..c2ff4e1a7f 100644 --- a/core/io/packet_peer_dtls.h +++ b/core/io/packet_peer_dtls.h @@ -60,7 +60,7 @@ public: static PacketPeerDTLS *create(); static bool is_available(); - PacketPeerDTLS(); + PacketPeerDTLS() {} }; VARIANT_ENUM_CAST(PacketPeerDTLS::Status); diff --git a/core/io/packet_peer_udp.cpp b/core/io/packet_peer_udp.cpp index f800ffc3db..8b6bd7ef90 100644 --- a/core/io/packet_peer_udp.cpp +++ b/core/io/packet_peer_udp.cpp @@ -343,12 +343,6 @@ void PacketPeerUDP::_bind_methods() { } PacketPeerUDP::PacketPeerUDP() : - packet_port(0), - queue_count(0), - peer_port(0), - connected(false), - blocking(true), - broadcast(false), _sock(Ref<NetSocket>(NetSocket::create())) { rb.resize(16); } diff --git a/core/io/packet_peer_udp.h b/core/io/packet_peer_udp.h index b5a9fc9ec3..23fc5460a6 100644 --- a/core/io/packet_peer_udp.h +++ b/core/io/packet_peer_udp.h @@ -47,14 +47,14 @@ protected: uint8_t recv_buffer[PACKET_BUFFER_SIZE]; uint8_t packet_buffer[PACKET_BUFFER_SIZE]; IP_Address packet_ip; - int packet_port; - int queue_count; + int packet_port = 0; + int queue_count = 0; IP_Address peer_addr; - int peer_port; - bool connected; - bool blocking; - bool broadcast; + int peer_port = 0; + bool connected = false; + bool blocking = true; + bool broadcast = false; Ref<NetSocket> _sock; static void _bind_methods(); diff --git a/core/io/pck_packer.cpp b/core/io/pck_packer.cpp index 5c4b3379ee..124ac30b88 100644 --- a/core/io/pck_packer.cpp +++ b/core/io/pck_packer.cpp @@ -180,11 +180,6 @@ Error PCKPacker::flush(bool p_verbose) { return OK; }; -PCKPacker::PCKPacker() { - - file = nullptr; -}; - PCKPacker::~PCKPacker() { if (file != nullptr) { memdelete(file); diff --git a/core/io/pck_packer.h b/core/io/pck_packer.h index 6058de8345..2848ac3a65 100644 --- a/core/io/pck_packer.h +++ b/core/io/pck_packer.h @@ -39,7 +39,7 @@ class PCKPacker : public Reference { GDCLASS(PCKPacker, Reference); - FileAccess *file; + FileAccess *file = nullptr; int alignment; static void _bind_methods(); @@ -58,7 +58,7 @@ public: Error add_file(const String &p_file, const String &p_src); Error flush(bool p_verbose = false); - PCKPacker(); + PCKPacker() {} ~PCKPacker(); }; diff --git a/core/io/resource_format_binary.cpp b/core/io/resource_format_binary.cpp index 8c7559479b..8ce17bcfbe 100644 --- a/core/io/resource_format_binary.cpp +++ b/core/io/resource_format_binary.cpp @@ -337,10 +337,14 @@ Error ResourceLoaderBinary::parse_variant(Variant &r_v) { } break; case OBJECT_INTERNAL_RESOURCE: { uint32_t index = f->get_32(); + String path = res_path + "::" + itos(index); + if (use_nocache) { - r_v = internal_resources[index].cache; + if (!internal_index_cache.has(path)) { + WARN_PRINT(String("Couldn't load resource (no cache): " + path).utf8().get_data()); + } + r_v = internal_index_cache[path]; } else { - String path = res_path + "::" + itos(index); RES res = ResourceLoader::load(path); if (res.is_null()) { WARN_PRINT(String("Couldn't load resource: " + path).utf8().get_data()); @@ -720,13 +724,15 @@ Error ResourceLoaderBinary::load() { if (!main) { + path = internal_resources[i].path; + + if (path.begins_with("local://")) { + path = path.replace_first("local://", ""); + subindex = path.to_int(); + path = res_path + "::" + path; + } + if (!use_nocache) { - path = internal_resources[i].path; - if (path.begins_with("local://")) { - path = path.replace_first("local://", ""); - subindex = path.to_int(); - path = res_path + "::" + path; - } if (ResourceCache::has(path)) { //already loaded, don't do anything @@ -769,7 +775,7 @@ Error ResourceLoaderBinary::load() { r->set_subindex(subindex); if (!main) { - internal_resources.write[i].cache = res; + internal_index_cache[path] = res; } int pc = f->get_32(); @@ -1018,18 +1024,6 @@ String ResourceLoaderBinary::recognize(FileAccess *p_f) { return type; } -ResourceLoaderBinary::ResourceLoaderBinary() : - translation_remapped(false), - ver_format(0), - f(nullptr), - importmd_ofs(0), - error(OK) { - - use_nocache = false; - progress = nullptr; - use_sub_threads = false; -} - ResourceLoaderBinary::~ResourceLoaderBinary() { if (f) @@ -2112,6 +2106,5 @@ void ResourceFormatSaverBinary::get_recognized_extensions(const RES &p_resource, ResourceFormatSaverBinary *ResourceFormatSaverBinary::singleton = nullptr; ResourceFormatSaverBinary::ResourceFormatSaverBinary() { - singleton = this; } diff --git a/core/io/resource_format_binary.h b/core/io/resource_format_binary.h index 0f8fc9445b..57aa086022 100644 --- a/core/io/resource_format_binary.h +++ b/core/io/resource_format_binary.h @@ -37,16 +37,16 @@ class ResourceLoaderBinary { - bool translation_remapped; + bool translation_remapped = false; String local_path; String res_path; String type; Ref<Resource> resource; - uint32_t ver_format; + uint32_t ver_format = 0; - FileAccess *f; + FileAccess *f = nullptr; - uint64_t importmd_ofs; + uint64_t importmd_ofs = 0; Vector<char> str_buf; List<RES> resource_cache; @@ -61,25 +61,25 @@ class ResourceLoaderBinary { RES cache; }; - bool use_sub_threads; - float *progress; + bool use_sub_threads = false; + float *progress = nullptr; Vector<ExtResource> external_resources; struct IntResource { String path; uint64_t offset; - RES cache; }; Vector<IntResource> internal_resources; + Map<String, RES> internal_index_cache; String get_unicode_string(); void _advance_padding(uint32_t p_len); Map<String, String> remaps; - Error error; + Error error = OK; - bool use_nocache; + bool use_nocache = false; friend class ResourceFormatLoaderBinary; @@ -98,7 +98,7 @@ public: String recognize(FileAccess *p_f); void get_dependencies(FileAccess *p_f, List<String> *p_dependencies, bool p_add_types); - ResourceLoaderBinary(); + ResourceLoaderBinary() {} ~ResourceLoaderBinary(); }; diff --git a/core/io/resource_importer.cpp b/core/io/resource_importer.cpp index 643df53f8c..9e22bdced7 100644 --- a/core/io/resource_importer.cpp +++ b/core/io/resource_importer.cpp @@ -91,7 +91,7 @@ Error ResourceFormatImporter::_get_path_and_type(const String &p_path, PathAndTy r_path_and_type.path = value; path_found = true; //first match must have priority } else if (assign == "type") { - r_path_and_type.type = value; + r_path_and_type.type = ClassDB::get_compatibility_remapped_class(value); } else if (assign == "importer") { r_path_and_type.importer = value; } else if (assign == "group_file") { diff --git a/core/io/resource_loader.cpp b/core/io/resource_loader.cpp index d90802d7e2..dc44be4e0b 100644 --- a/core/io/resource_loader.cpp +++ b/core/io/resource_loader.cpp @@ -33,7 +33,6 @@ #include "core/io/resource_importer.h" #include "core/os/file_access.h" #include "core/os/os.h" -#include "core/path_remap.h" #include "core/print_string.h" #include "core/project_settings.h" #include "core/translation.h" diff --git a/core/io/resource_loader.h b/core/io/resource_loader.h index 2d95d5545f..b10c08d693 100644 --- a/core/io/resource_loader.h +++ b/core/io/resource_loader.h @@ -160,7 +160,8 @@ public: static bool get_timestamp_on_load() { return timestamp_on_load; } static void notify_load_error(const String &p_err) { - if (err_notify) err_notify(err_notify_ud, p_err); + if (err_notify) + err_notify(err_notify_ud, p_err); } static void set_error_notify_func(void *p_ud, ResourceLoadErrorNotify p_err_notify) { err_notify = p_err_notify; @@ -168,7 +169,8 @@ public: } static void notify_dependency_error(const String &p_path, const String &p_dependency, const String &p_type) { - if (dep_err_notify) dep_err_notify(dep_err_notify_ud, p_path, p_dependency, p_type); + if (dep_err_notify) + dep_err_notify(dep_err_notify_ud, p_path, p_dependency, p_type); } static void set_dependency_error_notify_func(void *p_ud, DependencyErrorNotify p_err_notify) { dep_err_notify = p_err_notify; diff --git a/core/io/stream_peer.cpp b/core/io/stream_peer.cpp index b28b17aa95..9bbe92096d 100644 --- a/core/io/stream_peer.cpp +++ b/core/io/stream_peer.cpp @@ -536,8 +536,3 @@ Ref<StreamPeerBuffer> StreamPeerBuffer::duplicate() const { spb->data = data; return spb; } - -StreamPeerBuffer::StreamPeerBuffer() { - - pointer = 0; -} diff --git a/core/io/stream_peer.h b/core/io/stream_peer.h index 9358a2c07c..a390fdc325 100644 --- a/core/io/stream_peer.h +++ b/core/io/stream_peer.h @@ -47,7 +47,7 @@ protected: Array _get_data(int p_bytes); Array _get_partial_data(int p_bytes); - bool big_endian; + bool big_endian = false; public: virtual Error put_data(const uint8_t *p_data, int p_bytes) = 0; ///< put a whole chunk of data, blocking until it sent @@ -89,7 +89,7 @@ public: String get_utf8_string(int p_bytes = -1); Variant get_var(bool p_allow_objects = false); - StreamPeer() { big_endian = false; } + StreamPeer() {} }; class StreamPeerBuffer : public StreamPeer { @@ -97,7 +97,7 @@ class StreamPeerBuffer : public StreamPeer { GDCLASS(StreamPeerBuffer, StreamPeer); Vector<uint8_t> data; - int pointer; + int pointer = 0; protected: static void _bind_methods(); @@ -123,7 +123,7 @@ public: Ref<StreamPeerBuffer> duplicate() const; - StreamPeerBuffer(); + StreamPeerBuffer() {} }; #endif // STREAM_PEER_H diff --git a/core/io/stream_peer_ssl.cpp b/core/io/stream_peer_ssl.cpp index d98935f77c..1d86c35578 100644 --- a/core/io/stream_peer_ssl.cpp +++ b/core/io/stream_peer_ssl.cpp @@ -73,7 +73,3 @@ void StreamPeerSSL::_bind_methods() { BIND_ENUM_CONSTANT(STATUS_ERROR); BIND_ENUM_CONSTANT(STATUS_ERROR_HOSTNAME_MISMATCH); } - -StreamPeerSSL::StreamPeerSSL() { - blocking_handshake = true; -} diff --git a/core/io/stream_peer_ssl.h b/core/io/stream_peer_ssl.h index de3cb09c60..81b95b856d 100644 --- a/core/io/stream_peer_ssl.h +++ b/core/io/stream_peer_ssl.h @@ -43,7 +43,7 @@ protected: static bool available; - bool blocking_handshake; + bool blocking_handshake = true; public: enum Status { @@ -68,7 +68,7 @@ public: static bool is_available(); - StreamPeerSSL(); + StreamPeerSSL() {} }; VARIANT_ENUM_CAST(StreamPeerSSL::Status); diff --git a/core/io/stream_peer_tcp.cpp b/core/io/stream_peer_tcp.cpp index f0c5816d73..6218b98758 100644 --- a/core/io/stream_peer_tcp.cpp +++ b/core/io/stream_peer_tcp.cpp @@ -362,10 +362,7 @@ void StreamPeerTCP::_bind_methods() { } StreamPeerTCP::StreamPeerTCP() : - _sock(Ref<NetSocket>(NetSocket::create())), - timeout(0), - status(STATUS_NONE), - peer_port(0) { + _sock(Ref<NetSocket>(NetSocket::create())) { } StreamPeerTCP::~StreamPeerTCP() { diff --git a/core/io/stream_peer_tcp.h b/core/io/stream_peer_tcp.h index 86df9ab8cf..571f6b7c54 100644 --- a/core/io/stream_peer_tcp.h +++ b/core/io/stream_peer_tcp.h @@ -52,10 +52,10 @@ public: protected: Ref<NetSocket> _sock; - uint64_t timeout; - Status status; + uint64_t timeout = 0; + Status status = STATUS_NONE; IP_Address peer_host; - uint16_t peer_port; + uint16_t peer_port = 0; Error _connect(const String &p_address, int p_port); Error _poll_connection(); diff --git a/core/io/translation_loader_po.cpp b/core/io/translation_loader_po.cpp index bce5361c76..6f79e2554b 100644 --- a/core/io/translation_loader_po.cpp +++ b/core/io/translation_loader_po.cpp @@ -212,6 +212,3 @@ String TranslationLoaderPO::get_resource_type(const String &p_path) const { return "Translation"; return ""; } - -TranslationLoaderPO::TranslationLoaderPO() { -} diff --git a/core/io/translation_loader_po.h b/core/io/translation_loader_po.h index 137dfd1768..a196a37dc0 100644 --- a/core/io/translation_loader_po.h +++ b/core/io/translation_loader_po.h @@ -43,7 +43,7 @@ public: virtual bool handles_type(const String &p_type) const; virtual String get_resource_type(const String &p_path) const; - TranslationLoaderPO(); + TranslationLoaderPO() {} }; #endif // TRANSLATION_LOADER_PO_H diff --git a/core/io/xml_parser.cpp b/core/io/xml_parser.cpp index 9613ad3f10..a4b64bf17c 100644 --- a/core/io/xml_parser.cpp +++ b/core/io/xml_parser.cpp @@ -543,9 +543,6 @@ int XMLParser::get_current_line() const { } XMLParser::XMLParser() { - - data = nullptr; - close(); special_characters.push_back("&"); special_characters.push_back("<lt;"); special_characters.push_back(">gt;"); diff --git a/core/io/xml_parser.h b/core/io/xml_parser.h index 26c3e6802f..42b7d6e0d4 100644 --- a/core/io/xml_parser.h +++ b/core/io/xml_parser.h @@ -66,15 +66,15 @@ public: }; private: - char *data; - char *P; - uint64_t length; + char *data = nullptr; + char *P = nullptr; + uint64_t length = 0; void unescape(String &p_str); Vector<String> special_characters; String node_name; - bool node_empty; - NodeType node_type; - uint64_t node_offset; + bool node_empty = false; + NodeType node_type = NODE_NONE; + uint64_t node_offset = 0; struct Attribute { String name; diff --git a/core/list.h b/core/list.h index be2dccd876..7fbf3e50fc 100644 --- a/core/list.h +++ b/core/list.h @@ -54,9 +54,9 @@ public: friend class List<T, A>; T value; - Element *next_ptr; - Element *prev_ptr; - _Data *data; + Element *next_ptr = nullptr; + Element *prev_ptr = nullptr; + _Data *data = nullptr; public: /** @@ -139,11 +139,7 @@ public: data->erase(this); } - _FORCE_INLINE_ Element() { - next_ptr = 0; - prev_ptr = 0; - data = nullptr; - }; + _FORCE_INLINE_ Element() {} }; private: @@ -178,7 +174,7 @@ private: } }; - _Data *_data; + _Data *_data = nullptr; public: /** @@ -186,14 +182,14 @@ public: */ _FORCE_INLINE_ const Element *front() const { - return _data ? _data->first : 0; + return _data ? _data->first : nullptr; }; /** * return an iterator to the beginning of the list. */ _FORCE_INLINE_ Element *front() { - return _data ? _data->first : 0; + return _data ? _data->first : nullptr; }; /** @@ -201,7 +197,7 @@ public: */ _FORCE_INLINE_ const Element *back() const { - return _data ? _data->last : 0; + return _data ? _data->last : nullptr; }; /** @@ -209,7 +205,7 @@ public: */ _FORCE_INLINE_ Element *back() { - return _data ? _data->last : 0; + return _data ? _data->last : nullptr; }; /** @@ -229,7 +225,7 @@ public: n->value = (T &)value; n->prev_ptr = _data->last; - n->next_ptr = 0; + n->next_ptr = nullptr; n->data = _data; if (_data->last) { @@ -268,7 +264,7 @@ public: Element *n = memnew_allocator(Element, A); n->value = (T &)value; - n->prev_ptr = 0; + n->prev_ptr = nullptr; n->next_ptr = _data->first; n->data = _data; @@ -353,7 +349,8 @@ public: Element *it = front(); while (it) { - if (it->value == p_val) return it; + if (it->value == p_val) + return it; it = it->next(); }; @@ -686,7 +683,6 @@ public: */ List(const List &p_list) { - _data = nullptr; const Element *it = p_list.front(); while (it) { @@ -695,9 +691,8 @@ public: } } - List() { - _data = nullptr; - }; + List() {} + ~List() { clear(); if (_data) { diff --git a/core/local_vector.h b/core/local_vector.h new file mode 100644 index 0000000000..0b0ef6dfdc --- /dev/null +++ b/core/local_vector.h @@ -0,0 +1,246 @@ +/*************************************************************************/ +/* local_vector.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef LOCAL_VECTOR_H +#define LOCAL_VECTOR_H + +#include "core/error_macros.h" +#include "core/os/copymem.h" +#include "core/os/memory.h" +#include "core/sort_array.h" +#include "core/vector.h" + +template <class T, class U = uint32_t, bool force_trivial = false> +class LocalVector { +private: + U count = 0; + U capacity = 0; + T *data = nullptr; + +public: + _FORCE_INLINE_ void push_back(T p_elem) { + if (unlikely(count == capacity)) { + if (capacity == 0) { + capacity = 1; + } else { + capacity <<= 1; + } + data = (T *)memrealloc(data, capacity * sizeof(T)); + CRASH_COND_MSG(!data, "Out of memory"); + } + + if (!__has_trivial_constructor(T) && !force_trivial) { + memnew_placement(&data[count++], T(p_elem)); + } else { + data[count++] = p_elem; + } + } + + void remove(U p_index) { + ERR_FAIL_UNSIGNED_INDEX(p_index, count); + for (U i = p_index; i < count; i++) { + data[i] = data[i + 1]; + } + count--; + if (!__has_trivial_destructor(T) && !force_trivial) { + data[count].~T(); + } + } + + void erase(const T &p_val) { + U idx = find(p_val); + if (idx >= 0) + remove(idx); + } + + void invert() { + + for (U i = 0; i < count / 2; i++) { + SWAP(data[i], data[count - i - 1]); + } + } + + _FORCE_INLINE_ void clear() { resize(0); } + _FORCE_INLINE_ void reset() { + clear(); + if (data) { + memfree(data); + data = nullptr; + capacity = 0; + } + } + _FORCE_INLINE_ bool empty() const { return count == 0; } + _FORCE_INLINE_ void reserve(U p_size) { + p_size = nearest_power_of_2_templated(p_size); + if (p_size > capacity) { + capacity = p_size; + data = (T *)memrealloc(data, capacity * sizeof(T)); + CRASH_COND_MSG(!data, "Out of memory"); + } + } + + _FORCE_INLINE_ U size() const { return count; } + void resize(U p_size) { + + if (p_size < count) { + + if (!__has_trivial_destructor(T) && !force_trivial) { + for (U i = p_size; i < count; i++) { + data[i].~T(); + } + } + count = p_size; + } else if (p_size > count) { + + if (unlikely(p_size > capacity)) { + if (capacity == 0) { + capacity = 1; + } + while (capacity < p_size) { + capacity <<= 1; + } + data = (T *)memrealloc(data, capacity * sizeof(T)); + CRASH_COND_MSG(!data, "Out of memory"); + } + if (!__has_trivial_constructor(T) && !force_trivial) { + for (U i = count; i < p_size; i++) { + memnew_placement(&data[i], T); + } + } + count = p_size; + } + } + _FORCE_INLINE_ const T &operator[](U p_index) const { + CRASH_BAD_UNSIGNED_INDEX(p_index, count); + return data[p_index]; + } + _FORCE_INLINE_ T &operator[](U p_index) { + CRASH_BAD_UNSIGNED_INDEX(p_index, count); + return data[p_index]; + } + + void insert(U p_pos, T p_val) { + ERR_FAIL_UNSIGNED_INDEX(p_pos, count + 1); + if (p_pos == count) { + push_back(p_val); + } else { + resize(count + 1); + for (U i = count; i > p_pos; i--) { + data[i] = data[i - 1]; + } + data[p_pos] = p_val; + } + } + + int64_t find(const T &p_val, U p_from = 0) const { + + for (U i = 0; i < count; i++) { + if (data[i] == p_val) { + return int64_t(i); + } + } + return -1; + } + + template <class C> + void sort_custom() { + + U len = count; + if (len == 0) + return; + + SortArray<T, C> sorter; + sorter.sort(data, len); + } + + void sort() { + + sort_custom<_DefaultComparator<T>>(); + } + + void ordered_insert(T p_val) { + + U i; + for (i = 0; i < count; i++) { + + if (p_val < data[i]) { + break; + }; + }; + insert(i, p_val); + } + + operator Vector<T>() const { + Vector<T> ret; + ret.resize(size()); + T *w = ret.ptrw(); + copymem(w, data, sizeof(T) * count); + return ret; + } + + Vector<uint8_t> to_byte_array() const { //useful to pass stuff to gpu or variant + Vector<uint8_t> ret; + ret.resize(count * sizeof(T)); + uint8_t *w = ret.ptrw(); + copymem(w, data, sizeof(T) * count); + return ret; + } + + _FORCE_INLINE_ LocalVector() {} + _FORCE_INLINE_ LocalVector(const LocalVector &p_from) { + resize(p_from.size()); + for (U i = 0; i < p_from.count; i++) { + data[i] = p_from.data[i]; + } + } + inline LocalVector &operator=(const LocalVector &p_from) { + resize(p_from.size()); + for (U i = 0; i < p_from.count; i++) { + data[i] = p_from.data[i]; + } + return *this; + } + inline LocalVector &operator=(const Vector<T> &p_from) { + resize(p_from.size()); + for (U i = 0; i < count; i++) { + data[i] = p_from[i]; + } + return *this; + } + + _FORCE_INLINE_ ~LocalVector() { + + if (data) { + reset(); + } + } +}; + +#endif // LOCAL_VECTOR_H diff --git a/core/map.h b/core/map.h index 6b9dff51de..621b6c2842 100644 --- a/core/map.h +++ b/core/map.h @@ -51,12 +51,12 @@ public: private: friend class Map<K, V, C, A>; - int color; - Element *right; - Element *left; - Element *parent; - Element *_next; - Element *_prev; + int color = RED; + Element *right = nullptr; + Element *left = nullptr; + Element *parent = nullptr; + Element *_next = nullptr; + Element *_prev = nullptr; K _key; V _value; //_Data *data; @@ -93,22 +93,15 @@ public: const V &get() const { return _value; }; - Element() { - color = RED; - right = nullptr; - left = nullptr; - parent = nullptr; - _next = nullptr; - _prev = nullptr; - }; + Element() {} }; private: struct _Data { - Element *_root; + Element *_root = nullptr; Element *_nil; - int size_cache; + int size_cache = 0; _FORCE_INLINE_ _Data() { #ifdef GLOBALNIL_DISABLED @@ -118,8 +111,6 @@ private: #else _nil = (Element *)&_GlobalNilClass::_nil; #endif - _root = nullptr; - size_cache = 0; } void _create_root() { @@ -673,8 +664,7 @@ public: _copy_from(p_map); } - _FORCE_INLINE_ Map() { - } + _FORCE_INLINE_ Map() {} ~Map() { diff --git a/core/math/a_star.cpp b/core/math/a_star.cpp index 3e3e6c50a7..d6d6101402 100644 --- a/core/math/a_star.cpp +++ b/core/math/a_star.cpp @@ -164,7 +164,8 @@ void AStar::connect_points(int p_id, int p_with_id, bool bidirectional) { } Segment s(p_id, p_with_id); - if (bidirectional) s.direction = Segment::BIDIRECTIONAL; + if (bidirectional) + s.direction = Segment::BIDIRECTIONAL; Set<Segment>::Element *element = segments.find(s); if (element != nullptr) { @@ -290,7 +291,8 @@ int AStar::get_closest_point(const Vector3 &p_point, bool p_include_disabled) co for (OAHashMap<int, Point *>::Iterator it = points.iter(); it.valid; it = points.next_iter(it)) { - if (!p_include_disabled && !(*it.value)->enabled) continue; // Disabled points should not be considered. + if (!p_include_disabled && !(*it.value)->enabled) + continue; // Disabled points should not be considered. real_t d = p_point.distance_squared_to((*it.value)->pos); if (closest_id < 0 || d < closest_dist) { @@ -340,7 +342,8 @@ bool AStar::_solve(Point *begin_point, Point *end_point) { pass++; - if (!end_point->enabled) return false; + if (!end_point->enabled) + return false; bool found_route = false; @@ -451,7 +454,8 @@ Vector<Vector3> AStar::get_point_path(int p_from_id, int p_to_id) { Point *end_point = b; bool found_route = _solve(begin_point, end_point); - if (!found_route) return Vector<Vector3>(); + if (!found_route) + return Vector<Vector3>(); Point *p = end_point; int pc = 1; // Begin point @@ -499,7 +503,8 @@ Vector<int> AStar::get_id_path(int p_from_id, int p_to_id) { Point *end_point = b; bool found_route = _solve(begin_point, end_point); - if (!found_route) return Vector<int>(); + if (!found_route) + return Vector<int>(); Point *p = end_point; int pc = 1; // Begin point @@ -580,11 +585,6 @@ void AStar::_bind_methods() { BIND_VMETHOD(MethodInfo(Variant::FLOAT, "_compute_cost", PropertyInfo(Variant::INT, "from_id"), PropertyInfo(Variant::INT, "to_id"))); } -AStar::AStar() { - last_free_id = 0; - pass = 1; -} - AStar::~AStar() { clear(); } @@ -729,7 +729,8 @@ Vector<Vector2> AStar2D::get_point_path(int p_from_id, int p_to_id) { AStar::Point *end_point = b; bool found_route = _solve(begin_point, end_point); - if (!found_route) return Vector<Vector2>(); + if (!found_route) + return Vector<Vector2>(); AStar::Point *p = end_point; int pc = 1; // Begin point @@ -777,7 +778,8 @@ Vector<int> AStar2D::get_id_path(int p_from_id, int p_to_id) { AStar::Point *end_point = b; bool found_route = _solve(begin_point, end_point); - if (!found_route) return Vector<int>(); + if (!found_route) + return Vector<int>(); AStar::Point *p = end_point; int pc = 1; // Begin point @@ -809,7 +811,8 @@ bool AStar2D::_solve(AStar::Point *begin_point, AStar::Point *end_point) { astar.pass++; - if (!end_point->enabled) return false; + if (!end_point->enabled) + return false; bool found_route = false; @@ -902,9 +905,3 @@ void AStar2D::_bind_methods() { BIND_VMETHOD(MethodInfo(Variant::FLOAT, "_estimate_cost", PropertyInfo(Variant::INT, "from_id"), PropertyInfo(Variant::INT, "to_id"))); BIND_VMETHOD(MethodInfo(Variant::FLOAT, "_compute_cost", PropertyInfo(Variant::INT, "from_id"), PropertyInfo(Variant::INT, "to_id"))); } - -AStar2D::AStar2D() { -} - -AStar2D::~AStar2D() { -} diff --git a/core/math/a_star.h b/core/math/a_star.h index 8c10ace33c..ffb437ee04 100644 --- a/core/math/a_star.h +++ b/core/math/a_star.h @@ -47,17 +47,15 @@ class AStar : public Reference { struct Point { - Point() : - neighbours(4u), - unlinked_neighbours(4u) {} + Point() {} int id; Vector3 pos; real_t weight_scale; bool enabled; - OAHashMap<int, Point *> neighbours; - OAHashMap<int, Point *> unlinked_neighbours; + OAHashMap<int, Point *> neighbours = 4u; + OAHashMap<int, Point *> unlinked_neighbours = 4u; // Used for pathfinding. Point *prev_point; @@ -85,7 +83,7 @@ class AStar : public Reference { int32_t u; int32_t v; }; - uint64_t key; + uint64_t key = 0; }; enum { @@ -94,13 +92,11 @@ class AStar : public Reference { BACKWARD = 2, BIDIRECTIONAL = FORWARD | BACKWARD }; - unsigned char direction; + unsigned char direction = NONE; bool operator<(const Segment &p_s) const { return key < p_s.key; } - Segment() { - key = 0; - direction = NONE; - } + + Segment() {} Segment(int p_from, int p_to) { if (p_from < p_to) { u = p_from; @@ -114,8 +110,8 @@ class AStar : public Reference { } }; - int last_free_id; - uint64_t pass; + int last_free_id = 0; + uint64_t pass = 1; OAHashMap<int, Point *> points; Set<Segment> segments; @@ -159,7 +155,7 @@ public: Vector<Vector3> get_point_path(int p_from_id, int p_to_id); Vector<int> get_id_path(int p_from_id, int p_to_id); - AStar(); + AStar() {} ~AStar(); }; @@ -206,8 +202,8 @@ public: Vector<Vector2> get_point_path(int p_from_id, int p_to_id); Vector<int> get_id_path(int p_from_id, int p_to_id); - AStar2D(); - ~AStar2D(); + AStar2D() {} + ~AStar2D() {} }; #endif // A_STAR_H diff --git a/core/math/aabb.h b/core/math/aabb.h index 7fdad07c89..f87fced12d 100644 --- a/core/math/aabb.h +++ b/core/math/aabb.h @@ -177,14 +177,22 @@ Vector3 AABB::get_support(const Vector3 &p_normal) const { Vector3 AABB::get_endpoint(int p_point) const { switch (p_point) { - case 0: return Vector3(position.x, position.y, position.z); - case 1: return Vector3(position.x, position.y, position.z + size.z); - case 2: return Vector3(position.x, position.y + size.y, position.z); - case 3: return Vector3(position.x, position.y + size.y, position.z + size.z); - case 4: return Vector3(position.x + size.x, position.y, position.z); - case 5: return Vector3(position.x + size.x, position.y, position.z + size.z); - case 6: return Vector3(position.x + size.x, position.y + size.y, position.z); - case 7: return Vector3(position.x + size.x, position.y + size.y, position.z + size.z); + case 0: + return Vector3(position.x, position.y, position.z); + case 1: + return Vector3(position.x, position.y, position.z + size.z); + case 2: + return Vector3(position.x, position.y + size.y, position.z); + case 3: + return Vector3(position.x, position.y + size.y, position.z + size.z); + case 4: + return Vector3(position.x + size.x, position.y, position.z); + case 5: + return Vector3(position.x + size.x, position.y, position.z + size.z); + case 6: + return Vector3(position.x + size.x, position.y + size.y, position.z); + case 7: + return Vector3(position.x + size.x, position.y + size.y, position.z + size.z); }; ERR_FAIL_V(Vector3()); diff --git a/core/math/basis.cpp b/core/math/basis.cpp index 0f519a20d8..6218b7e248 100644 --- a/core/math/basis.cpp +++ b/core/math/basis.cpp @@ -783,7 +783,8 @@ void Basis::get_axis_angle(Vector3 &r_axis, real_t &r_angle) const { real_t s = Math::sqrt((elements[1][2] - elements[2][1]) * (elements[1][2] - elements[2][1]) + (elements[2][0] - elements[0][2]) * (elements[2][0] - elements[0][2]) + (elements[0][1] - elements[1][0]) * (elements[0][1] - elements[1][0])); // s=|axis||sin(angle)|, used to normalise angle = Math::acos((elements[0][0] + elements[1][1] + elements[2][2] - 1) / 2); - if (angle < 0) s = -s; + if (angle < 0) + s = -s; x = (elements[2][1] - elements[1][2]) / s; y = (elements[0][2] - elements[2][0]) / s; z = (elements[1][0] - elements[0][1]) / s; @@ -877,3 +878,114 @@ Basis Basis::slerp(const Basis &target, const real_t &t) const { return b; } + +void Basis::rotate_sh(real_t *p_values) { + + // code by John Hable + // http://filmicworlds.com/blog/simple-and-fast-spherical-harmonic-rotation/ + // this code is Public Domain + + const static real_t s_c3 = 0.94617469575; // (3*sqrt(5))/(4*sqrt(pi)) + const static real_t s_c4 = -0.31539156525; // (-sqrt(5))/(4*sqrt(pi)) + const static real_t s_c5 = 0.54627421529; // (sqrt(15))/(4*sqrt(pi)) + + const static real_t s_c_scale = 1.0 / 0.91529123286551084; + const static real_t s_c_scale_inv = 0.91529123286551084; + + const static real_t s_rc2 = 1.5853309190550713 * s_c_scale; + const static real_t s_c4_div_c3 = s_c4 / s_c3; + const static real_t s_c4_div_c3_x2 = (s_c4 / s_c3) * 2.0; + + const static real_t s_scale_dst2 = s_c3 * s_c_scale_inv; + const static real_t s_scale_dst4 = s_c5 * s_c_scale_inv; + + real_t src[9] = { p_values[0], p_values[1], p_values[2], p_values[3], p_values[4], p_values[5], p_values[6], p_values[7], p_values[8] }; + + real_t m00 = elements[0][0]; + real_t m01 = elements[0][1]; + real_t m02 = elements[0][2]; + real_t m10 = elements[1][0]; + real_t m11 = elements[1][1]; + real_t m12 = elements[1][2]; + real_t m20 = elements[2][0]; + real_t m21 = elements[2][1]; + real_t m22 = elements[2][2]; + + p_values[0] = src[0]; + p_values[1] = m11 * src[1] - m12 * src[2] + m10 * src[3]; + p_values[2] = -m21 * src[1] + m22 * src[2] - m20 * src[3]; + p_values[3] = m01 * src[1] - m02 * src[2] + m00 * src[3]; + + real_t sh0 = src[7] + src[8] + src[8] - src[5]; + real_t sh1 = src[4] + s_rc2 * src[6] + src[7] + src[8]; + real_t sh2 = src[4]; + real_t sh3 = -src[7]; + real_t sh4 = -src[5]; + + // Rotations. R0 and R1 just use the raw matrix columns + real_t r2x = m00 + m01; + real_t r2y = m10 + m11; + real_t r2z = m20 + m21; + + real_t r3x = m00 + m02; + real_t r3y = m10 + m12; + real_t r3z = m20 + m22; + + real_t r4x = m01 + m02; + real_t r4y = m11 + m12; + real_t r4z = m21 + m22; + + // dense matrix multiplication one column at a time + + // column 0 + real_t sh0_x = sh0 * m00; + real_t sh0_y = sh0 * m10; + real_t d0 = sh0_x * m10; + real_t d1 = sh0_y * m20; + real_t d2 = sh0 * (m20 * m20 + s_c4_div_c3); + real_t d3 = sh0_x * m20; + real_t d4 = sh0_x * m00 - sh0_y * m10; + + // column 1 + real_t sh1_x = sh1 * m02; + real_t sh1_y = sh1 * m12; + d0 += sh1_x * m12; + d1 += sh1_y * m22; + d2 += sh1 * (m22 * m22 + s_c4_div_c3); + d3 += sh1_x * m22; + d4 += sh1_x * m02 - sh1_y * m12; + + // column 2 + real_t sh2_x = sh2 * r2x; + real_t sh2_y = sh2 * r2y; + d0 += sh2_x * r2y; + d1 += sh2_y * r2z; + d2 += sh2 * (r2z * r2z + s_c4_div_c3_x2); + d3 += sh2_x * r2z; + d4 += sh2_x * r2x - sh2_y * r2y; + + // column 3 + real_t sh3_x = sh3 * r3x; + real_t sh3_y = sh3 * r3y; + d0 += sh3_x * r3y; + d1 += sh3_y * r3z; + d2 += sh3 * (r3z * r3z + s_c4_div_c3_x2); + d3 += sh3_x * r3z; + d4 += sh3_x * r3x - sh3_y * r3y; + + // column 4 + real_t sh4_x = sh4 * r4x; + real_t sh4_y = sh4 * r4y; + d0 += sh4_x * r4y; + d1 += sh4_y * r4z; + d2 += sh4 * (r4z * r4z + s_c4_div_c3_x2); + d3 += sh4_x * r4z; + d4 += sh4_x * r4x - sh4_y * r4y; + + // extra multipliers + p_values[4] = d0; + p_values[5] = -d1; + p_values[6] = d2 * s_scale_dst2; + p_values[7] = -d3; + p_values[8] = d4 * s_scale_dst4; +} diff --git a/core/math/basis.h b/core/math/basis.h index 0261cf67c6..2924a0ddbd 100644 --- a/core/math/basis.h +++ b/core/math/basis.h @@ -159,6 +159,7 @@ public: bool is_rotation() const; Basis slerp(const Basis &target, const real_t &t) const; + void rotate_sh(real_t *p_values); operator String() const; diff --git a/core/math/camera_matrix.cpp b/core/math/camera_matrix.cpp index c36070e47f..5d3ebc9f6d 100644 --- a/core/math/camera_matrix.cpp +++ b/core/math/camera_matrix.cpp @@ -33,6 +33,22 @@ #include "core/math/math_funcs.h" #include "core/print_string.h" +float CameraMatrix::determinant() const { + + return matrix[0][3] * matrix[1][2] * matrix[2][1] * matrix[3][0] - matrix[0][2] * matrix[1][3] * matrix[2][1] * matrix[3][0] - + matrix[0][3] * matrix[1][1] * matrix[2][2] * matrix[3][0] + matrix[0][1] * matrix[1][3] * matrix[2][2] * matrix[3][0] + + matrix[0][2] * matrix[1][1] * matrix[2][3] * matrix[3][0] - matrix[0][1] * matrix[1][2] * matrix[2][3] * matrix[3][0] - + matrix[0][3] * matrix[1][2] * matrix[2][0] * matrix[3][1] + matrix[0][2] * matrix[1][3] * matrix[2][0] * matrix[3][1] + + matrix[0][3] * matrix[1][0] * matrix[2][2] * matrix[3][1] - matrix[0][0] * matrix[1][3] * matrix[2][2] * matrix[3][1] - + matrix[0][2] * matrix[1][0] * matrix[2][3] * matrix[3][1] + matrix[0][0] * matrix[1][2] * matrix[2][3] * matrix[3][1] + + matrix[0][3] * matrix[1][1] * matrix[2][0] * matrix[3][2] - matrix[0][1] * matrix[1][3] * matrix[2][0] * matrix[3][2] - + matrix[0][3] * matrix[1][0] * matrix[2][1] * matrix[3][2] + matrix[0][0] * matrix[1][3] * matrix[2][1] * matrix[3][2] + + matrix[0][1] * matrix[1][0] * matrix[2][3] * matrix[3][2] - matrix[0][0] * matrix[1][1] * matrix[2][3] * matrix[3][2] - + matrix[0][2] * matrix[1][1] * matrix[2][0] * matrix[3][3] + matrix[0][1] * matrix[1][2] * matrix[2][0] * matrix[3][3] + + matrix[0][2] * matrix[1][0] * matrix[2][1] * matrix[3][3] - matrix[0][0] * matrix[1][2] * matrix[2][1] * matrix[3][3] - + matrix[0][1] * matrix[1][0] * matrix[2][2] * matrix[3][3] + matrix[0][0] * matrix[1][1] * matrix[2][2] * matrix[3][3]; +} + void CameraMatrix::set_identity() { for (int i = 0; i < 4; i++) { @@ -473,20 +489,23 @@ void CameraMatrix::invert() { /** Divide column by minus pivot value **/ for (i = 0; i < 4; i++) { - if (i != k) matrix[i][k] /= (-pvt_val); + if (i != k) + matrix[i][k] /= (-pvt_val); } /** Reduce the matrix **/ for (i = 0; i < 4; i++) { hold = matrix[i][k]; for (j = 0; j < 4; j++) { - if (i != k && j != k) matrix[i][j] += hold * matrix[k][j]; + if (i != k && j != k) + matrix[i][j] += hold * matrix[k][j]; } } /** Divide row by pivot **/ for (j = 0; j < 4; j++) { - if (j != k) matrix[k][j] /= pvt_val; + if (j != k) + matrix[k][j] /= pvt_val; } /** Replace pivot by reciprocal (at last we can touch it). **/ diff --git a/core/math/camera_matrix.h b/core/math/camera_matrix.h index c10193bc84..5420fa2984 100644 --- a/core/math/camera_matrix.h +++ b/core/math/camera_matrix.h @@ -47,6 +47,7 @@ struct CameraMatrix { real_t matrix[4][4]; + float determinant() const; void set_identity(); void set_zero(); void set_light_bias(); diff --git a/core/math/delaunay.h b/core/math/delaunay_2d.h index 29f84210d2..66b2f8f573 100644 --- a/core/math/delaunay.h +++ b/core/math/delaunay_2d.h @@ -1,5 +1,5 @@ /*************************************************************************/ -/* delaunay.h */ +/* delaunay_2d.h */ /*************************************************************************/ /* This file is part of: */ /* GODOT ENGINE */ @@ -28,32 +28,29 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#ifndef DELAUNAY_H -#define DELAUNAY_H +#ifndef DELAUNAY_2D_H +#define DELAUNAY_2D_H #include "core/math/rect2.h" class Delaunay2D { public: struct Triangle { - int points[3]; - bool bad; - Triangle() { bad = false; } + bool bad = false; + Triangle() {} Triangle(int p_a, int p_b, int p_c) { points[0] = p_a; points[1] = p_b; points[2] = p_c; - bad = false; } }; struct Edge { int edge[2]; - bool bad; - Edge() { bad = false; } + bool bad = false; + Edge() {} Edge(int p_a, int p_b) { - bad = false; edge[0] = p_a; edge[1] = p_b; } @@ -115,8 +112,6 @@ public: triangles.push_back(Triangle(p_points.size() + 0, p_points.size() + 1, p_points.size() + 2)); for (int i = 0; i < p_points.size(); i++) { - //std::cout << "Traitement du point " << *p << std::endl; - //std::cout << "_triangles contains " << _triangles.size() << " elements" << std::endl; Vector<Edge> polygon; @@ -172,4 +167,4 @@ public: } }; -#endif // DELAUNAY_H +#endif // DELAUNAY_2D_H diff --git a/core/math/delaunay_3d.h b/core/math/delaunay_3d.h new file mode 100644 index 0000000000..ab0993abc9 --- /dev/null +++ b/core/math/delaunay_3d.h @@ -0,0 +1,383 @@ +#ifndef DELAUNAY_3D_H +#define DELAUNAY_3D_H + +#include "core/local_vector.h" +#include "core/math/aabb.h" +#include "core/math/camera_matrix.h" +#include "core/math/vector3.h" +#include "core/oa_hash_map.h" +#include "core/os/file_access.h" +#include "core/print_string.h" +#include "core/variant.h" +#include "core/vector.h" + +#include "thirdparty/misc/r128.h" + +class Delaunay3D { + struct Simplex; + + enum { + ACCEL_GRID_SIZE = 16 + }; + struct GridPos { + Vector3i pos; + List<Simplex *>::Element *E = nullptr; + }; + + struct Simplex { + + uint32_t points[4]; + R128 circum_center_x; + R128 circum_center_y; + R128 circum_center_z; + R128 circum_r2; + LocalVector<GridPos> grid_positions; + List<Simplex *>::Element *SE = nullptr; + + _FORCE_INLINE_ Simplex() {} + _FORCE_INLINE_ Simplex(uint32_t p_a, uint32_t p_b, uint32_t p_c, uint32_t p_d) { + points[0] = p_a; + points[1] = p_b; + points[2] = p_c; + points[3] = p_d; + } + }; + + struct Triangle { + uint32_t triangle[3]; + bool bad = false; + _FORCE_INLINE_ bool operator==(const Triangle &p_triangle) const { + return triangle[0] == p_triangle.triangle[0] && triangle[1] == p_triangle.triangle[1] && triangle[2] == p_triangle.triangle[2]; + } + + _FORCE_INLINE_ Triangle() {} + _FORCE_INLINE_ Triangle(uint32_t p_a, uint32_t p_b, uint32_t p_c) { + if (p_a > p_b) + SWAP(p_a, p_b); + if (p_b > p_c) + SWAP(p_b, p_c); + if (p_a > p_b) + SWAP(p_a, p_b); + + triangle[0] = p_a; + triangle[1] = p_b; + triangle[2] = p_c; + } + }; + + struct TriangleHasher { + _FORCE_INLINE_ static uint32_t hash(const Triangle &p_triangle) { + uint32_t h = hash_djb2_one_32(p_triangle.triangle[0]); + h = hash_djb2_one_32(p_triangle.triangle[1], h); + return hash_djb2_one_32(p_triangle.triangle[2], h); + } + }; + + _FORCE_INLINE_ static void circum_sphere_compute(const Vector3 *p_points, Simplex *p_simplex) { + + // the only part in the algorithm where there may be precision errors is this one, so ensure that + // we do it as maximum precision as possible + + R128 v0_x = p_points[p_simplex->points[0]].x; + R128 v0_y = p_points[p_simplex->points[0]].y; + R128 v0_z = p_points[p_simplex->points[0]].z; + R128 v1_x = p_points[p_simplex->points[1]].x; + R128 v1_y = p_points[p_simplex->points[1]].y; + R128 v1_z = p_points[p_simplex->points[1]].z; + R128 v2_x = p_points[p_simplex->points[2]].x; + R128 v2_y = p_points[p_simplex->points[2]].y; + R128 v2_z = p_points[p_simplex->points[2]].z; + R128 v3_x = p_points[p_simplex->points[3]].x; + R128 v3_y = p_points[p_simplex->points[3]].y; + R128 v3_z = p_points[p_simplex->points[3]].z; + + //Create the rows of our "unrolled" 3x3 matrix + R128 row1_x = v1_x - v0_x; + R128 row1_y = v1_y - v0_y; + R128 row1_z = v1_z - v0_z; + + R128 row2_x = v2_x - v0_x; + R128 row2_y = v2_y - v0_y; + R128 row2_z = v2_z - v0_z; + + R128 row3_x = v3_x - v0_x; + R128 row3_y = v3_y - v0_y; + R128 row3_z = v3_z - v0_z; + + R128 sq_lenght1 = row1_x * row1_x + row1_y * row1_y + row1_z * row1_z; + R128 sq_lenght2 = row2_x * row2_x + row2_y * row2_y + row2_z * row2_z; + R128 sq_lenght3 = row3_x * row3_x + row3_y * row3_y + row3_z * row3_z; + + //Compute the determinant of said matrix + R128 determinant = row1_x * (row2_y * row3_z - row3_y * row2_z) - row2_x * (row1_y * row3_z - row3_y * row1_z) + row3_x * (row1_y * row2_z - row2_y * row1_z); + + // Compute the volume of the tetrahedron, and precompute a scalar quantity for re-use in the formula + R128 volume = determinant / R128(6.f); + R128 i12volume = R128(1.f) / (volume * R128(12.f)); + + R128 center_x = v0_x + i12volume * ((row2_y * row3_z - row3_y * row2_z) * sq_lenght1 - (row1_y * row3_z - row3_y * row1_z) * sq_lenght2 + (row1_y * row2_z - row2_y * row1_z) * sq_lenght3); + R128 center_y = v0_y + i12volume * (-(row2_x * row3_z - row3_x * row2_z) * sq_lenght1 + (row1_x * row3_z - row3_x * row1_z) * sq_lenght2 - (row1_x * row2_z - row2_x * row1_z) * sq_lenght3); + R128 center_z = v0_z + i12volume * ((row2_x * row3_y - row3_x * row2_y) * sq_lenght1 - (row1_x * row3_y - row3_x * row1_y) * sq_lenght2 + (row1_x * row2_y - row2_x * row1_y) * sq_lenght3); + + //Once we know the center, the radius is clearly the distance to any vertex + + R128 rel1_x = center_x - v0_x; + R128 rel1_y = center_y - v0_y; + R128 rel1_z = center_z - v0_z; + + R128 radius1 = rel1_x * rel1_x + rel1_y * rel1_y + rel1_z * rel1_z; + + p_simplex->circum_center_x = center_x; + p_simplex->circum_center_y = center_y; + p_simplex->circum_center_z = center_z; + p_simplex->circum_r2 = radius1; + } + + _FORCE_INLINE_ static bool simplex_contains(const Vector3 *p_points, const Simplex &p_simplex, uint32_t p_vertex) { + + R128 v_x = p_points[p_vertex].x; + R128 v_y = p_points[p_vertex].y; + R128 v_z = p_points[p_vertex].z; + + R128 rel2_x = p_simplex.circum_center_x - v_x; + R128 rel2_y = p_simplex.circum_center_y - v_y; + R128 rel2_z = p_simplex.circum_center_z - v_z; + + R128 radius2 = rel2_x * rel2_x + rel2_y * rel2_y + rel2_z * rel2_z; + + return radius2 < (p_simplex.circum_r2 - R128(0.00001)); + } + + static bool simplex_is_coplanar(const Vector3 *p_points, const Simplex &p_simplex) { + + Plane p(p_points[p_simplex.points[0]], p_points[p_simplex.points[1]], p_points[p_simplex.points[2]]); + if (ABS(p.distance_to(p_points[p_simplex.points[3]])) < CMP_EPSILON) { + return true; + } + + CameraMatrix cm; + + cm.matrix[0][0] = p_points[p_simplex.points[0]].x; + cm.matrix[0][1] = p_points[p_simplex.points[1]].x; + cm.matrix[0][2] = p_points[p_simplex.points[2]].x; + cm.matrix[0][3] = p_points[p_simplex.points[3]].x; + + cm.matrix[1][0] = p_points[p_simplex.points[0]].y; + cm.matrix[1][1] = p_points[p_simplex.points[1]].y; + cm.matrix[1][2] = p_points[p_simplex.points[2]].y; + cm.matrix[1][3] = p_points[p_simplex.points[3]].y; + + cm.matrix[2][0] = p_points[p_simplex.points[0]].z; + cm.matrix[2][1] = p_points[p_simplex.points[1]].z; + cm.matrix[2][2] = p_points[p_simplex.points[2]].z; + cm.matrix[2][3] = p_points[p_simplex.points[3]].z; + + cm.matrix[3][0] = 1.0; + cm.matrix[3][1] = 1.0; + cm.matrix[3][2] = 1.0; + cm.matrix[3][3] = 1.0; + + return ABS(cm.determinant()) <= CMP_EPSILON; + } + +public: + struct OutputSimplex { + uint32_t points[4]; + }; + + static Vector<OutputSimplex> tetrahedralize(const Vector<Vector3> &p_points) { + + uint32_t point_count = p_points.size(); + Vector3 *points = (Vector3 *)memalloc(sizeof(Vector3) * (point_count + 4)); + + { + const Vector3 *src_points = p_points.ptr(); + AABB rect; + for (uint32_t i = 0; i < point_count; i++) { + Vector3 point = src_points[i]; + if (i == 0) { + rect.position = point; + } else { + rect.expand_to(point); + } + points[i] = point; + } + + for (uint32_t i = 0; i < point_count; i++) { + points[i] = (points[i] - rect.position) / rect.size; + } + + float delta_max = Math::sqrt(2.0) * 20.0; + Vector3 center = Vector3(0.5, 0.5, 0.5); + + // any simplex that contains everything is good + points[point_count + 0] = center + Vector3(0, 1, 0) * delta_max; + points[point_count + 1] = center + Vector3(0, -1, 1) * delta_max; + points[point_count + 2] = center + Vector3(1, -1, -1) * delta_max; + points[point_count + 3] = center + Vector3(-1, -1, -1) * delta_max; + } + + List<Simplex *> acceleration_grid[ACCEL_GRID_SIZE][ACCEL_GRID_SIZE][ACCEL_GRID_SIZE]; + + List<Simplex *> simplex_list; + { + //create root simplex + Simplex *root = memnew(Simplex(point_count + 0, point_count + 1, point_count + 2, point_count + 3)); + root->SE = simplex_list.push_back(root); + + for (uint32_t i = 0; i < ACCEL_GRID_SIZE; i++) { + for (uint32_t j = 0; j < ACCEL_GRID_SIZE; j++) { + for (uint32_t k = 0; k < ACCEL_GRID_SIZE; k++) { + GridPos gp; + gp.E = acceleration_grid[i][j][k].push_back(root); + gp.pos = Vector3i(i, j, k); + root->grid_positions.push_back(gp); + } + } + } + + circum_sphere_compute(points, root); + } + + OAHashMap<Triangle, uint32_t, TriangleHasher> triangles_inserted; + LocalVector<Triangle> triangles; + + for (uint32_t i = 0; i < point_count; i++) { + + bool unique = true; + for (uint32_t j = i + 1; j < point_count; j++) { + if (points[i].is_equal_approx(points[j])) { + unique = false; + break; + } + } + if (!unique) { + continue; + } + + Vector3i grid_pos = Vector3i(points[i] * ACCEL_GRID_SIZE); + grid_pos.x = CLAMP(grid_pos.x, 0, ACCEL_GRID_SIZE - 1); + grid_pos.y = CLAMP(grid_pos.y, 0, ACCEL_GRID_SIZE - 1); + grid_pos.z = CLAMP(grid_pos.z, 0, ACCEL_GRID_SIZE - 1); + + for (List<Simplex *>::Element *E = acceleration_grid[grid_pos.x][grid_pos.y][grid_pos.z].front(); E;) { + List<Simplex *>::Element *N = E->next(); //may be deleted + + Simplex *simplex = E->get(); + + if (simplex_contains(points, *simplex, i)) { + + static const uint32_t triangle_order[4][3] = { + { 0, 1, 2 }, + { 0, 1, 3 }, + { 0, 2, 3 }, + { 1, 2, 3 }, + }; + + for (uint32_t k = 0; k < 4; k++) { + Triangle t = Triangle(simplex->points[triangle_order[k][0]], simplex->points[triangle_order[k][1]], simplex->points[triangle_order[k][2]]); + uint32_t *p = triangles_inserted.lookup_ptr(t); + if (p) { + triangles[*p].bad = true; + } else { + triangles_inserted.insert(t, triangles.size()); + triangles.push_back(t); + } + } + + //remove simplex and continue + simplex_list.erase(simplex->SE); + + for (uint32_t k = 0; k < simplex->grid_positions.size(); k++) { + Vector3i p = simplex->grid_positions[k].pos; + acceleration_grid[p.x][p.y][p.z].erase(simplex->grid_positions[k].E); + } + memdelete(simplex); + } + E = N; + } + + uint32_t good_triangles = 0; + for (uint32_t j = 0; j < triangles.size(); j++) { + + if (triangles[j].bad) { + continue; + } + Simplex *new_simplex = memnew(Simplex(triangles[j].triangle[0], triangles[j].triangle[1], triangles[j].triangle[2], i)); + circum_sphere_compute(points, new_simplex); + new_simplex->SE = simplex_list.push_back(new_simplex); + { + Vector3 center; + center.x = double(new_simplex->circum_center_x); + center.y = double(new_simplex->circum_center_y); + center.z = double(new_simplex->circum_center_z); + + float radius2 = Math::sqrt(double(new_simplex->circum_r2)); + radius2 += 0.0001; // + Vector3 extents = Vector3(radius2, radius2, radius2); + Vector3i from = Vector3i((center - extents) * ACCEL_GRID_SIZE); + Vector3i to = Vector3i((center + extents) * ACCEL_GRID_SIZE); + from.x = CLAMP(from.x, 0, ACCEL_GRID_SIZE - 1); + from.y = CLAMP(from.y, 0, ACCEL_GRID_SIZE - 1); + from.z = CLAMP(from.z, 0, ACCEL_GRID_SIZE - 1); + to.x = CLAMP(to.x, 0, ACCEL_GRID_SIZE - 1); + to.y = CLAMP(to.y, 0, ACCEL_GRID_SIZE - 1); + to.z = CLAMP(to.z, 0, ACCEL_GRID_SIZE - 1); + + for (int32_t x = from.x; x <= to.x; x++) { + for (int32_t y = from.y; y <= to.y; y++) { + for (int32_t z = from.z; z <= to.z; z++) { + GridPos gp; + gp.pos = Vector3(x, y, z); + gp.E = acceleration_grid[x][y][z].push_back(new_simplex); + new_simplex->grid_positions.push_back(gp); + } + } + } + } + + good_triangles++; + } + + //print_line("at point " + itos(i) + "/" + itos(point_count) + " simplices added " + itos(good_triangles) + "/" + itos(simplex_list.size()) + " - triangles: " + itos(triangles.size())); + triangles.clear(); + triangles_inserted.clear(); + } + + //print_line("end with simplices: " + itos(simplex_list.size())); + Vector<OutputSimplex> ret_simplices; + ret_simplices.resize(simplex_list.size()); + OutputSimplex *ret_simplicesw = ret_simplices.ptrw(); + uint32_t simplices_written = 0; + + for (List<Simplex *>::Element *E = simplex_list.front(); E; E = E->next()) { + Simplex *simplex = E->get(); + bool invalid = false; + for (int j = 0; j < 4; j++) { + if (simplex->points[j] >= point_count) { + invalid = true; + break; + } + } + if (invalid || simplex_is_coplanar(points, *simplex)) { + memdelete(simplex); + continue; + } + + ret_simplicesw[simplices_written].points[0] = simplex->points[0]; + ret_simplicesw[simplices_written].points[1] = simplex->points[1]; + ret_simplicesw[simplices_written].points[2] = simplex->points[2]; + ret_simplicesw[simplices_written].points[3] = simplex->points[3]; + simplices_written++; + memdelete(simplex); + } + + ret_simplices.resize(simplices_written); + + memfree(points); + + return ret_simplices; + } +}; + +#endif // DELAUNAY_3D_H diff --git a/core/math/expression.cpp b/core/math/expression.cpp index 859b9be8c5..c43831ddee 100644 --- a/core/math/expression.cpp +++ b/core/math/expression.cpp @@ -1023,11 +1023,21 @@ Error Expression::_get_token(Token &r_token) { switch (next) { - case 'b': res = 8; break; - case 't': res = 9; break; - case 'n': res = 10; break; - case 'f': res = 12; break; - case 'r': res = 13; break; + case 'b': + res = 8; + break; + case 't': + res = 9; + break; + case 'n': + res = 10; + break; + case 'f': + res = 12; + break; + case 'r': + res = 13; + break; case 'u': { // hex number for (int j = 0; j < 4; j++) { @@ -1703,27 +1713,69 @@ Expression::ENode *Expression::_parse_expression() { Variant::Operator op = Variant::OP_MAX; switch (tk.type) { - case TK_OP_IN: op = Variant::OP_IN; break; - case TK_OP_EQUAL: op = Variant::OP_EQUAL; break; - case TK_OP_NOT_EQUAL: op = Variant::OP_NOT_EQUAL; break; - case TK_OP_LESS: op = Variant::OP_LESS; break; - case TK_OP_LESS_EQUAL: op = Variant::OP_LESS_EQUAL; break; - case TK_OP_GREATER: op = Variant::OP_GREATER; break; - case TK_OP_GREATER_EQUAL: op = Variant::OP_GREATER_EQUAL; break; - case TK_OP_AND: op = Variant::OP_AND; break; - case TK_OP_OR: op = Variant::OP_OR; break; - case TK_OP_NOT: op = Variant::OP_NOT; break; - case TK_OP_ADD: op = Variant::OP_ADD; break; - case TK_OP_SUB: op = Variant::OP_SUBTRACT; break; - case TK_OP_MUL: op = Variant::OP_MULTIPLY; break; - case TK_OP_DIV: op = Variant::OP_DIVIDE; break; - case TK_OP_MOD: op = Variant::OP_MODULE; break; - case TK_OP_SHIFT_LEFT: op = Variant::OP_SHIFT_LEFT; break; - case TK_OP_SHIFT_RIGHT: op = Variant::OP_SHIFT_RIGHT; break; - case TK_OP_BIT_AND: op = Variant::OP_BIT_AND; break; - case TK_OP_BIT_OR: op = Variant::OP_BIT_OR; break; - case TK_OP_BIT_XOR: op = Variant::OP_BIT_XOR; break; - case TK_OP_BIT_INVERT: op = Variant::OP_BIT_NEGATE; break; + case TK_OP_IN: + op = Variant::OP_IN; + break; + case TK_OP_EQUAL: + op = Variant::OP_EQUAL; + break; + case TK_OP_NOT_EQUAL: + op = Variant::OP_NOT_EQUAL; + break; + case TK_OP_LESS: + op = Variant::OP_LESS; + break; + case TK_OP_LESS_EQUAL: + op = Variant::OP_LESS_EQUAL; + break; + case TK_OP_GREATER: + op = Variant::OP_GREATER; + break; + case TK_OP_GREATER_EQUAL: + op = Variant::OP_GREATER_EQUAL; + break; + case TK_OP_AND: + op = Variant::OP_AND; + break; + case TK_OP_OR: + op = Variant::OP_OR; + break; + case TK_OP_NOT: + op = Variant::OP_NOT; + break; + case TK_OP_ADD: + op = Variant::OP_ADD; + break; + case TK_OP_SUB: + op = Variant::OP_SUBTRACT; + break; + case TK_OP_MUL: + op = Variant::OP_MULTIPLY; + break; + case TK_OP_DIV: + op = Variant::OP_DIVIDE; + break; + case TK_OP_MOD: + op = Variant::OP_MODULE; + break; + case TK_OP_SHIFT_LEFT: + op = Variant::OP_SHIFT_LEFT; + break; + case TK_OP_SHIFT_RIGHT: + op = Variant::OP_SHIFT_RIGHT; + break; + case TK_OP_BIT_AND: + op = Variant::OP_BIT_AND; + break; + case TK_OP_BIT_OR: + op = Variant::OP_BIT_OR; + break; + case TK_OP_BIT_XOR: + op = Variant::OP_BIT_XOR; + break; + case TK_OP_BIT_INVERT: + op = Variant::OP_BIT_NEGATE; + break; default: { }; } @@ -1772,36 +1824,74 @@ Expression::ENode *Expression::_parse_expression() { unary = true; break; - case Variant::OP_MULTIPLY: priority = 2; break; - case Variant::OP_DIVIDE: priority = 2; break; - case Variant::OP_MODULE: priority = 2; break; + case Variant::OP_MULTIPLY: + priority = 2; + break; + case Variant::OP_DIVIDE: + priority = 2; + break; + case Variant::OP_MODULE: + priority = 2; + break; - case Variant::OP_ADD: priority = 3; break; - case Variant::OP_SUBTRACT: priority = 3; break; + case Variant::OP_ADD: + priority = 3; + break; + case Variant::OP_SUBTRACT: + priority = 3; + break; - case Variant::OP_SHIFT_LEFT: priority = 4; break; - case Variant::OP_SHIFT_RIGHT: priority = 4; break; + case Variant::OP_SHIFT_LEFT: + priority = 4; + break; + case Variant::OP_SHIFT_RIGHT: + priority = 4; + break; - case Variant::OP_BIT_AND: priority = 5; break; - case Variant::OP_BIT_XOR: priority = 6; break; - case Variant::OP_BIT_OR: priority = 7; break; + case Variant::OP_BIT_AND: + priority = 5; + break; + case Variant::OP_BIT_XOR: + priority = 6; + break; + case Variant::OP_BIT_OR: + priority = 7; + break; - case Variant::OP_LESS: priority = 8; break; - case Variant::OP_LESS_EQUAL: priority = 8; break; - case Variant::OP_GREATER: priority = 8; break; - case Variant::OP_GREATER_EQUAL: priority = 8; break; + case Variant::OP_LESS: + priority = 8; + break; + case Variant::OP_LESS_EQUAL: + priority = 8; + break; + case Variant::OP_GREATER: + priority = 8; + break; + case Variant::OP_GREATER_EQUAL: + priority = 8; + break; - case Variant::OP_EQUAL: priority = 8; break; - case Variant::OP_NOT_EQUAL: priority = 8; break; + case Variant::OP_EQUAL: + priority = 8; + break; + case Variant::OP_NOT_EQUAL: + priority = 8; + break; - case Variant::OP_IN: priority = 10; break; + case Variant::OP_IN: + priority = 10; + break; case Variant::OP_NOT: priority = 11; unary = true; break; - case Variant::OP_AND: priority = 12; break; - case Variant::OP_OR: priority = 13; break; + case Variant::OP_AND: + priority = 12; + break; + case Variant::OP_OR: + priority = 13; + break; default: { _set_error("Parser bug, invalid operator in expression: " + itos(expression[i].op)); @@ -2208,17 +2298,6 @@ void Expression::_bind_methods() { ClassDB::bind_method(D_METHOD("get_error_text"), &Expression::get_error_text); } -Expression::Expression() : - output_type(Variant::NIL), - sequenced(false), - error_set(true), - root(nullptr), - nodes(nullptr), - execution_error(false) { - str_ofs = 0; - expression_dirty = false; -} - Expression::~Expression() { if (nodes) { diff --git a/core/math/expression.h b/core/math/expression.h index 78de225ebf..bf710ecdd5 100644 --- a/core/math/expression.h +++ b/core/math/expression.h @@ -119,22 +119,20 @@ private: struct Input { - Variant::Type type; + Variant::Type type = Variant::NIL; String name; - Input() : - type(Variant::NIL) { - } + Input() {} }; Vector<Input> inputs; - Variant::Type output_type; + Variant::Type output_type = Variant::NIL; String expression; - bool sequenced; - int str_ofs; - bool expression_dirty; + bool sequenced = false; + int str_ofs = 0; + bool expression_dirty = false; bool _compile_expression(); @@ -197,7 +195,7 @@ private: Error _get_token(Token &r_token); String error_str; - bool error_set; + bool error_set = true; struct ENode { @@ -215,11 +213,11 @@ private: TYPE_CALL }; - ENode *next; + ENode *next = nullptr; Type type; - ENode() { next = nullptr; } + ENode() {} virtual ~ENode() { if (next) { memdelete(next); @@ -339,12 +337,12 @@ private: return node; } - ENode *root; - ENode *nodes; + ENode *root = nullptr; + ENode *nodes = nullptr; Vector<String> input_names; - bool execution_error; + bool execution_error = false; bool _execute(const Array &p_inputs, Object *p_instance, Expression::ENode *p_node, Variant &r_ret, String &r_error_str); protected: @@ -356,7 +354,7 @@ public: bool has_execute_failed() const; String get_error_text() const; - Expression(); + Expression() {} ~Expression(); }; diff --git a/core/math/face3.h b/core/math/face3.h index f4b8721caa..da269028f5 100644 --- a/core/math/face3.h +++ b/core/math/face3.h @@ -69,8 +69,8 @@ public: Vector3 get_median_point() const; Vector3 get_closest_point_to(const Vector3 &p_point) const; - bool intersects_ray(const Vector3 &p_from, const Vector3 &p_dir, Vector3 *p_intersection = 0) const; - bool intersects_segment(const Vector3 &p_from, const Vector3 &p_dir, Vector3 *p_intersection = 0) const; + bool intersects_ray(const Vector3 &p_from, const Vector3 &p_dir, Vector3 *p_intersection = nullptr) const; + bool intersects_segment(const Vector3 &p_from, const Vector3 &p_dir, Vector3 *p_intersection = nullptr) const; ClockDirection get_clock_dir() const; ///< todo, test if this is returning the proper clockwisity diff --git a/core/math/geometry.cpp b/core/math/geometry.cpp index fa96fc4535..f923b62542 100644 --- a/core/math/geometry.cpp +++ b/core/math/geometry.cpp @@ -31,8 +31,11 @@ #include "geometry.h" #include "core/print_string.h" + #include "thirdparty/misc/clipper.hpp" #include "thirdparty/misc/triangulator.h" +#define STB_RECT_PACK_IMPLEMENTATION +#include "thirdparty/misc/stb_rect_pack.h" #define SCALE_FACTOR 100000.0 // Based on CMP_EPSILON. @@ -101,25 +104,19 @@ struct _FaceClassify { struct _Link { - int face; - int edge; + int face = -1; + int edge = -1; void clear() { face = -1; edge = -1; } - _Link() { - face = -1; - edge = -1; - } + _Link() {} }; - bool valid; - int group; + bool valid = false; + int group = -1; _Link links[3]; Face3 face; - _FaceClassify() { - group = -1; - valid = false; - }; + _FaceClassify() {} }; static bool _connect_faces(_FaceClassify *p_faces, int len, int p_group) { @@ -445,7 +442,8 @@ static inline void _mark_outside(uint8_t ***p_cell_status, int x, int y, int z, next_z--; prev = _CELL_PREV_Z_POS; } break; - default: ERR_FAIL(); + default: + ERR_FAIL(); } if (next_x < 0 || next_x >= len_x) @@ -1083,10 +1081,18 @@ Vector<Vector<Point2>> Geometry::_polypaths_do_operation(PolyBooleanOperation p_ ClipType op = ctUnion; switch (p_op) { - case OPERATION_UNION: op = ctUnion; break; - case OPERATION_DIFFERENCE: op = ctDifference; break; - case OPERATION_INTERSECTION: op = ctIntersection; break; - case OPERATION_XOR: op = ctXor; break; + case OPERATION_UNION: + op = ctUnion; + break; + case OPERATION_DIFFERENCE: + op = ctDifference; + break; + case OPERATION_INTERSECTION: + op = ctIntersection; + break; + case OPERATION_XOR: + op = ctXor; + break; } Path path_a, path_b; @@ -1135,19 +1141,35 @@ Vector<Vector<Point2>> Geometry::_polypath_offset(const Vector<Point2> &p_polypa JoinType jt = jtSquare; switch (p_join_type) { - case JOIN_SQUARE: jt = jtSquare; break; - case JOIN_ROUND: jt = jtRound; break; - case JOIN_MITER: jt = jtMiter; break; + case JOIN_SQUARE: + jt = jtSquare; + break; + case JOIN_ROUND: + jt = jtRound; + break; + case JOIN_MITER: + jt = jtMiter; + break; } EndType et = etClosedPolygon; switch (p_end_type) { - case END_POLYGON: et = etClosedPolygon; break; - case END_JOINED: et = etClosedLine; break; - case END_BUTT: et = etOpenButt; break; - case END_SQUARE: et = etOpenSquare; break; - case END_ROUND: et = etOpenRound; break; + case END_POLYGON: + et = etClosedPolygon; + break; + case END_JOINED: + et = etClosedLine; + break; + case END_BUTT: + et = etOpenButt; + break; + case END_SQUARE: + et = etOpenSquare; + break; + case END_ROUND: + et = etOpenRound; + break; } ClipperOffset co(2.0, 0.25 * SCALE_FACTOR); // Defaults from ClipperOffset. Path path; @@ -1217,3 +1239,195 @@ Vector<Vector3> Geometry::compute_convex_mesh_points(const Plane *p_planes, int return points; } + +Vector<Point2i> Geometry::pack_rects(const Vector<Size2i> &p_sizes, const Size2i &p_atlas_size) { + + Vector<stbrp_node> nodes; + nodes.resize(p_atlas_size.width); + + stbrp_context context; + stbrp_init_target(&context, p_atlas_size.width, p_atlas_size.height, nodes.ptrw(), p_atlas_size.width); + + Vector<stbrp_rect> rects; + rects.resize(p_sizes.size()); + + for (int i = 0; i < p_sizes.size(); i++) { + rects.write[i].id = 0; + rects.write[i].w = p_sizes[i].width; + rects.write[i].h = p_sizes[i].height; + rects.write[i].x = 0; + rects.write[i].y = 0; + rects.write[i].was_packed = 0; + } + + int res = stbrp_pack_rects(&context, rects.ptrw(), rects.size()); + if (res == 0) { //pack failed + return Vector<Point2i>(); + } + + Vector<Point2i> ret; + ret.resize(p_sizes.size()); + + for (int i = 0; i < p_sizes.size(); i++) { + Point2i r(rects[i].x, rects[i].y); + ret.write[i] = r; + } + + return ret; +} + +Vector<Vector3i> Geometry::partial_pack_rects(const Vector<Vector2i> &p_sizes, const Size2i &p_atlas_size) { + + Vector<stbrp_node> nodes; + nodes.resize(p_atlas_size.width); + zeromem(nodes.ptrw(), sizeof(stbrp_node) * nodes.size()); + + stbrp_context context; + stbrp_init_target(&context, p_atlas_size.width, p_atlas_size.height, nodes.ptrw(), p_atlas_size.width); + + Vector<stbrp_rect> rects; + rects.resize(p_sizes.size()); + + for (int i = 0; i < p_sizes.size(); i++) { + rects.write[i].id = i; + rects.write[i].w = p_sizes[i].width; + rects.write[i].h = p_sizes[i].height; + rects.write[i].x = 0; + rects.write[i].y = 0; + rects.write[i].was_packed = 0; + } + + stbrp_pack_rects(&context, rects.ptrw(), rects.size()); + + Vector<Vector3i> ret; + ret.resize(p_sizes.size()); + + for (int i = 0; i < p_sizes.size(); i++) { + ret.write[rects[i].id] = Vector3i(rects[i].x, rects[i].y, rects[i].was_packed != 0 ? 1 : 0); + } + + return ret; +} + +#define square(m_s) ((m_s) * (m_s)) +#define INF 1e20 + +/* dt of 1d function using squared distance */ +static void edt(float *f, int stride, int n) { + + float *d = (float *)alloca(sizeof(float) * n + sizeof(int) * n + sizeof(float) * (n + 1)); + int *v = (int *)&(d[n]); + float *z = (float *)&v[n]; + + int k = 0; + v[0] = 0; + z[0] = -INF; + z[1] = +INF; + for (int q = 1; q <= n - 1; q++) { + float s = ((f[q * stride] + square(q)) - (f[v[k] * stride] + square(v[k]))) / (2 * q - 2 * v[k]); + while (s <= z[k]) { + k--; + s = ((f[q * stride] + square(q)) - (f[v[k] * stride] + square(v[k]))) / (2 * q - 2 * v[k]); + } + k++; + v[k] = q; + + z[k] = s; + z[k + 1] = +INF; + } + + k = 0; + for (int q = 0; q <= n - 1; q++) { + while (z[k + 1] < q) + k++; + d[q] = square(q - v[k]) + f[v[k] * stride]; + } + + for (int i = 0; i < n; i++) { + f[i * stride] = d[i]; + } +} + +#undef square + +Vector<uint32_t> Geometry::generate_edf(const Vector<bool> &p_voxels, const Vector3i &p_size, bool p_negative) { + + uint32_t float_count = p_size.x * p_size.y * p_size.z; + + ERR_FAIL_COND_V((uint32_t)p_voxels.size() != float_count, Vector<uint32_t>()); + + float *work_memory = memnew_arr(float, float_count); + for (uint32_t i = 0; i < float_count; i++) { + work_memory[i] = INF; + } + + uint32_t y_mult = p_size.x; + uint32_t z_mult = y_mult * p_size.y; + + //plot solid cells + { + const bool *voxr = p_voxels.ptr(); + for (uint32_t i = 0; i < float_count; i++) { + + bool plot = voxr[i]; + if (p_negative) { + plot = !plot; + } + if (plot) { + work_memory[i] = 0; + } + } + } + + //process in each direction + + //xy->z + + for (int i = 0; i < p_size.x; i++) { + for (int j = 0; j < p_size.y; j++) { + edt(&work_memory[i + j * y_mult], z_mult, p_size.z); + } + } + + //xz->y + + for (int i = 0; i < p_size.x; i++) { + for (int j = 0; j < p_size.z; j++) { + edt(&work_memory[i + j * z_mult], y_mult, p_size.y); + } + } + + //yz->x + for (int i = 0; i < p_size.y; i++) { + for (int j = 0; j < p_size.z; j++) { + edt(&work_memory[i * y_mult + j * z_mult], 1, p_size.x); + } + } + + Vector<uint32_t> ret; + ret.resize(float_count); + { + uint32_t *w = ret.ptrw(); + for (uint32_t i = 0; i < float_count; i++) { + w[i] = uint32_t(Math::sqrt(work_memory[i])); + } + } + + return ret; +} + +Vector<int8_t> Geometry::generate_sdf8(const Vector<uint32_t> &p_positive, const Vector<uint32_t> &p_negative) { + ERR_FAIL_COND_V(p_positive.size() != p_negative.size(), Vector<int8_t>()); + Vector<int8_t> sdf8; + int s = p_positive.size(); + sdf8.resize(s); + + const uint32_t *rpos = p_positive.ptr(); + const uint32_t *rneg = p_negative.ptr(); + int8_t *wsdf = sdf8.ptrw(); + for (int i = 0; i < s; i++) { + int32_t diff = int32_t(rpos[i]) - int32_t(rneg[i]); + wsdf[i] = CLAMP(diff, -128, 127); + } + return sdf8; +} diff --git a/core/math/geometry.h b/core/math/geometry.h index ea063a8a59..b90cae3786 100644 --- a/core/math/geometry.h +++ b/core/math/geometry.h @@ -31,13 +31,12 @@ #ifndef GEOMETRY_H #define GEOMETRY_H -#include "core/math/delaunay.h" +#include "core/math/delaunay_2d.h" #include "core/math/face3.h" #include "core/math/rect2.h" #include "core/math/triangulate.h" #include "core/math/vector3.h" #include "core/object.h" - #include "core/print_string.h" #include "core/vector.h" @@ -113,10 +112,14 @@ public: real_t mub = (d_of(p1, q1, q2, q1) + mua * d_of(q2, q1, p2, p1)) / d_of(q2, q1, q2, q1); // Clip the value between [0..1] constraining the solution to lie on the original curves. - if (mua < 0) mua = 0; - if (mub < 0) mub = 0; - if (mua > 1) mua = 1; - if (mub > 1) mub = 1; + if (mua < 0) + mua = 0; + if (mub < 0) + mub = 0; + if (mua > 1) + mua = 1; + if (mub > 1) + mub = 1; c1 = p1.lerp(p2, mua); c2 = q1.lerp(q2, mub); } @@ -187,7 +190,7 @@ public: return dP.length(); // Return the closest distance. } - static inline bool ray_intersects_triangle(const Vector3 &p_from, const Vector3 &p_dir, const Vector3 &p_v0, const Vector3 &p_v1, const Vector3 &p_v2, Vector3 *r_res = 0) { + static inline bool ray_intersects_triangle(const Vector3 &p_from, const Vector3 &p_dir, const Vector3 &p_v0, const Vector3 &p_v1, const Vector3 &p_v2, Vector3 *r_res = nullptr) { Vector3 e1 = p_v1 - p_v0; Vector3 e2 = p_v2 - p_v0; Vector3 h = p_dir.cross(e2); @@ -222,7 +225,7 @@ public: return false; } - static inline bool segment_intersects_triangle(const Vector3 &p_from, const Vector3 &p_to, const Vector3 &p_v0, const Vector3 &p_v1, const Vector3 &p_v2, Vector3 *r_res = 0) { + static inline bool segment_intersects_triangle(const Vector3 &p_from, const Vector3 &p_to, const Vector3 &p_v0, const Vector3 &p_v1, const Vector3 &p_v2, Vector3 *r_res = nullptr) { Vector3 rel = p_to - p_from; Vector3 e1 = p_v1 - p_v0; @@ -259,7 +262,7 @@ public: return false; } - static inline bool segment_intersects_sphere(const Vector3 &p_from, const Vector3 &p_to, const Vector3 &p_sphere_pos, real_t p_sphere_radius, Vector3 *r_res = 0, Vector3 *r_norm = 0) { + static inline bool segment_intersects_sphere(const Vector3 &p_from, const Vector3 &p_to, const Vector3 &p_sphere_pos, real_t p_sphere_radius, Vector3 *r_res = nullptr, Vector3 *r_norm = nullptr) { Vector3 sphere_pos = p_sphere_pos - p_from; Vector3 rel = (p_to - p_from); @@ -295,7 +298,7 @@ public: return true; } - static inline bool segment_intersects_cylinder(const Vector3 &p_from, const Vector3 &p_to, real_t p_height, real_t p_radius, Vector3 *r_res = 0, Vector3 *r_norm = 0) { + static inline bool segment_intersects_cylinder(const Vector3 &p_from, const Vector3 &p_to, real_t p_height, real_t p_radius, Vector3 *r_res = nullptr, Vector3 *r_norm = nullptr) { Vector3 rel = (p_to - p_from); real_t rel_l = rel.length(); @@ -497,7 +500,8 @@ public: bool orientation = an.cross(bn) > 0; - if ((bn.cross(cn) > 0) != orientation) return false; + if ((bn.cross(cn) > 0) != orientation) + return false; return (cn.cross(an) > 0) == orientation; } @@ -683,7 +687,8 @@ public: // If the term we intend to square root is less than 0 then the answer won't be real, // so it definitely won't be t in the range 0 to 1. - if (sqrtterm < 0) return -1; + if (sqrtterm < 0) + return -1; // If we can assume that the line segment starts outside the circle (e.g. for continuous time collision detection) // then the following can be skipped and we can just return the equivalent of res1. @@ -691,8 +696,10 @@ public: real_t res1 = (-b - sqrtterm) / (2 * a); real_t res2 = (-b + sqrtterm) / (2 * a); - if (res1 >= 0 && res1 <= 1) return res1; - if (res2 >= 0 && res2 <= 1) return res2; + if (res1 >= 0 && res1 <= 1) + return res1; + if (res2 >= 0 && res2 <= 1) + return res2; return -1; } @@ -1016,6 +1023,249 @@ public: static Vector<Vector3> compute_convex_mesh_points(const Plane *p_planes, int p_plane_count); +#define FINDMINMAX(x0, x1, x2, min, max) \ + min = max = x0; \ + if (x1 < min) \ + min = x1; \ + if (x1 > max) \ + max = x1; \ + if (x2 < min) \ + min = x2; \ + if (x2 > max) \ + max = x2; + + _FORCE_INLINE_ static bool planeBoxOverlap(Vector3 normal, float d, Vector3 maxbox) { + int q; + Vector3 vmin, vmax; + for (q = 0; q <= 2; q++) { + if (normal[q] > 0.0f) { + vmin[q] = -maxbox[q]; + vmax[q] = maxbox[q]; + } else { + vmin[q] = maxbox[q]; + vmax[q] = -maxbox[q]; + } + } + if (normal.dot(vmin) + d > 0.0f) + return false; + if (normal.dot(vmax) + d >= 0.0f) + return true; + + return false; + } + +/*======================== X-tests ========================*/ +#define AXISTEST_X01(a, b, fa, fb) \ + p0 = a * v0.y - b * v0.z; \ + p2 = a * v2.y - b * v2.z; \ + if (p0 < p2) { \ + min = p0; \ + max = p2; \ + } else { \ + min = p2; \ + max = p0; \ + } \ + rad = fa * boxhalfsize.y + fb * boxhalfsize.z; \ + if (min > rad || max < -rad) \ + return false; + +#define AXISTEST_X2(a, b, fa, fb) \ + p0 = a * v0.y - b * v0.z; \ + p1 = a * v1.y - b * v1.z; \ + if (p0 < p1) { \ + min = p0; \ + max = p1; \ + } else { \ + min = p1; \ + max = p0; \ + } \ + rad = fa * boxhalfsize.y + fb * boxhalfsize.z; \ + if (min > rad || max < -rad) \ + return false; + +/*======================== Y-tests ========================*/ +#define AXISTEST_Y02(a, b, fa, fb) \ + p0 = -a * v0.x + b * v0.z; \ + p2 = -a * v2.x + b * v2.z; \ + if (p0 < p2) { \ + min = p0; \ + max = p2; \ + } else { \ + min = p2; \ + max = p0; \ + } \ + rad = fa * boxhalfsize.x + fb * boxhalfsize.z; \ + if (min > rad || max < -rad) \ + return false; + +#define AXISTEST_Y1(a, b, fa, fb) \ + p0 = -a * v0.x + b * v0.z; \ + p1 = -a * v1.x + b * v1.z; \ + if (p0 < p1) { \ + min = p0; \ + max = p1; \ + } else { \ + min = p1; \ + max = p0; \ + } \ + rad = fa * boxhalfsize.x + fb * boxhalfsize.z; \ + if (min > rad || max < -rad) \ + return false; + + /*======================== Z-tests ========================*/ + +#define AXISTEST_Z12(a, b, fa, fb) \ + p1 = a * v1.x - b * v1.y; \ + p2 = a * v2.x - b * v2.y; \ + if (p2 < p1) { \ + min = p2; \ + max = p1; \ + } else { \ + min = p1; \ + max = p2; \ + } \ + rad = fa * boxhalfsize.x + fb * boxhalfsize.y; \ + if (min > rad || max < -rad) \ + return false; + +#define AXISTEST_Z0(a, b, fa, fb) \ + p0 = a * v0.x - b * v0.y; \ + p1 = a * v1.x - b * v1.y; \ + if (p0 < p1) { \ + min = p0; \ + max = p1; \ + } else { \ + min = p1; \ + max = p0; \ + } \ + rad = fa * boxhalfsize.x + fb * boxhalfsize.y; \ + if (min > rad || max < -rad) \ + return false; + + _FORCE_INLINE_ static bool triangle_box_overlap(const Vector3 &boxcenter, const Vector3 boxhalfsize, const Vector3 *triverts) { + + /* use separating axis theorem to test overlap between triangle and box */ + /* need to test for overlap in these directions: */ + /* 1) the {x,y,z}-directions (actually, since we use the AABB of the triangle */ + /* we do not even need to test these) */ + /* 2) normal of the triangle */ + /* 3) crossproduct(edge from tri, {x,y,z}-directin) */ + /* this gives 3x3=9 more tests */ + Vector3 v0, v1, v2; + float min, max, d, p0, p1, p2, rad, fex, fey, fez; + Vector3 normal, e0, e1, e2; + + /* This is the fastest branch on Sun */ + /* move everything so that the boxcenter is in (0,0,0) */ + + v0 = triverts[0] - boxcenter; + v1 = triverts[1] - boxcenter; + v2 = triverts[2] - boxcenter; + + /* compute triangle edges */ + e0 = v1 - v0; /* tri edge 0 */ + e1 = v2 - v1; /* tri edge 1 */ + e2 = v0 - v2; /* tri edge 2 */ + + /* Bullet 3: */ + /* test the 9 tests first (this was faster) */ + fex = Math::abs(e0.x); + fey = Math::abs(e0.y); + fez = Math::abs(e0.z); + AXISTEST_X01(e0.z, e0.y, fez, fey); + AXISTEST_Y02(e0.z, e0.x, fez, fex); + AXISTEST_Z12(e0.y, e0.x, fey, fex); + + fex = Math::abs(e1.x); + fey = Math::abs(e1.y); + fez = Math::abs(e1.z); + AXISTEST_X01(e1.z, e1.y, fez, fey); + AXISTEST_Y02(e1.z, e1.x, fez, fex); + AXISTEST_Z0(e1.y, e1.x, fey, fex); + + fex = Math::abs(e2.x); + fey = Math::abs(e2.y); + fez = Math::abs(e2.z); + AXISTEST_X2(e2.z, e2.y, fez, fey); + AXISTEST_Y1(e2.z, e2.x, fez, fex); + AXISTEST_Z12(e2.y, e2.x, fey, fex); + + /* Bullet 1: */ + /* first test overlap in the {x,y,z}-directions */ + /* find min, max of the triangle each direction, and test for overlap in */ + /* that direction -- this is equivalent to testing a minimal AABB around */ + /* the triangle against the AABB */ + + /* test in X-direction */ + FINDMINMAX(v0.x, v1.x, v2.x, min, max); + if (min > boxhalfsize.x || max < -boxhalfsize.x) + return false; + + /* test in Y-direction */ + FINDMINMAX(v0.y, v1.y, v2.y, min, max); + if (min > boxhalfsize.y || max < -boxhalfsize.y) + return false; + + /* test in Z-direction */ + FINDMINMAX(v0.z, v1.z, v2.z, min, max); + if (min > boxhalfsize.z || max < -boxhalfsize.z) + return false; + + /* Bullet 2: */ + /* test if the box intersects the plane of the triangle */ + /* compute plane equation of triangle: normal*x+d=0 */ + normal = e0.cross(e1); + d = -normal.dot(v0); /* plane eq: normal.x+d=0 */ + return planeBoxOverlap(normal, d, boxhalfsize); /* if true, box and triangle overlaps */ + } + + static Vector<Point2i> pack_rects(const Vector<Size2i> &p_sizes, const Size2i &p_atlas_size); + static Vector<Vector3i> partial_pack_rects(const Vector<Vector2i> &p_sizes, const Size2i &p_atlas_size); + + static Vector<uint32_t> generate_edf(const Vector<bool> &p_voxels, const Vector3i &p_size, bool p_negative); + static Vector<int8_t> generate_sdf8(const Vector<uint32_t> &p_positive, const Vector<uint32_t> &p_negative); + + static Vector3 triangle_get_barycentric_coords(const Vector3 &p_a, const Vector3 &p_b, const Vector3 &p_c, const Vector3 &p_pos) { + Vector3 v0 = p_b - p_a; + Vector3 v1 = p_c - p_a; + Vector3 v2 = p_pos - p_a; + + float d00 = v0.dot(v0); + float d01 = v0.dot(v1); + float d11 = v1.dot(v1); + float d20 = v2.dot(v0); + float d21 = v2.dot(v1); + float denom = (d00 * d11 - d01 * d01); + if (denom == 0) { + return Vector3(); //invalid triangle, return empty + } + float v = (d11 * d20 - d01 * d21) / denom; + float w = (d00 * d21 - d01 * d20) / denom; + float u = 1.0f - v - w; + return Vector3(u, v, w); + } + + static Color tetrahedron_get_barycentric_coords(const Vector3 &p_a, const Vector3 &p_b, const Vector3 &p_c, const Vector3 &p_d, const Vector3 &p_pos) { + Vector3 vap = p_pos - p_a; + Vector3 vbp = p_pos - p_b; + + Vector3 vab = p_b - p_a; + Vector3 vac = p_c - p_a; + Vector3 vad = p_d - p_a; + + Vector3 vbc = p_c - p_b; + Vector3 vbd = p_d - p_b; + // ScTP computes the scalar triple product +#define STP(m_a, m_b, m_c) ((m_a).dot((m_b).cross((m_c)))) + float va6 = STP(vbp, vbd, vbc); + float vb6 = STP(vap, vac, vad); + float vc6 = STP(vap, vad, vab); + float vd6 = STP(vap, vab, vac); + float v6 = 1 / STP(vab, vac, vad); + return Color(va6 * v6, vb6 * v6, vc6 * v6, vd6 * v6); +#undef STP + } + private: static Vector<Vector<Point2>> _polypaths_do_operation(PolyBooleanOperation p_op, const Vector<Point2> &p_polypath_a, const Vector<Point2> &p_polypath_b, bool is_a_open = false); static Vector<Vector<Point2>> _polypath_offset(const Vector<Point2> &p_polypath, real_t p_delta, PolyJoinType p_join_type, PolyEndType p_end_type); diff --git a/core/math/math_funcs.h b/core/math/math_funcs.h index 3e1eb14a6a..bd13c82894 100644 --- a/core/math/math_funcs.h +++ b/core/math/math_funcs.h @@ -233,12 +233,14 @@ public: static _ALWAYS_INLINE_ float range_lerp(float p_value, float p_istart, float p_istop, float p_ostart, float p_ostop) { return Math::lerp(p_ostart, p_ostop, Math::inverse_lerp(p_istart, p_istop, p_value)); } static _ALWAYS_INLINE_ double smoothstep(double p_from, double p_to, double p_weight) { - if (is_equal_approx(p_from, p_to)) return p_from; + if (is_equal_approx(p_from, p_to)) + return p_from; double x = CLAMP((p_weight - p_from) / (p_to - p_from), 0.0, 1.0); return x * x * (3.0 - 2.0 * x); } static _ALWAYS_INLINE_ float smoothstep(float p_from, float p_to, float p_weight) { - if (is_equal_approx(p_from, p_to)) return p_from; + if (is_equal_approx(p_from, p_to)) + return p_from; float x = CLAMP((p_weight - p_from) / (p_to - p_from), 0.0f, 1.0f); return x * x * (3.0f - 2.0f * x); } diff --git a/core/math/octree.h b/core/math/octree.h index ffb405bd0f..7d89c50f69 100644 --- a/core/math/octree.h +++ b/core/math/octree.h @@ -52,7 +52,6 @@ public: private: enum { - NEG = 0, POS = 1, }; @@ -106,49 +105,35 @@ private: // cached for FAST plane check AABB aabb; - uint64_t last_pass; - Octant *parent; - Octant *children[8]; + uint64_t last_pass = 0; + Octant *parent = nullptr; + Octant *children[8] = { nullptr }; - int children_count; // cache for amount of childrens (fast check for removal) - int parent_index; // cache for parent index (fast check for removal) + int children_count = 0; // cache for amount of childrens (fast check for removal) + int parent_index = -1; // cache for parent index (fast check for removal) List<Element *, AL> pairable_elements; List<Element *, AL> elements; - Octant() { - children_count = 0; - parent_index = -1; - last_pass = 0; - parent = nullptr; - for (int i = 0; i < 8; i++) - children[i] = nullptr; - } - - ~Octant() { - - /* - for (int i=0;i<8;i++) - memdelete_notnull(children[i]); - */ - } + Octant() {} + ~Octant() {} }; struct PairData; struct Element { - Octree *octree; + Octree *octree = nullptr; - T *userdata; - int subindex; - bool pairable; - uint32_t pairable_mask; - uint32_t pairable_type; + T *userdata = nullptr; + int subindex = 0; + bool pairable = false; + uint32_t pairable_mask = 0; + uint32_t pairable_type = 0; - uint64_t last_pass; - OctreeElementID _id; - Octant *common_parent; + uint64_t last_pass = 0; + OctreeElementID _id = 0; + Octant *common_parent = nullptr; AABB aabb; AABB container_aabb; @@ -163,17 +148,7 @@ private: List<OctantOwner, AL> octant_owners; - Element() { - last_pass = 0; - _id = 0; - pairable = false; - subindex = 0; - userdata = 0; - octree = 0; - pairable_mask = 0; - pairable_type = 0; - common_parent = nullptr; - } + Element() {} }; struct PairData { diff --git a/core/math/plane.cpp b/core/math/plane.cpp index a3818698bc..26ac0aac47 100644 --- a/core/math/plane.cpp +++ b/core/math/plane.cpp @@ -153,6 +153,10 @@ bool Plane::intersects_segment(const Vector3 &p_begin, const Vector3 &p_end, Vec /* misc */ +bool Plane::is_equal_approx_any_side(const Plane &p_plane) const { + return (normal.is_equal_approx(p_plane.normal) && Math::is_equal_approx(d, p_plane.d)) || (normal.is_equal_approx(-p_plane.normal) && Math::is_equal_approx(d, -p_plane.d)); +} + bool Plane::is_equal_approx(const Plane &p_plane) const { return normal.is_equal_approx(p_plane.normal) && Math::is_equal_approx(d, p_plane.d); diff --git a/core/math/plane.h b/core/math/plane.h index 771c8fc705..f4f205465f 100644 --- a/core/math/plane.h +++ b/core/math/plane.h @@ -36,7 +36,7 @@ class Plane { public: Vector3 normal; - real_t d; + real_t d = 0; void set_normal(const Vector3 &p_normal); _FORCE_INLINE_ Vector3 get_normal() const { return normal; }; ///Point is coplanar, CMP_EPSILON for precision @@ -56,7 +56,7 @@ public: /* intersections */ - bool intersect_3(const Plane &p_plane1, const Plane &p_plane2, Vector3 *r_result = 0) const; + bool intersect_3(const Plane &p_plane1, const Plane &p_plane2, Vector3 *r_result = nullptr) const; bool intersects_ray(const Vector3 &p_from, const Vector3 &p_dir, Vector3 *p_intersection) const; bool intersects_segment(const Vector3 &p_begin, const Vector3 &p_end, Vector3 *p_intersection) const; @@ -69,13 +69,13 @@ public: Plane operator-() const { return Plane(-normal, -d); } bool is_equal_approx(const Plane &p_plane) const; + bool is_equal_approx_any_side(const Plane &p_plane) const; _FORCE_INLINE_ bool operator==(const Plane &p_plane) const; _FORCE_INLINE_ bool operator!=(const Plane &p_plane) const; operator String() const; - _FORCE_INLINE_ Plane() : - d(0) {} + _FORCE_INLINE_ Plane() {} _FORCE_INLINE_ Plane(real_t p_a, real_t p_b, real_t p_c, real_t p_d) : normal(p_a, p_b, p_c), d(p_d) {} diff --git a/core/math/quat.cpp b/core/math/quat.cpp index 61cd41b23d..6fbea70279 100644 --- a/core/math/quat.cpp +++ b/core/math/quat.cpp @@ -206,7 +206,8 @@ Quat Quat::slerpni(const Quat &q, const real_t &t) const { real_t dot = from.dot(q); - if (Math::absf(dot) > 0.9999) return from; + if (Math::absf(dot) > 0.9999) + return from; real_t theta = Math::acos(dot), sinT = 1.0 / Math::sin(theta), diff --git a/core/math/quat.h b/core/math/quat.h index b3135ad1ca..1ca6fe7ce3 100644 --- a/core/math/quat.h +++ b/core/math/quat.h @@ -40,7 +40,7 @@ class Quat { public: - real_t x, y, z, w; + real_t x = 0, y = 0, z = 0, w = 1; _FORCE_INLINE_ real_t length_squared() const; bool is_equal_approx(const Quat &p_quat) const; @@ -112,7 +112,9 @@ public: z = p_z; w = p_w; } - inline Quat(real_t p_x, real_t p_y, real_t p_z, real_t p_w) : + + _FORCE_INLINE_ Quat() {} + _FORCE_INLINE_ Quat(real_t p_x, real_t p_y, real_t p_z, real_t p_w) : x(p_x), y(p_y), z(p_z), @@ -157,13 +159,6 @@ public: w = s * 0.5; } } - - inline Quat() : - x(0), - y(0), - z(0), - w(1) { - } }; real_t Quat::dot(const Quat &q) const { diff --git a/core/math/quick_hull.h b/core/math/quick_hull.h index 173f919a73..89061ab415 100644 --- a/core/math/quick_hull.h +++ b/core/math/quick_hull.h @@ -75,18 +75,12 @@ public: private: struct FaceConnect { - List<Face>::Element *left, *right; - FaceConnect() { - left = nullptr; - right = nullptr; - } + List<Face>::Element *left, *right = nullptr; + FaceConnect() {} }; struct RetFaceConnect { - List<Geometry::MeshData::Face>::Element *left, *right; - RetFaceConnect() { - left = nullptr; - right = nullptr; - } + List<Geometry::MeshData::Face>::Element *left, *right = nullptr; + RetFaceConnect() {} }; public: diff --git a/core/math/random_number_generator.cpp b/core/math/random_number_generator.cpp index 1a1bffb562..67f4c0b14a 100644 --- a/core/math/random_number_generator.cpp +++ b/core/math/random_number_generator.cpp @@ -30,8 +30,6 @@ #include "random_number_generator.h" -RandomNumberGenerator::RandomNumberGenerator() {} - void RandomNumberGenerator::_bind_methods() { ClassDB::bind_method(D_METHOD("set_seed", "seed"), &RandomNumberGenerator::set_seed); ClassDB::bind_method(D_METHOD("get_seed"), &RandomNumberGenerator::get_seed); diff --git a/core/math/random_number_generator.h b/core/math/random_number_generator.h index e7f188bb42..2b125433b3 100644 --- a/core/math/random_number_generator.h +++ b/core/math/random_number_generator.h @@ -65,7 +65,7 @@ public: return ret % (to - from + 1) + from; } - RandomNumberGenerator(); + RandomNumberGenerator() {} }; #endif // RANDOM_NUMBER_GENERATOR_H diff --git a/core/math/rect2.h b/core/math/rect2.h index 30dbfdbbe5..a3f3634bfb 100644 --- a/core/math/rect2.h +++ b/core/math/rect2.h @@ -393,11 +393,12 @@ struct Rect2i { operator String() const { return String(position) + ", " + String(size); } operator Rect2() const { return Rect2(position, size); } + + Rect2i() {} Rect2i(const Rect2 &p_r2) : position(p_r2.position), size(p_r2.size) { } - Rect2i() {} Rect2i(int p_x, int p_y, int p_width, int p_height) : position(Point2(p_x, p_y)), size(Size2(p_width, p_height)) { diff --git a/core/math/triangle_mesh.cpp b/core/math/triangle_mesh.cpp index 5c66721b9d..0f7350a260 100644 --- a/core/math/triangle_mesh.cpp +++ b/core/math/triangle_mesh.cpp @@ -558,14 +558,16 @@ bool TriangleMesh::intersect_convex_shape(const Plane *p_planes, int p_plane_cou if (p.intersects_segment(point, next_point, &res)) { bool inisde = true; for (int k = 0; k < p_plane_count; k++) { - if (k == i) continue; + if (k == i) + continue; const Plane &pp = p_planes[k]; if (pp.is_point_over(res)) { inisde = false; break; } } - if (inisde) return true; + if (inisde) + return true; } if (p.is_point_over(point)) { @@ -573,7 +575,8 @@ bool TriangleMesh::intersect_convex_shape(const Plane *p_planes, int p_plane_cou break; } } - if (over) return true; + if (over) + return true; } stack[level] = (VISIT_DONE_BIT << VISITED_BIT_SHIFT) | node; @@ -652,7 +655,8 @@ bool TriangleMesh::inside_convex_shape(const Plane *p_planes, int p_plane_count, case TEST_AABB_BIT: { bool intersects = scale.xform(b.aabb).intersects_convex_shape(p_planes, p_plane_count, p_points, p_point_count); - if (!intersects) return false; + if (!intersects) + return false; bool inside = scale.xform(b.aabb).inside_convex_shape(p_planes, p_plane_count); if (inside) { @@ -667,7 +671,8 @@ bool TriangleMesh::inside_convex_shape(const Plane *p_planes, int p_plane_count, Vector3 point = scale.xform(vertexptr[s.indices[j]]); for (int i = 0; i < p_plane_count; i++) { const Plane &p = p_planes[i]; - if (p.is_point_over(point)) return false; + if (p.is_point_over(point)) + return false; } } diff --git a/core/math/triangulate.cpp b/core/math/triangulate.cpp index cbcb232745..ae278b034d 100644 --- a/core/math/triangulate.cpp +++ b/core/math/triangulate.cpp @@ -103,13 +103,16 @@ bool Triangulate::snip(const Vector<Vector2> &p_contour, int u, int v, int w, in // To avoid that we allow zero-area triangles if all else failed. float threshold = relaxed ? -CMP_EPSILON : CMP_EPSILON; - if (threshold > (((Bx - Ax) * (Cy - Ay)) - ((By - Ay) * (Cx - Ax)))) return false; + if (threshold > (((Bx - Ax) * (Cy - Ay)) - ((By - Ay) * (Cx - Ax)))) + return false; for (p = 0; p < n; p++) { - if ((p == u) || (p == v) || (p == w)) continue; + if ((p == u) || (p == v) || (p == w)) + continue; Px = contour[V[p]].x; Py = contour[V[p]].y; - if (is_inside_triangle(Ax, Ay, Bx, By, Cx, Cy, Px, Py, relaxed)) return false; + if (is_inside_triangle(Ax, Ay, Bx, By, Cx, Cy, Px, Py, relaxed)) + return false; } return true; @@ -119,7 +122,8 @@ bool Triangulate::triangulate(const Vector<Vector2> &contour, Vector<int> &resul /* allocate and initialize list of Vertices in polygon */ int n = contour.size(); - if (n < 3) return false; + if (n < 3) + return false; Vector<int> V; V.resize(n); @@ -161,11 +165,14 @@ bool Triangulate::triangulate(const Vector<Vector2> &contour, Vector<int> &resul /* three consecutive vertices in current polygon, <u,v,w> */ int u = v; - if (nv <= u) u = 0; /* previous */ + if (nv <= u) + u = 0; /* previous */ v = u + 1; - if (nv <= v) v = 0; /* new v */ + if (nv <= v) + v = 0; /* new v */ int w = v + 1; - if (nv <= w) w = 0; /* next */ + if (nv <= w) + w = 0; /* next */ if (snip(contour, u, v, w, nv, V, relaxed)) { int a, b, c, s, t; diff --git a/core/math/vector2.h b/core/math/vector2.h index c0057f2543..5a3e6a0660 100644 --- a/core/math/vector2.h +++ b/core/math/vector2.h @@ -44,11 +44,11 @@ struct Vector2 { }; union { - real_t x; + real_t x = 0; real_t width; }; union { - real_t y; + real_t y = 0; real_t height; }; @@ -142,11 +142,11 @@ struct Vector2 { operator String() const { return String::num(x) + ", " + String::num(y); } + _FORCE_INLINE_ Vector2() {} _FORCE_INLINE_ Vector2(real_t p_x, real_t p_y) { x = p_x; y = p_y; } - _FORCE_INLINE_ Vector2() { x = y = 0; } }; _FORCE_INLINE_ Vector2 Vector2::plane_project(real_t p_d, const Vector2 &p_vec) const { @@ -260,11 +260,11 @@ struct Vector2i { }; union { - int x; + int x = 0; int width; }; union { - int y; + int y = 0; int height; }; @@ -307,6 +307,8 @@ struct Vector2i { operator String() const { return String::num(x) + ", " + String::num(y); } operator Vector2() const { return Vector2(x, y); } + + inline Vector2i() {} inline Vector2i(const Vector2 &p_vec2) { x = (int)p_vec2.x; y = (int)p_vec2.y; @@ -315,10 +317,6 @@ struct Vector2i { x = p_x; y = p_y; } - inline Vector2i() { - x = 0; - y = 0; - } }; typedef Vector2i Size2i; diff --git a/core/math/vector3.h b/core/math/vector3.h index a5e9d09208..7131063e04 100644 --- a/core/math/vector3.h +++ b/core/math/vector3.h @@ -52,7 +52,7 @@ struct Vector3 { real_t z; }; - real_t coord[3]; + real_t coord[3] = { 0 }; }; _FORCE_INLINE_ const real_t &operator[](int p_axis) const { @@ -152,18 +152,17 @@ struct Vector3 { return Vector3i(x, y, z); } + _FORCE_INLINE_ Vector3() {} _FORCE_INLINE_ Vector3(const Vector3i &p_ivec) { x = p_ivec.x; y = p_ivec.y; z = p_ivec.z; } - _FORCE_INLINE_ Vector3(real_t p_x, real_t p_y, real_t p_z) { x = p_x; y = p_y; z = p_z; } - _FORCE_INLINE_ Vector3() { x = y = z = 0; } }; Vector3 Vector3::cross(const Vector3 &p_b) const { diff --git a/core/math/vector3i.h b/core/math/vector3i.h index 6f9754d3b9..60e5b94c12 100644 --- a/core/math/vector3i.h +++ b/core/math/vector3i.h @@ -49,7 +49,7 @@ struct Vector3i { int32_t z; }; - int32_t coord[3]; + int32_t coord[3] = { 0 }; }; _FORCE_INLINE_ const int32_t &operator[](int p_axis) const { @@ -100,12 +100,12 @@ struct Vector3i { operator String() const; + _FORCE_INLINE_ Vector3i() {} _FORCE_INLINE_ Vector3i(int32_t p_x, int32_t p_y, int32_t p_z) { x = p_x; y = p_y; z = p_z; } - _FORCE_INLINE_ Vector3i() { x = y = z = 0; } }; Vector3i Vector3i::abs() const { diff --git a/core/message_queue.cpp b/core/message_queue.cpp index 652c424492..ad4211f3da 100644 --- a/core/message_queue.cpp +++ b/core/message_queue.cpp @@ -345,10 +345,7 @@ MessageQueue::MessageQueue() { ERR_FAIL_COND_MSG(singleton != nullptr, "A MessageQueue singleton already exists."); singleton = this; - flushing = false; - buffer_end = 0; - buffer_max_used = 0; buffer_size = GLOBAL_DEF_RST("memory/limits/message_queue/max_size_kb", DEFAULT_QUEUE_SIZE_KB); ProjectSettings::get_singleton()->set_custom_property_info("memory/limits/message_queue/max_size_kb", PropertyInfo(Variant::INT, "memory/limits/message_queue/max_size_kb", PROPERTY_HINT_RANGE, "1024,4096,1,or_greater")); buffer_size *= 1024; diff --git a/core/message_queue.h b/core/message_queue.h index 9ba748bb42..180e0ce362 100644 --- a/core/message_queue.h +++ b/core/message_queue.h @@ -63,15 +63,15 @@ class MessageQueue { }; uint8_t *buffer; - uint32_t buffer_end; - uint32_t buffer_max_used; + uint32_t buffer_end = 0; + uint32_t buffer_max_used = 0; uint32_t buffer_size; void _call_function(const Callable &p_callable, const Variant *p_args, int p_argcount, bool p_show_error); static MessageQueue *singleton; - bool flushing; + bool flushing = false; public: static MessageQueue *get_singleton(); diff --git a/core/method_bind.cpp b/core/method_bind.cpp index c513de9ca0..854e19cf8a 100644 --- a/core/method_bind.cpp +++ b/core/method_bind.cpp @@ -104,14 +104,6 @@ void MethodBind::_generate_argument_types(int p_count) { MethodBind::MethodBind() { static int last_id = 0; method_id = last_id++; - hint_flags = METHOD_FLAGS_DEFAULT; - argument_count = 0; - default_argument_count = 0; -#ifdef DEBUG_METHODS_ENABLED - argument_types = nullptr; -#endif - _const = false; - _returns = false; } MethodBind::~MethodBind() { diff --git a/core/method_bind.h b/core/method_bind.h index d39f107ba6..0092527a25 100644 --- a/core/method_bind.h +++ b/core/method_bind.h @@ -165,7 +165,8 @@ struct VariantObjectClassChecker<Control *> { #define CHECK_NOARG(m_arg) \ { \ if (p_arg##m_arg.get_type() != Variant::NIL) { \ - if (r_argerror) *r_argerror = (m_arg - 1); \ + if (r_argerror) \ + *r_argerror = (m_arg - 1); \ return CALL_ERROR_EXTRA_ARGUMENT; \ } \ } @@ -207,18 +208,18 @@ struct PtrToArg<wchar_t> { class MethodBind { int method_id; - uint32_t hint_flags; + uint32_t hint_flags = METHOD_FLAGS_DEFAULT; StringName name; Vector<Variant> default_arguments; - int default_argument_count; - int argument_count; + int default_argument_count = 0; + int argument_count = 0; - bool _const; - bool _returns; + bool _const = false; + bool _returns = false; protected: #ifdef DEBUG_METHODS_ENABLED - Variant::Type *argument_types; + Variant::Type *argument_types = nullptr; Vector<StringName> arg_names; #endif void _set_const(bool p_const); @@ -303,12 +304,11 @@ public: typedef Variant (T::*NativeCall)(const Variant **, int, Callable::CallError &); protected: - NativeCall call_method; + NativeCall call_method = nullptr; #ifdef DEBUG_METHODS_ENABLED - MethodInfo arguments; - #endif + public: #ifdef DEBUG_METHODS_ENABLED @@ -383,7 +383,6 @@ public: virtual bool is_vararg() const { return true; } MethodBindVarArg() { - call_method = nullptr; _set_returns(true); } }; diff --git a/core/node_path.cpp b/core/node_path.cpp index 83233622a0..f8001a354a 100644 --- a/core/node_path.cpp +++ b/core/node_path.cpp @@ -187,16 +187,6 @@ NodePath::operator String() const { return ret; } -NodePath::NodePath(const NodePath &p_path) { - - data = nullptr; - - if (p_path.data && p_path.data->refcount.ref()) { - - data = p_path.data; - } -} - Vector<StringName> NodePath::get_names() const { if (data) @@ -285,35 +275,8 @@ NodePath NodePath::get_as_property_path() const { } } -NodePath::NodePath(const Vector<StringName> &p_path, bool p_absolute) { - - data = nullptr; - - if (p_path.size() == 0) - return; - - data = memnew(Data); - data->refcount.init(); - data->absolute = p_absolute; - data->path = p_path; - data->has_slashes = true; - data->hash_cache_valid = false; -} - -NodePath::NodePath(const Vector<StringName> &p_path, const Vector<StringName> &p_subpath, bool p_absolute) { - - data = nullptr; - - if (p_path.size() == 0 && p_subpath.size() == 0) - return; - - data = memnew(Data); - data->refcount.init(); - data->absolute = p_absolute; - data->path = p_path; - data->subpath = p_subpath; - data->has_slashes = true; - data->hash_cache_valid = false; +bool NodePath::is_empty() const { + return !data; } void NodePath::simplify() { @@ -347,10 +310,38 @@ NodePath NodePath::simplified() const { return np; } -NodePath::NodePath(const String &p_path) { +NodePath::NodePath(const Vector<StringName> &p_path, bool p_absolute) { + if (p_path.size() == 0) + return; - data = nullptr; + data = memnew(Data); + data->refcount.init(); + data->absolute = p_absolute; + data->path = p_path; + data->has_slashes = true; + data->hash_cache_valid = false; +} + +NodePath::NodePath(const Vector<StringName> &p_path, const Vector<StringName> &p_subpath, bool p_absolute) { + if (p_path.size() == 0 && p_subpath.size() == 0) + return; + data = memnew(Data); + data->refcount.init(); + data->absolute = p_absolute; + data->path = p_path; + data->subpath = p_subpath; + data->has_slashes = true; + data->hash_cache_valid = false; +} + +NodePath::NodePath(const NodePath &p_path) { + if (p_path.data && p_path.data->refcount.ref()) { + data = p_path.data; + } +} + +NodePath::NodePath(const String &p_path) { if (p_path.length() == 0) return; @@ -373,7 +364,8 @@ NodePath::NodePath(const String &p_path) { String str = path.substr(from, i - from); if (str == "") { - if (path[i] == 0) continue; // Allow end-of-path : + if (path[i] == 0) + continue; // Allow end-of-path : ERR_FAIL_MSG("Invalid NodePath '" + p_path + "'."); } @@ -436,16 +428,6 @@ NodePath::NodePath(const String &p_path) { } } -bool NodePath::is_empty() const { - - return !data; -} -NodePath::NodePath() { - - data = nullptr; -} - NodePath::~NodePath() { - unref(); } diff --git a/core/node_path.h b/core/node_path.h index 76de36cd9f..fb15d017bf 100644 --- a/core/node_path.h +++ b/core/node_path.h @@ -48,7 +48,7 @@ class NodePath { mutable uint32_t hash_cache; }; - mutable Data *data; + mutable Data *data = nullptr; void unref(); void _update_hash_cache() const; @@ -93,7 +93,7 @@ public: NodePath(const Vector<StringName> &p_path, const Vector<StringName> &p_subpath, bool p_absolute); NodePath(const NodePath &p_path); NodePath(const String &p_path); - NodePath(); + NodePath() {} ~NodePath(); }; diff --git a/core/oa_hash_map.h b/core/oa_hash_map.h index 71e3ba9068..b4d9ce4d51 100644 --- a/core/oa_hash_map.h +++ b/core/oa_hash_map.h @@ -58,7 +58,7 @@ private: uint32_t capacity; - uint32_t num_elements; + uint32_t num_elements = 0; static const uint32_t EMPTY_HASH = 0; @@ -90,7 +90,7 @@ private: uint32_t pos = hash % capacity; uint32_t distance = 0; - while (42) { + while (true) { if (hashes[pos] == EMPTY_HASH) { return false; } @@ -118,7 +118,7 @@ private: TKey key = p_key; TValue value = p_value; - while (42) { + while (true) { if (hashes[pos] == EMPTY_HASH) { _construct(pos, hash, key, value); @@ -350,7 +350,6 @@ public: OAHashMap(uint32_t p_initial_capacity = 64) { capacity = p_initial_capacity; - num_elements = 0; keys = memnew_arr(TKey, p_initial_capacity); values = memnew_arr(TValue, p_initial_capacity); diff --git a/core/object.cpp b/core/object.cpp index b0e6f2bdae..9ae2d2dcde 100644 --- a/core/object.cpp +++ b/core/object.cpp @@ -127,11 +127,6 @@ MethodInfo::operator Dictionary() const { return d; } -MethodInfo::MethodInfo() : - flags(METHOD_FLAG_NORMAL), - id(0) { -} - MethodInfo MethodInfo::from_dict(const Dictionary &p_dict) { MethodInfo mi; @@ -165,29 +160,28 @@ MethodInfo MethodInfo::from_dict(const Dictionary &p_dict) { return mi; } +MethodInfo::MethodInfo() : + flags(METHOD_FLAG_NORMAL) {} + MethodInfo::MethodInfo(const String &p_name) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { } MethodInfo::MethodInfo(const String &p_name, const PropertyInfo &p_param1) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { arguments.push_back(p_param1); } MethodInfo::MethodInfo(const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { arguments.push_back(p_param1); arguments.push_back(p_param2); } MethodInfo::MethodInfo(const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2, const PropertyInfo &p_param3) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { arguments.push_back(p_param1); arguments.push_back(p_param2); arguments.push_back(p_param3); @@ -195,8 +189,7 @@ MethodInfo::MethodInfo(const String &p_name, const PropertyInfo &p_param1, const MethodInfo::MethodInfo(const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2, const PropertyInfo &p_param3, const PropertyInfo &p_param4) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { arguments.push_back(p_param1); arguments.push_back(p_param2); arguments.push_back(p_param3); @@ -205,8 +198,7 @@ MethodInfo::MethodInfo(const String &p_name, const PropertyInfo &p_param1, const MethodInfo::MethodInfo(const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2, const PropertyInfo &p_param3, const PropertyInfo &p_param4, const PropertyInfo &p_param5) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { arguments.push_back(p_param1); arguments.push_back(p_param2); arguments.push_back(p_param3); @@ -215,28 +207,24 @@ MethodInfo::MethodInfo(const String &p_name, const PropertyInfo &p_param1, const } MethodInfo::MethodInfo(Variant::Type ret) : - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { return_val.type = ret; } MethodInfo::MethodInfo(Variant::Type ret, const String &p_name) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { return_val.type = ret; } MethodInfo::MethodInfo(Variant::Type ret, const String &p_name, const PropertyInfo &p_param1) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { return_val.type = ret; arguments.push_back(p_param1); } MethodInfo::MethodInfo(Variant::Type ret, const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { return_val.type = ret; arguments.push_back(p_param1); arguments.push_back(p_param2); @@ -244,8 +232,7 @@ MethodInfo::MethodInfo(Variant::Type ret, const String &p_name, const PropertyIn MethodInfo::MethodInfo(Variant::Type ret, const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2, const PropertyInfo &p_param3) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { return_val.type = ret; arguments.push_back(p_param1); arguments.push_back(p_param2); @@ -254,8 +241,7 @@ MethodInfo::MethodInfo(Variant::Type ret, const String &p_name, const PropertyIn MethodInfo::MethodInfo(Variant::Type ret, const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2, const PropertyInfo &p_param3, const PropertyInfo &p_param4) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { return_val.type = ret; arguments.push_back(p_param1); arguments.push_back(p_param2); @@ -265,8 +251,7 @@ MethodInfo::MethodInfo(Variant::Type ret, const String &p_name, const PropertyIn MethodInfo::MethodInfo(Variant::Type ret, const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2, const PropertyInfo &p_param3, const PropertyInfo &p_param4, const PropertyInfo &p_param5) : name(p_name), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { return_val.type = ret; arguments.push_back(p_param1); arguments.push_back(p_param2); @@ -278,23 +263,20 @@ MethodInfo::MethodInfo(Variant::Type ret, const String &p_name, const PropertyIn MethodInfo::MethodInfo(const PropertyInfo &p_ret, const String &p_name) : name(p_name), return_val(p_ret), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { } MethodInfo::MethodInfo(const PropertyInfo &p_ret, const String &p_name, const PropertyInfo &p_param1) : name(p_name), return_val(p_ret), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { arguments.push_back(p_param1); } MethodInfo::MethodInfo(const PropertyInfo &p_ret, const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2) : name(p_name), return_val(p_ret), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { arguments.push_back(p_param1); arguments.push_back(p_param2); } @@ -302,8 +284,7 @@ MethodInfo::MethodInfo(const PropertyInfo &p_ret, const String &p_name, const Pr MethodInfo::MethodInfo(const PropertyInfo &p_ret, const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2, const PropertyInfo &p_param3) : name(p_name), return_val(p_ret), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { arguments.push_back(p_param1); arguments.push_back(p_param2); arguments.push_back(p_param3); @@ -312,8 +293,7 @@ MethodInfo::MethodInfo(const PropertyInfo &p_ret, const String &p_name, const Pr MethodInfo::MethodInfo(const PropertyInfo &p_ret, const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2, const PropertyInfo &p_param3, const PropertyInfo &p_param4) : name(p_name), return_val(p_ret), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { arguments.push_back(p_param1); arguments.push_back(p_param2); arguments.push_back(p_param3); @@ -323,8 +303,7 @@ MethodInfo::MethodInfo(const PropertyInfo &p_ret, const String &p_name, const Pr MethodInfo::MethodInfo(const PropertyInfo &p_ret, const String &p_name, const PropertyInfo &p_param1, const PropertyInfo &p_param2, const PropertyInfo &p_param3, const PropertyInfo &p_param4, const PropertyInfo &p_param5) : name(p_name), return_val(p_ret), - flags(METHOD_FLAG_NORMAL), - id(0) { + flags(METHOD_FLAG_NORMAL) { arguments.push_back(p_param1); arguments.push_back(p_param2); arguments.push_back(p_param3); @@ -543,7 +522,8 @@ void Object::set_indexed(const Vector<StringName> &p_names, const Variant &p_val } bool valid = false; - if (!r_valid) r_valid = &valid; + if (!r_valid) + r_valid = &valid; List<Variant> value_stack; @@ -1922,32 +1902,19 @@ void Object::set_script_instance_binding(int p_script_language_index, void *p_da void Object::_construct_object(bool p_reference) { type_is_reference = p_reference; - _class_ptr = nullptr; - _block_signals = false; - _predelete_ok = 0; _instance_id = ObjectDB::add_instance(this); - _can_translate = true; - _is_queued_for_deletion = false; - _emitting = false; - instance_binding_count = 0; memset(_script_instance_bindings, 0, sizeof(void *) * MAX_SCRIPT_INSTANCE_BINDINGS); - script_instance = nullptr; -#ifdef TOOLS_ENABLED - - _edited = false; - _edited_version = 0; -#endif #ifdef DEBUG_ENABLED _lock_index.init(1); #endif } + Object::Object(bool p_reference) { _construct_object(p_reference); } Object::Object() { - _construct_object(false); } diff --git a/core/object.h b/core/object.h index b40aef2a42..20defae095 100644 --- a/core/object.h +++ b/core/object.h @@ -139,12 +139,12 @@ enum PropertyUsageFlags { struct PropertyInfo { - Variant::Type type; + Variant::Type type = Variant::NIL; String name; StringName class_name; //for classes - PropertyHint hint; + PropertyHint hint = PROPERTY_HINT_NONE; String hint_string; - uint32_t usage; + uint32_t usage = PROPERTY_USAGE_DEFAULT; _FORCE_INLINE_ PropertyInfo added_usage(int p_fl) const { PropertyInfo pi = *this; @@ -156,11 +156,7 @@ struct PropertyInfo { static PropertyInfo from_dict(const Dictionary &p_dict); - PropertyInfo() : - type(Variant::NIL), - hint(PROPERTY_HINT_NONE), - usage(PROPERTY_USAGE_DEFAULT) { - } + PropertyInfo() {} PropertyInfo(Variant::Type p_type, const String p_name, PropertyHint p_hint = PROPERTY_HINT_NONE, const String &p_hint_string = "", uint32_t p_usage = PROPERTY_USAGE_DEFAULT, const StringName &p_class_name = StringName()) : type(p_type), @@ -178,10 +174,7 @@ struct PropertyInfo { PropertyInfo(const StringName &p_class_name) : type(Variant::OBJECT), - class_name(p_class_name), - hint(PROPERTY_HINT_NONE), - usage(PROPERTY_USAGE_DEFAULT) { - } + class_name(p_class_name) {} bool operator==(const PropertyInfo &p_info) const { return ((type == p_info.type) && @@ -203,8 +196,8 @@ struct MethodInfo { String name; PropertyInfo return_val; - uint32_t flags; - int id; + uint32_t flags; // NOLINT - prevent clang-tidy to assign method_bind.h constant here, it should stay in .cpp. + int id = 0; List<PropertyInfo> arguments; Vector<Variant> default_arguments; @@ -214,6 +207,7 @@ struct MethodInfo { operator Dictionary() const; static MethodInfo from_dict(const Dictionary &p_dict); + MethodInfo(); MethodInfo(const String &p_name); MethodInfo(const String &p_name, const PropertyInfo &p_param1); @@ -346,7 +340,8 @@ protected: return (bool (Object::*)(const StringName &, const Variant &)) & m_class::_set; \ } \ virtual bool _setv(const StringName &p_name, const Variant &p_property) { \ - if (m_inherits::_setv(p_name, p_property)) return true; \ + if (m_inherits::_setv(p_name, p_property)) \ + return true; \ if (m_class::_get_set() != m_inherits::_get_set()) { \ return _set(p_name, p_property); \ } \ @@ -415,14 +410,13 @@ public: ::Signal signal; Callable callable; - uint32_t flags; + uint32_t flags = 0; Vector<Variant> binds; bool operator<(const Connection &p_conn) const; operator Variant() const; - Connection() { - flags = 0; - } + + Connection() {} Connection(const Variant &p_variant); }; @@ -440,16 +434,13 @@ private: struct SignalData { struct Slot { - - int reference_count; + int reference_count = 0; Connection conn; - List<Connection>::Element *cE; - Slot() { reference_count = 0; } + List<Connection>::Element *cE = nullptr; }; MethodInfo user; VMap<Callable, Slot> slot_map; - SignalData() {} }; HashMap<StringName, SignalData> signal_map; @@ -457,24 +448,24 @@ private: #ifdef DEBUG_ENABLED SafeRefCount _lock_index; #endif - bool _block_signals; - int _predelete_ok; + bool _block_signals = false; + int _predelete_ok = 0; Set<Object *> change_receptors; ObjectID _instance_id; bool _predelete(); void _postinitialize(); - bool _can_translate; - bool _emitting; + bool _can_translate = true; + bool _emitting = false; #ifdef TOOLS_ENABLED - bool _edited; - uint32_t _edited_version; + bool _edited = false; + uint32_t _edited_version = 0; Set<String> editor_section_folding; #endif - ScriptInstance *script_instance; + ScriptInstance *script_instance = nullptr; Variant script; //reference does not yet exist, store it in a Dictionary metadata; mutable StringName _class_name; - mutable const StringName *_class_ptr; + mutable const StringName *_class_ptr = nullptr; void _add_user_signal(const String &p_name, const Array &p_args = Array()); bool _has_user_signal(const StringName &p_name) const; @@ -493,8 +484,9 @@ private: friend class Reference; bool type_is_reference = false; - uint32_t instance_binding_count; + uint32_t instance_binding_count = 0; void *_script_instance_bindings[MAX_SCRIPT_INSTANCE_BINDINGS]; + Object(bool p_reference); protected: @@ -502,14 +494,14 @@ protected: virtual bool _setv(const StringName &p_name, const Variant &p_property) { return false; }; virtual bool _getv(const StringName &p_name, Variant &r_property) const { return false; }; virtual void _get_property_listv(List<PropertyInfo> *p_list, bool p_reversed) const {}; - virtual void _notificationv(int p_notification, bool p_reversed){}; + virtual void _notificationv(int p_notification, bool p_reversed) {} static String _get_category() { return ""; } static void _bind_methods(); bool _set(const StringName &p_name, const Variant &p_property) { return false; }; bool _get(const StringName &p_name, Variant &r_property) const { return false; }; void _get_property_list(List<PropertyInfo> *p_list) const {}; - void _notification(int p_notification){}; + void _notification(int p_notification) {} _FORCE_INLINE_ static void (*_get_bind_methods())() { return &Object::_bind_methods; @@ -558,7 +550,7 @@ protected: public: //should be protected, but bug in clang++ static void initialize_class(); - _FORCE_INLINE_ static void register_custom_data_to_otdb(){}; + _FORCE_INLINE_ static void register_custom_data_to_otdb() {} public: #ifdef TOOLS_ENABLED @@ -578,8 +570,8 @@ public: bool _is_gpl_reversed() const { return false; } _FORCE_INLINE_ ObjectID get_instance_id() const { return _instance_id; } - // this is used for editors + // this is used for editors void add_change_receptor(Object *p_receptor); void remove_change_receptor(Object *p_receptor); @@ -612,7 +604,6 @@ public: } enum { - NOTIFICATION_POSTINITIALIZE = 0, NOTIFICATION_PREDELETE = 1 }; @@ -722,7 +713,7 @@ public: StringName tr(const StringName &p_message) const; // translate message (internationalization) - bool _is_queued_for_deletion; // set to true by SceneTree::queue_delete() + bool _is_queued_for_deletion = false; // set to true by SceneTree::queue_delete() bool is_queued_for_deletion() const; _FORCE_INLINE_ void set_message_translation(bool p_enable) { _can_translate = p_enable; } @@ -744,6 +735,7 @@ public: void clear_internal_resource_paths(); _ALWAYS_INLINE_ bool is_reference() const { return type_is_reference; } + Object(); virtual ~Object(); }; diff --git a/core/ordered_hash_map.h b/core/ordered_hash_map.h index 05debd529f..1f1be71741 100644 --- a/core/ordered_hash_map.h +++ b/core/ordered_hash_map.h @@ -55,9 +55,9 @@ public: class Element { friend class OrderedHashMap<K, V, Hasher, Comparator, MIN_HASH_TABLE_POWER, RELATIONSHIP>; - typename InternalList::Element *list_element; - typename InternalList::Element *prev_element; - typename InternalList::Element *next_element; + typename InternalList::Element *list_element = nullptr; + typename InternalList::Element *prev_element = nullptr; + typename InternalList::Element *next_element = nullptr; Element(typename InternalList::Element *p_element) { list_element = p_element; @@ -69,11 +69,7 @@ public: } public: - _FORCE_INLINE_ Element() : - list_element(nullptr), - prev_element(nullptr), - next_element(nullptr) { - } + _FORCE_INLINE_ Element() {} Element next() const { return Element(next_element); @@ -136,16 +132,14 @@ public: class ConstElement { friend class OrderedHashMap<K, V, Hasher, Comparator, MIN_HASH_TABLE_POWER, RELATIONSHIP>; - const typename InternalList::Element *list_element; + const typename InternalList::Element *list_element = nullptr; ConstElement(const typename InternalList::Element *p_element) : list_element(p_element) { } public: - _FORCE_INLINE_ ConstElement() : - list_element(nullptr) { - } + _FORCE_INLINE_ ConstElement() {} ConstElement(const ConstElement &other) : list_element(other.list_element) { @@ -299,8 +293,7 @@ public: _copy_from(p_map); } - _FORCE_INLINE_ OrderedHashMap() { - } + _FORCE_INLINE_ OrderedHashMap() {} }; #endif // ORDERED_HASH_MAP_H diff --git a/core/os/dir_access.cpp b/core/os/dir_access.cpp index 94c8cd5d73..53b959a580 100644 --- a/core/os/dir_access.cpp +++ b/core/os/dir_access.cpp @@ -39,18 +39,24 @@ String DirAccess::_get_root_path() const { switch (_access_type) { - case ACCESS_RESOURCES: return ProjectSettings::get_singleton()->get_resource_path(); - case ACCESS_USERDATA: return OS::get_singleton()->get_user_data_dir(); - default: return ""; + case ACCESS_RESOURCES: + return ProjectSettings::get_singleton()->get_resource_path(); + case ACCESS_USERDATA: + return OS::get_singleton()->get_user_data_dir(); + default: + return ""; } } String DirAccess::_get_root_string() const { switch (_access_type) { - case ACCESS_RESOURCES: return "res://"; - case ACCESS_USERDATA: return "user://"; - default: return ""; + case ACCESS_RESOURCES: + return "res://"; + case ACCESS_USERDATA: + return "user://"; + default: + return ""; } } @@ -220,7 +226,8 @@ String DirAccess::fix_path(String p_path) const { return p_path; } break; - case ACCESS_MAX: break; // Can't happen, but silences warning + case ACCESS_MAX: + break; // Can't happen, but silences warning } return p_path; @@ -439,11 +446,3 @@ bool DirAccess::exists(String p_dir) { memdelete(da); return valid; } - -DirAccess::DirAccess() { - - _access_type = ACCESS_FILESYSTEM; -} - -DirAccess::~DirAccess() { -} diff --git a/core/os/dir_access.h b/core/os/dir_access.h index 60eb553968..cac0d0ec7c 100644 --- a/core/os/dir_access.h +++ b/core/os/dir_access.h @@ -47,7 +47,7 @@ public: typedef DirAccess *(*CreateFunc)(); private: - AccessType _access_type; + AccessType _access_type = ACCESS_FILESYSTEM; static CreateFunc create_func[ACCESS_MAX]; ///< set this to instance a filesystem object Error _copy_dir(DirAccess *p_target_da, String p_to, int p_chmod_flags); @@ -110,16 +110,6 @@ public: static String get_full_path(const String &p_path, AccessType p_access); static DirAccess *create_for_path(const String &p_path); - /* - enum DirType { - - FILE_TYPE_INVALID, - FILE_TYPE_FILE, - FILE_TYPE_DIR, - }; - - //virtual DirType get_file_type() const=0; -*/ static DirAccess *create(AccessType p_access); template <class T> @@ -130,8 +120,8 @@ public: static DirAccess *open(const String &p_path, Error *r_error = nullptr); - DirAccess(); - virtual ~DirAccess(); + DirAccess() {} + virtual ~DirAccess() {} }; struct DirAccessRef { @@ -142,10 +132,13 @@ struct DirAccessRef { } operator bool() const { return f != nullptr; } + DirAccess *f; + DirAccessRef(DirAccess *fa) { f = fa; } ~DirAccessRef() { - if (f) memdelete(f); + if (f) + memdelete(f); } }; diff --git a/core/os/file_access.cpp b/core/os/file_access.cpp index 3922f031b7..cb8705f706 100644 --- a/core/os/file_access.cpp +++ b/core/os/file_access.cpp @@ -163,7 +163,8 @@ String FileAccess::fix_path(const String &p_path) const { return r_path; } break; - case ACCESS_MAX: break; // Can't happen, but silences warning + case ACCESS_MAX: + break; // Can't happen, but silences warning } return r_path; @@ -277,7 +278,7 @@ class CharBuffer { char *buffer; int capacity; - int written; + int written = 0; bool grow() { @@ -304,8 +305,7 @@ class CharBuffer { public: _FORCE_INLINE_ CharBuffer() : buffer(stack_buffer), - capacity(sizeof(stack_buffer) / sizeof(char)), - written(0) { + capacity(sizeof(stack_buffer) / sizeof(char)) { } _FORCE_INLINE_ void push_back(char c) { @@ -715,10 +715,3 @@ String FileAccess::get_sha256(const String &p_file) { memdelete(f); return String::hex_encode_buffer(hash, 32); } - -FileAccess::FileAccess() { - - endian_swap = false; - real_is_double = false; - _access_type = ACCESS_FILESYSTEM; -}; diff --git a/core/os/file_access.h b/core/os/file_access.h index 010cc74a87..0ee29abbc9 100644 --- a/core/os/file_access.h +++ b/core/os/file_access.h @@ -53,8 +53,8 @@ public: typedef void (*FileCloseFailNotify)(const String &); typedef FileAccess *(*CreateFunc)(); - bool endian_swap; - bool real_is_double; + bool endian_swap = false; + bool real_is_double = false; virtual uint32_t _get_unix_permissions(const String &p_file) = 0; virtual Error _set_unix_permissions(const String &p_file, uint32_t p_permissions) = 0; @@ -69,7 +69,7 @@ protected: private: static bool backup_save; - AccessType _access_type; + AccessType _access_type = ACCESS_FILESYSTEM; static CreateFunc create_func[ACCESS_MAX]; /** default file access creation function for a platform */ template <class T> static FileAccess *_create_builtin() { @@ -176,7 +176,7 @@ public: create_func[p_access] = _create_builtin<T>; } - FileAccess(); + FileAccess() {} virtual ~FileAccess() {} }; @@ -188,11 +188,15 @@ struct FileAccessRef { } operator bool() const { return f != nullptr; } + FileAccess *f; + operator FileAccess *() { return f; } + FileAccessRef(FileAccess *fa) { f = fa; } ~FileAccessRef() { - if (f) memdelete(f); + if (f) + memdelete(f); } }; diff --git a/core/os/main_loop.cpp b/core/os/main_loop.cpp index 0d1a080682..b29e3f6142 100644 --- a/core/os/main_loop.cpp +++ b/core/os/main_loop.cpp @@ -60,12 +60,6 @@ void MainLoop::set_init_script(const Ref<Script> &p_init_script) { init_script = p_init_script; } -MainLoop::MainLoop() { -} - -MainLoop::~MainLoop() { -} - void MainLoop::init() { if (init_script.is_valid()) diff --git a/core/os/main_loop.h b/core/os/main_loop.h index 8f6c8c91b1..c7cc8f01e0 100644 --- a/core/os/main_loop.h +++ b/core/os/main_loop.h @@ -64,8 +64,8 @@ public: void set_init_script(const Ref<Script> &p_init_script); - MainLoop(); - virtual ~MainLoop(); + MainLoop() {} + virtual ~MainLoop() {} }; #endif // MAIN_LOOP_H diff --git a/core/os/memory.cpp b/core/os/memory.cpp index d921c10ad4..0e48592cc1 100644 --- a/core/os/memory.cpp +++ b/core/os/memory.cpp @@ -204,8 +204,6 @@ uint64_t Memory::get_mem_max_usage() { } _GlobalNil::_GlobalNil() { - - color = 1; left = this; right = this; parent = this; diff --git a/core/os/memory.h b/core/os/memory.h index dcaedd92ba..03c6a80e89 100644 --- a/core/os/memory.h +++ b/core/os/memory.h @@ -130,9 +130,10 @@ void memdelete_allocator(T *p_class) { A::free(p_class); } -#define memdelete_notnull(m_v) \ - { \ - if (m_v) memdelete(m_v); \ +#define memdelete_notnull(m_v) \ + { \ + if (m_v) \ + memdelete(m_v); \ } #define memnew_arr(m_class, m_count) memnew_arr_template<m_class>(m_count) @@ -141,13 +142,13 @@ template <typename T> T *memnew_arr_template(size_t p_elements, const char *p_descr = "") { if (p_elements == 0) - return 0; + return nullptr; /** overloading operator new[] cannot be done , because it may not return the real allocated address (it may pad the 'element count' before the actual array). Because of that, it must be done by hand. This is the same strategy used by std::vector, and the Vector class, so it should be safe.*/ size_t len = sizeof(T) * p_elements; uint64_t *mem = (uint64_t *)Memory::alloc_static(len, true); - T *failptr = 0; //get rid of a warning + T *failptr = nullptr; //get rid of a warning ERR_FAIL_COND_V(!mem, failptr); *(mem - 1) = p_elements; @@ -193,7 +194,7 @@ void memdelete_arr(T *p_class) { struct _GlobalNil { - int color; + int color = 1; _GlobalNil *right; _GlobalNil *left; _GlobalNil *parent; diff --git a/core/os/mutex.h b/core/os/mutex.h index 526549dd93..69a15f96de 100644 --- a/core/os/mutex.h +++ b/core/os/mutex.h @@ -83,7 +83,7 @@ extern template class MutexLock<MutexImpl<std::mutex>>; class FakeMutex { - FakeMutex(){}; + FakeMutex() {} }; template <class MutexT> diff --git a/core/os/os.cpp b/core/os/os.cpp index 425132fbec..cdc9f1e0ff 100644 --- a/core/os/os.cpp +++ b/core/os/os.cpp @@ -508,25 +508,10 @@ void OS::close_midi_inputs() { OS::OS() { void *volatile stack_bottom; - restart_on_exit = false; singleton = this; - _keep_screen_on = true; // set default value to true, because this had been true before godot 2.0. - low_processor_usage_mode = false; - low_processor_usage_mode_sleep_usec = 10000; - _verbose_stdout = false; - _no_window = false; - _exit_code = 0; - _render_thread_mode = RENDER_THREAD_SAFE; - - _allow_hidpi = false; - _allow_layered = false; _stack_bottom = (void *)(&stack_bottom); - _logger = nullptr; - - has_server_feature_callback = nullptr; - Vector<Logger *> loggers; loggers.push_back(memnew(StdLogger)); _set_logger(memnew(CompositeLogger(loggers))); diff --git a/core/os/os.h b/core/os/os.h index 38114e6814..4340823cf4 100644 --- a/core/os/os.h +++ b/core/os/os.h @@ -46,17 +46,17 @@ class OS { static OS *singleton; String _execpath; List<String> _cmdline; - bool _keep_screen_on; - bool low_processor_usage_mode; - int low_processor_usage_mode_sleep_usec; - bool _verbose_stdout; + bool _keep_screen_on = true; // set default value to true, because this had been true before godot 2.0. + bool low_processor_usage_mode = false; + int low_processor_usage_mode_sleep_usec = 10000; + bool _verbose_stdout = false; String _local_clipboard; uint64_t _msec_splash; - bool _no_window; - int _exit_code; + bool _no_window = false; + int _exit_code = 0; int _orientation; - bool _allow_hidpi; - bool _allow_layered; + bool _allow_hidpi = false; + bool _allow_layered = false; bool _use_vsync; bool _vsync_via_compositor; bool _disable_wintab; @@ -65,9 +65,9 @@ class OS { void *_stack_bottom; - CompositeLogger *_logger; + CompositeLogger *_logger = nullptr; - bool restart_on_exit; + bool restart_on_exit = false; List<String> restart_commandline; protected: @@ -87,8 +87,8 @@ public: protected: friend class Main; - HasServerFeatureCallback has_server_feature_callback; - RenderThreadMode _render_thread_mode; + HasServerFeatureCallback has_server_feature_callback = nullptr; + RenderThreadMode _render_thread_mode = RENDER_THREAD_SAFE; // functions used by main to initialize/deinitialize the OS void add_logger(Logger *p_logger); diff --git a/core/os/rw_lock.cpp b/core/os/rw_lock.cpp index 1dd2c3bccb..81df7f7ea6 100644 --- a/core/os/rw_lock.cpp +++ b/core/os/rw_lock.cpp @@ -42,6 +42,3 @@ RWLock *RWLock::create() { return create_func(); } - -RWLock::~RWLock() { -} diff --git a/core/os/rw_lock.h b/core/os/rw_lock.h index 64dfbef20c..8dca8a230a 100644 --- a/core/os/rw_lock.h +++ b/core/os/rw_lock.h @@ -48,7 +48,7 @@ public: static RWLock *create(); ///< Create a rwlock - virtual ~RWLock(); + virtual ~RWLock() {} }; class RWLockRead { @@ -58,10 +58,12 @@ class RWLockRead { public: RWLockRead(const RWLock *p_lock) { lock = const_cast<RWLock *>(p_lock); - if (lock) lock->read_lock(); + if (lock) + lock->read_lock(); } ~RWLockRead() { - if (lock) lock->read_unlock(); + if (lock) + lock->read_unlock(); } }; @@ -72,10 +74,12 @@ class RWLockWrite { public: RWLockWrite(RWLock *p_lock) { lock = p_lock; - if (lock) lock->write_lock(); + if (lock) + lock->write_lock(); } ~RWLockWrite() { - if (lock) lock->write_unlock(); + if (lock) + lock->write_unlock(); } }; diff --git a/core/os/thread.cpp b/core/os/thread.cpp index 294b52f00c..a8eb0b2a9f 100644 --- a/core/os/thread.cpp +++ b/core/os/thread.cpp @@ -66,9 +66,3 @@ Error Thread::set_name(const String &p_name) { return ERR_UNAVAILABLE; }; - -Thread::Thread() { -} - -Thread::~Thread() { -} diff --git a/core/os/thread.h b/core/os/thread.h index 76d296bcf7..005217dca7 100644 --- a/core/os/thread.h +++ b/core/os/thread.h @@ -63,7 +63,7 @@ protected: static ID _main_thread_id; - Thread(); + Thread() {} public: virtual ID get_id() const = 0; @@ -74,7 +74,7 @@ public: static void wait_to_finish(Thread *p_thread); ///< waits until thread is finished, and deallocates it. static Thread *create(ThreadCreateCallback p_callback, void *p_user, const Settings &p_settings = Settings()); ///< Static function to create a thread, will call p_callback - virtual ~Thread(); + virtual ~Thread() {} }; #endif // THREAD_H diff --git a/core/packed_data_container.cpp b/core/packed_data_container.cpp index 04deba2c14..17b5905a93 100644 --- a/core/packed_data_container.cpp +++ b/core/packed_data_container.cpp @@ -381,11 +381,6 @@ void PackedDataContainer::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::PACKED_BYTE_ARRAY, "__data__"), "_set_data", "_get_data"); } -PackedDataContainer::PackedDataContainer() { - - datalen = 0; -} - ////////////////// Variant PackedDataContainerRef::_iter_init(const Array &p_iter) { @@ -429,6 +424,3 @@ int PackedDataContainerRef::size() const { return from->_size(offset); }; - -PackedDataContainerRef::PackedDataContainerRef() { -} diff --git a/core/packed_data_container.h b/core/packed_data_container.h index 0f08a1cb7b..00ec4248ee 100644 --- a/core/packed_data_container.h +++ b/core/packed_data_container.h @@ -49,7 +49,7 @@ class PackedDataContainer : public Resource { }; Vector<uint8_t> data; - int datalen; + int datalen = 0; uint32_t _pack(const Variant &p_data, Vector<uint8_t> &tmpdata, Map<String, uint32_t> &string_cache); @@ -78,7 +78,7 @@ public: int size() const; - PackedDataContainer(); + PackedDataContainer() {} }; class PackedDataContainerRef : public Reference { @@ -100,7 +100,7 @@ public: int size() const; virtual Variant getvar(const Variant &p_key, bool *r_valid = nullptr) const; - PackedDataContainerRef(); + PackedDataContainerRef() {} }; #endif // PACKED_DATA_CONTAINER_H diff --git a/core/pool_allocator.cpp b/core/pool_allocator.cpp index b74540395c..8fd67a47d2 100644 --- a/core/pool_allocator.cpp +++ b/core/pool_allocator.cpp @@ -182,7 +182,8 @@ PoolAllocator::ID PoolAllocator::alloc(int p_size) { ERR_FAIL_COND_V(p_size < 1, POOL_ALLOCATOR_INVALID_ID); #ifdef DEBUG_ENABLED - if (p_size > free_mem) OS::get_singleton()->debug_break(); + if (p_size > free_mem) + OS::get_singleton()->debug_break(); #endif ERR_FAIL_COND_V(p_size > free_mem, POOL_ALLOCATOR_INVALID_ID); diff --git a/core/pool_allocator.h b/core/pool_allocator.h index 8c1710ebe0..1cc21afb21 100644 --- a/core/pool_allocator.h +++ b/core/pool_allocator.h @@ -61,10 +61,10 @@ private: struct Entry { - unsigned int pos; - unsigned int len; - unsigned int lock; - unsigned int check; + unsigned int pos = 0; + unsigned int len = 0; + unsigned int lock = 0; + unsigned int check = 0; inline void clear() { pos = 0; @@ -72,7 +72,7 @@ private: lock = 0; check = 0; } - Entry() { clear(); } + Entry() {} }; typedef int EntryArrayPos; diff --git a/core/print_string.h b/core/print_string.h index d83cc35dd6..3e9125bddc 100644 --- a/core/print_string.h +++ b/core/print_string.h @@ -39,16 +39,12 @@ typedef void (*PrintHandlerFunc)(void *, const String &p_string, bool p_error); struct PrintHandlerList { - PrintHandlerFunc printfunc; - void *userdata; + PrintHandlerFunc printfunc = nullptr; + void *userdata = nullptr; - PrintHandlerList *next; + PrintHandlerList *next = nullptr; - PrintHandlerList() { - printfunc = 0; - next = 0; - userdata = 0; - } + PrintHandlerList() {} }; void add_print_handler(PrintHandlerList *p_handler); diff --git a/core/project_settings.cpp b/core/project_settings.cpp index 8829181489..e141e54e61 100644 --- a/core/project_settings.cpp +++ b/core/project_settings.cpp @@ -989,10 +989,6 @@ void ProjectSettings::_bind_methods() { ProjectSettings::ProjectSettings() { singleton = this; - last_order = NO_BUILTIN_ORDER_BASE; - last_builtin_order = 0; - disable_feature_overrides = false; - registering_order = true; Array events; Dictionary action; @@ -1037,7 +1033,7 @@ ProjectSettings::ProjectSettings() { key->set_keycode(KEY_SPACE); events.push_back(key); joyb.instance(); - joyb->set_button_index(JOY_BUTTON_0); + joyb->set_button_index(JOY_BUTTON_A); events.push_back(joyb); action["events"] = events; GLOBAL_DEF("input/ui_accept", action); @@ -1050,7 +1046,7 @@ ProjectSettings::ProjectSettings() { key->set_keycode(KEY_SPACE); events.push_back(key); joyb.instance(); - joyb->set_button_index(JOY_BUTTON_3); + joyb->set_button_index(JOY_BUTTON_Y); events.push_back(joyb); action["events"] = events; GLOBAL_DEF("input/ui_select", action); @@ -1063,7 +1059,7 @@ ProjectSettings::ProjectSettings() { key->set_keycode(KEY_ESCAPE); events.push_back(key); joyb.instance(); - joyb->set_button_index(JOY_BUTTON_1); + joyb->set_button_index(JOY_BUTTON_B); events.push_back(joyb); action["events"] = events; GLOBAL_DEF("input/ui_cancel", action); @@ -1097,7 +1093,7 @@ ProjectSettings::ProjectSettings() { key->set_keycode(KEY_LEFT); events.push_back(key); joyb.instance(); - joyb->set_button_index(JOY_DPAD_LEFT); + joyb->set_button_index(JOY_BUTTON_DPAD_LEFT); events.push_back(joyb); action["events"] = events; GLOBAL_DEF("input/ui_left", action); @@ -1110,7 +1106,7 @@ ProjectSettings::ProjectSettings() { key->set_keycode(KEY_RIGHT); events.push_back(key); joyb.instance(); - joyb->set_button_index(JOY_DPAD_RIGHT); + joyb->set_button_index(JOY_BUTTON_DPAD_RIGHT); events.push_back(joyb); action["events"] = events; GLOBAL_DEF("input/ui_right", action); @@ -1123,7 +1119,7 @@ ProjectSettings::ProjectSettings() { key->set_keycode(KEY_UP); events.push_back(key); joyb.instance(); - joyb->set_button_index(JOY_DPAD_UP); + joyb->set_button_index(JOY_BUTTON_DPAD_UP); events.push_back(joyb); action["events"] = events; GLOBAL_DEF("input/ui_up", action); @@ -1136,7 +1132,7 @@ ProjectSettings::ProjectSettings() { key->set_keycode(KEY_DOWN); events.push_back(key); joyb.instance(); - joyb->set_button_index(JOY_DPAD_DOWN); + joyb->set_button_index(JOY_BUTTON_DPAD_DOWN); events.push_back(joyb); action["events"] = events; GLOBAL_DEF("input/ui_down", action); @@ -1203,8 +1199,6 @@ ProjectSettings::ProjectSettings() { Compression::gzip_level = GLOBAL_DEF("compression/formats/gzip/compression_level", Z_DEFAULT_COMPRESSION); custom_prop_info["compression/formats/gzip/compression_level"] = PropertyInfo(Variant::INT, "compression/formats/gzip/compression_level", PROPERTY_HINT_RANGE, "-1,9,1"); - - using_datapack = false; } ProjectSettings::~ProjectSettings() { diff --git a/core/project_settings.h b/core/project_settings.h index 7b3ca18c62..87f2a8273f 100644 --- a/core/project_settings.h +++ b/core/project_settings.h @@ -50,38 +50,31 @@ public: protected: struct VariantContainer { - int order; - bool persist; + int order = 0; + bool persist = false; Variant variant; Variant initial; - bool hide_from_editor; - bool overridden; - bool restart_if_changed; - VariantContainer() : - order(0), - persist(false), - hide_from_editor(false), - overridden(false), - restart_if_changed(false) { - } + bool hide_from_editor = false; + bool overridden = false; + bool restart_if_changed = false; + + VariantContainer() {} + VariantContainer(const Variant &p_variant, int p_order, bool p_persist = false) : order(p_order), persist(p_persist), - variant(p_variant), - hide_from_editor(false), - overridden(false), - restart_if_changed(false) { + variant(p_variant) { } }; - bool registering_order; - int last_order; - int last_builtin_order; + bool registering_order = true; + int last_order = 0; + int last_builtin_order = NO_BUILTIN_ORDER_BASE; Map<StringName, VariantContainer> props; String resource_path; Map<StringName, PropertyInfo> custom_prop_info; - bool disable_feature_overrides; - bool using_datapack; + bool disable_feature_overrides = false; + bool using_datapack = false; List<String> input_presets; Set<String> custom_features; diff --git a/core/reference.cpp b/core/reference.cpp index dd65ccce69..57b72dcaad 100644 --- a/core/reference.cpp +++ b/core/reference.cpp @@ -109,9 +109,6 @@ Reference::Reference() : refcount_init.init(); } -Reference::~Reference() { -} - Variant WeakRef::get_ref() const { if (ref.is_null()) @@ -138,9 +135,6 @@ void WeakRef::set_ref(const REF &p_ref) { ref = p_ref.is_valid() ? p_ref->get_instance_id() : ObjectID(); } -WeakRef::WeakRef() { -} - void WeakRef::_bind_methods() { ClassDB::bind_method(D_METHOD("get_ref"), &WeakRef::get_ref); diff --git a/core/reference.h b/core/reference.h index 30a93d82a6..5190f6ab11 100644 --- a/core/reference.h +++ b/core/reference.h @@ -52,13 +52,13 @@ public: int reference_get_count() const; Reference(); - ~Reference(); + ~Reference() {} }; template <class T> class Ref { - T *reference; + T *reference = nullptr; void ref(const Ref &p_from) { @@ -189,15 +189,11 @@ public: } Ref(const Ref &p_from) { - - reference = nullptr; ref(p_from); } template <class T_Other> Ref(const Ref<T_Other> &p_from) { - - reference = nullptr; Reference *refb = const_cast<Reference *>(static_cast<const Reference *>(p_from.ptr())); if (!refb) { unref(); @@ -210,26 +206,20 @@ public: } Ref(T *p_reference) { - - reference = nullptr; if (p_reference) ref_pointer(p_reference); } Ref(const Variant &p_variant) { - Object *object = p_variant.get_validated_object(); if (!object) { - reference = nullptr; return; } T *r = Object::cast_to<T>(object); if (r && r->reference()) { reference = r; - } else { - reference = nullptr; } } @@ -252,13 +242,9 @@ public: ref(memnew(T)); } - Ref() { - - reference = nullptr; - } + Ref() {} ~Ref() { - unref(); } }; @@ -279,7 +265,7 @@ public: void set_obj(Object *p_object); void set_ref(const REF &p_ref); - WeakRef(); + WeakRef() {} }; #ifdef PTRCALL_ENABLED diff --git a/core/register_core_types.cpp b/core/register_core_types.cpp index 23f549be1a..f24397be5b 100644 --- a/core/register_core_types.cpp +++ b/core/register_core_types.cpp @@ -65,7 +65,6 @@ #include "core/math/triangle_mesh.h" #include "core/os/main_loop.h" #include "core/packed_data_container.h" -#include "core/path_remap.h" #include "core/project_settings.h" #include "core/translation.h" #include "core/undo_redo.h" @@ -233,8 +232,8 @@ void register_core_settings() { GLOBAL_DEF_RST("network/limits/packet_peer_stream/max_buffer_po2", (16)); ProjectSettings::get_singleton()->set_custom_property_info("network/limits/packet_peer_stream/max_buffer_po2", PropertyInfo(Variant::INT, "network/limits/packet_peer_stream/max_buffer_po2", PROPERTY_HINT_RANGE, "0,64,1,or_greater")); - GLOBAL_DEF("network/ssl/certificates", ""); - ProjectSettings::get_singleton()->set_custom_property_info("network/ssl/certificates", PropertyInfo(Variant::STRING, "network/ssl/certificates", PROPERTY_HINT_FILE, "*.crt")); + GLOBAL_DEF("network/ssl/certificate_bundle_override", ""); + ProjectSettings::get_singleton()->set_custom_property_info("network/ssl/certificate_bundle_override", PropertyInfo(Variant::STRING, "network/ssl/certificate_bundle_override", PROPERTY_HINT_FILE, "*.crt")); } void register_core_singletons() { diff --git a/core/resource.cpp b/core/resource.cpp index 8d5c441b21..f8948e9a59 100644 --- a/core/resource.cpp +++ b/core/resource.cpp @@ -429,17 +429,7 @@ void Resource::_bind_methods() { } Resource::Resource() : - remapped_list(this) { - -#ifdef TOOLS_ENABLED - last_modified_time = 0; - import_last_modified_time = 0; -#endif - - subindex = 0; - local_to_scene = false; - local_scene = nullptr; -} + remapped_list(this) {} Resource::~Resource() { diff --git a/core/resource.h b/core/resource.h index 3b7812c870..7d92b843dc 100644 --- a/core/resource.h +++ b/core/resource.h @@ -57,19 +57,19 @@ class Resource : public Reference { String name; String path_cache; - int subindex; + int subindex = 0; virtual bool _use_builtin_script() const { return true; } #ifdef TOOLS_ENABLED - uint64_t last_modified_time; - uint64_t import_last_modified_time; + uint64_t last_modified_time = 0; + uint64_t import_last_modified_time = 0; String import_path; #endif - bool local_to_scene; + bool local_to_scene = false; friend class SceneState; - Node *local_scene; + Node *local_scene = nullptr; SelfList<Resource> remapped_list; diff --git a/core/rid.h b/core/rid.h index a2f73423a3..ac07eacd08 100644 --- a/core/rid.h +++ b/core/rid.h @@ -37,7 +37,7 @@ class RID_AllocBase; class RID { friend class RID_AllocBase; - uint64_t _id; + uint64_t _id = 0; public: _FORCE_INLINE_ bool operator==(const RID &p_rid) const { @@ -65,9 +65,7 @@ public: _FORCE_INLINE_ uint64_t get_id() const { return _id; } - _FORCE_INLINE_ RID() { - _id = 0; - } + _FORCE_INLINE_ RID() {} }; #endif // RID_H diff --git a/core/rid_owner.h b/core/rid_owner.h index ad6996b9a7..77bbc3c83c 100644 --- a/core/rid_owner.h +++ b/core/rid_owner.h @@ -39,6 +39,7 @@ #include "core/safe_refcount.h" #include "core/set.h" #include "core/spin_lock.h" + #include <stdio.h> #include <typeinfo> @@ -68,15 +69,15 @@ public: template <class T, bool THREAD_SAFE = false> class RID_Alloc : public RID_AllocBase { - T **chunks; - uint32_t **free_list_chunks; - uint32_t **validator_chunks; + T **chunks = nullptr; + uint32_t **free_list_chunks = nullptr; + uint32_t **validator_chunks = nullptr; uint32_t elements_in_chunk; - uint32_t max_alloc; - uint32_t alloc_count; + uint32_t max_alloc = 0; + uint32_t alloc_count = 0; - const char *description; + const char *description = nullptr; SpinLock spin_lock; @@ -288,14 +289,7 @@ public: } RID_Alloc(uint32_t p_target_chunk_byte_size = 4096) { - chunks = nullptr; - free_list_chunks = nullptr; - validator_chunks = nullptr; - elements_in_chunk = sizeof(T) > p_target_chunk_byte_size ? 1 : (p_target_chunk_byte_size / sizeof(T)); - max_alloc = 0; - alloc_count = 0; - description = nullptr; } ~RID_Alloc() { @@ -412,4 +406,5 @@ public: RID_Owner(uint32_t p_target_chunk_byte_size = 4096) : alloc(p_target_chunk_byte_size) {} }; + #endif // RID_OWNER_H diff --git a/core/ring_buffer.h b/core/ring_buffer.h index 620a3a3846..8ef9b1a15c 100644 --- a/core/ring_buffer.h +++ b/core/ring_buffer.h @@ -37,8 +37,8 @@ template <typename T> class RingBuffer { Vector<T> data; - int read_pos; - int write_pos; + int read_pos = 0; + int write_pos = 0; int size_mask; inline int inc(int &p_var, int p_size) const { @@ -214,11 +214,9 @@ public: }; RingBuffer<T>(int p_power = 0) { - read_pos = 0; - write_pos = 0; resize(p_power); }; - ~RingBuffer<T>(){}; + ~RingBuffer<T>() {} }; #endif // RING_BUFFER_H diff --git a/core/script_language.cpp b/core/script_language.cpp index 82cac6bc9a..603b4dc13d 100644 --- a/core/script_language.cpp +++ b/core/script_language.cpp @@ -34,6 +34,7 @@ #include "core/debugger/engine_debugger.h" #include "core/debugger/script_debugger.h" #include "core/project_settings.h" + #include <stdint.h> ScriptLanguage *ScriptServer::_languages[MAX_LANGUAGES]; diff --git a/core/script_language.h b/core/script_language.h index 5cc240efcb..544de26d81 100644 --- a/core/script_language.h +++ b/core/script_language.h @@ -253,14 +253,12 @@ struct ScriptCodeCompletionOption { KIND_FILE_PATH, KIND_PLAIN_TEXT, }; - Kind kind; + Kind kind = KIND_PLAIN_TEXT; String display; String insert_text; RES icon; - ScriptCodeCompletionOption() { - kind = KIND_PLAIN_TEXT; - } + ScriptCodeCompletionOption() {} ScriptCodeCompletionOption(const String &p_text, Kind p_kind) { display = p_text; diff --git a/core/self_list.h b/core/self_list.h index 43aeb44fea..74c585e60e 100644 --- a/core/self_list.h +++ b/core/self_list.h @@ -39,8 +39,8 @@ class SelfList { public: class List { - SelfList<T> *_first; - SelfList<T> *_last; + SelfList<T> *_first = nullptr; + SelfList<T> *_last = nullptr; public: void add(SelfList<T> *p_elem) { @@ -105,23 +105,22 @@ public: _FORCE_INLINE_ SelfList<T> *first() { return _first; } _FORCE_INLINE_ const SelfList<T> *first() const { return _first; } - _FORCE_INLINE_ List() { - _first = nullptr; - _last = nullptr; - } + + _FORCE_INLINE_ List() {} _FORCE_INLINE_ ~List() { ERR_FAIL_COND(_first != nullptr); } }; private: - List *_root; + List *_root = nullptr; T *_self; - SelfList<T> *_next; - SelfList<T> *_prev; + SelfList<T> *_next = nullptr; + SelfList<T> *_prev = nullptr; public: _FORCE_INLINE_ bool in_list() const { return _root; } _FORCE_INLINE_ void remove_from_list() { - if (_root) _root->remove(this); + if (_root) + _root->remove(this); } _FORCE_INLINE_ SelfList<T> *next() { return _next; } _FORCE_INLINE_ SelfList<T> *prev() { return _prev; } @@ -130,15 +129,10 @@ public: _FORCE_INLINE_ T *self() const { return _self; } _FORCE_INLINE_ SelfList(T *p_self) { - _self = p_self; - _next = nullptr; - _prev = nullptr; - _root = nullptr; } _FORCE_INLINE_ ~SelfList() { - if (_root) _root->remove(this); } diff --git a/core/set.h b/core/set.h index c17ee15350..851a33b43a 100644 --- a/core/set.h +++ b/core/set.h @@ -51,12 +51,12 @@ public: private: friend class Set<T, C, A>; - int color; - Element *right; - Element *left; - Element *parent; - Element *_next; - Element *_prev; + int color = RED; + Element *right = nullptr; + Element *left = nullptr; + Element *parent = nullptr; + Element *_next = nullptr; + Element *_prev = nullptr; T value; //_Data *data; @@ -80,22 +80,15 @@ public: const T &get() const { return value; }; - Element() { - color = RED; - right = nullptr; - left = nullptr; - parent = nullptr; - _next = nullptr; - _prev = nullptr; - }; + Element() {} }; private: struct _Data { - Element *_root; + Element *_root = nullptr; Element *_nil; - int size_cache; + int size_cache = 0; _FORCE_INLINE_ _Data() { #ifdef GLOBALNIL_DISABLED @@ -105,8 +98,6 @@ private: #else _nil = (Element *)&_GlobalNilClass::_nil; #endif - _root = nullptr; - size_cache = 0; } void _create_root() { @@ -625,11 +616,9 @@ public: _copy_from(p_set); } - _FORCE_INLINE_ Set() { - } + _FORCE_INLINE_ Set() {} ~Set() { - clear(); } }; diff --git a/core/spin_lock.h b/core/spin_lock.h index c48631f94a..1bb810bb29 100644 --- a/core/spin_lock.h +++ b/core/spin_lock.h @@ -32,6 +32,7 @@ #define SPIN_LOCK_H #include "core/typedefs.h" + #include <atomic> class SpinLock { diff --git a/core/string_buffer.h b/core/string_buffer.h index a140f0abf7..cfe7cdabfe 100644 --- a/core/string_buffer.h +++ b/core/string_buffer.h @@ -38,7 +38,7 @@ class StringBuffer { CharType short_buffer[SHORT_BUFFER_SIZE]; String buffer; - int string_length; + int string_length = 0; _FORCE_INLINE_ CharType *current_buffer_ptr() { return static_cast<String &>(buffer).empty() ? short_buffer : buffer.ptrw(); @@ -78,10 +78,6 @@ public: _FORCE_INLINE_ operator String() { return as_string(); } - - StringBuffer() { - string_length = 0; - } }; template <int SHORT_BUFFER_SIZE> diff --git a/core/string_builder.h b/core/string_builder.h index dd8a154890..8fcd6669bd 100644 --- a/core/string_builder.h +++ b/core/string_builder.h @@ -32,12 +32,11 @@ #define STRING_BUILDER_H #include "core/ustring.h" - #include "core/vector.h" class StringBuilder { - uint32_t string_length; + uint32_t string_length = 0; Vector<String> strings; Vector<const char *> c_strings; @@ -80,9 +79,7 @@ public: return as_string(); } - StringBuilder() { - string_length = 0; - } + StringBuilder() {} }; #endif // STRING_BUILDER_H diff --git a/core/string_name.cpp b/core/string_name.cpp index 9cbac97a7c..bfc10d96e4 100644 --- a/core/string_name.cpp +++ b/core/string_name.cpp @@ -388,12 +388,6 @@ StringName StringName::search(const String &p_name) { return StringName(); //does not exist } -StringName::StringName() { - - _data = nullptr; -} - StringName::~StringName() { - unref(); } diff --git a/core/string_name.h b/core/string_name.h index aec87b8e66..762eb43610 100644 --- a/core/string_name.h +++ b/core/string_name.h @@ -52,25 +52,20 @@ class StringName { struct _Data { SafeRefCount refcount; - const char *cname; + const char *cname = nullptr; String name; String get_name() const { return cname ? String(cname) : name; } - int idx; - uint32_t hash; - _Data *prev; - _Data *next; - _Data() { - cname = nullptr; - next = prev = nullptr; - idx = 0; - hash = 0; - } + int idx = 0; + uint32_t hash = 0; + _Data *prev = nullptr; + _Data *next = nullptr; + _Data() {} }; static _Data *_table[STRING_TABLE_LEN]; - _Data *_data; + _Data *_data = nullptr; union _HashUnion { @@ -90,7 +85,7 @@ class StringName { StringName(_Data *p_data) { _data = p_data; } public: - operator const void *() const { return (_data && (_data->cname || !_data->name.empty())) ? (void *)1 : 0; } + operator const void *() const { return (_data && (_data->cname || !_data->name.empty())) ? (void *)1 : nullptr; } bool operator==(const String &p_name) const; bool operator==(const char *p_name) const; @@ -160,7 +155,7 @@ public: StringName(const StringName &p_name); StringName(const String &p_name); StringName(const StaticCString &p_static_string); - StringName(); + StringName() {} ~StringName(); }; diff --git a/core/thread_work_pool.cpp b/core/thread_work_pool.cpp index c8311f102f..28e933ac4d 100644 --- a/core/thread_work_pool.cpp +++ b/core/thread_work_pool.cpp @@ -29,6 +29,7 @@ /*************************************************************************/ #include "thread_work_pool.h" + #include "core/os/os.h" void ThreadWorkPool::_thread_function(ThreadData *p_thread) { diff --git a/core/thread_work_pool.h b/core/thread_work_pool.h index 214d2c4aa7..8005bf80b8 100644 --- a/core/thread_work_pool.h +++ b/core/thread_work_pool.h @@ -33,8 +33,10 @@ #include "core/os/memory.h" #include "core/os/semaphore.h" + #include <atomic> #include <thread> + class ThreadWorkPool { std::atomic<uint32_t> index; diff --git a/core/translation.cpp b/core/translation.cpp index 3f45bb17c9..191349e953 100644 --- a/core/translation.cpp +++ b/core/translation.cpp @@ -902,10 +902,6 @@ void Translation::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::STRING, "locale"), "set_locale", "get_locale"); } -Translation::Translation() : - locale("en") { -} - /////////////////////////////////////////////// bool TranslationServer::is_locale_valid(const String &p_locale) { @@ -989,7 +985,8 @@ String TranslationServer::get_locale() const { String TranslationServer::get_locale_name(const String &p_locale) const { - if (!locale_name_map.has(p_locale)) return String(); + if (!locale_name_map.has(p_locale)) + return String(); return locale_name_map[p_locale]; } @@ -1246,13 +1243,10 @@ void TranslationServer::load_translations() { } } -TranslationServer::TranslationServer() : - locale("en"), - enabled(true) { +TranslationServer::TranslationServer() { singleton = this; for (int i = 0; locale_list[i]; ++i) { - locale_name_map.insert(locale_list[i], String::utf8(locale_names[i])); } } diff --git a/core/translation.h b/core/translation.h index 29a068f450..423b3166b1 100644 --- a/core/translation.h +++ b/core/translation.h @@ -39,7 +39,7 @@ class Translation : public Resource { OBJ_SAVE_TYPE(Translation); RES_BASE_EXTENSION("translation"); - String locale; + String locale = "en"; Map<StringName, StringName> translation_map; Vector<String> _get_message_list() const; @@ -61,14 +61,14 @@ public: void get_message_list(List<StringName> *r_messages) const; int get_message_count() const; - Translation(); + Translation() {} }; class TranslationServer : public Object { GDCLASS(TranslationServer, Object); - String locale; + String locale = "en"; String fallback; Set<Ref<Translation>> translations; @@ -77,7 +77,7 @@ class TranslationServer : public Object { Map<String, String> locale_name_map; - bool enabled; + bool enabled = true; static TranslationServer *singleton; bool _load_translations(const String &p_from); diff --git a/core/typed_array.cpp b/core/typed_array.cpp deleted file mode 100644 index 55e45f0b3f..0000000000 --- a/core/typed_array.cpp +++ /dev/null @@ -1 +0,0 @@ -#include "typed_array.h" diff --git a/core/typed_array.h b/core/typed_array.h index 5e95e81ea3..2c7b7e0384 100644 --- a/core/typed_array.h +++ b/core/typed_array.h @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* typed_array.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #ifndef TYPED_ARRAY_H #define TYPED_ARRAY_H diff --git a/core/undo_redo.cpp b/core/undo_redo.cpp index 62ad3e9f98..9324dfb573 100644 --- a/core/undo_redo.cpp +++ b/core/undo_redo.cpp @@ -409,23 +409,6 @@ void UndoRedo::set_property_notify_callback(PropertyNotifyCallback p_property_ca prop_callback_ud = p_ud; } -UndoRedo::UndoRedo() { - - committing = 0; - version = 1; - action_level = 0; - current_action = -1; - merge_mode = MERGE_DISABLE; - merging = false; - callback = nullptr; - callback_ud = nullptr; - - method_callbck_ud = nullptr; - prop_callback_ud = nullptr; - method_callback = nullptr; - property_callback = nullptr; -} - UndoRedo::~UndoRedo() { clear_history(); diff --git a/core/undo_redo.h b/core/undo_redo.h index 3b91e9ce36..5b74ffcbb8 100644 --- a/core/undo_redo.h +++ b/core/undo_redo.h @@ -77,25 +77,25 @@ private: }; Vector<Action> actions; - int current_action; - int action_level; - MergeMode merge_mode; - bool merging; - uint64_t version; + int current_action = -1; + int action_level = 0; + MergeMode merge_mode = MERGE_DISABLE; + bool merging = false; + uint64_t version = 1; void _pop_history_tail(); void _process_operation_list(List<Operation>::Element *E); void _discard_redo(); - CommitNotifyCallback callback; - void *callback_ud; - void *method_callbck_ud; - void *prop_callback_ud; + CommitNotifyCallback callback = nullptr; + void *callback_ud = nullptr; + void *method_callbck_ud = nullptr; + void *prop_callback_ud = nullptr; - MethodNotifyCallback method_callback; - PropertyNotifyCallback property_callback; + MethodNotifyCallback method_callback = nullptr; + PropertyNotifyCallback property_callback = nullptr; - int committing; + int committing = 0; protected: static void _bind_methods(); @@ -128,7 +128,7 @@ public: void set_method_notify_callback(MethodNotifyCallback p_method_callback, void *p_ud); void set_property_notify_callback(PropertyNotifyCallback p_property_callback, void *p_ud); - UndoRedo(); + UndoRedo() {} ~UndoRedo(); }; diff --git a/core/ustring.cpp b/core/ustring.cpp index fbe3fcb1b2..992424f057 100644 --- a/core/ustring.cpp +++ b/core/ustring.cpp @@ -28,10 +28,6 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#ifdef _MSC_VER -#define _CRT_SECURE_NO_WARNINGS // to disable build-time warning which suggested to use strcpy_s instead strcpy -#endif - #include "ustring.h" #include "core/color.h" @@ -51,6 +47,10 @@ #include <stdlib.h> #endif +#ifdef _MSC_VER +#define _CRT_SECURE_NO_WARNINGS // to disable build-time warning which suggested to use strcpy_s instead strcpy +#endif + #if defined(MINGW_ENABLED) || defined(_MSC_VER) #define snprintf _snprintf_s #endif @@ -548,8 +548,8 @@ signed char String::naturalnocasecmp_to(const String &p_str) const { return -1; /* Compare the numbers */ - this_int = to_int(this_str); - that_int = to_int(that_str); + this_int = to_int(this_str, -1, true); + that_int = to_int(that_str, -1, true); if (this_int < that_int) return -1; @@ -1882,7 +1882,8 @@ bool String::is_numeric() const { }; int s = 0; - if (operator[](0) == '-') ++s; + if (operator[](0) == '-') + ++s; bool dot = false; for (int i = s; i < length(); i++) { @@ -2137,7 +2138,7 @@ double String::to_double(const CharType *p_str, const CharType **r_end) { return built_in_strtod<CharType>(p_str, (CharType **)r_end); } -int64_t String::to_int(const CharType *p_str, int p_len) { +int64_t String::to_int(const CharType *p_str, int p_len, bool p_clamp) { if (p_len == 0 || !p_str[0]) return 0; @@ -2181,7 +2182,15 @@ int64_t String::to_int(const CharType *p_str, int p_len) { while (*str && str != limit) { number += *(str++); } - ERR_FAIL_V_MSG(sign == 1 ? INT64_MAX : INT64_MIN, "Cannot represent " + number + " as integer, provided value is " + (sign == 1 ? "too big." : "too small.")); + if (p_clamp) { + if (sign == 1) { + return INT64_MAX; + } else { + return INT64_MIN; + } + } else { + ERR_FAIL_V_MSG(sign == 1 ? INT64_MAX : INT64_MIN, "Cannot represent " + number + " as integer, provided value is " + (sign == 1 ? "too big." : "too small.")); + } } integer *= 10; integer += c - '0'; @@ -4182,9 +4191,14 @@ String String::sprintf(const Array &values, bool *error) const { int base = 16; bool capitalize = false; switch (c) { - case 'd': base = 10; break; - case 'o': base = 8; break; - case 'x': break; + case 'd': + base = 10; + break; + case 'o': + base = 8; + break; + case 'x': + break; case 'X': base = 16; capitalize = true; diff --git a/core/ustring.h b/core/ustring.h index ee7e3b1e16..15bc2b323c 100644 --- a/core/ustring.h +++ b/core/ustring.h @@ -254,7 +254,7 @@ public: static int to_int(const char *p_str, int p_len = -1); static double to_double(const char *p_str); static double to_double(const CharType *p_str, const CharType **r_end = nullptr); - static int64_t to_int(const CharType *p_str, int p_len = -1); + static int64_t to_int(const CharType *p_str, int p_len = -1, bool p_clamp = false); String capitalize() const; String camelcase_to_underscore(bool lowercase = true) const; diff --git a/core/variant.cpp b/core/variant.cpp index a91876dd98..162d409026 100644 --- a/core/variant.cpp +++ b/core/variant.cpp @@ -1406,20 +1406,47 @@ void Variant::reference(const Variant &p_variant) { void Variant::zero() { switch (type) { - case NIL: break; - case BOOL: this->_data._bool = false; break; - case INT: this->_data._int = 0; break; - case FLOAT: this->_data._float = 0; break; - case VECTOR2: *reinterpret_cast<Vector2 *>(this->_data._mem) = Vector2(); break; - case VECTOR2I: *reinterpret_cast<Vector2i *>(this->_data._mem) = Vector2i(); break; - case RECT2: *reinterpret_cast<Rect2 *>(this->_data._mem) = Rect2(); break; - case RECT2I: *reinterpret_cast<Rect2i *>(this->_data._mem) = Rect2i(); break; - case VECTOR3: *reinterpret_cast<Vector3 *>(this->_data._mem) = Vector3(); break; - case VECTOR3I: *reinterpret_cast<Vector3i *>(this->_data._mem) = Vector3i(); break; - case PLANE: *reinterpret_cast<Plane *>(this->_data._mem) = Plane(); break; - case QUAT: *reinterpret_cast<Quat *>(this->_data._mem) = Quat(); break; - case COLOR: *reinterpret_cast<Color *>(this->_data._mem) = Color(); break; - default: this->clear(); break; + case NIL: + break; + case BOOL: + this->_data._bool = false; + break; + case INT: + this->_data._int = 0; + break; + case FLOAT: + this->_data._float = 0; + break; + case VECTOR2: + *reinterpret_cast<Vector2 *>(this->_data._mem) = Vector2(); + break; + case VECTOR2I: + *reinterpret_cast<Vector2i *>(this->_data._mem) = Vector2i(); + break; + case RECT2: + *reinterpret_cast<Rect2 *>(this->_data._mem) = Rect2(); + break; + case RECT2I: + *reinterpret_cast<Rect2i *>(this->_data._mem) = Rect2i(); + break; + case VECTOR3: + *reinterpret_cast<Vector3 *>(this->_data._mem) = Vector3(); + break; + case VECTOR3I: + *reinterpret_cast<Vector3i *>(this->_data._mem) = Vector3i(); + break; + case PLANE: + *reinterpret_cast<Plane *>(this->_data._mem) = Plane(); + break; + case QUAT: + *reinterpret_cast<Quat *>(this->_data._mem) = Quat(); + break; + case COLOR: + *reinterpret_cast<Color *>(this->_data._mem) = Color(); + break; + default: + this->clear(); + break; } } @@ -1545,11 +1572,16 @@ Variant::operator signed int() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1 : 0; - case INT: return _data._int; - case FLOAT: return _data._float; - case STRING: return operator String().to_int(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1 : 0; + case INT: + return _data._int; + case FLOAT: + return _data._float; + case STRING: + return operator String().to_int(); default: { return 0; @@ -1560,11 +1592,16 @@ Variant::operator unsigned int() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1 : 0; - case INT: return _data._int; - case FLOAT: return _data._float; - case STRING: return operator String().to_int(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1 : 0; + case INT: + return _data._int; + case FLOAT: + return _data._float; + case STRING: + return operator String().to_int(); default: { return 0; @@ -1576,11 +1613,16 @@ Variant::operator int64_t() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1 : 0; - case INT: return _data._int; - case FLOAT: return _data._float; - case STRING: return operator String().to_int64(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1 : 0; + case INT: + return _data._int; + case FLOAT: + return _data._float; + case STRING: + return operator String().to_int64(); default: { return 0; @@ -1612,11 +1654,16 @@ Variant::operator uint64_t() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1 : 0; - case INT: return _data._int; - case FLOAT: return _data._float; - case STRING: return operator String().to_int(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1 : 0; + case INT: + return _data._int; + case FLOAT: + return _data._float; + case STRING: + return operator String().to_int(); default: { return 0; @@ -1639,11 +1686,16 @@ Variant::operator signed long() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1 : 0; - case INT: return _data._int; - case FLOAT: return _data._real; - case STRING: return operator String().to_int(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1 : 0; + case INT: + return _data._int; + case FLOAT: + return _data._real; + case STRING: + return operator String().to_int(); default: { return 0; @@ -1657,11 +1709,16 @@ Variant::operator unsigned long() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1 : 0; - case INT: return _data._int; - case FLOAT: return _data._real; - case STRING: return operator String().to_int(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1 : 0; + case INT: + return _data._int; + case FLOAT: + return _data._real; + case STRING: + return operator String().to_int(); default: { return 0; @@ -1676,11 +1733,16 @@ Variant::operator signed short() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1 : 0; - case INT: return _data._int; - case FLOAT: return _data._float; - case STRING: return operator String().to_int(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1 : 0; + case INT: + return _data._int; + case FLOAT: + return _data._float; + case STRING: + return operator String().to_int(); default: { return 0; @@ -1691,11 +1753,16 @@ Variant::operator unsigned short() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1 : 0; - case INT: return _data._int; - case FLOAT: return _data._float; - case STRING: return operator String().to_int(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1 : 0; + case INT: + return _data._int; + case FLOAT: + return _data._float; + case STRING: + return operator String().to_int(); default: { return 0; @@ -1706,11 +1773,16 @@ Variant::operator signed char() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1 : 0; - case INT: return _data._int; - case FLOAT: return _data._float; - case STRING: return operator String().to_int(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1 : 0; + case INT: + return _data._int; + case FLOAT: + return _data._float; + case STRING: + return operator String().to_int(); default: { return 0; @@ -1721,11 +1793,16 @@ Variant::operator unsigned char() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1 : 0; - case INT: return _data._int; - case FLOAT: return _data._float; - case STRING: return operator String().to_int(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1 : 0; + case INT: + return _data._int; + case FLOAT: + return _data._float; + case STRING: + return operator String().to_int(); default: { return 0; @@ -1742,11 +1819,16 @@ Variant::operator float() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1.0 : 0.0; - case INT: return (float)_data._int; - case FLOAT: return _data._float; - case STRING: return operator String().to_double(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1.0 : 0.0; + case INT: + return (float)_data._int; + case FLOAT: + return _data._float; + case STRING: + return operator String().to_double(); default: { return 0; @@ -1757,11 +1839,16 @@ Variant::operator double() const { switch (type) { - case NIL: return 0; - case BOOL: return _data._bool ? 1.0 : 0.0; - case INT: return (double)_data._int; - case FLOAT: return _data._float; - case STRING: return operator String().to_double(); + case NIL: + return 0; + case BOOL: + return _data._bool ? 1.0 : 0.0; + case INT: + return (double)_data._int; + case FLOAT: + return _data._float; + case STRING: + return operator String().to_double(); default: { return 0; @@ -1800,27 +1887,40 @@ Variant::operator String() const { String Variant::stringify(List<const void *> &stack) const { switch (type) { - case NIL: return "Null"; - case BOOL: return _data._bool ? "True" : "False"; - case INT: return itos(_data._int); - case FLOAT: return rtos(_data._float); - case STRING: return *reinterpret_cast<const String *>(_data._mem); - case VECTOR2: return "(" + operator Vector2() + ")"; - case VECTOR2I: return "(" + operator Vector2i() + ")"; - case RECT2: return "(" + operator Rect2() + ")"; - case RECT2I: return "(" + operator Rect2i() + ")"; + case NIL: + return "Null"; + case BOOL: + return _data._bool ? "True" : "False"; + case INT: + return itos(_data._int); + case FLOAT: + return rtos(_data._float); + case STRING: + return *reinterpret_cast<const String *>(_data._mem); + case VECTOR2: + return "(" + operator Vector2() + ")"; + case VECTOR2I: + return "(" + operator Vector2i() + ")"; + case RECT2: + return "(" + operator Rect2() + ")"; + case RECT2I: + return "(" + operator Rect2i() + ")"; case TRANSFORM2D: { Transform2D mat32 = operator Transform2D(); return "(" + Variant(mat32.elements[0]).operator String() + ", " + Variant(mat32.elements[1]).operator String() + ", " + Variant(mat32.elements[2]).operator String() + ")"; } break; - case VECTOR3: return "(" + operator Vector3() + ")"; - case VECTOR3I: return "(" + operator Vector3i() + ")"; + case VECTOR3: + return "(" + operator Vector3() + ")"; + case VECTOR3I: + return "(" + operator Vector3i() + ")"; case PLANE: return operator Plane(); //case QUAT: - case AABB: return operator ::AABB(); - case QUAT: return "(" + operator Quat() + ")"; + case AABB: + return operator ::AABB(); + case QUAT: + return "(" + operator Quat() + ")"; case BASIS: { Basis mat3 = operator Basis(); @@ -1846,10 +1946,14 @@ String Variant::stringify(List<const void *> &stack) const { return mtx + ")"; } break; - case TRANSFORM: return operator Transform(); - case STRING_NAME: return operator StringName(); - case NODE_PATH: return operator NodePath(); - case COLOR: return String::num(operator Color().r) + "," + String::num(operator Color().g) + "," + String::num(operator Color().b) + "," + String::num(operator Color().a); + case TRANSFORM: + return operator Transform(); + case STRING_NAME: + return operator StringName(); + case NODE_PATH: + return operator NodePath(); + case COLOR: + return String::num(operator Color().r) + "," + String::num(operator Color().g) + "," + String::num(operator Color().b) + "," + String::num(operator Color().a); case DICTIONARY: { const Dictionary &d = *reinterpret_cast<const Dictionary *>(_data._mem); @@ -3085,16 +3189,9 @@ Variant::Variant(const IP_Address &p_address) { Variant::Variant(const Variant &p_variant) { - type = NIL; reference(p_variant); } -/* -Variant::~Variant() { - - clear(); -}*/ - uint32_t Variant::hash() const { switch (type) { @@ -3676,9 +3773,12 @@ bool Variant::is_shared() const { switch (type) { - case OBJECT: return true; - case ARRAY: return true; - case DICTIONARY: return true; + case OBJECT: + return true; + case ARRAY: + return true; + case DICTIONARY: + return true; default: { } } diff --git a/core/variant.h b/core/variant.h index 1ae09ad82f..0498e93825 100644 --- a/core/variant.h +++ b/core/variant.h @@ -123,7 +123,7 @@ private: // Variant takes 20 bytes when real_t is float, and 36 if double // it only allocates extra memory for aabb/matrix. - Type type; + Type type = NIL; struct ObjData { @@ -469,12 +469,12 @@ public: static void construct_from_string(const String &p_string, Variant &r_value, ObjectConstruct p_obj_construct = nullptr, void *p_construct_ud = nullptr); void operator=(const Variant &p_variant); // only this is enough for all the other types + Variant(const Variant &p_variant); - _FORCE_INLINE_ Variant() { - type = NIL; - } + _FORCE_INLINE_ Variant() {} _FORCE_INLINE_ ~Variant() { - if (type != Variant::NIL) clear(); + if (type != Variant::NIL) + clear(); } }; diff --git a/core/variant_call.cpp b/core/variant_call.cpp index 9fffb42ff6..416a1a5fb8 100644 --- a/core/variant_call.cpp +++ b/core/variant_call.cpp @@ -262,6 +262,7 @@ struct _VariantCall { VCALL_LOCALMEM3R(String, split); VCALL_LOCALMEM3R(String, rsplit); VCALL_LOCALMEM2R(String, split_floats); + VCALL_LOCALMEM1R(String, join); VCALL_LOCALMEM0R(String, to_upper); VCALL_LOCALMEM0R(String, to_lower); VCALL_LOCALMEM1R(String, left); @@ -861,9 +862,15 @@ struct _VariantCall { static void _call_Transform2D_xform(Variant &r_ret, Variant &p_self, const Variant **p_args) { switch (p_args[0]->type) { - case Variant::VECTOR2: r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform(p_args[0]->operator Vector2()); return; - case Variant::RECT2: r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform(p_args[0]->operator Rect2()); return; - case Variant::PACKED_VECTOR2_ARRAY: r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform(p_args[0]->operator PackedVector2Array()); return; + case Variant::VECTOR2: + r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform(p_args[0]->operator Vector2()); + return; + case Variant::RECT2: + r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform(p_args[0]->operator Rect2()); + return; + case Variant::PACKED_VECTOR2_ARRAY: + r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform(p_args[0]->operator PackedVector2Array()); + return; default: r_ret = Variant(); ERR_PRINT("Invalid type in function 'xform' in base 'Transform2D'. Valid types are Vector2, Rect2, and PackedVector2Array."); @@ -872,9 +879,15 @@ struct _VariantCall { static void _call_Transform2D_xform_inv(Variant &r_ret, Variant &p_self, const Variant **p_args) { switch (p_args[0]->type) { - case Variant::VECTOR2: r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform_inv(p_args[0]->operator Vector2()); return; - case Variant::RECT2: r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform_inv(p_args[0]->operator Rect2()); return; - case Variant::PACKED_VECTOR2_ARRAY: r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform_inv(p_args[0]->operator PackedVector2Array()); return; + case Variant::VECTOR2: + r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform_inv(p_args[0]->operator Vector2()); + return; + case Variant::RECT2: + r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform_inv(p_args[0]->operator Rect2()); + return; + case Variant::PACKED_VECTOR2_ARRAY: + r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->xform_inv(p_args[0]->operator PackedVector2Array()); + return; default: r_ret = Variant(); ERR_PRINT("Invalid type in function 'xform_inv' in base 'Transform2D'. Valid types are Vector2, Rect2, and PackedVector2Array."); @@ -883,7 +896,9 @@ struct _VariantCall { static void _call_Transform2D_basis_xform(Variant &r_ret, Variant &p_self, const Variant **p_args) { switch (p_args[0]->type) { - case Variant::VECTOR2: r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->basis_xform(p_args[0]->operator Vector2()); return; + case Variant::VECTOR2: + r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->basis_xform(p_args[0]->operator Vector2()); + return; default: r_ret = Variant(); ERR_PRINT("Invalid type in function 'basis_xform' in base 'Transform2D'. Only Vector2 is valid."); @@ -892,7 +907,9 @@ struct _VariantCall { static void _call_Transform2D_basis_xform_inv(Variant &r_ret, Variant &p_self, const Variant **p_args) { switch (p_args[0]->type) { - case Variant::VECTOR2: r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->basis_xform_inv(p_args[0]->operator Vector2()); return; + case Variant::VECTOR2: + r_ret = reinterpret_cast<Transform2D *>(p_self._data._ptr)->basis_xform_inv(p_args[0]->operator Vector2()); + return; default: r_ret = Variant(); ERR_PRINT("Invalid type in function 'basis_xform_inv' in base 'Transform2D'. Only Vector2 is valid."); @@ -929,10 +946,18 @@ struct _VariantCall { static void _call_Transform_xform(Variant &r_ret, Variant &p_self, const Variant **p_args) { switch (p_args[0]->type) { - case Variant::VECTOR3: r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform(p_args[0]->operator Vector3()); return; - case Variant::PLANE: r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform(p_args[0]->operator Plane()); return; - case Variant::AABB: r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform(p_args[0]->operator ::AABB()); return; - case Variant::PACKED_VECTOR3_ARRAY: r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform(p_args[0]->operator ::PackedVector3Array()); return; + case Variant::VECTOR3: + r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform(p_args[0]->operator Vector3()); + return; + case Variant::PLANE: + r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform(p_args[0]->operator Plane()); + return; + case Variant::AABB: + r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform(p_args[0]->operator ::AABB()); + return; + case Variant::PACKED_VECTOR3_ARRAY: + r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform(p_args[0]->operator ::PackedVector3Array()); + return; default: r_ret = Variant(); ERR_PRINT("Invalid type in function 'xform' in base 'Transform'. Valid types are Vector3, Plane, AABB, and PackedVector3Array."); @@ -941,10 +966,18 @@ struct _VariantCall { static void _call_Transform_xform_inv(Variant &r_ret, Variant &p_self, const Variant **p_args) { switch (p_args[0]->type) { - case Variant::VECTOR3: r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform_inv(p_args[0]->operator Vector3()); return; - case Variant::PLANE: r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform_inv(p_args[0]->operator Plane()); return; - case Variant::AABB: r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform_inv(p_args[0]->operator ::AABB()); return; - case Variant::PACKED_VECTOR3_ARRAY: r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform_inv(p_args[0]->operator ::PackedVector3Array()); return; + case Variant::VECTOR3: + r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform_inv(p_args[0]->operator Vector3()); + return; + case Variant::PLANE: + r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform_inv(p_args[0]->operator Plane()); + return; + case Variant::AABB: + r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform_inv(p_args[0]->operator ::AABB()); + return; + case Variant::PACKED_VECTOR3_ARRAY: + r_ret = reinterpret_cast<Transform *>(p_self._data._ptr)->xform_inv(p_args[0]->operator ::PackedVector3Array()); + return; default: r_ret = Variant(); ERR_PRINT("Invalid type in function 'xform_inv' in base 'Transform'. Valid types are Vector3, Plane, AABB, and PackedVector3Array."); @@ -1280,50 +1313,74 @@ Variant Variant::construct(const Variant::Type p_type, const Variant **p_args, i return Variant(); // atomic types - case BOOL: return Variant(false); - case INT: return 0; - case FLOAT: return 0.0f; + case BOOL: + return Variant(false); + case INT: + return 0; + case FLOAT: + return 0.0f; case STRING: return String(); // math types case VECTOR2: return Vector2(); - case RECT2: return Rect2(); - case VECTOR3: return Vector3(); - case TRANSFORM2D: return Transform2D(); - case PLANE: return Plane(); - case QUAT: return Quat(); + case RECT2: + return Rect2(); + case VECTOR3: + return Vector3(); + case TRANSFORM2D: + return Transform2D(); + case PLANE: + return Plane(); + case QUAT: + return Quat(); case AABB: return ::AABB(); - case BASIS: return Basis(); + case BASIS: + return Basis(); case TRANSFORM: return Transform(); // misc types - case COLOR: return Color(); + case COLOR: + return Color(); case STRING_NAME: return StringName(); case NODE_PATH: return NodePath(); - case _RID: return RID(); - case OBJECT: return (Object *)nullptr; - case CALLABLE: return Callable(); - case SIGNAL: return Signal(); - case DICTIONARY: return Dictionary(); + case _RID: + return RID(); + case OBJECT: + return (Object *)nullptr; + case CALLABLE: + return Callable(); + case SIGNAL: + return Signal(); + case DICTIONARY: + return Dictionary(); case ARRAY: return Array(); - case PACKED_BYTE_ARRAY: return PackedByteArray(); - case PACKED_INT32_ARRAY: return PackedInt32Array(); - case PACKED_INT64_ARRAY: return PackedInt64Array(); - case PACKED_FLOAT32_ARRAY: return PackedFloat32Array(); - case PACKED_FLOAT64_ARRAY: return PackedFloat64Array(); - case PACKED_STRING_ARRAY: return PackedStringArray(); + case PACKED_BYTE_ARRAY: + return PackedByteArray(); + case PACKED_INT32_ARRAY: + return PackedInt32Array(); + case PACKED_INT64_ARRAY: + return PackedInt64Array(); + case PACKED_FLOAT32_ARRAY: + return PackedFloat32Array(); + case PACKED_FLOAT64_ARRAY: + return PackedFloat64Array(); + case PACKED_STRING_ARRAY: + return PackedStringArray(); case PACKED_VECTOR2_ARRAY: return PackedVector2Array(); - case PACKED_VECTOR3_ARRAY: return PackedVector3Array(); - case PACKED_COLOR_ARRAY: return PackedColorArray(); - default: return Variant(); + case PACKED_VECTOR3_ARRAY: + return PackedVector3Array(); + case PACKED_COLOR_ARRAY: + return PackedColorArray(); + default: + return Variant(); } } else if (p_argcount == 1 && p_args[0]->type == p_type) { @@ -1354,46 +1411,68 @@ Variant Variant::construct(const Variant::Type p_type, const Variant **p_args, i case VECTOR2I: { return Vector2i(*p_args[0]); } - case RECT2: return (Rect2(*p_args[0])); - case RECT2I: return (Rect2i(*p_args[0])); - case VECTOR3: return (Vector3(*p_args[0])); - case VECTOR3I: return (Vector3i(*p_args[0])); + case RECT2: + return (Rect2(*p_args[0])); + case RECT2I: + return (Rect2i(*p_args[0])); + case VECTOR3: + return (Vector3(*p_args[0])); + case VECTOR3I: + return (Vector3i(*p_args[0])); case TRANSFORM2D: return (Transform2D(p_args[0]->operator Transform2D())); - case PLANE: return (Plane(*p_args[0])); - case QUAT: return (p_args[0]->operator Quat()); + case PLANE: + return (Plane(*p_args[0])); + case QUAT: + return (p_args[0]->operator Quat()); case AABB: return (::AABB(*p_args[0])); - case BASIS: return (Basis(p_args[0]->operator Basis())); + case BASIS: + return (Basis(p_args[0]->operator Basis())); case TRANSFORM: return (Transform(p_args[0]->operator Transform())); // misc types - case COLOR: return p_args[0]->type == Variant::STRING ? Color::html(*p_args[0]) : Color::hex(*p_args[0]); + case COLOR: + return p_args[0]->type == Variant::STRING ? Color::html(*p_args[0]) : Color::hex(*p_args[0]); case STRING_NAME: return (StringName(p_args[0]->operator StringName())); case NODE_PATH: return (NodePath(p_args[0]->operator NodePath())); - case _RID: return (RID(*p_args[0])); - case OBJECT: return ((Object *)(p_args[0]->operator Object *())); - case CALLABLE: return ((Callable)(p_args[0]->operator Callable())); - case SIGNAL: return ((Signal)(p_args[0]->operator Signal())); - case DICTIONARY: return p_args[0]->operator Dictionary(); + case _RID: + return (RID(*p_args[0])); + case OBJECT: + return ((Object *)(p_args[0]->operator Object *())); + case CALLABLE: + return ((Callable)(p_args[0]->operator Callable())); + case SIGNAL: + return ((Signal)(p_args[0]->operator Signal())); + case DICTIONARY: + return p_args[0]->operator Dictionary(); case ARRAY: return p_args[0]->operator Array(); // arrays - case PACKED_BYTE_ARRAY: return (PackedByteArray(*p_args[0])); - case PACKED_INT32_ARRAY: return (PackedInt32Array(*p_args[0])); - case PACKED_INT64_ARRAY: return (PackedInt64Array(*p_args[0])); - case PACKED_FLOAT32_ARRAY: return (PackedFloat32Array(*p_args[0])); - case PACKED_FLOAT64_ARRAY: return (PackedFloat64Array(*p_args[0])); - case PACKED_STRING_ARRAY: return (PackedStringArray(*p_args[0])); + case PACKED_BYTE_ARRAY: + return (PackedByteArray(*p_args[0])); + case PACKED_INT32_ARRAY: + return (PackedInt32Array(*p_args[0])); + case PACKED_INT64_ARRAY: + return (PackedInt64Array(*p_args[0])); + case PACKED_FLOAT32_ARRAY: + return (PackedFloat32Array(*p_args[0])); + case PACKED_FLOAT64_ARRAY: + return (PackedFloat64Array(*p_args[0])); + case PACKED_STRING_ARRAY: + return (PackedStringArray(*p_args[0])); case PACKED_VECTOR2_ARRAY: return (PackedVector2Array(*p_args[0])); - case PACKED_VECTOR3_ARRAY: return (PackedVector3Array(*p_args[0])); - case PACKED_COLOR_ARRAY: return (PackedColorArray(*p_args[0])); - default: return Variant(); + case PACKED_VECTOR3_ARRAY: + return (PackedVector3Array(*p_args[0])); + case PACKED_COLOR_ARRAY: + return (PackedColorArray(*p_args[0])); + default: + return Variant(); } } else if (p_argcount >= 1) { @@ -1738,6 +1817,7 @@ void register_variant_methods() { ADDFUNC3R(STRING, PACKED_STRING_ARRAY, String, split, STRING, "delimiter", BOOL, "allow_empty", INT, "maxsplit", varray(true, 0)); ADDFUNC3R(STRING, PACKED_STRING_ARRAY, String, rsplit, STRING, "delimiter", BOOL, "allow_empty", INT, "maxsplit", varray(true, 0)); ADDFUNC2R(STRING, PACKED_FLOAT32_ARRAY, String, split_floats, STRING, "delimiter", BOOL, "allow_empty", varray(true)); + ADDFUNC1R(STRING, STRING, String, join, PACKED_STRING_ARRAY, "parts", varray()); ADDFUNC0R(STRING, STRING, String, to_upper, varray()); ADDFUNC0R(STRING, STRING, String, to_lower, varray()); diff --git a/core/variant_op.cpp b/core/variant_op.cpp index 27624b81ee..4c9848f26a 100644 --- a/core/variant_op.cpp +++ b/core/variant_op.cpp @@ -78,38 +78,38 @@ #define TYPE(PREFIX, OP, TYPE) &&PREFIX##_##OP##_##TYPE /* clang-format off */ -#define TYPES(PREFIX, OP) { \ - TYPE(PREFIX, OP, NIL), \ - TYPE(PREFIX, OP, BOOL), \ - TYPE(PREFIX, OP, INT), \ - TYPE(PREFIX, OP, FLOAT), \ - TYPE(PREFIX, OP, STRING), \ - TYPE(PREFIX, OP, VECTOR2), \ - TYPE(PREFIX, OP, VECTOR2I), \ - TYPE(PREFIX, OP, RECT2), \ - TYPE(PREFIX, OP, RECT2I), \ - TYPE(PREFIX, OP, VECTOR3), \ - TYPE(PREFIX, OP, VECTOR3I), \ - TYPE(PREFIX, OP, TRANSFORM2D), \ - TYPE(PREFIX, OP, PLANE), \ - TYPE(PREFIX, OP, QUAT), \ - TYPE(PREFIX, OP, AABB), \ - TYPE(PREFIX, OP, BASIS), \ - TYPE(PREFIX, OP, TRANSFORM), \ - TYPE(PREFIX, OP, COLOR), \ +#define TYPES(PREFIX, OP) { \ + TYPE(PREFIX, OP, NIL), \ + TYPE(PREFIX, OP, BOOL), \ + TYPE(PREFIX, OP, INT), \ + TYPE(PREFIX, OP, FLOAT), \ + TYPE(PREFIX, OP, STRING), \ + TYPE(PREFIX, OP, VECTOR2), \ + TYPE(PREFIX, OP, VECTOR2I), \ + TYPE(PREFIX, OP, RECT2), \ + TYPE(PREFIX, OP, RECT2I), \ + TYPE(PREFIX, OP, VECTOR3), \ + TYPE(PREFIX, OP, VECTOR3I), \ + TYPE(PREFIX, OP, TRANSFORM2D), \ + TYPE(PREFIX, OP, PLANE), \ + TYPE(PREFIX, OP, QUAT), \ + TYPE(PREFIX, OP, AABB), \ + TYPE(PREFIX, OP, BASIS), \ + TYPE(PREFIX, OP, TRANSFORM), \ + TYPE(PREFIX, OP, COLOR), \ TYPE(PREFIX, OP, STRING_NAME), \ - TYPE(PREFIX, OP, NODE_PATH), \ - TYPE(PREFIX, OP, _RID), \ - TYPE(PREFIX, OP, OBJECT), \ + TYPE(PREFIX, OP, NODE_PATH), \ + TYPE(PREFIX, OP, _RID), \ + TYPE(PREFIX, OP, OBJECT), \ TYPE(PREFIX, OP, CALLABLE), \ - TYPE(PREFIX, OP, SIGNAL), \ - TYPE(PREFIX, OP, DICTIONARY), \ - TYPE(PREFIX, OP, ARRAY), \ + TYPE(PREFIX, OP, SIGNAL), \ + TYPE(PREFIX, OP, DICTIONARY), \ + TYPE(PREFIX, OP, ARRAY), \ TYPE(PREFIX, OP, PACKED_BYTE_ARRAY), \ - TYPE(PREFIX, OP, PACKED_INT32_ARRAY), \ - TYPE(PREFIX, OP, PACKED_INT64_ARRAY), \ - TYPE(PREFIX, OP, PACKED_FLOAT32_ARRAY), \ - TYPE(PREFIX, OP, PACKED_FLOAT64_ARRAY), \ + TYPE(PREFIX, OP, PACKED_INT32_ARRAY), \ + TYPE(PREFIX, OP, PACKED_INT64_ARRAY), \ + TYPE(PREFIX, OP, PACKED_FLOAT32_ARRAY), \ + TYPE(PREFIX, OP, PACKED_FLOAT64_ARRAY), \ TYPE(PREFIX, OP, PACKED_STRING_ARRAY), \ TYPE(PREFIX, OP, PACKED_VECTOR2_ARRAY), \ TYPE(PREFIX, OP, PACKED_VECTOR3_ARRAY), \ @@ -181,21 +181,26 @@ bool Variant::booleanize() const { return; \ } -#define DEFAULT_OP_NUM(m_prefix, m_op_name, m_name, m_op, m_type) \ - CASE_TYPE(m_prefix, m_op_name, m_name) { \ - if (p_b.type == INT) _RETURN(p_a._data.m_type m_op p_b._data._int); \ - if (p_b.type == FLOAT) _RETURN(p_a._data.m_type m_op p_b._data._float); \ - \ - _RETURN_FAIL \ +#define DEFAULT_OP_NUM(m_prefix, m_op_name, m_name, m_op, m_type) \ + CASE_TYPE(m_prefix, m_op_name, m_name) { \ + if (p_b.type == INT) \ + _RETURN(p_a._data.m_type m_op p_b._data._int); \ + if (p_b.type == FLOAT) \ + _RETURN(p_a._data.m_type m_op p_b._data._float); \ + \ + _RETURN_FAIL \ }; -#define DEFAULT_OP_NUM_NULL(m_prefix, m_op_name, m_name, m_op, m_type) \ - CASE_TYPE(m_prefix, m_op_name, m_name) { \ - if (p_b.type == INT) _RETURN(p_a._data.m_type m_op p_b._data._int); \ - if (p_b.type == FLOAT) _RETURN(p_a._data.m_type m_op p_b._data._float); \ - if (p_b.type == NIL) _RETURN(!(p_b.type m_op NIL)); \ - \ - _RETURN_FAIL \ +#define DEFAULT_OP_NUM_NULL(m_prefix, m_op_name, m_name, m_op, m_type) \ + CASE_TYPE(m_prefix, m_op_name, m_name) { \ + if (p_b.type == INT) \ + _RETURN(p_a._data.m_type m_op p_b._data._int); \ + if (p_b.type == FLOAT) \ + _RETURN(p_a._data.m_type m_op p_b._data._float); \ + if (p_b.type == NIL) \ + _RETURN(!(p_b.type m_op NIL)); \ + \ + _RETURN_FAIL \ }; #ifdef DEBUG_ENABLED @@ -219,12 +224,14 @@ bool Variant::booleanize() const { _RETURN_FAIL \ }; #else -#define DEFAULT_OP_NUM_DIV(m_prefix, m_op_name, m_name, m_type) \ - CASE_TYPE(m_prefix, m_op_name, m_name) { \ - if (p_b.type == INT) _RETURN(p_a._data.m_type / p_b._data._int); \ - if (p_b.type == FLOAT) _RETURN(p_a._data.m_type / p_b._data._float); \ - \ - _RETURN_FAIL \ +#define DEFAULT_OP_NUM_DIV(m_prefix, m_op_name, m_name, m_type) \ + CASE_TYPE(m_prefix, m_op_name, m_name) { \ + if (p_b.type == INT) \ + _RETURN(p_a._data.m_type / p_b._data._int); \ + if (p_b.type == FLOAT) \ + _RETURN(p_a._data.m_type / p_b._data._float); \ + \ + _RETURN_FAIL \ }; #endif @@ -238,62 +245,84 @@ bool Variant::booleanize() const { _RETURN(p_a._data.m_type); \ }; -#define DEFAULT_OP_NUM_VEC(m_prefix, m_op_name, m_name, m_op, m_type) \ - CASE_TYPE(m_prefix, m_op_name, m_name) { \ - if (p_b.type == INT) _RETURN(p_a._data.m_type m_op p_b._data._int); \ - if (p_b.type == FLOAT) _RETURN(p_a._data.m_type m_op p_b._data._float); \ - if (p_b.type == VECTOR2) _RETURN(p_a._data.m_type m_op *reinterpret_cast<const Vector2 *>(p_b._data._mem)); \ - if (p_b.type == VECTOR3) _RETURN(p_a._data.m_type m_op *reinterpret_cast<const Vector3 *>(p_b._data._mem)); \ - if (p_b.type == VECTOR2I) _RETURN(p_a._data.m_type m_op *reinterpret_cast<const Vector2 *>(p_b._data._mem)); \ - if (p_b.type == VECTOR3I) _RETURN(p_a._data.m_type m_op *reinterpret_cast<const Vector3 *>(p_b._data._mem)); \ - \ - _RETURN_FAIL \ +#define DEFAULT_OP_NUM_VEC(m_prefix, m_op_name, m_name, m_op, m_type) \ + CASE_TYPE(m_prefix, m_op_name, m_name) { \ + if (p_b.type == INT) \ + _RETURN(p_a._data.m_type m_op p_b._data._int); \ + if (p_b.type == FLOAT) \ + _RETURN(p_a._data.m_type m_op p_b._data._float); \ + if (p_b.type == VECTOR2) \ + _RETURN(p_a._data.m_type m_op *reinterpret_cast<const Vector2 *>(p_b._data._mem)); \ + if (p_b.type == VECTOR3) \ + _RETURN(p_a._data.m_type m_op *reinterpret_cast<const Vector3 *>(p_b._data._mem)); \ + if (p_b.type == VECTOR2I) \ + _RETURN(p_a._data.m_type m_op *reinterpret_cast<const Vector2 *>(p_b._data._mem)); \ + if (p_b.type == VECTOR3I) \ + _RETURN(p_a._data.m_type m_op *reinterpret_cast<const Vector3 *>(p_b._data._mem)); \ + \ + _RETURN_FAIL \ }; -#define DEFAULT_OP_STR_REV(m_prefix, m_op_name, m_name, m_op, m_type) \ - CASE_TYPE(m_prefix, m_op_name, m_name) { \ - if (p_b.type == STRING) _RETURN(*reinterpret_cast<const m_type *>(p_b._data._mem) m_op *reinterpret_cast<const String *>(p_a._data._mem)); \ - if (p_b.type == STRING_NAME) _RETURN(*reinterpret_cast<const m_type *>(p_b._data._mem) m_op *reinterpret_cast<const StringName *>(p_a._data._mem)); \ - if (p_b.type == NODE_PATH) _RETURN(*reinterpret_cast<const m_type *>(p_b._data._mem) m_op *reinterpret_cast<const NodePath *>(p_a._data._mem)); \ - \ - _RETURN_FAIL \ +#define DEFAULT_OP_STR_REV(m_prefix, m_op_name, m_name, m_op, m_type) \ + CASE_TYPE(m_prefix, m_op_name, m_name) { \ + if (p_b.type == STRING) \ + _RETURN(*reinterpret_cast<const m_type *>(p_b._data._mem) m_op *reinterpret_cast<const String *>(p_a._data._mem)); \ + if (p_b.type == STRING_NAME) \ + _RETURN(*reinterpret_cast<const m_type *>(p_b._data._mem) m_op *reinterpret_cast<const StringName *>(p_a._data._mem)); \ + if (p_b.type == NODE_PATH) \ + _RETURN(*reinterpret_cast<const m_type *>(p_b._data._mem) m_op *reinterpret_cast<const NodePath *>(p_a._data._mem)); \ + \ + _RETURN_FAIL \ }; -#define DEFAULT_OP_STR(m_prefix, m_op_name, m_name, m_op, m_type) \ - CASE_TYPE(m_prefix, m_op_name, m_name) { \ - if (p_b.type == STRING) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const String *>(p_b._data._mem)); \ - if (p_b.type == STRING_NAME) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const StringName *>(p_b._data._mem)); \ - if (p_b.type == NODE_PATH) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const NodePath *>(p_b._data._mem)); \ - \ - _RETURN_FAIL \ +#define DEFAULT_OP_STR(m_prefix, m_op_name, m_name, m_op, m_type) \ + CASE_TYPE(m_prefix, m_op_name, m_name) { \ + if (p_b.type == STRING) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const String *>(p_b._data._mem)); \ + if (p_b.type == STRING_NAME) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const StringName *>(p_b._data._mem)); \ + if (p_b.type == NODE_PATH) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const NodePath *>(p_b._data._mem)); \ + \ + _RETURN_FAIL \ }; -#define DEFAULT_OP_STR_NULL(m_prefix, m_op_name, m_name, m_op, m_type) \ - CASE_TYPE(m_prefix, m_op_name, m_name) { \ - if (p_b.type == STRING) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const String *>(p_b._data._mem)); \ - if (p_b.type == STRING_NAME) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const StringName *>(p_b._data._mem)); \ - if (p_b.type == NODE_PATH) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const NodePath *>(p_b._data._mem)); \ - if (p_b.type == NIL) _RETURN(!(p_b.type m_op NIL)); \ - \ - _RETURN_FAIL \ +#define DEFAULT_OP_STR_NULL(m_prefix, m_op_name, m_name, m_op, m_type) \ + CASE_TYPE(m_prefix, m_op_name, m_name) { \ + if (p_b.type == STRING) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const String *>(p_b._data._mem)); \ + if (p_b.type == STRING_NAME) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const StringName *>(p_b._data._mem)); \ + if (p_b.type == NODE_PATH) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const NodePath *>(p_b._data._mem)); \ + if (p_b.type == NIL) \ + _RETURN(!(p_b.type m_op NIL)); \ + \ + _RETURN_FAIL \ }; -#define DEFAULT_OP_STR_NULL_NP(m_prefix, m_op_name, m_name, m_op, m_type) \ - CASE_TYPE(m_prefix, m_op_name, m_name) { \ - if (p_b.type == STRING) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const String *>(p_b._data._mem)); \ - if (p_b.type == NODE_PATH) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const NodePath *>(p_b._data._mem)); \ - if (p_b.type == NIL) _RETURN(!(p_b.type m_op NIL)); \ - \ - _RETURN_FAIL \ +#define DEFAULT_OP_STR_NULL_NP(m_prefix, m_op_name, m_name, m_op, m_type) \ + CASE_TYPE(m_prefix, m_op_name, m_name) { \ + if (p_b.type == STRING) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const String *>(p_b._data._mem)); \ + if (p_b.type == NODE_PATH) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const NodePath *>(p_b._data._mem)); \ + if (p_b.type == NIL) \ + _RETURN(!(p_b.type m_op NIL)); \ + \ + _RETURN_FAIL \ }; -#define DEFAULT_OP_STR_NULL_SN(m_prefix, m_op_name, m_name, m_op, m_type) \ - CASE_TYPE(m_prefix, m_op_name, m_name) { \ - if (p_b.type == STRING) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const String *>(p_b._data._mem)); \ - if (p_b.type == STRING_NAME) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const StringName *>(p_b._data._mem)); \ - if (p_b.type == NIL) _RETURN(!(p_b.type m_op NIL)); \ - \ - _RETURN_FAIL \ +#define DEFAULT_OP_STR_NULL_SN(m_prefix, m_op_name, m_name, m_op, m_type) \ + CASE_TYPE(m_prefix, m_op_name, m_name) { \ + if (p_b.type == STRING) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const String *>(p_b._data._mem)); \ + if (p_b.type == STRING_NAME) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const StringName *>(p_b._data._mem)); \ + if (p_b.type == NIL) \ + _RETURN(!(p_b.type m_op NIL)); \ + \ + _RETURN_FAIL \ }; #define DEFAULT_OP_LOCALMEM_REV(m_prefix, m_op_name, m_name, m_op, m_type) \ @@ -332,13 +361,16 @@ bool Variant::booleanize() const { _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem)); \ } -#define DEFAULT_OP_LOCALMEM_NUM(m_prefix, m_op_name, m_name, m_op, m_type) \ - CASE_TYPE(m_prefix, m_op_name, m_name) { \ - if (p_b.type == m_name) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const m_type *>(p_b._data._mem)); \ - if (p_b.type == INT) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op p_b._data._int); \ - if (p_b.type == FLOAT) _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op p_b._data._float); \ - \ - _RETURN_FAIL \ +#define DEFAULT_OP_LOCALMEM_NUM(m_prefix, m_op_name, m_name, m_op, m_type) \ + CASE_TYPE(m_prefix, m_op_name, m_name) { \ + if (p_b.type == m_name) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op *reinterpret_cast<const m_type *>(p_b._data._mem)); \ + if (p_b.type == INT) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op p_b._data._int); \ + if (p_b.type == FLOAT) \ + _RETURN(*reinterpret_cast<const m_type *>(p_a._data._mem) m_op p_b._data._float); \ + \ + _RETURN_FAIL \ } #define DEFAULT_OP_PTR(m_op, m_name, m_sub) \ @@ -436,7 +468,8 @@ void Variant::evaluate(const Operator &p_op, const Variant &p_a, SWITCH(math, p_op, p_a.type) { SWITCH_OP(math, OP_EQUAL, p_a.type) { CASE_TYPE(math, OP_EQUAL, NIL) { - if (p_b.type == NIL) _RETURN(true); + if (p_b.type == NIL) + _RETURN(true); if (p_b.type == OBJECT) _RETURN(p_b._get_obj().obj == nullptr); @@ -532,7 +565,8 @@ void Variant::evaluate(const Operator &p_op, const Variant &p_a, SWITCH_OP(math, OP_NOT_EQUAL, p_a.type) { CASE_TYPE(math, OP_NOT_EQUAL, NIL) { - if (p_b.type == NIL) _RETURN(false); + if (p_b.type == NIL) + _RETURN(false); if (p_b.type == OBJECT) _RETURN(p_b._get_obj().obj != nullptr); @@ -981,7 +1015,8 @@ void Variant::evaluate(const Operator &p_op, const Variant &p_a, case VECTOR2: { _RETURN(p_a._data._transform2d->xform(*(const Vector2 *)p_b._data._mem)); } - default: _RETURN_FAIL; + default: + _RETURN_FAIL; } } @@ -996,7 +1031,8 @@ void Variant::evaluate(const Operator &p_op, const Variant &p_a, case FLOAT: { _RETURN(*reinterpret_cast<const Quat *>(p_a._data._mem) * p_b._data._float); } - default: _RETURN_FAIL; + default: + _RETURN_FAIL; } } @@ -1008,7 +1044,8 @@ void Variant::evaluate(const Operator &p_op, const Variant &p_a, case BASIS: { _RETURN(*p_a._data._basis * *p_b._data._basis); } - default: _RETURN_FAIL; + default: + _RETURN_FAIL; } } @@ -1020,7 +1057,8 @@ void Variant::evaluate(const Operator &p_op, const Variant &p_a, case TRANSFORM: { _RETURN(*p_a._data._transform * *p_b._data._transform); } - default: _RETURN_FAIL; + default: + _RETURN_FAIL; } } @@ -1983,7 +2021,8 @@ Variant Variant::get_named(const StringName &p_index, bool *r_valid) const { #define DEFAULT_OP_DVECTOR_SET(m_name, m_type, skip_cond) \ case m_name: { \ - if (skip_cond) return; \ + if (skip_cond) \ + return; \ \ if (p_index.get_type() == Variant::INT || p_index.get_type() == Variant::FLOAT) { \ int index = p_index; \ @@ -3458,17 +3497,39 @@ bool Variant::iter_init(Variant &r_iter, bool &valid) const { return _data._float > 0.0; } break; case VECTOR2: { - int64_t from = reinterpret_cast<const Vector2 *>(_data._mem)->x; - int64_t to = reinterpret_cast<const Vector2 *>(_data._mem)->y; + double from = reinterpret_cast<const Vector2 *>(_data._mem)->x; + double to = reinterpret_cast<const Vector2 *>(_data._mem)->y; + + r_iter = from; + + return from < to; + } break; + case VECTOR2I: { + int64_t from = reinterpret_cast<const Vector2i *>(_data._mem)->x; + int64_t to = reinterpret_cast<const Vector2i *>(_data._mem)->y; r_iter = from; return from < to; } break; case VECTOR3: { - int64_t from = reinterpret_cast<const Vector3 *>(_data._mem)->x; - int64_t to = reinterpret_cast<const Vector3 *>(_data._mem)->y; - int64_t step = reinterpret_cast<const Vector3 *>(_data._mem)->z; + double from = reinterpret_cast<const Vector3 *>(_data._mem)->x; + double to = reinterpret_cast<const Vector3 *>(_data._mem)->y; + double step = reinterpret_cast<const Vector3 *>(_data._mem)->z; + + r_iter = from; + + if (from == to) { + return false; + } else if (from < to) { + return step > 0; + } + return step < 0; + } break; + case VECTOR3I: { + int64_t from = reinterpret_cast<const Vector3i *>(_data._mem)->x; + int64_t to = reinterpret_cast<const Vector3i *>(_data._mem)->y; + int64_t step = reinterpret_cast<const Vector3i *>(_data._mem)->z; r_iter = from; @@ -3476,10 +3537,8 @@ bool Variant::iter_init(Variant &r_iter, bool &valid) const { return false; } else if (from < to) { return step > 0; - } else { - return step < 0; } - //return true; + return step < 0; } break; case OBJECT: { @@ -3640,7 +3699,19 @@ bool Variant::iter_next(Variant &r_iter, bool &valid) const { return true; } break; case VECTOR2: { - int64_t to = reinterpret_cast<const Vector2 *>(_data._mem)->y; + double to = reinterpret_cast<const Vector2 *>(_data._mem)->y; + + double idx = r_iter; + idx++; + + if (idx >= to) + return false; + + r_iter = idx; + return true; + } break; + case VECTOR2I: { + int64_t to = reinterpret_cast<const Vector2i *>(_data._mem)->y; int64_t idx = r_iter; idx++; @@ -3652,8 +3723,24 @@ bool Variant::iter_next(Variant &r_iter, bool &valid) const { return true; } break; case VECTOR3: { - int64_t to = reinterpret_cast<const Vector3 *>(_data._mem)->y; - int64_t step = reinterpret_cast<const Vector3 *>(_data._mem)->z; + double to = reinterpret_cast<const Vector3 *>(_data._mem)->y; + double step = reinterpret_cast<const Vector3 *>(_data._mem)->z; + + double idx = r_iter; + idx += step; + + if (step < 0 && idx <= to) + return false; + + if (step > 0 && idx >= to) + return false; + + r_iter = idx; + return true; + } break; + case VECTOR3I: { + int64_t to = reinterpret_cast<const Vector3i *>(_data._mem)->y; + int64_t step = reinterpret_cast<const Vector3i *>(_data._mem)->z; int64_t idx = r_iter; idx += step; @@ -3844,10 +3931,18 @@ Variant Variant::iter_get(const Variant &r_iter, bool &r_valid) const { return r_iter; } break; + case VECTOR2I: { + + return r_iter; + } break; case VECTOR3: { return r_iter; } break; + case VECTOR3I: { + + return r_iter; + } break; case OBJECT: { if (!_get_obj().obj) { diff --git a/core/variant_parser.cpp b/core/variant_parser.cpp index 0a578faf78..c9678c9933 100644 --- a/core/variant_parser.cpp +++ b/core/variant_parser.cpp @@ -243,11 +243,21 @@ Error VariantParser::get_token(Stream *p_stream, Token &r_token, int &line, Stri switch (next) { - case 'b': res = 8; break; - case 't': res = 9; break; - case 'n': res = 10; break; - case 'f': res = 12; break; - case 'r': res = 13; break; + case 'b': + res = 8; + break; + case 't': + res = 9; + break; + case 'n': + res = 10; + break; + case 'f': + res = 12; + break; + case 'r': + res = 13; + break; case 'u': { //hex number for (int j = 0; j < 4; j++) { diff --git a/core/variant_parser.h b/core/variant_parser.h index 63ed51bcc9..af7d55d1b8 100644 --- a/core/variant_parser.h +++ b/core/variant_parser.h @@ -43,34 +43,33 @@ public: virtual bool is_utf8() const = 0; virtual bool is_eof() const = 0; - CharType saved; + CharType saved = 0; - Stream() : - saved(0) {} + Stream() {} virtual ~Stream() {} }; struct StreamFile : public Stream { - FileAccess *f; + FileAccess *f = nullptr; virtual CharType get_char(); virtual bool is_utf8() const; virtual bool is_eof() const; - StreamFile() { f = nullptr; } + StreamFile() {} }; struct StreamString : public Stream { String s; - int pos; + int pos = 0; virtual CharType get_char(); virtual bool is_utf8() const; virtual bool is_eof() const; - StreamString() { pos = 0; } + StreamString() {} }; typedef Error (*ParseResourceFunc)(void *p_self, Stream *p_stream, Ref<Resource> &r_res, int &line, String &r_err_str); diff --git a/core/vector.h b/core/vector.h index 7277179621..7ab464fe11 100644 --- a/core/vector.h +++ b/core/vector.h @@ -39,6 +39,7 @@ #include "core/cowdata.h" #include "core/error_macros.h" +#include "core/os/copymem.h" #include "core/os/memory.h" #include "core/sort_array.h" @@ -69,7 +70,8 @@ public: void remove(int p_index) { _cowdata.remove(p_index); } void erase(const T &p_val) { int idx = find(p_val); - if (idx >= 0) remove(idx); + if (idx >= 0) + remove(idx); } void invert(); @@ -117,13 +119,18 @@ public: insert(i, p_val); } - _FORCE_INLINE_ Vector() {} - _FORCE_INLINE_ Vector(const Vector &p_from) { _cowdata._ref(p_from._cowdata); } inline Vector &operator=(const Vector &p_from) { _cowdata._ref(p_from._cowdata); return *this; } + Vector<uint8_t> to_byte_array() const { + Vector<uint8_t> ret; + ret.resize(size() * sizeof(T)); + copymem(ret.ptrw(), ptr(), sizeof(T) * size()); + return ret; + } + Vector<T> subarray(int p_from, int p_to) const { if (p_from < 0) { @@ -148,6 +155,9 @@ public: return slice; } + _FORCE_INLINE_ Vector() {} + _FORCE_INLINE_ Vector(const Vector &p_from) { _cowdata._ref(p_from._cowdata); } + _FORCE_INLINE_ ~Vector() {} }; diff --git a/core/vmap.h b/core/vmap.h index 84ae1aaf66..848b5055fa 100644 --- a/core/vmap.h +++ b/core/vmap.h @@ -202,11 +202,13 @@ public: return _cowdata.get_m(pos).value; } - _FORCE_INLINE_ VMap(){}; + _FORCE_INLINE_ VMap() {} _FORCE_INLINE_ VMap(const VMap &p_from) { _cowdata._ref(p_from._cowdata); } + inline VMap &operator=(const VMap &p_from) { _cowdata._ref(p_from._cowdata); return *this; } }; + #endif // VMAP_H diff --git a/doc/classes/@GlobalScope.xml b/doc/classes/@GlobalScope.xml index 222b70c9c7..bf2ce321ac 100644 --- a/doc/classes/@GlobalScope.xml +++ b/doc/classes/@GlobalScope.xml @@ -77,9 +77,6 @@ <member name="ProjectSettings" type="ProjectSettings" setter="" getter=""> The [ProjectSettings] singleton. </member> - <member name="RenderingDevice" type="RenderingDevice" setter="" getter=""> - The [RenderingDevice] singleton. - </member> <member name="RenderingServer" type="RenderingServer" setter="" getter=""> The [RenderingServer] singleton. </member> @@ -946,212 +943,185 @@ <constant name="BUTTON_MASK_XBUTTON2" value="256" enum="ButtonList"> Extra mouse button 2 mask. </constant> - <constant name="JOY_BUTTON_0" value="0" enum="JoystickList"> - Gamepad button 0. - </constant> - <constant name="JOY_BUTTON_1" value="1" enum="JoystickList"> - Gamepad button 1. - </constant> - <constant name="JOY_BUTTON_2" value="2" enum="JoystickList"> - Gamepad button 2. - </constant> - <constant name="JOY_BUTTON_3" value="3" enum="JoystickList"> - Gamepad button 3. - </constant> - <constant name="JOY_BUTTON_4" value="4" enum="JoystickList"> - Gamepad button 4. - </constant> - <constant name="JOY_BUTTON_5" value="5" enum="JoystickList"> - Gamepad button 5. - </constant> - <constant name="JOY_BUTTON_6" value="6" enum="JoystickList"> - Gamepad button 6. - </constant> - <constant name="JOY_BUTTON_7" value="7" enum="JoystickList"> - Gamepad button 7. - </constant> - <constant name="JOY_BUTTON_8" value="8" enum="JoystickList"> - Gamepad button 8. - </constant> - <constant name="JOY_BUTTON_9" value="9" enum="JoystickList"> - Gamepad button 9. + <constant name="JOY_INVALID_BUTTON" value="-1" enum="JoyButtonList"> + An invalid game controller button. </constant> - <constant name="JOY_BUTTON_10" value="10" enum="JoystickList"> - Gamepad button 10. + <constant name="JOY_BUTTON_A" value="0" enum="JoyButtonList"> + Game controller SDL button A. </constant> - <constant name="JOY_BUTTON_11" value="11" enum="JoystickList"> - Gamepad button 11. + <constant name="JOY_BUTTON_B" value="1" enum="JoyButtonList"> + Game controller SDL button B. </constant> - <constant name="JOY_BUTTON_12" value="12" enum="JoystickList"> - Gamepad button 12. + <constant name="JOY_BUTTON_X" value="2" enum="JoyButtonList"> + Game controller SDL button X. </constant> - <constant name="JOY_BUTTON_13" value="13" enum="JoystickList"> - Gamepad button 13. + <constant name="JOY_BUTTON_Y" value="3" enum="JoyButtonList"> + Game controller SDL button Y. </constant> - <constant name="JOY_BUTTON_14" value="14" enum="JoystickList"> - Gamepad button 14. + <constant name="JOY_BUTTON_BACK" value="4" enum="JoyButtonList"> + Game controller SDL back button. </constant> - <constant name="JOY_BUTTON_15" value="15" enum="JoystickList"> - Gamepad button 15. + <constant name="JOY_BUTTON_GUIDE" value="5" enum="JoyButtonList"> + Game controller SDL guide button. </constant> - <constant name="JOY_BUTTON_MAX" value="16" enum="JoystickList"> - Represents the maximum number of joystick buttons supported. + <constant name="JOY_BUTTON_START" value="6" enum="JoyButtonList"> + Game controller SDL start button. </constant> - <constant name="JOY_SONY_CIRCLE" value="1" enum="JoystickList"> - DualShock circle button. + <constant name="JOY_BUTTON_LEFT_STICK" value="7" enum="JoyButtonList"> + Game controller SDL left stick button. </constant> - <constant name="JOY_SONY_X" value="0" enum="JoystickList"> - DualShock X button. + <constant name="JOY_BUTTON_RIGHT_STICK" value="8" enum="JoyButtonList"> + Game controller SDL right stick button. </constant> - <constant name="JOY_SONY_SQUARE" value="2" enum="JoystickList"> - DualShock square button. + <constant name="JOY_BUTTON_LEFT_SHOULDER" value="9" enum="JoyButtonList"> + Game controller SDL left shoulder button. </constant> - <constant name="JOY_SONY_TRIANGLE" value="3" enum="JoystickList"> - DualShock triangle button. + <constant name="JOY_BUTTON_RIGHT_SHOULDER" value="10" enum="JoyButtonList"> + Game controller SDL right shoulder button. </constant> - <constant name="JOY_XBOX_B" value="1" enum="JoystickList"> - Xbox controller B button. + <constant name="JOY_BUTTON_DPAD_UP" value="11" enum="JoyButtonList"> + Game controller SDL D-pad up button. </constant> - <constant name="JOY_XBOX_A" value="0" enum="JoystickList"> - Xbox controller A button. + <constant name="JOY_BUTTON_DPAD_DOWN" value="12" enum="JoyButtonList"> + Game controller SDL D-pad down button. </constant> - <constant name="JOY_XBOX_X" value="2" enum="JoystickList"> - Xbox controller X button. + <constant name="JOY_BUTTON_DPAD_LEFT" value="13" enum="JoyButtonList"> + Game controller SDL D-pad left button. </constant> - <constant name="JOY_XBOX_Y" value="3" enum="JoystickList"> - Xbox controller Y button. + <constant name="JOY_BUTTON_DPAD_RIGHT" value="14" enum="JoyButtonList"> + Game controller SDL D-pad right button. </constant> - <constant name="JOY_DS_A" value="1" enum="JoystickList"> - Nintendo controller A button. + <constant name="JOY_SDL_BUTTONS" value="15" enum="JoyButtonList"> + The number of SDL game controller buttons. </constant> - <constant name="JOY_DS_B" value="0" enum="JoystickList"> - Nintendo controller B button. + <constant name="JOY_SONY_X" value="0" enum="JoyButtonList"> + Sony DualShock controller X button maps to SDL button A. </constant> - <constant name="JOY_DS_X" value="3" enum="JoystickList"> - Nintendo controller X button. + <constant name="JOY_SONY_CROSS" value="0" enum="JoyButtonList"> + Sony DualShock controller cross button maps to SDL button A. </constant> - <constant name="JOY_DS_Y" value="2" enum="JoystickList"> - Nintendo controller Y button. + <constant name="JOY_SONY_CIRCLE" value="1" enum="JoyButtonList"> + Sony DualShock controller circle button maps to SDL button B. </constant> - <constant name="JOY_VR_GRIP" value="2" enum="JoystickList"> - Grip (side) buttons on a VR controller. + <constant name="JOY_SONY_SQUARE" value="2" enum="JoyButtonList"> + Sony DualShock controller square button maps to SDL button X. </constant> - <constant name="JOY_VR_PAD" value="14" enum="JoystickList"> - Push down on the touchpad or main joystick on a VR controller. + <constant name="JOY_SONY_TRIANGLE" value="3" enum="JoyButtonList"> + Sony DualShock controller triangle button maps to SDL button Y. </constant> - <constant name="JOY_VR_TRIGGER" value="15" enum="JoystickList"> - Trigger on a VR controller. + <constant name="JOY_SONY_SELECT" value="4" enum="JoyButtonList"> + Sony DualShock controller select button maps to SDL back button. </constant> - <constant name="JOY_OCULUS_AX" value="7" enum="JoystickList"> - A button on the right Oculus Touch controller, X button on the left controller (also when used in OpenVR). + <constant name="JOY_SONY_START" value="6" enum="JoyButtonList"> + Sony DualShock controller start button maps to SDL start button. </constant> - <constant name="JOY_OCULUS_BY" value="1" enum="JoystickList"> - B button on the right Oculus Touch controller, Y button on the left controller (also when used in OpenVR). + <constant name="JOY_SONY_PS" value="5" enum="JoyButtonList"> + Sony DualShock controller PS button maps to SDL guide button. </constant> - <constant name="JOY_OCULUS_MENU" value="3" enum="JoystickList"> - Menu button on either Oculus Touch controller. + <constant name="JOY_SONY_L1" value="9" enum="JoyButtonList"> + Sony DualShock controller L1 button maps to SDL left shoulder button. </constant> - <constant name="JOY_OPENVR_MENU" value="1" enum="JoystickList"> - Menu button in OpenVR (Except when Oculus Touch controllers are used). + <constant name="JOY_SONY_R1" value="10" enum="JoyButtonList"> + Sony DualShock controller R1 button maps to SDL right shoulder button. </constant> - <constant name="JOY_SELECT" value="10" enum="JoystickList"> - Gamepad button Select. + <constant name="JOY_SONY_L3" value="7" enum="JoyButtonList"> + Sony DualShock controller L3 button maps to SDL left stick button. </constant> - <constant name="JOY_START" value="11" enum="JoystickList"> - Gamepad button Start. + <constant name="JOY_SONY_R3" value="8" enum="JoyButtonList"> + Sony DualShock controller R3 button maps to SDL right stick button. </constant> - <constant name="JOY_DPAD_UP" value="12" enum="JoystickList"> - Gamepad DPad up. + <constant name="JOY_XBOX_A" value="0" enum="JoyButtonList"> + Xbox game controller A button maps to SDL button A. </constant> - <constant name="JOY_DPAD_DOWN" value="13" enum="JoystickList"> - Gamepad DPad down. + <constant name="JOY_XBOX_B" value="1" enum="JoyButtonList"> + Xbox game controller B button maps to SDL button B. </constant> - <constant name="JOY_DPAD_LEFT" value="14" enum="JoystickList"> - Gamepad DPad left. + <constant name="JOY_XBOX_X" value="2" enum="JoyButtonList"> + Xbox game controller X button maps to SDL button X. </constant> - <constant name="JOY_DPAD_RIGHT" value="15" enum="JoystickList"> - Gamepad DPad right. + <constant name="JOY_XBOX_Y" value="3" enum="JoyButtonList"> + Xbox game controller Y button maps to SDL button Y. </constant> - <constant name="JOY_L" value="4" enum="JoystickList"> - Gamepad left Shoulder button. + <constant name="JOY_XBOX_BACK" value="4" enum="JoyButtonList"> + Xbox game controller back button maps to SDL back button. </constant> - <constant name="JOY_L2" value="6" enum="JoystickList"> - Gamepad left trigger. + <constant name="JOY_XBOX_START" value="6" enum="JoyButtonList"> + Xbox game controller start button maps to SDL start button. </constant> - <constant name="JOY_L3" value="8" enum="JoystickList"> - Gamepad left stick click. + <constant name="JOY_XBOX_HOME" value="5" enum="JoyButtonList"> + Xbox game controller home button maps to SDL guide button. </constant> - <constant name="JOY_R" value="5" enum="JoystickList"> - Gamepad right Shoulder button. + <constant name="JOY_XBOX_LS" value="7" enum="JoyButtonList"> + Xbox game controller left stick button maps to SDL left stick button. </constant> - <constant name="JOY_R2" value="7" enum="JoystickList"> - Gamepad right trigger. + <constant name="JOY_XBOX_RS" value="8" enum="JoyButtonList"> + Xbox game controller right stick button maps to SDL right stick button. </constant> - <constant name="JOY_R3" value="9" enum="JoystickList"> - Gamepad right stick click. + <constant name="JOY_XBOX_LB" value="9" enum="JoyButtonList"> + Xbox game controller left bumper button maps to SDL left shoulder button. </constant> - <constant name="JOY_AXIS_0" value="0" enum="JoystickList"> - Gamepad left stick horizontal axis. + <constant name="JOY_XBOX_RB" value="10" enum="JoyButtonList"> + Xbox game controller right bumper button maps to SDL right shoulder button. </constant> - <constant name="JOY_AXIS_1" value="1" enum="JoystickList"> - Gamepad left stick vertical axis. + <constant name="JOY_BUTTON_MAX" value="36" enum="JoyAxisList"> + The maximum number of game controller buttons. </constant> - <constant name="JOY_AXIS_2" value="2" enum="JoystickList"> - Gamepad right stick horizontal axis. + <constant name="JOY_INVALID_BUTTON" value="-1" enum="JoyButtonList"> + An invalid game controller axis. </constant> - <constant name="JOY_AXIS_3" value="3" enum="JoystickList"> - Gamepad right stick vertical axis. + <constant name="JOY_AXIS_LEFT_X" value="0" enum="JoyAxisList"> + Game controller left joystick x-axis. </constant> - <constant name="JOY_AXIS_4" value="4" enum="JoystickList"> - Generic gamepad axis 4. + <constant name="JOY_AXIS_LEFT_Y" value="1" enum="JoyAxisList"> + Game controller left joystick y-axis. </constant> - <constant name="JOY_AXIS_5" value="5" enum="JoystickList"> - Generic gamepad axis 5. + <constant name="JOY_AXIS_RIGHT_X" value="2" enum="JoyAxisList"> + Game controller right joystick x-axis. </constant> - <constant name="JOY_AXIS_6" value="6" enum="JoystickList"> - Gamepad left trigger analog axis. + <constant name="JOY_AXIS_RIGHT_Y" value="3" enum="JoyAxisList"> + Game controller right joystick y-axis. </constant> - <constant name="JOY_AXIS_7" value="7" enum="JoystickList"> - Gamepad right trigger analog axis. + <constant name="JOY_AXIS_TRIGGER_LEFT" value="4" enum="JoyAxisList"> + Game controller left trigger axis. </constant> - <constant name="JOY_AXIS_8" value="8" enum="JoystickList"> - Generic gamepad axis 8. + <constant name="JOY_AXIS_TRIGGER_RIGHT" value="5" enum="JoyAxisList"> + Game controller right trigger axis. </constant> - <constant name="JOY_AXIS_9" value="9" enum="JoystickList"> - Generic gamepad axis 9. + <constant name="JOY_SDL_AXES" value="6" enum="JoyAxisList"> + The number of SDL game controller axes. </constant> - <constant name="JOY_AXIS_MAX" value="10" enum="JoystickList"> - Represents the maximum number of joystick axes supported. + <constant name="JOY_AXIS_0_X" value="0" enum="JoyAxisList"> + Game controller joystick 0 x-axis. </constant> - <constant name="JOY_ANALOG_LX" value="0" enum="JoystickList"> - Gamepad left stick horizontal axis. + <constant name="JOY_AXIS_0_Y" value="1" enum="JoyAxisList"> + Game controller joystick 0 y-axis. </constant> - <constant name="JOY_ANALOG_LY" value="1" enum="JoystickList"> - Gamepad left stick vertical axis. + <constant name="JOY_AXIS_1_X" value="2" enum="JoyAxisList"> + Game controller joystick 1 x-axis. </constant> - <constant name="JOY_ANALOG_RX" value="2" enum="JoystickList"> - Gamepad right stick horizontal axis. + <constant name="JOY_AXIS_1_Y" value="3" enum="JoyAxisList"> + Game controller joystick 1 y-axis. </constant> - <constant name="JOY_ANALOG_RY" value="3" enum="JoystickList"> - Gamepad right stick vertical axis. + <constant name="JOY_AXIS_2_X" value="4" enum="JoyAxisList"> + Game controller joystick 2 x-axis. </constant> - <constant name="JOY_ANALOG_L2" value="6" enum="JoystickList"> - Gamepad left analog trigger. + <constant name="JOY_AXIS_2_Y" value="5" enum="JoyAxisList"> + Game controller joystick 2 y-axis. </constant> - <constant name="JOY_ANALOG_R2" value="7" enum="JoystickList"> - Gamepad right analog trigger. + <constant name="JOY_AXIS_3_X" value="6" enum="JoyAxisList"> + Game controller joystick 3 x-axis. </constant> - <constant name="JOY_VR_ANALOG_TRIGGER" value="2" enum="JoystickList"> - VR Controller analog trigger. + <constant name="JOY_AXIS_3_Y" value="7" enum="JoyAxisList"> + Game controller joystick 3 y-axis. </constant> - <constant name="JOY_VR_ANALOG_GRIP" value="4" enum="JoystickList"> - VR Controller analog grip (side buttons). + <constant name="JOY_AXIS_4_X" value="8" enum="JoyAxisList"> + Game controller joystick 4 x-axis. </constant> - <constant name="JOY_OPENVR_TOUCHPADX" value="0" enum="JoystickList"> - OpenVR touchpad X axis (Joystick axis on Oculus Touch and Windows MR controllers). + <constant name="JOY_AXIS_4_Y" value="9" enum="JoyAxisList"> + Game controller joystick 4 y-axis. </constant> - <constant name="JOY_OPENVR_TOUCHPADY" value="1" enum="JoystickList"> - OpenVR touchpad Y axis (Joystick axis on Oculus Touch and Windows MR controllers). + <constant name="JOY_AXIS_MAX" value="10" enum="JoyAxisList"> + The maximum number of game controller axes. </constant> <constant name="MIDI_MESSAGE_NOTE_OFF" value="8" enum="MidiMessageList"> MIDI note OFF message. diff --git a/doc/classes/BakedLightmap.xml b/doc/classes/BakedLightmap.xml new file mode 100644 index 0000000000..6fd08fc4e4 --- /dev/null +++ b/doc/classes/BakedLightmap.xml @@ -0,0 +1,81 @@ +<?xml version="1.0" encoding="UTF-8" ?> +<class name="BakedLightmap" inherits="VisualInstance3D" version="4.0"> + <brief_description> + </brief_description> + <description> + </description> + <tutorials> + </tutorials> + <methods> + </methods> + <members> + <member name="bias" type="float" setter="set_bias" getter="get_bias" default="0.0005"> + </member> + <member name="bounces" type="int" setter="set_bounces" getter="get_bounces" default="1"> + </member> + <member name="directional" type="bool" setter="set_directional" getter="is_directional" default="false"> + </member> + <member name="environment_custom_color" type="Color" setter="set_environment_custom_color" getter="get_environment_custom_color"> + </member> + <member name="environment_custom_energy" type="float" setter="set_environment_custom_energy" getter="get_environment_custom_energy"> + </member> + <member name="environment_custom_sky" type="Sky" setter="set_environment_custom_sky" getter="get_environment_custom_sky"> + </member> + <member name="environment_mode" type="int" setter="set_environment_mode" getter="get_environment_mode" enum="BakedLightmap.EnvironmentMode" default="0"> + </member> + <member name="generate_probes_subdiv" type="int" setter="set_generate_probes" getter="get_generate_probes" enum="BakedLightmap.GenerateProbes" default="0"> + </member> + <member name="interior" type="bool" setter="set_interior" getter="is_interior" default="false"> + </member> + <member name="light_data" type="BakedLightmapData" setter="set_light_data" getter="get_light_data"> + </member> + <member name="max_texture_size" type="int" setter="set_max_texture_size" getter="get_max_texture_size" default="16384"> + </member> + <member name="quality" type="int" setter="set_bake_quality" getter="get_bake_quality" enum="BakedLightmap.BakeQuality" default="1"> + </member> + <member name="use_denoiser" type="bool" setter="set_use_denoiser" getter="is_using_denoiser" default="true"> + </member> + </members> + <constants> + <constant name="BAKE_QUALITY_LOW" value="0" enum="BakeQuality"> + </constant> + <constant name="BAKE_QUALITY_MEDIUM" value="1" enum="BakeQuality"> + </constant> + <constant name="BAKE_QUALITY_HIGH" value="2" enum="BakeQuality"> + </constant> + <constant name="BAKE_QUALITY_ULTRA" value="3" enum="BakeQuality"> + </constant> + <constant name="GENERATE_PROBES_DISABLED" value="0" enum="GenerateProbes"> + </constant> + <constant name="GENERATE_PROBES_SUBDIV_4" value="1" enum="GenerateProbes"> + </constant> + <constant name="GENERATE_PROBES_SUBDIV_8" value="2" enum="GenerateProbes"> + </constant> + <constant name="GENERATE_PROBES_SUBDIV_16" value="3" enum="GenerateProbes"> + </constant> + <constant name="GENERATE_PROBES_SUBDIV_32" value="4" enum="GenerateProbes"> + </constant> + <constant name="BAKE_ERROR_OK" value="0" enum="BakeError"> + </constant> + <constant name="BAKE_ERROR_NO_LIGHTMAPPER" value="1" enum="BakeError"> + </constant> + <constant name="BAKE_ERROR_NO_SAVE_PATH" value="2" enum="BakeError"> + </constant> + <constant name="BAKE_ERROR_NO_MESHES" value="3" enum="BakeError"> + </constant> + <constant name="BAKE_ERROR_MESHES_INVALID" value="4" enum="BakeError"> + </constant> + <constant name="BAKE_ERROR_CANT_CREATE_IMAGE" value="5" enum="BakeError"> + </constant> + <constant name="BAKE_ERROR_USER_ABORTED" value="6" enum="BakeError"> + </constant> + <constant name="ENVIRONMENT_MODE_DISABLED" value="0" enum="EnvironmentMode"> + </constant> + <constant name="ENVIRONMENT_MODE_SCENE" value="1" enum="EnvironmentMode"> + </constant> + <constant name="ENVIRONMENT_MODE_CUSTOM_SKY" value="2" enum="EnvironmentMode"> + </constant> + <constant name="ENVIRONMENT_MODE_CUSTOM_COLOR" value="3" enum="EnvironmentMode"> + </constant> + </constants> +</class> diff --git a/doc/classes/BakedLightmapData.xml b/doc/classes/BakedLightmapData.xml new file mode 100644 index 0000000000..026477782a --- /dev/null +++ b/doc/classes/BakedLightmapData.xml @@ -0,0 +1,65 @@ +<?xml version="1.0" encoding="UTF-8" ?> +<class name="BakedLightmapData" inherits="Resource" version="4.0"> + <brief_description> + </brief_description> + <description> + </description> + <tutorials> + </tutorials> + <methods> + <method name="add_user"> + <return type="void"> + </return> + <argument index="0" name="path" type="NodePath"> + </argument> + <argument index="1" name="lightmap" type="Rect2"> + </argument> + <argument index="2" name="offset" type="int"> + </argument> + <argument index="3" name="arg3" type="int"> + </argument> + <description> + </description> + </method> + <method name="clear_users"> + <return type="void"> + </return> + <description> + </description> + </method> + <method name="get_user_count" qualifiers="const"> + <return type="int"> + </return> + <description> + </description> + </method> + <method name="get_user_path" qualifiers="const"> + <return type="NodePath"> + </return> + <argument index="0" name="user_idx" type="int"> + </argument> + <description> + </description> + </method> + <method name="is_using_spherical_harmonics" qualifiers="const"> + <return type="bool"> + </return> + <description> + </description> + </method> + <method name="set_uses_spherical_harmonics"> + <return type="void"> + </return> + <argument index="0" name="uses_spherical_harmonics" type="bool"> + </argument> + <description> + </description> + </method> + </methods> + <members> + <member name="light_texture" type="TextureLayered" setter="set_light_texture" getter="get_light_texture"> + </member> + </members> + <constants> + </constants> +</class> diff --git a/doc/classes/Cubemap.xml b/doc/classes/Cubemap.xml index 16431c65c9..61cb1d43f0 100644 --- a/doc/classes/Cubemap.xml +++ b/doc/classes/Cubemap.xml @@ -1,5 +1,5 @@ <?xml version="1.0" encoding="UTF-8" ?> -<class name="Cubemap" inherits="TextureLayered" version="4.0"> +<class name="Cubemap" inherits="ImageTextureLayered" version="4.0"> <brief_description> </brief_description> <description> diff --git a/doc/classes/CubemapArray.xml b/doc/classes/CubemapArray.xml index 03cfd75acf..627baf79e0 100644 --- a/doc/classes/CubemapArray.xml +++ b/doc/classes/CubemapArray.xml @@ -1,5 +1,5 @@ <?xml version="1.0" encoding="UTF-8" ?> -<class name="CubemapArray" inherits="TextureLayered" version="4.0"> +<class name="CubemapArray" inherits="ImageTextureLayered" version="4.0"> <brief_description> </brief_description> <description> diff --git a/doc/classes/Engine.xml b/doc/classes/Engine.xml index 45c153b6dc..12701d8688 100644 --- a/doc/classes/Engine.xml +++ b/doc/classes/Engine.xml @@ -1,10 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="Engine" inherits="Object" version="4.0"> <brief_description> - Access to basic engine properties. + Access to engine properties. </brief_description> <description> - The [Engine] class allows you to query and modify the project's run-time parameters, such as frames per second, time scale, and others. + The [Engine] singleton allows you to query and modify the project's run-time parameters, such as frames per second, time scale, and others. </description> <tutorials> </tutorials> diff --git a/doc/classes/GIProbe.xml b/doc/classes/GIProbe.xml index df50244c77..23dd562653 100644 --- a/doc/classes/GIProbe.xml +++ b/doc/classes/GIProbe.xml @@ -19,7 +19,7 @@ <argument index="1" name="create_visual_debug" type="bool" default="false"> </argument> <description> - Bakes the effect from all [GeometryInstance3D]s marked with [member GeometryInstance3D.use_in_baked_light] and [Light3D]s marked with either [constant Light3D.BAKE_INDIRECT] or [constant Light3D.BAKE_ALL]. If [code]create_visual_debug[/code] is [code]true[/code], after baking the light, this will generate a [MultiMesh] that has a cube representing each solid cell with each cube colored to the cell's albedo color. This can be used to visualize the [GIProbe]'s data and debug any issues that may be occurring. + Bakes the effect from all [GeometryInstance3D]s marked with [constant GeometryInstance3D.GI_MODE_BAKED] and [Light3D]s marked with either [constant Light3D.BAKE_INDIRECT] or [constant Light3D.BAKE_ALL]. If [code]create_visual_debug[/code] is [code]true[/code], after baking the light, this will generate a [MultiMesh] that has a cube representing each solid cell with each cube colored to the cell's albedo color. This can be used to visualize the [GIProbe]'s data and debug any issues that may be occurring. </description> </method> <method name="debug_bake"> diff --git a/doc/classes/GeometryInstance3D.xml b/doc/classes/GeometryInstance3D.xml index 518a172973..cc85ce295b 100644 --- a/doc/classes/GeometryInstance3D.xml +++ b/doc/classes/GeometryInstance3D.xml @@ -9,15 +9,6 @@ <tutorials> </tutorials> <methods> - <method name="get_flag" qualifiers="const"> - <return type="bool"> - </return> - <argument index="0" name="flag" type="int" enum="GeometryInstance3D.Flags"> - </argument> - <description> - Returns the [enum GeometryInstance3D.Flags] that have been set for this object. - </description> - </method> <method name="get_shader_instance_uniform" qualifiers="const"> <return type="Variant"> </return> @@ -35,17 +26,6 @@ Overrides the bounding box of this node with a custom one. To remove it, set an [AABB] with all fields set to zero. </description> </method> - <method name="set_flag"> - <return type="void"> - </return> - <argument index="0" name="flag" type="int" enum="GeometryInstance3D.Flags"> - </argument> - <argument index="1" name="value" type="bool"> - </argument> - <description> - Sets the [enum GeometryInstance3D.Flags] specified. See [enum GeometryInstance3D.Flags] for options. - </description> - </method> <method name="set_shader_instance_uniform"> <return type="void"> </return> @@ -64,6 +44,10 @@ <member name="extra_cull_margin" type="float" setter="set_extra_cull_margin" getter="get_extra_cull_margin" default="0.0"> The extra distance added to the GeometryInstance3D's bounding box ([AABB]) to increase its cull box. </member> + <member name="gi_lightmap_scale" type="int" setter="set_lightmap_scale" getter="get_lightmap_scale" enum="GeometryInstance3D.LightmapScale" default="0"> + </member> + <member name="gi_mode" type="int" setter="set_gi_mode" getter="get_gi_mode" enum="GeometryInstance3D.GIMode" default="0"> + </member> <member name="lod_max_distance" type="float" setter="set_lod_max_distance" getter="get_lod_max_distance" default="0.0"> The GeometryInstance3D's max LOD distance. [b]Note:[/b] This property currently has no effect. @@ -84,11 +68,6 @@ The material override for the whole geometry. If a material is assigned to this property, it will be used instead of any material set in any material slot of the mesh. </member> - <member name="use_dynamic_gi" type="bool" setter="set_flag" getter="get_flag" default="false"> - </member> - <member name="use_in_baked_light" type="bool" setter="set_flag" getter="get_flag" default="false"> - If [code]true[/code], this GeometryInstance3D will be used when baking lights using a [GIProbe]. - </member> </members> <constants> <constant name="SHADOW_CASTING_SETTING_OFF" value="0" enum="ShadowCastingSetting"> @@ -106,16 +85,21 @@ Will only show the shadows casted from this object. In other words, the actual mesh will not be visible, only the shadows casted from the mesh will be. </constant> - <constant name="FLAG_USE_BAKED_LIGHT" value="0" enum="Flags"> - Will allow the GeometryInstance3D to be used when baking lights using a [GIProbe]. + <constant name="GI_MODE_DISABLED" value="0" enum="GIMode"> + </constant> + <constant name="GI_MODE_BAKED" value="1" enum="GIMode"> + </constant> + <constant name="GI_MODE_DYNAMIC" value="2" enum="GIMode"> + </constant> + <constant name="LIGHTMAP_SCALE_1X" value="0" enum="LightmapScale"> + </constant> + <constant name="LIGHTMAP_SCALE_2X" value="1" enum="LightmapScale"> </constant> - <constant name="FLAG_USE_DYNAMIC_GI" value="1" enum="Flags"> + <constant name="LIGHTMAP_SCALE_4X" value="2" enum="LightmapScale"> </constant> - <constant name="FLAG_DRAW_NEXT_FRAME_IF_VISIBLE" value="2" enum="Flags"> - Unused in this class, exposed for consistency with [enum RenderingServer.InstanceFlags]. + <constant name="LIGHTMAP_SCALE_8X" value="3" enum="LightmapScale"> </constant> - <constant name="FLAG_MAX" value="3" enum="Flags"> - Represents the size of the [enum Flags] enum. + <constant name="LIGHTMAP_SCALE_MAX" value="4" enum="LightmapScale"> </constant> </constants> </class> diff --git a/doc/classes/Input.xml b/doc/classes/Input.xml index 0f212e7498..fc3c3776ce 100644 --- a/doc/classes/Input.xml +++ b/doc/classes/Input.xml @@ -96,7 +96,7 @@ <argument index="1" name="axis" type="int"> </argument> <description> - Returns the current value of the joypad axis at given index (see [enum JoystickList]). + Returns the current value of the joypad axis at given index (see [enum JoyAxisList]). </description> </method> <method name="get_joy_axis_index_from_string"> @@ -114,7 +114,7 @@ <argument index="0" name="axis_index" type="int"> </argument> <description> - Receives a [enum JoystickList] axis and returns its equivalent name as a string. + Receives a [enum JoyAxisList] axis and returns its equivalent name as a string. </description> </method> <method name="get_joy_button_index_from_string"> @@ -132,7 +132,7 @@ <argument index="0" name="button_index" type="int"> </argument> <description> - Receives a gamepad button from [enum JoystickList] and returns its equivalent name as a string. + Receives a gamepad button from [enum JoyButtonList] and returns its equivalent name as a string. </description> </method> <method name="get_joy_guid" qualifiers="const"> @@ -235,7 +235,7 @@ <argument index="1" name="button" type="int"> </argument> <description> - Returns [code]true[/code] if you are pressing the joypad button (see [enum JoystickList]). + Returns [code]true[/code] if you are pressing the joypad button (see [enum JoyButtonList]). </description> </method> <method name="is_joy_known"> @@ -244,7 +244,7 @@ <argument index="0" name="device" type="int"> </argument> <description> - Returns [code]true[/code] if the system knows the specified device. This means that it sets all button and axis indices exactly as defined in [enum JoystickList]. Unknown joypads are not expected to match these constants, but you can still retrieve events from them. + Returns [code]true[/code] if the system knows the specified device. This means that it sets all button and axis indices. Unknown joypads are not expected to match these constants, but you can still retrieve events from them. </description> </method> <method name="is_key_pressed" qualifiers="const"> diff --git a/doc/classes/InputEventJoypadButton.xml b/doc/classes/InputEventJoypadButton.xml index 19aa97e1ec..7876bace75 100644 --- a/doc/classes/InputEventJoypadButton.xml +++ b/doc/classes/InputEventJoypadButton.xml @@ -13,7 +13,7 @@ </methods> <members> <member name="button_index" type="int" setter="set_button_index" getter="get_button_index" default="0"> - Button identifier. One of the [enum JoystickList] button constants. + Button identifier. One of the [enum JoyButtonList] button constants. </member> <member name="pressed" type="bool" setter="set_pressed" getter="is_pressed" default="false"> If [code]true[/code], the button's state is pressed. If [code]false[/code], the button's state is released. diff --git a/doc/classes/InputEventJoypadMotion.xml b/doc/classes/InputEventJoypadMotion.xml index 01e02b79b1..bfd961ce1f 100644 --- a/doc/classes/InputEventJoypadMotion.xml +++ b/doc/classes/InputEventJoypadMotion.xml @@ -13,7 +13,7 @@ </methods> <members> <member name="axis" type="int" setter="set_axis" getter="get_axis" default="0"> - Axis identifier. Use one of the [enum JoystickList] axis constants. + Axis identifier. Use one of the [enum JoyAxisList] axis constants. </member> <member name="axis_value" type="float" setter="set_axis_value" getter="get_axis_value" default="0.0"> Current position of the joystick on the given axis. The value ranges from [code]-1.0[/code] to [code]1.0[/code]. A value of [code]0[/code] means the axis is in its resting position. diff --git a/doc/classes/LightmapProbe.xml b/doc/classes/LightmapProbe.xml new file mode 100644 index 0000000000..3af71f3774 --- /dev/null +++ b/doc/classes/LightmapProbe.xml @@ -0,0 +1,13 @@ +<?xml version="1.0" encoding="UTF-8" ?> +<class name="LightmapProbe" inherits="Node3D" version="4.0"> + <brief_description> + </brief_description> + <description> + </description> + <tutorials> + </tutorials> + <methods> + </methods> + <constants> + </constants> +</class> diff --git a/doc/classes/LightmapperRD.xml b/doc/classes/LightmapperRD.xml new file mode 100644 index 0000000000..0993b28f19 --- /dev/null +++ b/doc/classes/LightmapperRD.xml @@ -0,0 +1,13 @@ +<?xml version="1.0" encoding="UTF-8" ?> +<class name="LightmapperRD" inherits="Lightmapper" version="4.0"> + <brief_description> + </brief_description> + <description> + </description> + <tutorials> + </tutorials> + <methods> + </methods> + <constants> + </constants> +</class> diff --git a/doc/classes/Mesh.xml b/doc/classes/Mesh.xml index 6958c815a6..367099e455 100644 --- a/doc/classes/Mesh.xml +++ b/doc/classes/Mesh.xml @@ -102,7 +102,7 @@ </method> </methods> <members> - <member name="lightmap_size_hint" type="Vector2" setter="set_lightmap_size_hint" getter="get_lightmap_size_hint" default="Vector2( 0, 0 )"> + <member name="lightmap_size_hint" type="Vector2i" setter="set_lightmap_size_hint" getter="get_lightmap_size_hint" default="Vector2i( 0, 0 )"> Sets a hint to be used for lightmap resolution. </member> </members> diff --git a/doc/classes/Node.xml b/doc/classes/Node.xml index 5ba3c6c56a..0c0d16753a 100644 --- a/doc/classes/Node.xml +++ b/doc/classes/Node.xml @@ -305,7 +305,7 @@ <return type="Node"> </return> <description> - Returns the parent node of the current node, or an empty [Node] if the node lacks a parent. + Returns the parent node of the current node, or a [code]null instance[/code] if the node lacks a parent. </description> </method> <method name="get_path" qualifiers="const"> diff --git a/doc/classes/PopupMenu.xml b/doc/classes/PopupMenu.xml index 569da5c58b..59450b7ea2 100644 --- a/doc/classes/PopupMenu.xml +++ b/doc/classes/PopupMenu.xml @@ -242,6 +242,12 @@ Removes all items from the [PopupMenu]. </description> </method> + <method name="get_current_index" qualifiers="const"> + <return type="int"> + </return> + <description> + </description> + </method> <method name="get_item_accelerator" qualifiers="const"> <return type="int"> </return> diff --git a/doc/classes/ProjectSettings.xml b/doc/classes/ProjectSettings.xml index dc07dd8837..92f116c7ec 100644 --- a/doc/classes/ProjectSettings.xml +++ b/doc/classes/ProjectSettings.xml @@ -874,8 +874,9 @@ <member name="network/remote_fs/page_size" type="int" setter="" getter="" default="65536"> Page size used by remote filesystem (in bytes). </member> - <member name="network/ssl/certificates" type="String" setter="" getter="" default=""""> - CA certificates bundle to use for SSL connections. If not defined, Godot's internal CA certificates are used. + <member name="network/ssl/certificate_bundle_override" type="String" setter="" getter="" default=""""> + The CA certificates bundle to use for SSL connections. If this is set to a non-empty value, this will [i]override[/i] Godot's default [url=https://github.com/godotengine/godot/blob/master/thirdparty/certs/ca-certificates.crt]Mozilla certificate bundle[/url]. If left empty, the default certificate bundle will be used. + If in doubt, leave this setting empty. </member> <member name="node/name_casing" type="int" setter="" getter="" default="0"> When creating node names automatically, set the type of casing in this project. This is mostly an editor setting. @@ -977,8 +978,32 @@ <member name="rendering/environment/default_environment" type="String" setter="" getter="" default=""""> [Environment] that will be used as a fallback environment in case a scene does not specify its own environment. The default environment is loaded in at scene load time regardless of whether you have set an environment or not. If you do not rely on the fallback environment, it is best to delete [code]default_env.tres[/code], or to specify a different default environment here. </member> + <member name="rendering/gpu_lightmapper/performance/max_rays_per_pass" type="int" setter="" getter="" default="32"> + </member> + <member name="rendering/gpu_lightmapper/performance/max_rays_per_probe_pass" type="int" setter="" getter="" default="64"> + </member> + <member name="rendering/gpu_lightmapper/performance/region_size" type="int" setter="" getter="" default="512"> + </member> + <member name="rendering/gpu_lightmapper/quality/high_quality_probe_ray_count" type="int" setter="" getter="" default="512"> + </member> + <member name="rendering/gpu_lightmapper/quality/high_quality_ray_count" type="int" setter="" getter="" default="256"> + </member> + <member name="rendering/gpu_lightmapper/quality/low_quality_probe_ray_count" type="int" setter="" getter="" default="64"> + </member> + <member name="rendering/gpu_lightmapper/quality/low_quality_ray_count" type="int" setter="" getter="" default="16"> + </member> + <member name="rendering/gpu_lightmapper/quality/medium_quality_probe_ray_count" type="int" setter="" getter="" default="256"> + </member> + <member name="rendering/gpu_lightmapper/quality/medium_quality_ray_count" type="int" setter="" getter="" default="64"> + </member> + <member name="rendering/gpu_lightmapper/quality/ultra_quality_probe_ray_count" type="int" setter="" getter="" default="2048"> + </member> + <member name="rendering/gpu_lightmapper/quality/ultra_quality_ray_count" type="int" setter="" getter="" default="1024"> + </member> <member name="rendering/high_end/global_shader_variables_buffer_size" type="int" setter="" getter="" default="65536"> </member> + <member name="rendering/lightmapper/probe_capture_update_speed" type="float" setter="" getter="" default="15"> + </member> <member name="rendering/limits/rendering/max_renderable_elements" type="int" setter="" getter="" default="128000"> Max amount of elements renderable in a frame. If more than this are visible per frame, they will be dropped. Keep in mind elements refer to mesh surfaces and not meshes themselves. </member> diff --git a/doc/classes/RenderingServer.xml b/doc/classes/RenderingServer.xml index 648f4d0107..d8be6d4bd7 100644 --- a/doc/classes/RenderingServer.xml +++ b/doc/classes/RenderingServer.xml @@ -1282,7 +1282,7 @@ <argument index="1" name="base" type="RID"> </argument> <description> - Sets the base of the instance. A base can be any of the 3D objects that are created in the RenderingServer that can be displayed. For example, any of the light types, mesh, multimesh, immediate geometry, particle system, reflection probe, lightmap capture, and the GI probe are all types that can be set as the base of an instance in order to be displayed in the scenario. + Sets the base of the instance. A base can be any of the 3D objects that are created in the RenderingServer that can be displayed. For example, any of the light types, mesh, multimesh, immediate geometry, particle system, reflection probe, lightmap, and the GI probe are all types that can be set as the base of an instance in order to be displayed in the scenario. </description> </method> <method name="instance_set_blend_shape_weight"> @@ -1377,19 +1377,6 @@ Sets the world space transform of the instance. Equivalent to [member Node3D.transform]. </description> </method> - <method name="instance_set_use_lightmap"> - <return type="void"> - </return> - <argument index="0" name="instance" type="RID"> - </argument> - <argument index="1" name="lightmap_instance" type="RID"> - </argument> - <argument index="2" name="lightmap" type="RID"> - </argument> - <description> - Sets the lightmap to use with this instance. - </description> - </method> <method name="instance_set_visible"> <return type="void"> </return> @@ -1584,115 +1571,6 @@ Sets whether GI probes capture light information from this light. </description> </method> - <method name="lightmap_capture_create"> - <return type="RID"> - </return> - <description> - Creates a lightmap capture and adds it to the RenderingServer. It can be accessed with the RID that is returned. This RID will be used in all [code]lightmap_capture_*[/code] RenderingServer functions. - Once finished with your RID, you will want to free the RID using the RenderingServer's [method free_rid] static method. - To place in a scene, attach this lightmap capture to an instance using [method instance_set_base] using the returned RID. - </description> - </method> - <method name="lightmap_capture_get_bounds" qualifiers="const"> - <return type="AABB"> - </return> - <argument index="0" name="capture" type="RID"> - </argument> - <description> - Returns the size of the lightmap capture area. - </description> - </method> - <method name="lightmap_capture_get_energy" qualifiers="const"> - <return type="float"> - </return> - <argument index="0" name="capture" type="RID"> - </argument> - <description> - Returns the energy multiplier used by the lightmap capture. - </description> - </method> - <method name="lightmap_capture_get_octree" qualifiers="const"> - <return type="PackedByteArray"> - </return> - <argument index="0" name="capture" type="RID"> - </argument> - <description> - Returns the octree used by the lightmap capture. - </description> - </method> - <method name="lightmap_capture_get_octree_cell_subdiv" qualifiers="const"> - <return type="int"> - </return> - <argument index="0" name="capture" type="RID"> - </argument> - <description> - Returns the cell subdivision amount used by this lightmap capture's octree. - </description> - </method> - <method name="lightmap_capture_get_octree_cell_transform" qualifiers="const"> - <return type="Transform"> - </return> - <argument index="0" name="capture" type="RID"> - </argument> - <description> - Returns the cell transform for this lightmap capture's octree. - </description> - </method> - <method name="lightmap_capture_set_bounds"> - <return type="void"> - </return> - <argument index="0" name="capture" type="RID"> - </argument> - <argument index="1" name="bounds" type="AABB"> - </argument> - <description> - Sets the size of the area covered by the lightmap capture. - </description> - </method> - <method name="lightmap_capture_set_energy"> - <return type="void"> - </return> - <argument index="0" name="capture" type="RID"> - </argument> - <argument index="1" name="energy" type="float"> - </argument> - <description> - Sets the energy multiplier for this lightmap capture. - </description> - </method> - <method name="lightmap_capture_set_octree"> - <return type="void"> - </return> - <argument index="0" name="capture" type="RID"> - </argument> - <argument index="1" name="octree" type="PackedByteArray"> - </argument> - <description> - Sets the octree to be used by this lightmap capture. - </description> - </method> - <method name="lightmap_capture_set_octree_cell_subdiv"> - <return type="void"> - </return> - <argument index="0" name="capture" type="RID"> - </argument> - <argument index="1" name="subdiv" type="int"> - </argument> - <description> - Sets the subdivision level of this lightmap capture's octree. - </description> - </method> - <method name="lightmap_capture_set_octree_cell_transform"> - <return type="void"> - </return> - <argument index="0" name="capture" type="RID"> - </argument> - <argument index="1" name="xform" type="Transform"> - </argument> - <description> - Sets the octree cell transform for this lightmap capture's octree. - </description> - </method> <method name="make_sphere_mesh"> <return type="RID"> </return> @@ -3725,8 +3603,8 @@ <constant name="INSTANCE_GI_PROBE" value="8" enum="InstanceType"> The instance is a GI probe. </constant> - <constant name="INSTANCE_LIGHTMAP_CAPTURE" value="9" enum="InstanceType"> - The instance is a lightmap capture. + <constant name="INSTANCE_LIGHTMAP" value="9" enum="InstanceType"> + The instance is a lightmap. </constant> <constant name="INSTANCE_MAX" value="10" enum="InstanceType"> Represents the size of the [enum InstanceType] enum. diff --git a/doc/classes/ResourceFormatLoader.xml b/doc/classes/ResourceFormatLoader.xml index 713f2c1726..ad0c438f98 100644 --- a/doc/classes/ResourceFormatLoader.xml +++ b/doc/classes/ResourceFormatLoader.xml @@ -6,7 +6,7 @@ <description> Godot loads resources in the editor or in exported games using ResourceFormatLoaders. They are queried automatically via the [ResourceLoader] singleton, or when a resource with internal dependencies is loaded. Each file type may load as a different resource type, so multiple ResourceFormatLoaders are registered in the engine. Extending this class allows you to define your own loader. Be sure to respect the documented return types and values. You should give it a global class name with [code]class_name[/code] for it to be registered. Like built-in ResourceFormatLoaders, it will be called automatically when loading resources of its handled type(s). You may also implement a [ResourceFormatSaver]. - [b]Note:[/b] You can also extend [EditorImportPlugin] if the resource type you need exists but Godot is unable to load its format. Choosing one way over another depends if the format is suitable or not for the final exported game. For example, it's better to import [code].png[/code] textures as [code].stex[/code] ([StreamTexture]) first, so they can be loaded with better efficiency on the graphics card. + [b]Note:[/b] You can also extend [EditorImportPlugin] if the resource type you need exists but Godot is unable to load its format. Choosing one way over another depends if the format is suitable or not for the final exported game. For example, it's better to import [code].png[/code] textures as [code].stex[/code] ([StreamTexture2D]) first, so they can be loaded with better efficiency on the graphics card. </description> <tutorials> </tutorials> diff --git a/doc/classes/StreamCubemap.xml b/doc/classes/StreamCubemap.xml new file mode 100644 index 0000000000..16648266eb --- /dev/null +++ b/doc/classes/StreamCubemap.xml @@ -0,0 +1,13 @@ +<?xml version="1.0" encoding="UTF-8" ?> +<class name="StreamCubemap" inherits="StreamTextureLayered" version="4.0"> + <brief_description> + </brief_description> + <description> + </description> + <tutorials> + </tutorials> + <methods> + </methods> + <constants> + </constants> +</class> diff --git a/doc/classes/StreamCubemapArray.xml b/doc/classes/StreamCubemapArray.xml new file mode 100644 index 0000000000..b84973fd14 --- /dev/null +++ b/doc/classes/StreamCubemapArray.xml @@ -0,0 +1,13 @@ +<?xml version="1.0" encoding="UTF-8" ?> +<class name="StreamCubemapArray" inherits="StreamTextureLayered" version="4.0"> + <brief_description> + </brief_description> + <description> + </description> + <tutorials> + </tutorials> + <methods> + </methods> + <constants> + </constants> +</class> diff --git a/doc/classes/StreamTexture.xml b/doc/classes/StreamTexture2D.xml index 03afcb5b0d..214298475c 100644 --- a/doc/classes/StreamTexture.xml +++ b/doc/classes/StreamTexture2D.xml @@ -1,5 +1,5 @@ <?xml version="1.0" encoding="UTF-8" ?> -<class name="StreamTexture" inherits="Texture2D" version="4.0"> +<class name="StreamTexture2D" inherits="Texture2D" version="4.0"> <brief_description> A [code].stex[/code] texture. </brief_description> diff --git a/doc/classes/StreamTexture2DArray.xml b/doc/classes/StreamTexture2DArray.xml new file mode 100644 index 0000000000..ec545b24d0 --- /dev/null +++ b/doc/classes/StreamTexture2DArray.xml @@ -0,0 +1,13 @@ +<?xml version="1.0" encoding="UTF-8" ?> +<class name="StreamTexture2DArray" inherits="StreamTextureLayered" version="4.0"> + <brief_description> + </brief_description> + <description> + </description> + <tutorials> + </tutorials> + <methods> + </methods> + <constants> + </constants> +</class> diff --git a/doc/classes/StreamTextureLayered.xml b/doc/classes/StreamTextureLayered.xml new file mode 100644 index 0000000000..10a7aae976 --- /dev/null +++ b/doc/classes/StreamTextureLayered.xml @@ -0,0 +1,25 @@ +<?xml version="1.0" encoding="UTF-8" ?> +<class name="StreamTextureLayered" inherits="TextureLayered" version="4.0"> + <brief_description> + </brief_description> + <description> + </description> + <tutorials> + </tutorials> + <methods> + <method name="load"> + <return type="int" enum="Error"> + </return> + <argument index="0" name="path" type="String"> + </argument> + <description> + </description> + </method> + </methods> + <members> + <member name="load_path" type="String" setter="load" getter="get_load_path" default=""""> + </member> + </members> + <constants> + </constants> +</class> diff --git a/doc/classes/String.xml b/doc/classes/String.xml index a72b8f05d8..0dd6923129 100644 --- a/doc/classes/String.xml +++ b/doc/classes/String.xml @@ -620,6 +620,19 @@ Returns [code]true[/code] if this string contains a valid IP address. </description> </method> + <method name="join"> + <return type="String"> + </return> + <argument index="0" name="parts" type="PackedStringArray"> + </argument> + <description> + Return a [String] which is the concatenation of the [code]parts[/code]. The separator between elements is the string providing this method. + Example: + [codeblock] + print(", ".join(["One", "Two", "Three", "Four"])) + [/codeblock] + </description> + </method> <method name="json_escape"> <return type="String"> </return> diff --git a/doc/classes/TextEdit.xml b/doc/classes/TextEdit.xml index bb2d355bf6..e553518b39 100644 --- a/doc/classes/TextEdit.xml +++ b/doc/classes/TextEdit.xml @@ -343,26 +343,26 @@ Select all the text. </description> </method> - <method name="set_line_as_hidden"> + <method name="set_line"> <return type="void"> </return> <argument index="0" name="line" type="int"> </argument> - <argument index="1" name="enable" type="bool"> + <argument index="1" name="new_text" type="String"> </argument> <description> - If [code]true[/code], hides the line of the specified index. + Sets the text for a specific line. </description> </method> - <method name="set_line"> + <method name="set_line_as_hidden"> <return type="void"> </return> <argument index="0" name="line" type="int"> </argument> - <argument index="1" name="new_text" type="String"> + <argument index="1" name="enable" type="bool"> </argument> <description> - Sets the text for a specific line. + If [code]true[/code], hides the line of the specified index. </description> </method> <method name="toggle_fold_line"> diff --git a/doc/classes/Texture2DArray.xml b/doc/classes/Texture2DArray.xml index 657506120e..bb9283803d 100644 --- a/doc/classes/Texture2DArray.xml +++ b/doc/classes/Texture2DArray.xml @@ -1,5 +1,5 @@ <?xml version="1.0" encoding="UTF-8" ?> -<class name="Texture2DArray" inherits="TextureLayered" version="4.0"> +<class name="Texture2DArray" inherits="ImageTextureLayered" version="4.0"> <brief_description> </brief_description> <description> diff --git a/doc/classes/TextureLayered.xml b/doc/classes/TextureLayered.xml index 66e5b69ab4..d81c3c7c5a 100644 --- a/doc/classes/TextureLayered.xml +++ b/doc/classes/TextureLayered.xml @@ -9,14 +9,6 @@ <tutorials> </tutorials> <methods> - <method name="create_from_images"> - <return type="int" enum="Error"> - </return> - <argument index="0" name="images" type="Array"> - </argument> - <description> - </description> - </method> <method name="get_format" qualifiers="const"> <return type="int" enum="Image.Format"> </return> @@ -40,6 +32,12 @@ Returns an [Image] resource with the data from specified [code]layer[/code]. </description> </method> + <method name="get_layered_type" qualifiers="const"> + <return type="int" enum="TextureLayered.LayeredType"> + </return> + <description> + </description> + </method> <method name="get_layers" qualifiers="const"> <return type="int"> </return> @@ -53,17 +51,19 @@ Returns the width of the texture. Width is typically represented by the X-axis. </description> </method> - <method name="update_layer"> - <return type="void"> + <method name="has_mipmaps" qualifiers="const"> + <return type="bool"> </return> - <argument index="0" name="image" type="Image"> - </argument> - <argument index="1" name="layer" type="int"> - </argument> <description> </description> </method> </methods> <constants> + <constant name="LAYERED_TYPE_2D_ARRAY" value="0" enum="LayeredType"> + </constant> + <constant name="LAYERED_TYPE_CUBEMAP" value="1" enum="LayeredType"> + </constant> + <constant name="LAYERED_TYPE_CUBEMAP_ARRAY" value="2" enum="LayeredType"> + </constant> </constants> </class> diff --git a/doc/classes/TileMap.xml b/doc/classes/TileMap.xml index bfe6983f06..9df2b656f4 100644 --- a/doc/classes/TileMap.xml +++ b/doc/classes/TileMap.xml @@ -80,13 +80,13 @@ Returns a [Vector2] array with the positions of all cells containing a tile from the tileset (i.e. a tile index different from [code]-1[/code]). </description> </method> - <method name="get_used_cells_by_id" qualifiers="const"> + <method name="get_used_cells_by_index" qualifiers="const"> <return type="Vector2i[]"> </return> - <argument index="0" name="id" type="int"> + <argument index="0" name="index" type="int"> </argument> <description> - Returns an array of all cells with the given tile [code]id[/code]. + Returns an array of all cells with the given tile [code]index[/code]. </description> </method> <method name="get_used_rect"> @@ -273,7 +273,7 @@ <member name="cell_tile_origin" type="int" setter="set_tile_origin" getter="get_tile_origin" enum="TileMap.TileOrigin" default="0"> Position for tile origin. See [enum TileOrigin] for possible values. </member> - <member name="cell_y_sort" type="bool" setter="set_y_sort_mode" getter="is_y_sort_mode_enabled" default="false"> + <member name="cell_y_sort" type="bool" setter="set_y_sort_enabled" getter="is_y_sort_enabled" default="false"> If [code]true[/code], the TileMap's children will be drawn in order of their Y coordinate. </member> <member name="centered_textures" type="bool" setter="set_centered_textures" getter="is_centered_textures_enabled" default="false"> diff --git a/doc/classes/TreeItem.xml b/doc/classes/TreeItem.xml index a8a17370c2..126d6b4180 100644 --- a/doc/classes/TreeItem.xml +++ b/doc/classes/TreeItem.xml @@ -249,6 +249,14 @@ <description> </description> </method> + <method name="get_suffix" qualifiers="const"> + <return type="String"> + </return> + <argument index="0" name="column" type="int"> + </argument> + <description> + </description> + </method> <method name="get_text" qualifiers="const"> <return type="String"> </return> @@ -572,6 +580,16 @@ If [code]true[/code], the given column is selectable. </description> </method> + <method name="set_suffix"> + <return type="void"> + </return> + <argument index="0" name="column" type="int"> + </argument> + <argument index="1" name="text" type="String"> + </argument> + <description> + </description> + </method> <method name="set_text"> <return type="void"> </return> diff --git a/doc/classes/Vector3.xml b/doc/classes/Vector3.xml index 8c18ca8cc9..29222bb4d1 100644 --- a/doc/classes/Vector3.xml +++ b/doc/classes/Vector3.xml @@ -309,7 +309,7 @@ <argument index="0" name="by" type="Vector3"> </argument> <description> - Returns a copy of the vector snapped to the lowest neared multiple. + Returns the vector snapped to a grid with the given size. </description> </method> <method name="to_diagonal_matrix"> diff --git a/doc/classes/VisualShaderNodeInput.xml b/doc/classes/VisualShaderNodeInput.xml index ed629508d0..9261d0088d 100644 --- a/doc/classes/VisualShaderNodeInput.xml +++ b/doc/classes/VisualShaderNodeInput.xml @@ -4,8 +4,10 @@ Represents the input shader parameter within the visual shader graph. </brief_description> <description> + Gives access to input variables (built-ins) available for the shader. See the shading reference for the list of available built-ins for each shader type (check [code]Tutorials[/code] section for link). </description> <tutorials> + <link>https://docs.godotengine.org/en/stable/tutorials/shading/shading_reference/index.html</link> </tutorials> <methods> <method name="get_input_real_name" qualifiers="const"> @@ -18,7 +20,7 @@ </methods> <members> <member name="input_name" type="String" setter="set_input_name" getter="get_input_name" default=""[None]""> - One of the several input constants in lower-case style like: "vertex"([/code]VERTEX[code]) or "point_size"([code]POINT_SIZE[/code]). + One of the several input constants in lower-case style like: "vertex"([code]VERTEX[/code]) or "point_size"([code]POINT_SIZE[/code]). </member> </members> <signals> diff --git a/doc/classes/VisualShaderNodeIs.xml b/doc/classes/VisualShaderNodeIs.xml index 184c9e099f..b767a9638e 100644 --- a/doc/classes/VisualShaderNodeIs.xml +++ b/doc/classes/VisualShaderNodeIs.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeIs" inherits="VisualShaderNode" version="4.0"> <brief_description> + A boolean comparison operator to be used within the visual shader graph. </brief_description> <description> + Returns the boolean result of the comparison between [code]INF[/code] or [code]NaN[/code] and a scalar parameter. </description> <tutorials> </tutorials> @@ -10,12 +12,15 @@ </methods> <members> <member name="function" type="int" setter="set_function" getter="get_function" enum="VisualShaderNodeIs.Function" default="0"> + The comparison function. See [enum Function] for options. </member> </members> <constants> <constant name="FUNC_IS_INF" value="0" enum="Function"> + Comparison with [code]INF[/code] (Infinity). </constant> <constant name="FUNC_IS_NAN" value="1" enum="Function"> + Comparison with [code]NaN[/code] (Not a Number; denotes invalid numeric results, e.g. division by zero). </constant> </constants> </class> diff --git a/doc/classes/VisualShaderNodeOuterProduct.xml b/doc/classes/VisualShaderNodeOuterProduct.xml index b8d4fd687f..ba6822bfce 100644 --- a/doc/classes/VisualShaderNodeOuterProduct.xml +++ b/doc/classes/VisualShaderNodeOuterProduct.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeOuterProduct" inherits="VisualShaderNode" version="4.0"> <brief_description> + Calculates an outer product of two vectors within the visual shader graph. </brief_description> <description> + [code]OuterProduct[/code] treats the first parameter [code]c[/code] as a column vector (matrix with one column) and the second parameter [code]r[/code] as a row vector (matrix with one row) and does a linear algebraic matrix multiply [code]c * r[/code], yielding a matrix whose number of rows is the number of components in [code]c[/code] and whose number of columns is the number of components in [code]r[/code]. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeOutput.xml b/doc/classes/VisualShaderNodeOutput.xml index c63e307bad..2b4aed9ae4 100644 --- a/doc/classes/VisualShaderNodeOutput.xml +++ b/doc/classes/VisualShaderNodeOutput.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeOutput" inherits="VisualShaderNode" version="4.0"> <brief_description> + Represents the output shader parameters within the visual shader graph. </brief_description> <description> + This visual shader node is present in all shader graphs in form of "Output" block with mutliple output value ports. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeScalarClamp.xml b/doc/classes/VisualShaderNodeScalarClamp.xml index fd963dcb5d..7432e8dfca 100644 --- a/doc/classes/VisualShaderNodeScalarClamp.xml +++ b/doc/classes/VisualShaderNodeScalarClamp.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeScalarClamp" inherits="VisualShaderNode" version="4.0"> <brief_description> + Clamps a scalar value within the visual shader graph. </brief_description> <description> + Constrains a value to lie between [code]min[/code] and [code]max[/code] values. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeScalarDerivativeFunc.xml b/doc/classes/VisualShaderNodeScalarDerivativeFunc.xml index fa9aa07761..33777c1e49 100644 --- a/doc/classes/VisualShaderNodeScalarDerivativeFunc.xml +++ b/doc/classes/VisualShaderNodeScalarDerivativeFunc.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeScalarDerivativeFunc" inherits="VisualShaderNode" version="4.0"> <brief_description> + Calculates a scalar derivative within the visual shader graph. </brief_description> <description> + This node is only available in [code]Fragment[/code] and [code]Light[/code] visual shaders. </description> <tutorials> </tutorials> @@ -10,14 +12,18 @@ </methods> <members> <member name="function" type="int" setter="set_function" getter="get_function" enum="VisualShaderNodeScalarDerivativeFunc.Function" default="0"> + The derivative type. See [enum Function] for options. </member> </members> <constants> <constant name="FUNC_SUM" value="0" enum="Function"> + Sum of absolute derivative in [code]x[/code] and [code]y[/code]. </constant> <constant name="FUNC_X" value="1" enum="Function"> + Derivative in [code]x[/code] using local differencing. </constant> <constant name="FUNC_Y" value="2" enum="Function"> + Derivative in [code]y[/code] using local differencing. </constant> </constants> </class> diff --git a/doc/classes/VisualShaderNodeScalarInterp.xml b/doc/classes/VisualShaderNodeScalarInterp.xml index a25ab750cc..393ea70e1a 100644 --- a/doc/classes/VisualShaderNodeScalarInterp.xml +++ b/doc/classes/VisualShaderNodeScalarInterp.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeScalarInterp" inherits="VisualShaderNode" version="4.0"> <brief_description> + Linearly interpolates between two scalars within the visual shader graph. </brief_description> <description> + Translates to [code]mix(a, b, weight)[/code] in the shader language. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeScalarSmoothStep.xml b/doc/classes/VisualShaderNodeScalarSmoothStep.xml index 1ac16e451f..e619cc8571 100644 --- a/doc/classes/VisualShaderNodeScalarSmoothStep.xml +++ b/doc/classes/VisualShaderNodeScalarSmoothStep.xml @@ -1,8 +1,11 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeScalarSmoothStep" inherits="VisualShaderNode" version="4.0"> <brief_description> + Calculates a scalar SmoothStep function within the visual shader graph. </brief_description> <description> + Translates to [code]smoothstep(edge0, edge1, x)[/code] in the shader language. + Returns [code]0.0[/code] if [code]x[/code] is smaller than [code]edge0[/code] and [code]1.0[/code] if [code]x[/code] is larger than [code]edge1[/code]. Otherwise the return value is interpolated between [code]0.0[/code] and [code]1.0[/code] using Hermite polynomials. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeScalarSwitch.xml b/doc/classes/VisualShaderNodeScalarSwitch.xml index 789c8972bb..2ad5202745 100644 --- a/doc/classes/VisualShaderNodeScalarSwitch.xml +++ b/doc/classes/VisualShaderNodeScalarSwitch.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeScalarSwitch" inherits="VisualShaderNodeSwitch" version="4.0"> <brief_description> + A boolean/scalar function for use within the visual shader graph. </brief_description> <description> + Returns an associated scalar if the provided boolean value is [code]true[/code] or [code]false[/code]. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeSwitch.xml b/doc/classes/VisualShaderNodeSwitch.xml index 5bbb9168a0..9f8a12c0fd 100644 --- a/doc/classes/VisualShaderNodeSwitch.xml +++ b/doc/classes/VisualShaderNodeSwitch.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeSwitch" inherits="VisualShaderNode" version="4.0"> <brief_description> + A boolean/vector function for use within the visual shader graph. </brief_description> <description> + Returns an associated vector if the provided boolean value is [code]true[/code] or [code]false[/code]. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeTexture.xml b/doc/classes/VisualShaderNodeTexture.xml index a28a7f5c65..8e389e0b40 100644 --- a/doc/classes/VisualShaderNodeTexture.xml +++ b/doc/classes/VisualShaderNodeTexture.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeTexture" inherits="VisualShaderNode" version="4.0"> <brief_description> + Performs a texture lookup within the visual shader graph. </brief_description> <description> + Performs a lookup operation on the provided texture, with support for multiple texture sources to choose from. </description> <tutorials> </tutorials> @@ -10,30 +12,42 @@ </methods> <members> <member name="source" type="int" setter="set_source" getter="get_source" enum="VisualShaderNodeTexture.Source" default="0"> + Determines the source for the lookup. See [enum Source] for options. </member> <member name="texture" type="Texture2D" setter="set_texture" getter="get_texture"> + The source texture, if needed for the selected [member source]. </member> <member name="texture_type" type="int" setter="set_texture_type" getter="get_texture_type" enum="VisualShaderNodeTexture.TextureType" default="0"> + Specifies the type of the texture if [member source] is set to [constant SOURCE_TEXTURE]. See [enum TextureType] for options. </member> </members> <constants> <constant name="SOURCE_TEXTURE" value="0" enum="Source"> + Use the texture given as an argument for this function. </constant> <constant name="SOURCE_SCREEN" value="1" enum="Source"> + Use the current viewport's texture as the source. </constant> <constant name="SOURCE_2D_TEXTURE" value="2" enum="Source"> + Use the texture from this shader's texture built-in (e.g. a texture of a [Sprite2D]). </constant> <constant name="SOURCE_2D_NORMAL" value="3" enum="Source"> + Use the texture from this shader's normal map built-in. </constant> <constant name="SOURCE_DEPTH" value="4" enum="Source"> + Use the depth texture available for this shader. </constant> <constant name="SOURCE_PORT" value="5" enum="Source"> + Use the texture provided in the input port for this function. </constant> <constant name="TYPE_DATA" value="0" enum="TextureType"> + No hints are added to the uniform declaration. </constant> <constant name="TYPE_COLOR" value="1" enum="TextureType"> + Adds [code]hint_albedo[/code] as hint to the uniform declaration for proper sRGB to linear conversion. </constant> <constant name="TYPE_NORMALMAP" value="2" enum="TextureType"> + Adds [code]hint_normal[/code] as hint to the uniform declaration, which internally converts the texture for proper usage as normal map. </constant> </constants> </class> diff --git a/doc/classes/VisualShaderNodeTextureUniform.xml b/doc/classes/VisualShaderNodeTextureUniform.xml index 4e2c39a297..107f08ba28 100644 --- a/doc/classes/VisualShaderNodeTextureUniform.xml +++ b/doc/classes/VisualShaderNodeTextureUniform.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeTextureUniform" inherits="VisualShaderNodeUniform" version="4.0"> <brief_description> + Performs a uniform texture lookup within the visual shader graph. </brief_description> <description> + Performs a lookup operation on the texture provided as a uniform for the shader. </description> <tutorials> </tutorials> @@ -10,22 +12,30 @@ </methods> <members> <member name="color_default" type="int" setter="set_color_default" getter="get_color_default" enum="VisualShaderNodeTextureUniform.ColorDefault" default="0"> + Sets the default color if no texture is assigned to the uniform. </member> <member name="texture_type" type="int" setter="set_texture_type" getter="get_texture_type" enum="VisualShaderNodeTextureUniform.TextureType" default="0"> + Defines the type of data provided by the source texture. See [enum TextureType] for options. </member> </members> <constants> <constant name="TYPE_DATA" value="0" enum="TextureType"> + No hints are added to the uniform declaration. </constant> <constant name="TYPE_COLOR" value="1" enum="TextureType"> + Adds [code]hint_albedo[/code] as hint to the uniform declaration for proper sRGB to linear conversion. </constant> <constant name="TYPE_NORMALMAP" value="2" enum="TextureType"> + Adds [code]hint_normal[/code] as hint to the uniform declaration, which internally converts the texture for proper usage as normal map. </constant> <constant name="TYPE_ANISO" value="3" enum="TextureType"> + Adds [code]hint_aniso[/code] as hint to the uniform declaration to use for a flowmap. </constant> <constant name="COLOR_DEFAULT_WHITE" value="0" enum="ColorDefault"> + Defaults to white color. </constant> <constant name="COLOR_DEFAULT_BLACK" value="1" enum="ColorDefault"> + Defaults to black color. </constant> </constants> </class> diff --git a/doc/classes/VisualShaderNodeTextureUniformTriplanar.xml b/doc/classes/VisualShaderNodeTextureUniformTriplanar.xml index 3d69575444..28504cc7ac 100644 --- a/doc/classes/VisualShaderNodeTextureUniformTriplanar.xml +++ b/doc/classes/VisualShaderNodeTextureUniformTriplanar.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeTextureUniformTriplanar" inherits="VisualShaderNodeTextureUniform" version="4.0"> <brief_description> + Performs a uniform texture lookup with triplanar within the visual shader graph. </brief_description> <description> + Performs a lookup operation on the texture provided as a uniform for the shader, with support for triplanar mapping. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeTransformCompose.xml b/doc/classes/VisualShaderNodeTransformCompose.xml index 6d9cab7ab0..41762b0099 100644 --- a/doc/classes/VisualShaderNodeTransformCompose.xml +++ b/doc/classes/VisualShaderNodeTransformCompose.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeTransformCompose" inherits="VisualShaderNode" version="4.0"> <brief_description> + Composes a [Transform] from four [Vector3]s within the visual shader graph. </brief_description> <description> + Creates a 4x4 transform matrix using four vectors of type [code]vec3[/code]. Each vector is one row in the matrix and the last column is a [code]vec4(0, 0, 0, 1)[/code]. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeTransformConstant.xml b/doc/classes/VisualShaderNodeTransformConstant.xml index 15422e1728..e5004e5bb6 100644 --- a/doc/classes/VisualShaderNodeTransformConstant.xml +++ b/doc/classes/VisualShaderNodeTransformConstant.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeTransformConstant" inherits="VisualShaderNode" version="4.0"> <brief_description> + A [Transform] constant for use within the visual shader graph. </brief_description> <description> + A constant [Transform], which can be used as an input node. </description> <tutorials> </tutorials> @@ -10,6 +12,7 @@ </methods> <members> <member name="constant" type="Transform" setter="set_constant" getter="get_constant" default="Transform( 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0 )"> + A [Transform] constant which represents the state of this node. </member> </members> <constants> diff --git a/doc/classes/VisualShaderNodeTransformDecompose.xml b/doc/classes/VisualShaderNodeTransformDecompose.xml index 4d3c464781..c8d893db00 100644 --- a/doc/classes/VisualShaderNodeTransformDecompose.xml +++ b/doc/classes/VisualShaderNodeTransformDecompose.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeTransformDecompose" inherits="VisualShaderNode" version="4.0"> <brief_description> + Decomposes a [Transform] into four [Vector3]s within the visual shader graph. </brief_description> <description> + Takes a 4x4 transform matrix and decomposes it into four [code]vec3[/code] values, one from each row of the matrix. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeTransformFunc.xml b/doc/classes/VisualShaderNodeTransformFunc.xml index d2b6fcef2b..d0b5c5129d 100644 --- a/doc/classes/VisualShaderNodeTransformFunc.xml +++ b/doc/classes/VisualShaderNodeTransformFunc.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeTransformFunc" inherits="VisualShaderNode" version="4.0"> <brief_description> + Computes a [Transform] function within the visual shader graph. </brief_description> <description> + Computes an inverse or transpose function on the provided [Transform]. </description> <tutorials> </tutorials> @@ -10,12 +12,15 @@ </methods> <members> <member name="function" type="int" setter="set_function" getter="get_function" enum="VisualShaderNodeTransformFunc.Function" default="0"> + The function to be computed. See [enum Function] for options. </member> </members> <constants> <constant name="FUNC_INVERSE" value="0" enum="Function"> + Perform the inverse operation on the [Transform] matrix. </constant> <constant name="FUNC_TRANSPOSE" value="1" enum="Function"> + Perform the transpose operation on the [Transform] matrix. </constant> </constants> </class> diff --git a/doc/classes/VisualShaderNodeTransformMult.xml b/doc/classes/VisualShaderNodeTransformMult.xml index 5893d1413b..02b6e0cd1c 100644 --- a/doc/classes/VisualShaderNodeTransformMult.xml +++ b/doc/classes/VisualShaderNodeTransformMult.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeTransformMult" inherits="VisualShaderNode" version="4.0"> <brief_description> + Multiplies [Transform] by [Transform] within the visual shader graph. </brief_description> <description> + A multiplication operation on two transforms (4x4 matrices), with support for different multiplication operators. </description> <tutorials> </tutorials> @@ -10,16 +12,21 @@ </methods> <members> <member name="operator" type="int" setter="set_operator" getter="get_operator" enum="VisualShaderNodeTransformMult.Operator" default="0"> + The multiplication type to be performed on the transforms. See [enum Operator] for options. </member> </members> <constants> <constant name="OP_AxB" value="0" enum="Operator"> + Multiplies transform [code]a[/code] by the transform [code]b[/code]. </constant> <constant name="OP_BxA" value="1" enum="Operator"> + Multiplies transform [code]b[/code] by the transform [code]a[/code]. </constant> <constant name="OP_AxB_COMP" value="2" enum="Operator"> + Performs a component-wise multiplication of transform [code]a[/code] by the transform [code]b[/code]. </constant> <constant name="OP_BxA_COMP" value="3" enum="Operator"> + Performs a component-wise multiplication of transform [code]b[/code] by the transform [code]a[/code]. </constant> </constants> </class> diff --git a/doc/classes/VisualShaderNodeTransformUniform.xml b/doc/classes/VisualShaderNodeTransformUniform.xml index 7605ef96d5..43696c8226 100644 --- a/doc/classes/VisualShaderNodeTransformUniform.xml +++ b/doc/classes/VisualShaderNodeTransformUniform.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeTransformUniform" inherits="VisualShaderNodeUniform" version="4.0"> <brief_description> + A [Transform] uniform for use within the visual shader graph. </brief_description> <description> + Translated to [code]uniform mat4[/code] in the shader language. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeTransformVecMult.xml b/doc/classes/VisualShaderNodeTransformVecMult.xml index d53c3c5ae5..3d5f87f727 100644 --- a/doc/classes/VisualShaderNodeTransformVecMult.xml +++ b/doc/classes/VisualShaderNodeTransformVecMult.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeTransformVecMult" inherits="VisualShaderNode" version="4.0"> <brief_description> + Multiplies a [Transform] and a [Vector3] within the visual shader graph. </brief_description> <description> + A multiplication operation on a transform (4x4 matrix) and a vector, with support for different multiplication operators. </description> <tutorials> </tutorials> @@ -10,16 +12,21 @@ </methods> <members> <member name="operator" type="int" setter="set_operator" getter="get_operator" enum="VisualShaderNodeTransformVecMult.Operator" default="0"> + The multiplication type to be performed. See [enum Operator] for options. </member> </members> <constants> <constant name="OP_AxB" value="0" enum="Operator"> + Multiplies transform [code]a[/code] by the vector [code]b[/code]. </constant> <constant name="OP_BxA" value="1" enum="Operator"> + Multiplies vector [code]b[/code] by the transform [code]a[/code]. </constant> <constant name="OP_3x3_AxB" value="2" enum="Operator"> + Multiplies transform [code]a[/code] by the vector [code]b[/code], skipping the last row and column of the transform. </constant> <constant name="OP_3x3_BxA" value="3" enum="Operator"> + Multiplies vector [code]b[/code] by the transform [code]a[/code], skipping the last row and column of the transform. </constant> </constants> </class> diff --git a/doc/classes/VisualShaderNodeUniform.xml b/doc/classes/VisualShaderNodeUniform.xml index 1225d9c320..83261344bd 100644 --- a/doc/classes/VisualShaderNodeUniform.xml +++ b/doc/classes/VisualShaderNodeUniform.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeUniform" inherits="VisualShaderNode" version="4.0"> <brief_description> + A base type for the uniforms within the visual shader graph. </brief_description> <description> + A uniform represents a variable in the shader which is set externally, i.e. from the [ShaderMaterial]. Uniforms are exposed as properties in the [ShaderMaterial] and can be assigned from the inspector or from a script. </description> <tutorials> </tutorials> @@ -12,6 +14,7 @@ <member name="qualifier" type="int" setter="set_qualifier" getter="get_qualifier" enum="VisualShaderNodeUniform.Qualifier" default="0"> </member> <member name="uniform_name" type="String" setter="set_uniform_name" getter="get_uniform_name" default=""""> + Name of the uniform, by which it can be accessed through the [ShaderMaterial] properties. </member> </members> <constants> diff --git a/doc/classes/VisualShaderNodeVec3Constant.xml b/doc/classes/VisualShaderNodeVec3Constant.xml index 8532c5476f..4dfc9dc081 100644 --- a/doc/classes/VisualShaderNodeVec3Constant.xml +++ b/doc/classes/VisualShaderNodeVec3Constant.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVec3Constant" inherits="VisualShaderNode" version="4.0"> <brief_description> + A [Vector3] constant to be used within the visual shader graph. </brief_description> <description> + A constant [Vector3], which can be used as an input node. </description> <tutorials> </tutorials> @@ -10,6 +12,7 @@ </methods> <members> <member name="constant" type="Vector3" setter="set_constant" getter="get_constant" default="Vector3( 0, 0, 0 )"> + A [Vector3] constant which represents the state of this node. </member> </members> <constants> diff --git a/doc/classes/VisualShaderNodeVec3Uniform.xml b/doc/classes/VisualShaderNodeVec3Uniform.xml index f2b4c4778b..c8b44fdc8d 100644 --- a/doc/classes/VisualShaderNodeVec3Uniform.xml +++ b/doc/classes/VisualShaderNodeVec3Uniform.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVec3Uniform" inherits="VisualShaderNodeUniform" version="4.0"> <brief_description> + A [Vector3] uniform to be used within the visual shader graph. </brief_description> <description> + Translated to [code]uniform vec3[/code] in the shader language. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorClamp.xml b/doc/classes/VisualShaderNodeVectorClamp.xml index 85c093f84c..567fed8a41 100644 --- a/doc/classes/VisualShaderNodeVectorClamp.xml +++ b/doc/classes/VisualShaderNodeVectorClamp.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorClamp" inherits="VisualShaderNode" version="4.0"> <brief_description> + Clamps a vector value within the visual shader graph. </brief_description> <description> + Constrains a value to lie between [code]min[/code] and [code]max[/code] values. The operation is performed on each component of the vector individually. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorCompose.xml b/doc/classes/VisualShaderNodeVectorCompose.xml index e5e4924a9e..c9ff3cd38e 100644 --- a/doc/classes/VisualShaderNodeVectorCompose.xml +++ b/doc/classes/VisualShaderNodeVectorCompose.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorCompose" inherits="VisualShaderNode" version="4.0"> <brief_description> + Composes a [Vector3] from three scalars within the visual shader graph. </brief_description> <description> + Creates a [code]vec3[/code] using three scalar values that can be provided from separate inputs. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorDecompose.xml b/doc/classes/VisualShaderNodeVectorDecompose.xml index 5378f38b6d..95af323c9b 100644 --- a/doc/classes/VisualShaderNodeVectorDecompose.xml +++ b/doc/classes/VisualShaderNodeVectorDecompose.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorDecompose" inherits="VisualShaderNode" version="4.0"> <brief_description> + Decomposes a [Vector3] into three scalars within the visual shader graph. </brief_description> <description> + Takes a [code]vec3[/code] and decomposes it into three scalar values that can be used as separate inputs. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorDerivativeFunc.xml b/doc/classes/VisualShaderNodeVectorDerivativeFunc.xml index d62512d68b..859c47bc33 100644 --- a/doc/classes/VisualShaderNodeVectorDerivativeFunc.xml +++ b/doc/classes/VisualShaderNodeVectorDerivativeFunc.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorDerivativeFunc" inherits="VisualShaderNode" version="4.0"> <brief_description> + Calculates a vector derivative within the visual shader graph. </brief_description> <description> + This node is only available in [code]Fragment[/code] and [code]Light[/code] visual shaders. </description> <tutorials> </tutorials> @@ -10,14 +12,18 @@ </methods> <members> <member name="function" type="int" setter="set_function" getter="get_function" enum="VisualShaderNodeVectorDerivativeFunc.Function" default="0"> + A derivative type. See [enum Function] for options. </member> </members> <constants> <constant name="FUNC_SUM" value="0" enum="Function"> + Sum of absolute derivative in [code]x[/code] and [code]y[/code]. </constant> <constant name="FUNC_X" value="1" enum="Function"> + Derivative in [code]x[/code] using local differencing. </constant> <constant name="FUNC_Y" value="2" enum="Function"> + Derivative in [code]y[/code] using local differencing. </constant> </constants> </class> diff --git a/doc/classes/VisualShaderNodeVectorDistance.xml b/doc/classes/VisualShaderNodeVectorDistance.xml index 2e681156a5..2da04b122e 100644 --- a/doc/classes/VisualShaderNodeVectorDistance.xml +++ b/doc/classes/VisualShaderNodeVectorDistance.xml @@ -1,8 +1,11 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorDistance" inherits="VisualShaderNode" version="4.0"> <brief_description> + Returns the distance between two points. To be used within the visual shader graph. </brief_description> <description> + Calculates distance from point represented by vector [code]p0[/code] to vector [code]p1[/code]. + Translated to [code]distance(p0, p1)[/code] in the shader language. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorFunc.xml b/doc/classes/VisualShaderNodeVectorFunc.xml index 0b3f317b8b..cbda3dfb46 100644 --- a/doc/classes/VisualShaderNodeVectorFunc.xml +++ b/doc/classes/VisualShaderNodeVectorFunc.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorFunc" inherits="VisualShaderNode" version="4.0"> <brief_description> + A vector function to be used within the visual shader graph. </brief_description> <description> + A visual shader node able to perform different functions using vectors. </description> <tutorials> </tutorials> @@ -10,78 +12,114 @@ </methods> <members> <member name="function" type="int" setter="set_function" getter="get_function" enum="VisualShaderNodeVectorFunc.Function" default="0"> + The function to be performed. See [enum Function] for options. </member> </members> <constants> <constant name="FUNC_NORMALIZE" value="0" enum="Function"> + Normalizes the vector so that it has a length of [code]1[/code] but points in the same direction. </constant> <constant name="FUNC_SATURATE" value="1" enum="Function"> + Clamps the value between [code]0.0[/code] and [code]1.0[/code]. </constant> <constant name="FUNC_NEGATE" value="2" enum="Function"> + Returns the opposite value of the parameter. </constant> <constant name="FUNC_RECIPROCAL" value="3" enum="Function"> + Returns [code]1/vector[/code]. </constant> <constant name="FUNC_RGB2HSV" value="4" enum="Function"> + Converts RGB vector to HSV equivalent. </constant> <constant name="FUNC_HSV2RGB" value="5" enum="Function"> + Converts HSV vector to RGB equivalent. </constant> <constant name="FUNC_ABS" value="6" enum="Function"> + Returns the absolute value of the parameter. </constant> <constant name="FUNC_ACOS" value="7" enum="Function"> + Returns the arc-cosine of the parameter. </constant> <constant name="FUNC_ACOSH" value="8" enum="Function"> + Returns the inverse hyperbolic cosine of the parameter. </constant> <constant name="FUNC_ASIN" value="9" enum="Function"> + Returns the arc-sine of the parameter. </constant> <constant name="FUNC_ASINH" value="10" enum="Function"> + Returns the inverse hyperbolic sine of the parameter. </constant> <constant name="FUNC_ATAN" value="11" enum="Function"> + Returns the arc-tangent of the parameter. </constant> <constant name="FUNC_ATANH" value="12" enum="Function"> + Returns the inverse hyperbolic tangent of the parameter. </constant> <constant name="FUNC_CEIL" value="13" enum="Function"> + Finds the nearest integer that is greater than or equal to the parameter. </constant> <constant name="FUNC_COS" value="14" enum="Function"> + Returns the cosine of the parameter. </constant> <constant name="FUNC_COSH" value="15" enum="Function"> + Returns the hyperbolic cosine of the parameter. </constant> <constant name="FUNC_DEGREES" value="16" enum="Function"> + Converts a quantity in radians to degrees. </constant> <constant name="FUNC_EXP" value="17" enum="Function"> + Base-e Exponential. </constant> <constant name="FUNC_EXP2" value="18" enum="Function"> + Base-2 Exponential. </constant> <constant name="FUNC_FLOOR" value="19" enum="Function"> + Finds the nearest integer less than or equal to the parameter. </constant> <constant name="FUNC_FRAC" value="20" enum="Function"> + Computes the fractional part of the argument. </constant> <constant name="FUNC_INVERSE_SQRT" value="21" enum="Function"> + Returns the inverse of the square root of the parameter. </constant> <constant name="FUNC_LOG" value="22" enum="Function"> + Natural logarithm. </constant> <constant name="FUNC_LOG2" value="23" enum="Function"> + Base-2 logarithm. </constant> <constant name="FUNC_RADIANS" value="24" enum="Function"> + Converts a quantity in degrees to radians. </constant> <constant name="FUNC_ROUND" value="25" enum="Function"> + Finds the nearest integer to the parameter. </constant> <constant name="FUNC_ROUNDEVEN" value="26" enum="Function"> + Finds the nearest even integer to the parameter. </constant> <constant name="FUNC_SIGN" value="27" enum="Function"> + Extracts the sign of the parameter, i.e. returns [code]-1[/code] if the parameter is negative, [code]1[/code] if it's positive and [code]0[/code] otherwise. </constant> <constant name="FUNC_SIN" value="28" enum="Function"> + Returns the sine of the parameter. </constant> <constant name="FUNC_SINH" value="29" enum="Function"> + Returns the hyperbolic sine of the parameter. </constant> <constant name="FUNC_SQRT" value="30" enum="Function"> + Returns the square root of the parameter. </constant> <constant name="FUNC_TAN" value="31" enum="Function"> + Returns the tangent of the parameter. </constant> <constant name="FUNC_TANH" value="32" enum="Function"> + Returns the hyperbolic tangent of the parameter. </constant> <constant name="FUNC_TRUNC" value="33" enum="Function"> + Returns a value equal to the nearest integer to the parameter whose absolute value is not larger than the absolute value of the parameter. </constant> <constant name="FUNC_ONEMINUS" value="34" enum="Function"> + Returns [code]1.0 - vector[/code]. </constant> </constants> </class> diff --git a/doc/classes/VisualShaderNodeVectorInterp.xml b/doc/classes/VisualShaderNodeVectorInterp.xml index 4d6d3ac577..b63d34b742 100644 --- a/doc/classes/VisualShaderNodeVectorInterp.xml +++ b/doc/classes/VisualShaderNodeVectorInterp.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorInterp" inherits="VisualShaderNode" version="4.0"> <brief_description> + Linearly interpolates between two vectors within the visual shader graph. </brief_description> <description> + Translates to [code]mix(a, b, weight)[/code] in the shader language, where [code]weight[/code] is a [Vector3] with weights for each component. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorLen.xml b/doc/classes/VisualShaderNodeVectorLen.xml index ade575310c..77261d3190 100644 --- a/doc/classes/VisualShaderNodeVectorLen.xml +++ b/doc/classes/VisualShaderNodeVectorLen.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorLen" inherits="VisualShaderNode" version="4.0"> <brief_description> + Returns the length of a [Vector3] within the visual shader graph. </brief_description> <description> + Translated to [code]length(p0)[/code] in the shader language. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorOp.xml b/doc/classes/VisualShaderNodeVectorOp.xml index 8c391d09a2..d56c012f8f 100644 --- a/doc/classes/VisualShaderNodeVectorOp.xml +++ b/doc/classes/VisualShaderNodeVectorOp.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorOp" inherits="VisualShaderNode" version="4.0"> <brief_description> + A vector operator to be used within the visual shader graph. </brief_description> <description> + A visual shader node for use of vector operators. Operates on vector [code]a[/code] and vector [code]b[/code]. </description> <tutorials> </tutorials> @@ -10,32 +12,45 @@ </methods> <members> <member name="operator" type="int" setter="set_operator" getter="get_operator" enum="VisualShaderNodeVectorOp.Operator" default="0"> + The operator to be used. See [enum Operator] for options. </member> </members> <constants> <constant name="OP_ADD" value="0" enum="Operator"> + Adds two vectors. </constant> <constant name="OP_SUB" value="1" enum="Operator"> + Subtracts a vector from a vector. </constant> <constant name="OP_MUL" value="2" enum="Operator"> + Multiplies two vectors. </constant> <constant name="OP_DIV" value="3" enum="Operator"> + Divides vector by vector. </constant> <constant name="OP_MOD" value="4" enum="Operator"> + Returns the remainder of the two vectors. </constant> <constant name="OP_POW" value="5" enum="Operator"> + Returns the value of the first parameter raised to the power of the second, for each component of the vectors. </constant> <constant name="OP_MAX" value="6" enum="Operator"> + Returns the greater of two values, for each component of the vectors. </constant> <constant name="OP_MIN" value="7" enum="Operator"> + Returns the lesser of two values, for each component of the vectors. </constant> <constant name="OP_CROSS" value="8" enum="Operator"> + Calculates the cross product of two vectors. </constant> <constant name="OP_ATAN2" value="9" enum="Operator"> + Returns the arc-tangent of the parameters. </constant> <constant name="OP_REFLECT" value="10" enum="Operator"> + Returns the vector that points in the direction of reflection. [code]a[/code] is incident vector and [code]b[/code] is the normal vector. </constant> <constant name="OP_STEP" value="11" enum="Operator"> + Vector step operator. Returns [code]0.0[/code] if [code]a[/code] is smaller than [code]b[/code] and [code]1.0[/code] otherwise. </constant> </constants> </class> diff --git a/doc/classes/VisualShaderNodeVectorRefract.xml b/doc/classes/VisualShaderNodeVectorRefract.xml index c5962e7e10..0fa90a69cf 100644 --- a/doc/classes/VisualShaderNodeVectorRefract.xml +++ b/doc/classes/VisualShaderNodeVectorRefract.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorRefract" inherits="VisualShaderNode" version="4.0"> <brief_description> + Returns the [Vector3] that points in the direction of refraction. For use within the visual shader graph. </brief_description> <description> + Translated to [code]refract(I, N, eta)[/code] in the shader language, where [code]I[/code] is the incident vector, [code]N[/code] is the normal vector and [code]eta[/code] is the ratio of the indicies of the refraction. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorScalarMix.xml b/doc/classes/VisualShaderNodeVectorScalarMix.xml index c425bfe45e..791a9e6be1 100644 --- a/doc/classes/VisualShaderNodeVectorScalarMix.xml +++ b/doc/classes/VisualShaderNodeVectorScalarMix.xml @@ -1,8 +1,10 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorScalarMix" inherits="VisualShaderNode" version="4.0"> <brief_description> + Linearly interpolates between two vectors using a scalar. For use within the visual shader graph. </brief_description> <description> + Translates to [code]mix(a, b, weight)[/code] in the shader language, where [code]a[/code] and [code]b[/code] are vectors and [code]weight[/code] is a scalar. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorScalarSmoothStep.xml b/doc/classes/VisualShaderNodeVectorScalarSmoothStep.xml index 2f341d71e0..580abaf5fe 100644 --- a/doc/classes/VisualShaderNodeVectorScalarSmoothStep.xml +++ b/doc/classes/VisualShaderNodeVectorScalarSmoothStep.xml @@ -1,8 +1,11 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorScalarSmoothStep" inherits="VisualShaderNode" version="4.0"> <brief_description> + Calculates a vector SmoothStep function using scalar within the visual shader graph. </brief_description> <description> + Translates to [code]smoothstep(edge0, edge1, x)[/code] in the shader language, where [code]x[/code] is a scalar. + Returns [code]0.0[/code] if [code]x[/code] is smaller than [code]edge0[/code] and [code]1.0[/code] if [code]x[/code] is larger than [code]edge1[/code]. Otherwise the return value is interpolated between [code]0.0[/code] and [code]1.0[/code] using Hermite polynomials. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorScalarStep.xml b/doc/classes/VisualShaderNodeVectorScalarStep.xml index 11da106001..d61414f3a8 100644 --- a/doc/classes/VisualShaderNodeVectorScalarStep.xml +++ b/doc/classes/VisualShaderNodeVectorScalarStep.xml @@ -1,8 +1,11 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorScalarStep" inherits="VisualShaderNode" version="4.0"> <brief_description> + Calculates a vector Step function within the visual shader graph. </brief_description> <description> + Translates to [code]step(edge, x)[/code] in the shader language. + Returns [code]0.0[/code] if [code]x[/code] is smaller than [code]edge[/code] and [code]1.0[/code] otherwise. </description> <tutorials> </tutorials> diff --git a/doc/classes/VisualShaderNodeVectorSmoothStep.xml b/doc/classes/VisualShaderNodeVectorSmoothStep.xml index 54e9f1bd7d..1b77a3c535 100644 --- a/doc/classes/VisualShaderNodeVectorSmoothStep.xml +++ b/doc/classes/VisualShaderNodeVectorSmoothStep.xml @@ -1,8 +1,11 @@ <?xml version="1.0" encoding="UTF-8" ?> <class name="VisualShaderNodeVectorSmoothStep" inherits="VisualShaderNode" version="4.0"> <brief_description> + Calculates a vector SmoothStep function within the visual shader graph. </brief_description> <description> + Translates to [code]smoothstep(edge0, edge1, x)[/code] in the shader language, where [code]x[/code] is a vector. + Returns [code]0.0[/code] if [code]x[/code] is smaller than [code]edge0[/code] and [code]1.0[/code] if [code]x[/code] is larger than [code]edge1[/code]. Otherwise the return value is interpolated between [code]0.0[/code] and [code]1.0[/code] using Hermite polynomials. </description> <tutorials> </tutorials> diff --git a/doc/classes/World2D.xml b/doc/classes/World2D.xml index 2d8382b7e3..e66f21a0e7 100644 --- a/doc/classes/World2D.xml +++ b/doc/classes/World2D.xml @@ -16,7 +16,7 @@ The [RID] of this world's canvas resource. Used by the [RenderingServer] for 2D drawing. </member> <member name="direct_space_state" type="PhysicsDirectSpaceState2D" setter="" getter="get_direct_space_state"> - The state of this world's physics space. This allows arbitrary querying for collision. + Direct access to the world's physics 2D space state. Used for querying current and potential collisions. Must only be accessed from the main thread within [code]_physics_process(delta)[/code]. </member> <member name="space" type="RID" setter="" getter="get_space"> The [RID] of this world's physics space resource. Used by the [PhysicsServer2D] for 2D physics, treating it as both a space and an area. diff --git a/doc/classes/World3D.xml b/doc/classes/World3D.xml index 4224a2a2c3..6d3b94794e 100644 --- a/doc/classes/World3D.xml +++ b/doc/classes/World3D.xml @@ -15,7 +15,7 @@ <member name="camera_effects" type="CameraEffects" setter="set_camera_effects" getter="get_camera_effects"> </member> <member name="direct_space_state" type="PhysicsDirectSpaceState3D" setter="" getter="get_direct_space_state"> - The World3D's physics direct space state, used for making various queries. Might be used only during [code]_physics_process[/code]. + Direct access to the world's physics 3D space state. Used for querying current and potential collisions. Must only be accessed from within [code]_physics_process(delta)[/code]. </member> <member name="environment" type="Environment" setter="set_environment" getter="get_environment"> The World3D's [Environment]. diff --git a/doc/classes/XRController3D.xml b/doc/classes/XRController3D.xml index e4a06a80db..8e80eb9a32 100644 --- a/doc/classes/XRController3D.xml +++ b/doc/classes/XRController3D.xml @@ -62,7 +62,7 @@ <argument index="0" name="button" type="int"> </argument> <description> - Returns [code]true[/code] if the button at index [code]button[/code] is pressed. See [enum JoystickList], in particular the [code]JOY_VR_*[/code] constants. + Returns [code]true[/code] if the button at index [code]button[/code] is pressed. See [enum JoyButtonList]. </description> </method> </methods> diff --git a/doc/translations/classes.pot b/doc/translations/classes.pot index d58b9d9472..28db7e4e04 100644 --- a/doc/translations/classes.pot +++ b/doc/translations/classes.pot @@ -24656,7 +24656,7 @@ msgstr "" #: doc/classes/Input.xml:99 msgid "" "Returns the current value of the joypad axis at given index (see [enum " -"JoystickList])." +"JoyAxisList])." msgstr "" #: doc/classes/Input.xml:108 @@ -24665,7 +24665,7 @@ msgstr "" #: doc/classes/Input.xml:117 msgid "" -"Receives a [enum JoystickList] axis and returns its equivalent name as a " +"Receives a [enum JoyAxisList] axis and returns its equivalent name as a " "string." msgstr "" @@ -24675,7 +24675,7 @@ msgstr "" #: doc/classes/Input.xml:135 msgid "" -"Receives a gamepad button from [enum JoystickList] and returns its " +"Receives a gamepad button from [enum JoyButtonList] and returns its " "equivalent name as a string." msgstr "" @@ -24750,15 +24750,14 @@ msgstr "" #: doc/classes/Input.xml:238 msgid "" "Returns [code]true[/code] if you are pressing the joypad button (see [enum " -"JoystickList])." +"JoyButtonList])." msgstr "" #: doc/classes/Input.xml:247 msgid "" "Returns [code]true[/code] if the system knows the specified device. This " -"means that it sets all button and axis indices exactly as defined in [enum " -"JoystickList]. Unknown joypads are not expected to match these constants, " -"but you can still retrieve events from them." +"means that it sets all button and axis indices. Unknown joypads are not " +"expected to match these constants, but you can still retrieve events from them." msgstr "" #: doc/classes/Input.xml:256 @@ -25150,7 +25149,7 @@ msgid "" msgstr "" #: doc/classes/InputEventJoypadButton.xml:16 -msgid "Button identifier. One of the [enum JoystickList] button constants." +msgid "Button identifier. One of the [enum JoyButtonList] button constants." msgstr "" #: doc/classes/InputEventJoypadButton.xml:19 @@ -25178,7 +25177,7 @@ msgid "" msgstr "" #: doc/classes/InputEventJoypadMotion.xml:16 -msgid "Axis identifier. Use one of the [enum JoystickList] axis constants." +msgid "Axis identifier. Use one of the [enum JoyAxisList] axis constants." msgstr "" #: doc/classes/InputEventJoypadMotion.xml:19 @@ -57946,8 +57945,7 @@ msgstr "" #: doc/classes/XRController3D.xml:65 msgid "" "Returns [code]true[/code] if the button at index [code]button[/code] is " -"pressed. See [enum JoystickList], in particular the [code]JOY_VR_*[/code] " -"constants." +"pressed. See [enum JoyButtonList]." msgstr "" #: doc/classes/XRController3D.xml:71 diff --git a/doc/translations/fr.po b/doc/translations/fr.po index d39ab5f248..57466d0b04 100644 --- a/doc/translations/fr.po +++ b/doc/translations/fr.po @@ -24666,7 +24666,7 @@ msgstr "" #: doc/classes/Input.xml:99 msgid "" "Returns the current value of the joypad axis at given index (see [enum " -"JoystickList])." +"JoyAxisList])." msgstr "" #: doc/classes/Input.xml:108 @@ -24675,7 +24675,7 @@ msgstr "" #: doc/classes/Input.xml:117 msgid "" -"Receives a [enum JoystickList] axis and returns its equivalent name as a " +"Receives a [enum JoyAxisList] axis and returns its equivalent name as a " "string." msgstr "" @@ -24685,7 +24685,7 @@ msgstr "" #: doc/classes/Input.xml:135 msgid "" -"Receives a gamepad button from [enum JoystickList] and returns its " +"Receives a gamepad button from [enum JoyButtonList] and returns its " "equivalent name as a string." msgstr "" @@ -24760,15 +24760,14 @@ msgstr "" #: doc/classes/Input.xml:238 msgid "" "Returns [code]true[/code] if you are pressing the joypad button (see [enum " -"JoystickList])." +"JoyButtonList])." msgstr "" #: doc/classes/Input.xml:247 msgid "" "Returns [code]true[/code] if the system knows the specified device. This " -"means that it sets all button and axis indices exactly as defined in [enum " -"JoystickList]. Unknown joypads are not expected to match these constants, " -"but you can still retrieve events from them." +"means that it sets all button and axis indices. Unknown joypads are not " +"expected to match these constants, but you can still retrieve events from them." msgstr "" #: doc/classes/Input.xml:256 @@ -25160,7 +25159,7 @@ msgid "" msgstr "" #: doc/classes/InputEventJoypadButton.xml:16 -msgid "Button identifier. One of the [enum JoystickList] button constants." +msgid "Button identifier. One of the [enum JoyButtonList] button constants." msgstr "" #: doc/classes/InputEventJoypadButton.xml:19 @@ -25188,7 +25187,7 @@ msgid "" msgstr "" #: doc/classes/InputEventJoypadMotion.xml:16 -msgid "Axis identifier. Use one of the [enum JoystickList] axis constants." +msgid "Axis identifier. Use one of the [enum JoyAxisList] axis constants." msgstr "" #: doc/classes/InputEventJoypadMotion.xml:19 @@ -57956,8 +57955,7 @@ msgstr "" #: doc/classes/XRController3D.xml:65 msgid "" "Returns [code]true[/code] if the button at index [code]button[/code] is " -"pressed. See [enum JoystickList], in particular the [code]JOY_VR_*[/code] " -"constants." +"pressed. See [enum JoyButtonList]." msgstr "" #: doc/classes/XRController3D.xml:71 diff --git a/drivers/alsa/audio_driver_alsa.cpp b/drivers/alsa/audio_driver_alsa.cpp index 48e694dd3a..e394222d3a 100644 --- a/drivers/alsa/audio_driver_alsa.cpp +++ b/drivers/alsa/audio_driver_alsa.cpp @@ -331,14 +331,4 @@ void AudioDriverALSA::finish() { finish_device(); } -AudioDriverALSA::AudioDriverALSA() : - thread(nullptr), - pcm_handle(nullptr), - device_name("Default"), - new_device("Default") { -} - -AudioDriverALSA::~AudioDriverALSA() { -} - -#endif +#endif // ALSA_ENABLED diff --git a/drivers/alsa/audio_driver_alsa.h b/drivers/alsa/audio_driver_alsa.h index 50bd9e853d..d437993901 100644 --- a/drivers/alsa/audio_driver_alsa.h +++ b/drivers/alsa/audio_driver_alsa.h @@ -41,13 +41,13 @@ class AudioDriverALSA : public AudioDriver { - Thread *thread; + Thread *thread = nullptr; Mutex mutex; - snd_pcm_t *pcm_handle; + snd_pcm_t *pcm_handle = nullptr; - String device_name; - String new_device; + String device_name = "Default"; + String new_device = "Default"; Vector<int32_t> samples_in; Vector<int16_t> samples_out; @@ -85,8 +85,8 @@ public: virtual void unlock(); virtual void finish(); - AudioDriverALSA(); - ~AudioDriverALSA(); + AudioDriverALSA() {} + ~AudioDriverALSA() {} }; #endif // AUDIO_DRIVER_ALSA_H diff --git a/drivers/coreaudio/audio_driver_coreaudio.cpp b/drivers/coreaudio/audio_driver_coreaudio.cpp index 21c3649445..76d2d13dfe 100644 --- a/drivers/coreaudio/audio_driver_coreaudio.cpp +++ b/drivers/coreaudio/audio_driver_coreaudio.cpp @@ -676,19 +676,8 @@ String AudioDriverCoreAudio::capture_get_device() { #endif -AudioDriverCoreAudio::AudioDriverCoreAudio() : - audio_unit(nullptr), - input_unit(nullptr), - active(false), - device_name("Default"), - capture_device_name("Default"), - mix_rate(0), - channels(2), - capture_channels(2), - buffer_frames(0) { +AudioDriverCoreAudio::AudioDriverCoreAudio() { samples_in.clear(); } -AudioDriverCoreAudio::~AudioDriverCoreAudio(){}; - -#endif +#endif // COREAUDIO_ENABLED diff --git a/drivers/coreaudio/audio_driver_coreaudio.h b/drivers/coreaudio/audio_driver_coreaudio.h index fb9473e230..89dd52181f 100644 --- a/drivers/coreaudio/audio_driver_coreaudio.h +++ b/drivers/coreaudio/audio_driver_coreaudio.h @@ -42,19 +42,19 @@ class AudioDriverCoreAudio : public AudioDriver { - AudioComponentInstance audio_unit; - AudioComponentInstance input_unit; + AudioComponentInstance audio_unit = nullptr; + AudioComponentInstance input_unit = nullptr; - bool active; + bool active = false; Mutex mutex; - String device_name; - String capture_device_name; + String device_name = "Default"; + String capture_device_name = "Default"; - int mix_rate; - unsigned int channels; - unsigned int capture_channels; - unsigned int buffer_frames; + int mix_rate = 0; + unsigned int channels = 2; + unsigned int capture_channels = 2; + unsigned int buffer_frames = 0; Vector<int32_t> samples_in; Vector<int16_t> input_buf; @@ -118,7 +118,7 @@ public: #endif AudioDriverCoreAudio(); - ~AudioDriverCoreAudio(); + ~AudioDriverCoreAudio() {} }; #endif diff --git a/drivers/coremidi/midi_driver_coremidi.cpp b/drivers/coremidi/midi_driver_coremidi.cpp index 2cd322813b..f155b4accc 100644 --- a/drivers/coremidi/midi_driver_coremidi.cpp +++ b/drivers/coremidi/midi_driver_coremidi.cpp @@ -112,9 +112,7 @@ PackedStringArray MIDIDriverCoreMidi::get_connected_inputs() { return list; } -MIDIDriverCoreMidi::MIDIDriverCoreMidi() : - client(0) { -} +MIDIDriverCoreMidi::MIDIDriverCoreMidi() {} MIDIDriverCoreMidi::~MIDIDriverCoreMidi() { close(); diff --git a/drivers/coremidi/midi_driver_coremidi.h b/drivers/coremidi/midi_driver_coremidi.h index e8b4481c20..68de7a11fb 100644 --- a/drivers/coremidi/midi_driver_coremidi.h +++ b/drivers/coremidi/midi_driver_coremidi.h @@ -41,7 +41,7 @@ class MIDIDriverCoreMidi : public MIDIDriver { - MIDIClientRef client; + MIDIClientRef client = 0; MIDIPortRef port_in; Vector<MIDIEndpointRef> connected_sources; diff --git a/drivers/dummy/rasterizer_dummy.h b/drivers/dummy/rasterizer_dummy.h index 0bcfed2dcf..05a45315b2 100644 --- a/drivers/dummy/rasterizer_dummy.h +++ b/drivers/dummy/rasterizer_dummy.h @@ -59,6 +59,7 @@ public: void sky_set_texture(RID p_sky, RID p_panorama) {} void sky_set_texture(RID p_sky, RID p_cube_map, int p_radiance_size) {} void sky_set_material(RID p_sky, RID p_material) {} + virtual Ref<Image> sky_bake_panorama(RID p_sky, float p_energy, bool p_bake_irradiance, const Size2i &p_size) { return Ref<Image>(); } /* ENVIRONMENT API */ @@ -78,10 +79,11 @@ public: #endif void environment_set_glow(RID p_env, bool p_enable, int p_level_flags, float p_intensity, float p_strength, float p_mix, float p_bloom_threshold, RS::EnvironmentGlowBlendMode p_blend_mode, float p_hdr_bleed_threshold, float p_hdr_bleed_scale, float p_hdr_luminance_cap) {} - + virtual void environment_glow_set_use_bicubic_upscale(bool p_enable) {} void environment_set_fog(RID p_env, bool p_enable, float p_begin, float p_end, RID p_gradient_texture) {} - void environment_set_ssr(RID p_env, bool p_enable, int p_max_steps, float p_fade_int, float p_fade_out, float p_depth_tolerance, bool p_roughness) {} + void environment_set_ssr(RID p_env, bool p_enable, int p_max_steps, float p_fade_int, float p_fade_out, float p_depth_tolerance) {} + virtual void environment_set_ssr_roughness_quality(RS::EnvironmentSSRRoughnessQuality p_quality) {} virtual void environment_set_ssao(RID p_env, bool p_enable, float p_radius, float p_intensity, float p_bias, float p_light_affect, float p_ao_channel_affect, RS::EnvironmentSSAOBlur p_blur, float p_bilateral_sharpness) {} virtual void environment_set_ssao_quality(RS::EnvironmentSSAOQuality p_quality, bool p_half_size) {} @@ -93,6 +95,8 @@ public: void environment_set_fog_depth(RID p_env, bool p_enable, float p_depth_begin, float p_depth_end, float p_depth_curve, bool p_transmit, float p_transmit_curve) {} void environment_set_fog_height(RID p_env, bool p_enable, float p_min_height, float p_max_height, float p_height_curve) {} + virtual Ref<Image> environment_bake_panorama(RID p_env, bool p_bake_irradiance, const Size2i &p_size) { return Ref<Image>(); } + bool is_environment(RID p_env) const { return false; } RS::EnvironmentBG environment_get_background(RID p_env) const { return RS::ENV_BG_KEEP; } int environment_get_canvas_max_layer(RID p_env) const { return 0; } @@ -105,9 +109,12 @@ public: virtual void camera_effects_set_dof_blur(RID p_camera_effects, bool p_far_enable, float p_far_distance, float p_far_transition, bool p_near_enable, float p_near_distance, float p_near_transition, float p_amount) {} virtual void camera_effects_set_custom_exposure(RID p_camera_effects, bool p_enable, float p_exposure) {} + virtual void shadows_quality_set(RS::ShadowQuality p_quality) {} + virtual void directional_shadow_quality_set(RS::ShadowQuality p_quality) {} + RID light_instance_create(RID p_light) { return RID(); } void light_instance_set_transform(RID p_light_instance, const Transform &p_transform) {} - void light_instance_set_shadow_transform(RID p_light_instance, const CameraMatrix &p_projection, const Transform &p_transform, float p_far, float p_split, int p_pass, float p_bias_scale = 1.0) {} + void light_instance_set_shadow_transform(RID p_light_instance, const CameraMatrix &p_projection, const Transform &p_transform, float p_far, float p_split, int p_pass, float p_shadow_texel_size, float p_bias_scale = 1.0, float p_range_begin = 0, const Vector2 &p_uv_scale = Vector2()) {} void light_instance_mark_visible(RID p_light_instance) {} RID reflection_atlas_create() { return RID(); } @@ -121,13 +128,16 @@ public: bool reflection_probe_instance_begin_render(RID p_instance, RID p_reflection_atlas) { return false; } bool reflection_probe_instance_postprocess_step(RID p_instance) { return true; } + virtual RID decal_instance_create(RID p_decal) { return RID(); } + virtual void decal_instance_set_transform(RID p_decal, const Transform &p_transform) {} + virtual RID gi_probe_instance_create(RID p_gi_probe) { return RID(); } void gi_probe_instance_set_light_data(RID p_probe, RID p_base, RID p_data) {} void gi_probe_instance_set_transform_to_data(RID p_probe, const Transform &p_xform) {} virtual bool gi_probe_needs_update(RID p_probe) const { return false; } virtual void gi_probe_update(RID p_probe, bool p_update_light_instances, const Vector<RID> &p_light_instances, int p_dynamic_object_count, InstanceBase **p_dynamic_objects) {} - virtual void render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass) {} + virtual void render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass) {} void render_shadow(RID p_light, RID p_shadow_atlas, int p_pass, InstanceBase **p_cull_result, int p_cull_count) {} virtual void render_material(const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) {} @@ -136,11 +146,16 @@ public: void set_debug_draw_mode(RS::ViewportDebugDraw p_debug_draw) {} virtual RID render_buffers_create() { return RID(); } - virtual void render_buffers_configure(RID p_render_buffers, RID p_render_target, int p_width, int p_height, RS::ViewportMSAA p_msaa) {} + virtual void render_buffers_configure(RID p_render_buffers, RID p_render_target, int p_width, int p_height, RS::ViewportMSAA p_msaa, RS::ViewportScreenSpaceAA p_screen_space_aa) {} virtual void screen_space_roughness_limiter_set_active(bool p_enable, float p_curve) {} virtual bool screen_space_roughness_limiter_is_active() const { return false; } + virtual void sub_surface_scattering_set_quality(RS::SubSurfaceScatteringQuality p_quality) {} + virtual void sub_surface_scattering_set_scale(float p_scale, float p_depth_scale) {} + + virtual TypedArray<Image> bake_render_uv2(RID p_base, const Vector<RID> &p_material_overrides, const Size2i &p_image_size) { return TypedArray<Image>(); } + bool free(RID p_rid) { return true; } virtual void update() {} @@ -192,7 +207,7 @@ public: virtual void texture_proxy_update(RID p_proxy, RID p_base) {} virtual RID texture_2d_placeholder_create() { return RID(); } - virtual RID texture_2d_layered_placeholder_create() { return RID(); } + virtual RID texture_2d_layered_placeholder_create(RenderingServer::TextureLayeredType p_layered_type) { return RID(); } virtual RID texture_3d_placeholder_create() { return RID(); } virtual Ref<Image> texture_2d_get(RID p_texture) const { return Ref<Image>(); } @@ -217,6 +232,9 @@ public: virtual void texture_set_force_redraw_if_visible(RID p_texture, bool p_enable) {} virtual Size2 texture_size_with_proxy(RID p_proxy) { return Size2(); } + virtual void texture_add_to_decal_atlas(RID p_texture, bool p_panorama_to_dp = false) {} + virtual void texture_remove_from_decal_atlas(RID p_texture, bool p_panorama_to_dp = false) {} + #if 0 RID texture_create() { @@ -345,6 +363,7 @@ public: bool material_is_animated(RID p_material) { return false; } bool material_casts_shadows(RID p_material) { return false; } + virtual void material_get_instance_shader_parameters(RID p_material, List<InstanceShaderParam> *r_parameters) {} void material_update_dependency(RID p_material, RasterizerScene::InstanceBase *p_instance) {} /* MESH API */ @@ -608,6 +627,21 @@ public: virtual void base_update_dependency(RID p_base, RasterizerScene::InstanceBase *p_instance) {} virtual void skeleton_update_dependency(RID p_base, RasterizerScene::InstanceBase *p_instance) {} + /* DECAL API */ + + virtual RID decal_create() { return RID(); } + virtual void decal_set_extents(RID p_decal, const Vector3 &p_extents) {} + virtual void decal_set_texture(RID p_decal, RS::DecalTexture p_type, RID p_texture) {} + virtual void decal_set_emission_energy(RID p_decal, float p_energy) {} + virtual void decal_set_albedo_mix(RID p_decal, float p_mix) {} + virtual void decal_set_modulate(RID p_decal, const Color &p_modulate) {} + virtual void decal_set_cull_mask(RID p_decal, uint32_t p_layers) {} + virtual void decal_set_distance_fade(RID p_decal, bool p_enabled, float p_begin, float p_length) {} + virtual void decal_set_fade(RID p_decal, float p_above, float p_below) {} + virtual void decal_set_normal_fade(RID p_decal, float p_fade) {} + + virtual AABB decal_get_aabb(RID p_decal) const { return AABB(); } + /* GI PROBE API */ RID gi_probe_create() { return RID(); } @@ -656,6 +690,7 @@ public: uint32_t gi_probe_get_version(RID p_gi_probe) { return 0; } /* LIGHTMAP CAPTURE */ +#if 0 struct Instantiable { SelfList<RasterizerScene::InstanceBase>::List instance_list; @@ -722,6 +757,23 @@ public: ERR_FAIL_COND_V(!capture, nullptr); return &capture->octree; } +#endif + + virtual RID lightmap_create() { return RID(); } + + virtual void lightmap_set_textures(RID p_lightmap, RID p_light, bool p_uses_spherical_haromics) {} + virtual void lightmap_set_probe_bounds(RID p_lightmap, const AABB &p_bounds) {} + virtual void lightmap_set_probe_interior(RID p_lightmap, bool p_interior) {} + virtual void lightmap_set_probe_capture_data(RID p_lightmap, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree) {} + virtual PackedVector3Array lightmap_get_probe_capture_points(RID p_lightmap) const { return PackedVector3Array(); } + virtual PackedColorArray lightmap_get_probe_capture_sh(RID p_lightmap) const { return PackedColorArray(); } + virtual PackedInt32Array lightmap_get_probe_capture_tetrahedra(RID p_lightmap) const { return PackedInt32Array(); } + virtual PackedInt32Array lightmap_get_probe_capture_bsp_tree(RID p_lightmap) const { return PackedInt32Array(); } + virtual AABB lightmap_get_aabb(RID p_lightmap) const { return AABB(); } + virtual void lightmap_tap_sh_light(RID p_lightmap, const Vector3 &p_point, Color *r_sh) {} + virtual bool lightmap_is_interior(RID p_lightmap) const { return false; } + virtual void lightmap_set_probe_capture_update_speed(float p_speed) {} + virtual float lightmap_get_probe_capture_update_speed() const { return 0; } /* PARTICLES */ @@ -757,6 +809,24 @@ public: int particles_get_draw_passes(RID p_particles) const { return 0; } RID particles_get_draw_pass_mesh(RID p_particles, int p_pass) const { return RID(); } + /* GLOBAL VARIABLES */ + + virtual void global_variable_add(const StringName &p_name, RS::GlobalVariableType p_type, const Variant &p_value) {} + virtual void global_variable_remove(const StringName &p_name) {} + virtual Vector<StringName> global_variable_get_list() const { return Vector<StringName>(); } + + virtual void global_variable_set(const StringName &p_name, const Variant &p_value) {} + virtual void global_variable_set_override(const StringName &p_name, const Variant &p_value) {} + virtual Variant global_variable_get(const StringName &p_name) const { return Variant(); } + virtual RS::GlobalVariableType global_variable_get_type(const StringName &p_name) const { return RS::GLOBAL_VAR_TYPE_MAX; } + + virtual void global_variables_load_settings(bool p_load_textures = true) {} + virtual void global_variables_clear() {} + + virtual int32_t global_variables_instance_allocate(RID p_instance) { return 0; } + virtual void global_variables_instance_free(RID p_instance) {} + virtual void global_variables_instance_update(RID p_instance, int p_index, const Variant &p_value) {} + virtual bool particles_is_inactive(RID p_particles) const { return false; } /* RENDER TARGET */ @@ -853,6 +923,10 @@ public: }; class RasterizerDummy : public Rasterizer { +private: + uint64_t frame = 1; + float delta = 0; + protected: RasterizerCanvasDummy canvas; RasterizerStorageDummy storage; @@ -866,12 +940,20 @@ public: void set_boot_image(const Ref<Image> &p_image, const Color &p_color, bool p_scale, bool p_use_filter = true) {} void initialize() {} - void begin_frame(double frame_step) {} + void begin_frame(double frame_step) { + frame++; + delta = frame_step; + } virtual void prepare_for_blitting_render_targets() {} virtual void blit_render_targets_to_screen(int p_screen, const BlitToScreen *p_render_targets, int p_amount) {} - void end_frame(bool p_swap_buffers) { OS::get_singleton()->swap_buffers(); } + void end_frame(bool p_swap_buffers) { + if (p_swap_buffers) { + DisplayServer::get_singleton()->swap_buffers(); + } + } + void finalize() {} static Error is_viable() { @@ -887,6 +969,8 @@ public: } virtual bool is_low_end() const { return true; } + virtual uint64_t get_frame_number() const { return frame; } + virtual float get_frame_delta_time() const { return delta; } RasterizerDummy() {} ~RasterizerDummy() {} diff --git a/drivers/gles2/shader_compiler_gles2.cpp b/drivers/gles2/shader_compiler_gles2.cpp index 92c1ada850..f6a764c26c 100644 --- a/drivers/gles2/shader_compiler_gles2.cpp +++ b/drivers/gles2/shader_compiler_gles2.cpp @@ -55,10 +55,14 @@ static String _typestr(SL::DataType p_type) { static String _prestr(SL::DataPrecision p_pres) { switch (p_pres) { - case SL::PRECISION_LOWP: return "lowp "; - case SL::PRECISION_MEDIUMP: return "mediump "; - case SL::PRECISION_HIGHP: return "highp "; - case SL::PRECISION_DEFAULT: return ""; + case SL::PRECISION_LOWP: + return "lowp "; + case SL::PRECISION_MEDIUMP: + return "mediump "; + case SL::PRECISION_HIGHP: + return "highp "; + case SL::PRECISION_DEFAULT: + return ""; } return ""; } @@ -66,9 +70,12 @@ static String _prestr(SL::DataPrecision p_pres) { static String _qualstr(SL::ArgumentQualifier p_qual) { switch (p_qual) { - case SL::ARGUMENT_QUALIFIER_IN: return "in "; - case SL::ARGUMENT_QUALIFIER_OUT: return "out "; - case SL::ARGUMENT_QUALIFIER_INOUT: return "inout "; + case SL::ARGUMENT_QUALIFIER_IN: + return "in "; + case SL::ARGUMENT_QUALIFIER_OUT: + return "out "; + case SL::ARGUMENT_QUALIFIER_INOUT: + return "inout "; } return ""; } @@ -96,7 +103,8 @@ static String f2sp0(float p_float) { static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNode::Value> &p_values) { switch (p_type) { - case SL::TYPE_BOOL: return p_values[0].boolean ? "true" : "false"; + case SL::TYPE_BOOL: + return p_values[0].boolean ? "true" : "false"; case SL::TYPE_BVEC2: case SL::TYPE_BVEC3: case SL::TYPE_BVEC4: { @@ -118,7 +126,8 @@ static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNo } // GLSL ES 2 doesn't support uints, so we just use signed ints instead... - case SL::TYPE_UINT: return itos(p_values[0].uint); + case SL::TYPE_UINT: + return itos(p_values[0].uint); case SL::TYPE_UVEC2: case SL::TYPE_UVEC3: case SL::TYPE_UVEC4: { @@ -140,7 +149,8 @@ static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNo } break; - case SL::TYPE_INT: return itos(p_values[0].sint); + case SL::TYPE_INT: + return itos(p_values[0].sint); case SL::TYPE_IVEC2: case SL::TYPE_IVEC3: case SL::TYPE_IVEC4: { @@ -161,7 +171,8 @@ static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNo return text.as_string(); } break; - case SL::TYPE_FLOAT: return f2sp0(p_values[0].real); + case SL::TYPE_FLOAT: + return f2sp0(p_values[0].real); case SL::TYPE_VEC2: case SL::TYPE_VEC3: case SL::TYPE_VEC4: { @@ -202,7 +213,8 @@ static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNo return text.as_string(); } break; - default: ERR_FAIL_V(String()); + default: + ERR_FAIL_V(String()); } } diff --git a/drivers/gles2/shaders/scene.glsl b/drivers/gles2/shaders/scene.glsl index 84aadcbbc3..b720c71cec 100644 --- a/drivers/gles2/shaders/scene.glsl +++ b/drivers/gles2/shaders/scene.glsl @@ -1128,7 +1128,8 @@ float SchlickFresnel(float u) { } float GTR1(float NdotH, float a) { - if (a >= 1.0) return 1.0 / M_PI; + if (a >= 1.0) + return 1.0 / M_PI; float a2 = a * a; float t = 1.0 + (a2 - 1.0) * NdotH * NdotH; return (a2 - 1.0) / (M_PI * log(a2) * t); diff --git a/drivers/png/png_driver_common.cpp b/drivers/png/png_driver_common.cpp index f17abcb54c..3f9c824e93 100644 --- a/drivers/png/png_driver_common.cpp +++ b/drivers/png/png_driver_common.cpp @@ -115,7 +115,7 @@ Error png_to_image(const uint8_t *p_source, size_t p_size, Ref<Image> p_image) { ERR_FAIL_COND_V(!success, ERR_FILE_CORRUPT); //print_line("png width: "+itos(png_img.width)+" height: "+itos(png_img.height)); - p_image->create(png_img.width, png_img.height, 0, dest_format, buffer); + p_image->create(png_img.width, png_img.height, false, dest_format, buffer); return OK; } diff --git a/drivers/pulseaudio/audio_driver_pulseaudio.cpp b/drivers/pulseaudio/audio_driver_pulseaudio.cpp index 8a47f6cf96..b16408f727 100644 --- a/drivers/pulseaudio/audio_driver_pulseaudio.cpp +++ b/drivers/pulseaudio/audio_driver_pulseaudio.cpp @@ -793,30 +793,9 @@ String AudioDriverPulseAudio::capture_get_device() { return name; } -AudioDriverPulseAudio::AudioDriverPulseAudio() : - thread(nullptr), - pa_ml(nullptr), - pa_ctx(nullptr), - pa_str(nullptr), - pa_rec_str(nullptr), - device_name("Default"), - new_device("Default"), - default_device(""), - mix_rate(0), - buffer_frames(0), - pa_buffer_size(0), - channels(0), - pa_ready(0), - pa_status(0), - active(false), - thread_exited(false), - exit_thread(false), - latency(0) { +AudioDriverPulseAudio::AudioDriverPulseAudio() { samples_in.clear(); samples_out.clear(); } -AudioDriverPulseAudio::~AudioDriverPulseAudio() { -} - -#endif +#endif // PULSEAUDIO_ENABLED diff --git a/drivers/pulseaudio/audio_driver_pulseaudio.h b/drivers/pulseaudio/audio_driver_pulseaudio.h index 1ece332a8a..ab55a15076 100644 --- a/drivers/pulseaudio/audio_driver_pulseaudio.h +++ b/drivers/pulseaudio/audio_driver_pulseaudio.h @@ -41,18 +41,18 @@ class AudioDriverPulseAudio : public AudioDriver { - Thread *thread; + Thread *thread = nullptr; Mutex mutex; - pa_mainloop *pa_ml; - pa_context *pa_ctx; - pa_stream *pa_str; - pa_stream *pa_rec_str; + pa_mainloop *pa_ml = nullptr; + pa_context *pa_ctx = nullptr; + pa_stream *pa_str = nullptr; + pa_stream *pa_rec_str = nullptr; pa_channel_map pa_map; pa_channel_map pa_rec_map; - String device_name; - String new_device; + String device_name = "Default"; + String new_device = "Default"; String default_device; String capture_device_name; @@ -62,20 +62,20 @@ class AudioDriverPulseAudio : public AudioDriver { Vector<int32_t> samples_in; Vector<int16_t> samples_out; - unsigned int mix_rate; - unsigned int buffer_frames; - unsigned int pa_buffer_size; - int channels; - int pa_ready; - int pa_status; + unsigned int mix_rate = 0; + unsigned int buffer_frames = 0; + unsigned int pa_buffer_size = 0; + int channels = 0; + int pa_ready = 0; + int pa_status = 0; Array pa_devices; Array pa_rec_devices; - bool active; - bool thread_exited; - mutable bool exit_thread; + bool active = false; + bool thread_exited = false; + mutable bool exit_thread = false; - float latency; + float latency = 0; static void pa_state_cb(pa_context *c, void *userdata); static void pa_sink_info_cb(pa_context *c, const pa_sink_info *l, int eol, void *userdata); @@ -122,7 +122,7 @@ public: virtual Error capture_stop(); AudioDriverPulseAudio(); - ~AudioDriverPulseAudio(); + ~AudioDriverPulseAudio() {} }; #endif // AUDIO_DRIVER_PULSEAUDIO_H diff --git a/drivers/unix/file_access_unix.cpp b/drivers/unix/file_access_unix.cpp index 4aa408a1f0..54a585c6fd 100644 --- a/drivers/unix/file_access_unix.cpp +++ b/drivers/unix/file_access_unix.cpp @@ -354,14 +354,7 @@ FileAccess *FileAccessUnix::create_libc() { CloseNotificationFunc FileAccessUnix::close_notification_func = nullptr; -FileAccessUnix::FileAccessUnix() : - f(nullptr), - flags(0), - last_error(OK) { -} - FileAccessUnix::~FileAccessUnix() { - close(); } diff --git a/drivers/unix/file_access_unix.h b/drivers/unix/file_access_unix.h index 8116f72345..1dd080914d 100644 --- a/drivers/unix/file_access_unix.h +++ b/drivers/unix/file_access_unix.h @@ -42,10 +42,10 @@ typedef void (*CloseNotificationFunc)(const String &p_file, int p_flags); class FileAccessUnix : public FileAccess { - FILE *f; - int flags; + FILE *f = nullptr; + int flags = 0; void check_errors() const; - mutable Error last_error; + mutable Error last_error = OK; String save_path; String path; String path_src; @@ -84,7 +84,7 @@ public: virtual uint32_t _get_unix_permissions(const String &p_file); virtual Error _set_unix_permissions(const String &p_file, uint32_t p_permissions); - FileAccessUnix(); + FileAccessUnix() {} virtual ~FileAccessUnix(); }; diff --git a/drivers/unix/ip_unix.cpp b/drivers/unix/ip_unix.cpp index 5e3dedfc2f..56be9a2f74 100644 --- a/drivers/unix/ip_unix.cpp +++ b/drivers/unix/ip_unix.cpp @@ -248,7 +248,8 @@ void IP_Unix::get_local_interfaces(Map<String, Interface_Info> *r_interfaces) co info.ip_addresses.push_front(_sockaddr2ip(ifa->ifa_addr)); } - if (ifAddrStruct != nullptr) freeifaddrs(ifAddrStruct); + if (ifAddrStruct != nullptr) + freeifaddrs(ifAddrStruct); } #endif diff --git a/drivers/unix/net_socket_posix.cpp b/drivers/unix/net_socket_posix.cpp index 7c6543c3a2..81ea20e5da 100644 --- a/drivers/unix/net_socket_posix.cpp +++ b/drivers/unix/net_socket_posix.cpp @@ -173,9 +173,7 @@ void NetSocketPosix::cleanup() { } NetSocketPosix::NetSocketPosix() : - _sock(SOCK_EMPTY), - _ip_type(IP::TYPE_NONE), - _is_stream(false) { + _sock(SOCK_EMPTY) { } NetSocketPosix::~NetSocketPosix() { diff --git a/drivers/unix/net_socket_posix.h b/drivers/unix/net_socket_posix.h index 0a19967265..4e1fedfcb0 100644 --- a/drivers/unix/net_socket_posix.h +++ b/drivers/unix/net_socket_posix.h @@ -47,9 +47,9 @@ class NetSocketPosix : public NetSocket { private: - SOCKET_TYPE _sock; - IP::Type _ip_type; - bool _is_stream; + SOCKET_TYPE _sock; // NOLINT - the default value is defined in the .cpp + IP::Type _ip_type = IP::TYPE_NONE; + bool _is_stream = false; enum NetError { ERR_NET_WOULD_BLOCK, diff --git a/drivers/unix/syslog_logger.cpp b/drivers/unix/syslog_logger.cpp index dc9112bf14..8296d6ce30 100644 --- a/drivers/unix/syslog_logger.cpp +++ b/drivers/unix/syslog_logger.cpp @@ -49,11 +49,21 @@ void SyslogLogger::print_error(const char *p_function, const char *p_file, int p const char *err_type = "**ERROR**"; switch (p_type) { - case ERR_ERROR: err_type = "**ERROR**"; break; - case ERR_WARNING: err_type = "**WARNING**"; break; - case ERR_SCRIPT: err_type = "**SCRIPT ERROR**"; break; - case ERR_SHADER: err_type = "**SHADER ERROR**"; break; - default: ERR_PRINT("Unknown error type"); break; + case ERR_ERROR: + err_type = "**ERROR**"; + break; + case ERR_WARNING: + err_type = "**WARNING**"; + break; + case ERR_SCRIPT: + err_type = "**SCRIPT ERROR**"; + break; + case ERR_SHADER: + err_type = "**SHADER ERROR**"; + break; + default: + ERR_PRINT("Unknown error type"); + break; } const char *err_details; diff --git a/drivers/vulkan/rendering_device_vulkan.cpp b/drivers/vulkan/rendering_device_vulkan.cpp index 23e9227a39..71be891b1d 100644 --- a/drivers/vulkan/rendering_device_vulkan.cpp +++ b/drivers/vulkan/rendering_device_vulkan.cpp @@ -566,52 +566,66 @@ int RenderingDeviceVulkan::get_format_vertex_size(DataFormat p_format) { case DATA_FORMAT_B8G8R8A8_UNORM: case DATA_FORMAT_B8G8R8A8_SNORM: case DATA_FORMAT_B8G8R8A8_UINT: - case DATA_FORMAT_B8G8R8A8_SINT: return 4; + case DATA_FORMAT_B8G8R8A8_SINT: + return 4; case DATA_FORMAT_R16_UNORM: case DATA_FORMAT_R16_SNORM: case DATA_FORMAT_R16_UINT: case DATA_FORMAT_R16_SINT: - case DATA_FORMAT_R16_SFLOAT: return 4; + case DATA_FORMAT_R16_SFLOAT: + return 4; case DATA_FORMAT_R16G16_UNORM: case DATA_FORMAT_R16G16_SNORM: case DATA_FORMAT_R16G16_UINT: case DATA_FORMAT_R16G16_SINT: - case DATA_FORMAT_R16G16_SFLOAT: return 4; + case DATA_FORMAT_R16G16_SFLOAT: + return 4; case DATA_FORMAT_R16G16B16_UNORM: case DATA_FORMAT_R16G16B16_SNORM: case DATA_FORMAT_R16G16B16_UINT: case DATA_FORMAT_R16G16B16_SINT: - case DATA_FORMAT_R16G16B16_SFLOAT: return 8; + case DATA_FORMAT_R16G16B16_SFLOAT: + return 8; case DATA_FORMAT_R16G16B16A16_UNORM: case DATA_FORMAT_R16G16B16A16_SNORM: case DATA_FORMAT_R16G16B16A16_UINT: case DATA_FORMAT_R16G16B16A16_SINT: - case DATA_FORMAT_R16G16B16A16_SFLOAT: return 8; + case DATA_FORMAT_R16G16B16A16_SFLOAT: + return 8; case DATA_FORMAT_R32_UINT: case DATA_FORMAT_R32_SINT: - case DATA_FORMAT_R32_SFLOAT: return 4; + case DATA_FORMAT_R32_SFLOAT: + return 4; case DATA_FORMAT_R32G32_UINT: case DATA_FORMAT_R32G32_SINT: - case DATA_FORMAT_R32G32_SFLOAT: return 8; + case DATA_FORMAT_R32G32_SFLOAT: + return 8; case DATA_FORMAT_R32G32B32_UINT: case DATA_FORMAT_R32G32B32_SINT: - case DATA_FORMAT_R32G32B32_SFLOAT: return 12; + case DATA_FORMAT_R32G32B32_SFLOAT: + return 12; case DATA_FORMAT_R32G32B32A32_UINT: case DATA_FORMAT_R32G32B32A32_SINT: - case DATA_FORMAT_R32G32B32A32_SFLOAT: return 16; + case DATA_FORMAT_R32G32B32A32_SFLOAT: + return 16; case DATA_FORMAT_R64_UINT: case DATA_FORMAT_R64_SINT: - case DATA_FORMAT_R64_SFLOAT: return 8; + case DATA_FORMAT_R64_SFLOAT: + return 8; case DATA_FORMAT_R64G64_UINT: case DATA_FORMAT_R64G64_SINT: - case DATA_FORMAT_R64G64_SFLOAT: return 16; + case DATA_FORMAT_R64G64_SFLOAT: + return 16; case DATA_FORMAT_R64G64B64_UINT: case DATA_FORMAT_R64G64B64_SINT: - case DATA_FORMAT_R64G64B64_SFLOAT: return 24; + case DATA_FORMAT_R64G64B64_SFLOAT: + return 24; case DATA_FORMAT_R64G64B64A64_UINT: case DATA_FORMAT_R64G64B64A64_SINT: - case DATA_FORMAT_R64G64B64A64_SFLOAT: return 32; - default: return 0; + case DATA_FORMAT_R64G64B64A64_SFLOAT: + return 32; + default: + return 0; } } @@ -619,28 +633,32 @@ uint32_t RenderingDeviceVulkan::get_image_format_pixel_size(DataFormat p_format) switch (p_format) { - case DATA_FORMAT_R4G4_UNORM_PACK8: return 1; + case DATA_FORMAT_R4G4_UNORM_PACK8: + return 1; case DATA_FORMAT_R4G4B4A4_UNORM_PACK16: case DATA_FORMAT_B4G4R4A4_UNORM_PACK16: case DATA_FORMAT_R5G6B5_UNORM_PACK16: case DATA_FORMAT_B5G6R5_UNORM_PACK16: case DATA_FORMAT_R5G5B5A1_UNORM_PACK16: case DATA_FORMAT_B5G5R5A1_UNORM_PACK16: - case DATA_FORMAT_A1R5G5B5_UNORM_PACK16: return 2; + case DATA_FORMAT_A1R5G5B5_UNORM_PACK16: + return 2; case DATA_FORMAT_R8_UNORM: case DATA_FORMAT_R8_SNORM: case DATA_FORMAT_R8_USCALED: case DATA_FORMAT_R8_SSCALED: case DATA_FORMAT_R8_UINT: case DATA_FORMAT_R8_SINT: - case DATA_FORMAT_R8_SRGB: return 1; + case DATA_FORMAT_R8_SRGB: + return 1; case DATA_FORMAT_R8G8_UNORM: case DATA_FORMAT_R8G8_SNORM: case DATA_FORMAT_R8G8_USCALED: case DATA_FORMAT_R8G8_SSCALED: case DATA_FORMAT_R8G8_UINT: case DATA_FORMAT_R8G8_SINT: - case DATA_FORMAT_R8G8_SRGB: return 2; + case DATA_FORMAT_R8G8_SRGB: + return 2; case DATA_FORMAT_R8G8B8_UNORM: case DATA_FORMAT_R8G8B8_SNORM: case DATA_FORMAT_R8G8B8_USCALED: @@ -654,7 +672,8 @@ uint32_t RenderingDeviceVulkan::get_image_format_pixel_size(DataFormat p_format) case DATA_FORMAT_B8G8R8_SSCALED: case DATA_FORMAT_B8G8R8_UINT: case DATA_FORMAT_B8G8R8_SINT: - case DATA_FORMAT_B8G8R8_SRGB: return 3; + case DATA_FORMAT_B8G8R8_SRGB: + return 3; case DATA_FORMAT_R8G8B8A8_UNORM: case DATA_FORMAT_R8G8B8A8_SNORM: case DATA_FORMAT_R8G8B8A8_USCALED: @@ -668,7 +687,8 @@ uint32_t RenderingDeviceVulkan::get_image_format_pixel_size(DataFormat p_format) case DATA_FORMAT_B8G8R8A8_SSCALED: case DATA_FORMAT_B8G8R8A8_UINT: case DATA_FORMAT_B8G8R8A8_SINT: - case DATA_FORMAT_B8G8R8A8_SRGB: return 4; + case DATA_FORMAT_B8G8R8A8_SRGB: + return 4; case DATA_FORMAT_A8B8G8R8_UNORM_PACK32: case DATA_FORMAT_A8B8G8R8_SNORM_PACK32: case DATA_FORMAT_A8B8G8R8_USCALED_PACK32: @@ -687,67 +707,87 @@ uint32_t RenderingDeviceVulkan::get_image_format_pixel_size(DataFormat p_format) case DATA_FORMAT_A2B10G10R10_USCALED_PACK32: case DATA_FORMAT_A2B10G10R10_SSCALED_PACK32: case DATA_FORMAT_A2B10G10R10_UINT_PACK32: - case DATA_FORMAT_A2B10G10R10_SINT_PACK32: return 4; + case DATA_FORMAT_A2B10G10R10_SINT_PACK32: + return 4; case DATA_FORMAT_R16_UNORM: case DATA_FORMAT_R16_SNORM: case DATA_FORMAT_R16_USCALED: case DATA_FORMAT_R16_SSCALED: case DATA_FORMAT_R16_UINT: case DATA_FORMAT_R16_SINT: - case DATA_FORMAT_R16_SFLOAT: return 2; + case DATA_FORMAT_R16_SFLOAT: + return 2; case DATA_FORMAT_R16G16_UNORM: case DATA_FORMAT_R16G16_SNORM: case DATA_FORMAT_R16G16_USCALED: case DATA_FORMAT_R16G16_SSCALED: case DATA_FORMAT_R16G16_UINT: case DATA_FORMAT_R16G16_SINT: - case DATA_FORMAT_R16G16_SFLOAT: return 4; + case DATA_FORMAT_R16G16_SFLOAT: + return 4; case DATA_FORMAT_R16G16B16_UNORM: case DATA_FORMAT_R16G16B16_SNORM: case DATA_FORMAT_R16G16B16_USCALED: case DATA_FORMAT_R16G16B16_SSCALED: case DATA_FORMAT_R16G16B16_UINT: case DATA_FORMAT_R16G16B16_SINT: - case DATA_FORMAT_R16G16B16_SFLOAT: return 6; + case DATA_FORMAT_R16G16B16_SFLOAT: + return 6; case DATA_FORMAT_R16G16B16A16_UNORM: case DATA_FORMAT_R16G16B16A16_SNORM: case DATA_FORMAT_R16G16B16A16_USCALED: case DATA_FORMAT_R16G16B16A16_SSCALED: case DATA_FORMAT_R16G16B16A16_UINT: case DATA_FORMAT_R16G16B16A16_SINT: - case DATA_FORMAT_R16G16B16A16_SFLOAT: return 8; + case DATA_FORMAT_R16G16B16A16_SFLOAT: + return 8; case DATA_FORMAT_R32_UINT: case DATA_FORMAT_R32_SINT: - case DATA_FORMAT_R32_SFLOAT: return 4; + case DATA_FORMAT_R32_SFLOAT: + return 4; case DATA_FORMAT_R32G32_UINT: case DATA_FORMAT_R32G32_SINT: - case DATA_FORMAT_R32G32_SFLOAT: return 8; + case DATA_FORMAT_R32G32_SFLOAT: + return 8; case DATA_FORMAT_R32G32B32_UINT: case DATA_FORMAT_R32G32B32_SINT: - case DATA_FORMAT_R32G32B32_SFLOAT: return 12; + case DATA_FORMAT_R32G32B32_SFLOAT: + return 12; case DATA_FORMAT_R32G32B32A32_UINT: case DATA_FORMAT_R32G32B32A32_SINT: - case DATA_FORMAT_R32G32B32A32_SFLOAT: return 16; + case DATA_FORMAT_R32G32B32A32_SFLOAT: + return 16; case DATA_FORMAT_R64_UINT: case DATA_FORMAT_R64_SINT: - case DATA_FORMAT_R64_SFLOAT: return 8; + case DATA_FORMAT_R64_SFLOAT: + return 8; case DATA_FORMAT_R64G64_UINT: case DATA_FORMAT_R64G64_SINT: - case DATA_FORMAT_R64G64_SFLOAT: return 16; + case DATA_FORMAT_R64G64_SFLOAT: + return 16; case DATA_FORMAT_R64G64B64_UINT: case DATA_FORMAT_R64G64B64_SINT: - case DATA_FORMAT_R64G64B64_SFLOAT: return 24; + case DATA_FORMAT_R64G64B64_SFLOAT: + return 24; case DATA_FORMAT_R64G64B64A64_UINT: case DATA_FORMAT_R64G64B64A64_SINT: - case DATA_FORMAT_R64G64B64A64_SFLOAT: return 32; + case DATA_FORMAT_R64G64B64A64_SFLOAT: + return 32; case DATA_FORMAT_B10G11R11_UFLOAT_PACK32: - case DATA_FORMAT_E5B9G9R9_UFLOAT_PACK32: return 4; - case DATA_FORMAT_D16_UNORM: return 2; - case DATA_FORMAT_X8_D24_UNORM_PACK32: return 4; - case DATA_FORMAT_D32_SFLOAT: return 4; - case DATA_FORMAT_S8_UINT: return 1; - case DATA_FORMAT_D16_UNORM_S8_UINT: return 4; - case DATA_FORMAT_D24_UNORM_S8_UINT: return 4; + case DATA_FORMAT_E5B9G9R9_UFLOAT_PACK32: + return 4; + case DATA_FORMAT_D16_UNORM: + return 2; + case DATA_FORMAT_X8_D24_UNORM_PACK32: + return 4; + case DATA_FORMAT_D32_SFLOAT: + return 4; + case DATA_FORMAT_S8_UINT: + return 1; + case DATA_FORMAT_D16_UNORM_S8_UINT: + return 4; + case DATA_FORMAT_D24_UNORM_S8_UINT: + return 4; case DATA_FORMAT_D32_SFLOAT_S8_UINT: return 5; //? case DATA_FORMAT_BC1_RGB_UNORM_BLOCK: @@ -765,17 +805,20 @@ uint32_t RenderingDeviceVulkan::get_image_format_pixel_size(DataFormat p_format) case DATA_FORMAT_BC6H_UFLOAT_BLOCK: case DATA_FORMAT_BC6H_SFLOAT_BLOCK: case DATA_FORMAT_BC7_UNORM_BLOCK: - case DATA_FORMAT_BC7_SRGB_BLOCK: return 1; + case DATA_FORMAT_BC7_SRGB_BLOCK: + return 1; case DATA_FORMAT_ETC2_R8G8B8_UNORM_BLOCK: case DATA_FORMAT_ETC2_R8G8B8_SRGB_BLOCK: case DATA_FORMAT_ETC2_R8G8B8A1_UNORM_BLOCK: case DATA_FORMAT_ETC2_R8G8B8A1_SRGB_BLOCK: case DATA_FORMAT_ETC2_R8G8B8A8_UNORM_BLOCK: - case DATA_FORMAT_ETC2_R8G8B8A8_SRGB_BLOCK: return 1; + case DATA_FORMAT_ETC2_R8G8B8A8_SRGB_BLOCK: + return 1; case DATA_FORMAT_EAC_R11_UNORM_BLOCK: case DATA_FORMAT_EAC_R11_SNORM_BLOCK: case DATA_FORMAT_EAC_R11G11_UNORM_BLOCK: - case DATA_FORMAT_EAC_R11G11_SNORM_BLOCK: return 1; + case DATA_FORMAT_EAC_R11G11_SNORM_BLOCK: + return 1; case DATA_FORMAT_ASTC_4x4_UNORM_BLOCK: case DATA_FORMAT_ASTC_4x4_SRGB_BLOCK: case DATA_FORMAT_ASTC_5x4_UNORM_BLOCK: @@ -803,14 +846,17 @@ uint32_t RenderingDeviceVulkan::get_image_format_pixel_size(DataFormat p_format) case DATA_FORMAT_ASTC_12x10_UNORM_BLOCK: case DATA_FORMAT_ASTC_12x10_SRGB_BLOCK: case DATA_FORMAT_ASTC_12x12_UNORM_BLOCK: - case DATA_FORMAT_ASTC_12x12_SRGB_BLOCK: return 1; + case DATA_FORMAT_ASTC_12x12_SRGB_BLOCK: + return 1; case DATA_FORMAT_G8B8G8R8_422_UNORM: - case DATA_FORMAT_B8G8R8G8_422_UNORM: return 4; + case DATA_FORMAT_B8G8R8G8_422_UNORM: + return 4; case DATA_FORMAT_G8_B8_R8_3PLANE_420_UNORM: case DATA_FORMAT_G8_B8R8_2PLANE_420_UNORM: case DATA_FORMAT_G8_B8_R8_3PLANE_422_UNORM: case DATA_FORMAT_G8_B8R8_2PLANE_422_UNORM: - case DATA_FORMAT_G8_B8_R8_3PLANE_444_UNORM: return 4; + case DATA_FORMAT_G8_B8_R8_3PLANE_444_UNORM: + return 4; case DATA_FORMAT_R10X6_UNORM_PACK16: case DATA_FORMAT_R10X6G10X6_UNORM_2PACK16: case DATA_FORMAT_R10X6G10X6B10X6A10X6_UNORM_4PACK16: @@ -830,14 +876,16 @@ uint32_t RenderingDeviceVulkan::get_image_format_pixel_size(DataFormat p_format) case DATA_FORMAT_G12X4_B12X4R12X4_2PLANE_420_UNORM_3PACK16: case DATA_FORMAT_G12X4_B12X4_R12X4_3PLANE_422_UNORM_3PACK16: case DATA_FORMAT_G12X4_B12X4R12X4_2PLANE_422_UNORM_3PACK16: - case DATA_FORMAT_G12X4_B12X4_R12X4_3PLANE_444_UNORM_3PACK16: return 2; + case DATA_FORMAT_G12X4_B12X4_R12X4_3PLANE_444_UNORM_3PACK16: + return 2; case DATA_FORMAT_G16B16G16R16_422_UNORM: case DATA_FORMAT_B16G16R16G16_422_UNORM: case DATA_FORMAT_G16_B16_R16_3PLANE_420_UNORM: case DATA_FORMAT_G16_B16R16_2PLANE_420_UNORM: case DATA_FORMAT_G16_B16_R16_3PLANE_422_UNORM: case DATA_FORMAT_G16_B16R16_2PLANE_422_UNORM: - case DATA_FORMAT_G16_B16_R16_3PLANE_444_UNORM: return 8; + case DATA_FORMAT_G16_B16_R16_3PLANE_444_UNORM: + return 8; case DATA_FORMAT_PVRTC1_2BPP_UNORM_BLOCK_IMG: case DATA_FORMAT_PVRTC1_4BPP_UNORM_BLOCK_IMG: case DATA_FORMAT_PVRTC2_2BPP_UNORM_BLOCK_IMG: @@ -845,7 +893,8 @@ uint32_t RenderingDeviceVulkan::get_image_format_pixel_size(DataFormat p_format) case DATA_FORMAT_PVRTC1_2BPP_SRGB_BLOCK_IMG: case DATA_FORMAT_PVRTC1_4BPP_SRGB_BLOCK_IMG: case DATA_FORMAT_PVRTC2_2BPP_SRGB_BLOCK_IMG: - case DATA_FORMAT_PVRTC2_4BPP_SRGB_BLOCK_IMG: return 1; + case DATA_FORMAT_PVRTC2_4BPP_SRGB_BLOCK_IMG: + return 1; default: { ERR_PRINT("Format not handled, bug"); } @@ -943,29 +992,41 @@ uint32_t RenderingDeviceVulkan::get_compressed_image_format_block_byte_size(Data case DATA_FORMAT_BC1_RGB_UNORM_BLOCK: case DATA_FORMAT_BC1_RGB_SRGB_BLOCK: case DATA_FORMAT_BC1_RGBA_UNORM_BLOCK: - case DATA_FORMAT_BC1_RGBA_SRGB_BLOCK: return 8; + case DATA_FORMAT_BC1_RGBA_SRGB_BLOCK: + return 8; case DATA_FORMAT_BC2_UNORM_BLOCK: - case DATA_FORMAT_BC2_SRGB_BLOCK: return 16; + case DATA_FORMAT_BC2_SRGB_BLOCK: + return 16; case DATA_FORMAT_BC3_UNORM_BLOCK: - case DATA_FORMAT_BC3_SRGB_BLOCK: return 16; + case DATA_FORMAT_BC3_SRGB_BLOCK: + return 16; case DATA_FORMAT_BC4_UNORM_BLOCK: - case DATA_FORMAT_BC4_SNORM_BLOCK: return 8; + case DATA_FORMAT_BC4_SNORM_BLOCK: + return 8; case DATA_FORMAT_BC5_UNORM_BLOCK: - case DATA_FORMAT_BC5_SNORM_BLOCK: return 16; + case DATA_FORMAT_BC5_SNORM_BLOCK: + return 16; case DATA_FORMAT_BC6H_UFLOAT_BLOCK: - case DATA_FORMAT_BC6H_SFLOAT_BLOCK: return 16; + case DATA_FORMAT_BC6H_SFLOAT_BLOCK: + return 16; case DATA_FORMAT_BC7_UNORM_BLOCK: - case DATA_FORMAT_BC7_SRGB_BLOCK: return 16; + case DATA_FORMAT_BC7_SRGB_BLOCK: + return 16; case DATA_FORMAT_ETC2_R8G8B8_UNORM_BLOCK: - case DATA_FORMAT_ETC2_R8G8B8_SRGB_BLOCK: return 8; + case DATA_FORMAT_ETC2_R8G8B8_SRGB_BLOCK: + return 8; case DATA_FORMAT_ETC2_R8G8B8A1_UNORM_BLOCK: - case DATA_FORMAT_ETC2_R8G8B8A1_SRGB_BLOCK: return 8; + case DATA_FORMAT_ETC2_R8G8B8A1_SRGB_BLOCK: + return 8; case DATA_FORMAT_ETC2_R8G8B8A8_UNORM_BLOCK: - case DATA_FORMAT_ETC2_R8G8B8A8_SRGB_BLOCK: return 16; + case DATA_FORMAT_ETC2_R8G8B8A8_SRGB_BLOCK: + return 16; case DATA_FORMAT_EAC_R11_UNORM_BLOCK: - case DATA_FORMAT_EAC_R11_SNORM_BLOCK: return 8; + case DATA_FORMAT_EAC_R11_SNORM_BLOCK: + return 8; case DATA_FORMAT_EAC_R11G11_UNORM_BLOCK: - case DATA_FORMAT_EAC_R11G11_SNORM_BLOCK: return 16; + case DATA_FORMAT_EAC_R11G11_SNORM_BLOCK: + return 16; case DATA_FORMAT_ASTC_4x4_UNORM_BLOCK: //again, not sure about astc case DATA_FORMAT_ASTC_4x4_SRGB_BLOCK: case DATA_FORMAT_ASTC_5x4_UNORM_BLOCK: @@ -1028,11 +1089,13 @@ uint32_t RenderingDeviceVulkan::get_compressed_image_format_pixel_rshift(DataFor case DATA_FORMAT_PVRTC1_4BPP_UNORM_BLOCK_IMG: case DATA_FORMAT_PVRTC2_4BPP_UNORM_BLOCK_IMG: case DATA_FORMAT_PVRTC1_4BPP_SRGB_BLOCK_IMG: - case DATA_FORMAT_PVRTC2_4BPP_SRGB_BLOCK_IMG: return 1; + case DATA_FORMAT_PVRTC2_4BPP_SRGB_BLOCK_IMG: + return 1; case DATA_FORMAT_PVRTC1_2BPP_UNORM_BLOCK_IMG: //these formats are quarter byte size, so rshift is 1 case DATA_FORMAT_PVRTC2_2BPP_UNORM_BLOCK_IMG: case DATA_FORMAT_PVRTC1_2BPP_SRGB_BLOCK_IMG: - case DATA_FORMAT_PVRTC2_2BPP_SRGB_BLOCK_IMG: return 2; + case DATA_FORMAT_PVRTC2_2BPP_SRGB_BLOCK_IMG: + return 2; default: { } } @@ -2390,7 +2453,7 @@ Vector<uint8_t> RenderingDeviceVulkan::texture_get_data(RID p_texture, uint32_t uint32_t buffer_size = get_image_format_required_size(tex->format, tex->width, tex->height, tex->depth, tex->mipmaps, &width, &height, &depth); //allocate buffer - VkCommandBuffer command_buffer = frames[frame].setup_command_buffer; + VkCommandBuffer command_buffer = frames[frame].draw_command_buffer; //makes more sense to retrieve Buffer tmp_buffer; _buffer_allocate(&tmp_buffer, buffer_size, VK_BUFFER_USAGE_TRANSFER_DST_BIT, VMA_MEMORY_USAGE_CPU_ONLY); @@ -6796,6 +6859,7 @@ void RenderingDeviceVulkan::sync() { context->local_device_sync(local_device); _begin_frame(); + local_device_processing = false; } void RenderingDeviceVulkan::_free_pending_resources(int p_frame) { @@ -6912,6 +6976,12 @@ uint32_t RenderingDeviceVulkan::get_frame_delay() const { return frame_count; } +uint64_t RenderingDeviceVulkan::get_memory_usage() const { + VmaStats stats; + vmaCalculateStats(allocator, &stats); + return stats.total.usedBytes; +} + void RenderingDeviceVulkan::_flush(bool p_current_frame) { if (local_device.is_valid() && !p_current_frame) { @@ -6976,6 +7046,7 @@ void RenderingDeviceVulkan::initialize(VulkanContext *p_context, bool p_local_de if (p_local_device) { frame_count = 1; local_device = p_context->local_device_create(); + device = p_context->local_device_get_vk_device(local_device); } else { frame_count = p_context->get_swapchain_image_count() + 1; //always need one extra to ensure it's unused at any time, without having to use a fence for this. } @@ -7222,42 +7293,77 @@ String RenderingDeviceVulkan::get_captured_timestamp_name(uint32_t p_index) cons int RenderingDeviceVulkan::limit_get(Limit p_limit) { switch (p_limit) { - case LIMIT_MAX_BOUND_UNIFORM_SETS: return limits.maxBoundDescriptorSets; - case LIMIT_MAX_FRAMEBUFFER_COLOR_ATTACHMENTS: return limits.maxColorAttachments; - case LIMIT_MAX_TEXTURES_PER_UNIFORM_SET: return limits.maxDescriptorSetSampledImages; - case LIMIT_MAX_SAMPLERS_PER_UNIFORM_SET: return limits.maxDescriptorSetSamplers; - case LIMIT_MAX_STORAGE_BUFFERS_PER_UNIFORM_SET: return limits.maxDescriptorSetStorageBuffers; - case LIMIT_MAX_STORAGE_IMAGES_PER_UNIFORM_SET: return limits.maxDescriptorSetStorageImages; - case LIMIT_MAX_UNIFORM_BUFFERS_PER_UNIFORM_SET: return limits.maxDescriptorSetUniformBuffers; - case LIMIT_MAX_DRAW_INDEXED_INDEX: return limits.maxDrawIndexedIndexValue; - case LIMIT_MAX_FRAMEBUFFER_HEIGHT: return limits.maxFramebufferHeight; - case LIMIT_MAX_FRAMEBUFFER_WIDTH: return limits.maxFramebufferWidth; - case LIMIT_MAX_TEXTURE_ARRAY_LAYERS: return limits.maxImageArrayLayers; - case LIMIT_MAX_TEXTURE_SIZE_1D: return limits.maxImageDimension1D; - case LIMIT_MAX_TEXTURE_SIZE_2D: return limits.maxImageDimension2D; - case LIMIT_MAX_TEXTURE_SIZE_3D: return limits.maxImageDimension3D; - case LIMIT_MAX_TEXTURE_SIZE_CUBE: return limits.maxImageDimensionCube; - case LIMIT_MAX_TEXTURES_PER_SHADER_STAGE: return limits.maxPerStageDescriptorSampledImages; - case LIMIT_MAX_SAMPLERS_PER_SHADER_STAGE: return limits.maxPerStageDescriptorSamplers; - case LIMIT_MAX_STORAGE_BUFFERS_PER_SHADER_STAGE: return limits.maxPerStageDescriptorStorageBuffers; - case LIMIT_MAX_STORAGE_IMAGES_PER_SHADER_STAGE: return limits.maxPerStageDescriptorStorageImages; - case LIMIT_MAX_UNIFORM_BUFFERS_PER_SHADER_STAGE: return limits.maxPerStageDescriptorUniformBuffers; - case LIMIT_MAX_PUSH_CONSTANT_SIZE: return limits.maxPushConstantsSize; - case LIMIT_MAX_UNIFORM_BUFFER_SIZE: return limits.maxUniformBufferRange; - case LIMIT_MAX_VERTEX_INPUT_ATTRIBUTE_OFFSET: return limits.maxVertexInputAttributeOffset; - case LIMIT_MAX_VERTEX_INPUT_ATTRIBUTES: return limits.maxVertexInputAttributes; - case LIMIT_MAX_VERTEX_INPUT_BINDINGS: return limits.maxVertexInputBindings; - case LIMIT_MAX_VERTEX_INPUT_BINDING_STRIDE: return limits.maxVertexInputBindingStride; - case LIMIT_MIN_UNIFORM_BUFFER_OFFSET_ALIGNMENT: return limits.minUniformBufferOffsetAlignment; - case LIMIT_MAX_COMPUTE_WORKGROUP_COUNT_X: return limits.maxComputeWorkGroupCount[0]; - case LIMIT_MAX_COMPUTE_WORKGROUP_COUNT_Y: return limits.maxComputeWorkGroupCount[1]; - case LIMIT_MAX_COMPUTE_WORKGROUP_COUNT_Z: return limits.maxComputeWorkGroupCount[2]; - case LIMIT_MAX_COMPUTE_WORKGROUP_INVOCATIONS: return limits.maxComputeWorkGroupInvocations; - case LIMIT_MAX_COMPUTE_WORKGROUP_SIZE_X: return limits.maxComputeWorkGroupSize[0]; - case LIMIT_MAX_COMPUTE_WORKGROUP_SIZE_Y: return limits.maxComputeWorkGroupSize[1]; - case LIMIT_MAX_COMPUTE_WORKGROUP_SIZE_Z: return limits.maxComputeWorkGroupSize[2]; - - default: ERR_FAIL_V(0); + case LIMIT_MAX_BOUND_UNIFORM_SETS: + return limits.maxBoundDescriptorSets; + case LIMIT_MAX_FRAMEBUFFER_COLOR_ATTACHMENTS: + return limits.maxColorAttachments; + case LIMIT_MAX_TEXTURES_PER_UNIFORM_SET: + return limits.maxDescriptorSetSampledImages; + case LIMIT_MAX_SAMPLERS_PER_UNIFORM_SET: + return limits.maxDescriptorSetSamplers; + case LIMIT_MAX_STORAGE_BUFFERS_PER_UNIFORM_SET: + return limits.maxDescriptorSetStorageBuffers; + case LIMIT_MAX_STORAGE_IMAGES_PER_UNIFORM_SET: + return limits.maxDescriptorSetStorageImages; + case LIMIT_MAX_UNIFORM_BUFFERS_PER_UNIFORM_SET: + return limits.maxDescriptorSetUniformBuffers; + case LIMIT_MAX_DRAW_INDEXED_INDEX: + return limits.maxDrawIndexedIndexValue; + case LIMIT_MAX_FRAMEBUFFER_HEIGHT: + return limits.maxFramebufferHeight; + case LIMIT_MAX_FRAMEBUFFER_WIDTH: + return limits.maxFramebufferWidth; + case LIMIT_MAX_TEXTURE_ARRAY_LAYERS: + return limits.maxImageArrayLayers; + case LIMIT_MAX_TEXTURE_SIZE_1D: + return limits.maxImageDimension1D; + case LIMIT_MAX_TEXTURE_SIZE_2D: + return limits.maxImageDimension2D; + case LIMIT_MAX_TEXTURE_SIZE_3D: + return limits.maxImageDimension3D; + case LIMIT_MAX_TEXTURE_SIZE_CUBE: + return limits.maxImageDimensionCube; + case LIMIT_MAX_TEXTURES_PER_SHADER_STAGE: + return limits.maxPerStageDescriptorSampledImages; + case LIMIT_MAX_SAMPLERS_PER_SHADER_STAGE: + return limits.maxPerStageDescriptorSamplers; + case LIMIT_MAX_STORAGE_BUFFERS_PER_SHADER_STAGE: + return limits.maxPerStageDescriptorStorageBuffers; + case LIMIT_MAX_STORAGE_IMAGES_PER_SHADER_STAGE: + return limits.maxPerStageDescriptorStorageImages; + case LIMIT_MAX_UNIFORM_BUFFERS_PER_SHADER_STAGE: + return limits.maxPerStageDescriptorUniformBuffers; + case LIMIT_MAX_PUSH_CONSTANT_SIZE: + return limits.maxPushConstantsSize; + case LIMIT_MAX_UNIFORM_BUFFER_SIZE: + return limits.maxUniformBufferRange; + case LIMIT_MAX_VERTEX_INPUT_ATTRIBUTE_OFFSET: + return limits.maxVertexInputAttributeOffset; + case LIMIT_MAX_VERTEX_INPUT_ATTRIBUTES: + return limits.maxVertexInputAttributes; + case LIMIT_MAX_VERTEX_INPUT_BINDINGS: + return limits.maxVertexInputBindings; + case LIMIT_MAX_VERTEX_INPUT_BINDING_STRIDE: + return limits.maxVertexInputBindingStride; + case LIMIT_MIN_UNIFORM_BUFFER_OFFSET_ALIGNMENT: + return limits.minUniformBufferOffsetAlignment; + case LIMIT_MAX_COMPUTE_WORKGROUP_COUNT_X: + return limits.maxComputeWorkGroupCount[0]; + case LIMIT_MAX_COMPUTE_WORKGROUP_COUNT_Y: + return limits.maxComputeWorkGroupCount[1]; + case LIMIT_MAX_COMPUTE_WORKGROUP_COUNT_Z: + return limits.maxComputeWorkGroupCount[2]; + case LIMIT_MAX_COMPUTE_WORKGROUP_INVOCATIONS: + return limits.maxComputeWorkGroupInvocations; + case LIMIT_MAX_COMPUTE_WORKGROUP_SIZE_X: + return limits.maxComputeWorkGroupSize[0]; + case LIMIT_MAX_COMPUTE_WORKGROUP_SIZE_Y: + return limits.maxComputeWorkGroupSize[1]; + case LIMIT_MAX_COMPUTE_WORKGROUP_SIZE_Z: + return limits.maxComputeWorkGroupSize[2]; + + default: + ERR_FAIL_V(0); } return 0; diff --git a/drivers/vulkan/rendering_device_vulkan.h b/drivers/vulkan/rendering_device_vulkan.h index 6432946fbe..87af5d03d4 100644 --- a/drivers/vulkan/rendering_device_vulkan.h +++ b/drivers/vulkan/rendering_device_vulkan.h @@ -1138,6 +1138,8 @@ public: virtual RenderingDevice *create_local_device(); + virtual uint64_t get_memory_usage() const; + RenderingDeviceVulkan(); ~RenderingDeviceVulkan(); }; diff --git a/drivers/vulkan/vulkan_context.cpp b/drivers/vulkan/vulkan_context.cpp index d293abdee3..9471b4604c 100644 --- a/drivers/vulkan/vulkan_context.cpp +++ b/drivers/vulkan/vulkan_context.cpp @@ -601,12 +601,13 @@ Error VulkanContext::_initialize_queues(VkSurfaceKHR surface) { _create_device(); static PFN_vkGetDeviceProcAddr g_gdpa = nullptr; -#define GET_DEVICE_PROC_ADDR(dev, entrypoint) \ - { \ - if (!g_gdpa) g_gdpa = (PFN_vkGetDeviceProcAddr)vkGetInstanceProcAddr(inst, "vkGetDeviceProcAddr"); \ - fp##entrypoint = (PFN_vk##entrypoint)g_gdpa(dev, "vk" #entrypoint); \ - ERR_FAIL_COND_V_MSG(fp##entrypoint == nullptr, ERR_CANT_CREATE, \ - "vkGetDeviceProcAddr failed to find vk" #entrypoint); \ +#define GET_DEVICE_PROC_ADDR(dev, entrypoint) \ + { \ + if (!g_gdpa) \ + g_gdpa = (PFN_vkGetDeviceProcAddr)vkGetInstanceProcAddr(inst, "vkGetDeviceProcAddr"); \ + fp##entrypoint = (PFN_vk##entrypoint)g_gdpa(dev, "vk" #entrypoint); \ + ERR_FAIL_COND_V_MSG(fp##entrypoint == nullptr, ERR_CANT_CREATE, \ + "vkGetDeviceProcAddr failed to find vk" #entrypoint); \ } GET_DEVICE_PROC_ADDR(device, CreateSwapchainKHR); @@ -1566,6 +1567,15 @@ void VulkanContext::local_device_push_command_buffers(RID p_local_device, const submit_info.pSignalSemaphores = nullptr; VkResult err = vkQueueSubmit(ld->queue, 1, &submit_info, VK_NULL_HANDLE); + if (err == VK_ERROR_OUT_OF_HOST_MEMORY) { + print_line("out of host memory"); + } + if (err == VK_ERROR_OUT_OF_DEVICE_MEMORY) { + print_line("out of device memory"); + } + if (err == VK_ERROR_DEVICE_LOST) { + print_line("device lost"); + } ERR_FAIL_COND(err); ld->waiting = true; diff --git a/drivers/wasapi/audio_driver_wasapi.cpp b/drivers/wasapi/audio_driver_wasapi.cpp index ab2976f02c..1fc01ce76e 100644 --- a/drivers/wasapi/audio_driver_wasapi.cpp +++ b/drivers/wasapi/audio_driver_wasapi.cpp @@ -69,13 +69,11 @@ static bool default_render_device_changed = false; static bool default_capture_device_changed = false; class CMMNotificationClient : public IMMNotificationClient { - LONG _cRef; - IMMDeviceEnumerator *_pEnumerator; + LONG _cRef = 1; + IMMDeviceEnumerator *_pEnumerator = nullptr; public: - CMMNotificationClient() : - _cRef(1), - _pEnumerator(nullptr) {} + CMMNotificationClient() {} virtual ~CMMNotificationClient() { if ((_pEnumerator) != nullptr) { (_pEnumerator)->Release(); @@ -854,17 +852,7 @@ String AudioDriverWASAPI::capture_get_device() { } AudioDriverWASAPI::AudioDriverWASAPI() { - - thread = nullptr; - samples_in.clear(); - - channels = 0; - mix_rate = 0; - buffer_frames = 0; - - thread_exited = false; - exit_thread = false; } -#endif +#endif // WASAPI_ENABLED diff --git a/drivers/wasapi/audio_driver_wasapi.h b/drivers/wasapi/audio_driver_wasapi.h index 01a4666812..2fcf8936fa 100644 --- a/drivers/wasapi/audio_driver_wasapi.h +++ b/drivers/wasapi/audio_driver_wasapi.h @@ -45,47 +45,36 @@ class AudioDriverWASAPI : public AudioDriver { class AudioDeviceWASAPI { public: - IAudioClient *audio_client; - IAudioRenderClient *render_client; - IAudioCaptureClient *capture_client; - bool active; - - WORD format_tag; - WORD bits_per_sample; - unsigned int channels; - unsigned int frame_size; - - String device_name; - String new_device; - - AudioDeviceWASAPI() : - audio_client(nullptr), - render_client(nullptr), - capture_client(nullptr), - active(false), - format_tag(0), - bits_per_sample(0), - channels(0), - frame_size(0), - device_name("Default"), - new_device("Default") { - } + IAudioClient *audio_client = nullptr; + IAudioRenderClient *render_client = nullptr; + IAudioCaptureClient *capture_client = nullptr; + bool active = false; + + WORD format_tag = 0; + WORD bits_per_sample = 0; + unsigned int channels = 0; + unsigned int frame_size = 0; + + String device_name = "Default"; + String new_device = "Default"; + + AudioDeviceWASAPI() {} }; AudioDeviceWASAPI audio_input; AudioDeviceWASAPI audio_output; Mutex mutex; - Thread *thread; + Thread *thread = nullptr; Vector<int32_t> samples_in; - unsigned int channels; - int mix_rate; - int buffer_frames; + unsigned int channels = 0; + int mix_rate = 0; + int buffer_frames = 0; - bool thread_exited; - mutable bool exit_thread; + bool thread_exited = false; + mutable bool exit_thread = false; static _FORCE_INLINE_ void write_sample(WORD format_tag, int bits_per_sample, BYTE *buffer, int i, int32_t sample); static _FORCE_INLINE_ int32_t read_sample(WORD format_tag, int bits_per_sample, BYTE *buffer, int i); diff --git a/drivers/windows/file_access_windows.cpp b/drivers/windows/file_access_windows.cpp index 69078b3326..f1326abb7b 100644 --- a/drivers/windows/file_access_windows.cpp +++ b/drivers/windows/file_access_windows.cpp @@ -353,15 +353,8 @@ Error FileAccessWindows::_set_unix_permissions(const String &p_file, uint32_t p_ return ERR_UNAVAILABLE; } -FileAccessWindows::FileAccessWindows() : - f(nullptr), - flags(0), - prev_op(0), - last_error(OK) { -} FileAccessWindows::~FileAccessWindows() { - close(); } -#endif +#endif // WINDOWS_ENABLED diff --git a/drivers/windows/file_access_windows.h b/drivers/windows/file_access_windows.h index 28d4375878..34a7e400a0 100644 --- a/drivers/windows/file_access_windows.h +++ b/drivers/windows/file_access_windows.h @@ -40,11 +40,11 @@ class FileAccessWindows : public FileAccess { - FILE *f; - int flags; + FILE *f = nullptr; + int flags = 0; void check_errors() const; - mutable int prev_op; - mutable Error last_error; + mutable int prev_op = 0; + mutable Error last_error = OK; String path; String path_src; String save_path; @@ -79,9 +79,10 @@ public: virtual uint32_t _get_unix_permissions(const String &p_file); virtual Error _set_unix_permissions(const String &p_file, uint32_t p_permissions); - FileAccessWindows(); + FileAccessWindows() {} virtual ~FileAccessWindows(); }; -#endif -#endif +#endif // WINDOWS_ENABLED + +#endif // FILE_ACCESS_WINDOWS_H diff --git a/drivers/windows/thread_windows.cpp b/drivers/windows/thread_windows.cpp index aea2db2603..c36437d891 100644 --- a/drivers/windows/thread_windows.cpp +++ b/drivers/windows/thread_windows.cpp @@ -90,11 +90,4 @@ void ThreadWindows::make_default() { wait_to_finish_func = wait_to_finish_func_windows; } -ThreadWindows::ThreadWindows() : - handle(nullptr) { -} - -ThreadWindows::~ThreadWindows() { -} - #endif diff --git a/drivers/windows/thread_windows.h b/drivers/windows/thread_windows.h index 669956cd32..93de4c6e8c 100644 --- a/drivers/windows/thread_windows.h +++ b/drivers/windows/thread_windows.h @@ -43,7 +43,7 @@ class ThreadWindows : public Thread { ThreadCreateCallback callback; void *user; ID id; - HANDLE handle; + HANDLE handle = nullptr; static Thread *create_thread_windows(); @@ -53,14 +53,14 @@ class ThreadWindows : public Thread { static ID get_thread_id_func_windows(); static void wait_to_finish_func_windows(Thread *p_thread); - ThreadWindows(); + ThreadWindows() {} public: virtual ID get_id() const; static void make_default(); - ~ThreadWindows(); + ~ThreadWindows() {} }; #endif diff --git a/drivers/xaudio2/audio_driver_xaudio2.cpp b/drivers/xaudio2/audio_driver_xaudio2.cpp index 120bfa2b36..d12ebf8d94 100644 --- a/drivers/xaudio2/audio_driver_xaudio2.cpp +++ b/drivers/xaudio2/audio_driver_xaudio2.cpp @@ -196,15 +196,9 @@ void AudioDriverXAudio2::finish() { thread = nullptr; } -AudioDriverXAudio2::AudioDriverXAudio2() : - thread(nullptr), - current_buffer(0) { - wave_format = { 0 }; +AudioDriverXAudio2::AudioDriverXAudio2() { for (int i = 0; i < AUDIO_BUFFERS; i++) { xaudio_buffer[i] = { 0 }; samples_out[i] = 0; } } - -AudioDriverXAudio2::~AudioDriverXAudio2() { -} diff --git a/drivers/xaudio2/audio_driver_xaudio2.h b/drivers/xaudio2/audio_driver_xaudio2.h index eb4a6d6e95..9182fa4172 100644 --- a/drivers/xaudio2/audio_driver_xaudio2.h +++ b/drivers/xaudio2/audio_driver_xaudio2.h @@ -64,7 +64,7 @@ class AudioDriverXAudio2 : public AudioDriver { void STDMETHODCALLTYPE OnVoiceError(void *pBufferContext, HRESULT Error) {} }; - Thread *thread; + Thread *thread = nullptr; Mutex mutex; int32_t *samples_in; @@ -83,9 +83,9 @@ class AudioDriverXAudio2 : public AudioDriver { mutable bool exit_thread; bool pcm_open; - WAVEFORMATEX wave_format; + WAVEFORMATEX wave_format = { 0 }; Microsoft::WRL::ComPtr<IXAudio2> xaudio; - int current_buffer; + int current_buffer = 0; IXAudio2MasteringVoice *mastering_voice; XAUDIO2_BUFFER xaudio_buffer[AUDIO_BUFFERS]; IXAudio2SourceVoice *source_voice; @@ -104,7 +104,7 @@ public: virtual void finish(); AudioDriverXAudio2(); - ~AudioDriverXAudio2(); + ~AudioDriverXAudio2() {} }; #endif diff --git a/editor/animation_track_editor.cpp b/editor/animation_track_editor.cpp index 09f55bea0c..4a43cb0c18 100644 --- a/editor/animation_track_editor.cpp +++ b/editor/animation_track_editor.cpp @@ -5297,10 +5297,18 @@ void AnimationTrackEditor::_edit_menu_pressed(int p_option) { } switch (animation->track_get_type(i)) { - case Animation::TYPE_TRANSFORM: text += " (Transform)"; break; - case Animation::TYPE_METHOD: text += " (Methods)"; break; - case Animation::TYPE_BEZIER: text += " (Bezier)"; break; - case Animation::TYPE_AUDIO: text += " (Audio)"; break; + case Animation::TYPE_TRANSFORM: + text += " (Transform)"; + break; + case Animation::TYPE_METHOD: + text += " (Methods)"; + break; + case Animation::TYPE_BEZIER: + text += " (Bezier)"; + break; + case Animation::TYPE_AUDIO: + text += " (Audio)"; + break; default: { }; } @@ -5947,7 +5955,7 @@ AnimationTrackEditor::AnimationTrackEditor() { insert_confirm_bezier->set_text(TTR("Use Bezier Curves")); icvb->add_child(insert_confirm_bezier); keying = false; - moving_selection = 0; + moving_selection = false; key_edit = nullptr; multi_key_edit = nullptr; diff --git a/editor/code_editor.cpp b/editor/code_editor.cpp index 987d5649b1..76716f01b7 100644 --- a/editor/code_editor.cpp +++ b/editor/code_editor.cpp @@ -309,7 +309,8 @@ void FindReplaceBar::_update_results_count() { results_count = 0; String searched = get_search_text(); - if (searched.empty()) return; + if (searched.empty()) + return; String full_text = text_edit->get_text(); @@ -317,7 +318,8 @@ void FindReplaceBar::_update_results_count() { while (true) { int pos = is_case_sensitive() ? full_text.find(searched, from_pos) : full_text.findn(searched, from_pos); - if (pos == -1) break; + if (pos == -1) + break; if (is_whole_words()) { from_pos++; // Making sure we won't hit the same match next time, if we get out via a continue. diff --git a/editor/connections_dialog.cpp b/editor/connections_dialog.cpp index bef5c3c2b0..4556a6e827 100644 --- a/editor/connections_dialog.cpp +++ b/editor/connections_dialog.cpp @@ -171,20 +171,48 @@ void ConnectDialog::_add_bind() { Variant value; switch (vt) { - case Variant::BOOL: value = false; break; - case Variant::INT: value = 0; break; - case Variant::FLOAT: value = 0.0; break; - case Variant::STRING: value = ""; break; - case Variant::STRING_NAME: value = ""; break; - case Variant::VECTOR2: value = Vector2(); break; - case Variant::RECT2: value = Rect2(); break; - case Variant::VECTOR3: value = Vector3(); break; - case Variant::PLANE: value = Plane(); break; - case Variant::QUAT: value = Quat(); break; - case Variant::AABB: value = AABB(); break; - case Variant::BASIS: value = Basis(); break; - case Variant::TRANSFORM: value = Transform(); break; - case Variant::COLOR: value = Color(); break; + case Variant::BOOL: + value = false; + break; + case Variant::INT: + value = 0; + break; + case Variant::FLOAT: + value = 0.0; + break; + case Variant::STRING: + value = ""; + break; + case Variant::STRING_NAME: + value = ""; + break; + case Variant::VECTOR2: + value = Vector2(); + break; + case Variant::RECT2: + value = Rect2(); + break; + case Variant::VECTOR3: + value = Vector3(); + break; + case Variant::PLANE: + value = Plane(); + break; + case Variant::QUAT: + value = Quat(); + break; + case Variant::AABB: + value = AABB(); + break; + case Variant::BASIS: + value = Basis(); + break; + case Variant::TRANSFORM: + value = Transform(); + break; + case Variant::COLOR: + value = Color(); + break; default: { ERR_FAIL(); } break; diff --git a/editor/debugger/editor_debugger_inspector.h b/editor/debugger/editor_debugger_inspector.h index e1dfbefcf3..f0ee7a2535 100644 --- a/editor/debugger/editor_debugger_inspector.h +++ b/editor/debugger/editor_debugger_inspector.h @@ -61,7 +61,7 @@ public: void update() { _change_notify(); } - EditorDebuggerRemoteObject(){}; + EditorDebuggerRemoteObject() {} }; class EditorDebuggerInspector : public EditorInspector { diff --git a/editor/debugger/editor_debugger_node.cpp b/editor/debugger/editor_debugger_node.cpp index 3302b50103..5bfb79806e 100644 --- a/editor/debugger/editor_debugger_node.cpp +++ b/editor/debugger/editor_debugger_node.cpp @@ -173,7 +173,7 @@ ScriptEditorDebugger *EditorDebuggerNode::get_default_debugger() const { return Object::cast_to<ScriptEditorDebugger>(tabs->get_tab_control(0)); } -Error EditorDebuggerNode::start() { +Error EditorDebuggerNode::start(const String &p_protocol) { stop(); if (EDITOR_GET("run/output/always_open_output_on_play")) { EditorNode::get_singleton()->make_bottom_panel_item_visible(EditorNode::get_log()); @@ -181,7 +181,7 @@ Error EditorDebuggerNode::start() { EditorNode::get_singleton()->make_bottom_panel_item_visible(this); } - server = Ref<EditorDebuggerServer>(EditorDebuggerServer::create_default()); + server = Ref<EditorDebuggerServer>(EditorDebuggerServer::create(p_protocol)); const Error err = server->start(); if (err != OK) { return err; @@ -213,14 +213,6 @@ void EditorDebuggerNode::stop() { void EditorDebuggerNode::_notification(int p_what) { switch (p_what) { - case NOTIFICATION_ENTER_TREE: { - EditorNode::get_singleton()->connect("play_pressed", callable_mp(this, &EditorDebuggerNode::start)); - EditorNode::get_singleton()->connect("stop_pressed", callable_mp(this, &EditorDebuggerNode::stop)); - } break; - case NOTIFICATION_EXIT_TREE: { - EditorNode::get_singleton()->disconnect("play_pressed", callable_mp(this, &EditorDebuggerNode::start)); - EditorNode::get_singleton()->disconnect("stop_pressed", callable_mp(this, &EditorDebuggerNode::stop)); - } break; case EditorSettings::NOTIFICATION_EDITOR_SETTINGS_CHANGED: { if (tabs->get_tab_count() > 1) { add_theme_constant_override("margin_left", -EditorNode::get_singleton()->get_gui_base()->get_theme_stylebox("BottomPanelDebuggerOverride", "EditorStyles")->get_margin(MARGIN_LEFT)); @@ -261,10 +253,12 @@ void EditorDebuggerNode::_notification(int p_what) { debugger_button->set_icon(Ref<Texture2D>()); } else { debugger_button->set_text(TTR("Debugger") + " (" + itos(error_count + warning_count) + ")"); - if (error_count == 0) { - debugger_button->set_icon(get_theme_icon("Warning", "EditorIcons")); - } else { + if (error_count >= 1 && warning_count >= 1) { + debugger_button->set_icon(get_theme_icon("ErrorWarning", "EditorIcons")); + } else if (error_count >= 1) { debugger_button->set_icon(get_theme_icon("Error", "EditorIcons")); + } else { + debugger_button->set_icon(get_theme_icon("Warning", "EditorIcons")); } } last_error_count = error_count; diff --git a/editor/debugger/editor_debugger_node.h b/editor/debugger/editor_debugger_node.h index 9467442c9a..88b82d37fd 100644 --- a/editor/debugger/editor_debugger_node.h +++ b/editor/debugger/editor_debugger_node.h @@ -76,7 +76,7 @@ private: return line < p_b.line; } - Breakpoint(){}; + Breakpoint() {} Breakpoint(const String &p_source, int p_line) { line = p_line; @@ -183,7 +183,7 @@ public: void set_camera_override(CameraOverride p_override) { camera_override = p_override; } CameraOverride get_camera_override() { return camera_override; } - Error start(); + Error start(const String &p_protocol = "tcp://"); void stop(); }; diff --git a/editor/debugger/editor_debugger_server.cpp b/editor/debugger/editor_debugger_server.cpp index c80988a662..33f20d9a12 100644 --- a/editor/debugger/editor_debugger_server.cpp +++ b/editor/debugger/editor_debugger_server.cpp @@ -44,6 +44,7 @@ private: Ref<TCP_Server> server; public: + static EditorDebuggerServer *create(const String &p_protocol); virtual void poll() {} virtual Error start(); virtual void stop(); @@ -54,6 +55,11 @@ public: EditorDebuggerServerTCP(); }; +EditorDebuggerServer *EditorDebuggerServerTCP::create(const String &p_protocol) { + ERR_FAIL_COND_V(p_protocol != "tcp://", nullptr); + return memnew(EditorDebuggerServerTCP); +} + EditorDebuggerServerTCP::EditorDebuggerServerTCP() { server.instance(); } @@ -85,6 +91,23 @@ Ref<RemoteDebuggerPeer> EditorDebuggerServerTCP::take_connection() { return memnew(RemoteDebuggerPeerTCP(server->take_connection())); } -EditorDebuggerServer *EditorDebuggerServer::create_default() { - return memnew(EditorDebuggerServerTCP); +/// EditorDebuggerServer +Map<StringName, EditorDebuggerServer::CreateServerFunc> EditorDebuggerServer::protocols; + +EditorDebuggerServer *EditorDebuggerServer::create(const String &p_protocol) { + ERR_FAIL_COND_V(!protocols.has(p_protocol), nullptr); + return protocols[p_protocol](p_protocol); +} + +void EditorDebuggerServer::register_protocol_handler(const String &p_protocol, CreateServerFunc p_func) { + ERR_FAIL_COND(protocols.has(p_protocol)); + protocols[p_protocol] = p_func; +} + +void EditorDebuggerServer::initialize() { + register_protocol_handler("tcp://", EditorDebuggerServerTCP::create); +} + +void EditorDebuggerServer::deinitialize() { + protocols.clear(); } diff --git a/editor/debugger/editor_debugger_server.h b/editor/debugger/editor_debugger_server.h index e9798c90b3..515ffce628 100644 --- a/editor/debugger/editor_debugger_server.h +++ b/editor/debugger/editor_debugger_server.h @@ -37,7 +37,17 @@ class EditorDebuggerServer : public Reference { public: - static EditorDebuggerServer *create_default(); + typedef EditorDebuggerServer *(*CreateServerFunc)(const String &p_uri); + +private: + static Map<StringName, CreateServerFunc> protocols; + +public: + static void initialize(); + static void deinitialize(); + + static void register_protocol_handler(const String &p_protocol, CreateServerFunc p_func); + static EditorDebuggerServer *create(const String &p_protocol); virtual void poll() = 0; virtual Error start() = 0; virtual void stop() = 0; diff --git a/editor/debugger/editor_profiler.cpp b/editor/debugger/editor_profiler.cpp index c7d4e9128a..347de2470b 100644 --- a/editor/debugger/editor_profiler.cpp +++ b/editor/debugger/editor_profiler.cpp @@ -344,7 +344,7 @@ void EditorProfiler::_update_plot() { Ref<Image> img; img.instance(); - img->create(w, h, 0, Image::FORMAT_RGBA8, graph_image); + img->create(w, h, false, Image::FORMAT_RGBA8, graph_image); if (reset_texture) { diff --git a/editor/debugger/editor_visual_profiler.cpp b/editor/debugger/editor_visual_profiler.cpp index 7d2822b1c9..589ef0d64e 100644 --- a/editor/debugger/editor_visual_profiler.cpp +++ b/editor/debugger/editor_visual_profiler.cpp @@ -307,7 +307,7 @@ void EditorVisualProfiler::_update_plot() { Ref<Image> img; img.instance(); - img->create(w, h, 0, Image::FORMAT_RGBA8, graph_image); + img->create(w, h, false, Image::FORMAT_RGBA8, graph_image); if (reset_texture) { diff --git a/editor/debugger/script_editor_debugger.cpp b/editor/debugger/script_editor_debugger.cpp index 152989f90b..3ed271c7c6 100644 --- a/editor/debugger/script_editor_debugger.cpp +++ b/editor/debugger/script_editor_debugger.cpp @@ -74,7 +74,8 @@ void ScriptEditorDebugger::_put_msg(String p_message, Array p_data) { void ScriptEditorDebugger::debug_copy() { String msg = reason->get_text(); - if (msg == "") return; + if (msg == "") + return; DisplayServer::get_singleton()->clipboard_set(msg); } @@ -130,10 +131,12 @@ void ScriptEditorDebugger::update_tabs() { tabs->set_tab_icon(errors_tab->get_index(), Ref<Texture2D>()); } else { errors_tab->set_name(TTR("Errors") + " (" + itos(error_count + warning_count) + ")"); - if (error_count == 0) { - tabs->set_tab_icon(errors_tab->get_index(), get_theme_icon("Warning", "EditorIcons")); - } else { + if (error_count >= 1 && warning_count >= 1) { + tabs->set_tab_icon(errors_tab->get_index(), get_theme_icon("ErrorWarning", "EditorIcons")); + } else if (error_count >= 1) { tabs->set_tab_icon(errors_tab->get_index(), get_theme_icon("Error", "EditorIcons")); + } else { + tabs->set_tab_icon(errors_tab->get_index(), get_theme_icon("Warning", "EditorIcons")); } } } @@ -843,6 +846,8 @@ void ScriptEditorDebugger::_notification(int p_what) { if (is_session_active()) { + peer->poll(); + if (camera_override == CameraOverride::OVERRIDE_2D) { CanvasItemEditor *editor = CanvasItemEditor::get_singleton(); diff --git a/editor/editor_audio_buses.cpp b/editor/editor_audio_buses.cpp index c80ae5f21b..39c94cc523 100644 --- a/editor/editor_audio_buses.cpp +++ b/editor/editor_audio_buses.cpp @@ -1480,11 +1480,7 @@ void EditorAudioMeterNotches::_draw_audio_notches() { } } -EditorAudioMeterNotches::EditorAudioMeterNotches() : - line_length(5.0f), - label_space(2.0f), - btm_padding(9.0f), - top_padding(5.0f) { +EditorAudioMeterNotches::EditorAudioMeterNotches() { notch_color = EditorSettings::get_singleton()->is_dark_theme() ? Color(1, 1, 1) : Color(0, 0, 0); } diff --git a/editor/editor_audio_buses.h b/editor/editor_audio_buses.h index be1551629d..3e911bcc65 100644 --- a/editor/editor_audio_buses.h +++ b/editor/editor_audio_buses.h @@ -246,10 +246,10 @@ private: List<AudioNotch> notches; public: - float line_length; - float label_space; - float btm_padding; - float top_padding; + float line_length = 5.0f; + float label_space = 2.0f; + float btm_padding = 9.0f; + float top_padding = 5.0f; Color notch_color; void add_notch(float p_normalized_offset, float p_db_value, bool p_render_value = false); diff --git a/editor/editor_export.cpp b/editor/editor_export.cpp index e8167070d4..abfd8e5484 100644 --- a/editor/editor_export.cpp +++ b/editor/editor_export.cpp @@ -252,13 +252,6 @@ String EditorExportPreset::get_script_encryption_key() const { return script_key; } -EditorExportPreset::EditorExportPreset() : - export_filter(EXPORT_ALL_RESOURCES), - export_path(""), - runnable(false), - script_mode(MODE_SCRIPT_COMPILED) { -} - /////////////////////////////////// void EditorExportPlatform::gen_debug_flags(Vector<String> &r_flags, int p_flags) { @@ -284,7 +277,7 @@ void EditorExportPlatform::gen_debug_flags(Vector<String> &r_flags, int p_flags) r_flags.push_back("--remote-debug"); - r_flags.push_back(host + ":" + String::num(remote_port)); + r_flags.push_back(get_debug_protocol() + host + ":" + String::num(remote_port)); List<String> breakpoints; ScriptEditor::get_singleton()->get_breakpoints(&breakpoints); @@ -871,7 +864,7 @@ Error EditorExportPlatform::export_project_files(const Ref<EditorExportPreset> & ProjectSettings::CustomMap custom_map; if (path_remaps.size()) { - if (1) { //new remap mode, use always as it's friendlier with multiple .pck exports + if (true) { //new remap mode, use always as it's friendlier with multiple .pck exports for (int i = 0; i < path_remaps.size(); i += 2) { String from = path_remaps[i]; String to = path_remaps[i + 1]; @@ -1127,7 +1120,7 @@ void EditorExportPlatform::gen_export_flags(Vector<String> &r_flags, int p_flags r_flags.push_back("--remote-debug"); - r_flags.push_back(host + ":" + String::num(remote_port)); + r_flags.push_back(get_debug_protocol() + host + ":" + String::num(remote_port)); List<String> breakpoints; ScriptEditor::get_singleton()->get_breakpoints(&breakpoints); diff --git a/editor/editor_export.h b/editor/editor_export.h index f47fe9c95e..1dedc21350 100644 --- a/editor/editor_export.h +++ b/editor/editor_export.h @@ -61,14 +61,14 @@ public: private: Ref<EditorExportPlatform> platform; - ExportFilter export_filter; + ExportFilter export_filter = EXPORT_ALL_RESOURCES; String include_filter; String exclude_filter; String export_path; String exporter; Set<String> selected_files; - bool runnable; + bool runnable = false; Vector<String> patches; @@ -82,7 +82,7 @@ private: String custom_features; - int script_mode; + int script_mode = MODE_SCRIPT_COMPILED; String script_key; protected: @@ -136,7 +136,7 @@ public: const List<PropertyInfo> &get_properties() const { return properties; } - EditorExportPreset(); + EditorExportPreset() {} }; struct SharedObject { @@ -270,6 +270,7 @@ public: virtual Error export_zip(const Ref<EditorExportPreset> &p_preset, bool p_debug, const String &p_path, int p_flags = 0); virtual void get_platform_features(List<String> *r_features) = 0; virtual void resolve_platform_feature_priorities(const Ref<EditorExportPreset> &p_preset, Set<String> &p_features) = 0; + virtual String get_debug_protocol() const { return "tcp://"; } EditorExportPlatform(); }; diff --git a/editor/editor_file_dialog.cpp b/editor/editor_file_dialog.cpp index 6a06c6657e..2411852541 100644 --- a/editor/editor_file_dialog.cpp +++ b/editor/editor_file_dialog.cpp @@ -691,7 +691,8 @@ bool EditorFileDialog::_is_open_should_be_disabled() { void EditorFileDialog::update_file_name() { int idx = filter->get_selected() - 1; if ((idx == -1 && filter->get_item_count() == 2) || (filter->get_item_count() > 2 && idx >= 0 && idx < filter->get_item_count() - 2)) { - if (idx == -1) idx += 1; + if (idx == -1) + idx += 1; String filter_str = filters[idx]; String file_str = file->get_text(); String base_name = file_str.get_basename(); diff --git a/editor/editor_help.cpp b/editor/editor_help.cpp index b566ad0fa4..6c5e8e17e4 100644 --- a/editor/editor_help.cpp +++ b/editor/editor_help.cpp @@ -526,7 +526,7 @@ void EditorHelp::_update_doc() { class_desc->push_font(doc_code_font); class_desc->push_indent(1); class_desc->push_table(2); - class_desc->set_table_column_expand(1, 1); + class_desc->set_table_column_expand(1, true); for (int i = 0; i < cd.properties.size(); i++) { property_line[cd.properties[i].name] = class_desc->get_line_count() - 2; //gets overridden if description @@ -629,7 +629,7 @@ void EditorHelp::_update_doc() { class_desc->push_font(doc_code_font); class_desc->push_indent(1); class_desc->push_table(2); - class_desc->set_table_column_expand(1, 1); + class_desc->set_table_column_expand(1, true); bool any_previous = false; for (int pass = 0; pass < 2; pass++) { @@ -698,7 +698,7 @@ void EditorHelp::_update_doc() { class_desc->push_indent(1); class_desc->push_table(2); - class_desc->set_table_column_expand(1, 1); + class_desc->set_table_column_expand(1, true); for (int i = 0; i < cd.theme_properties.size(); i++) { @@ -1005,7 +1005,7 @@ void EditorHelp::_update_doc() { property_line[cd.properties[i].name] = class_desc->get_line_count() - 2; class_desc->push_table(2); - class_desc->set_table_column_expand(1, 1); + class_desc->set_table_column_expand(1, true); class_desc->push_cell(); class_desc->push_font(doc_code_font); @@ -1497,7 +1497,8 @@ void EditorHelp::_notification(int p_what) { _class_desc_resized(); } } break; - default: break; + default: + break; } } @@ -1642,7 +1643,8 @@ void EditorHelpBit::_notification(int p_what) { rich_text->add_theme_color_override("selection_color", get_theme_color("accent_color", "Editor") * Color(1, 1, 1, 0.4)); } break; - default: break; + default: + break; } } @@ -1786,7 +1788,8 @@ void FindBar::_update_results_count() { results_count = 0; String searched = search_text->get_text(); - if (searched.empty()) return; + if (searched.empty()) + return; String full_text = rich_text_label->get_text(); diff --git a/editor/editor_help_search.cpp b/editor/editor_help_search.cpp index 01a50cad2c..1a865de23d 100644 --- a/editor/editor_help_search.cpp +++ b/editor/editor_help_search.cpp @@ -595,7 +595,6 @@ bool EditorHelpSearch::Runner::work(uint64_t slot) { } EditorHelpSearch::Runner::Runner(Control *p_icon_service, Tree *p_results_tree, const String &p_term, int p_search_flags) : - phase(0), ui_service(p_icon_service), results_tree(p_results_tree), term((p_search_flags & SEARCH_CASE_SENSITIVE) == 0 ? p_term.strip_edges().to_lower() : p_term.strip_edges()), diff --git a/editor/editor_help_search.h b/editor/editor_help_search.h index feff96d2e5..9f58c7244f 100644 --- a/editor/editor_help_search.h +++ b/editor/editor_help_search.h @@ -95,7 +95,7 @@ class EditorHelpSearch::Runner : public Reference { PHASE_SELECT_MATCH, PHASE_MAX }; - int phase; + int phase = 0; struct ClassMatch { DocData::ClassDoc *doc; diff --git a/editor/editor_log.cpp b/editor/editor_log.cpp index c89a7bcf23..78aed96363 100644 --- a/editor/editor_log.cpp +++ b/editor/editor_log.cpp @@ -63,11 +63,13 @@ void EditorLog::_notification(int p_what) { //button->set_icon(get_icon("Console","EditorIcons")); log->add_theme_font_override("normal_font", get_theme_font("output_source", "EditorFonts")); + log->add_theme_color_override("selection_color", get_theme_color("accent_color", "Editor") * Color(1, 1, 1, 0.4)); } else if (p_what == NOTIFICATION_THEME_CHANGED) { Ref<DynamicFont> df_output_code = get_theme_font("output_source", "EditorFonts"); if (df_output_code.is_valid()) { if (log != nullptr) { log->add_theme_font_override("normal_font", get_theme_font("output_source", "EditorFonts")); + log->add_theme_color_override("selection_color", get_theme_color("accent_color", "Editor") * Color(1, 1, 1, 0.4)); } } } diff --git a/editor/editor_node.cpp b/editor/editor_node.cpp index 90cea06439..c38f705c1b 100644 --- a/editor/editor_node.cpp +++ b/editor/editor_node.cpp @@ -42,7 +42,6 @@ #include "core/os/file_access.h" #include "core/os/keyboard.h" #include "core/os/os.h" -#include "core/path_remap.h" #include "core/print_string.h" #include "core/project_settings.h" #include "core/translation.h" @@ -158,6 +157,7 @@ #include "editor/plugins/style_box_editor_plugin.h" #include "editor/plugins/text_editor.h" #include "editor/plugins/texture_editor_plugin.h" +#include "editor/plugins/texture_layered_editor_plugin.h" #include "editor/plugins/texture_region_editor_plugin.h" #include "editor/plugins/theme_editor_plugin.h" #include "editor/plugins/tile_map_editor_plugin.h" @@ -381,6 +381,8 @@ void EditorNode::_notification(int p_what) { RS::get_singleton()->shadows_quality_set(shadows_quality); RS::ShadowQuality directional_shadow_quality = RS::ShadowQuality(int(GLOBAL_GET("rendering/quality/directional_shadow/soft_shadow_quality"))); RS::get_singleton()->directional_shadow_quality_set(directional_shadow_quality); + float probe_update_speed = GLOBAL_GET("rendering/lightmapper/probe_capture_update_speed"); + RS::get_singleton()->lightmap_set_probe_capture_update_speed(probe_update_speed); } ResourceImporterTexture::get_singleton()->update_imports(); @@ -713,7 +715,6 @@ void EditorNode::_sources_changed(bool p_exist) { // Reload the global shader variables, but this time // loading texures, as they are now properly imported. - print_line("done scanning, reload textures"); RenderingServer::get_singleton()->global_variables_load_settings(true); // Start preview thread now that it's safe. @@ -2070,9 +2071,11 @@ void EditorNode::_run(bool p_current, const String &p_custom) { args = ProjectSettings::get_singleton()->get("editor/main_run_args"); skip_breakpoints = EditorDebuggerNode::get_singleton()->is_skip_breakpoints(); + EditorDebuggerNode::get_singleton()->start(); Error error = editor_run.run(run_filename, args, breakpoints, skip_breakpoints); if (error != OK) { + EditorDebuggerNode::get_singleton()->stop(); show_accept(TTR("Could not start subprocess!"), TTR("OK")); return; } @@ -2094,6 +2097,24 @@ void EditorNode::_run(bool p_current, const String &p_custom) { _playing_edited = p_current; } +void EditorNode::_run_native(const Ref<EditorExportPreset> &p_preset) { + + bool autosave = EDITOR_GET("run/auto_save/save_before_running"); + if (autosave) { + _menu_option_confirm(FILE_SAVE_ALL_SCENES, false); + } + if (run_native->is_deploy_debug_remote_enabled()) { + _menu_option_confirm(RUN_STOP, true); + + if (!call_build()) + return; // build failed + + EditorDebuggerNode::get_singleton()->start(p_preset->get_platform()->get_debug_protocol()); + emit_signal("play_pressed"); + editor_run.run_native_notify(); + } +} + void EditorNode::_menu_option_confirm(int p_option, bool p_confirmed) { if (!p_confirmed) //this may be a hack.. @@ -2390,7 +2411,7 @@ void EditorNode::_menu_option_confirm(int p_option, bool p_confirmed) { } } break; - case EDIT_REVERT: { + case EDIT_RELOAD_SAVED_SCENE: { Node *scene = get_edited_scene(); @@ -2405,8 +2426,9 @@ void EditorNode::_menu_option_confirm(int p_option, bool p_confirmed) { } if (unsaved_cache && !p_confirmed) { - confirmation->get_ok()->set_text(TTR("Revert")); - confirmation->set_text(TTR("This action cannot be undone. Revert anyway?")); + confirmation->get_ok()->set_text(TTR("Reload Saved Scene")); + confirmation->set_text( + TTR("The current scene has unsaved changes.\nReload the saved scene anyway? This action cannot be undone.")); confirmation->popup_centered(); break; } @@ -2462,6 +2484,7 @@ void EditorNode::_menu_option_confirm(int p_option, bool p_confirmed) { } } } + EditorDebuggerNode::get_singleton()->stop(); emit_signal("stop_pressed"); } break; @@ -2480,22 +2503,6 @@ void EditorNode::_menu_option_confirm(int p_option, bool p_confirmed) { _run(true); } break; - case RUN_PLAY_NATIVE: { - - bool autosave = EDITOR_GET("run/auto_save/save_before_running"); - if (autosave) { - _menu_option_confirm(FILE_SAVE_ALL_SCENES, false); - } - if (run_native->is_deploy_debug_remote_enabled()) { - _menu_option_confirm(RUN_STOP, true); - - if (!call_build()) - break; // build failed - - emit_signal("play_pressed"); - editor_run.run_native_notify(); - } - } break; case RUN_SCENE_SETTINGS: { run_settings_dialog->popup_run_settings(); @@ -3082,7 +3089,8 @@ void EditorNode::_remove_edited_scene(bool p_change_tab) { ScriptEditor::get_singleton()->close_builtin_scripts_from_scene(editor_data.get_scene_path(old_index)); } - if (p_change_tab) _scene_tab_changed(new_index); + if (p_change_tab) + _scene_tab_changed(new_index); editor_data.remove_scene(old_index); editor_data.get_undo_redo().clear_history(false); _update_title(); @@ -5677,7 +5685,7 @@ EditorNode::EditorNode() { import_texture.instance(); ResourceFormatImporter::get_singleton()->add_importer(import_texture); - /* Ref<ResourceImporterLayeredTexture> import_cubemap; + Ref<ResourceImporterLayeredTexture> import_cubemap; import_cubemap.instance(); import_cubemap->set_mode(ResourceImporterLayeredTexture::MODE_CUBEMAP); ResourceFormatImporter::get_singleton()->add_importer(import_cubemap); @@ -5691,7 +5699,12 @@ EditorNode::EditorNode() { import_cubemap_array.instance(); import_cubemap_array->set_mode(ResourceImporterLayeredTexture::MODE_CUBEMAP_ARRAY); ResourceFormatImporter::get_singleton()->add_importer(import_cubemap_array); -*/ + + /*Ref<ResourceImporterLayeredTexture> import_3d; + import_3d.instance(); + import_3d->set_mode(ResourceImporterLayeredTexture::MODE_3D); + ResourceFormatImporter::get_singleton()->add_importer(import_3d);*/ + Ref<ResourceImporterImage> import_image; import_image.instance(); ResourceFormatImporter::get_singleton()->add_importer(import_image); @@ -6156,7 +6169,7 @@ EditorNode::EditorNode() { p->add_shortcut(ED_SHORTCUT("editor/redo", TTR("Redo"), KEY_MASK_CMD + KEY_MASK_SHIFT + KEY_Z), EDIT_REDO, true); p->add_separator(); - p->add_shortcut(ED_SHORTCUT("editor/revert_scene", TTR("Revert Scene")), EDIT_REVERT); + p->add_shortcut(ED_SHORTCUT("editor/reload_saved_scene", TTR("Reload Saved Scene")), EDIT_RELOAD_SAVED_SCENE); p->add_shortcut(ED_SHORTCUT("editor/close_scene", TTR("Close Scene"), KEY_MASK_CMD + KEY_MASK_SHIFT + KEY_W), FILE_CLOSE); recent_scenes = memnew(PopupMenu); @@ -6338,7 +6351,7 @@ EditorNode::EditorNode() { run_native = memnew(EditorRunNative); play_hb->add_child(run_native); - run_native->connect("native_run", callable_mp(this, &EditorNode::_menu_option), varray(RUN_PLAY_NATIVE)); + run_native->connect("native_run", callable_mp(this, &EditorNode::_run_native)); play_scene_button = memnew(ToolButton); play_hb->add_child(play_scene_button); @@ -6662,7 +6675,7 @@ EditorNode::EditorNode() { add_editor_plugin(memnew(SpriteFramesEditorPlugin(this))); add_editor_plugin(memnew(TextureRegionEditorPlugin(this))); add_editor_plugin(memnew(GIProbeEditorPlugin(this))); - //add_editor_plugin(memnew(BakedLightmapEditorPlugin(this))); + add_editor_plugin(memnew(BakedLightmapEditorPlugin(this))); add_editor_plugin(memnew(Path2DEditorPlugin(this))); add_editor_plugin(memnew(Path3DEditorPlugin(this))); add_editor_plugin(memnew(Line2DEditorPlugin(this))); @@ -6673,6 +6686,7 @@ EditorNode::EditorNode() { add_editor_plugin(memnew(CollisionShape2DEditorPlugin(this))); add_editor_plugin(memnew(CurveEditorPlugin(this))); add_editor_plugin(memnew(TextureEditorPlugin(this))); + add_editor_plugin(memnew(TextureLayeredEditorPlugin(this))); add_editor_plugin(memnew(AudioStreamEditorPlugin(this))); add_editor_plugin(memnew(AudioBusesEditorPlugin(audio_bus_editor))); add_editor_plugin(memnew(Skeleton3DEditorPlugin(this))); diff --git a/editor/editor_node.h b/editor/editor_node.h index c6f04b0749..b39a3bbfd0 100644 --- a/editor/editor_node.h +++ b/editor/editor_node.h @@ -32,6 +32,7 @@ #define EDITOR_NODE_H #include "editor/editor_data.h" +#include "editor/editor_export.h" #include "editor/editor_folding.h" #include "editor/editor_run.h" #include "editor/inspector_dock.h" @@ -152,7 +153,7 @@ private: FILE_EXTERNAL_OPEN_SCENE, EDIT_UNDO, EDIT_REDO, - EDIT_REVERT, + EDIT_RELOAD_SAVED_SCENE, TOOLS_ORPHAN_RESOURCES, TOOLS_CUSTOM, RESOURCE_SAVE, @@ -161,7 +162,6 @@ private: RUN_STOP, RUN_PLAY_SCENE, - RUN_PLAY_NATIVE, RUN_PLAY_CUSTOM_SCENE, RUN_SCENE_SETTINGS, RUN_SETTINGS, @@ -492,6 +492,7 @@ private: void _quick_run(); void _run(bool p_current = false, const String &p_custom = ""); + void _run_native(const Ref<EditorExportPreset> &p_preset); void _save_optimized(); void _import_action(const String &p_action); diff --git a/editor/editor_plugin.cpp b/editor/editor_plugin.cpp index 746ebc8292..8689cad45b 100644 --- a/editor/editor_plugin.cpp +++ b/editor/editor_plugin.cpp @@ -913,16 +913,6 @@ void EditorPlugin::_bind_methods() { BIND_ENUM_CONSTANT(DOCK_SLOT_MAX); } -EditorPlugin::EditorPlugin() : - undo_redo(nullptr), - input_event_forwarding_always_enabled(false), - force_draw_over_forwarding_enabled(false), - last_main_screen_name("") { -} - -EditorPlugin::~EditorPlugin() { -} - EditorPluginCreateFunc EditorPlugins::creation_funcs[MAX_CREATE_FUNCS]; int EditorPlugins::creation_func_count = 0; diff --git a/editor/editor_plugin.h b/editor/editor_plugin.h index 2ca96ceed2..a939d58752 100644 --- a/editor/editor_plugin.h +++ b/editor/editor_plugin.h @@ -113,12 +113,12 @@ class EditorPlugin : public Node { GDCLASS(EditorPlugin, Node); friend class EditorData; - UndoRedo *undo_redo; + UndoRedo *undo_redo = nullptr; UndoRedo *_get_undo_redo() { return undo_redo; } - bool input_event_forwarding_always_enabled; - bool force_draw_over_forwarding_enabled; + bool input_event_forwarding_always_enabled = false; + bool force_draw_over_forwarding_enabled = false; String last_main_screen_name; @@ -242,8 +242,8 @@ public: void enable_plugin(); void disable_plugin(); - EditorPlugin(); - virtual ~EditorPlugin(); + EditorPlugin() {} + virtual ~EditorPlugin() {} }; VARIANT_ENUM_CAST(EditorPlugin::CustomControlContainer); diff --git a/editor/editor_properties.cpp b/editor/editor_properties.cpp index 5213d7ec15..c5772e0ea7 100644 --- a/editor/editor_properties.cpp +++ b/editor/editor_properties.cpp @@ -3318,13 +3318,27 @@ bool EditorInspectorDefaultPlugin::parse_property(Object *p_object, Variant::Typ EditorPropertyMember::Type type = EditorPropertyMember::MEMBER_METHOD_OF_BASE_TYPE; switch (p_hint) { - case PROPERTY_HINT_METHOD_OF_BASE_TYPE: type = EditorPropertyMember::MEMBER_METHOD_OF_BASE_TYPE; break; - case PROPERTY_HINT_METHOD_OF_INSTANCE: type = EditorPropertyMember::MEMBER_METHOD_OF_INSTANCE; break; - case PROPERTY_HINT_METHOD_OF_SCRIPT: type = EditorPropertyMember::MEMBER_METHOD_OF_SCRIPT; break; - case PROPERTY_HINT_PROPERTY_OF_VARIANT_TYPE: type = EditorPropertyMember::MEMBER_PROPERTY_OF_VARIANT_TYPE; break; - case PROPERTY_HINT_PROPERTY_OF_BASE_TYPE: type = EditorPropertyMember::MEMBER_PROPERTY_OF_BASE_TYPE; break; - case PROPERTY_HINT_PROPERTY_OF_INSTANCE: type = EditorPropertyMember::MEMBER_PROPERTY_OF_INSTANCE; break; - case PROPERTY_HINT_PROPERTY_OF_SCRIPT: type = EditorPropertyMember::MEMBER_PROPERTY_OF_SCRIPT; break; + case PROPERTY_HINT_METHOD_OF_BASE_TYPE: + type = EditorPropertyMember::MEMBER_METHOD_OF_BASE_TYPE; + break; + case PROPERTY_HINT_METHOD_OF_INSTANCE: + type = EditorPropertyMember::MEMBER_METHOD_OF_INSTANCE; + break; + case PROPERTY_HINT_METHOD_OF_SCRIPT: + type = EditorPropertyMember::MEMBER_METHOD_OF_SCRIPT; + break; + case PROPERTY_HINT_PROPERTY_OF_VARIANT_TYPE: + type = EditorPropertyMember::MEMBER_PROPERTY_OF_VARIANT_TYPE; + break; + case PROPERTY_HINT_PROPERTY_OF_BASE_TYPE: + type = EditorPropertyMember::MEMBER_PROPERTY_OF_BASE_TYPE; + break; + case PROPERTY_HINT_PROPERTY_OF_INSTANCE: + type = EditorPropertyMember::MEMBER_PROPERTY_OF_INSTANCE; + break; + case PROPERTY_HINT_PROPERTY_OF_SCRIPT: + type = EditorPropertyMember::MEMBER_PROPERTY_OF_SCRIPT; + break; default: { } } diff --git a/editor/editor_run.cpp b/editor/editor_run.cpp index b4ddb7ebfa..01156d9809 100644 --- a/editor/editor_run.cpp +++ b/editor/editor_run.cpp @@ -53,7 +53,7 @@ Error EditorRun::run(const String &p_scene, const String &p_custom_args, const L } args.push_back("--remote-debug"); - args.push_back(remote_host + ":" + String::num(remote_port)); + args.push_back("tcp://" + remote_host + ":" + String::num(remote_port)); args.push_back("--allow_focus_steal_pid"); args.push_back(itos(OS::get_singleton()->get_process_id())); diff --git a/editor/editor_run_native.cpp b/editor/editor_run_native.cpp index 464666ac6a..d4450acea2 100644 --- a/editor/editor_run_native.cpp +++ b/editor/editor_run_native.cpp @@ -133,7 +133,7 @@ void EditorRunNative::_run_native(int p_idx, int p_platform) { return; } - emit_signal("native_run"); + emit_signal("native_run", preset); int flags = 0; @@ -160,7 +160,7 @@ void EditorRunNative::resume_run_native() { void EditorRunNative::_bind_methods() { - ADD_SIGNAL(MethodInfo("native_run")); + ADD_SIGNAL(MethodInfo("native_run", PropertyInfo(Variant::OBJECT, "preset", PROPERTY_HINT_RESOURCE_TYPE, "EditorExportPreset"))); } bool EditorRunNative::is_deploy_debug_remote_enabled() const { diff --git a/editor/editor_sectioned_inspector.cpp b/editor/editor_sectioned_inspector.cpp index bccc38ca1b..3bb7072eac 100644 --- a/editor/editor_sectioned_inspector.cpp +++ b/editor/editor_sectioned_inspector.cpp @@ -29,7 +29,9 @@ /*************************************************************************/ #include "editor_sectioned_inspector.h" + #include "editor_scale.h" + class SectionedInspectorFilter : public Object { GDCLASS(SectionedInspectorFilter, Object); @@ -307,8 +309,7 @@ EditorInspector *SectionedInspector::get_inspector() { SectionedInspector::SectionedInspector() : sections(memnew(Tree)), filter(memnew(SectionedInspectorFilter)), - inspector(memnew(EditorInspector)), - search_box(nullptr) { + inspector(memnew(EditorInspector)) { add_theme_constant_override("autohide", 1); // Fixes the dragger always showing up VBoxContainer *left_vb = memnew(VBoxContainer); diff --git a/editor/editor_sectioned_inspector.h b/editor/editor_sectioned_inspector.h index 7073b73751..f01e6e250e 100644 --- a/editor/editor_sectioned_inspector.h +++ b/editor/editor_sectioned_inspector.h @@ -48,7 +48,7 @@ class SectionedInspector : public HSplitContainer { Map<String, TreeItem *> section_map; EditorInspector *inspector; - LineEdit *search_box; + LineEdit *search_box = nullptr; String selected_category; diff --git a/editor/editor_settings.h b/editor/editor_settings.h index 29b89ef1a8..8f1a3c8333 100644 --- a/editor/editor_settings.h +++ b/editor/editor_settings.h @@ -64,27 +64,19 @@ public: private: struct VariantContainer { - int order; + int order = 0; Variant variant; Variant initial; - bool has_default_value; - bool hide_from_editor; - bool save; - bool restart_if_changed; - VariantContainer() : - order(0), - has_default_value(false), - hide_from_editor(false), - save(false), - restart_if_changed(false) { - } + bool has_default_value = false; + bool hide_from_editor = false; + bool save = false; + bool restart_if_changed = false; + + VariantContainer() {} + VariantContainer(const Variant &p_variant, int p_order) : order(p_order), - variant(p_variant), - has_default_value(false), - hide_from_editor(false), - save(false), - restart_if_changed(false) { + variant(p_variant) { } }; diff --git a/editor/export_template_manager.cpp b/editor/export_template_manager.cpp index f0ee5d451f..f52e340a26 100644 --- a/editor/export_template_manager.cpp +++ b/editor/export_template_manager.cpp @@ -504,18 +504,26 @@ void ExportTemplateManager::_notification(int p_what) { status = TTR("Disconnected"); errored = true; break; - case HTTPClient::STATUS_RESOLVING: status = TTR("Resolving"); break; + case HTTPClient::STATUS_RESOLVING: + status = TTR("Resolving"); + break; case HTTPClient::STATUS_CANT_RESOLVE: status = TTR("Can't Resolve"); errored = true; break; - case HTTPClient::STATUS_CONNECTING: status = TTR("Connecting..."); break; + case HTTPClient::STATUS_CONNECTING: + status = TTR("Connecting..."); + break; case HTTPClient::STATUS_CANT_CONNECT: status = TTR("Can't Connect"); errored = true; break; - case HTTPClient::STATUS_CONNECTED: status = TTR("Connected"); break; - case HTTPClient::STATUS_REQUESTING: status = TTR("Requesting..."); break; + case HTTPClient::STATUS_CONNECTED: + status = TTR("Connected"); + break; + case HTTPClient::STATUS_REQUESTING: + status = TTR("Requesting..."); + break; case HTTPClient::STATUS_BODY: status = TTR("Downloading"); if (download_templates->get_body_size() > 0) { diff --git a/editor/filesystem_dock.h b/editor/filesystem_dock.h index 6d2d8510d1..08c2124ae8 100644 --- a/editor/filesystem_dock.h +++ b/editor/filesystem_dock.h @@ -148,11 +148,9 @@ private: class FileOrFolder { public: String path; - bool is_file; + bool is_file = false; - FileOrFolder() : - path(""), - is_file(false) {} + FileOrFolder() {} FileOrFolder(const String &p_path, bool p_is_file) : path(p_path), is_file(p_is_file) {} diff --git a/editor/icons/ErrorWarning.svg b/editor/icons/ErrorWarning.svg new file mode 100644 index 0000000000..72b5037e50 --- /dev/null +++ b/editor/icons/ErrorWarning.svg @@ -0,0 +1 @@ +<svg height="8" viewBox="0 0 8 8" width="8" xmlns="http://www.w3.org/2000/svg"><path d="m4 0c-2.216 0-4 1.784-4 4s1.784 4 4 4z" fill="#ff5d5d"/><path d="m4 .00000003c2.216 0 4 1.78399997 4 3.99999997s-1.784 4-4 4z" fill="#ffdd65"/></svg>
\ No newline at end of file diff --git a/editor/icons/README.md b/editor/icons/README.md index 5f652e47ab..3159565180 100644 --- a/editor/icons/README.md +++ b/editor/icons/README.md @@ -1,12 +1,7 @@ -The icons here are optimized SVGs, because the editor renders the svgs at runtime, they need -to be small in size, so they can be efficiently parsed. +# Editor icons -The original icons can be found at: -https://github.com/godotengine/godot-design/tree/master/engine/icons +This folder contains all the icons used by Godot editor (except for platform +icons which are located in their respective platform folder). -There you can find the optimizer script. - -If you add a new icon, please make a pull request to this repo: -https://github.com/godotengine/godot-design/ - -and store the optimized SVG version here. +See [Editor icons](https://docs.godotengine.org/en/latest/development/editor/creating_icons.html) +in the documentation for details on creating icons for the Godot editor. diff --git a/editor/import/collada.h b/editor/import/collada.h index b74332fb22..187a8092da 100644 --- a/editor/import/collada.h +++ b/editor/import/collada.h @@ -61,26 +61,22 @@ public: struct Channel { - int uv_idx; + int uv_idx = 0; String texture; Color color; - Channel() { uv_idx = 0; } + Channel() {} }; Channel diffuse, specular, emission, bump; - float shininess; - bool found_double_sided; - bool double_sided; - bool unshaded; + float shininess = 40; + bool found_double_sided = false; + bool double_sided = true; + bool unshaded = false; String get_texture_path(const String &p_source, Collada &state) const; Effect() { diffuse.color = Color(1, 1, 1, 1); - double_sided = true; - found_double_sided = false; - shininess = 40; - unshaded = false; } }; @@ -91,31 +87,24 @@ public: MODE_ORTHOGONAL }; - Mode mode; + Mode mode = MODE_PERSPECTIVE; union { struct { - float x_fov; - float y_fov; + float x_fov = 0; + float y_fov = 0; } perspective; struct { - float x_mag; - float y_mag; + float x_mag = 0; + float y_mag = 0; } orthogonal; }; - float aspect; - float z_near; - float z_far; - - CameraData() : - mode(MODE_PERSPECTIVE), - aspect(1), - z_near(0.1), - z_far(100) { - perspective.x_fov = 0; - perspective.y_fov = 0; - } + float aspect = 1; + float z_near = 0.1; + float z_far = 100; + + CameraData() {} }; struct LightData { @@ -127,26 +116,18 @@ public: MODE_SPOT }; - Mode mode; + Mode mode = MODE_AMBIENT; - Color color; + Color color = Color(1, 1, 1, 1); - float constant_att; - float linear_att; - float quad_att; - - float spot_angle; - float spot_exp; - - LightData() : - mode(MODE_AMBIENT), - color(Color(1, 1, 1, 1)), - constant_att(0), - linear_att(0), - quad_att(0), - spot_angle(45), - spot_exp(1) { - } + float constant_att = 0; + float linear_att = 0; + float quad_att = 0; + + float spot_angle = 45; + float spot_exp = 1; + + LightData() {} }; struct MeshData { @@ -185,19 +166,16 @@ public: Vector<Primitives> primitives; - bool found_double_sided; - bool double_sided; + bool found_double_sided = false; + bool double_sided = true; - MeshData() { - found_double_sided = false; - double_sided = true; - } + MeshData() {} }; struct CurveData { String name; - bool closed; + bool closed = false; struct Source { @@ -210,15 +188,13 @@ public: Map<String, String> control_vertices; - CurveData() { - - closed = false; - } + CurveData() {} }; + struct SkinControllerData { String base; - bool use_idrefs; + bool use_idrefs = false; Transform bind_shape; @@ -226,10 +202,8 @@ public: Vector<String> sarray; //maybe for names Vector<float> array; - int stride; - Source() { - stride = 1; - } + int stride = 1; + Source() {} }; Map<String, Source> sources; @@ -256,7 +230,7 @@ public: Map<String, Transform> bone_rest_map; - SkinControllerData() { use_idrefs = false; } + SkinControllerData() {} }; struct MorphControllerData { @@ -266,10 +240,10 @@ public: struct Source { - int stride; + int stride = 1; Vector<String> sarray; //maybe for names Vector<float> array; - Source() { stride = 1; } + Source() {} }; Map<String, Source> sources; @@ -280,14 +254,14 @@ public: struct Vertex { - int idx; + int idx = 0; Vector3 vertex; Vector3 normal; Vector3 uv; Vector3 uv2; Plane tangent; Color color; - int uid; + int uid = 0; struct Weight { int bone_idx; float weight; @@ -350,11 +324,9 @@ public: return uid < p_vert.uid; } - Vertex() { - uid = 0; - idx = 0; - } + Vertex() {} }; + struct Node { enum Type { @@ -382,31 +354,26 @@ public: Vector<float> data; }; - Type type; + Type type = TYPE_NODE; String name; String id; String empty_draw_type; - bool noname; + bool noname = false; Vector<XForm> xform_list; Transform default_transform; Transform post_transform; Vector<Node *> children; - Node *parent; + Node *parent = nullptr; Transform compute_transform(Collada &state) const; Transform get_global_transform() const; Transform get_transform() const; - bool ignore_anim; + bool ignore_anim = false; - Node() { - noname = false; - type = TYPE_NODE; - parent = nullptr; - ignore_anim = false; - } + Node() {} virtual ~Node() { for (int i = 0; i < children.size(); i++) memdelete(children[i]); @@ -420,11 +387,10 @@ public: struct NodeJoint : public Node { - NodeSkeleton *owner; + NodeSkeleton *owner = nullptr; String sid; NodeJoint() { type = TYPE_JOINT; - owner = nullptr; } }; @@ -471,14 +437,11 @@ public: struct AnimationClip { String name; - float begin; - float end; + float begin = 0; + float end = 1; Vector<String> tracks; - AnimationClip() { - begin = 0; - end = 1; - } + AnimationClip() {} }; struct AnimationTrack { @@ -487,7 +450,7 @@ public: String target; String param; String component; - bool property; + bool property = false; enum InterpolationType { INTERP_LINEAR, @@ -505,16 +468,16 @@ public: Vector<float> data; Point2 in_tangent; Point2 out_tangent; - InterpolationType interp_type; + InterpolationType interp_type = INTERP_LINEAR; - Key() { interp_type = INTERP_LINEAR; } + Key() {} }; Vector<float> get_value_at_time(float p_time) const; Vector<Key> keys; - AnimationTrack() { property = false; } + AnimationTrack() {} }; /****************/ @@ -523,10 +486,10 @@ public: struct State { - int import_flags; + int import_flags = 0; - float unit_scale; - Vector3::Axis up_axis; + float unit_scale = 1.0; + Vector3::Axis up_axis = Vector3::AXIS_Y; bool z_up; struct Version { @@ -573,14 +536,9 @@ public: Map<String, Vector<int>> referenced_tracks; Map<String, Vector<int>> by_id_tracks; - float animation_length; + float animation_length = 0; - State() : - import_flags(0), - unit_scale(1.0), - up_axis(Vector3::AXIS_Y), - animation_length(0) { - } + State() {} } state; Error load(const String &p_path, int p_flags = 0); diff --git a/editor/import/editor_scene_importer_gltf.cpp b/editor/import/editor_scene_importer_gltf.cpp index 45e376a2aa..1a1e7171b9 100644 --- a/editor/import/editor_scene_importer_gltf.cpp +++ b/editor/import/editor_scene_importer_gltf.cpp @@ -538,12 +538,18 @@ Error EditorSceneImporterGLTF::_parse_accessors(GLTFState &state) { String EditorSceneImporterGLTF::_get_component_type_name(const uint32_t p_component) { switch (p_component) { - case COMPONENT_TYPE_BYTE: return "Byte"; - case COMPONENT_TYPE_UNSIGNED_BYTE: return "UByte"; - case COMPONENT_TYPE_SHORT: return "Short"; - case COMPONENT_TYPE_UNSIGNED_SHORT: return "UShort"; - case COMPONENT_TYPE_INT: return "Int"; - case COMPONENT_TYPE_FLOAT: return "Float"; + case COMPONENT_TYPE_BYTE: + return "Byte"; + case COMPONENT_TYPE_UNSIGNED_BYTE: + return "UByte"; + case COMPONENT_TYPE_SHORT: + return "Short"; + case COMPONENT_TYPE_UNSIGNED_SHORT: + return "UShort"; + case COMPONENT_TYPE_INT: + return "Int"; + case COMPONENT_TYPE_FLOAT: + return "Float"; } return "<Error>"; @@ -655,12 +661,24 @@ Error EditorSceneImporterGLTF::_decode_buffer_view(GLTFState &state, double *dst int EditorSceneImporterGLTF::_get_component_type_size(const int component_type) { switch (component_type) { - case COMPONENT_TYPE_BYTE: return 1; break; - case COMPONENT_TYPE_UNSIGNED_BYTE: return 1; break; - case COMPONENT_TYPE_SHORT: return 2; break; - case COMPONENT_TYPE_UNSIGNED_SHORT: return 2; break; - case COMPONENT_TYPE_INT: return 4; break; - case COMPONENT_TYPE_FLOAT: return 4; break; + case COMPONENT_TYPE_BYTE: + return 1; + break; + case COMPONENT_TYPE_UNSIGNED_BYTE: + return 1; + break; + case COMPONENT_TYPE_SHORT: + return 2; + break; + case COMPONENT_TYPE_UNSIGNED_SHORT: + return 2; + break; + case COMPONENT_TYPE_INT: + return 4; + break; + case COMPONENT_TYPE_FLOAT: + return 4; + break; default: { ERR_FAIL_V(0); } diff --git a/editor/import/editor_scene_importer_gltf.h b/editor/import/editor_scene_importer_gltf.h index d127a87782..86d7e627a3 100644 --- a/editor/import/editor_scene_importer_gltf.h +++ b/editor/import/editor_scene_importer_gltf.h @@ -94,90 +94,60 @@ class EditorSceneImporterGLTF : public EditorSceneImporter { struct GLTFNode { //matrices need to be transformed to this - GLTFNodeIndex parent; - int height; + GLTFNodeIndex parent = -1; + int height = -1; Transform xform; String name; - GLTFMeshIndex mesh; - GLTFCameraIndex camera; - GLTFSkinIndex skin; + GLTFMeshIndex mesh = -1; + GLTFCameraIndex camera = -1; + GLTFSkinIndex skin = -1; - GLTFSkeletonIndex skeleton; - bool joint; + GLTFSkeletonIndex skeleton = -1; + bool joint = false; Vector3 translation; Quat rotation; - Vector3 scale; + Vector3 scale = Vector3(1, 1, 1); Vector<int> children; - GLTFNodeIndex fake_joint_parent; - - GLTFNode() : - parent(-1), - height(-1), - mesh(-1), - camera(-1), - skin(-1), - skeleton(-1), - joint(false), - translation(0, 0, 0), - scale(Vector3(1, 1, 1)), - fake_joint_parent(-1) {} + GLTFNodeIndex fake_joint_parent = -1; + + GLTFNode() {} }; struct GLTFBufferView { - GLTFBufferIndex buffer; - int byte_offset; - int byte_length; - int byte_stride; - bool indices; + GLTFBufferIndex buffer = -1; + int byte_offset = 0; + int byte_length = 0; + int byte_stride = 0; + bool indices = false; //matrices need to be transformed to this - GLTFBufferView() : - buffer(-1), - byte_offset(0), - byte_length(0), - byte_stride(0), - indices(false) { - } + GLTFBufferView() {} }; struct GLTFAccessor { - GLTFBufferViewIndex buffer_view; - int byte_offset; - int component_type; - bool normalized; - int count; + GLTFBufferViewIndex buffer_view = 0; + int byte_offset = 0; + int component_type = 0; + bool normalized = false; + int count = 0; GLTFType type; - float min; - float max; - int sparse_count; - int sparse_indices_buffer_view; - int sparse_indices_byte_offset; - int sparse_indices_component_type; - int sparse_values_buffer_view; - int sparse_values_byte_offset; - - GLTFAccessor() { - buffer_view = 0; - byte_offset = 0; - component_type = 0; - normalized = false; - count = 0; - min = 0; - max = 0; - sparse_count = 0; - sparse_indices_buffer_view = 0; - sparse_indices_byte_offset = 0; - sparse_indices_component_type = 0; - sparse_values_buffer_view = 0; - sparse_values_byte_offset = 0; - } + float min = 0; + float max = 0; + int sparse_count = 0; + int sparse_indices_buffer_view = 0; + int sparse_indices_byte_offset = 0; + int sparse_indices_component_type = 0; + int sparse_values_buffer_view = 0; + int sparse_values_byte_offset = 0; + + GLTFAccessor() {} }; struct GLTFTexture { GLTFImageIndex src_image; @@ -192,21 +162,19 @@ class EditorSceneImporterGLTF : public EditorSceneImporter { Vector<GLTFNodeIndex> roots; // The created Skeleton for the scene - Skeleton3D *godot_skeleton; + Skeleton3D *godot_skeleton = nullptr; // Set of unique bone names for the skeleton Set<String> unique_names; - GLTFSkeleton() : - godot_skeleton(nullptr) { - } + GLTFSkeleton() {} }; struct GLTFSkin { String name; // The "skeleton" property defined in the gltf spec. -1 = Scene Root - GLTFNodeIndex skin_root; + GLTFNodeIndex skin_root = -1; Vector<GLTFNodeIndex> joints_original; Vector<Transform> inverse_binds; @@ -226,7 +194,7 @@ class EditorSceneImporterGLTF : public EditorSceneImporter { Vector<GLTFNodeIndex> roots; // The GLTF Skeleton this Skin points to (after we determine skeletons) - GLTFSkeletonIndex skeleton; + GLTFSkeletonIndex skeleton = -1; // A mapping from the joint indices (in the order of joints_original) to the // Godot Skeleton's bone_indices @@ -237,9 +205,7 @@ class EditorSceneImporterGLTF : public EditorSceneImporter { // to the generated skeleton for the mesh instances. Ref<Skin> godot_skin; - GLTFSkin() : - skin_root(-1), - skeleton(-1) {} + GLTFSkin() {} }; struct GLTFMesh { @@ -249,17 +215,12 @@ class EditorSceneImporterGLTF : public EditorSceneImporter { struct GLTFCamera { - bool perspective; - float fov_size; - float zfar; - float znear; + bool perspective = true; + float fov_size = 64; + float zfar = 500; + float znear = 0.1; - GLTFCamera() { - perspective = true; - fov_size = 65; - zfar = 500; - znear = 0.1; - } + GLTFCamera() {} }; struct GLTFAnimation { diff --git a/editor/import/resource_importer_csv_translation.cpp b/editor/import/resource_importer_csv_translation.cpp index 1d7bed3975..9d72396449 100644 --- a/editor/import/resource_importer_csv_translation.cpp +++ b/editor/import/resource_importer_csv_translation.cpp @@ -83,9 +83,15 @@ Error ResourceImporterCSVTranslation::import(const String &p_source_file, const String delimiter; switch ((int)p_options["delimiter"]) { - case 0: delimiter = ","; break; - case 1: delimiter = ";"; break; - case 2: delimiter = "\t"; break; + case 0: + delimiter = ","; + break; + case 1: + delimiter = ";"; + break; + case 2: + delimiter = "\t"; + break; } FileAccessRef f = FileAccess::open(p_source_file, FileAccess::READ); diff --git a/editor/import/resource_importer_layered_texture.cpp b/editor/import/resource_importer_layered_texture.cpp index a4cbc81b26..c46cf4c1a8 100644 --- a/editor/import/resource_importer_layered_texture.cpp +++ b/editor/import/resource_importer_layered_texture.cpp @@ -36,9 +36,9 @@ #include "core/io/image_loader.h" #include "editor/editor_file_system.h" #include "editor/editor_node.h" +#include "resource_importer_texture.h" #include "scene/resources/texture.h" -#if 0 String ResourceImporterLayeredTexture::get_importer_name() const { switch (mode) { @@ -51,6 +51,9 @@ String ResourceImporterLayeredTexture::get_importer_name() const { case MODE_CUBEMAP_ARRAY: { return "cubemap_array_texture"; } break; + case MODE_3D: { + return "cubemap_3d_texture"; + } break; } ERR_FAIL_V(""); @@ -68,6 +71,9 @@ String ResourceImporterLayeredTexture::get_visible_name() const { case MODE_CUBEMAP_ARRAY: { return "CubemapArray"; } break; + case MODE_3D: { + return "3D"; + } break; } ERR_FAIL_V(""); @@ -79,13 +85,16 @@ void ResourceImporterLayeredTexture::get_recognized_extensions(List<String> *p_e String ResourceImporterLayeredTexture::get_save_extension() const { switch (mode) { case MODE_CUBEMAP: { - return "cube"; + return "scube"; } break; case MODE_2D_ARRAY: { - return "tex2darr"; + return "stexarray"; } break; case MODE_CUBEMAP_ARRAY: { - return "cubearr"; + return "scubearray"; + } break; + case MODE_3D: { + return "stex3d"; } break; } @@ -96,13 +105,16 @@ String ResourceImporterLayeredTexture::get_resource_type() const { switch (mode) { case MODE_CUBEMAP: { - return "Cubemap"; + return "StreamCubemap"; } break; case MODE_2D_ARRAY: { - return "Texture2DArray"; + return "StreamTexture2DArray"; } break; case MODE_CUBEMAP_ARRAY: { - return "CubemapArray"; + return "StreamCubemapArray"; + } break; + case MODE_3D: { + return "StreamTexture3D"; } break; } ERR_FAIL_V(String()); @@ -110,6 +122,9 @@ String ResourceImporterLayeredTexture::get_resource_type() const { bool ResourceImporterLayeredTexture::get_option_visibility(const String &p_option, const Map<StringName, Variant> &p_options) const { + if (p_option == "compress/lossy_quality" && p_options.has("compress/mode")) { + return int(p_options["compress/mode"]) == COMPRESS_LOSSY; + } return true; } @@ -123,138 +138,109 @@ String ResourceImporterLayeredTexture::get_preset_name(int p_idx) const { void ResourceImporterLayeredTexture::get_import_options(List<ImportOption> *r_options, int p_preset) const { - r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/mode", PROPERTY_HINT_ENUM, "Lossless,Video RAM,Uncompressed", PROPERTY_USAGE_DEFAULT | PROPERTY_USAGE_UPDATE_ALL_IF_MODIFIED), 1)); - r_options->push_back(ImportOption(PropertyInfo(Variant::BOOL, "compress/no_bptc_if_rgb"), false)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/mode", PROPERTY_HINT_ENUM, "Lossless,Lossy,Video RAM,Uncompressed,Basis Universal", PROPERTY_USAGE_DEFAULT | PROPERTY_USAGE_UPDATE_ALL_IF_MODIFIED), 1)); + r_options->push_back(ImportOption(PropertyInfo(Variant::FLOAT, "compress/lossy_quality", PROPERTY_HINT_RANGE, "0,1,0.01"), 0.7)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/hdr_compression", PROPERTY_HINT_ENUM, "Disabled,Opaque Only,Always"), 1)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/bptc_ldr", PROPERTY_HINT_ENUM, "Disabled,Enabled,RGBA Only"), 0)); r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/channel_pack", PROPERTY_HINT_ENUM, "sRGB Friendly,Optimized"), 0)); - r_options->push_back(ImportOption(PropertyInfo(Variant::BOOL, "flags/mipmaps"), true)); - if (mode == MODE_2D_ARRAY) { + r_options->push_back(ImportOption(PropertyInfo(Variant::BOOL, "mipmaps/generate"), true)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "mipmaps/limit", PROPERTY_HINT_RANGE, "-1,256"), -1)); + + if (mode == MODE_2D_ARRAY || mode == MODE_3D) { r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "slices/horizontal", PROPERTY_HINT_RANGE, "1,256,1"), 8)); - } - if (mode == MODE_2D_ARRAY || mode == MODE_CUBEMAP_ARRAY) { r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "slices/vertical", PROPERTY_HINT_RANGE, "1,256,1"), 8)); } -} - -void ResourceImporterLayeredTexture::_save_tex(const Vector<Ref<Image> > &p_images, const String &p_to_path, int p_compress_mode, Image::CompressMode p_vram_compression, bool p_mipmaps) { - - FileAccess *f = FileAccess::open(p_to_path, FileAccess::WRITE); - f->store_8('G'); - f->store_8('D'); - switch (mode) { - case MODE_2D_ARRAY: f->store_8('A'); break; - case MODE_CUBEMAP: f->store_8('C'); break; - case MODE_CUBEMAP_ARRAY: f->store_8('X'); break; - } - - f->store_8('T'); //godot streamable texture - - f->store_32(p_images[0]->get_width()); - f->store_32(p_images[0]->get_height()); - f->store_32(p_images.size()); //depth - uint32_t flags = 0; - if (p_mipmaps) { - flags |= TEXTURE_FLAGS_MIPMAPS; - } - f->store_32(flags); - if (p_compress_mode != COMPRESS_VIDEO_RAM) { - //vram needs to do a first compression to tell what the format is, for the rest its ok - f->store_32(p_images[0]->get_format()); - f->store_32(p_compress_mode); // 0 - lossless (PNG), 1 - vram, 2 - uncompressed + if (mode == MODE_CUBEMAP || mode == MODE_CUBEMAP_ARRAY) { + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "slices/arrangement", PROPERTY_HINT_ENUM, "1x6,2x3,3x2,6x1"), 1)); + if (mode == MODE_CUBEMAP_ARRAY) { + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "slices/layout", PROPERTY_HINT_ENUM, "Horizontal,Vertical"), 1)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "slices/amount", PROPERTY_HINT_RANGE, "1,1024,1,or_greater"), 1)); + } } +} - if ((p_compress_mode == COMPRESS_LOSSLESS) && p_images[0]->get_format() > Image::FORMAT_RGBA8) { - p_compress_mode = COMPRESS_UNCOMPRESSED; //these can't go as lossy - } +void ResourceImporterLayeredTexture::_save_tex(Vector<Ref<Image>> p_images, const String &p_to_path, int p_compress_mode, float p_lossy, Image::CompressMode p_vram_compression, Image::CompressSource p_csource, Image::UsedChannels used_channels, bool p_mipmaps, bool p_force_po2) { for (int i = 0; i < p_images.size(); i++) { - switch (p_compress_mode) { - case COMPRESS_LOSSLESS: { - - Ref<Image> image = p_images[i]->duplicate(); - if (p_mipmaps) { - image->generate_mipmaps(); - } else { - image->clear_mipmaps(); - } - - int mmc = image->get_mipmap_count() + 1; - f->store_32(mmc); - - for (int j = 0; j < mmc; j++) { - - if (j > 0) { - image->shrink_x2(); - } - - Vector<uint8_t> data = Image::lossless_packer(image); - int data_len = data.size(); - f->store_32(data_len); - - const uint8_t* r = data.ptr(); - f->store_buffer(r.ptr(), data_len); - } - - } break; - case COMPRESS_VIDEO_RAM: { - - Ref<Image> image = p_images[i]->duplicate(); - image->generate_mipmaps(false); - - Image::CompressSource csource = Image::COMPRESS_SOURCE_LAYERED; - image->compress(p_vram_compression, csource, 0.7); - - if (i == 0) { - //hack so we can properly tell the format - f->store_32(image->get_format()); - f->store_32(p_compress_mode); // 0 - lossless (PNG), 1 - vram, 2 - uncompressed - } - - Vector<uint8_t> data = image->get_data(); - int dl = data.size(); - - const uint8_t* r = data.ptr(); - f->store_buffer(r.ptr(), dl); - } break; - case COMPRESS_UNCOMPRESSED: { - - Ref<Image> image = p_images[i]->duplicate(); + if (p_force_po2) { + p_images.write[i]->resize_to_po2(); + } - if (p_mipmaps) { - image->generate_mipmaps(); - } else { - image->clear_mipmaps(); - } + if (p_mipmaps) { + p_images.write[i]->generate_mipmaps(); + } else { + p_images.write[i]->clear_mipmaps(); + } + } - Vector<uint8_t> data = image->get_data(); - int dl = data.size(); + FileAccessRef f = FileAccess::open(p_to_path, FileAccess::WRITE); + f->store_8('G'); + f->store_8('S'); + f->store_8('T'); + f->store_8('L'); - const uint8_t* r = data.ptr(); + f->store_32(StreamTextureLayered::FORMAT_VERSION); + f->store_32(p_images.size()); + f->store_32(mode); + f->store_32(0); //dataformat + f->store_32(0); //mipmap limit - f->store_buffer(r.ptr(), dl); + //reserverd + f->store_32(0); + f->store_32(0); + f->store_32(0); - } break; - } + for (int i = 0; i < p_images.size(); i++) { + ResourceImporterTexture::save_to_stex_format(f, p_images[i], ResourceImporterTexture::CompressMode(p_compress_mode), used_channels, p_vram_compression, p_lossy); } - memdelete(f); + f->close(); } Error ResourceImporterLayeredTexture::import(const String &p_source_file, const String &p_save_path, const Map<StringName, Variant> &p_options, List<String> *r_platform_variants, List<String> *r_gen_files, Variant *r_metadata) { int compress_mode = p_options["compress/mode"]; - int no_bptc_if_rgb = p_options["compress/no_bptc_if_rgb"]; - bool mipmaps = p_options["flags/mipmaps"]; + float lossy = p_options["compress/lossy_quality"]; + int hdr_compression = p_options["compress/hdr_compression"]; + int bptc_ldr = p_options["compress/bptc_ldr"]; + bool mipmaps = p_options["mipmaps/generate"]; + //bool mipmap_limit = p_options["mipmaps/limit"]; + int channel_pack = p_options["compress/channel_pack"]; int hslices = (p_options.has("slices/horizontal")) ? int(p_options["slices/horizontal"]) : 0; int vslices = (p_options.has("slices/vertical")) ? int(p_options["slices/vertical"]) : 0; + int arrangement = (p_options.has("slices/arrangement")) ? int(p_options["slices/arrangement"]) : 0; + int layout = (p_options.has("slices/layout")) ? int(p_options["slices/layout"]) : 0; + int amount = (p_options.has("slices/amount")) ? int(p_options["slices/amount"]) : 0; + + if (mode == MODE_CUBEMAP || mode == MODE_CUBEMAP_ARRAY) { + switch (arrangement) { + case CUBEMAP_FORMAT_1X6: { + hslices = 1; + vslices = 6; + } break; + case CUBEMAP_FORMAT_2X3: { + hslices = 2; + vslices = 3; + } break; + case CUBEMAP_FORMAT_3X2: { + hslices = 3; + vslices = 2; + } break; + case CUBEMAP_FORMAT_6X1: { + hslices = 6; + vslices = 1; + } break; + } - if (mode == MODE_CUBEMAP) { - hslices = 3; - vslices = 2; - } else if (mode == MODE_CUBEMAP_ARRAY) { - hslices = 3; - vslices *= 2; //put cubemaps vertically + if (mode == MODE_CUBEMAP_ARRAY) { + if (layout == 0) { + hslices *= amount; + } else { + vslices *= amount; + } + } } Ref<Image> image; @@ -263,28 +249,40 @@ Error ResourceImporterLayeredTexture::import(const String &p_source_file, const if (err != OK) return err; - if (compress_mode == COMPRESS_VIDEO_RAM) { - mipmaps = true; + if (compress_mode == COMPRESS_BASIS_UNIVERSAL && image->get_format() >= Image::FORMAT_RF) { + //basis universal does not support float formats, fall back + compress_mode = COMPRESS_VRAM_COMPRESSED; } - Vector<Ref<Image> > slices; - - int slice_w = image->get_width() / hslices; - int slice_h = image->get_height() / vslices; + if (compress_mode == COMPRESS_VRAM_COMPRESSED) { + mipmaps = true; + } //optimize - if (compress_mode == COMPRESS_VIDEO_RAM) { + if (compress_mode == COMPRESS_VRAM_COMPRESSED) { //if using video ram, optimize if (channel_pack == 0) { //remove alpha if not needed, so compression is more efficient if (image->get_format() == Image::FORMAT_RGBA8 && !image->detect_alpha()) { image->convert(Image::FORMAT_RGB8); } - } else { + } else if (image->get_format() < Image::FORMAT_RGBA8) { image->optimize_channels(); } } + Image::CompressSource csource = Image::COMPRESS_SOURCE_GENERIC; + if (channel_pack == 0) { + csource = Image::COMPRESS_SOURCE_SRGB; + } + + Image::UsedChannels used_channels = image->detect_used_channels(csource); + + Vector<Ref<Image>> slices; + + int slice_w = image->get_width() / hslices; + int slice_h = image->get_height() / vslices; + for (int i = 0; i < vslices; i++) { for (int j = 0; j < hslices; j++) { int x = slice_w * j; @@ -301,58 +299,82 @@ Error ResourceImporterLayeredTexture::import(const String &p_source_file, const String extension = get_save_extension(); Array formats_imported; - if (compress_mode == COMPRESS_VIDEO_RAM) { + if (compress_mode == COMPRESS_VRAM_COMPRESSED) { //must import in all formats, in order of priority (so platform choses the best supported one. IE, etc2 over etc). //Android, GLES 2.x bool ok_on_pc = false; - bool encode_bptc = false; + bool is_hdr = (image->get_format() >= Image::FORMAT_RF && image->get_format() <= Image::FORMAT_RGBE9995); + bool is_ldr = (image->get_format() >= Image::FORMAT_L8 && image->get_format() <= Image::FORMAT_RGB565); + bool can_bptc = ProjectSettings::get_singleton()->get("rendering/vram_compression/import_bptc"); + bool can_s3tc = ProjectSettings::get_singleton()->get("rendering/vram_compression/import_s3tc"); - if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_bptc")) { + if (can_bptc) { + formats_imported.push_back("bptc"); //needs to be aded anyway + } + bool can_compress_hdr = hdr_compression > 0; + + if (is_hdr && can_compress_hdr) { + + if (used_channels == Image::USED_CHANNELS_LA || used_channels == Image::USED_CHANNELS_RGBA) { + //can compress hdr, but hdr with alpha is not compressible - encode_bptc = true; + if (hdr_compression == 2) { + //but user selected to compress hdr anyway, so force an alpha-less format. + if (image->get_format() == Image::FORMAT_RGBAF) { + for (int i = 0; i < slices.size(); i++) { + slices.write[i]->convert(Image::FORMAT_RGBF); + } - if (no_bptc_if_rgb) { - Image::UsedChannels channels = image->detect_used_channels(); - if (channels != Image::USED_CHANNELS_LA && channels != Image::USED_CHANNELS_RGBA) { - encode_bptc = false; + } else if (image->get_format() == Image::FORMAT_RGBAH) { + for (int i = 0; i < slices.size(); i++) { + slices.write[i]->convert(Image::FORMAT_RGBH); + } + } + } else { + can_compress_hdr = false; } } - formats_imported.push_back("bptc"); - } + if (can_compress_hdr) { - if (encode_bptc) { + if (!can_bptc) { - _save_tex(slices, p_save_path + ".bptc." + extension, compress_mode, Image::COMPRESS_BPTC, mipmaps); - r_platform_variants->push_back("bptc"); - ok_on_pc = true; + //default to rgbe + if (image->get_format() != Image::FORMAT_RGBE9995) { + for (int i = 0; i < slices.size(); i++) { + slices.write[i]->convert(Image::FORMAT_RGBE9995); + } + } + } + } else { + can_bptc = false; + } } - if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_s3tc")) { + if (is_ldr && can_bptc) { + if (bptc_ldr == 0 || (bptc_ldr == 1 && !(used_channels == Image::USED_CHANNELS_LA || used_channels == Image::USED_CHANNELS_RGBA))) { + can_bptc = false; + } + } - _save_tex(slices, p_save_path + ".s3tc." + extension, compress_mode, Image::COMPRESS_S3TC, mipmaps); + if (can_bptc || can_s3tc) { + _save_tex(slices, p_save_path + ".s3tc." + extension, compress_mode, lossy, can_bptc ? Image::COMPRESS_BPTC : Image::COMPRESS_S3TC, csource, used_channels, mipmaps, false); r_platform_variants->push_back("s3tc"); - ok_on_pc = true; formats_imported.push_back("s3tc"); + ok_on_pc = true; } if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_etc2")) { - _save_tex(slices, p_save_path + ".etc2." + extension, compress_mode, Image::COMPRESS_ETC2, mipmaps); + _save_tex(slices, p_save_path + ".etc2." + extension, compress_mode, lossy, Image::COMPRESS_ETC2, csource, used_channels, mipmaps, true); r_platform_variants->push_back("etc2"); formats_imported.push_back("etc2"); } - if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_etc")) { - _save_tex(slices, p_save_path + ".etc." + extension, compress_mode, Image::COMPRESS_ETC, mipmaps); - r_platform_variants->push_back("etc"); - formats_imported.push_back("etc"); - } - if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_pvrtc")) { - _save_tex(slices, p_save_path + ".pvrtc." + extension, compress_mode, Image::COMPRESS_PVRTC4, mipmaps); + _save_tex(slices, p_save_path + ".etc2." + extension, compress_mode, lossy, Image::COMPRESS_ETC2, csource, used_channels, mipmaps, true); r_platform_variants->push_back("pvrtc"); formats_imported.push_back("pvrtc"); } @@ -362,12 +384,12 @@ Error ResourceImporterLayeredTexture::import(const String &p_source_file, const } } else { //import normally - _save_tex(slices, p_save_path + "." + extension, compress_mode, Image::COMPRESS_S3TC /*this is ignored */, mipmaps); + _save_tex(slices, p_save_path + "." + extension, compress_mode, lossy, Image::COMPRESS_S3TC /* IGNORED */, csource, used_channels, mipmaps, false); } if (r_metadata) { Dictionary metadata; - metadata["vram_texture"] = compress_mode == COMPRESS_VIDEO_RAM; + metadata["vram_texture"] = compress_mode == COMPRESS_VRAM_COMPRESSED; if (formats_imported.size()) { metadata["imported_formats"] = formats_imported; } @@ -448,4 +470,3 @@ ResourceImporterLayeredTexture::ResourceImporterLayeredTexture() { ResourceImporterLayeredTexture::~ResourceImporterLayeredTexture() { } -#endif diff --git a/editor/import/resource_importer_layered_texture.h b/editor/import/resource_importer_layered_texture.h index 40e5c9023e..18eaf31f6b 100644 --- a/editor/import/resource_importer_layered_texture.h +++ b/editor/import/resource_importer_layered_texture.h @@ -28,7 +28,6 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#if 0 /*************************************************************************/ /* resource_importer_layered_texture.h */ /*************************************************************************/ @@ -65,16 +64,24 @@ #include "core/image.h" #include "core/io/resource_importer.h" - -class StreamTexture; +class StreamTexture2D; class ResourceImporterLayeredTexture : public ResourceImporter { GDCLASS(ResourceImporterLayeredTexture, ResourceImporter); + public: enum Mode { - MODE_CUBEMAP, MODE_2D_ARRAY, - MODE_CUBEMAP_ARRAY + MODE_CUBEMAP, + MODE_CUBEMAP_ARRAY, + MODE_3D, + }; + + enum CubemapFormat { + CUBEMAP_FORMAT_1X6, + CUBEMAP_FORMAT_2X3, + CUBEMAP_FORMAT_3X2, + CUBEMAP_FORMAT_6X1, }; enum TextureFlags { @@ -86,9 +93,9 @@ private: static const char *compression_formats[]; protected: - static void _texture_reimport_srgb(const Ref<StreamTexture> &p_tex); - static void _texture_reimport_3d(const Ref<StreamTexture> &p_tex); - static void _texture_reimport_normal(const Ref<StreamTexture> &p_tex); + static void _texture_reimport_srgb(const Ref<StreamTexture2D> &p_tex); + static void _texture_reimport_3d(const Ref<StreamTexture2D> &p_tex); + static void _texture_reimport_normal(const Ref<StreamTexture2D> &p_tex); static ResourceImporterLayeredTexture *singleton; @@ -102,8 +109,10 @@ public: enum CompressMode { COMPRESS_LOSSLESS, - COMPRESS_VIDEO_RAM, - COMPRESS_UNCOMPRESSED + COMPRESS_LOSSY, + COMPRESS_VRAM_COMPRESSED, + COMPRESS_VRAM_UNCOMPRESSED, + COMPRESS_BASIS_UNIVERSAL }; virtual int get_preset_count() const; @@ -112,7 +121,7 @@ public: virtual void get_import_options(List<ImportOption> *r_options, int p_preset = 0) const; virtual bool get_option_visibility(const String &p_option, const Map<StringName, Variant> &p_options) const; - void _save_tex(const Vector<Ref<Image> > &p_images, const String &p_to_path, int p_compress_mode, Image::CompressMode p_vram_compression, bool p_mipmaps); + void _save_tex(Vector<Ref<Image>> p_images, const String &p_to_path, int p_compress_mode, float p_lossy, Image::CompressMode p_vram_compression, Image::CompressSource p_csource, Image::UsedChannels used_channels, bool p_mipmaps, bool p_force_po2); virtual Error import(const String &p_source_file, const String &p_save_path, const Map<StringName, Variant> &p_options, List<String> *r_platform_variants, List<String> *r_gen_files = nullptr, Variant *r_metadata = nullptr); @@ -126,6 +135,5 @@ public: ResourceImporterLayeredTexture(); ~ResourceImporterLayeredTexture(); }; -#endif // RESOURCE_IMPORTER_LAYERED_TEXTURE_H -#endif +#endif // RESOURCE_IMPORTER_LAYERED_TEXTURE_H diff --git a/editor/import/resource_importer_scene.cpp b/editor/import/resource_importer_scene.cpp index 239fae2268..37941109a1 100644 --- a/editor/import/resource_importer_scene.cpp +++ b/editor/import/resource_importer_scene.cpp @@ -221,16 +221,26 @@ int ResourceImporterScene::get_preset_count() const { String ResourceImporterScene::get_preset_name(int p_idx) const { switch (p_idx) { - case PRESET_SINGLE_SCENE: return TTR("Import as Single Scene"); - case PRESET_SEPARATE_ANIMATIONS: return TTR("Import with Separate Animations"); - case PRESET_SEPARATE_MATERIALS: return TTR("Import with Separate Materials"); - case PRESET_SEPARATE_MESHES: return TTR("Import with Separate Objects"); - case PRESET_SEPARATE_MESHES_AND_MATERIALS: return TTR("Import with Separate Objects+Materials"); - case PRESET_SEPARATE_MESHES_AND_ANIMATIONS: return TTR("Import with Separate Objects+Animations"); - case PRESET_SEPARATE_MATERIALS_AND_ANIMATIONS: return TTR("Import with Separate Materials+Animations"); - case PRESET_SEPARATE_MESHES_MATERIALS_AND_ANIMATIONS: return TTR("Import with Separate Objects+Materials+Animations"); - case PRESET_MULTIPLE_SCENES: return TTR("Import as Multiple Scenes"); - case PRESET_MULTIPLE_SCENES_AND_MATERIALS: return TTR("Import as Multiple Scenes+Materials"); + case PRESET_SINGLE_SCENE: + return TTR("Import as Single Scene"); + case PRESET_SEPARATE_ANIMATIONS: + return TTR("Import with Separate Animations"); + case PRESET_SEPARATE_MATERIALS: + return TTR("Import with Separate Materials"); + case PRESET_SEPARATE_MESHES: + return TTR("Import with Separate Objects"); + case PRESET_SEPARATE_MESHES_AND_MATERIALS: + return TTR("Import with Separate Objects+Materials"); + case PRESET_SEPARATE_MESHES_AND_ANIMATIONS: + return TTR("Import with Separate Objects+Animations"); + case PRESET_SEPARATE_MATERIALS_AND_ANIMATIONS: + return TTR("Import with Separate Materials+Animations"); + case PRESET_SEPARATE_MESHES_MATERIALS_AND_ANIMATIONS: + return TTR("Import with Separate Objects+Materials+Animations"); + case PRESET_MULTIPLE_SCENES: + return TTR("Import as Multiple Scenes"); + case PRESET_MULTIPLE_SCENES_AND_MATERIALS: + return TTR("Import as Multiple Scenes+Materials"); } return ""; @@ -344,7 +354,7 @@ Node *ResourceImporterScene::_fix_node(Node *p_node, Node *p_root, Map<Ref<Mesh> if (p_light_bake_mode != LIGHT_BAKE_DISABLED) { - mi->set_flag(GeometryInstance3D::FLAG_USE_BAKED_LIGHT, true); + mi->set_gi_mode(GeometryInstance3D::GI_MODE_BAKED); } } @@ -945,7 +955,7 @@ void ResourceImporterScene::_find_meshes(Node *p_node, Map<Ref<ArrayMesh>, Trans Transform transform; while (s) { transform = transform * s->get_transform(); - s = s->get_parent_spatial(); + s = Object::cast_to<Node3D>(s->get_parent()); } meshes[mesh] = transform; @@ -1348,8 +1358,9 @@ Error ResourceImporterScene::import(const String &p_source_file, const String &p scene->set_script(Variant(root_script)); } + float root_scale = 1.0; if (Object::cast_to<Node3D>(scene)) { - float root_scale = p_options["nodes/root_scale"]; + root_scale = p_options["nodes/root_scale"]; Object::cast_to<Node3D>(scene)->scale(Vector3(root_scale, root_scale, root_scale)); } @@ -1571,7 +1582,9 @@ Error ResourceImporterScene::import(const String &p_source_file, const String &p post_import_script->init(base_path, p_source_file); scene = post_import_script->post_import(scene); if (!scene) { - EditorNode::add_io_error(TTR("Error running post-import script:") + " " + post_import_script_path); + EditorNode::add_io_error( + TTR("Error running post-import script:") + " " + post_import_script_path + "\n" + + TTR("Did you return a Node-derived object in the `post_import()` method?")); return err; } } diff --git a/editor/import/resource_importer_shader_file.cpp b/editor/import/resource_importer_shader_file.cpp index a2f178de12..f085341969 100644 --- a/editor/import/resource_importer_shader_file.cpp +++ b/editor/import/resource_importer_shader_file.cpp @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* resource_importer_shader_file.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #include "resource_importer_shader_file.h" #include "core/io/marshalls.h" @@ -73,7 +103,7 @@ Error ResourceImporterShaderFile::import(const String &p_source_file, const Stri Ref<RDShaderFile> shader_file; shader_file.instance(); String base_path = p_source_file.get_base_dir(); - err = shader_file->parse_versions_from_text(file_txt, _include_function, &base_path); + err = shader_file->parse_versions_from_text(file_txt, "", _include_function, &base_path); if (err != OK) { if (!ShaderFileEditor::singleton->is_visible_in_tree()) { diff --git a/editor/import/resource_importer_shader_file.h b/editor/import/resource_importer_shader_file.h index f6b50bee9e..fa95ceecc1 100644 --- a/editor/import/resource_importer_shader_file.h +++ b/editor/import/resource_importer_shader_file.h @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* resource_importer_shader_file.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #ifndef RESOURCE_IMPORTER_SHADER_FILE_H #define RESOURCE_IMPORTER_SHADER_FILE_H diff --git a/editor/import/resource_importer_texture.cpp b/editor/import/resource_importer_texture.cpp index f8ed9304b6..111eab78b4 100644 --- a/editor/import/resource_importer_texture.cpp +++ b/editor/import/resource_importer_texture.cpp @@ -36,7 +36,7 @@ #include "editor/editor_file_system.h" #include "editor/editor_node.h" -void ResourceImporterTexture::_texture_reimport_roughness(const Ref<StreamTexture> &p_tex, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_channel) { +void ResourceImporterTexture::_texture_reimport_roughness(const Ref<StreamTexture2D> &p_tex, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_channel) { MutexLock lock(singleton->mutex); @@ -51,7 +51,7 @@ void ResourceImporterTexture::_texture_reimport_roughness(const Ref<StreamTextur singleton->make_flags[path].normal_path_for_roughness = p_normal_path; } -void ResourceImporterTexture::_texture_reimport_3d(const Ref<StreamTexture> &p_tex) { +void ResourceImporterTexture::_texture_reimport_3d(const Ref<StreamTexture2D> &p_tex) { MutexLock lock(singleton->mutex); @@ -64,7 +64,7 @@ void ResourceImporterTexture::_texture_reimport_3d(const Ref<StreamTexture> &p_t singleton->make_flags[path].flags |= MAKE_3D_FLAG; } -void ResourceImporterTexture::_texture_reimport_normal(const Ref<StreamTexture> &p_tex) { +void ResourceImporterTexture::_texture_reimport_normal(const Ref<StreamTexture2D> &p_tex) { MutexLock lock(singleton->mutex); @@ -157,7 +157,7 @@ String ResourceImporterTexture::get_save_extension() const { String ResourceImporterTexture::get_resource_type() const { - return "StreamTexture"; + return "StreamTexture2D"; } bool ResourceImporterTexture::get_option_visibility(const String &p_option, const Map<StringName, Variant> &p_options) const { @@ -207,8 +207,8 @@ void ResourceImporterTexture::get_import_options(List<ImportOption> *r_options, r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/mode", PROPERTY_HINT_ENUM, "Lossless,Lossy,VRAM Compressed,VRAM Uncompressed,Basis Universal", PROPERTY_USAGE_DEFAULT | PROPERTY_USAGE_UPDATE_ALL_IF_MODIFIED), p_preset == PRESET_3D ? 2 : 0)); r_options->push_back(ImportOption(PropertyInfo(Variant::FLOAT, "compress/lossy_quality", PROPERTY_HINT_RANGE, "0,1,0.01"), 0.7)); - r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/hdr_mode", PROPERTY_HINT_ENUM, "Enabled,Force RGBE"), 0)); - r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/bptc_ldr", PROPERTY_HINT_ENUM, "Enabled,RGBA Only"), 0)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/hdr_compression", PROPERTY_HINT_ENUM, "Disabled,Opaque Only,Always"), 1)); + r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/bptc_ldr", PROPERTY_HINT_ENUM, "Disabled,Enabled,RGBA Only"), 0)); r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/normal_map", PROPERTY_HINT_ENUM, "Detect,Enable,Disabled"), 0)); r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/channel_pack", PROPERTY_HINT_ENUM, "sRGB Friendly,Optimized"), 0)); r_options->push_back(ImportOption(PropertyInfo(Variant::INT, "compress/streamed"), false)); @@ -225,12 +225,12 @@ void ResourceImporterTexture::get_import_options(List<ImportOption> *r_options, r_options->push_back(ImportOption(PropertyInfo(Variant::FLOAT, "svg/scale", PROPERTY_HINT_RANGE, "0.001,100,0.001"), 1.0)); } -void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref<Image> &p_image, CompressMode p_compress_mode, Image::UsedChannels p_channels, Image::CompressMode p_compress_format, float p_lossy_quality, bool p_force_rgbe) { +void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref<Image> &p_image, CompressMode p_compress_mode, Image::UsedChannels p_channels, Image::CompressMode p_compress_format, float p_lossy_quality) { switch (p_compress_mode) { case COMPRESS_LOSSLESS: { - f->store_32(StreamTexture::DATA_FORMAT_LOSSLESS); + f->store_32(StreamTexture2D::DATA_FORMAT_LOSSLESS); f->store_16(p_image->get_width()); f->store_16(p_image->get_height()); f->store_32(p_image->get_mipmap_count()); @@ -249,7 +249,7 @@ void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref<Image } break; case COMPRESS_LOSSY: { - f->store_32(StreamTexture::DATA_FORMAT_LOSSY); + f->store_32(StreamTexture2D::DATA_FORMAT_LOSSY); f->store_16(p_image->get_width()); f->store_16(p_image->get_height()); f->store_32(p_image->get_mipmap_count()); @@ -269,13 +269,9 @@ void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref<Image Ref<Image> image = p_image->duplicate(); - if (p_force_rgbe && image->get_format() >= Image::FORMAT_RF && image->get_format() < Image::FORMAT_RGBE9995) { - image->convert(Image::FORMAT_RGBE9995); - } else { - image->compress_from_channels(p_compress_format, p_channels, p_lossy_quality); - } + image->compress_from_channels(p_compress_format, p_channels, p_lossy_quality); - f->store_32(StreamTexture::DATA_FORMAT_IMAGE); + f->store_32(StreamTexture2D::DATA_FORMAT_IMAGE); f->store_16(image->get_width()); f->store_16(image->get_height()); f->store_32(image->get_mipmap_count()); @@ -288,7 +284,7 @@ void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref<Image } break; case COMPRESS_VRAM_UNCOMPRESSED: { - f->store_32(StreamTexture::DATA_FORMAT_IMAGE); + f->store_32(StreamTexture2D::DATA_FORMAT_IMAGE); f->store_16(p_image->get_width()); f->store_16(p_image->get_height()); f->store_32(p_image->get_mipmap_count()); @@ -303,7 +299,7 @@ void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref<Image } break; case COMPRESS_BASIS_UNIVERSAL: { - f->store_32(StreamTexture::DATA_FORMAT_BASIS_UNIVERSAL); + f->store_32(StreamTexture2D::DATA_FORMAT_BASIS_UNIVERSAL); f->store_16(p_image->get_width()); f->store_16(p_image->get_height()); f->store_32(p_image->get_mipmap_count()); @@ -322,7 +318,7 @@ void ResourceImporterTexture::save_to_stex_format(FileAccess *f, const Ref<Image } } -void ResourceImporterTexture::_save_stex(const Ref<Image> &p_image, const String &p_to_path, CompressMode p_compress_mode, float p_lossy_quality, Image::CompressMode p_vram_compression, bool p_mipmaps, bool p_streamable, bool p_detect_3d, bool p_detect_roughness, bool p_force_rgbe, bool p_detect_normal, bool p_force_normal, bool p_srgb_friendly, bool p_force_po2_for_compressed, uint32_t p_limit_mipmap, const Ref<Image> &p_normal, Image::RoughnessChannel p_roughness_channel) { +void ResourceImporterTexture::_save_stex(const Ref<Image> &p_image, const String &p_to_path, CompressMode p_compress_mode, float p_lossy_quality, Image::CompressMode p_vram_compression, bool p_mipmaps, bool p_streamable, bool p_detect_3d, bool p_detect_roughness, bool p_detect_normal, bool p_force_normal, bool p_srgb_friendly, bool p_force_po2_for_compressed, uint32_t p_limit_mipmap, const Ref<Image> &p_normal, Image::RoughnessChannel p_roughness_channel) { FileAccess *f = FileAccess::open(p_to_path, FileAccess::WRITE); f->store_8('G'); @@ -331,22 +327,22 @@ void ResourceImporterTexture::_save_stex(const Ref<Image> &p_image, const String f->store_8('2'); //godot streamable texture 2D //format version - f->store_32(StreamTexture::FORMAT_VERSION); + f->store_32(StreamTexture2D::FORMAT_VERSION); //texture may be resized later, so original size must be saved first f->store_32(p_image->get_width()); f->store_32(p_image->get_height()); uint32_t flags = 0; if (p_streamable) - flags |= StreamTexture::FORMAT_BIT_STREAM; + flags |= StreamTexture2D::FORMAT_BIT_STREAM; if (p_mipmaps) - flags |= StreamTexture::FORMAT_BIT_HAS_MIPMAPS; //mipmaps bit + flags |= StreamTexture2D::FORMAT_BIT_HAS_MIPMAPS; //mipmaps bit if (p_detect_3d) - flags |= StreamTexture::FORMAT_BIT_DETECT_3D; + flags |= StreamTexture2D::FORMAT_BIT_DETECT_3D; if (p_detect_roughness) - flags |= StreamTexture::FORMAT_BIT_DETECT_ROUGNESS; + flags |= StreamTexture2D::FORMAT_BIT_DETECT_ROUGNESS; if (p_detect_normal) - flags |= StreamTexture::FORMAT_BIT_DETECT_NORMAL; + flags |= StreamTexture2D::FORMAT_BIT_DETECT_NORMAL; f->store_32(flags); f->store_32(p_limit_mipmap); @@ -385,10 +381,6 @@ void ResourceImporterTexture::_save_stex(const Ref<Image> &p_image, const String image->generate_mipmap_roughness(p_roughness_channel, p_normal); } - if (p_force_rgbe && image->get_format() >= Image::FORMAT_RF && image->get_format() < Image::FORMAT_RGBE9995) { - image->convert(Image::FORMAT_RGBE9995); - } - Image::CompressSource csource = Image::COMPRESS_SOURCE_GENERIC; if (p_force_normal) { csource = Image::COMPRESS_SOURCE_NORMAL; @@ -398,7 +390,7 @@ void ResourceImporterTexture::_save_stex(const Ref<Image> &p_image, const String Image::UsedChannels used_channels = image->detect_used_channels(csource); - save_to_stex_format(f, image, p_compress_mode, used_channels, p_vram_compression, p_lossy_quality, p_force_rgbe); + save_to_stex_format(f, image, p_compress_mode, used_channels, p_vram_compression, p_lossy_quality); memdelete(f); } @@ -418,7 +410,7 @@ Error ResourceImporterTexture::import(const String &p_source_file, const String bool hdr_as_srgb = p_options["process/HDR_as_SRGB"]; int normal = p_options["compress/normal_map"]; float scale = p_options["svg/scale"]; - bool force_rgbe = int(p_options["compress/hdr_mode"]) == 1; + int hdr_compression = p_options["compress/hdr_compression"]; int bptc_ldr = p_options["compress/bptc_ldr"]; int roughness = p_options["roughness/mode"]; String normal_map = p_options["roughness/src_normal"]; @@ -501,30 +493,49 @@ Error ResourceImporterTexture::import(const String &p_source_file, const String bool can_s3tc = ProjectSettings::get_singleton()->get("rendering/vram_compression/import_s3tc"); if (can_bptc) { - Image::UsedChannels channels = image->detect_used_channels(); - if (is_hdr) { + //add to the list anyway + formats_imported.push_back("bptc"); + } - if (channels == Image::USED_CHANNELS_LA || channels == Image::USED_CHANNELS_RGBA) { - can_bptc = false; + bool can_compress_hdr = hdr_compression > 0; + bool has_alpha = image->detect_alpha() != Image::ALPHA_NONE; + + if (is_hdr && can_compress_hdr) { + + if (has_alpha) { + //can compress hdr, but hdr with alpha is not compressible + if (hdr_compression == 2) { + //but user selected to compress hdr anyway, so force an alpha-less format. + if (image->get_format() == Image::FORMAT_RGBAF) { + image->convert(Image::FORMAT_RGBF); + } else if (image->get_format() == Image::FORMAT_RGBAH) { + image->convert(Image::FORMAT_RGBH); + } + } else { + can_compress_hdr = false; } - } else if (is_ldr) { + } - //handle "RGBA Only" setting - if (bptc_ldr == 1 && channels != Image::USED_CHANNELS_LA && channels != Image::USED_CHANNELS_RGBA) { - can_bptc = false; + if (can_compress_hdr) { + if (!can_bptc) { + //fallback to RGBE99995 + if (image->get_format() != Image::FORMAT_RGBE9995) { + image->convert(Image::FORMAT_RGBE9995); + } } + } else { + can_bptc = false; } - - formats_imported.push_back("bptc"); } - if (!can_bptc && is_hdr && !force_rgbe) { - //convert to ldr if this can't be stored hdr - image->convert(Image::FORMAT_RGBA8); + if (is_ldr && can_bptc) { + if (bptc_ldr == 0 || (bptc_ldr == 1 && !has_alpha)) { + can_bptc = false; + } } if (can_bptc || can_s3tc) { - _save_stex(image, p_save_path + ".s3tc.stex", compress_mode, lossy, can_bptc ? Image::COMPRESS_BPTC : Image::COMPRESS_S3TC, mipmaps, stream, detect_3d, detect_roughness, force_rgbe, detect_normal, force_normal, srgb_friendly_pack, false, mipmap_limit, normal_image, roughness_channel); + _save_stex(image, p_save_path + ".s3tc.stex", compress_mode, lossy, can_bptc ? Image::COMPRESS_BPTC : Image::COMPRESS_S3TC, mipmaps, stream, detect_3d, detect_roughness, detect_normal, force_normal, srgb_friendly_pack, false, mipmap_limit, normal_image, roughness_channel); r_platform_variants->push_back("s3tc"); formats_imported.push_back("s3tc"); ok_on_pc = true; @@ -532,20 +543,20 @@ Error ResourceImporterTexture::import(const String &p_source_file, const String if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_etc2")) { - _save_stex(image, p_save_path + ".etc2.stex", compress_mode, lossy, Image::COMPRESS_ETC2, mipmaps, stream, detect_3d, detect_roughness, force_rgbe, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); + _save_stex(image, p_save_path + ".etc2.stex", compress_mode, lossy, Image::COMPRESS_ETC2, mipmaps, stream, detect_3d, detect_roughness, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); r_platform_variants->push_back("etc2"); formats_imported.push_back("etc2"); } if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_etc")) { - _save_stex(image, p_save_path + ".etc.stex", compress_mode, lossy, Image::COMPRESS_ETC, mipmaps, stream, detect_3d, detect_roughness, force_rgbe, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); + _save_stex(image, p_save_path + ".etc.stex", compress_mode, lossy, Image::COMPRESS_ETC, mipmaps, stream, detect_3d, detect_roughness, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); r_platform_variants->push_back("etc"); formats_imported.push_back("etc"); } if (ProjectSettings::get_singleton()->get("rendering/vram_compression/import_pvrtc")) { - _save_stex(image, p_save_path + ".pvrtc.stex", compress_mode, lossy, Image::COMPRESS_PVRTC4, mipmaps, stream, detect_3d, detect_roughness, force_rgbe, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); + _save_stex(image, p_save_path + ".pvrtc.stex", compress_mode, lossy, Image::COMPRESS_PVRTC4, mipmaps, stream, detect_3d, detect_roughness, detect_normal, force_normal, srgb_friendly_pack, true, mipmap_limit, normal_image, roughness_channel); r_platform_variants->push_back("pvrtc"); formats_imported.push_back("pvrtc"); } @@ -555,7 +566,7 @@ Error ResourceImporterTexture::import(const String &p_source_file, const String } } else { //import normally - _save_stex(image, p_save_path + ".stex", compress_mode, lossy, Image::COMPRESS_S3TC /*this is ignored */, mipmaps, stream, detect_3d, detect_roughness, force_rgbe, detect_normal, force_normal, srgb_friendly_pack, false, mipmap_limit, normal_image, roughness_channel); + _save_stex(image, p_save_path + ".stex", compress_mode, lossy, Image::COMPRESS_S3TC /*this is ignored */, mipmaps, stream, detect_3d, detect_roughness, detect_normal, force_normal, srgb_friendly_pack, false, mipmap_limit, normal_image, roughness_channel); } if (r_metadata) { @@ -635,9 +646,9 @@ ResourceImporterTexture *ResourceImporterTexture::singleton = nullptr; ResourceImporterTexture::ResourceImporterTexture() { singleton = this; - StreamTexture::request_3d_callback = _texture_reimport_3d; - StreamTexture::request_roughness_callback = _texture_reimport_roughness; - StreamTexture::request_normal_callback = _texture_reimport_normal; + StreamTexture2D::request_3d_callback = _texture_reimport_3d; + StreamTexture2D::request_roughness_callback = _texture_reimport_roughness; + StreamTexture2D::request_normal_callback = _texture_reimport_normal; } ResourceImporterTexture::~ResourceImporterTexture() { diff --git a/editor/import/resource_importer_texture.h b/editor/import/resource_importer_texture.h index e1c71ff1b8..da8ce3c0a8 100644 --- a/editor/import/resource_importer_texture.h +++ b/editor/import/resource_importer_texture.h @@ -37,7 +37,7 @@ #include "scene/resources/texture.h" #include "servers/rendering_server.h" -class StreamTexture; +class StreamTexture2D; class ResourceImporterTexture : public ResourceImporter { GDCLASS(ResourceImporterTexture, ResourceImporter); @@ -72,17 +72,17 @@ protected: Map<StringName, MakeInfo> make_flags; - static void _texture_reimport_roughness(const Ref<StreamTexture> &p_tex, const String &p_normal_path, RenderingServer::TextureDetectRoughnessChannel p_channel); - static void _texture_reimport_3d(const Ref<StreamTexture> &p_tex); - static void _texture_reimport_normal(const Ref<StreamTexture> &p_tex); + static void _texture_reimport_roughness(const Ref<StreamTexture2D> &p_tex, const String &p_normal_path, RenderingServer::TextureDetectRoughnessChannel p_channel); + static void _texture_reimport_3d(const Ref<StreamTexture2D> &p_tex); + static void _texture_reimport_normal(const Ref<StreamTexture2D> &p_tex); static ResourceImporterTexture *singleton; static const char *compression_formats[]; - void _save_stex(const Ref<Image> &p_image, const String &p_to_path, CompressMode p_compress_mode, float p_lossy_quality, Image::CompressMode p_vram_compression, bool p_mipmaps, bool p_streamable, bool p_detect_3d, bool p_detect_srgb, bool p_force_rgbe, bool p_detect_normal, bool p_force_normal, bool p_srgb_friendly, bool p_force_po2_for_compressed, uint32_t p_limit_mipmap, const Ref<Image> &p_normal, Image::RoughnessChannel p_roughness_channel); + void _save_stex(const Ref<Image> &p_image, const String &p_to_path, CompressMode p_compress_mode, float p_lossy_quality, Image::CompressMode p_vram_compression, bool p_mipmaps, bool p_streamable, bool p_detect_3d, bool p_detect_srgb, bool p_detect_normal, bool p_force_normal, bool p_srgb_friendly, bool p_force_po2_for_compressed, uint32_t p_limit_mipmap, const Ref<Image> &p_normal, Image::RoughnessChannel p_roughness_channel); public: - void save_to_stex_format(FileAccess *f, const Ref<Image> &p_image, CompressMode p_compress_mode, Image::UsedChannels p_channels, Image::CompressMode p_compress_format, float p_lossy_quality, bool p_force_rgbe); + static void save_to_stex_format(FileAccess *f, const Ref<Image> &p_image, CompressMode p_compress_mode, Image::UsedChannels p_channels, Image::CompressMode p_compress_format, float p_lossy_quality); static ResourceImporterTexture *get_singleton() { return singleton; } virtual String get_importer_name() const; diff --git a/editor/node_3d_editor_gizmos.cpp b/editor/node_3d_editor_gizmos.cpp index c321e85546..cb0d9fa02b 100644 --- a/editor/node_3d_editor_gizmos.cpp +++ b/editor/node_3d_editor_gizmos.cpp @@ -41,6 +41,7 @@ #include "scene/3d/gi_probe.h" #include "scene/3d/gpu_particles_3d.h" #include "scene/3d/light_3d.h" +#include "scene/3d/lightmap_probe.h" #include "scene/3d/listener_3d.h" #include "scene/3d/mesh_instance_3d.h" #include "scene/3d/navigation_region_3d.h" @@ -431,7 +432,8 @@ bool EditorNode3DGizmo::intersect_frustum(const Camera3D *p_camera, const Vector ERR_FAIL_COND_V(!spatial_node, false); ERR_FAIL_COND_V(!valid, false); - if (hidden && !gizmo_plugin->is_selectable_when_hidden()) return false; + if (hidden && !gizmo_plugin->is_selectable_when_hidden()) + return false; if (selectable_icon_size > 0.0f) { Vector3 origin = spatial_node->get_global_transform().get_origin(); @@ -470,10 +472,12 @@ bool EditorNode3DGizmo::intersect_frustum(const Camera3D *p_camera, const Vector break; } } - if (any_out) break; + if (any_out) + break; } - if (!any_out) return true; + if (!any_out) + return true; } if (collision_mesh.is_valid()) { @@ -504,7 +508,8 @@ bool EditorNode3DGizmo::intersect_ray(Camera3D *p_camera, const Point2 &p_point, ERR_FAIL_COND_V(!spatial_node, false); ERR_FAIL_COND_V(!valid, false); - if (hidden && !gizmo_plugin->is_selectable_when_hidden()) return false; + if (hidden && !gizmo_plugin->is_selectable_when_hidden()) + return false; if (r_gizmo_handle && !hidden) { @@ -785,7 +790,8 @@ EditorNode3DGizmo::EditorNode3DGizmo() { EditorNode3DGizmo::~EditorNode3DGizmo() { - if (gizmo_plugin != nullptr) gizmo_plugin->unregister_gizmo(this); + if (gizmo_plugin != nullptr) + gizmo_plugin->unregister_gizmo(this); clear(); } @@ -2207,12 +2213,18 @@ int VisibilityNotifier3DGizmoPlugin::get_priority() const { String VisibilityNotifier3DGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { switch (p_idx) { - case 0: return "Size X"; - case 1: return "Size Y"; - case 2: return "Size Z"; - case 3: return "Pos X"; - case 4: return "Pos Y"; - case 5: return "Pos Z"; + case 0: + return "Size X"; + case 1: + return "Size Y"; + case 2: + return "Size Z"; + case 3: + return "Pos X"; + case 4: + return "Pos Y"; + case 5: + return "Pos Z"; } return ""; @@ -2399,12 +2411,18 @@ bool GPUParticles3DGizmoPlugin::is_selectable_when_hidden() const { String GPUParticles3DGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { switch (p_idx) { - case 0: return "Size X"; - case 1: return "Size Y"; - case 2: return "Size Z"; - case 3: return "Pos X"; - case 4: return "Pos Y"; - case 5: return "Pos Z"; + case 0: + return "Size X"; + case 1: + return "Size Y"; + case 2: + return "Size Z"; + case 3: + return "Pos X"; + case 4: + return "Pos Y"; + case 5: + return "Pos Z"; } return ""; @@ -2564,12 +2582,18 @@ int ReflectionProbeGizmoPlugin::get_priority() const { String ReflectionProbeGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { switch (p_idx) { - case 0: return "Extents X"; - case 1: return "Extents Y"; - case 2: return "Extents Z"; - case 3: return "Origin X"; - case 4: return "Origin Y"; - case 5: return "Origin Z"; + case 0: + return "Extents X"; + case 1: + return "Extents Y"; + case 2: + return "Extents Z"; + case 3: + return "Origin X"; + case 4: + return "Origin Y"; + case 5: + return "Origin Z"; } return ""; @@ -2747,9 +2771,12 @@ int DecalGizmoPlugin::get_priority() const { String DecalGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { switch (p_idx) { - case 0: return "Extents X"; - case 1: return "Extents Y"; - case 2: return "Extents Z"; + case 0: + return "Extents X"; + case 1: + return "Extents Y"; + case 2: + return "Extents Z"; } return ""; @@ -2888,9 +2915,12 @@ int GIProbeGizmoPlugin::get_priority() const { String GIProbeGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { switch (p_idx) { - case 0: return "Extents X"; - case 1: return "Extents Y"; - case 2: return "Extents Z"; + case 0: + return "Extents X"; + case 1: + return "Extents Y"; + case 2: + return "Extents Z"; } return ""; @@ -3040,136 +3070,296 @@ void GIProbeGizmoPlugin::redraw(EditorNode3DGizmo *p_gizmo) { } //// -#if 0 -BakedIndirectLightGizmoPlugin::BakedIndirectLightGizmoPlugin() { - Color gizmo_color = EDITOR_DEF("editors/3d_gizmos/gizmo_colors/baked_indirect_light", Color(0.5, 0.6, 1)); - create_material("baked_indirect_light_material", gizmo_color); +BakedLightmapGizmoPlugin::BakedLightmapGizmoPlugin() { + Color gizmo_color = EDITOR_DEF("editors/3d_gizmos/gizmo_colors/lightmap_lines", Color(0.5, 0.6, 1)); gizmo_color.a = 0.1; - create_material("baked_indirect_light_internal_material", gizmo_color); + create_material("lightmap_lines", gizmo_color); - create_icon_material("baked_indirect_light_icon", Node3DEditor::get_singleton()->get_icon("GizmoBakedLightmap", "EditorIcons")); - create_handle_material("handles"); -} + Ref<StandardMaterial3D> mat = memnew(StandardMaterial3D); + mat->set_shading_mode(StandardMaterial3D::SHADING_MODE_UNSHADED); + mat->set_cull_mode(StandardMaterial3D::CULL_DISABLED); + mat->set_flag(StandardMaterial3D::FLAG_ALBEDO_FROM_VERTEX_COLOR, true); + mat->set_flag(StandardMaterial3D::FLAG_SRGB_VERTEX_COLOR, false); -String BakedIndirectLightGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { + add_material("lightmap_probe_material", mat); - switch (p_idx) { - case 0: return "Extents X"; - case 1: return "Extents Y"; - case 2: return "Extents Z"; - } + create_icon_material("baked_indirect_light_icon", Node3DEditor::get_singleton()->get_theme_icon("GizmoBakedLightmap", "EditorIcons")); +} + +String BakedLightmapGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { return ""; } -Variant BakedIndirectLightGizmoPlugin::get_handle_value(EditorNode3DGizmo *p_gizmo, int p_idx) const { +Variant BakedLightmapGizmoPlugin::get_handle_value(EditorNode3DGizmo *p_gizmo, int p_idx) const { - BakedLightmap *baker = Object::cast_to<BakedLightmap>(p_gizmo->get_spatial_node()); - return baker->get_extents(); + return Variant(); +} +void BakedLightmapGizmoPlugin::set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera3D *p_camera, const Point2 &p_point) { } -void BakedIndirectLightGizmoPlugin::set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera *p_camera, const Point2 &p_point) { +void BakedLightmapGizmoPlugin::commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, const Variant &p_restore, bool p_cancel) { +} + +bool BakedLightmapGizmoPlugin::has_gizmo(Node3D *p_spatial) { + return Object::cast_to<BakedLightmap>(p_spatial) != nullptr; +} + +String BakedLightmapGizmoPlugin::get_name() const { + return "BakedLightmap"; +} + +int BakedLightmapGizmoPlugin::get_priority() const { + return -1; +} + +void BakedLightmapGizmoPlugin::redraw(EditorNode3DGizmo *p_gizmo) { + + Ref<Material> icon = get_material("baked_indirect_light_icon", p_gizmo); BakedLightmap *baker = Object::cast_to<BakedLightmap>(p_gizmo->get_spatial_node()); + Ref<BakedLightmapData> data = baker->get_light_data(); - Transform gt = baker->get_global_transform(); - Transform gi = gt.affine_inverse(); + p_gizmo->add_unscaled_billboard(icon, 0.05); - Vector3 extents = baker->get_extents(); + if (data.is_null()) { + return; + } - Vector3 ray_from = p_camera->project_ray_origin(p_point); - Vector3 ray_dir = p_camera->project_ray_normal(p_point); + Ref<Material> material_lines = get_material("lightmap_lines", p_gizmo); + Ref<Material> material_probes = get_material("lightmap_probe_material", p_gizmo); - Vector3 sg[2] = { gi.xform(ray_from), gi.xform(ray_from + ray_dir * 16384) }; + p_gizmo->clear(); - Vector3 axis; - axis[p_idx] = 1.0; + Vector<Vector3> lines; + Set<Vector2i> lines_found; - Vector3 ra, rb; - Geometry::get_closest_points_between_segments(Vector3(), axis * 16384, sg[0], sg[1], ra, rb); - float d = ra[p_idx]; - if (Node3DEditor::get_singleton()->is_snap_enabled()) { - d = Math::stepify(d, Node3DEditor::get_singleton()->get_translate_snap()); + Vector<Vector3> points = data->get_capture_points(); + if (points.size() == 0) { + return; + } + Vector<Color> sh = data->get_capture_sh(); + if (sh.size() != points.size() * 9) { + return; } - if (d < 0.001) - d = 0.001; + Vector<int> tetrahedrons = data->get_capture_tetrahedra(); - extents[p_idx] = d; - baker->set_extents(extents); -} + for (int i = 0; i < tetrahedrons.size(); i += 4) { -void BakedIndirectLightGizmoPlugin::commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, const Variant &p_restore, bool p_cancel) { + for (int j = 0; j < 4; j++) { + for (int k = j + 1; k < 4; k++) { - BakedLightmap *baker = Object::cast_to<BakedLightmap>(p_gizmo->get_spatial_node()); + Vector2i pair; + pair.x = tetrahedrons[i + j]; + pair.y = tetrahedrons[i + k]; - Vector3 restore = p_restore; + if (pair.y < pair.x) { + SWAP(pair.x, pair.y); + } + if (lines_found.has(pair)) { + continue; + } + lines_found.insert(pair); + lines.push_back(points[pair.x]); + lines.push_back(points[pair.y]); + } + } + } - if (p_cancel) { - baker->set_extents(restore); - return; + p_gizmo->add_lines(lines, material_lines); + + int stack_count = 8; + int sector_count = 16; + + float sector_step = 2 * Math_PI / sector_count; + float stack_step = Math_PI / stack_count; + + Vector<Vector3> vertices; + Vector<Color> colors; + Vector<int> indices; + float radius = 0.3; + + for (int p = 0; p < points.size(); p++) { + + int vertex_base = vertices.size(); + Vector3 sh_col[9]; + for (int i = 0; i < 9; i++) { + sh_col[i].x = sh[p * 9 + i].r; + sh_col[i].y = sh[p * 9 + i].g; + sh_col[i].z = sh[p * 9 + i].b; + } + + for (int i = 0; i <= stack_count; ++i) { + float stack_angle = Math_PI / 2 - i * stack_step; // starting from pi/2 to -pi/2 + float xy = radius * Math::cos(stack_angle); // r * cos(u) + float z = radius * Math::sin(stack_angle); // r * sin(u) + + // add (sector_count+1) vertices per stack + // the first and last vertices have same position and normal, but different tex coords + for (int j = 0; j <= sector_count; ++j) { + float sector_angle = j * sector_step; // starting from 0 to 2pi + + // vertex position (x, y, z) + float x = xy * Math::cos(sector_angle); // r * cos(u) * cos(v) + float y = xy * Math::sin(sector_angle); // r * cos(u) * sin(v) + + Vector3 n = Vector3(x, z, y); + vertices.push_back(points[p] + n); + n.normalize(); + + const float c1 = 0.429043; + const float c2 = 0.511664; + const float c3 = 0.743125; + const float c4 = 0.886227; + const float c5 = 0.247708; + Vector3 light = (c1 * sh_col[8] * (n.x * n.x - n.y * n.y) + + c3 * sh_col[6] * n.z * n.z + + c4 * sh_col[0] - + c5 * sh_col[6] + + 2.0 * c1 * sh_col[4] * n.x * n.y + + 2.0 * c1 * sh_col[7] * n.x * n.z + + 2.0 * c1 * sh_col[5] * n.y * n.z + + 2.0 * c2 * sh_col[3] * n.x + + 2.0 * c2 * sh_col[1] * n.y + + 2.0 * c2 * sh_col[2] * n.z); + + colors.push_back(Color(light.x, light.y, light.z, 1)); + } + } + + for (int i = 0; i < stack_count; ++i) { + int k1 = i * (sector_count + 1); // beginning of current stack + int k2 = k1 + sector_count + 1; // beginning of next stack + + for (int j = 0; j < sector_count; ++j, ++k1, ++k2) { + // 2 triangles per sector excluding first and last stacks + // k1 => k2 => k1+1 + if (i != 0) { + indices.push_back(vertex_base + k1); + indices.push_back(vertex_base + k2); + indices.push_back(vertex_base + k1 + 1); + } + + // k1+1 => k2 => k2+1 + if (i != (stack_count - 1)) { + indices.push_back(vertex_base + k1 + 1); + indices.push_back(vertex_base + k2); + indices.push_back(vertex_base + k2 + 1); + } + } + } } - UndoRedo *ur = Node3DEditor::get_singleton()->get_undo_redo(); - ur->create_action(TTR("Change Probe Extents")); - ur->add_do_method(baker, "set_extents", baker->get_extents()); - ur->add_undo_method(baker, "set_extents", restore); - ur->commit_action(); + Array array; + array.resize(RS::ARRAY_MAX); + array[RS::ARRAY_VERTEX] = vertices; + array[RS::ARRAY_INDEX] = indices; + array[RS::ARRAY_COLOR] = colors; + + Ref<ArrayMesh> mesh; + mesh.instance(); + mesh->add_surface_from_arrays(Mesh::PRIMITIVE_TRIANGLES, array, Array(), Dictionary(), 0); //no compression + mesh->surface_set_material(0, material_probes); + + p_gizmo->add_mesh(mesh); } +///////// -bool BakedIndirectLightGizmoPlugin::has_gizmo(Spatial *p_spatial) { - return Object::cast_to<BakedLightmap>(p_spatial) != nullptr; +LightmapProbeGizmoPlugin::LightmapProbeGizmoPlugin() { + Color gizmo_color = EDITOR_DEF("editors/3d_gizmos/gizmo_colors/lightprobe_lines", Color(0.5, 0.6, 1)); + + gizmo_color.a = 0.3; + create_material("lightprobe_lines", gizmo_color); } -String BakedIndirectLightGizmoPlugin::get_name() const { - return "BakedLightmap"; +String LightmapProbeGizmoPlugin::get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const { + + return ""; } +Variant LightmapProbeGizmoPlugin::get_handle_value(EditorNode3DGizmo *p_gizmo, int p_idx) const { -int BakedIndirectLightGizmoPlugin::get_priority() const { - return -1; + return Variant(); +} +void LightmapProbeGizmoPlugin::set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera3D *p_camera, const Point2 &p_point) { } -void BakedIndirectLightGizmoPlugin::redraw(EditorNode3DGizmo *p_gizmo) { +void LightmapProbeGizmoPlugin::commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, const Variant &p_restore, bool p_cancel) { +} - BakedLightmap *baker = Object::cast_to<BakedLightmap>(p_gizmo->get_spatial_node()); +bool LightmapProbeGizmoPlugin::has_gizmo(Node3D *p_spatial) { + return Object::cast_to<LightmapProbe>(p_spatial) != nullptr; +} - Ref<Material> material = get_material("baked_indirect_light_material", p_gizmo); - Ref<Material> icon = get_material("baked_indirect_light_icon", p_gizmo); - Ref<Material> material_internal = get_material("baked_indirect_light_internal_material", p_gizmo); +String LightmapProbeGizmoPlugin::get_name() const { + return "LightmapProbe"; +} + +int LightmapProbeGizmoPlugin::get_priority() const { + return -1; +} + +void LightmapProbeGizmoPlugin::redraw(EditorNode3DGizmo *p_gizmo) { + + Ref<Material> material_lines = get_material("lightprobe_lines", p_gizmo); p_gizmo->clear(); Vector<Vector3> lines; - Vector3 extents = baker->get_extents(); - AABB aabb = AABB(-extents, extents * 2); + int stack_count = 8; + int sector_count = 16; - for (int i = 0; i < 12; i++) { - Vector3 a, b; - aabb.get_edge(i, a, b); - lines.push_back(a); - lines.push_back(b); - } + float sector_step = 2 * Math_PI / sector_count; + float stack_step = Math_PI / stack_count; - p_gizmo->add_lines(lines, material); + Vector<Vector3> vertices; + float radius = 0.2; - Vector<Vector3> handles; + for (int i = 0; i <= stack_count; ++i) { + float stack_angle = Math_PI / 2 - i * stack_step; // starting from pi/2 to -pi/2 + float xy = radius * Math::cos(stack_angle); // r * cos(u) + float z = radius * Math::sin(stack_angle); // r * sin(u) - for (int i = 0; i < 3; i++) { + // add (sector_count+1) vertices per stack + // the first and last vertices have same position and normal, but different tex coords + for (int j = 0; j <= sector_count; ++j) { + float sector_angle = j * sector_step; // starting from 0 to 2pi - Vector3 ax; - ax[i] = aabb.position[i] + aabb.size[i]; - handles.push_back(ax); + // vertex position (x, y, z) + float x = xy * Math::cos(sector_angle); // r * cos(u) * cos(v) + float y = xy * Math::sin(sector_angle); // r * cos(u) * sin(v) + + Vector3 n = Vector3(x, z, y); + vertices.push_back(n); + } } - if (p_gizmo->is_selected()) { - p_gizmo->add_solid_box(material_internal, aabb.get_size()); + for (int i = 0; i < stack_count; ++i) { + int k1 = i * (sector_count + 1); // beginning of current stack + int k2 = k1 + sector_count + 1; // beginning of next stack + + for (int j = 0; j < sector_count; ++j, ++k1, ++k2) { + // 2 triangles per sector excluding first and last stacks + // k1 => k2 => k1+1 + if (i != 0) { + lines.push_back(vertices[k1]); + lines.push_back(vertices[k2]); + lines.push_back(vertices[k1]); + lines.push_back(vertices[k1 + 1]); + } + + if (i != (stack_count - 1)) { + lines.push_back(vertices[k1 + 1]); + lines.push_back(vertices[k2]); + lines.push_back(vertices[k2]); + lines.push_back(vertices[k2 + 1]); + } + } } - p_gizmo->add_unscaled_billboard(icon, 0.05); - p_gizmo->add_handles(handles, get_material("handles")); + p_gizmo->add_lines(lines, material_lines); } -#endif //// CollisionShape3DGizmoPlugin::CollisionShape3DGizmoPlugin() { diff --git a/editor/node_3d_editor_gizmos.h b/editor/node_3d_editor_gizmos.h index 6432feeecb..c25fff528c 100644 --- a/editor/node_3d_editor_gizmos.h +++ b/editor/node_3d_editor_gizmos.h @@ -321,25 +321,42 @@ public: GIProbeGizmoPlugin(); }; -#if 0 -class BakedIndirectLightGizmoPlugin : public EditorNode3DGizmoPlugin { +class BakedLightmapGizmoPlugin : public EditorNode3DGizmoPlugin { - GDCLASS(BakedIndirectLightGizmoPlugin, EditorNode3DGizmoPlugin); + GDCLASS(BakedLightmapGizmoPlugin, EditorNode3DGizmoPlugin); public: - bool has_gizmo(Spatial *p_spatial); + bool has_gizmo(Node3D *p_spatial); + String get_name() const; + int get_priority() const; + void redraw(EditorNode3DGizmo *p_gizmo); + + String get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const; + Variant get_handle_value(EditorNode3DGizmo *p_gizmo, int p_idx) const; + void set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera3D *p_camera, const Point2 &p_point); + void commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, const Variant &p_restore, bool p_cancel = false); + + BakedLightmapGizmoPlugin(); +}; + +class LightmapProbeGizmoPlugin : public EditorNode3DGizmoPlugin { + + GDCLASS(LightmapProbeGizmoPlugin, EditorNode3DGizmoPlugin); + +public: + bool has_gizmo(Node3D *p_spatial); String get_name() const; int get_priority() const; void redraw(EditorNode3DGizmo *p_gizmo); String get_handle_name(const EditorNode3DGizmo *p_gizmo, int p_idx) const; Variant get_handle_value(EditorNode3DGizmo *p_gizmo, int p_idx) const; - void set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera *p_camera, const Point2 &p_point); + void set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Camera3D *p_camera, const Point2 &p_point); void commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, const Variant &p_restore, bool p_cancel = false); - BakedIndirectLightGizmoPlugin(); + LightmapProbeGizmoPlugin(); }; -#endif + class CollisionShape3DGizmoPlugin : public EditorNode3DGizmoPlugin { GDCLASS(CollisionShape3DGizmoPlugin, EditorNode3DGizmoPlugin); diff --git a/editor/plugins/abstract_polygon_2d_editor.cpp b/editor/plugins/abstract_polygon_2d_editor.cpp index c26daa3857..beb3d760c0 100644 --- a/editor/plugins/abstract_polygon_2d_editor.cpp +++ b/editor/plugins/abstract_polygon_2d_editor.cpp @@ -34,24 +34,6 @@ #include "core/os/keyboard.h" #include "editor/editor_scale.h" -AbstractPolygon2DEditor::Vertex::Vertex() : - polygon(-1), - vertex(-1) { - // invalid vertex -} - -AbstractPolygon2DEditor::Vertex::Vertex(int p_vertex) : - polygon(-1), - vertex(p_vertex) { - // vertex p_vertex of current wip polygon -} - -AbstractPolygon2DEditor::Vertex::Vertex(int p_polygon, int p_vertex) : - polygon(p_polygon), - vertex(p_vertex) { - // vertex p_vertex of polygon p_polygon -} - bool AbstractPolygon2DEditor::Vertex::operator==(const AbstractPolygon2DEditor::Vertex &p_vertex) const { return polygon == p_vertex.polygon && vertex == p_vertex.vertex; @@ -67,20 +49,6 @@ bool AbstractPolygon2DEditor::Vertex::valid() const { return vertex >= 0; } -AbstractPolygon2DEditor::PosVertex::PosVertex() { - // invalid vertex -} - -AbstractPolygon2DEditor::PosVertex::PosVertex(const Vertex &p_vertex, const Vector2 &p_pos) : - Vertex(p_vertex.polygon, p_vertex.vertex), - pos(p_pos) { -} - -AbstractPolygon2DEditor::PosVertex::PosVertex(int p_polygon, int p_vertex, const Vector2 &p_pos) : - Vertex(p_polygon, p_vertex), - pos(p_pos) { -} - bool AbstractPolygon2DEditor::_is_empty() const { if (!_get_node()) diff --git a/editor/plugins/abstract_polygon_2d_editor.h b/editor/plugins/abstract_polygon_2d_editor.h index 6ed6d0a257..7d4a3a0f98 100644 --- a/editor/plugins/abstract_polygon_2d_editor.h +++ b/editor/plugins/abstract_polygon_2d_editor.h @@ -47,23 +47,30 @@ class AbstractPolygon2DEditor : public HBoxContainer { ToolButton *button_delete; struct Vertex { - Vertex(); - Vertex(int p_vertex); - Vertex(int p_polygon, int p_vertex); + Vertex() {} + Vertex(int p_vertex) : + vertex(p_vertex) {} + Vertex(int p_polygon, int p_vertex) : + polygon(p_polygon), + vertex(p_vertex) {} bool operator==(const Vertex &p_vertex) const; bool operator!=(const Vertex &p_vertex) const; bool valid() const; - int polygon; - int vertex; + int polygon = -1; + int vertex = -1; }; struct PosVertex : public Vertex { - PosVertex(); - PosVertex(const Vertex &p_vertex, const Vector2 &p_pos); - PosVertex(int p_polygon, int p_vertex, const Vector2 &p_pos); + PosVertex() {} + PosVertex(const Vertex &p_vertex, const Vector2 &p_pos) : + Vertex(p_vertex.polygon, p_vertex.vertex), + pos(p_pos) {} + PosVertex(int p_polygon, int p_vertex, const Vector2 &p_pos) : + Vertex(p_polygon, p_vertex), + pos(p_pos) {} Vector2 pos; }; diff --git a/editor/plugins/animation_blend_space_2d_editor.cpp b/editor/plugins/animation_blend_space_2d_editor.cpp index 17cb68df3a..4ea84e716b 100644 --- a/editor/plugins/animation_blend_space_2d_editor.cpp +++ b/editor/plugins/animation_blend_space_2d_editor.cpp @@ -32,7 +32,7 @@ #include "core/input/input.h" #include "core/io/resource_loader.h" -#include "core/math/delaunay.h" +#include "core/math/delaunay_2d.h" #include "core/os/keyboard.h" #include "core/project_settings.h" #include "editor/editor_scale.h" diff --git a/editor/plugins/animation_state_machine_editor.cpp b/editor/plugins/animation_state_machine_editor.cpp index ed51a2d2cf..509bf59716 100644 --- a/editor/plugins/animation_state_machine_editor.cpp +++ b/editor/plugins/animation_state_machine_editor.cpp @@ -32,7 +32,7 @@ #include "core/input/input.h" #include "core/io/resource_loader.h" -#include "core/math/delaunay.h" +#include "core/math/delaunay_2d.h" #include "core/os/keyboard.h" #include "core/project_settings.h" #include "editor/editor_scale.h" diff --git a/editor/plugins/animation_tree_editor_plugin.cpp b/editor/plugins/animation_tree_editor_plugin.cpp index 9452c0f11b..9c48444e3e 100644 --- a/editor/plugins/animation_tree_editor_plugin.cpp +++ b/editor/plugins/animation_tree_editor_plugin.cpp @@ -36,7 +36,7 @@ #include "animation_state_machine_editor.h" #include "core/input/input.h" #include "core/io/resource_loader.h" -#include "core/math/delaunay.h" +#include "core/math/delaunay_2d.h" #include "core/os/keyboard.h" #include "core/project_settings.h" #include "editor/editor_scale.h" diff --git a/editor/plugins/asset_library_editor_plugin.cpp b/editor/plugins/asset_library_editor_plugin.cpp index bb5147972c..1928d49556 100644 --- a/editor/plugins/asset_library_editor_plugin.cpp +++ b/editor/plugins/asset_library_editor_plugin.cpp @@ -1297,7 +1297,8 @@ void EditorAssetLibrary::_http_request_completed(int p_status, int p_code, const } } } break; - default: break; + default: + break; } } diff --git a/editor/plugins/baked_lightmap_editor_plugin.cpp b/editor/plugins/baked_lightmap_editor_plugin.cpp index ba161244d6..f754dd4725 100644 --- a/editor/plugins/baked_lightmap_editor_plugin.cpp +++ b/editor/plugins/baked_lightmap_editor_plugin.cpp @@ -28,23 +28,36 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#if 0 #include "baked_lightmap_editor_plugin.h" -void BakedLightmapEditorPlugin::_bake() { +void BakedLightmapEditorPlugin::_bake_select_file(const String &p_file) { if (lightmap) { BakedLightmap::BakeError err; if (get_tree()->get_edited_scene_root() && get_tree()->get_edited_scene_root() == lightmap) { - err = lightmap->bake(lightmap); + err = lightmap->bake(lightmap, p_file, bake_func_step); } else { - err = lightmap->bake(lightmap->get_parent()); + err = lightmap->bake(lightmap->get_parent(), p_file, bake_func_step); } + bake_func_end(); + switch (err) { - case BakedLightmap::BAKE_ERROR_NO_SAVE_PATH: - EditorNode::get_singleton()->show_warning(TTR("Can't determine a save path for lightmap images.\nSave your scene (for images to be saved in the same dir), or pick a save path from the BakedLightmap properties.")); - break; + case BakedLightmap::BAKE_ERROR_NO_SAVE_PATH: { + String scene_path = lightmap->get_filename(); + if (scene_path == String()) { + scene_path = lightmap->get_owner()->get_filename(); + } + if (scene_path == String()) { + EditorNode::get_singleton()->show_warning(TTR("Can't determine a save path for lightmap images.\nSave your scene and try again.")); + break; + } + scene_path = scene_path.get_basename() + ".lmbake"; + + file_dialog->set_current_path(scene_path); + file_dialog->popup_centered_ratio(); + + } break; case BakedLightmap::BAKE_ERROR_NO_MESHES: EditorNode::get_singleton()->show_warning(TTR("No meshes to bake. Make sure they contain an UV2 channel and that the 'Bake Light' flag is on.")); break; @@ -57,6 +70,11 @@ void BakedLightmapEditorPlugin::_bake() { } } +void BakedLightmapEditorPlugin::_bake() { + + _bake_select_file(""); +} + void BakedLightmapEditorPlugin::edit(Object *p_object) { BakedLightmap *s = Object::cast_to<BakedLightmap>(p_object); @@ -83,23 +101,20 @@ void BakedLightmapEditorPlugin::make_visible(bool p_visible) { EditorProgress *BakedLightmapEditorPlugin::tmp_progress = nullptr; -void BakedLightmapEditorPlugin::bake_func_begin(int p_steps) { +bool BakedLightmapEditorPlugin::bake_func_step(float p_progress, const String &p_description, void *, bool p_refresh) { - ERR_FAIL_COND(tmp_progress != nullptr); - - tmp_progress = memnew(EditorProgress("bake_lightmaps", TTR("Bake Lightmaps"), p_steps, true)); -} - -bool BakedLightmapEditorPlugin::bake_func_step(int p_step, const String &p_description) { - - ERR_FAIL_COND_V(tmp_progress == nullptr, false); - return tmp_progress->step(p_description, p_step, false); + if (!tmp_progress) { + tmp_progress = memnew(EditorProgress("bake_lightmaps", TTR("Bake Lightmaps"), 1000, false)); + ERR_FAIL_COND_V(tmp_progress == nullptr, false); + } + return tmp_progress->step(p_description, p_progress * 1000, p_refresh); } void BakedLightmapEditorPlugin::bake_func_end() { - ERR_FAIL_COND(tmp_progress == nullptr); - memdelete(tmp_progress); - tmp_progress = nullptr; + if (tmp_progress != nullptr) { + memdelete(tmp_progress); + tmp_progress = nullptr; + } } void BakedLightmapEditorPlugin::_bind_methods() { @@ -111,18 +126,20 @@ BakedLightmapEditorPlugin::BakedLightmapEditorPlugin(EditorNode *p_node) { editor = p_node; bake = memnew(ToolButton); - bake->set_icon(editor->get_gui_base()->get_icon("Bake", "EditorIcons")); + bake->set_icon(editor->get_gui_base()->get_theme_icon("Bake", "EditorIcons")); bake->set_text(TTR("Bake Lightmaps")); bake->hide(); - bake->connect("pressed", this, "_bake"); + bake->connect("pressed", Callable(this, "_bake")); add_control_to_container(CONTAINER_SPATIAL_EDITOR_MENU, bake); lightmap = nullptr; - BakedLightmap::bake_begin_function = bake_func_begin; - BakedLightmap::bake_step_function = bake_func_step; - BakedLightmap::bake_end_function = bake_func_end; + file_dialog = memnew(EditorFileDialog); + file_dialog->set_file_mode(EditorFileDialog::FILE_MODE_SAVE_FILE); + file_dialog->add_filter("*.lmbake ; LightMap Bake"); + file_dialog->set_title(TTR("Select lightmap bake file:")); + file_dialog->connect("file_selected", callable_mp(this, &BakedLightmapEditorPlugin::_bake_select_file)); + bake->add_child(file_dialog); } BakedLightmapEditorPlugin::~BakedLightmapEditorPlugin() { } -#endif diff --git a/editor/plugins/baked_lightmap_editor_plugin.h b/editor/plugins/baked_lightmap_editor_plugin.h index 818cdfe8fa..2dbc09fc1d 100644 --- a/editor/plugins/baked_lightmap_editor_plugin.h +++ b/editor/plugins/baked_lightmap_editor_plugin.h @@ -28,7 +28,6 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#if 0 #ifndef BAKED_LIGHTMAP_EDITOR_PLUGIN_H #define BAKED_LIGHTMAP_EDITOR_PLUGIN_H @@ -46,11 +45,12 @@ class BakedLightmapEditorPlugin : public EditorPlugin { ToolButton *bake; EditorNode *editor; + EditorFileDialog *file_dialog; static EditorProgress *tmp_progress; - static void bake_func_begin(int p_steps); - static bool bake_func_step(int p_step, const String &p_description); + static bool bake_func_step(float p_progress, const String &p_description, void *, bool p_refresh); static void bake_func_end(); + void _bake_select_file(const String &p_file); void _bake(); protected: @@ -67,5 +67,4 @@ public: ~BakedLightmapEditorPlugin(); }; -#endif // BAKED_LIGHTMAP_EDITOR_PLUGIN_H #endif diff --git a/editor/plugins/canvas_item_editor_plugin.cpp b/editor/plugins/canvas_item_editor_plugin.cpp index e882b3a8d7..b5fcf82d76 100644 --- a/editor/plugins/canvas_item_editor_plugin.cpp +++ b/editor/plugins/canvas_item_editor_plugin.cpp @@ -1624,20 +1624,28 @@ bool CanvasItemEditor::_gui_input_anchors(const Ref<InputEvent> &p_event) { switch (drag_type) { case DRAG_ANCHOR_TOP_LEFT: - if (!use_single_axis || !use_y) control->set_anchor(MARGIN_LEFT, new_anchor.x, false, false); - if (!use_single_axis || use_y) control->set_anchor(MARGIN_TOP, new_anchor.y, false, false); + if (!use_single_axis || !use_y) + control->set_anchor(MARGIN_LEFT, new_anchor.x, false, false); + if (!use_single_axis || use_y) + control->set_anchor(MARGIN_TOP, new_anchor.y, false, false); break; case DRAG_ANCHOR_TOP_RIGHT: - if (!use_single_axis || !use_y) control->set_anchor(MARGIN_RIGHT, new_anchor.x, false, false); - if (!use_single_axis || use_y) control->set_anchor(MARGIN_TOP, new_anchor.y, false, false); + if (!use_single_axis || !use_y) + control->set_anchor(MARGIN_RIGHT, new_anchor.x, false, false); + if (!use_single_axis || use_y) + control->set_anchor(MARGIN_TOP, new_anchor.y, false, false); break; case DRAG_ANCHOR_BOTTOM_RIGHT: - if (!use_single_axis || !use_y) control->set_anchor(MARGIN_RIGHT, new_anchor.x, false, false); - if (!use_single_axis || use_y) control->set_anchor(MARGIN_BOTTOM, new_anchor.y, false, false); + if (!use_single_axis || !use_y) + control->set_anchor(MARGIN_RIGHT, new_anchor.x, false, false); + if (!use_single_axis || use_y) + control->set_anchor(MARGIN_BOTTOM, new_anchor.y, false, false); break; case DRAG_ANCHOR_BOTTOM_LEFT: - if (!use_single_axis || !use_y) control->set_anchor(MARGIN_LEFT, new_anchor.x, false, false); - if (!use_single_axis || use_y) control->set_anchor(MARGIN_BOTTOM, new_anchor.y, false, false); + if (!use_single_axis || !use_y) + control->set_anchor(MARGIN_LEFT, new_anchor.x, false, false); + if (!use_single_axis || use_y) + control->set_anchor(MARGIN_BOTTOM, new_anchor.y, false, false); break; case DRAG_ANCHOR_ALL: if (!use_single_axis || !use_y) { @@ -5037,7 +5045,8 @@ void CanvasItemEditor::_focus_selection(int p_op) { Map<Node *, Object *> &selection = editor_selection->get_selection(); for (Map<Node *, Object *>::Element *E = selection.front(); E; E = E->next()) { CanvasItem *canvas_item = Object::cast_to<CanvasItem>(E->key()); - if (!canvas_item) continue; + if (!canvas_item) + continue; if (canvas_item->get_viewport() != EditorNode::get_singleton()->get_scene_root()) continue; @@ -5065,7 +5074,8 @@ void CanvasItemEditor::_focus_selection(int p_op) { rect = rect.merge(canvas_item_rect); } }; - if (count == 0) return; + if (count == 0) + return; if (p_op == VIEW_CENTER_TO_SELECTION) { @@ -6151,7 +6161,7 @@ bool CanvasItemEditorViewport::can_drop_data(const Point2 &p_point, const Varian type == "ViewportTexture" || type == "CurveTexture" || type == "GradientTexture" || - type == "StreamTexture" || + type == "StreamTexture2D" || type == "AtlasTexture" || type == "LargeTexture") { Ref<Texture2D> texture = Ref<Texture2D>(Object::cast_to<Texture2D>(*res)); @@ -6252,7 +6262,8 @@ void CanvasItemEditorViewport::_notification(int p_what) { disconnect("mouse_exited", callable_mp(this, &CanvasItemEditorViewport::_on_mouse_exit)); } break; - default: break; + default: + break; } } diff --git a/editor/plugins/canvas_item_editor_plugin.h b/editor/plugins/canvas_item_editor_plugin.h index 9f1a92f563..77f23dfd6d 100644 --- a/editor/plugins/canvas_item_editor_plugin.h +++ b/editor/plugins/canvas_item_editor_plugin.h @@ -48,10 +48,10 @@ class CanvasItemEditorSelectedItem : public Object { public: Transform2D prev_xform; - float prev_rot; + float prev_rot = 0; Rect2 prev_rect; Vector2 prev_pivot; - float prev_anchors[4]; + float prev_anchors[4] = { 0.0f }; Transform2D pre_drag_xform; Rect2 pre_drag_rect; @@ -61,10 +61,7 @@ public: Dictionary undo_state; - CanvasItemEditorSelectedItem() : - prev_anchors() { - prev_rot = 0; - } + CanvasItemEditorSelectedItem() {} }; class CanvasItemEditor : public VBoxContainer { @@ -314,12 +311,10 @@ private: struct BoneList { Transform2D xform; - float length; - uint64_t last_pass; + float length = 0.f; + uint64_t last_pass = 0; - BoneList() : - length(0.f), - last_pass(0) {} + BoneList() {} }; uint64_t bone_last_frame; diff --git a/editor/plugins/curve_editor_plugin.cpp b/editor/plugins/curve_editor_plugin.cpp index 9b5c6bae3b..fc8eef52c0 100644 --- a/editor/plugins/curve_editor_plugin.cpp +++ b/editor/plugins/curve_editor_plugin.cpp @@ -797,7 +797,7 @@ Ref<Texture2D> CurvePreviewGenerator::generate(const Ref<Resource> &p_from, cons img_ref.instance(); Image &im = **img_ref; - im.create(thumbnail_size, thumbnail_size / 2, 0, Image::FORMAT_RGBA8); + im.create(thumbnail_size, thumbnail_size / 2, false, Image::FORMAT_RGBA8); Color bg_color(0.1, 0.1, 0.1, 1.0); for (int i = 0; i < thumbnail_size; i++) { diff --git a/editor/plugins/debugger_editor_plugin.cpp b/editor/plugins/debugger_editor_plugin.cpp index 566ff378c3..e0d345663c 100644 --- a/editor/plugins/debugger_editor_plugin.cpp +++ b/editor/plugins/debugger_editor_plugin.cpp @@ -32,11 +32,14 @@ #include "core/os/keyboard.h" #include "editor/debugger/editor_debugger_node.h" +#include "editor/debugger/editor_debugger_server.h" #include "editor/editor_node.h" #include "editor/fileserver/editor_file_server.h" #include "scene/gui/menu_button.h" DebuggerEditorPlugin::DebuggerEditorPlugin(EditorNode *p_editor, MenuButton *p_debug_menu) { + EditorDebuggerServer::initialize(); + ED_SHORTCUT("debugger/step_into", TTR("Step Into"), KEY_F11); ED_SHORTCUT("debugger/step_over", TTR("Step Over"), KEY_F10); ED_SHORTCUT("debugger/break", TTR("Break")); @@ -96,6 +99,7 @@ DebuggerEditorPlugin::DebuggerEditorPlugin(EditorNode *p_editor, MenuButton *p_d } DebuggerEditorPlugin::~DebuggerEditorPlugin() { + EditorDebuggerServer::deinitialize(); memdelete(file_server); } @@ -179,12 +183,18 @@ void DebuggerEditorPlugin::_update_debug_options() { bool check_reload_scripts = EditorSettings::get_singleton()->get_project_metadata("debug_options", "run_reload_scripts", false); int instances = EditorSettings::get_singleton()->get_project_metadata("debug_options", "run_debug_instances", 1); - if (check_deploy_remote) _menu_option(RUN_DEPLOY_REMOTE_DEBUG); - if (check_file_server) _menu_option(RUN_FILE_SERVER); - if (check_debug_collisions) _menu_option(RUN_DEBUG_COLLISONS); - if (check_debug_navigation) _menu_option(RUN_DEBUG_NAVIGATION); - if (check_live_debug) _menu_option(RUN_LIVE_DEBUG); - if (check_reload_scripts) _menu_option(RUN_RELOAD_SCRIPTS); + if (check_deploy_remote) + _menu_option(RUN_DEPLOY_REMOTE_DEBUG); + if (check_file_server) + _menu_option(RUN_FILE_SERVER); + if (check_debug_collisions) + _menu_option(RUN_DEBUG_COLLISONS); + if (check_debug_navigation) + _menu_option(RUN_DEBUG_NAVIGATION); + if (check_live_debug) + _menu_option(RUN_LIVE_DEBUG); + if (check_reload_scripts) + _menu_option(RUN_RELOAD_SCRIPTS); int len = instances_menu->get_item_count(); for (int idx = 0; idx < len; idx++) { diff --git a/editor/plugins/editor_preview_plugins.cpp b/editor/plugins/editor_preview_plugins.cpp index a8c4bddccf..2db0903f04 100644 --- a/editor/plugins/editor_preview_plugins.cpp +++ b/editor/plugins/editor_preview_plugins.cpp @@ -223,7 +223,7 @@ Ref<Texture2D> EditorBitmapPreviewPlugin::generate(const RES &p_from, const Size Ref<Image> img; img.instance(); - img->create(bm->get_size().width, bm->get_size().height, 0, Image::FORMAT_L8, data); + img->create(bm->get_size().width, bm->get_size().height, false, Image::FORMAT_L8, data); if (img->is_compressed()) { if (img->decompress() != OK) @@ -511,7 +511,7 @@ Ref<Texture2D> EditorScriptPreviewPlugin::generate(const RES &p_from, const Size Ref<Image> img; img.instance(); int thumbnail_size = MAX(p_size.x, p_size.y); - img->create(thumbnail_size, thumbnail_size, 0, Image::FORMAT_RGBA8); + img->create(thumbnail_size, thumbnail_size, false, Image::FORMAT_RGBA8); Color bg_color = EditorSettings::get_singleton()->get("text_editor/highlighting/background_color"); Color keyword_color = EditorSettings::get_singleton()->get("text_editor/highlighting/keyword_color"); diff --git a/editor/plugins/mesh_library_editor_plugin.cpp b/editor/plugins/mesh_library_editor_plugin.cpp index a3e3d88ae2..1ca8be528f 100644 --- a/editor/plugins/mesh_library_editor_plugin.cpp +++ b/editor/plugins/mesh_library_editor_plugin.cpp @@ -168,7 +168,7 @@ void MeshLibraryEditor::_import_scene(Node *p_scene, Ref<MeshLibrary> p_library, //generate previews! - if (1) { + if (true) { Vector<Ref<Mesh>> meshes; Vector<Transform> transforms; diff --git a/editor/plugins/node_3d_editor_plugin.cpp b/editor/plugins/node_3d_editor_plugin.cpp index 457a0aaa36..1614f3f073 100644 --- a/editor/plugins/node_3d_editor_plugin.cpp +++ b/editor/plugins/node_3d_editor_plugin.cpp @@ -718,9 +718,11 @@ void Node3DEditorViewport::_select_region() { item = sel; } - if (selected.find(item) != -1) continue; + if (selected.find(item) != -1) + continue; - if (_is_node_locked(item)) continue; + if (_is_node_locked(item)) + continue; Ref<EditorNode3DGizmo> seg = sp->get_gizmo(); @@ -783,11 +785,16 @@ static int _get_key_modifier_setting(const String &p_property) { switch (EditorSettings::get_singleton()->get(p_property).operator int()) { - case 0: return 0; - case 1: return KEY_SHIFT; - case 2: return KEY_ALT; - case 3: return KEY_META; - case 4: return KEY_CONTROL; + case 0: + return 0; + case 1: + return KEY_SHIFT; + case 2: + return KEY_ALT; + case 3: + return KEY_META; + case 4: + return KEY_CONTROL; } return 0; } @@ -1381,7 +1388,8 @@ void Node3DEditorViewport::_sinput(const Ref<InputEvent> &p_event) { if (cursor.region_select) { - if (!clicked_wants_append) _clear_selected(); + if (!clicked_wants_append) + _clear_selected(); _select_region(); cursor.region_select = false; @@ -2076,7 +2084,8 @@ void Node3DEditorViewport::_sinput(const Ref<InputEvent> &p_event) { } if (k->get_keycode() == KEY_SPACE) { - if (!k->is_pressed()) emit_signal("toggle_maximize_view", this); + if (!k->is_pressed()) + emit_signal("toggle_maximize_view", this); } } @@ -3876,7 +3885,7 @@ Node3DEditorViewport::Node3DEditorViewport(Node3DEditor *p_spatial_editor, Edito _edit.mode = TRANSFORM_NONE; _edit.plane = TRANSFORM_VIEW; _edit.edited_gizmo = 0; - _edit.snap = 1; + _edit.snap = true; _edit.gizmo_handle = 0; index = p_index; @@ -4633,7 +4642,8 @@ Dictionary Node3DEditor::get_state() const { Dictionary gizmos_status; for (int i = 0; i < gizmo_plugins_by_name.size(); i++) { - if (!gizmo_plugins_by_name[i]->can_be_hidden()) continue; + if (!gizmo_plugins_by_name[i]->can_be_hidden()) + continue; int state = gizmos_menu->get_item_state(gizmos_menu->get_item_index(i)); String name = gizmo_plugins_by_name[i]->get_name(); gizmos_status[name] = state; @@ -4727,7 +4737,8 @@ void Node3DEditor::set_state(const Dictionary &p_state) { gizmos_status.get_key_list(&keys); for (int j = 0; j < gizmo_plugins_by_name.size(); ++j) { - if (!gizmo_plugins_by_name[j]->can_be_hidden()) continue; + if (!gizmo_plugins_by_name[j]->can_be_hidden()) + continue; int state = EditorNode3DGizmoPlugin::VISIBLE; for (int i = 0; i < keys.size(); i++) { if (gizmo_plugins_by_name.write[j]->get_name() == keys[i]) { @@ -5492,7 +5503,8 @@ void Node3DEditor::_update_gizmos_menu() { gizmos_menu->clear(); for (int i = 0; i < gizmo_plugins_by_name.size(); ++i) { - if (!gizmo_plugins_by_name[i]->can_be_hidden()) continue; + if (!gizmo_plugins_by_name[i]->can_be_hidden()) + continue; String plugin_name = gizmo_plugins_by_name[i]->get_name(); const int plugin_state = gizmo_plugins_by_name[i]->get_state(); gizmos_menu->add_multistate_item(TTR(plugin_name), 3, plugin_state, i); @@ -5513,7 +5525,8 @@ void Node3DEditor::_update_gizmos_menu() { void Node3DEditor::_update_gizmos_menu_theme() { for (int i = 0; i < gizmo_plugins_by_name.size(); ++i) { - if (!gizmo_plugins_by_name[i]->can_be_hidden()) continue; + if (!gizmo_plugins_by_name[i]->can_be_hidden()) + continue; const int plugin_state = gizmo_plugins_by_name[i]->get_state(); const int idx = gizmos_menu->get_item_index(i); switch (plugin_state) { @@ -5924,9 +5937,11 @@ void Node3DEditor::_request_gizmo(Object *p_obj) { } void Node3DEditor::_toggle_maximize_view(Object *p_viewport) { - if (!p_viewport) return; + if (!p_viewport) + return; Node3DEditorViewport *current_viewport = Object::cast_to<Node3DEditorViewport>(p_viewport); - if (!current_viewport) return; + if (!current_viewport) + return; int index = -1; bool maximized = false; @@ -5938,7 +5953,8 @@ void Node3DEditor::_toggle_maximize_view(Object *p_viewport) { break; } } - if (index == -1) return; + if (index == -1) + return; if (!maximized) { @@ -5992,7 +6008,8 @@ void Node3DEditor::_register_all_gizmos() { add_gizmo_plugin(Ref<ReflectionProbeGizmoPlugin>(memnew(ReflectionProbeGizmoPlugin))); add_gizmo_plugin(Ref<DecalGizmoPlugin>(memnew(DecalGizmoPlugin))); add_gizmo_plugin(Ref<GIProbeGizmoPlugin>(memnew(GIProbeGizmoPlugin))); - // add_gizmo_plugin(Ref<BakedIndirectLightGizmoPlugin>(memnew(BakedIndirectLightGizmoPlugin))); + add_gizmo_plugin(Ref<BakedLightmapGizmoPlugin>(memnew(BakedLightmapGizmoPlugin))); + add_gizmo_plugin(Ref<LightmapProbeGizmoPlugin>(memnew(LightmapProbeGizmoPlugin))); add_gizmo_plugin(Ref<CollisionShape3DGizmoPlugin>(memnew(CollisionShape3DGizmoPlugin))); add_gizmo_plugin(Ref<CollisionPolygon3DGizmoPlugin>(memnew(CollisionPolygon3DGizmoPlugin))); add_gizmo_plugin(Ref<NavigationRegion3DGizmoPlugin>(memnew(NavigationRegion3DGizmoPlugin))); @@ -6655,7 +6672,8 @@ Ref<StandardMaterial3D> EditorNode3DGizmoPlugin::get_material(const String &p_na ERR_FAIL_COND_V(!materials.has(p_name), Ref<StandardMaterial3D>()); ERR_FAIL_COND_V(materials[p_name].size() == 0, Ref<StandardMaterial3D>()); - if (p_gizmo.is_null() || materials[p_name].size() == 1) return materials[p_name][0]; + if (p_gizmo.is_null() || materials[p_name].size() == 1) + return materials[p_name][0]; int index = (p_gizmo->is_selected() ? 1 : 0) + (p_gizmo->is_editable() ? 2 : 0); @@ -6692,7 +6710,8 @@ Ref<EditorNode3DGizmo> EditorNode3DGizmoPlugin::get_gizmo(Node3D *p_spatial) { Ref<EditorNode3DGizmo> ref = create_gizmo(p_spatial); - if (ref.is_null()) return ref; + if (ref.is_null()) + return ref; ref->set_plugin(this); ref->set_spatial_node(p_spatial); @@ -6751,7 +6770,8 @@ Ref<EditorNode3DGizmo> EditorNode3DGizmoPlugin::create_gizmo(Node3D *p_spatial) } Ref<EditorNode3DGizmo> ref; - if (has_gizmo(p_spatial)) ref.instance(); + if (has_gizmo(p_spatial)) + ref.instance(); return ref; } diff --git a/editor/plugins/path_2d_editor_plugin.cpp b/editor/plugins/path_2d_editor_plugin.cpp index 4516b7035b..2edb337b1c 100644 --- a/editor/plugins/path_2d_editor_plugin.cpp +++ b/editor/plugins/path_2d_editor_plugin.cpp @@ -288,8 +288,10 @@ bool Path2DEditor::forward_gui_input(const Ref<InputEvent> &p_event) { Vector2 gpoint = mm->get_position(); Ref<Curve2D> curve = node->get_curve(); - if (curve == nullptr) return true; - if (curve->get_point_count() < 2) return true; + if (curve == nullptr) + return true; + if (curve->get_point_count() < 2) + return true; // Find edge edge_point = xform.xform(curve->get_closest_point(xform.affine_inverse().xform(mm->get_position()))); diff --git a/editor/plugins/path_3d_editor_plugin.cpp b/editor/plugins/path_3d_editor_plugin.cpp index d3ece9556d..6c475d829f 100644 --- a/editor/plugins/path_3d_editor_plugin.cpp +++ b/editor/plugins/path_3d_editor_plugin.cpp @@ -631,7 +631,8 @@ Ref<EditorNode3DGizmo> Path3DGizmoPlugin::create_gizmo(Node3D *p_spatial) { Ref<Path3DGizmo> ref; Path3D *path = Object::cast_to<Path3D>(p_spatial); - if (path) ref = Ref<Path3DGizmo>(memnew(Path3DGizmo(path))); + if (path) + ref = Ref<Path3DGizmo>(memnew(Path3DGizmo(path))); return ref; } diff --git a/editor/plugins/physical_bone_3d_editor_plugin.cpp b/editor/plugins/physical_bone_3d_editor_plugin.cpp index 6d38f7f318..567d58922f 100644 --- a/editor/plugins/physical_bone_3d_editor_plugin.cpp +++ b/editor/plugins/physical_bone_3d_editor_plugin.cpp @@ -48,8 +48,7 @@ void PhysicalBone3DEditor::_set_move_joint() { } PhysicalBone3DEditor::PhysicalBone3DEditor(EditorNode *p_editor) : - editor(p_editor), - selected(nullptr) { + editor(p_editor) { spatial_editor_hb = memnew(HBoxContainer); spatial_editor_hb->set_h_size_flags(Control::SIZE_EXPAND_FILL); @@ -69,8 +68,6 @@ PhysicalBone3DEditor::PhysicalBone3DEditor(EditorNode *p_editor) : hide(); } -PhysicalBone3DEditor::~PhysicalBone3DEditor() {} - void PhysicalBone3DEditor::set_selected(PhysicalBone3D *p_pb) { button_transform_joint->set_pressed(false); @@ -90,7 +87,6 @@ void PhysicalBone3DEditor::show() { PhysicalBone3DEditorPlugin::PhysicalBone3DEditorPlugin(EditorNode *p_editor) : editor(p_editor), - selected(nullptr), physical_bone_editor(editor) {} void PhysicalBone3DEditorPlugin::make_visible(bool p_visible) { diff --git a/editor/plugins/physical_bone_3d_editor_plugin.h b/editor/plugins/physical_bone_3d_editor_plugin.h index 74932710d6..79c7cc4bb1 100644 --- a/editor/plugins/physical_bone_3d_editor_plugin.h +++ b/editor/plugins/physical_bone_3d_editor_plugin.h @@ -40,7 +40,7 @@ class PhysicalBone3DEditor : public Object { HBoxContainer *spatial_editor_hb; ToolButton *button_transform_joint; - PhysicalBone3D *selected; + PhysicalBone3D *selected = nullptr; protected: static void _bind_methods(); @@ -51,7 +51,7 @@ private: public: PhysicalBone3DEditor(EditorNode *p_editor); - ~PhysicalBone3DEditor(); + ~PhysicalBone3DEditor() {} void set_selected(PhysicalBone3D *p_pb); @@ -63,7 +63,7 @@ class PhysicalBone3DEditorPlugin : public EditorPlugin { GDCLASS(PhysicalBone3DEditorPlugin, EditorPlugin); EditorNode *editor; - PhysicalBone3D *selected; + PhysicalBone3D *selected = nullptr; PhysicalBone3DEditor physical_bone_editor; public: diff --git a/editor/plugins/script_text_editor.cpp b/editor/plugins/script_text_editor.cpp index c79c97737a..109d83d838 100644 --- a/editor/plugins/script_text_editor.cpp +++ b/editor/plugins/script_text_editor.cpp @@ -822,7 +822,8 @@ void ScriptTextEditor::_code_complete_scripts(void *p_ud, const String &p_code, void ScriptTextEditor::_code_complete_script(const String &p_code, List<ScriptCodeCompletionOption> *r_options, bool &r_force) { - if (color_panel->is_visible()) return; + if (color_panel->is_visible()) + return; Node *base = get_tree()->get_edited_scene_root(); if (base) { base = _find_node_for_script(base, base, script); diff --git a/editor/plugins/shader_file_editor_plugin.cpp b/editor/plugins/shader_file_editor_plugin.cpp index 296c7a01b6..9d5ffd6516 100644 --- a/editor/plugins/shader_file_editor_plugin.cpp +++ b/editor/plugins/shader_file_editor_plugin.cpp @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* shader_file_editor_plugin.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #include "shader_file_editor_plugin.h" #include "core/io/resource_loader.h" diff --git a/editor/plugins/shader_file_editor_plugin.h b/editor/plugins/shader_file_editor_plugin.h index 44d32de2f1..7df177a0d5 100644 --- a/editor/plugins/shader_file_editor_plugin.h +++ b/editor/plugins/shader_file_editor_plugin.h @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* shader_file_editor_plugin.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #ifndef SHADER_FILE_EDITOR_PLUGIN_H #define SHADER_FILE_EDITOR_PLUGIN_H diff --git a/editor/plugins/skeleton_3d_editor_plugin.h b/editor/plugins/skeleton_3d_editor_plugin.h index 2ba5a817bc..1bcf27e2f2 100644 --- a/editor/plugins/skeleton_3d_editor_plugin.h +++ b/editor/plugins/skeleton_3d_editor_plugin.h @@ -46,10 +46,9 @@ class Skeleton3DEditor : public Node { }; struct BoneInfo { - PhysicalBone3D *physical_bone; + PhysicalBone3D *physical_bone = nullptr; Transform relative_rest; // Relative to skeleton node - BoneInfo() : - physical_bone(nullptr) {} + BoneInfo() {} }; Skeleton3D *skeleton; diff --git a/editor/plugins/texture_editor_plugin.cpp b/editor/plugins/texture_editor_plugin.cpp index c1184c1c89..7a3e571f16 100644 --- a/editor/plugins/texture_editor_plugin.cpp +++ b/editor/plugins/texture_editor_plugin.cpp @@ -84,8 +84,8 @@ void TextureEditor::_notification(int p_what) { String format; if (Object::cast_to<ImageTexture>(*texture)) { format = Image::get_format_name(Object::cast_to<ImageTexture>(*texture)->get_format()); - } else if (Object::cast_to<StreamTexture>(*texture)) { - format = Image::get_format_name(Object::cast_to<StreamTexture>(*texture)->get_format()); + } else if (Object::cast_to<StreamTexture2D>(*texture)) { + format = Image::get_format_name(Object::cast_to<StreamTexture2D>(*texture)->get_format()); } else { format = texture->get_class(); } @@ -144,7 +144,7 @@ TextureEditor::~TextureEditor() { // bool EditorInspectorPluginTexture::can_handle(Object *p_object) { - return Object::cast_to<ImageTexture>(p_object) != nullptr || Object::cast_to<AtlasTexture>(p_object) != nullptr || Object::cast_to<StreamTexture>(p_object) != nullptr || Object::cast_to<LargeTexture>(p_object) != nullptr || Object::cast_to<AnimatedTexture>(p_object) != nullptr; + return Object::cast_to<ImageTexture>(p_object) != nullptr || Object::cast_to<AtlasTexture>(p_object) != nullptr || Object::cast_to<StreamTexture2D>(p_object) != nullptr || Object::cast_to<LargeTexture>(p_object) != nullptr || Object::cast_to<AnimatedTexture>(p_object) != nullptr; } void EditorInspectorPluginTexture::parse_begin(Object *p_object) { diff --git a/editor/plugins/texture_layered_editor_plugin.cpp b/editor/plugins/texture_layered_editor_plugin.cpp new file mode 100644 index 0000000000..6d716951b3 --- /dev/null +++ b/editor/plugins/texture_layered_editor_plugin.cpp @@ -0,0 +1,286 @@ +/*************************************************************************/ +/* texture_editor_plugin.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "texture_layered_editor_plugin.h" + +#include "core/io/resource_loader.h" +#include "core/project_settings.h" +#include "editor/editor_settings.h" + +void TextureLayeredEditor::_gui_input(Ref<InputEvent> p_event) { + Ref<InputEventMouseMotion> mm = p_event; + if (mm.is_valid() && mm->get_button_mask() & BUTTON_MASK_LEFT) { + y_rot += -mm->get_relative().x * 0.01; + x_rot += mm->get_relative().y * 0.01; + _update_material(); + } +} + +void TextureLayeredEditor::_texture_rect_draw() { + texture_rect->draw_rect(Rect2(Point2(), texture_rect->get_size()), Color(1, 1, 1, 1)); +} + +void TextureLayeredEditor::_notification(int p_what) { + + if (p_what == NOTIFICATION_READY) { + + //get_scene()->connect("node_removed",this,"_node_removed"); + } + if (p_what == NOTIFICATION_RESIZED) { + _texture_rect_update_area(); + } + + if (p_what == NOTIFICATION_DRAW) { + + Ref<Texture2D> checkerboard = get_theme_icon("Checkerboard", "EditorIcons"); + Size2 size = get_size(); + + draw_texture_rect(checkerboard, Rect2(Point2(), size), true); + } +} + +void TextureLayeredEditor::_changed_callback(Object *p_changed, const char *p_prop) { + + if (!is_visible()) + return; + update(); +} + +void TextureLayeredEditor::_update_material() { + + materials[0]->set_shader_param("layer", layer->get_value()); + materials[2]->set_shader_param("layer", layer->get_value()); + materials[texture->get_layered_type()]->set_shader_param("tex", texture->get_rid()); + + Vector3 v(1, 1, 1); + v.normalize(); + + Basis b; + b.rotate(Vector3(1, 0, 0), x_rot); + b.rotate(Vector3(0, 1, 0), y_rot); + + materials[1]->set_shader_param("normal", v); + materials[1]->set_shader_param("rot", b); + materials[2]->set_shader_param("normal", v); + materials[2]->set_shader_param("rot", b); + + String format = Image::get_format_name(texture->get_format()); + + String text; + if (texture->get_layered_type() == TextureLayered::LAYERED_TYPE_2D_ARRAY) { + text = itos(texture->get_width()) + "x" + itos(texture->get_height()) + " (x " + itos(texture->get_layers()) + ")" + format; + } else if (texture->get_layered_type() == TextureLayered::LAYERED_TYPE_CUBEMAP) { + text = itos(texture->get_width()) + "x" + itos(texture->get_height()) + " " + format; + } else if (texture->get_layered_type() == TextureLayered::LAYERED_TYPE_CUBEMAP_ARRAY) { + text = itos(texture->get_width()) + "x" + itos(texture->get_height()) + " (x " + itos(texture->get_layers() / 6) + ")" + format; + } + + info->set_text(text); +} + +void TextureLayeredEditor::_make_shaders() { + String shader_2d_array = "" + "shader_type canvas_item;\n" + "uniform sampler2DArray tex;\n" + "uniform float layer;\n" + "void fragment() {\n" + " COLOR = textureLod(tex,vec3(UV,layer),0.0);\n" + "}"; + + shaders[0].instance(); + shaders[0]->set_code(shader_2d_array); + + String shader_cube = "" + "shader_type canvas_item;\n" + "uniform samplerCube tex;\n" + "uniform vec3 normal;\n" + "uniform mat3 rot;\n" + "void fragment() {\n" + " vec3 n = rot * normalize(vec3(normal.xy*(UV * 2.0 - 1.0),normal.z));\n" + " COLOR = textureLod(tex,n,0.0);\n" + "}"; + + shaders[1].instance(); + shaders[1]->set_code(shader_cube); + + String shader_cube_array = "" + "shader_type canvas_item;\n" + "uniform samplerCubeArray tex;\n" + "uniform vec3 normal;\n" + "uniform mat3 rot;\n" + "uniform float layer;\n" + "void fragment() {\n" + " vec3 n = rot * normalize(vec3(normal.xy*(UV * 2.0 - 1.0),normal.z));\n" + " COLOR = textureLod(tex,vec4(n,layer),0.0);\n" + "}"; + + shaders[2].instance(); + shaders[2]->set_code(shader_cube_array); + + for (int i = 0; i < 3; i++) { + materials[i].instance(); + materials[i]->set_shader(shaders[i]); + } +} + +void TextureLayeredEditor::_texture_rect_update_area() { + + Size2 size = get_size(); + int tex_width = texture->get_width() * size.height / texture->get_height(); + int tex_height = size.height; + + if (tex_width > size.width) { + tex_width = size.width; + tex_height = texture->get_height() * tex_width / texture->get_width(); + } + + // Prevent the texture from being unpreviewable after the rescale, so that we can still see something + if (tex_height <= 0) + tex_height = 1; + if (tex_width <= 0) + tex_width = 1; + + int ofs_x = (size.width - tex_width) / 2; + int ofs_y = (size.height - tex_height) / 2; + + texture_rect->set_position(Vector2(ofs_x, ofs_y)); + texture_rect->set_size(Vector2(tex_width, tex_height)); +} + +void TextureLayeredEditor::edit(Ref<TextureLayered> p_texture) { + + if (!texture.is_null()) + texture->remove_change_receptor(this); + + texture = p_texture; + + if (!texture.is_null()) { + + if (shaders[0].is_null()) { + _make_shaders(); + } + + texture->add_change_receptor(this); + update(); + texture_rect->set_material(materials[texture->get_layered_type()]); + setting = true; + if (texture->get_layered_type() == TextureLayered::LAYERED_TYPE_2D_ARRAY) { + layer->set_max(texture->get_layers() - 1); + layer->set_value(0); + layer->show(); + } else if (texture->get_layered_type() == TextureLayered::LAYERED_TYPE_CUBEMAP_ARRAY) { + layer->set_max(texture->get_layers() / 6 - 1); + layer->set_value(0); + layer->show(); + } else { + layer->hide(); + } + x_rot = 0; + y_rot = 0; + _update_material(); + setting = false; + _texture_rect_update_area(); + } else { + hide(); + } +} + +void TextureLayeredEditor::_bind_methods() { + + ClassDB::bind_method(D_METHOD("_gui_input"), &TextureLayeredEditor::_gui_input); + ClassDB::bind_method(D_METHOD("_layer_changed"), &TextureLayeredEditor::_layer_changed); +} + +TextureLayeredEditor::TextureLayeredEditor() { + + set_texture_repeat(TextureRepeat::TEXTURE_REPEAT_ENABLED); + set_custom_minimum_size(Size2(1, 150)); + texture_rect = memnew(Control); + texture_rect->connect("draw", callable_mp(this, &TextureLayeredEditor::_texture_rect_draw)); + texture_rect->set_mouse_filter(MOUSE_FILTER_IGNORE); + add_child(texture_rect); + + layer = memnew(SpinBox); + layer->set_step(1); + layer->set_max(100); + add_child(layer); + layer->set_anchor(MARGIN_RIGHT, 1); + layer->set_anchor(MARGIN_LEFT, 1); + layer->set_h_grow_direction(GROW_DIRECTION_BEGIN); + layer->set_modulate(Color(1, 1, 1, 0.8)); + info = memnew(Label); + add_child(info); + info->set_anchor(MARGIN_RIGHT, 1); + info->set_anchor(MARGIN_LEFT, 1); + info->set_anchor(MARGIN_BOTTOM, 1); + info->set_anchor(MARGIN_TOP, 1); + info->set_h_grow_direction(GROW_DIRECTION_BEGIN); + info->set_v_grow_direction(GROW_DIRECTION_BEGIN); + info->add_theme_color_override("font_color", Color(1, 1, 1, 1)); + info->add_theme_color_override("font_color_shadow", Color(0, 0, 0, 0.5)); + info->add_theme_color_override("font_color_shadow", Color(0, 0, 0, 0.5)); + info->add_theme_constant_override("shadow_as_outline", 1); + info->add_theme_constant_override("shadow_offset_x", 2); + info->add_theme_constant_override("shadow_offset_y", 2); + + setting = false; + layer->connect("value_changed", Callable(this, "_layer_changed")); +} + +TextureLayeredEditor::~TextureLayeredEditor() { + if (!texture.is_null()) { + texture->remove_change_receptor(this); + } +} +// +bool EditorInspectorPluginLayeredTexture::can_handle(Object *p_object) { + + return Object::cast_to<TextureLayered>(p_object) != nullptr; +} + +void EditorInspectorPluginLayeredTexture::parse_begin(Object *p_object) { + + TextureLayered *texture = Object::cast_to<TextureLayered>(p_object); + if (!texture) { + return; + } + Ref<TextureLayered> m(texture); + + TextureLayeredEditor *editor = memnew(TextureLayeredEditor); + editor->edit(m); + add_custom_control(editor); +} + +TextureLayeredEditorPlugin::TextureLayeredEditorPlugin(EditorNode *p_node) { + + Ref<EditorInspectorPluginLayeredTexture> plugin; + plugin.instance(); + add_inspector_plugin(plugin); +} diff --git a/editor/plugins/texture_layered_editor_plugin.h b/editor/plugins/texture_layered_editor_plugin.h new file mode 100644 index 0000000000..e8503e845e --- /dev/null +++ b/editor/plugins/texture_layered_editor_plugin.h @@ -0,0 +1,95 @@ +/*************************************************************************/ +/* texture_editor_plugin.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef TEXTURE_LAYERED_EDITOR_PLUGIN_H +#define TEXTURE_LAYERED_EDITOR_PLUGIN_H + +#include "editor/editor_node.h" +#include "editor/editor_plugin.h" +#include "scene/resources/shader.h" +#include "scene/resources/texture.h" +class TextureLayeredEditor : public Control { + + GDCLASS(TextureLayeredEditor, Control); + + SpinBox *layer; + Label *info; + Ref<TextureLayered> texture; + + Ref<Shader> shaders[3]; + Ref<ShaderMaterial> materials[3]; + + float x_rot = 0; + float y_rot = 0; + Control *texture_rect; + + void _make_shaders(); + + void _update_material(); + bool setting; + void _layer_changed(double) { + if (!setting) + _update_material(); + } + + void _texture_rect_update_area(); + void _texture_rect_draw(); + +protected: + void _notification(int p_what); + void _gui_input(Ref<InputEvent> p_event); + void _changed_callback(Object *p_changed, const char *p_prop); + static void _bind_methods(); + +public: + void edit(Ref<TextureLayered> p_texture); + TextureLayeredEditor(); + ~TextureLayeredEditor(); +}; + +class EditorInspectorPluginLayeredTexture : public EditorInspectorPlugin { + GDCLASS(EditorInspectorPluginLayeredTexture, EditorInspectorPlugin); + +public: + virtual bool can_handle(Object *p_object); + virtual void parse_begin(Object *p_object); +}; + +class TextureLayeredEditorPlugin : public EditorPlugin { + + GDCLASS(TextureLayeredEditorPlugin, EditorPlugin); + +public: + virtual String get_name() const { return "TextureLayered"; } + + TextureLayeredEditorPlugin(EditorNode *p_node); +}; + +#endif // TEXTURE_EDITOR_PLUGIN_H diff --git a/editor/plugins/theme_editor_plugin.cpp b/editor/plugins/theme_editor_plugin.cpp index b246b611fd..d9be2e32cb 100644 --- a/editor/plugins/theme_editor_plugin.cpp +++ b/editor/plugins/theme_editor_plugin.cpp @@ -74,11 +74,21 @@ void ThemeEditor::_name_menu_about_to_show() { switch (type_select->get_selected()) { - case 0: Theme::get_default()->get_icon_list(fromtype, &names); break; - case 1: Theme::get_default()->get_stylebox_list(fromtype, &names); break; - case 2: Theme::get_default()->get_font_list(fromtype, &names); break; - case 3: Theme::get_default()->get_color_list(fromtype, &names); break; - case 4: Theme::get_default()->get_constant_list(fromtype, &names); break; + case 0: + Theme::get_default()->get_icon_list(fromtype, &names); + break; + case 1: + Theme::get_default()->get_stylebox_list(fromtype, &names); + break; + case 2: + Theme::get_default()->get_font_list(fromtype, &names); + break; + case 3: + Theme::get_default()->get_color_list(fromtype, &names); + break; + case 4: + Theme::get_default()->get_constant_list(fromtype, &names); + break; } } else if (popup_mode == POPUP_REMOVE) { @@ -324,11 +334,21 @@ void ThemeEditor::_dialog_cbk() { switch (type_select->get_selected()) { - case 0: theme->set_icon(name_edit->get_text(), type_edit->get_text(), Ref<Texture2D>()); break; - case 1: theme->set_stylebox(name_edit->get_text(), type_edit->get_text(), Ref<StyleBox>()); break; - case 2: theme->set_font(name_edit->get_text(), type_edit->get_text(), Ref<Font>()); break; - case 3: theme->set_color(name_edit->get_text(), type_edit->get_text(), Color()); break; - case 4: theme->set_constant(name_edit->get_text(), type_edit->get_text(), 0); break; + case 0: + theme->set_icon(name_edit->get_text(), type_edit->get_text(), Ref<Texture2D>()); + break; + case 1: + theme->set_stylebox(name_edit->get_text(), type_edit->get_text(), Ref<StyleBox>()); + break; + case 2: + theme->set_font(name_edit->get_text(), type_edit->get_text(), Ref<Font>()); + break; + case 3: + theme->set_color(name_edit->get_text(), type_edit->get_text(), Color()); + break; + case 4: + theme->set_constant(name_edit->get_text(), type_edit->get_text(), 0); + break; } } break; @@ -376,11 +396,21 @@ void ThemeEditor::_dialog_cbk() { case POPUP_REMOVE: { switch (type_select->get_selected()) { - case 0: theme->clear_icon(name_edit->get_text(), type_edit->get_text()); break; - case 1: theme->clear_stylebox(name_edit->get_text(), type_edit->get_text()); break; - case 2: theme->clear_font(name_edit->get_text(), type_edit->get_text()); break; - case 3: theme->clear_color(name_edit->get_text(), type_edit->get_text()); break; - case 4: theme->clear_constant(name_edit->get_text(), type_edit->get_text()); break; + case 0: + theme->clear_icon(name_edit->get_text(), type_edit->get_text()); + break; + case 1: + theme->clear_stylebox(name_edit->get_text(), type_edit->get_text()); + break; + case 2: + theme->clear_font(name_edit->get_text(), type_edit->get_text()); + break; + case 3: + theme->clear_color(name_edit->get_text(), type_edit->get_text()); + break; + case 4: + theme->clear_constant(name_edit->get_text(), type_edit->get_text()); + break; } } break; diff --git a/editor/plugins/tile_map_editor_plugin.h b/editor/plugins/tile_map_editor_plugin.h index f43e5bb5cb..28b0e9b6db 100644 --- a/editor/plugins/tile_map_editor_plugin.h +++ b/editor/plugins/tile_map_editor_plugin.h @@ -33,7 +33,6 @@ #include "editor/editor_node.h" #include "editor/editor_plugin.h" - #include "scene/2d/tile_map.h" #include "scene/gui/check_box.h" #include "scene/gui/label.h" @@ -127,34 +126,26 @@ class TileMapEditor : public VBoxContainer { List<Point2i> bucket_queue; struct CellOp { - int idx; - bool xf; - bool yf; - bool tr; + int idx = TileMap::INVALID_CELL; + bool xf = false; + bool yf = false; + bool tr = false; Vector2 ac; - CellOp() : - idx(TileMap::INVALID_CELL), - xf(false), - yf(false), - tr(false) {} + CellOp() {} }; Map<Point2i, CellOp> paint_undo; struct TileData { Point2i pos; - int cell; - bool flip_h; - bool flip_v; - bool transpose; + int cell = TileMap::INVALID_CELL; + bool flip_h = false; + bool flip_v = false; + bool transpose = false; Point2i autotile_coord; - TileData() : - cell(TileMap::INVALID_CELL), - flip_h(false), - flip_v(false), - transpose(false) {} + TileData() {} }; List<TileData> copydata; diff --git a/editor/plugins/tile_set_editor_plugin.cpp b/editor/plugins/tile_set_editor_plugin.cpp index c393b15a97..b0d325efc1 100644 --- a/editor/plugins/tile_set_editor_plugin.cpp +++ b/editor/plugins/tile_set_editor_plugin.cpp @@ -126,7 +126,8 @@ void TileSetEditor::_import_node(Node *p_node, Ref<TileSet> p_library) { sb->get_shape_owners(&shapes); for (List<uint32_t>::Element *E = shapes.front(); E; E = E->next()) { - if (sb->is_shape_owner_disabled(E->get())) continue; + if (sb->is_shape_owner_disabled(E->get())) + continue; Transform2D shape_transform = sb->get_transform() * sb->shape_owner_get_transform(E->get()); bool one_way = sb->is_shape_owner_one_way_collision_enabled(E->get()); diff --git a/editor/plugins/visual_shader_editor_plugin.cpp b/editor/plugins/visual_shader_editor_plugin.cpp index 35ed29f562..07251ad7ad 100644 --- a/editor/plugins/visual_shader_editor_plugin.cpp +++ b/editor/plugins/visual_shader_editor_plugin.cpp @@ -1239,11 +1239,13 @@ void VisualShaderEditor::_port_name_focus_out(Object *line_edit, int p_node_id, List<String> output_names; for (int i = 0; i < node->get_input_port_count(); i++) { - if (!p_output && i == p_port_id) continue; + if (!p_output && i == p_port_id) + continue; input_names.push_back(node->get_input_port_name(i)); } for (int i = 0; i < node->get_output_port_count(); i++) { - if (p_output && i == p_port_id) continue; + if (p_output && i == p_port_id) + continue; output_names.push_back(node->get_output_port_name(i)); } diff --git a/editor/project_manager.cpp b/editor/project_manager.cpp index c918a799e1..2d1eb54e9a 100644 --- a/editor/project_manager.cpp +++ b/editor/project_manager.cpp @@ -2429,12 +2429,24 @@ ProjectManager::ProjectManager() { #endif } break; - case 1: editor_set_scale(0.75); break; - case 2: editor_set_scale(1.0); break; - case 3: editor_set_scale(1.25); break; - case 4: editor_set_scale(1.5); break; - case 5: editor_set_scale(1.75); break; - case 6: editor_set_scale(2.0); break; + case 1: + editor_set_scale(0.75); + break; + case 2: + editor_set_scale(1.0); + break; + case 3: + editor_set_scale(1.25); + break; + case 4: + editor_set_scale(1.5); + break; + case 5: + editor_set_scale(1.75); + break; + case 6: + editor_set_scale(2.0); + break; default: { editor_set_scale(custom_display_scale); diff --git a/editor/project_settings_editor.cpp b/editor/project_settings_editor.cpp index 49c02dc895..e5dd0c148e 100644 --- a/editor/project_settings_editor.cpp +++ b/editor/project_settings_editor.cpp @@ -43,37 +43,45 @@ ProjectSettingsEditor *ProjectSettingsEditor::singleton = nullptr; -static const char *_button_names[JOY_BUTTON_MAX] = { - "DualShock Cross, Xbox A, Nintendo B", - "DualShock Circle, Xbox B, Nintendo A", - "DualShock Square, Xbox X, Nintendo Y", - "DualShock Triangle, Xbox Y, Nintendo X", - "L, L1", - "R, R1", - "L2", - "R2", - "L3", - "R3", - "Select, DualShock Share, Nintendo -", - "Start, DualShock Options, Nintendo +", +static const char *_button_descriptions[JOY_SDL_BUTTONS] = { + "Face Bottom, DualShock Cross, Xbox A, Nintendo B", + "Face Right, DualShock Circle, Xbox B, Nintendo A", + "Face Left, DualShock Square, Xbox X, Nintendo Y", + "Face Top, DualShock Triangle, Xbox Y, Nintendo X", + "DualShock Select, Xbox Back, Nintendo -", + "Home, DualShock PS, Guide", + "Start, Nintendo +", + "Left Stick, DualShock L3, Xbox L/LS", + "Right Stick, DualShock R3, Xbox R/RS", + "Left Shoulder, DualShock L1, Xbox LB", + "Right Shoulder, DualShock R1, Xbox RB", "D-Pad Up", "D-Pad Down", "D-Pad Left", "D-Pad Right" }; -static const char *_axis_names[JOY_AXIS_MAX * 2] = { - " (Left Stick Left)", - " (Left Stick Right)", - " (Left Stick Up)", - " (Left Stick Down)", - " (Right Stick Left)", - " (Right Stick Right)", - " (Right Stick Up)", - " (Right Stick Down)", - "", "", "", "", - "", " (L2)", - "", " (R2)" +static const char *_axis_descriptions[JOY_AXIS_MAX * 2] = { + "Left Stick Left", + "Left Stick Right", + "Left Stick Up", + "Left Stick Down", + "Right Stick Left", + "Right Stick Right", + "Right Stick Up", + "Right Stick Down", + "Joystick 2 Left", + "Joystick 2 Right, Left Trigger, L2, LT", + "Joystick 2 Up", + "Joystick 2 Down, Right Trigger, R2, RT", + "Joystick 3 Left", + "Joystick 3 Right", + "Joystick 3 Up", + "Joystick 3 Down", + "Joystick 4 Left", + "Joystick 4 Right", + "Joystick 4 Up", + "Joystick 4 Down", }; void ProjectSettingsEditor::_unhandled_input(const Ref<InputEvent> &p_event) { @@ -446,11 +454,13 @@ void ProjectSettingsEditor::_show_last_added(const Ref<InputEvent> &p_event, con } child = child->get_next(); } - if (found) break; + if (found) + break; r = r->get_next(); } - if (found) input_editor->ensure_cursor_is_visible(); + if (found) + input_editor->ensure_cursor_is_visible(); } void ProjectSettingsEditor::_wait_for_key(const Ref<InputEvent> &p_event) { @@ -525,9 +535,9 @@ void ProjectSettingsEditor::_add_item(int p_item, Ref<InputEvent> p_exiting_even device_index_label->set_text(TTR("Joypad Axis Index:")); device_index->clear(); for (int i = 0; i < JOY_AXIS_MAX * 2; i++) { - - String desc = _axis_names[i]; - device_index->add_item(TTR("Axis") + " " + itos(i / 2) + " " + ((i & 1) ? "+" : "-") + desc); + String desc = TTR("Axis") + " " + itos(i / 2) + " " + ((i & 1) ? "+" : "-") + + " (" + TTR(_axis_descriptions[i]) + ")"; + device_index->add_item(desc); } device_input->popup_centered(Size2(350, 95) * EDSCALE); @@ -546,10 +556,11 @@ void ProjectSettingsEditor::_add_item(int p_item, Ref<InputEvent> p_exiting_even device_index_label->set_text(TTR("Joypad Button Index:")); device_index->clear(); - for (int i = 0; i < JOY_BUTTON_MAX; i++) { - - device_index->add_item(itos(i) + ": " + String(_button_names[i])); + String desc = TTR("Button") + " " + itos(i); + if (i < JOY_SDL_BUTTONS) + desc += " (" + TTR(_button_descriptions[i]) + ")"; + device_index->add_item(desc); } device_input->popup_centered(Size2(350, 95) * EDSCALE); @@ -790,9 +801,10 @@ void ProjectSettingsEditor::_update_actions() { if (jb.is_valid()) { - String str = _get_device_string(jb->get_device()) + ", " + TTR("Button") + " " + itos(jb->get_button_index()); - if (jb->get_button_index() >= 0 && jb->get_button_index() < JOY_BUTTON_MAX) { - str += String() + " (" + _button_names[jb->get_button_index()] + ")"; + String str = _get_device_string(jb->get_device()) + ", " + + TTR("Button") + " " + itos(jb->get_button_index()); + if (jb->get_button_index() >= 0 && jb->get_button_index() < JOY_SDL_BUTTONS) { + str += String() + " (" + TTR(_button_descriptions[jb->get_button_index()]) + ")"; } action2->set_text(0, str); @@ -804,12 +816,23 @@ void ProjectSettingsEditor::_update_actions() { if (mb.is_valid()) { String str = _get_device_string(mb->get_device()) + ", "; switch (mb->get_button_index()) { - case BUTTON_LEFT: str += TTR("Left Button"); break; - case BUTTON_RIGHT: str += TTR("Right Button"); break; - case BUTTON_MIDDLE: str += TTR("Middle Button"); break; - case BUTTON_WHEEL_UP: str += TTR("Wheel Up"); break; - case BUTTON_WHEEL_DOWN: str += TTR("Wheel Down"); break; - default: str += vformat(TTR("%d Button"), mb->get_button_index()); + case BUTTON_LEFT: + str += TTR("Left Button"); + break; + case BUTTON_RIGHT: + str += TTR("Right Button"); + break; + case BUTTON_MIDDLE: + str += TTR("Middle Button"); + break; + case BUTTON_WHEEL_UP: + str += TTR("Wheel Up"); + break; + case BUTTON_WHEEL_DOWN: + str += TTR("Wheel Down"); + break; + default: + str += vformat(TTR("%d Button"), mb->get_button_index()); } action2->set_text(0, str); @@ -822,8 +845,9 @@ void ProjectSettingsEditor::_update_actions() { int ax = jm->get_axis(); int n = 2 * ax + (jm->get_axis_value() < 0 ? 0 : 1); - String desc = _axis_names[n]; - String str = _get_device_string(jm->get_device()) + ", " + TTR("Axis") + " " + itos(ax) + " " + (jm->get_axis_value() < 0 ? "-" : "+") + desc; + String str = _get_device_string(jm->get_device()) + ", " + + TTR("Axis") + " " + itos(ax) + " " + (jm->get_axis_value() < 0 ? "-" : "+") + + " (" + _axis_descriptions[n] + ")"; action2->set_text(0, str); action2->set_icon(0, input_editor->get_theme_icon("JoyAxis", "EditorIcons")); } @@ -1596,7 +1620,8 @@ void ProjectSettingsEditor::_update_translations() { String n = names[i]; String l = langs[i]; bool is_checked = l_filter.has(l); - if (filter_mode == SHOW_ONLY_SELECTED_LOCALES && !is_checked) continue; + if (filter_mode == SHOW_ONLY_SELECTED_LOCALES && !is_checked) + continue; TreeItem *t = translation_filter->create_item(root); t->set_cell_mode(0, TreeItem::CELL_MODE_CHECK); diff --git a/editor/rename_dialog.h b/editor/rename_dialog.h index 194dd57648..280bd00d95 100644 --- a/editor/rename_dialog.h +++ b/editor/rename_dialog.h @@ -48,7 +48,7 @@ class RenameDialog : public ConfirmationDialog { GDCLASS(RenameDialog, ConfirmationDialog); virtual void ok_pressed() { rename(); }; - void _cancel_pressed(){}; + void _cancel_pressed() {} void _features_toggled(bool pressed); void _insert_text(String text); void _update_substitute(); @@ -111,7 +111,7 @@ public: void rename(); RenameDialog(SceneTreeEditor *p_scene_tree_editor, UndoRedo *p_undo_redo = nullptr); - ~RenameDialog(){}; + ~RenameDialog() {} }; #endif diff --git a/editor/scene_tree_dock.cpp b/editor/scene_tree_dock.cpp index a8aeb05150..6e00e3c291 100644 --- a/editor/scene_tree_dock.cpp +++ b/editor/scene_tree_dock.cpp @@ -97,8 +97,8 @@ void SceneTreeDock::_unhandled_key_input(Ref<InputEvent> p_event) { _tool_selected(TOOL_DUPLICATE); } else if (ED_IS_SHORTCUT("scene_tree/attach_script", p_event)) { _tool_selected(TOOL_ATTACH_SCRIPT); - } else if (ED_IS_SHORTCUT("scene_tree/clear_script", p_event)) { - _tool_selected(TOOL_CLEAR_SCRIPT); + } else if (ED_IS_SHORTCUT("scene_tree/detach_script", p_event)) { + _tool_selected(TOOL_DETACH_SCRIPT); } else if (ED_IS_SHORTCUT("scene_tree/move_up", p_event)) { _tool_selected(TOOL_MOVE_UP); } else if (ED_IS_SHORTCUT("scene_tree/move_down", p_event)) { @@ -426,7 +426,7 @@ void SceneTreeDock::_tool_selected(int p_tool, bool p_confirm_override) { case TOOL_ATTACH_SCRIPT: { attach_script_to_selected(false); } break; - case TOOL_CLEAR_SCRIPT: { + case TOOL_DETACH_SCRIPT: { if (!profile_allow_script_editing) { break; @@ -437,7 +437,7 @@ void SceneTreeDock::_tool_selected(int p_tool, bool p_confirm_override) { if (selection.empty()) return; - editor_data->get_undo_redo().create_action(TTR("Clear Script")); + editor_data->get_undo_redo().create_action(TTR("Detach Script")); editor_data->get_undo_redo().add_do_method(editor, "push_item", (Script *)nullptr); for (int i = 0; i < selection.size(); i++) { @@ -491,8 +491,10 @@ void SceneTreeDock::_tool_selected(int p_tool, bool p_confirm_override) { for (List<Node *>::Element *E = selection.front(); E; E = E->next()) { int index = E->get()->get_index(); - if (index > highest_id) highest_id = index; - if (index < lowest_id) lowest_id = index; + if (index > highest_id) + highest_id = index; + if (index < lowest_id) + lowest_id = index; if (E->get()->get_parent() != common_parent) common_parent = nullptr; @@ -501,8 +503,10 @@ void SceneTreeDock::_tool_selected(int p_tool, bool p_confirm_override) { if (!common_parent || (MOVING_DOWN && highest_id >= common_parent->get_child_count() - MOVING_DOWN) || (MOVING_UP && lowest_id == 0)) break; // one or more nodes can not be moved - if (selection.size() == 1) editor_data->get_undo_redo().create_action(TTR("Move Node In Parent")); - if (selection.size() > 1) editor_data->get_undo_redo().create_action(TTR("Move Nodes In Parent")); + if (selection.size() == 1) + editor_data->get_undo_redo().create_action(TTR("Move Node In Parent")); + if (selection.size() > 1) + editor_data->get_undo_redo().create_action(TTR("Move Nodes In Parent")); for (int i = 0; i < selection.size(); i++) { Node *top_node = selection[i]; @@ -986,8 +990,12 @@ void SceneTreeDock::_tool_selected(int p_tool, bool p_confirm_override) { } } else { switch (p_tool) { - case TOOL_CREATE_2D_SCENE: new_node = memnew(Node2D); break; - case TOOL_CREATE_3D_SCENE: new_node = memnew(Node3D); break; + case TOOL_CREATE_2D_SCENE: + new_node = memnew(Node2D); + break; + case TOOL_CREATE_3D_SCENE: + new_node = memnew(Node3D); + break; case TOOL_CREATE_USER_INTERFACE: { Control *node = memnew(Control); node->set_anchors_and_margins_preset(PRESET_WIDE); //more useful for resizable UIs. @@ -1063,7 +1071,7 @@ void SceneTreeDock::_notification(int p_what) { button_add->set_icon(get_theme_icon("Add", "EditorIcons")); button_instance->set_icon(get_theme_icon("Instance", "EditorIcons")); button_create_script->set_icon(get_theme_icon("ScriptCreate", "EditorIcons")); - button_clear_script->set_icon(get_theme_icon("ScriptRemove", "EditorIcons")); + button_detach_script->set_icon(get_theme_icon("ScriptRemove", "EditorIcons")); filter->set_right_icon(get_theme_icon("Search", "EditorIcons")); filter->set_clear_button_enabled(true); @@ -1140,7 +1148,7 @@ void SceneTreeDock::_notification(int p_what) { button_add->set_icon(get_theme_icon("Add", "EditorIcons")); button_instance->set_icon(get_theme_icon("Instance", "EditorIcons")); button_create_script->set_icon(get_theme_icon("ScriptCreate", "EditorIcons")); - button_clear_script->set_icon(get_theme_icon("ScriptRemove", "EditorIcons")); + button_detach_script->set_icon(get_theme_icon("ScriptRemove", "EditorIcons")); filter->set_right_icon(get_theme_icon("Search", "EditorIcons")); filter->set_clear_button_enabled(true); @@ -1875,18 +1883,18 @@ void SceneTreeDock::_update_script_button() { if (!profile_allow_script_editing) { button_create_script->hide(); - button_clear_script->hide(); + button_detach_script->hide(); } else if (EditorNode::get_singleton()->get_editor_selection()->get_selection().size() == 0) { button_create_script->hide(); - button_clear_script->hide(); + button_detach_script->hide(); } else if (EditorNode::get_singleton()->get_editor_selection()->get_selection().size() == 1) { Node *n = EditorNode::get_singleton()->get_editor_selection()->get_selected_node_list()[0]; if (n->get_script().is_null()) { button_create_script->show(); - button_clear_script->hide(); + button_detach_script->hide(); } else { button_create_script->hide(); - button_clear_script->show(); + button_detach_script->show(); } } else { button_create_script->hide(); @@ -1894,11 +1902,11 @@ void SceneTreeDock::_update_script_button() { for (int i = 0; i < selection.size(); i++) { Node *n = Object::cast_to<Node>(selection[i]); if (!n->get_script().is_null()) { - button_clear_script->show(); + button_detach_script->show(); return; } } - button_clear_script->hide(); + button_detach_script->hide(); } } @@ -2450,7 +2458,7 @@ void SceneTreeDock::_tree_rmb(const Vector2 &p_menu_pos) { } if (existing_script.is_valid() && existing_script_removable) { add_separator = true; - menu->add_icon_shortcut(get_theme_icon("ScriptRemove", "EditorIcons"), ED_GET_SHORTCUT("scene_tree/clear_script"), TOOL_CLEAR_SCRIPT); + menu->add_icon_shortcut(get_theme_icon("ScriptRemove", "EditorIcons"), ED_GET_SHORTCUT("scene_tree/detach_script"), TOOL_DETACH_SCRIPT); } else if (full_selection.size() > 1) { bool script_exists = false; for (List<Node *>::Element *E = full_selection.front(); E; E = E->next()) { @@ -2462,7 +2470,7 @@ void SceneTreeDock::_tree_rmb(const Vector2 &p_menu_pos) { if (script_exists) { add_separator = true; - menu->add_icon_shortcut(get_theme_icon("ScriptRemove", "EditorIcons"), ED_GET_SHORTCUT("scene_tree/clear_script"), TOOL_CLEAR_SCRIPT); + menu->add_icon_shortcut(get_theme_icon("ScriptRemove", "EditorIcons"), ED_GET_SHORTCUT("scene_tree/detach_script"), TOOL_DETACH_SCRIPT); } } @@ -2808,7 +2816,7 @@ SceneTreeDock::SceneTreeDock(EditorNode *p_editor, Node *p_scene_root, EditorSel ED_SHORTCUT("scene_tree/change_node_type", TTR("Change Type")); ED_SHORTCUT("scene_tree/attach_script", TTR("Attach Script")); ED_SHORTCUT("scene_tree/extend_script", TTR("Extend Script")); - ED_SHORTCUT("scene_tree/clear_script", TTR("Clear Script")); + ED_SHORTCUT("scene_tree/detach_script", TTR("Detach Script")); ED_SHORTCUT("scene_tree/move_up", TTR("Move Up"), KEY_MASK_CMD | KEY_UP); ED_SHORTCUT("scene_tree/move_down", TTR("Move Down"), KEY_MASK_CMD | KEY_DOWN); ED_SHORTCUT("scene_tree/duplicate", TTR("Duplicate"), KEY_MASK_CMD | KEY_D); @@ -2843,17 +2851,17 @@ SceneTreeDock::SceneTreeDock(EditorNode *p_editor, Node *p_scene_root, EditorSel button_create_script = memnew(ToolButton); button_create_script->connect("pressed", callable_mp(this, &SceneTreeDock::_tool_selected), make_binds(TOOL_ATTACH_SCRIPT, false)); - button_create_script->set_tooltip(TTR("Attach a new or existing script for the selected node.")); + button_create_script->set_tooltip(TTR("Attach a new or existing script to the selected node.")); button_create_script->set_shortcut(ED_GET_SHORTCUT("scene_tree/attach_script")); filter_hbc->add_child(button_create_script); button_create_script->hide(); - button_clear_script = memnew(ToolButton); - button_clear_script->connect("pressed", callable_mp(this, &SceneTreeDock::_tool_selected), make_binds(TOOL_CLEAR_SCRIPT, false)); - button_clear_script->set_tooltip(TTR("Clear a script for the selected node.")); - button_clear_script->set_shortcut(ED_GET_SHORTCUT("scene_tree/clear_script")); - filter_hbc->add_child(button_clear_script); - button_clear_script->hide(); + button_detach_script = memnew(ToolButton); + button_detach_script->connect("pressed", callable_mp(this, &SceneTreeDock::_tool_selected), make_binds(TOOL_DETACH_SCRIPT, false)); + button_detach_script->set_tooltip(TTR("Detach the script from the selected node.")); + button_detach_script->set_shortcut(ED_GET_SHORTCUT("scene_tree/detach_script")); + filter_hbc->add_child(button_detach_script); + button_detach_script->hide(); button_hb = memnew(HBoxContainer); vbc->add_child(button_hb); diff --git a/editor/scene_tree_dock.h b/editor/scene_tree_dock.h index 31ef1ce7d0..00b95c9853 100644 --- a/editor/scene_tree_dock.h +++ b/editor/scene_tree_dock.h @@ -66,7 +66,7 @@ class SceneTreeDock : public VBoxContainer { TOOL_REPLACE, TOOL_EXTEND_SCRIPT, TOOL_ATTACH_SCRIPT, - TOOL_CLEAR_SCRIPT, + TOOL_DETACH_SCRIPT, TOOL_MOVE_UP, TOOL_MOVE_DOWN, TOOL_DUPLICATE, @@ -110,7 +110,7 @@ class SceneTreeDock : public VBoxContainer { ToolButton *button_add; ToolButton *button_instance; ToolButton *button_create_script; - ToolButton *button_clear_script; + ToolButton *button_detach_script; Button *button_3d; diff --git a/editor/script_create_dialog.cpp b/editor/script_create_dialog.cpp index 12b21d871b..f84b7e73ed 100644 --- a/editor/script_create_dialog.cpp +++ b/editor/script_create_dialog.cpp @@ -167,11 +167,14 @@ String ScriptCreateDialog::_validate_path(const String &p_path, bool p_file_must String p = p_path.strip_edges(); - if (p == "") return TTR("Path is empty."); - if (p.get_file().get_basename() == "") return TTR("Filename is empty."); + if (p == "") + return TTR("Path is empty."); + if (p.get_file().get_basename() == "") + return TTR("Filename is empty."); p = ProjectSettings::get_singleton()->localize_path(p); - if (!p.begins_with("res://")) return TTR("Path is not local."); + if (!p.begins_with("res://")) + return TTR("Path is not local."); DirAccess *d = DirAccess::create(DirAccess::ACCESS_RESOURCES); if (d->change_dir(p.get_base_dir()) != OK) { @@ -216,12 +219,15 @@ String ScriptCreateDialog::_validate_path(const String &p_path, bool p_file_must index++; } - if (!found) return TTR("Invalid extension."); - if (!match) return TTR("Wrong extension chosen."); + if (!found) + return TTR("Invalid extension."); + if (!match) + return TTR("Wrong extension chosen."); /* Let ScriptLanguage do custom validation */ String path_error = ScriptServer::get_language(language_menu->get_selected())->validate_path(p); - if (path_error != "") return path_error; + if (path_error != "") + return path_error; /* All checks passed */ return ""; diff --git a/editor/shader_globals_editor.cpp b/editor/shader_globals_editor.cpp index 566ac54612..a013ba2ad3 100644 --- a/editor/shader_globals_editor.cpp +++ b/editor/shader_globals_editor.cpp @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* shader_globals_editor.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #include "shader_globals_editor.h" #include "editor_node.h" @@ -447,6 +477,6 @@ ShaderGlobalsEditor::ShaderGlobalsEditor() { interface->connect("var_changed", Callable(this, "_changed")); } ShaderGlobalsEditor::~ShaderGlobalsEditor() { - inspector->edit(NULL); + inspector->edit(nullptr); memdelete(interface); } diff --git a/editor/shader_globals_editor.h b/editor/shader_globals_editor.h index 59cdeddd8d..b3dbddc87e 100644 --- a/editor/shader_globals_editor.h +++ b/editor/shader_globals_editor.h @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* shader_globals_editor.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #ifndef SHADER_GLOBALS_EDITOR_H #define SHADER_GLOBALS_EDITOR_H diff --git a/gles_builders.py b/gles_builders.py index 6ff2f4248b..85a8d7aa15 100644 --- a/gles_builders.py +++ b/gles_builders.py @@ -740,5 +740,69 @@ def build_rd_headers(target, source, env): build_rd_header(str(x)) +class RAWHeaderStruct: + def __init__(self): + self.code = "" + + +def include_file_in_raw_header(filename, header_data, depth): + fs = open(filename, "r") + line = fs.readline() + text = "" + + while line: + + while line.find("#include ") != -1: + includeline = line.replace("#include ", "").strip()[1:-1] + + import os.path + + included_file = os.path.relpath(os.path.dirname(filename) + "/" + includeline) + include_file_in_raw_header(included_file, header_data, depth + 1) + + line = fs.readline() + + header_data.code += line + line = fs.readline() + + fs.close() + + +def build_raw_header(filename): + header_data = RAWHeaderStruct() + include_file_in_raw_header(filename, header_data, 0) + + out_file = filename + ".gen.h" + fd = open(out_file, "w") + + enum_constants = [] + + fd.write("/* WARNING, THIS FILE WAS GENERATED, DO NOT EDIT */\n") + + out_file_base = out_file.replace(".glsl.gen.h", "_shader_glsl") + out_file_base = out_file_base[out_file_base.rfind("/") + 1 :] + out_file_base = out_file_base[out_file_base.rfind("\\") + 1 :] + out_file_ifdef = out_file_base.replace(".", "_").upper() + fd.write("#ifndef " + out_file_ifdef + "_RAW_H\n") + fd.write("#define " + out_file_ifdef + "_RAW_H\n") + fd.write("\n") + fd.write("static const char " + out_file_base + "[] = {\n") + for c in header_data.code: + fd.write(str(ord(c)) + ",") + fd.write("\t\t0};\n\n") + fd.write("#endif\n") + fd.close() + + +def build_rd_headers(target, source, env): + for x in source: + build_rd_header(str(x)) + + +def build_raw_headers(target, source, env): + for x in source: + build_raw_header(str(x)) + + if __name__ == "__main__": subprocess_main(globals()) diff --git a/main/main.cpp b/main/main.cpp index 95449dd5cc..958d964b35 100644 --- a/main/main.cpp +++ b/main/main.cpp @@ -324,7 +324,7 @@ void Main::print_help(const char *p_binary) { OS::get_singleton()->print(" -b, --breakpoints Breakpoint list as source::line comma-separated pairs, no spaces (use %%20 instead).\n"); OS::get_singleton()->print(" --profiling Enable profiling in the script debugger.\n"); OS::get_singleton()->print(" --gpu-abort Abort on GPU errors (usually validation layer errors), may help see the problem if your system freezes.\n"); - OS::get_singleton()->print(" --remote-debug <address> Remote debug (<host/IP>:<port> address).\n"); + OS::get_singleton()->print(" --remote-debug <uri> Remote debug (<protocol>://<host/IP>[:<port>], e.g. tcp://127.0.0.1:6007).\n"); #if defined(DEBUG_ENABLED) && !defined(SERVER_ENABLED) OS::get_singleton()->print(" --debug-collisions Show collision shapes when running the scene.\n"); OS::get_singleton()->print(" --debug-navigation Show navigation polygons when running the scene.\n"); @@ -844,11 +844,10 @@ Error Main::setup(const char *execpath, int argc, char *argv[], bool p_second_ph if (I->next()) { debug_uri = I->next()->get(); - if (debug_uri.find(":") == -1) { // wrong address - OS::get_singleton()->print("Invalid debug host address, it should be of the form <host/IP>:<port>.\n"); + if (debug_uri.find("://") == -1) { // wrong address + OS::get_singleton()->print("Invalid debug host address, it should be of the form <protocol>://<host/IP>:<port>.\n"); goto error; } - debug_uri = "tcp://" + debug_uri; // will support multiple protocols eventually. N = I->next()->next(); } else { OS::get_singleton()->print("Missing remote debug host address, aborting.\n"); @@ -2011,7 +2010,7 @@ bool Main::start() { if (!project_manager && !editor) { // game // Load SSL Certificates from Project Settings (or builtin). - Crypto::load_default_certificates(GLOBAL_DEF("network/ssl/certificates", "")); + Crypto::load_default_certificates(GLOBAL_DEF("network/ssl/certificate_bundle_override", "")); if (game_path != "") { Node *scene = nullptr; @@ -2063,9 +2062,11 @@ bool Main::start() { } if (project_manager || editor) { - // Hide console window if requested (Windows-only). - bool hide_console = EditorSettings::get_singleton()->get_setting("interface/editor/hide_console_window"); - DisplayServer::get_singleton()->console_set_visible(!hide_console); + if (DisplayServer::get_singleton()->has_feature(DisplayServer::FEATURE_CONSOLE_WINDOW)) { + // Hide console window if requested (Windows-only). + bool hide_console = EditorSettings::get_singleton()->get_setting("interface/editor/hide_console_window"); + DisplayServer::get_singleton()->console_set_visible(!hide_console); + } // Load SSL Certificates from Editor Settings (or builtin) Crypto::load_default_certificates(EditorSettings::get_singleton()->get_setting("network/ssl/editor_ssl_certificates").operator String()); @@ -2260,7 +2261,8 @@ bool Main::iteration() { uint64_t time_step = 1000000L / target_fps; target_ticks += time_step; uint64_t current_ticks = OS::get_singleton()->get_ticks_usec(); - if (current_ticks < target_ticks) OS::get_singleton()->delay_usec(target_ticks - current_ticks); + if (current_ticks < target_ticks) + OS::get_singleton()->delay_usec(target_ticks - current_ticks); current_ticks = OS::get_singleton()->get_ticks_usec(); target_ticks = MIN(MAX(target_ticks, current_ticks - time_step), current_ticks + time_step); } diff --git a/main/main_timer_sync.cpp b/main/main_timer_sync.cpp index 9e23a1f5cb..469ef6f20a 100644 --- a/main/main_timer_sync.cpp +++ b/main/main_timer_sync.cpp @@ -193,12 +193,7 @@ float MainTimerSync::get_cpu_idle_step() { return cpu_ticks_elapsed / 1000000.0; } -MainTimerSync::MainTimerSync() : - last_cpu_ticks_usec(0), - current_cpu_ticks_usec(0), - time_accum(0), - time_deficit(0), - fixed_fps(0) { +MainTimerSync::MainTimerSync() { for (int i = CONTROL_STEPS - 1; i >= 0; --i) { typical_physics_steps[i] = i; accumulated_physics_steps[i] = i; diff --git a/main/main_timer_sync.h b/main/main_timer_sync.h index 620d1c747d..2126381c7c 100644 --- a/main/main_timer_sync.h +++ b/main/main_timer_sync.h @@ -43,14 +43,14 @@ struct MainFrameTime { class MainTimerSync { // wall clock time measured on the main thread - uint64_t last_cpu_ticks_usec; - uint64_t current_cpu_ticks_usec; + uint64_t last_cpu_ticks_usec = 0; + uint64_t current_cpu_ticks_usec = 0; // logical game time since last physics timestep - float time_accum; + float time_accum = 0; // current difference between wall clock time and reported sum of idle_steps - float time_deficit; + float time_deficit = 0; // number of frames back for keeping accumulated physics steps roughly constant. // value of 12 chosen because that is what is required to make 144 Hz monitors @@ -64,7 +64,7 @@ class MainTimerSync { // typical value for accumulated_physics_steps[i] is either this or this plus one int typical_physics_steps[CONTROL_STEPS]; - int fixed_fps; + int fixed_fps = 0; protected: // returns the fraction of p_frame_slice required for the timer to overshoot diff --git a/main/performance.cpp b/main/performance.cpp index 3ca7d7bed8..3de2cba125 100644 --- a/main/performance.cpp +++ b/main/performance.cpp @@ -125,33 +125,60 @@ String Performance::get_monitor_name(Monitor p_monitor) const { float Performance::get_monitor(Monitor p_monitor) const { switch (p_monitor) { - case TIME_FPS: return Engine::get_singleton()->get_frames_per_second(); - case TIME_PROCESS: return _process_time; - case TIME_PHYSICS_PROCESS: return _physics_process_time; - case MEMORY_STATIC: return Memory::get_mem_usage(); - case MEMORY_STATIC_MAX: return Memory::get_mem_max_usage(); - case MEMORY_MESSAGE_BUFFER_MAX: return MessageQueue::get_singleton()->get_max_buffer_usage(); - case OBJECT_COUNT: return ObjectDB::get_object_count(); - case OBJECT_RESOURCE_COUNT: return ResourceCache::get_cached_resource_count(); - case OBJECT_NODE_COUNT: return _get_node_count(); - case OBJECT_ORPHAN_NODE_COUNT: return Node::orphan_node_count; - case RENDER_OBJECTS_IN_FRAME: return RS::get_singleton()->get_render_info(RS::INFO_OBJECTS_IN_FRAME); - case RENDER_VERTICES_IN_FRAME: return RS::get_singleton()->get_render_info(RS::INFO_VERTICES_IN_FRAME); - case RENDER_MATERIAL_CHANGES_IN_FRAME: return RS::get_singleton()->get_render_info(RS::INFO_MATERIAL_CHANGES_IN_FRAME); - case RENDER_SHADER_CHANGES_IN_FRAME: return RS::get_singleton()->get_render_info(RS::INFO_SHADER_CHANGES_IN_FRAME); - case RENDER_SURFACE_CHANGES_IN_FRAME: return RS::get_singleton()->get_render_info(RS::INFO_SURFACE_CHANGES_IN_FRAME); - case RENDER_DRAW_CALLS_IN_FRAME: return RS::get_singleton()->get_render_info(RS::INFO_DRAW_CALLS_IN_FRAME); - case RENDER_VIDEO_MEM_USED: return RS::get_singleton()->get_render_info(RS::INFO_VIDEO_MEM_USED); - case RENDER_TEXTURE_MEM_USED: return RS::get_singleton()->get_render_info(RS::INFO_TEXTURE_MEM_USED); - case RENDER_VERTEX_MEM_USED: return RS::get_singleton()->get_render_info(RS::INFO_VERTEX_MEM_USED); - case RENDER_USAGE_VIDEO_MEM_TOTAL: return RS::get_singleton()->get_render_info(RS::INFO_USAGE_VIDEO_MEM_TOTAL); - case PHYSICS_2D_ACTIVE_OBJECTS: return PhysicsServer2D::get_singleton()->get_process_info(PhysicsServer2D::INFO_ACTIVE_OBJECTS); - case PHYSICS_2D_COLLISION_PAIRS: return PhysicsServer2D::get_singleton()->get_process_info(PhysicsServer2D::INFO_COLLISION_PAIRS); - case PHYSICS_2D_ISLAND_COUNT: return PhysicsServer2D::get_singleton()->get_process_info(PhysicsServer2D::INFO_ISLAND_COUNT); - case PHYSICS_3D_ACTIVE_OBJECTS: return PhysicsServer3D::get_singleton()->get_process_info(PhysicsServer3D::INFO_ACTIVE_OBJECTS); - case PHYSICS_3D_COLLISION_PAIRS: return PhysicsServer3D::get_singleton()->get_process_info(PhysicsServer3D::INFO_COLLISION_PAIRS); - case PHYSICS_3D_ISLAND_COUNT: return PhysicsServer3D::get_singleton()->get_process_info(PhysicsServer3D::INFO_ISLAND_COUNT); - case AUDIO_OUTPUT_LATENCY: return AudioServer::get_singleton()->get_output_latency(); + case TIME_FPS: + return Engine::get_singleton()->get_frames_per_second(); + case TIME_PROCESS: + return _process_time; + case TIME_PHYSICS_PROCESS: + return _physics_process_time; + case MEMORY_STATIC: + return Memory::get_mem_usage(); + case MEMORY_STATIC_MAX: + return Memory::get_mem_max_usage(); + case MEMORY_MESSAGE_BUFFER_MAX: + return MessageQueue::get_singleton()->get_max_buffer_usage(); + case OBJECT_COUNT: + return ObjectDB::get_object_count(); + case OBJECT_RESOURCE_COUNT: + return ResourceCache::get_cached_resource_count(); + case OBJECT_NODE_COUNT: + return _get_node_count(); + case OBJECT_ORPHAN_NODE_COUNT: + return Node::orphan_node_count; + case RENDER_OBJECTS_IN_FRAME: + return RS::get_singleton()->get_render_info(RS::INFO_OBJECTS_IN_FRAME); + case RENDER_VERTICES_IN_FRAME: + return RS::get_singleton()->get_render_info(RS::INFO_VERTICES_IN_FRAME); + case RENDER_MATERIAL_CHANGES_IN_FRAME: + return RS::get_singleton()->get_render_info(RS::INFO_MATERIAL_CHANGES_IN_FRAME); + case RENDER_SHADER_CHANGES_IN_FRAME: + return RS::get_singleton()->get_render_info(RS::INFO_SHADER_CHANGES_IN_FRAME); + case RENDER_SURFACE_CHANGES_IN_FRAME: + return RS::get_singleton()->get_render_info(RS::INFO_SURFACE_CHANGES_IN_FRAME); + case RENDER_DRAW_CALLS_IN_FRAME: + return RS::get_singleton()->get_render_info(RS::INFO_DRAW_CALLS_IN_FRAME); + case RENDER_VIDEO_MEM_USED: + return RS::get_singleton()->get_render_info(RS::INFO_VIDEO_MEM_USED); + case RENDER_TEXTURE_MEM_USED: + return RS::get_singleton()->get_render_info(RS::INFO_TEXTURE_MEM_USED); + case RENDER_VERTEX_MEM_USED: + return RS::get_singleton()->get_render_info(RS::INFO_VERTEX_MEM_USED); + case RENDER_USAGE_VIDEO_MEM_TOTAL: + return RS::get_singleton()->get_render_info(RS::INFO_USAGE_VIDEO_MEM_TOTAL); + case PHYSICS_2D_ACTIVE_OBJECTS: + return PhysicsServer2D::get_singleton()->get_process_info(PhysicsServer2D::INFO_ACTIVE_OBJECTS); + case PHYSICS_2D_COLLISION_PAIRS: + return PhysicsServer2D::get_singleton()->get_process_info(PhysicsServer2D::INFO_COLLISION_PAIRS); + case PHYSICS_2D_ISLAND_COUNT: + return PhysicsServer2D::get_singleton()->get_process_info(PhysicsServer2D::INFO_ISLAND_COUNT); + case PHYSICS_3D_ACTIVE_OBJECTS: + return PhysicsServer3D::get_singleton()->get_process_info(PhysicsServer3D::INFO_ACTIVE_OBJECTS); + case PHYSICS_3D_COLLISION_PAIRS: + return PhysicsServer3D::get_singleton()->get_process_info(PhysicsServer3D::INFO_COLLISION_PAIRS); + case PHYSICS_3D_ISLAND_COUNT: + return PhysicsServer3D::get_singleton()->get_process_info(PhysicsServer3D::INFO_ISLAND_COUNT); + case AUDIO_OUTPUT_LATENCY: + return AudioServer::get_singleton()->get_output_latency(); default: { } diff --git a/main/tests/test_astar.cpp b/main/tests/test_astar.cpp index e0b4a7f2c8..3cc754e230 100644 --- a/main/tests/test_astar.cpp +++ b/main/tests/test_astar.cpp @@ -173,7 +173,8 @@ bool test_add_remove() { for (int i = 0; i < 20000; i++) { int u = Math::rand() % 5; int v = Math::rand() % 4; - if (u == v) v = 4; + if (u == v) + v = 4; if (Math::rand() % 2 == 1) { // Add a (possibly existing) directed edge and confirm connectivity a.connect_points(u, v, false); @@ -195,7 +196,8 @@ bool test_add_remove() { for (int j = 0; j < 10; j++) { int u = Math::rand() % 5; int v = Math::rand() % 4; - if (u == v) v = 4; + if (u == v) + v = 4; if (Math::rand() % 2 == 1) a.connect_points(u, v, false); else @@ -239,7 +241,8 @@ bool test_solutions() { int u, v; u = Math::rand() % N; v = Math::rand() % (N - 1); - if (u == v) v = N - 1; + if (u == v) + v = N - 1; // Pick a random operation int op = Math::rand(); @@ -253,14 +256,16 @@ bool test_solutions() { // Add edge (u, v); possibly bidirectional a.connect_points(u, v, op % 2); adj[u][v] = true; - if (op % 2) adj[v][u] = true; + if (op % 2) + adj[v][u] = true; break; case 6: case 7: // Remove edge (u, v); possibly bidirectional a.disconnect_points(u, v, op % 2); adj[u][v] = false; - if (op % 2) adj[v][u] = false; + if (op % 2) + adj[v][u] = false; break; case 8: // Remove point u and add it back; clears adjacent edges and changes coordinates @@ -291,12 +296,14 @@ bool test_solutions() { int count = 0; for (int u = 0; u < N; u++) for (int v = 0; v < N; v++) - if (adj[u][v]) count++; + if (adj[u][v]) + count++; printf("Test #%4d: %3d edges, ", test + 1, count); count = 0; for (int u = 0; u < N; u++) for (int v = 0; v < N; v++) - if (!Math::is_inf(d[u][v])) count++; + if (!Math::is_inf(d[u][v])) + count++; printf("%3d/%d pairs of reachable points\n", count - N, N * (N - 1)); // Check A*'s output @@ -339,12 +346,13 @@ bool test_solutions() { } exit: - if (!match) return false; + if (!match) + return false; } return true; } -typedef bool (*TestFunc)(void); +typedef bool (*TestFunc)(); TestFunc test_funcs[] = { test_abc, diff --git a/main/tests/test_math.cpp b/main/tests/test_math.cpp index 38f7371802..fbd1aa275a 100644 --- a/main/tests/test_math.cpp +++ b/main/tests/test_math.cpp @@ -32,8 +32,10 @@ #include "core/math/basis.h" #include "core/math/camera_matrix.h" +#include "core/math/delaunay_3d.h" #include "core/math/math_funcs.h" #include "core/math/transform.h" +#include "core/method_ptrcall.h" #include "core/os/file_access.h" #include "core/os/keyboard.h" #include "core/os/os.h" @@ -45,8 +47,6 @@ #include "scene/resources/texture.h" #include "servers/rendering/shader_language.h" -#include "core/method_ptrcall.h" - namespace TestMath { class GetClassAndNamespace { @@ -199,14 +199,24 @@ class GetClassAndNamespace { switch (next) { - case 'b': res = 8; break; - case 't': res = 9; break; - case 'n': res = 10; break; - case 'f': res = 12; break; + case 'b': + res = 8; + break; + case 't': + res = 9; + break; + case 'n': + res = 10; + break; + case 'f': + res = 12; + break; case 'r': res = 13; break; - case '\"': res = '\"'; break; + case '\"': + res = '\"'; + break; case '\\': res = '\\'; break; @@ -405,6 +415,55 @@ uint32_t ihash3(uint32_t a) { MainLoop *test() { { + Vector<Vector3> points; + points.push_back(Vector3(0, 0, 0)); + points.push_back(Vector3(0, 0, 1)); + points.push_back(Vector3(0, 1, 0)); + points.push_back(Vector3(0, 1, 1)); + points.push_back(Vector3(1, 1, 0)); + points.push_back(Vector3(1, 0, 0)); + points.push_back(Vector3(1, 0, 1)); + points.push_back(Vector3(1, 1, 1)); + + for (int i = 0; i < 800; i++) { + points.push_back(Vector3(Math::randf() * 2.0 - 1.0, Math::randf() * 2.0 - 1.0, Math::randf() * 2.0 - 1.0) * Vector3(25, 30, 33)); + } + + Vector<Delaunay3D::OutputSimplex> os = Delaunay3D::tetrahedralize(points); + print_line("simplices in the end: " + itos(os.size())); + for (int i = 0; i < os.size(); i++) { + print_line("Simplex " + itos(i) + ": "); + print_line(points[os[i].points[0]]); + print_line(points[os[i].points[1]]); + print_line(points[os[i].points[2]]); + print_line(points[os[i].points[3]]); + } + + { + FileAccessRef f = FileAccess::open("res://bsp.obj", FileAccess::WRITE); + for (int i = 0; i < os.size(); i++) { + f->store_line("o Simplex" + itos(i)); + for (int j = 0; j < 4; j++) { + f->store_line(vformat("v %f %f %f", points[os[i].points[j]].x, points[os[i].points[j]].y, points[os[i].points[j]].z)); + } + static const int face_order[4][3] = { + { 1, 2, 3 }, + { 1, 3, 4 }, + { 1, 2, 4 }, + { 2, 3, 4 } + }; + + for (int j = 0; j < 4; j++) { + f->store_line(vformat("f %d %d %d", 4 * i + face_order[j][0], 4 * i + face_order[j][1], 4 * i + face_order[j][2])); + } + } + f->close(); + } + + return nullptr; + } + + { float r = 1; float g = 0.5; float b = 0.1; diff --git a/main/tests/test_ordered_hash_map.cpp b/main/tests/test_ordered_hash_map.cpp index 0720eebaf9..e909626243 100644 --- a/main/tests/test_ordered_hash_map.cpp +++ b/main/tests/test_ordered_hash_map.cpp @@ -130,7 +130,7 @@ bool test_const_iteration() { return test_const_iteration(map); } -typedef bool (*TestFunc)(void); +typedef bool (*TestFunc)(); TestFunc test_funcs[] = { diff --git a/main/tests/test_shader_lang.cpp b/main/tests/test_shader_lang.cpp index dbe2da86cf..abcf30c97f 100644 --- a/main/tests/test_shader_lang.cpp +++ b/main/tests/test_shader_lang.cpp @@ -61,10 +61,14 @@ static String _typestr(SL::DataType p_type) { static String _prestr(SL::DataPrecision p_pres) { switch (p_pres) { - case SL::PRECISION_LOWP: return "lowp "; - case SL::PRECISION_MEDIUMP: return "mediump "; - case SL::PRECISION_HIGHP: return "highp "; - case SL::PRECISION_DEFAULT: return ""; + case SL::PRECISION_LOWP: + return "lowp "; + case SL::PRECISION_MEDIUMP: + return "mediump "; + case SL::PRECISION_HIGHP: + return "highp "; + case SL::PRECISION_DEFAULT: + return ""; } return ""; } @@ -77,23 +81,40 @@ static String _opstr(SL::Operator p_op) { static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNode::Value> &p_values) { switch (p_type) { - case SL::TYPE_BOOL: return p_values[0].boolean ? "true" : "false"; - case SL::TYPE_BVEC2: return String() + "bvec2(" + (p_values[0].boolean ? "true" : "false") + (p_values[1].boolean ? "true" : "false") + ")"; - case SL::TYPE_BVEC3: return String() + "bvec3(" + (p_values[0].boolean ? "true" : "false") + "," + (p_values[1].boolean ? "true" : "false") + "," + (p_values[2].boolean ? "true" : "false") + ")"; - case SL::TYPE_BVEC4: return String() + "bvec4(" + (p_values[0].boolean ? "true" : "false") + "," + (p_values[1].boolean ? "true" : "false") + "," + (p_values[2].boolean ? "true" : "false") + "," + (p_values[3].boolean ? "true" : "false") + ")"; - case SL::TYPE_INT: return rtos(p_values[0].sint); - case SL::TYPE_IVEC2: return String() + "ivec2(" + rtos(p_values[0].sint) + "," + rtos(p_values[1].sint) + ")"; - case SL::TYPE_IVEC3: return String() + "ivec3(" + rtos(p_values[0].sint) + "," + rtos(p_values[1].sint) + "," + rtos(p_values[2].sint) + ")"; - case SL::TYPE_IVEC4: return String() + "ivec4(" + rtos(p_values[0].sint) + "," + rtos(p_values[1].sint) + "," + rtos(p_values[2].sint) + "," + rtos(p_values[3].sint) + ")"; - case SL::TYPE_UINT: return rtos(p_values[0].real); - case SL::TYPE_UVEC2: return String() + "uvec2(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + ")"; - case SL::TYPE_UVEC3: return String() + "uvec3(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + "," + rtos(p_values[2].real) + ")"; - case SL::TYPE_UVEC4: return String() + "uvec4(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + "," + rtos(p_values[2].real) + "," + rtos(p_values[3].real) + ")"; - case SL::TYPE_FLOAT: return rtos(p_values[0].real); - case SL::TYPE_VEC2: return String() + "vec2(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + ")"; - case SL::TYPE_VEC3: return String() + "vec3(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + "," + rtos(p_values[2].real) + ")"; - case SL::TYPE_VEC4: return String() + "vec4(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + "," + rtos(p_values[2].real) + "," + rtos(p_values[3].real) + ")"; - default: ERR_FAIL_V(String()); + case SL::TYPE_BOOL: + return p_values[0].boolean ? "true" : "false"; + case SL::TYPE_BVEC2: + return String() + "bvec2(" + (p_values[0].boolean ? "true" : "false") + (p_values[1].boolean ? "true" : "false") + ")"; + case SL::TYPE_BVEC3: + return String() + "bvec3(" + (p_values[0].boolean ? "true" : "false") + "," + (p_values[1].boolean ? "true" : "false") + "," + (p_values[2].boolean ? "true" : "false") + ")"; + case SL::TYPE_BVEC4: + return String() + "bvec4(" + (p_values[0].boolean ? "true" : "false") + "," + (p_values[1].boolean ? "true" : "false") + "," + (p_values[2].boolean ? "true" : "false") + "," + (p_values[3].boolean ? "true" : "false") + ")"; + case SL::TYPE_INT: + return rtos(p_values[0].sint); + case SL::TYPE_IVEC2: + return String() + "ivec2(" + rtos(p_values[0].sint) + "," + rtos(p_values[1].sint) + ")"; + case SL::TYPE_IVEC3: + return String() + "ivec3(" + rtos(p_values[0].sint) + "," + rtos(p_values[1].sint) + "," + rtos(p_values[2].sint) + ")"; + case SL::TYPE_IVEC4: + return String() + "ivec4(" + rtos(p_values[0].sint) + "," + rtos(p_values[1].sint) + "," + rtos(p_values[2].sint) + "," + rtos(p_values[3].sint) + ")"; + case SL::TYPE_UINT: + return rtos(p_values[0].real); + case SL::TYPE_UVEC2: + return String() + "uvec2(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + ")"; + case SL::TYPE_UVEC3: + return String() + "uvec3(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + "," + rtos(p_values[2].real) + ")"; + case SL::TYPE_UVEC4: + return String() + "uvec4(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + "," + rtos(p_values[2].real) + "," + rtos(p_values[3].real) + ")"; + case SL::TYPE_FLOAT: + return rtos(p_values[0].real); + case SL::TYPE_VEC2: + return String() + "vec2(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + ")"; + case SL::TYPE_VEC3: + return String() + "vec3(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + "," + rtos(p_values[2].real) + ")"; + case SL::TYPE_VEC4: + return String() + "vec4(" + rtos(p_values[0].real) + "," + rtos(p_values[1].real) + "," + rtos(p_values[2].real) + "," + rtos(p_values[3].real) + ")"; + default: + ERR_FAIL_V(String()); } } @@ -342,7 +363,7 @@ MainLoop *test() { Set<String> types; types.insert("spatial"); - Error err = sl.compile(code, dt, rm, types, NULL); + Error err = sl.compile(code, dt, rm, types, nullptr); if (err) { diff --git a/main/tests/test_string.cpp b/main/tests/test_string.cpp index 7438e2bae9..76631f9eae 100644 --- a/main/tests/test_string.cpp +++ b/main/tests/test_string.cpp @@ -972,22 +972,26 @@ bool test_31() { String a = ""; success = a[0] == 0; OS::get_singleton()->print("Is 0 String[0]:, %s\n", success ? "OK" : "FAIL"); - if (!success) state = false; + if (!success) + state = false; String b = "Godot"; success = b[b.size()] == 0; OS::get_singleton()->print("Is 0 String[size()]:, %s\n", success ? "OK" : "FAIL"); - if (!success) state = false; + if (!success) + state = false; const String c = ""; success = c[0] == 0; OS::get_singleton()->print("Is 0 const String[0]:, %s\n", success ? "OK" : "FAIL"); - if (!success) state = false; + if (!success) + state = false; const String d = "Godot"; success = d[d.size()] == 0; OS::get_singleton()->print("Is 0 const String[size()]:, %s\n", success ? "OK" : "FAIL"); - if (!success) state = false; + if (!success) + state = false; return state; }; @@ -1125,7 +1129,7 @@ bool test_35() { return state; } -typedef bool (*TestFunc)(void); +typedef bool (*TestFunc)(); TestFunc test_funcs[] = { diff --git a/misc/dist/html/fixed-size.html b/misc/dist/html/fixed-size.html index e7a23b3f29..a5633115d5 100644 --- a/misc/dist/html/fixed-size.html +++ b/misc/dist/html/fixed-size.html @@ -232,6 +232,7 @@ $GODOT_HEAD_INCLUDE const EXECUTABLE_NAME = '$GODOT_BASENAME'; const MAIN_PACK = '$GODOT_BASENAME.pck'; + const EXTRA_ARGS = JSON.parse('$GODOT_ARGS'); const DEBUG_ENABLED = $GODOT_DEBUG_ENABLED; const INDETERMINATE_STATUS_STEP_MS = 100; @@ -382,7 +383,7 @@ $GODOT_HEAD_INCLUDE } else { setStatusMode('indeterminate'); engine.setCanvas(canvas); - engine.startGame(EXECUTABLE_NAME, MAIN_PACK).then(() => { + engine.startGame(EXECUTABLE_NAME, MAIN_PACK, EXTRA_ARGS).then(() => { setStatusMode('hidden'); initializing = false; }, displayFailureNotice); diff --git a/misc/dist/html/full-size.html b/misc/dist/html/full-size.html index 193f2a6aad..435013cb5e 100644 --- a/misc/dist/html/full-size.html +++ b/misc/dist/html/full-size.html @@ -145,6 +145,7 @@ $GODOT_HEAD_INCLUDE const EXECUTABLE_NAME = '$GODOT_BASENAME'; const MAIN_PACK = '$GODOT_BASENAME.pck'; + const EXTRA_ARGS = JSON.parse('$GODOT_ARGS'); const INDETERMINATE_STATUS_STEP_MS = 100; var canvas = document.getElementById('canvas'); @@ -169,8 +170,6 @@ $GODOT_HEAD_INCLUDE var height = window.innerHeight; canvas.width = width * scale; canvas.height = height * scale; - canvas.style.width = width + "px"; - canvas.style.height = height + "px"; } animationCallbacks.push(adjustCanvasDimensions); adjustCanvasDimensions(); @@ -256,7 +255,7 @@ $GODOT_HEAD_INCLUDE } else { setStatusMode('indeterminate'); engine.setCanvas(canvas); - engine.startGame(EXECUTABLE_NAME, MAIN_PACK).then(() => { + engine.startGame(EXECUTABLE_NAME, MAIN_PACK, EXTRA_ARGS).then(() => { setStatusMode('hidden'); initializing = false; }, displayFailureNotice); diff --git a/misc/scons/site_tools/compilation_db.py b/misc/scons/site_tools/compilation_db.py new file mode 100644 index 0000000000..87db32adc9 --- /dev/null +++ b/misc/scons/site_tools/compilation_db.py @@ -0,0 +1,177 @@ +# Copyright 2015 MongoDB Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import SCons +import itertools + +# Implements the ability for SCons to emit a compilation database for the MongoDB project. See +# http://clang.llvm.org/docs/JSONCompilationDatabase.html for details on what a compilation +# database is, and why you might want one. The only user visible entry point here is +# 'env.CompilationDatabase'. This method takes an optional 'target' to name the file that +# should hold the compilation database, otherwise, the file defaults to compile_commands.json, +# which is the name that most clang tools search for by default. + +# TODO: Is there a better way to do this than this global? Right now this exists so that the +# emitter we add can record all of the things it emits, so that the scanner for the top level +# compilation database can access the complete list, and also so that the writer has easy +# access to write all of the files. But it seems clunky. How can the emitter and the scanner +# communicate more gracefully? +__COMPILATION_DB_ENTRIES = [] + +# We make no effort to avoid rebuilding the entries. Someday, perhaps we could and even +# integrate with the cache, but there doesn't seem to be much call for it. +class __CompilationDbNode(SCons.Node.Python.Value): + def __init__(self, value): + SCons.Node.Python.Value.__init__(self, value) + self.Decider(changed_since_last_build_node) + + +def changed_since_last_build_node(child, target, prev_ni, node): + """ Dummy decider to force always building""" + return True + + +def makeEmitCompilationDbEntry(comstr): + """ + Effectively this creates a lambda function to capture: + * command line + * source + * target + :param comstr: unevaluated command line + :return: an emitter which has captured the above + """ + user_action = SCons.Action.Action(comstr) + + def EmitCompilationDbEntry(target, source, env): + """ + This emitter will be added to each c/c++ object build to capture the info needed + for clang tools + :param target: target node(s) + :param source: source node(s) + :param env: Environment for use building this node + :return: target(s), source(s) + """ + + dbtarget = __CompilationDbNode(source) + + entry = env.__COMPILATIONDB_Entry( + target=dbtarget, + source=[], + __COMPILATIONDB_UTARGET=target, + __COMPILATIONDB_USOURCE=source, + __COMPILATIONDB_UACTION=user_action, + __COMPILATIONDB_ENV=env, + ) + + # TODO: Technically, these next two lines should not be required: it should be fine to + # cache the entries. However, they don't seem to update properly. Since they are quick + # to re-generate disable caching and sidestep this problem. + env.AlwaysBuild(entry) + env.NoCache(entry) + + __COMPILATION_DB_ENTRIES.append(dbtarget) + + return target, source + + return EmitCompilationDbEntry + + +def CompilationDbEntryAction(target, source, env, **kw): + """ + Create a dictionary with evaluated command line, target, source + and store that info as an attribute on the target + (Which has been stored in __COMPILATION_DB_ENTRIES array + :param target: target node(s) + :param source: source node(s) + :param env: Environment for use building this node + :param kw: + :return: None + """ + + command = env["__COMPILATIONDB_UACTION"].strfunction( + target=env["__COMPILATIONDB_UTARGET"], source=env["__COMPILATIONDB_USOURCE"], env=env["__COMPILATIONDB_ENV"], + ) + + entry = { + "directory": env.Dir("#").abspath, + "command": command, + "file": str(env["__COMPILATIONDB_USOURCE"][0]), + } + + target[0].write(entry) + + +def WriteCompilationDb(target, source, env): + entries = [] + + for s in __COMPILATION_DB_ENTRIES: + entries.append(s.read()) + + with open(str(target[0]), "w") as target_file: + json.dump(entries, target_file, sort_keys=True, indent=4, separators=(",", ": ")) + + +def ScanCompilationDb(node, env, path): + return __COMPILATION_DB_ENTRIES + + +def generate(env, **kwargs): + + static_obj, shared_obj = SCons.Tool.createObjBuilders(env) + + env["COMPILATIONDB_COMSTR"] = kwargs.get("COMPILATIONDB_COMSTR", "Building compilation database $TARGET") + + components_by_suffix = itertools.chain( + itertools.product( + env["CPPSUFFIXES"], + [ + (static_obj, SCons.Defaults.StaticObjectEmitter, "$CXXCOM"), + (shared_obj, SCons.Defaults.SharedObjectEmitter, "$SHCXXCOM"), + ], + ), + ) + + for entry in components_by_suffix: + suffix = entry[0] + builder, base_emitter, command = entry[1] + + # Ensure we have a valid entry + # used to auto ignore header files + if suffix in builder.emitter: + emitter = builder.emitter[suffix] + builder.emitter[suffix] = SCons.Builder.ListEmitter([emitter, makeEmitCompilationDbEntry(command),]) + + env["BUILDERS"]["__COMPILATIONDB_Entry"] = SCons.Builder.Builder( + action=SCons.Action.Action(CompilationDbEntryAction, None), + ) + + env["BUILDERS"]["__COMPILATIONDB_Database"] = SCons.Builder.Builder( + action=SCons.Action.Action(WriteCompilationDb, "$COMPILATIONDB_COMSTR"), + target_scanner=SCons.Scanner.Scanner(function=ScanCompilationDb, node_class=None), + ) + + def CompilationDatabase(env, target): + result = env.__COMPILATIONDB_Database(target=target, source=[]) + + env.AlwaysBuild(result) + env.NoCache(result) + + return result + + env.AddMethod(CompilationDatabase, "CompilationDatabase") + + +def exists(env): + return True diff --git a/modules/assimp/import_utils.h b/modules/assimp/import_utils.h index f78931add3..e3510c2cb3 100644 --- a/modules/assimp/import_utils.h +++ b/modules/assimp/import_utils.h @@ -202,20 +202,34 @@ public: */ static float get_fbx_fps(int32_t time_mode, const aiScene *p_scene) { switch (time_mode) { - case AssetImportFbx::TIME_MODE_DEFAULT: return 24; //hack - case AssetImportFbx::TIME_MODE_120: return 120; - case AssetImportFbx::TIME_MODE_100: return 100; - case AssetImportFbx::TIME_MODE_60: return 60; - case AssetImportFbx::TIME_MODE_50: return 50; - case AssetImportFbx::TIME_MODE_48: return 48; - case AssetImportFbx::TIME_MODE_30: return 30; - case AssetImportFbx::TIME_MODE_30_DROP: return 30; - case AssetImportFbx::TIME_MODE_NTSC_DROP_FRAME: return 29.9700262f; - case AssetImportFbx::TIME_MODE_NTSC_FULL_FRAME: return 29.9700262f; - case AssetImportFbx::TIME_MODE_PAL: return 25; - case AssetImportFbx::TIME_MODE_CINEMA: return 24; - case AssetImportFbx::TIME_MODE_1000: return 1000; - case AssetImportFbx::TIME_MODE_CINEMA_ND: return 23.976f; + case AssetImportFbx::TIME_MODE_DEFAULT: + return 24; //hack + case AssetImportFbx::TIME_MODE_120: + return 120; + case AssetImportFbx::TIME_MODE_100: + return 100; + case AssetImportFbx::TIME_MODE_60: + return 60; + case AssetImportFbx::TIME_MODE_50: + return 50; + case AssetImportFbx::TIME_MODE_48: + return 48; + case AssetImportFbx::TIME_MODE_30: + return 30; + case AssetImportFbx::TIME_MODE_30_DROP: + return 30; + case AssetImportFbx::TIME_MODE_NTSC_DROP_FRAME: + return 29.9700262f; + case AssetImportFbx::TIME_MODE_NTSC_FULL_FRAME: + return 29.9700262f; + case AssetImportFbx::TIME_MODE_PAL: + return 25; + case AssetImportFbx::TIME_MODE_CINEMA: + return 24; + case AssetImportFbx::TIME_MODE_1000: + return 1000; + case AssetImportFbx::TIME_MODE_CINEMA_ND: + return 23.976f; case AssetImportFbx::TIME_MODE_CUSTOM: int32_t frame_rate = -1; p_scene->mMetaData->Get("FrameRate", frame_rate); diff --git a/modules/bmp/image_loader_bmp.cpp b/modules/bmp/image_loader_bmp.cpp index ea2b2b548f..f69f3a43a4 100644 --- a/modules/bmp/image_loader_bmp.cpp +++ b/modules/bmp/image_loader_bmp.cpp @@ -153,7 +153,7 @@ Error ImageLoaderBMP::convert_to_image(Ref<Image> p_image, if (p_color_buffer == nullptr || color_table_size == 0) { // regular pixels - p_image->create(width, height, 0, Image::FORMAT_RGBA8, data); + p_image->create(width, height, false, Image::FORMAT_RGBA8, data); } else { // data is in indexed format, extend it @@ -193,7 +193,7 @@ Error ImageLoaderBMP::convert_to_image(Ref<Image> p_image, dest += 4; } - p_image->create(width, height, 0, Image::FORMAT_RGBA8, extended_data); + p_image->create(width, height, false, Image::FORMAT_RGBA8, extended_data); } } return err; diff --git a/modules/bullet/area_bullet.cpp b/modules/bullet/area_bullet.cpp index 4d727529ef..a4a86ab751 100644 --- a/modules/bullet/area_bullet.cpp +++ b/modules/bullet/area_bullet.cpp @@ -44,18 +44,7 @@ */ AreaBullet::AreaBullet() : - RigidCollisionObjectBullet(CollisionObjectBullet::TYPE_AREA), - monitorable(true), - spOv_mode(PhysicsServer3D::AREA_SPACE_OVERRIDE_DISABLED), - spOv_gravityPoint(false), - spOv_gravityPointDistanceScale(0), - spOv_gravityPointAttenuation(1), - spOv_gravityVec(0, -1, 0), - spOv_gravityMag(10), - spOv_linearDump(0.1), - spOv_angularDump(1), - spOv_priority(0), - isScratched(false) { + RigidCollisionObjectBullet(CollisionObjectBullet::TYPE_AREA) { btGhost = bulletnew(btGhostObject); reload_shapes(); diff --git a/modules/bullet/area_bullet.h b/modules/bullet/area_bullet.h index 0272350510..cde889c1ba 100644 --- a/modules/bullet/area_bullet.h +++ b/modules/bullet/area_bullet.h @@ -61,12 +61,10 @@ public: }; struct OverlappingObjectData { - CollisionObjectBullet *object; - OverlapState state; + CollisionObjectBullet *object = nullptr; + OverlapState state = OVERLAP_STATE_ENTER; - OverlappingObjectData() : - object(nullptr), - state(OVERLAP_STATE_ENTER) {} + OverlappingObjectData() {} OverlappingObjectData(CollisionObjectBullet *p_object, OverlapState p_state) : object(p_object), state(p_state) {} @@ -86,19 +84,19 @@ private: btGhostObject *btGhost; Vector<OverlappingObjectData> overlappingObjects; - bool monitorable; - - PhysicsServer3D::AreaSpaceOverrideMode spOv_mode; - bool spOv_gravityPoint; - real_t spOv_gravityPointDistanceScale; - real_t spOv_gravityPointAttenuation; - Vector3 spOv_gravityVec; - real_t spOv_gravityMag; - real_t spOv_linearDump; - real_t spOv_angularDump; - int spOv_priority; - - bool isScratched; + bool monitorable = true; + + PhysicsServer3D::AreaSpaceOverrideMode spOv_mode = PhysicsServer3D::AREA_SPACE_OVERRIDE_DISABLED; + bool spOv_gravityPoint = false; + real_t spOv_gravityPointDistanceScale = 0; + real_t spOv_gravityPointAttenuation = 1; + Vector3 spOv_gravityVec = Vector3(0, -1, 0); + real_t spOv_gravityMag = 10; + real_t spOv_linearDump = 0.1; + real_t spOv_angularDump = 1; + int spOv_priority = 0; + + bool isScratched = false; InOutEventCallback eventsCallbacks[2]; diff --git a/modules/bullet/bullet_physics_server.cpp b/modules/bullet/bullet_physics_server.cpp index 2705c749a2..09a5f6f983 100644 --- a/modules/bullet/bullet_physics_server.cpp +++ b/modules/bullet/bullet_physics_server.cpp @@ -79,9 +79,7 @@ void BulletPhysicsServer3D::_bind_methods() { } BulletPhysicsServer3D::BulletPhysicsServer3D() : - PhysicsServer3D(), - active(true), - active_spaces_count(0) {} + PhysicsServer3D() {} BulletPhysicsServer3D::~BulletPhysicsServer3D() {} diff --git a/modules/bullet/bullet_physics_server.h b/modules/bullet/bullet_physics_server.h index ea9c5e589e..558d1ce5f7 100644 --- a/modules/bullet/bullet_physics_server.h +++ b/modules/bullet/bullet_physics_server.h @@ -40,6 +40,7 @@ #include "shape_bullet.h" #include "soft_body_bullet.h" #include "space_bullet.h" + /** @author AndreaCatania */ @@ -49,8 +50,8 @@ class BulletPhysicsServer3D : public PhysicsServer3D { friend class BulletPhysicsDirectSpaceState; - bool active; - char active_spaces_count; + bool active = true; + char active_spaces_count = 0; Vector<SpaceBullet *> active_spaces; mutable RID_PtrOwner<SpaceBullet> space_owner; diff --git a/modules/bullet/collision_object_bullet.cpp b/modules/bullet/collision_object_bullet.cpp index 1b72c2f577..9ad74ad262 100644 --- a/modules/bullet/collision_object_bullet.cpp +++ b/modules/bullet/collision_object_bullet.cpp @@ -89,18 +89,7 @@ void CollisionObjectBullet::ShapeWrapper::claim_bt_shape(const btVector3 &body_s CollisionObjectBullet::CollisionObjectBullet(Type p_type) : RIDBullet(), - type(p_type), - instance_id(ObjectID()), - collisionLayer(0), - collisionMask(0), - collisionsEnabled(true), - m_isStatic(false), - ray_pickable(false), - bt_collision_object(nullptr), - body_scale(1., 1., 1.), - force_shape_reset(false), - space(nullptr), - isTransformChanged(false) {} + type(p_type) {} CollisionObjectBullet::~CollisionObjectBullet() { // Remove all overlapping, notify is not required since godot take care of it @@ -225,11 +214,6 @@ void CollisionObjectBullet::notify_transform_changed() { isTransformChanged = true; } -RigidCollisionObjectBullet::RigidCollisionObjectBullet(Type p_type) : - CollisionObjectBullet(p_type), - mainShape(nullptr) { -} - RigidCollisionObjectBullet::~RigidCollisionObjectBullet() { remove_all_shapes(true, true); if (mainShape && mainShape->isCompound()) { diff --git a/modules/bullet/collision_object_bullet.h b/modules/bullet/collision_object_bullet.h index 25176458a7..f1423a69e4 100644 --- a/modules/bullet/collision_object_bullet.h +++ b/modules/bullet/collision_object_bullet.h @@ -69,27 +69,22 @@ public: }; struct ShapeWrapper { - ShapeBullet *shape; - btCollisionShape *bt_shape; + ShapeBullet *shape = nullptr; + btCollisionShape *bt_shape = nullptr; btTransform transform; btVector3 scale; - bool active; + bool active = true; - ShapeWrapper() : - shape(nullptr), - bt_shape(nullptr), - active(true) {} + ShapeWrapper() {} ShapeWrapper(ShapeBullet *p_shape, const btTransform &p_transform, bool p_active) : shape(p_shape), - bt_shape(nullptr), active(p_active) { set_transform(p_transform); } ShapeWrapper(ShapeBullet *p_shape, const Transform &p_transform, bool p_active) : shape(p_shape), - bt_shape(nullptr), active(p_active) { set_transform(p_transform); } @@ -117,15 +112,15 @@ public: protected: Type type; ObjectID instance_id; - uint32_t collisionLayer; - uint32_t collisionMask; - bool collisionsEnabled; - bool m_isStatic; - bool ray_pickable; - btCollisionObject *bt_collision_object; - Vector3 body_scale; - bool force_shape_reset; - SpaceBullet *space; + uint32_t collisionLayer = 0; + uint32_t collisionMask = 0; + bool collisionsEnabled = true; + bool m_isStatic = false; + bool ray_pickable = false; + btCollisionObject *bt_collision_object = nullptr; + Vector3 body_scale = Vector3(1, 1, 1); + bool force_shape_reset = false; + SpaceBullet *space = nullptr; VSet<RID> exceptions; @@ -133,7 +128,7 @@ protected: /// New area is added when overlap with new area (AreaBullet::addOverlap), then is removed when it exit (CollisionObjectBullet::onExitArea) /// This array is used mainly to know which area hold the pointer of this object Vector<AreaBullet *> areasOverlapped; - bool isTransformChanged; + bool isTransformChanged = false; public: CollisionObjectBullet(Type p_type); @@ -218,11 +213,12 @@ public: class RigidCollisionObjectBullet : public CollisionObjectBullet, public ShapeOwnerBullet { protected: - btCollisionShape *mainShape; + btCollisionShape *mainShape = nullptr; Vector<ShapeWrapper> shapes; public: - RigidCollisionObjectBullet(Type p_type); + RigidCollisionObjectBullet(Type p_type) : + CollisionObjectBullet(p_type) {} ~RigidCollisionObjectBullet(); _FORCE_INLINE_ const Vector<ShapeWrapper> &get_shapes_wrappers() const { return shapes; } diff --git a/modules/bullet/constraint_bullet.cpp b/modules/bullet/constraint_bullet.cpp index 469b58521e..c47a23e75f 100644 --- a/modules/bullet/constraint_bullet.cpp +++ b/modules/bullet/constraint_bullet.cpp @@ -37,10 +37,7 @@ @author AndreaCatania */ -ConstraintBullet::ConstraintBullet() : - space(nullptr), - constraint(nullptr), - disabled_collisions_between_bodies(true) {} +ConstraintBullet::ConstraintBullet() {} void ConstraintBullet::setup(btTypedConstraint *p_constraint) { constraint = p_constraint; diff --git a/modules/bullet/constraint_bullet.h b/modules/bullet/constraint_bullet.h index 1946807bad..125940439f 100644 --- a/modules/bullet/constraint_bullet.h +++ b/modules/bullet/constraint_bullet.h @@ -47,9 +47,9 @@ class btTypedConstraint; class ConstraintBullet : public RIDBullet { protected: - SpaceBullet *space; - btTypedConstraint *constraint; - bool disabled_collisions_between_bodies; + SpaceBullet *space = nullptr; + btTypedConstraint *constraint = nullptr; + bool disabled_collisions_between_bodies = true; public: ConstraintBullet(); diff --git a/modules/bullet/godot_ray_world_algorithm.cpp b/modules/bullet/godot_ray_world_algorithm.cpp index 2ef277cf5b..2caa75c2a7 100644 --- a/modules/bullet/godot_ray_world_algorithm.cpp +++ b/modules/bullet/godot_ray_world_algorithm.cpp @@ -52,7 +52,6 @@ GodotRayWorldAlgorithm::GodotRayWorldAlgorithm(const btDiscreteDynamicsWorld *wo btActivatingCollisionAlgorithm(ci, body0Wrap, body1Wrap), m_world(world), m_manifoldPtr(mf), - m_ownManifold(false), m_isSwapped(isSwapped) {} GodotRayWorldAlgorithm::~GodotRayWorldAlgorithm() { diff --git a/modules/bullet/godot_ray_world_algorithm.h b/modules/bullet/godot_ray_world_algorithm.h index 2cdea6c133..ec7f68dc51 100644 --- a/modules/bullet/godot_ray_world_algorithm.h +++ b/modules/bullet/godot_ray_world_algorithm.h @@ -45,7 +45,7 @@ class GodotRayWorldAlgorithm : public btActivatingCollisionAlgorithm { const btDiscreteDynamicsWorld *m_world; btPersistentManifold *m_manifoldPtr; - bool m_ownManifold; + bool m_ownManifold = false; bool m_isSwapped; public: diff --git a/modules/bullet/godot_result_callbacks.h b/modules/bullet/godot_result_callbacks.h index 7e74a2b22e..8636ca8eb6 100644 --- a/modules/bullet/godot_result_callbacks.h +++ b/modules/bullet/godot_result_callbacks.h @@ -56,8 +56,8 @@ struct GodotFilterCallback : public btOverlapFilterCallback { /// It performs an additional check allow exclusions. struct GodotClosestRayResultCallback : public btCollisionWorld::ClosestRayResultCallback { const Set<RID> *m_exclude; - bool m_pickRay; - int m_shapeId; + bool m_pickRay = false; + int m_shapeId = 0; bool collide_with_bodies; bool collide_with_areas; @@ -66,8 +66,6 @@ public: GodotClosestRayResultCallback(const btVector3 &rayFromWorld, const btVector3 &rayToWorld, const Set<RID> *p_exclude, bool p_collide_with_bodies, bool p_collide_with_areas) : btCollisionWorld::ClosestRayResultCallback(rayFromWorld, rayToWorld), m_exclude(p_exclude), - m_pickRay(false), - m_shapeId(0), collide_with_bodies(p_collide_with_bodies), collide_with_areas(p_collide_with_areas) {} @@ -88,13 +86,12 @@ public: PhysicsDirectSpaceState3D::ShapeResult *m_results; int m_resultMax; const Set<RID> *m_exclude; - int count; + int count = 0; GodotAllConvexResultCallback(PhysicsDirectSpaceState3D::ShapeResult *p_results, int p_resultMax, const Set<RID> *p_exclude) : m_results(p_results), m_resultMax(p_resultMax), - m_exclude(p_exclude), - count(0) {} + m_exclude(p_exclude) {} virtual bool needsCollision(btBroadphaseProxy *proxy0) const; @@ -117,7 +114,7 @@ public: struct GodotClosestConvexResultCallback : public btCollisionWorld::ClosestConvexResultCallback { public: const Set<RID> *m_exclude; - int m_shapeId; + int m_shapeId = 0; bool collide_with_bodies; bool collide_with_areas; @@ -125,7 +122,6 @@ public: GodotClosestConvexResultCallback(const btVector3 &convexFromWorld, const btVector3 &convexToWorld, const Set<RID> *p_exclude, bool p_collide_with_bodies, bool p_collide_with_areas) : btCollisionWorld::ClosestConvexResultCallback(convexFromWorld, convexToWorld), m_exclude(p_exclude), - m_shapeId(0), collide_with_bodies(p_collide_with_bodies), collide_with_areas(p_collide_with_areas) {} @@ -140,7 +136,7 @@ public: PhysicsDirectSpaceState3D::ShapeResult *m_results; int m_resultMax; const Set<RID> *m_exclude; - int m_count; + int m_count = 0; bool collide_with_bodies; bool collide_with_areas; @@ -150,7 +146,6 @@ public: m_results(p_results), m_resultMax(p_resultMax), m_exclude(p_exclude), - m_count(0), collide_with_bodies(p_collide_with_bodies), collide_with_areas(p_collide_with_areas) {} @@ -166,7 +161,7 @@ public: Vector3 *m_results; int m_resultMax; const Set<RID> *m_exclude; - int m_count; + int m_count = 0; bool collide_with_bodies; bool collide_with_areas; @@ -176,7 +171,6 @@ public: m_results(p_results), m_resultMax(p_resultMax), m_exclude(p_exclude), - m_count(0), collide_with_bodies(p_collide_with_bodies), collide_with_areas(p_collide_with_areas) {} @@ -190,8 +184,8 @@ public: const btCollisionObject *m_self_object; PhysicsDirectSpaceState3D::ShapeRestInfo *m_result; const Set<RID> *m_exclude; - bool m_collided; - real_t m_min_distance; + bool m_collided = false; + real_t m_min_distance = 0; const btCollisionObject *m_rest_info_collision_object; btVector3 m_rest_info_bt_point; bool collide_with_bodies; @@ -201,8 +195,6 @@ public: m_self_object(p_self_object), m_result(p_result), m_exclude(p_exclude), - m_collided(false), - m_min_distance(0), collide_with_bodies(p_collide_with_bodies), collide_with_areas(p_collide_with_areas) {} @@ -214,13 +206,11 @@ public: struct GodotDeepPenetrationContactResultCallback : public btManifoldResult { btVector3 m_pointNormalWorld; btVector3 m_pointWorld; - btScalar m_penetration_distance; - int m_other_compound_shape_index; + btScalar m_penetration_distance = 0; + int m_other_compound_shape_index = 0; GodotDeepPenetrationContactResultCallback(const btCollisionObjectWrapper *body0Wrap, const btCollisionObjectWrapper *body1Wrap) : - btManifoldResult(body0Wrap, body1Wrap), - m_penetration_distance(0), - m_other_compound_shape_index(0) {} + btManifoldResult(body0Wrap, body1Wrap) {} void reset() { m_penetration_distance = 0; diff --git a/modules/bullet/hinge_joint_bullet.cpp b/modules/bullet/hinge_joint_bullet.cpp index 4bea9f87c0..e7f3d75c10 100644 --- a/modules/bullet/hinge_joint_bullet.cpp +++ b/modules/bullet/hinge_joint_bullet.cpp @@ -162,7 +162,8 @@ void HingeJointBullet::set_flag(PhysicsServer3D::HingeJointFlag p_flag, bool p_v case PhysicsServer3D::HINGE_JOINT_FLAG_ENABLE_MOTOR: hingeConstraint->enableMotor(p_value); break; - case PhysicsServer3D::HINGE_JOINT_FLAG_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::HINGE_JOINT_FLAG_MAX: + break; // Can't happen, but silences warning } } diff --git a/modules/bullet/rigid_body_bullet.cpp b/modules/bullet/rigid_body_bullet.cpp index e393396713..7a244b8c32 100644 --- a/modules/bullet/rigid_body_bullet.cpp +++ b/modules/bullet/rigid_body_bullet.cpp @@ -256,25 +256,7 @@ void RigidBodyBullet::KinematicUtilities::just_delete_shapes(int new_size) { } RigidBodyBullet::RigidBodyBullet() : - RigidCollisionObjectBullet(CollisionObjectBullet::TYPE_RIGID_BODY), - kinematic_utilities(nullptr), - locked_axis(0), - mass(1), - gravity_scale(1), - linearDamp(0), - angularDamp(0), - can_sleep(true), - omit_forces_integration(false), - can_integrate_forces(false), - maxCollisionsDetection(0), - collisionsCount(0), - prev_collision_count(0), - maxAreasWhereIam(10), - areaWhereIamCount(0), - countGravityPointSpaces(0), - isScratchedSpaceOverrideModificator(false), - previousActiveState(true), - force_integration_callback(nullptr) { + RigidCollisionObjectBullet(CollisionObjectBullet::TYPE_RIGID_BODY) { godotMotionState = bulletnew(GodotMotionState(this)); diff --git a/modules/bullet/rigid_body_bullet.h b/modules/bullet/rigid_body_bullet.h index 420b5cc443..f94dea8036 100644 --- a/modules/bullet/rigid_body_bullet.h +++ b/modules/bullet/rigid_body_bullet.h @@ -162,11 +162,10 @@ public: /// Used to hold shapes struct KinematicShape { - class btConvexShape *shape; + class btConvexShape *shape = nullptr; btTransform transform; - KinematicShape() : - shape(nullptr) {} + KinematicShape() {} bool is_active() const { return shape; } }; @@ -190,19 +189,19 @@ private: friend class BulletPhysicsDirectBodyState3D; // This is required only for Kinematic movement - KinematicUtilities *kinematic_utilities; + KinematicUtilities *kinematic_utilities = nullptr; PhysicsServer3D::BodyMode mode; GodotMotionState *godotMotionState; btRigidBody *btBody; - uint16_t locked_axis; - real_t mass; - real_t gravity_scale; - real_t linearDamp; - real_t angularDamp; - bool can_sleep; - bool omit_forces_integration; - bool can_integrate_forces; + uint16_t locked_axis = 0; + real_t mass = 1; + real_t gravity_scale = 1; + real_t linearDamp = 0; + real_t angularDamp = 0; + bool can_sleep = true; + bool omit_forces_integration = false; + bool can_integrate_forces = false; Vector<CollisionData> collisions; Vector<RigidBodyBullet *> collision_traces_1; @@ -211,21 +210,21 @@ private: Vector<RigidBodyBullet *> *curr_collision_traces; // these parameters are used to avoid vector resize - int maxCollisionsDetection; - int collisionsCount; - int prev_collision_count; + int maxCollisionsDetection = 0; + int collisionsCount = 0; + int prev_collision_count = 0; Vector<AreaBullet *> areasWhereIam; // these parameters are used to avoid vector resize - int maxAreasWhereIam; - int areaWhereIamCount; + int maxAreasWhereIam = 10; + int areaWhereIamCount = 0; // Used to know if the area is used as gravity point - int countGravityPointSpaces; - bool isScratchedSpaceOverrideModificator; + int countGravityPointSpaces = 0; + bool isScratchedSpaceOverrideModificator = false; - bool previousActiveState; // Last check state + bool previousActiveState = true; // Last check state - ForceIntegrationCallback *force_integration_callback; + ForceIntegrationCallback *force_integration_callback = nullptr; public: RigidBodyBullet(); diff --git a/modules/bullet/shape_bullet.cpp b/modules/bullet/shape_bullet.cpp index 8ac26a0fdb..e3b869ad8d 100644 --- a/modules/bullet/shape_bullet.cpp +++ b/modules/bullet/shape_bullet.cpp @@ -46,8 +46,7 @@ @author AndreaCatania */ -ShapeBullet::ShapeBullet() : - margin(0.04) {} +ShapeBullet::ShapeBullet() {} ShapeBullet::~ShapeBullet() {} @@ -81,7 +80,8 @@ void ShapeBullet::add_owner(ShapeOwnerBullet *p_owner) { void ShapeBullet::remove_owner(ShapeOwnerBullet *p_owner, bool p_permanentlyFromThisBody) { Map<ShapeOwnerBullet *, int>::Element *E = owners.find(p_owner); - if (!E) return; + if (!E) + return; E->get()--; if (p_permanentlyFromThisBody || 0 >= E->get()) { owners.erase(E); @@ -361,8 +361,7 @@ btCollisionShape *ConvexPolygonShapeBullet::create_bt_shape(const btVector3 &p_i /* Concave polygon */ ConcavePolygonShapeBullet::ConcavePolygonShapeBullet() : - ShapeBullet(), - meshShape(nullptr) {} + ShapeBullet() {} ConcavePolygonShapeBullet::~ConcavePolygonShapeBullet() { if (meshShape) { @@ -562,9 +561,7 @@ btCollisionShape *HeightMapShapeBullet::create_bt_shape(const btVector3 &p_impli /* Ray shape */ RayShapeBullet::RayShapeBullet() : - ShapeBullet(), - length(1), - slips_on_slope(false) {} + ShapeBullet() {} void RayShapeBullet::set_data(const Variant &p_data) { diff --git a/modules/bullet/shape_bullet.h b/modules/bullet/shape_bullet.h index 0dbc616fe5..88b62b6dc9 100644 --- a/modules/bullet/shape_bullet.h +++ b/modules/bullet/shape_bullet.h @@ -52,7 +52,7 @@ class btBvhTriangleMeshShape; class ShapeBullet : public RIDBullet { Map<ShapeOwnerBullet *, int> owners; - real_t margin; + real_t margin = 0.04; protected: /// return self @@ -200,7 +200,7 @@ private: }; class ConcavePolygonShapeBullet : public ShapeBullet { - class btBvhTriangleMeshShape *meshShape; + class btBvhTriangleMeshShape *meshShape = nullptr; public: Vector<Vector3> faces; @@ -240,8 +240,8 @@ private: class RayShapeBullet : public ShapeBullet { public: - real_t length; - bool slips_on_slope; + real_t length = 1; + bool slips_on_slope = false; RayShapeBullet(); diff --git a/modules/bullet/slider_joint_bullet.cpp b/modules/bullet/slider_joint_bullet.cpp index f193daef39..de47c91e5d 100644 --- a/modules/bullet/slider_joint_bullet.cpp +++ b/modules/bullet/slider_joint_bullet.cpp @@ -344,56 +344,123 @@ real_t SliderJointBullet::getLinearPos() { void SliderJointBullet::set_param(PhysicsServer3D::SliderJointParam p_param, real_t p_value) { switch (p_param) { - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_UPPER: setUpperLinLimit(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_LOWER: setLowerLinLimit(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_SOFTNESS: setSoftnessLimLin(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_RESTITUTION: setRestitutionLimLin(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_DAMPING: setDampingLimLin(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_SOFTNESS: setSoftnessDirLin(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_RESTITUTION: setRestitutionDirLin(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_DAMPING: setDampingDirLin(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_SOFTNESS: setSoftnessOrthoLin(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_RESTITUTION: setRestitutionOrthoLin(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_DAMPING: setDampingOrthoLin(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_UPPER: setUpperAngLimit(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_LOWER: setLowerAngLimit(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_SOFTNESS: setSoftnessLimAng(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_RESTITUTION: setRestitutionLimAng(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_DAMPING: setDampingLimAng(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_SOFTNESS: setSoftnessDirAng(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_RESTITUTION: setRestitutionDirAng(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_DAMPING: setDampingDirAng(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_SOFTNESS: setSoftnessOrthoAng(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_RESTITUTION: setRestitutionOrthoAng(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_DAMPING: setDampingOrthoAng(p_value); break; - case PhysicsServer3D::SLIDER_JOINT_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_UPPER: + setUpperLinLimit(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_LOWER: + setLowerLinLimit(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_SOFTNESS: + setSoftnessLimLin(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_RESTITUTION: + setRestitutionLimLin(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_DAMPING: + setDampingLimLin(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_SOFTNESS: + setSoftnessDirLin(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_RESTITUTION: + setRestitutionDirLin(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_DAMPING: + setDampingDirLin(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_SOFTNESS: + setSoftnessOrthoLin(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_RESTITUTION: + setRestitutionOrthoLin(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_DAMPING: + setDampingOrthoLin(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_UPPER: + setUpperAngLimit(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_LOWER: + setLowerAngLimit(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_SOFTNESS: + setSoftnessLimAng(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_RESTITUTION: + setRestitutionLimAng(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_DAMPING: + setDampingLimAng(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_SOFTNESS: + setSoftnessDirAng(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_RESTITUTION: + setRestitutionDirAng(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_DAMPING: + setDampingDirAng(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_SOFTNESS: + setSoftnessOrthoAng(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_RESTITUTION: + setRestitutionOrthoAng(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_DAMPING: + setDampingOrthoAng(p_value); + break; + case PhysicsServer3D::SLIDER_JOINT_MAX: + break; // Can't happen, but silences warning } } real_t SliderJointBullet::get_param(PhysicsServer3D::SliderJointParam p_param) const { switch (p_param) { - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_UPPER: return getUpperLinLimit(); - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_LOWER: return getLowerLinLimit(); - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_SOFTNESS: return getSoftnessLimLin(); - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_RESTITUTION: return getRestitutionLimLin(); - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_DAMPING: return getDampingLimLin(); - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_SOFTNESS: return getSoftnessDirLin(); - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_RESTITUTION: return getRestitutionDirLin(); - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_DAMPING: return getDampingDirLin(); - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_SOFTNESS: return getSoftnessOrthoLin(); - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_RESTITUTION: return getRestitutionOrthoLin(); - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_DAMPING: return getDampingOrthoLin(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_UPPER: return getUpperAngLimit(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_LOWER: return getLowerAngLimit(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_SOFTNESS: return getSoftnessLimAng(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_RESTITUTION: return getRestitutionLimAng(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_DAMPING: return getDampingLimAng(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_SOFTNESS: return getSoftnessDirAng(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_RESTITUTION: return getRestitutionDirAng(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_DAMPING: return getDampingDirAng(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_SOFTNESS: return getSoftnessOrthoAng(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_RESTITUTION: return getRestitutionOrthoAng(); - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_DAMPING: return getDampingOrthoAng(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_UPPER: + return getUpperLinLimit(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_LOWER: + return getLowerLinLimit(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_SOFTNESS: + return getSoftnessLimLin(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_RESTITUTION: + return getRestitutionLimLin(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_DAMPING: + return getDampingLimLin(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_SOFTNESS: + return getSoftnessDirLin(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_RESTITUTION: + return getRestitutionDirLin(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_DAMPING: + return getDampingDirLin(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_SOFTNESS: + return getSoftnessOrthoLin(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_RESTITUTION: + return getRestitutionOrthoLin(); + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_DAMPING: + return getDampingOrthoLin(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_UPPER: + return getUpperAngLimit(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_LOWER: + return getLowerAngLimit(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_SOFTNESS: + return getSoftnessLimAng(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_RESTITUTION: + return getRestitutionLimAng(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_DAMPING: + return getDampingLimAng(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_SOFTNESS: + return getSoftnessDirAng(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_RESTITUTION: + return getRestitutionDirAng(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_DAMPING: + return getDampingDirAng(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_SOFTNESS: + return getSoftnessOrthoAng(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_RESTITUTION: + return getRestitutionOrthoAng(); + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_DAMPING: + return getDampingOrthoAng(); default: return 0; } diff --git a/modules/bullet/soft_body_bullet.cpp b/modules/bullet/soft_body_bullet.cpp index 236bdc7c8a..bbaa23e064 100644 --- a/modules/bullet/soft_body_bullet.cpp +++ b/modules/bullet/soft_body_bullet.cpp @@ -36,18 +36,7 @@ #include "space_bullet.h" SoftBodyBullet::SoftBodyBullet() : - CollisionObjectBullet(CollisionObjectBullet::TYPE_SOFT_BODY), - bt_soft_body(nullptr), - isScratched(false), - simulation_precision(5), - total_mass(1.), - linear_stiffness(0.5), - areaAngular_stiffness(0.5), - volume_stiffness(0.5), - pressure_coefficient(0.), - pose_matching_coefficient(0.), - damping_coefficient(0.01), - drag_coefficient(0.) {} + CollisionObjectBullet(CollisionObjectBullet::TYPE_SOFT_BODY) {} SoftBodyBullet::~SoftBodyBullet() { } diff --git a/modules/bullet/soft_body_bullet.h b/modules/bullet/soft_body_bullet.h index 3c6871e0d6..d28af7d61d 100644 --- a/modules/bullet/soft_body_bullet.h +++ b/modules/bullet/soft_body_bullet.h @@ -58,22 +58,22 @@ class SoftBodyBullet : public CollisionObjectBullet { private: - btSoftBody *bt_soft_body; + btSoftBody *bt_soft_body = nullptr; Vector<Vector<int>> indices_table; btSoftBody::Material *mat0; // This is just a copy of pointer managed by btSoftBody - bool isScratched; + bool isScratched = false; Ref<Mesh> soft_mesh; - int simulation_precision; - real_t total_mass; - real_t linear_stiffness; // [0,1] - real_t areaAngular_stiffness; // [0,1] - real_t volume_stiffness; // [0,1] - real_t pressure_coefficient; // [-inf,+inf] - real_t pose_matching_coefficient; // [0,1] - real_t damping_coefficient; // [0,1] - real_t drag_coefficient; // [0,1] + int simulation_precision = 5; + real_t total_mass = 1.; + real_t linear_stiffness = 0.5; // [0,1] + real_t areaAngular_stiffness = 0.5; // [0,1] + real_t volume_stiffness = 0.5; // [0,1] + real_t pressure_coefficient = 0.; // [-inf,+inf] + real_t pose_matching_coefficient = 0.; // [0,1] + real_t damping_coefficient = 0.01; // [0,1] + real_t drag_coefficient = 0.; // [0,1] Vector<int> pinned_nodes; // Other property to add diff --git a/modules/bullet/space_bullet.cpp b/modules/bullet/space_bullet.cpp index d49e635fd5..aff203d7b4 100644 --- a/modules/bullet/space_bullet.cpp +++ b/modules/bullet/space_bullet.cpp @@ -205,7 +205,7 @@ bool BulletPhysicsDirectSpaceState::cast_motion(const RID &p_shape, const Transf /// Returns the list of contacts pairs in this order: Local contact, other body contact bool BulletPhysicsDirectSpaceState::collide_shape(RID p_shape, const Transform &p_shape_xform, float p_margin, Vector3 *r_results, int p_result_max, int &r_result_count, const Set<RID> &p_exclude, uint32_t p_collision_mask, bool p_collide_with_bodies, bool p_collide_with_areas) { if (p_result_max <= 0) - return 0; + return false; ShapeBullet *shape = space->get_physics_server()->get_shape_owner()->getornull(p_shape); @@ -213,7 +213,7 @@ bool BulletPhysicsDirectSpaceState::collide_shape(RID p_shape, const Transform & if (!btShape->isConvex()) { bulletdelete(btShape); ERR_PRINT("The shape is not a convex shape, then is not supported: shape type: " + itos(shape->get_type())); - return 0; + return false; } btConvexShape *btConvex = static_cast<btConvexShape *>(btShape); @@ -245,7 +245,7 @@ bool BulletPhysicsDirectSpaceState::rest_info(RID p_shape, const Transform &p_sh if (!btShape->isConvex()) { bulletdelete(btShape); ERR_PRINT("The shape is not a convex shape, then is not supported: shape type: " + itos(shape->get_type())); - return 0; + return false; } btConvexShape *btConvex = static_cast<btConvexShape *>(btShape); @@ -331,22 +331,7 @@ Vector3 BulletPhysicsDirectSpaceState::get_closest_point_to_object_volume(RID p_ } } -SpaceBullet::SpaceBullet() : - broadphase(nullptr), - collisionConfiguration(nullptr), - dispatcher(nullptr), - solver(nullptr), - dynamicsWorld(nullptr), - soft_body_world_info(nullptr), - ghostPairCallback(nullptr), - godotFilterCallback(nullptr), - gravityDirection(0, -1, 0), - gravityMagnitude(10), - linear_damp(0.0), - angular_damp(0.0), - contactDebugCount(0), - delta_time(0.) { - +SpaceBullet::SpaceBullet() { create_empty_world(GLOBAL_DEF("physics/3d/active_soft_world", true)); direct_access = memnew(BulletPhysicsDirectSpaceState(this)); } diff --git a/modules/bullet/space_bullet.h b/modules/bullet/space_bullet.h index 3d4a2aeceb..6fe4571bba 100644 --- a/modules/bullet/space_bullet.h +++ b/modules/bullet/space_bullet.h @@ -92,30 +92,30 @@ class SpaceBullet : public RIDBullet { friend void onBulletTickCallback(btDynamicsWorld *world, btScalar timeStep); friend class BulletPhysicsDirectSpaceState; - btBroadphaseInterface *broadphase; - btDefaultCollisionConfiguration *collisionConfiguration; - btCollisionDispatcher *dispatcher; - btConstraintSolver *solver; - btDiscreteDynamicsWorld *dynamicsWorld; - btSoftBodyWorldInfo *soft_body_world_info; - btGhostPairCallback *ghostPairCallback; - GodotFilterCallback *godotFilterCallback; + btBroadphaseInterface *broadphase = nullptr; + btDefaultCollisionConfiguration *collisionConfiguration = nullptr; + btCollisionDispatcher *dispatcher = nullptr; + btConstraintSolver *solver = nullptr; + btDiscreteDynamicsWorld *dynamicsWorld = nullptr; + btSoftBodyWorldInfo *soft_body_world_info = nullptr; + btGhostPairCallback *ghostPairCallback = nullptr; + GodotFilterCallback *godotFilterCallback = nullptr; btGjkEpaPenetrationDepthSolver *gjk_epa_pen_solver; btVoronoiSimplexSolver *gjk_simplex_solver; BulletPhysicsDirectSpaceState *direct_access; - Vector3 gravityDirection; - real_t gravityMagnitude; + Vector3 gravityDirection = Vector3(0, -1, 0); + real_t gravityMagnitude = 10; - real_t linear_damp; - real_t angular_damp; + real_t linear_damp = 0.0; + real_t angular_damp = 0.0; Vector<AreaBullet *> areas; Vector<Vector3> contactDebug; - int contactDebugCount; - real_t delta_time; + int contactDebugCount = 0; + real_t delta_time = 0.; public: SpaceBullet(); @@ -170,7 +170,8 @@ public: contactDebugCount = 0; } _FORCE_INLINE_ void add_debug_contact(const Vector3 &p_contact) { - if (contactDebugCount < contactDebug.size()) contactDebug.write[contactDebugCount++] = p_contact; + if (contactDebugCount < contactDebug.size()) + contactDebug.write[contactDebugCount++] = p_contact; } _FORCE_INLINE_ Vector<Vector3> get_debug_contacts() { return contactDebug; } _FORCE_INLINE_ int get_debug_contact_count() { return contactDebugCount; } @@ -193,22 +194,15 @@ private: void check_body_collision(); struct RecoverResult { - bool hasPenetration; - btVector3 normal; - btVector3 pointWorld; - btScalar penetration_distance; // Negative mean penetration - int other_compound_shape_index; - const btCollisionObject *other_collision_object; - int local_shape_most_recovered; - - RecoverResult() : - hasPenetration(false), - normal(0, 0, 0), - pointWorld(0, 0, 0), - penetration_distance(1e20), - other_compound_shape_index(0), - other_collision_object(nullptr), - local_shape_most_recovered(0) {} + bool hasPenetration = false; + btVector3 normal = btVector3(0, 0, 0); + btVector3 pointWorld = btVector3(0, 0, 0); + btScalar penetration_distance = 1e20; // Negative mean penetration + int other_compound_shape_index = 0; + const btCollisionObject *other_collision_object = nullptr; + int local_shape_most_recovered = 0; + + RecoverResult() {} }; bool recover_from_penetration(RigidBodyBullet *p_body, const btTransform &p_body_position, btScalar p_recover_movement_scale, bool p_infinite_inertia, btVector3 &r_delta_recover_movement, RecoverResult *r_recover_result = nullptr); diff --git a/modules/csg/csg.cpp b/modules/csg/csg.cpp index 6714db76bb..a6951a9320 100644 --- a/modules/csg/csg.cpp +++ b/modules/csg/csg.cpp @@ -138,10 +138,12 @@ inline bool is_point_in_triangle(const Vector3 &p_point, const Vector3 p_vertice lambda[2] = p_vertices[0].cross(p_vertices[1]).dot(p_point) / det; // Point is in the plane if all lambdas sum to 1. - if (!Math::is_equal_approx(lambda[0] + lambda[1] + lambda[2], 1)) return false; + if (!Math::is_equal_approx(lambda[0] + lambda[1] + lambda[2], 1)) + return false; // Point is inside the triangle if all lambdas are positive. - if (lambda[0] < 0 || lambda[1] < 0 || lambda[2] < 0) return false; + if (lambda[0] < 0 || lambda[1] < 0 || lambda[2] < 0) + return false; return true; } @@ -524,7 +526,8 @@ void CSGBrushOperation::MeshMerge::_add_distance(List<real_t> &r_intersectionsA, // Check if distance exists. for (const List<real_t>::Element *E = intersections.front(); E; E = E->next()) - if (Math::abs(**E - p_distance) < vertex_snap) return; + if (Math::abs(**E - p_distance) < vertex_snap) + return; intersections.push_back(p_distance); } @@ -790,7 +793,8 @@ int CSGBrushOperation::Build2DFaces::_add_vertex(const Vertex2D &p_vertex) { // Check if vertex exists. int vertex_id = _get_point_idx(p_vertex.point); - if (vertex_id != -1) return vertex_id; + if (vertex_id != -1) + return vertex_id; vertices.push_back(p_vertex); return vertices.size() - 1; @@ -816,7 +820,8 @@ void CSGBrushOperation::Build2DFaces::_add_vertex_idx_sorted(Vector<int> &r_vert // Sort along the axis with the greatest difference. int axis = 0; - if (Math::abs(new_point.x - first_point.x) < Math::abs(new_point.y - first_point.y)) axis = 1; + if (Math::abs(new_point.x - first_point.x) < Math::abs(new_point.y - first_point.y)) + axis = 1; // Add it to the beginning or the end appropriately. if (new_point[axis] < first_point[axis]) @@ -834,7 +839,8 @@ void CSGBrushOperation::Build2DFaces::_add_vertex_idx_sorted(Vector<int> &r_vert // Determine axis being sorted against i.e. the axis with the greatest difference. int axis = 0; - if (Math::abs(last_point.x - first_point.x) < Math::abs(last_point.y - first_point.y)) axis = 1; + if (Math::abs(last_point.x - first_point.x) < Math::abs(last_point.y - first_point.y)) + axis = 1; // Insert the point at the appropriate index. for (int insert_idx = 0; insert_idx < r_vertex_indices.size(); ++insert_idx) { @@ -853,7 +859,8 @@ void CSGBrushOperation::Build2DFaces::_add_vertex_idx_sorted(Vector<int> &r_vert void CSGBrushOperation::Build2DFaces::_merge_faces(const Vector<int> &p_segment_indices) { int segments = p_segment_indices.size() - 1; - if (segments < 2) return; + if (segments < 2) + return; // Faces around an inner vertex are merged by moving the inner vertex to the first vertex. for (int sorted_idx = 1; sorted_idx < segments; ++sorted_idx) { @@ -893,7 +900,8 @@ void CSGBrushOperation::Build2DFaces::_merge_faces(const Vector<int> &p_segment_ // Skip flattened faces. if (outer_edge_idx[0] == p_segment_indices[closest_idx] || - outer_edge_idx[1] == p_segment_indices[closest_idx]) continue; + outer_edge_idx[1] == p_segment_indices[closest_idx]) + continue; //Don't create degenerate triangles. Vector2 edge1[2] = { @@ -924,7 +932,8 @@ void CSGBrushOperation::Build2DFaces::_merge_faces(const Vector<int> &p_segment_ for (int i = 0; i < merge_faces_idx.size(); ++i) faces.remove(merge_faces_idx[i]); - if (degenerate_points.size() == 0) continue; + if (degenerate_points.size() == 0) + continue; // Split faces using degenerate points. for (int face_idx = 0; face_idx < faces.size(); ++face_idx) { @@ -954,7 +963,8 @@ void CSGBrushOperation::Build2DFaces::_merge_faces(const Vector<int> &p_segment_ break; } } - if (existing) continue; + if (existing) + continue; // Check if point is on an each edge. for (int face_edge_idx = 0; face_edge_idx < 3; ++face_edge_idx) { @@ -1043,10 +1053,12 @@ void CSGBrushOperation::Build2DFaces::_find_edge_intersections(const Vector2 p_s // Check if intersection point is an edge point. if ((intersection_point - edge_points[0]).length_squared() < vertex_snap2 || - (intersection_point - edge_points[1]).length_squared() < vertex_snap2) continue; + (intersection_point - edge_points[1]).length_squared() < vertex_snap2) + continue; // Check if edge exists, by checking if the intersecting segment is parallel to the edge. - if (are_segements_parallel(p_segment_points, edge_points, vertex_snap2)) continue; + if (are_segements_parallel(p_segment_points, edge_points, vertex_snap2)) + continue; // Add the intersection point as a new vertex. Vertex2D new_vertex; @@ -1384,7 +1396,8 @@ void CSGBrushOperation::update_faces(const CSGBrush &p_brush_a, const int p_face p_collection.build2DFacesB[p_face_idx_b] = Build2DFaces(); has_degenerate = true; } - if (has_degenerate) return; + if (has_degenerate) + return; // Ensure B has points either side of or in the plane of A. int in_plane_count = 0, over_count = 0, under_count = 0; @@ -1400,7 +1413,8 @@ void CSGBrushOperation::update_faces(const CSGBrush &p_brush_a, const int p_face under_count++; } // If all points under or over the plane, there is no intesection. - if (over_count == 3 || under_count == 3) return; + if (over_count == 3 || under_count == 3) + return; // Ensure A has points either side of or in the plane of B. in_plane_count = 0; @@ -1418,7 +1432,8 @@ void CSGBrushOperation::update_faces(const CSGBrush &p_brush_a, const int p_face under_count++; } // If all points under or over the plane, there is no intesection. - if (over_count == 3 || under_count == 3) return; + if (over_count == 3 || under_count == 3) + return; // Check for intersection using the SAT theorem. { diff --git a/modules/csg/csg_gizmos.cpp b/modules/csg/csg_gizmos.cpp index fa176cb94e..a859ef80a6 100644 --- a/modules/csg/csg_gizmos.cpp +++ b/modules/csg/csg_gizmos.cpp @@ -90,9 +90,12 @@ Variant CSGShape3DGizmoPlugin::get_handle_value(EditorNode3DGizmo *p_gizmo, int CSGBox3D *s = Object::cast_to<CSGBox3D>(cs); switch (p_idx) { - case 0: return s->get_width(); - case 1: return s->get_height(); - case 2: return s->get_depth(); + case 0: + return s->get_width(); + case 1: + return s->get_height(); + case 2: + return s->get_depth(); } } @@ -157,9 +160,15 @@ void CSGShape3DGizmoPlugin::set_handle(EditorNode3DGizmo *p_gizmo, int p_idx, Ca d = 0.001; switch (p_idx) { - case 0: s->set_width(d * 2); break; - case 1: s->set_height(d * 2); break; - case 2: s->set_depth(d * 2); break; + case 0: + s->set_width(d * 2); + break; + case 1: + s->set_height(d * 2); + break; + case 2: + s->set_depth(d * 2); + break; } } @@ -229,9 +238,15 @@ void CSGShape3DGizmoPlugin::commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, CSGBox3D *s = Object::cast_to<CSGBox3D>(cs); if (p_cancel) { switch (p_idx) { - case 0: s->set_width(p_restore); break; - case 1: s->set_height(p_restore); break; - case 2: s->set_depth(p_restore); break; + case 0: + s->set_width(p_restore); + break; + case 1: + s->set_height(p_restore); + break; + case 2: + s->set_depth(p_restore); + break; } return; } @@ -241,9 +256,15 @@ void CSGShape3DGizmoPlugin::commit_handle(EditorNode3DGizmo *p_gizmo, int p_idx, static const char *method[3] = { "set_width", "set_height", "set_depth" }; float current = 0; switch (p_idx) { - case 0: current = s->get_width(); break; - case 1: current = s->get_height(); break; - case 2: current = s->get_depth(); break; + case 0: + current = s->get_width(); + break; + case 1: + current = s->get_height(); + break; + case 2: + current = s->get_depth(); + break; } ur->add_do_method(s, method[p_idx], current); diff --git a/modules/csg/csg_shape.cpp b/modules/csg/csg_shape.cpp index 550a919d0d..a5b664eeab 100644 --- a/modules/csg/csg_shape.cpp +++ b/modules/csg/csg_shape.cpp @@ -185,9 +185,15 @@ CSGBrush *CSGShape3D::_get_brush() { CSGBrushOperation bop; switch (child->get_operation()) { - case CSGShape3D::OPERATION_UNION: bop.merge_brushes(CSGBrushOperation::OPERATION_UNION, *n, *nn2, *nn, snap); break; - case CSGShape3D::OPERATION_INTERSECTION: bop.merge_brushes(CSGBrushOperation::OPERATION_INTERSECTION, *n, *nn2, *nn, snap); break; - case CSGShape3D::OPERATION_SUBTRACTION: bop.merge_brushes(CSGBrushOperation::OPERATION_SUBSTRACTION, *n, *nn2, *nn, snap); break; + case CSGShape3D::OPERATION_UNION: + bop.merge_brushes(CSGBrushOperation::OPERATION_UNION, *n, *nn2, *nn, snap); + break; + case CSGShape3D::OPERATION_INTERSECTION: + bop.merge_brushes(CSGBrushOperation::OPERATION_INTERSECTION, *n, *nn2, *nn, snap); + break; + case CSGShape3D::OPERATION_SUBTRACTION: + bop.merge_brushes(CSGBrushOperation::OPERATION_SUBSTRACTION, *n, *nn2, *nn, snap); + break; } memdelete(n); memdelete(nn2); @@ -340,21 +346,31 @@ void CSGShape3D::_update_shape() { } } - //fill arrays - Vector<Vector3> physics_faces; - bool fill_physics_faces = false; + // Update collision faces. if (root_collision_shape.is_valid()) { + + Vector<Vector3> physics_faces; physics_faces.resize(n->faces.size() * 3); - fill_physics_faces = true; - } + Vector3 *physicsw = physics_faces.ptrw(); - { - Vector3 *physicsw; + for (int i = 0; i < n->faces.size(); i++) { + + int order[3] = { 0, 1, 2 }; + + if (n->faces[i].invert) { + SWAP(order[1], order[2]); + } - if (fill_physics_faces) { - physicsw = physics_faces.ptrw(); + physicsw[i * 3 + 0] = n->faces[i].vertices[order[0]]; + physicsw[i * 3 + 1] = n->faces[i].vertices[order[1]]; + physicsw[i * 3 + 2] = n->faces[i].vertices[order[2]]; } + root_collision_shape->set_faces(physics_faces); + } + + //fill arrays + { for (int i = 0; i < n->faces.size(); i++) { int order[3] = { 0, 1, 2 }; @@ -363,12 +379,6 @@ void CSGShape3D::_update_shape() { SWAP(order[1], order[2]); } - if (fill_physics_faces) { - physicsw[i * 3 + 0] = n->faces[i].vertices[order[0]]; - physicsw[i * 3 + 1] = n->faces[i].vertices[order[1]]; - physicsw[i * 3 + 2] = n->faces[i].vertices[order[2]]; - } - int mat = n->faces[i].material; ERR_CONTINUE(mat < -1 || mat >= face_count.size()); int idx = mat == -1 ? face_count.size() - 1 : mat; @@ -452,10 +462,6 @@ void CSGShape3D::_update_shape() { root_mesh->surface_set_material(idx, surfaces[i].material); } - if (root_collision_shape.is_valid()) { - root_collision_shape->set_faces(physics_faces); - } - set_base(root_mesh->get_rid()); } AABB CSGShape3D::get_aabb() const { @@ -1783,11 +1789,15 @@ CSGBrush *CSGPolygon3D::_build_brush() { final_polygon_min = p; final_polygon_max = final_polygon_min; } else { - if (p.x < final_polygon_min.x) final_polygon_min.x = p.x; - if (p.y < final_polygon_min.y) final_polygon_min.y = p.y; - - if (p.x > final_polygon_max.x) final_polygon_max.x = p.x; - if (p.y > final_polygon_max.y) final_polygon_max.y = p.y; + if (p.x < final_polygon_min.x) + final_polygon_min.x = p.x; + if (p.y < final_polygon_min.y) + final_polygon_min.y = p.y; + + if (p.x > final_polygon_max.x) + final_polygon_max.x = p.x; + if (p.y > final_polygon_max.y) + final_polygon_max.y = p.y; } } Vector2 final_polygon_size = final_polygon_max - final_polygon_min; @@ -1826,8 +1836,12 @@ CSGBrush *CSGPolygon3D::_build_brush() { int face_count = 0; switch (mode) { - case MODE_DEPTH: face_count = triangles.size() * 2 / 3 + (final_polygon.size()) * 2; break; - case MODE_SPIN: face_count = (spin_degrees < 360 ? triangles.size() * 2 / 3 : 0) + (final_polygon.size()) * 2 * spin_sides; break; + case MODE_DEPTH: + face_count = triangles.size() * 2 / 3 + (final_polygon.size()) * 2; + break; + case MODE_SPIN: + face_count = (spin_degrees < 360 ? triangles.size() * 2 / 3 : 0) + (final_polygon.size()) * 2 * spin_sides; + break; case MODE_PATH: { float bl = curve->get_baked_length(); int splits = MAX(2, Math::ceil(bl / path_interval)); diff --git a/modules/denoise/SCsub b/modules/denoise/SCsub new file mode 100644 index 0000000000..0fa65c6296 --- /dev/null +++ b/modules/denoise/SCsub @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +import resource_to_cpp +from platform_methods import run_in_subprocess + +Import("env") +Import("env_modules") + +env_oidn = env_modules.Clone() + +# Thirdparty source files +thirdparty_dir = "#thirdparty/oidn/" +thirdparty_sources = [ + "core/api.cpp", + "core/device.cpp", + "core/filter.cpp", + "core/network.cpp", + "core/autoencoder.cpp", + "core/transfer_function.cpp", + "weights/rtlightmap_hdr.gen.cpp", + "mkl-dnn/src/common/batch_normalization.cpp", + "mkl-dnn/src/common/concat.cpp", + "mkl-dnn/src/common/convolution.cpp", + "mkl-dnn/src/common/convolution_pd.cpp", + "mkl-dnn/src/common/deconvolution.cpp", + "mkl-dnn/src/common/eltwise.cpp", + "mkl-dnn/src/common/engine.cpp", + "mkl-dnn/src/common/inner_product.cpp", + "mkl-dnn/src/common/inner_product_pd.cpp", + "mkl-dnn/src/common/lrn.cpp", + "mkl-dnn/src/common/memory.cpp", + "mkl-dnn/src/common/memory_desc_wrapper.cpp", + "mkl-dnn/src/common/mkldnn_debug.cpp", + "mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp", + "mkl-dnn/src/common/pooling.cpp", + "mkl-dnn/src/common/primitive.cpp", + "mkl-dnn/src/common/primitive_attr.cpp", + "mkl-dnn/src/common/primitive_desc.cpp", + "mkl-dnn/src/common/primitive_exec_types.cpp", + "mkl-dnn/src/common/primitive_iterator.cpp", + "mkl-dnn/src/common/query.cpp", + "mkl-dnn/src/common/reorder.cpp", + "mkl-dnn/src/common/rnn.cpp", + "mkl-dnn/src/common/scratchpad.cpp", + "mkl-dnn/src/common/shuffle.cpp", + "mkl-dnn/src/common/softmax.cpp", + "mkl-dnn/src/common/stream.cpp", + "mkl-dnn/src/common/sum.cpp", + "mkl-dnn/src/common/utils.cpp", + "mkl-dnn/src/common/verbose.cpp", + "mkl-dnn/src/cpu/cpu_barrier.cpp", + "mkl-dnn/src/cpu/cpu_concat.cpp", + "mkl-dnn/src/cpu/cpu_engine.cpp", + "mkl-dnn/src/cpu/cpu_memory.cpp", + "mkl-dnn/src/cpu/cpu_reducer.cpp", + "mkl-dnn/src/cpu/cpu_reorder.cpp", + "mkl-dnn/src/cpu/cpu_sum.cpp", + "mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_avx2_convolution.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp", + "mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp", + "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp", + "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp", + "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp", + "mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_sse42_convolution.cpp", + "mkl-dnn/src/cpu/jit_transpose_src_utils.cpp", + "mkl-dnn/src/cpu/jit_uni_eltwise.cpp", + "mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp", + "mkl-dnn/src/cpu/jit_uni_pooling.cpp", + "mkl-dnn/src/cpu/jit_uni_reorder.cpp", + "mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp", + "mkl-dnn/src/cpu/jit_utils/jit_utils.cpp", + "mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c", + "common/platform.cpp", + "common/thread.cpp", + "common/tensor.cpp", +] +thirdparty_sources = [thirdparty_dir + file for file in thirdparty_sources] + +thirdparty_include_dirs = [ + "", + "include", + "mkl-dnn/include", + "mkl-dnn/src", + "mkl-dnn/src/common", + "mkl-dnn/src/cpu/xbyak", + "mkl-dnn/src/cpu", +] +thirdparty_include_dirs = [thirdparty_dir + file for file in thirdparty_include_dirs] + + +env_oidn.Prepend(CPPPATH=thirdparty_include_dirs) +env_oidn.Append( + CPPDEFINES=[ + "MKLDNN_THR=MKLDNN_THR_SEQ", + "OIDN_STATIC_LIB", + "__STDC_CONSTANT_MACROS", + "__STDC_LIMIT_MACROS", + "DISABLE_VERBOSE", + "MKLDNN_ENABLE_CONCURRENT_EXEC", + "NDEBUG", + ] +) + +env_thirdparty = env_oidn.Clone() +env_thirdparty.disable_warnings() +env_thirdparty.add_source_files(env.modules_sources, thirdparty_sources) + +weights_in_path = thirdparty_dir + "weights/rtlightmap_hdr.tza" +weights_out_path = thirdparty_dir + "weights/rtlightmap_hdr.gen.cpp" + +env_thirdparty.Depends(weights_out_path, weights_in_path) +env_thirdparty.CommandNoCache(weights_out_path, weights_in_path, resource_to_cpp.tza_to_cpp) + +env_oidn.add_source_files(env.modules_sources, "denoise_wrapper.cpp") +env_modules.add_source_files(env.modules_sources, ["register_types.cpp", "lightmap_denoiser.cpp"]) diff --git a/modules/denoise/config.py b/modules/denoise/config.py new file mode 100644 index 0000000000..53b8f2f2e3 --- /dev/null +++ b/modules/denoise/config.py @@ -0,0 +1,6 @@ +def can_build(env, platform): + return env["tools"] + + +def configure(env): + pass diff --git a/modules/denoise/denoise_wrapper.cpp b/modules/denoise/denoise_wrapper.cpp new file mode 100644 index 0000000000..feeeaef507 --- /dev/null +++ b/modules/denoise/denoise_wrapper.cpp @@ -0,0 +1,34 @@ +#include "denoise_wrapper.h" +#include "thirdparty/oidn/include/OpenImageDenoise/oidn.h" +#include <stdio.h> + +void *oidn_denoiser_init() { + OIDNDeviceImpl *device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU); + oidnCommitDevice(device); + return device; +} + +bool oidn_denoise(void *deviceptr, float *p_floats, int p_width, int p_height) { + OIDNDeviceImpl *device = (OIDNDeviceImpl *)deviceptr; + OIDNFilter filter = oidnNewFilter(device, "RTLightmap"); + oidnSetSharedFilterImage(filter, "color", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0); + oidnSetSharedFilterImage(filter, "output", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0); + oidnSetFilter1b(filter, "hdr", true); + //oidnSetFilter1f(filter, "hdrScale", 1.0f); + oidnCommitFilter(filter); + oidnExecuteFilter(filter); + + const char *msg; + bool success = true; + if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE) { + printf("LightmapDenoiser: %s\n", msg); + success = false; + } + + oidnReleaseFilter(filter); + return success; +} + +void oidn_denoiser_finish(void *device) { + oidnReleaseDevice((OIDNDeviceImpl *)device); +} diff --git a/modules/denoise/denoise_wrapper.h b/modules/denoise/denoise_wrapper.h new file mode 100644 index 0000000000..3aef326e22 --- /dev/null +++ b/modules/denoise/denoise_wrapper.h @@ -0,0 +1,8 @@ +#ifndef DENOISE_WRAPPER_H +#define DENOISE_WRAPPER_H + +void *oidn_denoiser_init(); +bool oidn_denoise(void *device, float *p_floats, int p_width, int p_height); +void oidn_denoiser_finish(void *device); + +#endif // DENOISE_WRAPPER_H diff --git a/modules/denoise/lightmap_denoiser.cpp b/modules/denoise/lightmap_denoiser.cpp new file mode 100644 index 0000000000..c821b22d85 --- /dev/null +++ b/modules/denoise/lightmap_denoiser.cpp @@ -0,0 +1,63 @@ +/*************************************************************************/ +/* lightmap_denoiser.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "lightmap_denoiser.h" +#include "denoise_wrapper.h" + +LightmapDenoiser *LightmapDenoiserOIDN::create_oidn_denoiser() { + return memnew(LightmapDenoiserOIDN); +} + +void LightmapDenoiserOIDN::make_default_denoiser() { + create_function = create_oidn_denoiser; +} + +Ref<Image> LightmapDenoiserOIDN::denoise_image(const Ref<Image> &p_image) { + + Ref<Image> img = p_image->duplicate(); + + img->convert(Image::FORMAT_RGBF); + + Vector<uint8_t> data = img->get_data(); + if (!oidn_denoise(device, (float *)data.ptrw(), img->get_width(), img->get_height())) { + return p_image; + } + + img->create(img->get_width(), img->get_height(), false, img->get_format(), data); + return img; +} + +LightmapDenoiserOIDN::LightmapDenoiserOIDN() { + device = oidn_denoiser_init(); +} + +LightmapDenoiserOIDN::~LightmapDenoiserOIDN() { + oidn_denoiser_finish(device); +} diff --git a/core/os/semaphore.cpp b/modules/denoise/lightmap_denoiser.h index 93f1e2dff4..ac0cc8b9db 100644 --- a/core/os/semaphore.cpp +++ b/modules/denoise/lightmap_denoiser.h @@ -1,5 +1,5 @@ /*************************************************************************/ -/* semaphore.cpp */ +/* lightmap_denoiser.h */ /*************************************************************************/ /* This file is part of: */ /* GODOT ENGINE */ @@ -28,4 +28,30 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#include "semaphore.h" +#ifndef LIGHTMAP_DENOISER_H +#define LIGHTMAP_DENOISER_H + +#include "core/object.h" +#include "scene/3d/lightmapper.h" + +struct OIDNDeviceImpl; + +class LightmapDenoiserOIDN : public LightmapDenoiser { + + GDCLASS(LightmapDenoiserOIDN, LightmapDenoiser); + +protected: + void *device = nullptr; + +public: + static LightmapDenoiser *create_oidn_denoiser(); + + Ref<Image> denoise_image(const Ref<Image> &p_image); + + static void make_default_denoiser(); + + LightmapDenoiserOIDN(); + ~LightmapDenoiserOIDN(); +}; + +#endif // LIGHTMAP_DENOISER_H diff --git a/core/path_remap.cpp b/modules/denoise/register_types.cpp index e1708e0350..b6b92701c8 100644 --- a/core/path_remap.cpp +++ b/modules/denoise/register_types.cpp @@ -1,5 +1,5 @@ /*************************************************************************/ -/* path_remap.cpp */ +/* register_types.cpp */ /*************************************************************************/ /* This file is part of: */ /* GODOT ENGINE */ @@ -28,4 +28,14 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#include "path_remap.h" +#include "register_types.h" +#include "core/engine.h" +#include "lightmap_denoiser.h" + +void register_denoise_types() { + + LightmapDenoiserOIDN::make_default_denoiser(); +} + +void unregister_denoise_types() { +} diff --git a/core/math/audio_frame.cpp b/modules/denoise/register_types.h index c565ea9b13..2ffc36ee2c 100644 --- a/core/math/audio_frame.cpp +++ b/modules/denoise/register_types.h @@ -1,5 +1,5 @@ /*************************************************************************/ -/* audio_frame.cpp */ +/* register_types.h */ /*************************************************************************/ /* This file is part of: */ /* GODOT ENGINE */ @@ -28,4 +28,5 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#include "audio_frame.h" +void register_denoise_types(); +void unregister_denoise_types(); diff --git a/modules/denoise/resource_to_cpp.py b/modules/denoise/resource_to_cpp.py new file mode 100644 index 0000000000..4c0b67f701 --- /dev/null +++ b/modules/denoise/resource_to_cpp.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python + +## ======================================================================== ## +## Copyright 2009-2019 Intel Corporation ## +## ## +## Licensed under the Apache License, Version 2.0 (the "License"); ## +## you may not use this file except in compliance with the License. ## +## You may obtain a copy of the License at ## +## ## +## http://www.apache.org/licenses/LICENSE-2.0 ## +## ## +## Unless required by applicable law or agreed to in writing, software ## +## distributed under the License is distributed on an "AS IS" BASIS, ## +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ## +## See the License for the specific language governing permissions and ## +## limitations under the License. ## +## ======================================================================== ## + +import os +import sys +import argparse +from array import array + +# Generates a C++ file from the specified binary resource file +def generate(in_path, out_path): + + namespace = "oidn::weights" + scopes = namespace.split("::") + + file_name = os.path.basename(in_path) + var_name = os.path.splitext(file_name)[0] + + with open(in_path, "rb") as in_file, open(out_path, "w") as out_file: + # Header + out_file.write("// Generated from: %s\n" % file_name) + out_file.write("#include <cstddef>\n\n") + + # Open the namespaces + for s in scopes: + out_file.write("namespace %s {\n" % s) + if scopes: + out_file.write("\n") + + # Read the file + in_data = array("B", in_file.read()) + + # Write the size + out_file.write("//const size_t %s_size = %d;\n\n" % (var_name, len(in_data))) + + # Write the data + out_file.write("unsigned char %s[] = {" % var_name) + for i in range(len(in_data)): + c = in_data[i] + if i > 0: + out_file.write(",") + if (i + 1) % 20 == 1: + out_file.write("\n") + out_file.write("%d" % c) + out_file.write("\n};\n") + + # Close the namespaces + if scopes: + out_file.write("\n") + for scope in reversed(scopes): + out_file.write("} // namespace %s\n" % scope) + + +def tza_to_cpp(target, source, env): + for x in zip(source, target): + generate(str(x[0]), str(x[1])) diff --git a/modules/gdnative/gdnative/gdnative.cpp b/modules/gdnative/gdnative/gdnative.cpp index 3175340448..1216d1d9d3 100644 --- a/modules/gdnative/gdnative/gdnative.cpp +++ b/modules/gdnative/gdnative/gdnative.cpp @@ -177,7 +177,8 @@ void *godot_get_class_tag(const godot_string_name *p_class) { } godot_object *godot_object_cast_to(const godot_object *p_object, void *p_class_tag) { - if (!p_object) return nullptr; + if (!p_object) + return nullptr; Object *o = (Object *)p_object; return o->is_class_ptr(p_class_tag) ? (godot_object *)o : nullptr; diff --git a/modules/gdnative/include/videodecoder/godot_videodecoder.h b/modules/gdnative/include/videodecoder/godot_videodecoder.h index 3e91a2e9ac..16c92abd22 100644 --- a/modules/gdnative/include/videodecoder/godot_videodecoder.h +++ b/modules/gdnative/include/videodecoder/godot_videodecoder.h @@ -46,7 +46,7 @@ typedef struct void *next; void *(*constructor)(godot_object *); void (*destructor)(void *); - const char *(*get_plugin_name)(void); + const char *(*get_plugin_name)(); const char **(*get_supported_extensions)(int *count); godot_bool (*open_file)(void *, void *); // data struct, and a FileAccess pointer godot_real (*get_length)(const void *); diff --git a/modules/gdnative/nativescript/nativescript.h b/modules/gdnative/nativescript/nativescript.h index 7e7598e06c..d24b247e59 100644 --- a/modules/gdnative/nativescript/nativescript.h +++ b/modules/gdnative/nativescript/nativescript.h @@ -43,6 +43,7 @@ #include "scene/main/node.h" #include "modules/gdnative/gdnative.h" + #include <nativescript/godot_nativescript.h> struct NativeScriptDesc { @@ -54,6 +55,7 @@ struct NativeScriptDesc { uint16_t rpc_method_id; String documentation; }; + struct Property { godot_property_set_func setter; godot_property_get_func getter; @@ -69,9 +71,9 @@ struct NativeScriptDesc { String documentation; }; - uint16_t rpc_count; + uint16_t rpc_count = 0; Map<StringName, Method> methods; - uint16_t rset_count; + uint16_t rset_count = 0; OrderedHashMap<StringName, Property> properties; Map<StringName, Signal> signals_; // QtCreator doesn't like the name signals StringName base; @@ -82,20 +84,11 @@ struct NativeScriptDesc { String documentation; - const void *type_tag; + const void *type_tag = nullptr; bool is_tool; - inline NativeScriptDesc() : - rpc_count(0), - methods(), - rset_count(0), - properties(), - signals_(), - base(), - base_native_type(), - documentation(), - type_tag(nullptr) { + inline NativeScriptDesc() { zeromem(&create_func, sizeof(godot_instance_create_func)); zeromem(&destroy_func, sizeof(godot_instance_destroy_func)); } @@ -396,14 +389,13 @@ inline NativeScriptDesc *NativeScript::get_script_desc() const { class NativeReloadNode : public Node { GDCLASS(NativeReloadNode, Node); - bool unloaded; + bool unloaded = false; public: static void _bind_methods(); void _notification(int p_what); - NativeReloadNode() : - unloaded(false) {} + NativeReloadNode() {} }; class ResourceFormatLoaderNativeScript : public ResourceFormatLoader { diff --git a/modules/gdnative/pluginscript/pluginscript_script.cpp b/modules/gdnative/pluginscript/pluginscript_script.cpp index 6b303c8716..9b00d654d1 100644 --- a/modules/gdnative/pluginscript/pluginscript_script.cpp +++ b/modules/gdnative/pluginscript/pluginscript_script.cpp @@ -550,11 +550,6 @@ MultiplayerAPI::RPCMode PluginScript::get_rset_mode(const StringName &p_variable } PluginScript::PluginScript() : - _data(nullptr), - _desc(nullptr), - _language(nullptr), - _tool(false), - _valid(false), _script_list(this) { } diff --git a/modules/gdnative/pluginscript/pluginscript_script.h b/modules/gdnative/pluginscript/pluginscript_script.h index 70b9ca980b..287f42bf7b 100644 --- a/modules/gdnative/pluginscript/pluginscript_script.h +++ b/modules/gdnative/pluginscript/pluginscript_script.h @@ -45,11 +45,11 @@ class PluginScript : public Script { friend class PluginScriptLanguage; private: - godot_pluginscript_script_data *_data; - const godot_pluginscript_script_desc *_desc; - PluginScriptLanguage *_language; - bool _tool; - bool _valid; + godot_pluginscript_script_data *_data = nullptr; + const godot_pluginscript_script_desc *_desc = nullptr; + PluginScriptLanguage *_language = nullptr; + bool _tool = false; + bool _valid = false; Ref<Script> _ref_base_parent; StringName _native_parent; diff --git a/modules/gdnative/videodecoder/video_stream_gdnative.cpp b/modules/gdnative/videodecoder/video_stream_gdnative.cpp index f7d87595af..a2ff376e31 100644 --- a/modules/gdnative/videodecoder/video_stream_gdnative.cpp +++ b/modules/gdnative/videodecoder/video_stream_gdnative.cpp @@ -202,22 +202,7 @@ void VideoStreamPlaybackGDNative::update_texture() { // ctor and dtor VideoStreamPlaybackGDNative::VideoStreamPlaybackGDNative() : - texture(Ref<ImageTexture>(memnew(ImageTexture))), - playing(false), - paused(false), - mix_udata(nullptr), - mix_callback(nullptr), - num_channels(-1), - time(0), - seek_backward(false), - mix_rate(0), - delay_compensation(0), - pcm(nullptr), - pcm_write_idx(0), - samples_decoded(0), - file(nullptr), - interface(nullptr), - data_struct(nullptr) {} + texture(Ref<ImageTexture>(memnew(ImageTexture))) {} VideoStreamPlaybackGDNative::~VideoStreamPlaybackGDNative() { cleanup(); diff --git a/modules/gdnative/videodecoder/video_stream_gdnative.h b/modules/gdnative/videodecoder/video_stream_gdnative.h index 092e10a0f5..f1bae22801 100644 --- a/modules/gdnative/videodecoder/video_stream_gdnative.h +++ b/modules/gdnative/videodecoder/video_stream_gdnative.h @@ -37,13 +37,11 @@ #include "scene/resources/video_stream.h" struct VideoDecoderGDNative { - const godot_videodecoder_interface_gdnative *interface; - String plugin_name; + const godot_videodecoder_interface_gdnative *interface = nullptr; + String plugin_name = "none"; Vector<String> supported_extensions; - VideoDecoderGDNative() : - interface(nullptr), - plugin_name("none") {} + VideoDecoderGDNative() {} VideoDecoderGDNative(const godot_videodecoder_interface_gdnative *p_interface) : interface(p_interface), @@ -111,23 +109,23 @@ class VideoStreamPlaybackGDNative : public VideoStreamPlayback { GDCLASS(VideoStreamPlaybackGDNative, VideoStreamPlayback); Ref<ImageTexture> texture; - bool playing; - bool paused; + bool playing = false; + bool paused = false; Vector2 texture_size; - void *mix_udata; - AudioMixCallback mix_callback; + void *mix_udata = nullptr; + AudioMixCallback mix_callback = nullptr; - int num_channels; - float time; - bool seek_backward; - int mix_rate; - double delay_compensation; + int num_channels = -1; + float time = 0; + bool seek_backward = false; + int mix_rate = 0; + double delay_compensation = 0; - float *pcm; - int pcm_write_idx; - int samples_decoded; + float *pcm = nullptr; + int pcm_write_idx = 0; + int samples_decoded = 0; void cleanup(); void update_texture(); @@ -135,10 +133,10 @@ class VideoStreamPlaybackGDNative : public VideoStreamPlayback { protected: String file_name; - FileAccess *file; + FileAccess *file = nullptr; - const godot_videodecoder_interface_gdnative *interface; - void *data_struct; + const godot_videodecoder_interface_gdnative *interface = nullptr; + void *data_struct = nullptr; public: VideoStreamPlaybackGDNative(); @@ -181,7 +179,7 @@ class VideoStreamGDNative : public VideoStream { GDCLASS(VideoStreamGDNative, VideoStream); String file; - int audio_track; + int audio_track = 0; protected: static void @@ -194,7 +192,7 @@ public: virtual void set_audio_track(int p_track); virtual Ref<VideoStreamPlayback> instance_playback(); - VideoStreamGDNative() { audio_track = 0; } + VideoStreamGDNative() {} }; class ResourceFormatLoaderVideoStreamGDNative : public ResourceFormatLoader { diff --git a/modules/gdnavigation/gd_navigation_server.cpp b/modules/gdnavigation/gd_navigation_server.cpp index 278c27ae22..3792098af4 100644 --- a/modules/gdnavigation/gd_navigation_server.cpp +++ b/modules/gdnavigation/gd_navigation_server.cpp @@ -114,8 +114,7 @@ void GdNavigationServer::MERGE(_cmd_, F_NAME)(T_0 D_0, T_1 D_1, T_2 D_2, T_3 D_3) GdNavigationServer::GdNavigationServer() : - NavigationServer3D(), - active(true) { + NavigationServer3D() { } GdNavigationServer::~GdNavigationServer() { diff --git a/modules/gdnavigation/gd_navigation_server.h b/modules/gdnavigation/gd_navigation_server.h index 01d1a4fba9..e3e02f3d7c 100644 --- a/modules/gdnavigation/gd_navigation_server.h +++ b/modules/gdnavigation/gd_navigation_server.h @@ -78,7 +78,7 @@ class GdNavigationServer : public NavigationServer3D { mutable RID_PtrOwner<NavRegion> region_owner; mutable RID_PtrOwner<RvoAgent> agent_owner; - bool active; + bool active = true; Vector<NavMap *> active_maps; public: diff --git a/modules/gdnavigation/nav_map.cpp b/modules/gdnavigation/nav_map.cpp index 7e6a3f7a26..d6dd95d6e7 100644 --- a/modules/gdnavigation/nav_map.cpp +++ b/modules/gdnavigation/nav_map.cpp @@ -33,6 +33,7 @@ #include "core/os/threaded_array_processor.h" #include "nav_region.h" #include "rvo_agent.h" + #include <algorithm> /** @@ -41,16 +42,6 @@ #define USE_ENTRY_POINT -NavMap::NavMap() : - up(0, 1, 0), - cell_size(0.3), - edge_connection_margin(5.0), - regenerate_polygons(true), - regenerate_links(true), - agents_dirty(false), - deltatime(0.0), - map_update_id(0) {} - void NavMap::set_up(Vector3 p_up) { up = p_up; regenerate_polygons = true; diff --git a/modules/gdnavigation/nav_map.h b/modules/gdnavigation/nav_map.h index 4543f00926..d39e301511 100644 --- a/modules/gdnavigation/nav_map.h +++ b/modules/gdnavigation/nav_map.h @@ -48,17 +48,17 @@ class NavRegion; class NavMap : public NavRid { /// Map Up - Vector3 up; + Vector3 up = Vector3(0, 1, 0); /// To find the polygons edges the vertices are displaced in a grid where /// each cell has the following cell_size. - real_t cell_size; + real_t cell_size = 0.3; /// This value is used to detect the near edges to connect. - real_t edge_connection_margin; + real_t edge_connection_margin = 5.0; - bool regenerate_polygons; - bool regenerate_links; + bool regenerate_polygons = true; + bool regenerate_links = true; std::vector<NavRegion *> regions; @@ -69,7 +69,7 @@ class NavMap : public NavRid { RVO::KdTree rvo; /// Is agent array modified? - bool agents_dirty; + bool agents_dirty = false; /// All the Agents (even the controlled one) std::vector<RvoAgent *> agents; @@ -78,13 +78,13 @@ class NavMap : public NavRid { std::vector<RvoAgent *> controlled_agents; /// Physics delta time - real_t deltatime; + real_t deltatime = 0.0; /// Change the id each time the map is updated. - uint32_t map_update_id; + uint32_t map_update_id = 0; public: - NavMap(); + NavMap() {} void set_up(Vector3 p_up); Vector3 get_up() const { diff --git a/modules/gdnavigation/nav_region.cpp b/modules/gdnavigation/nav_region.cpp index b91376f761..2bd42ba980 100644 --- a/modules/gdnavigation/nav_region.cpp +++ b/modules/gdnavigation/nav_region.cpp @@ -36,11 +36,6 @@ @author AndreaCatania */ -NavRegion::NavRegion() : - map(nullptr), - polygons_dirty(true) { -} - void NavRegion::set_map(NavMap *p_map) { map = p_map; polygons_dirty = true; diff --git a/modules/gdnavigation/nav_region.h b/modules/gdnavigation/nav_region.h index f35ee4bea0..731855bfb5 100644 --- a/modules/gdnavigation/nav_region.h +++ b/modules/gdnavigation/nav_region.h @@ -45,17 +45,17 @@ class NavMap; class NavRegion; class NavRegion : public NavRid { - NavMap *map; + NavMap *map = nullptr; Transform transform; Ref<NavigationMesh> mesh; - bool polygons_dirty; + bool polygons_dirty = true; /// Cache std::vector<gd::Polygon> polygons; public: - NavRegion(); + NavRegion() {} void scratch_polygons() { polygons_dirty = true; diff --git a/modules/gdnavigation/nav_utils.h b/modules/gdnavigation/nav_utils.h index 3401284c31..388e53a66a 100644 --- a/modules/gdnavigation/nav_utils.h +++ b/modules/gdnavigation/nav_utils.h @@ -32,6 +32,7 @@ #define NAV_UTILS_H #include "core/math/vector3.h" + #include <vector> /** @@ -80,19 +81,15 @@ struct Point { struct Edge { /// This edge ID - int this_edge; + int this_edge = -1; /// Other Polygon - Polygon *other_polygon; + Polygon *other_polygon = nullptr; /// The other `Polygon` at this edge id has this `Polygon`. - int other_edge; + int other_edge = -1; - Edge() { - this_edge = -1; - other_polygon = nullptr; - other_edge = -1; - } + Edge() {} }; struct Polygon { @@ -113,39 +110,29 @@ struct Polygon { struct Connection { - Polygon *A; - int A_edge; - Polygon *B; - int B_edge; + Polygon *A = nullptr; + int A_edge = -1; + Polygon *B = nullptr; + int B_edge = -1; - Connection() { - A = nullptr; - B = nullptr; - A_edge = -1; - B_edge = -1; - } + Connection() {} }; struct NavigationPoly { - uint32_t self_id; + uint32_t self_id = 0; /// This poly. const Polygon *poly; /// The previous navigation poly (id in the `navigation_poly` array). - int prev_navigation_poly_id; + int prev_navigation_poly_id = -1; /// The edge id in this `Poly` to reach the `prev_navigation_poly_id`. - uint32_t back_navigation_edge; + uint32_t back_navigation_edge = 0; /// The entry location of this poly. Vector3 entry; /// The distance to the destination. - float traveled_distance; + float traveled_distance = 0.0; NavigationPoly(const Polygon *p_poly) : - self_id(0), - poly(p_poly), - prev_navigation_poly_id(-1), - back_navigation_edge(0), - traveled_distance(0.0) { - } + poly(p_poly) {} bool operator==(const NavigationPoly &other) const { return this->poly == other.poly; diff --git a/modules/gdnavigation/rvo_agent.cpp b/modules/gdnavigation/rvo_agent.cpp index 3c39f02c26..1e1bdbd07d 100644 --- a/modules/gdnavigation/rvo_agent.cpp +++ b/modules/gdnavigation/rvo_agent.cpp @@ -36,8 +36,7 @@ @author AndreaCatania */ -RvoAgent::RvoAgent() : - map(nullptr) { +RvoAgent::RvoAgent() { callback.id = ObjectID(); } diff --git a/modules/gdnavigation/rvo_agent.h b/modules/gdnavigation/rvo_agent.h index 914cbaa7d9..f5c579ba84 100644 --- a/modules/gdnavigation/rvo_agent.h +++ b/modules/gdnavigation/rvo_agent.h @@ -49,7 +49,7 @@ class RvoAgent : public NavRid { Variant new_velocity; }; - NavMap *map; + NavMap *map = nullptr; RVO::Agent agent; AvoidanceComputedCallback callback; uint32_t map_update_id; diff --git a/modules/gdscript/gdscript.cpp b/modules/gdscript/gdscript.cpp index cdd5deb7ee..8559fac57c 100644 --- a/modules/gdscript/gdscript.cpp +++ b/modules/gdscript/gdscript.cpp @@ -376,10 +376,15 @@ void GDScript::_update_exports_values(Map<StringName, Variant> &values, List<Pro } #endif -bool GDScript::_update_exports() { +bool GDScript::_update_exports(bool *r_err, bool p_recursive_call) { #ifdef TOOLS_ENABLED + static Vector<GDScript *> base_caches; + if (!p_recursive_call) + base_caches.clear(); + base_caches.append(this); + bool changed = false; if (source_changed_cache) { @@ -473,7 +478,22 @@ bool GDScript::_update_exports() { placeholder_fallback_enabled = false; if (base_cache.is_valid() && base_cache->is_valid()) { - if (base_cache->_update_exports()) { + for (int i = 0; i < base_caches.size(); i++) { + if (base_caches[i] == base_cache.ptr()) { + if (r_err) + *r_err = true; + valid = false; // to show error in the editor + base_cache->valid = false; + base_cache->inheriters_cache.clear(); // to prevent future stackoverflows + base_cache.unref(); + base.unref(); + _base = nullptr; + ERR_FAIL_V_MSG(false, "Cyclic inheritance in script class."); + } + } + if (base_cache->_update_exports(r_err, true)) { + if (r_err && *r_err) + return false; changed = true; } } @@ -501,7 +521,10 @@ void GDScript::update_exports() { #ifdef TOOLS_ENABLED - _update_exports(); + bool cyclic_error = false; + _update_exports(&cyclic_error); + if (cyclic_error) + return; Set<ObjectID> copy = inheriters_cache; //might get modified @@ -635,12 +658,14 @@ uint16_t GDScript::get_rpc_method_id(const StringName &p_method) const { } StringName GDScript::get_rpc_method(const uint16_t p_rpc_method_id) const { - if (p_rpc_method_id >= rpc_functions.size()) return StringName(); + if (p_rpc_method_id >= rpc_functions.size()) + return StringName(); return rpc_functions[p_rpc_method_id].name; } MultiplayerAPI::RPCMode GDScript::get_rpc_mode_by_id(const uint16_t p_rpc_method_id) const { - if (p_rpc_method_id >= rpc_functions.size()) return MultiplayerAPI::RPC_MODE_DISABLED; + if (p_rpc_method_id >= rpc_functions.size()) + return MultiplayerAPI::RPC_MODE_DISABLED; return rpc_functions[p_rpc_method_id].mode; } @@ -662,12 +687,14 @@ uint16_t GDScript::get_rset_property_id(const StringName &p_variable) const { } StringName GDScript::get_rset_property(const uint16_t p_rset_member_id) const { - if (p_rset_member_id >= rpc_variables.size()) return StringName(); + if (p_rset_member_id >= rpc_variables.size()) + return StringName(); return rpc_variables[p_rset_member_id].name; } MultiplayerAPI::RPCMode GDScript::get_rset_mode_by_id(const uint16_t p_rset_member_id) const { - if (p_rset_member_id >= rpc_variables.size()) return MultiplayerAPI::RPC_MODE_DISABLED; + if (p_rset_member_id >= rpc_variables.size()) + return MultiplayerAPI::RPC_MODE_DISABLED; return rpc_variables[p_rset_member_id].mode; } @@ -2157,7 +2184,8 @@ String GDScriptWarning::get_message() const { case STANDALONE_TERNARY: { return "Standalone ternary conditional operator: the return value is being discarded."; } - case WARNING_MAX: break; // Can't happen, but silences warning + case WARNING_MAX: + break; // Can't happen, but silences warning } ERR_FAIL_V_MSG(String(), "Invalid GDScript warning code: " + get_name_from_code(code) + "."); diff --git a/modules/gdscript/gdscript.h b/modules/gdscript/gdscript.h index 2dbc2252fa..3cba621578 100644 --- a/modules/gdscript/gdscript.h +++ b/modules/gdscript/gdscript.h @@ -135,7 +135,7 @@ class GDScript : public Script { #endif - bool _update_exports(); + bool _update_exports(bool *r_err = nullptr, bool p_recursive_call = false); void _save_orphaned_subclasses(); void _init_rpc_methods_properties(); @@ -334,18 +334,18 @@ struct GDScriptWarning { DEPRECATED_KEYWORD, // The keyword is deprecated and should be replaced STANDALONE_TERNARY, // Return value of ternary expression is discarded WARNING_MAX, - } code; + }; + + Code code = WARNING_MAX; Vector<String> symbols; - int line; + int line = -1; String get_name() const; String get_message() const; static String get_name_from_code(Code p_code); static Code get_code_from_name(const String &p_name); - GDScriptWarning() : - code(WARNING_MAX), - line(-1) {} + GDScriptWarning() {} }; #endif // DEBUG_ENABLED diff --git a/modules/gdscript/gdscript_compiler.cpp b/modules/gdscript/gdscript_compiler.cpp index 2bbec29043..473b6fab05 100644 --- a/modules/gdscript/gdscript_compiler.cpp +++ b/modules/gdscript/gdscript_compiler.cpp @@ -177,16 +177,36 @@ int GDScriptCompiler::_parse_assign_right_expression(CodeGen &codegen, const GDS switch (p_expression->op) { - case GDScriptParser::OperatorNode::OP_ASSIGN_ADD: var_op = Variant::OP_ADD; break; - case GDScriptParser::OperatorNode::OP_ASSIGN_SUB: var_op = Variant::OP_SUBTRACT; break; - case GDScriptParser::OperatorNode::OP_ASSIGN_MUL: var_op = Variant::OP_MULTIPLY; break; - case GDScriptParser::OperatorNode::OP_ASSIGN_DIV: var_op = Variant::OP_DIVIDE; break; - case GDScriptParser::OperatorNode::OP_ASSIGN_MOD: var_op = Variant::OP_MODULE; break; - case GDScriptParser::OperatorNode::OP_ASSIGN_SHIFT_LEFT: var_op = Variant::OP_SHIFT_LEFT; break; - case GDScriptParser::OperatorNode::OP_ASSIGN_SHIFT_RIGHT: var_op = Variant::OP_SHIFT_RIGHT; break; - case GDScriptParser::OperatorNode::OP_ASSIGN_BIT_AND: var_op = Variant::OP_BIT_AND; break; - case GDScriptParser::OperatorNode::OP_ASSIGN_BIT_OR: var_op = Variant::OP_BIT_OR; break; - case GDScriptParser::OperatorNode::OP_ASSIGN_BIT_XOR: var_op = Variant::OP_BIT_XOR; break; + case GDScriptParser::OperatorNode::OP_ASSIGN_ADD: + var_op = Variant::OP_ADD; + break; + case GDScriptParser::OperatorNode::OP_ASSIGN_SUB: + var_op = Variant::OP_SUBTRACT; + break; + case GDScriptParser::OperatorNode::OP_ASSIGN_MUL: + var_op = Variant::OP_MULTIPLY; + break; + case GDScriptParser::OperatorNode::OP_ASSIGN_DIV: + var_op = Variant::OP_DIVIDE; + break; + case GDScriptParser::OperatorNode::OP_ASSIGN_MOD: + var_op = Variant::OP_MODULE; + break; + case GDScriptParser::OperatorNode::OP_ASSIGN_SHIFT_LEFT: + var_op = Variant::OP_SHIFT_LEFT; + break; + case GDScriptParser::OperatorNode::OP_ASSIGN_SHIFT_RIGHT: + var_op = Variant::OP_SHIFT_RIGHT; + break; + case GDScriptParser::OperatorNode::OP_ASSIGN_BIT_AND: + var_op = Variant::OP_BIT_AND; + break; + case GDScriptParser::OperatorNode::OP_ASSIGN_BIT_OR: + var_op = Variant::OP_BIT_OR; + break; + case GDScriptParser::OperatorNode::OP_ASSIGN_BIT_XOR: + var_op = Variant::OP_BIT_XOR; + break; case GDScriptParser::OperatorNode::OP_INIT_ASSIGN: case GDScriptParser::OperatorNode::OP_ASSIGN: { @@ -861,71 +881,92 @@ int GDScriptCompiler::_parse_expression(CodeGen &codegen, const GDScriptParser:: } break; //unary operators case GDScriptParser::OperatorNode::OP_NEG: { - if (!_create_unary_operator(codegen, on, Variant::OP_NEGATE, p_stack_level)) return -1; + if (!_create_unary_operator(codegen, on, Variant::OP_NEGATE, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_POS: { - if (!_create_unary_operator(codegen, on, Variant::OP_POSITIVE, p_stack_level)) return -1; + if (!_create_unary_operator(codegen, on, Variant::OP_POSITIVE, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_NOT: { - if (!_create_unary_operator(codegen, on, Variant::OP_NOT, p_stack_level)) return -1; + if (!_create_unary_operator(codegen, on, Variant::OP_NOT, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_BIT_INVERT: { - if (!_create_unary_operator(codegen, on, Variant::OP_BIT_NEGATE, p_stack_level)) return -1; + if (!_create_unary_operator(codegen, on, Variant::OP_BIT_NEGATE, p_stack_level)) + return -1; } break; //binary operators (in precedence order) case GDScriptParser::OperatorNode::OP_IN: { - if (!_create_binary_operator(codegen, on, Variant::OP_IN, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_IN, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_EQUAL: { - if (!_create_binary_operator(codegen, on, Variant::OP_EQUAL, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_EQUAL, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_NOT_EQUAL: { - if (!_create_binary_operator(codegen, on, Variant::OP_NOT_EQUAL, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_NOT_EQUAL, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_LESS: { - if (!_create_binary_operator(codegen, on, Variant::OP_LESS, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_LESS, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_LESS_EQUAL: { - if (!_create_binary_operator(codegen, on, Variant::OP_LESS_EQUAL, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_LESS_EQUAL, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_GREATER: { - if (!_create_binary_operator(codegen, on, Variant::OP_GREATER, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_GREATER, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_GREATER_EQUAL: { - if (!_create_binary_operator(codegen, on, Variant::OP_GREATER_EQUAL, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_GREATER_EQUAL, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_ADD: { - if (!_create_binary_operator(codegen, on, Variant::OP_ADD, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_ADD, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_SUB: { - if (!_create_binary_operator(codegen, on, Variant::OP_SUBTRACT, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_SUBTRACT, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_MUL: { - if (!_create_binary_operator(codegen, on, Variant::OP_MULTIPLY, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_MULTIPLY, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_DIV: { - if (!_create_binary_operator(codegen, on, Variant::OP_DIVIDE, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_DIVIDE, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_MOD: { - if (!_create_binary_operator(codegen, on, Variant::OP_MODULE, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_MODULE, p_stack_level)) + return -1; } break; //case GDScriptParser::OperatorNode::OP_SHIFT_LEFT: { if (!_create_binary_operator(codegen,on,Variant::OP_SHIFT_LEFT,p_stack_level)) return -1;} break; //case GDScriptParser::OperatorNode::OP_SHIFT_RIGHT: { if (!_create_binary_operator(codegen,on,Variant::OP_SHIFT_RIGHT,p_stack_level)) return -1;} break; case GDScriptParser::OperatorNode::OP_BIT_AND: { - if (!_create_binary_operator(codegen, on, Variant::OP_BIT_AND, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_BIT_AND, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_BIT_OR: { - if (!_create_binary_operator(codegen, on, Variant::OP_BIT_OR, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_BIT_OR, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_BIT_XOR: { - if (!_create_binary_operator(codegen, on, Variant::OP_BIT_XOR, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_BIT_XOR, p_stack_level)) + return -1; } break; //shift case GDScriptParser::OperatorNode::OP_SHIFT_LEFT: { - if (!_create_binary_operator(codegen, on, Variant::OP_SHIFT_LEFT, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_SHIFT_LEFT, p_stack_level)) + return -1; } break; case GDScriptParser::OperatorNode::OP_SHIFT_RIGHT: { - if (!_create_binary_operator(codegen, on, Variant::OP_SHIFT_RIGHT, p_stack_level)) return -1; + if (!_create_binary_operator(codegen, on, Variant::OP_SHIFT_RIGHT, p_stack_level)) + return -1; } break; //assignment operators case GDScriptParser::OperatorNode::OP_ASSIGN_ADD: diff --git a/modules/gdscript/gdscript_compiler.h b/modules/gdscript/gdscript_compiler.h index 34b066b5c7..08e1ec74af 100644 --- a/modules/gdscript/gdscript_compiler.h +++ b/modules/gdscript/gdscript_compiler.h @@ -123,10 +123,12 @@ class GDScriptCompiler { Vector<int> opcodes; void alloc_stack(int p_level) { - if (p_level >= stack_max) stack_max = p_level + 1; + if (p_level >= stack_max) + stack_max = p_level + 1; } void alloc_call(int p_params) { - if (p_params >= call_max) call_max = p_params; + if (p_params >= call_max) + call_max = p_params; } int current_line; diff --git a/modules/gdscript/gdscript_editor.cpp b/modules/gdscript/gdscript_editor.cpp index ab3228d076..56381e8af7 100644 --- a/modules/gdscript/gdscript_editor.cpp +++ b/modules/gdscript/gdscript_editor.cpp @@ -492,31 +492,24 @@ String GDScriptLanguage::make_function(const String &p_class, const String &p_na struct GDScriptCompletionContext { - const GDScriptParser::ClassNode *_class; - const GDScriptParser::FunctionNode *function; - const GDScriptParser::BlockNode *block; - Object *base; + const GDScriptParser::ClassNode *_class = nullptr; + const GDScriptParser::FunctionNode *function = nullptr; + const GDScriptParser::BlockNode *block = nullptr; + Object *base = nullptr; String base_path; - int line; - uint32_t depth; - - GDScriptCompletionContext() : - _class(nullptr), - function(nullptr), - block(nullptr), - base(nullptr), - line(0), - depth(0) {} + int line = 0; + uint32_t depth = 0; + + GDScriptCompletionContext() {} }; struct GDScriptCompletionIdentifier { GDScriptParser::DataType type; String enumeration; Variant value; - const GDScriptParser::Node *assigned_expression; + const GDScriptParser::Node *assigned_expression = nullptr; - GDScriptCompletionIdentifier() : - assigned_expression(nullptr) {} + GDScriptCompletionIdentifier() {} }; static void _get_directory_contents(EditorFileSystemDirectory *p_dir, Map<String, ScriptCodeCompletionOption> &r_list) { @@ -1082,16 +1075,36 @@ static bool _guess_expression_type(GDScriptCompletionContext &p_context, const G Variant::Operator vop = Variant::OP_MAX; switch (op->op) { - case GDScriptParser::OperatorNode::OP_ADD: vop = Variant::OP_ADD; break; - case GDScriptParser::OperatorNode::OP_SUB: vop = Variant::OP_SUBTRACT; break; - case GDScriptParser::OperatorNode::OP_MUL: vop = Variant::OP_MULTIPLY; break; - case GDScriptParser::OperatorNode::OP_DIV: vop = Variant::OP_DIVIDE; break; - case GDScriptParser::OperatorNode::OP_MOD: vop = Variant::OP_MODULE; break; - case GDScriptParser::OperatorNode::OP_SHIFT_LEFT: vop = Variant::OP_SHIFT_LEFT; break; - case GDScriptParser::OperatorNode::OP_SHIFT_RIGHT: vop = Variant::OP_SHIFT_RIGHT; break; - case GDScriptParser::OperatorNode::OP_BIT_AND: vop = Variant::OP_BIT_AND; break; - case GDScriptParser::OperatorNode::OP_BIT_OR: vop = Variant::OP_BIT_OR; break; - case GDScriptParser::OperatorNode::OP_BIT_XOR: vop = Variant::OP_BIT_XOR; break; + case GDScriptParser::OperatorNode::OP_ADD: + vop = Variant::OP_ADD; + break; + case GDScriptParser::OperatorNode::OP_SUB: + vop = Variant::OP_SUBTRACT; + break; + case GDScriptParser::OperatorNode::OP_MUL: + vop = Variant::OP_MULTIPLY; + break; + case GDScriptParser::OperatorNode::OP_DIV: + vop = Variant::OP_DIVIDE; + break; + case GDScriptParser::OperatorNode::OP_MOD: + vop = Variant::OP_MODULE; + break; + case GDScriptParser::OperatorNode::OP_SHIFT_LEFT: + vop = Variant::OP_SHIFT_LEFT; + break; + case GDScriptParser::OperatorNode::OP_SHIFT_RIGHT: + vop = Variant::OP_SHIFT_RIGHT; + break; + case GDScriptParser::OperatorNode::OP_BIT_AND: + vop = Variant::OP_BIT_AND; + break; + case GDScriptParser::OperatorNode::OP_BIT_OR: + vop = Variant::OP_BIT_OR; + break; + case GDScriptParser::OperatorNode::OP_BIT_XOR: + vop = Variant::OP_BIT_XOR; + break; default: { } } diff --git a/modules/gdscript/gdscript_function.cpp b/modules/gdscript/gdscript_function.cpp index df0fac956c..4e31ffe2a4 100644 --- a/modules/gdscript/gdscript_function.cpp +++ b/modules/gdscript/gdscript_function.cpp @@ -1289,7 +1289,7 @@ Variant GDScriptFunction::call(GDScriptInstance *p_instance, const Variant **p_a gdfs->state.instance = p_instance; p_instance->pending_func_states.add(&gdfs->instances_list); } else { - gdfs->state.instance = NULL; + gdfs->state.instance = nullptr; } } #ifdef DEBUG_ENABLED diff --git a/modules/gdscript/gdscript_function.h b/modules/gdscript/gdscript_function.h index d38b6d0739..7043c9b69b 100644 --- a/modules/gdscript/gdscript_function.h +++ b/modules/gdscript/gdscript_function.h @@ -43,20 +43,24 @@ class GDScriptInstance; class GDScript; struct GDScriptDataType { - bool has_type; - enum { + enum Kind { UNINITIALIZED, BUILTIN, NATIVE, SCRIPT, GDSCRIPT, - } kind; - Variant::Type builtin_type; + }; + + Kind kind = UNINITIALIZED; + + bool has_type = false; + Variant::Type builtin_type = Variant::NIL; StringName native_type; Ref<Script> script_type; bool is_type(const Variant &p_variant, bool p_allow_implicit_conversion = false) const { - if (!has_type) return true; // Can't type check + if (!has_type) + return true; // Can't type check switch (kind) { case UNINITIALIZED: @@ -146,10 +150,7 @@ struct GDScriptDataType { return info; } - GDScriptDataType() : - has_type(false), - kind(UNINITIALIZED), - builtin_type(Variant::NIL) {} + GDScriptDataType() {} }; class GDScriptFunction { diff --git a/modules/gdscript/gdscript_parser.cpp b/modules/gdscript/gdscript_parser.cpp index b3ec42a7fb..8cc6986f34 100644 --- a/modules/gdscript/gdscript_parser.cpp +++ b/modules/gdscript/gdscript_parser.cpp @@ -72,6 +72,16 @@ bool GDScriptParser::_end_statement() { return false; } +void GDScriptParser::_set_end_statement_error(String p_name) { + String error_msg; + if (tokenizer->get_token() == GDScriptTokenizer::TK_IDENTIFIER) { + error_msg = vformat("Expected end of statement (\"%s\"), got %s (\"%s\") instead.", p_name, tokenizer->get_token_name(tokenizer->get_token()), tokenizer->get_token_identifier()); + } else { + error_msg = vformat("Expected end of statement (\"%s\"), got %s instead.", p_name, tokenizer->get_token_name(tokenizer->get_token())); + } + _set_error(error_msg); +} + bool GDScriptParser::_enter_indent_block(BlockNode *p_block) { if (tokenizer->get_token() != GDScriptTokenizer::TK_COLON) { @@ -910,10 +920,18 @@ GDScriptParser::Node *GDScriptParser::_parse_expression(Node *p_parent, bool p_s e.is_op = true; switch (tokenizer->get_token()) { - case GDScriptTokenizer::TK_OP_ADD: e.op = OperatorNode::OP_POS; break; - case GDScriptTokenizer::TK_OP_SUB: e.op = OperatorNode::OP_NEG; break; - case GDScriptTokenizer::TK_OP_NOT: e.op = OperatorNode::OP_NOT; break; - case GDScriptTokenizer::TK_OP_BIT_INVERT: e.op = OperatorNode::OP_BIT_INVERT; break; + case GDScriptTokenizer::TK_OP_ADD: + e.op = OperatorNode::OP_POS; + break; + case GDScriptTokenizer::TK_OP_SUB: + e.op = OperatorNode::OP_NEG; + break; + case GDScriptTokenizer::TK_OP_NOT: + e.op = OperatorNode::OP_NOT; + break; + case GDScriptTokenizer::TK_OP_BIT_INVERT: + e.op = OperatorNode::OP_BIT_INVERT; + break; default: { } } @@ -1329,25 +1347,55 @@ GDScriptParser::Node *GDScriptParser::_parse_expression(Node *p_parent, bool p_s switch (tokenizer->get_token()) { //see operator - case GDScriptTokenizer::TK_OP_IN: op = OperatorNode::OP_IN; break; - case GDScriptTokenizer::TK_OP_EQUAL: op = OperatorNode::OP_EQUAL; break; - case GDScriptTokenizer::TK_OP_NOT_EQUAL: op = OperatorNode::OP_NOT_EQUAL; break; - case GDScriptTokenizer::TK_OP_LESS: op = OperatorNode::OP_LESS; break; - case GDScriptTokenizer::TK_OP_LESS_EQUAL: op = OperatorNode::OP_LESS_EQUAL; break; - case GDScriptTokenizer::TK_OP_GREATER: op = OperatorNode::OP_GREATER; break; - case GDScriptTokenizer::TK_OP_GREATER_EQUAL: op = OperatorNode::OP_GREATER_EQUAL; break; - case GDScriptTokenizer::TK_OP_AND: op = OperatorNode::OP_AND; break; - case GDScriptTokenizer::TK_OP_OR: op = OperatorNode::OP_OR; break; - case GDScriptTokenizer::TK_OP_ADD: op = OperatorNode::OP_ADD; break; - case GDScriptTokenizer::TK_OP_SUB: op = OperatorNode::OP_SUB; break; - case GDScriptTokenizer::TK_OP_MUL: op = OperatorNode::OP_MUL; break; - case GDScriptTokenizer::TK_OP_DIV: op = OperatorNode::OP_DIV; break; + case GDScriptTokenizer::TK_OP_IN: + op = OperatorNode::OP_IN; + break; + case GDScriptTokenizer::TK_OP_EQUAL: + op = OperatorNode::OP_EQUAL; + break; + case GDScriptTokenizer::TK_OP_NOT_EQUAL: + op = OperatorNode::OP_NOT_EQUAL; + break; + case GDScriptTokenizer::TK_OP_LESS: + op = OperatorNode::OP_LESS; + break; + case GDScriptTokenizer::TK_OP_LESS_EQUAL: + op = OperatorNode::OP_LESS_EQUAL; + break; + case GDScriptTokenizer::TK_OP_GREATER: + op = OperatorNode::OP_GREATER; + break; + case GDScriptTokenizer::TK_OP_GREATER_EQUAL: + op = OperatorNode::OP_GREATER_EQUAL; + break; + case GDScriptTokenizer::TK_OP_AND: + op = OperatorNode::OP_AND; + break; + case GDScriptTokenizer::TK_OP_OR: + op = OperatorNode::OP_OR; + break; + case GDScriptTokenizer::TK_OP_ADD: + op = OperatorNode::OP_ADD; + break; + case GDScriptTokenizer::TK_OP_SUB: + op = OperatorNode::OP_SUB; + break; + case GDScriptTokenizer::TK_OP_MUL: + op = OperatorNode::OP_MUL; + break; + case GDScriptTokenizer::TK_OP_DIV: + op = OperatorNode::OP_DIV; + break; case GDScriptTokenizer::TK_OP_MOD: op = OperatorNode::OP_MOD; break; //case GDScriptTokenizer::TK_OP_NEG: op=OperatorNode::OP_NEG ; break; - case GDScriptTokenizer::TK_OP_SHIFT_LEFT: op = OperatorNode::OP_SHIFT_LEFT; break; - case GDScriptTokenizer::TK_OP_SHIFT_RIGHT: op = OperatorNode::OP_SHIFT_RIGHT; break; + case GDScriptTokenizer::TK_OP_SHIFT_LEFT: + op = OperatorNode::OP_SHIFT_LEFT; + break; + case GDScriptTokenizer::TK_OP_SHIFT_RIGHT: + op = OperatorNode::OP_SHIFT_RIGHT; + break; case GDScriptTokenizer::TK_OP_ASSIGN: { _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN; @@ -1364,23 +1412,57 @@ GDScriptParser::Node *GDScriptParser::_parse_expression(Node *p_parent, bool p_s } } break; - case GDScriptTokenizer::TK_OP_ASSIGN_ADD: _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_ADD; break; - case GDScriptTokenizer::TK_OP_ASSIGN_SUB: _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_SUB; break; - case GDScriptTokenizer::TK_OP_ASSIGN_MUL: _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_MUL; break; - case GDScriptTokenizer::TK_OP_ASSIGN_DIV: _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_DIV; break; - case GDScriptTokenizer::TK_OP_ASSIGN_MOD: _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_MOD; break; - case GDScriptTokenizer::TK_OP_ASSIGN_SHIFT_LEFT: _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_SHIFT_LEFT; break; - case GDScriptTokenizer::TK_OP_ASSIGN_SHIFT_RIGHT: _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_SHIFT_RIGHT; break; - case GDScriptTokenizer::TK_OP_ASSIGN_BIT_AND: _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_BIT_AND; break; - case GDScriptTokenizer::TK_OP_ASSIGN_BIT_OR: _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_BIT_OR; break; - case GDScriptTokenizer::TK_OP_ASSIGN_BIT_XOR: _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_BIT_XOR; break; - case GDScriptTokenizer::TK_OP_BIT_AND: op = OperatorNode::OP_BIT_AND; break; - case GDScriptTokenizer::TK_OP_BIT_OR: op = OperatorNode::OP_BIT_OR; break; - case GDScriptTokenizer::TK_OP_BIT_XOR: op = OperatorNode::OP_BIT_XOR; break; - case GDScriptTokenizer::TK_PR_IS: op = OperatorNode::OP_IS; break; - case GDScriptTokenizer::TK_CF_IF: op = OperatorNode::OP_TERNARY_IF; break; - case GDScriptTokenizer::TK_CF_ELSE: op = OperatorNode::OP_TERNARY_ELSE; break; - default: valid = false; break; + case GDScriptTokenizer::TK_OP_ASSIGN_ADD: + _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_ADD; + break; + case GDScriptTokenizer::TK_OP_ASSIGN_SUB: + _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_SUB; + break; + case GDScriptTokenizer::TK_OP_ASSIGN_MUL: + _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_MUL; + break; + case GDScriptTokenizer::TK_OP_ASSIGN_DIV: + _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_DIV; + break; + case GDScriptTokenizer::TK_OP_ASSIGN_MOD: + _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_MOD; + break; + case GDScriptTokenizer::TK_OP_ASSIGN_SHIFT_LEFT: + _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_SHIFT_LEFT; + break; + case GDScriptTokenizer::TK_OP_ASSIGN_SHIFT_RIGHT: + _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_SHIFT_RIGHT; + break; + case GDScriptTokenizer::TK_OP_ASSIGN_BIT_AND: + _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_BIT_AND; + break; + case GDScriptTokenizer::TK_OP_ASSIGN_BIT_OR: + _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_BIT_OR; + break; + case GDScriptTokenizer::TK_OP_ASSIGN_BIT_XOR: + _VALIDATE_ASSIGN op = OperatorNode::OP_ASSIGN_BIT_XOR; + break; + case GDScriptTokenizer::TK_OP_BIT_AND: + op = OperatorNode::OP_BIT_AND; + break; + case GDScriptTokenizer::TK_OP_BIT_OR: + op = OperatorNode::OP_BIT_OR; + break; + case GDScriptTokenizer::TK_OP_BIT_XOR: + op = OperatorNode::OP_BIT_XOR; + break; + case GDScriptTokenizer::TK_PR_IS: + op = OperatorNode::OP_IS; + break; + case GDScriptTokenizer::TK_CF_IF: + op = OperatorNode::OP_TERNARY_IF; + break; + case GDScriptTokenizer::TK_CF_ELSE: + op = OperatorNode::OP_TERNARY_ELSE; + break; + default: + valid = false; + break; } if (valid) { @@ -1433,36 +1515,74 @@ GDScriptParser::Node *GDScriptParser::_parse_expression(Node *p_parent, bool p_s unary = true; break; - case OperatorNode::OP_MUL: priority = 2; break; - case OperatorNode::OP_DIV: priority = 2; break; - case OperatorNode::OP_MOD: priority = 2; break; + case OperatorNode::OP_MUL: + priority = 2; + break; + case OperatorNode::OP_DIV: + priority = 2; + break; + case OperatorNode::OP_MOD: + priority = 2; + break; - case OperatorNode::OP_ADD: priority = 3; break; - case OperatorNode::OP_SUB: priority = 3; break; + case OperatorNode::OP_ADD: + priority = 3; + break; + case OperatorNode::OP_SUB: + priority = 3; + break; - case OperatorNode::OP_SHIFT_LEFT: priority = 4; break; - case OperatorNode::OP_SHIFT_RIGHT: priority = 4; break; + case OperatorNode::OP_SHIFT_LEFT: + priority = 4; + break; + case OperatorNode::OP_SHIFT_RIGHT: + priority = 4; + break; - case OperatorNode::OP_BIT_AND: priority = 5; break; - case OperatorNode::OP_BIT_XOR: priority = 6; break; - case OperatorNode::OP_BIT_OR: priority = 7; break; + case OperatorNode::OP_BIT_AND: + priority = 5; + break; + case OperatorNode::OP_BIT_XOR: + priority = 6; + break; + case OperatorNode::OP_BIT_OR: + priority = 7; + break; - case OperatorNode::OP_LESS: priority = 8; break; - case OperatorNode::OP_LESS_EQUAL: priority = 8; break; - case OperatorNode::OP_GREATER: priority = 8; break; - case OperatorNode::OP_GREATER_EQUAL: priority = 8; break; + case OperatorNode::OP_LESS: + priority = 8; + break; + case OperatorNode::OP_LESS_EQUAL: + priority = 8; + break; + case OperatorNode::OP_GREATER: + priority = 8; + break; + case OperatorNode::OP_GREATER_EQUAL: + priority = 8; + break; - case OperatorNode::OP_EQUAL: priority = 8; break; - case OperatorNode::OP_NOT_EQUAL: priority = 8; break; + case OperatorNode::OP_EQUAL: + priority = 8; + break; + case OperatorNode::OP_NOT_EQUAL: + priority = 8; + break; - case OperatorNode::OP_IN: priority = 10; break; + case OperatorNode::OP_IN: + priority = 10; + break; case OperatorNode::OP_NOT: priority = 11; unary = true; break; - case OperatorNode::OP_AND: priority = 12; break; - case OperatorNode::OP_OR: priority = 13; break; + case OperatorNode::OP_AND: + priority = 12; + break; + case OperatorNode::OP_OR: + priority = 13; + break; case OperatorNode::OP_TERNARY_IF: priority = 14; @@ -1475,17 +1595,39 @@ GDScriptParser::Node *GDScriptParser::_parse_expression(Node *p_parent, bool p_s // Rigth-to-left should be false in this case, otherwise it would always error. break; - case OperatorNode::OP_ASSIGN: priority = 15; break; - case OperatorNode::OP_ASSIGN_ADD: priority = 15; break; - case OperatorNode::OP_ASSIGN_SUB: priority = 15; break; - case OperatorNode::OP_ASSIGN_MUL: priority = 15; break; - case OperatorNode::OP_ASSIGN_DIV: priority = 15; break; - case OperatorNode::OP_ASSIGN_MOD: priority = 15; break; - case OperatorNode::OP_ASSIGN_SHIFT_LEFT: priority = 15; break; - case OperatorNode::OP_ASSIGN_SHIFT_RIGHT: priority = 15; break; - case OperatorNode::OP_ASSIGN_BIT_AND: priority = 15; break; - case OperatorNode::OP_ASSIGN_BIT_OR: priority = 15; break; - case OperatorNode::OP_ASSIGN_BIT_XOR: priority = 15; break; + case OperatorNode::OP_ASSIGN: + priority = 15; + break; + case OperatorNode::OP_ASSIGN_ADD: + priority = 15; + break; + case OperatorNode::OP_ASSIGN_SUB: + priority = 15; + break; + case OperatorNode::OP_ASSIGN_MUL: + priority = 15; + break; + case OperatorNode::OP_ASSIGN_DIV: + priority = 15; + break; + case OperatorNode::OP_ASSIGN_MOD: + priority = 15; + break; + case OperatorNode::OP_ASSIGN_SHIFT_LEFT: + priority = 15; + break; + case OperatorNode::OP_ASSIGN_SHIFT_RIGHT: + priority = 15; + break; + case OperatorNode::OP_ASSIGN_BIT_AND: + priority = 15; + break; + case OperatorNode::OP_ASSIGN_BIT_OR: + priority = 15; + break; + case OperatorNode::OP_ASSIGN_BIT_XOR: + priority = 15; + break; default: { _set_error("GDScriptParser bug, invalid operator in expression: " + itos(expression[i].op)); @@ -2037,7 +2179,8 @@ bool GDScriptParser::_reduce_export_var_type(Variant &p_value, int p_line) { if (p_value.get_type() == Variant::ARRAY) { Array arr = p_value; for (int i = 0; i < arr.size(); i++) { - if (!_reduce_export_var_type(arr[i], p_line)) return false; + if (!_reduce_export_var_type(arr[i], p_line)) + return false; } return true; } @@ -2046,7 +2189,8 @@ bool GDScriptParser::_reduce_export_var_type(Variant &p_value, int p_line) { Dictionary dict = p_value; for (int i = 0; i < dict.size(); i++) { Variant value = dict.get_value_at_index(i); - if (!_reduce_export_var_type(value, p_line)) return false; + if (!_reduce_export_var_type(value, p_line)) + return false; } return true; } @@ -2938,7 +3082,7 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { lv->assign = assigned; if (!_end_statement()) { - _set_error("Expected end of statement (\"var\")."); + _set_end_statement_error("var"); return; } @@ -3160,6 +3304,8 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { ConstantNode *c = static_cast<ConstantNode *>(op->arguments[i]); if (c->value.get_type() == Variant::FLOAT || c->value.get_type() == Variant::INT) { constants.push_back(c->value); + } else { + constant = false; } } else { constant = false; @@ -3172,9 +3318,15 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { ConstantNode *cn = alloc_node<ConstantNode>(); switch (args.size()) { - case 1: cn->value = (int)constants[0]; break; - case 2: cn->value = Vector2(constants[0], constants[1]); break; - case 3: cn->value = Vector3(constants[0], constants[1], constants[2]); break; + case 1: + cn->value = (int64_t)constants[0]; + break; + case 2: + cn->value = Vector2i(constants[0], constants[1]); + break; + case 3: + cn->value = Vector3i(constants[0], constants[1], constants[2]); + break; } cn->datatype = _type_from_variant(cn->value); container = cn; @@ -3186,9 +3338,15 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { on->arguments.push_back(tn); switch (args.size()) { - case 1: tn->vtype = Variant::INT; break; - case 2: tn->vtype = Variant::VECTOR2; break; - case 3: tn->vtype = Variant::VECTOR3; break; + case 1: + tn->vtype = Variant::INT; + break; + case 2: + tn->vtype = Variant::VECTOR2I; + break; + case 3: + tn->vtype = Variant::VECTOR3I; + break; } for (int i = 0; i < args.size(); i++) { @@ -3249,7 +3407,7 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { cf_continue->cf_type = ControlFlowNode::CF_CONTINUE; p_block->statements.push_back(cf_continue); if (!_end_statement()) { - _set_error("Expected end of statement (\"continue\")."); + _set_end_statement_error("continue"); return; } } break; @@ -3261,7 +3419,7 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { cf_break->cf_type = ControlFlowNode::CF_BREAK; p_block->statements.push_back(cf_break); if (!_end_statement()) { - _set_error("Expected end of statement (\"break\")."); + _set_end_statement_error("break"); return; } } break; @@ -3290,7 +3448,7 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { cf_return->arguments.push_back(retexpr); p_block->statements.push_back(cf_return); if (!_end_statement()) { - _set_error("Expected end of statement after return expression."); + _set_end_statement_error("return"); return; } } @@ -3327,7 +3485,8 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { _parse_pattern_block(compiled_branches, match_node->branches, p_static); - if (error_set) return; + if (error_set) + return; ControlFlowNode *match_cf_node = alloc_node<ControlFlowNode>(); match_cf_node->cf_type = ControlFlowNode::CF_MATCH; @@ -3379,7 +3538,7 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { p_block->statements.push_back(an); if (!_end_statement()) { - _set_error("Expected end of statement after \"assert\".", assert_line); + _set_end_statement_error("assert"); return; } } break; @@ -3390,7 +3549,7 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { p_block->statements.push_back(bn); if (!_end_statement()) { - _set_error("Expected end of statement after \"breakpoint\"."); + _set_end_statement_error("breakpoint"); return; } } break; @@ -3409,7 +3568,7 @@ void GDScriptParser::_parse_block(BlockNode *p_block, bool p_static) { if (tokenizer->get_token() == GDScriptTokenizer::TK_COLON && tokenizer->get_token(1) == GDScriptTokenizer::TK_OP_ASSIGN) { _set_error("Unexpected ':=', use '=' instead. Expected end of statement after expression."); } else { - _set_error(String() + "Expected end of statement after expression, got " + tokenizer->get_token_name(tokenizer->get_token()) + " instead"); + _set_error(vformat("Expected end of statement after expression, got %s instead.", tokenizer->get_token_name(tokenizer->get_token()))); } return; } @@ -3599,7 +3758,7 @@ void GDScriptParser::_parse_class(ClassNode *p_class) { if (error_set) return; if (!_end_statement()) { - _set_error("Expected end of statement after \"extends\"."); + _set_end_statement_error("extends"); return; } @@ -4104,7 +4263,7 @@ void GDScriptParser::_parse_class(ClassNode *p_class) { p_class->_signals.push_back(sig); if (!_end_statement()) { - _set_error("Expected end of statement (\"signal\")."); + _set_end_statement_error("signal"); return; } } break; @@ -4924,7 +5083,8 @@ void GDScriptParser::_parse_class(ClassNode *p_class) { return; } - if (!_reduce_export_var_type(cn->value, member.line)) return; + if (!_reduce_export_var_type(cn->value, member.line)) + return; member._export.type = cn->value.get_type(); member._export.usage |= PROPERTY_USAGE_SCRIPT_VARIABLE; @@ -5047,7 +5207,7 @@ void GDScriptParser::_parse_class(ClassNode *p_class) { p_class->variables.push_back(member); if (!_end_statement()) { - _set_error("Expected end of statement (\"continue\")."); + _set_end_statement_error("var"); return; } } break; @@ -5127,7 +5287,7 @@ void GDScriptParser::_parse_class(ClassNode *p_class) { p_class->constant_expressions.insert(const_id, constant); if (!_end_statement()) { - _set_error("Expected end of statement (constant).", line); + _set_end_statement_error("const"); return; } @@ -5281,7 +5441,7 @@ void GDScriptParser::_parse_class(ClassNode *p_class) { } if (!_end_statement()) { - _set_error("Expected end of statement (\"enum\")."); + _set_end_statement_error("enum"); return; } @@ -5441,8 +5601,10 @@ void GDScriptParser::_determine_inheritance(ClassNode *p_class, bool p_recursive } } - if (base_class) break; - if (found) continue; + if (base_class) + break; + if (found) + continue; if (p->constant_expressions.has(base)) { if (p->constant_expressions[base].expression->type != Node::TYPE_CONSTANT) { @@ -5544,10 +5706,12 @@ void GDScriptParser::_determine_inheritance(ClassNode *p_class, bool p_recursive } String GDScriptParser::DataType::to_string() const { - if (!has_type) return "var"; + if (!has_type) + return "var"; switch (kind) { case BUILTIN: { - if (builtin_type == Variant::NIL) return "null"; + if (builtin_type == Variant::NIL) + return "null"; return Variant::get_type_name(builtin_type); } break; case NATIVE: { @@ -5711,8 +5875,10 @@ bool GDScriptParser::_parse_type(DataType &r_type, bool p_can_be_void) { } GDScriptParser::DataType GDScriptParser::_resolve_type(const DataType &p_source, int p_line) { - if (!p_source.has_type) return p_source; - if (p_source.kind != DataType::UNRESOLVED) return p_source; + if (!p_source.has_type) + return p_source; + if (p_source.kind != DataType::UNRESOLVED) + return p_source; Vector<String> full_name = p_source.native_type.operator String().split(".", false); int name_part = 0; @@ -6952,7 +7118,8 @@ bool GDScriptParser::_get_function_signature(DataType &p_base_type, const String native = "_" + native.operator String(); } if (!ClassDB::class_exists(native)) { - if (!check_types) return false; + if (!check_types) + return false; ERR_FAIL_V_MSG(false, "Parser bug: Class '" + String(native) + "' not found."); } @@ -7043,7 +7210,8 @@ GDScriptParser::DataType GDScriptParser::_reduce_function_call_type(const Operat par_types.write[i - 1] = _reduce_node_type(p_call->arguments[i]); } - if (error_set) return DataType(); + if (error_set) + return DataType(); // Special case: check copy constructor. Those are defined implicitly in Variant. if (par_types.size() == 1) { @@ -7111,7 +7279,8 @@ GDScriptParser::DataType GDScriptParser::_reduce_function_call_type(const Operat err += "' matches the signature '"; err += Variant::get_type_name(tn->vtype) + "("; for (int i = 0; i < par_types.size(); i++) { - if (i > 0) err += ", "; + if (i > 0) + err += ", "; err += par_types[i].to_string(); } err += ")'."; @@ -7340,7 +7509,7 @@ GDScriptParser::DataType GDScriptParser::_reduce_function_call_type(const Operat return return_type; } -bool GDScriptParser::_get_member_type(const DataType &p_base_type, const StringName &p_member, DataType &r_member_type) const { +bool GDScriptParser::_get_member_type(const DataType &p_base_type, const StringName &p_member, DataType &r_member_type, bool *r_is_const) const { DataType base_type = p_base_type; // Check classes in current file @@ -7351,6 +7520,8 @@ bool GDScriptParser::_get_member_type(const DataType &p_base_type, const StringN while (base) { if (base->constant_expressions.has(p_member)) { + if (r_is_const) + *r_is_const = true; r_member_type = base->constant_expressions[p_member].expression->get_datatype(); return true; } @@ -7469,7 +7640,8 @@ bool GDScriptParser::_get_member_type(const DataType &p_base_type, const StringN native = "_" + native.operator String(); } if (!ClassDB::class_exists(native)) { - if (!check_types) return false; + if (!check_types) + return false; ERR_FAIL_V_MSG(false, "Parser bug: Class \"" + String(native) + "\" not found."); } @@ -7574,8 +7746,9 @@ GDScriptParser::DataType GDScriptParser::_reduce_identifier_type(const DataType base_type = DataType(*p_base_type); } - if (_get_member_type(base_type, p_identifier, member_type)) { - if (!p_base_type && current_function && current_function->_static) { + bool is_const = false; + if (_get_member_type(base_type, p_identifier, member_type, &is_const)) { + if (!p_base_type && current_function && current_function->_static && !is_const) { _set_error("Can't access member variable (\"" + p_identifier.operator String() + "\") from a static function.", p_line); return DataType(); } @@ -7766,12 +7939,14 @@ void GDScriptParser::_check_class_level_types(ClassNode *p_class) { // Function declarations for (int i = 0; i < p_class->static_functions.size(); i++) { _check_function_types(p_class->static_functions[i]); - if (error_set) return; + if (error_set) + return; } for (int i = 0; i < p_class->functions.size(); i++) { _check_function_types(p_class->functions[i]); - if (error_set) return; + if (error_set) + return; } // Class variables @@ -7786,6 +7961,7 @@ void GDScriptParser::_check_class_level_types(ClassNode *p_class) { _mark_line_as_safe(v.line); v.data_type = _resolve_type(v.data_type, v.line); + v.initial_assignment->arguments[0]->set_datatype(v.data_type); if (v.expression) { DataType expr_type = _reduce_node_type(v.expression); @@ -7810,7 +7986,7 @@ void GDScriptParser::_check_class_level_types(ClassNode *p_class) { ConstantNode *tgt_type = alloc_node<ConstantNode>(); tgt_type->line = v.line; - tgt_type->value = (int)v.data_type.builtin_type; + tgt_type->value = (int64_t)v.data_type.builtin_type; OperatorNode *convert_call = alloc_node<OperatorNode>(); convert_call->line = v.line; @@ -7846,7 +8022,8 @@ void GDScriptParser::_check_class_level_types(ClassNode *p_class) { } // Setter and getter - if (v.setter == StringName() && v.getter == StringName()) continue; + if (v.setter == StringName() && v.getter == StringName()) + continue; bool found_getter = false; bool found_setter = false; @@ -7889,10 +8066,12 @@ void GDScriptParser::_check_class_level_types(ClassNode *p_class) { return; } } - if (found_getter && found_setter) break; + if (found_getter && found_setter) + break; } - if ((found_getter || v.getter == StringName()) && (found_setter || v.setter == StringName())) continue; + if ((found_getter || v.getter == StringName()) && (found_setter || v.setter == StringName())) + continue; // Check for static functions for (int j = 0; j < p_class->static_functions.size(); j++) { @@ -7923,7 +8102,8 @@ void GDScriptParser::_check_class_level_types(ClassNode *p_class) { for (int i = 0; i < p_class->subclasses.size(); i++) { current_class = p_class->subclasses[i]; _check_class_level_types(current_class); - if (error_set) return; + if (error_set) + return; current_class = p_class; } } @@ -8060,7 +8240,8 @@ void GDScriptParser::_check_class_blocks_types(ClassNode *p_class) { _check_block_types(current_block); current_block = nullptr; current_function = nullptr; - if (error_set) return; + if (error_set) + return; } for (int i = 0; i < p_class->functions.size(); i++) { @@ -8070,7 +8251,8 @@ void GDScriptParser::_check_class_blocks_types(ClassNode *p_class) { _check_block_types(current_block); current_block = nullptr; current_function = nullptr; - if (error_set) return; + if (error_set) + return; } #ifdef DEBUG_ENABLED @@ -8091,7 +8273,8 @@ void GDScriptParser::_check_class_blocks_types(ClassNode *p_class) { for (int i = 0; i < p_class->subclasses.size(); i++) { current_class = p_class->subclasses[i]; _check_class_blocks_types(current_class); - if (error_set) return; + if (error_set) + return; current_class = p_class; } } @@ -8181,7 +8364,7 @@ void GDScriptParser::_check_block_types(BlockNode *p_block) { ConstantNode *tgt_type = alloc_node<ConstantNode>(); tgt_type->line = lv->line; - tgt_type->value = (int)lv->datatype.builtin_type; + tgt_type->value = (int64_t)lv->datatype.builtin_type; tgt_type->datatype = _type_from_variant(tgt_type->value); OperatorNode *convert_call = alloc_node<OperatorNode>(); @@ -8359,7 +8542,8 @@ void GDScriptParser::_check_block_types(BlockNode *p_block) { _add_warning(GDScriptWarning::RETURN_VALUE_DISCARDED, op->line, func_name); } #endif // DEBUG_ENABLED - if (error_set) return; + if (error_set) + return; } break; case OperatorNode::OP_YIELD: { _mark_line_as_safe(op->line); @@ -8394,7 +8578,8 @@ void GDScriptParser::_check_block_types(BlockNode *p_block) { } } - if (!function_type.has_type) break; + if (!function_type.has_type) + break; if (function_type.kind == DataType::BUILTIN && function_type.builtin_type == Variant::NIL) { // Return void, should not have arguments @@ -8454,7 +8639,8 @@ void GDScriptParser::_check_block_types(BlockNode *p_block) { current_block = p_block->sub_blocks[i]; _check_block_types(current_block); current_block = p_block; - if (error_set) return; + if (error_set) + return; } #ifdef DEBUG_ENABLED @@ -8597,7 +8783,8 @@ Error GDScriptParser::_parse(const String &p_base_path) { current_function = nullptr; current_block = nullptr; - if (for_completion) check_types = false; + if (for_completion) + check_types = false; // Resolve all class-level stuff before getting into function blocks _check_class_level_types(main_class); @@ -8786,7 +8973,7 @@ int GDScriptParser::get_completion_argument_index() { return completion_argument; } -int GDScriptParser::get_completion_identifier_is_function() { +bool GDScriptParser::get_completion_identifier_is_function() { return completion_ident_is_call; } diff --git a/modules/gdscript/gdscript_parser.h b/modules/gdscript/gdscript_parser.h index f254352423..035af30b6a 100644 --- a/modules/gdscript/gdscript_parser.h +++ b/modules/gdscript/gdscript_parser.h @@ -45,25 +45,27 @@ public: struct ClassNode; struct DataType { - enum { + enum Kind { BUILTIN, NATIVE, SCRIPT, GDSCRIPT, CLASS, UNRESOLVED - } kind; + }; + + Kind kind = UNRESOLVED; - bool has_type; - bool is_constant; - bool is_meta_type; // Whether the value can be used as a type - bool infer_type; - bool may_yield; // For function calls + bool has_type = false; + bool is_constant = false; + bool is_meta_type = false; // Whether the value can be used as a type + bool infer_type = false; + bool may_yield = false; // For function calls - Variant::Type builtin_type; + Variant::Type builtin_type = Variant::NIL; StringName native_type; Ref<Script> script_type; - ClassNode *class_type; + ClassNode *class_type = nullptr; String to_string() const; @@ -94,15 +96,7 @@ public: return false; } - DataType() : - kind(UNRESOLVED), - has_type(false), - is_constant(false), - is_meta_type(false), - infer_type(false), - may_yield(false), - builtin_type(Variant::NIL), - class_type(nullptr) {} + DataType() {} }; struct Node { @@ -236,66 +230,63 @@ public: struct BlockNode : public Node { - ClassNode *parent_class; - BlockNode *parent_block; + ClassNode *parent_class = nullptr; + BlockNode *parent_block = nullptr; List<Node *> statements; Map<StringName, LocalVarNode *> variables; - bool has_return; + bool has_return = false; - Node *if_condition; //tiny hack to improve code completion on if () blocks + Node *if_condition = nullptr; //tiny hack to improve code completion on if () blocks //the following is useful for code completion List<BlockNode *> sub_blocks; - int end_line; + int end_line = -1; + BlockNode() { - if_condition = nullptr; type = TYPE_BLOCK; - end_line = -1; - parent_block = nullptr; - parent_class = nullptr; - has_return = false; } }; struct TypeNode : public Node { - Variant::Type vtype; - TypeNode() { type = TYPE_TYPE; } + + TypeNode() { + type = TYPE_TYPE; + } }; + struct BuiltInFunctionNode : public Node { GDScriptFunctions::Function function; - BuiltInFunctionNode() { type = TYPE_BUILT_IN_FUNCTION; } + + BuiltInFunctionNode() { + type = TYPE_BUILT_IN_FUNCTION; + } }; struct IdentifierNode : public Node { - StringName name; - BlockNode *declared_block; // Simplify lookup by checking if it is declared locally + BlockNode *declared_block = nullptr; // Simplify lookup by checking if it is declared locally DataType datatype; virtual DataType get_datatype() const { return datatype; } virtual void set_datatype(const DataType &p_datatype) { datatype = p_datatype; } + IdentifierNode() { type = TYPE_IDENTIFIER; - declared_block = nullptr; } }; struct LocalVarNode : public Node { - StringName name; - Node *assign; - OperatorNode *assign_op; - int assignments; - int usages; + Node *assign = nullptr; + OperatorNode *assign_op = nullptr; + int assignments = 0; + int usages = 0; DataType datatype; virtual DataType get_datatype() const { return datatype; } virtual void set_datatype(const DataType &p_datatype) { datatype = p_datatype; } + LocalVarNode() { type = TYPE_LOCAL_VAR; - assign = nullptr; - assign_op = nullptr; - assignments = 0; - usages = 0; } }; @@ -304,15 +295,18 @@ public: DataType datatype; virtual DataType get_datatype() const { return datatype; } virtual void set_datatype(const DataType &p_datatype) { datatype = p_datatype; } - ConstantNode() { type = TYPE_CONSTANT; } + + ConstantNode() { + type = TYPE_CONSTANT; + } }; struct ArrayNode : public Node { - Vector<Node *> elements; DataType datatype; virtual DataType get_datatype() const { return datatype; } virtual void set_datatype(const DataType &p_datatype) { datatype = p_datatype; } + ArrayNode() { type = TYPE_ARRAY; datatype.has_type = true; @@ -324,7 +318,6 @@ public: struct DictionaryNode : public Node { struct Pair { - Node *key; Node *value; }; @@ -333,6 +326,7 @@ public: DataType datatype; virtual DataType get_datatype() const { return datatype; } virtual void set_datatype(const DataType &p_datatype) { datatype = p_datatype; } + DictionaryNode() { type = TYPE_DICTIONARY; datatype.has_type = true; @@ -342,7 +336,9 @@ public: }; struct SelfNode : public Node { - SelfNode() { type = TYPE_SELF; } + SelfNode() { + type = TYPE_SELF; + } }; struct OperatorNode : public Node { @@ -404,7 +400,9 @@ public: DataType datatype; virtual DataType get_datatype() const { return datatype; } virtual void set_datatype(const DataType &p_datatype) { datatype = p_datatype; } - OperatorNode() { type = TYPE_OPERATOR; } + OperatorNode() { + type = TYPE_OPERATOR; + } }; struct PatternNode : public Node { @@ -454,19 +452,17 @@ public: CF_MATCH }; - CFType cf_type; + CFType cf_type = CF_IF; Vector<Node *> arguments; - BlockNode *body; - BlockNode *body_else; + BlockNode *body = nullptr; + BlockNode *body_else = nullptr; MatchNode *match; ControlFlowNode *_else; //used for if + ControlFlowNode() { type = TYPE_CONTROL_FLOW; - cf_type = CF_IF; - body = nullptr; - body_else = nullptr; } }; @@ -476,29 +472,34 @@ public: DataType return_type; virtual DataType get_datatype() const { return return_type; } virtual void set_datatype(const DataType &p_datatype) { return_type = p_datatype; } - CastNode() { type = TYPE_CAST; } + + CastNode() { + type = TYPE_CAST; + } }; struct AssertNode : public Node { - Node *condition; - Node *message; - AssertNode() : - condition(0), - message(0) { + Node *condition = nullptr; + Node *message = nullptr; + + AssertNode() { type = TYPE_ASSERT; } }; struct BreakpointNode : public Node { - BreakpointNode() { type = TYPE_BREAKPOINT; } + BreakpointNode() { + type = TYPE_BREAKPOINT; + } }; struct NewLineNode : public Node { - NewLineNode() { type = TYPE_NEWLINE; } + NewLineNode() { + type = TYPE_NEWLINE; + } }; struct Expression { - bool is_op; union { OperatorNode::Operator op; @@ -553,8 +554,8 @@ private: int pending_newline; struct IndentLevel { - int indent; - int tabs; + int indent = 0; + int tabs = 0; bool is_mixed(IndentLevel other) { return ( @@ -563,9 +564,7 @@ private: (indent < other.indent && tabs > other.tabs)); } - IndentLevel() : - indent(0), - tabs(0) {} + IndentLevel() {} IndentLevel(int p_indent, int p_tabs) : indent(p_indent), @@ -624,6 +623,7 @@ private: void _parse_extends(ClassNode *p_class); void _parse_class(ClassNode *p_class); bool _end_statement(); + void _set_end_statement_error(String p_name); void _determine_inheritance(ClassNode *p_class, bool p_recursive = true); bool _parse_type(DataType &r_type, bool p_can_be_void = false); @@ -634,7 +634,7 @@ private: DataType _get_operation_type(const Variant::Operator p_op, const DataType &p_a, const DataType &p_b, bool &r_valid) const; Variant::Operator _get_variant_operation(const OperatorNode::Operator &p_op) const; bool _get_function_signature(DataType &p_base_type, const StringName &p_function, DataType &r_return_type, List<DataType> &r_arg_types, int &r_default_arg_count, bool &r_static, bool &r_vararg) const; - bool _get_member_type(const DataType &p_base_type, const StringName &p_member, DataType &r_member_type) const; + bool _get_member_type(const DataType &p_base_type, const StringName &p_member, DataType &r_member_type, bool *r_is_const = nullptr) const; bool _is_type_compatible(const DataType &p_container, const DataType &p_expression, bool p_allow_implicit_conversion = false) const; Node *_get_default_value_for_type(const DataType &p_type, int p_line = -1); @@ -647,12 +647,14 @@ private: void _check_block_types(BlockNode *p_block); _FORCE_INLINE_ void _mark_line_as_safe(int p_line) const { #ifdef DEBUG_ENABLED - if (safe_lines) safe_lines->insert(p_line); + if (safe_lines) + safe_lines->insert(p_line); #endif // DEBUG_ENABLED } _FORCE_INLINE_ void _mark_line_as_unsafe(int p_line) const { #ifdef DEBUG_ENABLED - if (safe_lines) safe_lines->erase(p_line); + if (safe_lines) + safe_lines->erase(p_line); #endif // DEBUG_ENABLED } @@ -683,7 +685,7 @@ public: BlockNode *get_completion_block(); FunctionNode *get_completion_function(); int get_completion_argument_index(); - int get_completion_identifier_is_function(); + bool get_completion_identifier_is_function(); const List<String> &get_dependencies() const { return dependencies; } diff --git a/modules/gdscript/gdscript_tokenizer.cpp b/modules/gdscript/gdscript_tokenizer.cpp index 76f42ead5f..1c8282e13e 100644 --- a/modules/gdscript/gdscript_tokenizer.cpp +++ b/modules/gdscript/gdscript_tokenizer.cpp @@ -803,16 +803,36 @@ void GDScriptTokenizerText::_advance() { switch (next) { - case 'a': res = '\a'; break; - case 'b': res = '\b'; break; - case 't': res = '\t'; break; - case 'n': res = '\n'; break; - case 'v': res = '\v'; break; - case 'f': res = '\f'; break; - case 'r': res = '\r'; break; - case '\'': res = '\''; break; - case '\"': res = '\"'; break; - case '\\': res = '\\'; break; + case 'a': + res = '\a'; + break; + case 'b': + res = '\b'; + break; + case 't': + res = '\t'; + break; + case 'n': + res = '\n'; + break; + case 'v': + res = '\v'; + break; + case 'f': + res = '\f'; + break; + case 'r': + res = '\r'; + break; + case '\'': + res = '\''; + break; + case '\"': + res = '\"'; + break; + case '\\': + res = '\\'; + break; case 'u': { // hex number diff --git a/modules/gdscript/gdscript_tokenizer.h b/modules/gdscript/gdscript_tokenizer.h index 180ec3c77e..76410433de 100644 --- a/modules/gdscript/gdscript_tokenizer.h +++ b/modules/gdscript/gdscript_tokenizer.h @@ -176,7 +176,7 @@ public: virtual bool is_ignoring_warnings() const = 0; #endif // DEBUG_ENABLED - virtual ~GDScriptTokenizer(){}; + virtual ~GDScriptTokenizer() {} }; class GDScriptTokenizerText : public GDScriptTokenizer { diff --git a/modules/gdscript/language_server/gdscript_extend_parser.cpp b/modules/gdscript/language_server/gdscript_extend_parser.cpp index b2c6b0e1ab..a6b749059a 100644 --- a/modules/gdscript/language_server/gdscript_extend_parser.cpp +++ b/modules/gdscript/language_server/gdscript_extend_parser.cpp @@ -385,7 +385,8 @@ String ExtendGDScriptParser::parse_documentation(int p_line, bool p_docs_down) { int start_line = p_docs_down ? p_line : p_line - 1; for (int i = start_line; true; i += step) { - if (i < 0 || i >= lines.size()) break; + if (i < 0 || i >= lines.size()) + break; String line_comment = lines[i].strip_edges(true, false); if (line_comment.begins_with("#")) { diff --git a/modules/gdscript/language_server/gdscript_workspace.cpp b/modules/gdscript/language_server/gdscript_workspace.cpp index 32fc8f36f0..be036b44c4 100644 --- a/modules/gdscript/language_server/gdscript_workspace.cpp +++ b/modules/gdscript/language_server/gdscript_workspace.cpp @@ -185,7 +185,8 @@ Array GDScriptWorkspace::symbol(const Dictionary &p_params) { } Error GDScriptWorkspace::initialize() { - if (initialized) return OK; + if (initialized) + return OK; DocData *doc = EditorHelp::get_doc_data(); for (Map<String, DocData::ClassDoc>::Element *E = doc->class_list.front(); E; E = E->next()) { diff --git a/modules/gdscript/language_server/lsp.hpp b/modules/gdscript/language_server/lsp.hpp index 124fcbfed8..e469a26df8 100644 --- a/modules/gdscript/language_server/lsp.hpp +++ b/modules/gdscript/language_server/lsp.hpp @@ -282,7 +282,8 @@ struct Command { Dictionary dict; dict["title"] = title; dict["command"] = command; - if (arguments.size()) dict["arguments"] = arguments; + if (arguments.size()) + dict["arguments"] = arguments; return dict; } }; @@ -946,16 +947,20 @@ struct CompletionItem { dict["preselect"] = preselect; dict["sortText"] = sortText; dict["filterText"] = filterText; - if (commitCharacters.size()) dict["commitCharacters"] = commitCharacters; + if (commitCharacters.size()) + dict["commitCharacters"] = commitCharacters; dict["command"] = command.to_json(); } return dict; } void load(const Dictionary &p_dict) { - if (p_dict.has("label")) label = p_dict["label"]; - if (p_dict.has("kind")) kind = p_dict["kind"]; - if (p_dict.has("detail")) detail = p_dict["detail"]; + if (p_dict.has("label")) + label = p_dict["label"]; + if (p_dict.has("kind")) + kind = p_dict["kind"]; + if (p_dict.has("detail")) + detail = p_dict["detail"]; if (p_dict.has("documentation")) { Variant doc = p_dict["documentation"]; if (doc.get_type() == Variant::STRING) { @@ -965,12 +970,18 @@ struct CompletionItem { documentation.value = v["value"]; } } - if (p_dict.has("deprecated")) deprecated = p_dict["deprecated"]; - if (p_dict.has("preselect")) preselect = p_dict["preselect"]; - if (p_dict.has("sortText")) sortText = p_dict["sortText"]; - if (p_dict.has("filterText")) filterText = p_dict["filterText"]; - if (p_dict.has("insertText")) insertText = p_dict["insertText"]; - if (p_dict.has("data")) data = p_dict["data"]; + if (p_dict.has("deprecated")) + deprecated = p_dict["deprecated"]; + if (p_dict.has("preselect")) + preselect = p_dict["preselect"]; + if (p_dict.has("sortText")) + sortText = p_dict["sortText"]; + if (p_dict.has("filterText")) + filterText = p_dict["filterText"]; + if (p_dict.has("insertText")) + insertText = p_dict["insertText"]; + if (p_dict.has("data")) + data = p_dict["data"]; } }; diff --git a/modules/glslang/register_types.cpp b/modules/glslang/register_types.cpp index 2540ba476c..5203460830 100644 --- a/modules/glslang/register_types.cpp +++ b/modules/glslang/register_types.cpp @@ -130,15 +130,15 @@ static const TBuiltInResource default_builtin_resource = { /*maxTaskWorkGroupSizeZ_NV*/ 0, /*maxMeshViewCountNV*/ 0, /*limits*/ { - /*nonInductiveForLoops*/ 1, - /*whileLoops*/ 1, - /*doWhileLoops*/ 1, - /*generalUniformIndexing*/ 1, - /*generalAttributeMatrixVectorIndexing*/ 1, - /*generalVaryingIndexing*/ 1, - /*generalSamplerIndexing*/ 1, - /*generalVariableIndexing*/ 1, - /*generalConstantMatrixVectorIndexing*/ 1, + /*nonInductiveForLoops*/ true, + /*whileLoops*/ true, + /*doWhileLoops*/ true, + /*generalUniformIndexing*/ true, + /*generalAttributeMatrixVectorIndexing*/ true, + /*generalVaryingIndexing*/ true, + /*generalSamplerIndexing*/ true, + /*generalVariableIndexing*/ true, + /*generalConstantMatrixVectorIndexing*/ true, } }; diff --git a/modules/gridmap/grid_map_editor_plugin.cpp b/modules/gridmap/grid_map_editor_plugin.cpp index 9abbac6a0b..fc545430f5 100644 --- a/modules/gridmap/grid_map_editor_plugin.cpp +++ b/modules/gridmap/grid_map_editor_plugin.cpp @@ -1405,8 +1405,8 @@ GridMapEditor::GridMapEditor(EditorNode *p_editor) { Vector3 points[4]; for (int j = 0; j < 4; j++) { - static const bool orderx[4] = { 0, 1, 1, 0 }; - static const bool ordery[4] = { 0, 0, 1, 1 }; + static const bool orderx[4] = { false, true, true, false }; + static const bool ordery[4] = { false, false, true, true }; Vector3 sp; if (orderx[j]) { diff --git a/modules/jpg/image_loader_jpegd.cpp b/modules/jpg/image_loader_jpegd.cpp index 9e87d11ac1..1a2afeb5b9 100644 --- a/modules/jpg/image_loader_jpegd.cpp +++ b/modules/jpg/image_loader_jpegd.cpp @@ -96,7 +96,7 @@ Error jpeg_load_image_from_buffer(Image *p_image, const uint8_t *p_buffer, int p else fmt = Image::FORMAT_RGB8; - p_image->create(image_width, image_height, 0, fmt, data); + p_image->create(image_width, image_height, false, fmt, data); return OK; } diff --git a/modules/jsonrpc/jsonrpc.cpp b/modules/jsonrpc/jsonrpc.cpp index 393269d422..208ce24f3d 100644 --- a/modules/jsonrpc/jsonrpc.cpp +++ b/modules/jsonrpc/jsonrpc.cpp @@ -148,7 +148,8 @@ Variant JSONRPC::process_action(const Variant &p_action, bool p_process_arr_elem String JSONRPC::process_string(const String &p_input) { - if (p_input.empty()) return String(); + if (p_input.empty()) + return String(); Variant ret; Variant input; diff --git a/modules/lightmapper_rd/SCsub b/modules/lightmapper_rd/SCsub new file mode 100644 index 0000000000..2f04f1833e --- /dev/null +++ b/modules/lightmapper_rd/SCsub @@ -0,0 +1,12 @@ +#!/usr/bin/env python + +Import("env") +Import("env_modules") + +env_lightmapper_rd = env_modules.Clone() +env_lightmapper_rd.GLSL_HEADER("lm_raster.glsl") +env_lightmapper_rd.GLSL_HEADER("lm_compute.glsl") +env_lightmapper_rd.GLSL_HEADER("lm_blendseams.glsl") + +# Godot source files +env_lightmapper_rd.add_source_files(env.modules_sources, "*.cpp") diff --git a/modules/lightmapper_rd/config.py b/modules/lightmapper_rd/config.py new file mode 100644 index 0000000000..d22f9454ed --- /dev/null +++ b/modules/lightmapper_rd/config.py @@ -0,0 +1,6 @@ +def can_build(env, platform): + return True + + +def configure(env): + pass diff --git a/modules/lightmapper_rd/lightmapper_rd.cpp b/modules/lightmapper_rd/lightmapper_rd.cpp new file mode 100644 index 0000000000..6983c222c0 --- /dev/null +++ b/modules/lightmapper_rd/lightmapper_rd.cpp @@ -0,0 +1,1754 @@ +#include "lightmapper_rd.h" +#include "core/math/geometry.h" +#include "core/project_settings.h" +#include "lm_blendseams.glsl.gen.h" +#include "lm_compute.glsl.gen.h" +#include "lm_raster.glsl.gen.h" +#include "servers/rendering/rendering_device_binds.h" + +//uncomment this if you want to see textures from all the process saved +//#define DEBUG_TEXTURES + +void LightmapperRD::add_mesh(const MeshData &p_mesh) { + ERR_FAIL_COND(p_mesh.albedo_on_uv2.is_null() || p_mesh.albedo_on_uv2->empty()); + ERR_FAIL_COND(p_mesh.emission_on_uv2.is_null() || p_mesh.emission_on_uv2->empty()); + ERR_FAIL_COND(p_mesh.albedo_on_uv2->get_width() != p_mesh.emission_on_uv2->get_width()); + ERR_FAIL_COND(p_mesh.albedo_on_uv2->get_height() != p_mesh.emission_on_uv2->get_height()); + ERR_FAIL_COND(p_mesh.points.size() == 0); + MeshInstance mi; + mi.data = p_mesh; + mesh_instances.push_back(mi); +} + +void LightmapperRD::add_directional_light(bool p_static, const Vector3 &p_direction, const Color &p_color, float p_energy, float p_angular_distance) { + Light l; + l.type = LIGHT_TYPE_DIRECTIONAL; + l.direction[0] = p_direction.x; + l.direction[1] = p_direction.y; + l.direction[2] = p_direction.z; + l.color[0] = p_color.r; + l.color[1] = p_color.g; + l.color[2] = p_color.b; + l.energy = p_energy; + l.static_bake = p_static; + l.size = p_angular_distance; + lights.push_back(l); +} +void LightmapperRD::add_omni_light(bool p_static, const Vector3 &p_position, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_size) { + Light l; + l.type = LIGHT_TYPE_OMNI; + l.position[0] = p_position.x; + l.position[1] = p_position.y; + l.position[2] = p_position.z; + l.range = p_range; + l.attenuation = p_attenuation; + l.color[0] = p_color.r; + l.color[1] = p_color.g; + l.color[2] = p_color.b; + l.energy = p_energy; + l.static_bake = p_static; + l.size = p_size; + lights.push_back(l); +} +void LightmapperRD::add_spot_light(bool p_static, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation, float p_size) { + + Light l; + l.type = LIGHT_TYPE_SPOT; + l.position[0] = p_position.x; + l.position[1] = p_position.y; + l.position[2] = p_position.z; + l.direction[0] = p_direction.x; + l.direction[1] = p_direction.y; + l.direction[2] = p_direction.z; + l.range = p_range; + l.attenuation = p_attenuation; + l.spot_angle = Math::deg2rad(p_spot_angle); + l.spot_attenuation = p_spot_attenuation; + l.color[0] = p_color.r; + l.color[1] = p_color.g; + l.color[2] = p_color.b; + l.energy = p_energy; + l.static_bake = p_static; + l.size = p_size; + lights.push_back(l); +} + +void LightmapperRD::add_probe(const Vector3 &p_position) { + Probe probe; + probe.position[0] = p_position.x; + probe.position[1] = p_position.y; + probe.position[2] = p_position.z; + probe.position[3] = 0; + probe_positions.push_back(probe); +} + +void LightmapperRD::_plot_triangle_into_triangle_index_list(int p_size, const Vector3i &p_ofs, const AABB &p_bounds, const Vector3 p_points[3], uint32_t p_triangle_index, LocalVector<TriangleSort> &triangles, uint32_t p_grid_size) { + + int half_size = p_size / 2; + + for (int i = 0; i < 8; i++) { + + AABB aabb = p_bounds; + aabb.size *= 0.5; + Vector3i n = p_ofs; + + if (i & 1) { + aabb.position.x += aabb.size.x; + n.x += half_size; + } + if (i & 2) { + aabb.position.y += aabb.size.y; + n.y += half_size; + } + if (i & 4) { + aabb.position.z += aabb.size.z; + n.z += half_size; + } + + { + Vector3 qsize = aabb.size * 0.5; //quarter size, for fast aabb test + + if (!Geometry::triangle_box_overlap(aabb.position + qsize, qsize, p_points)) { + //does not fit in child, go on + continue; + } + } + + if (half_size == 1) { + //got to the end + TriangleSort ts; + ts.cell_index = n.x + (n.y * p_grid_size) + (n.z * p_grid_size * p_grid_size); + ts.triangle_index = p_triangle_index; + triangles.push_back(ts); + } else { + _plot_triangle_into_triangle_index_list(half_size, n, aabb, p_points, p_triangle_index, triangles, p_grid_size); + } + } +} + +Lightmapper::BakeError LightmapperRD::_blit_meshes_into_atlas(int p_max_texture_size, Vector<Ref<Image>> &albedo_images, Vector<Ref<Image>> &emission_images, AABB &bounds, Size2i &atlas_size, int &atlas_slices, BakeStepFunc p_step_function, void *p_bake_userdata) { + + Vector<Size2i> sizes; + + for (int m_i = 0; m_i < mesh_instances.size(); m_i++) { + + MeshInstance &mi = mesh_instances.write[m_i]; + Size2i s = Size2i(mi.data.albedo_on_uv2->get_width(), mi.data.albedo_on_uv2->get_height()); + sizes.push_back(s); + atlas_size.width = MAX(atlas_size.width, s.width); + atlas_size.height = MAX(atlas_size.height, s.height); + } + + int max = nearest_power_of_2_templated(atlas_size.width); + max = MAX(max, nearest_power_of_2_templated(atlas_size.height)); + + if (max > p_max_texture_size) { + return BAKE_ERROR_LIGHTMAP_TOO_SMALL; + } + + if (p_step_function) { + p_step_function(0.1, TTR("Determining optimal atlas size"), p_bake_userdata, true); + } + + atlas_size = Size2i(max, max); + + Size2i best_atlas_size; + int best_atlas_slices = 0; + int best_atlas_memory = 0x7FFFFFFF; + Vector<Vector3i> best_atlas_offsets; + + //determine best texture array atlas size by bruteforce fitting + while (atlas_size.x <= p_max_texture_size && atlas_size.y <= p_max_texture_size) { + + Vector<Vector2i> source_sizes = sizes; + Vector<int> source_indices; + source_indices.resize(source_sizes.size()); + for (int i = 0; i < source_indices.size(); i++) { + source_indices.write[i] = i; + } + Vector<Vector3i> atlas_offsets; + atlas_offsets.resize(source_sizes.size()); + + int slices = 0; + + while (source_sizes.size() > 0) { + + Vector<Vector3i> offsets = Geometry::partial_pack_rects(source_sizes, atlas_size); + Vector<int> new_indices; + Vector<Vector2i> new_sources; + for (int i = 0; i < offsets.size(); i++) { + Vector3i ofs = offsets[i]; + int sidx = source_indices[i]; + if (ofs.z > 0) { + //valid + ofs.z = slices; + atlas_offsets.write[sidx] = ofs; + } else { + new_indices.push_back(sidx); + new_sources.push_back(source_sizes[i]); + } + } + + source_sizes = new_sources; + source_indices = new_indices; + slices++; + } + + int mem_used = atlas_size.x * atlas_size.y * slices; + if (mem_used < best_atlas_memory) { + best_atlas_size = atlas_size; + best_atlas_offsets = atlas_offsets; + best_atlas_slices = slices; + best_atlas_memory = mem_used; + } + + if (atlas_size.width == atlas_size.height) { + atlas_size.width *= 2; + } else { + atlas_size.height *= 2; + } + } + atlas_size = best_atlas_size; + atlas_slices = best_atlas_slices; + + // apply the offsets and slice to all images, and also blit albedo and emission + albedo_images.resize(atlas_slices); + emission_images.resize(atlas_slices); + + if (p_step_function) { + p_step_function(0.2, TTR("Blitting albedo and emission"), p_bake_userdata, true); + } + + for (int i = 0; i < atlas_slices; i++) { + Ref<Image> albedo; + albedo.instance(); + albedo->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBA8); + albedo->set_as_black(); + albedo_images.write[i] = albedo; + + Ref<Image> emission; + emission.instance(); + emission->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH); + emission->set_as_black(); + emission_images.write[i] = emission; + } + + //assign uv positions + + for (int m_i = 0; m_i < mesh_instances.size(); m_i++) { + + MeshInstance &mi = mesh_instances.write[m_i]; + mi.offset.x = best_atlas_offsets[m_i].x; + mi.offset.y = best_atlas_offsets[m_i].y; + mi.slice = best_atlas_offsets[m_i].z; + albedo_images.write[mi.slice]->blit_rect(mi.data.albedo_on_uv2, Rect2(Vector2(), Size2i(mi.data.albedo_on_uv2->get_width(), mi.data.albedo_on_uv2->get_height())), mi.offset); + emission_images.write[mi.slice]->blit_rect(mi.data.emission_on_uv2, Rect2(Vector2(), Size2i(mi.data.emission_on_uv2->get_width(), mi.data.emission_on_uv2->get_height())), mi.offset); + } + + return BAKE_OK; +} + +void LightmapperRD::_create_acceleration_structures(RenderingDevice *rd, Size2i atlas_size, int atlas_slices, AABB &bounds, int grid_size, Vector<Probe> &probe_positions, GenerateProbes p_generate_probes, Vector<int> &slice_triangle_count, Vector<int> &slice_seam_count, RID &vertex_buffer, RID &triangle_buffer, RID &box_buffer, RID &lights_buffer, RID &triangle_cell_indices_buffer, RID &probe_positions_buffer, RID &grid_texture, RID &grid_texture_sdf, RID &seams_buffer, BakeStepFunc p_step_function, void *p_bake_userdata) { + + HashMap<Vertex, uint32_t, VertexHash> vertex_map; + + //fill triangles array and vertex array + LocalVector<Triangle> triangles; + LocalVector<Vertex> vertex_array; + LocalVector<Box> box_array; + LocalVector<Seam> seams; + + slice_triangle_count.resize(atlas_slices); + slice_seam_count.resize(atlas_slices); + + for (int i = 0; i < atlas_slices; i++) { + slice_triangle_count.write[i] = 0; + slice_seam_count.write[i] = 0; + } + + bounds = AABB(); + + for (int m_i = 0; m_i < mesh_instances.size(); m_i++) { + + if (p_step_function) { + float p = float(m_i + 1) / mesh_instances.size() * 0.1; + p_step_function(0.3 + p, vformat(TTR("Plotting mesh into acceleration structure %d/%d"), m_i + 1, mesh_instances.size()), p_bake_userdata, false); + } + + HashMap<Edge, EdgeUV2, EdgeHash> edges; + + MeshInstance &mi = mesh_instances.write[m_i]; + + Vector2 uv_scale = Vector2(mi.data.albedo_on_uv2->get_width(), mi.data.albedo_on_uv2->get_height()) / Vector2(atlas_size); + Vector2 uv_offset = Vector2(mi.offset) / Vector2(atlas_size); + if (m_i == 0) { + bounds.position = mi.data.points[0]; + } + + for (int i = 0; i < mi.data.points.size(); i += 3) { + + Vector3 vtxs[3] = { mi.data.points[i + 0], mi.data.points[i + 1], mi.data.points[i + 2] }; + Vector2 uvs[3] = { mi.data.uv2[i + 0] * uv_scale + uv_offset, mi.data.uv2[i + 1] * uv_scale + uv_offset, mi.data.uv2[i + 2] * uv_scale + uv_offset }; + Vector3 normal[3] = { mi.data.normal[i + 0], mi.data.normal[i + 1], mi.data.normal[i + 2] }; + + AABB taabb; + Triangle t; + t.slice = mi.slice; + for (int k = 0; k < 3; k++) { + + bounds.expand_to(vtxs[k]); + + Vertex v; + v.position[0] = vtxs[k].x; + v.position[1] = vtxs[k].y; + v.position[2] = vtxs[k].z; + v.uv[0] = uvs[k].x; + v.uv[1] = uvs[k].y; + v.normal_xy[0] = normal[k].x; + v.normal_xy[1] = normal[k].y; + v.normal_z = normal[k].z; + + uint32_t *indexptr = vertex_map.getptr(v); + + if (indexptr) { + t.indices[k] = *indexptr; + } else { + uint32_t new_index = vertex_map.size(); + t.indices[k] = new_index; + vertex_map[v] = new_index; + vertex_array.push_back(v); + } + + if (k == 0) { + taabb.position = vtxs[k]; + } else { + taabb.expand_to(vtxs[k]); + } + } + + //compute seams that will need to be blended later + for (int k = 0; k < 3; k++) { + int n = (k + 1) % 3; + + Edge edge(vtxs[k], vtxs[n], normal[k], normal[n]); + Vector2i edge_indices(t.indices[k], t.indices[n]); + EdgeUV2 uv2(uvs[k], uvs[n], edge_indices); + + if (edge.b == edge.a) { + continue; //degenerate, somehow + } + if (edge.b < edge.a) { + SWAP(edge.a, edge.b); + SWAP(edge.na, edge.nb); + SWAP(uv2.a, uv2.b); + SWAP(edge_indices.x, edge_indices.y); + } + + EdgeUV2 *euv2 = edges.getptr(edge); + if (!euv2) { + edges[edge] = uv2; + } else { + if (*euv2 == uv2) { + continue; // seam shared UV space, no need to blend + } + if (euv2->seam_found) { + continue; //bad geometry + } + + Seam seam; + seam.a = edge_indices; + seam.b = euv2->indices; + seam.slice = mi.slice; + seams.push_back(seam); + slice_seam_count.write[mi.slice]++; + euv2->seam_found = true; + } + } + + Box box; + box.min_bounds[0] = taabb.position.x; + box.min_bounds[1] = taabb.position.y; + box.min_bounds[2] = taabb.position.z; + box.max_bounds[0] = taabb.position.x + MAX(taabb.size.x, 0.0001); + box.max_bounds[1] = taabb.position.y + MAX(taabb.size.y, 0.0001); + box.max_bounds[2] = taabb.position.z + MAX(taabb.size.z, 0.0001); + box.pad0 = box.pad1 = 0; //make valgrind not complain + box_array.push_back(box); + + triangles.push_back(t); + slice_triangle_count.write[t.slice]++; + } + } + + //also consider probe positions for bounds + for (int i = 0; i < probe_positions.size(); i++) { + Vector3 pp(probe_positions[i].position[0], probe_positions[i].position[1], probe_positions[i].position[2]); + bounds.expand_to(pp); + } + bounds.grow_by(0.1); //grow a bit to avoid numerical error + + triangles.sort(); //sort by slice + seams.sort(); + + if (p_step_function) { + p_step_function(0.4, TTR("Optimizing acceleration structure"), p_bake_userdata, true); + } + + //fill list of triangles in grid + LocalVector<TriangleSort> triangle_sort; + for (uint32_t i = 0; i < triangles.size(); i++) { + + const Triangle &t = triangles[i]; + Vector3 face[3] = { + Vector3(vertex_array[t.indices[0]].position[0], vertex_array[t.indices[0]].position[1], vertex_array[t.indices[0]].position[2]), + Vector3(vertex_array[t.indices[1]].position[0], vertex_array[t.indices[1]].position[1], vertex_array[t.indices[1]].position[2]), + Vector3(vertex_array[t.indices[2]].position[0], vertex_array[t.indices[2]].position[1], vertex_array[t.indices[2]].position[2]) + }; + _plot_triangle_into_triangle_index_list(grid_size, Vector3i(), bounds, face, i, triangle_sort, grid_size); + } + //sort it + triangle_sort.sort(); + + Vector<uint32_t> triangle_indices; + triangle_indices.resize(triangle_sort.size()); + Vector<uint32_t> grid_indices; + grid_indices.resize(grid_size * grid_size * grid_size * 2); + zeromem(grid_indices.ptrw(), grid_indices.size() * sizeof(uint32_t)); + Vector<bool> solid; + solid.resize(grid_size * grid_size * grid_size); + zeromem(solid.ptrw(), solid.size() * sizeof(bool)); + + { + uint32_t *tiw = triangle_indices.ptrw(); + uint32_t last_cell = 0xFFFFFFFF; + uint32_t *giw = grid_indices.ptrw(); + bool *solidw = solid.ptrw(); + for (uint32_t i = 0; i < triangle_sort.size(); i++) { + uint32_t cell = triangle_sort[i].cell_index; + if (cell != last_cell) { + //cell changed, update pointer to indices + giw[cell * 2 + 1] = i; + last_cell = cell; + solidw[cell] = true; + } + tiw[i] = triangle_sort[i].triangle_index; + giw[cell * 2]++; //update counter + last_cell = cell; + } + } +#if 0 + for (int i = 0; i < grid_size; i++) { + for (int j = 0; j < grid_size; j++) { + for (int k = 0; k < grid_size; k++) { + uint32_t index = i * (grid_size * grid_size) + j * grid_size + k; + grid_indices.write[index * 2] = float(i) / grid_size * 255; + grid_indices.write[index * 2 + 1] = float(j) / grid_size * 255; + } + } + } +#endif + +#if 0 + for (int i = 0; i < grid_size; i++) { + Vector<uint8_t> grid_usage; + grid_usage.resize(grid_size * grid_size); + for (int j = 0; j < grid_usage.size(); j++) { + uint32_t ofs = i * grid_size * grid_size + j; + uint32_t count = grid_indices[ofs * 2]; + grid_usage.write[j] = count > 0 ? 255 : 0; + } + + Ref<Image> img; + img.instance(); + img->create(grid_size, grid_size, false, Image::FORMAT_L8, grid_usage); + img->save_png("res://grid_layer_" + itos(1000 + i).substr(1, 3) + ".png"); + } +#endif + if (p_step_function) { + p_step_function(0.45, TTR("Generating Signed Distance Field"), p_bake_userdata, true); + } + + //generate SDF for raytracing + Vector<uint32_t> euclidean_pos = Geometry::generate_edf(solid, Vector3i(grid_size, grid_size, grid_size), false); + Vector<uint32_t> euclidean_neg = Geometry::generate_edf(solid, Vector3i(grid_size, grid_size, grid_size), true); + Vector<int8_t> sdf8 = Geometry::generate_sdf8(euclidean_pos, euclidean_neg); + + /*****************************/ + /*** CREATE GPU STRUCTURES ***/ + /*****************************/ + + lights.sort(); + + Vector<Vector2i> seam_buffer_vec; + seam_buffer_vec.resize(seams.size() * 2); + for (uint32_t i = 0; i < seams.size(); i++) { + seam_buffer_vec.write[i * 2 + 0] = seams[i].a; + seam_buffer_vec.write[i * 2 + 1] = seams[i].b; + } + + { //buffers + Vector<uint8_t> vb = vertex_array.to_byte_array(); + vertex_buffer = rd->storage_buffer_create(vb.size(), vb); + + Vector<uint8_t> tb = triangles.to_byte_array(); + triangle_buffer = rd->storage_buffer_create(tb.size(), tb); + + Vector<uint8_t> bb = box_array.to_byte_array(); + box_buffer = rd->storage_buffer_create(bb.size(), bb); + + Vector<uint8_t> tib = triangle_indices.to_byte_array(); + triangle_cell_indices_buffer = rd->storage_buffer_create(tib.size(), tib); + + Vector<uint8_t> lb = lights.to_byte_array(); + if (lb.size() == 0) { + lb.resize(sizeof(Light)); //even if no lights, the buffer must exist + } + lights_buffer = rd->storage_buffer_create(lb.size(), lb); + + Vector<uint8_t> sb = seam_buffer_vec.to_byte_array(); + if (sb.size() == 0) { + sb.resize(sizeof(Vector2i) * 2); //even if no seams, the buffer must exist + } + seams_buffer = rd->storage_buffer_create(sb.size(), sb); + + Vector<uint8_t> pb = probe_positions.to_byte_array(); + if (pb.size() == 0) { + pb.resize(sizeof(Probe)); + } + probe_positions_buffer = rd->storage_buffer_create(pb.size(), pb); + } + + { //grid + + RD::TextureFormat tf; + tf.width = grid_size; + tf.height = grid_size; + tf.depth = grid_size; + tf.type = RD::TEXTURE_TYPE_3D; + tf.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT; + + Vector<Vector<uint8_t>> texdata; + texdata.resize(1); + //grid and indices + tf.format = RD::DATA_FORMAT_R32G32_UINT; + texdata.write[0] = grid_indices.to_byte_array(); + grid_texture = rd->texture_create(tf, RD::TextureView(), texdata); + //sdf + tf.format = RD::DATA_FORMAT_R8_SNORM; + texdata.write[0] = sdf8.to_byte_array(); + grid_texture_sdf = rd->texture_create(tf, RD::TextureView(), texdata); + } +} + +void LightmapperRD::_raster_geometry(RenderingDevice *rd, Size2i atlas_size, int atlas_slices, int grid_size, AABB bounds, float p_bias, Vector<int> slice_triangle_count, RID position_tex, RID unocclude_tex, RID normal_tex, RID raster_depth_buffer, RID rasterize_shader, RID raster_base_uniform) { + + Vector<RID> framebuffers; + + for (int i = 0; i < atlas_slices; i++) { + RID slice_pos_tex = rd->texture_create_shared_from_slice(RD::TextureView(), position_tex, i, 0); + RID slice_unoc_tex = rd->texture_create_shared_from_slice(RD::TextureView(), unocclude_tex, i, 0); + RID slice_norm_tex = rd->texture_create_shared_from_slice(RD::TextureView(), normal_tex, i, 0); + Vector<RID> fb; + fb.push_back(slice_pos_tex); + fb.push_back(slice_norm_tex); + fb.push_back(slice_unoc_tex); + fb.push_back(raster_depth_buffer); + framebuffers.push_back(rd->framebuffer_create(fb)); + } + + RD::PipelineDepthStencilState ds; + ds.enable_depth_test = true; + ds.enable_depth_write = true; + ds.depth_compare_operator = RD::COMPARE_OP_LESS; //so it does render same pixel twice + + RID raster_pipeline = rd->render_pipeline_create(rasterize_shader, rd->framebuffer_get_format(framebuffers[0]), RD::INVALID_FORMAT_ID, RD::RENDER_PRIMITIVE_TRIANGLES, RD::PipelineRasterizationState(), RD::PipelineMultisampleState(), ds, RD::PipelineColorBlendState::create_disabled(3), 0); + RID raster_pipeline_wire; + { + + RD::PipelineRasterizationState rw; + rw.wireframe = true; + raster_pipeline_wire = rd->render_pipeline_create(rasterize_shader, rd->framebuffer_get_format(framebuffers[0]), RD::INVALID_FORMAT_ID, RD::RENDER_PRIMITIVE_TRIANGLES, rw, RD::PipelineMultisampleState(), ds, RD::PipelineColorBlendState::create_disabled(3), 0); + } + + uint32_t triangle_offset = 0; + Vector<Color> clear_colors; + clear_colors.push_back(Color(0, 0, 0, 0)); + clear_colors.push_back(Color(0, 0, 0, 0)); + clear_colors.push_back(Color(0, 0, 0, 0)); + + for (int i = 0; i < atlas_slices; i++) { + + RasterPushConstant raster_push_constant; + raster_push_constant.atlas_size[0] = atlas_size.x; + raster_push_constant.atlas_size[1] = atlas_size.y; + raster_push_constant.base_triangle = triangle_offset; + raster_push_constant.to_cell_offset[0] = bounds.position.x; + raster_push_constant.to_cell_offset[1] = bounds.position.y; + raster_push_constant.to_cell_offset[2] = bounds.position.z; + raster_push_constant.bias = p_bias; + raster_push_constant.to_cell_size[0] = (1.0 / bounds.size.x) * float(grid_size); + raster_push_constant.to_cell_size[1] = (1.0 / bounds.size.y) * float(grid_size); + raster_push_constant.to_cell_size[2] = (1.0 / bounds.size.z) * float(grid_size); + raster_push_constant.grid_size[0] = grid_size; + raster_push_constant.grid_size[1] = grid_size; + raster_push_constant.grid_size[2] = grid_size; + raster_push_constant.uv_offset[0] = 0; + raster_push_constant.uv_offset[1] = 0; + + RD::DrawListID draw_list = rd->draw_list_begin(framebuffers[i], RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_READ, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_DISCARD, clear_colors); + //draw opaque + rd->draw_list_bind_render_pipeline(draw_list, raster_pipeline); + rd->draw_list_bind_uniform_set(draw_list, raster_base_uniform, 0); + rd->draw_list_set_push_constant(draw_list, &raster_push_constant, sizeof(RasterPushConstant)); + rd->draw_list_draw(draw_list, false, 1, slice_triangle_count[i] * 3); + //draw wire + rd->draw_list_bind_render_pipeline(draw_list, raster_pipeline_wire); + rd->draw_list_bind_uniform_set(draw_list, raster_base_uniform, 0); + rd->draw_list_set_push_constant(draw_list, &raster_push_constant, sizeof(RasterPushConstant)); + rd->draw_list_draw(draw_list, false, 1, slice_triangle_count[i] * 3); + + rd->draw_list_end(); + + triangle_offset += slice_triangle_count[i]; + } +} + +LightmapperRD::BakeError LightmapperRD::bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref<Image> &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function, void *p_bake_userdata) { + + if (p_step_function) { + p_step_function(0.0, TTR("Begin Bake"), p_bake_userdata, true); + } + bake_textures.clear(); + int grid_size = 128; + + /* STEP 1: Fetch material textures and compute the bounds */ + + AABB bounds; + Size2i atlas_size; + int atlas_slices; + Vector<Ref<Image>> albedo_images; + Vector<Ref<Image>> emission_images; + + BakeError bake_error = _blit_meshes_into_atlas(p_max_texture_size, albedo_images, emission_images, bounds, atlas_size, atlas_slices, p_step_function, p_bake_userdata); + if (bake_error != BAKE_OK) { + return bake_error; + } + +#ifdef DEBUG_TEXTURES + for (int i = 0; i < atlas_slices; i++) { + albedo_images[i]->save_png("res://0_albedo_" + itos(i) + ".png"); + emission_images[i]->save_png("res://0_emission_" + itos(i) + ".png"); + } +#endif + + RenderingDevice *rd = RenderingDevice::get_singleton()->create_local_device(); + + RID albedo_array_tex; + RID emission_array_tex; + RID normal_tex; + RID position_tex; + RID unocclude_tex; + RID light_source_tex; + RID light_dest_tex; + RID light_accum_tex; + RID light_accum_tex2; + RID light_primary_dynamic_tex; + RID light_environment_tex; + +#define FREE_TEXTURES \ + rd->free(albedo_array_tex); \ + rd->free(emission_array_tex); \ + rd->free(normal_tex); \ + rd->free(position_tex); \ + rd->free(unocclude_tex); \ + rd->free(light_source_tex); \ + rd->free(light_accum_tex2); \ + rd->free(light_accum_tex); \ + rd->free(light_primary_dynamic_tex); \ + rd->free(light_environment_tex); + + { // create all textures + + Vector<Vector<uint8_t>> albedo_data; + Vector<Vector<uint8_t>> emission_data; + for (int i = 0; i < atlas_slices; i++) { + albedo_data.push_back(albedo_images[i]->get_data()); + emission_data.push_back(emission_images[i]->get_data()); + } + + RD::TextureFormat tf; + tf.width = atlas_size.width; + tf.height = atlas_size.height; + tf.array_layers = atlas_slices; + tf.type = RD::TEXTURE_TYPE_2D_ARRAY; + tf.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT; + tf.format = RD::DATA_FORMAT_R8G8B8A8_UNORM; + + albedo_array_tex = rd->texture_create(tf, RD::TextureView(), albedo_data); + + tf.format = RD::DATA_FORMAT_R16G16B16A16_SFLOAT; + + emission_array_tex = rd->texture_create(tf, RD::TextureView(), emission_data); + + //this will be rastered to + tf.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_COLOR_ATTACHMENT_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT | RD::TEXTURE_USAGE_STORAGE_BIT; + normal_tex = rd->texture_create(tf, RD::TextureView()); + tf.format = RD::DATA_FORMAT_R32G32B32A32_SFLOAT; + position_tex = rd->texture_create(tf, RD::TextureView()); + unocclude_tex = rd->texture_create(tf, RD::TextureView()); + + tf.format = RD::DATA_FORMAT_R16G16B16A16_SFLOAT; + tf.usage_bits = RD::TEXTURE_USAGE_COLOR_ATTACHMENT_BIT | RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_STORAGE_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT | RD::TEXTURE_USAGE_CAN_COPY_TO_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT; + + light_source_tex = rd->texture_create(tf, RD::TextureView()); + rd->texture_clear(light_source_tex, Color(0, 0, 0, 0), 0, 1, 0, atlas_slices); + light_primary_dynamic_tex = rd->texture_create(tf, RD::TextureView()); + rd->texture_clear(light_primary_dynamic_tex, Color(0, 0, 0, 0), 0, 1, 0, atlas_slices); + + if (p_bake_sh) { + tf.array_layers *= 4; + } + light_accum_tex = rd->texture_create(tf, RD::TextureView()); + rd->texture_clear(light_accum_tex, Color(0, 0, 0, 0), 0, 1, 0, tf.array_layers); + light_dest_tex = rd->texture_create(tf, RD::TextureView()); + rd->texture_clear(light_dest_tex, Color(0, 0, 0, 0), 0, 1, 0, tf.array_layers); + light_accum_tex2 = light_dest_tex; + + //env + { + Ref<Image> panorama_tex; + if (p_environment_panorama.is_valid()) { + panorama_tex = p_environment_panorama; + panorama_tex->convert(Image::FORMAT_RGBAF); + } else { + panorama_tex.instance(); + panorama_tex->create(8, 8, false, Image::FORMAT_RGBAF); + for (int i = 0; i < 8; i++) { + for (int j = 0; j < 8; j++) { + panorama_tex->set_pixel(i, j, Color(0, 0, 0, 1)); + } + } + } + + RD::TextureFormat tfp; + tfp.width = panorama_tex->get_width(); + tfp.height = panorama_tex->get_height(); + tfp.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT; + tfp.format = RD::DATA_FORMAT_R32G32B32A32_SFLOAT; + + Vector<Vector<uint8_t>> tdata; + tdata.push_back(panorama_tex->get_data()); + light_environment_tex = rd->texture_create(tfp, RD::TextureView(), tdata); + +#ifdef DEBUG_TEXTURES + panorama_tex->convert(Image::FORMAT_RGB8); + panorama_tex->save_png("res://0_panorama.png"); +#endif + } + } + + /* STEP 2: create the acceleration structure for the GPU*/ + + Vector<int> slice_triangle_count; + RID vertex_buffer; + RID triangle_buffer; + RID box_buffer; + RID lights_buffer; + RID triangle_cell_indices_buffer; + RID grid_texture; + RID grid_texture_sdf; + RID seams_buffer; + RID probe_positions_buffer; + + Vector<int> slice_seam_count; + +#define FREE_BUFFERS \ + rd->free(vertex_buffer); \ + rd->free(triangle_buffer); \ + rd->free(box_buffer); \ + rd->free(lights_buffer); \ + rd->free(triangle_cell_indices_buffer); \ + rd->free(grid_texture); \ + rd->free(grid_texture_sdf); \ + rd->free(seams_buffer); \ + rd->free(probe_positions_buffer); + + _create_acceleration_structures(rd, atlas_size, atlas_slices, bounds, grid_size, probe_positions, p_generate_probes, slice_triangle_count, slice_seam_count, vertex_buffer, triangle_buffer, box_buffer, lights_buffer, triangle_cell_indices_buffer, probe_positions_buffer, grid_texture, grid_texture_sdf, seams_buffer, p_step_function, p_bake_userdata); + + if (p_step_function) { + p_step_function(0.47, TTR("Preparing shaders"), p_bake_userdata, true); + } + + //shaders + Ref<RDShaderFile> raster_shader; + raster_shader.instance(); + Error err = raster_shader->parse_versions_from_text(lm_raster_shader_glsl); + if (err != OK) { + raster_shader->print_errors("raster_shader"); + + FREE_TEXTURES + FREE_BUFFERS + + memdelete(rd); + } + ERR_FAIL_COND_V(err != OK, BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); + + RID rasterize_shader = rd->shader_create_from_bytecode(raster_shader->get_bytecode()); + + ERR_FAIL_COND_V(rasterize_shader.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); //this is a bug check, though, should not happen + + RID sampler; + { + RD::SamplerState s; + s.mag_filter = RD::SAMPLER_FILTER_LINEAR; + s.min_filter = RD::SAMPLER_FILTER_LINEAR; + s.max_lod = 0; + + sampler = rd->sampler_create(s); + } + + Vector<RD::Uniform> base_uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 1; + u.ids.push_back(vertex_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 2; + u.ids.push_back(triangle_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 3; + u.ids.push_back(box_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 4; + u.ids.push_back(triangle_cell_indices_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 5; + u.ids.push_back(lights_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 6; + u.ids.push_back(seams_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 7; + u.ids.push_back(probe_positions_buffer); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 8; + u.ids.push_back(grid_texture); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 9; + u.ids.push_back(grid_texture_sdf); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 10; + u.ids.push_back(albedo_array_tex); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 11; + u.ids.push_back(emission_array_tex); + base_uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_SAMPLER; + u.binding = 12; + u.ids.push_back(sampler); + base_uniforms.push_back(u); + } + } + + RID raster_base_uniform = rd->uniform_set_create(base_uniforms, rasterize_shader, 0); + RID raster_depth_buffer; + { + RD::TextureFormat tf; + tf.width = atlas_size.width; + tf.height = atlas_size.height; + tf.depth = 1; + tf.type = RD::TEXTURE_TYPE_2D; + tf.usage_bits = RD::TEXTURE_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT; + tf.format = RD::DATA_FORMAT_D32_SFLOAT; + + raster_depth_buffer = rd->texture_create(tf, RD::TextureView()); + } + + rd->submit(); + rd->sync(); + + /* STEP 3: Raster the geometry to UV2 coords in the atlas textures GPU*/ + + _raster_geometry(rd, atlas_size, atlas_slices, grid_size, bounds, p_bias, slice_triangle_count, position_tex, unocclude_tex, normal_tex, raster_depth_buffer, rasterize_shader, raster_base_uniform); + +#ifdef DEBUG_TEXTURES + + for (int i = 0; i < atlas_slices; i++) { + Vector<uint8_t> s = rd->texture_get_data(position_tex, i); + Ref<Image> img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAF, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://1_position_" + itos(i) + ".png"); + + s = rd->texture_get_data(normal_tex, i); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://1_normal_" + itos(i) + ".png"); + } +#endif + +#define FREE_RASTER_RESOURCES \ + rd->free(rasterize_shader); \ + rd->free(sampler); \ + rd->free(raster_depth_buffer); + + /* Plot direct light */ + + Ref<RDShaderFile> compute_shader; + compute_shader.instance(); + err = compute_shader->parse_versions_from_text(lm_compute_shader_glsl, p_bake_sh ? "\n#define USE_SH_LIGHTMAPS\n" : ""); + if (err != OK) { + + FREE_TEXTURES + FREE_BUFFERS + FREE_RASTER_RESOURCES + memdelete(rd); + compute_shader->print_errors("compute_shader"); + } + ERR_FAIL_COND_V(err != OK, BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); + + //unoccluder + RID compute_shader_unocclude = rd->shader_create_from_bytecode(compute_shader->get_bytecode("unocclude")); + ERR_FAIL_COND_V(compute_shader_unocclude.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); // internal check, should not happen + RID compute_shader_unocclude_pipeline = rd->compute_pipeline_create(compute_shader_unocclude); + + //direct light + RID compute_shader_primary = rd->shader_create_from_bytecode(compute_shader->get_bytecode("primary")); + ERR_FAIL_COND_V(compute_shader_primary.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); // internal check, should not happen + RID compute_shader_primary_pipeline = rd->compute_pipeline_create(compute_shader_primary); + + //indirect light + RID compute_shader_secondary = rd->shader_create_from_bytecode(compute_shader->get_bytecode("secondary")); + ERR_FAIL_COND_V(compute_shader_secondary.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); //internal check, should not happen + RID compute_shader_secondary_pipeline = rd->compute_pipeline_create(compute_shader_secondary); + + //dilate + RID compute_shader_dilate = rd->shader_create_from_bytecode(compute_shader->get_bytecode("dilate")); + ERR_FAIL_COND_V(compute_shader_dilate.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); //internal check, should not happen + RID compute_shader_dilate_pipeline = rd->compute_pipeline_create(compute_shader_dilate); + + //dilate + RID compute_shader_light_probes = rd->shader_create_from_bytecode(compute_shader->get_bytecode("light_probes")); + ERR_FAIL_COND_V(compute_shader_light_probes.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); //internal check, should not happen + RID compute_shader_light_probes_pipeline = rd->compute_pipeline_create(compute_shader_light_probes); + + RID compute_base_uniform_set = rd->uniform_set_create(base_uniforms, compute_shader_primary, 0); + +#define FREE_COMPUTE_RESOURCES \ + rd->free(compute_shader_unocclude); \ + rd->free(compute_shader_primary); \ + rd->free(compute_shader_secondary); \ + rd->free(compute_shader_dilate); \ + rd->free(compute_shader_light_probes); + + PushConstant push_constant; + { + //set defaults + push_constant.atlas_size[0] = atlas_size.width; + push_constant.atlas_size[1] = atlas_size.height; + push_constant.world_size[0] = bounds.size.x; + push_constant.world_size[1] = bounds.size.y; + push_constant.world_size[2] = bounds.size.z; + push_constant.to_cell_offset[0] = bounds.position.x; + push_constant.to_cell_offset[1] = bounds.position.y; + push_constant.to_cell_offset[2] = bounds.position.z; + push_constant.bias = p_bias; + push_constant.to_cell_size[0] = (1.0 / bounds.size.x) * float(grid_size); + push_constant.to_cell_size[1] = (1.0 / bounds.size.y) * float(grid_size); + push_constant.to_cell_size[2] = (1.0 / bounds.size.z) * float(grid_size); + push_constant.light_count = lights.size(); + push_constant.grid_size = grid_size; + push_constant.atlas_slice = 0; + push_constant.region_ofs[0] = 0; + push_constant.region_ofs[1] = 0; + push_constant.environment_xform[0] = p_environment_transform.elements[0][0]; + push_constant.environment_xform[1] = p_environment_transform.elements[1][0]; + push_constant.environment_xform[2] = p_environment_transform.elements[2][0]; + push_constant.environment_xform[3] = 0; + push_constant.environment_xform[4] = p_environment_transform.elements[0][1]; + push_constant.environment_xform[5] = p_environment_transform.elements[1][1]; + push_constant.environment_xform[6] = p_environment_transform.elements[2][1]; + push_constant.environment_xform[7] = 0; + push_constant.environment_xform[8] = p_environment_transform.elements[0][2]; + push_constant.environment_xform[9] = p_environment_transform.elements[1][2]; + push_constant.environment_xform[10] = p_environment_transform.elements[2][2]; + push_constant.environment_xform[11] = 0; + } + + Vector3i group_size((atlas_size.x - 1) / 8 + 1, (atlas_size.y - 1) / 8 + 1, 1); + rd->submit(); + rd->sync(); + + if (p_step_function) { + p_step_function(0.49, TTR("Un-occluding geometry"), p_bake_userdata, true); + } + + /* UNOCCLUDE */ + { + + Vector<RD::Uniform> uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 0; + u.ids.push_back(position_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 1; + u.ids.push_back(unocclude_tex); //will be unused + uniforms.push_back(u); + } + } + + RID unocclude_uniform_set = rd->uniform_set_create(uniforms, compute_shader_unocclude, 1); + + RD::ComputeListID compute_list = rd->compute_list_begin(); + rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_unocclude_pipeline); + rd->compute_list_bind_uniform_set(compute_list, compute_base_uniform_set, 0); + rd->compute_list_bind_uniform_set(compute_list, unocclude_uniform_set, 1); + + for (int i = 0; i < atlas_slices; i++) { + push_constant.atlas_slice = i; + rd->compute_list_set_push_constant(compute_list, &push_constant, sizeof(PushConstant)); + rd->compute_list_dispatch(compute_list, group_size.x, group_size.y, group_size.z); + //no barrier, let them run all together + } + rd->compute_list_end(); //done + } + + if (p_step_function) { + p_step_function(0.5, TTR("Plot direct lighting"), p_bake_userdata, true); + } + + /* PRIMARY (direct) LIGHT PASS */ + { + + Vector<RD::Uniform> uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 0; + u.ids.push_back(light_source_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 1; + u.ids.push_back(light_dest_tex); //will be unused + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 2; + u.ids.push_back(position_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 3; + u.ids.push_back(normal_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 4; + u.ids.push_back(light_accum_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 5; + u.ids.push_back(light_primary_dynamic_tex); + uniforms.push_back(u); + } + } + + RID light_uniform_set = rd->uniform_set_create(uniforms, compute_shader_primary, 1); + + RD::ComputeListID compute_list = rd->compute_list_begin(); + rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_primary_pipeline); + rd->compute_list_bind_uniform_set(compute_list, compute_base_uniform_set, 0); + rd->compute_list_bind_uniform_set(compute_list, light_uniform_set, 1); + + for (int i = 0; i < atlas_slices; i++) { + push_constant.atlas_slice = i; + rd->compute_list_set_push_constant(compute_list, &push_constant, sizeof(PushConstant)); + rd->compute_list_dispatch(compute_list, group_size.x, group_size.y, group_size.z); + //no barrier, let them run all together + } + rd->compute_list_end(); //done + } + +#ifdef DEBUG_TEXTURES + + for (int i = 0; i < atlas_slices; i++) { + Vector<uint8_t> s = rd->texture_get_data(light_source_tex, i); + Ref<Image> img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://2_light_primary_" + itos(i) + ".png"); + } +#endif + + /* SECONDARY (indirect) LIGHT PASS(ES) */ + if (p_step_function) { + p_step_function(0.6, TTR("Integrate indirect lighting"), p_bake_userdata, true); + } + + if (p_bounces > 0) { + + Vector<RD::Uniform> uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 0; + u.ids.push_back(light_dest_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 1; + u.ids.push_back(light_source_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 2; + u.ids.push_back(position_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 3; + u.ids.push_back(normal_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 4; + u.ids.push_back(light_accum_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 5; + u.ids.push_back(unocclude_tex); //reuse unocclude tex + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 6; + u.ids.push_back(light_environment_tex); //reuse unocclude tex + uniforms.push_back(u); + } + } + + RID secondary_uniform_set[2]; + secondary_uniform_set[0] = rd->uniform_set_create(uniforms, compute_shader_secondary, 1); + uniforms.write[0].ids.write[0] = light_source_tex; + uniforms.write[1].ids.write[0] = light_dest_tex; + secondary_uniform_set[1] = rd->uniform_set_create(uniforms, compute_shader_secondary, 1); + + switch (p_quality) { + case BAKE_QUALITY_LOW: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/low_quality_ray_count"); + } break; + case BAKE_QUALITY_MEDIUM: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/medium_quality_ray_count"); + } break; + case BAKE_QUALITY_HIGH: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/high_quality_ray_count"); + } break; + case BAKE_QUALITY_ULTRA: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/ultra_quality_ray_count"); + } break; + } + + push_constant.ray_count = CLAMP(push_constant.ray_count, 16, 8192); + + int max_region_size = nearest_power_of_2_templated(int(GLOBAL_GET("rendering/gpu_lightmapper/performance/region_size"))); + int max_rays = GLOBAL_GET("rendering/gpu_lightmapper/performance/max_rays_per_pass"); + + int x_regions = (atlas_size.width - 1) / max_region_size + 1; + int y_regions = (atlas_size.height - 1) / max_region_size + 1; + int ray_iterations = (push_constant.ray_count - 1) / max_rays + 1; + + rd->submit(); + rd->sync(); + + for (int b = 0; b < p_bounces; b++) { + int count = 0; + if (b > 0) { + SWAP(light_source_tex, light_dest_tex); + SWAP(secondary_uniform_set[0], secondary_uniform_set[1]); + } + + for (int s = 0; s < atlas_slices; s++) { + push_constant.atlas_slice = s; + + for (int i = 0; i < x_regions; i++) { + for (int j = 0; j < y_regions; j++) { + + int x = i * max_region_size; + int y = j * max_region_size; + int w = MIN((i + 1) * max_region_size, atlas_size.width) - x; + int h = MIN((j + 1) * max_region_size, atlas_size.height) - y; + + push_constant.region_ofs[0] = x; + push_constant.region_ofs[1] = y; + + group_size = Vector3i((w - 1) / 8 + 1, (h - 1) / 8 + 1, 1); + + for (int k = 0; k < ray_iterations; k++) { + + RD::ComputeListID compute_list = rd->compute_list_begin(); + rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_secondary_pipeline); + rd->compute_list_bind_uniform_set(compute_list, compute_base_uniform_set, 0); + rd->compute_list_bind_uniform_set(compute_list, secondary_uniform_set[0], 1); + + push_constant.ray_from = k * max_rays; + push_constant.ray_to = MIN((k + 1) * max_rays, int32_t(push_constant.ray_count)); + rd->compute_list_set_push_constant(compute_list, &push_constant, sizeof(PushConstant)); + rd->compute_list_dispatch(compute_list, group_size.x, group_size.y, group_size.z); + + rd->compute_list_end(); //done + rd->submit(); + rd->sync(); + + count++; + if (p_step_function) { + int total = (atlas_slices * x_regions * y_regions * ray_iterations); + int percent = count * 100 / total; + float p = float(count) / total * 0.1; + p_step_function(0.6 + p, vformat(TTR("Bounce %d/%d: Integrate indirect lighting %d%%"), b + 1, p_bounces, percent), p_bake_userdata, false); + } + } + } + } + } + } + } + + /* LIGHPROBES */ + + RID light_probe_buffer; + + if (probe_positions.size()) { + + light_probe_buffer = rd->storage_buffer_create(sizeof(float) * 4 * 9 * probe_positions.size()); + + if (p_step_function) { + p_step_function(0.7, TTR("Baking lightprobes"), p_bake_userdata, true); + } + + Vector<RD::Uniform> uniforms; + { + + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.binding = 0; + u.ids.push_back(light_probe_buffer); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 1; + u.ids.push_back(light_dest_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 2; + u.ids.push_back(light_primary_dynamic_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 3; + u.ids.push_back(light_environment_tex); + uniforms.push_back(u); + } + } + RID light_probe_uniform_set = rd->uniform_set_create(uniforms, compute_shader_light_probes, 1); + + switch (p_quality) { + case BAKE_QUALITY_LOW: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/low_quality_probe_ray_count"); + } break; + case BAKE_QUALITY_MEDIUM: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/medium_quality_probe_ray_count"); + } break; + case BAKE_QUALITY_HIGH: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/high_quality_probe_ray_count"); + } break; + case BAKE_QUALITY_ULTRA: { + push_constant.ray_count = GLOBAL_GET("rendering/gpu_lightmapper/quality/ultra_quality_probe_ray_count"); + } break; + } + + push_constant.atlas_size[0] = probe_positions.size(); + push_constant.ray_count = CLAMP(push_constant.ray_count, 16, 8192); + + int max_rays = GLOBAL_GET("rendering/gpu_lightmapper/performance/max_rays_per_probe_pass"); + int ray_iterations = (push_constant.ray_count - 1) / max_rays + 1; + + for (int i = 0; i < ray_iterations; i++) { + + RD::ComputeListID compute_list = rd->compute_list_begin(); + rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_light_probes_pipeline); + rd->compute_list_bind_uniform_set(compute_list, compute_base_uniform_set, 0); + rd->compute_list_bind_uniform_set(compute_list, light_probe_uniform_set, 1); + + push_constant.ray_from = i * max_rays; + push_constant.ray_to = MIN((i + 1) * max_rays, int32_t(push_constant.ray_count)); + rd->compute_list_set_push_constant(compute_list, &push_constant, sizeof(PushConstant)); + rd->compute_list_dispatch(compute_list, (probe_positions.size() - 1) / 64 + 1, 1, 1); + + rd->compute_list_end(); //done + rd->submit(); + rd->sync(); + + if (p_step_function) { + int percent = i * 100 / ray_iterations; + float p = float(i) / ray_iterations * 0.1; + p_step_function(0.7 + p, vformat(TTR("Integrating light probes %d%%"), percent), p_bake_userdata, false); + } + } + + push_constant.atlas_size[0] = atlas_size.x; //restore + } + +#if 0 + for (int i = 0; i < probe_positions.size(); i++) { + Ref<Image> img; + img.instance(); + img->create(6, 4, false, Image::FORMAT_RGB8); + for (int j = 0; j < 6; j++) { + Vector<uint8_t> s = rd->texture_get_data(lightprobe_tex, i * 6 + j); + Ref<Image> img2; + img2.instance(); + img2->create(2, 2, false, Image::FORMAT_RGBAF, s); + img2->convert(Image::FORMAT_RGB8); + img->blit_rect(img2, Rect2(0, 0, 2, 2), Point2((j % 3) * 2, (j / 3) * 2)); + } + img->save_png("res://3_light_probe_" + itos(i) + ".png"); + } +#endif + + /* DENOISE */ + + if (p_use_denoiser) { + if (p_step_function) { + p_step_function(0.8, TTR("Denoising"), p_bake_userdata, true); + } + + Ref<LightmapDenoiser> denoiser = LightmapDenoiser::create(); + if (denoiser.is_valid()) { + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + Vector<uint8_t> s = rd->texture_get_data(light_accum_tex, i); + Ref<Image> img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + + Ref<Image> denoised = denoiser->denoise_image(img); + if (denoised != img) { + denoised->convert(Image::FORMAT_RGBAH); + Vector<uint8_t> ds = denoised->get_data(); + denoised.unref(); //avoid copy on write + { //restore alpha + uint32_t count = s.size() / 2; //uint16s + const uint16_t *src = (const uint16_t *)s.ptr(); + uint16_t *dst = (uint16_t *)ds.ptrw(); + for (uint32_t j = 0; j < count; j += 4) { + dst[j + 3] = src[j + 3]; + } + } + rd->texture_update(light_accum_tex, i, ds, true); + } + } + } + } + +#ifdef DEBUG_TEXTURES + + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + Vector<uint8_t> s = rd->texture_get_data(light_accum_tex, i); + Ref<Image> img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://4_light_secondary_" + itos(i) + ".png"); + } +#endif + + /* DILATE LIGHTMAP */ + { + + SWAP(light_accum_tex, light_accum_tex2); + + Vector<RD::Uniform> uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_IMAGE; + u.binding = 0; + u.ids.push_back(light_accum_tex); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 1; + u.ids.push_back(light_accum_tex2); + uniforms.push_back(u); + } + } + + RID dilate_uniform_set = rd->uniform_set_create(uniforms, compute_shader_dilate, 1); + + RD::ComputeListID compute_list = rd->compute_list_begin(); + rd->compute_list_bind_compute_pipeline(compute_list, compute_shader_dilate_pipeline); + rd->compute_list_bind_uniform_set(compute_list, compute_base_uniform_set, 0); + rd->compute_list_bind_uniform_set(compute_list, dilate_uniform_set, 1); + push_constant.region_ofs[0] = 0; + push_constant.region_ofs[1] = 0; + group_size = Vector3i((atlas_size.x - 1) / 8 + 1, (atlas_size.y - 1) / 8 + 1, 1); //restore group size + + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + push_constant.atlas_slice = i; + rd->compute_list_set_push_constant(compute_list, &push_constant, sizeof(PushConstant)); + rd->compute_list_dispatch(compute_list, group_size.x, group_size.y, group_size.z); + //no barrier, let them run all together + } + rd->compute_list_end(); + } + +#ifdef DEBUG_TEXTURES + + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + Vector<uint8_t> s = rd->texture_get_data(light_accum_tex, i); + Ref<Image> img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://5_dilated_" + itos(i) + ".png"); + } +#endif + + /* BLEND SEAMS */ + //shaders + Ref<RDShaderFile> blendseams_shader; + blendseams_shader.instance(); + err = blendseams_shader->parse_versions_from_text(lm_blendseams_shader_glsl); + if (err != OK) { + FREE_TEXTURES + FREE_BUFFERS + FREE_RASTER_RESOURCES + FREE_COMPUTE_RESOURCES + memdelete(rd); + blendseams_shader->print_errors("blendseams_shader"); + } + ERR_FAIL_COND_V(err != OK, BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); + + RID blendseams_line_raster_shader = rd->shader_create_from_bytecode(blendseams_shader->get_bytecode("lines")); + + ERR_FAIL_COND_V(blendseams_line_raster_shader.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); + + RID blendseams_triangle_raster_shader = rd->shader_create_from_bytecode(blendseams_shader->get_bytecode("triangles")); + + ERR_FAIL_COND_V(blendseams_triangle_raster_shader.is_null(), BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES); + +#define FREE_BLENDSEAMS_RESOURCES \ + rd->free(blendseams_line_raster_shader); \ + rd->free(blendseams_triangle_raster_shader); + + { + + //pre copy + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + rd->texture_copy(light_accum_tex, light_accum_tex2, Vector3(), Vector3(), Vector3(atlas_size.width, atlas_size.height, 1), 0, 0, i, i, true); + } + + Vector<RID> framebuffers; + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + RID slice_tex = rd->texture_create_shared_from_slice(RD::TextureView(), light_accum_tex, i, 0); + Vector<RID> fb; + fb.push_back(slice_tex); + fb.push_back(raster_depth_buffer); + framebuffers.push_back(rd->framebuffer_create(fb)); + } + + Vector<RD::Uniform> uniforms; + { + { + + RD::Uniform u; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.binding = 0; + u.ids.push_back(light_accum_tex2); + uniforms.push_back(u); + } + } + + RID blendseams_raster_uniform = rd->uniform_set_create(uniforms, blendseams_line_raster_shader, 1); + + bool debug = false; + RD::PipelineColorBlendState bs = RD::PipelineColorBlendState::create_blend(1); + bs.attachments.write[0].src_alpha_blend_factor = RD::BLEND_FACTOR_ZERO; + bs.attachments.write[0].dst_alpha_blend_factor = RD::BLEND_FACTOR_ONE; + + RD::PipelineDepthStencilState ds; + ds.enable_depth_test = true; + ds.enable_depth_write = true; + ds.depth_compare_operator = RD::COMPARE_OP_LESS; //so it does not render same pixel twice, this avoids wrong blending + + RID blendseams_line_raster_pipeline = rd->render_pipeline_create(blendseams_line_raster_shader, rd->framebuffer_get_format(framebuffers[0]), RD::INVALID_FORMAT_ID, RD::RENDER_PRIMITIVE_LINES, RD::PipelineRasterizationState(), RD::PipelineMultisampleState(), ds, bs, 0); + RID blendseams_triangle_raster_pipeline = rd->render_pipeline_create(blendseams_triangle_raster_shader, rd->framebuffer_get_format(framebuffers[0]), RD::INVALID_FORMAT_ID, RD::RENDER_PRIMITIVE_TRIANGLES, RD::PipelineRasterizationState(), RD::PipelineMultisampleState(), ds, bs, 0); + + uint32_t seam_offset = 0; + uint32_t triangle_offset = 0; + + Vector<Color> clear_colors; + clear_colors.push_back(Color(0, 0, 0, 1)); + for (int i = 0; i < atlas_slices; i++) { + + int subslices = (p_bake_sh ? 4 : 1); + for (int k = 0; k < subslices; k++) { + + RasterSeamsPushConstant seams_push_constant; + seams_push_constant.slice = uint32_t(i * subslices + k); + seams_push_constant.debug = debug; + + RD::DrawListID draw_list = rd->draw_list_begin(framebuffers[i], RD::INITIAL_ACTION_KEEP, RD::FINAL_ACTION_READ, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_DISCARD, clear_colors); + + rd->draw_list_bind_uniform_set(draw_list, raster_base_uniform, 0); + rd->draw_list_bind_uniform_set(draw_list, blendseams_raster_uniform, 1); + + const int uv_offset_count = 9; + static const Vector3 uv_offsets[uv_offset_count] = { + Vector3(0, 0, 0.5), //using zbuffer, so go inwards-outwards + Vector3(0, 1, 0.2), + Vector3(0, -1, 0.2), + Vector3(1, 0, 0.2), + Vector3(-1, 0, 0.2), + Vector3(-1, -1, 0.1), + Vector3(1, -1, 0.1), + Vector3(1, 1, 0.1), + Vector3(-1, 1, 0.1), + }; + + /* step 1 use lines to blend the edges */ + { + seams_push_constant.base_index = seam_offset; + rd->draw_list_bind_render_pipeline(draw_list, blendseams_line_raster_pipeline); + seams_push_constant.uv_offset[0] = uv_offsets[0].x / float(atlas_size.width); + seams_push_constant.uv_offset[1] = uv_offsets[0].y / float(atlas_size.height); + seams_push_constant.blend = uv_offsets[0].z; + + rd->draw_list_set_push_constant(draw_list, &seams_push_constant, sizeof(RasterSeamsPushConstant)); + rd->draw_list_draw(draw_list, false, 1, slice_seam_count[i] * 4); + } + + /* step 2 use triangles to mask the interior */ + + { + seams_push_constant.base_index = triangle_offset; + rd->draw_list_bind_render_pipeline(draw_list, blendseams_triangle_raster_pipeline); + seams_push_constant.blend = 0; //do not draw them, just fill the z-buffer so its used as a mask + + rd->draw_list_set_push_constant(draw_list, &seams_push_constant, sizeof(RasterSeamsPushConstant)); + rd->draw_list_draw(draw_list, false, 1, slice_triangle_count[i] * 3); + } + /* step 3 blend around the triangle */ + + rd->draw_list_bind_render_pipeline(draw_list, blendseams_line_raster_pipeline); + + for (int j = 1; j < uv_offset_count; j++) { + + seams_push_constant.base_index = seam_offset; + seams_push_constant.uv_offset[0] = uv_offsets[j].x / float(atlas_size.width); + seams_push_constant.uv_offset[1] = uv_offsets[j].y / float(atlas_size.height); + seams_push_constant.blend = uv_offsets[0].z; + + rd->draw_list_set_push_constant(draw_list, &seams_push_constant, sizeof(RasterSeamsPushConstant)); + rd->draw_list_draw(draw_list, false, 1, slice_seam_count[i] * 4); + } + rd->draw_list_end(); + } + seam_offset += slice_seam_count[i]; + triangle_offset += slice_triangle_count[i]; + } + } + +#ifdef DEBUG_TEXTURES + + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + Vector<uint8_t> s = rd->texture_get_data(light_accum_tex, i); + Ref<Image> img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBA8); + img->save_png("res://5_blendseams" + itos(i) + ".png"); + } +#endif + if (p_step_function) { + p_step_function(0.9, TTR("Retrieving textures"), p_bake_userdata, true); + } + + for (int i = 0; i < atlas_slices * (p_bake_sh ? 4 : 1); i++) { + Vector<uint8_t> s = rd->texture_get_data(light_accum_tex, i); + Ref<Image> img; + img.instance(); + img->create(atlas_size.width, atlas_size.height, false, Image::FORMAT_RGBAH, s); + img->convert(Image::FORMAT_RGBH); //remove alpha + bake_textures.push_back(img); + } + + if (probe_positions.size() > 0) { + probe_values.resize(probe_positions.size() * 9); + Vector<uint8_t> probe_data = rd->buffer_get_data(light_probe_buffer); + copymem(probe_values.ptrw(), probe_data.ptr(), probe_data.size()); + rd->free(light_probe_buffer); + +#ifdef DEBUG_TEXTURES + { + Ref<Image> img2; + img2.instance(); + img2->create(probe_values.size(), 1, false, Image::FORMAT_RGBAF, probe_data); + img2->save_png("res://6_lightprobes.png"); + } +#endif + } + + FREE_TEXTURES + FREE_BUFFERS + FREE_RASTER_RESOURCES + FREE_COMPUTE_RESOURCES + FREE_BLENDSEAMS_RESOURCES + + memdelete(rd); + + return BAKE_OK; +} + +int LightmapperRD::get_bake_texture_count() const { + return bake_textures.size(); +} +Ref<Image> LightmapperRD::get_bake_texture(int p_index) const { + ERR_FAIL_INDEX_V(p_index, bake_textures.size(), Ref<Image>()); + return bake_textures[p_index]; +} +int LightmapperRD::get_bake_mesh_count() const { + return mesh_instances.size(); +} +Variant LightmapperRD::get_bake_mesh_userdata(int p_index) const { + ERR_FAIL_INDEX_V(p_index, mesh_instances.size(), Variant()); + return mesh_instances[p_index].data.userdata; +} +Rect2 LightmapperRD::get_bake_mesh_uv_scale(int p_index) const { + + ERR_FAIL_COND_V(bake_textures.size() == 0, Rect2()); + Rect2 uv_ofs; + Vector2 atlas_size = Vector2(bake_textures[0]->get_width(), bake_textures[0]->get_height()); + uv_ofs.position = Vector2(mesh_instances[p_index].offset) / atlas_size; + uv_ofs.size = Vector2(mesh_instances[p_index].data.albedo_on_uv2->get_width(), mesh_instances[p_index].data.albedo_on_uv2->get_height()) / atlas_size; + return uv_ofs; +} +int LightmapperRD::get_bake_mesh_texture_slice(int p_index) const { + ERR_FAIL_INDEX_V(p_index, mesh_instances.size(), Variant()); + return mesh_instances[p_index].slice; +} + +int LightmapperRD::get_bake_probe_count() const { + return probe_positions.size(); +} + +Vector3 LightmapperRD::get_bake_probe_point(int p_probe) const { + ERR_FAIL_INDEX_V(p_probe, probe_positions.size(), Variant()); + return Vector3(probe_positions[p_probe].position[0], probe_positions[p_probe].position[1], probe_positions[p_probe].position[2]); +} + +Vector<Color> LightmapperRD::get_bake_probe_sh(int p_probe) const { + ERR_FAIL_INDEX_V(p_probe, probe_positions.size(), Vector<Color>()); + Vector<Color> ret; + ret.resize(9); + copymem(ret.ptrw(), &probe_values[p_probe * 9], sizeof(Color) * 9); + return ret; +} + +LightmapperRD::LightmapperRD() { +} diff --git a/modules/lightmapper_rd/lightmapper_rd.h b/modules/lightmapper_rd/lightmapper_rd.h new file mode 100644 index 0000000000..cb98efbeaa --- /dev/null +++ b/modules/lightmapper_rd/lightmapper_rd.h @@ -0,0 +1,229 @@ +#ifndef LIGHTMAPPER_RD_H +#define LIGHTMAPPER_RD_H + +#include "core/local_vector.h" +#include "scene/3d/lightmapper.h" +#include "scene/resources/mesh.h" +#include "servers/rendering/rendering_device.h" + +class LightmapperRD : public Lightmapper { + GDCLASS(LightmapperRD, Lightmapper) + + struct MeshInstance { + MeshData data; + int slice = 0; + Vector2i offset; + }; + + struct Light { + float position[3]; + uint32_t type = LIGHT_TYPE_DIRECTIONAL; + float direction[3]; + float energy; + float color[3]; + float size; + float range; + float attenuation; + float spot_angle; + float spot_attenuation; + uint32_t static_bake; + uint32_t pad[3]; + + bool operator<(const Light &p_light) const { + return type < p_light.type; + } + }; + + struct Vertex { + float position[3]; + float normal_z; + float uv[2]; + float normal_xy[2]; + + bool operator==(const Vertex &p_vtx) const { + return (position[0] == p_vtx.position[0]) && + (position[1] == p_vtx.position[1]) && + (position[2] == p_vtx.position[2]) && + (uv[0] == p_vtx.uv[0]) && + (uv[1] == p_vtx.uv[1]) && + (normal_xy[0] == p_vtx.normal_xy[0]) && + (normal_xy[1] == p_vtx.normal_xy[1]) && + (normal_z == p_vtx.normal_z); + } + }; + + struct Edge { + Vector3 a; + Vector3 b; + Vector3 na; + Vector3 nb; + bool operator==(const Edge &p_seam) const { + return a == p_seam.a && b == p_seam.b && na == p_seam.na && nb == p_seam.nb; + } + Edge() { + } + + Edge(const Vector3 &p_a, const Vector3 &p_b, const Vector3 &p_na, const Vector3 &p_nb) { + a = p_a; + b = p_b; + na = p_na; + nb = p_nb; + } + }; + + struct Probe { + float position[4]; + }; + + Vector<Probe> probe_positions; + + struct EdgeHash { + _FORCE_INLINE_ static uint32_t hash(const Edge &p_edge) { + uint32_t h = hash_djb2_one_float(p_edge.a.x); + h = hash_djb2_one_float(p_edge.a.y, h); + h = hash_djb2_one_float(p_edge.a.z, h); + h = hash_djb2_one_float(p_edge.b.x, h); + h = hash_djb2_one_float(p_edge.b.y, h); + h = hash_djb2_one_float(p_edge.b.z, h); + return h; + } + }; + struct EdgeUV2 { + Vector2 a; + Vector2 b; + Vector2i indices; + bool operator==(const EdgeUV2 &p_uv2) const { + return a == p_uv2.a && b == p_uv2.b; + } + bool seam_found = false; + EdgeUV2(Vector2 p_a, Vector2 p_b, Vector2i p_indices) { + a = p_a; + b = p_b; + indices = p_indices; + } + EdgeUV2() {} + }; + + struct Seam { + Vector2i a; + Vector2i b; + uint32_t slice; + bool operator<(const Seam &p_seam) const { + return slice < p_seam.slice; + } + }; + + struct VertexHash { + _FORCE_INLINE_ static uint32_t hash(const Vertex &p_vtx) { + uint32_t h = hash_djb2_one_float(p_vtx.position[0]); + h = hash_djb2_one_float(p_vtx.position[1], h); + h = hash_djb2_one_float(p_vtx.position[2], h); + h = hash_djb2_one_float(p_vtx.uv[0], h); + h = hash_djb2_one_float(p_vtx.uv[1], h); + h = hash_djb2_one_float(p_vtx.normal_xy[0], h); + h = hash_djb2_one_float(p_vtx.normal_xy[1], h); + h = hash_djb2_one_float(p_vtx.normal_z, h); + return h; + } + }; + + struct Box { + float min_bounds[3]; + float pad0; + float max_bounds[3]; + float pad1; + }; + + struct Triangle { + uint32_t indices[3]; + uint32_t slice; + bool operator<(const Triangle &p_triangle) const { + return slice < p_triangle.slice; + } + }; + + Vector<MeshInstance> mesh_instances; + + Vector<Light> lights; + + struct TriangleSort { + uint32_t cell_index; + uint32_t triangle_index; + bool operator<(const TriangleSort &p_triangle_sort) const { + return cell_index < p_triangle_sort.cell_index; //sorting by triangle index in this case makes no sense + } + }; + + void _plot_triangle_into_triangle_index_list(int p_size, const Vector3i &p_ofs, const AABB &p_bounds, const Vector3 p_points[], uint32_t p_triangle_index, LocalVector<TriangleSort> &triangles, uint32_t p_grid_size); + + struct RasterPushConstant { + float atlas_size[2]; + float uv_offset[2]; + float to_cell_size[3]; + uint32_t base_triangle; + float to_cell_offset[3]; + float bias; + int32_t grid_size[3]; + uint32_t pad2; + }; + + struct RasterSeamsPushConstant { + + uint32_t base_index; + uint32_t slice; + float uv_offset[2]; + uint32_t debug; + float blend; + uint32_t pad[2]; + }; + + struct PushConstant { + int32_t atlas_size[2]; + uint32_t ray_count; + uint32_t ray_to; + + float world_size[3]; + float bias; + + float to_cell_offset[3]; + uint32_t ray_from; + + float to_cell_size[3]; + uint32_t light_count; + + int32_t grid_size; + int32_t atlas_slice; + int32_t region_ofs[2]; + + float environment_xform[12]; + }; + + Vector<Ref<Image>> bake_textures; + Vector<Color> probe_values; + + BakeError _blit_meshes_into_atlas(int p_max_texture_size, Vector<Ref<Image>> &albedo_images, Vector<Ref<Image>> &emission_images, AABB &bounds, Size2i &atlas_size, int &atlas_slices, BakeStepFunc p_step_function, void *p_bake_userdata); + void _create_acceleration_structures(RenderingDevice *rd, Size2i atlas_size, int atlas_slices, AABB &bounds, int grid_size, Vector<Probe> &probe_positions, GenerateProbes p_generate_probes, Vector<int> &slice_triangle_count, Vector<int> &slice_seam_count, RID &vertex_buffer, RID &triangle_buffer, RID &box_buffer, RID &lights_buffer, RID &triangle_cell_indices_buffer, RID &probe_positions_buffer, RID &grid_texture, RID &grid_texture_sdf, RID &seams_buffer, BakeStepFunc p_step_function, void *p_bake_userdata); + void _raster_geometry(RenderingDevice *rd, Size2i atlas_size, int atlas_slices, int grid_size, AABB bounds, float p_bias, Vector<int> slice_triangle_count, RID position_tex, RID unocclude_tex, RID normal_tex, RID raster_depth_buffer, RID rasterize_shader, RID raster_base_uniform); + +public: + virtual void add_mesh(const MeshData &p_mesh); + virtual void add_directional_light(bool p_static, const Vector3 &p_direction, const Color &p_color, float p_energy, float p_angular_distance); + virtual void add_omni_light(bool p_static, const Vector3 &p_position, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_size); + virtual void add_spot_light(bool p_static, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation, float p_size); + virtual void add_probe(const Vector3 &p_position); + virtual BakeError bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref<Image> &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function = nullptr, void *p_bake_userdata = nullptr); + + int get_bake_texture_count() const; + Ref<Image> get_bake_texture(int p_index) const; + int get_bake_mesh_count() const; + Variant get_bake_mesh_userdata(int p_index) const; + Rect2 get_bake_mesh_uv_scale(int p_index) const; + int get_bake_mesh_texture_slice(int p_index) const; + int get_bake_probe_count() const; + Vector3 get_bake_probe_point(int p_probe) const; + Vector<Color> get_bake_probe_sh(int p_probe) const; + + LightmapperRD(); +}; + +#endif // LIGHTMAPPER_H diff --git a/modules/lightmapper_rd/lm_blendseams.glsl b/modules/lightmapper_rd/lm_blendseams.glsl new file mode 100644 index 0000000000..ef1ece8ea1 --- /dev/null +++ b/modules/lightmapper_rd/lm_blendseams.glsl @@ -0,0 +1,117 @@ +/* clang-format off */ +[versions] + +lines = "#define MODE_LINES" +triangles = "#define MODE_TRIANGLES" + +[vertex] + +#version 450 + +VERSION_DEFINES + +#include "lm_common_inc.glsl" + + /* clang-format on */ + + layout(push_constant, binding = 0, std430) uniform Params { + uint base_index; + uint slice; + vec2 uv_offset; + bool debug; + float blend; + uint pad[2]; + } params; + +layout(location = 0) out vec3 uv_interp; + +void main() { + +#ifdef MODE_TRIANGLES + + uint triangle_idx = params.base_index + gl_VertexIndex / 3; + uint triangle_subidx = gl_VertexIndex % 3; + + vec2 uv; + if (triangle_subidx == 0) { + uv = vertices.data[triangles.data[triangle_idx].indices.x].uv; + } else if (triangle_subidx == 1) { + uv = vertices.data[triangles.data[triangle_idx].indices.y].uv; + } else { + uv = vertices.data[triangles.data[triangle_idx].indices.z].uv; + } + + uv_interp = vec3(uv, float(params.slice)); + gl_Position = vec4((uv + params.uv_offset) * 2.0 - 1.0, 0.0001, 1.0); + +#endif + +#ifdef MODE_LINES + uint seam_idx = params.base_index + gl_VertexIndex / 4; + uint seam_subidx = gl_VertexIndex % 4; + + uint src_idx; + uint dst_idx; + + if (seam_subidx == 0) { + src_idx = seams.data[seam_idx].b.x; + dst_idx = seams.data[seam_idx].a.x; + } else if (seam_subidx == 1) { + src_idx = seams.data[seam_idx].b.y; + dst_idx = seams.data[seam_idx].a.y; + } else if (seam_subidx == 2) { + src_idx = seams.data[seam_idx].a.x; + dst_idx = seams.data[seam_idx].b.x; + } else if (seam_subidx == 3) { + src_idx = seams.data[seam_idx].a.y; + dst_idx = seams.data[seam_idx].b.y; + } + + vec2 src_uv = vertices.data[src_idx].uv; + vec2 dst_uv = vertices.data[dst_idx].uv + params.uv_offset; + + uv_interp = vec3(src_uv, float(params.slice)); + gl_Position = vec4(dst_uv * 2.0 - 1.0, 0.0001, 1.0); + ; +#endif +} + +/* clang-format off */ +[fragment] + +#version 450 + +VERSION_DEFINES + +#include "lm_common_inc.glsl" + + /* clang-format on */ + + layout(push_constant, binding = 0, std430) uniform Params { + uint base_index; + uint slice; + vec2 uv_offset; + bool debug; + float blend; + uint pad[2]; + } params; + +layout(location = 0) in vec3 uv_interp; + +layout(location = 0) out vec4 dst_color; + +layout(set = 1, binding = 0) uniform texture2DArray src_color_tex; + +void main() { + + if (params.debug) { +#ifdef MODE_TRIANGLES + dst_color = vec4(1, 0, 1, 1); +#else + dst_color = vec4(1, 1, 0, 1); +#endif + } else { + vec4 src_color = textureLod(sampler2DArray(src_color_tex, linear_sampler), uv_interp, 0.0); + dst_color = vec4(src_color.rgb, params.blend); //mix + } +} diff --git a/modules/lightmapper_rd/lm_common_inc.glsl b/modules/lightmapper_rd/lm_common_inc.glsl new file mode 100644 index 0000000000..0ff455936e --- /dev/null +++ b/modules/lightmapper_rd/lm_common_inc.glsl @@ -0,0 +1,92 @@ + +/* SET 0, static data that does not change between any call */ + +struct Vertex { + vec3 position; + float normal_z; + vec2 uv; + vec2 normal_xy; +}; + +layout(set = 0, binding = 1, std430) restrict readonly buffer Vertices { + Vertex data[]; +} +vertices; + +struct Triangle { + uvec3 indices; + uint slice; +}; + +layout(set = 0, binding = 2, std430) restrict readonly buffer Triangles { + Triangle data[]; +} +triangles; + +struct Box { + vec3 min_bounds; + uint pad0; + vec3 max_bounds; + uint pad1; +}; + +layout(set = 0, binding = 3, std430) restrict readonly buffer Boxes { + Box data[]; +} +boxes; + +layout(set = 0, binding = 4, std430) restrict readonly buffer GridIndices { + uint data[]; +} +grid_indices; + +#define LIGHT_TYPE_DIRECTIONAL 0 +#define LIGHT_TYPE_OMNI 1 +#define LIGHT_TYPE_SPOT 2 + +struct Light { + vec3 position; + uint type; + + vec3 direction; + float energy; + + vec3 color; + float size; + + float range; + float attenuation; + float spot_angle; + float spot_attenuation; + + bool static_bake; + uint pad[3]; +}; + +layout(set = 0, binding = 5, std430) restrict readonly buffer Lights { + Light data[]; +} +lights; + +struct Seam { + uvec2 a; + uvec2 b; +}; + +layout(set = 0, binding = 6, std430) restrict readonly buffer Seams { + Seam data[]; +} +seams; + +layout(set = 0, binding = 7, std430) restrict readonly buffer Probes { + vec4 data[]; +} +probe_positions; + +layout(set = 0, binding = 8) uniform utexture3D grid; +layout(set = 0, binding = 9) uniform texture3D grid_sdf; + +layout(set = 0, binding = 10) uniform texture2DArray albedo_tex; +layout(set = 0, binding = 11) uniform texture2DArray emission_tex; + +layout(set = 0, binding = 12) uniform sampler linear_sampler; diff --git a/modules/lightmapper_rd/lm_compute.glsl b/modules/lightmapper_rd/lm_compute.glsl new file mode 100644 index 0000000000..a178bd9b2e --- /dev/null +++ b/modules/lightmapper_rd/lm_compute.glsl @@ -0,0 +1,657 @@ +/* clang-format off */ +[versions] + +primary = "#define MODE_DIRECT_LIGHT" +secondary = "#define MODE_BOUNCE_LIGHT" +dilate = "#define MODE_DILATE" +unocclude = "#define MODE_UNOCCLUDE" +light_probes = "#define MODE_LIGHT_PROBES" + +[compute] + +#version 450 + +VERSION_DEFINES + +// One 2D local group focusing in one layer at a time, though all +// in parallel (no barriers) makes more sense than a 3D local group +// as this can take more advantage of the cache for each group. + +#ifdef MODE_LIGHT_PROBES + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +#else + +layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in; + +#endif + +#include "lm_common_inc.glsl" + +/* clang-format on */ + +#ifdef MODE_LIGHT_PROBES + +layout(set = 1, binding = 0, std430) restrict buffer LightProbeData { + vec4 data[]; +} +light_probes; + +layout(set = 1, binding = 1) uniform texture2DArray source_light; +layout(set = 1, binding = 2) uniform texture2DArray source_direct_light; //also need the direct light, which was omitted +layout(set = 1, binding = 3) uniform texture2D environment; +#endif + +#ifdef MODE_UNOCCLUDE + +layout(rgba32f, set = 1, binding = 0) uniform restrict image2DArray position; +layout(rgba32f, set = 1, binding = 1) uniform restrict readonly image2DArray unocclude; + +#endif + +#if defined(MODE_DIRECT_LIGHT) || defined(MODE_BOUNCE_LIGHT) + +layout(rgba16f, set = 1, binding = 0) uniform restrict writeonly image2DArray dest_light; +layout(set = 1, binding = 1) uniform texture2DArray source_light; +layout(set = 1, binding = 2) uniform texture2DArray source_position; +layout(set = 1, binding = 3) uniform texture2DArray source_normal; +layout(rgba16f, set = 1, binding = 4) uniform restrict image2DArray accum_light; + +#endif + +#ifdef MODE_BOUNCE_LIGHT +layout(rgba32f, set = 1, binding = 5) uniform restrict image2DArray bounce_accum; +layout(set = 1, binding = 6) uniform texture2D environment; +#endif +#ifdef MODE_DIRECT_LIGHT +layout(rgba32f, set = 1, binding = 5) uniform restrict writeonly image2DArray primary_dynamic; +#endif + +#ifdef MODE_DILATE +layout(rgba16f, set = 1, binding = 0) uniform restrict writeonly image2DArray dest_light; +layout(set = 1, binding = 1) uniform texture2DArray source_light; +#endif + +layout(push_constant, binding = 0, std430) uniform Params { + ivec2 atlas_size; // x used for light probe mode total probes + uint ray_count; + uint ray_to; + + vec3 world_size; + float bias; + + vec3 to_cell_offset; + uint ray_from; + + vec3 to_cell_size; + uint light_count; + + int grid_size; + int atlas_slice; + ivec2 region_ofs; + + mat3x4 env_transform; +} +params; + +//check it, but also return distance and barycentric coords (for uv lookup) +bool ray_hits_triangle(vec3 from, vec3 dir, float max_dist, vec3 p0, vec3 p1, vec3 p2, out float r_distance, out vec3 r_barycentric) { + + const vec3 e0 = p1 - p0; + const vec3 e1 = p0 - p2; + vec3 triangleNormal = cross(e1, e0); + + const vec3 e2 = (1.0 / dot(triangleNormal, dir)) * (p0 - from); + const vec3 i = cross(dir, e2); + + r_barycentric.y = dot(i, e1); + r_barycentric.z = dot(i, e0); + r_barycentric.x = 1.0 - (r_barycentric.z + r_barycentric.y); + r_distance = dot(triangleNormal, e2); + return (r_distance > params.bias) && (r_distance < max_dist) && all(greaterThanEqual(r_barycentric, vec3(0.0))); +} + +bool trace_ray(vec3 p_from, vec3 p_to +#if defined(MODE_BOUNCE_LIGHT) || defined(MODE_LIGHT_PROBES) + , + out uint r_triangle, out vec3 r_barycentric +#endif +#if defined(MODE_UNOCCLUDE) + , + out float r_distance, out vec3 r_normal +#endif +) { + + /* world coords */ + + vec3 rel = p_to - p_from; + float rel_len = length(rel); + vec3 dir = normalize(rel); + vec3 inv_dir = 1.0 / dir; + + /* cell coords */ + + vec3 from_cell = (p_from - params.to_cell_offset) * params.to_cell_size; + vec3 to_cell = (p_to - params.to_cell_offset) * params.to_cell_size; + + //prepare DDA + vec3 rel_cell = to_cell - from_cell; + ivec3 icell = ivec3(from_cell); + ivec3 iendcell = ivec3(to_cell); + vec3 dir_cell = normalize(rel_cell); + vec3 delta = abs(1.0 / dir_cell); //vec3(length(rel_cell)) / rel_cell); + ivec3 step = ivec3(sign(rel_cell)); + vec3 side = (sign(rel_cell) * (vec3(icell) - from_cell) + (sign(rel_cell) * 0.5) + 0.5) * delta; + + uint iters = 0; + while (all(greaterThanEqual(icell, ivec3(0))) && all(lessThan(icell, ivec3(params.grid_size))) && iters < 1000) { + + uvec2 cell_data = texelFetch(usampler3D(grid, linear_sampler), icell, 0).xy; + if (cell_data.x > 0) { //triangles here + + bool hit = false; +#if defined(MODE_UNOCCLUDE) + bool hit_backface = false; +#endif + float best_distance = 1e20; + + for (uint i = 0; i < cell_data.x; i++) { + uint tidx = grid_indices.data[cell_data.y + i]; + + //Ray-Box test + vec3 t0 = (boxes.data[tidx].min_bounds - p_from) * inv_dir; + vec3 t1 = (boxes.data[tidx].max_bounds - p_from) * inv_dir; + vec3 tmin = min(t0, t1), tmax = max(t0, t1); + + if (max(tmin.x, max(tmin.y, tmin.z)) <= min(tmax.x, min(tmax.y, tmax.z))) { + continue; //ray box failed + } + + //prepare triangle vertices + vec3 vtx0 = vertices.data[triangles.data[tidx].indices.x].position; + vec3 vtx1 = vertices.data[triangles.data[tidx].indices.y].position; + vec3 vtx2 = vertices.data[triangles.data[tidx].indices.z].position; +#if defined(MODE_UNOCCLUDE) + vec3 normal = -normalize(cross((vtx0 - vtx1), (vtx0 - vtx2))); + + bool backface = dot(normal, dir) >= 0.0; +#endif + float distance; + vec3 barycentric; + + if (ray_hits_triangle(p_from, dir, rel_len, vtx0, vtx1, vtx2, distance, barycentric)) { +#ifdef MODE_DIRECT_LIGHT + return true; //any hit good +#endif + +#if defined(MODE_UNOCCLUDE) + if (!backface) { + // the case of meshes having both a front and back face in the same plane is more common than + // expected, so if this is a front-face, bias it closer to the ray origin, so it always wins over the back-face + distance = max(params.bias, distance - params.bias); + } + + hit = true; + + if (distance < best_distance) { + hit_backface = backface; + best_distance = distance; + r_distance = distance; + r_normal = normal; + } + +#endif + +#if defined(MODE_BOUNCE_LIGHT) || defined(MODE_LIGHT_PROBES) + + hit = true; + if (distance < best_distance) { + best_distance = distance; + r_triangle = tidx; + r_barycentric = barycentric; + } + +#endif + } + } +#if defined(MODE_UNOCCLUDE) + + if (hit) { + return hit_backface; + } +#endif +#if defined(MODE_BOUNCE_LIGHT) || defined(MODE_LIGHT_PROBES) + if (hit) { + return true; + } +#endif + } + + if (icell == iendcell) { + break; + } + + bvec3 mask = lessThanEqual(side.xyz, min(side.yzx, side.zxy)); + side += vec3(mask) * delta; + icell += ivec3(vec3(mask)) * step; + + iters++; + } + + return false; +} + +const float PI = 3.14159265f; +const float GOLDEN_ANGLE = PI * (3.0 - sqrt(5.0)); + +vec3 vogel_hemisphere(uint p_index, uint p_count, float p_offset) { + float r = sqrt(float(p_index) + 0.5f) / sqrt(float(p_count)); + float theta = float(p_index) * GOLDEN_ANGLE + p_offset; + float y = cos(r * PI * 0.5); + float l = sin(r * PI * 0.5); + return vec3(l * cos(theta), l * sin(theta), y); +} + +float quick_hash(vec2 pos) { + return fract(sin(dot(pos * 19.19, vec2(49.5791, 97.413))) * 49831.189237); +} + +void main() { + +#ifdef MODE_LIGHT_PROBES + int probe_index = int(gl_GlobalInvocationID.x); + if (probe_index >= params.atlas_size.x) { //too large, do nothing + return; + } + +#else + ivec2 atlas_pos = ivec2(gl_GlobalInvocationID.xy) + params.region_ofs; + if (any(greaterThanEqual(atlas_pos, params.atlas_size))) { //too large, do nothing + return; + } +#endif + +#ifdef MODE_DIRECT_LIGHT + + vec3 normal = texelFetch(sampler2DArray(source_normal, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).xyz; + if (length(normal) < 0.5) { + return; //empty texel, no process + } + vec3 position = texelFetch(sampler2DArray(source_position, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).xyz; + + //go through all lights + //start by own light (emissive) + vec3 static_light = vec3(0.0); + vec3 dynamic_light = vec3(0.0); + +#ifdef USE_SH_LIGHTMAPS + vec4 sh_accum[4] = vec4[]( + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0)); +#endif + + for (uint i = 0; i < params.light_count; i++) { + + vec3 light_pos; + float attenuation; + if (lights.data[i].type == LIGHT_TYPE_DIRECTIONAL) { + vec3 light_vec = lights.data[i].direction; + light_pos = position - light_vec * length(params.world_size); + attenuation = 1.0; + } else { + light_pos = lights.data[i].position; + float d = distance(position, light_pos); + if (d > lights.data[i].range) { + continue; + } + + d /= lights.data[i].range; + + attenuation = pow(max(1.0 - d, 0.0), lights.data[i].attenuation); + + if (lights.data[i].type == LIGHT_TYPE_SPOT) { + + vec3 rel = normalize(position - light_pos); + float angle = acos(dot(rel, lights.data[i].direction)); + if (angle > lights.data[i].spot_angle) { + continue; //invisible, dont try + } + + float d = clamp(angle / lights.data[i].spot_angle, 0, 1); + attenuation *= pow(1.0 - d, lights.data[i].spot_attenuation); + } + } + + vec3 light_dir = normalize(light_pos - position); + attenuation *= max(0.0, dot(normal, light_dir)); + + if (attenuation <= 0.0001) { + continue; //no need to do anything + } + + if (!trace_ray(position + light_dir * params.bias, light_pos)) { + vec3 light = lights.data[i].color * lights.data[i].energy * attenuation; + if (lights.data[i].static_bake) { + static_light += light; +#ifdef USE_SH_LIGHTMAPS + + float c[4] = float[]( + 0.282095, //l0 + 0.488603 * light_dir.y, //l1n1 + 0.488603 * light_dir.z, //l1n0 + 0.488603 * light_dir.x //l1p1 + ); + + for (uint j = 0; j < 4; j++) { + sh_accum[j].rgb += light * c[j] * (1.0 / 3.0); + } +#endif + + } else { + dynamic_light += light; + } + } + } + + vec3 albedo = texelFetch(sampler2DArray(albedo_tex, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).rgb; + vec3 emissive = texelFetch(sampler2DArray(emission_tex, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).rgb; + + dynamic_light *= albedo; //if it will bounce, must multiply by albedo + dynamic_light += emissive; + + //keep for lightprobes + imageStore(primary_dynamic, ivec3(atlas_pos, params.atlas_slice), vec4(dynamic_light, 1.0)); + + dynamic_light += static_light * albedo; //send for bounces + imageStore(dest_light, ivec3(atlas_pos, params.atlas_slice), vec4(dynamic_light, 1.0)); + +#ifdef USE_SH_LIGHTMAPS + //keep for adding at the end + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + 0), sh_accum[0]); + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + 1), sh_accum[1]); + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + 2), sh_accum[2]); + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + 3), sh_accum[3]); + +#else + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice), vec4(static_light, 1.0)); +#endif + +#endif + +#ifdef MODE_BOUNCE_LIGHT + + vec3 normal = texelFetch(sampler2DArray(source_normal, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).xyz; + if (length(normal) < 0.5) { + return; //empty texel, no process + } + + vec3 position = texelFetch(sampler2DArray(source_position, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0).xyz; + + vec3 v0 = abs(normal.z) < 0.999 ? vec3(0.0, 0.0, 1.0) : vec3(0.0, 1.0, 0.0); + vec3 tangent = normalize(cross(v0, normal)); + vec3 bitangent = normalize(cross(tangent, normal)); + mat3 normal_mat = mat3(tangent, bitangent, normal); + +#ifdef USE_SH_LIGHTMAPS + vec4 sh_accum[4] = vec4[]( + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0), + vec4(0.0, 0.0, 0.0, 1.0)); +#endif + vec3 light_average = vec3(0.0); + for (uint i = params.ray_from; i < params.ray_to; i++) { + vec3 ray_dir = normal_mat * vogel_hemisphere(i, params.ray_count, quick_hash(vec2(atlas_pos))); + + uint tidx; + vec3 barycentric; + + vec3 light; + if (trace_ray(position + ray_dir * params.bias, position + ray_dir * length(params.world_size), tidx, barycentric)) { + //hit a triangle + vec2 uv0 = vertices.data[triangles.data[tidx].indices.x].uv; + vec2 uv1 = vertices.data[triangles.data[tidx].indices.y].uv; + vec2 uv2 = vertices.data[triangles.data[tidx].indices.z].uv; + vec3 uvw = vec3(barycentric.x * uv0 + barycentric.y * uv1 + barycentric.z * uv2, float(triangles.data[tidx].slice)); + + light = textureLod(sampler2DArray(source_light, linear_sampler), uvw, 0.0).rgb; + } else { + //did not hit a triangle, reach out for the sky + vec3 sky_dir = normalize(mat3(params.env_transform) * ray_dir); + + vec2 st = vec2( + atan(sky_dir.x, sky_dir.z), + acos(sky_dir.y)); + + if (st.x < 0.0) + st.x += PI * 2.0; + + st /= vec2(PI * 2.0, PI); + + light = textureLod(sampler2D(environment, linear_sampler), st, 0.0).rgb; + } + + light_average += light; + +#ifdef USE_SH_LIGHTMAPS + + float c[4] = float[]( + 0.282095, //l0 + 0.488603 * ray_dir.y, //l1n1 + 0.488603 * ray_dir.z, //l1n0 + 0.488603 * ray_dir.x //l1p1 + ); + + for (uint j = 0; j < 4; j++) { + sh_accum[j].rgb += light * c[j] * (8.0 / float(params.ray_count)); + } +#endif + } + + vec3 light_total; + if (params.ray_from == 0) { + light_total = vec3(0.0); + } else { + light_total = imageLoad(bounce_accum, ivec3(atlas_pos, params.atlas_slice)).rgb; + } + + light_total += light_average; + +#ifdef USE_SH_LIGHTMAPS + + for (int i = 0; i < 4; i++) { + vec4 accum = imageLoad(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + i)); + accum.rgb += sh_accum[i].rgb; + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice * 4 + i), accum); + } + +#endif + if (params.ray_to == params.ray_count) { + light_total /= float(params.ray_count); + imageStore(dest_light, ivec3(atlas_pos, params.atlas_slice), vec4(light_total, 1.0)); +#ifndef USE_SH_LIGHTMAPS + vec4 accum = imageLoad(accum_light, ivec3(atlas_pos, params.atlas_slice)); + accum.rgb += light_total; + imageStore(accum_light, ivec3(atlas_pos, params.atlas_slice), accum); +#endif + } else { + imageStore(bounce_accum, ivec3(atlas_pos, params.atlas_slice), vec4(light_total, 1.0)); + } + +#endif + +#ifdef MODE_UNOCCLUDE + + //texel_size = 0.5; + //compute tangents + + vec4 position_alpha = imageLoad(position, ivec3(atlas_pos, params.atlas_slice)); + if (position_alpha.a < 0.5) { + return; + } + + vec3 vertex_pos = position_alpha.xyz; + vec4 normal_tsize = imageLoad(unocclude, ivec3(atlas_pos, params.atlas_slice)); + + vec3 face_normal = normal_tsize.xyz; + float texel_size = normal_tsize.w; + + vec3 v0 = abs(face_normal.z) < 0.999 ? vec3(0.0, 0.0, 1.0) : vec3(0.0, 1.0, 0.0); + vec3 tangent = normalize(cross(v0, face_normal)); + vec3 bitangent = normalize(cross(tangent, face_normal)); + vec3 base_pos = vertex_pos + face_normal * params.bias; //raise a bit + + vec3 rays[4] = vec3[](tangent, bitangent, -tangent, -bitangent); + float min_d = 1e20; + for (int i = 0; i < 4; i++) { + vec3 ray_to = base_pos + rays[i] * texel_size; + float d; + vec3 norm; + + if (trace_ray(base_pos, ray_to, d, norm)) { + + if (d < min_d) { + vertex_pos = base_pos + rays[i] * d + norm * params.bias * 10.0; //this bias needs to be greater than the regular bias, because otherwise later, rays will go the other side when pointing back. + min_d = d; + } + } + } + + position_alpha.xyz = vertex_pos; + + imageStore(position, ivec3(atlas_pos, params.atlas_slice), position_alpha); + +#endif + +#ifdef MODE_LIGHT_PROBES + + vec3 position = probe_positions.data[probe_index].xyz; + + vec4 probe_sh_accum[9] = vec4[]( + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0), + vec4(0.0)); + + for (uint i = params.ray_from; i < params.ray_to; i++) { + vec3 ray_dir = vogel_hemisphere(i, params.ray_count, quick_hash(vec2(float(probe_index), 0.0))); + if (bool(i & 1)) { + //throw to both sides, so alternate them + ray_dir.z *= -1.0; + } + + uint tidx; + vec3 barycentric; + vec3 light; + + if (trace_ray(position + ray_dir * params.bias, position + ray_dir * length(params.world_size), tidx, barycentric)) { + vec2 uv0 = vertices.data[triangles.data[tidx].indices.x].uv; + vec2 uv1 = vertices.data[triangles.data[tidx].indices.y].uv; + vec2 uv2 = vertices.data[triangles.data[tidx].indices.z].uv; + vec3 uvw = vec3(barycentric.x * uv0 + barycentric.y * uv1 + barycentric.z * uv2, float(triangles.data[tidx].slice)); + + light = textureLod(sampler2DArray(source_light, linear_sampler), uvw, 0.0).rgb; + light += textureLod(sampler2DArray(source_direct_light, linear_sampler), uvw, 0.0).rgb; + } else { + + //did not hit a triangle, reach out for the sky + vec3 sky_dir = normalize(mat3(params.env_transform) * ray_dir); + + vec2 st = vec2( + atan(sky_dir.x, sky_dir.z), + acos(sky_dir.y)); + + if (st.x < 0.0) + st.x += PI * 2.0; + + st /= vec2(PI * 2.0, PI); + + light = textureLod(sampler2D(environment, linear_sampler), st, 0.0).rgb; + } + + { + float c[9] = float[]( + 0.282095, //l0 + 0.488603 * ray_dir.y, //l1n1 + 0.488603 * ray_dir.z, //l1n0 + 0.488603 * ray_dir.x, //l1p1 + 1.092548 * ray_dir.x * ray_dir.y, //l2n2 + 1.092548 * ray_dir.y * ray_dir.z, //l2n1 + //0.315392 * (ray_dir.x * ray_dir.x + ray_dir.y * ray_dir.y + 2.0 * ray_dir.z * ray_dir.z), //l20 + 0.315392 * (3.0 * ray_dir.z * ray_dir.z - 1.0), //l20 + 1.092548 * ray_dir.x * ray_dir.z, //l2p1 + 0.546274 * (ray_dir.x * ray_dir.x - ray_dir.y * ray_dir.y) //l2p2 + ); + + for (uint j = 0; j < 9; j++) { + probe_sh_accum[j].rgb += light * c[j]; + } + } + } + + if (params.ray_from > 0) { + for (uint j = 0; j < 9; j++) { //accum from existing + probe_sh_accum[j] += light_probes.data[probe_index * 9 + j]; + } + } + + if (params.ray_to == params.ray_count) { + for (uint j = 0; j < 9; j++) { //accum from existing + probe_sh_accum[j] *= 4.0 / float(params.ray_count); + } + } + + for (uint j = 0; j < 9; j++) { //accum from existing + light_probes.data[probe_index * 9 + j] = probe_sh_accum[j]; + } + +#endif + +#ifdef MODE_DILATE + + vec4 c = texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos, params.atlas_slice), 0); + //sides first, as they are closer + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-1, 0), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(0, 1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(1, 0), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(0, -1), params.atlas_slice), 0); + //endpoints second + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-1, -1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-1, 1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(1, -1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(1, 1), params.atlas_slice), 0); + + //far sides third + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-2, 0), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(0, 2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(2, 0), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(0, -2), params.atlas_slice), 0); + + //far-mid endpoints + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-2, -1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-2, 1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(2, -1), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(2, 1), params.atlas_slice), 0); + + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-1, -2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-1, 2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(1, -2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(1, 2), params.atlas_slice), 0); + //far endpoints + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-2, -2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(-2, 2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(2, -2), params.atlas_slice), 0); + c = c.a > 0.5 ? c : texelFetch(sampler2DArray(source_light, linear_sampler), ivec3(atlas_pos + ivec2(2, 2), params.atlas_slice), 0); + + imageStore(dest_light, ivec3(atlas_pos, params.atlas_slice), c); + +#endif +} diff --git a/modules/lightmapper_rd/lm_raster.glsl b/modules/lightmapper_rd/lm_raster.glsl new file mode 100644 index 0000000000..ae3038aead --- /dev/null +++ b/modules/lightmapper_rd/lm_raster.glsl @@ -0,0 +1,170 @@ +/* clang-format off */ +[vertex] + +#version 450 + +VERSION_DEFINES + +#include "lm_common_inc.glsl" + + /* clang-format on */ + + layout(location = 0) out vec3 vertex_interp; +layout(location = 1) out vec3 normal_interp; +layout(location = 2) out vec2 uv_interp; +layout(location = 3) out vec3 barycentric; +layout(location = 4) flat out uvec3 vertex_indices; +layout(location = 5) flat out vec3 face_normal; + +layout(push_constant, binding = 0, std430) uniform Params { + vec2 atlas_size; + vec2 uv_offset; + vec3 to_cell_size; + uint base_triangle; + vec3 to_cell_offset; + float bias; + ivec3 grid_size; + uint pad2; +} +params; + +/* clang-format on */ + +void main() { + + uint triangle_idx = params.base_triangle + gl_VertexIndex / 3; + uint triangle_subidx = gl_VertexIndex % 3; + + vertex_indices = triangles.data[triangle_idx].indices; + + uint vertex_idx; + if (triangle_subidx == 0) { + vertex_idx = vertex_indices.x; + barycentric = vec3(1, 0, 0); + } else if (triangle_subidx == 1) { + vertex_idx = vertex_indices.y; + barycentric = vec3(0, 1, 0); + } else { + vertex_idx = vertex_indices.z; + barycentric = vec3(0, 0, 1); + } + + vertex_interp = vertices.data[vertex_idx].position; + uv_interp = vertices.data[vertex_idx].uv; + normal_interp = vec3(vertices.data[vertex_idx].normal_xy, vertices.data[vertex_idx].normal_z); + + face_normal = -normalize(cross((vertices.data[vertex_indices.x].position - vertices.data[vertex_indices.y].position), (vertices.data[vertex_indices.x].position - vertices.data[vertex_indices.z].position))); + + gl_Position = vec4((uv_interp + params.uv_offset) * 2.0 - 1.0, 0.0001, 1.0); + ; +} + +/* clang-format off */ + +[fragment] + +#version 450 + +VERSION_DEFINES + +#include "lm_common_inc.glsl" + + +layout(push_constant, binding = 0, std430) uniform Params { + vec2 atlas_size; + vec2 uv_offset; + vec3 to_cell_size; + uint base_triangle; + vec3 to_cell_offset; + float bias; + ivec3 grid_size; + uint pad2; +} params; + +/* clang-format on */ + +layout(location = 0) in vec3 vertex_interp; +layout(location = 1) in vec3 normal_interp; +layout(location = 2) in vec2 uv_interp; +layout(location = 3) in vec3 barycentric; +layout(location = 4) in flat uvec3 vertex_indices; +layout(location = 5) in flat vec3 face_normal; + +layout(location = 0) out vec4 position; +layout(location = 1) out vec4 normal; +layout(location = 2) out vec4 unocclude; + +void main() { + + vec3 vertex_pos = vertex_interp; + + { + // smooth out vertex position by interpolating its projection in the 3 normal planes (normal plane is created by vertex pos and normal) + // because we don't want to interpolate inwards, normals found pointing inwards are pushed out. + + vec3 pos_a = vertices.data[vertex_indices.x].position; + vec3 pos_b = vertices.data[vertex_indices.y].position; + vec3 pos_c = vertices.data[vertex_indices.z].position; + vec3 center = (pos_a + pos_b + pos_c) * 0.3333333; + vec3 norm_a = vec3(vertices.data[vertex_indices.x].normal_xy, vertices.data[vertex_indices.x].normal_z); + vec3 norm_b = vec3(vertices.data[vertex_indices.y].normal_xy, vertices.data[vertex_indices.y].normal_z); + vec3 norm_c = vec3(vertices.data[vertex_indices.z].normal_xy, vertices.data[vertex_indices.z].normal_z); + + { + vec3 dir_a = normalize(pos_a - center); + float d_a = dot(dir_a, norm_a); + if (d_a < 0) { + //pointing inwards + norm_a = normalize(norm_a - dir_a * d_a); + } + } + { + vec3 dir_b = normalize(pos_b - center); + float d_b = dot(dir_b, norm_b); + if (d_b < 0) { + //pointing inwards + norm_b = normalize(norm_b - dir_b * d_b); + } + } + { + vec3 dir_c = normalize(pos_c - center); + float d_c = dot(dir_c, norm_c); + if (d_c < 0) { + //pointing inwards + norm_c = normalize(norm_c - dir_c * d_c); + } + } + + float d_a = dot(norm_a, pos_a); + float d_b = dot(norm_b, pos_b); + float d_c = dot(norm_c, pos_c); + + vec3 proj_a = vertex_pos - norm_a * (dot(norm_a, vertex_pos) - d_a); + vec3 proj_b = vertex_pos - norm_b * (dot(norm_b, vertex_pos) - d_b); + vec3 proj_c = vertex_pos - norm_c * (dot(norm_c, vertex_pos) - d_c); + + vec3 smooth_position = proj_a * barycentric.x + proj_b * barycentric.y + proj_c * barycentric.z; + + if (dot(face_normal, smooth_position) > dot(face_normal, vertex_pos)) { //only project outwards + vertex_pos = smooth_position; + } + } + + { + // unocclusion technique based on: + // https://ndotl.wordpress.com/2018/08/29/baking-artifact-free-lightmaps/ + + /* compute texel size */ + vec3 delta_uv = max(abs(dFdx(vertex_interp)), abs(dFdy(vertex_interp))); + float texel_size = max(delta_uv.x, max(delta_uv.y, delta_uv.z)); + texel_size *= sqrt(2.0); //expand to unit box edge length (again, worst case) + + unocclude.xyz = face_normal; + unocclude.w = texel_size; + + //continued on lm_compute.glsl + } + + position = vec4(vertex_pos, 1.0); + normal = vec4(normalize(normal_interp), 1.0); +} diff --git a/modules/lightmapper_rd/register_types.cpp b/modules/lightmapper_rd/register_types.cpp new file mode 100644 index 0000000000..f3938f3190 --- /dev/null +++ b/modules/lightmapper_rd/register_types.cpp @@ -0,0 +1,64 @@ +/*************************************************************************/ +/* register_types.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "register_types.h" + +#include "core/project_settings.h" +#include "lightmapper_rd.h" +#include "scene/3d/lightmapper.h" + +#ifndef _3D_DISABLED +static Lightmapper *create_lightmapper_rd() { + return memnew(LightmapperRD); +} +#endif + +void register_lightmapper_rd_types() { + + GLOBAL_DEF("rendering/gpu_lightmapper/quality/low_quality_ray_count", 16); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/medium_quality_ray_count", 64); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/high_quality_ray_count", 256); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/ultra_quality_ray_count", 1024); + GLOBAL_DEF("rendering/gpu_lightmapper/performance/max_rays_per_pass", 32); + GLOBAL_DEF("rendering/gpu_lightmapper/performance/region_size", 512); + + GLOBAL_DEF("rendering/gpu_lightmapper/quality/low_quality_probe_ray_count", 64); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/medium_quality_probe_ray_count", 256); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/high_quality_probe_ray_count", 512); + GLOBAL_DEF("rendering/gpu_lightmapper/quality/ultra_quality_probe_ray_count", 2048); + GLOBAL_DEF("rendering/gpu_lightmapper/performance/max_rays_per_probe_pass", 64); +#ifndef _3D_DISABLED + ClassDB::register_class<LightmapperRD>(); + Lightmapper::create_gpu = create_lightmapper_rd; +#endif +} + +void unregister_lightmapper_rd_types() { +} diff --git a/core/path_remap.h b/modules/lightmapper_rd/register_types.h index 1580e88625..b0e15a927f 100644 --- a/core/path_remap.h +++ b/modules/lightmapper_rd/register_types.h @@ -1,5 +1,5 @@ /*************************************************************************/ -/* path_remap.h */ +/* register_types.h */ /*************************************************************************/ /* This file is part of: */ /* GODOT ENGINE */ @@ -28,7 +28,10 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#ifndef PATH_REMAP_H -#define PATH_REMAP_H +#ifndef LIGHTMAPPER_RD_REGISTER_TYPES_H +#define LIGHTMAPPER_RD_REGISTER_TYPES_H -#endif // PATH_REMAP_H +void register_lightmapper_rd_types(); +void unregister_lightmapper_rd_types(); + +#endif // XATLAS_UNWRAP_REGISTER_TYPES_H diff --git a/modules/mbedtls/packet_peer_mbed_dtls.cpp b/modules/mbedtls/packet_peer_mbed_dtls.cpp index b2aa5f5827..37477e1246 100755..100644 --- a/modules/mbedtls/packet_peer_mbed_dtls.cpp +++ b/modules/mbedtls/packet_peer_mbed_dtls.cpp @@ -36,7 +36,8 @@ int PacketPeerMbedDTLS::bio_send(void *ctx, const unsigned char *buf, size_t len) { - if (buf == nullptr || len <= 0) return 0; + if (buf == nullptr || len <= 0) + return 0; PacketPeerMbedDTLS *sp = (PacketPeerMbedDTLS *)ctx; @@ -53,7 +54,8 @@ int PacketPeerMbedDTLS::bio_send(void *ctx, const unsigned char *buf, size_t len int PacketPeerMbedDTLS::bio_recv(void *ctx, unsigned char *buf, size_t len) { - if (buf == nullptr || len <= 0) return 0; + if (buf == nullptr || len <= 0) + return 0; PacketPeerMbedDTLS *sp = (PacketPeerMbedDTLS *)ctx; diff --git a/modules/mbedtls/stream_peer_mbedtls.cpp b/modules/mbedtls/stream_peer_mbedtls.cpp index 983095c536..af36b29dac 100755..100644 --- a/modules/mbedtls/stream_peer_mbedtls.cpp +++ b/modules/mbedtls/stream_peer_mbedtls.cpp @@ -35,7 +35,8 @@ int StreamPeerMbedTLS::bio_send(void *ctx, const unsigned char *buf, size_t len) { - if (buf == nullptr || len <= 0) return 0; + if (buf == nullptr || len <= 0) + return 0; StreamPeerMbedTLS *sp = (StreamPeerMbedTLS *)ctx; @@ -54,7 +55,8 @@ int StreamPeerMbedTLS::bio_send(void *ctx, const unsigned char *buf, size_t len) int StreamPeerMbedTLS::bio_recv(void *ctx, unsigned char *buf, size_t len) { - if (buf == nullptr || len <= 0) return 0; + if (buf == nullptr || len <= 0) + return 0; StreamPeerMbedTLS *sp = (StreamPeerMbedTLS *)ctx; diff --git a/modules/mobile_vr/mobile_vr_interface.cpp b/modules/mobile_vr/mobile_vr_interface.cpp index 48276c0d3d..6d10de0096 100644 --- a/modules/mobile_vr/mobile_vr_interface.cpp +++ b/modules/mobile_vr/mobile_vr_interface.cpp @@ -61,13 +61,19 @@ Vector3 MobileVRInterface::scale_magneto(const Vector3 &p_magnetometer) { }; // adjust our min and max - if (mag_raw.x > mag_next_max.x) mag_next_max.x = mag_raw.x; - if (mag_raw.y > mag_next_max.y) mag_next_max.y = mag_raw.y; - if (mag_raw.z > mag_next_max.z) mag_next_max.z = mag_raw.z; - - if (mag_raw.x < mag_next_min.x) mag_next_min.x = mag_raw.x; - if (mag_raw.y < mag_next_min.y) mag_next_min.y = mag_raw.y; - if (mag_raw.z < mag_next_min.z) mag_next_min.z = mag_raw.z; + if (mag_raw.x > mag_next_max.x) + mag_next_max.x = mag_raw.x; + if (mag_raw.y > mag_next_max.y) + mag_next_max.y = mag_raw.y; + if (mag_raw.z > mag_next_max.z) + mag_next_max.z = mag_raw.z; + + if (mag_raw.x < mag_next_min.x) + mag_next_min.x = mag_raw.x; + if (mag_raw.y < mag_next_min.y) + mag_next_min.y = mag_raw.y; + if (mag_raw.z < mag_next_min.z) + mag_next_min.z = mag_raw.z; // scale our x, y and z if (!(mag_current_max.x - mag_current_min.x)) { diff --git a/modules/mono/build_scripts/godot_tools_build.py b/modules/mono/build_scripts/godot_tools_build.py index 8e77f3bb44..7391e8790d 100644 --- a/modules/mono/build_scripts/godot_tools_build.py +++ b/modules/mono/build_scripts/godot_tools_build.py @@ -15,7 +15,7 @@ def build_godot_tools(source, target, env): from .solution_builder import build_solution - build_solution(env, solution_path, build_config, restore=True) + build_solution(env, solution_path, build_config) # No need to copy targets. The GodotTools csproj takes care of copying them. diff --git a/modules/mono/build_scripts/solution_builder.py b/modules/mono/build_scripts/solution_builder.py index e8ddb7114e..371819fd72 100644 --- a/modules/mono/build_scripts/solution_builder.py +++ b/modules/mono/build_scripts/solution_builder.py @@ -4,7 +4,29 @@ import os verbose = False -def find_msbuild_unix(filename): +def find_dotnet_cli(): + import os.path + + if os.name == "nt": + windows_exts = os.environ["PATHEXT"] + windows_exts = windows_exts.split(os.pathsep) if windows_exts else [] + + for hint_dir in os.environ["PATH"].split(os.pathsep): + hint_dir = hint_dir.strip('"') + hint_path = os.path.join(hint_dir, "dotnet") + if os.path.isfile(hint_path) and os.access(hint_path, os.X_OK): + return hint_path + if os.path.isfile(hint_path + ".exe") and os.access(hint_path + ".exe", os.X_OK): + return hint_path + ".exe" + else: + for hint_dir in os.environ["PATH"].split(os.pathsep): + hint_dir = hint_dir.strip('"') + hint_path = os.path.join(hint_dir, "dotnet") + if os.path.isfile(hint_path) and os.access(hint_path, os.X_OK): + return hint_path + + +def find_msbuild_unix(): import os.path import sys @@ -86,15 +108,7 @@ def run_command(command, args, env_override=None, name=None): raise RuntimeError("'%s' exited with error code: %s" % (name, e.returncode)) -def nuget_restore(env, *args): - global verbose - verbose = env["verbose"] - - # Do NuGet restore - run_command(nuget_path, ["restore"] + list(args), name="nuget restore") - - -def build_solution(env, solution_path, build_config, extra_msbuild_args=[], restore=False): +def build_solution(env, solution_path, build_config, extra_msbuild_args=[]): global verbose verbose = env["verbose"] @@ -104,27 +118,33 @@ def build_solution(env, solution_path, build_config, extra_msbuild_args=[], rest if "PLATFORM" in msbuild_env: del msbuild_env["PLATFORM"] - # Find MSBuild - if os.name == "nt": - msbuild_info = find_msbuild_windows(env) - if msbuild_info is None: - raise RuntimeError("Cannot find MSBuild executable") - msbuild_path = msbuild_info[0] - msbuild_env.update(msbuild_info[1]) + msbuild_args = [] + + dotnet_cli = find_dotnet_cli() + + if dotnet_cli: + msbuild_path = dotnet_cli + msbuild_args += ["msbuild"] # `dotnet msbuild` command else: - msbuild_path = find_msbuild_unix("msbuild") - if msbuild_path is None: - raise RuntimeError("Cannot find MSBuild executable") + # Find MSBuild + if os.name == "nt": + msbuild_info = find_msbuild_windows(env) + if msbuild_info is None: + raise RuntimeError("Cannot find MSBuild executable") + msbuild_path = msbuild_info[0] + msbuild_env.update(msbuild_info[1]) + else: + msbuild_path = find_msbuild_unix() + if msbuild_path is None: + raise RuntimeError("Cannot find MSBuild executable") print("MSBuild path: " + msbuild_path) # Build solution - targets = ["Build"] - if restore: - targets.insert(0, "Restore") + targets = ["Restore", "Build"] - msbuild_args = [solution_path, "/t:%s" % ",".join(targets), "/p:Configuration=" + build_config] + msbuild_args += [solution_path, "/t:%s" % ",".join(targets), "/p:Configuration=" + build_config] msbuild_args += extra_msbuild_args run_command(msbuild_path, msbuild_args, env_override=msbuild_env, name="msbuild") diff --git a/modules/mono/csharp_script.cpp b/modules/mono/csharp_script.cpp index 6c1c8b87ef..43c86d3e28 100644 --- a/modules/mono/csharp_script.cpp +++ b/modules/mono/csharp_script.cpp @@ -295,7 +295,7 @@ void CSharpLanguage::get_reserved_words(List<String> *p_words) const { "when", "where", "yield", - 0 + nullptr }; const char **w = _reserved_words; @@ -2421,58 +2421,68 @@ void CSharpScript::_update_member_info_no_exports() { bool CSharpScript::_update_exports() { #ifdef TOOLS_ENABLED - if (!Engine::get_singleton()->is_editor_hint()) - return false; - - placeholder_fallback_enabled = true; // until proven otherwise - + bool is_editor = Engine::get_singleton()->is_editor_hint(); + if (is_editor) + placeholder_fallback_enabled = true; // until proven otherwise +#endif if (!valid) return false; bool changed = false; - if (exports_invalidated) { +#ifdef TOOLS_ENABLED + if (exports_invalidated) +#endif + { GD_MONO_SCOPE_THREAD_ATTACH; - exports_invalidated = false; - changed = true; member_info.clear(); - exported_members_cache.clear(); - exported_members_defval_cache.clear(); - // Here we create a temporary managed instance of the class to get the initial values +#ifdef TOOLS_ENABLED + MonoObject *tmp_object = nullptr; + Object *tmp_native = nullptr; + uint32_t tmp_pinned_gchandle = 0; - MonoObject *tmp_object = mono_object_new(mono_domain_get(), script_class->get_mono_ptr()); + if (is_editor) { + exports_invalidated = false; - if (!tmp_object) { - ERR_PRINT("Failed to allocate temporary MonoObject."); - return false; - } + exported_members_cache.clear(); + exported_members_defval_cache.clear(); - uint32_t tmp_pinned_gchandle = GDMonoUtils::new_strong_gchandle_pinned(tmp_object); // pin it (not sure if needed) + // Here we create a temporary managed instance of the class to get the initial values + tmp_object = mono_object_new(mono_domain_get(), script_class->get_mono_ptr()); - GDMonoMethod *ctor = script_class->get_method(CACHED_STRING_NAME(dotctor), 0); + if (!tmp_object) { + ERR_PRINT("Failed to allocate temporary MonoObject."); + return false; + } - ERR_FAIL_NULL_V_MSG(ctor, false, - "Cannot construct temporary MonoObject because the class does not define a parameterless constructor: '" + get_path() + "'."); + tmp_pinned_gchandle = GDMonoUtils::new_strong_gchandle_pinned(tmp_object); // pin it (not sure if needed) - MonoException *ctor_exc = nullptr; - ctor->invoke(tmp_object, nullptr, &ctor_exc); + GDMonoMethod *ctor = script_class->get_method(CACHED_STRING_NAME(dotctor), 0); - Object *tmp_native = GDMonoMarshal::unbox<Object *>(CACHED_FIELD(GodotObject, ptr)->get_value(tmp_object)); + ERR_FAIL_NULL_V_MSG(ctor, false, + "Cannot construct temporary MonoObject because the class does not define a parameterless constructor: '" + get_path() + "'."); - if (ctor_exc) { - // TODO: Should we free 'tmp_native' if the exception was thrown after its creation? + MonoException *ctor_exc = nullptr; + ctor->invoke(tmp_object, nullptr, &ctor_exc); - GDMonoUtils::free_gchandle(tmp_pinned_gchandle); - tmp_object = nullptr; + tmp_native = GDMonoMarshal::unbox<Object *>(CACHED_FIELD(GodotObject, ptr)->get_value(tmp_object)); - ERR_PRINT("Exception thrown from constructor of temporary MonoObject:"); - GDMonoUtils::debug_print_unhandled_exception(ctor_exc); - return false; + if (ctor_exc) { + // TODO: Should we free 'tmp_native' if the exception was thrown after its creation? + + GDMonoUtils::free_gchandle(tmp_pinned_gchandle); + tmp_object = nullptr; + + ERR_PRINT("Exception thrown from constructor of temporary MonoObject:"); + GDMonoUtils::debug_print_unhandled_exception(ctor_exc); + return false; + } } +#endif GDMonoClass *top = script_class; @@ -2488,16 +2498,16 @@ bool CSharpScript::_update_exports() { if (_get_member_export(field, /* inspect export: */ true, prop_info, exported)) { StringName member_name = field->get_name(); - if (exported) { - member_info[member_name] = prop_info; + member_info[member_name] = prop_info; +#ifdef TOOLS_ENABLED + if (is_editor && exported) { exported_members_cache.push_front(prop_info); if (tmp_object) { exported_members_defval_cache[member_name] = GDMonoMarshal::mono_object_to_variant(field->get_value(tmp_object)); } - } else { - member_info[member_name] = prop_info; } +#endif } } @@ -2509,10 +2519,10 @@ bool CSharpScript::_update_exports() { if (_get_member_export(property, /* inspect export: */ true, prop_info, exported)) { StringName member_name = property->get_name(); - if (exported) { - member_info[member_name] = prop_info; + member_info[member_name] = prop_info; +#ifdef TOOLS_ENABLED + if (is_editor && exported) { exported_members_cache.push_front(prop_info); - if (tmp_object) { MonoException *exc = nullptr; MonoObject *ret = property->get_value(tmp_object, &exc); @@ -2523,57 +2533,62 @@ bool CSharpScript::_update_exports() { exported_members_defval_cache[member_name] = GDMonoMarshal::mono_object_to_variant(ret); } } - } else { - member_info[member_name] = prop_info; } +#endif } } top = top->get_parent_class(); } - // Need to check this here, before disposal - bool base_ref = Object::cast_to<Reference>(tmp_native) != nullptr; +#ifdef TOOLS_ENABLED + if (is_editor) { + // Need to check this here, before disposal + bool base_ref = Object::cast_to<Reference>(tmp_native) != nullptr; - // Dispose the temporary managed instance + // Dispose the temporary managed instance - MonoException *exc = nullptr; - GDMonoUtils::dispose(tmp_object, &exc); + MonoException *exc = nullptr; + GDMonoUtils::dispose(tmp_object, &exc); - if (exc) { - ERR_PRINT("Exception thrown from method Dispose() of temporary MonoObject:"); - GDMonoUtils::debug_print_unhandled_exception(exc); - } + if (exc) { + ERR_PRINT("Exception thrown from method Dispose() of temporary MonoObject:"); + GDMonoUtils::debug_print_unhandled_exception(exc); + } - GDMonoUtils::free_gchandle(tmp_pinned_gchandle); - tmp_object = nullptr; + GDMonoUtils::free_gchandle(tmp_pinned_gchandle); + tmp_object = nullptr; - if (tmp_native && !base_ref) { - Node *node = Object::cast_to<Node>(tmp_native); - if (node && node->is_inside_tree()) { - ERR_PRINT("Temporary instance was added to the scene tree."); - } else { - memdelete(tmp_native); + if (tmp_native && !base_ref) { + Node *node = Object::cast_to<Node>(tmp_native); + if (node && node->is_inside_tree()) { + ERR_PRINT("Temporary instance was added to the scene tree."); + } else { + memdelete(tmp_native); + } } } +#endif } - placeholder_fallback_enabled = false; +#ifdef TOOLS_ENABLED + if (is_editor) { + placeholder_fallback_enabled = false; - if (placeholders.size()) { - // Update placeholders if any - Map<StringName, Variant> values; - List<PropertyInfo> propnames; - _update_exports_values(values, propnames); + if (placeholders.size()) { + // Update placeholders if any + Map<StringName, Variant> values; + List<PropertyInfo> propnames; + _update_exports_values(values, propnames); - for (Set<PlaceHolderScriptInstance *>::Element *E = placeholders.front(); E; E = E->next()) { - E->get()->update(propnames, values); + for (Set<PlaceHolderScriptInstance *>::Element *E = placeholders.front(); E; E = E->next()) { + E->get()->update(propnames, values); + } } } +#endif return changed; -#endif - return false; } void CSharpScript::load_script_signals(GDMonoClass *p_class, GDMonoClass *p_native_class) { @@ -2679,7 +2694,6 @@ bool CSharpScript::_get_signal(GDMonoClass *p_class, GDMonoMethod *p_delegate_in return true; } -#ifdef TOOLS_ENABLED /** * Returns false if there was an error, otherwise true. * If there was an error, r_prop_info and r_exported are not assigned any value. @@ -2693,8 +2707,10 @@ bool CSharpScript::_get_member_export(IMonoClassMember *p_member, bool p_inspect (m_member->get_enclosing_class()->get_full_name() + "." + (String)m_member->get_name()) if (p_member->is_static()) { +#ifdef TOOLS_ENABLED if (p_member->has_attribute(CACHED_CLASS(ExportAttribute))) ERR_PRINT("Cannot export member because it is static: '" + MEMBER_FULL_QUALIFIED_NAME(p_member) + "'."); +#endif return false; } @@ -2716,13 +2732,17 @@ bool CSharpScript::_get_member_export(IMonoClassMember *p_member, bool p_inspect if (p_member->get_member_type() == IMonoClassMember::MEMBER_TYPE_PROPERTY) { GDMonoProperty *property = static_cast<GDMonoProperty *>(p_member); if (!property->has_getter()) { +#ifdef TOOLS_ENABLED if (exported) ERR_PRINT("Read-only property cannot be exported: '" + MEMBER_FULL_QUALIFIED_NAME(p_member) + "'."); +#endif return false; } if (!property->has_setter()) { +#ifdef TOOLS_ENABLED if (exported) ERR_PRINT("Write-only property (without getter) cannot be exported: '" + MEMBER_FULL_QUALIFIED_NAME(p_member) + "'."); +#endif return false; } } @@ -2742,10 +2762,13 @@ bool CSharpScript::_get_member_export(IMonoClassMember *p_member, bool p_inspect String hint_string; if (variant_type == Variant::NIL && !nil_is_variant) { +#ifdef TOOLS_ENABLED ERR_PRINT("Unknown exported member type: '" + MEMBER_FULL_QUALIFIED_NAME(p_member) + "'."); +#endif return false; } +#ifdef TOOLS_ENABLED int hint_res = _try_get_member_export_hint(p_member, type, variant_type, /* allow_generics: */ true, hint, hint_string); ERR_FAIL_COND_V_MSG(hint_res == -1, false, @@ -2756,6 +2779,7 @@ bool CSharpScript::_get_member_export(IMonoClassMember *p_member, bool p_inspect hint = PropertyHint(CACHED_FIELD(ExportAttribute, hint)->get_int_value(attr)); hint_string = CACHED_FIELD(ExportAttribute, hintString)->get_string_value(attr); } +#endif uint32_t prop_usage = PROPERTY_USAGE_DEFAULT | PROPERTY_USAGE_SCRIPT_VARIABLE; @@ -2772,6 +2796,7 @@ bool CSharpScript::_get_member_export(IMonoClassMember *p_member, bool p_inspect #undef MEMBER_FULL_QUALIFIED_NAME } +#ifdef TOOLS_ENABLED int CSharpScript::_try_get_member_export_hint(IMonoClassMember *p_member, ManagedType p_type, Variant::Type p_variant_type, bool p_allow_generics, PropertyHint &r_hint, String &r_hint_string) { if (p_variant_type == Variant::NIL) { @@ -3542,10 +3567,15 @@ bool CSharpScript::inherits_script(const Ref<Script> &p_script) const { return false; } -#ifndef _MSC_VER -#warning TODO: Implement CSharpScript::inherits_script and other relevant changes after GH-38063. -#endif - return false; + if (script_class == nullptr || cs->script_class == nullptr) { + return false; + } + + if (script_class == cs->script_class) { + return true; + } + + return cs->script_class->is_assignable_from(script_class); } Ref<Script> CSharpScript::get_base_script() const { diff --git a/modules/mono/csharp_script.h b/modules/mono/csharp_script.h index 05e2857538..280c981a47 100644 --- a/modules/mono/csharp_script.h +++ b/modules/mono/csharp_script.h @@ -147,8 +147,9 @@ private: bool _get_signal(GDMonoClass *p_class, GDMonoMethod *p_delegate_invoke, Vector<SignalParameter> ¶ms); bool _update_exports(); -#ifdef TOOLS_ENABLED + bool _get_member_export(IMonoClassMember *p_member, bool p_inspect_export, PropertyInfo &r_prop_info, bool &r_exported); +#ifdef TOOLS_ENABLED static int _try_get_member_export_hint(IMonoClassMember *p_member, ManagedType p_type, Variant::Type p_variant_type, bool p_allow_generics, PropertyHint &r_hint, String &r_hint_string); #endif @@ -325,17 +326,13 @@ public: }; struct CSharpScriptBinding { - bool inited; + bool inited = false; StringName type_name; - GDMonoClass *wrapper_class; + GDMonoClass *wrapper_class = nullptr; MonoGCHandleData gchandle; - Object *owner; + Object *owner = nullptr; - CSharpScriptBinding() : - inited(false), - wrapper_class(nullptr), - owner(nullptr) { - } + CSharpScriptBinding() {} }; class ManagedCallableMiddleman : public Object { diff --git a/modules/mono/editor/GodotTools/GodotTools.BuildLogger/GodotBuildLogger.cs b/modules/mono/editor/GodotTools/GodotTools.BuildLogger/GodotBuildLogger.cs index 6015cb22b6..c2549b4ad5 100644 --- a/modules/mono/editor/GodotTools/GodotTools.BuildLogger/GodotBuildLogger.cs +++ b/modules/mono/editor/GodotTools/GodotTools.BuildLogger/GodotBuildLogger.cs @@ -2,7 +2,6 @@ using System; using System.IO; using System.Security; using Microsoft.Build.Framework; -using GodotTools.Core; namespace GodotTools.BuildLogger { @@ -183,4 +182,17 @@ namespace GodotTools.BuildLogger private StreamWriter issuesStreamWriter; private int indent; } + + internal static class StringExtensions + { + public static string CsvEscape(this string value, char delimiter = ',') + { + bool hasSpecialChar = value.IndexOfAny(new[] { '\"', '\n', '\r', delimiter }) != -1; + + if (hasSpecialChar) + return "\"" + value.Replace("\"", "\"\"") + "\""; + + return value; + } + } } diff --git a/modules/mono/editor/GodotTools/GodotTools.BuildLogger/GodotTools.BuildLogger.csproj b/modules/mono/editor/GodotTools/GodotTools.BuildLogger/GodotTools.BuildLogger.csproj index 8fdd485209..0afec970c6 100644 --- a/modules/mono/editor/GodotTools/GodotTools.BuildLogger/GodotTools.BuildLogger.csproj +++ b/modules/mono/editor/GodotTools/GodotTools.BuildLogger/GodotTools.BuildLogger.csproj @@ -1,60 +1,10 @@ -<?xml version="1.0" encoding="utf-8"?> -<Project ToolsVersion="4.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> - <Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" /> +<Project Sdk="Microsoft.NET.Sdk"> <PropertyGroup> - <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> - <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform> <ProjectGuid>{6CE9A984-37B1-4F8A-8FE9-609F05F071B3}</ProjectGuid> - <OutputType>Library</OutputType> - <AppDesignerFolder>Properties</AppDesignerFolder> - <RootNamespace>GodotTools.BuildLogger</RootNamespace> - <AssemblyName>GodotTools.BuildLogger</AssemblyName> - <TargetFrameworkVersion>v4.7</TargetFrameworkVersion> - <FileAlignment>512</FileAlignment> - <LangVersion>7</LangVersion> + <TargetFramework>netstandard2.0</TargetFramework> + <LangVersion>7.2</LangVersion> </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' "> - <PlatformTarget>AnyCPU</PlatformTarget> - <DebugSymbols>true</DebugSymbols> - <DebugType>portable</DebugType> - <Optimize>false</Optimize> - <OutputPath>bin\Debug\</OutputPath> - <DefineConstants>DEBUG;TRACE</DefineConstants> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' "> - <PlatformTarget>AnyCPU</PlatformTarget> - <DebugType>portable</DebugType> - <Optimize>true</Optimize> - <OutputPath>bin\Release\</OutputPath> - <DefineConstants>TRACE</DefineConstants> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - </PropertyGroup> - <ItemGroup> - <Reference Include="Microsoft.Build.Framework" /> - <Reference Include="System" /> - <Reference Include="System.Core" /> - <Reference Include="System.Data" /> - <Reference Include="System.Xml" /> - </ItemGroup> - <ItemGroup> - <Compile Include="GodotBuildLogger.cs" /> - <Compile Include="Properties\AssemblyInfo.cs" /> - </ItemGroup> <ItemGroup> - <ProjectReference Include="..\GodotTools.Core\GodotTools.Core.csproj"> - <Project>{639e48bd-44e5-4091-8edd-22d36dc0768d}</Project> - <Name>GodotTools.Core</Name> - </ProjectReference> + <PackageReference Include="Microsoft.Build.Framework" Version="16.5.0" /> </ItemGroup> - <Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" /> - <!-- To modify your build process, add your task inside one of the targets below and uncomment it. - Other similar extension points exist, see Microsoft.Common.targets. - <Target Name="BeforeBuild"> - </Target> - <Target Name="AfterBuild"> - </Target> - --> </Project> diff --git a/modules/mono/editor/GodotTools/GodotTools.BuildLogger/Properties/AssemblyInfo.cs b/modules/mono/editor/GodotTools/GodotTools.BuildLogger/Properties/AssemblyInfo.cs deleted file mode 100644 index 4374f21cfa..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.BuildLogger/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System.Reflection; -using System.Runtime.InteropServices; - -// General Information about an assembly is controlled through the following -// set of attributes. Change these attribute values to modify the information -// associated with an assembly. -[assembly: AssemblyTitle("GodotTools.BuildLogger")] -[assembly: AssemblyDescription("")] -[assembly: AssemblyConfiguration("")] -[assembly: AssemblyCompany("")] -[assembly: AssemblyProduct("")] -[assembly: AssemblyCopyright("Godot Engine contributors")] -[assembly: AssemblyTrademark("")] -[assembly: AssemblyCulture("")] - -// Setting ComVisible to false makes the types in this assembly not visible -// to COM components. If you need to access a type in this assembly from -// COM, set the ComVisible attribute to true on that type. -[assembly: ComVisible(false)] - -// The following GUID is for the ID of the typelib if this project is exposed to COM -[assembly: Guid("6CE9A984-37B1-4F8A-8FE9-609F05F071B3")] - -// Version information for an assembly consists of the following four values: -// -// Major Version -// Minor Version -// Build Number -// Revision -// -// You can specify all the values or you can default the Build and Revision Numbers -// by using the '*' as shown below: -// [assembly: AssemblyVersion("1.0.*")] -[assembly: AssemblyVersion("1.0.0.0")] -[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/modules/mono/editor/GodotTools/GodotTools.Core/GodotTools.Core.csproj b/modules/mono/editor/GodotTools/GodotTools.Core/GodotTools.Core.csproj index c9ea7d3a2c..d6d8962f90 100644 --- a/modules/mono/editor/GodotTools/GodotTools.Core/GodotTools.Core.csproj +++ b/modules/mono/editor/GodotTools/GodotTools.Core/GodotTools.Core.csproj @@ -1,40 +1,7 @@ -<?xml version="1.0" encoding="utf-8"?> -<Project DefaultTargets="Build" ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> +<Project Sdk="Microsoft.NET.Sdk"> <PropertyGroup> - <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> - <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform> <ProjectGuid>{639E48BD-44E5-4091-8EDD-22D36DC0768D}</ProjectGuid> - <OutputType>Library</OutputType> - <RootNamespace>GodotTools.Core</RootNamespace> - <AssemblyName>GodotTools.Core</AssemblyName> - <TargetFrameworkVersion>v4.7</TargetFrameworkVersion> - <LangVersion>7</LangVersion> + <TargetFramework>netstandard2.0</TargetFramework> + <LangVersion>7.2</LangVersion> </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' "> - <DebugSymbols>true</DebugSymbols> - <DebugType>full</DebugType> - <Optimize>false</Optimize> - <OutputPath>bin\Debug</OutputPath> - <DefineConstants>DEBUG;</DefineConstants> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - <ConsolePause>false</ConsolePause> - </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' "> - <Optimize>true</Optimize> - <OutputPath>bin\Release</OutputPath> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - <ConsolePause>false</ConsolePause> - </PropertyGroup> - <ItemGroup> - <Reference Include="System" /> - </ItemGroup> - <ItemGroup> - <Compile Include="FileUtils.cs" /> - <Compile Include="ProcessExtensions.cs" /> - <Compile Include="Properties\AssemblyInfo.cs" /> - <Compile Include="StringExtensions.cs" /> - </ItemGroup> - <Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" /> </Project> diff --git a/modules/mono/editor/GodotTools/GodotTools.Core/Properties/AssemblyInfo.cs b/modules/mono/editor/GodotTools/GodotTools.Core/Properties/AssemblyInfo.cs deleted file mode 100644 index 699ae6e741..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.Core/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System.Reflection; -using System.Runtime.CompilerServices; - -// Information about this assembly is defined by the following attributes. -// Change them to the values specific to your project. - -[assembly: AssemblyTitle("GodotTools.Core")] -[assembly: AssemblyDescription("")] -[assembly: AssemblyConfiguration("")] -[assembly: AssemblyCompany("")] -[assembly: AssemblyProduct("")] -[assembly: AssemblyCopyright("Godot Engine contributors")] -[assembly: AssemblyTrademark("")] -[assembly: AssemblyCulture("")] - -// The assembly version has the format "{Major}.{Minor}.{Build}.{Revision}". -// The form "{Major}.{Minor}.*" will automatically update the build and revision, -// and "{Major}.{Minor}.{Build}.*" will update just the revision. - -[assembly: AssemblyVersion("1.0.*")] - -// The following attributes are used to specify the signing key for the assembly, -// if desired. See the Mono documentation for more information about signing. - -//[assembly: AssemblyDelaySign(false)] -//[assembly: AssemblyKeyFile("")] diff --git a/modules/mono/editor/GodotTools/GodotTools.Core/StringExtensions.cs b/modules/mono/editor/GodotTools/GodotTools.Core/StringExtensions.cs index 326c49f096..7ab5c5fc59 100644 --- a/modules/mono/editor/GodotTools/GodotTools.Core/StringExtensions.cs +++ b/modules/mono/editor/GodotTools/GodotTools.Core/StringExtensions.cs @@ -33,23 +33,13 @@ namespace GodotTools.Core return rooted ? Path.DirectorySeparatorChar + path : path; } - private static readonly string driveRoot = Path.GetPathRoot(Environment.CurrentDirectory); + private static readonly string DriveRoot = Path.GetPathRoot(Environment.CurrentDirectory); public static bool IsAbsolutePath(this string path) { return path.StartsWith("/", StringComparison.Ordinal) || path.StartsWith("\\", StringComparison.Ordinal) || - path.StartsWith(driveRoot, StringComparison.Ordinal); - } - - public static string CsvEscape(this string value, char delimiter = ',') - { - bool hasSpecialChar = value.IndexOfAny(new char[] { '\"', '\n', '\r', delimiter }) != -1; - - if (hasSpecialChar) - return "\"" + value.Replace("\"", "\"\"") + "\""; - - return value; + path.StartsWith(DriveRoot, StringComparison.Ordinal); } public static string ToSafeDirName(this string dirName, bool allowDirSeparator) diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/ConsoleLogger.cs b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/ConsoleLogger.cs deleted file mode 100644 index 7a2ff2ca56..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/ConsoleLogger.cs +++ /dev/null @@ -1,33 +0,0 @@ -using System; - -namespace GodotTools.IdeConnection -{ - public class ConsoleLogger : ILogger - { - public void LogDebug(string message) - { - Console.WriteLine("DEBUG: " + message); - } - - public void LogInfo(string message) - { - Console.WriteLine("INFO: " + message); - } - - public void LogWarning(string message) - { - Console.WriteLine("WARN: " + message); - } - - public void LogError(string message) - { - Console.WriteLine("ERROR: " + message); - } - - public void LogError(string message, Exception e) - { - Console.WriteLine("EXCEPTION: " + message); - Console.WriteLine(e); - } - } -} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeBase.cs b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeBase.cs deleted file mode 100644 index be89638241..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeBase.cs +++ /dev/null @@ -1,94 +0,0 @@ -using System; -using Path = System.IO.Path; - -namespace GodotTools.IdeConnection -{ - public class GodotIdeBase : IDisposable - { - private ILogger logger; - - public ILogger Logger - { - get => logger ?? (logger = new ConsoleLogger()); - set => logger = value; - } - - private readonly string projectMetadataDir; - - protected const string MetaFileName = "ide_server_meta.txt"; - protected string MetaFilePath => Path.Combine(projectMetadataDir, MetaFileName); - - private GodotIdeConnection connection; - protected readonly object ConnectionLock = new object(); - - public bool IsDisposed { get; private set; } = false; - - public bool IsConnected => connection != null && !connection.IsDisposed && connection.IsConnected; - - public event Action Connected - { - add - { - if (connection != null && !connection.IsDisposed) - connection.Connected += value; - } - remove - { - if (connection != null && !connection.IsDisposed) - connection.Connected -= value; - } - } - - protected GodotIdeConnection Connection - { - get => connection; - set - { - connection?.Dispose(); - connection = value; - } - } - - protected GodotIdeBase(string projectMetadataDir) - { - this.projectMetadataDir = projectMetadataDir; - } - - protected void DisposeConnection() - { - lock (ConnectionLock) - { - connection?.Dispose(); - } - } - - ~GodotIdeBase() - { - Dispose(disposing: false); - } - - public void Dispose() - { - if (IsDisposed) - return; - - lock (ConnectionLock) - { - if (IsDisposed) // lock may not be fair - return; - IsDisposed = true; - } - - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - protected virtual void Dispose(bool disposing) - { - if (disposing) - { - connection?.Dispose(); - } - } - } -} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeClient.cs b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeClient.cs deleted file mode 100644 index 2bf3b83c75..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeClient.cs +++ /dev/null @@ -1,219 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Net; -using System.Net.Sockets; -using System.Threading; - -namespace GodotTools.IdeConnection -{ - public abstract class GodotIdeClient : GodotIdeBase - { - protected GodotIdeMetadata GodotIdeMetadata; - - private readonly FileSystemWatcher fsWatcher; - - protected GodotIdeClient(string projectMetadataDir) : base(projectMetadataDir) - { - messageHandlers = InitializeMessageHandlers(); - - // FileSystemWatcher requires an existing directory - if (!File.Exists(projectMetadataDir)) - Directory.CreateDirectory(projectMetadataDir); - - fsWatcher = new FileSystemWatcher(projectMetadataDir, MetaFileName); - } - - private void OnMetaFileChanged(object sender, FileSystemEventArgs e) - { - if (IsDisposed) - return; - - lock (ConnectionLock) - { - if (IsDisposed) - return; - - if (!File.Exists(MetaFilePath)) - return; - - var metadata = ReadMetadataFile(); - - if (metadata != null && metadata != GodotIdeMetadata) - { - GodotIdeMetadata = metadata.Value; - ConnectToServer(); - } - } - } - - private void OnMetaFileDeleted(object sender, FileSystemEventArgs e) - { - if (IsDisposed) - return; - - if (IsConnected) - DisposeConnection(); - - // The file may have been re-created - - lock (ConnectionLock) - { - if (IsDisposed) - return; - - if (IsConnected || !File.Exists(MetaFilePath)) - return; - - var metadata = ReadMetadataFile(); - - if (metadata != null) - { - GodotIdeMetadata = metadata.Value; - ConnectToServer(); - } - } - } - - private GodotIdeMetadata? ReadMetadataFile() - { - using (var reader = File.OpenText(MetaFilePath)) - { - string portStr = reader.ReadLine(); - - if (portStr == null) - return null; - - string editorExecutablePath = reader.ReadLine(); - - if (editorExecutablePath == null) - return null; - - if (!int.TryParse(portStr, out int port)) - return null; - - return new GodotIdeMetadata(port, editorExecutablePath); - } - } - - private void ConnectToServer() - { - var tcpClient = new TcpClient(); - - Connection = new GodotIdeConnectionClient(tcpClient, HandleMessage); - Connection.Logger = Logger; - - try - { - Logger.LogInfo("Connecting to Godot Ide Server"); - - tcpClient.Connect(IPAddress.Loopback, GodotIdeMetadata.Port); - - Logger.LogInfo("Connection open with Godot Ide Server"); - - var clientThread = new Thread(Connection.Start) - { - IsBackground = true, - Name = "Godot Ide Connection Client" - }; - clientThread.Start(); - } - catch (SocketException e) - { - if (e.SocketErrorCode == SocketError.ConnectionRefused) - Logger.LogError("The connection to the Godot Ide Server was refused"); - else - throw; - } - } - - public void Start() - { - Logger.LogInfo("Starting Godot Ide Client"); - - fsWatcher.Changed += OnMetaFileChanged; - fsWatcher.Deleted += OnMetaFileDeleted; - fsWatcher.EnableRaisingEvents = true; - - lock (ConnectionLock) - { - if (IsDisposed) - return; - - if (!File.Exists(MetaFilePath)) - { - Logger.LogInfo("There is no Godot Ide Server running"); - return; - } - - var metadata = ReadMetadataFile(); - - if (metadata != null) - { - GodotIdeMetadata = metadata.Value; - ConnectToServer(); - } - else - { - Logger.LogError("Failed to read Godot Ide metadata file"); - } - } - } - - public bool WriteMessage(Message message) - { - return Connection.WriteMessage(message); - } - - protected override void Dispose(bool disposing) - { - base.Dispose(disposing); - - if (disposing) - { - fsWatcher?.Dispose(); - } - } - - protected virtual bool HandleMessage(Message message) - { - if (messageHandlers.TryGetValue(message.Id, out var action)) - { - action(message.Arguments); - return true; - } - - return false; - } - - private readonly Dictionary<string, Action<string[]>> messageHandlers; - - private Dictionary<string, Action<string[]>> InitializeMessageHandlers() - { - return new Dictionary<string, Action<string[]>> - { - ["OpenFile"] = args => - { - switch (args.Length) - { - case 1: - OpenFile(file: args[0]); - return; - case 2: - OpenFile(file: args[0], line: int.Parse(args[1])); - return; - case 3: - OpenFile(file: args[0], line: int.Parse(args[1]), column: int.Parse(args[2])); - return; - default: - throw new ArgumentException(); - } - } - }; - } - - protected abstract void OpenFile(string file); - protected abstract void OpenFile(string file, int line); - protected abstract void OpenFile(string file, int line, int column); - } -} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeConnection.cs b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeConnection.cs deleted file mode 100644 index 6441be8d6e..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeConnection.cs +++ /dev/null @@ -1,207 +0,0 @@ -using System; -using System.Diagnostics; -using System.IO; -using System.Net.Sockets; -using System.Text; - -namespace GodotTools.IdeConnection -{ - public abstract class GodotIdeConnection : IDisposable - { - protected const string Version = "1.0"; - - protected static readonly string ClientHandshake = $"Godot Ide Client Version {Version}"; - protected static readonly string ServerHandshake = $"Godot Ide Server Version {Version}"; - - private const int ClientWriteTimeout = 8000; - private readonly TcpClient tcpClient; - - private TextReader clientReader; - private TextWriter clientWriter; - - private readonly object writeLock = new object(); - - private readonly Func<Message, bool> messageHandler; - - public event Action Connected; - - private ILogger logger; - - public ILogger Logger - { - get => logger ?? (logger = new ConsoleLogger()); - set => logger = value; - } - - public bool IsDisposed { get; private set; } = false; - - public bool IsConnected => tcpClient.Client != null && tcpClient.Client.Connected; - - protected GodotIdeConnection(TcpClient tcpClient, Func<Message, bool> messageHandler) - { - this.tcpClient = tcpClient; - this.messageHandler = messageHandler; - } - - public void Start() - { - try - { - if (!StartConnection()) - return; - - string messageLine; - while ((messageLine = ReadLine()) != null) - { - if (!MessageParser.TryParse(messageLine, out Message msg)) - { - Logger.LogError($"Received message with invalid format: {messageLine}"); - continue; - } - - Logger.LogDebug($"Received message: {msg}"); - - if (msg.Id == "close") - { - Logger.LogInfo("Closing connection"); - return; - } - - try - { - try - { - Debug.Assert(messageHandler != null); - - if (!messageHandler(msg)) - Logger.LogError($"Received unknown message: {msg}"); - } - catch (Exception e) - { - Logger.LogError($"Message handler for '{msg}' failed with exception", e); - } - } - catch (Exception e) - { - Logger.LogError($"Exception thrown from message handler. Message: {msg}", e); - } - } - } - catch (Exception e) - { - Logger.LogError($"Unhandled exception in the Godot Ide Connection thread", e); - } - finally - { - Dispose(); - } - } - - private bool StartConnection() - { - NetworkStream clientStream = tcpClient.GetStream(); - - clientReader = new StreamReader(clientStream, Encoding.UTF8); - - lock (writeLock) - clientWriter = new StreamWriter(clientStream, Encoding.UTF8); - - clientStream.WriteTimeout = ClientWriteTimeout; - - if (!WriteHandshake()) - { - Logger.LogError("Could not write handshake"); - return false; - } - - if (!IsValidResponseHandshake(ReadLine())) - { - Logger.LogError("Received invalid handshake"); - return false; - } - - Connected?.Invoke(); - - Logger.LogInfo("Godot Ide connection started"); - - return true; - } - - private string ReadLine() - { - try - { - return clientReader?.ReadLine(); - } - catch (Exception e) - { - if (IsDisposed) - { - var se = e as SocketException ?? e.InnerException as SocketException; - if (se != null && se.SocketErrorCode == SocketError.Interrupted) - return null; - } - - throw; - } - } - - public bool WriteMessage(Message message) - { - Logger.LogDebug($"Sending message {message}"); - - var messageComposer = new MessageComposer(); - - messageComposer.AddArgument(message.Id); - foreach (string argument in message.Arguments) - messageComposer.AddArgument(argument); - - return WriteLine(messageComposer.ToString()); - } - - protected bool WriteLine(string text) - { - if (clientWriter == null || IsDisposed || !IsConnected) - return false; - - lock (writeLock) - { - try - { - clientWriter.WriteLine(text); - clientWriter.Flush(); - } - catch (Exception e) - { - if (!IsDisposed) - { - var se = e as SocketException ?? e.InnerException as SocketException; - if (se != null && se.SocketErrorCode == SocketError.Shutdown) - Logger.LogInfo("Client disconnected ungracefully"); - else - Logger.LogError("Exception thrown when trying to write to client", e); - - Dispose(); - } - } - } - - return true; - } - - protected abstract bool WriteHandshake(); - protected abstract bool IsValidResponseHandshake(string handshakeLine); - - public void Dispose() - { - if (IsDisposed) - return; - - IsDisposed = true; - - clientReader?.Dispose(); - clientWriter?.Dispose(); - ((IDisposable)tcpClient)?.Dispose(); - } - } -} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeConnectionClient.cs b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeConnectionClient.cs deleted file mode 100644 index 1b11a14358..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeConnectionClient.cs +++ /dev/null @@ -1,24 +0,0 @@ -using System; -using System.Net.Sockets; -using System.Threading.Tasks; - -namespace GodotTools.IdeConnection -{ - public class GodotIdeConnectionClient : GodotIdeConnection - { - public GodotIdeConnectionClient(TcpClient tcpClient, Func<Message, bool> messageHandler) - : base(tcpClient, messageHandler) - { - } - - protected override bool WriteHandshake() - { - return WriteLine(ClientHandshake); - } - - protected override bool IsValidResponseHandshake(string handshakeLine) - { - return handshakeLine == ServerHandshake; - } - } -} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeConnectionServer.cs b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeConnectionServer.cs deleted file mode 100644 index aa98dc7ca3..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeConnectionServer.cs +++ /dev/null @@ -1,24 +0,0 @@ -using System; -using System.Net.Sockets; -using System.Threading.Tasks; - -namespace GodotTools.IdeConnection -{ - public class GodotIdeConnectionServer : GodotIdeConnection - { - public GodotIdeConnectionServer(TcpClient tcpClient, Func<Message, bool> messageHandler) - : base(tcpClient, messageHandler) - { - } - - protected override bool WriteHandshake() - { - return WriteLine(ServerHandshake); - } - - protected override bool IsValidResponseHandshake(string handshakeLine) - { - return handshakeLine == ClientHandshake; - } - } -} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotTools.IdeConnection.csproj b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotTools.IdeConnection.csproj deleted file mode 100644 index 8454535fba..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotTools.IdeConnection.csproj +++ /dev/null @@ -1,53 +0,0 @@ -<?xml version="1.0" encoding="utf-8"?> -<Project ToolsVersion="4.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> - <Import Project="$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props" Condition="Exists('$(MSBuildExtensionsPath)\$(MSBuildToolsVersion)\Microsoft.Common.props')" /> - <PropertyGroup> - <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> - <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform> - <ProjectGuid>{92600954-25F0-4291-8E11-1FEE9FC4BE20}</ProjectGuid> - <OutputType>Library</OutputType> - <AppDesignerFolder>Properties</AppDesignerFolder> - <RootNamespace>GodotTools.IdeConnection</RootNamespace> - <AssemblyName>GodotTools.IdeConnection</AssemblyName> - <TargetFrameworkVersion>v4.7</TargetFrameworkVersion> - <FileAlignment>512</FileAlignment> - <LangVersion>7</LangVersion> - </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' "> - <PlatformTarget>AnyCPU</PlatformTarget> - <DebugSymbols>true</DebugSymbols> - <DebugType>portable</DebugType> - <Optimize>false</Optimize> - <OutputPath>bin\Debug\</OutputPath> - <DefineConstants>DEBUG;TRACE</DefineConstants> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' "> - <PlatformTarget>AnyCPU</PlatformTarget> - <DebugType>portable</DebugType> - <Optimize>true</Optimize> - <OutputPath>bin\Release\</OutputPath> - <DefineConstants>TRACE</DefineConstants> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - </PropertyGroup> - <ItemGroup> - <Reference Include="System" /> - </ItemGroup> - <ItemGroup> - <Compile Include="ConsoleLogger.cs" /> - <Compile Include="GodotIdeMetadata.cs" /> - <Compile Include="GodotIdeBase.cs" /> - <Compile Include="GodotIdeClient.cs" /> - <Compile Include="GodotIdeConnection.cs" /> - <Compile Include="GodotIdeConnectionClient.cs" /> - <Compile Include="GodotIdeConnectionServer.cs" /> - <Compile Include="ILogger.cs" /> - <Compile Include="Message.cs" /> - <Compile Include="MessageComposer.cs" /> - <Compile Include="MessageParser.cs" /> - <Compile Include="Properties\AssemblyInfo.cs" /> - </ItemGroup> - <Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" /> -</Project> diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/Message.cs b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/Message.cs deleted file mode 100644 index f24d324ae3..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/Message.cs +++ /dev/null @@ -1,21 +0,0 @@ -using System.Linq; - -namespace GodotTools.IdeConnection -{ - public struct Message - { - public string Id { get; set; } - public string[] Arguments { get; set; } - - public Message(string id, params string[] arguments) - { - Id = id; - Arguments = arguments; - } - - public override string ToString() - { - return $"(Id: '{Id}', Arguments: '{string.Join(",", Arguments)}')"; - } - } -} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/MessageComposer.cs b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/MessageComposer.cs deleted file mode 100644 index 30ffe7a06e..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/MessageComposer.cs +++ /dev/null @@ -1,46 +0,0 @@ -using System.Linq; -using System.Text; - -namespace GodotTools.IdeConnection -{ - public class MessageComposer - { - private readonly StringBuilder stringBuilder = new StringBuilder(); - - private static readonly char[] CharsToEscape = { '\\', '"' }; - - public void AddArgument(string argument) - { - AddArgument(argument, quoted: argument.Contains(",")); - } - - public void AddArgument(string argument, bool quoted) - { - if (stringBuilder.Length > 0) - stringBuilder.Append(','); - - if (quoted) - { - stringBuilder.Append('"'); - - foreach (char @char in argument) - { - if (CharsToEscape.Contains(@char)) - stringBuilder.Append('\\'); - stringBuilder.Append(@char); - } - - stringBuilder.Append('"'); - } - else - { - stringBuilder.Append(argument); - } - } - - public override string ToString() - { - return stringBuilder.ToString(); - } - } -} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/MessageParser.cs b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/MessageParser.cs deleted file mode 100644 index 4365d69989..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/MessageParser.cs +++ /dev/null @@ -1,88 +0,0 @@ -using System.Collections.Generic; -using System.Linq; -using System.Text; - -namespace GodotTools.IdeConnection -{ - public static class MessageParser - { - public static bool TryParse(string messageLine, out Message message) - { - var arguments = new List<string>(); - var stringBuilder = new StringBuilder(); - - bool expectingArgument = true; - - for (int i = 0; i < messageLine.Length; i++) - { - char @char = messageLine[i]; - - if (@char == ',') - { - if (expectingArgument) - arguments.Add(string.Empty); - - expectingArgument = true; - continue; - } - - bool quoted = false; - - if (messageLine[i] == '"') - { - quoted = true; - i++; - } - - while (i < messageLine.Length) - { - @char = messageLine[i]; - - if (quoted && @char == '"') - { - i++; - break; - } - - if (@char == '\\') - { - i++; - if (i < messageLine.Length) - break; - - stringBuilder.Append(messageLine[i]); - } - else if (!quoted && @char == ',') - { - break; // We don't increment the counter to allow the colon to be parsed after this - } - else - { - stringBuilder.Append(@char); - } - - i++; - } - - arguments.Add(stringBuilder.ToString()); - stringBuilder.Clear(); - - expectingArgument = false; - } - - if (arguments.Count == 0) - { - message = new Message(); - return false; - } - - message = new Message - { - Id = arguments[0], - Arguments = arguments.Skip(1).ToArray() - }; - - return true; - } - } -} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/Properties/AssemblyInfo.cs b/modules/mono/editor/GodotTools/GodotTools.IdeConnection/Properties/AssemblyInfo.cs deleted file mode 100644 index 0806d02ca0..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,35 +0,0 @@ -using System.Reflection; -using System.Runtime.InteropServices; - -// General Information about an assembly is controlled through the following -// set of attributes. Change these attribute values to modify the information -// associated with an assembly. -[assembly: AssemblyTitle("GodotTools.IdeConnection")] -[assembly: AssemblyDescription("")] -[assembly: AssemblyConfiguration("")] -[assembly: AssemblyCompany("")] -[assembly: AssemblyProduct("")] -[assembly: AssemblyCopyright("Godot Engine contributors")] -[assembly: AssemblyTrademark("")] -[assembly: AssemblyCulture("")] - -// Setting ComVisible to false makes the types in this assembly not visible -// to COM components. If you need to access a type in this assembly from -// COM, set the ComVisible attribute to true on that type. -[assembly: ComVisible(false)] - -// The following GUID is for the ID of the typelib if this project is exposed to COM -[assembly: Guid("92600954-25F0-4291-8E11-1FEE9FC4BE20")] - -// Version information for an assembly consists of the following four values: -// -// Major Version -// Minor Version -// Build Number -// Revision -// -// You can specify all the values or you can default the Build and Revision Numbers -// by using the '*' as shown below: -// [assembly: AssemblyVersion("1.0.*")] -[assembly: AssemblyVersion("1.0.0.0")] -[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging.CLI/ForwarderMessageHandler.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging.CLI/ForwarderMessageHandler.cs new file mode 100644 index 0000000000..3cb6a6687e --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging.CLI/ForwarderMessageHandler.cs @@ -0,0 +1,57 @@ +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using GodotTools.IdeMessaging.Utils; + +namespace GodotTools.IdeMessaging.CLI +{ + public class ForwarderMessageHandler : IMessageHandler + { + private readonly StreamWriter outputWriter; + private readonly SemaphoreSlim outputWriteSem = new SemaphoreSlim(1); + + public ForwarderMessageHandler(StreamWriter outputWriter) + { + this.outputWriter = outputWriter; + } + + public async Task<MessageContent> HandleRequest(Peer peer, string id, MessageContent content, ILogger logger) + { + await WriteRequestToOutput(id, content); + return new MessageContent(MessageStatus.RequestNotSupported, "null"); + } + + private async Task WriteRequestToOutput(string id, MessageContent content) + { + using (await outputWriteSem.UseAsync()) + { + await outputWriter.WriteLineAsync("======= Request ======="); + await outputWriter.WriteLineAsync(id); + await outputWriter.WriteLineAsync(content.Body.Count(c => c == '\n').ToString()); + await outputWriter.WriteLineAsync(content.Body); + await outputWriter.WriteLineAsync("======================="); + await outputWriter.FlushAsync(); + } + } + + public async Task WriteResponseToOutput(string id, MessageContent content) + { + using (await outputWriteSem.UseAsync()) + { + await outputWriter.WriteLineAsync("======= Response ======="); + await outputWriter.WriteLineAsync(id); + await outputWriter.WriteLineAsync(content.Body.Count(c => c == '\n').ToString()); + await outputWriter.WriteLineAsync(content.Body); + await outputWriter.WriteLineAsync("========================"); + await outputWriter.FlushAsync(); + } + } + + public async Task WriteLineToOutput(string eventName) + { + using (await outputWriteSem.UseAsync()) + await outputWriter.WriteLineAsync($"======= {eventName} ======="); + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging.CLI/GodotTools.IdeMessaging.CLI.csproj b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging.CLI/GodotTools.IdeMessaging.CLI.csproj new file mode 100644 index 0000000000..ae78da27bc --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging.CLI/GodotTools.IdeMessaging.CLI.csproj @@ -0,0 +1,17 @@ +<Project Sdk="Microsoft.NET.Sdk"> + <PropertyGroup> + <ProjectGuid>{B06C2951-C8E3-4F28-80B2-717CF327EB19}</ProjectGuid> + <OutputType>Exe</OutputType> + <TargetFramework>net472</TargetFramework> + <LangVersion>7.2</LangVersion> + </PropertyGroup> + <ItemGroup> + <Reference Include="System" /> + </ItemGroup> + <ItemGroup> + <ProjectReference Include="..\GodotTools.IdeMessaging\GodotTools.IdeMessaging.csproj" /> + </ItemGroup> + <ItemGroup> + <PackageReference Include="Newtonsoft.Json" Version="12.0.3" /> + </ItemGroup> +</Project> diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging.CLI/Program.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging.CLI/Program.cs new file mode 100644 index 0000000000..99a55c471b --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging.CLI/Program.cs @@ -0,0 +1,218 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; +using GodotTools.IdeMessaging.Requests; +using Newtonsoft.Json; + +namespace GodotTools.IdeMessaging.CLI +{ + internal static class Program + { + private static readonly ILogger Logger = new CustomLogger(); + + public static int Main(string[] args) + { + try + { + var mainTask = StartAsync(args, Console.OpenStandardInput(), Console.OpenStandardOutput()); + mainTask.Wait(); + return mainTask.Result; + } + catch (Exception ex) + { + Logger.LogError("Unhandled exception: ", ex); + return 1; + } + } + + private static async Task<int> StartAsync(string[] args, Stream inputStream, Stream outputStream) + { + var inputReader = new StreamReader(inputStream, Encoding.UTF8); + var outputWriter = new StreamWriter(outputStream, Encoding.UTF8); + + try + { + if (args.Length == 0) + { + Logger.LogError("Expected at least 1 argument"); + return 1; + } + + string godotProjectDir = args[0]; + + if (!Directory.Exists(godotProjectDir)) + { + Logger.LogError($"The specified Godot project directory does not exist: {godotProjectDir}"); + return 1; + } + + var forwarder = new ForwarderMessageHandler(outputWriter); + + using (var fwdClient = new Client("VisualStudioCode", godotProjectDir, forwarder, Logger)) + { + fwdClient.Start(); + + // ReSharper disable AccessToDisposedClosure + fwdClient.Connected += async () => await forwarder.WriteLineToOutput("Event=Connected"); + fwdClient.Disconnected += async () => await forwarder.WriteLineToOutput("Event=Disconnected"); + // ReSharper restore AccessToDisposedClosure + + // TODO: Await connected with timeout + + while (!fwdClient.IsDisposed) + { + string firstLine = await inputReader.ReadLineAsync(); + + if (firstLine == null || firstLine == "QUIT") + goto ExitMainLoop; + + string messageId = firstLine; + + string messageArgcLine = await inputReader.ReadLineAsync(); + + if (messageArgcLine == null) + { + Logger.LogInfo("EOF when expecting argument count"); + goto ExitMainLoop; + } + + if (!int.TryParse(messageArgcLine, out int messageArgc)) + { + Logger.LogError("Received invalid line for argument count: " + firstLine); + continue; + } + + var body = new StringBuilder(); + + for (int i = 0; i < messageArgc; i++) + { + string bodyLine = await inputReader.ReadLineAsync(); + + if (bodyLine == null) + { + Logger.LogInfo($"EOF when expecting body line #{i + 1}"); + goto ExitMainLoop; + } + + body.AppendLine(bodyLine); + } + + var response = await SendRequest(fwdClient, messageId, new MessageContent(MessageStatus.Ok, body.ToString())); + + if (response == null) + { + Logger.LogError($"Failed to write message to the server: {messageId}"); + } + else + { + var content = new MessageContent(response.Status, JsonConvert.SerializeObject(response)); + await forwarder.WriteResponseToOutput(messageId, content); + } + } + + ExitMainLoop: + + await forwarder.WriteLineToOutput("Event=Quit"); + } + + return 0; + } + catch (Exception e) + { + Logger.LogError("Unhandled exception", e); + return 1; + } + } + + private static async Task<Response> SendRequest(Client client, string id, MessageContent content) + { + var handlers = new Dictionary<string, Func<Task<Response>>> + { + [PlayRequest.Id] = async () => + { + var request = JsonConvert.DeserializeObject<PlayRequest>(content.Body); + return await client.SendRequest<PlayResponse>(request); + }, + [DebugPlayRequest.Id] = async () => + { + var request = JsonConvert.DeserializeObject<DebugPlayRequest>(content.Body); + return await client.SendRequest<DebugPlayResponse>(request); + }, + [ReloadScriptsRequest.Id] = async () => + { + var request = JsonConvert.DeserializeObject<ReloadScriptsRequest>(content.Body); + return await client.SendRequest<ReloadScriptsResponse>(request); + }, + [CodeCompletionRequest.Id] = async () => + { + var request = JsonConvert.DeserializeObject<CodeCompletionRequest>(content.Body); + return await client.SendRequest<CodeCompletionResponse>(request); + } + }; + + if (handlers.TryGetValue(id, out var handler)) + return await handler(); + + Console.WriteLine("INVALID REQUEST"); + return null; + } + + private class CustomLogger : ILogger + { + private static string ThisAppPath => Assembly.GetExecutingAssembly().Location; + private static string ThisAppPathWithoutExtension => Path.ChangeExtension(ThisAppPath, null); + + private static readonly string LogPath = $"{ThisAppPathWithoutExtension}.log"; + + private static StreamWriter NewWriter() => new StreamWriter(LogPath, append: true, encoding: Encoding.UTF8); + + private static void Log(StreamWriter writer, string message) + { + writer.WriteLine($"{DateTime.Now:HH:mm:ss.ffffff}: {message}"); + } + + public void LogDebug(string message) + { + using (var writer = NewWriter()) + { + Log(writer, "DEBUG: " + message); + } + } + + public void LogInfo(string message) + { + using (var writer = NewWriter()) + { + Log(writer, "INFO: " + message); + } + } + + public void LogWarning(string message) + { + using (var writer = NewWriter()) + { + Log(writer, "WARN: " + message); + } + } + + public void LogError(string message) + { + using (var writer = NewWriter()) + { + Log(writer, "ERROR: " + message); + } + } + + public void LogError(string message, Exception e) + { + using (var writer = NewWriter()) + { + Log(writer, "EXCEPTION: " + message + '\n' + e); + } + } + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Client.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Client.cs new file mode 100644 index 0000000000..d069651dd3 --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Client.cs @@ -0,0 +1,332 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Sockets; +using Newtonsoft.Json; +using System.Threading; +using System.Threading.Tasks; +using GodotTools.IdeMessaging.Requests; +using GodotTools.IdeMessaging.Utils; + +namespace GodotTools.IdeMessaging +{ + // ReSharper disable once UnusedType.Global + public sealed class Client : IDisposable + { + private readonly ILogger logger; + + private readonly string identity; + + private string MetaFilePath { get; } + private GodotIdeMetadata godotIdeMetadata; + private readonly FileSystemWatcher fsWatcher; + + private readonly IMessageHandler messageHandler; + + private Peer peer; + private readonly SemaphoreSlim connectionSem = new SemaphoreSlim(1); + + private readonly Queue<NotifyAwaiter<bool>> clientConnectedAwaiters = new Queue<NotifyAwaiter<bool>>(); + private readonly Queue<NotifyAwaiter<bool>> clientDisconnectedAwaiters = new Queue<NotifyAwaiter<bool>>(); + + // ReSharper disable once UnusedMember.Global + public async Task<bool> AwaitConnected() + { + var awaiter = new NotifyAwaiter<bool>(); + clientConnectedAwaiters.Enqueue(awaiter); + return await awaiter; + } + + // ReSharper disable once UnusedMember.Global + public async Task<bool> AwaitDisconnected() + { + var awaiter = new NotifyAwaiter<bool>(); + clientDisconnectedAwaiters.Enqueue(awaiter); + return await awaiter; + } + + // ReSharper disable once MemberCanBePrivate.Global + public bool IsDisposed { get; private set; } + + // ReSharper disable once MemberCanBePrivate.Global + public bool IsConnected => peer != null && !peer.IsDisposed && peer.IsTcpClientConnected; + + // ReSharper disable once EventNeverSubscribedTo.Global + public event Action Connected + { + add + { + if (peer != null && !peer.IsDisposed) + peer.Connected += value; + } + remove + { + if (peer != null && !peer.IsDisposed) + peer.Connected -= value; + } + } + + // ReSharper disable once EventNeverSubscribedTo.Global + public event Action Disconnected + { + add + { + if (peer != null && !peer.IsDisposed) + peer.Disconnected += value; + } + remove + { + if (peer != null && !peer.IsDisposed) + peer.Disconnected -= value; + } + } + + ~Client() + { + Dispose(disposing: false); + } + + public async void Dispose() + { + if (IsDisposed) + return; + + using (await connectionSem.UseAsync()) + { + if (IsDisposed) // lock may not be fair + return; + IsDisposed = true; + } + + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (disposing) + { + peer?.Dispose(); + fsWatcher?.Dispose(); + } + } + + public Client(string identity, string godotProjectDir, IMessageHandler messageHandler, ILogger logger) + { + this.identity = identity; + this.messageHandler = messageHandler; + this.logger = logger; + + string projectMetadataDir = Path.Combine(godotProjectDir, ".mono", "metadata"); + + MetaFilePath = Path.Combine(projectMetadataDir, GodotIdeMetadata.DefaultFileName); + + // FileSystemWatcher requires an existing directory + if (!File.Exists(projectMetadataDir)) + Directory.CreateDirectory(projectMetadataDir); + + fsWatcher = new FileSystemWatcher(projectMetadataDir, GodotIdeMetadata.DefaultFileName); + } + + private async void OnMetaFileChanged(object sender, FileSystemEventArgs e) + { + if (IsDisposed) + return; + + using (await connectionSem.UseAsync()) + { + if (IsDisposed) + return; + + if (!File.Exists(MetaFilePath)) + return; + + var metadata = ReadMetadataFile(); + + if (metadata != null && metadata != godotIdeMetadata) + { + godotIdeMetadata = metadata.Value; + _ = Task.Run(ConnectToServer); + } + } + } + + private async void OnMetaFileDeleted(object sender, FileSystemEventArgs e) + { + if (IsDisposed) + return; + + if (IsConnected) + { + using (await connectionSem.UseAsync()) + peer?.Dispose(); + } + + // The file may have been re-created + + using (await connectionSem.UseAsync()) + { + if (IsDisposed) + return; + + if (IsConnected || !File.Exists(MetaFilePath)) + return; + + var metadata = ReadMetadataFile(); + + if (metadata != null) + { + godotIdeMetadata = metadata.Value; + _ = Task.Run(ConnectToServer); + } + } + } + + private GodotIdeMetadata? ReadMetadataFile() + { + using (var reader = File.OpenText(MetaFilePath)) + { + string portStr = reader.ReadLine(); + + if (portStr == null) + return null; + + string editorExecutablePath = reader.ReadLine(); + + if (editorExecutablePath == null) + return null; + + if (!int.TryParse(portStr, out int port)) + return null; + + return new GodotIdeMetadata(port, editorExecutablePath); + } + } + + private async Task AcceptClient(TcpClient tcpClient) + { + logger.LogDebug("Accept client..."); + + using (peer = new Peer(tcpClient, new ClientHandshake(), messageHandler, logger)) + { + // ReSharper disable AccessToDisposedClosure + peer.Connected += () => + { + logger.LogInfo("Connection open with Ide Client"); + + while (clientConnectedAwaiters.Count > 0) + clientConnectedAwaiters.Dequeue().SetResult(true); + }; + + peer.Disconnected += () => + { + while (clientDisconnectedAwaiters.Count > 0) + clientDisconnectedAwaiters.Dequeue().SetResult(true); + }; + // ReSharper restore AccessToDisposedClosure + + try + { + if (!await peer.DoHandshake(identity)) + { + logger.LogError("Handshake failed"); + return; + } + } + catch (Exception e) + { + logger.LogError("Handshake failed with unhandled exception: ", e); + return; + } + + await peer.Process(); + + logger.LogInfo("Connection closed with Ide Client"); + } + } + + private async Task ConnectToServer() + { + var tcpClient = new TcpClient(); + + try + { + logger.LogInfo("Connecting to Godot Ide Server"); + + await tcpClient.ConnectAsync(IPAddress.Loopback, godotIdeMetadata.Port); + + logger.LogInfo("Connection open with Godot Ide Server"); + + await AcceptClient(tcpClient); + } + catch (SocketException e) + { + if (e.SocketErrorCode == SocketError.ConnectionRefused) + logger.LogError("The connection to the Godot Ide Server was refused"); + else + throw; + } + } + + // ReSharper disable once UnusedMember.Global + public async void Start() + { + fsWatcher.Changed += OnMetaFileChanged; + fsWatcher.Deleted += OnMetaFileDeleted; + fsWatcher.EnableRaisingEvents = true; + + using (await connectionSem.UseAsync()) + { + if (IsDisposed) + return; + + if (IsConnected) + return; + + if (!File.Exists(MetaFilePath)) + { + logger.LogInfo("There is no Godot Ide Server running"); + return; + } + + var metadata = ReadMetadataFile(); + + if (metadata != null) + { + godotIdeMetadata = metadata.Value; + _ = Task.Run(ConnectToServer); + } + else + { + logger.LogError("Failed to read Godot Ide metadata file"); + } + } + } + + public async Task<TResponse> SendRequest<TResponse>(Request request) + where TResponse : Response, new() + { + if (!IsConnected) + { + logger.LogError("Cannot write request. Not connected to the Godot Ide Server."); + return null; + } + + string body = JsonConvert.SerializeObject(request); + return await peer.SendRequest<TResponse>(request.Id, body); + } + + public async Task<TResponse> SendRequest<TResponse>(string id, string body) + where TResponse : Response, new() + { + if (!IsConnected) + { + logger.LogError("Cannot write request. Not connected to the Godot Ide Server."); + return null; + } + + return await peer.SendRequest<TResponse>(id, body); + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ClientHandshake.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ClientHandshake.cs new file mode 100644 index 0000000000..43041be7be --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ClientHandshake.cs @@ -0,0 +1,44 @@ +using System.Text.RegularExpressions; + +namespace GodotTools.IdeMessaging +{ + public class ClientHandshake : IHandshake + { + private static readonly string ClientHandshakeBase = $"{Peer.ClientHandshakeName},Version={Peer.ProtocolVersionMajor}.{Peer.ProtocolVersionMinor}.{Peer.ProtocolVersionRevision}"; + private static readonly string ServerHandshakePattern = $@"{Regex.Escape(Peer.ServerHandshakeName)},Version=([0-9]+)\.([0-9]+)\.([0-9]+),([_a-zA-Z][_a-zA-Z0-9]{{0,63}})"; + + public string GetHandshakeLine(string identity) => $"{ClientHandshakeBase},{identity}"; + + public bool IsValidPeerHandshake(string handshake, out string identity, ILogger logger) + { + identity = null; + + var match = Regex.Match(handshake, ServerHandshakePattern); + + if (!match.Success) + return false; + + if (!uint.TryParse(match.Groups[1].Value, out uint serverMajor) || Peer.ProtocolVersionMajor != serverMajor) + { + logger.LogDebug("Incompatible major version: " + match.Groups[1].Value); + return false; + } + + if (!uint.TryParse(match.Groups[2].Value, out uint serverMinor) || Peer.ProtocolVersionMinor < serverMinor) + { + logger.LogDebug("Incompatible minor version: " + match.Groups[2].Value); + return false; + } + + if (!uint.TryParse(match.Groups[3].Value, out uint _)) // Revision + { + logger.LogDebug("Incompatible revision build: " + match.Groups[3].Value); + return false; + } + + identity = match.Groups[4].Value; + + return true; + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ClientMessageHandler.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ClientMessageHandler.cs new file mode 100644 index 0000000000..64bcfd824c --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ClientMessageHandler.cs @@ -0,0 +1,52 @@ +using System.Collections.Generic; +using System.Threading.Tasks; +using GodotTools.IdeMessaging.Requests; +using Newtonsoft.Json; + +namespace GodotTools.IdeMessaging +{ + // ReSharper disable once UnusedType.Global + public abstract class ClientMessageHandler : IMessageHandler + { + private readonly Dictionary<string, Peer.RequestHandler> requestHandlers; + + protected ClientMessageHandler() + { + requestHandlers = InitializeRequestHandlers(); + } + + public async Task<MessageContent> HandleRequest(Peer peer, string id, MessageContent content, ILogger logger) + { + if (!requestHandlers.TryGetValue(id, out var handler)) + { + logger.LogError($"Received unknown request: {id}"); + return new MessageContent(MessageStatus.RequestNotSupported, "null"); + } + + try + { + var response = await handler(peer, content); + return new MessageContent(response.Status, JsonConvert.SerializeObject(response)); + } + catch (JsonException) + { + logger.LogError($"Received request with invalid body: {id}"); + return new MessageContent(MessageStatus.InvalidRequestBody, "null"); + } + } + + private Dictionary<string, Peer.RequestHandler> InitializeRequestHandlers() + { + return new Dictionary<string, Peer.RequestHandler> + { + [OpenFileRequest.Id] = async (peer, content) => + { + var request = JsonConvert.DeserializeObject<OpenFileRequest>(content.Body); + return await HandleOpenFile(request); + } + }; + } + + protected abstract Task<Response> HandleOpenFile(OpenFileRequest request); + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeMetadata.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/GodotIdeMetadata.cs index d16daba0e2..686202e81e 100644 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/GodotIdeMetadata.cs +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/GodotIdeMetadata.cs @@ -1,10 +1,12 @@ -namespace GodotTools.IdeConnection +namespace GodotTools.IdeMessaging { - public struct GodotIdeMetadata + public readonly struct GodotIdeMetadata { public int Port { get; } public string EditorExecutablePath { get; } + public const string DefaultFileName = "ide_messaging_meta.txt"; + public GodotIdeMetadata(int port, string editorExecutablePath) { Port = port; diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/GodotTools.IdeMessaging.csproj b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/GodotTools.IdeMessaging.csproj new file mode 100644 index 0000000000..67815959a6 --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/GodotTools.IdeMessaging.csproj @@ -0,0 +1,24 @@ +<Project Sdk="Microsoft.NET.Sdk"> + <PropertyGroup> + <ProjectGuid>{92600954-25F0-4291-8E11-1FEE9FC4BE20}</ProjectGuid> + <TargetFramework>netstandard2.0</TargetFramework> + <LangVersion>7.2</LangVersion> + <PackageId>GodotTools.IdeMessaging</PackageId> + <Version>1.1.0</Version> + <AssemblyVersion>$(Version)</AssemblyVersion> + <Authors>Godot Engine contributors</Authors> + <Company /> + <PackageTags>godot</PackageTags> + <RepositoryUrl>https://github.com/godotengine/godot/tree/master/modules/mono/editor/GodotTools/GodotTools.IdeMessaging</RepositoryUrl> + <PackageLicenseExpression>MIT</PackageLicenseExpression> + <Description> +This library enables communication with the Godot Engine editor (the version with .NET support). +It's intended for use in IDEs/editors plugins for a better experience working with Godot C# projects. + +A client using this library is only compatible with servers of the same major version and of a lower or equal minor version. + </Description> + </PropertyGroup> + <ItemGroup> + <PackageReference Include="Newtonsoft.Json" Version="12.0.3" /> + </ItemGroup> +</Project> diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/IHandshake.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/IHandshake.cs new file mode 100644 index 0000000000..6387145a28 --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/IHandshake.cs @@ -0,0 +1,8 @@ +namespace GodotTools.IdeMessaging +{ + public interface IHandshake + { + string GetHandshakeLine(string identity); + bool IsValidPeerHandshake(string handshake, out string identity, ILogger logger); + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/ILogger.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ILogger.cs index 614bb30271..d2855f93a1 100644 --- a/modules/mono/editor/GodotTools/GodotTools.IdeConnection/ILogger.cs +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ILogger.cs @@ -1,6 +1,6 @@ using System; -namespace GodotTools.IdeConnection +namespace GodotTools.IdeMessaging { public interface ILogger { diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/IMessageHandler.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/IMessageHandler.cs new file mode 100644 index 0000000000..9622fcc96d --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/IMessageHandler.cs @@ -0,0 +1,9 @@ +using System.Threading.Tasks; + +namespace GodotTools.IdeMessaging +{ + public interface IMessageHandler + { + Task<MessageContent> HandleRequest(Peer peer, string id, MessageContent content, ILogger logger); + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Message.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Message.cs new file mode 100644 index 0000000000..6903ec197b --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Message.cs @@ -0,0 +1,52 @@ +namespace GodotTools.IdeMessaging +{ + public class Message + { + public MessageKind Kind { get; } + public string Id { get; } + public MessageContent Content { get; } + + public Message(MessageKind kind, string id, MessageContent content) + { + Kind = kind; + Id = id; + Content = content; + } + + public override string ToString() + { + return $"{Kind} | {Id}"; + } + } + + public enum MessageKind + { + Request, + Response + } + + public enum MessageStatus + { + Ok, + RequestNotSupported, + InvalidRequestBody + } + + public readonly struct MessageContent + { + public MessageStatus Status { get; } + public string Body { get; } + + public MessageContent(string body) + { + Status = MessageStatus.Ok; + Body = body; + } + + public MessageContent(MessageStatus status, string body) + { + Status = status; + Body = body; + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/MessageDecoder.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/MessageDecoder.cs new file mode 100644 index 0000000000..a00575a2a1 --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/MessageDecoder.cs @@ -0,0 +1,100 @@ +using System; +using System.Text; + +namespace GodotTools.IdeMessaging +{ + public class MessageDecoder + { + private class DecodedMessage + { + public MessageKind? Kind; + public string Id; + public MessageStatus? Status; + public readonly StringBuilder Body = new StringBuilder(); + public uint? PendingBodyLines; + + public void Clear() + { + Kind = null; + Id = null; + Status = null; + Body.Clear(); + PendingBodyLines = null; + } + + public Message ToMessage() + { + if (!Kind.HasValue || Id == null || !Status.HasValue || + !PendingBodyLines.HasValue || PendingBodyLines.Value > 0) + throw new InvalidOperationException(); + + return new Message(Kind.Value, Id, new MessageContent(Status.Value, Body.ToString())); + } + } + + public enum State + { + Decoding, + Decoded, + Errored + } + + private readonly DecodedMessage decodingMessage = new DecodedMessage(); + + public State Decode(string messageLine, out Message decodedMessage) + { + decodedMessage = null; + + if (!decodingMessage.Kind.HasValue) + { + if (!Enum.TryParse(messageLine, ignoreCase: true, out MessageKind kind)) + { + decodingMessage.Clear(); + return State.Errored; + } + + decodingMessage.Kind = kind; + } + else if (decodingMessage.Id == null) + { + decodingMessage.Id = messageLine; + } + else if (decodingMessage.Status == null) + { + if (!Enum.TryParse(messageLine, ignoreCase: true, out MessageStatus status)) + { + decodingMessage.Clear(); + return State.Errored; + } + + decodingMessage.Status = status; + } + else if (decodingMessage.PendingBodyLines == null) + { + if (!uint.TryParse(messageLine, out uint pendingBodyLines)) + { + decodingMessage.Clear(); + return State.Errored; + } + + decodingMessage.PendingBodyLines = pendingBodyLines; + } + else + { + if (decodingMessage.PendingBodyLines > 0) + { + decodingMessage.Body.AppendLine(messageLine); + decodingMessage.PendingBodyLines -= 1; + } + else + { + decodedMessage = decodingMessage.ToMessage(); + decodingMessage.Clear(); + return State.Decoded; + } + } + + return State.Decoding; + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Peer.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Peer.cs new file mode 100644 index 0000000000..a4e86d6177 --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Peer.cs @@ -0,0 +1,302 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Sockets; +using System.Reflection; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using GodotTools.IdeMessaging.Requests; +using GodotTools.IdeMessaging.Utils; + +namespace GodotTools.IdeMessaging +{ + public sealed class Peer : IDisposable + { + /// <summary> + /// Major version. + /// There is no forward nor backward compatibility between different major versions. + /// Connection is refused if client and server have different major versions. + /// </summary> + public static readonly int ProtocolVersionMajor = Assembly.GetAssembly(typeof(Peer)).GetName().Version.Major; + + /// <summary> + /// Minor version, which clients must be backward compatible with. + /// Connection is refused if the client's minor version is lower than the server's. + /// </summary> + public static readonly int ProtocolVersionMinor = Assembly.GetAssembly(typeof(Peer)).GetName().Version.Minor; + + /// <summary> + /// Revision, which doesn't affect compatibility. + /// </summary> + public static readonly int ProtocolVersionRevision = Assembly.GetAssembly(typeof(Peer)).GetName().Version.Revision; + + public const string ClientHandshakeName = "GodotIdeClient"; + public const string ServerHandshakeName = "GodotIdeServer"; + + private const int ClientWriteTimeout = 8000; + + public delegate Task<Response> RequestHandler(Peer peer, MessageContent content); + + private readonly TcpClient tcpClient; + + private readonly TextReader clientReader; + private readonly TextWriter clientWriter; + + private readonly SemaphoreSlim writeSem = new SemaphoreSlim(1); + + private string remoteIdentity = string.Empty; + public string RemoteIdentity => remoteIdentity; + + public event Action Connected; + public event Action Disconnected; + + private ILogger Logger { get; } + + public bool IsDisposed { get; private set; } + + public bool IsTcpClientConnected => tcpClient.Client != null && tcpClient.Client.Connected; + + private bool IsConnected { get; set; } + + private readonly IHandshake handshake; + private readonly IMessageHandler messageHandler; + + private readonly Dictionary<string, Queue<ResponseAwaiter>> requestAwaiterQueues = new Dictionary<string, Queue<ResponseAwaiter>>(); + private readonly SemaphoreSlim requestsSem = new SemaphoreSlim(1); + + public Peer(TcpClient tcpClient, IHandshake handshake, IMessageHandler messageHandler, ILogger logger) + { + this.tcpClient = tcpClient; + this.handshake = handshake; + this.messageHandler = messageHandler; + + Logger = logger; + + NetworkStream clientStream = tcpClient.GetStream(); + clientStream.WriteTimeout = ClientWriteTimeout; + + clientReader = new StreamReader(clientStream, Encoding.UTF8); + clientWriter = new StreamWriter(clientStream, Encoding.UTF8) {NewLine = "\n"}; + } + + public async Task Process() + { + try + { + var decoder = new MessageDecoder(); + + string messageLine; + while ((messageLine = await ReadLine()) != null) + { + var state = decoder.Decode(messageLine, out var msg); + + if (state == MessageDecoder.State.Decoding) + continue; // Not finished decoding yet + + if (state == MessageDecoder.State.Errored) + { + Logger.LogError($"Received message line with invalid format: {messageLine}"); + continue; + } + + Logger.LogDebug($"Received message: {msg}"); + + try + { + try + { + if (msg.Kind == MessageKind.Request) + { + var responseContent = await messageHandler.HandleRequest(this, msg.Id, msg.Content, Logger); + await WriteMessage(new Message(MessageKind.Response, msg.Id, responseContent)); + } + else if (msg.Kind == MessageKind.Response) + { + ResponseAwaiter responseAwaiter; + + using (await requestsSem.UseAsync()) + { + if (!requestAwaiterQueues.TryGetValue(msg.Id, out var queue) || queue.Count <= 0) + { + Logger.LogError($"Received unexpected response: {msg.Id}"); + return; + } + + responseAwaiter = queue.Dequeue(); + } + + responseAwaiter.SetResult(msg.Content); + } + else + { + throw new IndexOutOfRangeException($"Invalid message kind {msg.Kind}"); + } + } + catch (Exception e) + { + Logger.LogError($"Message handler for '{msg}' failed with exception", e); + } + } + catch (Exception e) + { + Logger.LogError($"Exception thrown from message handler. Message: {msg}", e); + } + } + } + catch (Exception e) + { + Logger.LogError("Unhandled exception in the peer loop", e); + } + } + + public async Task<bool> DoHandshake(string identity) + { + if (!await WriteLine(handshake.GetHandshakeLine(identity))) + { + Logger.LogError("Could not write handshake"); + return false; + } + + var readHandshakeTask = ReadLine(); + + if (await Task.WhenAny(readHandshakeTask, Task.Delay(8000)) != readHandshakeTask) + { + Logger.LogError("Timeout waiting for the client handshake"); + return false; + } + + string peerHandshake = await readHandshakeTask; + + if (handshake == null || !handshake.IsValidPeerHandshake(peerHandshake, out remoteIdentity, Logger)) + { + Logger.LogError("Received invalid handshake: " + peerHandshake); + return false; + } + + IsConnected = true; + Connected?.Invoke(); + + Logger.LogInfo("Peer connection started"); + + return true; + } + + private async Task<string> ReadLine() + { + try + { + return await clientReader.ReadLineAsync(); + } + catch (Exception e) + { + if (IsDisposed) + { + var se = e as SocketException ?? e.InnerException as SocketException; + if (se != null && se.SocketErrorCode == SocketError.Interrupted) + return null; + } + + throw; + } + } + + private Task<bool> WriteMessage(Message message) + { + Logger.LogDebug($"Sending message: {message}"); + int bodyLineCount = message.Content.Body.Count(c => c == '\n'); + + bodyLineCount += 1; // Extra line break at the end + + var builder = new StringBuilder(); + + builder.AppendLine(message.Kind.ToString()); + builder.AppendLine(message.Id); + builder.AppendLine(message.Content.Status.ToString()); + builder.AppendLine(bodyLineCount.ToString()); + builder.AppendLine(message.Content.Body); + + return WriteLine(builder.ToString()); + } + + public async Task<TResponse> SendRequest<TResponse>(string id, string body) + where TResponse : Response, new() + { + ResponseAwaiter responseAwaiter; + + using (await requestsSem.UseAsync()) + { + bool written = await WriteMessage(new Message(MessageKind.Request, id, new MessageContent(body))); + + if (!written) + return null; + + if (!requestAwaiterQueues.TryGetValue(id, out var queue)) + { + queue = new Queue<ResponseAwaiter>(); + requestAwaiterQueues.Add(id, queue); + } + + responseAwaiter = new ResponseAwaiter<TResponse>(); + queue.Enqueue(responseAwaiter); + } + + return (TResponse)await responseAwaiter; + } + + private async Task<bool> WriteLine(string text) + { + if (clientWriter == null || IsDisposed || !IsTcpClientConnected) + return false; + + using (await writeSem.UseAsync()) + { + try + { + await clientWriter.WriteLineAsync(text); + await clientWriter.FlushAsync(); + } + catch (Exception e) + { + if (!IsDisposed) + { + var se = e as SocketException ?? e.InnerException as SocketException; + if (se != null && se.SocketErrorCode == SocketError.Shutdown) + Logger.LogInfo("Client disconnected ungracefully"); + else + Logger.LogError("Exception thrown when trying to write to client", e); + + Dispose(); + } + } + } + + return true; + } + + // ReSharper disable once UnusedMember.Global + public void ShutdownSocketSend() + { + tcpClient.Client.Shutdown(SocketShutdown.Send); + } + + public void Dispose() + { + if (IsDisposed) + return; + + IsDisposed = true; + + if (IsTcpClientConnected) + { + if (IsConnected) + Disconnected?.Invoke(); + } + + clientReader?.Dispose(); + clientWriter?.Dispose(); + ((IDisposable)tcpClient)?.Dispose(); + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Requests/Requests.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Requests/Requests.cs new file mode 100644 index 0000000000..1dd4f852e5 --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Requests/Requests.cs @@ -0,0 +1,116 @@ +// ReSharper disable ClassNeverInstantiated.Global +// ReSharper disable UnusedMember.Global +// ReSharper disable UnusedAutoPropertyAccessor.Global + +using Newtonsoft.Json; + +namespace GodotTools.IdeMessaging.Requests +{ + public abstract class Request + { + [JsonIgnore] public string Id { get; } + + protected Request(string id) + { + Id = id; + } + } + + public abstract class Response + { + [JsonIgnore] public MessageStatus Status { get; set; } = MessageStatus.Ok; + } + + public sealed class CodeCompletionRequest : Request + { + public enum CompletionKind + { + InputActions = 0, + NodePaths, + ResourcePaths, + ScenePaths, + ShaderParams, + Signals, + ThemeColors, + ThemeConstants, + ThemeFonts, + ThemeStyles + } + + public CompletionKind Kind { get; set; } + public string ScriptFile { get; set; } + + public new const string Id = "CodeCompletion"; + + public CodeCompletionRequest() : base(Id) + { + } + } + + public sealed class CodeCompletionResponse : Response + { + public CodeCompletionRequest.CompletionKind Kind; + public string ScriptFile { get; set; } + public string[] Suggestions { get; set; } + } + + public sealed class PlayRequest : Request + { + public new const string Id = "Play"; + + public PlayRequest() : base(Id) + { + } + } + + public sealed class PlayResponse : Response + { + } + + public sealed class DebugPlayRequest : Request + { + public string DebuggerHost { get; set; } + public int DebuggerPort { get; set; } + public bool? BuildBeforePlaying { get; set; } + + public new const string Id = "DebugPlay"; + + public DebugPlayRequest() : base(Id) + { + } + } + + public sealed class DebugPlayResponse : Response + { + } + + public sealed class OpenFileRequest : Request + { + public string File { get; set; } + public int? Line { get; set; } + public int? Column { get; set; } + + public new const string Id = "OpenFile"; + + public OpenFileRequest() : base(Id) + { + } + } + + public sealed class OpenFileResponse : Response + { + } + + public sealed class ReloadScriptsRequest : Request + { + public new const string Id = "ReloadScripts"; + + public ReloadScriptsRequest() : base(Id) + { + } + } + + public sealed class ReloadScriptsResponse : Response + { + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ResponseAwaiter.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ResponseAwaiter.cs new file mode 100644 index 0000000000..548e7f06ee --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/ResponseAwaiter.cs @@ -0,0 +1,23 @@ +using GodotTools.IdeMessaging.Requests; +using GodotTools.IdeMessaging.Utils; +using Newtonsoft.Json; + +namespace GodotTools.IdeMessaging +{ + public abstract class ResponseAwaiter : NotifyAwaiter<Response> + { + public abstract void SetResult(MessageContent content); + } + + public class ResponseAwaiter<T> : ResponseAwaiter + where T : Response, new() + { + public override void SetResult(MessageContent content) + { + if (content.Status == MessageStatus.Ok) + SetResult(JsonConvert.DeserializeObject<T>(content.Body)); + else + SetResult(new T {Status = content.Status}); + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools/Utils/NotifyAwaiter.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Utils/NotifyAwaiter.cs index 700b786752..d84a63c83c 100644 --- a/modules/mono/editor/GodotTools/GodotTools/Utils/NotifyAwaiter.cs +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Utils/NotifyAwaiter.cs @@ -1,9 +1,9 @@ using System; using System.Runtime.CompilerServices; -namespace GodotTools.Utils +namespace GodotTools.IdeMessaging.Utils { - public sealed class NotifyAwaiter<T> : INotifyCompletion + public class NotifyAwaiter<T> : INotifyCompletion { private Action continuation; private Exception exception; diff --git a/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Utils/SemaphoreExtensions.cs b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Utils/SemaphoreExtensions.cs new file mode 100644 index 0000000000..9d593fbf8a --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.IdeMessaging/Utils/SemaphoreExtensions.cs @@ -0,0 +1,32 @@ +using System; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace GodotTools.IdeMessaging.Utils +{ + public static class SemaphoreExtensions + { + public static ConfiguredTaskAwaitable<IDisposable> UseAsync(this SemaphoreSlim semaphoreSlim, CancellationToken cancellationToken = default(CancellationToken)) + { + var wrapper = new SemaphoreSlimWaitReleaseWrapper(semaphoreSlim, out Task waitAsyncTask, cancellationToken); + return waitAsyncTask.ContinueWith<IDisposable>(t => wrapper, cancellationToken).ConfigureAwait(false); + } + + private struct SemaphoreSlimWaitReleaseWrapper : IDisposable + { + private readonly SemaphoreSlim semaphoreSlim; + + public SemaphoreSlimWaitReleaseWrapper(SemaphoreSlim semaphoreSlim, out Task waitAsyncTask, CancellationToken cancellationToken = default(CancellationToken)) + { + this.semaphoreSlim = semaphoreSlim; + waitAsyncTask = this.semaphoreSlim.WaitAsync(cancellationToken); + } + + public void Dispose() + { + semaphoreSlim.Release(); + } + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/GodotTools.ProjectEditor.csproj b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/GodotTools.ProjectEditor.csproj index b60e501beb..9cb50014b0 100644 --- a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/GodotTools.ProjectEditor.csproj +++ b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/GodotTools.ProjectEditor.csproj @@ -1,57 +1,23 @@ -<?xml version="1.0" encoding="utf-8"?> -<Project DefaultTargets="Build" ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> +<Project Sdk="Microsoft.NET.Sdk"> <PropertyGroup> - <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> - <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform> <ProjectGuid>{A8CDAD94-C6D4-4B19-A7E7-76C53CC92984}</ProjectGuid> - <OutputType>Library</OutputType> - <RootNamespace>GodotTools.ProjectEditor</RootNamespace> - <AssemblyName>GodotTools.ProjectEditor</AssemblyName> - <TargetFrameworkVersion>v4.7</TargetFrameworkVersion> - <BaseIntermediateOutputPath>obj</BaseIntermediateOutputPath> - <LangVersion>7</LangVersion> + <TargetFramework>net472</TargetFramework> + <LangVersion>7.2</LangVersion> </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' "> - <DebugSymbols>true</DebugSymbols> - <DebugType>portable</DebugType> - <Optimize>false</Optimize> - <OutputPath>bin\Debug</OutputPath> - <DefineConstants>DEBUG;</DefineConstants> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - <ConsolePause>false</ConsolePause> - </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' "> - <Optimize>true</Optimize> - <OutputPath>bin\Release</OutputPath> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - <ConsolePause>false</ConsolePause> - </PropertyGroup> - <ItemGroup> - <Reference Include="System" /> - <Reference Include="Microsoft.Build" /> - <Reference Include="DotNet.Glob, Version=2.1.1.0, Culture=neutral, PublicKeyToken=b68cc888b4f632d1, processorArchitecture=MSIL"> - <HintPath>$(SolutionDir)\packages\DotNet.Glob.2.1.1\lib\net45\DotNet.Glob.dll</HintPath> - </Reference> - </ItemGroup> <ItemGroup> - <Compile Include="ApiAssembliesInfo.cs" /> - <Compile Include="DotNetSolution.cs" /> - <Compile Include="Properties\AssemblyInfo.cs" /> - <Compile Include="IdentifierUtils.cs" /> - <Compile Include="ProjectExtensions.cs" /> - <Compile Include="ProjectGenerator.cs" /> - <Compile Include="ProjectUtils.cs" /> + <PackageReference Include="Microsoft.Build" Version="16.5.0" /> + <PackageReference Include="Microsoft.NETFramework.ReferenceAssemblies" Version="1.0.0" PrivateAssets="All" /> </ItemGroup> <ItemGroup> - <None Include="packages.config" /> + <ProjectReference Include="..\GodotTools.Core\GodotTools.Core.csproj" /> </ItemGroup> <ItemGroup> - <ProjectReference Include="..\GodotTools.Core\GodotTools.Core.csproj"> - <Project>{639E48BD-44E5-4091-8EDD-22D36DC0768D}</Project> - <Name>GodotTools.Core</Name> - </ProjectReference> + <!-- + The Microsoft.Build.Runtime package is too problematic so we create a MSBuild.exe stub. The workaround described + here doesn't work with Microsoft.NETFramework.ReferenceAssemblies: https://github.com/microsoft/msbuild/issues/3486 + We need a MSBuild.exe file as there's an issue in Microsoft.Build where it executes platform dependent code when + searching for MSBuild.exe before the fallback to not using it. A stub is fine as it should never be executed. + --> + <None Include="MSBuild.exe" CopyToOutputDirectory="Always" /> </ItemGroup> - <Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" /> </Project> diff --git a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/MSBuild.exe b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/MSBuild.exe new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/MSBuild.exe diff --git a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectExtensions.cs b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectExtensions.cs index f0e0d1b33d..704f2ec194 100644 --- a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectExtensions.cs +++ b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectExtensions.cs @@ -2,8 +2,8 @@ using GodotTools.Core; using System; using System.Collections.Generic; using System.IO; -using DotNet.Globbing; using Microsoft.Build.Construction; +using Microsoft.Build.Globbing; namespace GodotTools.ProjectEditor { @@ -11,8 +11,6 @@ namespace GodotTools.ProjectEditor { public static ProjectItemElement FindItemOrNull(this ProjectRootElement root, string itemType, string include, bool noCondition = false) { - GlobOptions globOptions = new GlobOptions {Evaluation = {CaseInsensitive = false}}; - string normalizedInclude = include.NormalizePath(); foreach (var itemGroup in root.ItemGroups) @@ -25,7 +23,8 @@ namespace GodotTools.ProjectEditor if (item.ItemType != itemType) continue; - var glob = Glob.Parse(item.Include.NormalizePath(), globOptions); + //var glob = Glob.Parse(item.Include.NormalizePath(), globOptions); + var glob = MSBuildGlob.Parse(item.Include.NormalizePath()); if (glob.IsMatch(normalizedInclude)) return item; @@ -36,8 +35,6 @@ namespace GodotTools.ProjectEditor } public static ProjectItemElement FindItemOrNullAbs(this ProjectRootElement root, string itemType, string include, bool noCondition = false) { - GlobOptions globOptions = new GlobOptions {Evaluation = {CaseInsensitive = false}}; - string normalizedInclude = Path.GetFullPath(include).NormalizePath(); foreach (var itemGroup in root.ItemGroups) @@ -50,7 +47,7 @@ namespace GodotTools.ProjectEditor if (item.ItemType != itemType) continue; - var glob = Glob.Parse(Path.GetFullPath(item.Include).NormalizePath(), globOptions); + var glob = MSBuildGlob.Parse(Path.GetFullPath(item.Include).NormalizePath()); if (glob.IsMatch(normalizedInclude)) return item; diff --git a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectGenerator.cs b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectGenerator.cs index cbe3afaedd..fb2beb6995 100644 --- a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectGenerator.cs +++ b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectGenerator.cs @@ -125,6 +125,12 @@ namespace GodotTools.ProjectEditor // References var referenceGroup = root.AddItemGroup(); referenceGroup.AddItem("Reference", "System"); + var frameworkRefAssembliesItem = referenceGroup.AddItem("PackageReference", "Microsoft.NETFramework.ReferenceAssemblies"); + + // Use metadata (child nodes) instead of attributes for the PackageReference. + // This is for compatibility with 3.2, where GodotTools uses an old Microsoft.Build. + frameworkRefAssembliesItem.AddMetadata("Version", "1.0.0"); + frameworkRefAssembliesItem.AddMetadata("PrivateAssets", "All"); root.AddImport(Path.Combine("$(MSBuildBinPath)", "Microsoft.CSharp.targets").Replace("/", "\\")); diff --git a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectUtils.cs b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectUtils.cs index f2ebef1a7d..069a1edaa3 100644 --- a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectUtils.cs +++ b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/ProjectUtils.cs @@ -4,8 +4,8 @@ using System.Diagnostics; using System.IO; using System.Linq; using System.Reflection; -using DotNet.Globbing; using Microsoft.Build.Construction; +using Microsoft.Build.Globbing; namespace GodotTools.ProjectEditor { @@ -133,9 +133,6 @@ namespace GodotTools.ProjectEditor var result = new List<string>(); var existingFiles = GetAllFilesRecursive(Path.GetDirectoryName(projectPath), "*.cs"); - var globOptions = new GlobOptions(); - globOptions.Evaluation.CaseInsensitive = false; - var root = ProjectRootElement.Open(projectPath); Debug.Assert(root != null); @@ -151,7 +148,7 @@ namespace GodotTools.ProjectEditor string normalizedInclude = item.Include.NormalizePath(); - var glob = Glob.Parse(normalizedInclude, globOptions); + var glob = MSBuildGlob.Parse(normalizedInclude); // TODO Check somehow if path has no blob to avoid the following loop... @@ -176,7 +173,7 @@ namespace GodotTools.ProjectEditor void AddPropertyIfNotPresent(string name, string condition, string value) { if (root.PropertyGroups - .Any(g => (g.Condition == string.Empty || g.Condition.Trim() == condition) && + .Any(g => (string.IsNullOrEmpty(g.Condition) || g.Condition.Trim() == condition) && g.Properties .Any(p => p.Name == name && p.Value == value && @@ -267,7 +264,7 @@ namespace GodotTools.ProjectEditor bool hasGodotProjectGeneratorVersion = false; bool foundOldConfiguration = false; - foreach (var propertyGroup in root.PropertyGroups.Where(g => g.Condition == string.Empty)) + foreach (var propertyGroup in root.PropertyGroups.Where(g => string.IsNullOrEmpty(g.Condition))) { if (!hasGodotProjectGeneratorVersion && propertyGroup.Properties.Any(p => p.Name == "GodotProjectGeneratorVersion")) hasGodotProjectGeneratorVersion = true; @@ -283,7 +280,7 @@ namespace GodotTools.ProjectEditor if (!hasGodotProjectGeneratorVersion) { - root.PropertyGroups.First(g => g.Condition == string.Empty)? + root.PropertyGroups.First(g => string.IsNullOrEmpty(g.Condition))? .AddProperty("GodotProjectGeneratorVersion", Assembly.GetExecutingAssembly().GetName().Version.ToString()); project.HasUnsavedChanges = true; } @@ -351,5 +348,25 @@ namespace GodotTools.ProjectEditor MigrateConfigurationConditions("Tools", "Debug"); // Must be last } } + + public static void EnsureHasNugetNetFrameworkRefAssemblies(MSBuildProject project) + { + var root = project.Root; + + bool found = root.ItemGroups.Any(g => string.IsNullOrEmpty(g.Condition) && g.Items.Any( + item => item.ItemType == "PackageReference" && item.Include == "Microsoft.NETFramework.ReferenceAssemblies")); + + if (found) + return; + + var frameworkRefAssembliesItem = root.AddItem("PackageReference", "Microsoft.NETFramework.ReferenceAssemblies"); + + // Use metadata (child nodes) instead of attributes for the PackageReference. + // This is for compatibility with 3.2, where GodotTools uses an old Microsoft.Build. + frameworkRefAssembliesItem.AddMetadata("Version", "1.0.0"); + frameworkRefAssembliesItem.AddMetadata("PrivateAssets", "All"); + + project.HasUnsavedChanges = true; + } } } diff --git a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/Properties/AssemblyInfo.cs b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/Properties/AssemblyInfo.cs deleted file mode 100644 index 3a0464c9bc..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,27 +0,0 @@ -using System.Reflection; -using System.Runtime.CompilerServices; - -// Information about this assembly is defined by the following attributes. -// Change them to the values specific to your project. - -[assembly: AssemblyTitle("GodotTools.ProjectEditor")] -[assembly: AssemblyDescription("")] -[assembly: AssemblyConfiguration("")] -[assembly: AssemblyCompany("")] -[assembly: AssemblyProduct("")] -[assembly: AssemblyCopyright("Godot Engine contributors")] -[assembly: AssemblyTrademark("")] -[assembly: AssemblyCulture("")] - -// The assembly version has the format "{Major}.{Minor}.{Build}.{Revision}". -// The form "{Major}.{Minor}.*" will automatically update the build and revision, -// and "{Major}.{Minor}.{Build}.*" will update just the revision. - -[assembly: AssemblyVersion("1.0.*")] - -// The following attributes are used to specify the signing key for the assembly, -// if desired. See the Mono documentation for more information about signing. - -//[assembly: AssemblyDelaySign(false)] -//[assembly: AssemblyKeyFile("")] - diff --git a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/packages.config b/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/packages.config deleted file mode 100644 index 2db030f9d8..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools.ProjectEditor/packages.config +++ /dev/null @@ -1,4 +0,0 @@ -<?xml version="1.0" encoding="utf-8"?> -<packages> - <package id="DotNet.Glob" version="2.1.1" targetFramework="net45" /> -</packages> diff --git a/modules/mono/editor/GodotTools/GodotTools.sln b/modules/mono/editor/GodotTools/GodotTools.sln index a3438ea5f3..f6147eb5bb 100644 --- a/modules/mono/editor/GodotTools/GodotTools.sln +++ b/modules/mono/editor/GodotTools/GodotTools.sln @@ -9,7 +9,7 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GodotTools.Core", "GodotToo EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GodotTools.BuildLogger", "GodotTools.BuildLogger\GodotTools.BuildLogger.csproj", "{6CE9A984-37B1-4F8A-8FE9-609F05F071B3}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GodotTools.IdeConnection", "GodotTools.IdeConnection\GodotTools.IdeConnection.csproj", "{92600954-25F0-4291-8E11-1FEE9FC4BE20}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GodotTools.IdeMessaging", "GodotTools.IdeMessaging\GodotTools.IdeMessaging.csproj", "{92600954-25F0-4291-8E11-1FEE9FC4BE20}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution diff --git a/modules/mono/editor/GodotTools/GodotTools/Build/BuildSystem.cs b/modules/mono/editor/GodotTools/GodotTools/Build/BuildSystem.cs index 43c96d2e30..e55558c100 100644 --- a/modules/mono/editor/GodotTools/GodotTools/Build/BuildSystem.cs +++ b/modules/mono/editor/GodotTools/GodotTools/Build/BuildSystem.cs @@ -14,16 +14,6 @@ namespace GodotTools.Build { public static class BuildSystem { - private static string GetMsBuildPath() - { - string msbuildPath = MsBuildFinder.FindMsBuild(); - - if (msbuildPath == null) - throw new FileNotFoundException("Cannot find the MSBuild executable."); - - return msbuildPath; - } - private static string MonoWindowsBinDir { get @@ -46,8 +36,8 @@ namespace GodotTools.Build { if (OS.IsWindows) { - return (BuildManager.BuildTool)EditorSettings.GetSetting("mono/builds/build_tool") - == BuildManager.BuildTool.MsBuildMono; + return (BuildTool)EditorSettings.GetSetting("mono/builds/build_tool") + == BuildTool.MsBuildMono; } return false; @@ -57,16 +47,21 @@ namespace GodotTools.Build private static bool PrintBuildOutput => (bool)EditorSettings.GetSetting("mono/builds/print_build_output"); - private static Process LaunchBuild(string solution, string config, string loggerOutputDir, IEnumerable<string> customProperties = null) + private static Process LaunchBuild(string solution, IEnumerable<string> targets, string config, string loggerOutputDir, IEnumerable<string> customProperties = null) { + (string msbuildPath, BuildTool buildTool) = MsBuildFinder.FindMsBuild(); + + if (msbuildPath == null) + throw new FileNotFoundException("Cannot find the MSBuild executable."); + var customPropertiesList = new List<string>(); if (customProperties != null) customPropertiesList.AddRange(customProperties); - string compilerArgs = BuildArguments(solution, config, loggerOutputDir, customPropertiesList); + string compilerArgs = BuildArguments(buildTool, solution, targets, config, loggerOutputDir, customPropertiesList); - var startInfo = new ProcessStartInfo(GetMsBuildPath(), compilerArgs); + var startInfo = new ProcessStartInfo(msbuildPath, compilerArgs); bool redirectOutput = !IsDebugMsBuildRequested() && !PrintBuildOutput; @@ -90,7 +85,7 @@ namespace GodotTools.Build // Needed when running from Developer Command Prompt for VS RemovePlatformVariable(startInfo.EnvironmentVariables); - var process = new Process { StartInfo = startInfo }; + var process = new Process {StartInfo = startInfo}; process.Start(); @@ -105,19 +100,19 @@ namespace GodotTools.Build public static int Build(BuildInfo buildInfo) { - return Build(buildInfo.Solution, buildInfo.Configuration, + return Build(buildInfo.Solution, buildInfo.Targets, buildInfo.Configuration, buildInfo.LogsDirPath, buildInfo.CustomProperties); } - public static async Task<int> BuildAsync(BuildInfo buildInfo) + public static Task<int> BuildAsync(BuildInfo buildInfo) { - return await BuildAsync(buildInfo.Solution, buildInfo.Configuration, + return BuildAsync(buildInfo.Solution, buildInfo.Targets, buildInfo.Configuration, buildInfo.LogsDirPath, buildInfo.CustomProperties); } - public static int Build(string solution, string config, string loggerOutputDir, IEnumerable<string> customProperties = null) + public static int Build(string solution, string[] targets, string config, string loggerOutputDir, IEnumerable<string> customProperties = null) { - using (var process = LaunchBuild(solution, config, loggerOutputDir, customProperties)) + using (var process = LaunchBuild(solution, targets, config, loggerOutputDir, customProperties)) { process.WaitForExit(); @@ -125,9 +120,9 @@ namespace GodotTools.Build } } - public static async Task<int> BuildAsync(string solution, string config, string loggerOutputDir, IEnumerable<string> customProperties = null) + public static async Task<int> BuildAsync(string solution, IEnumerable<string> targets, string config, string loggerOutputDir, IEnumerable<string> customProperties = null) { - using (var process = LaunchBuild(solution, config, loggerOutputDir, customProperties)) + using (var process = LaunchBuild(solution, targets, config, loggerOutputDir, customProperties)) { await process.WaitForExitAsync(); @@ -135,10 +130,15 @@ namespace GodotTools.Build } } - private static string BuildArguments(string solution, string config, string loggerOutputDir, List<string> customProperties) + private static string BuildArguments(BuildTool buildTool, string solution, IEnumerable<string> targets, string config, string loggerOutputDir, IEnumerable<string> customProperties) { - string arguments = $@"""{solution}"" /v:normal /t:Build ""/p:{"Configuration=" + config}"" " + - $@"""/l:{typeof(GodotBuildLogger).FullName},{GodotBuildLogger.AssemblyPath};{loggerOutputDir}"""; + string arguments = string.Empty; + + if (buildTool == BuildTool.DotnetCli) + arguments += "msbuild "; // `dotnet msbuild` command + + arguments += $@"""{solution}"" /v:normal /t:{string.Join(",", targets)} ""/p:{"Configuration=" + config}"" " + + $@"""/l:{typeof(GodotBuildLogger).FullName},{GodotBuildLogger.AssemblyPath};{loggerOutputDir}"""; foreach (string customProperty in customProperties) { diff --git a/modules/mono/editor/GodotTools/GodotTools/Build/BuildTool.cs b/modules/mono/editor/GodotTools/GodotTools/Build/BuildTool.cs new file mode 100644 index 0000000000..a1a69334e3 --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools/Build/BuildTool.cs @@ -0,0 +1,10 @@ +namespace GodotTools.Build +{ + public enum BuildTool + { + MsBuildMono, + MsBuildVs, + JetBrainsMsBuild, + DotnetCli + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools/Build/MsBuildFinder.cs b/modules/mono/editor/GodotTools/GodotTools/Build/MsBuildFinder.cs index af8d070cbd..f36e581a5f 100644 --- a/modules/mono/editor/GodotTools/GodotTools/Build/MsBuildFinder.cs +++ b/modules/mono/editor/GodotTools/GodotTools/Build/MsBuildFinder.cs @@ -17,75 +17,96 @@ namespace GodotTools.Build private static string _msbuildToolsPath = string.Empty; private static string _msbuildUnixPath = string.Empty; - public static string FindMsBuild() + public static (string, BuildTool) FindMsBuild() { var editorSettings = GodotSharpEditor.Instance.GetEditorInterface().GetEditorSettings(); - var buildTool = (BuildManager.BuildTool)editorSettings.GetSetting("mono/builds/build_tool"); + var buildTool = (BuildTool)editorSettings.GetSetting("mono/builds/build_tool"); if (OS.IsWindows) { switch (buildTool) { - case BuildManager.BuildTool.MsBuildVs: + case BuildTool.DotnetCli: { - if (_msbuildToolsPath.Empty() || !File.Exists(_msbuildToolsPath)) + string dotnetCliPath = OS.PathWhich("dotnet"); + if (!string.IsNullOrEmpty(dotnetCliPath)) + return (dotnetCliPath, BuildTool.DotnetCli); + GD.PushError("Cannot find dotnet CLI executable. Fallback to MSBuild from Visual Studio."); + goto case BuildTool.MsBuildVs; + } + case BuildTool.MsBuildVs: + { + if (string.IsNullOrEmpty(_msbuildToolsPath) || !File.Exists(_msbuildToolsPath)) { // Try to search it again if it wasn't found last time or if it was removed from its location _msbuildToolsPath = FindMsBuildToolsPathOnWindows(); - if (_msbuildToolsPath.Empty()) - { - throw new FileNotFoundException($"Cannot find executable for '{BuildManager.PropNameMsbuildVs}'."); - } + if (string.IsNullOrEmpty(_msbuildToolsPath)) + throw new FileNotFoundException($"Cannot find executable for '{BuildManager.PropNameMSBuildVs}'."); } if (!_msbuildToolsPath.EndsWith("\\")) _msbuildToolsPath += "\\"; - return Path.Combine(_msbuildToolsPath, "MSBuild.exe"); + return (Path.Combine(_msbuildToolsPath, "MSBuild.exe"), BuildTool.MsBuildVs); } - case BuildManager.BuildTool.MsBuildMono: + case BuildTool.MsBuildMono: { string msbuildPath = Path.Combine(Internal.MonoWindowsInstallRoot, "bin", "msbuild.bat"); if (!File.Exists(msbuildPath)) - { - throw new FileNotFoundException($"Cannot find executable for '{BuildManager.PropNameMsbuildMono}'. Tried with path: {msbuildPath}"); - } + throw new FileNotFoundException($"Cannot find executable for '{BuildManager.PropNameMSBuildMono}'. Tried with path: {msbuildPath}"); - return msbuildPath; + return (msbuildPath, BuildTool.MsBuildMono); } - case BuildManager.BuildTool.JetBrainsMsBuild: + case BuildTool.JetBrainsMsBuild: + { var editorPath = (string)editorSettings.GetSetting(RiderPathManager.EditorPathSettingName); + if (!File.Exists(editorPath)) throw new FileNotFoundException($"Cannot find Rider executable. Tried with path: {editorPath}"); - var riderDir = new FileInfo(editorPath).Directory.Parent; - return Path.Combine(riderDir.FullName, @"tools\MSBuild\Current\Bin\MSBuild.exe"); + + var riderDir = new FileInfo(editorPath).Directory?.Parent; + + string msbuildPath = Path.Combine(riderDir.FullName, @"tools\MSBuild\Current\Bin\MSBuild.exe"); + + if (!File.Exists(msbuildPath)) + throw new FileNotFoundException($"Cannot find executable for '{BuildManager.PropNameMSBuildJetBrains}'. Tried with path: {msbuildPath}"); + + return (msbuildPath, BuildTool.JetBrainsMsBuild); + } default: throw new IndexOutOfRangeException("Invalid build tool in editor settings"); } } - if (OS.IsUnixLike()) + if (OS.IsUnixLike) { - if (buildTool == BuildManager.BuildTool.MsBuildMono) + switch (buildTool) { - if (_msbuildUnixPath.Empty() || !File.Exists(_msbuildUnixPath)) + case BuildTool.DotnetCli: { - // Try to search it again if it wasn't found last time or if it was removed from its location - _msbuildUnixPath = FindBuildEngineOnUnix("msbuild"); + string dotnetCliPath = OS.PathWhich("dotnet"); + if (!string.IsNullOrEmpty(dotnetCliPath)) + return (dotnetCliPath, BuildTool.DotnetCli); + GD.PushError("Cannot find dotnet CLI executable. Fallback to MSBuild from Mono."); + goto case BuildTool.MsBuildMono; } - - if (_msbuildUnixPath.Empty()) + case BuildTool.MsBuildMono: { - throw new FileNotFoundException($"Cannot find binary for '{BuildManager.PropNameMsbuildMono}'"); - } + if (string.IsNullOrEmpty(_msbuildUnixPath) || !File.Exists(_msbuildUnixPath)) + { + // Try to search it again if it wasn't found last time or if it was removed from its location + _msbuildUnixPath = FindBuildEngineOnUnix("msbuild"); + } - return _msbuildUnixPath; - } - else - { - throw new IndexOutOfRangeException("Invalid build tool in editor settings"); + if (string.IsNullOrEmpty(_msbuildUnixPath)) + throw new FileNotFoundException($"Cannot find binary for '{BuildManager.PropNameMSBuildMono}'"); + + return (_msbuildUnixPath, BuildTool.MsBuildMono); + } + default: + throw new IndexOutOfRangeException("Invalid build tool in editor settings"); } } @@ -114,12 +135,12 @@ namespace GodotTools.Build { string ret = OS.PathWhich(name); - if (!ret.Empty()) + if (!string.IsNullOrEmpty(ret)) return ret; string retFallback = OS.PathWhich($"{name}.exe"); - if (!retFallback.Empty()) + if (!string.IsNullOrEmpty(retFallback)) return retFallback; foreach (string hintDir in MsBuildHintDirs) @@ -143,7 +164,7 @@ namespace GodotTools.Build string vsWherePath = Environment.GetEnvironmentVariable(Internal.GodotIs32Bits() ? "ProgramFiles" : "ProgramFiles(x86)"); vsWherePath += "\\Microsoft Visual Studio\\Installer\\vswhere.exe"; - var vsWhereArgs = new[] { "-latest", "-products", "*", "-requires", "Microsoft.Component.MSBuild" }; + var vsWhereArgs = new[] {"-latest", "-products", "*", "-requires", "Microsoft.Component.MSBuild"}; var outputArray = new Godot.Collections.Array<string>(); int exitCode = Godot.OS.Execute(vsWherePath, vsWhereArgs, @@ -171,7 +192,7 @@ namespace GodotTools.Build string value = line.Substring(sepIdx + 1).StripEdges(); - if (value.Empty()) + if (string.IsNullOrEmpty(value)) throw new FormatException("installationPath value is empty"); if (!value.EndsWith("\\")) diff --git a/modules/mono/editor/GodotTools/GodotTools/BuildInfo.cs b/modules/mono/editor/GodotTools/GodotTools/BuildInfo.cs index 70bd552f2f..cca0983c01 100644 --- a/modules/mono/editor/GodotTools/GodotTools/BuildInfo.cs +++ b/modules/mono/editor/GodotTools/GodotTools/BuildInfo.cs @@ -10,6 +10,7 @@ namespace GodotTools public sealed class BuildInfo : Reference // TODO Remove Reference once we have proper serialization { public string Solution { get; } + public string[] Targets { get; } public string Configuration { get; } public Array<string> CustomProperties { get; } = new Array<string>(); // TODO Use List once we have proper serialization @@ -38,9 +39,10 @@ namespace GodotTools { } - public BuildInfo(string solution, string configuration) + public BuildInfo(string solution, string[] targets, string configuration) { Solution = solution; + Targets = targets; Configuration = configuration; } } diff --git a/modules/mono/editor/GodotTools/GodotTools/BuildManager.cs b/modules/mono/editor/GodotTools/GodotTools/BuildManager.cs index 520e665595..598787ba03 100644 --- a/modules/mono/editor/GodotTools/GodotTools/BuildManager.cs +++ b/modules/mono/editor/GodotTools/GodotTools/BuildManager.cs @@ -15,20 +15,14 @@ namespace GodotTools { private static readonly List<BuildInfo> BuildsInProgress = new List<BuildInfo>(); - public const string PropNameMsbuildMono = "MSBuild (Mono)"; - public const string PropNameMsbuildVs = "MSBuild (VS Build Tools)"; - public const string PropNameMsbuildJetBrains = "MSBuild (JetBrains Rider)"; + public const string PropNameMSBuildMono = "MSBuild (Mono)"; + public const string PropNameMSBuildVs = "MSBuild (VS Build Tools)"; + public const string PropNameMSBuildJetBrains = "MSBuild (JetBrains Rider)"; + public const string PropNameDotnetCli = "dotnet CLI"; public const string MsBuildIssuesFileName = "msbuild_issues.csv"; public const string MsBuildLogFileName = "msbuild_log.txt"; - public enum BuildTool - { - MsBuildMono, - MsBuildVs, - JetBrainsMsBuild - } - private static void RemoveOldIssuesFile(BuildInfo buildInfo) { var issuesFile = GetIssuesFilePath(buildInfo); @@ -181,10 +175,12 @@ namespace GodotTools { pr.Step("Building project solution", 0); - var buildInfo = new BuildInfo(GodotSharpDirs.ProjectSlnPath, config); + var buildInfo = new BuildInfo(GodotSharpDirs.ProjectSlnPath, targets: new[] {"Restore", "Build"}, config); + + bool escapeNeedsDoubleBackslash = buildTool == BuildTool.MsBuildMono || buildTool == BuildTool.DotnetCli; // Add Godot defines - string constants = buildTool != BuildTool.MsBuildMono ? "GodotDefineConstants=\"" : "GodotDefineConstants=\\\""; + string constants = !escapeNeedsDoubleBackslash ? "GodotDefineConstants=\"" : "GodotDefineConstants=\\\""; foreach (var godotDefine in godotDefines) constants += $"GODOT_{godotDefine.ToUpper().Replace("-", "_").Replace(" ", "_").Replace(";", "_")};"; @@ -192,7 +188,7 @@ namespace GodotTools if (Internal.GodotIsRealTDouble()) constants += "GODOT_REAL_T_IS_DOUBLE;"; - constants += buildTool != BuildTool.MsBuildMono ? "\"" : "\\\""; + constants += !escapeNeedsDoubleBackslash ? "\"" : "\\\""; buildInfo.CustomProperties.Add(constants); @@ -219,7 +215,7 @@ namespace GodotTools if (File.Exists(editorScriptsMetadataPath)) File.Copy(editorScriptsMetadataPath, playerScriptsMetadataPath); - var currentPlayRequest = GodotSharpEditor.Instance.GodotIdeManager.GodotIdeServer.CurrentPlayRequest; + var currentPlayRequest = GodotSharpEditor.Instance.CurrentPlaySettings; if (currentPlayRequest != null) { @@ -233,7 +229,8 @@ namespace GodotTools ",server=n"); } - return true; // Requested play from an external editor/IDE which already built the project + if (!currentPlayRequest.Value.BuildBeforePlaying) + return true; // Requested play from an external editor/IDE which already built the project } var godotDefines = new[] @@ -249,22 +246,44 @@ namespace GodotTools { // Build tool settings var editorSettings = GodotSharpEditor.Instance.GetEditorInterface().GetEditorSettings(); - var msbuild = BuildTool.MsBuildMono; + + BuildTool msbuildDefault; + if (OS.IsWindows) - msbuild = RiderPathManager.IsExternalEditorSetToRider(editorSettings) - ? BuildTool.JetBrainsMsBuild - : BuildTool.MsBuildVs; + { + if (RiderPathManager.IsExternalEditorSetToRider(editorSettings)) + msbuildDefault = BuildTool.JetBrainsMsBuild; + else + msbuildDefault = !string.IsNullOrEmpty(OS.PathWhich("dotnet")) ? BuildTool.DotnetCli : BuildTool.MsBuildVs; + } + else + { + msbuildDefault = !string.IsNullOrEmpty(OS.PathWhich("dotnet")) ? BuildTool.DotnetCli : BuildTool.MsBuildMono; + } + + EditorDef("mono/builds/build_tool", msbuildDefault); - EditorDef("mono/builds/build_tool", msbuild); + string hintString; + + if (OS.IsWindows) + { + hintString = $"{PropNameMSBuildMono}:{(int)BuildTool.MsBuildMono}," + + $"{PropNameMSBuildVs}:{(int)BuildTool.MsBuildVs}," + + $"{PropNameMSBuildJetBrains}:{(int)BuildTool.JetBrainsMsBuild}," + + $"{PropNameDotnetCli}:{(int)BuildTool.DotnetCli}"; + } + else + { + hintString = $"{PropNameMSBuildMono}:{(int)BuildTool.MsBuildMono}," + + $"{PropNameDotnetCli}:{(int)BuildTool.DotnetCli}"; + } editorSettings.AddPropertyInfo(new Godot.Collections.Dictionary { ["type"] = Godot.Variant.Type.Int, ["name"] = "mono/builds/build_tool", ["hint"] = Godot.PropertyHint.Enum, - ["hint_string"] = OS.IsWindows ? - $"{PropNameMsbuildMono},{PropNameMsbuildVs},{PropNameMsbuildJetBrains}" : - $"{PropNameMsbuildMono}" + ["hint_string"] = hintString }); EditorDef("mono/builds/print_build_output", false); diff --git a/modules/mono/editor/GodotTools/GodotTools/BuildTab.cs b/modules/mono/editor/GodotTools/GodotTools/BuildTab.cs index 938c3d8be1..0106e1f1ac 100644 --- a/modules/mono/editor/GodotTools/GodotTools/BuildTab.cs +++ b/modules/mono/editor/GodotTools/GodotTools/BuildTab.cs @@ -72,7 +72,7 @@ namespace GodotTools { string[] csvColumns = file.GetCsvLine(); - if (csvColumns.Length == 1 && csvColumns[0].Empty()) + if (csvColumns.Length == 1 && string.IsNullOrEmpty(csvColumns[0])) return; if (csvColumns.Length != 7) @@ -115,12 +115,12 @@ namespace GodotTools // Get correct issue idx from issue list int issueIndex = (int)issuesList.GetItemMetadata(idx); - if (idx < 0 || idx >= issues.Count) + if (issueIndex < 0 || issueIndex >= issues.Count) throw new IndexOutOfRangeException("Issue index out of range"); BuildIssue issue = issues[issueIndex]; - if (issue.ProjectFile.Empty() && issue.File.Empty()) + if (string.IsNullOrEmpty(issue.ProjectFile) && string.IsNullOrEmpty(issue.File)) return; string projectDir = issue.ProjectFile.Length > 0 ? issue.ProjectFile.GetBaseDir() : BuildInfo.Solution.GetBaseDir(); @@ -158,14 +158,14 @@ namespace GodotTools string tooltip = string.Empty; tooltip += $"Message: {issue.Message}"; - if (!issue.Code.Empty()) + if (!string.IsNullOrEmpty(issue.Code)) tooltip += $"\nCode: {issue.Code}"; tooltip += $"\nType: {(issue.Warning ? "warning" : "error")}"; string text = string.Empty; - if (!issue.File.Empty()) + if (!string.IsNullOrEmpty(issue.File)) { text += $"{issue.File}({issue.Line},{issue.Column}): "; @@ -174,7 +174,7 @@ namespace GodotTools tooltip += $"\nColumn: {issue.Column}"; } - if (!issue.ProjectFile.Empty()) + if (!string.IsNullOrEmpty(issue.ProjectFile)) tooltip += $"\nProject: {issue.ProjectFile}"; text += issue.Message; diff --git a/modules/mono/editor/GodotTools/GodotTools/Export/AotBuilder.cs b/modules/mono/editor/GodotTools/GodotTools/Export/AotBuilder.cs index f1765f7e19..f60e469503 100755 --- a/modules/mono/editor/GodotTools/GodotTools/Export/AotBuilder.cs +++ b/modules/mono/editor/GodotTools/GodotTools/Export/AotBuilder.cs @@ -587,7 +587,7 @@ MONO_AOT_MODE_LAST = 1000, string arch = "x86_64"; return $"{platform}-{arch}"; } - case OS.Platforms.X11: + case OS.Platforms.LinuxBSD: case OS.Platforms.Server: { string arch = bits == "64" ? "x86_64" : "i686"; diff --git a/modules/mono/editor/GodotTools/GodotTools/Export/ExportPlugin.cs b/modules/mono/editor/GodotTools/GodotTools/Export/ExportPlugin.cs index 2ceb4888a2..6bfbc62f3b 100644 --- a/modules/mono/editor/GodotTools/GodotTools/Export/ExportPlugin.cs +++ b/modules/mono/editor/GodotTools/GodotTools/Export/ExportPlugin.cs @@ -414,7 +414,7 @@ namespace GodotTools.Export case OS.Platforms.UWP: return "net_4_x_win"; case OS.Platforms.OSX: - case OS.Platforms.X11: + case OS.Platforms.LinuxBSD: case OS.Platforms.Server: case OS.Platforms.Haiku: return "net_4_x"; diff --git a/modules/mono/editor/GodotTools/GodotTools/GodotSharpEditor.cs b/modules/mono/editor/GodotTools/GodotTools/GodotSharpEditor.cs index c070cb16d9..eb7696685f 100644 --- a/modules/mono/editor/GodotTools/GodotTools/GodotSharpEditor.cs +++ b/modules/mono/editor/GodotTools/GodotTools/GodotSharpEditor.cs @@ -37,6 +37,8 @@ namespace GodotTools public BottomPanel BottomPanel { get; private set; } + public PlaySettings? CurrentPlaySettings { get; set; } + public static string ProjectAssemblyName { get @@ -228,12 +230,12 @@ namespace GodotTools [UsedImplicitly] public Error OpenInExternalEditor(Script script, int line, int col) { - var editor = (ExternalEditorId)editorSettings.GetSetting("mono/editor/external_editor"); + var editorId = (ExternalEditorId)editorSettings.GetSetting("mono/editor/external_editor"); - switch (editor) + switch (editorId) { case ExternalEditorId.None: - // Tells the caller to fallback to the global external editor settings or the built-in editor + // Not an error. Tells the caller to fallback to the global external editor settings or the built-in editor. return Error.Unavailable; case ExternalEditorId.VisualStudio: throw new NotSupportedException(); @@ -249,17 +251,20 @@ namespace GodotTools { string scriptPath = ProjectSettings.GlobalizePath(script.ResourcePath); - if (line >= 0) - GodotIdeManager.SendOpenFile(scriptPath, line + 1, col); - else - GodotIdeManager.SendOpenFile(scriptPath); + GodotIdeManager.LaunchIdeAsync().ContinueWith(launchTask => + { + var editorPick = launchTask.Result; + if (line >= 0) + editorPick?.SendOpenFile(scriptPath, line + 1, col); + else + editorPick?.SendOpenFile(scriptPath); + }); break; } - case ExternalEditorId.VsCode: { - if (_vsCodePath.Empty() || !File.Exists(_vsCodePath)) + if (string.IsNullOrEmpty(_vsCodePath) || !File.Exists(_vsCodePath)) { // Try to search it again if it wasn't found last time or if it was removed from its location _vsCodePath = VsCodeNames.SelectFirstNotNull(OS.PathWhich, orElse: string.Empty); @@ -300,7 +305,7 @@ namespace GodotTools if (line >= 0) { args.Add("-g"); - args.Add($"{scriptPath}:{line + 1}:{col}"); + args.Add($"{scriptPath}:{line}:{col}"); } else { @@ -311,7 +316,7 @@ namespace GodotTools if (OS.IsOSX) { - if (!osxAppBundleInstalled && _vsCodePath.Empty()) + if (!osxAppBundleInstalled && string.IsNullOrEmpty(_vsCodePath)) { GD.PushError("Cannot find code editor: VSCode"); return Error.FileNotFound; @@ -321,7 +326,7 @@ namespace GodotTools } else { - if (_vsCodePath.Empty()) + if (string.IsNullOrEmpty(_vsCodePath)) { GD.PushError("Cannot find code editor: VSCode"); return Error.FileNotFound; @@ -341,7 +346,6 @@ namespace GodotTools break; } - default: throw new ArgumentOutOfRangeException(); } @@ -457,6 +461,9 @@ namespace GodotTools // Make sure the existing project has Api assembly references configured correctly ProjectUtils.FixApiHintPath(msbuildProject); + // Make sure the existing project references the Microsoft.NETFramework.ReferenceAssemblies nuget package + ProjectUtils.EnsureHasNugetNetFrameworkRefAssemblies(msbuildProject); + if (msbuildProject.HasUnsavedChanges) { // Save a copy of the project before replacing it @@ -505,7 +512,7 @@ namespace GodotTools $",Visual Studio Code:{(int)ExternalEditorId.VsCode}" + $",JetBrains Rider:{(int)ExternalEditorId.Rider}"; } - else if (OS.IsUnixLike()) + else if (OS.IsUnixLike) { settingsHintStr += $",MonoDevelop:{(int)ExternalEditorId.MonoDevelop}" + $",Visual Studio Code:{(int)ExternalEditorId.VsCode}" + diff --git a/modules/mono/editor/GodotTools/GodotTools/GodotTools.csproj b/modules/mono/editor/GodotTools/GodotTools/GodotTools.csproj index e1570d6465..ba527ca3b5 100644 --- a/modules/mono/editor/GodotTools/GodotTools/GodotTools.csproj +++ b/modules/mono/editor/GodotTools/GodotTools/GodotTools.csproj @@ -1,14 +1,8 @@ -<?xml version="1.0" encoding="utf-8"?> -<Project DefaultTargets="Build" ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> +<Project Sdk="Microsoft.NET.Sdk"> <PropertyGroup> - <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> - <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform> <ProjectGuid>{27B00618-A6F2-4828-B922-05CAEB08C286}</ProjectGuid> - <OutputType>Library</OutputType> - <RootNamespace>GodotTools</RootNamespace> - <AssemblyName>GodotTools</AssemblyName> - <TargetFrameworkVersion>v4.7</TargetFrameworkVersion> - <LangVersion>7</LangVersion> + <TargetFramework>net472</TargetFramework> + <LangVersion>7.2</LangVersion> <GodotApiConfiguration>Debug</GodotApiConfiguration> <!-- The Godot editor uses the Debug Godot API assemblies --> <GodotSourceRootPath>$(SolutionDir)/../../../../</GodotSourceRootPath> <GodotOutputDataDir>$(GodotSourceRootPath)/bin/GodotSharp</GodotOutputDataDir> @@ -18,32 +12,13 @@ <!-- The project is part of the Godot source tree --> <!-- Use the Godot source tree output folder instead of '$(ProjectDir)/bin' --> <OutputPath>$(GodotOutputDataDir)/Tools</OutputPath> - </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' "> - <DebugSymbols>true</DebugSymbols> - <DebugType>portable</DebugType> - <Optimize>false</Optimize> - <DefineConstants>DEBUG;</DefineConstants> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - <ConsolePause>false</ConsolePause> - </PropertyGroup> - <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' "> - <Optimize>true</Optimize> - <ErrorReport>prompt</ErrorReport> - <WarningLevel>4</WarningLevel> - <ConsolePause>false</ConsolePause> + <!-- Must not append '$(TargetFramework)' to the output path in this case --> + <AppendTargetFrameworkToOutputPath>False</AppendTargetFrameworkToOutputPath> </PropertyGroup> <ItemGroup> - <Reference Include="JetBrains.Annotations, Version=2019.1.3.0, Culture=neutral, PublicKeyToken=1010a0d8d6380325"> - <HintPath>..\packages\JetBrains.Annotations.2019.1.3\lib\net20\JetBrains.Annotations.dll</HintPath> - <Private>True</Private> - </Reference> - <Reference Include="Newtonsoft.Json, Version=12.0.0.0, Culture=neutral, PublicKeyToken=30ad4fe6b2a6aeed"> - <HintPath>..\packages\Newtonsoft.Json.12.0.3\lib\net45\Newtonsoft.Json.dll</HintPath> - <Private>True</Private> - </Reference> - <Reference Include="System" /> + <PackageReference Include="JetBrains.Annotations" Version="2019.1.3.0" ExcludeAssets="runtime" PrivateAssets="all" /> + <PackageReference Include="Microsoft.NETFramework.ReferenceAssemblies" Version="1.0.0" PrivateAssets="All" /> + <PackageReference Include="Newtonsoft.Json" Version="12.0.3" /> <Reference Include="GodotSharp"> <HintPath>$(GodotApiAssembliesDir)/GodotSharp.dll</HintPath> <Private>False</Private> @@ -54,58 +29,9 @@ </Reference> </ItemGroup> <ItemGroup> - <Compile Include="Build\MsBuildFinder.cs" /> - <Compile Include="Export\AotBuilder.cs" /> - <Compile Include="Export\ExportPlugin.cs" /> - <Compile Include="Export\XcodeHelper.cs" /> - <Compile Include="ExternalEditorId.cs" /> - <Compile Include="Ides\GodotIdeManager.cs" /> - <Compile Include="Ides\GodotIdeServer.cs" /> - <Compile Include="Ides\MonoDevelop\EditorId.cs" /> - <Compile Include="Ides\MonoDevelop\Instance.cs" /> - <Compile Include="Ides\Rider\RiderPathLocator.cs" /> - <Compile Include="Ides\Rider\RiderPathManager.cs" /> - <Compile Include="Internals\EditorProgress.cs" /> - <Compile Include="Internals\GodotSharpDirs.cs" /> - <Compile Include="Internals\Internal.cs" /> - <Compile Include="Internals\ScriptClassParser.cs" /> - <Compile Include="Internals\Globals.cs" /> - <Compile Include="Properties\AssemblyInfo.cs" /> - <Compile Include="Build\BuildSystem.cs" /> - <Compile Include="Utils\Directory.cs" /> - <Compile Include="Utils\File.cs" /> - <Compile Include="Utils\NotifyAwaiter.cs" /> - <Compile Include="Utils\OS.cs" /> - <Compile Include="GodotSharpEditor.cs" /> - <Compile Include="BuildManager.cs" /> - <Compile Include="HotReloadAssemblyWatcher.cs" /> - <Compile Include="BuildInfo.cs" /> - <Compile Include="BuildTab.cs" /> - <Compile Include="BottomPanel.cs" /> - <Compile Include="CsProjOperations.cs" /> - <Compile Include="Utils\CollectionExtensions.cs" /> - <Compile Include="Utils\User32Dll.cs" /> - </ItemGroup> - <ItemGroup> - <ProjectReference Include="..\GodotTools.BuildLogger\GodotTools.BuildLogger.csproj"> - <Project>{6ce9a984-37b1-4f8a-8fe9-609f05f071b3}</Project> - <Name>GodotTools.BuildLogger</Name> - </ProjectReference> - <ProjectReference Include="..\GodotTools.IdeConnection\GodotTools.IdeConnection.csproj"> - <Project>{92600954-25f0-4291-8e11-1fee9fc4be20}</Project> - <Name>GodotTools.IdeConnection</Name> - </ProjectReference> - <ProjectReference Include="..\GodotTools.ProjectEditor\GodotTools.ProjectEditor.csproj"> - <Project>{A8CDAD94-C6D4-4B19-A7E7-76C53CC92984}</Project> - <Name>GodotTools.ProjectEditor</Name> - </ProjectReference> - <ProjectReference Include="..\GodotTools.Core\GodotTools.Core.csproj"> - <Project>{639E48BD-44E5-4091-8EDD-22D36DC0768D}</Project> - <Name>GodotTools.Core</Name> - </ProjectReference> - </ItemGroup> - <ItemGroup> - <None Include="packages.config" /> + <ProjectReference Include="..\GodotTools.BuildLogger\GodotTools.BuildLogger.csproj" /> + <ProjectReference Include="..\GodotTools.IdeMessaging\GodotTools.IdeMessaging.csproj" /> + <ProjectReference Include="..\GodotTools.ProjectEditor\GodotTools.ProjectEditor.csproj" /> + <ProjectReference Include="..\GodotTools.Core\GodotTools.Core.csproj" /> </ItemGroup> - <Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" /> </Project> diff --git a/modules/mono/editor/GodotTools/GodotTools/Ides/GodotIdeManager.cs b/modules/mono/editor/GodotTools/GodotTools/Ides/GodotIdeManager.cs index 54f0ffab96..e4932ca217 100644 --- a/modules/mono/editor/GodotTools/GodotTools/Ides/GodotIdeManager.cs +++ b/modules/mono/editor/GodotTools/GodotTools/Ides/GodotIdeManager.cs @@ -1,73 +1,104 @@ using System; using System.IO; +using System.Threading.Tasks; using Godot; -using GodotTools.IdeConnection; +using GodotTools.IdeMessaging; +using GodotTools.IdeMessaging.Requests; using GodotTools.Internals; namespace GodotTools.Ides { - public class GodotIdeManager : Node, ISerializationListener + public sealed class GodotIdeManager : Node, ISerializationListener { - public GodotIdeServer GodotIdeServer { get; private set; } + private MessagingServer MessagingServer { get; set; } private MonoDevelop.Instance monoDevelInstance; private MonoDevelop.Instance vsForMacInstance; - private GodotIdeServer GetRunningServer() + private MessagingServer GetRunningOrNewServer() { - if (GodotIdeServer != null && !GodotIdeServer.IsDisposed) - return GodotIdeServer; - StartServer(); - return GodotIdeServer; + if (MessagingServer != null && !MessagingServer.IsDisposed) + return MessagingServer; + + MessagingServer?.Dispose(); + MessagingServer = new MessagingServer(OS.GetExecutablePath(), ProjectSettings.GlobalizePath(GodotSharpDirs.ResMetadataDir), new GodotLogger()); + + _ = MessagingServer.Listen(); + + return MessagingServer; } public override void _Ready() { - StartServer(); + _ = GetRunningOrNewServer(); } public void OnBeforeSerialize() { - GodotIdeServer?.Dispose(); } public void OnAfterDeserialize() { - StartServer(); + _ = GetRunningOrNewServer(); } - private ILogger logger; + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); - protected ILogger Logger + if (disposing) + { + MessagingServer?.Dispose(); + } + } + + private string GetExternalEditorIdentity(ExternalEditorId editorId) { - get => logger ?? (logger = new GodotLogger()); + // Manually convert to string to avoid breaking compatibility in case we rename the enum fields. + switch (editorId) + { + case ExternalEditorId.None: + return null; + case ExternalEditorId.VisualStudio: + return "VisualStudio"; + case ExternalEditorId.VsCode: + return "VisualStudioCode"; + case ExternalEditorId.Rider: + return "Rider"; + case ExternalEditorId.VisualStudioForMac: + return "VisualStudioForMac"; + case ExternalEditorId.MonoDevelop: + return "MonoDevelop"; + default: + throw new NotImplementedException(); + } } - private void StartServer() + public async Task<EditorPick?> LaunchIdeAsync(int millisecondsTimeout = 10000) { - GodotIdeServer?.Dispose(); - GodotIdeServer = new GodotIdeServer(LaunchIde, - OS.GetExecutablePath(), - ProjectSettings.GlobalizePath(GodotSharpDirs.ResMetadataDir)); + var editorId = (ExternalEditorId)GodotSharpEditor.Instance.GetEditorInterface() + .GetEditorSettings().GetSetting("mono/editor/external_editor"); + string editorIdentity = GetExternalEditorIdentity(editorId); - GodotIdeServer.Logger = Logger; + var runningServer = GetRunningOrNewServer(); - GodotIdeServer.StartServer(); - } + if (runningServer.IsAnyConnected(editorIdentity)) + return new EditorPick(editorIdentity); - protected override void Dispose(bool disposing) - { - base.Dispose(disposing); + LaunchIde(editorId, editorIdentity); + + var timeoutTask = Task.Delay(millisecondsTimeout); + var completedTask = await Task.WhenAny(timeoutTask, runningServer.AwaitClientConnected(editorIdentity)); + + if (completedTask != timeoutTask) + return new EditorPick(editorIdentity); - GodotIdeServer?.Dispose(); + return null; } - private void LaunchIde() + private void LaunchIde(ExternalEditorId editorId, string editorIdentity) { - var editor = (ExternalEditorId)GodotSharpEditor.Instance.GetEditorInterface() - .GetEditorSettings().GetSetting("mono/editor/external_editor"); - - switch (editor) + switch (editorId) { case ExternalEditorId.None: case ExternalEditorId.VisualStudio: @@ -80,14 +111,14 @@ namespace GodotTools.Ides { MonoDevelop.Instance GetMonoDevelopInstance(string solutionPath) { - if (Utils.OS.IsOSX && editor == ExternalEditorId.VisualStudioForMac) + if (Utils.OS.IsOSX && editorId == ExternalEditorId.VisualStudioForMac) { - vsForMacInstance = vsForMacInstance ?? + vsForMacInstance = (vsForMacInstance?.IsDisposed ?? true ? null : vsForMacInstance) ?? new MonoDevelop.Instance(solutionPath, MonoDevelop.EditorId.VisualStudioForMac); return vsForMacInstance; } - monoDevelInstance = monoDevelInstance ?? + monoDevelInstance = (monoDevelInstance?.IsDisposed ?? true ? null : monoDevelInstance) ?? new MonoDevelop.Instance(solutionPath, MonoDevelop.EditorId.MonoDevelop); return monoDevelInstance; } @@ -96,12 +127,25 @@ namespace GodotTools.Ides { var instance = GetMonoDevelopInstance(GodotSharpDirs.ProjectSlnPath); - if (!instance.IsRunning) + if (instance.IsRunning && !GetRunningOrNewServer().IsAnyConnected(editorIdentity)) + { + // After launch we wait up to 30 seconds for the IDE to connect to our messaging server. + var waitAfterLaunch = TimeSpan.FromSeconds(30); + var timeSinceLaunch = DateTime.Now - instance.LaunchTime; + if (timeSinceLaunch > waitAfterLaunch) + { + instance.Dispose(); + instance.Execute(); + } + } + else if (!instance.IsRunning) + { instance.Execute(); + } } catch (FileNotFoundException) { - string editorName = editor == ExternalEditorId.VisualStudioForMac ? "Visual Studio" : "MonoDevelop"; + string editorName = editorId == ExternalEditorId.VisualStudioForMac ? "Visual Studio" : "MonoDevelop"; GD.PushError($"Cannot find code editor: {editorName}"); } @@ -113,26 +157,45 @@ namespace GodotTools.Ides } } - private void WriteMessage(string id, params string[] arguments) + public readonly struct EditorPick { - GetRunningServer().WriteMessage(new Message(id, arguments)); - } + private readonly string identity; - public void SendOpenFile(string file) - { - WriteMessage("OpenFile", file); - } + public EditorPick(string identity) + { + this.identity = identity; + } - public void SendOpenFile(string file, int line) - { - WriteMessage("OpenFile", file, line.ToString()); - } + public bool IsAnyConnected() => + GodotSharpEditor.Instance.GodotIdeManager.GetRunningOrNewServer().IsAnyConnected(identity); - public void SendOpenFile(string file, int line, int column) - { - WriteMessage("OpenFile", file, line.ToString(), column.ToString()); + private void SendRequest<TResponse>(Request request) + where TResponse : Response, new() + { + // Logs an error if no client is connected with the specified identity + GodotSharpEditor.Instance.GodotIdeManager + .GetRunningOrNewServer() + .BroadcastRequest<TResponse>(identity, request); + } + + public void SendOpenFile(string file) + { + SendRequest<OpenFileResponse>(new OpenFileRequest {File = file}); + } + + public void SendOpenFile(string file, int line) + { + SendRequest<OpenFileResponse>(new OpenFileRequest {File = file, Line = line}); + } + + public void SendOpenFile(string file, int line, int column) + { + SendRequest<OpenFileResponse>(new OpenFileRequest {File = file, Line = line, Column = column}); + } } + public EditorPick PickEditor(ExternalEditorId editorId) => new EditorPick(GetExternalEditorIdentity(editorId)); + private class GodotLogger : ILogger { public void LogDebug(string message) diff --git a/modules/mono/editor/GodotTools/GodotTools/Ides/GodotIdeServer.cs b/modules/mono/editor/GodotTools/GodotTools/Ides/GodotIdeServer.cs deleted file mode 100644 index 72676a8b24..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools/Ides/GodotIdeServer.cs +++ /dev/null @@ -1,212 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Net; -using System.Net.Sockets; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using GodotTools.IdeConnection; -using GodotTools.Internals; -using GodotTools.Utils; -using Directory = System.IO.Directory; -using File = System.IO.File; -using Thread = System.Threading.Thread; - -namespace GodotTools.Ides -{ - public class GodotIdeServer : GodotIdeBase - { - private readonly TcpListener listener; - private readonly FileStream metaFile; - private readonly Action launchIdeAction; - private readonly NotifyAwaiter<bool> clientConnectedAwaiter = new NotifyAwaiter<bool>(); - - private async Task<bool> AwaitClientConnected() - { - return await clientConnectedAwaiter.Reset(); - } - - public GodotIdeServer(Action launchIdeAction, string editorExecutablePath, string projectMetadataDir) - : base(projectMetadataDir) - { - messageHandlers = InitializeMessageHandlers(); - - this.launchIdeAction = launchIdeAction; - - // Make sure the directory exists - Directory.CreateDirectory(projectMetadataDir); - - // The Godot editor's file system thread can keep the file open for writing, so we are forced to allow write sharing... - const FileShare metaFileShare = FileShare.ReadWrite; - - metaFile = File.Open(MetaFilePath, FileMode.Create, FileAccess.Write, metaFileShare); - - listener = new TcpListener(new IPEndPoint(IPAddress.Loopback, port: 0)); - listener.Start(); - - int port = ((IPEndPoint)listener.Server.LocalEndPoint).Port; - using (var metaFileWriter = new StreamWriter(metaFile, Encoding.UTF8)) - { - metaFileWriter.WriteLine(port); - metaFileWriter.WriteLine(editorExecutablePath); - } - - StartServer(); - } - - public void StartServer() - { - var serverThread = new Thread(RunServerThread) { Name = "Godot Ide Connection Server" }; - serverThread.Start(); - } - - private void RunServerThread() - { - SynchronizationContext.SetSynchronizationContext(Godot.Dispatcher.SynchronizationContext); - - try - { - while (!IsDisposed) - { - TcpClient tcpClient = listener.AcceptTcpClient(); - - Logger.LogInfo("Connection open with Ide Client"); - - lock (ConnectionLock) - { - Connection = new GodotIdeConnectionServer(tcpClient, HandleMessage); - Connection.Logger = Logger; - } - - Connected += () => clientConnectedAwaiter.SetResult(true); - - Connection.Start(); - } - } - catch (Exception e) - { - if (!IsDisposed && !(e is SocketException se && se.SocketErrorCode == SocketError.Interrupted)) - throw; - } - } - - public async void WriteMessage(Message message) - { - async Task LaunchIde() - { - if (IsConnected) - return; - - launchIdeAction(); - await Task.WhenAny(Task.Delay(10000), AwaitClientConnected()); - } - - await LaunchIde(); - - if (!IsConnected) - { - Logger.LogError("Cannot write message: Godot Ide Server not connected"); - return; - } - - Connection.WriteMessage(message); - } - - protected override void Dispose(bool disposing) - { - base.Dispose(disposing); - - if (disposing) - { - listener?.Stop(); - - metaFile?.Dispose(); - - File.Delete(MetaFilePath); - } - } - - protected virtual bool HandleMessage(Message message) - { - if (messageHandlers.TryGetValue(message.Id, out var action)) - { - action(message.Arguments); - return true; - } - - return false; - } - - private readonly Dictionary<string, Action<string[]>> messageHandlers; - - private Dictionary<string, Action<string[]>> InitializeMessageHandlers() - { - return new Dictionary<string, Action<string[]>> - { - ["Play"] = args => - { - switch (args.Length) - { - case 0: - Play(); - return; - case 2: - Play(debuggerHost: args[0], debuggerPort: int.Parse(args[1])); - return; - default: - throw new ArgumentException(); - } - }, - ["ReloadScripts"] = args => ReloadScripts() - }; - } - - private void DispatchToMainThread(Action action) - { - var d = new SendOrPostCallback(state => action()); - Godot.Dispatcher.SynchronizationContext.Post(d, null); - } - - private void Play() - { - DispatchToMainThread(() => - { - CurrentPlayRequest = new PlayRequest(); - Internal.EditorRunPlay(); - CurrentPlayRequest = null; - }); - } - - private void Play(string debuggerHost, int debuggerPort) - { - DispatchToMainThread(() => - { - CurrentPlayRequest = new PlayRequest(debuggerHost, debuggerPort); - Internal.EditorRunPlay(); - CurrentPlayRequest = null; - }); - } - - private void ReloadScripts() - { - DispatchToMainThread(Internal.ScriptEditorDebugger_ReloadScripts); - } - - public PlayRequest? CurrentPlayRequest { get; private set; } - - public struct PlayRequest - { - public bool HasDebugger { get; } - public string DebuggerHost { get; } - public int DebuggerPort { get; } - - public PlayRequest(string debuggerHost, int debuggerPort) - { - HasDebugger = true; - DebuggerHost = debuggerHost; - DebuggerPort = debuggerPort; - } - } - } -} diff --git a/modules/mono/editor/GodotTools/GodotTools/Ides/MessagingServer.cs b/modules/mono/editor/GodotTools/GodotTools/Ides/MessagingServer.cs new file mode 100644 index 0000000000..32f264d100 --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools/Ides/MessagingServer.cs @@ -0,0 +1,360 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; +using GodotTools.IdeMessaging; +using GodotTools.IdeMessaging.Requests; +using GodotTools.IdeMessaging.Utils; +using GodotTools.Internals; +using Newtonsoft.Json; +using Directory = System.IO.Directory; +using File = System.IO.File; + +namespace GodotTools.Ides +{ + public sealed class MessagingServer : IDisposable + { + private readonly ILogger logger; + + private readonly FileStream metaFile; + private string MetaFilePath { get; } + + private readonly SemaphoreSlim peersSem = new SemaphoreSlim(1); + + private readonly TcpListener listener; + + private readonly Dictionary<string, Queue<NotifyAwaiter<bool>>> clientConnectedAwaiters = new Dictionary<string, Queue<NotifyAwaiter<bool>>>(); + private readonly Dictionary<string, Queue<NotifyAwaiter<bool>>> clientDisconnectedAwaiters = new Dictionary<string, Queue<NotifyAwaiter<bool>>>(); + + public async Task<bool> AwaitClientConnected(string identity) + { + if (!clientConnectedAwaiters.TryGetValue(identity, out var queue)) + { + queue = new Queue<NotifyAwaiter<bool>>(); + clientConnectedAwaiters.Add(identity, queue); + } + + var awaiter = new NotifyAwaiter<bool>(); + queue.Enqueue(awaiter); + return await awaiter; + } + + public async Task<bool> AwaitClientDisconnected(string identity) + { + if (!clientDisconnectedAwaiters.TryGetValue(identity, out var queue)) + { + queue = new Queue<NotifyAwaiter<bool>>(); + clientDisconnectedAwaiters.Add(identity, queue); + } + + var awaiter = new NotifyAwaiter<bool>(); + queue.Enqueue(awaiter); + return await awaiter; + } + + public bool IsDisposed { get; private set; } + + public bool IsAnyConnected(string identity) => string.IsNullOrEmpty(identity) ? + Peers.Count > 0 : + Peers.Any(c => c.RemoteIdentity == identity); + + private List<Peer> Peers { get; } = new List<Peer>(); + + ~MessagingServer() + { + Dispose(disposing: false); + } + + public async void Dispose() + { + if (IsDisposed) + return; + + using (await peersSem.UseAsync()) + { + if (IsDisposed) // lock may not be fair + return; + IsDisposed = true; + } + + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (disposing) + { + foreach (var connection in Peers) + connection.Dispose(); + Peers.Clear(); + listener?.Stop(); + + metaFile?.Dispose(); + + File.Delete(MetaFilePath); + } + } + + public MessagingServer(string editorExecutablePath, string projectMetadataDir, ILogger logger) + { + this.logger = logger; + + MetaFilePath = Path.Combine(projectMetadataDir, GodotIdeMetadata.DefaultFileName); + + // Make sure the directory exists + Directory.CreateDirectory(projectMetadataDir); + + // The Godot editor's file system thread can keep the file open for writing, so we are forced to allow write sharing... + const FileShare metaFileShare = FileShare.ReadWrite; + + metaFile = File.Open(MetaFilePath, FileMode.Create, FileAccess.Write, metaFileShare); + + listener = new TcpListener(new IPEndPoint(IPAddress.Loopback, port: 0)); + listener.Start(); + + int port = ((IPEndPoint)listener.Server.LocalEndPoint).Port; + using (var metaFileWriter = new StreamWriter(metaFile, Encoding.UTF8)) + { + metaFileWriter.WriteLine(port); + metaFileWriter.WriteLine(editorExecutablePath); + } + } + + private async Task AcceptClient(TcpClient tcpClient) + { + logger.LogDebug("Accept client..."); + + using (var peer = new Peer(tcpClient, new ServerHandshake(), new ServerMessageHandler(), logger)) + { + // ReSharper disable AccessToDisposedClosure + peer.Connected += () => + { + logger.LogInfo("Connection open with Ide Client"); + + if (clientConnectedAwaiters.TryGetValue(peer.RemoteIdentity, out var queue)) + { + while (queue.Count > 0) + queue.Dequeue().SetResult(true); + clientConnectedAwaiters.Remove(peer.RemoteIdentity); + } + }; + + peer.Disconnected += () => + { + if (clientDisconnectedAwaiters.TryGetValue(peer.RemoteIdentity, out var queue)) + { + while (queue.Count > 0) + queue.Dequeue().SetResult(true); + clientDisconnectedAwaiters.Remove(peer.RemoteIdentity); + } + }; + // ReSharper restore AccessToDisposedClosure + + try + { + if (!await peer.DoHandshake("server")) + { + logger.LogError("Handshake failed"); + return; + } + } + catch (Exception e) + { + logger.LogError("Handshake failed with unhandled exception: ", e); + return; + } + + using (await peersSem.UseAsync()) + Peers.Add(peer); + + try + { + await peer.Process(); + } + finally + { + using (await peersSem.UseAsync()) + Peers.Remove(peer); + } + } + } + + public async Task Listen() + { + try + { + while (!IsDisposed) + _ = AcceptClient(await listener.AcceptTcpClientAsync()); + } + catch (Exception e) + { + if (!IsDisposed && !(e is SocketException se && se.SocketErrorCode == SocketError.Interrupted)) + throw; + } + } + + public async void BroadcastRequest<TResponse>(string identity, Request request) + where TResponse : Response, new() + { + using (await peersSem.UseAsync()) + { + if (!IsAnyConnected(identity)) + { + logger.LogError("Cannot write request. No client connected to the Godot Ide Server."); + return; + } + + var selectedConnections = string.IsNullOrEmpty(identity) ? + Peers : + Peers.Where(c => c.RemoteIdentity == identity); + + string body = JsonConvert.SerializeObject(request); + + foreach (var connection in selectedConnections) + _ = connection.SendRequest<TResponse>(request.Id, body); + } + } + + private class ServerHandshake : IHandshake + { + private static readonly string ServerHandshakeBase = $"{Peer.ServerHandshakeName},Version={Peer.ProtocolVersionMajor}.{Peer.ProtocolVersionMinor}.{Peer.ProtocolVersionRevision}"; + private static readonly string ClientHandshakePattern = $@"{Regex.Escape(Peer.ClientHandshakeName)},Version=([0-9]+)\.([0-9]+)\.([0-9]+),([_a-zA-Z][_a-zA-Z0-9]{{0,63}})"; + + public string GetHandshakeLine(string identity) => $"{ServerHandshakeBase},{identity}"; + + public bool IsValidPeerHandshake(string handshake, out string identity, ILogger logger) + { + identity = null; + + var match = Regex.Match(handshake, ClientHandshakePattern); + + if (!match.Success) + return false; + + if (!uint.TryParse(match.Groups[1].Value, out uint clientMajor) || Peer.ProtocolVersionMajor != clientMajor) + { + logger.LogDebug("Incompatible major version: " + match.Groups[1].Value); + return false; + } + + // ReSharper disable once ConditionIsAlwaysTrueOrFalse + if (!uint.TryParse(match.Groups[2].Value, out uint clientMinor) || Peer.ProtocolVersionMinor > clientMinor) + { + logger.LogDebug("Incompatible minor version: " + match.Groups[2].Value); + return false; + } + + if (!uint.TryParse(match.Groups[3].Value, out uint _)) // Revision + { + logger.LogDebug("Incompatible revision build: " + match.Groups[3].Value); + return false; + } + + identity = match.Groups[4].Value; + + return true; + } + } + + private class ServerMessageHandler : IMessageHandler + { + private static void DispatchToMainThread(Action action) + { + var d = new SendOrPostCallback(state => action()); + Godot.Dispatcher.SynchronizationContext.Post(d, null); + } + + private readonly Dictionary<string, Peer.RequestHandler> requestHandlers = InitializeRequestHandlers(); + + public async Task<MessageContent> HandleRequest(Peer peer, string id, MessageContent content, ILogger logger) + { + if (!requestHandlers.TryGetValue(id, out var handler)) + { + logger.LogError($"Received unknown request: {id}"); + return new MessageContent(MessageStatus.RequestNotSupported, "null"); + } + + try + { + var response = await handler(peer, content); + return new MessageContent(response.Status, JsonConvert.SerializeObject(response)); + } + catch (JsonException) + { + logger.LogError($"Received request with invalid body: {id}"); + return new MessageContent(MessageStatus.InvalidRequestBody, "null"); + } + } + + private static Dictionary<string, Peer.RequestHandler> InitializeRequestHandlers() + { + return new Dictionary<string, Peer.RequestHandler> + { + [PlayRequest.Id] = async (peer, content) => + { + _ = JsonConvert.DeserializeObject<PlayRequest>(content.Body); + return await HandlePlay(); + }, + [DebugPlayRequest.Id] = async (peer, content) => + { + var request = JsonConvert.DeserializeObject<DebugPlayRequest>(content.Body); + return await HandleDebugPlay(request); + }, + [ReloadScriptsRequest.Id] = async (peer, content) => + { + _ = JsonConvert.DeserializeObject<ReloadScriptsRequest>(content.Body); + return await HandleReloadScripts(); + }, + [CodeCompletionRequest.Id] = async (peer, content) => + { + var request = JsonConvert.DeserializeObject<CodeCompletionRequest>(content.Body); + return await HandleCodeCompletionRequest(request); + } + }; + } + + private static Task<Response> HandlePlay() + { + DispatchToMainThread(() => + { + GodotSharpEditor.Instance.CurrentPlaySettings = new PlaySettings(); + Internal.EditorRunPlay(); + GodotSharpEditor.Instance.CurrentPlaySettings = null; + }); + return Task.FromResult<Response>(new PlayResponse()); + } + + private static Task<Response> HandleDebugPlay(DebugPlayRequest request) + { + DispatchToMainThread(() => + { + GodotSharpEditor.Instance.CurrentPlaySettings = + new PlaySettings(request.DebuggerHost, request.DebuggerPort, request.BuildBeforePlaying ?? true); + Internal.EditorRunPlay(); + GodotSharpEditor.Instance.CurrentPlaySettings = null; + }); + return Task.FromResult<Response>(new DebugPlayResponse()); + } + + private static Task<Response> HandleReloadScripts() + { + DispatchToMainThread(Internal.ScriptEditorDebugger_ReloadScripts); + return Task.FromResult<Response>(new ReloadScriptsResponse()); + } + + private static async Task<Response> HandleCodeCompletionRequest(CodeCompletionRequest request) + { + var response = new CodeCompletionResponse {Kind = request.Kind, ScriptFile = request.ScriptFile}; + response.Suggestions = await Task.Run(() => Internal.CodeCompletionRequest(response.Kind, response.ScriptFile)); + return response; + } + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools/Ides/MonoDevelop/Instance.cs b/modules/mono/editor/GodotTools/GodotTools/Ides/MonoDevelop/Instance.cs index 6026c109ad..d6fa2eeba7 100644 --- a/modules/mono/editor/GodotTools/GodotTools/Ides/MonoDevelop/Instance.cs +++ b/modules/mono/editor/GodotTools/GodotTools/Ides/MonoDevelop/Instance.cs @@ -7,14 +7,16 @@ using GodotTools.Utils; namespace GodotTools.Ides.MonoDevelop { - public class Instance + public class Instance : IDisposable { + public DateTime LaunchTime { get; private set; } private readonly string solutionFile; private readonly EditorId editorId; private Process process; public bool IsRunning => process != null && !process.HasExited; + public bool IsDisposed { get; private set; } public void Execute() { @@ -59,6 +61,8 @@ namespace GodotTools.Ides.MonoDevelop if (command == null) throw new FileNotFoundException(); + LaunchTime = DateTime.Now; + if (newWindow) { process = Process.Start(new ProcessStartInfo @@ -88,6 +92,12 @@ namespace GodotTools.Ides.MonoDevelop this.editorId = editorId; } + public void Dispose() + { + IsDisposed = true; + process?.Dispose(); + } + private static readonly IReadOnlyDictionary<EditorId, string> ExecutableNames; private static readonly IReadOnlyDictionary<EditorId, string> BundleIds; @@ -118,7 +128,7 @@ namespace GodotTools.Ides.MonoDevelop {EditorId.MonoDevelop, "MonoDevelop.exe"} }; } - else if (OS.IsUnixLike()) + else if (OS.IsUnixLike) { ExecutableNames = new Dictionary<EditorId, string> { diff --git a/modules/mono/editor/GodotTools/GodotTools/Ides/Rider/RiderPathLocator.cs b/modules/mono/editor/GodotTools/GodotTools/Ides/Rider/RiderPathLocator.cs index e3a4fa7b45..e22e9af919 100644 --- a/modules/mono/editor/GodotTools/GodotTools/Ides/Rider/RiderPathLocator.cs +++ b/modules/mono/editor/GodotTools/GodotTools/Ides/Rider/RiderPathLocator.cs @@ -36,7 +36,7 @@ namespace GodotTools.Ides.Rider { return CollectRiderInfosMac(); } - if (OS.IsUnixLike()) + if (OS.IsUnixLike) { return CollectAllRiderPathsLinux(); } @@ -141,16 +141,16 @@ namespace GodotTools.Ides.Rider if (OS.IsOSX) { var home = Environment.GetEnvironmentVariable("HOME"); - if (string.IsNullOrEmpty(home)) + if (string.IsNullOrEmpty(home)) return string.Empty; var localAppData = Path.Combine(home, @"Library/Application Support"); return GetToolboxRiderRootPath(localAppData); } - if (OS.IsUnixLike()) + if (OS.IsUnixLike) { var home = Environment.GetEnvironmentVariable("HOME"); - if (string.IsNullOrEmpty(home)) + if (string.IsNullOrEmpty(home)) return string.Empty; var localAppData = Path.Combine(home, @".local/share"); return GetToolboxRiderRootPath(localAppData); @@ -209,7 +209,7 @@ namespace GodotTools.Ides.Rider private static string GetRelativePathToBuildTxt() { - if (OS.IsWindows || OS.IsUnixLike()) + if (OS.IsWindows || OS.IsUnixLike) return "../../build.txt"; if (OS.IsOSX) return "Contents/Resources/build.txt"; @@ -322,7 +322,7 @@ namespace GodotTools.Ides.Rider class SettingsJson { public string install_location; - + [CanBeNull] public static string GetInstallLocationFromJson(string json) { diff --git a/modules/mono/editor/GodotTools/GodotTools/Internals/Internal.cs b/modules/mono/editor/GodotTools/GodotTools/Internals/Internal.cs index 026a7db89c..7e5049e4b7 100644 --- a/modules/mono/editor/GodotTools/GodotTools/Internals/Internal.cs +++ b/modules/mono/editor/GodotTools/GodotTools/Internals/Internal.cs @@ -2,6 +2,7 @@ using System; using System.Runtime.CompilerServices; using Godot; using Godot.Collections; +using GodotTools.IdeMessaging.Requests; namespace GodotTools.Internals { @@ -52,6 +53,9 @@ namespace GodotTools.Internals public static void ScriptEditorDebugger_ReloadScripts() => internal_ScriptEditorDebugger_ReloadScripts(); + public static string[] CodeCompletionRequest(CodeCompletionRequest.CompletionKind kind, string scriptFile) => + internal_CodeCompletionRequest((int)kind, scriptFile); + #region Internal [MethodImpl(MethodImplOptions.InternalCall)] @@ -111,6 +115,9 @@ namespace GodotTools.Internals [MethodImpl(MethodImplOptions.InternalCall)] private static extern void internal_ScriptEditorDebugger_ReloadScripts(); + [MethodImpl(MethodImplOptions.InternalCall)] + private static extern string[] internal_CodeCompletionRequest(int kind, string scriptFile); + #endregion } } diff --git a/modules/mono/editor/GodotTools/GodotTools/PlaySettings.cs b/modules/mono/editor/GodotTools/GodotTools/PlaySettings.cs new file mode 100644 index 0000000000..820d0c0b83 --- /dev/null +++ b/modules/mono/editor/GodotTools/GodotTools/PlaySettings.cs @@ -0,0 +1,19 @@ +namespace GodotTools +{ + public struct PlaySettings + { + public bool HasDebugger { get; } + public string DebuggerHost { get; } + public int DebuggerPort { get; } + + public bool BuildBeforePlaying { get; } + + public PlaySettings(string debuggerHost, int debuggerPort, bool buildBeforePlaying) + { + HasDebugger = true; + DebuggerHost = debuggerHost; + DebuggerPort = debuggerPort; + BuildBeforePlaying = buildBeforePlaying; + } + } +} diff --git a/modules/mono/editor/GodotTools/GodotTools/Properties/AssemblyInfo.cs b/modules/mono/editor/GodotTools/GodotTools/Properties/AssemblyInfo.cs deleted file mode 100644 index f5fe85c722..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,26 +0,0 @@ -using System.Reflection; -using System.Runtime.CompilerServices; - -// Information about this assembly is defined by the following attributes. -// Change them to the values specific to your project. - -[assembly: AssemblyTitle("GodotTools")] -[assembly: AssemblyDescription("")] -[assembly: AssemblyConfiguration("")] -[assembly: AssemblyCompany("")] -[assembly: AssemblyProduct("")] -[assembly: AssemblyCopyright("Godot Engine contributors")] -[assembly: AssemblyTrademark("")] -[assembly: AssemblyCulture("")] - -// The assembly version has the format "{Major}.{Minor}.{Build}.{Revision}". -// The form "{Major}.{Minor}.*" will automatically update the build and revision, -// and "{Major}.{Minor}.{Build}.*" will update just the revision. - -[assembly: AssemblyVersion("1.0.*")] - -// The following attributes are used to specify the signing key for the assembly, -// if desired. See the Mono documentation for more information about signing. - -//[assembly: AssemblyDelaySign(false)] -//[assembly: AssemblyKeyFile("")] diff --git a/modules/mono/editor/GodotTools/GodotTools/Utils/OS.cs b/modules/mono/editor/GodotTools/GodotTools/Utils/OS.cs index b057ac12c6..6c05891f2c 100644 --- a/modules/mono/editor/GodotTools/GodotTools/Utils/OS.cs +++ b/modules/mono/editor/GodotTools/GodotTools/Utils/OS.cs @@ -22,7 +22,10 @@ namespace GodotTools.Utils { public const string Windows = "Windows"; public const string OSX = "OSX"; - public const string X11 = "X11"; + public const string Linux = "Linux"; + public const string FreeBSD = "FreeBSD"; + public const string NetBSD = "NetBSD"; + public const string BSD = "BSD"; public const string Server = "Server"; public const string UWP = "UWP"; public const string Haiku = "Haiku"; @@ -35,7 +38,7 @@ namespace GodotTools.Utils { public const string Windows = "windows"; public const string OSX = "osx"; - public const string X11 = "linuxbsd"; + public const string LinuxBSD = "linuxbsd"; public const string Server = "server"; public const string UWP = "uwp"; public const string Haiku = "haiku"; @@ -48,7 +51,10 @@ namespace GodotTools.Utils { [Names.Windows] = Platforms.Windows, [Names.OSX] = Platforms.OSX, - [Names.X11] = Platforms.X11, + [Names.Linux] = Platforms.LinuxBSD, + [Names.FreeBSD] = Platforms.LinuxBSD, + [Names.NetBSD] = Platforms.LinuxBSD, + [Names.BSD] = Platforms.LinuxBSD, [Names.Server] = Platforms.Server, [Names.UWP] = Platforms.UWP, [Names.Haiku] = Platforms.Haiku, @@ -62,38 +68,39 @@ namespace GodotTools.Utils return name.Equals(GetPlatformName(), StringComparison.OrdinalIgnoreCase); } + private static bool IsAnyOS(IEnumerable<string> names) + { + return names.Any(p => p.Equals(GetPlatformName(), StringComparison.OrdinalIgnoreCase)); + } + + private static readonly IEnumerable<string> LinuxBSDPlatforms = + new[] {Names.Linux, Names.FreeBSD, Names.NetBSD, Names.BSD}; + + private static readonly IEnumerable<string> UnixLikePlatforms = + new[] {Names.OSX, Names.Server, Names.Haiku, Names.Android, Names.iOS} + .Concat(LinuxBSDPlatforms).ToArray(); + private static readonly Lazy<bool> _isWindows = new Lazy<bool>(() => IsOS(Names.Windows)); private static readonly Lazy<bool> _isOSX = new Lazy<bool>(() => IsOS(Names.OSX)); - private static readonly Lazy<bool> _isX11 = new Lazy<bool>(() => IsOS(Names.X11)); + private static readonly Lazy<bool> _isLinuxBSD = new Lazy<bool>(() => IsAnyOS(LinuxBSDPlatforms)); private static readonly Lazy<bool> _isServer = new Lazy<bool>(() => IsOS(Names.Server)); private static readonly Lazy<bool> _isUWP = new Lazy<bool>(() => IsOS(Names.UWP)); private static readonly Lazy<bool> _isHaiku = new Lazy<bool>(() => IsOS(Names.Haiku)); private static readonly Lazy<bool> _isAndroid = new Lazy<bool>(() => IsOS(Names.Android)); private static readonly Lazy<bool> _isiOS = new Lazy<bool>(() => IsOS(Names.iOS)); private static readonly Lazy<bool> _isHTML5 = new Lazy<bool>(() => IsOS(Names.HTML5)); + private static readonly Lazy<bool> _isUnixLike = new Lazy<bool>(() => IsAnyOS(UnixLikePlatforms)); public static bool IsWindows => _isWindows.Value || IsUWP; public static bool IsOSX => _isOSX.Value; - public static bool IsX11 => _isX11.Value; + public static bool IsLinuxBSD => _isLinuxBSD.Value; public static bool IsServer => _isServer.Value; public static bool IsUWP => _isUWP.Value; public static bool IsHaiku => _isHaiku.Value; public static bool IsAndroid => _isAndroid.Value; public static bool IsiOS => _isiOS.Value; public static bool IsHTML5 => _isHTML5.Value; - - private static bool? _isUnixCache; - private static readonly string[] UnixLikePlatforms = { Names.OSX, Names.X11, Names.Server, Names.Haiku, Names.Android, Names.iOS }; - - public static bool IsUnixLike() - { - if (_isUnixCache.HasValue) - return _isUnixCache.Value; - - string osName = GetPlatformName(); - _isUnixCache = UnixLikePlatforms.Any(p => p.Equals(osName, StringComparison.OrdinalIgnoreCase)); - return _isUnixCache.Value; - } + public static bool IsUnixLike => _isUnixLike.Value; public static char PathSep => IsWindows ? ';' : ':'; @@ -121,10 +128,10 @@ namespace GodotTools.Utils return searchDirs.Select(dir => Path.Combine(dir, name)).FirstOrDefault(File.Exists); return (from dir in searchDirs - select Path.Combine(dir, name) + select Path.Combine(dir, name) into path - from ext in windowsExts - select path + ext).FirstOrDefault(File.Exists); + from ext in windowsExts + select path + ext).FirstOrDefault(File.Exists); } private static string PathWhichUnix([NotNull] string name) @@ -189,7 +196,7 @@ namespace GodotTools.Utils startInfo.UseShellExecute = false; - using (var process = new Process { StartInfo = startInfo }) + using (var process = new Process {StartInfo = startInfo}) { process.Start(); process.WaitForExit(); diff --git a/modules/mono/editor/GodotTools/GodotTools/packages.config b/modules/mono/editor/GodotTools/GodotTools/packages.config deleted file mode 100644 index dd3de2865a..0000000000 --- a/modules/mono/editor/GodotTools/GodotTools/packages.config +++ /dev/null @@ -1,5 +0,0 @@ -<?xml version="1.0" encoding="utf-8"?> -<packages> - <package id="JetBrains.Annotations" version="2019.1.3" targetFramework="net45" /> - <package id="Newtonsoft.Json" version="12.0.3" targetFramework="net45" /> -</packages> diff --git a/modules/mono/editor/bindings_generator.cpp b/modules/mono/editor/bindings_generator.cpp index bdf9cf965f..258b8ed3ed 100644 --- a/modules/mono/editor/bindings_generator.cpp +++ b/modules/mono/editor/bindings_generator.cpp @@ -1664,6 +1664,10 @@ Error BindingsGenerator::_generate_cs_method(const BindingsGenerator::TypeInterf } if (!p_imethod.is_internal) { + // TODO: This alone adds ~0.2 MB of bloat to the core API assembly. It would be + // better to generate a table in the C++ glue instead. That way the strings wouldn't + // add that much extra bloat as they're already used in engine code. Also, it would + // probably be much faster than looking up the attributes when fetching methods. p_output.append(MEMBER_BEGIN "[GodotMethod(\""); p_output.append(p_imethod.name); p_output.append("\")]"); @@ -2139,7 +2143,7 @@ Error BindingsGenerator::_generate_glue_method(const BindingsGenerator::TypeInte if (return_type->ret_as_byref_arg) { p_output.append("\tif (" CS_PARAM_INSTANCE " == nullptr) { *arg_ret = "); p_output.append(fail_ret); - p_output.append("; ERR_FAIL_MSG(\"Parameter ' arg_ret ' is null.\"); }\n"); + p_output.append("; ERR_FAIL_MSG(\"Parameter ' " CS_PARAM_INSTANCE " ' is null.\"); }\n"); } else { p_output.append("\tERR_FAIL_NULL_V(" CS_PARAM_INSTANCE ", "); p_output.append(fail_ret); @@ -2390,6 +2394,11 @@ bool BindingsGenerator::_populate_object_type_interfaces() { if (property.usage & PROPERTY_USAGE_GROUP || property.usage & PROPERTY_USAGE_SUBGROUP || property.usage & PROPERTY_USAGE_CATEGORY) continue; + if (property.name.find("/") >= 0) { + // Ignore properties with '/' (slash) in the name. These are only meant for use in the inspector. + continue; + } + PropertyInterface iprop; iprop.cname = property.name; iprop.setter = ClassDB::get_property_setter(type_cname, iprop.cname); @@ -2402,7 +2411,7 @@ bool BindingsGenerator::_populate_object_type_interfaces() { bool valid = false; iprop.index = ClassDB::get_property_index(type_cname, iprop.cname, &valid); - ERR_FAIL_COND_V(!valid, false); + ERR_FAIL_COND_V_MSG(!valid, false, "Invalid property: '" + itype.name + "." + String(iprop.cname) + "'."); iprop.proxy_name = escape_csharp_keyword(snake_to_pascal_case(iprop.cname)); @@ -2414,8 +2423,6 @@ bool BindingsGenerator::_populate_object_type_interfaces() { iprop.proxy_name += "_"; } - iprop.proxy_name = iprop.proxy_name.replace("/", "__"); // Some members have a slash... - iprop.prop_doc = nullptr; for (int i = 0; i < itype.class_doc->properties.size(); i++) { diff --git a/modules/mono/editor/bindings_generator.h b/modules/mono/editor/bindings_generator.h index 7c87c688db..9aad9622d9 100644 --- a/modules/mono/editor/bindings_generator.h +++ b/modules/mono/editor/bindings_generator.h @@ -85,16 +85,12 @@ class BindingsGenerator { struct TypeReference { StringName cname; - bool is_enum; + bool is_enum = false; - TypeReference() : - is_enum(false) { - } + TypeReference() {} TypeReference(const StringName &p_cname) : - cname(p_cname), - is_enum(false) { - } + cname(p_cname) {} }; struct ArgumentInterface { @@ -107,7 +103,7 @@ class BindingsGenerator { TypeReference type; String name; - DefaultParamMode def_param_mode; + DefaultParamMode def_param_mode = CONSTANT; /** * Determines the expression for the parameter default value. @@ -116,9 +112,7 @@ class BindingsGenerator { */ String default_argument; - ArgumentInterface() { - def_param_mode = CONSTANT; - } + ArgumentInterface() {} }; struct MethodInterface { @@ -138,19 +132,19 @@ class BindingsGenerator { /** * Determines if the method has a variable number of arguments (VarArg) */ - bool is_vararg; + bool is_vararg = false; /** * Virtual methods ("virtual" as defined by the Godot API) are methods that by default do nothing, * but can be overridden by the user to add custom functionality. * e.g.: _ready, _process, etc. */ - bool is_virtual; + bool is_virtual = false; /** * Determines if the call should fallback to Godot's object.Call(string, params) in C#. */ - bool requires_object_call; + bool requires_object_call = false; /** * Determines if the method visibility is 'internal' (visible only to files in the same assembly). @@ -158,27 +152,20 @@ class BindingsGenerator { * but are required by properties as getters or setters. * Methods that are not meant to be exposed are those that begin with underscore and are not virtual. */ - bool is_internal; + bool is_internal = false; List<ArgumentInterface> arguments; - const DocData::MethodDoc *method_doc; + const DocData::MethodDoc *method_doc = nullptr; - bool is_deprecated; + bool is_deprecated = false; String deprecation_message; void add_argument(const ArgumentInterface &argument) { arguments.push_back(argument); } - MethodInterface() { - is_vararg = false; - is_virtual = false; - requires_object_call = false; - is_internal = false; - method_doc = nullptr; - is_deprecated = false; - } + MethodInterface() {} }; struct SignalInterface { @@ -192,19 +179,16 @@ class BindingsGenerator { List<ArgumentInterface> arguments; - const DocData::MethodDoc *method_doc; + const DocData::MethodDoc *method_doc = nullptr; - bool is_deprecated; + bool is_deprecated = false; String deprecation_message; void add_argument(const ArgumentInterface &argument) { arguments.push_back(argument); } - SignalInterface() { - method_doc = nullptr; - is_deprecated = false; - } + SignalInterface() {} }; struct TypeInterface { @@ -225,26 +209,26 @@ class BindingsGenerator { */ String proxy_name; - ClassDB::APIType api_type; + ClassDB::APIType api_type = ClassDB::API_NONE; - bool is_enum; - bool is_object_type; - bool is_singleton; - bool is_reference; + bool is_enum = false; + bool is_object_type = false; + bool is_singleton = false; + bool is_reference = false; /** * Used only by Object-derived types. * Determines if this type is not abstract (incomplete). * e.g.: CanvasItem cannot be instantiated. */ - bool is_instantiable; + bool is_instantiable = false; /** * Used only by Object-derived types. * Determines if the C# class owns the native handle and must free it somehow when disposed. * e.g.: Reference types must notify when the C# instance is disposed, for proper refcounting. */ - bool memory_own; + bool memory_own = false; /** * This must be set to true for any struct bigger than 32-bits. Those cannot be passed/returned by value @@ -252,7 +236,7 @@ class BindingsGenerator { * In this case, [c_out] and [cs_out] must have a different format, explained below. * The Mono IL interpreter icall trampolines don't support passing structs bigger than 32-bits by value (at least not on WASM). */ - bool ret_as_byref_arg; + bool ret_as_byref_arg = false; // !! The comments of the following fields make reference to other fields via square brackets, e.g.: [field_name] // !! When renaming those fields, make sure to rename their references in the comments @@ -279,7 +263,7 @@ class BindingsGenerator { * Formatting elements: * %0 or %s: name of the parameter */ - String c_arg_in; + String c_arg_in = "%s"; /** * One or more statements that determine how a variable of this type is returned from a function. @@ -362,7 +346,7 @@ class BindingsGenerator { */ String im_type_out; - const DocData::ClassDoc *class_doc; + const DocData::ClassDoc *class_doc = nullptr; List<ConstantInterface> constants; List<EnumInterface> enums; @@ -482,24 +466,7 @@ class BindingsGenerator { r_enum_itype.class_doc = &EditorHelp::get_doc_data()->class_list[r_enum_itype.proxy_name]; } - TypeInterface() { - - api_type = ClassDB::API_NONE; - - is_enum = false; - is_object_type = false; - is_singleton = false; - is_reference = false; - is_instantiable = false; - - memory_own = false; - - ret_as_byref_arg = false; - - c_arg_in = "%s"; - - class_doc = nullptr; - } + TypeInterface() {} }; struct InternalCall { @@ -532,8 +499,8 @@ class BindingsGenerator { } }; - bool log_print_enabled; - bool initialized; + bool log_print_enabled = true; + bool initialized = false; OrderedHashMap<StringName, TypeInterface> obj_types; @@ -619,7 +586,8 @@ class BindingsGenerator { const List<InternalCall>::Element *find_icall_by_name(const String &p_name, const List<InternalCall> &p_list) { const List<InternalCall>::Element *it = p_list.front(); while (it) { - if (it->get().name == p_name) return it; + if (it->get().name == p_name) + return it; it = it->next(); } return nullptr; @@ -696,9 +664,7 @@ public: static void handle_cmdline_args(const List<String> &p_cmdline_args); - BindingsGenerator() : - log_print_enabled(true), - initialized(false) { + BindingsGenerator() { _initialize(); } }; diff --git a/modules/mono/editor/code_completion.cpp b/modules/mono/editor/code_completion.cpp new file mode 100644 index 0000000000..7a5e465e7a --- /dev/null +++ b/modules/mono/editor/code_completion.cpp @@ -0,0 +1,249 @@ +/*************************************************************************/ +/* code_completion.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "code_completion.h" + +#include "core/project_settings.h" +#include "editor/editor_file_system.h" +#include "editor/editor_settings.h" +#include "scene/gui/control.h" +#include "scene/main/node.h" + +namespace gdmono { + +// Almost everything here is taken from functions used by GDScript for code completion, adapted for C#. + +_FORCE_INLINE_ String quoted(const String &p_str) { + return "\"" + p_str + "\""; +} + +void _add_nodes_suggestions(const Node *p_base, const Node *p_node, PackedStringArray &r_suggestions) { + if (p_node != p_base && !p_node->get_owner()) + return; + + String path_relative_to_orig = p_base->get_path_to(p_node); + + r_suggestions.push_back(quoted(path_relative_to_orig)); + + for (int i = 0; i < p_node->get_child_count(); i++) { + _add_nodes_suggestions(p_base, p_node->get_child(i), r_suggestions); + } +} + +Node *_find_node_for_script(Node *p_base, Node *p_current, const Ref<Script> &p_script) { + if (p_current->get_owner() != p_base && p_base != p_current) + return nullptr; + + Ref<Script> c = p_current->get_script(); + + if (c == p_script) + return p_current; + + for (int i = 0; i < p_current->get_child_count(); i++) { + Node *found = _find_node_for_script(p_base, p_current->get_child(i), p_script); + if (found) + return found; + } + + return nullptr; +} + +void _get_directory_contents(EditorFileSystemDirectory *p_dir, PackedStringArray &r_suggestions) { + for (int i = 0; i < p_dir->get_file_count(); i++) { + r_suggestions.push_back(quoted(p_dir->get_file_path(i))); + } + + for (int i = 0; i < p_dir->get_subdir_count(); i++) { + _get_directory_contents(p_dir->get_subdir(i), r_suggestions); + } +} + +Node *_try_find_owner_node_in_tree(const Ref<Script> p_script) { + SceneTree *tree = SceneTree::get_singleton(); + if (!tree) + return nullptr; + Node *base = tree->get_edited_scene_root(); + if (base) { + base = _find_node_for_script(base, base, p_script); + } + return base; +} + +PackedStringArray get_code_completion(CompletionKind p_kind, const String &p_script_file) { + PackedStringArray suggestions; + + switch (p_kind) { + case CompletionKind::INPUT_ACTIONS: { + List<PropertyInfo> project_props; + ProjectSettings::get_singleton()->get_property_list(&project_props); + + for (List<PropertyInfo>::Element *E = project_props.front(); E; E = E->next()) { + const PropertyInfo &prop = E->get(); + + if (!prop.name.begins_with("input/")) + continue; + + String name = prop.name.substr(prop.name.find("/") + 1, prop.name.length()); + suggestions.push_back(quoted(name)); + } + } break; + case CompletionKind::NODE_PATHS: { + { + // AutoLoads + List<PropertyInfo> props; + ProjectSettings::get_singleton()->get_property_list(&props); + + for (List<PropertyInfo>::Element *E = props.front(); E; E = E->next()) { + String s = E->get().name; + if (!s.begins_with("autoload/")) { + continue; + } + String name = s.get_slice("/", 1); + suggestions.push_back(quoted("/root/" + name)); + } + } + + { + // Current edited scene tree + Ref<Script> script = ResourceLoader::load(p_script_file.simplify_path()); + Node *base = _try_find_owner_node_in_tree(script); + if (base) { + _add_nodes_suggestions(base, base, suggestions); + } + } + } break; + case CompletionKind::RESOURCE_PATHS: { + if (bool(EditorSettings::get_singleton()->get("text_editor/completion/complete_file_paths"))) { + _get_directory_contents(EditorFileSystem::get_singleton()->get_filesystem(), suggestions); + } + } break; + case CompletionKind::SCENE_PATHS: { + DirAccessRef dir_access = DirAccess::create(DirAccess::ACCESS_RESOURCES); + List<String> directories; + directories.push_back(dir_access->get_current_dir()); + + while (!directories.empty()) { + dir_access->change_dir(directories.back()->get()); + directories.pop_back(); + + dir_access->list_dir_begin(); + String filename = dir_access->get_next(); + + while (filename != "") { + if (filename == "." || filename == "..") { + filename = dir_access->get_next(); + continue; + } + + if (dir_access->dir_exists(filename)) { + directories.push_back(dir_access->get_current_dir().plus_file(filename)); + } else if (filename.ends_with(".tscn") || filename.ends_with(".scn")) { + suggestions.push_back(quoted(dir_access->get_current_dir().plus_file(filename))); + } + + filename = dir_access->get_next(); + } + } + } break; + case CompletionKind::SHADER_PARAMS: { + print_verbose("Shared params completion for C# not implemented."); + } break; + case CompletionKind::SIGNALS: { + Ref<Script> script = ResourceLoader::load(p_script_file.simplify_path()); + + List<MethodInfo> signals; + script->get_script_signal_list(&signals); + + StringName native = script->get_instance_base_type(); + if (native != StringName()) { + ClassDB::get_signal_list(native, &signals, /* p_no_inheritance: */ false); + } + + for (List<MethodInfo>::Element *E = signals.front(); E; E = E->next()) { + const String &signal = E->get().name; + suggestions.push_back(quoted(signal)); + } + } break; + case CompletionKind::THEME_COLORS: { + Ref<Script> script = ResourceLoader::load(p_script_file.simplify_path()); + Node *base = _try_find_owner_node_in_tree(script); + if (base && Object::cast_to<Control>(base)) { + List<StringName> sn; + Theme::get_default()->get_color_list(base->get_class(), &sn); + + for (List<StringName>::Element *E = sn.front(); E; E = E->next()) { + suggestions.push_back(quoted(E->get())); + } + } + } break; + case CompletionKind::THEME_CONSTANTS: { + Ref<Script> script = ResourceLoader::load(p_script_file.simplify_path()); + Node *base = _try_find_owner_node_in_tree(script); + if (base && Object::cast_to<Control>(base)) { + List<StringName> sn; + Theme::get_default()->get_constant_list(base->get_class(), &sn); + + for (List<StringName>::Element *E = sn.front(); E; E = E->next()) { + suggestions.push_back(quoted(E->get())); + } + } + } break; + case CompletionKind::THEME_FONTS: { + Ref<Script> script = ResourceLoader::load(p_script_file.simplify_path()); + Node *base = _try_find_owner_node_in_tree(script); + if (base && Object::cast_to<Control>(base)) { + List<StringName> sn; + Theme::get_default()->get_font_list(base->get_class(), &sn); + + for (List<StringName>::Element *E = sn.front(); E; E = E->next()) { + suggestions.push_back(quoted(E->get())); + } + } + } break; + case CompletionKind::THEME_STYLES: { + Ref<Script> script = ResourceLoader::load(p_script_file.simplify_path()); + Node *base = _try_find_owner_node_in_tree(script); + if (base && Object::cast_to<Control>(base)) { + List<StringName> sn; + Theme::get_default()->get_stylebox_list(base->get_class(), &sn); + + for (List<StringName>::Element *E = sn.front(); E; E = E->next()) { + suggestions.push_back(quoted(E->get())); + } + } + } break; + default: + ERR_FAIL_V_MSG(suggestions, "Invalid completion kind."); + } + + return suggestions; +} + +} // namespace gdmono diff --git a/core/math/disjoint_set.cpp b/modules/mono/editor/code_completion.h index a508151ad3..77673b766f 100644 --- a/core/math/disjoint_set.cpp +++ b/modules/mono/editor/code_completion.h @@ -1,5 +1,5 @@ /*************************************************************************/ -/* disjoint_set.cpp */ +/* code_completion.h */ /*************************************************************************/ /* This file is part of: */ /* GODOT ENGINE */ @@ -28,4 +28,29 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#include "disjoint_set.h" +#ifndef CODE_COMPLETION_H +#define CODE_COMPLETION_H + +#include "core/ustring.h" +#include "core/variant.h" + +namespace gdmono { + +enum class CompletionKind { + INPUT_ACTIONS = 0, + NODE_PATHS, + RESOURCE_PATHS, + SCENE_PATHS, + SHADER_PARAMS, + SIGNALS, + THEME_COLORS, + THEME_CONSTANTS, + THEME_FONTS, + THEME_STYLES +}; + +PackedStringArray get_code_completion(CompletionKind p_kind, const String &p_script_file); + +} // namespace gdmono + +#endif // CODE_COMPLETION_H diff --git a/modules/mono/editor/editor_internal_calls.cpp b/modules/mono/editor/editor_internal_calls.cpp index c3e7e67ae9..c9117f1312 100644 --- a/modules/mono/editor/editor_internal_calls.cpp +++ b/modules/mono/editor/editor_internal_calls.cpp @@ -48,6 +48,7 @@ #include "../mono_gd/gd_mono_marshal.h" #include "../utils/osx_utils.h" #include "bindings_generator.h" +#include "code_completion.h" #include "godotsharp_export.h" #include "script_class_parser.h" @@ -354,6 +355,12 @@ void godot_icall_Internal_ScriptEditorDebugger_ReloadScripts() { } } +MonoArray *godot_icall_Internal_CodeCompletionRequest(int32_t p_kind, MonoString *p_script_file) { + String script_file = GDMonoMarshal::mono_string_to_godot(p_script_file); + PackedStringArray suggestions = gdmono::get_code_completion((gdmono::CompletionKind)p_kind, script_file); + return GDMonoMarshal::PackedStringArray_to_mono_array(suggestions); +} + float godot_icall_Globals_EditorScale() { return EDSCALE; } @@ -454,6 +461,7 @@ void register_editor_internal_calls() { mono_add_internal_call("GodotTools.Internals.Internal::internal_EditorRunPlay", (void *)godot_icall_Internal_EditorRunPlay); mono_add_internal_call("GodotTools.Internals.Internal::internal_EditorRunStop", (void *)godot_icall_Internal_EditorRunStop); mono_add_internal_call("GodotTools.Internals.Internal::internal_ScriptEditorDebugger_ReloadScripts", (void *)godot_icall_Internal_ScriptEditorDebugger_ReloadScripts); + mono_add_internal_call("GodotTools.Internals.Internal::internal_CodeCompletionRequest", (void *)godot_icall_Internal_CodeCompletionRequest); // Globals mono_add_internal_call("GodotTools.Internals.Globals::internal_EditorScale", (void *)godot_icall_Globals_EditorScale); diff --git a/modules/mono/editor/godotsharp_export.cpp b/modules/mono/editor/godotsharp_export.cpp index 4126da16be..1cdb08d50e 100644 --- a/modules/mono/editor/godotsharp_export.cpp +++ b/modules/mono/editor/godotsharp_export.cpp @@ -32,68 +32,70 @@ #include <mono/metadata/image.h> +#include "core/io/file_access_pack.h" #include "core/os/os.h" +#include "core/project_settings.h" #include "../mono_gd/gd_mono.h" #include "../mono_gd/gd_mono_assembly.h" #include "../mono_gd/gd_mono_cache.h" +#include "../utils/macros.h" namespace GodotSharpExport { -String get_assemblyref_name(MonoImage *p_image, int index) { +struct AssemblyRefInfo { + String name; + uint16_t major; + uint16_t minor; + uint16_t build; + uint16_t revision; +}; + +AssemblyRefInfo get_assemblyref_name(MonoImage *p_image, int index) { const MonoTableInfo *table_info = mono_image_get_table_info(p_image, MONO_TABLE_ASSEMBLYREF); uint32_t cols[MONO_ASSEMBLYREF_SIZE]; mono_metadata_decode_row(table_info, index, cols, MONO_ASSEMBLYREF_SIZE); - return String::utf8(mono_metadata_string_heap(p_image, cols[MONO_ASSEMBLYREF_NAME])); + return { + String::utf8(mono_metadata_string_heap(p_image, cols[MONO_ASSEMBLYREF_NAME])), + (uint16_t)cols[MONO_ASSEMBLYREF_MAJOR_VERSION], + (uint16_t)cols[MONO_ASSEMBLYREF_MINOR_VERSION], + (uint16_t)cols[MONO_ASSEMBLYREF_BUILD_NUMBER], + (uint16_t)cols[MONO_ASSEMBLYREF_REV_NUMBER] + }; } Error get_assembly_dependencies(GDMonoAssembly *p_assembly, const Vector<String> &p_search_dirs, Dictionary &r_assembly_dependencies) { MonoImage *image = p_assembly->get_image(); for (int i = 0; i < mono_image_get_table_rows(image, MONO_TABLE_ASSEMBLYREF); i++) { - String ref_name = get_assemblyref_name(image, i); + AssemblyRefInfo ref_info = get_assemblyref_name(image, i); + + const String &ref_name = ref_info.name; if (r_assembly_dependencies.has(ref_name)) continue; GDMonoAssembly *ref_assembly = nullptr; - String path; - bool has_extension = ref_name.ends_with(".dll") || ref_name.ends_with(".exe"); - - for (int j = 0; j < p_search_dirs.size(); j++) { - const String &search_dir = p_search_dirs[j]; - - if (has_extension) { - path = search_dir.plus_file(ref_name); - if (FileAccess::exists(path)) { - GDMono::get_singleton()->load_assembly_from(ref_name.get_basename(), path, &ref_assembly, true); - if (ref_assembly != nullptr) - break; - } - } else { - path = search_dir.plus_file(ref_name + ".dll"); - if (FileAccess::exists(path)) { - GDMono::get_singleton()->load_assembly_from(ref_name, path, &ref_assembly, true); - if (ref_assembly != nullptr) - break; - } - - path = search_dir.plus_file(ref_name + ".exe"); - if (FileAccess::exists(path)) { - GDMono::get_singleton()->load_assembly_from(ref_name, path, &ref_assembly, true); - if (ref_assembly != nullptr) - break; - } - } - } - ERR_FAIL_COND_V_MSG(!ref_assembly, ERR_CANT_RESOLVE, "Cannot load assembly (refonly): '" + ref_name + "'."); + { + MonoAssemblyName *ref_aname = mono_assembly_name_new("A"); // We can't allocate an empty MonoAssemblyName, hence "A" + CRASH_COND(ref_aname == nullptr); + SCOPE_EXIT { + mono_assembly_name_free(ref_aname); + mono_free(ref_aname); + }; + + mono_assembly_get_assemblyref(image, i, ref_aname); - // Use the path we got from the search. Don't try to get the path from the loaded assembly as we can't trust it will be from the selected BCL dir. - r_assembly_dependencies[ref_name] = path; + if (!GDMono::get_singleton()->load_assembly(ref_name, ref_aname, &ref_assembly, /* refonly: */ true, p_search_dirs)) { + ERR_FAIL_V_MSG(ERR_CANT_RESOLVE, "Cannot load assembly (refonly): '" + ref_name + "'."); + } + + r_assembly_dependencies[ref_name] = ref_assembly->get_path(); + } Error err = get_assembly_dependencies(ref_assembly, p_search_dirs, r_assembly_dependencies); ERR_FAIL_COND_V_MSG(err != OK, err, "Cannot load one of the dependencies for the assembly: '" + ref_name + "'."); @@ -113,6 +115,11 @@ Error get_exported_assembly_dependencies(const Dictionary &p_initial_assemblies, Vector<String> search_dirs; GDMonoAssembly::fill_search_dirs(search_dirs, p_build_config, p_custom_bcl_dir); + if (p_custom_bcl_dir.length()) { + // Only one mscorlib can be loaded. We need this workaround to make sure we get it from the right BCL directory. + r_assembly_dependencies["mscorlib"] = p_custom_bcl_dir.plus_file("mscorlib.dll").simplify_path(); + } + for (const Variant *key = p_initial_assemblies.next(); key; key = p_initial_assemblies.next(key)) { String assembly_name = *key; String assembly_path = p_initial_assemblies[*key]; diff --git a/modules/mono/editor/script_class_parser.cpp b/modules/mono/editor/script_class_parser.cpp index bece23c9a6..3ffbf8ba14 100644 --- a/modules/mono/editor/script_class_parser.cpp +++ b/modules/mono/editor/script_class_parser.cpp @@ -181,14 +181,24 @@ ScriptClassParser::Token ScriptClassParser::get_token() { CharType res = 0; switch (next) { - case 'b': res = 8; break; - case 't': res = 9; break; - case 'n': res = 10; break; - case 'f': res = 12; break; + case 'b': + res = 8; + break; + case 't': + res = 9; + break; + case 'n': + res = 10; + break; + case 'f': + res = 12; + break; case 'r': res = 13; break; - case '\"': res = '\"'; break; + case '\"': + res = '\"'; + break; case '\\': res = '\\'; break; diff --git a/modules/mono/glue/GodotSharp/GodotSharp/GodotSharp.csproj b/modules/mono/glue/GodotSharp/GodotSharp/GodotSharp.csproj index ba0bbd7630..b5ac124c9a 100644 --- a/modules/mono/glue/GodotSharp/GodotSharp/GodotSharp.csproj +++ b/modules/mono/glue/GodotSharp/GodotSharp/GodotSharp.csproj @@ -30,6 +30,7 @@ <ConsolePause>false</ConsolePause> </PropertyGroup> <ItemGroup> + <PackageReference Include="Microsoft.NETFramework.ReferenceAssemblies" Version="1.0.0" PrivateAssets="All" /> <Reference Include="System" /> </ItemGroup> <ItemGroup> diff --git a/modules/mono/glue/GodotSharp/GodotSharpEditor/GodotSharpEditor.csproj b/modules/mono/glue/GodotSharp/GodotSharpEditor/GodotSharpEditor.csproj index 22853797c1..8785931312 100644 --- a/modules/mono/glue/GodotSharp/GodotSharpEditor/GodotSharpEditor.csproj +++ b/modules/mono/glue/GodotSharp/GodotSharpEditor/GodotSharpEditor.csproj @@ -30,6 +30,7 @@ <ConsolePause>false</ConsolePause> </PropertyGroup> <ItemGroup> + <PackageReference Include="Microsoft.NETFramework.ReferenceAssemblies" Version="1.0.0" PrivateAssets="All" /> <Reference Include="System" /> </ItemGroup> <ItemGroup> diff --git a/modules/mono/managed_callable.cpp b/modules/mono/managed_callable.cpp index dfd78a8244..26347e9162 100644 --- a/modules/mono/managed_callable.cpp +++ b/modules/mono/managed_callable.cpp @@ -82,6 +82,7 @@ CallableCustom::CompareLessFunc ManagedCallable::get_compare_less_func() const { } ObjectID ManagedCallable::get_object() const { + // TODO: If the delegate target extends Godot.Object, use that instead! return CSharpLanguage::get_singleton()->get_managed_callable_middleman()->get_instance_id(); } diff --git a/modules/mono/mono_gc_handle.h b/modules/mono/mono_gc_handle.h index fbcb405b0d..005ee52b35 100644 --- a/modules/mono/mono_gc_handle.h +++ b/modules/mono/mono_gc_handle.h @@ -47,8 +47,8 @@ enum class GCHandleType : char { // Manual release of the GC handle must be done when using this struct struct MonoGCHandleData { - uint32_t handle; - gdmono::GCHandleType type; + uint32_t handle = 0; + gdmono::GCHandleType type = gdmono::GCHandleType::NIL; _FORCE_INLINE_ bool is_released() const { return !handle; } _FORCE_INLINE_ bool is_weak() const { return type == gdmono::GCHandleType::WEAK_HANDLE; } @@ -68,10 +68,7 @@ struct MonoGCHandleData { MonoGCHandleData(const MonoGCHandleData &) = default; - MonoGCHandleData() : - handle(0), - type(gdmono::GCHandleType::NIL) { - } + MonoGCHandleData() {} MonoGCHandleData(uint32_t p_handle, gdmono::GCHandleType p_type) : handle(p_handle), diff --git a/modules/mono/mono_gd/gd_mono.cpp b/modules/mono/mono_gd/gd_mono.cpp index 3298c5da4c..fbaa81e83f 100644 --- a/modules/mono/mono_gd/gd_mono.cpp +++ b/modules/mono/mono_gd/gd_mono.cpp @@ -133,6 +133,10 @@ void gd_mono_debug_init() { CharString da_args = OS::get_singleton()->get_environment("GODOT_MONO_DEBUGGER_AGENT").utf8(); + if (da_args.length()) { + OS::get_singleton()->set_environment("GODOT_MONO_DEBUGGER_AGENT", String()); + } + #ifdef TOOLS_ENABLED int da_port = GLOBAL_DEF("mono/debugger_agent/port", 23685); bool da_suspend = GLOBAL_DEF("mono/debugger_agent/wait_for_debugger", false); @@ -515,8 +519,8 @@ void GDMono::add_assembly(uint32_t p_domain_id, GDMonoAssembly *p_assembly) { GDMonoAssembly *GDMono::get_loaded_assembly(const String &p_name) { - if (p_name == "mscorlib") - return get_corlib_assembly(); + if (p_name == "mscorlib" && corlib_assembly) + return corlib_assembly; MonoDomain *domain = mono_domain_get(); uint32_t domain_id = domain ? mono_domain_get_id(domain) : 0; @@ -526,7 +530,9 @@ GDMonoAssembly *GDMono::get_loaded_assembly(const String &p_name) { bool GDMono::load_assembly(const String &p_name, GDMonoAssembly **r_assembly, bool p_refonly) { +#ifdef DEBUG_ENABLED CRASH_COND(!r_assembly); +#endif MonoAssemblyName *aname = mono_assembly_name_new(p_name.utf8()); bool result = load_assembly(p_name, aname, r_assembly, p_refonly); @@ -538,26 +544,27 @@ bool GDMono::load_assembly(const String &p_name, GDMonoAssembly **r_assembly, bo bool GDMono::load_assembly(const String &p_name, MonoAssemblyName *p_aname, GDMonoAssembly **r_assembly, bool p_refonly) { +#ifdef DEBUG_ENABLED CRASH_COND(!r_assembly); +#endif - print_verbose("Mono: Loading assembly " + p_name + (p_refonly ? " (refonly)" : "") + "..."); + return load_assembly(p_name, p_aname, r_assembly, p_refonly, GDMonoAssembly::get_default_search_dirs()); +} - MonoImageOpenStatus status = MONO_IMAGE_OK; - MonoAssembly *assembly = mono_assembly_load_full(p_aname, nullptr, &status, p_refonly); +bool GDMono::load_assembly(const String &p_name, MonoAssemblyName *p_aname, GDMonoAssembly **r_assembly, bool p_refonly, const Vector<String> &p_search_dirs) { - if (!assembly) - return false; - - ERR_FAIL_COND_V(status != MONO_IMAGE_OK, false); +#ifdef DEBUG_ENABLED + CRASH_COND(!r_assembly); +#endif - uint32_t domain_id = mono_domain_get_id(mono_domain_get()); + print_verbose("Mono: Loading assembly " + p_name + (p_refonly ? " (refonly)" : "") + "..."); - GDMonoAssembly **stored_assembly = assemblies[domain_id].getptr(p_name); + GDMonoAssembly *assembly = GDMonoAssembly::load(p_name, p_aname, p_refonly, p_search_dirs); - ERR_FAIL_COND_V(stored_assembly == nullptr, false); - ERR_FAIL_COND_V((*stored_assembly)->get_assembly() != assembly, false); + if (!assembly) + return false; - *r_assembly = *stored_assembly; + *r_assembly = assembly; print_verbose("Mono: Assembly " + p_name + (p_refonly ? " (refonly)" : "") + " loaded from path: " + (*r_assembly)->get_path()); diff --git a/modules/mono/mono_gd/gd_mono.h b/modules/mono/mono_gd/gd_mono.h index 4898833e8e..3b0be4c180 100644 --- a/modules/mono/mono_gd/gd_mono.h +++ b/modules/mono/mono_gd/gd_mono.h @@ -48,9 +48,9 @@ enum Type { }; struct Version { - uint64_t godot_api_hash; - uint32_t bindings_version; - uint32_t cs_glue_version; + uint64_t godot_api_hash = 0; + uint32_t bindings_version = 0; + uint32_t cs_glue_version = 0; bool operator==(const Version &p_other) const { return godot_api_hash == p_other.godot_api_hash && @@ -58,11 +58,7 @@ struct Version { cs_glue_version == p_other.cs_glue_version; } - Version() : - godot_api_hash(0), - bindings_version(0), - cs_glue_version(0) { - } + Version() {} Version(uint64_t p_godot_api_hash, uint32_t p_bindings_version, @@ -87,13 +83,10 @@ public: }; struct LoadedApiAssembly { - GDMonoAssembly *assembly; - bool out_of_sync; + GDMonoAssembly *assembly = nullptr; + bool out_of_sync = false; - LoadedApiAssembly() : - assembly(nullptr), - out_of_sync(false) { - } + LoadedApiAssembly() {} }; private: @@ -241,6 +234,7 @@ public: bool load_assembly(const String &p_name, GDMonoAssembly **r_assembly, bool p_refonly = false); bool load_assembly(const String &p_name, MonoAssemblyName *p_aname, GDMonoAssembly **r_assembly, bool p_refonly = false); + bool load_assembly(const String &p_name, MonoAssemblyName *p_aname, GDMonoAssembly **r_assembly, bool p_refonly, const Vector<String> &p_search_dirs); bool load_assembly_from(const String &p_name, const String &p_path, GDMonoAssembly **r_assembly, bool p_refonly = false); Error finalize_and_unload_domain(MonoDomain *p_domain); diff --git a/modules/mono/mono_gd/gd_mono_assembly.cpp b/modules/mono/mono_gd/gd_mono_assembly.cpp index 0f211eebc6..ca84338666 100644 --- a/modules/mono/mono_gd/gd_mono_assembly.cpp +++ b/modules/mono/mono_gd/gd_mono_assembly.cpp @@ -33,6 +33,7 @@ #include <mono/metadata/mono-debug.h> #include <mono/metadata/tokentype.h> +#include "core/io/file_access_pack.h" #include "core/list.h" #include "core/os/file_access.h" #include "core/os/os.h" @@ -99,7 +100,7 @@ void GDMonoAssembly::fill_search_dirs(Vector<String> &r_search_dirs, const Strin // - The 'load' hook is called after the assembly has been loaded. Its job is to add the // assembly to the list of loaded assemblies so that the 'search' hook can look it up. -void GDMonoAssembly::assembly_load_hook(MonoAssembly *assembly, void *user_data) { +void GDMonoAssembly::assembly_load_hook(MonoAssembly *assembly, [[maybe_unused]] void *user_data) { String name = String::utf8(mono_assembly_name_get_name(mono_assembly_get_name(assembly))); @@ -133,9 +134,7 @@ MonoAssembly *GDMonoAssembly::assembly_refonly_preload_hook(MonoAssemblyName *an return GDMonoAssembly::_preload_hook(aname, assemblies_path, user_data, true); } -MonoAssembly *GDMonoAssembly::_search_hook(MonoAssemblyName *aname, void *user_data, bool refonly) { - - (void)user_data; // UNUSED +MonoAssembly *GDMonoAssembly::_search_hook(MonoAssemblyName *aname, [[maybe_unused]] void *user_data, bool refonly) { String name = String::utf8(mono_assembly_name_get_name(aname)); bool has_extension = name.ends_with(".dll") || name.ends_with(".exe"); @@ -147,15 +146,13 @@ MonoAssembly *GDMonoAssembly::_search_hook(MonoAssemblyName *aname, void *user_d return nullptr; } -MonoAssembly *GDMonoAssembly::_preload_hook(MonoAssemblyName *aname, char **, void *user_data, bool refonly) { - - (void)user_data; // UNUSED +MonoAssembly *GDMonoAssembly::_preload_hook(MonoAssemblyName *aname, char **, [[maybe_unused]] void *user_data, bool refonly) { String name = String::utf8(mono_assembly_name_get_name(aname)); - return _load_assembly_search(name, search_dirs, refonly); + return _load_assembly_search(name, aname, refonly, search_dirs); } -MonoAssembly *GDMonoAssembly::_load_assembly_search(const String &p_name, const Vector<String> &p_search_dirs, bool p_refonly) { +MonoAssembly *GDMonoAssembly::_load_assembly_search(const String &p_name, MonoAssemblyName *p_aname, bool p_refonly, const Vector<String> &p_search_dirs) { MonoAssembly *res = nullptr; String path; @@ -168,21 +165,21 @@ MonoAssembly *GDMonoAssembly::_load_assembly_search(const String &p_name, const if (has_extension) { path = search_dir.plus_file(p_name); if (FileAccess::exists(path)) { - res = _real_load_assembly_from(path, p_refonly); + res = _real_load_assembly_from(path, p_refonly, p_aname); if (res != nullptr) return res; } } else { path = search_dir.plus_file(p_name + ".dll"); if (FileAccess::exists(path)) { - res = _real_load_assembly_from(path, p_refonly); + res = _real_load_assembly_from(path, p_refonly, p_aname); if (res != nullptr) return res; } path = search_dir.plus_file(p_name + ".exe"); if (FileAccess::exists(path)) { - res = _real_load_assembly_from(path, p_refonly); + res = _real_load_assembly_from(path, p_refonly, p_aname); if (res != nullptr) return res; } @@ -230,7 +227,7 @@ void GDMonoAssembly::initialize() { mono_install_assembly_load_hook(&assembly_load_hook, nullptr); } -MonoAssembly *GDMonoAssembly::_real_load_assembly_from(const String &p_path, bool p_refonly) { +MonoAssembly *GDMonoAssembly::_real_load_assembly_from(const String &p_path, bool p_refonly, MonoAssemblyName *p_aname) { Vector<uint8_t> data = FileAccess::get_file_as_array(p_path); ERR_FAIL_COND_V_MSG(data.empty(), nullptr, "Could read the assembly in the specified location"); @@ -255,7 +252,33 @@ MonoAssembly *GDMonoAssembly::_real_load_assembly_from(const String &p_path, boo true, &status, p_refonly, image_filename.utf8()); - ERR_FAIL_COND_V_MSG(status != MONO_IMAGE_OK || !image, nullptr, "Failed to open assembly image from the loaded data"); + ERR_FAIL_COND_V_MSG(status != MONO_IMAGE_OK || !image, nullptr, "Failed to open assembly image from memory: '" + p_path + "'."); + + if (p_aname != nullptr) { + // Check assembly version + const MonoTableInfo *table = mono_image_get_table_info(image, MONO_TABLE_ASSEMBLY); + + ERR_FAIL_NULL_V(table, nullptr); + + if (mono_table_info_get_rows(table)) { + uint32_t cols[MONO_ASSEMBLY_SIZE]; + mono_metadata_decode_row(table, 0, cols, MONO_ASSEMBLY_SIZE); + + // Not sure about .NET's policy. We will only ensure major and minor are equal, and ignore build and revision. + uint16_t major = cols[MONO_ASSEMBLY_MAJOR_VERSION]; + uint16_t minor = cols[MONO_ASSEMBLY_MINOR_VERSION]; + + uint16_t required_minor; + uint16_t required_major = mono_assembly_name_get_version(p_aname, &required_minor, nullptr, nullptr); + + if (required_major != 0) { + if (major != required_major && minor != required_minor) { + mono_image_close(image); + return nullptr; + } + } + } + } #ifdef DEBUG_ENABLED Vector<uint8_t> pdb_data; @@ -425,6 +448,26 @@ GDMonoClass *GDMonoAssembly::get_object_derived_class(const StringName &p_class) return match; } +GDMonoAssembly *GDMonoAssembly::load(const String &p_name, MonoAssemblyName *p_aname, bool p_refonly, const Vector<String> &p_search_dirs) { + + if (GDMono::get_singleton()->get_corlib_assembly() && (p_name == "mscorlib" || p_name == "mscorlib.dll")) + return GDMono::get_singleton()->get_corlib_assembly(); + + // We need to manually call the search hook in this case, as it won't be called in the next step + MonoAssembly *assembly = mono_assembly_invoke_search_hook(p_aname); + + if (!assembly) { + assembly = _load_assembly_search(p_name, p_aname, p_refonly, p_search_dirs); + ERR_FAIL_NULL_V(assembly, nullptr); + } + + GDMonoAssembly *loaded_asm = GDMono::get_singleton()->get_loaded_assembly(p_name); + ERR_FAIL_NULL_V_MSG(loaded_asm, nullptr, "Loaded assembly missing from table. Did we not receive the load hook?"); + ERR_FAIL_COND_V(loaded_asm->get_assembly() != assembly, nullptr); + + return loaded_asm; +} + GDMonoAssembly *GDMonoAssembly::load_from(const String &p_name, const String &p_path, bool p_refonly) { if (p_name == "mscorlib" || p_name == "mscorlib.dll") @@ -447,18 +490,7 @@ GDMonoAssembly *GDMonoAssembly::load_from(const String &p_name, const String &p_ return loaded_asm; } -GDMonoAssembly::GDMonoAssembly(const String &p_name, MonoImage *p_image, MonoAssembly *p_assembly) : - name(p_name), - image(p_image), - assembly(p_assembly), -#ifdef GD_MONO_HOT_RELOAD - modified_time(0), -#endif - gdobject_class_cache_updated(false) { -} - GDMonoAssembly::~GDMonoAssembly() { - if (image) unload(); } diff --git a/modules/mono/mono_gd/gd_mono_assembly.h b/modules/mono/mono_gd/gd_mono_assembly.h index 43c8225b74..cc8d699558 100644 --- a/modules/mono/mono_gd/gd_mono_assembly.h +++ b/modules/mono/mono_gd/gd_mono_assembly.h @@ -73,10 +73,10 @@ class GDMonoAssembly { MonoAssembly *assembly; #ifdef GD_MONO_HOT_RELOAD - uint64_t modified_time; + uint64_t modified_time = 0; #endif - bool gdobject_class_cache_updated; + bool gdobject_class_cache_updated = false; Map<StringName, GDMonoClass *> gdobject_class_cache; HashMap<ClassKey, GDMonoClass *, ClassKey::Hasher> cached_classes; @@ -93,8 +93,8 @@ class GDMonoAssembly { static MonoAssembly *_search_hook(MonoAssemblyName *aname, void *user_data, bool refonly); static MonoAssembly *_preload_hook(MonoAssemblyName *aname, char **assemblies_path, void *user_data, bool refonly); - static MonoAssembly *_real_load_assembly_from(const String &p_path, bool p_refonly); - static MonoAssembly *_load_assembly_search(const String &p_name, const Vector<String> &p_search_dirs, bool p_refonly); + static MonoAssembly *_real_load_assembly_from(const String &p_path, bool p_refonly, MonoAssemblyName *p_aname = nullptr); + static MonoAssembly *_load_assembly_search(const String &p_name, MonoAssemblyName *p_aname, bool p_refonly, const Vector<String> &p_search_dirs); friend class GDMono; static void initialize(); @@ -120,10 +120,16 @@ public: static String find_assembly(const String &p_name); static void fill_search_dirs(Vector<String> &r_search_dirs, const String &p_custom_config = String(), const String &p_custom_bcl_dir = String()); + static const Vector<String> &get_default_search_dirs() { return search_dirs; } + static GDMonoAssembly *load(const String &p_name, MonoAssemblyName *p_aname, bool p_refonly, const Vector<String> &p_search_dirs); static GDMonoAssembly *load_from(const String &p_name, const String &p_path, bool p_refonly); - GDMonoAssembly(const String &p_name, MonoImage *p_image, MonoAssembly *p_assembly); + GDMonoAssembly(const String &p_name, MonoImage *p_image, MonoAssembly *p_assembly) : + name(p_name), + image(p_image), + assembly(p_assembly) { + } ~GDMonoAssembly(); }; diff --git a/modules/mono/mono_gd/gd_mono_field.cpp b/modules/mono/mono_gd/gd_mono_field.cpp index e76cb84d43..948170f51c 100644 --- a/modules/mono/mono_gd/gd_mono_field.cpp +++ b/modules/mono/mono_gd/gd_mono_field.cpp @@ -501,7 +501,8 @@ void GDMonoField::set_value_from_variant(MonoObject *p_object, const Variant &p_ case Variant::PACKED_COLOR_ARRAY: { SET_FROM_ARRAY(PackedColorArray); } break; - default: break; + default: + break; } } break; diff --git a/modules/mono/mono_gd/gd_mono_log.cpp b/modules/mono/mono_gd/gd_mono_log.cpp index ca16c2b76a..b56350ae1b 100644 --- a/modules/mono/mono_gd/gd_mono_log.cpp +++ b/modules/mono/mono_gd/gd_mono_log.cpp @@ -175,7 +175,7 @@ void GDMonoLog::initialize() { log_level_id = get_log_level_id(log_level.get_data()); if (log_file) { - OS::get_singleton()->print("Mono: Logfile is: %s\n", log_file_path.utf8().get_data()); + OS::get_singleton()->print("Mono: Log file is: '%s'\n", log_file_path.utf8().get_data()); mono_trace_set_log_handler(mono_log_callback, this); } else { OS::get_singleton()->printerr("Mono: No log file, using default log handler\n"); diff --git a/modules/mono/mono_gd/gd_mono_method_thunk.h b/modules/mono/mono_gd/gd_mono_method_thunk.h index 0e05e974e9..82c6f32c81 100644 --- a/modules/mono/mono_gd/gd_mono_method_thunk.h +++ b/modules/mono/mono_gd/gd_mono_method_thunk.h @@ -50,7 +50,7 @@ struct GDMonoMethodThunk { typedef void(GD_MONO_STDCALL *M)(ParamTypes... p_args, MonoException **); - M mono_method_thunk; + M mono_method_thunk = nullptr; public: _FORCE_INLINE_ void invoke(ParamTypes... p_args, MonoException **r_exc) { @@ -81,9 +81,7 @@ public: mono_method_thunk = (M)mono_method_get_unmanaged_thunk(p_mono_method->get_mono_ptr()); } - GDMonoMethodThunk() : - mono_method_thunk(nullptr) { - } + GDMonoMethodThunk() {} explicit GDMonoMethodThunk(GDMonoMethod *p_mono_method) { set_from_method(p_mono_method); @@ -95,7 +93,7 @@ struct GDMonoMethodThunkR { typedef R(GD_MONO_STDCALL *M)(ParamTypes... p_args, MonoException **); - M mono_method_thunk; + M mono_method_thunk = nullptr; public: _FORCE_INLINE_ R invoke(ParamTypes... p_args, MonoException **r_exc) { @@ -127,9 +125,7 @@ public: mono_method_thunk = (M)mono_method_get_unmanaged_thunk(p_mono_method->get_mono_ptr()); } - GDMonoMethodThunkR() : - mono_method_thunk(nullptr) { - } + GDMonoMethodThunkR() {} explicit GDMonoMethodThunkR(GDMonoMethod *p_mono_method) { #ifdef DEBUG_ENABLED @@ -248,7 +244,7 @@ struct VariadicInvokeMonoMethodR<1, R, P1> { template <class... ParamTypes> struct GDMonoMethodThunk { - GDMonoMethod *mono_method; + GDMonoMethod *mono_method = nullptr; public: _FORCE_INLINE_ void invoke(ParamTypes... p_args, MonoException **r_exc) { @@ -277,9 +273,7 @@ public: mono_method = p_mono_method; } - GDMonoMethodThunk() : - mono_method(nullptr) { - } + GDMonoMethodThunk() {} explicit GDMonoMethodThunk(GDMonoMethod *p_mono_method) { set_from_method(p_mono_method); @@ -289,7 +283,7 @@ public: template <class R, class... ParamTypes> struct GDMonoMethodThunkR { - GDMonoMethod *mono_method; + GDMonoMethod *mono_method = nullptr; public: _FORCE_INLINE_ R invoke(ParamTypes... p_args, MonoException **r_exc) { @@ -318,9 +312,7 @@ public: mono_method = p_mono_method; } - GDMonoMethodThunkR() : - mono_method(nullptr) { - } + GDMonoMethodThunkR() {} explicit GDMonoMethodThunkR(GDMonoMethod *p_mono_method) { set_from_method(p_mono_method); diff --git a/modules/mono/mono_gd/gd_mono_utils.cpp b/modules/mono/mono_gd/gd_mono_utils.cpp index f9d492dabb..a2ae42ae9f 100644 --- a/modules/mono/mono_gd/gd_mono_utils.cpp +++ b/modules/mono/mono_gd/gd_mono_utils.cpp @@ -564,9 +564,10 @@ namespace Marshal { #ifdef MONO_GLUE_ENABLED #ifdef TOOLS_ENABLED -#define NO_GLUE_RET(m_ret) \ - { \ - if (!GDMonoCache::cached_data.godot_api_cache_updated) return m_ret; \ +#define NO_GLUE_RET(m_ret) \ + { \ + if (!GDMonoCache::cached_data.godot_api_cache_updated) \ + return m_ret; \ } #else #define NO_GLUE_RET(m_ret) \ @@ -663,8 +664,7 @@ GDMonoClass *make_generic_dictionary_type(MonoReflectionType *p_key_reftype, Mon } // namespace Marshal -ScopeThreadAttach::ScopeThreadAttach() : - mono_thread(nullptr) { +ScopeThreadAttach::ScopeThreadAttach() { if (likely(GDMono::get_singleton()->is_runtime_initialized()) && unlikely(!mono_domain_get())) { mono_thread = GDMonoUtils::attach_current_thread(); } diff --git a/modules/mono/mono_gd/gd_mono_utils.h b/modules/mono/mono_gd/gd_mono_utils.h index e3011ade5d..caf0c792b7 100644 --- a/modules/mono/mono_gd/gd_mono_utils.h +++ b/modules/mono/mono_gd/gd_mono_utils.h @@ -155,7 +155,7 @@ struct ScopeThreadAttach { ~ScopeThreadAttach(); private: - MonoThread *mono_thread; + MonoThread *mono_thread = nullptr; }; StringName get_native_godot_class_name(GDMonoClass *p_class); diff --git a/modules/mono/mono_gd/managed_type.h b/modules/mono/mono_gd/managed_type.h index 84d1837853..491a2f3d20 100644 --- a/modules/mono/mono_gd/managed_type.h +++ b/modules/mono/mono_gd/managed_type.h @@ -36,18 +36,15 @@ #include "gd_mono_header.h" struct ManagedType { - int type_encoding; - GDMonoClass *type_class; + int type_encoding = 0; + GDMonoClass *type_class = nullptr; static ManagedType from_class(GDMonoClass *p_class); static ManagedType from_class(MonoClass *p_mono_class); static ManagedType from_type(MonoType *p_mono_type); static ManagedType from_reftype(MonoReflectionType *p_mono_reftype); - ManagedType() : - type_encoding(0), - type_class(nullptr) { - } + ManagedType() {} ManagedType(int p_type_encoding, GDMonoClass *p_type_class) : type_encoding(p_type_encoding), diff --git a/modules/mono/utils/macros.h b/modules/mono/utils/macros.h index 8650d6cc09..dc542477f5 100644 --- a/modules/mono/utils/macros.h +++ b/modules/mono/utils/macros.h @@ -68,6 +68,6 @@ public: } // namespace gdmono #define SCOPE_EXIT \ - auto GD_UNIQUE_NAME(gd_scope_exit) = gdmono::ScopeExitAux() + [=]() + auto GD_UNIQUE_NAME(gd_scope_exit) = gdmono::ScopeExitAux() + [=]() -> void #endif // UTIL_MACROS_H diff --git a/modules/opensimplex/noise_texture.cpp b/modules/opensimplex/noise_texture.cpp index 2018f90e9f..16cd04b044 100644 --- a/modules/opensimplex/noise_texture.cpp +++ b/modules/opensimplex/noise_texture.cpp @@ -202,19 +202,22 @@ Ref<OpenSimplexNoise> NoiseTexture::get_noise() { } void NoiseTexture::set_width(int p_width) { - if (p_width == size.x) return; + if (p_width == size.x) + return; size.x = p_width; _queue_update(); } void NoiseTexture::set_height(int p_height) { - if (p_height == size.y) return; + if (p_height == size.y) + return; size.y = p_height; _queue_update(); } void NoiseTexture::set_seamless(bool p_seamless) { - if (p_seamless == seamless) return; + if (p_seamless == seamless) + return; seamless = p_seamless; _queue_update(); } @@ -224,7 +227,8 @@ bool NoiseTexture::get_seamless() { } void NoiseTexture::set_as_normalmap(bool p_as_normalmap) { - if (p_as_normalmap == as_normalmap) return; + if (p_as_normalmap == as_normalmap) + return; as_normalmap = p_as_normalmap; _queue_update(); _change_notify(); @@ -236,7 +240,8 @@ bool NoiseTexture::is_normalmap() { void NoiseTexture::set_bump_strength(float p_bump_strength) { - if (p_bump_strength == bump_strength) return; + if (p_bump_strength == bump_strength) + return; bump_strength = p_bump_strength; if (as_normalmap) _queue_update(); diff --git a/modules/opensimplex/open_simplex_noise.cpp b/modules/opensimplex/open_simplex_noise.cpp index 238faa4130..205c033614 100644 --- a/modules/opensimplex/open_simplex_noise.cpp +++ b/modules/opensimplex/open_simplex_noise.cpp @@ -70,7 +70,8 @@ int OpenSimplexNoise::get_seed() { } void OpenSimplexNoise::set_octaves(int p_octaves) { - if (p_octaves == octaves) return; + if (p_octaves == octaves) + return; ERR_FAIL_COND_MSG(p_octaves > MAX_OCTAVES, vformat("The number of OpenSimplexNoise octaves is limited to %d; ignoring the new value.", MAX_OCTAVES)); @@ -79,19 +80,22 @@ void OpenSimplexNoise::set_octaves(int p_octaves) { } void OpenSimplexNoise::set_period(float p_period) { - if (p_period == period) return; + if (p_period == period) + return; period = p_period; emit_changed(); } void OpenSimplexNoise::set_persistence(float p_persistence) { - if (p_persistence == persistence) return; + if (p_persistence == persistence) + return; persistence = p_persistence; emit_changed(); } void OpenSimplexNoise::set_lacunarity(float p_lacunarity) { - if (p_lacunarity == lacunarity) return; + if (p_lacunarity == lacunarity) + return; lacunarity = p_lacunarity; emit_changed(); } diff --git a/modules/pvr/texture_loader_pvr.cpp b/modules/pvr/texture_loader_pvr.cpp index a8e8a9a2f1..d498c6d858 100644 --- a/modules/pvr/texture_loader_pvr.cpp +++ b/modules/pvr/texture_loader_pvr.cpp @@ -111,9 +111,13 @@ RES ResourceFormatPVR::load(const String &p_path, const String &p_original_path, switch (flags & 0xFF) { case 0x18: - case 0xC: format = (flags & PVR_HAS_ALPHA) ? Image::FORMAT_PVRTC2A : Image::FORMAT_PVRTC2; break; + case 0xC: + format = (flags & PVR_HAS_ALPHA) ? Image::FORMAT_PVRTC2A : Image::FORMAT_PVRTC2; + break; case 0x19: - case 0xD: format = (flags & PVR_HAS_ALPHA) ? Image::FORMAT_PVRTC4A : Image::FORMAT_PVRTC4; break; + case 0xD: + format = (flags & PVR_HAS_ALPHA) ? Image::FORMAT_PVRTC4A : Image::FORMAT_PVRTC4; + break; case 0x16: format = Image::FORMAT_L8; break; @@ -257,9 +261,9 @@ struct PVRTCBlock { _FORCE_INLINE_ bool is_po2(uint32_t p_input) { if (p_input == 0) - return 0; + return false; uint32_t minus1 = p_input - 1; - return ((p_input | minus1) == (p_input ^ minus1)) ? 1 : 0; + return ((p_input | minus1) == (p_input ^ minus1)) ? true : false; } static void unpack_5554(const PVRTCBlock *p_block, int p_ab_colors[2][4]) { diff --git a/modules/tga/image_loader_tga.cpp b/modules/tga/image_loader_tga.cpp index fc9d727bb0..eb9f23d894 100644 --- a/modules/tga/image_loader_tga.cpp +++ b/modules/tga/image_loader_tga.cpp @@ -199,7 +199,7 @@ Error ImageLoaderTGA::convert_to_image(Ref<Image> p_image, const uint8_t *p_buff } } - p_image->create(width, height, 0, Image::FORMAT_RGBA8, image_data); + p_image->create(width, height, false, Image::FORMAT_RGBA8, image_data); return OK; } diff --git a/modules/theora/video_stream_theora.cpp b/modules/theora/video_stream_theora.cpp index b9f276fb12..3163b60bb8 100644 --- a/modules/theora/video_stream_theora.cpp +++ b/modules/theora/video_stream_theora.cpp @@ -80,7 +80,7 @@ int VideoStreamPlaybackTheora::queue_page(ogg_page *page) { return 0; } -void VideoStreamPlaybackTheora::video_write(void) { +void VideoStreamPlaybackTheora::video_write() { th_ycbcr_buffer yuv; th_decode_ycbcr_out(td, yuv); @@ -209,7 +209,8 @@ void VideoStreamPlaybackTheora::set_file(const String &p_file) { while (!stateflag) { int ret = buffer_data(); - if (ret == 0) break; + if (ret == 0) + break; while (ogg_sync_pageout(&oy, &og) > 0) { ogg_stream_state test; @@ -286,7 +287,8 @@ void VideoStreamPlaybackTheora::set_file(const String &p_file) { return; } vorbis_p++; - if (vorbis_p == 3) break; + if (vorbis_p == 3) + break; } /* The header pages/packets will arrive before anything else we diff --git a/modules/theora/video_stream_theora.h b/modules/theora/video_stream_theora.h index f98368ed8b..d5a2640794 100644 --- a/modules/theora/video_stream_theora.h +++ b/modules/theora/video_stream_theora.h @@ -63,7 +63,7 @@ class VideoStreamPlaybackTheora : public VideoStreamPlayback { int buffer_data(); int queue_page(ogg_page *page); - void video_write(void); + void video_write(); float get_time() const; bool theora_eos; diff --git a/modules/tinyexr/image_saver_tinyexr.cpp b/modules/tinyexr/image_saver_tinyexr.cpp index 05080289bd..bc30f4e4fd 100644 --- a/modules/tinyexr/image_saver_tinyexr.cpp +++ b/modules/tinyexr/image_saver_tinyexr.cpp @@ -267,13 +267,21 @@ Error save_exr(const String &p_path, const Ref<Image> &p_img, bool p_grayscale) header.channels = channel_infos; header.pixel_types = pixel_types; header.requested_pixel_types = requested_pixel_types; + header.compression_type = TINYEXR_COMPRESSIONTYPE_PIZ; - CharString utf8_filename = p_path.utf8(); - const char *err; - int ret = SaveEXRImageToFile(&image, &header, utf8_filename.ptr(), &err); - if (ret != TINYEXR_SUCCESS) { + unsigned char *mem = nullptr; + const char *err = nullptr; + + size_t bytes = SaveEXRImageToMemory(&image, &header, &mem, &err); + + if (bytes == 0) { print_error(String("Saving EXR failed. Error: {0}").format(varray(err))); return ERR_FILE_CANT_WRITE; + } else { + FileAccessRef ref = FileAccess::open(p_path, FileAccess::WRITE); + ERR_FAIL_COND_V(!ref, ERR_FILE_CANT_WRITE); + ref->store_buffer(mem, bytes); + free(mem); } return OK; diff --git a/modules/visual_script/visual_script_editor.cpp b/modules/visual_script/visual_script_editor.cpp index 9a4076bec4..3649486724 100644 --- a/modules/visual_script/visual_script_editor.cpp +++ b/modules/visual_script/visual_script_editor.cpp @@ -342,84 +342,212 @@ static Color _color_from_type(Variant::Type p_type, bool dark_theme = true) { Color color; if (dark_theme) switch (p_type) { - case Variant::NIL: color = Color(0.41, 0.93, 0.74); break; - - case Variant::BOOL: color = Color(0.55, 0.65, 0.94); break; - case Variant::INT: color = Color(0.49, 0.78, 0.94); break; - case Variant::FLOAT: color = Color(0.38, 0.85, 0.96); break; - case Variant::STRING: color = Color(0.42, 0.65, 0.93); break; - - case Variant::VECTOR2: color = Color(0.74, 0.57, 0.95); break; - case Variant::VECTOR2I: color = Color(0.74, 0.57, 0.95); break; - case Variant::RECT2: color = Color(0.95, 0.57, 0.65); break; - case Variant::RECT2I: color = Color(0.95, 0.57, 0.65); break; - case Variant::VECTOR3: color = Color(0.84, 0.49, 0.93); break; - case Variant::VECTOR3I: color = Color(0.84, 0.49, 0.93); break; - case Variant::TRANSFORM2D: color = Color(0.77, 0.93, 0.41); break; - case Variant::PLANE: color = Color(0.97, 0.44, 0.44); break; - case Variant::QUAT: color = Color(0.93, 0.41, 0.64); break; - case Variant::AABB: color = Color(0.93, 0.47, 0.57); break; - case Variant::BASIS: color = Color(0.89, 0.93, 0.41); break; - case Variant::TRANSFORM: color = Color(0.96, 0.66, 0.43); break; - - case Variant::COLOR: color = Color(0.62, 1.0, 0.44); break; - case Variant::NODE_PATH: color = Color(0.41, 0.58, 0.93); break; - case Variant::_RID: color = Color(0.41, 0.93, 0.6); break; - case Variant::OBJECT: color = Color(0.47, 0.95, 0.91); break; - case Variant::DICTIONARY: color = Color(0.47, 0.93, 0.69); break; - - case Variant::ARRAY: color = Color(0.88, 0.88, 0.88); break; - case Variant::PACKED_BYTE_ARRAY: color = Color(0.67, 0.96, 0.78); break; - case Variant::PACKED_INT32_ARRAY: color = Color(0.69, 0.86, 0.96); break; - case Variant::PACKED_FLOAT32_ARRAY: color = Color(0.59, 0.91, 0.97); break; - case Variant::PACKED_INT64_ARRAY: color = Color(0.69, 0.86, 0.96); break; - case Variant::PACKED_FLOAT64_ARRAY: color = Color(0.59, 0.91, 0.97); break; - case Variant::PACKED_STRING_ARRAY: color = Color(0.62, 0.77, 0.95); break; - case Variant::PACKED_VECTOR2_ARRAY: color = Color(0.82, 0.7, 0.96); break; - case Variant::PACKED_VECTOR3_ARRAY: color = Color(0.87, 0.61, 0.95); break; - case Variant::PACKED_COLOR_ARRAY: color = Color(0.91, 1.0, 0.59); break; + case Variant::NIL: + color = Color(0.41, 0.93, 0.74); + break; + + case Variant::BOOL: + color = Color(0.55, 0.65, 0.94); + break; + case Variant::INT: + color = Color(0.49, 0.78, 0.94); + break; + case Variant::FLOAT: + color = Color(0.38, 0.85, 0.96); + break; + case Variant::STRING: + color = Color(0.42, 0.65, 0.93); + break; + + case Variant::VECTOR2: + color = Color(0.74, 0.57, 0.95); + break; + case Variant::VECTOR2I: + color = Color(0.74, 0.57, 0.95); + break; + case Variant::RECT2: + color = Color(0.95, 0.57, 0.65); + break; + case Variant::RECT2I: + color = Color(0.95, 0.57, 0.65); + break; + case Variant::VECTOR3: + color = Color(0.84, 0.49, 0.93); + break; + case Variant::VECTOR3I: + color = Color(0.84, 0.49, 0.93); + break; + case Variant::TRANSFORM2D: + color = Color(0.77, 0.93, 0.41); + break; + case Variant::PLANE: + color = Color(0.97, 0.44, 0.44); + break; + case Variant::QUAT: + color = Color(0.93, 0.41, 0.64); + break; + case Variant::AABB: + color = Color(0.93, 0.47, 0.57); + break; + case Variant::BASIS: + color = Color(0.89, 0.93, 0.41); + break; + case Variant::TRANSFORM: + color = Color(0.96, 0.66, 0.43); + break; + + case Variant::COLOR: + color = Color(0.62, 1.0, 0.44); + break; + case Variant::NODE_PATH: + color = Color(0.41, 0.58, 0.93); + break; + case Variant::_RID: + color = Color(0.41, 0.93, 0.6); + break; + case Variant::OBJECT: + color = Color(0.47, 0.95, 0.91); + break; + case Variant::DICTIONARY: + color = Color(0.47, 0.93, 0.69); + break; + + case Variant::ARRAY: + color = Color(0.88, 0.88, 0.88); + break; + case Variant::PACKED_BYTE_ARRAY: + color = Color(0.67, 0.96, 0.78); + break; + case Variant::PACKED_INT32_ARRAY: + color = Color(0.69, 0.86, 0.96); + break; + case Variant::PACKED_FLOAT32_ARRAY: + color = Color(0.59, 0.91, 0.97); + break; + case Variant::PACKED_INT64_ARRAY: + color = Color(0.69, 0.86, 0.96); + break; + case Variant::PACKED_FLOAT64_ARRAY: + color = Color(0.59, 0.91, 0.97); + break; + case Variant::PACKED_STRING_ARRAY: + color = Color(0.62, 0.77, 0.95); + break; + case Variant::PACKED_VECTOR2_ARRAY: + color = Color(0.82, 0.7, 0.96); + break; + case Variant::PACKED_VECTOR3_ARRAY: + color = Color(0.87, 0.61, 0.95); + break; + case Variant::PACKED_COLOR_ARRAY: + color = Color(0.91, 1.0, 0.59); + break; default: color.set_hsv(p_type / float(Variant::VARIANT_MAX), 0.7, 0.7); } else switch (p_type) { - case Variant::NIL: color = Color(0.15, 0.89, 0.63); break; - - case Variant::BOOL: color = Color(0.43, 0.56, 0.92); break; - case Variant::INT: color = Color(0.31, 0.7, 0.91); break; - case Variant::FLOAT: color = Color(0.15, 0.8, 0.94); break; - case Variant::STRING: color = Color(0.27, 0.56, 0.91); break; - - case Variant::VECTOR2: color = Color(0.68, 0.46, 0.93); break; - case Variant::VECTOR2I: color = Color(0.68, 0.46, 0.93); break; - case Variant::RECT2: color = Color(0.93, 0.46, 0.56); break; - case Variant::RECT2I: color = Color(0.93, 0.46, 0.56); break; - case Variant::VECTOR3: color = Color(0.86, 0.42, 0.93); break; - case Variant::VECTOR3I: color = Color(0.86, 0.42, 0.93); break; - case Variant::TRANSFORM2D: color = Color(0.59, 0.81, 0.1); break; - case Variant::PLANE: color = Color(0.97, 0.44, 0.44); break; - case Variant::QUAT: color = Color(0.93, 0.41, 0.64); break; - case Variant::AABB: color = Color(0.93, 0.47, 0.57); break; - case Variant::BASIS: color = Color(0.7, 0.73, 0.1); break; - case Variant::TRANSFORM: color = Color(0.96, 0.56, 0.28); break; - - case Variant::COLOR: color = Color(0.24, 0.75, 0.0); break; - case Variant::NODE_PATH: color = Color(0.41, 0.58, 0.93); break; - case Variant::_RID: color = Color(0.17, 0.9, 0.45); break; - case Variant::OBJECT: color = Color(0.07, 0.84, 0.76); break; - case Variant::DICTIONARY: color = Color(0.34, 0.91, 0.62); break; - - case Variant::ARRAY: color = Color(0.45, 0.45, 0.45); break; - case Variant::PACKED_BYTE_ARRAY: color = Color(0.38, 0.92, 0.6); break; - case Variant::PACKED_INT32_ARRAY: color = Color(0.38, 0.73, 0.92); break; - case Variant::PACKED_FLOAT32_ARRAY: color = Color(0.25, 0.83, 0.95); break; - case Variant::PACKED_INT64_ARRAY: color = Color(0.38, 0.73, 0.92); break; - case Variant::PACKED_FLOAT64_ARRAY: color = Color(0.25, 0.83, 0.95); break; - case Variant::PACKED_STRING_ARRAY: color = Color(0.38, 0.62, 0.92); break; - case Variant::PACKED_VECTOR2_ARRAY: color = Color(0.62, 0.36, 0.92); break; - case Variant::PACKED_VECTOR3_ARRAY: color = Color(0.79, 0.35, 0.92); break; - case Variant::PACKED_COLOR_ARRAY: color = Color(0.57, 0.73, 0.0); break; + case Variant::NIL: + color = Color(0.15, 0.89, 0.63); + break; + + case Variant::BOOL: + color = Color(0.43, 0.56, 0.92); + break; + case Variant::INT: + color = Color(0.31, 0.7, 0.91); + break; + case Variant::FLOAT: + color = Color(0.15, 0.8, 0.94); + break; + case Variant::STRING: + color = Color(0.27, 0.56, 0.91); + break; + + case Variant::VECTOR2: + color = Color(0.68, 0.46, 0.93); + break; + case Variant::VECTOR2I: + color = Color(0.68, 0.46, 0.93); + break; + case Variant::RECT2: + color = Color(0.93, 0.46, 0.56); + break; + case Variant::RECT2I: + color = Color(0.93, 0.46, 0.56); + break; + case Variant::VECTOR3: + color = Color(0.86, 0.42, 0.93); + break; + case Variant::VECTOR3I: + color = Color(0.86, 0.42, 0.93); + break; + case Variant::TRANSFORM2D: + color = Color(0.59, 0.81, 0.1); + break; + case Variant::PLANE: + color = Color(0.97, 0.44, 0.44); + break; + case Variant::QUAT: + color = Color(0.93, 0.41, 0.64); + break; + case Variant::AABB: + color = Color(0.93, 0.47, 0.57); + break; + case Variant::BASIS: + color = Color(0.7, 0.73, 0.1); + break; + case Variant::TRANSFORM: + color = Color(0.96, 0.56, 0.28); + break; + + case Variant::COLOR: + color = Color(0.24, 0.75, 0.0); + break; + case Variant::NODE_PATH: + color = Color(0.41, 0.58, 0.93); + break; + case Variant::_RID: + color = Color(0.17, 0.9, 0.45); + break; + case Variant::OBJECT: + color = Color(0.07, 0.84, 0.76); + break; + case Variant::DICTIONARY: + color = Color(0.34, 0.91, 0.62); + break; + + case Variant::ARRAY: + color = Color(0.45, 0.45, 0.45); + break; + case Variant::PACKED_BYTE_ARRAY: + color = Color(0.38, 0.92, 0.6); + break; + case Variant::PACKED_INT32_ARRAY: + color = Color(0.38, 0.73, 0.92); + break; + case Variant::PACKED_FLOAT32_ARRAY: + color = Color(0.25, 0.83, 0.95); + break; + case Variant::PACKED_INT64_ARRAY: + color = Color(0.38, 0.73, 0.92); + break; + case Variant::PACKED_FLOAT64_ARRAY: + color = Color(0.25, 0.83, 0.95); + break; + case Variant::PACKED_STRING_ARRAY: + color = Color(0.38, 0.62, 0.92); + break; + case Variant::PACKED_VECTOR2_ARRAY: + color = Color(0.62, 0.36, 0.92); + break; + case Variant::PACKED_VECTOR3_ARRAY: + color = Color(0.79, 0.35, 0.92); + break; + case Variant::PACKED_COLOR_ARRAY: + color = Color(0.57, 0.73, 0.0); + break; default: color.set_hsv(p_type / float(Variant::VARIANT_MAX), 0.3, 0.3); diff --git a/modules/visual_script/visual_script_expression.cpp b/modules/visual_script/visual_script_expression.cpp index 71ed483d65..616a621845 100644 --- a/modules/visual_script/visual_script_expression.cpp +++ b/modules/visual_script/visual_script_expression.cpp @@ -386,11 +386,21 @@ Error VisualScriptExpression::_get_token(Token &r_token) { switch (next) { - case 'b': res = 8; break; - case 't': res = 9; break; - case 'n': res = 10; break; - case 'f': res = 12; break; - case 'r': res = 13; break; + case 'b': + res = 8; + break; + case 't': + res = 9; + break; + case 'n': + res = 10; + break; + case 'f': + res = 12; + break; + case 'r': + res = 13; + break; case 'u': { // hex number for (int j = 0; j < 4; j++) { @@ -1005,27 +1015,69 @@ VisualScriptExpression::ENode *VisualScriptExpression::_parse_expression() { Variant::Operator op = Variant::OP_MAX; switch (tk.type) { - case TK_OP_IN: op = Variant::OP_IN; break; - case TK_OP_EQUAL: op = Variant::OP_EQUAL; break; - case TK_OP_NOT_EQUAL: op = Variant::OP_NOT_EQUAL; break; - case TK_OP_LESS: op = Variant::OP_LESS; break; - case TK_OP_LESS_EQUAL: op = Variant::OP_LESS_EQUAL; break; - case TK_OP_GREATER: op = Variant::OP_GREATER; break; - case TK_OP_GREATER_EQUAL: op = Variant::OP_GREATER_EQUAL; break; - case TK_OP_AND: op = Variant::OP_AND; break; - case TK_OP_OR: op = Variant::OP_OR; break; - case TK_OP_NOT: op = Variant::OP_NOT; break; - case TK_OP_ADD: op = Variant::OP_ADD; break; - case TK_OP_SUB: op = Variant::OP_SUBTRACT; break; - case TK_OP_MUL: op = Variant::OP_MULTIPLY; break; - case TK_OP_DIV: op = Variant::OP_DIVIDE; break; - case TK_OP_MOD: op = Variant::OP_MODULE; break; - case TK_OP_SHIFT_LEFT: op = Variant::OP_SHIFT_LEFT; break; - case TK_OP_SHIFT_RIGHT: op = Variant::OP_SHIFT_RIGHT; break; - case TK_OP_BIT_AND: op = Variant::OP_BIT_AND; break; - case TK_OP_BIT_OR: op = Variant::OP_BIT_OR; break; - case TK_OP_BIT_XOR: op = Variant::OP_BIT_XOR; break; - case TK_OP_BIT_INVERT: op = Variant::OP_BIT_NEGATE; break; + case TK_OP_IN: + op = Variant::OP_IN; + break; + case TK_OP_EQUAL: + op = Variant::OP_EQUAL; + break; + case TK_OP_NOT_EQUAL: + op = Variant::OP_NOT_EQUAL; + break; + case TK_OP_LESS: + op = Variant::OP_LESS; + break; + case TK_OP_LESS_EQUAL: + op = Variant::OP_LESS_EQUAL; + break; + case TK_OP_GREATER: + op = Variant::OP_GREATER; + break; + case TK_OP_GREATER_EQUAL: + op = Variant::OP_GREATER_EQUAL; + break; + case TK_OP_AND: + op = Variant::OP_AND; + break; + case TK_OP_OR: + op = Variant::OP_OR; + break; + case TK_OP_NOT: + op = Variant::OP_NOT; + break; + case TK_OP_ADD: + op = Variant::OP_ADD; + break; + case TK_OP_SUB: + op = Variant::OP_SUBTRACT; + break; + case TK_OP_MUL: + op = Variant::OP_MULTIPLY; + break; + case TK_OP_DIV: + op = Variant::OP_DIVIDE; + break; + case TK_OP_MOD: + op = Variant::OP_MODULE; + break; + case TK_OP_SHIFT_LEFT: + op = Variant::OP_SHIFT_LEFT; + break; + case TK_OP_SHIFT_RIGHT: + op = Variant::OP_SHIFT_RIGHT; + break; + case TK_OP_BIT_AND: + op = Variant::OP_BIT_AND; + break; + case TK_OP_BIT_OR: + op = Variant::OP_BIT_OR; + break; + case TK_OP_BIT_XOR: + op = Variant::OP_BIT_XOR; + break; + case TK_OP_BIT_INVERT: + op = Variant::OP_BIT_NEGATE; + break; default: { }; } @@ -1074,36 +1126,74 @@ VisualScriptExpression::ENode *VisualScriptExpression::_parse_expression() { unary = true; break; - case Variant::OP_MULTIPLY: priority = 2; break; - case Variant::OP_DIVIDE: priority = 2; break; - case Variant::OP_MODULE: priority = 2; break; + case Variant::OP_MULTIPLY: + priority = 2; + break; + case Variant::OP_DIVIDE: + priority = 2; + break; + case Variant::OP_MODULE: + priority = 2; + break; - case Variant::OP_ADD: priority = 3; break; - case Variant::OP_SUBTRACT: priority = 3; break; + case Variant::OP_ADD: + priority = 3; + break; + case Variant::OP_SUBTRACT: + priority = 3; + break; - case Variant::OP_SHIFT_LEFT: priority = 4; break; - case Variant::OP_SHIFT_RIGHT: priority = 4; break; + case Variant::OP_SHIFT_LEFT: + priority = 4; + break; + case Variant::OP_SHIFT_RIGHT: + priority = 4; + break; - case Variant::OP_BIT_AND: priority = 5; break; - case Variant::OP_BIT_XOR: priority = 6; break; - case Variant::OP_BIT_OR: priority = 7; break; + case Variant::OP_BIT_AND: + priority = 5; + break; + case Variant::OP_BIT_XOR: + priority = 6; + break; + case Variant::OP_BIT_OR: + priority = 7; + break; - case Variant::OP_LESS: priority = 8; break; - case Variant::OP_LESS_EQUAL: priority = 8; break; - case Variant::OP_GREATER: priority = 8; break; - case Variant::OP_GREATER_EQUAL: priority = 8; break; + case Variant::OP_LESS: + priority = 8; + break; + case Variant::OP_LESS_EQUAL: + priority = 8; + break; + case Variant::OP_GREATER: + priority = 8; + break; + case Variant::OP_GREATER_EQUAL: + priority = 8; + break; - case Variant::OP_EQUAL: priority = 8; break; - case Variant::OP_NOT_EQUAL: priority = 8; break; + case Variant::OP_EQUAL: + priority = 8; + break; + case Variant::OP_NOT_EQUAL: + priority = 8; + break; - case Variant::OP_IN: priority = 10; break; + case Variant::OP_IN: + priority = 10; + break; case Variant::OP_NOT: priority = 11; unary = true; break; - case Variant::OP_AND: priority = 12; break; - case Variant::OP_OR: priority = 13; break; + case Variant::OP_AND: + priority = 12; + break; + case Variant::OP_OR: + priority = 13; + break; default: { _set_error("Parser bug, invalid operator in expression: " + itos(expression[i].op)); diff --git a/modules/visual_script/visual_script_property_selector.cpp b/modules/visual_script/visual_script_property_selector.cpp index f57853078d..b06cf513ba 100644 --- a/modules/visual_script/visual_script_property_selector.cpp +++ b/modules/visual_script/visual_script_property_selector.cpp @@ -172,7 +172,7 @@ void VisualScriptPropertySelector::_update_search() { item->set_metadata(0, F->get().name); item->set_icon(0, type_icons[F->get().type]); item->set_metadata(1, "get"); - item->set_collapsed(1); + item->set_collapsed(true); item->set_selectable(0, true); item->set_selectable(1, false); item->set_selectable(2, false); @@ -257,7 +257,7 @@ void VisualScriptPropertySelector::_update_search() { item->set_selectable(0, true); item->set_metadata(1, "method"); - item->set_collapsed(1); + item->set_collapsed(true); item->set_selectable(1, false); item->set_selectable(2, false); @@ -320,7 +320,7 @@ void VisualScriptPropertySelector::create_visualscript_item(const String &name, item->set_metadata(0, name); item->set_metadata(1, "action"); item->set_selectable(0, true); - item->set_collapsed(1); + item->set_collapsed(true); item->set_selectable(1, false); item->set_selectable(2, false); item->set_metadata(2, connecting); diff --git a/modules/visual_script/visual_script_yield_nodes.cpp b/modules/visual_script/visual_script_yield_nodes.cpp index b300aec385..2296745ad0 100644 --- a/modules/visual_script/visual_script_yield_nodes.cpp +++ b/modules/visual_script/visual_script_yield_nodes.cpp @@ -81,10 +81,18 @@ String VisualScriptYield::get_caption() const { String VisualScriptYield::get_text() const { switch (yield_mode) { - case YIELD_RETURN: return ""; break; - case YIELD_FRAME: return "Next Frame"; break; - case YIELD_PHYSICS_FRAME: return "Next Physics Frame"; break; - case YIELD_WAIT: return rtos(wait_time) + " sec(s)"; break; + case YIELD_RETURN: + return ""; + break; + case YIELD_FRAME: + return "Next Frame"; + break; + case YIELD_PHYSICS_FRAME: + return "Next Physics Frame"; + break; + case YIELD_WAIT: + return rtos(wait_time) + " sec(s)"; + break; } return String(); @@ -122,9 +130,15 @@ public: case VisualScriptYield::YIELD_RETURN: ret = STEP_EXIT_FUNCTION_BIT; break; //return the yield - case VisualScriptYield::YIELD_FRAME: state->connect_to_signal(tree, "idle_frame", Array()); break; - case VisualScriptYield::YIELD_PHYSICS_FRAME: state->connect_to_signal(tree, "physics_frame", Array()); break; - case VisualScriptYield::YIELD_WAIT: state->connect_to_signal(tree->create_timer(wait_time).ptr(), "timeout", Array()); break; + case VisualScriptYield::YIELD_FRAME: + state->connect_to_signal(tree, "idle_frame", Array()); + break; + case VisualScriptYield::YIELD_PHYSICS_FRAME: + state->connect_to_signal(tree, "physics_frame", Array()); + break; + case VisualScriptYield::YIELD_WAIT: + state->connect_to_signal(tree->create_timer(wait_time).ptr(), "timeout", Array()); + break; } *p_working_mem = state; diff --git a/modules/webm/video_stream_webm.cpp b/modules/webm/video_stream_webm.cpp index a2d0f78f5f..faf1f32124 100644 --- a/modules/webm/video_stream_webm.cpp +++ b/modules/webm/video_stream_webm.cpp @@ -95,26 +95,8 @@ private: /**/ VideoStreamPlaybackWebm::VideoStreamPlaybackWebm() : - audio_track(0), - webm(nullptr), - video(nullptr), - audio(nullptr), - video_frames(nullptr), - audio_frame(nullptr), - video_frames_pos(0), - video_frames_capacity(0), - num_decoded_samples(0), - samples_offset(-1), - mix_callback(nullptr), - mix_udata(nullptr), - playing(false), - paused(false), - delay_compensation(0.0), - time(0.0), - video_frame_delay(0.0), - video_pos(0.0), - texture(memnew(ImageTexture)), - pcm(nullptr) {} + + texture(memnew(ImageTexture)) {} VideoStreamPlaybackWebm::~VideoStreamPlaybackWebm() { delete_pointers(); @@ -438,8 +420,7 @@ void VideoStreamPlaybackWebm::delete_pointers() { /**/ -VideoStreamWebm::VideoStreamWebm() : - audio_track(0) {} +VideoStreamWebm::VideoStreamWebm() {} Ref<VideoStreamPlayback> VideoStreamWebm::instance_playback() { diff --git a/modules/webm/video_stream_webm.h b/modules/webm/video_stream_webm.h index 6677fb85aa..0a32dfc671 100644 --- a/modules/webm/video_stream_webm.h +++ b/modules/webm/video_stream_webm.h @@ -44,27 +44,27 @@ class VideoStreamPlaybackWebm : public VideoStreamPlayback { GDCLASS(VideoStreamPlaybackWebm, VideoStreamPlayback); String file_name; - int audio_track; + int audio_track = 0; - WebMDemuxer *webm; - VPXDecoder *video; - OpusVorbisDecoder *audio; + WebMDemuxer *webm = nullptr; + VPXDecoder *video = nullptr; + OpusVorbisDecoder *audio = nullptr; - WebMFrame **video_frames, *audio_frame; - int video_frames_pos, video_frames_capacity; + WebMFrame **video_frames = nullptr, *audio_frame = nullptr; + int video_frames_pos = 0, video_frames_capacity = 0; - int num_decoded_samples, samples_offset; - AudioMixCallback mix_callback; - void *mix_udata; + int num_decoded_samples = 0, samples_offset = -1; + AudioMixCallback mix_callback = nullptr; + void *mix_udata = nullptr; - bool playing, paused; - double delay_compensation; - double time, video_frame_delay, video_pos; + bool playing = false, paused = false; + double delay_compensation = 0.0; + double time = 0.0, video_frame_delay = 0.0, video_pos = 0.0; Vector<uint8_t> frame_data; Ref<ImageTexture> texture; - float *pcm; + float *pcm = nullptr; public: VideoStreamPlaybackWebm(); @@ -111,7 +111,7 @@ class VideoStreamWebm : public VideoStream { GDCLASS(VideoStreamWebm, VideoStream); String file; - int audio_track; + int audio_track = 0; protected: static void _bind_methods(); diff --git a/modules/webp/image_loader_webp.cpp b/modules/webp/image_loader_webp.cpp index 0998977bb4..6c778d2809 100644 --- a/modules/webp/image_loader_webp.cpp +++ b/modules/webp/image_loader_webp.cpp @@ -135,7 +135,7 @@ Error webp_load_image_from_buffer(Image *p_image, const uint8_t *p_buffer, int p ERR_FAIL_COND_V_MSG(errdec, ERR_FILE_CORRUPT, "Failed decoding WebP image."); - p_image->create(features.width, features.height, 0, features.has_alpha ? Image::FORMAT_RGBA8 : Image::FORMAT_RGB8, dst_image); + p_image->create(features.width, features.height, false, features.has_alpha ? Image::FORMAT_RGBA8 : Image::FORMAT_RGB8, dst_image); return OK; } diff --git a/modules/webrtc/webrtc_data_channel_js.cpp b/modules/webrtc/webrtc_data_channel_js.cpp index 1b360720a2..40c3f5801b 100644 --- a/modules/webrtc/webrtc_data_channel_js.cpp +++ b/modules/webrtc/webrtc_data_channel_js.cpp @@ -312,14 +312,14 @@ WebRTCDataChannelJS::WebRTCDataChannelJS(int js_id) { return; } var len = buffer.length*buffer.BYTES_PER_ELEMENT; - var out = Module._malloc(len); - Module.HEAPU8.set(buffer, out); + var out = _malloc(len); + HEAPU8.set(buffer, out); ccall("_emrtc_on_ch_message", "void", ["number", "number", "number", "number"], [c_ptr, out, len, is_string] ); - Module._free(out); + _free(out); } }, this, js_id); diff --git a/modules/webrtc/webrtc_multiplayer.cpp b/modules/webrtc/webrtc_multiplayer.cpp index 78a4d1e61a..f294733961 100644 --- a/modules/webrtc/webrtc_multiplayer.cpp +++ b/modules/webrtc/webrtc_multiplayer.cpp @@ -144,7 +144,8 @@ void WebRTCMultiplayer::poll() { void WebRTCMultiplayer::_find_next_peer() { Map<int, Ref<ConnectedPeer>>::Element *E = peer_map.find(next_packet_peer); - if (E) E = E->next(); + if (E) + E = E->next(); // After last. while (E) { for (List<Ref<WebRTCDataChannel>>::Element *F = E->get()->channels.front(); F; F = F->next()) { diff --git a/modules/websocket/editor_debugger_server_websocket.cpp b/modules/websocket/editor_debugger_server_websocket.cpp new file mode 100644 index 0000000000..0cf78eaa9e --- /dev/null +++ b/modules/websocket/editor_debugger_server_websocket.cpp @@ -0,0 +1,92 @@ +/*************************************************************************/ +/* script_editor_debugger_websocket.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2019 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2019 Godot Engine contributors (cf. AUTHORS.md) */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "editor_debugger_server_websocket.h" + +#include "core/project_settings.h" +#include "editor/editor_settings.h" +#include "modules/websocket/remote_debugger_peer_websocket.h" + +void EditorDebuggerServerWebSocket::_peer_connected(int p_id, String _protocol) { + pending_peers.push_back(p_id); +} + +void EditorDebuggerServerWebSocket::_peer_disconnected(int p_id, bool p_was_clean) { + if (pending_peers.find(p_id)) + pending_peers.erase(p_id); +} + +void EditorDebuggerServerWebSocket::poll() { + server->poll(); +} + +Error EditorDebuggerServerWebSocket::start() { + int remote_port = (int)EditorSettings::get_singleton()->get("network/debug/remote_port"); + Vector<String> protocols; + protocols.push_back("binary"); // compatibility with EMSCRIPTEN TCP-to-WebSocket layer. + return server->listen(remote_port, protocols); +} + +void EditorDebuggerServerWebSocket::stop() { + server->stop(); + pending_peers.clear(); +} + +bool EditorDebuggerServerWebSocket::is_active() const { + return server->is_listening(); +} + +bool EditorDebuggerServerWebSocket::is_connection_available() const { + return pending_peers.size() > 0; +} + +Ref<RemoteDebuggerPeer> EditorDebuggerServerWebSocket::take_connection() { + ERR_FAIL_COND_V(!is_connection_available(), Ref<RemoteDebuggerPeer>()); + RemoteDebuggerPeer *peer = memnew(RemoteDebuggerPeerWebSocket(server->get_peer(pending_peers[0]))); + pending_peers.pop_front(); + return peer; +} + +EditorDebuggerServerWebSocket::EditorDebuggerServerWebSocket() { + server = Ref<WebSocketServer>(WebSocketServer::create()); + int max_pkts = (int)GLOBAL_GET("network/limits/debugger/max_queued_messages"); + server->set_buffers(8192, max_pkts, 8192, max_pkts); + server->connect("client_connected", callable_mp(this, &EditorDebuggerServerWebSocket::_peer_connected)); + server->connect("client_disconnected", callable_mp(this, &EditorDebuggerServerWebSocket::_peer_disconnected)); +} + +EditorDebuggerServerWebSocket::~EditorDebuggerServerWebSocket() { + stop(); +} + +EditorDebuggerServer *EditorDebuggerServerWebSocket::create(const String &p_protocol) { + ERR_FAIL_COND_V(p_protocol != "ws://", nullptr); + return memnew(EditorDebuggerServerWebSocket); +} diff --git a/modules/websocket/editor_debugger_server_websocket.h b/modules/websocket/editor_debugger_server_websocket.h new file mode 100644 index 0000000000..81a31d8364 --- /dev/null +++ b/modules/websocket/editor_debugger_server_websocket.h @@ -0,0 +1,63 @@ +/*************************************************************************/ +/* script_editor_debugger_websocket.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2019 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2019 Godot Engine contributors (cf. AUTHORS.md) */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef SCRIPT_EDITOR_DEBUGGER_WEBSOCKET_H +#define SCRIPT_EDITOR_DEBUGGER_WEBSOCKET_H + +#include "modules/websocket/websocket_server.h" + +#include "editor/debugger/editor_debugger_server.h" + +class EditorDebuggerServerWebSocket : public EditorDebuggerServer { + + GDCLASS(EditorDebuggerServerWebSocket, EditorDebuggerServer); + +private: + Ref<WebSocketServer> server; + List<int> pending_peers; + +public: + static EditorDebuggerServer *create(const String &p_protocol); + + void _peer_connected(int p_peer, String p_protocol); + void _peer_disconnected(int p_peer, bool p_was_clean); + + void poll(); + Error start(); + void stop(); + bool is_active() const; + bool is_connection_available() const; + Ref<RemoteDebuggerPeer> take_connection(); + + EditorDebuggerServerWebSocket(); + ~EditorDebuggerServerWebSocket(); +}; + +#endif // SCRIPT_EDITOR_DEBUGGER_WEBSOCKET_H diff --git a/modules/websocket/emws_client.cpp b/modules/websocket/emws_client.cpp index bbe4d6dc5b..bc9d75d327 100644 --- a/modules/websocket/emws_client.cpp +++ b/modules/websocket/emws_client.cpp @@ -142,14 +142,14 @@ Error EMWSClient::connect_to_host(String p_host, String p_path, uint16_t p_port, } var len = buffer.length*buffer.BYTES_PER_ELEMENT; - var out = Module._malloc(len); - Module.HEAPU8.set(buffer, out); + var out = _malloc(len); + HEAPU8.set(buffer, out); ccall("_esws_on_message", "void", ["number", "number", "number", "number"], [c_ptr, out, len, is_string] ); - Module._free(out); + _free(out); }); socket.addEventListener("error", function (event) { @@ -231,8 +231,9 @@ Error EMWSClient::set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffe } EMWSClient::EMWSClient() { - _in_buf_size = nearest_shift((int)GLOBAL_GET(WSC_IN_BUF) - 1) + 10; - _in_pkt_size = nearest_shift((int)GLOBAL_GET(WSC_IN_PKT) - 1); + _in_buf_size = DEF_BUF_SHIFT; + _in_pkt_size = DEF_PKT_SHIFT; + _is_connecting = false; _peer = Ref<EMWSPeer>(memnew(EMWSPeer)); /* clang-format off */ diff --git a/modules/websocket/register_types.cpp b/modules/websocket/register_types.cpp index e389d75ffd..1ad249e1eb 100644 --- a/modules/websocket/register_types.cpp +++ b/modules/websocket/register_types.cpp @@ -40,24 +40,12 @@ #include "wsl_client.h" #include "wsl_server.h" #endif +#ifdef TOOLS_ENABLED +#include "editor/debugger/editor_debugger_server.h" +#include "editor_debugger_server_websocket.h" +#endif void register_websocket_types() { -#define _SET_HINT(NAME, _VAL_, _MAX_) \ - GLOBAL_DEF(NAME, _VAL_); \ - ProjectSettings::get_singleton()->set_custom_property_info(NAME, PropertyInfo(Variant::INT, NAME, PROPERTY_HINT_RANGE, "2," #_MAX_ ",1,or_greater")); - - // Client buffers project settings - _SET_HINT(WSC_IN_BUF, 64, 4096); - _SET_HINT(WSC_IN_PKT, 1024, 16384); - _SET_HINT(WSC_OUT_BUF, 64, 4096); - _SET_HINT(WSC_OUT_PKT, 1024, 16384); - - // Server buffers project settings - _SET_HINT(WSS_IN_BUF, 64, 4096); - _SET_HINT(WSS_IN_PKT, 1024, 16384); - _SET_HINT(WSS_OUT_BUF, 64, 4096); - _SET_HINT(WSS_OUT_PKT, 1024, 16384); - #ifdef JAVASCRIPT_ENABLED EMWSPeer::make_default(); EMWSClient::make_default(); @@ -72,6 +60,10 @@ void register_websocket_types() { ClassDB::register_custom_instance_class<WebSocketServer>(); ClassDB::register_custom_instance_class<WebSocketClient>(); ClassDB::register_custom_instance_class<WebSocketPeer>(); + +#ifdef TOOLS_ENABLED + EditorDebuggerServer::register_protocol_handler("ws://", EditorDebuggerServerWebSocket::create); +#endif } void unregister_websocket_types() {} diff --git a/modules/websocket/remote_debugger_peer_websocket.cpp b/modules/websocket/remote_debugger_peer_websocket.cpp new file mode 100644 index 0000000000..d156a39f53 --- /dev/null +++ b/modules/websocket/remote_debugger_peer_websocket.cpp @@ -0,0 +1,134 @@ +/*************************************************************************/ +/* script_debugger_websocket.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2019 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2019 Godot Engine contributors (cf. AUTHORS.md) */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#include "remote_debugger_peer_websocket.h" + +#include "core/project_settings.h" + +Error RemoteDebuggerPeerWebSocket::connect_to_host(const String &p_uri) { + + Vector<String> protocols; + protocols.push_back("binary"); // Compatibility for emscripten TCP-to-WebSocket. + + ws_client->connect_to_url(p_uri, protocols); + ws_client->poll(); + + if (ws_client->get_connection_status() == WebSocketClient::CONNECTION_DISCONNECTED) { + + ERR_PRINT("Remote Debugger: Unable to connect. Status: " + String::num(ws_client->get_connection_status()) + "."); + return FAILED; + } + + ws_peer = ws_client->get_peer(1); + + return OK; +} + +bool RemoteDebuggerPeerWebSocket::is_peer_connected() { + return ws_peer.is_valid() && ws_peer->is_connected_to_host(); +} + +void RemoteDebuggerPeerWebSocket::poll() { + ws_client->poll(); + + while (ws_peer->is_connected_to_host() && ws_peer->get_available_packet_count() > 0 && in_queue.size() < max_queued_messages) { + Variant var; + Error err = ws_peer->get_var(var); + ERR_CONTINUE(err != OK); + ERR_CONTINUE(var.get_type() != Variant::ARRAY); + in_queue.push_back(var); + } + + while (ws_peer->is_connected_to_host() && out_queue.size() > 0) { + Array var = out_queue[0]; + Error err = ws_peer->put_var(var); + ERR_BREAK(err != OK); // Peer buffer full? + out_queue.pop_front(); + } +} + +int RemoteDebuggerPeerWebSocket::get_max_message_size() const { + return 8 << 20; // 8 Mib +} + +bool RemoteDebuggerPeerWebSocket::has_message() { + return in_queue.size() > 0; +} + +Array RemoteDebuggerPeerWebSocket::get_message() { + ERR_FAIL_COND_V(in_queue.size() < 1, Array()); + Array msg = in_queue[0]; + in_queue.pop_front(); + return msg; +} + +Error RemoteDebuggerPeerWebSocket::put_message(const Array &p_arr) { + if (out_queue.size() >= max_queued_messages) + return ERR_OUT_OF_MEMORY; + out_queue.push_back(p_arr); + return OK; +} + +void RemoteDebuggerPeerWebSocket::close() { + if (ws_peer.is_valid()) { + ws_peer.unref(); + } + ws_client->disconnect_from_host(); +} + +bool RemoteDebuggerPeerWebSocket::can_block() const { +#ifdef JAVASCRIPT_ENABLED + return false; +#else + return true; +#endif +} + +RemoteDebuggerPeerWebSocket::RemoteDebuggerPeerWebSocket(Ref<WebSocketPeer> p_peer) { +#ifdef JAVASCRIPT_ENABLED + ws_client = Ref<WebSocketClient>(memnew(EMWSClient)); +#else + ws_client = Ref<WebSocketClient>(memnew(WSLClient)); +#endif + max_queued_messages = (int)GLOBAL_GET("network/limits/debugger/max_queued_messages"); + ws_client->set_buffers(8192, max_queued_messages, 8192, max_queued_messages); + ws_peer = p_peer; +} + +RemoteDebuggerPeer *RemoteDebuggerPeerWebSocket::create(const String &p_uri) { + ERR_FAIL_COND_V(!p_uri.begins_with("ws://") && !p_uri.begins_with("wss://"), nullptr); + RemoteDebuggerPeerWebSocket *peer = memnew(RemoteDebuggerPeerWebSocket); + Error err = peer->connect_to_host(p_uri); + if (err != OK) { + memdelete(peer); + return nullptr; + } + return peer; +} diff --git a/modules/websocket/remote_debugger_peer_websocket.h b/modules/websocket/remote_debugger_peer_websocket.h new file mode 100644 index 0000000000..fe46533bed --- /dev/null +++ b/modules/websocket/remote_debugger_peer_websocket.h @@ -0,0 +1,66 @@ +/*************************************************************************/ +/* script_debugger_websocket.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2019 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2019 Godot Engine contributors (cf. AUTHORS.md) */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef SCRIPT_DEBUGGER_WEBSOCKET_H +#define SCRIPT_DEBUGGER_WEBSOCKET_H + +#ifdef JAVASCRIPT_ENABLED +#include "modules/websocket/emws_client.h" +#else +#include "modules/websocket/wsl_client.h" +#endif +#include "core/debugger/remote_debugger_peer.h" + +class RemoteDebuggerPeerWebSocket : public RemoteDebuggerPeer { + + Ref<WebSocketClient> ws_client; + Ref<WebSocketPeer> ws_peer; + List<Array> in_queue; + List<Array> out_queue; + + int max_queued_messages; + +public: + static RemoteDebuggerPeer *create(const String &p_uri); + + Error connect_to_host(const String &p_uri); + bool is_peer_connected(); + int get_max_message_size() const; + bool has_message(); + Error put_message(const Array &p_arr); + Array get_message(); + void close(); + void poll(); + bool can_block() const; + + RemoteDebuggerPeerWebSocket(Ref<WebSocketPeer> p_peer = Ref<WebSocketPeer>()); +}; + +#endif // SCRIPT_DEBUGGER_WEBSOCKET_H diff --git a/modules/websocket/websocket_macros.h b/modules/websocket/websocket_macros.h index f7eafcff1f..cf4545b435 100644 --- a/modules/websocket/websocket_macros.h +++ b/modules/websocket/websocket_macros.h @@ -31,15 +31,9 @@ #ifndef WEBSOCKETMACTOS_H #define WEBSOCKETMACTOS_H -#define WSC_IN_BUF "network/limits/websocket_client/max_in_buffer_kb" -#define WSC_IN_PKT "network/limits/websocket_client/max_in_packets" -#define WSC_OUT_BUF "network/limits/websocket_client/max_out_buffer_kb" -#define WSC_OUT_PKT "network/limits/websocket_client/max_out_packets" - -#define WSS_IN_BUF "network/limits/websocket_server/max_in_buffer_kb" -#define WSS_IN_PKT "network/limits/websocket_server/max_in_packets" -#define WSS_OUT_BUF "network/limits/websocket_server/max_out_buffer_kb" -#define WSS_OUT_PKT "network/limits/websocket_server/max_out_packets" +// Defaults per peer buffers, 1024 packets with a shared 65536 bytes payload. +#define DEF_PKT_SHIFT 10 +#define DEF_BUF_SHIFT 16 /* clang-format off */ #define GDCICLASS(CNAME) \ diff --git a/modules/websocket/wsl_client.cpp b/modules/websocket/wsl_client.cpp index 9f05500eb9..0eaafe2d66 100644 --- a/modules/websocket/wsl_client.cpp +++ b/modules/websocket/wsl_client.cpp @@ -336,10 +336,10 @@ Error WSLClient::set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer } WSLClient::WSLClient() { - _in_buf_size = nearest_shift((int)GLOBAL_GET(WSC_IN_BUF) - 1) + 10; - _in_pkt_size = nearest_shift((int)GLOBAL_GET(WSC_IN_PKT) - 1); - _out_buf_size = nearest_shift((int)GLOBAL_GET(WSC_OUT_BUF) - 1) + 10; - _out_pkt_size = nearest_shift((int)GLOBAL_GET(WSC_OUT_PKT) - 1); + _in_buf_size = DEF_BUF_SHIFT; + _in_pkt_size = DEF_PKT_SHIFT; + _out_buf_size = DEF_BUF_SHIFT; + _out_pkt_size = DEF_PKT_SHIFT; _peer.instance(); _tcp.instance(); diff --git a/modules/websocket/wsl_server.cpp b/modules/websocket/wsl_server.cpp index 6f155a6ffa..cc4685973e 100644 --- a/modules/websocket/wsl_server.cpp +++ b/modules/websocket/wsl_server.cpp @@ -298,10 +298,10 @@ Error WSLServer::set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer } WSLServer::WSLServer() { - _in_buf_size = nearest_shift((int)GLOBAL_GET(WSS_IN_BUF) - 1) + 10; - _in_pkt_size = nearest_shift((int)GLOBAL_GET(WSS_IN_PKT) - 1); - _out_buf_size = nearest_shift((int)GLOBAL_GET(WSS_OUT_BUF) - 1) + 10; - _out_pkt_size = nearest_shift((int)GLOBAL_GET(WSS_OUT_PKT) - 1); + _in_buf_size = DEF_BUF_SHIFT; + _in_pkt_size = DEF_PKT_SHIFT; + _out_buf_size = DEF_BUF_SHIFT; + _out_pkt_size = DEF_PKT_SHIFT; _server.instance(); } diff --git a/modules/xatlas_unwrap/register_types.cpp b/modules/xatlas_unwrap/register_types.cpp index 8c5525bed3..f77646ce28 100644 --- a/modules/xatlas_unwrap/register_types.cpp +++ b/modules/xatlas_unwrap/register_types.cpp @@ -137,6 +137,7 @@ bool xatlas_mesh_lightmap_unwrap_callback(float p_texel_size, const float *p_ver pack_options.maxChartSize = 4096; pack_options.blockAlign = true; + pack_options.padding = 1; pack_options.texelsPerUnit = 1.0 / p_texel_size; xatlas::Atlas *atlas = xatlas::Create(); diff --git a/platform/android/api/java_class_wrapper.h b/platform/android/api/java_class_wrapper.h index 59fcd94b4d..3dfdc75cf4 100644 --- a/platform/android/api/java_class_wrapper.h +++ b/platform/android/api/java_class_wrapper.h @@ -83,9 +83,13 @@ class JavaClass : public Reference { switch (p_sig) { - case ARG_TYPE_VOID: r_type = Variant::NIL; break; + case ARG_TYPE_VOID: + r_type = Variant::NIL; + break; case ARG_TYPE_BOOLEAN | ARG_NUMBER_CLASS_BIT: - case ARG_TYPE_BOOLEAN: r_type = Variant::BOOL; break; + case ARG_TYPE_BOOLEAN: + r_type = Variant::BOOL; + break; case ARG_TYPE_BYTE | ARG_NUMBER_CLASS_BIT: case ARG_TYPE_BYTE: r_type = Variant::INT; @@ -121,10 +125,18 @@ class JavaClass : public Reference { r_type = Variant::FLOAT; likelihood = 0.5; break; - case ARG_TYPE_STRING: r_type = Variant::STRING; break; - case ARG_TYPE_CLASS: r_type = Variant::OBJECT; break; - case ARG_ARRAY_BIT | ARG_TYPE_VOID: r_type = Variant::NIL; break; - case ARG_ARRAY_BIT | ARG_TYPE_BOOLEAN: r_type = Variant::ARRAY; break; + case ARG_TYPE_STRING: + r_type = Variant::STRING; + break; + case ARG_TYPE_CLASS: + r_type = Variant::OBJECT; + break; + case ARG_ARRAY_BIT | ARG_TYPE_VOID: + r_type = Variant::NIL; + break; + case ARG_ARRAY_BIT | ARG_TYPE_BOOLEAN: + r_type = Variant::ARRAY; + break; case ARG_ARRAY_BIT | ARG_TYPE_BYTE: r_type = Variant::PACKED_BYTE_ARRAY; likelihood = 1.0; @@ -153,8 +165,12 @@ class JavaClass : public Reference { r_type = Variant::PACKED_FLOAT32_ARRAY; likelihood = 0.5; break; - case ARG_ARRAY_BIT | ARG_TYPE_STRING: r_type = Variant::PACKED_STRING_ARRAY; break; - case ARG_ARRAY_BIT | ARG_TYPE_CLASS: r_type = Variant::ARRAY; break; + case ARG_ARRAY_BIT | ARG_TYPE_STRING: + r_type = Variant::PACKED_STRING_ARRAY; + break; + case ARG_ARRAY_BIT | ARG_TYPE_CLASS: + r_type = Variant::ARRAY; + break; } } diff --git a/platform/android/export/export.cpp b/platform/android/export/export.cpp index d4fc52eaa9..f346ca54d2 100644 --- a/platform/android/export/export.cpp +++ b/platform/android/export/export.cpp @@ -372,7 +372,8 @@ class EditorExportPlatformAndroid : public EditorExportPlatform { } d.name = vendor + " " + device; - if (device == String()) continue; + if (device == String()) + continue; } ndevices.push_back(d); diff --git a/platform/android/java/lib/src/org/godotengine/godot/Godot.java b/platform/android/java/lib/src/org/godotengine/godot/Godot.java index 957f6223a9..ffe5402a54 100644 --- a/platform/android/java/lib/src/org/godotengine/godot/Godot.java +++ b/platform/android/java/lib/src/org/godotengine/godot/Godot.java @@ -920,7 +920,8 @@ public abstract class Godot extends FragmentActivity implements SensorEventListe int cnt = 0; for (int i = cc.length; --i >= 0; cnt += cc[i] != 0 ? 1 : 0) ; - if (cnt == 0) return super.onKeyMultiple(inKeyCode, repeatCount, event); + if (cnt == 0) + return super.onKeyMultiple(inKeyCode, repeatCount, event); mRenderView.queueOnRenderThread(new Runnable() { // This method will be called on the rendering thread: public void run() { diff --git a/platform/android/java/lib/src/org/godotengine/godot/plugin/SignalInfo.java b/platform/android/java/lib/src/org/godotengine/godot/plugin/SignalInfo.java index b940d679f0..424251169b 100644 --- a/platform/android/java/lib/src/org/godotengine/godot/plugin/SignalInfo.java +++ b/platform/android/java/lib/src/org/godotengine/godot/plugin/SignalInfo.java @@ -5,8 +5,8 @@ /* GODOT ENGINE */ /* https://godotengine.org */ /*************************************************************************/ -/* Copyright (c) 2007-2019 Juan Linietsky, Ariel Manzur. */ -/* Copyright (c) 2014-2019 Godot Engine contributors (cf. AUTHORS.md) */ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ /* */ /* Permission is hereby granted, free of charge, to any person obtaining */ /* a copy of this software and associated documentation files (the */ diff --git a/platform/android/net_socket_android.cpp b/platform/android/net_socket_android.cpp index 320bdd3817..dc0154ea01 100644 --- a/platform/android/net_socket_android.cpp +++ b/platform/android/net_socket_android.cpp @@ -72,11 +72,6 @@ void NetSocketAndroid::make_default() { _create = _create_func; } -NetSocketAndroid::NetSocketAndroid() : - wants_broadcast(false), - multicast_groups(0) { -} - NetSocketAndroid::~NetSocketAndroid() { close(); } diff --git a/platform/android/net_socket_android.h b/platform/android/net_socket_android.h index 4fc80d2de1..7d6267e8b3 100644 --- a/platform/android/net_socket_android.h +++ b/platform/android/net_socket_android.h @@ -52,8 +52,8 @@ private: static jmethodID _multicast_lock_acquire; static jmethodID _multicast_lock_release; - bool wants_broadcast; - int multicast_groups; + bool wants_broadcast = false; + int multicast_groups = 0; static void multicast_lock_acquire(); static void multicast_lock_release(); @@ -71,7 +71,7 @@ public: virtual Error join_multicast_group(const IP_Address &p_multi_address, String p_if_name); virtual Error leave_multicast_group(const IP_Address &p_multi_address, String p_if_name); - NetSocketAndroid(); + NetSocketAndroid() {} ~NetSocketAndroid(); }; diff --git a/platform/haiku/key_mapping_haiku.h b/platform/haiku/key_mapping_haiku.h index a0e85e3390..e735108e44 100644 --- a/platform/haiku/key_mapping_haiku.h +++ b/platform/haiku/key_mapping_haiku.h @@ -32,7 +32,7 @@ #define KEY_MAPPING_HAIKU_H class KeyMappingHaiku { - KeyMappingHaiku(){}; + KeyMappingHaiku() {} public: static unsigned int get_keysym(int32 raw_char, int32 key); diff --git a/platform/iphone/app_delegate.mm b/platform/iphone/app_delegate.mm index 0ac8bb7a56..fb30441bd3 100644 --- a/platform/iphone/app_delegate.mm +++ b/platform/iphone/app_delegate.mm @@ -247,37 +247,31 @@ static void on_focus_in(ViewController *view_controller, bool *is_focus_out) { int joy_id = [self getJoyIdForController:controller]; if (element == gamepad.buttonA) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_0, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_A, gamepad.buttonA.isPressed); } else if (element == gamepad.buttonB) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_1, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_B, gamepad.buttonB.isPressed); } else if (element == gamepad.buttonX) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_2, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_X, gamepad.buttonX.isPressed); } else if (element == gamepad.buttonY) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_3, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_Y, gamepad.buttonY.isPressed); } else if (element == gamepad.leftShoulder) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_L, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_LEFT_SHOULDER, gamepad.leftShoulder.isPressed); } else if (element == gamepad.rightShoulder) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_R, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_RIGHT_SHOULDER, gamepad.rightShoulder.isPressed); - } else if (element == gamepad.leftTrigger) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_L2, - gamepad.leftTrigger.isPressed); - } else if (element == gamepad.rightTrigger) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_R2, - gamepad.rightTrigger.isPressed); } else if (element == gamepad.dpad) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_UP, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_UP, gamepad.dpad.up.isPressed); - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_DOWN, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_DOWN, gamepad.dpad.down.isPressed); - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_LEFT, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_LEFT, gamepad.dpad.left.isPressed); - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_RIGHT, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_RIGHT, gamepad.dpad.right.isPressed); }; @@ -285,20 +279,20 @@ static void on_focus_in(ViewController *view_controller, bool *is_focus_out) { jx.min = -1; if (element == gamepad.leftThumbstick) { jx.value = gamepad.leftThumbstick.xAxis.value; - OSIPhone::get_singleton()->joy_axis(joy_id, JOY_ANALOG_LX, jx); + OSIPhone::get_singleton()->joy_axis(joy_id, JOY_AXIS_LEFT_X, jx); jx.value = -gamepad.leftThumbstick.yAxis.value; - OSIPhone::get_singleton()->joy_axis(joy_id, JOY_ANALOG_LY, jx); + OSIPhone::get_singleton()->joy_axis(joy_id, JOY_AXIS_LEFT_Y, jx); } else if (element == gamepad.rightThumbstick) { jx.value = gamepad.rightThumbstick.xAxis.value; - OSIPhone::get_singleton()->joy_axis(joy_id, JOY_ANALOG_RX, jx); + OSIPhone::get_singleton()->joy_axis(joy_id, JOY_AXIS_RIGHT_X, jx); jx.value = -gamepad.rightThumbstick.yAxis.value; - OSIPhone::get_singleton()->joy_axis(joy_id, JOY_ANALOG_RY, jx); + OSIPhone::get_singleton()->joy_axis(joy_id, JOY_AXIS_RIGHT_Y, jx); } else if (element == gamepad.leftTrigger) { jx.value = gamepad.leftTrigger.value; - OSIPhone::get_singleton()->joy_axis(joy_id, JOY_ANALOG_L2, jx); + OSIPhone::get_singleton()->joy_axis(joy_id, JOY_AXIS_TRIGGER_LEFT, jx); } else if (element == gamepad.rightTrigger) { jx.value = gamepad.rightTrigger.value; - OSIPhone::get_singleton()->joy_axis(joy_id, JOY_ANALOG_R2, jx); + OSIPhone::get_singleton()->joy_axis(joy_id, JOY_AXIS_TRIGGER_RIGHT, jx); }; }; } else if (controller.gamepad != nil) { @@ -309,31 +303,31 @@ static void on_focus_in(ViewController *view_controller, bool *is_focus_out) { int joy_id = [self getJoyIdForController:controller]; if (element == gamepad.buttonA) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_0, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_A, gamepad.buttonA.isPressed); } else if (element == gamepad.buttonB) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_1, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_B, gamepad.buttonB.isPressed); } else if (element == gamepad.buttonX) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_2, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_X, gamepad.buttonX.isPressed); } else if (element == gamepad.buttonY) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_3, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_Y, gamepad.buttonY.isPressed); } else if (element == gamepad.leftShoulder) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_L, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_LEFT_SHOULDER, gamepad.leftShoulder.isPressed); } else if (element == gamepad.rightShoulder) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_R, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_RIGHT_SHOULDER, gamepad.rightShoulder.isPressed); } else if (element == gamepad.dpad) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_UP, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_UP, gamepad.dpad.up.isPressed); - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_DOWN, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_DOWN, gamepad.dpad.down.isPressed); - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_LEFT, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_LEFT, gamepad.dpad.left.isPressed); - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_RIGHT, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_RIGHT, gamepad.dpad.right.isPressed); }; }; @@ -347,19 +341,19 @@ static void on_focus_in(ViewController *view_controller, bool *is_focus_out) { int joy_id = [self getJoyIdForController:controller]; if (element == gamepad.buttonA) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_0, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_A, gamepad.buttonA.isPressed); } else if (element == gamepad.buttonX) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_2, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_X, gamepad.buttonX.isPressed); } else if (element == gamepad.dpad) { - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_UP, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_UP, gamepad.dpad.up.isPressed); - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_DOWN, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_DOWN, gamepad.dpad.down.isPressed); - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_LEFT, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_LEFT, gamepad.dpad.left.isPressed); - OSIPhone::get_singleton()->joy_button(joy_id, JOY_DPAD_RIGHT, + OSIPhone::get_singleton()->joy_button(joy_id, JOY_BUTTON_DPAD_RIGHT, gamepad.dpad.right.isPressed); }; }; diff --git a/platform/iphone/export/export.cpp b/platform/iphone/export/export.cpp index 3efe338ac7..3904fd6cf9 100644 --- a/platform/iphone/export/export.cpp +++ b/platform/iphone/export/export.cpp @@ -29,6 +29,7 @@ /*************************************************************************/ #include "export.h" + #include "core/io/image_loader.h" #include "core/io/marshalls.h" #include "core/io/resource_saver.h" @@ -74,12 +75,9 @@ class EditorExportPlatformIOS : public EditorExportPlatform { struct ExportArchitecture { String name; - bool is_default; + bool is_default = false; - ExportArchitecture() : - name(""), - is_default(false) { - } + ExportArchitecture() {} ExportArchitecture(String p_name, bool p_is_default) { name = p_name; @@ -423,7 +421,8 @@ String EditorExportPlatformIOS::_get_linker_flags() { String result; for (int i = 0; i < export_plugins.size(); ++i) { String flags = export_plugins[i]->get_ios_linker_flags(); - if (flags.length() == 0) continue; + if (flags.length() == 0) + continue; if (result.length() > 0) { result += ' '; } @@ -456,8 +455,10 @@ void EditorExportPlatformIOS::_blend_and_rotate(Ref<Image> &p_dst, Ref<Image> &p int xs = (x_pos >= 0) ? 0 : -x_pos; int ys = (y_pos >= 0) ? 0 : -y_pos; - if (sw + x_pos > p_dst->get_width()) sw = p_dst->get_width() - x_pos; - if (sh + y_pos > p_dst->get_height()) sh = p_dst->get_height() - y_pos; + if (sw + x_pos > p_dst->get_width()) + sw = p_dst->get_width() - x_pos; + if (sh + y_pos > p_dst->get_height()) + sh = p_dst->get_height() - y_pos; for (int y = ys; y < sh; y++) { for (int x = xs; x < sw; x++) { diff --git a/platform/iphone/game_center.mm b/platform/iphone/game_center.mm index 99d539d4ff..52f025f333 100644 --- a/platform/iphone/game_center.mm +++ b/platform/iphone/game_center.mm @@ -395,6 +395,6 @@ GameCenter::GameCenter() { authenticated = false; }; -GameCenter::~GameCenter(){}; +GameCenter::~GameCenter() {} #endif diff --git a/platform/iphone/godot_iphone.cpp b/platform/iphone/godot_iphone.cpp index 3e67362e16..cea0e5c7f0 100644 --- a/platform/iphone/godot_iphone.cpp +++ b/platform/iphone/godot_iphone.cpp @@ -50,7 +50,8 @@ int iphone_main(int width, int height, int argc, char **argv, String data_dir) { size_t len = strlen(argv[0]); while (len--) { - if (argv[0][len] == '/') break; + if (argv[0][len] == '/') + break; } if (len >= 0) { diff --git a/platform/iphone/icloud.mm b/platform/iphone/icloud.mm index 251f78f2da..1b9b354727 100644 --- a/platform/iphone/icloud.mm +++ b/platform/iphone/icloud.mm @@ -354,6 +354,6 @@ ICloud::ICloud() { }]; } -ICloud::~ICloud(){}; +ICloud::~ICloud() {} #endif diff --git a/platform/iphone/in_app_store.mm b/platform/iphone/in_app_store.mm index a8a887824f..c5a8ed2b2b 100644 --- a/platform/iphone/in_app_store.mm +++ b/platform/iphone/in_app_store.mm @@ -334,6 +334,6 @@ void InAppStore::set_auto_finish_transaction(bool b) { auto_finish_transactions = b; } -InAppStore::~InAppStore(){}; +InAppStore::~InAppStore() {} #endif diff --git a/platform/iphone/ios.mm b/platform/iphone/ios.mm index 2656f125b9..697c9b8a3d 100644 --- a/platform/iphone/ios.mm +++ b/platform/iphone/ios.mm @@ -81,4 +81,4 @@ String iOS::get_rate_url(int p_app_id) const { return ret; }; -iOS::iOS(){}; +iOS::iOS() {} diff --git a/platform/iphone/os_iphone.cpp b/platform/iphone/os_iphone.cpp index 3ef521a61a..9254e660d8 100644 --- a/platform/iphone/os_iphone.cpp +++ b/platform/iphone/os_iphone.cpp @@ -372,8 +372,8 @@ void OSIPhone::finalize() { event_count = 0; }; -void OSIPhone::set_mouse_show(bool p_show){}; -void OSIPhone::set_mouse_grab(bool p_grab){}; +void OSIPhone::set_mouse_show(bool p_show) {} +void OSIPhone::set_mouse_grab(bool p_grab) {} bool OSIPhone::is_mouse_grab_enabled() const { @@ -390,7 +390,7 @@ int OSIPhone::get_mouse_button_state() const { return 0; }; -void OSIPhone::set_window_title(const String &p_title){}; +void OSIPhone::set_window_title(const String &p_title) {} void OSIPhone::alert(const String &p_alert, const String &p_title) { diff --git a/platform/javascript/SCsub b/platform/javascript/SCsub index 7239648937..dcf9a46bf9 100644 --- a/platform/javascript/SCsub +++ b/platform/javascript/SCsub @@ -4,6 +4,7 @@ Import("env") javascript_files = [ "audio_driver_javascript.cpp", + "display_server_javascript.cpp", "http_client_javascript.cpp", "javascript_eval.cpp", "javascript_main.cpp", @@ -17,22 +18,22 @@ if env["threads_enabled"]: build = env.add_program(build_targets, javascript_files) js_libraries = [ - "http_request.js", + "native/http_request.js", ] for lib in js_libraries: env.Append(LINKFLAGS=["--js-library", env.File(lib).path]) env.Depends(build, js_libraries) -js_modules = [ - "id_handler.js", +js_pre = [ + "native/id_handler.js", + "native/utils.js", ] -for module in js_modules: - env.Append(LINKFLAGS=["--pre-js", env.File(module).path]) -env.Depends(build, js_modules) +for js in js_pre: + env.Append(LINKFLAGS=["--pre-js", env.File(js).path]) +env.Depends(build, js_pre) engine = [ "engine/preloader.js", - "engine/loader.js", "engine/utils.js", "engine/engine.js", ] @@ -47,11 +48,17 @@ wrap_list = [ js_wrapped = env.Textfile("#bin/godot", [env.File(f) for f in wrap_list], TEXTFILESUFFIX="${PROGSUFFIX}.wrapped.js") zip_dir = env.Dir("#bin/.javascript_zip") -out_files = [zip_dir.File("godot.js"), zip_dir.File("godot.wasm"), zip_dir.File("godot.html")] -in_files = [js_wrapped, build[1], "#misc/dist/html/full-size.html"] +binary_name = "godot.tools" if env["tools"] else "godot" +out_files = [ + zip_dir.File(binary_name + ".js"), + zip_dir.File(binary_name + ".wasm"), + zip_dir.File(binary_name + ".html"), +] +html_file = "#misc/dist/html/full-size.html" +in_files = [js_wrapped, build[1], html_file] if env["threads_enabled"]: in_files.append(build[2]) - out_files.append(zip_dir.File("godot.worker.js")) + out_files.append(zip_dir.File(binary_name + ".worker.js")) zip_files = env.InstallAs(out_files, in_files) env.Zip( diff --git a/platform/javascript/audio_driver_javascript.cpp b/platform/javascript/audio_driver_javascript.cpp index 8f857478e5..690ce5ba47 100644 --- a/platform/javascript/audio_driver_javascript.cpp +++ b/platform/javascript/audio_driver_javascript.cpp @@ -190,19 +190,43 @@ void AudioDriverJavaScript::lock() { void AudioDriverJavaScript::unlock() { } -void AudioDriverJavaScript::finish() { +void AudioDriverJavaScript::finish_async() { + + // Close the context, add the operation to the async_finish list in module. + int id = _driver_id; + _driver_id = 0; /* clang-format off */ EM_ASM({ + var ref = Module.IDHandler.get($0); + Module.async_finish.push(new Promise(function(accept, reject) { + if (!ref) { + console.log("Ref not found!", $0, Module.IDHandler); + setTimeout(accept, 0); + } else { + const context = ref['context']; + // Disconnect script and input. + ref['script'].disconnect(); + if (ref['input']) + ref['input'].disconnect(); + ref = null; + context.close().then(function() { + accept(); + }).catch(function(e) { + accept(); + }); + } + })); Module.IDHandler.remove($0); - }, _driver_id); + }, id); /* clang-format on */ +} +void AudioDriverJavaScript::finish() { if (internal_buffer) { memdelete_arr(internal_buffer); internal_buffer = nullptr; } - _driver_id = 0; } Error AudioDriverJavaScript::capture_start() { diff --git a/platform/javascript/audio_driver_javascript.h b/platform/javascript/audio_driver_javascript.h index f6f2dacd4e..4a44f4683f 100644 --- a/platform/javascript/audio_driver_javascript.h +++ b/platform/javascript/audio_driver_javascript.h @@ -56,6 +56,7 @@ public: virtual void lock(); virtual void unlock(); virtual void finish(); + void finish_async(); virtual Error capture_start(); virtual Error capture_stop(); diff --git a/platform/javascript/detect.py b/platform/javascript/detect.py index 9486e10717..81287cead8 100644 --- a/platform/javascript/detect.py +++ b/platform/javascript/detect.py @@ -22,7 +22,7 @@ def get_opts(): # eval() can be a security concern, so it can be disabled. BoolVariable("javascript_eval", "Enable JavaScript eval interface", True), BoolVariable("threads_enabled", "Enable WebAssembly Threads support (limited browser support)", False), - BoolVariable("use_closure_compiler", "Use closure compiler to minimize Javascript code", False), + BoolVariable("use_closure_compiler", "Use closure compiler to minimize JavaScript code", False), ] @@ -57,7 +57,7 @@ def configure(env): env.Append(CPPDEFINES=["DEBUG_ENABLED"]) # Retain function names for backtraces at the cost of file size. env.Append(LINKFLAGS=["--profiling-funcs"]) - else: # 'debug' + else: # "debug" env.Append(CPPDEFINES=["DEBUG_ENABLED"]) env.Append(CCFLAGS=["-O1", "-g"]) env.Append(LINKFLAGS=["-O1", "-g"]) @@ -150,7 +150,7 @@ def configure(env): env.Append(LIBS=["idbfs.js"]) env.Append(LINKFLAGS=["-s", "BINARYEN=1"]) - env.Append(LINKFLAGS=["-s", "MODULARIZE=1", "-s", 'EXPORT_NAME="Godot"']) + env.Append(LINKFLAGS=["-s", "MODULARIZE=1", "-s", "EXPORT_NAME='Godot'"]) # Allow increasing memory buffer size during runtime. This is efficient # when using WebAssembly (in comparison to asm.js) and works well for @@ -162,5 +162,10 @@ def configure(env): env.Append(LINKFLAGS=["-s", "INVOKE_RUN=0"]) - # callMain for manual start, FS for preloading. - env.Append(LINKFLAGS=["-s", 'EXTRA_EXPORTED_RUNTIME_METHODS=["callMain", "FS"]']) + # Allow use to take control of swapping WebGL buffers. + env.Append(LINKFLAGS=["-s", "OFFSCREEN_FRAMEBUFFER=1"]) + + # callMain for manual start, FS for preloading, PATH and ERRNO_CODES for BrowserFS. + env.Append(LINKFLAGS=["-s", "EXTRA_EXPORTED_RUNTIME_METHODS=['callMain', 'FS', 'PATH']"]) + # Add code that allow exiting runtime. + env.Append(LINKFLAGS=["-s", "EXIT_RUNTIME=1"]) diff --git a/platform/javascript/display_server_javascript.cpp b/platform/javascript/display_server_javascript.cpp new file mode 100644 index 0000000000..c302614eca --- /dev/null +++ b/platform/javascript/display_server_javascript.cpp @@ -0,0 +1,1201 @@ +#include "platform/javascript/display_server_javascript.h" + +#include "drivers/dummy/rasterizer_dummy.h" +#include "platform/javascript/os_javascript.h" + +#include <emscripten.h> +#include <png.h> + +#include "dom_keys.inc" + +#define DOM_BUTTON_LEFT 0 +#define DOM_BUTTON_MIDDLE 1 +#define DOM_BUTTON_RIGHT 2 +#define DOM_BUTTON_XBUTTON1 3 +#define DOM_BUTTON_XBUTTON2 4 + +DisplayServerJavaScript *DisplayServerJavaScript::get_singleton() { + return static_cast<DisplayServerJavaScript *>(DisplayServer::get_singleton()); +} + +// Window (canvas) +extern "C" EMSCRIPTEN_KEEPALIVE void _set_canvas_id(uint8_t *p_data, int p_data_size) { + DisplayServerJavaScript *display = DisplayServerJavaScript::get_singleton(); + display->canvas_id.parse_utf8((const char *)p_data, p_data_size); + display->canvas_id = "#" + display->canvas_id; +} + +static void focus_canvas() { + + /* clang-format off */ + EM_ASM( + Module['canvas'].focus(); + ); + /* clang-format on */ +} + +static bool is_canvas_focused() { + + /* clang-format off */ + return EM_ASM_INT_V( + return document.activeElement == Module['canvas']; + ); + /* clang-format on */ +} + +static Point2 compute_position_in_canvas(int x, int y) { + DisplayServerJavaScript *display = DisplayServerJavaScript::get_singleton(); + int canvas_x = EM_ASM_INT({ + return Module['canvas'].getBoundingClientRect().x; + }); + int canvas_y = EM_ASM_INT({ + return Module['canvas'].getBoundingClientRect().y; + }); + int canvas_width; + int canvas_height; + emscripten_get_canvas_element_size(display->canvas_id.utf8().get_data(), &canvas_width, &canvas_height); + + double element_width; + double element_height; + emscripten_get_element_css_size(display->canvas_id.utf8().get_data(), &element_width, &element_height); + + return Point2((int)(canvas_width / element_width * (x - canvas_x)), + (int)(canvas_height / element_height * (y - canvas_y))); +} + +static bool cursor_inside_canvas = true; + +EM_BOOL DisplayServerJavaScript::fullscreen_change_callback(int p_event_type, const EmscriptenFullscreenChangeEvent *p_event, void *p_user_data) { + + DisplayServerJavaScript *display = get_singleton(); + // Empty ID is canvas. + String target_id = String::utf8(p_event->id); + if (target_id.empty() || "#" + target_id == display->canvas_id) { + // This event property is the only reliable data on + // browser fullscreen state. + if (p_event->isFullscreen) { + display->window_mode = WINDOW_MODE_FULLSCREEN; + } else { + display->window_mode = WINDOW_MODE_WINDOWED; + } + } + return false; +} + +// Drag and drop callback (see native/utils.js). +extern "C" EMSCRIPTEN_KEEPALIVE void _drop_files_callback(char *p_filev[], int p_filec) { + DisplayServerJavaScript *ds = DisplayServerJavaScript::get_singleton(); + if (!ds) { + ERR_FAIL_MSG("Unable to drop files because the DisplayServer is not active"); + } + if (ds->drop_files_callback.is_null()) + return; + Vector<String> files; + for (int i = 0; i < p_filec; i++) { + files.push_back(String::utf8(p_filev[i])); + } + Variant v = files; + Variant *vp = &v; + Variant ret; + Callable::CallError ce; + ds->drop_files_callback.call((const Variant **)&vp, 1, ret, ce); +} + +// Keys + +template <typename T> +static void dom2godot_mod(T *emscripten_event_ptr, Ref<InputEventWithModifiers> godot_event) { + + godot_event->set_shift(emscripten_event_ptr->shiftKey); + godot_event->set_alt(emscripten_event_ptr->altKey); + godot_event->set_control(emscripten_event_ptr->ctrlKey); + godot_event->set_metakey(emscripten_event_ptr->metaKey); +} + +static Ref<InputEventKey> setup_key_event(const EmscriptenKeyboardEvent *emscripten_event) { + + Ref<InputEventKey> ev; + ev.instance(); + ev->set_echo(emscripten_event->repeat); + dom2godot_mod(emscripten_event, ev); + ev->set_keycode(dom2godot_keycode(emscripten_event->keyCode)); + ev->set_physical_keycode(dom2godot_keycode(emscripten_event->keyCode)); + + String unicode = String::utf8(emscripten_event->key); + // Check if empty or multi-character (e.g. `CapsLock`). + if (unicode.length() != 1) { + // Might be empty as well, but better than nonsense. + unicode = String::utf8(emscripten_event->charValue); + } + if (unicode.length() == 1) { + ev->set_unicode(unicode[0]); + } + + return ev; +} + +EM_BOOL DisplayServerJavaScript::keydown_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data) { + + DisplayServerJavaScript *display = get_singleton(); + Ref<InputEventKey> ev = setup_key_event(p_event); + ev->set_pressed(true); + if (ev->get_unicode() == 0 && keycode_has_unicode(ev->get_keycode())) { + // Defer to keypress event for legacy unicode retrieval. + display->deferred_key_event = ev; + // Do not suppress keypress event. + return false; + } + Input::get_singleton()->parse_input_event(ev); + return true; +} + +EM_BOOL DisplayServerJavaScript::keypress_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data) { + + DisplayServerJavaScript *display = get_singleton(); + display->deferred_key_event->set_unicode(p_event->charCode); + Input::get_singleton()->parse_input_event(display->deferred_key_event); + return true; +} + +EM_BOOL DisplayServerJavaScript::keyup_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data) { + + Ref<InputEventKey> ev = setup_key_event(p_event); + ev->set_pressed(false); + Input::get_singleton()->parse_input_event(ev); + return ev->get_keycode() != KEY_UNKNOWN && ev->get_keycode() != 0; +} + +// Mouse + +EM_BOOL DisplayServerJavaScript::mouse_button_callback(int p_event_type, const EmscriptenMouseEvent *p_event, void *p_user_data) { + + DisplayServerJavaScript *display = get_singleton(); + + Ref<InputEventMouseButton> ev; + ev.instance(); + ev->set_pressed(p_event_type == EMSCRIPTEN_EVENT_MOUSEDOWN); + ev->set_position(compute_position_in_canvas(p_event->clientX, p_event->clientY)); + ev->set_global_position(ev->get_position()); + dom2godot_mod(p_event, ev); + + switch (p_event->button) { + case DOM_BUTTON_LEFT: + ev->set_button_index(BUTTON_LEFT); + break; + case DOM_BUTTON_MIDDLE: + ev->set_button_index(BUTTON_MIDDLE); + break; + case DOM_BUTTON_RIGHT: + ev->set_button_index(BUTTON_RIGHT); + break; + case DOM_BUTTON_XBUTTON1: + ev->set_button_index(BUTTON_XBUTTON1); + break; + case DOM_BUTTON_XBUTTON2: + ev->set_button_index(BUTTON_XBUTTON2); + break; + default: + return false; + } + + if (ev->is_pressed()) { + + double diff = emscripten_get_now() - display->last_click_ms; + + if (ev->get_button_index() == display->last_click_button_index) { + + if (diff < 400 && Point2(display->last_click_pos).distance_to(ev->get_position()) < 5) { + + display->last_click_ms = 0; + display->last_click_pos = Point2(-100, -100); + display->last_click_button_index = -1; + ev->set_doubleclick(true); + } + + } else { + display->last_click_button_index = ev->get_button_index(); + } + + if (!ev->is_doubleclick()) { + display->last_click_ms += diff; + display->last_click_pos = ev->get_position(); + } + } + + Input *input = Input::get_singleton(); + int mask = input->get_mouse_button_mask(); + int button_flag = 1 << (ev->get_button_index() - 1); + if (ev->is_pressed()) { + // Since the event is consumed, focus manually. The containing iframe, + // if exists, may not have focus yet, so focus even if already focused. + focus_canvas(); + mask |= button_flag; + } else if (mask & button_flag) { + mask &= ~button_flag; + } else { + // Received release event, but press was outside the canvas, so ignore. + return false; + } + ev->set_button_mask(mask); + + input->parse_input_event(ev); + // Prevent multi-click text selection and wheel-click scrolling anchor. + // Context menu is prevented through contextmenu event. + return true; +} + +EM_BOOL DisplayServerJavaScript::mousemove_callback(int p_event_type, const EmscriptenMouseEvent *p_event, void *p_user_data) { + + Input *input = Input::get_singleton(); + int input_mask = input->get_mouse_button_mask(); + Point2 pos = compute_position_in_canvas(p_event->clientX, p_event->clientY); + // For motion outside the canvas, only read mouse movement if dragging + // started inside the canvas; imitating desktop app behaviour. + if (!cursor_inside_canvas && !input_mask) + return false; + + Ref<InputEventMouseMotion> ev; + ev.instance(); + dom2godot_mod(p_event, ev); + ev->set_button_mask(input_mask); + + ev->set_position(pos); + ev->set_global_position(ev->get_position()); + + ev->set_relative(Vector2(p_event->movementX, p_event->movementY)); + input->set_mouse_position(ev->get_position()); + ev->set_speed(input->get_last_mouse_speed()); + + input->parse_input_event(ev); + // Don't suppress mouseover/-leave events. + return false; +} + +// Cursor +static const char *godot2dom_cursor(DisplayServer::CursorShape p_shape) { + + switch (p_shape) { + case DisplayServer::CURSOR_ARROW: + return "auto"; + case DisplayServer::CURSOR_IBEAM: + return "text"; + case DisplayServer::CURSOR_POINTING_HAND: + return "pointer"; + case DisplayServer::CURSOR_CROSS: + return "crosshair"; + case DisplayServer::CURSOR_WAIT: + return "progress"; + case DisplayServer::CURSOR_BUSY: + return "wait"; + case DisplayServer::CURSOR_DRAG: + return "grab"; + case DisplayServer::CURSOR_CAN_DROP: + return "grabbing"; + case DisplayServer::CURSOR_FORBIDDEN: + return "no-drop"; + case DisplayServer::CURSOR_VSIZE: + return "ns-resize"; + case DisplayServer::CURSOR_HSIZE: + return "ew-resize"; + case DisplayServer::CURSOR_BDIAGSIZE: + return "nesw-resize"; + case DisplayServer::CURSOR_FDIAGSIZE: + return "nwse-resize"; + case DisplayServer::CURSOR_MOVE: + return "move"; + case DisplayServer::CURSOR_VSPLIT: + return "row-resize"; + case DisplayServer::CURSOR_HSPLIT: + return "col-resize"; + case DisplayServer::CURSOR_HELP: + return "help"; + default: + return "auto"; + } +} + +static void set_css_cursor(const char *p_cursor) { + + /* clang-format off */ + EM_ASM_({ + Module['canvas'].style.cursor = UTF8ToString($0); + }, p_cursor); + /* clang-format on */ +} + +static bool is_css_cursor_hidden() { + + /* clang-format off */ + return EM_ASM_INT({ + return Module['canvas'].style.cursor === 'none'; + }); + /* clang-format on */ +} + +void DisplayServerJavaScript::cursor_set_shape(CursorShape p_shape) { + + ERR_FAIL_INDEX(p_shape, CURSOR_MAX); + + if (mouse_get_mode() == MOUSE_MODE_VISIBLE) { + if (cursors[p_shape] != "") { + Vector<String> url = cursors[p_shape].split("?"); + set_css_cursor(("url(\"" + url[0] + "\") " + url[1] + ", auto").utf8()); + } else { + set_css_cursor(godot2dom_cursor(p_shape)); + } + } + + cursor_shape = p_shape; +} + +DisplayServer::CursorShape DisplayServerJavaScript::cursor_get_shape() const { + return cursor_shape; +} + +void DisplayServerJavaScript::cursor_set_custom_image(const RES &p_cursor, CursorShape p_shape, const Vector2 &p_hotspot) { + + if (p_cursor.is_valid()) { + + Map<CursorShape, Vector<Variant>>::Element *cursor_c = cursors_cache.find(p_shape); + + if (cursor_c) { + if (cursor_c->get()[0] == p_cursor && cursor_c->get()[1] == p_hotspot) { + cursor_set_shape(p_shape); + return; + } + + cursors_cache.erase(p_shape); + } + + Ref<Texture2D> texture = p_cursor; + Ref<AtlasTexture> atlas_texture = p_cursor; + Ref<Image> image; + Size2 texture_size; + Rect2 atlas_rect; + + if (texture.is_valid()) { + image = texture->get_data(); + } + + if (!image.is_valid() && atlas_texture.is_valid()) { + texture = atlas_texture->get_atlas(); + + atlas_rect.size.width = texture->get_width(); + atlas_rect.size.height = texture->get_height(); + atlas_rect.position.x = atlas_texture->get_region().position.x; + atlas_rect.position.y = atlas_texture->get_region().position.y; + + texture_size.width = atlas_texture->get_region().size.x; + texture_size.height = atlas_texture->get_region().size.y; + } else if (image.is_valid()) { + texture_size.width = texture->get_width(); + texture_size.height = texture->get_height(); + } + + ERR_FAIL_COND(!texture.is_valid()); + ERR_FAIL_COND(p_hotspot.x < 0 || p_hotspot.y < 0); + ERR_FAIL_COND(texture_size.width > 256 || texture_size.height > 256); + ERR_FAIL_COND(p_hotspot.x > texture_size.width || p_hotspot.y > texture_size.height); + + image = texture->get_data(); + + ERR_FAIL_COND(!image.is_valid()); + + image = image->duplicate(); + + if (atlas_texture.is_valid()) + image->crop_from_point( + atlas_rect.position.x, + atlas_rect.position.y, + texture_size.width, + texture_size.height); + + if (image->get_format() != Image::FORMAT_RGBA8) { + image->convert(Image::FORMAT_RGBA8); + } + + png_image png_meta; + memset(&png_meta, 0, sizeof png_meta); + png_meta.version = PNG_IMAGE_VERSION; + png_meta.width = texture_size.width; + png_meta.height = texture_size.height; + png_meta.format = PNG_FORMAT_RGBA; + + PackedByteArray png; + size_t len; + PackedByteArray data = image->get_data(); + ERR_FAIL_COND(!png_image_write_get_memory_size(png_meta, len, 0, data.ptr(), 0, nullptr)); + + png.resize(len); + ERR_FAIL_COND(!png_image_write_to_memory(&png_meta, png.ptrw(), &len, 0, data.ptr(), 0, nullptr)); + + char *object_url; + /* clang-format off */ + EM_ASM({ + var PNG_PTR = $0; + var PNG_LEN = $1; + var PTR = $2; + + var png = new Blob([HEAPU8.slice(PNG_PTR, PNG_PTR + PNG_LEN)], { type: 'image/png' }); + var url = URL.createObjectURL(png); + var length_bytes = lengthBytesUTF8(url) + 1; + var string_on_wasm_heap = _malloc(length_bytes); + setValue(PTR, string_on_wasm_heap, '*'); + stringToUTF8(url, string_on_wasm_heap, length_bytes); + }, png.ptr(), len, &object_url); + /* clang-format on */ + + String url = String::utf8(object_url) + "?" + itos(p_hotspot.x) + " " + itos(p_hotspot.y); + + /* clang-format off */ + EM_ASM({ _free($0); }, object_url); + /* clang-format on */ + + if (cursors[p_shape] != "") { + /* clang-format off */ + EM_ASM({ + URL.revokeObjectURL(UTF8ToString($0).split('?')[0]); + }, cursors[p_shape].utf8().get_data()); + /* clang-format on */ + cursors[p_shape] = ""; + } + + cursors[p_shape] = url; + + Vector<Variant> params; + params.push_back(p_cursor); + params.push_back(p_hotspot); + cursors_cache.insert(p_shape, params); + + } else if (cursors[p_shape] != "") { + /* clang-format off */ + EM_ASM({ + URL.revokeObjectURL(UTF8ToString($0).split('?')[0]); + }, cursors[p_shape].utf8().get_data()); + /* clang-format on */ + cursors[p_shape] = ""; + + cursors_cache.erase(p_shape); + } + + cursor_set_shape(cursor_shape); +} + +// Mouse mode +void DisplayServerJavaScript::mouse_set_mode(MouseMode p_mode) { + + ERR_FAIL_COND_MSG(p_mode == MOUSE_MODE_CONFINED, "MOUSE_MODE_CONFINED is not supported for the HTML5 platform."); + if (p_mode == mouse_get_mode()) + return; + + if (p_mode == MOUSE_MODE_VISIBLE) { + + // set_css_cursor must be called before set_cursor_shape to make the cursor visible + set_css_cursor(godot2dom_cursor(cursor_shape)); + cursor_set_shape(cursor_shape); + emscripten_exit_pointerlock(); + + } else if (p_mode == MOUSE_MODE_HIDDEN) { + + set_css_cursor("none"); + emscripten_exit_pointerlock(); + + } else if (p_mode == MOUSE_MODE_CAPTURED) { + + EMSCRIPTEN_RESULT result = emscripten_request_pointerlock("canvas", false); + ERR_FAIL_COND_MSG(result == EMSCRIPTEN_RESULT_FAILED_NOT_DEFERRED, "MOUSE_MODE_CAPTURED can only be entered from within an appropriate input callback."); + ERR_FAIL_COND_MSG(result != EMSCRIPTEN_RESULT_SUCCESS, "MOUSE_MODE_CAPTURED can only be entered from within an appropriate input callback."); + // set_css_cursor must be called before cursor_set_shape to make the cursor visible + set_css_cursor(godot2dom_cursor(cursor_shape)); + cursor_set_shape(cursor_shape); + } +} + +DisplayServer::MouseMode DisplayServerJavaScript::mouse_get_mode() const { + + if (is_css_cursor_hidden()) + return MOUSE_MODE_HIDDEN; + + EmscriptenPointerlockChangeEvent ev; + emscripten_get_pointerlock_status(&ev); + return (ev.isActive && String::utf8(ev.id) == "canvas") ? MOUSE_MODE_CAPTURED : MOUSE_MODE_VISIBLE; +} + +// Wheel + +EM_BOOL DisplayServerJavaScript::wheel_callback(int p_event_type, const EmscriptenWheelEvent *p_event, void *p_user_data) { + + ERR_FAIL_COND_V(p_event_type != EMSCRIPTEN_EVENT_WHEEL, false); + if (!is_canvas_focused()) { + if (cursor_inside_canvas) { + focus_canvas(); + } else { + return false; + } + } + + Input *input = Input::get_singleton(); + Ref<InputEventMouseButton> ev; + ev.instance(); + ev->set_position(input->get_mouse_position()); + ev->set_global_position(ev->get_position()); + + ev->set_shift(input->is_key_pressed(KEY_SHIFT)); + ev->set_alt(input->is_key_pressed(KEY_ALT)); + ev->set_control(input->is_key_pressed(KEY_CONTROL)); + ev->set_metakey(input->is_key_pressed(KEY_META)); + + if (p_event->deltaY < 0) + ev->set_button_index(BUTTON_WHEEL_UP); + else if (p_event->deltaY > 0) + ev->set_button_index(BUTTON_WHEEL_DOWN); + else if (p_event->deltaX > 0) + ev->set_button_index(BUTTON_WHEEL_LEFT); + else if (p_event->deltaX < 0) + ev->set_button_index(BUTTON_WHEEL_RIGHT); + else + return false; + + // Different browsers give wildly different delta values, and we can't + // interpret deltaMode, so use default value for wheel events' factor. + + int button_flag = 1 << (ev->get_button_index() - 1); + + ev->set_pressed(true); + ev->set_button_mask(input->get_mouse_button_mask() | button_flag); + input->parse_input_event(ev); + + ev->set_pressed(false); + ev->set_button_mask(input->get_mouse_button_mask() & ~button_flag); + input->parse_input_event(ev); + + return true; +} + +// Touch +EM_BOOL DisplayServerJavaScript::touch_press_callback(int p_event_type, const EmscriptenTouchEvent *p_event, void *p_user_data) { + + DisplayServerJavaScript *display = get_singleton(); + Ref<InputEventScreenTouch> ev; + ev.instance(); + int lowest_id_index = -1; + for (int i = 0; i < p_event->numTouches; ++i) { + + const EmscriptenTouchPoint &touch = p_event->touches[i]; + if (lowest_id_index == -1 || touch.identifier < p_event->touches[lowest_id_index].identifier) + lowest_id_index = i; + if (!touch.isChanged) + continue; + ev->set_index(touch.identifier); + ev->set_position(compute_position_in_canvas(touch.clientX, touch.clientY)); + display->touches[i] = ev->get_position(); + ev->set_pressed(p_event_type == EMSCRIPTEN_EVENT_TOUCHSTART); + + Input::get_singleton()->parse_input_event(ev); + } + // Resume audio context after input in case autoplay was denied. + return true; +} + +EM_BOOL DisplayServerJavaScript::touchmove_callback(int p_event_type, const EmscriptenTouchEvent *p_event, void *p_user_data) { + + DisplayServerJavaScript *display = get_singleton(); + Ref<InputEventScreenDrag> ev; + ev.instance(); + int lowest_id_index = -1; + for (int i = 0; i < p_event->numTouches; ++i) { + + const EmscriptenTouchPoint &touch = p_event->touches[i]; + if (lowest_id_index == -1 || touch.identifier < p_event->touches[lowest_id_index].identifier) + lowest_id_index = i; + if (!touch.isChanged) + continue; + ev->set_index(touch.identifier); + ev->set_position(compute_position_in_canvas(touch.clientX, touch.clientY)); + Point2 &prev = display->touches[i]; + ev->set_relative(ev->get_position() - prev); + prev = ev->get_position(); + + Input::get_singleton()->parse_input_event(ev); + } + return true; +} + +bool DisplayServerJavaScript::screen_is_touchscreen(int p_screen) const { + return EM_ASM_INT({ return 'ontouchstart' in window; }); +} + +// Gamepad + +EM_BOOL DisplayServerJavaScript::gamepad_change_callback(int p_event_type, const EmscriptenGamepadEvent *p_event, void *p_user_data) { + + Input *input = Input::get_singleton(); + if (p_event_type == EMSCRIPTEN_EVENT_GAMEPADCONNECTED) { + + String guid = ""; + if (String::utf8(p_event->mapping) == "standard") + guid = "Default HTML5 Gamepad"; + input->joy_connection_changed(p_event->index, true, String::utf8(p_event->id), guid); + } else { + input->joy_connection_changed(p_event->index, false, ""); + } + return true; +} + +void DisplayServerJavaScript::process_joypads() { + + int joypad_count = emscripten_get_num_gamepads(); + Input *input = Input::get_singleton(); + for (int joypad = 0; joypad < joypad_count; joypad++) { + EmscriptenGamepadEvent state; + EMSCRIPTEN_RESULT query_result = emscripten_get_gamepad_status(joypad, &state); + // Chromium reserves gamepads slots, so NO_DATA is an expected result. + ERR_CONTINUE(query_result != EMSCRIPTEN_RESULT_SUCCESS && + query_result != EMSCRIPTEN_RESULT_NO_DATA); + if (query_result == EMSCRIPTEN_RESULT_SUCCESS && state.connected) { + + int button_count = MIN(state.numButtons, 18); + int axis_count = MIN(state.numAxes, 8); + for (int button = 0; button < button_count; button++) { + + float value = state.analogButton[button]; + input->joy_button(joypad, button, value); + } + for (int axis = 0; axis < axis_count; axis++) { + + Input::JoyAxis joy_axis; + joy_axis.min = -1; + joy_axis.value = state.axis[axis]; + input->joy_axis(joypad, axis, joy_axis); + } + } + } +} + +#if 0 +bool DisplayServerJavaScript::is_joy_known(int p_device) { + + return Input::get_singleton()->is_joy_mapped(p_device); +} + +String DisplayServerJavaScript::get_joy_guid(int p_device) const { + + return Input::get_singleton()->get_joy_guid_remapped(p_device); +} +#endif + +Vector<String> DisplayServerJavaScript::get_rendering_drivers_func() { + Vector<String> drivers; + drivers.push_back("dummy"); + return drivers; +} + +// Clipboard +extern "C" EMSCRIPTEN_KEEPALIVE void update_clipboard(const char *p_text) { + // Only call set_clipboard from OS (sets local clipboard) + DisplayServerJavaScript::get_singleton()->clipboard = p_text; +} + +void DisplayServerJavaScript::clipboard_set(const String &p_text) { + /* clang-format off */ + int err = EM_ASM_INT({ + var text = UTF8ToString($0); + if (!navigator.clipboard || !navigator.clipboard.writeText) + return 1; + navigator.clipboard.writeText(text).catch(function(e) { + // Setting OS clipboard is only possible from an input callback. + console.error("Setting OS clipboard is only possible from an input callback for the HTML5 plafrom. Exception:", e); + }); + return 0; + }, p_text.utf8().get_data()); + /* clang-format on */ + ERR_FAIL_COND_MSG(err, "Clipboard API is not supported."); +} + +String DisplayServerJavaScript::clipboard_get() const { + /* clang-format off */ + EM_ASM({ + try { + navigator.clipboard.readText().then(function (result) { + ccall('update_clipboard', 'void', ['string'], [result]); + }).catch(function (e) { + // Fail graciously. + }); + } catch (e) { + // Fail graciously. + } + }); + /* clang-format on */ + return clipboard; +} + +extern "C" EMSCRIPTEN_KEEPALIVE void send_window_event(int p_notification) { + + if (p_notification == DisplayServer::WINDOW_EVENT_MOUSE_ENTER || p_notification == DisplayServer::WINDOW_EVENT_MOUSE_EXIT) { + cursor_inside_canvas = p_notification == DisplayServer::WINDOW_EVENT_MOUSE_ENTER; + } + OS_JavaScript *os = OS_JavaScript::get_singleton(); + if (os->is_finalizing()) + return; // We don't want events anymore. + DisplayServerJavaScript *ds = DisplayServerJavaScript::get_singleton(); + if (ds && !ds->window_event_callback.is_null()) { + Variant event = int(p_notification); + Variant *eventp = &event; + Variant ret; + Callable::CallError ce; + ds->window_event_callback.call((const Variant **)&eventp, 1, ret, ce); + } +} + +void DisplayServerJavaScript::alert(const String &p_alert, const String &p_title) { + + /* clang-format off */ + EM_ASM_({ + window.alert(UTF8ToString($0)); + }, p_alert.utf8().get_data()); + /* clang-format on */ +} + +void DisplayServerJavaScript::set_icon(const Ref<Image> &p_icon) { + + ERR_FAIL_COND(p_icon.is_null()); + Ref<Image> icon = p_icon; + if (icon->is_compressed()) { + icon = icon->duplicate(); + ERR_FAIL_COND(icon->decompress() != OK); + } + if (icon->get_format() != Image::FORMAT_RGBA8) { + if (icon == p_icon) + icon = icon->duplicate(); + icon->convert(Image::FORMAT_RGBA8); + } + + png_image png_meta; + memset(&png_meta, 0, sizeof png_meta); + png_meta.version = PNG_IMAGE_VERSION; + png_meta.width = icon->get_width(); + png_meta.height = icon->get_height(); + png_meta.format = PNG_FORMAT_RGBA; + + PackedByteArray png; + size_t len; + PackedByteArray data = icon->get_data(); + ERR_FAIL_COND(!png_image_write_get_memory_size(png_meta, len, 0, data.ptr(), 0, nullptr)); + + png.resize(len); + ERR_FAIL_COND(!png_image_write_to_memory(&png_meta, png.ptrw(), &len, 0, data.ptr(), 0, nullptr)); + + /* clang-format off */ + EM_ASM({ + var PNG_PTR = $0; + var PNG_LEN = $1; + + var png = new Blob([HEAPU8.slice(PNG_PTR, PNG_PTR + PNG_LEN)], { type: "image/png" }); + var url = URL.createObjectURL(png); + var link = document.getElementById('-gd-engine-icon'); + if (link === null) { + link = document.createElement('link'); + link.rel = 'icon'; + link.id = '-gd-engine-icon'; + document.head.appendChild(link); + } + link.href = url; + }, png.ptr(), len); + /* clang-format on */ +} + +void DisplayServerJavaScript::_dispatch_input_event(const Ref<InputEvent> &p_event) { + OS_JavaScript *os = OS_JavaScript::get_singleton(); + if (os->is_finalizing()) + return; // We don't want events anymore. + + // Resume audio context after input in case autoplay was denied. + os->resume_audio(); + + Callable cb = get_singleton()->input_event_callback; + if (!cb.is_null()) { + Variant ev = p_event; + Variant *evp = &ev; + Variant ret; + Callable::CallError ce; + cb.call((const Variant **)&evp, 1, ret, ce); + } +} + +DisplayServer *DisplayServerJavaScript::create_func(const String &p_rendering_driver, DisplayServer::WindowMode p_mode, uint32_t p_flags, const Vector2i &p_resolution, Error &r_error) { + return memnew(DisplayServerJavaScript(p_rendering_driver, p_mode, p_flags, p_resolution, r_error)); +} + +DisplayServerJavaScript::DisplayServerJavaScript(const String &p_rendering_driver, WindowMode p_mode, uint32_t p_flags, const Vector2i &p_resolution, Error &r_error) { + + /* clang-format off */ + EM_ASM({ + const canvas = Module['canvas']; + var enc = new TextEncoder("utf-8"); + var buffer = new Uint8Array(enc.encode(canvas.id)); + var len = buffer.byteLength; + var out = _malloc(len); + HEAPU8.set(buffer, out); + ccall("_set_canvas_id", + "void", + ["number", "number"], + [out, len] + ); + _free(out); + }); + /* clang-format on */ + + RasterizerDummy::make_current(); // TODO GLES2 in Godot 4.0... or webgpu? +#if 0 + EmscriptenWebGLContextAttributes attributes; + emscripten_webgl_init_context_attributes(&attributes); + attributes.alpha = GLOBAL_GET("display/window/per_pixel_transparency/allowed"); + attributes.antialias = false; + ERR_FAIL_INDEX_V(p_video_driver, VIDEO_DRIVER_MAX, ERR_INVALID_PARAMETER); + + if (p_desired.layered) { + set_window_per_pixel_transparency_enabled(true); + } + + bool gl_initialization_error = false; + + if (RasterizerGLES2::is_viable() == OK) { + attributes.majorVersion = 1; + RasterizerGLES2::register_config(); + RasterizerGLES2::make_current(); + } else { + gl_initialization_error = true; + } + + EMSCRIPTEN_WEBGL_CONTEXT_HANDLE ctx = emscripten_webgl_create_context(canvas_id.utf8().get_data(), &attributes); + if (emscripten_webgl_make_context_current(ctx) != EMSCRIPTEN_RESULT_SUCCESS) { + gl_initialization_error = true; + } + + if (gl_initialization_error) { + OS::get_singleton()->alert("Your browser does not seem to support WebGL. Please update your browser version.", + "Unable to initialize video driver"); + return ERR_UNAVAILABLE; + } + + video_driver_index = p_video_driver; +#endif + + /* clang-format off */ + window_set_mode(p_mode); + if (EM_ASM_INT_V({ return Module['resizeCanvasOnStart'] })) { + /* clang-format on */ + window_set_size(p_resolution); + } + + EMSCRIPTEN_RESULT result; + CharString id = canvas_id.utf8(); +#define EM_CHECK(ev) \ + if (result != EMSCRIPTEN_RESULT_SUCCESS) \ + ERR_PRINT("Error while setting " #ev " callback: Code " + itos(result)); +#define SET_EM_CALLBACK(target, ev, cb) \ + result = emscripten_set_##ev##_callback(target, nullptr, true, &cb); \ + EM_CHECK(ev) +#define SET_EM_CALLBACK_NOTARGET(ev, cb) \ + result = emscripten_set_##ev##_callback(nullptr, true, &cb); \ + EM_CHECK(ev) + // These callbacks from Emscripten's html5.h suffice to access most + // JavaScript APIs. For APIs that are not (sufficiently) exposed, EM_ASM + // is used below. + SET_EM_CALLBACK(EMSCRIPTEN_EVENT_TARGET_WINDOW, mousemove, mousemove_callback) + SET_EM_CALLBACK(id.get_data(), mousedown, mouse_button_callback) + SET_EM_CALLBACK(EMSCRIPTEN_EVENT_TARGET_WINDOW, mouseup, mouse_button_callback) + SET_EM_CALLBACK(id.get_data(), wheel, wheel_callback) + SET_EM_CALLBACK(id.get_data(), touchstart, touch_press_callback) + SET_EM_CALLBACK(id.get_data(), touchmove, touchmove_callback) + SET_EM_CALLBACK(id.get_data(), touchend, touch_press_callback) + SET_EM_CALLBACK(id.get_data(), touchcancel, touch_press_callback) + SET_EM_CALLBACK(id.get_data(), keydown, keydown_callback) + SET_EM_CALLBACK(id.get_data(), keypress, keypress_callback) + SET_EM_CALLBACK(id.get_data(), keyup, keyup_callback) + SET_EM_CALLBACK(EMSCRIPTEN_EVENT_TARGET_DOCUMENT, fullscreenchange, fullscreen_change_callback) + SET_EM_CALLBACK_NOTARGET(gamepadconnected, gamepad_change_callback) + SET_EM_CALLBACK_NOTARGET(gamepaddisconnected, gamepad_change_callback) +#undef SET_EM_CALLBACK_NOTARGET +#undef SET_EM_CALLBACK +#undef EM_CHECK + + /* clang-format off */ + EM_ASM_ARGS({ + Module.listeners = {}; + const canvas = Module['canvas']; + const send_window_event = cwrap('send_window_event', null, ['number']); + const notifications = arguments; + (['mouseover', 'mouseleave', 'focus', 'blur']).forEach(function(event, index) { + Module.listeners[event] = send_window_event.bind(null, notifications[index]); + canvas.addEventListener(event, Module.listeners[event]); + }); + // Clipboard + const update_clipboard = cwrap('update_clipboard', null, ['string']); + Module.listeners['paste'] = function(evt) { + update_clipboard(evt.clipboardData.getData('text')); + }; + window.addEventListener('paste', Module.listeners['paste'], false); + Module.listeners['dragover'] = function(ev) { + // Prevent default behavior (which would try to open the file(s)) + ev.preventDefault(); + }; + Module.listeners['drop'] = Module.drop_handler; // Defined in native/utils.js + canvas.addEventListener('dragover', Module.listeners['dragover'], false); + canvas.addEventListener('drop', Module.listeners['drop'], false); + }, + WINDOW_EVENT_MOUSE_ENTER, + WINDOW_EVENT_MOUSE_EXIT, + WINDOW_EVENT_FOCUS_IN, + WINDOW_EVENT_FOCUS_OUT + ); + /* clang-format on */ + + Input::get_singleton()->set_event_dispatch_function(_dispatch_input_event); +} + +DisplayServerJavaScript::~DisplayServerJavaScript() { + EM_ASM({ + Object.entries(Module.listeners).forEach(function(kv) { + if (kv[0] == 'paste') { + window.removeEventListener(kv[0], kv[1], true); + } else { + Module['canvas'].removeEventListener(kv[0], kv[1]); + } + }); + Module.listeners = {}; + }); + //emscripten_webgl_commit_frame(); + //emscripten_webgl_destroy_context(webgl_ctx); +} + +bool DisplayServerJavaScript::has_feature(Feature p_feature) const { + switch (p_feature) { + //case FEATURE_CONSOLE_WINDOW: + //case FEATURE_GLOBAL_MENU: + //case FEATURE_HIDPI: + //case FEATURE_IME: + case FEATURE_ICON: + case FEATURE_CLIPBOARD: + case FEATURE_CURSOR_SHAPE: + case FEATURE_CUSTOM_CURSOR_SHAPE: + case FEATURE_MOUSE: + case FEATURE_TOUCHSCREEN: + return true; + //case FEATURE_MOUSE_WARP: + //case FEATURE_NATIVE_DIALOG: + //case FEATURE_NATIVE_ICON: + //case FEATURE_NATIVE_VIDEO: + //case FEATURE_WINDOW_TRANSPARENCY: + //case FEATURE_KEEP_SCREEN_ON: + //case FEATURE_ORIENTATION: + //case FEATURE_VIRTUAL_KEYBOARD: + default: + return false; + } +} + +void DisplayServerJavaScript::register_javascript_driver() { + register_create_function("javascript", create_func, get_rendering_drivers_func); +} + +String DisplayServerJavaScript::get_name() const { + return "javascript"; +} + +int DisplayServerJavaScript::get_screen_count() const { + return 1; +} + +Point2i DisplayServerJavaScript::screen_get_position(int p_screen) const { + return Point2i(); // TODO offsetX/Y? +} + +Size2i DisplayServerJavaScript::screen_get_size(int p_screen) const { + + EmscriptenFullscreenChangeEvent ev; + EMSCRIPTEN_RESULT result = emscripten_get_fullscreen_status(&ev); + ERR_FAIL_COND_V(result != EMSCRIPTEN_RESULT_SUCCESS, Size2i()); + return Size2i(ev.screenWidth, ev.screenHeight); +} + +Rect2i DisplayServerJavaScript::screen_get_usable_rect(int p_screen) const { + int canvas[2]; + emscripten_get_canvas_element_size(canvas_id.utf8().get_data(), canvas, canvas + 1); + return Rect2i(0, 0, canvas[0], canvas[1]); +} + +int DisplayServerJavaScript::screen_get_dpi(int p_screen) const { + return 96; // TODO maybe check pixel ratio via window.devicePixelRatio * 96? Inexact. +} + +Vector<DisplayServer::WindowID> DisplayServerJavaScript::get_window_list() const { + Vector<WindowID> ret; + ret.push_back(MAIN_WINDOW_ID); + return ret; +} + +DisplayServerJavaScript::WindowID DisplayServerJavaScript::get_window_at_screen_position(const Point2i &p_position) const { + return MAIN_WINDOW_ID; +} + +void DisplayServerJavaScript::window_attach_instance_id(ObjectID p_instance, WindowID p_window) { + window_attached_instance_id = p_instance; +} + +ObjectID DisplayServerJavaScript::window_get_attached_instance_id(WindowID p_window) const { + return window_attached_instance_id; +} + +void DisplayServerJavaScript::window_set_rect_changed_callback(const Callable &p_callable, WindowID p_window) { + // Not supported. +} + +void DisplayServerJavaScript::window_set_window_event_callback(const Callable &p_callable, WindowID p_window) { + window_event_callback = p_callable; +} + +void DisplayServerJavaScript::window_set_input_event_callback(const Callable &p_callable, WindowID p_window) { + input_event_callback = p_callable; +} + +void DisplayServerJavaScript::window_set_input_text_callback(const Callable &p_callable, WindowID p_window) { + input_text_callback = p_callable; // TODO unused... do I need this? +} + +void DisplayServerJavaScript::window_set_drop_files_callback(const Callable &p_callable, WindowID p_window) { + drop_files_callback = p_callable; +} + +void DisplayServerJavaScript::window_set_title(const String &p_title, WindowID p_window) { + /* clang-format off */ + EM_ASM_({ + document.title = UTF8ToString($0); + }, p_title.utf8().get_data()); + /* clang-format on */ +} + +int DisplayServerJavaScript::window_get_current_screen(WindowID p_window) const { + return 1; +} + +void DisplayServerJavaScript::window_set_current_screen(int p_screen, WindowID p_window) { + // Not implemented. +} + +Point2i DisplayServerJavaScript::window_get_position(WindowID p_window) const { + return Point2i(); // TODO Does this need implementation? +} +void DisplayServerJavaScript::window_set_position(const Point2i &p_position, WindowID p_window) { + // Not supported. +} + +void DisplayServerJavaScript::window_set_transient(WindowID p_window, WindowID p_parent) { + // Not supported. +} + +void DisplayServerJavaScript::window_set_max_size(const Size2i p_size, WindowID p_window) { + // Not supported. +} + +Size2i DisplayServerJavaScript::window_get_max_size(WindowID p_window) const { + return Size2i(); +} + +void DisplayServerJavaScript::window_set_min_size(const Size2i p_size, WindowID p_window) { + // Not supported. +} + +Size2i DisplayServerJavaScript::window_get_min_size(WindowID p_window) const { + return Size2i(); +} + +void DisplayServerJavaScript::window_set_size(const Size2i p_size, WindowID p_window) { + emscripten_set_canvas_element_size(canvas_id.utf8().get_data(), p_size.x, p_size.y); +} + +Size2i DisplayServerJavaScript::window_get_size(WindowID p_window) const { + int canvas[2]; + emscripten_get_canvas_element_size(canvas_id.utf8().get_data(), canvas, canvas + 1); + return Size2(canvas[0], canvas[1]); +} + +Size2i DisplayServerJavaScript::window_get_real_size(WindowID p_window) const { + return window_get_size(p_window); +} + +void DisplayServerJavaScript::window_set_mode(WindowMode p_mode, WindowID p_window) { + if (window_mode == p_mode) + return; + + switch (p_mode) { + case WINDOW_MODE_WINDOWED: { + if (window_mode == WINDOW_MODE_FULLSCREEN) { + emscripten_exit_fullscreen(); + } + window_mode = WINDOW_MODE_WINDOWED; + window_set_size(windowed_size); + } break; + case WINDOW_MODE_FULLSCREEN: { + EmscriptenFullscreenStrategy strategy; + strategy.scaleMode = EMSCRIPTEN_FULLSCREEN_SCALE_STRETCH; + strategy.canvasResolutionScaleMode = EMSCRIPTEN_FULLSCREEN_CANVAS_SCALE_STDDEF; + strategy.filteringMode = EMSCRIPTEN_FULLSCREEN_FILTERING_DEFAULT; + strategy.canvasResizedCallback = nullptr; + EMSCRIPTEN_RESULT result = emscripten_request_fullscreen_strategy(canvas_id.utf8().get_data(), false, &strategy); + ERR_FAIL_COND_MSG(result == EMSCRIPTEN_RESULT_FAILED_NOT_DEFERRED, "Enabling fullscreen is only possible from an input callback for the HTML5 platform."); + ERR_FAIL_COND_MSG(result != EMSCRIPTEN_RESULT_SUCCESS, "Enabling fullscreen is only possible from an input callback for the HTML5 platform."); + } break; + case WINDOW_MODE_MAXIMIZED: + case WINDOW_MODE_MINIMIZED: + WARN_PRINT("WindowMode MAXIMIZED and MINIMIZED are not supported in HTML5 platform."); + break; + default: + break; + } +} + +DisplayServerJavaScript::WindowMode DisplayServerJavaScript::window_get_mode(WindowID p_window) const { + return window_mode; +} + +bool DisplayServerJavaScript::window_is_maximize_allowed(WindowID p_window) const { + return false; +} + +void DisplayServerJavaScript::window_set_flag(WindowFlags p_flag, bool p_enabled, WindowID p_window) { + // Not supported. +} + +bool DisplayServerJavaScript::window_get_flag(WindowFlags p_flag, WindowID p_window) const { + return false; +} + +void DisplayServerJavaScript::window_request_attention(WindowID p_window) { + // Not supported. +} + +void DisplayServerJavaScript::window_move_to_foreground(WindowID p_window) { + // Not supported. +} + +bool DisplayServerJavaScript::window_can_draw(WindowID p_window) const { + return true; +} + +bool DisplayServerJavaScript::can_any_window_draw() const { + return true; +} + +void DisplayServerJavaScript::process_events() { + if (emscripten_sample_gamepad_data() == EMSCRIPTEN_RESULT_SUCCESS) + process_joypads(); +} + +int DisplayServerJavaScript::get_current_video_driver() const { + return 1; +} + +void DisplayServerJavaScript::swap_buffers() { + //emscripten_webgl_commit_frame(); +} diff --git a/platform/javascript/display_server_javascript.h b/platform/javascript/display_server_javascript.h new file mode 100644 index 0000000000..73a7b017e6 --- /dev/null +++ b/platform/javascript/display_server_javascript.h @@ -0,0 +1,188 @@ +/*************************************************************************/ +/* display_server_javascript.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +#ifndef DISPLAY_SERVER_JAVASCRIPT_H +#define DISPLAY_SERVER_JAVASCRIPT_H + +#include "servers/display_server.h" + +#include <emscripten.h> +#include <emscripten/html5.h> + +class DisplayServerJavaScript : public DisplayServer { + + //int video_driver_index; + + Vector2 windowed_size; + + ObjectID window_attached_instance_id = {}; + + Ref<InputEventKey> deferred_key_event; + CursorShape cursor_shape = CURSOR_ARROW; + String cursors[CURSOR_MAX]; + Map<CursorShape, Vector<Variant>> cursors_cache; + Point2 touches[32]; + + Point2i last_click_pos = Point2(-100, -100); // TODO check this again. + double last_click_ms = 0; + int last_click_button_index = -1; + + static EM_BOOL fullscreen_change_callback(int p_event_type, const EmscriptenFullscreenChangeEvent *p_event, void *p_user_data); + + static EM_BOOL keydown_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data); + static EM_BOOL keypress_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data); + static EM_BOOL keyup_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data); + + static EM_BOOL mousemove_callback(int p_event_type, const EmscriptenMouseEvent *p_event, void *p_user_data); + static EM_BOOL mouse_button_callback(int p_event_type, const EmscriptenMouseEvent *p_event, void *p_user_data); + + static EM_BOOL wheel_callback(int p_event_type, const EmscriptenWheelEvent *p_event, void *p_user_data); + + static EM_BOOL touch_press_callback(int p_event_type, const EmscriptenTouchEvent *p_event, void *p_user_data); + static EM_BOOL touchmove_callback(int p_event_type, const EmscriptenTouchEvent *p_event, void *p_user_data); + + static EM_BOOL gamepad_change_callback(int p_event_type, const EmscriptenGamepadEvent *p_event, void *p_user_data); + void process_joypads(); + + static Vector<String> get_rendering_drivers_func(); + static DisplayServer *create_func(const String &p_rendering_driver, DisplayServer::WindowMode p_mode, uint32_t p_flags, const Vector2i &p_resolution, Error &r_error); + + static void _dispatch_input_event(const Ref<InputEvent> &p_event); + +protected: + virtual int get_current_video_driver() const; + +public: + // Override return type to make writing static callbacks less tedious. + static DisplayServerJavaScript *get_singleton(); + + WindowMode window_mode = WINDOW_MODE_WINDOWED; + + String clipboard; + String canvas_id; + + Callable window_event_callback; + Callable input_event_callback; + Callable input_text_callback; + Callable drop_files_callback; + + // from DisplayServer + virtual void alert(const String &p_alert, const String &p_title = "ALERT!"); + virtual bool has_feature(Feature p_feature) const; + virtual String get_name() const; + + // cursor + virtual void cursor_set_shape(CursorShape p_shape); + virtual CursorShape cursor_get_shape() const; + virtual void cursor_set_custom_image(const RES &p_cursor, CursorShape p_shape = CURSOR_ARROW, const Vector2 &p_hotspot = Vector2()); + + // mouse + virtual void mouse_set_mode(MouseMode p_mode); + virtual MouseMode mouse_get_mode() const; + + // touch + virtual bool screen_is_touchscreen(int p_screen = SCREEN_OF_MAIN_WINDOW) const; + + // clipboard + virtual void clipboard_set(const String &p_text); + virtual String clipboard_get() const; + + // screen + virtual int get_screen_count() const; + virtual Point2i screen_get_position(int p_screen = SCREEN_OF_MAIN_WINDOW) const; + virtual Size2i screen_get_size(int p_screen = SCREEN_OF_MAIN_WINDOW) const; + virtual Rect2i screen_get_usable_rect(int p_screen = SCREEN_OF_MAIN_WINDOW) const; + virtual int screen_get_dpi(int p_screen = SCREEN_OF_MAIN_WINDOW) const; + + // windows + virtual Vector<DisplayServer::WindowID> get_window_list() const; + virtual WindowID get_window_at_screen_position(const Point2i &p_position) const; + + virtual void window_attach_instance_id(ObjectID p_instance, WindowID p_window = MAIN_WINDOW_ID); + virtual ObjectID window_get_attached_instance_id(WindowID p_window = MAIN_WINDOW_ID) const; + + virtual void window_set_rect_changed_callback(const Callable &p_callable, WindowID p_window = MAIN_WINDOW_ID); + + virtual void window_set_window_event_callback(const Callable &p_callable, WindowID p_window = MAIN_WINDOW_ID); + virtual void window_set_input_event_callback(const Callable &p_callable, WindowID p_window = MAIN_WINDOW_ID); + virtual void window_set_input_text_callback(const Callable &p_callable, WindowID p_window = MAIN_WINDOW_ID); + + virtual void window_set_drop_files_callback(const Callable &p_callable, WindowID p_window = MAIN_WINDOW_ID); + + virtual void window_set_title(const String &p_title, WindowID p_window = MAIN_WINDOW_ID); + + virtual int window_get_current_screen(WindowID p_window = MAIN_WINDOW_ID) const; + virtual void window_set_current_screen(int p_screen, WindowID p_window = MAIN_WINDOW_ID); + + virtual Point2i window_get_position(WindowID p_window = MAIN_WINDOW_ID) const; + virtual void window_set_position(const Point2i &p_position, WindowID p_window = MAIN_WINDOW_ID); + + virtual void window_set_transient(WindowID p_window, WindowID p_parent); + + virtual void window_set_max_size(const Size2i p_size, WindowID p_window = MAIN_WINDOW_ID); + virtual Size2i window_get_max_size(WindowID p_window = MAIN_WINDOW_ID) const; + + virtual void window_set_min_size(const Size2i p_size, WindowID p_window = MAIN_WINDOW_ID); + virtual Size2i window_get_min_size(WindowID p_window = MAIN_WINDOW_ID) const; + + virtual void window_set_size(const Size2i p_size, WindowID p_window = MAIN_WINDOW_ID); + virtual Size2i window_get_size(WindowID p_window = MAIN_WINDOW_ID) const; + virtual Size2i window_get_real_size(WindowID p_window = MAIN_WINDOW_ID) const; // FIXME: Find clearer name for this. + + virtual void window_set_mode(WindowMode p_mode, WindowID p_window = MAIN_WINDOW_ID); + virtual WindowMode window_get_mode(WindowID p_window = MAIN_WINDOW_ID) const; + + virtual bool window_is_maximize_allowed(WindowID p_window = MAIN_WINDOW_ID) const; + + virtual void window_set_flag(WindowFlags p_flag, bool p_enabled, WindowID p_window = MAIN_WINDOW_ID); + virtual bool window_get_flag(WindowFlags p_flag, WindowID p_window = MAIN_WINDOW_ID) const; + + virtual void window_request_attention(WindowID p_window = MAIN_WINDOW_ID); + virtual void window_move_to_foreground(WindowID p_window = MAIN_WINDOW_ID); + + virtual bool window_can_draw(WindowID p_window = MAIN_WINDOW_ID) const; + + virtual bool can_any_window_draw() const; + + // events + virtual void process_events(); + + // icon + virtual void set_icon(const Ref<Image> &p_icon); + + // others + virtual void swap_buffers(); + + static void register_javascript_driver(); + DisplayServerJavaScript(const String &p_rendering_driver, WindowMode p_mode, uint32_t p_flags, const Vector2i &p_resolution, Error &r_error); + ~DisplayServerJavaScript(); +}; + +#endif // DISPLAY_SERVER_JAVASCRIPT_H diff --git a/platform/javascript/dom_keys.inc b/platform/javascript/dom_keys.inc index fd9df765d2..42d394fd4f 100644 --- a/platform/javascript/dom_keys.inc +++ b/platform/javascript/dom_keys.inc @@ -237,9 +237,12 @@ int dom2godot_keycode(int dom_keycode) { switch (dom_keycode) { //case DOM_VK_CANCEL: return KEY_UNKNOWN; - case DOM_VK_HELP: return KEY_HELP; - case DOM_VK_BACK_SPACE: return KEY_BACKSPACE; - case DOM_VK_TAB: return KEY_TAB; + case DOM_VK_HELP: + return KEY_HELP; + case DOM_VK_BACK_SPACE: + return KEY_BACKSPACE; + case DOM_VK_TAB: + return KEY_TAB; case DOM_VK_CLEAR: case DOM_VK_WIN_OEM_CLEAR: // OEM duplicate @@ -249,14 +252,17 @@ int dom2godot_keycode(int dom_keycode) { case DOM_VK_ENTER: // unused according to MDN return KEY_ENTER; - case DOM_VK_SHIFT: return KEY_SHIFT; - case DOM_VK_CONTROL: return KEY_CONTROL; + case DOM_VK_SHIFT: + return KEY_SHIFT; + case DOM_VK_CONTROL: + return KEY_CONTROL; case DOM_VK_ALT: case DOM_VK_ALTGR: return KEY_ALT; - case DOM_VK_PAUSE: return KEY_PAUSE; + case DOM_VK_PAUSE: + return KEY_PAUSE; case DOM_VK_CAPS_LOCK: return KEY_CAPSLOCK; @@ -279,14 +285,22 @@ int dom2godot_keycode(int dom_keycode) { case DOM_VK_MODECHANGE: return KEY_UNKNOWN; */ - case DOM_VK_SPACE: return KEY_SPACE; - case DOM_VK_PAGE_UP: return KEY_PAGEUP; - case DOM_VK_PAGE_DOWN: return KEY_PAGEDOWN; - case DOM_VK_END: return KEY_END; - case DOM_VK_HOME: return KEY_HOME; - case DOM_VK_LEFT: return KEY_LEFT; - case DOM_VK_UP: return KEY_UP; - case DOM_VK_RIGHT: return KEY_RIGHT; + case DOM_VK_SPACE: + return KEY_SPACE; + case DOM_VK_PAGE_UP: + return KEY_PAGEUP; + case DOM_VK_PAGE_DOWN: + return KEY_PAGEDOWN; + case DOM_VK_END: + return KEY_END; + case DOM_VK_HOME: + return KEY_HOME; + case DOM_VK_LEFT: + return KEY_LEFT; + case DOM_VK_UP: + return KEY_UP; + case DOM_VK_RIGHT: + return KEY_RIGHT; case DOM_VK_DOWN: return KEY_DOWN; @@ -297,24 +311,31 @@ int dom2godot_keycode(int dom_keycode) { return KEY_PRINT; //case DOM_VK_EXECUTE: return KEY_UNKNOWN; - case DOM_VK_INSERT: return KEY_INSERT; - case DOM_VK_DELETE: return KEY_DELETE; + case DOM_VK_INSERT: + return KEY_INSERT; + case DOM_VK_DELETE: + return KEY_DELETE; case DOM_VK_META: case DOM_VK_WIN: return KEY_META; - case DOM_VK_CONTEXT_MENU: return KEY_MENU; + case DOM_VK_CONTEXT_MENU: + return KEY_MENU; case DOM_VK_SLEEP: return KEY_STANDBY; // Numpad keys - case DOM_VK_MULTIPLY: return KEY_KP_MULTIPLY; - case DOM_VK_ADD: return KEY_KP_ADD; + case DOM_VK_MULTIPLY: + return KEY_KP_MULTIPLY; + case DOM_VK_ADD: + return KEY_KP_ADD; case DOM_VK_SEPARATOR: return KEY_KP_PERIOD; // Good enough? - case DOM_VK_SUBTRACT: return KEY_KP_SUBTRACT; - case DOM_VK_DECIMAL: return KEY_KP_PERIOD; + case DOM_VK_SUBTRACT: + return KEY_KP_SUBTRACT; + case DOM_VK_DECIMAL: + return KEY_KP_PERIOD; case DOM_VK_DIVIDE: return KEY_KP_DIVIDE; @@ -329,7 +350,8 @@ int dom2godot_keycode(int dom_keycode) { case DOM_VK_F24: return KEY_UNKNOWN; */ - case DOM_VK_NUM_LOCK: return KEY_NUMLOCK; + case DOM_VK_NUM_LOCK: + return KEY_NUMLOCK; case DOM_VK_SCROLL_LOCK: return KEY_SCROLLLOCK; @@ -341,40 +363,68 @@ int dom2godot_keycode(int dom_keycode) { case DOM_VK_WIN_OEM_FJ_ROYA: return KEY_UNKNOWN; */ - case DOM_VK_CIRCUMFLEX: return KEY_ASCIICIRCUM; - case DOM_VK_EXCLAMATION: return KEY_EXCLAM; - case DOM_VK_DOUBLE_QUOTE: return KEY_QUOTEDBL; - case DOM_VK_HASH: return KEY_NUMBERSIGN; - case DOM_VK_DOLLAR: return KEY_DOLLAR; - case DOM_VK_PERCENT: return KEY_PERCENT; - case DOM_VK_AMPERSAND: return KEY_AMPERSAND; - case DOM_VK_UNDERSCORE: return KEY_UNDERSCORE; - case DOM_VK_OPEN_PAREN: return KEY_PARENLEFT; - case DOM_VK_CLOSE_PAREN: return KEY_PARENRIGHT; - case DOM_VK_ASTERISK: return KEY_ASTERISK; - case DOM_VK_PLUS: return KEY_PLUS; - case DOM_VK_PIPE: return KEY_BAR; - case DOM_VK_HYPHEN_MINUS: return KEY_MINUS; - case DOM_VK_OPEN_CURLY_BRACKET: return KEY_BRACELEFT; - case DOM_VK_CLOSE_CURLY_BRACKET: return KEY_BRACERIGHT; - case DOM_VK_TILDE: return KEY_ASCIITILDE; + case DOM_VK_CIRCUMFLEX: + return KEY_ASCIICIRCUM; + case DOM_VK_EXCLAMATION: + return KEY_EXCLAM; + case DOM_VK_DOUBLE_QUOTE: + return KEY_QUOTEDBL; + case DOM_VK_HASH: + return KEY_NUMBERSIGN; + case DOM_VK_DOLLAR: + return KEY_DOLLAR; + case DOM_VK_PERCENT: + return KEY_PERCENT; + case DOM_VK_AMPERSAND: + return KEY_AMPERSAND; + case DOM_VK_UNDERSCORE: + return KEY_UNDERSCORE; + case DOM_VK_OPEN_PAREN: + return KEY_PARENLEFT; + case DOM_VK_CLOSE_PAREN: + return KEY_PARENRIGHT; + case DOM_VK_ASTERISK: + return KEY_ASTERISK; + case DOM_VK_PLUS: + return KEY_PLUS; + case DOM_VK_PIPE: + return KEY_BAR; + case DOM_VK_HYPHEN_MINUS: + return KEY_MINUS; + case DOM_VK_OPEN_CURLY_BRACKET: + return KEY_BRACELEFT; + case DOM_VK_CLOSE_CURLY_BRACKET: + return KEY_BRACERIGHT; + case DOM_VK_TILDE: + return KEY_ASCIITILDE; - case DOM_VK_VOLUME_MUTE: return KEY_VOLUMEMUTE; - case DOM_VK_VOLUME_DOWN: return KEY_VOLUMEDOWN; - case DOM_VK_VOLUME_UP: return KEY_VOLUMEUP; + case DOM_VK_VOLUME_MUTE: + return KEY_VOLUMEMUTE; + case DOM_VK_VOLUME_DOWN: + return KEY_VOLUMEDOWN; + case DOM_VK_VOLUME_UP: + return KEY_VOLUMEUP; - case DOM_VK_COMMA: return KEY_COMMA; - case DOM_VK_PERIOD: return KEY_PERIOD; - case DOM_VK_SLASH: return KEY_SLASH; - case DOM_VK_BACK_QUOTE: return KEY_QUOTELEFT; - case DOM_VK_OPEN_BRACKET: return KEY_BRACKETLEFT; - case DOM_VK_BACK_SLASH: return KEY_BACKSLASH; - case DOM_VK_CLOSE_BRACKET: return KEY_BRACKETRIGHT; + case DOM_VK_COMMA: + return KEY_COMMA; + case DOM_VK_PERIOD: + return KEY_PERIOD; + case DOM_VK_SLASH: + return KEY_SLASH; + case DOM_VK_BACK_QUOTE: + return KEY_QUOTELEFT; + case DOM_VK_OPEN_BRACKET: + return KEY_BRACKETLEFT; + case DOM_VK_BACK_SLASH: + return KEY_BACKSLASH; + case DOM_VK_CLOSE_BRACKET: + return KEY_BRACKETRIGHT; case DOM_VK_QUOTE: return KEY_APOSTROPHE; // The rest is OEM/unusual. - default: return KEY_UNKNOWN; + default: + return KEY_UNKNOWN; }; } diff --git a/platform/javascript/engine/engine.js b/platform/javascript/engine/engine.js index 6d7509377f..d709422abb 100644 --- a/platform/javascript/engine/engine.js +++ b/platform/javascript/engine/engine.js @@ -1,16 +1,8 @@ Function('return this')()['Engine'] = (function() { - - var unloadAfterInit = true; - var canvas = null; - var resizeCanvasOnStart = false; - var customLocale = 'en_US'; - var wasmExt = '.wasm'; - var preloader = new Preloader(); - var loader = new Loader(); - var rtenv = null; - var executableName = ''; + var wasmExt = '.wasm'; + var unloadAfterInit = true; var loadPath = ''; var loadPromise = null; var initPromise = null; @@ -33,31 +25,43 @@ Function('return this')()['Engine'] = (function() { }; /** @constructor */ - function Engine() {}; + function Engine() { + this.canvas = null; + this.executableName = ''; + this.rtenv = null; + this.customLocale = null; + this.resizeCanvasOnStart = false; + this.onExecute = null; + this.onExit = null; + }; Engine.prototype.init = /** @param {string=} basePath */ function(basePath) { if (initPromise) { return initPromise; } - if (!loadPromise) { + if (loadPromise == null) { if (!basePath) { initPromise = Promise.reject(new Error("A base path must be provided when calling `init` and the engine is not loaded.")); return initPromise; } load(basePath); } - var config = {} + var config = {}; if (typeof stdout === 'function') config.print = stdout; if (typeof stderr === 'function') config.printErr = stderr; - initPromise = loader.init(loadPromise, loadPath, config).then(function() { - return new Promise(function(resolve, reject) { - rtenv = loader.env; + var me = this; + initPromise = new Promise(function(resolve, reject) { + config['locateFile'] = Utils.createLocateRewrite(loadPath); + config['instantiateWasm'] = Utils.createInstantiatePromise(loadPromise); + Godot(config).then(function(module) { + me.rtenv = module; if (unloadAfterInit) { - loadPromise = null; + unload(); } resolve(); + config = null; }); }); return initPromise; @@ -76,33 +80,72 @@ Function('return this')()['Engine'] = (function() { args.push(arguments[i]); } var me = this; - return new Promise(function(resolve, reject) { - return me.init().then(function() { - if (!(canvas instanceof HTMLCanvasElement)) { - canvas = Utils.findCanvas(); - } - rtenv['locale'] = customLocale; - rtenv['canvas'] = canvas; - rtenv['thisProgram'] = executableName; - rtenv['resizeCanvasOnStart'] = resizeCanvasOnStart; - loader.start(preloader.preloadedFiles, args).then(function() { - loader = null; - initPromise = null; - resolve(); + return me.init().then(function() { + if (!me.rtenv) { + return Promise.reject(new Error('The engine must be initialized before it can be started')); + } + + if (!(me.canvas instanceof HTMLCanvasElement)) { + me.canvas = Utils.findCanvas(); + } + + // Canvas can grab focus on click, or key events won't work. + if (me.canvas.tabIndex < 0) { + me.canvas.tabIndex = 0; + } + + // Disable right-click context menu. + me.canvas.addEventListener('contextmenu', function(ev) { + ev.preventDefault(); + }, false); + + // Until context restoration is implemented warn the user of context loss. + me.canvas.addEventListener('webglcontextlost', function(ev) { + alert("WebGL context lost, please reload the page"); + ev.preventDefault(); + }, false); + + // Browser locale, or custom one if defined. + var locale = me.customLocale; + if (!locale) { + locale = navigator.languages ? navigator.languages[0] : navigator.language; + locale = locale.split('.')[0]; + } + me.rtenv['locale'] = locale; + me.rtenv['canvas'] = me.canvas; + me.rtenv['thisProgram'] = me.executableName; + me.rtenv['resizeCanvasOnStart'] = me.resizeCanvasOnStart; + me.rtenv['noExitRuntime'] = true; + me.rtenv['onExecute'] = me.onExecute; + me.rtenv['onExit'] = function(code) { + if (me.onExit) + me.onExit(code); + me.rtenv = null; + } + return new Promise(function(resolve, reject) { + preloader.preloadedFiles.forEach(function(file) { + me.rtenv['copyToFS'](file.path, file.buffer); }); + preloader.preloadedFiles.length = 0; // Clear memory + me.rtenv['callMain'](args); + initPromise = null; + resolve(); }); }); }; - Engine.prototype.startGame = function(execName, mainPack) { + Engine.prototype.startGame = function(execName, mainPack, extraArgs) { // Start and init with execName as loadPath if not inited. - executableName = execName; + this.executableName = execName; var me = this; return Promise.all([ this.init(execName), this.preloadFile(mainPack, mainPack) ]).then(function() { - return me.start('--main-pack', mainPack); + var args = ['--main-pack', mainPack]; + if (extraArgs) + args = args.concat(extraArgs); + return me.start.apply(me, args); }); }; @@ -118,67 +161,85 @@ Function('return this')()['Engine'] = (function() { }; Engine.prototype.setCanvas = function(canvasElem) { - canvas = canvasElem; + this.canvas = canvasElem; }; Engine.prototype.setCanvasResizedOnStart = function(enabled) { - resizeCanvasOnStart = enabled; + this.resizeCanvasOnStart = enabled; }; Engine.prototype.setLocale = function(locale) { - customLocale = locale; + this.customLocale = locale; }; Engine.prototype.setExecutableName = function(newName) { - executableName = newName; + this.executableName = newName; }; Engine.prototype.setProgressFunc = function(func) { progressFunc = func; - } + }; Engine.prototype.setStdoutFunc = function(func) { - var print = function(text) { if (arguments.length > 1) { text = Array.prototype.slice.call(arguments).join(" "); } func(text); }; - if (rtenv) - rtenv.print = print; + if (this.rtenv) + this.rtenv.print = print; stdout = print; }; Engine.prototype.setStderrFunc = function(func) { - var printErr = function(text) { if (arguments.length > 1) text = Array.prototype.slice.call(arguments).join(" "); func(text); }; - if (rtenv) - rtenv.printErr = printErr; + if (this.rtenv) + this.rtenv.printErr = printErr; stderr = printErr; }; + Engine.prototype.setOnExecute = function(onExecute) { + if (this.rtenv) + this.rtenv.onExecute = onExecute; + this.onExecute = onExecute; + } + + Engine.prototype.setOnExit = function(onExit) { + this.onExit = onExit; + } + + Engine.prototype.copyToFS = function(path, buffer) { + if (this.rtenv == null) { + throw new Error("Engine must be inited before copying files"); + } + this.rtenv['copyToFS'](path, buffer); + } + // Closure compiler exported engine methods. /** @export */ Engine['isWebGLAvailable'] = Utils.isWebGLAvailable; Engine['load'] = load; Engine['unload'] = unload; - Engine.prototype['init'] = Engine.prototype.init - Engine.prototype['preloadFile'] = Engine.prototype.preloadFile - Engine.prototype['start'] = Engine.prototype.start - Engine.prototype['startGame'] = Engine.prototype.startGame - Engine.prototype['setWebAssemblyFilenameExtension'] = Engine.prototype.setWebAssemblyFilenameExtension - Engine.prototype['setUnloadAfterInit'] = Engine.prototype.setUnloadAfterInit - Engine.prototype['setCanvas'] = Engine.prototype.setCanvas - Engine.prototype['setCanvasResizedOnStart'] = Engine.prototype.setCanvasResizedOnStart - Engine.prototype['setLocale'] = Engine.prototype.setLocale - Engine.prototype['setExecutableName'] = Engine.prototype.setExecutableName - Engine.prototype['setProgressFunc'] = Engine.prototype.setProgressFunc - Engine.prototype['setStdoutFunc'] = Engine.prototype.setStdoutFunc - Engine.prototype['setStderrFunc'] = Engine.prototype.setStderrFunc + Engine.prototype['init'] = Engine.prototype.init; + Engine.prototype['preloadFile'] = Engine.prototype.preloadFile; + Engine.prototype['start'] = Engine.prototype.start; + Engine.prototype['startGame'] = Engine.prototype.startGame; + Engine.prototype['setWebAssemblyFilenameExtension'] = Engine.prototype.setWebAssemblyFilenameExtension; + Engine.prototype['setUnloadAfterInit'] = Engine.prototype.setUnloadAfterInit; + Engine.prototype['setCanvas'] = Engine.prototype.setCanvas; + Engine.prototype['setCanvasResizedOnStart'] = Engine.prototype.setCanvasResizedOnStart; + Engine.prototype['setLocale'] = Engine.prototype.setLocale; + Engine.prototype['setExecutableName'] = Engine.prototype.setExecutableName; + Engine.prototype['setProgressFunc'] = Engine.prototype.setProgressFunc; + Engine.prototype['setStdoutFunc'] = Engine.prototype.setStdoutFunc; + Engine.prototype['setStderrFunc'] = Engine.prototype.setStderrFunc; + Engine.prototype['setOnExecute'] = Engine.prototype.setOnExecute; + Engine.prototype['setOnExit'] = Engine.prototype.setOnExit; + Engine.prototype['copyToFS'] = Engine.prototype.copyToFS; return Engine; })(); diff --git a/platform/javascript/engine/loader.js b/platform/javascript/engine/loader.js deleted file mode 100644 index d27fbf612e..0000000000 --- a/platform/javascript/engine/loader.js +++ /dev/null @@ -1,33 +0,0 @@ -var Loader = /** @constructor */ function() { - - this.env = null; - - this.init = function(loadPromise, basePath, config) { - var me = this; - return new Promise(function(resolve, reject) { - var cfg = config || {}; - cfg['locateFile'] = Utils.createLocateRewrite(basePath); - cfg['instantiateWasm'] = Utils.createInstantiatePromise(loadPromise); - loadPromise = null; - Godot(cfg).then(function(module) { - me.env = module; - resolve(); - }); - }); - } - - this.start = function(preloadedFiles, args) { - var me = this; - return new Promise(function(resolve, reject) { - if (!me.env) { - reject(new Error('The engine must be initialized before it can be started')); - } - preloadedFiles.forEach(function(file) { - Utils.copyToFS(me.env['FS'], file.path, file.buffer); - }); - preloadedFiles.length = 0; // Clear memory - me.env['callMain'](args); - resolve(); - }); - } -}; diff --git a/platform/javascript/engine/utils.js b/platform/javascript/engine/utils.js index fdff90a923..0c97b38199 100644 --- a/platform/javascript/engine/utils.js +++ b/platform/javascript/engine/utils.js @@ -27,24 +27,6 @@ var Utils = { return instantiateWasm; }, - copyToFS: function(fs, path, buffer) { - var p = path.lastIndexOf("/"); - var dir = "/"; - if (p > 0) { - dir = path.slice(0, path.lastIndexOf("/")); - } - try { - fs.stat(dir); - } catch (e) { - if (e.errno !== 44) { // 'ENOENT', see https://github.com/emscripten-core/emscripten/blob/master/system/lib/libc/musl/arch/emscripten/bits/errno.h - throw e; - } - fs['mkdirTree'](dir); - } - // With memory growth, canOwn should be false. - fs['writeFile'](path, new Uint8Array(buffer), {'flags': 'wx+'}); - }, - findCanvas: function() { var nodes = document.getElementsByTagName('canvas'); if (nodes.length && nodes[0] instanceof HTMLCanvasElement) { diff --git a/platform/javascript/export/export.cpp b/platform/javascript/export/export.cpp index 36076a2af9..fbc298a399 100644 --- a/platform/javascript/export/export.cpp +++ b/platform/javascript/export/export.cpp @@ -28,6 +28,7 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ +#include "core/io/json.h" #include "core/io/tcp_server.h" #include "core/io/zip_io.h" #include "editor/editor_export.h" @@ -198,7 +199,7 @@ class EditorExportPlatformJavaScript : public EditorExportPlatform { Ref<ImageTexture> stop_icon; int menu_options; - void _fix_html(Vector<uint8_t> &p_html, const Ref<EditorExportPreset> &p_preset, const String &p_name, bool p_debug); + void _fix_html(Vector<uint8_t> &p_html, const Ref<EditorExportPreset> &p_preset, const String &p_name, bool p_debug, int p_flags); private: Ref<EditorHTTPServer> server; @@ -238,15 +239,21 @@ public: virtual void resolve_platform_feature_priorities(const Ref<EditorExportPreset> &p_preset, Set<String> &p_features) { } + String get_debug_protocol() const { return "ws://"; } + EditorExportPlatformJavaScript(); ~EditorExportPlatformJavaScript(); }; -void EditorExportPlatformJavaScript::_fix_html(Vector<uint8_t> &p_html, const Ref<EditorExportPreset> &p_preset, const String &p_name, bool p_debug) { +void EditorExportPlatformJavaScript::_fix_html(Vector<uint8_t> &p_html, const Ref<EditorExportPreset> &p_preset, const String &p_name, bool p_debug, int p_flags) { String str_template = String::utf8(reinterpret_cast<const char *>(p_html.ptr()), p_html.size()); String str_export; Vector<String> lines = str_template.split("\n"); + Vector<String> flags; + String flags_json; + gen_export_flags(flags, p_flags); + flags_json = JSON::print(flags); for (int i = 0; i < lines.size(); i++) { @@ -255,6 +262,7 @@ void EditorExportPlatformJavaScript::_fix_html(Vector<uint8_t> &p_html, const Re current_line = current_line.replace("$GODOT_PROJECT_NAME", ProjectSettings::get_singleton()->get_setting("application/config/name")); current_line = current_line.replace("$GODOT_HEAD_INCLUDE", p_preset->get("html/head_include")); current_line = current_line.replace("$GODOT_DEBUG_ENABLED", p_debug ? "true" : "false"); + current_line = current_line.replace("$GODOT_ARGS", flags_json); str_export += current_line + "\n"; } @@ -430,7 +438,7 @@ Error EditorExportPlatformJavaScript::export_project(const Ref<EditorExportPrese if (!custom_html.empty()) { continue; } - _fix_html(data, p_preset, p_path.get_file().get_basename(), p_debug); + _fix_html(data, p_preset, p_path.get_file().get_basename(), p_debug, p_flags); file = p_path.get_file(); } else if (file == "godot.js") { @@ -469,7 +477,7 @@ Error EditorExportPlatformJavaScript::export_project(const Ref<EditorExportPrese buf.resize(f->get_len()); f->get_buffer(buf.ptrw(), buf.size()); memdelete(f); - _fix_html(buf, p_preset, p_path.get_file().get_basename(), p_debug); + _fix_html(buf, p_preset, p_path.get_file().get_basename(), p_debug, p_flags); f = FileAccess::open(p_path, FileAccess::WRITE); if (!f) { diff --git a/platform/javascript/http_client.h.inc b/platform/javascript/http_client.h.inc index ac275aadbc..4d5ff88bdd 100644 --- a/platform/javascript/http_client.h.inc +++ b/platform/javascript/http_client.h.inc @@ -33,21 +33,21 @@ Error prepare_request(Method p_method, const String &p_url, const Vector<String> &p_headers); int xhr_id; -int read_limit; -int response_read_offset; -Status status; +int read_limit = 4096; +int response_read_offset = 0; +Status status = STATUS_DISCONNECTED; String host; -int port; -bool use_tls; +int port = -1; +bool use_tls = false; String username; String password; -int polled_response_code; +int polled_response_code = 0; String polled_response_header; PackedByteArray polled_response; #ifdef DEBUG_ENABLED -bool has_polled; -uint64_t last_polling_frame; +bool has_polled = false; +uint64_t last_polling_frame = 0; #endif diff --git a/platform/javascript/http_client_javascript.cpp b/platform/javascript/http_client_javascript.cpp index 863c207896..19ce7ef219 100644 --- a/platform/javascript/http_client_javascript.cpp +++ b/platform/javascript/http_client_javascript.cpp @@ -29,6 +29,7 @@ /*************************************************************************/ #include "core/io/http_client.h" + #include "http_request.h" Error HTTPClient::connect_to_host(const String &p_host, int p_port, bool p_ssl, bool p_verify_host) { @@ -281,15 +282,6 @@ Error HTTPClient::poll() { HTTPClient::HTTPClient() { xhr_id = godot_xhr_new(); - read_limit = 4096; - status = STATUS_DISCONNECTED; - port = -1; - use_tls = false; - polled_response_code = 0; -#ifdef DEBUG_ENABLED - has_polled = false; - last_polling_frame = 0; -#endif } HTTPClient::~HTTPClient() { diff --git a/platform/javascript/javascript_main.cpp b/platform/javascript/javascript_main.cpp index 815bc7e456..854383aeee 100644 --- a/platform/javascript/javascript_main.cpp +++ b/platform/javascript/javascript_main.cpp @@ -34,22 +34,63 @@ #include <emscripten/emscripten.h> +static OS_JavaScript *os = nullptr; + +void exit_callback() { + emscripten_cancel_main_loop(); // After this, we can exit! + Main::cleanup(); + int exit_code = OS_JavaScript::get_singleton()->get_exit_code(); + memdelete(os); + os = nullptr; + emscripten_force_exit(exit_code); // No matter that we call cancel_main_loop, regular "exit" will not work, forcing. +} + +void main_loop_callback() { + + if (os->main_loop_iterate()) { + emscripten_cancel_main_loop(); // Cancel current loop and wait for finalize_async. + EM_ASM({ + // This will contain the list of operations that need to complete before cleanup. + Module.async_finish = []; + }); + os->get_main_loop()->finish(); + os->finalize_async(); // Will add all the async finish functions. + EM_ASM({ + Promise.all(Module.async_finish).then(function() { + Module.async_finish = []; + ccall("cleanup_after_sync", null, []); + }); + }); + } +} + +extern "C" EMSCRIPTEN_KEEPALIVE void cleanup_after_sync() { + emscripten_set_main_loop(exit_callback, -1, false); +} + extern "C" EMSCRIPTEN_KEEPALIVE void main_after_fs_sync(char *p_idbfs_err) { String idbfs_err = String::utf8(p_idbfs_err); if (!idbfs_err.empty()) { print_line("IndexedDB not available: " + idbfs_err); } - OS_JavaScript *os = OS_JavaScript::get_singleton(); os->set_idb_available(idbfs_err.empty()); + // TODO: Check error return value. + Main::setup2(); // Manual second phase. // Ease up compatibility. ResourceLoader::set_abort_on_missing_resources(false); Main::start(); - os->run_async(); + os->get_main_loop()->init(); + emscripten_resume_main_loop(); } int main(int argc, char *argv[]) { + os = new OS_JavaScript(); + Main::setup(argv[0], argc - 1, &argv[1], false); + emscripten_set_main_loop(main_loop_callback, -1, false); + emscripten_pause_main_loop(); // Will need to wait for FS sync. + // Sync from persistent state into memory and then // run the 'main_after_fs_sync' function. /* clang-format off */ @@ -59,13 +100,10 @@ int main(int argc, char *argv[]) { FS.syncfs(true, function(err) { ccall('main_after_fs_sync', null, ['string'], [err ? err.message : ""]) }); + ); /* clang-format on */ - new OS_JavaScript(argc, argv); - // TODO: Check error return value. - Main::setup(argv[0], argc - 1, &argv[1]); - return 0; // Continued async in main_after_fs_sync() from the syncfs() callback. } diff --git a/platform/javascript/http_request.js b/platform/javascript/native/http_request.js index f621689f9d..f621689f9d 100644 --- a/platform/javascript/http_request.js +++ b/platform/javascript/native/http_request.js diff --git a/platform/javascript/id_handler.js b/platform/javascript/native/id_handler.js index 67d29075b8..67d29075b8 100644 --- a/platform/javascript/id_handler.js +++ b/platform/javascript/native/id_handler.js diff --git a/platform/javascript/native/utils.js b/platform/javascript/native/utils.js new file mode 100644 index 0000000000..95585d26ae --- /dev/null +++ b/platform/javascript/native/utils.js @@ -0,0 +1,204 @@ +/*************************************************************************/ +/* utils.js */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + +Module['copyToFS'] = function(path, buffer) { + var p = path.lastIndexOf("/"); + var dir = "/"; + if (p > 0) { + dir = path.slice(0, path.lastIndexOf("/")); + } + try { + FS.stat(dir); + } catch (e) { + if (e.errno !== ERRNO_CODES.ENOENT) { // 'ENOENT', see https://github.com/emscripten-core/emscripten/blob/master/system/lib/libc/musl/arch/emscripten/bits/errno.h + throw e; + } + FS.mkdirTree(dir); + } + // With memory growth, canOwn should be false. + FS.writeFile(path, new Uint8Array(buffer), {'flags': 'wx+'}); +} + +Module.drop_handler = (function() { + var upload = []; + var uploadPromises = []; + var uploadCallback = null; + + function readFilePromise(entry, path) { + return new Promise(function(resolve, reject) { + entry.file(function(file) { + var reader = new FileReader(); + reader.onload = function() { + var f = { + "path": file.relativePath || file.webkitRelativePath, + "name": file.name, + "type": file.type, + "size": file.size, + "data": reader.result + }; + if (!f['path']) + f['path'] = f['name']; + upload.push(f); + resolve() + }; + reader.onerror = function() { + console.log("Error reading file"); + reject(); + } + + reader.readAsArrayBuffer(file); + + }, function(err) { + console.log("Error!"); + reject(); + }); + }); + } + + function readDirectoryPromise(entry) { + return new Promise(function(resolve, reject) { + var reader = entry.createReader(); + reader.readEntries(function(entries) { + for (var i = 0; i < entries.length; i++) { + var ent = entries[i]; + if (ent.isDirectory) { + uploadPromises.push(readDirectoryPromise(ent)); + } else if (ent.isFile) { + uploadPromises.push(readFilePromise(ent)); + } + } + resolve(); + }); + }); + } + + function processUploadsPromises(resolve, reject) { + if (uploadPromises.length == 0) { + resolve(); + return; + } + uploadPromises.pop().then(function() { + setTimeout(function() { + processUploadsPromises(resolve, reject); + //processUploadsPromises.bind(null, resolve, reject) + }, 0); + }); + } + + function dropFiles(files) { + var args = files || []; + var argc = args.length; + var argv = stackAlloc((argc + 1) * 4); + for (var i = 0; i < argc; i++) { + HEAP32[(argv >> 2) + i] = allocateUTF8OnStack(args[i]); + } + HEAP32[(argv >> 2) + argc] = 0; + // Defined in display_server_javascript.cpp + ccall('_drop_files_callback', 'void', ['number', 'number'], [argv, argc]); + } + + return function(ev) { + ev.preventDefault(); + if (ev.dataTransfer.items) { + // Use DataTransferItemList interface to access the file(s) + for (var i = 0; i < ev.dataTransfer.items.length; i++) { + const item = ev.dataTransfer.items[i]; + var entry = null; + if ("getAsEntry" in item) { + entry = item.getAsEntry(); + } else if ("webkitGetAsEntry" in item) { + entry = item.webkitGetAsEntry(); + } + if (!entry) { + console.error("File upload not supported"); + } else if (entry.isDirectory) { + uploadPromises.push(readDirectoryPromise(entry)); + } else if (entry.isFile) { + uploadPromises.push(readFilePromise(entry)); + } else { + console.error("Unrecognized entry...", entry); + } + } + } else { + console.error("File upload not supported"); + } + uploadCallback = new Promise(processUploadsPromises).then(function() { + const DROP = "/tmp/drop-" + parseInt(Math.random() * Math.pow(2, 31)) + "/"; + var drops = []; + var files = []; + upload.forEach((elem) => { + var path = elem['path']; + Module['copyToFS'](DROP + path, elem['data']); + var idx = path.indexOf("/"); + if (idx == -1) { + // Root file + drops.push(DROP + path); + } else { + // Subdir + var sub = path.substr(0, idx); + idx = sub.indexOf("/"); + if (idx < 0 && drops.indexOf(DROP + sub) == -1) { + drops.push(DROP + sub); + } + } + files.push(DROP + path); + }); + uploadPromises = []; + upload = []; + dropFiles(drops); + var dirs = [DROP.substr(0, DROP.length -1)]; + files.forEach(function (file) { + FS.unlink(file); + var dir = file.replace(DROP, ""); + var idx = dir.lastIndexOf("/"); + while (idx > 0) { + dir = dir.substr(0, idx); + if (dirs.indexOf(DROP + dir) == -1) { + dirs.push(DROP + dir); + } + idx = dir.lastIndexOf("/"); + } + }); + // Remove dirs. + dirs = dirs.sort(function(a, b) { + var al = (a.match(/\//g) || []).length; + var bl = (b.match(/\//g) || []).length; + if (al > bl) + return -1; + else if (al < bl) + return 1; + return 0; + }); + dirs.forEach(function(dir) { + FS.rmdir(dir); + }); + }); + } +})(); diff --git a/platform/javascript/os_javascript.cpp b/platform/javascript/os_javascript.cpp index c0230b94fa..1ec23973d6 100644 --- a/platform/javascript/os_javascript.cpp +++ b/platform/javascript/os_javascript.cpp @@ -30,666 +30,22 @@ #include "os_javascript.h" +#include "core/debugger/engine_debugger.h" #include "core/io/file_access_buffered_fa.h" -//#include "drivers/gles2/rasterizer_gles2.h" -#include "drivers/dummy/rasterizer_dummy.h" +#include "core/io/json.h" #include "drivers/unix/dir_access_unix.h" #include "drivers/unix/file_access_unix.h" #include "main/main.h" -#include "servers/rendering/rendering_server_raster.h" +#include "modules/modules_enabled.gen.h" +#include "platform/javascript/display_server_javascript.h" + +#ifdef MODULE_WEBSOCKET_ENABLED +#include "modules/websocket/remote_debugger_peer_websocket.h" +#endif #include <emscripten.h> -#include <png.h> #include <stdlib.h> -#include "dom_keys.inc" - -#define DOM_BUTTON_LEFT 0 -#define DOM_BUTTON_MIDDLE 1 -#define DOM_BUTTON_RIGHT 2 -#define DOM_BUTTON_XBUTTON1 3 -#define DOM_BUTTON_XBUTTON2 4 -#define GODOT_CANVAS_SELECTOR "#canvas" - -// Window (canvas) - -static void focus_canvas() { - - /* clang-format off */ - EM_ASM( - Module.canvas.focus(); - ); - /* clang-format on */ -} - -static bool is_canvas_focused() { - - /* clang-format off */ - return EM_ASM_INT_V( - return document.activeElement == Module.canvas; - ); - /* clang-format on */ -} - -static Point2 compute_position_in_canvas(int x, int y) { - int canvas_x = EM_ASM_INT({ - return document.getElementById('canvas').getBoundingClientRect().x; - }); - int canvas_y = EM_ASM_INT({ - return document.getElementById('canvas').getBoundingClientRect().y; - }); - int canvas_width; - int canvas_height; - emscripten_get_canvas_element_size(GODOT_CANVAS_SELECTOR, &canvas_width, &canvas_height); - - double element_width; - double element_height; - emscripten_get_element_css_size(GODOT_CANVAS_SELECTOR, &element_width, &element_height); - - return Point2((int)(canvas_width / element_width * (x - canvas_x)), - (int)(canvas_height / element_height * (y - canvas_y))); -} - -static bool cursor_inside_canvas = true; - -EM_BOOL OS_JavaScript::fullscreen_change_callback(int p_event_type, const EmscriptenFullscreenChangeEvent *p_event, void *p_user_data) { - - OS_JavaScript *os = get_singleton(); - // Empty ID is canvas. - String target_id = String::utf8(p_event->id); - if (target_id.empty() || target_id == "canvas") { - // This event property is the only reliable data on - // browser fullscreen state. - os->video_mode.fullscreen = p_event->isFullscreen; - if (os->video_mode.fullscreen) { - os->entering_fullscreen = false; - } else { - // Restoring maximized window now will cause issues, - // so delay until main_loop_iterate. - os->just_exited_fullscreen = true; - } - } - return false; -} - -void OS_JavaScript::set_video_mode(const VideoMode &p_video_mode, int p_screen) { - - video_mode = p_video_mode; -} - -OS::VideoMode OS_JavaScript::get_video_mode(int p_screen) const { - - return video_mode; -} - -Size2 OS_JavaScript::get_screen_size(int p_screen) const { - - EmscriptenFullscreenChangeEvent ev; - EMSCRIPTEN_RESULT result = emscripten_get_fullscreen_status(&ev); - ERR_FAIL_COND_V(result != EMSCRIPTEN_RESULT_SUCCESS, Size2()); - return Size2(ev.screenWidth, ev.screenHeight); -} - -void OS_JavaScript::set_window_size(const Size2 p_size) { - - windowed_size = p_size; - if (video_mode.fullscreen) { - window_maximized = false; - set_window_fullscreen(false); - } else { - if (window_maximized) { - emscripten_exit_soft_fullscreen(); - window_maximized = false; - } - emscripten_set_canvas_element_size(GODOT_CANVAS_SELECTOR, p_size.x, p_size.y); - } -} - -Size2 OS_JavaScript::get_window_size() const { - - int canvas[2]; - emscripten_get_canvas_element_size(GODOT_CANVAS_SELECTOR, canvas, canvas + 1); - return Size2(canvas[0], canvas[1]); -} - -void OS_JavaScript::set_window_maximized(bool p_enabled) { - - if (video_mode.fullscreen) { - window_maximized = p_enabled; - set_window_fullscreen(false); - } else if (!p_enabled) { - emscripten_exit_soft_fullscreen(); - window_maximized = false; - } else if (!window_maximized) { - // Prevent calling emscripten_enter_soft_fullscreen mutltiple times, - // this would hide page elements permanently. - EmscriptenFullscreenStrategy strategy; - strategy.scaleMode = EMSCRIPTEN_FULLSCREEN_SCALE_STRETCH; - strategy.canvasResolutionScaleMode = EMSCRIPTEN_FULLSCREEN_CANVAS_SCALE_STDDEF; - strategy.filteringMode = EMSCRIPTEN_FULLSCREEN_FILTERING_DEFAULT; - strategy.canvasResizedCallback = nullptr; - emscripten_enter_soft_fullscreen(GODOT_CANVAS_SELECTOR, &strategy); - window_maximized = p_enabled; - } -} - -bool OS_JavaScript::is_window_maximized() const { - - return window_maximized; -} - -void OS_JavaScript::set_window_fullscreen(bool p_enabled) { - - if (p_enabled == video_mode.fullscreen) { - return; - } - - // Just request changes here, if successful, logic continues in - // fullscreen_change_callback. - if (p_enabled) { - if (window_maximized) { - // Soft fullsreen during real fullscreen can cause issues, so exit. - // This must be called before requesting full screen. - emscripten_exit_soft_fullscreen(); - } - EmscriptenFullscreenStrategy strategy; - strategy.scaleMode = EMSCRIPTEN_FULLSCREEN_SCALE_STRETCH; - strategy.canvasResolutionScaleMode = EMSCRIPTEN_FULLSCREEN_CANVAS_SCALE_STDDEF; - strategy.filteringMode = EMSCRIPTEN_FULLSCREEN_FILTERING_DEFAULT; - strategy.canvasResizedCallback = nullptr; - EMSCRIPTEN_RESULT result = emscripten_request_fullscreen_strategy(GODOT_CANVAS_SELECTOR, false, &strategy); - ERR_FAIL_COND_MSG(result == EMSCRIPTEN_RESULT_FAILED_NOT_DEFERRED, "Enabling fullscreen is only possible from an input callback for the HTML5 platform."); - ERR_FAIL_COND_MSG(result != EMSCRIPTEN_RESULT_SUCCESS, "Enabling fullscreen is only possible from an input callback for the HTML5 platform."); - // Not fullscreen yet, so prevent "windowed" canvas dimensions from - // being overwritten. - entering_fullscreen = true; - } else { - // No logic allowed here, since exiting w/ ESC key won't use this function. - ERR_FAIL_COND(emscripten_exit_fullscreen() != EMSCRIPTEN_RESULT_SUCCESS); - } -} - -bool OS_JavaScript::is_window_fullscreen() const { - - return video_mode.fullscreen; -} - -void OS_JavaScript::get_fullscreen_mode_list(List<VideoMode> *p_list, int p_screen) const { - - Size2 screen = get_screen_size(); - p_list->push_back(OS::VideoMode(screen.width, screen.height, true)); -} - -bool OS_JavaScript::get_window_per_pixel_transparency_enabled() const { - if (!is_layered_allowed()) { - return false; - } - return transparency_enabled; -} - -void OS_JavaScript::set_window_per_pixel_transparency_enabled(bool p_enabled) { - if (!is_layered_allowed()) { - return; - } - transparency_enabled = p_enabled; -} - -// Keys - -template <typename T> -static void dom2godot_mod(T *emscripten_event_ptr, Ref<InputEventWithModifiers> godot_event) { - - godot_event->set_shift(emscripten_event_ptr->shiftKey); - godot_event->set_alt(emscripten_event_ptr->altKey); - godot_event->set_control(emscripten_event_ptr->ctrlKey); - godot_event->set_metakey(emscripten_event_ptr->metaKey); -} - -static Ref<InputEventKey> setup_key_event(const EmscriptenKeyboardEvent *emscripten_event) { - - Ref<InputEventKey> ev; - ev.instance(); - ev->set_echo(emscripten_event->repeat); - dom2godot_mod(emscripten_event, ev); - ev->set_keycode(dom2godot_keycode(emscripten_event->keyCode)); - ev->set_physical_keycode(dom2godot_keycode(emscripten_event->keyCode)); - - String unicode = String::utf8(emscripten_event->key); - // Check if empty or multi-character (e.g. `CapsLock`). - if (unicode.length() != 1) { - // Might be empty as well, but better than nonsense. - unicode = String::utf8(emscripten_event->charValue); - } - if (unicode.length() == 1) { - ev->set_unicode(unicode[0]); - } - - return ev; -} - -EM_BOOL OS_JavaScript::keydown_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data) { - - OS_JavaScript *os = get_singleton(); - Ref<InputEventKey> ev = setup_key_event(p_event); - ev->set_pressed(true); - if (ev->get_unicode() == 0 && keycode_has_unicode(ev->get_keycode())) { - // Defer to keypress event for legacy unicode retrieval. - os->deferred_key_event = ev; - // Do not suppress keypress event. - return false; - } - os->input->parse_input_event(ev); - // Resume audio context after input in case autoplay was denied. - os->audio_driver_javascript.resume(); - return true; -} - -EM_BOOL OS_JavaScript::keypress_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data) { - - OS_JavaScript *os = get_singleton(); - os->deferred_key_event->set_unicode(p_event->charCode); - os->input->parse_input_event(os->deferred_key_event); - return true; -} - -EM_BOOL OS_JavaScript::keyup_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data) { - - Ref<InputEventKey> ev = setup_key_event(p_event); - ev->set_pressed(false); - get_singleton()->input->parse_input_event(ev); - return ev->get_keycode() != KEY_UNKNOWN && ev->get_keycode() != 0; -} - -// Mouse - -Point2 OS_JavaScript::get_mouse_position() const { - - return input->get_mouse_position(); -} - -int OS_JavaScript::get_mouse_button_state() const { - - return input->get_mouse_button_mask(); -} - -EM_BOOL OS_JavaScript::mouse_button_callback(int p_event_type, const EmscriptenMouseEvent *p_event, void *p_user_data) { - - OS_JavaScript *os = get_singleton(); - - Ref<InputEventMouseButton> ev; - ev.instance(); - ev->set_pressed(p_event_type == EMSCRIPTEN_EVENT_MOUSEDOWN); - ev->set_position(compute_position_in_canvas(p_event->clientX, p_event->clientY)); - ev->set_global_position(ev->get_position()); - dom2godot_mod(p_event, ev); - - switch (p_event->button) { - case DOM_BUTTON_LEFT: ev->set_button_index(BUTTON_LEFT); break; - case DOM_BUTTON_MIDDLE: ev->set_button_index(BUTTON_MIDDLE); break; - case DOM_BUTTON_RIGHT: ev->set_button_index(BUTTON_RIGHT); break; - case DOM_BUTTON_XBUTTON1: ev->set_button_index(BUTTON_XBUTTON1); break; - case DOM_BUTTON_XBUTTON2: ev->set_button_index(BUTTON_XBUTTON2); break; - default: return false; - } - - if (ev->is_pressed()) { - - double diff = emscripten_get_now() - os->last_click_ms; - - if (ev->get_button_index() == os->last_click_button_index) { - - if (diff < 400 && Point2(os->last_click_pos).distance_to(ev->get_position()) < 5) { - - os->last_click_ms = 0; - os->last_click_pos = Point2(-100, -100); - os->last_click_button_index = -1; - ev->set_doubleclick(true); - } - - } else { - os->last_click_button_index = ev->get_button_index(); - } - - if (!ev->is_doubleclick()) { - os->last_click_ms += diff; - os->last_click_pos = ev->get_position(); - } - } - - int mask = os->input->get_mouse_button_mask(); - int button_flag = 1 << (ev->get_button_index() - 1); - if (ev->is_pressed()) { - // Since the event is consumed, focus manually. The containing iframe, - // if exists, may not have focus yet, so focus even if already focused. - focus_canvas(); - mask |= button_flag; - } else if (mask & button_flag) { - mask &= ~button_flag; - } else { - // Received release event, but press was outside the canvas, so ignore. - return false; - } - ev->set_button_mask(mask); - - os->input->parse_input_event(ev); - // Resume audio context after input in case autoplay was denied. - os->audio_driver_javascript.resume(); - // Prevent multi-click text selection and wheel-click scrolling anchor. - // Context menu is prevented through contextmenu event. - return true; -} - -EM_BOOL OS_JavaScript::mousemove_callback(int p_event_type, const EmscriptenMouseEvent *p_event, void *p_user_data) { - - OS_JavaScript *os = get_singleton(); - - int input_mask = os->input->get_mouse_button_mask(); - Point2 pos = compute_position_in_canvas(p_event->clientX, p_event->clientY); - // For motion outside the canvas, only read mouse movement if dragging - // started inside the canvas; imitating desktop app behaviour. - if (!cursor_inside_canvas && !input_mask) - return false; - - Ref<InputEventMouseMotion> ev; - ev.instance(); - dom2godot_mod(p_event, ev); - ev->set_button_mask(input_mask); - - ev->set_position(pos); - ev->set_global_position(ev->get_position()); - - ev->set_relative(Vector2(p_event->movementX, p_event->movementY)); - os->input->set_mouse_position(ev->get_position()); - ev->set_speed(os->input->get_last_mouse_speed()); - - os->input->parse_input_event(ev); - // Don't suppress mouseover/-leave events. - return false; -} - -static const char *godot2dom_cursor(OS::CursorShape p_shape) { - - switch (p_shape) { - case OS::CURSOR_ARROW: - default: - return "auto"; - case OS::CURSOR_IBEAM: return "text"; - case OS::CURSOR_POINTING_HAND: return "pointer"; - case OS::CURSOR_CROSS: return "crosshair"; - case OS::CURSOR_WAIT: return "progress"; - case OS::CURSOR_BUSY: return "wait"; - case OS::CURSOR_DRAG: return "grab"; - case OS::CURSOR_CAN_DROP: return "grabbing"; - case OS::CURSOR_FORBIDDEN: return "no-drop"; - case OS::CURSOR_VSIZE: return "ns-resize"; - case OS::CURSOR_HSIZE: return "ew-resize"; - case OS::CURSOR_BDIAGSIZE: return "nesw-resize"; - case OS::CURSOR_FDIAGSIZE: return "nwse-resize"; - case OS::CURSOR_MOVE: return "move"; - case OS::CURSOR_VSPLIT: return "row-resize"; - case OS::CURSOR_HSPLIT: return "col-resize"; - case OS::CURSOR_HELP: return "help"; - } -} - -static void set_css_cursor(const char *p_cursor) { - - /* clang-format off */ - EM_ASM_({ - Module.canvas.style.cursor = UTF8ToString($0); - }, p_cursor); - /* clang-format on */ -} - -static bool is_css_cursor_hidden() { - - /* clang-format off */ - return EM_ASM_INT({ - return Module.canvas.style.cursor === 'none'; - }); - /* clang-format on */ -} - -void OS_JavaScript::set_cursor_shape(CursorShape p_shape) { - - ERR_FAIL_INDEX(p_shape, CURSOR_MAX); - - if (get_mouse_mode() == MOUSE_MODE_VISIBLE) { - if (cursors[p_shape] != "") { - Vector<String> url = cursors[p_shape].split("?"); - set_css_cursor(("url(\"" + url[0] + "\") " + url[1] + ", auto").utf8()); - } else { - set_css_cursor(godot2dom_cursor(p_shape)); - } - } - - cursor_shape = p_shape; -} - -void OS_JavaScript::set_custom_mouse_cursor(const RES &p_cursor, CursorShape p_shape, const Vector2 &p_hotspot) { - - if (p_cursor.is_valid()) { - - Map<CursorShape, Vector<Variant>>::Element *cursor_c = cursors_cache.find(p_shape); - - if (cursor_c) { - if (cursor_c->get()[0] == p_cursor && cursor_c->get()[1] == p_hotspot) { - set_cursor_shape(p_shape); - return; - } - - cursors_cache.erase(p_shape); - } - - Ref<Texture2D> texture = p_cursor; - Ref<AtlasTexture> atlas_texture = p_cursor; - Ref<Image> image; - Size2 texture_size; - Rect2 atlas_rect; - - if (texture.is_valid()) { - image = texture->get_data(); - if (image.is_valid()) { - image->duplicate(); - } - } - - if (!image.is_valid() && atlas_texture.is_valid()) { - texture = atlas_texture->get_atlas(); - - atlas_rect.size.width = texture->get_width(); - atlas_rect.size.height = texture->get_height(); - atlas_rect.position.x = atlas_texture->get_region().position.x; - atlas_rect.position.y = atlas_texture->get_region().position.y; - - texture_size.width = atlas_texture->get_region().size.x; - texture_size.height = atlas_texture->get_region().size.y; - } else if (image.is_valid()) { - texture_size.width = texture->get_width(); - texture_size.height = texture->get_height(); - } - - ERR_FAIL_COND(!texture.is_valid()); - ERR_FAIL_COND(p_hotspot.x < 0 || p_hotspot.y < 0); - ERR_FAIL_COND(texture_size.width > 256 || texture_size.height > 256); - ERR_FAIL_COND(p_hotspot.x > texture_size.width || p_hotspot.y > texture_size.height); - - image = texture->get_data(); - - ERR_FAIL_COND(!image.is_valid()); - - image = image->duplicate(); - - if (atlas_texture.is_valid()) - image->crop_from_point( - atlas_rect.position.x, - atlas_rect.position.y, - texture_size.width, - texture_size.height); - - if (image->get_format() != Image::FORMAT_RGBA8) { - image->convert(Image::FORMAT_RGBA8); - } - - png_image png_meta; - memset(&png_meta, 0, sizeof png_meta); - png_meta.version = PNG_IMAGE_VERSION; - png_meta.width = texture_size.width; - png_meta.height = texture_size.height; - png_meta.format = PNG_FORMAT_RGBA; - - PackedByteArray png; - size_t len; - PackedByteArray data = image->get_data(); - ERR_FAIL_COND(!png_image_write_get_memory_size(png_meta, len, 0, data.ptr(), 0, nullptr)); - - png.resize(len); - ERR_FAIL_COND(!png_image_write_to_memory(&png_meta, png.ptrw(), &len, 0, data.ptr(), 0, nullptr)); - - char *object_url; - /* clang-format off */ - EM_ASM({ - var PNG_PTR = $0; - var PNG_LEN = $1; - var PTR = $2; - - var png = new Blob([HEAPU8.slice(PNG_PTR, PNG_PTR + PNG_LEN)], { type: 'image/png' }); - var url = URL.createObjectURL(png); - var length_bytes = lengthBytesUTF8(url) + 1; - var string_on_wasm_heap = _malloc(length_bytes); - setValue(PTR, string_on_wasm_heap, '*'); - stringToUTF8(url, string_on_wasm_heap, length_bytes); - }, png.ptr(), len, &object_url); - /* clang-format on */ - - String url = String::utf8(object_url) + "?" + itos(p_hotspot.x) + " " + itos(p_hotspot.y); - - /* clang-format off */ - EM_ASM({ _free($0); }, object_url); - /* clang-format on */ - - if (cursors[p_shape] != "") { - /* clang-format off */ - EM_ASM({ - URL.revokeObjectURL(UTF8ToString($0).split('?')[0]); - }, cursors[p_shape].utf8().get_data()); - /* clang-format on */ - cursors[p_shape] = ""; - } - - cursors[p_shape] = url; - - Vector<Variant> params; - params.push_back(p_cursor); - params.push_back(p_hotspot); - cursors_cache.insert(p_shape, params); - - } else if (cursors[p_shape] != "") { - /* clang-format off */ - EM_ASM({ - URL.revokeObjectURL(UTF8ToString($0).split('?')[0]); - }, cursors[p_shape].utf8().get_data()); - /* clang-format on */ - cursors[p_shape] = ""; - - cursors_cache.erase(p_shape); - } - - set_cursor_shape(cursor_shape); -} - -void OS_JavaScript::set_mouse_mode(OS::MouseMode p_mode) { - - ERR_FAIL_COND_MSG(p_mode == MOUSE_MODE_CONFINED, "MOUSE_MODE_CONFINED is not supported for the HTML5 platform."); - if (p_mode == get_mouse_mode()) - return; - - if (p_mode == MOUSE_MODE_VISIBLE) { - - // set_css_cursor must be called before set_cursor_shape to make the cursor visible - set_css_cursor(godot2dom_cursor(cursor_shape)); - set_cursor_shape(cursor_shape); - emscripten_exit_pointerlock(); - - } else if (p_mode == MOUSE_MODE_HIDDEN) { - - set_css_cursor("none"); - emscripten_exit_pointerlock(); - - } else if (p_mode == MOUSE_MODE_CAPTURED) { - - EMSCRIPTEN_RESULT result = emscripten_request_pointerlock("canvas", false); - ERR_FAIL_COND_MSG(result == EMSCRIPTEN_RESULT_FAILED_NOT_DEFERRED, "MOUSE_MODE_CAPTURED can only be entered from within an appropriate input callback."); - ERR_FAIL_COND_MSG(result != EMSCRIPTEN_RESULT_SUCCESS, "MOUSE_MODE_CAPTURED can only be entered from within an appropriate input callback."); - // set_css_cursor must be called before set_cursor_shape to make the cursor visible - set_css_cursor(godot2dom_cursor(cursor_shape)); - set_cursor_shape(cursor_shape); - } -} - -OS::MouseMode OS_JavaScript::get_mouse_mode() const { - - if (is_css_cursor_hidden()) - return MOUSE_MODE_HIDDEN; - - EmscriptenPointerlockChangeEvent ev; - emscripten_get_pointerlock_status(&ev); - return (ev.isActive && String::utf8(ev.id) == "canvas") ? MOUSE_MODE_CAPTURED : MOUSE_MODE_VISIBLE; -} - -// Wheel - -EM_BOOL OS_JavaScript::wheel_callback(int p_event_type, const EmscriptenWheelEvent *p_event, void *p_user_data) { - - ERR_FAIL_COND_V(p_event_type != EMSCRIPTEN_EVENT_WHEEL, false); - if (!is_canvas_focused()) { - if (cursor_inside_canvas) { - focus_canvas(); - } else { - return false; - } - } - - InputDefault *input = get_singleton()->input; - Ref<InputEventMouseButton> ev; - ev.instance(); - ev->set_position(input->get_mouse_position()); - ev->set_global_position(ev->get_position()); - - ev->set_shift(input->is_key_pressed(KEY_SHIFT)); - ev->set_alt(input->is_key_pressed(KEY_ALT)); - ev->set_control(input->is_key_pressed(KEY_CONTROL)); - ev->set_metakey(input->is_key_pressed(KEY_META)); - - if (p_event->deltaY < 0) - ev->set_button_index(BUTTON_WHEEL_UP); - else if (p_event->deltaY > 0) - ev->set_button_index(BUTTON_WHEEL_DOWN); - else if (p_event->deltaX > 0) - ev->set_button_index(BUTTON_WHEEL_LEFT); - else if (p_event->deltaX < 0) - ev->set_button_index(BUTTON_WHEEL_RIGHT); - else - return false; - - // Different browsers give wildly different delta values, and we can't - // interpret deltaMode, so use default value for wheel events' factor. - - int button_flag = 1 << (ev->get_button_index() - 1); - - ev->set_pressed(true); - ev->set_button_mask(input->get_mouse_button_mask() | button_flag); - input->parse_input_event(ev); - - ev->set_pressed(false); - ev->set_button_mask(input->get_mouse_button_mask() & ~button_flag); - input->parse_input_event(ev); - - return true; -} - -// Touch - bool OS_JavaScript::has_touchscreen_ui_hint() const { /* clang-format off */ @@ -699,134 +55,6 @@ bool OS_JavaScript::has_touchscreen_ui_hint() const { /* clang-format on */ } -EM_BOOL OS_JavaScript::touch_press_callback(int p_event_type, const EmscriptenTouchEvent *p_event, void *p_user_data) { - - OS_JavaScript *os = get_singleton(); - Ref<InputEventScreenTouch> ev; - ev.instance(); - int lowest_id_index = -1; - for (int i = 0; i < p_event->numTouches; ++i) { - - const EmscriptenTouchPoint &touch = p_event->touches[i]; - if (lowest_id_index == -1 || touch.identifier < p_event->touches[lowest_id_index].identifier) - lowest_id_index = i; - if (!touch.isChanged) - continue; - ev->set_index(touch.identifier); - ev->set_position(compute_position_in_canvas(touch.clientX, touch.clientY)); - os->touches[i] = ev->get_position(); - ev->set_pressed(p_event_type == EMSCRIPTEN_EVENT_TOUCHSTART); - - os->input->parse_input_event(ev); - } - // Resume audio context after input in case autoplay was denied. - os->audio_driver_javascript.resume(); - return true; -} - -EM_BOOL OS_JavaScript::touchmove_callback(int p_event_type, const EmscriptenTouchEvent *p_event, void *p_user_data) { - - OS_JavaScript *os = get_singleton(); - Ref<InputEventScreenDrag> ev; - ev.instance(); - int lowest_id_index = -1; - for (int i = 0; i < p_event->numTouches; ++i) { - - const EmscriptenTouchPoint &touch = p_event->touches[i]; - if (lowest_id_index == -1 || touch.identifier < p_event->touches[lowest_id_index].identifier) - lowest_id_index = i; - if (!touch.isChanged) - continue; - ev->set_index(touch.identifier); - ev->set_position(compute_position_in_canvas(touch.clientX, touch.clientY)); - Point2 &prev = os->touches[i]; - ev->set_relative(ev->get_position() - prev); - prev = ev->get_position(); - - os->input->parse_input_event(ev); - } - return true; -} - -// Gamepad - -EM_BOOL OS_JavaScript::gamepad_change_callback(int p_event_type, const EmscriptenGamepadEvent *p_event, void *p_user_data) { - - InputDefault *input = get_singleton()->input; - if (p_event_type == EMSCRIPTEN_EVENT_GAMEPADCONNECTED) { - - String guid = ""; - if (String::utf8(p_event->mapping) == "standard") - guid = "Default HTML5 Gamepad"; - input->joy_connection_changed(p_event->index, true, String::utf8(p_event->id), guid); - } else { - input->joy_connection_changed(p_event->index, false, ""); - } - return true; -} - -void OS_JavaScript::process_joypads() { - - int joypad_count = emscripten_get_num_gamepads(); - for (int joypad = 0; joypad < joypad_count; joypad++) { - EmscriptenGamepadEvent state; - EMSCRIPTEN_RESULT query_result = emscripten_get_gamepad_status(joypad, &state); - // Chromium reserves gamepads slots, so NO_DATA is an expected result. - ERR_CONTINUE(query_result != EMSCRIPTEN_RESULT_SUCCESS && - query_result != EMSCRIPTEN_RESULT_NO_DATA); - if (query_result == EMSCRIPTEN_RESULT_SUCCESS && state.connected) { - - int button_count = MIN(state.numButtons, 18); - int axis_count = MIN(state.numAxes, 8); - for (int button = 0; button < button_count; button++) { - - float value = state.analogButton[button]; - if (String::utf8(state.mapping) == "standard" && (button == JOY_ANALOG_L2 || button == JOY_ANALOG_R2)) { - InputDefault::JoyAxis joy_axis; - joy_axis.min = 0; - joy_axis.value = value; - input->joy_axis(joypad, button, joy_axis); - } else { - input->joy_button(joypad, button, value); - } - } - for (int axis = 0; axis < axis_count; axis++) { - - InputDefault::JoyAxis joy_axis; - joy_axis.min = -1; - joy_axis.value = state.axis[axis]; - input->joy_axis(joypad, axis, joy_axis); - } - } - } -} - -bool OS_JavaScript::is_joy_known(int p_device) { - - return input->is_joy_mapped(p_device); -} - -String OS_JavaScript::get_joy_guid(int p_device) const { - - return input->get_joy_guid_remapped(p_device); -} - -// Video - -int OS_JavaScript::get_video_driver_count() const { - - return VIDEO_DRIVER_MAX; -} - -const char *OS_JavaScript::get_video_driver_name(int p_driver) const { - - switch (p_driver) { - case VIDEO_DRIVER_GLES2: - return "GLES2"; - } - ERR_FAIL_V_MSG(nullptr, "Invalid video driver index: " + itos(p_driver) + "."); -} - // Audio int OS_JavaScript::get_audio_driver_count() const { @@ -839,192 +67,34 @@ const char *OS_JavaScript::get_audio_driver_name(int p_driver) const { return "JavaScript"; } -// Clipboard -extern "C" EMSCRIPTEN_KEEPALIVE void update_clipboard(const char *p_text) { - // Only call set_clipboard from OS (sets local clipboard) - OS::get_singleton()->OS::set_clipboard(p_text); -} - -void OS_JavaScript::set_clipboard(const String &p_text) { - OS::set_clipboard(p_text); - /* clang-format off */ - int err = EM_ASM_INT({ - var text = UTF8ToString($0); - if (!navigator.clipboard || !navigator.clipboard.writeText) - return 1; - navigator.clipboard.writeText(text).catch(function(e) { - // Setting OS clipboard is only possible from an input callback. - console.error("Setting OS clipboard is only possible from an input callback for the HTML5 plafrom. Exception:", e); - }); - return 0; - }, p_text.utf8().get_data()); - /* clang-format on */ - ERR_FAIL_COND_MSG(err, "Clipboard API is not supported."); -} - -String OS_JavaScript::get_clipboard() const { - /* clang-format off */ - EM_ASM({ - try { - navigator.clipboard.readText().then(function (result) { - ccall('update_clipboard', 'void', ['string'], [result]); - }).catch(function (e) { - // Fail graciously. - }); - } catch (e) { - // Fail graciously. - } - }); - /* clang-format on */ - return this->OS::get_clipboard(); -} - // Lifecycle -int OS_JavaScript::get_current_video_driver() const { - return video_driver_index; -} - -void OS_JavaScript::initialize_core() { +void OS_JavaScript::initialize() { OS_Unix::initialize_core(); FileAccess::make_default<FileAccessBufferedFA<FileAccessUnix>>(FileAccess::ACCESS_RESOURCES); -} - -Error OS_JavaScript::initialize(const VideoMode &p_desired, int p_video_driver, int p_audio_driver) { - -#if 0 - EmscriptenWebGLContextAttributes attributes; - emscripten_webgl_init_context_attributes(&attributes); - attributes.alpha = GLOBAL_GET("display/window/per_pixel_transparency/allowed"); - attributes.antialias = false; - ERR_FAIL_INDEX_V(p_video_driver, VIDEO_DRIVER_MAX, ERR_INVALID_PARAMETER); - - if (p_desired.layered) { - set_window_per_pixel_transparency_enabled(true); - } - - bool gl_initialization_error = false; - - if (RasterizerGLES2::is_viable() == OK) { - attributes.majorVersion = 1; - RasterizerGLES2::register_config(); - RasterizerGLES2::make_current(); - } else { - gl_initialization_error = true; - } - - EMSCRIPTEN_WEBGL_CONTEXT_HANDLE ctx = emscripten_webgl_create_context(GODOT_CANVAS_SELECTOR, &attributes); - if (emscripten_webgl_make_context_current(ctx) != EMSCRIPTEN_RESULT_SUCCESS) { - gl_initialization_error = true; - } + DisplayServerJavaScript::register_javascript_driver(); - if (gl_initialization_error) { - OS::get_singleton()->alert("Your browser does not seem to support WebGL. Please update your browser version.", - "Unable to initialize video driver"); - return ERR_UNAVAILABLE; - } - - video_driver_index = p_video_driver; - - video_mode = p_desired; - // fullscreen_change_callback will correct this if the request is successful. - video_mode.fullscreen = false; - // Emscripten only attempts fullscreen requests if the user input callback - // was registered through one its own functions, so request manually for - // start-up fullscreen. - if (p_desired.fullscreen) { - /* clang-format off */ - EM_ASM({ - const canvas = Module.canvas; - (canvas.requestFullscreen || canvas.msRequestFullscreen || - canvas.mozRequestFullScreen || canvas.mozRequestFullscreen || - canvas.webkitRequestFullscreen - ).call(canvas); - }); - /* clang-format on */ - } - /* clang-format off */ - if (EM_ASM_INT_V({ return Module.resizeCanvasOnStart })) { - /* clang-format on */ - set_window_size(Size2(video_mode.width, video_mode.height)); - } else { - set_window_size(get_window_size()); - } +#ifdef MODULE_WEBSOCKET_ENABLED + EngineDebugger::register_uri_handler("ws://", RemoteDebuggerPeerWebSocket::create); + EngineDebugger::register_uri_handler("wss://", RemoteDebuggerPeerWebSocket::create); #endif - RasterizerDummy::make_current(); // TODO GLES2 in Godot 4.0... or webgpu? char locale_ptr[16]; /* clang-format off */ - EM_ASM_ARGS({ - stringToUTF8(Module.locale, $0, 16); + EM_ASM({ + stringToUTF8(Module['locale'], $0, 16); }, locale_ptr); /* clang-format on */ setenv("LANG", locale_ptr, true); +} - AudioDriverManager::initialize(p_audio_driver); - RenderingServer *rendering_server = memnew(RenderingServerRaster()); - input = memnew(InputDefault); - - EMSCRIPTEN_RESULT result; -#define EM_CHECK(ev) \ - if (result != EMSCRIPTEN_RESULT_SUCCESS) \ - ERR_PRINT("Error while setting " #ev " callback: Code " + itos(result)); -#define SET_EM_CALLBACK(target, ev, cb) \ - result = emscripten_set_##ev##_callback(target, nullptr, true, &cb); \ - EM_CHECK(ev) -#define SET_EM_CALLBACK_NOTARGET(ev, cb) \ - result = emscripten_set_##ev##_callback(nullptr, true, &cb); \ - EM_CHECK(ev) - // These callbacks from Emscripten's html5.h suffice to access most - // JavaScript APIs. For APIs that are not (sufficiently) exposed, EM_ASM - // is used below. - SET_EM_CALLBACK(EMSCRIPTEN_EVENT_TARGET_WINDOW, mousemove, mousemove_callback) - SET_EM_CALLBACK(GODOT_CANVAS_SELECTOR, mousedown, mouse_button_callback) - SET_EM_CALLBACK(EMSCRIPTEN_EVENT_TARGET_WINDOW, mouseup, mouse_button_callback) - SET_EM_CALLBACK(GODOT_CANVAS_SELECTOR, wheel, wheel_callback) - SET_EM_CALLBACK(GODOT_CANVAS_SELECTOR, touchstart, touch_press_callback) - SET_EM_CALLBACK(GODOT_CANVAS_SELECTOR, touchmove, touchmove_callback) - SET_EM_CALLBACK(GODOT_CANVAS_SELECTOR, touchend, touch_press_callback) - SET_EM_CALLBACK(GODOT_CANVAS_SELECTOR, touchcancel, touch_press_callback) - SET_EM_CALLBACK(GODOT_CANVAS_SELECTOR, keydown, keydown_callback) - SET_EM_CALLBACK(GODOT_CANVAS_SELECTOR, keypress, keypress_callback) - SET_EM_CALLBACK(GODOT_CANVAS_SELECTOR, keyup, keyup_callback) - SET_EM_CALLBACK(EMSCRIPTEN_EVENT_TARGET_DOCUMENT, fullscreenchange, fullscreen_change_callback) - SET_EM_CALLBACK_NOTARGET(gamepadconnected, gamepad_change_callback) - SET_EM_CALLBACK_NOTARGET(gamepaddisconnected, gamepad_change_callback) -#undef SET_EM_CALLBACK_NOTARGET -#undef SET_EM_CALLBACK -#undef EM_CHECK - - /* clang-format off */ - EM_ASM_ARGS({ - const send_notification = cwrap('send_notification', null, ['number']); - const notifications = arguments; - (['mouseover', 'mouseleave', 'focus', 'blur']).forEach(function(event, index) { - Module.canvas.addEventListener(event, send_notification.bind(null, notifications[index])); - }); - // Clipboard - const update_clipboard = cwrap('update_clipboard', null, ['string']); - window.addEventListener('paste', function(evt) { - update_clipboard(evt.clipboardData.getData('text')); - }, true); - }, - NOTIFICATION_WM_MOUSE_ENTER, - NOTIFICATION_WM_MOUSE_EXIT, - NOTIFICATION_WM_FOCUS_IN, - NOTIFICATION_WM_FOCUS_OUT - ); - /* clang-format on */ - - rendering_server->init(); - - return OK; +void OS_JavaScript::resume_audio() { + audio_driver_javascript.resume(); } void OS_JavaScript::set_main_loop(MainLoop *p_main_loop) { main_loop = p_main_loop; - input->set_main_loop(p_main_loop); } MainLoop *OS_JavaScript::get_main_loop() const { @@ -1032,12 +102,6 @@ MainLoop *OS_JavaScript::get_main_loop() const { return main_loop; } -void OS_JavaScript::run_async() { - - main_loop->init(); - emscripten_set_main_loop(main_loop_callback, -1, false); -} - void OS_JavaScript::main_loop_callback() { get_singleton()->main_loop_iterate(); @@ -1063,50 +127,51 @@ bool OS_JavaScript::main_loop_iterate() { } } - if (emscripten_sample_gamepad_data() == EMSCRIPTEN_RESULT_SUCCESS) - process_joypads(); - - if (just_exited_fullscreen) { - if (window_maximized) { - EmscriptenFullscreenStrategy strategy; - strategy.scaleMode = EMSCRIPTEN_FULLSCREEN_SCALE_STRETCH; - strategy.canvasResolutionScaleMode = EMSCRIPTEN_FULLSCREEN_CANVAS_SCALE_STDDEF; - strategy.filteringMode = EMSCRIPTEN_FULLSCREEN_FILTERING_DEFAULT; - strategy.canvasResizedCallback = nullptr; - emscripten_enter_soft_fullscreen(GODOT_CANVAS_SELECTOR, &strategy); - } else { - emscripten_set_canvas_element_size(GODOT_CANVAS_SELECTOR, windowed_size.width, windowed_size.height); - } - just_exited_fullscreen = false; - } - - int canvas[2]; - emscripten_get_canvas_element_size(GODOT_CANVAS_SELECTOR, canvas, canvas + 1); - video_mode.width = canvas[0]; - video_mode.height = canvas[1]; - if (!window_maximized && !video_mode.fullscreen && !just_exited_fullscreen && !entering_fullscreen) { - windowed_size.width = canvas[0]; - windowed_size.height = canvas[1]; - } + DisplayServer::get_singleton()->process_events(); return Main::iteration(); } void OS_JavaScript::delete_main_loop() { - memdelete(main_loop); + if (main_loop) { + memdelete(main_loop); + } + main_loop = nullptr; +} + +void OS_JavaScript::finalize_async() { + finalizing = true; + audio_driver_javascript.finish_async(); } void OS_JavaScript::finalize() { - memdelete(input); + delete_main_loop(); } // Miscellaneous Error OS_JavaScript::execute(const String &p_path, const List<String> &p_arguments, bool p_blocking, ProcessID *r_child_id, String *r_pipe, int *r_exitcode, bool read_stderr, Mutex *p_pipe_mutex) { - ERR_FAIL_V_MSG(ERR_UNAVAILABLE, "OS::execute() is not available on the HTML5 platform."); + Array args; + for (const List<String>::Element *E = p_arguments.front(); E; E = E->next()) { + args.push_back(E->get()); + } + String json_args = JSON::print(args); + /* clang-format off */ + int failed = EM_ASM_INT({ + const json_args = UTF8ToString($0); + const args = JSON.parse(json_args); + if (Module["onExecute"]) { + Module["onExecute"](args); + return 0; + } + return 1; + }, json_args.utf8().get_data()); + /* clang-format on */ + ERR_FAIL_COND_V_MSG(failed, ERR_UNAVAILABLE, "OS::execute() must be implemented in JavaScript via 'engine.setOnExecute' if required."); + return OK; } Error OS_JavaScript::kill(const ProcessID &p_pid) { @@ -1119,14 +184,6 @@ int OS_JavaScript::get_process_id() const { ERR_FAIL_V_MSG(0, "OS::get_process_id() is not available on the HTML5 platform."); } -extern "C" EMSCRIPTEN_KEEPALIVE void send_notification(int p_notification) { - - if (p_notification == NOTIFICATION_WM_MOUSE_ENTER || p_notification == NOTIFICATION_WM_MOUSE_EXIT) { - cursor_inside_canvas = p_notification == NOTIFICATION_WM_MOUSE_ENTER; - } - OS_JavaScript::get_singleton()->get_main_loop()->notification(p_notification); -} - bool OS_JavaScript::_check_internal_feature_support(const String &p_feature) { if (p_feature == "HTML5" || p_feature == "web") @@ -1140,72 +197,6 @@ bool OS_JavaScript::_check_internal_feature_support(const String &p_feature) { return false; } -void OS_JavaScript::alert(const String &p_alert, const String &p_title) { - - /* clang-format off */ - EM_ASM_({ - window.alert(UTF8ToString($0)); - }, p_alert.utf8().get_data()); - /* clang-format on */ -} - -void OS_JavaScript::set_window_title(const String &p_title) { - - /* clang-format off */ - EM_ASM_({ - document.title = UTF8ToString($0); - }, p_title.utf8().get_data()); - /* clang-format on */ -} - -void OS_JavaScript::set_icon(const Ref<Image> &p_icon) { - - ERR_FAIL_COND(p_icon.is_null()); - Ref<Image> icon = p_icon; - if (icon->is_compressed()) { - icon = icon->duplicate(); - ERR_FAIL_COND(icon->decompress() != OK); - } - if (icon->get_format() != Image::FORMAT_RGBA8) { - if (icon == p_icon) - icon = icon->duplicate(); - icon->convert(Image::FORMAT_RGBA8); - } - - png_image png_meta; - memset(&png_meta, 0, sizeof png_meta); - png_meta.version = PNG_IMAGE_VERSION; - png_meta.width = icon->get_width(); - png_meta.height = icon->get_height(); - png_meta.format = PNG_FORMAT_RGBA; - - PackedByteArray png; - size_t len; - PackedByteArray data = icon->get_data(); - ERR_FAIL_COND(!png_image_write_get_memory_size(png_meta, len, 0, data.ptr(), 0, nullptr)); - - png.resize(len); - ERR_FAIL_COND(!png_image_write_to_memory(&png_meta, png.ptrw(), &len, 0, data.ptr(), 0, nullptr)); - - /* clang-format off */ - EM_ASM_ARGS({ - var PNG_PTR = $0; - var PNG_LEN = $1; - - var png = new Blob([HEAPU8.slice(PNG_PTR, PNG_PTR + PNG_LEN)], { type: "image/png" }); - var url = URL.createObjectURL(png); - var link = document.getElementById('-gd-engine-icon'); - if (link === null) { - link = document.createElement('link'); - link.rel = 'icon'; - link.id = '-gd-engine-icon'; - document.head.appendChild(link); - } - link.href = url; - }, png.ptr(), len); - /* clang-format on */ -} - String OS_JavaScript::get_executable_path() const { return OS::get_executable_path(); @@ -1237,9 +228,19 @@ String OS_JavaScript::get_user_data_dir() const { return "/userfs"; }; -String OS_JavaScript::get_resource_dir() const { +String OS_JavaScript::get_cache_path() const { + + return "/home/web_user/.cache"; +} + +String OS_JavaScript::get_config_path() const { - return "/"; + return "/home/web_user/.config"; +} + +String OS_JavaScript::get_data_path() const { + + return "/home/web_user/.local/share"; } void OS_JavaScript::file_access_close_callback(const String &p_file, int p_flags) { @@ -1267,27 +268,10 @@ OS_JavaScript *OS_JavaScript::get_singleton() { return static_cast<OS_JavaScript *>(OS::get_singleton()); } -OS_JavaScript::OS_JavaScript(int p_argc, char *p_argv[]) { - - List<String> arguments; - for (int i = 1; i < p_argc; i++) { - arguments.push_back(String::utf8(p_argv[i])); - } - set_cmdline(p_argv[0], arguments); - - last_click_button_index = -1; - last_click_ms = 0; - last_click_pos = Point2(-100, -100); - - window_maximized = false; - entering_fullscreen = false; - just_exited_fullscreen = false; - transparency_enabled = false; - - main_loop = nullptr; +void OS_JavaScript::initialize_joypads() { +} - idb_available = false; - sync_wait_time = -1; +OS_JavaScript::OS_JavaScript() { AudioDriverManager::add_driver(&audio_driver_javascript); diff --git a/platform/javascript/os_javascript.h b/platform/javascript/os_javascript.h index feb0f0651a..0a81a4a5b3 100644 --- a/platform/javascript/os_javascript.h +++ b/platform/javascript/os_javascript.h @@ -35,64 +35,25 @@ #include "core/input/input.h" #include "drivers/unix/os_unix.h" #include "servers/audio_server.h" -#include "servers/rendering/rasterizer.h" #include <emscripten/html5.h> class OS_JavaScript : public OS_Unix { - VideoMode video_mode; - Vector2 windowed_size; - bool window_maximized; - bool entering_fullscreen; - bool just_exited_fullscreen; - bool transparency_enabled; - - InputDefault *input; - Ref<InputEventKey> deferred_key_event; - CursorShape cursor_shape; - String cursors[CURSOR_MAX]; - Map<CursorShape, Vector<Variant>> cursors_cache; - Point2 touches[32]; - - Point2i last_click_pos; - double last_click_ms; - int last_click_button_index; - - MainLoop *main_loop; - int video_driver_index; + MainLoop *main_loop = nullptr; AudioDriverJavaScript audio_driver_javascript; - bool idb_available; - int64_t sync_wait_time; - int64_t last_sync_check_time; - - static EM_BOOL fullscreen_change_callback(int p_event_type, const EmscriptenFullscreenChangeEvent *p_event, void *p_user_data); - - static EM_BOOL keydown_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data); - static EM_BOOL keypress_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data); - static EM_BOOL keyup_callback(int p_event_type, const EmscriptenKeyboardEvent *p_event, void *p_user_data); - - static EM_BOOL mousemove_callback(int p_event_type, const EmscriptenMouseEvent *p_event, void *p_user_data); - static EM_BOOL mouse_button_callback(int p_event_type, const EmscriptenMouseEvent *p_event, void *p_user_data); - - static EM_BOOL wheel_callback(int p_event_type, const EmscriptenWheelEvent *p_event, void *p_user_data); - - static EM_BOOL touch_press_callback(int p_event_type, const EmscriptenTouchEvent *p_event, void *p_user_data); - static EM_BOOL touchmove_callback(int p_event_type, const EmscriptenTouchEvent *p_event, void *p_user_data); - - static EM_BOOL gamepad_change_callback(int p_event_type, const EmscriptenGamepadEvent *p_event, void *p_user_data); - void process_joypads(); + bool finalizing = false; + bool idb_available = false; + int64_t sync_wait_time = -1; + int64_t last_sync_check_time = -1; static void main_loop_callback(); static void file_access_close_callback(const String &p_file, int p_flags); protected: - virtual int get_current_video_driver() const; - - virtual void initialize_core(); - virtual Error initialize(const VideoMode &p_desired, int p_video_driver, int p_audio_driver); + virtual void initialize(); virtual void set_main_loop(MainLoop *p_main_loop); virtual void delete_main_loop(); @@ -105,65 +66,38 @@ public: // Override return type to make writing static callbacks less tedious. static OS_JavaScript *get_singleton(); - virtual void set_video_mode(const VideoMode &p_video_mode, int p_screen = 0); - virtual VideoMode get_video_mode(int p_screen = 0) const; - virtual void get_fullscreen_mode_list(List<VideoMode> *p_list, int p_screen = 0) const; - - virtual void set_window_size(const Size2); - virtual Size2 get_window_size() const; - virtual void set_window_maximized(bool p_enabled); - virtual bool is_window_maximized() const; - virtual void set_window_fullscreen(bool p_enabled); - virtual bool is_window_fullscreen() const; - virtual Size2 get_screen_size(int p_screen = -1) const; - - virtual Point2 get_mouse_position() const; - virtual int get_mouse_button_state() const; - virtual void set_cursor_shape(CursorShape p_shape); - virtual void set_custom_mouse_cursor(const RES &p_cursor, CursorShape p_shape, const Vector2 &p_hotspot); - virtual void set_mouse_mode(MouseMode p_mode); - virtual MouseMode get_mouse_mode() const; - - virtual bool get_window_per_pixel_transparency_enabled() const; - virtual void set_window_per_pixel_transparency_enabled(bool p_enabled); + virtual void initialize_joypads(); virtual bool has_touchscreen_ui_hint() const; - virtual bool is_joy_known(int p_device); - virtual String get_joy_guid(int p_device) const; - - virtual int get_video_driver_count() const; - virtual const char *get_video_driver_name(int p_driver) const; - virtual int get_audio_driver_count() const; virtual const char *get_audio_driver_name(int p_driver) const; - virtual void set_clipboard(const String &p_text); - virtual String get_clipboard() const; - virtual MainLoop *get_main_loop() const; - void run_async(); + void finalize_async(); bool main_loop_iterate(); virtual Error execute(const String &p_path, const List<String> &p_arguments, bool p_blocking = true, ProcessID *r_child_id = nullptr, String *r_pipe = nullptr, int *r_exitcode = nullptr, bool read_stderr = false, Mutex *p_pipe_mutex = nullptr); virtual Error kill(const ProcessID &p_pid); virtual int get_process_id() const; - virtual void alert(const String &p_alert, const String &p_title = "ALERT!"); - virtual void set_window_title(const String &p_title); - virtual void set_icon(const Ref<Image> &p_icon); String get_executable_path() const; virtual Error shell_open(String p_uri); virtual String get_name() const; virtual bool can_draw() const; - virtual String get_resource_dir() const; + virtual String get_cache_path() const; + virtual String get_config_path() const; + virtual String get_data_path() const; virtual String get_user_data_dir() const; void set_idb_available(bool p_idb_available); virtual bool is_userfs_persistent() const; - OS_JavaScript(int p_argc, char *p_argv[]); + void resume_audio(); + bool is_finalizing() { return finalizing; } + + OS_JavaScript(); }; #endif diff --git a/platform/linuxbsd/detect_prime_x11.cpp b/platform/linuxbsd/detect_prime_x11.cpp index 1bec65ff04..1e46d3222d 100644 --- a/platform/linuxbsd/detect_prime_x11.cpp +++ b/platform/linuxbsd/detect_prime_x11.cpp @@ -178,7 +178,8 @@ int detect_prime() { close(fdset[0]); - if (i) setenv("DRI_PRIME", "1", 1); + if (i) + setenv("DRI_PRIME", "1", 1); create_context(); const char *vendor = (const char *)glGetString(GL_VENDOR); diff --git a/platform/linuxbsd/display_server_x11.cpp b/platform/linuxbsd/display_server_x11.cpp index dd9298d667..f016892453 100644 --- a/platform/linuxbsd/display_server_x11.cpp +++ b/platform/linuxbsd/display_server_x11.cpp @@ -551,7 +551,8 @@ int DisplayServerX11::get_screen_count() const { // Using Xinerama Extension int event_base, error_base; const Bool ext_okay = XineramaQueryExtension(x11_display, &event_base, &error_base); - if (!ext_okay) return 0; + if (!ext_okay) + return 0; int count; XineramaScreenInfo *xsi = XineramaQueryScreens(x11_display, &count); @@ -600,11 +601,13 @@ Rect2i DisplayServerX11::screen_get_usable_rect(int p_screen) const { // Using Xinerama Extension int event_base, error_base; const Bool ext_okay = XineramaQueryExtension(x11_display, &event_base, &error_base); - if (!ext_okay) return Rect2i(0, 0, 0, 0); + if (!ext_okay) + return Rect2i(0, 0, 0, 0); int count; XineramaScreenInfo *xsi = XineramaQueryScreens(x11_display, &count); - if (p_screen >= count) return Rect2i(0, 0, 0, 0); + if (p_screen >= count) + return Rect2i(0, 0, 0, 0); Rect2i rect = Rect2i(xsi[p_screen].x_org, xsi[p_screen].y_org, xsi[p_screen].width, xsi[p_screen].height); XFree(xsi); @@ -827,7 +830,8 @@ void DisplayServerX11::window_set_current_screen(int p_screen, WindowID p_window WindowData &wd = windows[p_window]; int count = get_screen_count(); - if (p_screen >= count) return; + if (p_screen >= count) + return; if (window_get_mode(p_window) == WINDOW_MODE_FULLSCREEN) { Point2i position = screen_get_position(p_screen); diff --git a/platform/linuxbsd/joypad_linux.cpp b/platform/linuxbsd/joypad_linux.cpp index 5ceea788e0..f57f74907f 100644 --- a/platform/linuxbsd/joypad_linux.cpp +++ b/platform/linuxbsd/joypad_linux.cpp @@ -463,7 +463,8 @@ void JoypadLinux::process_joypads() { } for (int i = 0; i < JOYPADS_MAX; i++) { - if (joypads[i].fd == -1) continue; + if (joypads[i].fd == -1) + continue; input_event events[32]; Joypad *joy = &joypads[i]; diff --git a/platform/linuxbsd/key_mapping_x11.h b/platform/linuxbsd/key_mapping_x11.h index 10db43bcc4..8f5e01a3c2 100644 --- a/platform/linuxbsd/key_mapping_x11.h +++ b/platform/linuxbsd/key_mapping_x11.h @@ -41,7 +41,7 @@ #include "core/os/keyboard.h" class KeyMappingX11 { - KeyMappingX11(){}; + KeyMappingX11() {} public: static unsigned int get_keycode(KeySym p_keysym); diff --git a/platform/osx/display_server_osx.mm b/platform/osx/display_server_osx.mm index 9d92992332..71e4584dac 100644 --- a/platform/osx/display_server_osx.mm +++ b/platform/osx/display_server_osx.mm @@ -1656,7 +1656,8 @@ String DisplayServerOSX::global_menu_get_item_submenu(const String &p_menu_root, const NSMenu *sub_menu = [menu_item submenu]; if (sub_menu) { for (Map<String, NSMenu *>::Element *E = submenu.front(); E; E = E->next()) { - if (E->get() == sub_menu) return E->key(); + if (E->get() == sub_menu) + return E->key(); } } } @@ -2479,7 +2480,8 @@ void DisplayServerOSX::_set_window_per_pixel_transparency_enabled(bool p_enabled ERR_FAIL_COND(!windows.has(p_window)); WindowData &wd = windows[p_window]; - if (!OS_OSX::get_singleton()->is_layered_allowed()) return; + if (!OS_OSX::get_singleton()->is_layered_allowed()) + return; if (wd.layered_window != p_enabled) { if (p_enabled) { [wd.window_object setBackgroundColor:[NSColor clearColor]]; @@ -2774,23 +2776,57 @@ void DisplayServerOSX::cursor_set_shape(CursorShape p_shape) { [cursors[p_shape] set]; } else { switch (p_shape) { - case CURSOR_ARROW: [[NSCursor arrowCursor] set]; break; - case CURSOR_IBEAM: [[NSCursor IBeamCursor] set]; break; - case CURSOR_POINTING_HAND: [[NSCursor pointingHandCursor] set]; break; - case CURSOR_CROSS: [[NSCursor crosshairCursor] set]; break; - case CURSOR_WAIT: [[NSCursor arrowCursor] set]; break; - case CURSOR_BUSY: [[NSCursor arrowCursor] set]; break; - case CURSOR_DRAG: [[NSCursor closedHandCursor] set]; break; - case CURSOR_CAN_DROP: [[NSCursor openHandCursor] set]; break; - case CURSOR_FORBIDDEN: [[NSCursor operationNotAllowedCursor] set]; break; - case CURSOR_VSIZE: [_cursorFromSelector(@selector(_windowResizeNorthSouthCursor), @selector(resizeUpDownCursor)) set]; break; - case CURSOR_HSIZE: [_cursorFromSelector(@selector(_windowResizeEastWestCursor), @selector(resizeLeftRightCursor)) set]; break; - case CURSOR_BDIAGSIZE: [_cursorFromSelector(@selector(_windowResizeNorthEastSouthWestCursor)) set]; break; - case CURSOR_FDIAGSIZE: [_cursorFromSelector(@selector(_windowResizeNorthWestSouthEastCursor)) set]; break; - case CURSOR_MOVE: [[NSCursor arrowCursor] set]; break; - case CURSOR_VSPLIT: [[NSCursor resizeUpDownCursor] set]; break; - case CURSOR_HSPLIT: [[NSCursor resizeLeftRightCursor] set]; break; - case CURSOR_HELP: [_cursorFromSelector(@selector(_helpCursor)) set]; break; + case CURSOR_ARROW: + [[NSCursor arrowCursor] set]; + break; + case CURSOR_IBEAM: + [[NSCursor IBeamCursor] set]; + break; + case CURSOR_POINTING_HAND: + [[NSCursor pointingHandCursor] set]; + break; + case CURSOR_CROSS: + [[NSCursor crosshairCursor] set]; + break; + case CURSOR_WAIT: + [[NSCursor arrowCursor] set]; + break; + case CURSOR_BUSY: + [[NSCursor arrowCursor] set]; + break; + case CURSOR_DRAG: + [[NSCursor closedHandCursor] set]; + break; + case CURSOR_CAN_DROP: + [[NSCursor openHandCursor] set]; + break; + case CURSOR_FORBIDDEN: + [[NSCursor operationNotAllowedCursor] set]; + break; + case CURSOR_VSIZE: + [_cursorFromSelector(@selector(_windowResizeNorthSouthCursor), @selector(resizeUpDownCursor)) set]; + break; + case CURSOR_HSIZE: + [_cursorFromSelector(@selector(_windowResizeEastWestCursor), @selector(resizeLeftRightCursor)) set]; + break; + case CURSOR_BDIAGSIZE: + [_cursorFromSelector(@selector(_windowResizeNorthEastSouthWestCursor)) set]; + break; + case CURSOR_FDIAGSIZE: + [_cursorFromSelector(@selector(_windowResizeNorthWestSouthEastCursor)) set]; + break; + case CURSOR_MOVE: + [[NSCursor arrowCursor] set]; + break; + case CURSOR_VSPLIT: + [[NSCursor resizeUpDownCursor] set]; + break; + case CURSOR_HSPLIT: + [[NSCursor resizeLeftRightCursor] set]; + break; + case CURSOR_HELP: + [_cursorFromSelector(@selector(_helpCursor)) set]; + break; default: { } } diff --git a/platform/osx/joypad_osx.cpp b/platform/osx/joypad_osx.cpp index 7f5ec05967..0f50ba63c6 100644 --- a/platform/osx/joypad_osx.cpp +++ b/platform/osx/joypad_osx.cpp @@ -374,7 +374,8 @@ bool joypad::check_ff_features() { if (ret == FF_OK && (features.supportedEffects & FFCAP_ET_CONSTANTFORCE)) { uint32_t val; ret = FFDeviceGetForceFeedbackProperty(ff_device, FFPROP_FFGAIN, &val, sizeof(val)); - if (ret != FF_OK) return false; + if (ret != FF_OK) + return false; int num_axes = features.numFfAxes; ff_axes = (DWORD *)memalloc(sizeof(DWORD) * num_axes); ff_directions = (LONG *)memalloc(sizeof(LONG) * num_axes); @@ -509,14 +510,16 @@ void JoypadOSX::joypad_vibration_stop(int p_id, uint64_t p_timestamp) { int JoypadOSX::get_joy_index(int p_id) const { for (int i = 0; i < device_list.size(); i++) { - if (device_list[i].id == p_id) return i; + if (device_list[i].id == p_id) + return i; } return -1; } int JoypadOSX::get_joy_ref(IOHIDDeviceRef p_device) const { for (int i = 0; i < device_list.size(); i++) { - if (device_list[i].device_ref == p_device) return i; + if (device_list[i].device_ref == p_device) + return i; } return -1; } diff --git a/platform/uwp/app.cpp b/platform/uwp/app.cpp index d3870b0b6c..988f958739 100644 --- a/platform/uwp/app.cpp +++ b/platform/uwp/app.cpp @@ -78,16 +78,6 @@ public: return 0; } -App::App() : - mWindowClosed(false), - mWindowVisible(true), - mWindowWidth(0), - mWindowHeight(0), - mEglDisplay(EGL_NO_DISPLAY), - mEglContext(EGL_NO_CONTEXT), - mEglSurface(EGL_NO_SURFACE) { -} - // The first method called when the IFrameworkView is being created. void App::Initialize(CoreApplicationView ^ applicationView) { // Register event handlers for app lifecycle. This example includes Activated, so that we diff --git a/platform/uwp/app.h b/platform/uwp/app.h index b7265ad086..5cffe378b1 100644 --- a/platform/uwp/app.h +++ b/platform/uwp/app.h @@ -45,7 +45,7 @@ namespace GodotUWP ref class App sealed : public Windows::ApplicationModel::Core::IFrameworkView { public: - App(); + App() {} // IFrameworkView Methods. virtual void Initialize(Windows::ApplicationModel::Core::CoreApplicationView^ applicationView); @@ -92,14 +92,14 @@ namespace GodotUWP char** get_command_line(unsigned int* out_argc); - bool mWindowClosed; - bool mWindowVisible; - GLsizei mWindowWidth; - GLsizei mWindowHeight; + bool mWindowClosed = false; + bool mWindowVisible = true; + GLsizei mWindowWidth = 0; + GLsizei mWindowHeight = 0; - EGLDisplay mEglDisplay; - EGLContext mEglContext; - EGLSurface mEglSurface; + EGLDisplay mEglDisplay = EGL_NO_DISPLAY; + EGLContext mEglContext = EGL_NO_CONTEXT; + EGLSurface mEglSurface = EGL_NO_SURFACE; CoreWindow^ window; OS_UWP* os; diff --git a/platform/uwp/export/export.cpp b/platform/uwp/export/export.cpp index 06bf738dc1..50136ccb66 100644 --- a/platform/uwp/export/export.cpp +++ b/platform/uwp/export/export.cpp @@ -115,21 +115,15 @@ class AppxPackager { struct FileMeta { String name; - int lfh_size; - bool compressed; - size_t compressed_size; - size_t uncompressed_size; + int lfh_size = 0; + bool compressed = false; + size_t compressed_size = 0; + size_t uncompressed_size = 0; Vector<BlockHash> hashes; - uLong file_crc32; - ZPOS64_T zip_offset; - - FileMeta() : - lfh_size(0), - compressed(false), - compressed_size(0), - uncompressed_size(0), - file_crc32(0), - zip_offset(0) {} + uLong file_crc32 = 0; + ZPOS64_T zip_offset = 0; + + FileMeta() {} }; String progress_task; @@ -265,7 +259,8 @@ void AppxPackager::make_content_types(const String &p_path) { String ext = file_metadata[i].name.get_extension(); - if (types.has(ext)) continue; + if (types.has(ext)) + continue; types[ext] = content_type(ext); @@ -664,8 +659,10 @@ class EditorExportPlatformUWP : public EditorExportPlatform { bool _valid_resource_name(const String &p_name) const { - if (p_name.empty()) return false; - if (p_name.ends_with(".")) return false; + if (p_name.empty()) + return false; + if (p_name.ends_with(".")) + return false; static const char *invalid_names[] = { "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", @@ -675,7 +672,8 @@ class EditorExportPlatformUWP : public EditorExportPlatform { const char **t = invalid_names; while (*t) { - if (p_name == *t) return false; + if (p_name == *t) + return false; t++; } @@ -686,19 +684,25 @@ class EditorExportPlatformUWP : public EditorExportPlatform { Vector<String> parts = p_guid.split("-"); - if (parts.size() != 5) return false; - if (parts[0].length() != 8) return false; + if (parts.size() != 5) + return false; + if (parts[0].length() != 8) + return false; for (int i = 1; i < 4; i++) - if (parts[i].length() != 4) return false; - if (parts[4].length() != 12) return false; + if (parts[i].length() != 4) + return false; + if (parts[4].length() != 12) + return false; return true; } bool _valid_bgcolor(const String &p_color) const { - if (p_color.empty()) return true; - if (p_color.begins_with("#") && p_color.is_valid_html_color()) return true; + if (p_color.empty()) + return true; + if (p_color.begins_with("#") && p_color.is_valid_html_color()) + return true; // Colors from https://msdn.microsoft.com/en-us/library/windows/apps/dn934817.aspx static const char *valid_colors[] = { @@ -732,14 +736,15 @@ class EditorExportPlatformUWP : public EditorExportPlatform { const char **color = valid_colors; while (*color) { - if (p_color == *color) return true; + if (p_color == *color) + return true; color++; } return false; } - bool _valid_image(const StreamTexture *p_image, int p_width, int p_height) const { + bool _valid_image(const StreamTexture2D *p_image, int p_width, int p_height) const { if (!p_image) { return false; @@ -876,27 +881,28 @@ class EditorExportPlatformUWP : public EditorExportPlatform { Vector<uint8_t> _get_image_data(const Ref<EditorExportPreset> &p_preset, const String &p_path) { Vector<uint8_t> data; - StreamTexture *image = nullptr; + StreamTexture2D *image = nullptr; if (p_path.find("StoreLogo") != -1) { - image = p_preset->get("images/store_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture>(((Object *)p_preset->get("images/store_logo"))); + image = p_preset->get("images/store_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture2D>(((Object *)p_preset->get("images/store_logo"))); } else if (p_path.find("Square44x44Logo") != -1) { - image = p_preset->get("images/square44x44_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture>(((Object *)p_preset->get("images/square44x44_logo"))); + image = p_preset->get("images/square44x44_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture2D>(((Object *)p_preset->get("images/square44x44_logo"))); } else if (p_path.find("Square71x71Logo") != -1) { - image = p_preset->get("images/square71x71_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture>(((Object *)p_preset->get("images/square71x71_logo"))); + image = p_preset->get("images/square71x71_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture2D>(((Object *)p_preset->get("images/square71x71_logo"))); } else if (p_path.find("Square150x150Logo") != -1) { - image = p_preset->get("images/square150x150_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture>(((Object *)p_preset->get("images/square150x150_logo"))); + image = p_preset->get("images/square150x150_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture2D>(((Object *)p_preset->get("images/square150x150_logo"))); } else if (p_path.find("Square310x310Logo") != -1) { - image = p_preset->get("images/square310x310_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture>(((Object *)p_preset->get("images/square310x310_logo"))); + image = p_preset->get("images/square310x310_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture2D>(((Object *)p_preset->get("images/square310x310_logo"))); } else if (p_path.find("Wide310x150Logo") != -1) { - image = p_preset->get("images/wide310x150_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture>(((Object *)p_preset->get("images/wide310x150_logo"))); + image = p_preset->get("images/wide310x150_logo").is_zero() ? nullptr : Object::cast_to<StreamTexture2D>(((Object *)p_preset->get("images/wide310x150_logo"))); } else if (p_path.find("SplashScreen") != -1) { - image = p_preset->get("images/splash_screen").is_zero() ? nullptr : Object::cast_to<StreamTexture>(((Object *)p_preset->get("images/splash_screen"))); + image = p_preset->get("images/splash_screen").is_zero() ? nullptr : Object::cast_to<StreamTexture2D>(((Object *)p_preset->get("images/splash_screen"))); } else { ERR_PRINT("Unable to load logo"); } - if (!image) return data; + if (!image) + return data; String tmp_path = EditorSettings::get_singleton()->get_cache_dir().plus_file("uwp_tmp_logo.png"); @@ -1054,13 +1060,13 @@ public: r_options->push_back(ExportOption(PropertyInfo(Variant::BOOL, "orientation/portrait_flipped"), true)); r_options->push_back(ExportOption(PropertyInfo(Variant::STRING, "images/background_color"), "transparent")); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/store_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square44x44_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square71x71_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square150x150_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square310x310_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/wide310x150_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); - r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/splash_screen", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/store_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square44x44_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square71x71_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square150x150_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/square310x310_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/wide310x150_logo", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); + r_options->push_back(ExportOption(PropertyInfo(Variant::OBJECT, "images/splash_screen", PROPERTY_HINT_RESOURCE_TYPE, "StreamTexture2D"), Variant())); r_options->push_back(ExportOption(PropertyInfo(Variant::BOOL, "tiles/show_name_on_square150x150"), false)); r_options->push_back(ExportOption(PropertyInfo(Variant::BOOL, "tiles/show_name_on_wide310x150"), false)); @@ -1161,37 +1167,37 @@ public: err += TTR("Invalid background color.") + "\n"; } - if (!p_preset->get("images/store_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture>((Object *)p_preset->get("images/store_logo"))), 50, 50)) { + if (!p_preset->get("images/store_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture2D>((Object *)p_preset->get("images/store_logo"))), 50, 50)) { valid = false; err += TTR("Invalid Store Logo image dimensions (should be 50x50).") + "\n"; } - if (!p_preset->get("images/square44x44_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture>((Object *)p_preset->get("images/square44x44_logo"))), 44, 44)) { + if (!p_preset->get("images/square44x44_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture2D>((Object *)p_preset->get("images/square44x44_logo"))), 44, 44)) { valid = false; err += TTR("Invalid square 44x44 logo image dimensions (should be 44x44).") + "\n"; } - if (!p_preset->get("images/square71x71_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture>((Object *)p_preset->get("images/square71x71_logo"))), 71, 71)) { + if (!p_preset->get("images/square71x71_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture2D>((Object *)p_preset->get("images/square71x71_logo"))), 71, 71)) { valid = false; err += TTR("Invalid square 71x71 logo image dimensions (should be 71x71).") + "\n"; } - if (!p_preset->get("images/square150x150_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture>((Object *)p_preset->get("images/square150x150_logo"))), 150, 150)) { + if (!p_preset->get("images/square150x150_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture2D>((Object *)p_preset->get("images/square150x150_logo"))), 150, 150)) { valid = false; err += TTR("Invalid square 150x150 logo image dimensions (should be 150x150).") + "\n"; } - if (!p_preset->get("images/square310x310_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture>((Object *)p_preset->get("images/square310x310_logo"))), 310, 310)) { + if (!p_preset->get("images/square310x310_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture2D>((Object *)p_preset->get("images/square310x310_logo"))), 310, 310)) { valid = false; err += TTR("Invalid square 310x310 logo image dimensions (should be 310x310).") + "\n"; } - if (!p_preset->get("images/wide310x150_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture>((Object *)p_preset->get("images/wide310x150_logo"))), 310, 150)) { + if (!p_preset->get("images/wide310x150_logo").is_zero() && !_valid_image((Object::cast_to<StreamTexture2D>((Object *)p_preset->get("images/wide310x150_logo"))), 310, 150)) { valid = false; err += TTR("Invalid wide 310x150 logo image dimensions (should be 310x150).") + "\n"; } - if (!p_preset->get("images/splash_screen").is_zero() && !_valid_image((Object::cast_to<StreamTexture>((Object *)p_preset->get("images/splash_screen"))), 620, 300)) { + if (!p_preset->get("images/splash_screen").is_zero() && !_valid_image((Object::cast_to<StreamTexture2D>((Object *)p_preset->get("images/splash_screen"))), 620, 300)) { valid = false; err += TTR("Invalid splash screen image dimensions (should be 620x300).") + "\n"; } @@ -1301,7 +1307,8 @@ public: path = path.replace(".scale-100", ""); data = _get_image_data(p_preset, path); - if (data.size() > 0) do_read = false; + if (data.size() > 0) + do_read = false; } //read diff --git a/platform/uwp/joypad_uwp.cpp b/platform/uwp/joypad_uwp.cpp index 90df6fe5d7..066787e9d6 100644 --- a/platform/uwp/joypad_uwp.cpp +++ b/platform/uwp/joypad_uwp.cpp @@ -48,7 +48,8 @@ void JoypadUWP::process_controllers() { ControllerDevice &joy = controllers[i]; - if (!joy.connected) break; + if (!joy.connected) + break; switch (joy.type) { @@ -63,12 +64,12 @@ void JoypadUWP::process_controllers() { button_mask *= 2; } - input->joy_axis(joy.id, JOY_AXIS_0, axis_correct(reading.LeftThumbstickX)); - input->joy_axis(joy.id, JOY_AXIS_1, axis_correct(reading.LeftThumbstickY, true)); - input->joy_axis(joy.id, JOY_AXIS_2, axis_correct(reading.RightThumbstickX)); - input->joy_axis(joy.id, JOY_AXIS_3, axis_correct(reading.RightThumbstickY, true)); - input->joy_axis(joy.id, JOY_AXIS_4, axis_correct(reading.LeftTrigger, false, true)); - input->joy_axis(joy.id, JOY_AXIS_5, axis_correct(reading.RightTrigger, false, true)); + input->joy_axis(joy.id, JOY_AXIS_LEFT_X, axis_correct(reading.LeftThumbstickX)); + input->joy_axis(joy.id, JOY_AXIS_LEFT_Y, axis_correct(reading.LeftThumbstickY, true)); + input->joy_axis(joy.id, JOY_AXIS_RIGHT_X, axis_correct(reading.RightThumbstickX)); + input->joy_axis(joy.id, JOY_AXIS_RIGHT_Y, axis_correct(reading.RightThumbstickY, true)); + input->joy_axis(joy.id, JOY_AXIS_TRIGGER_LEFT, axis_correct(reading.LeftTrigger, false, true)); + input->joy_axis(joy.id, JOY_AXIS_TRIGGER_RIGHT, axis_correct(reading.RightTrigger, false, true)); uint64_t timestamp = input->get_joy_vibration_timestamp(joy.id); if (timestamp > joy.ff_timestamp) { diff --git a/platform/uwp/os_uwp.cpp b/platform/uwp/os_uwp.cpp index f5e989b370..1c83ebfdf7 100644 --- a/platform/uwp/os_uwp.cpp +++ b/platform/uwp/os_uwp.cpp @@ -123,7 +123,8 @@ bool OS_UWP::is_window_fullscreen() const { void OS_UWP::set_keep_screen_on(bool p_enabled) { - if (is_keep_screen_on() == p_enabled) return; + if (is_keep_screen_on() == p_enabled) + return; if (p_enabled) display_request->RequestActive(); @@ -826,7 +827,8 @@ void OS_UWP::run() { while (!force_quit) { CoreWindow::GetForCurrentThread()->Dispatcher->ProcessEvents(CoreProcessEventsOption::ProcessAllIfPresent); - if (managed_object->alert_close_handle) continue; + if (managed_object->alert_close_handle) + continue; process_events(); // get rid of pending events if (Main::iteration()) break; diff --git a/platform/windows/display_server_windows.cpp b/platform/windows/display_server_windows.cpp index 701cf69207..e794efb4fb 100644 --- a/platform/windows/display_server_windows.cpp +++ b/platform/windows/display_server_windows.cpp @@ -677,7 +677,8 @@ void DisplayServerWindows::window_set_position(const Point2i &p_position, Window ERR_FAIL_COND(!windows.has(p_window)); WindowData &wd = windows[p_window]; - if (wd.fullscreen) return; + if (wd.fullscreen) + return; #if 0 //wrong needs to account properly for decorations RECT r; @@ -1057,7 +1058,8 @@ void DisplayServerWindows::window_set_flag(WindowFlags p_flag, bool p_enabled, W wd.no_focus = p_enabled; _update_window_style(p_window); } break; - case WINDOW_FLAG_MAX: break; + case WINDOW_FLAG_MAX: + break; } } @@ -1088,7 +1090,8 @@ bool DisplayServerWindows::window_get_flag(WindowFlags p_flag, WindowID p_window return wd.no_focus; } break; - case WINDOW_FLAG_MAX: break; + case WINDOW_FLAG_MAX: + break; } return false; @@ -1474,19 +1477,22 @@ DisplayServer::LatinKeyboardVariant DisplayServerWindows::get_latin_keyboard_var int i = 0; while (azerty[i] != 0) { - if (azerty[i] == hex) return LATIN_KEYBOARD_AZERTY; + if (azerty[i] == hex) + return LATIN_KEYBOARD_AZERTY; i++; } i = 0; while (qwertz[i] != 0) { - if (qwertz[i] == hex) return LATIN_KEYBOARD_QWERTZ; + if (qwertz[i] == hex) + return LATIN_KEYBOARD_QWERTZ; i++; } i = 0; while (dvorak[i] != 0) { - if (dvorak[i] == hex) return LATIN_KEYBOARD_DVORAK; + if (dvorak[i] == hex) + return LATIN_KEYBOARD_DVORAK; i++; } diff --git a/platform/windows/joypad_windows.cpp b/platform/windows/joypad_windows.cpp index 2adf2a8652..821d4eb685 100644 --- a/platform/windows/joypad_windows.cpp +++ b/platform/windows/joypad_windows.cpp @@ -112,7 +112,8 @@ bool JoypadWindows::is_xinput_device(const GUID *p_guid) { return false; } dev_list = (PRAWINPUTDEVICELIST)malloc(sizeof(RAWINPUTDEVICELIST) * dev_list_count); - if (!dev_list) return false; + if (!dev_list) + return false; if (GetRawInputDeviceList(dev_list, &dev_list_count, sizeof(RAWINPUTDEVICELIST)) == (UINT)-1) { free(dev_list); @@ -267,7 +268,8 @@ void JoypadWindows::close_joypad(int id) { return; } - if (!d_joypads[id].attached) return; + if (!d_joypads[id].attached) + return; d_joypads[id].di_joy->Unacquire(); d_joypads[id].di_joy->Release(); @@ -345,12 +347,12 @@ void JoypadWindows::process_joypads() { button_mask = button_mask * 2; } - input->joy_axis(joy.id, JOY_AXIS_0, axis_correct(joy.state.Gamepad.sThumbLX, true)); - input->joy_axis(joy.id, JOY_AXIS_1, axis_correct(joy.state.Gamepad.sThumbLY, true, false, true)); - input->joy_axis(joy.id, JOY_AXIS_2, axis_correct(joy.state.Gamepad.sThumbRX, true)); - input->joy_axis(joy.id, JOY_AXIS_3, axis_correct(joy.state.Gamepad.sThumbRY, true, false, true)); - input->joy_axis(joy.id, JOY_AXIS_4, axis_correct(joy.state.Gamepad.bLeftTrigger, true, true)); - input->joy_axis(joy.id, JOY_AXIS_5, axis_correct(joy.state.Gamepad.bRightTrigger, true, true)); + input->joy_axis(joy.id, JOY_AXIS_LEFT_X, axis_correct(joy.state.Gamepad.sThumbLX, true)); + input->joy_axis(joy.id, JOY_AXIS_LEFT_Y, axis_correct(joy.state.Gamepad.sThumbLY, true, false, true)); + input->joy_axis(joy.id, JOY_AXIS_RIGHT_X, axis_correct(joy.state.Gamepad.sThumbRX, true)); + input->joy_axis(joy.id, JOY_AXIS_RIGHT_Y, axis_correct(joy.state.Gamepad.sThumbRY, true, false, true)); + input->joy_axis(joy.id, JOY_AXIS_TRIGGER_LEFT, axis_correct(joy.state.Gamepad.bLeftTrigger, true, true)); + input->joy_axis(joy.id, JOY_AXIS_TRIGGER_RIGHT, axis_correct(joy.state.Gamepad.bRightTrigger, true, true)); joy.last_packet = joy.state.dwPacketNumber; } uint64_t timestamp = input->get_joy_vibration_timestamp(joy.id); diff --git a/platform/windows/key_mapping_windows.h b/platform/windows/key_mapping_windows.h index 3361ad397f..028ca9fd5d 100644 --- a/platform/windows/key_mapping_windows.h +++ b/platform/windows/key_mapping_windows.h @@ -39,7 +39,7 @@ class KeyMappingWindows { - KeyMappingWindows(){}; + KeyMappingWindows() {} public: static unsigned int get_keysym(unsigned int p_code); diff --git a/platform/windows/windows_terminal_logger.cpp b/platform/windows/windows_terminal_logger.cpp index 884d95e082..3fb2adfd2c 100644 --- a/platform/windows/windows_terminal_logger.cpp +++ b/platform/windows/windows_terminal_logger.cpp @@ -89,20 +89,36 @@ void WindowsTerminalLogger::log_error(const char *p_function, const char *p_file uint32_t basecol = 0; switch (p_type) { - case ERR_ERROR: basecol = FOREGROUND_RED; break; - case ERR_WARNING: basecol = FOREGROUND_RED | FOREGROUND_GREEN; break; - case ERR_SCRIPT: basecol = FOREGROUND_RED | FOREGROUND_BLUE; break; - case ERR_SHADER: basecol = FOREGROUND_GREEN | FOREGROUND_BLUE; break; + case ERR_ERROR: + basecol = FOREGROUND_RED; + break; + case ERR_WARNING: + basecol = FOREGROUND_RED | FOREGROUND_GREEN; + break; + case ERR_SCRIPT: + basecol = FOREGROUND_RED | FOREGROUND_BLUE; + break; + case ERR_SHADER: + basecol = FOREGROUND_GREEN | FOREGROUND_BLUE; + break; } basecol |= current_bg; SetConsoleTextAttribute(hCon, basecol | FOREGROUND_INTENSITY); switch (p_type) { - case ERR_ERROR: logf("ERROR:"); break; - case ERR_WARNING: logf("WARNING:"); break; - case ERR_SCRIPT: logf("SCRIPT ERROR:"); break; - case ERR_SHADER: logf("SHADER ERROR:"); break; + case ERR_ERROR: + logf("ERROR:"); + break; + case ERR_WARNING: + logf("WARNING:"); + break; + case ERR_SCRIPT: + logf("SCRIPT ERROR:"); + break; + case ERR_SHADER: + logf("SHADER ERROR:"); + break; } SetConsoleTextAttribute(hCon, basecol); @@ -115,10 +131,18 @@ void WindowsTerminalLogger::log_error(const char *p_function, const char *p_file // `FOREGROUND_INTENSITY` alone results in gray text. SetConsoleTextAttribute(hCon, FOREGROUND_INTENSITY); switch (p_type) { - case ERR_ERROR: logf(" at: "); break; - case ERR_WARNING: logf(" at: "); break; - case ERR_SCRIPT: logf(" at: "); break; - case ERR_SHADER: logf(" at: "); break; + case ERR_ERROR: + logf(" at: "); + break; + case ERR_WARNING: + logf(" at: "); + break; + case ERR_SCRIPT: + logf(" at: "); + break; + case ERR_SHADER: + logf(" at: "); + break; } if (p_rationale && p_rationale[0]) { diff --git a/scene/2d/cpu_particles_2d.cpp b/scene/2d/cpu_particles_2d.cpp index 0a6b091a51..c37cd398c4 100644 --- a/scene/2d/cpu_particles_2d.cpp +++ b/scene/2d/cpu_particles_2d.cpp @@ -961,7 +961,8 @@ void CPUParticles2D::_particles_process(float p_delta) { //scale by scale float base_scale = tex_scale * Math::lerp(parameters[PARAM_SCALE], 1.0f, p.scale_rand * randomness[PARAM_SCALE]); - if (base_scale < 0.000001) base_scale = 0.000001; + if (base_scale < 0.000001) + base_scale = 0.000001; p.transform.elements[0] *= base_scale; p.transform.elements[1] *= base_scale; @@ -1196,7 +1197,8 @@ void CPUParticles2D::convert_from_particles(Node *p_particles) { set_param(m_param, material->get_param(ParticlesMaterial::m_param)); \ { \ Ref<CurveTexture> ctex = material->get_param_texture(ParticlesMaterial::m_param); \ - if (ctex.is_valid()) set_param_curve(m_param, ctex->get_curve()); \ + if (ctex.is_valid()) \ + set_param_curve(m_param, ctex->get_curve()); \ } \ set_param_randomness(m_param, material->get_param_randomness(ParticlesMaterial::m_param)); diff --git a/scene/2d/navigation_agent_2d.cpp b/scene/2d/navigation_agent_2d.cpp index 32da46e8a8..dfeb5eea45 100644 --- a/scene/2d/navigation_agent_2d.cpp +++ b/scene/2d/navigation_agent_2d.cpp @@ -133,15 +133,7 @@ void NavigationAgent2D::_notification(int p_what) { } } -NavigationAgent2D::NavigationAgent2D() : - agent_parent(nullptr), - navigation(nullptr), - agent(RID()), - target_desired_distance(1.0), - path_max_distance(3.0), - velocity_submitted(false), - target_reached(false), - navigation_finished(true) { +NavigationAgent2D::NavigationAgent2D() { agent = NavigationServer2D::get_singleton()->agent_create(); set_neighbor_dist(500.0); set_max_neighbors(10); @@ -288,9 +280,12 @@ String NavigationAgent2D::get_configuration_warning() const { void NavigationAgent2D::update_navigation() { - if (agent_parent == nullptr) return; - if (navigation == nullptr) return; - if (update_frame_id == Engine::get_singleton()->get_physics_frames()) return; + if (agent_parent == nullptr) + return; + if (navigation == nullptr) + return; + if (update_frame_id == Engine::get_singleton()->get_physics_frames()) + return; update_frame_id = Engine::get_singleton()->get_physics_frames(); diff --git a/scene/2d/navigation_agent_2d.h b/scene/2d/navigation_agent_2d.h index 26eccfc949..796a85f3f2 100644 --- a/scene/2d/navigation_agent_2d.h +++ b/scene/2d/navigation_agent_2d.h @@ -40,29 +40,29 @@ class Navigation2D; class NavigationAgent2D : public Node { GDCLASS(NavigationAgent2D, Node); - Node2D *agent_parent; - Navigation2D *navigation; + Node2D *agent_parent = nullptr; + Navigation2D *navigation = nullptr; RID agent; - real_t target_desired_distance; + real_t target_desired_distance = 1.0; real_t radius; real_t neighbor_dist; int max_neighbors; real_t time_horizon; real_t max_speed; - real_t path_max_distance; + real_t path_max_distance = 3.0; Vector2 target_location; Vector<Vector2> navigation_path; int nav_path_index; - bool velocity_submitted; + bool velocity_submitted = false; Vector2 prev_safe_velocity; /// The submitted target velocity Vector2 target_velocity; - bool target_reached; - bool navigation_finished; + bool target_reached = false; + bool navigation_finished = true; // No initialized on purpose uint32_t update_frame_id; diff --git a/scene/2d/navigation_obstacle_2d.cpp b/scene/2d/navigation_obstacle_2d.cpp index 50d02ca507..3eb3ef332f 100644 --- a/scene/2d/navigation_obstacle_2d.cpp +++ b/scene/2d/navigation_obstacle_2d.cpp @@ -78,9 +78,7 @@ void NavigationObstacle2D::_notification(int p_what) { } } -NavigationObstacle2D::NavigationObstacle2D() : - navigation(nullptr), - agent(RID()) { +NavigationObstacle2D::NavigationObstacle2D() { agent = NavigationServer2D::get_singleton()->agent_create(); } diff --git a/scene/2d/navigation_obstacle_2d.h b/scene/2d/navigation_obstacle_2d.h index 3935fe1bc5..bdef6f2843 100644 --- a/scene/2d/navigation_obstacle_2d.h +++ b/scene/2d/navigation_obstacle_2d.h @@ -38,7 +38,7 @@ class Navigation2D; class NavigationObstacle2D : public Node { GDCLASS(NavigationObstacle2D, Node); - Navigation2D *navigation; + Navigation2D *navigation = nullptr; RID agent; diff --git a/scene/2d/navigation_region_2d.cpp b/scene/2d/navigation_region_2d.cpp index abbfbf83b7..f3f335e66a 100644 --- a/scene/2d/navigation_region_2d.cpp +++ b/scene/2d/navigation_region_2d.cpp @@ -367,13 +367,6 @@ void NavigationPolygon::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::ARRAY, "outlines", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR | PROPERTY_USAGE_INTERNAL), "_set_outlines", "_get_outlines"); } -NavigationPolygon::NavigationPolygon() : - rect_cache_dirty(true) { -} - -NavigationPolygon::~NavigationPolygon() { -} - void NavigationRegion2D::set_enabled(bool p_enabled) { if (enabled == p_enabled) @@ -569,12 +562,8 @@ void NavigationRegion2D::_bind_methods() { } NavigationRegion2D::NavigationRegion2D() { - - enabled = true; set_notify_transform(true); region = NavigationServer2D::get_singleton()->region_create(); - - navigation = nullptr; } NavigationRegion2D::~NavigationRegion2D() { diff --git a/scene/2d/navigation_region_2d.h b/scene/2d/navigation_region_2d.h index e730df6373..54d2ac11f8 100644 --- a/scene/2d/navigation_region_2d.h +++ b/scene/2d/navigation_region_2d.h @@ -46,7 +46,7 @@ class NavigationPolygon : public Resource { Vector<Vector<Vector2>> outlines; mutable Rect2 item_rect; - mutable bool rect_cache_dirty; + mutable bool rect_cache_dirty = true; Mutex navmesh_generation; // Navigation mesh @@ -88,8 +88,8 @@ public: Ref<NavigationMesh> get_mesh(); - NavigationPolygon(); - ~NavigationPolygon(); + NavigationPolygon() {} + ~NavigationPolygon() {} }; class Navigation2D; @@ -98,9 +98,9 @@ class NavigationRegion2D : public Node2D { GDCLASS(NavigationRegion2D, Node2D); - bool enabled; + bool enabled = true; RID region; - Navigation2D *navigation; + Navigation2D *navigation = nullptr; Ref<NavigationPolygon> navpoly; void _navpoly_changed(); diff --git a/scene/2d/physics_body_2d.cpp b/scene/2d/physics_body_2d.cpp index de15f0efc2..4198eb6c06 100644 --- a/scene/2d/physics_body_2d.cpp +++ b/scene/2d/physics_body_2d.cpp @@ -1394,7 +1394,8 @@ Vector2 KinematicCollision2D::get_remainder() const { return collision.remainder; } Object *KinematicCollision2D::get_local_shape() const { - if (!owner) return nullptr; + if (!owner) + return nullptr; uint32_t ownerid = owner->shape_find_owner(collision.local_shape); return owner->shape_owner_get_owner(ownerid); } diff --git a/scene/2d/polygon_2d.cpp b/scene/2d/polygon_2d.cpp index 84c1828b47..cec598774e 100644 --- a/scene/2d/polygon_2d.cpp +++ b/scene/2d/polygon_2d.cpp @@ -723,7 +723,7 @@ void Polygon2D::_bind_methods() { Polygon2D::Polygon2D() { - invert = 0; + invert = false; invert_border = 100; antialiased = false; tex_rot = 0; diff --git a/scene/2d/tile_map.cpp b/scene/2d/tile_map.cpp index 86e61fe878..869f290f97 100644 --- a/scene/2d/tile_map.cpp +++ b/scene/2d/tile_map.cpp @@ -40,7 +40,7 @@ int TileMap::_get_quadrant_size() const { - if (y_sort_mode) + if (use_y_sort) return 1; else return quadrant_size; @@ -1355,7 +1355,8 @@ bool TileMap::get_collision_use_parent() const { void TileMap::set_collision_use_parent(bool p_use_parent) { - if (use_parent == p_use_parent) return; + if (use_parent == p_use_parent) + return; _clear_quadrants(); @@ -1648,18 +1649,18 @@ Vector2 TileMap::world_to_map(const Vector2 &p_pos) const { return ret.floor(); } -void TileMap::set_y_sort_mode(bool p_enable) { +void TileMap::set_y_sort_enabled(bool p_enable) { _clear_quadrants(); - y_sort_mode = p_enable; - RS::get_singleton()->canvas_item_set_sort_children_by_y(get_canvas_item(), y_sort_mode); + use_y_sort = p_enable; + RS::get_singleton()->canvas_item_set_sort_children_by_y(get_canvas_item(), use_y_sort); _recreate_quadrants(); emit_signal("settings_changed"); } -bool TileMap::is_y_sort_mode_enabled() const { +bool TileMap::is_y_sort_enabled() const { - return y_sort_mode; + return use_y_sort; } void TileMap::set_compatibility_mode(bool p_enable) { @@ -1702,7 +1703,7 @@ TypedArray<Vector2i> TileMap::get_used_cells() const { return a; } -TypedArray<Vector2i> TileMap::get_used_cells_by_id(int p_id) const { +TypedArray<Vector2i> TileMap::get_used_cells_by_index(int p_id) const { TypedArray<Vector2i> a; for (Map<PosKey, Cell>::Element *E = tile_map.front(); E; E = E->next()) { @@ -1822,8 +1823,8 @@ void TileMap::_bind_methods() { ClassDB::bind_method(D_METHOD("set_clip_uv", "enable"), &TileMap::set_clip_uv); ClassDB::bind_method(D_METHOD("get_clip_uv"), &TileMap::get_clip_uv); - ClassDB::bind_method(D_METHOD("set_y_sort_mode", "enable"), &TileMap::set_y_sort_mode); - ClassDB::bind_method(D_METHOD("is_y_sort_mode_enabled"), &TileMap::is_y_sort_mode_enabled); + ClassDB::bind_method(D_METHOD("set_y_sort_enabled", "enable"), &TileMap::set_y_sort_enabled); + ClassDB::bind_method(D_METHOD("is_y_sort_enabled"), &TileMap::is_y_sort_enabled); ClassDB::bind_method(D_METHOD("set_compatibility_mode", "enable"), &TileMap::set_compatibility_mode); ClassDB::bind_method(D_METHOD("is_compatibility_mode_enabled"), &TileMap::is_compatibility_mode_enabled); @@ -1873,7 +1874,7 @@ void TileMap::_bind_methods() { ClassDB::bind_method(D_METHOD("clear"), &TileMap::clear); ClassDB::bind_method(D_METHOD("get_used_cells"), &TileMap::get_used_cells); - ClassDB::bind_method(D_METHOD("get_used_cells_by_id", "id"), &TileMap::get_used_cells_by_id); + ClassDB::bind_method(D_METHOD("get_used_cells_by_index", "index"), &TileMap::get_used_cells_by_index); ClassDB::bind_method(D_METHOD("get_used_rect"), &TileMap::get_used_rect); ClassDB::bind_method(D_METHOD("map_to_world", "map_position", "ignore_half_ofs"), &TileMap::map_to_world, DEFVAL(false)); @@ -1897,7 +1898,7 @@ void TileMap::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::TRANSFORM2D, "cell_custom_transform"), "set_custom_transform", "get_custom_transform"); ADD_PROPERTY(PropertyInfo(Variant::INT, "cell_half_offset", PROPERTY_HINT_ENUM, "Offset X,Offset Y,Disabled,Offset Negative X,Offset Negative Y"), "set_half_offset", "get_half_offset"); ADD_PROPERTY(PropertyInfo(Variant::INT, "cell_tile_origin", PROPERTY_HINT_ENUM, "Top Left,Center,Bottom Left"), "set_tile_origin", "get_tile_origin"); - ADD_PROPERTY(PropertyInfo(Variant::BOOL, "cell_y_sort"), "set_y_sort_mode", "is_y_sort_mode_enabled"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "cell_y_sort"), "set_y_sort_enabled", "is_y_sort_enabled"); ADD_PROPERTY(PropertyInfo(Variant::BOOL, "compatibility_mode"), "set_compatibility_mode", "is_compatibility_mode_enabled"); ADD_PROPERTY(PropertyInfo(Variant::BOOL, "centered_textures"), "set_centered_textures", "is_centered_textures_enabled"); ADD_PROPERTY(PropertyInfo(Variant::BOOL, "cell_clip_uv"), "set_clip_uv", "get_clip_uv"); @@ -1959,7 +1960,7 @@ TileMap::TileMap() { collision_parent = nullptr; use_kinematic = false; navigation = nullptr; - y_sort_mode = false; + use_y_sort = false; compatibility_mode = false; centered_textures = false; occluder_light_mask = 1; diff --git a/scene/2d/tile_map.h b/scene/2d/tile_map.h index cc1537f583..16784571bf 100644 --- a/scene/2d/tile_map.h +++ b/scene/2d/tile_map.h @@ -187,7 +187,7 @@ private: Rect2 used_size_cache; bool used_size_cache_dirty; bool quadrant_order_dirty; - bool y_sort_mode; + bool use_y_sort; bool compatibility_mode; bool centered_textures; bool clip_uv; @@ -319,8 +319,8 @@ public: Vector2 map_to_world(const Vector2 &p_pos, bool p_ignore_ofs = false) const; Vector2 world_to_map(const Vector2 &p_pos) const; - void set_y_sort_mode(bool p_enable); - bool is_y_sort_mode_enabled() const; + void set_y_sort_enabled(bool p_enable); + bool is_y_sort_enabled() const; void set_compatibility_mode(bool p_enable); bool is_compatibility_mode_enabled() const; @@ -329,7 +329,7 @@ public: bool is_centered_textures_enabled() const; TypedArray<Vector2i> get_used_cells() const; - TypedArray<Vector2i> get_used_cells_by_id(int p_id) const; + TypedArray<Vector2i> get_used_cells_by_index(int p_index) const; Rect2 get_used_rect(); // Not const because of cache void set_occluder_light_mask(int p_mask); diff --git a/scene/2d/visibility_notifier_2d.cpp b/scene/2d/visibility_notifier_2d.cpp index c374dd5faa..780d08693d 100644 --- a/scene/2d/visibility_notifier_2d.cpp +++ b/scene/2d/visibility_notifier_2d.cpp @@ -248,10 +248,19 @@ void VisibilityEnabler2D::_notification(int p_what) { _find_nodes(from); - if (enabler[ENABLER_PARENT_PHYSICS_PROCESS] && get_parent()) - get_parent()->set_physics_process(false); - if (enabler[ENABLER_PARENT_PROCESS] && get_parent()) - get_parent()->set_process(false); + // We need to defer the call of set_process and set_physics_process, + // otherwise they are overwritten inside NOTIFICATION_READY. + // We can't use call_deferred, because it happens after a physics frame. + // The ready signal works as it's emitted immediately after NOTIFICATION_READY. + + if (enabler[ENABLER_PARENT_PHYSICS_PROCESS] && get_parent()) { + get_parent()->connect(SceneStringNames::get_singleton()->ready, + callable_mp(get_parent(), &Node::set_physics_process), varray(false), CONNECT_ONESHOT); + } + if (enabler[ENABLER_PARENT_PROCESS] && get_parent()) { + get_parent()->connect(SceneStringNames::get_singleton()->ready, + callable_mp(get_parent(), &Node::set_process), varray(false), CONNECT_ONESHOT); + } } if (p_what == NOTIFICATION_EXIT_TREE) { diff --git a/scene/3d/audio_stream_player_3d.cpp b/scene/3d/audio_stream_player_3d.cpp index 5701d3cea2..8a00d67e12 100644 --- a/scene/3d/audio_stream_player_3d.cpp +++ b/scene/3d/audio_stream_player_3d.cpp @@ -320,7 +320,8 @@ float AudioStreamPlayer3D::_get_attenuation_db(float p_distance) const { case ATTENUATION_LOGARITHMIC: { att = -20 * Math::log(p_distance / unit_size + CMP_EPSILON); } break; - case ATTENUATION_DISABLED: break; + case ATTENUATION_DISABLED: + break; default: { ERR_PRINT("Unknown attenuation type"); break; diff --git a/scene/3d/baked_lightmap.cpp b/scene/3d/baked_lightmap.cpp index 6bde56104e..82a9a1e589 100644 --- a/scene/3d/baked_lightmap.cpp +++ b/scene/3d/baked_lightmap.cpp @@ -28,72 +28,25 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#if 0 #include "baked_lightmap.h" + #include "core/io/config_file.h" #include "core/io/resource_saver.h" +#include "core/math/camera_matrix.h" +#include "core/math/delaunay_3d.h" #include "core/os/dir_access.h" +#include "core/os/file_access.h" #include "core/os/os.h" -#include "voxel_light_baker.h" - -void BakedLightmapData::set_bounds(const AABB &p_bounds) { - - bounds = p_bounds; - RS::get_singleton()->lightmap_capture_set_bounds(baked_light, p_bounds); -} - -AABB BakedLightmapData::get_bounds() const { - - return bounds; -} - -void BakedLightmapData::set_octree(const Vector<uint8_t> &p_octree) { - - RS::get_singleton()->lightmap_capture_set_octree(baked_light, p_octree); -} - -Vector<uint8_t> BakedLightmapData::get_octree() const { - - return RS::get_singleton()->lightmap_capture_get_octree(baked_light); -} - -void BakedLightmapData::set_cell_space_transform(const Transform &p_xform) { - - cell_space_xform = p_xform; - RS::get_singleton()->lightmap_capture_set_octree_cell_transform(baked_light, p_xform); -} - -Transform BakedLightmapData::get_cell_space_transform() const { - return cell_space_xform; -} - -void BakedLightmapData::set_cell_subdiv(int p_cell_subdiv) { - cell_subdiv = p_cell_subdiv; - RS::get_singleton()->lightmap_capture_set_octree_cell_subdiv(baked_light, p_cell_subdiv); -} - -int BakedLightmapData::get_cell_subdiv() const { - return cell_subdiv; -} +#include "core/sort_array.h" +#include "lightmap_probe.h" -void BakedLightmapData::set_energy(float p_energy) { +void BakedLightmapData::add_user(const NodePath &p_path, const Rect2 &p_uv_scale, int p_slice_index, int32_t p_sub_instance) { - energy = p_energy; - RS::get_singleton()->lightmap_capture_set_energy(baked_light, energy); -} - -float BakedLightmapData::get_energy() const { - - return energy; -} - -void BakedLightmapData::add_user(const NodePath &p_path, const Ref<Texture2D> &p_lightmap, int p_instance) { - - ERR_FAIL_COND_MSG(p_lightmap.is_null(), "It's not a reference to a valid Texture object."); User user; user.path = p_path; - user.lightmap = p_lightmap; - user.instance_index = p_instance; + user.uv_scale = p_uv_scale; + user.slice_index = p_slice_index; + user.sub_instance = p_sub_instance; users.push_back(user); } @@ -106,16 +59,23 @@ NodePath BakedLightmapData::get_user_path(int p_user) const { ERR_FAIL_INDEX_V(p_user, users.size(), NodePath()); return users[p_user].path; } -Ref<Texture2D> BakedLightmapData::get_user_lightmap(int p_user) const { - ERR_FAIL_INDEX_V(p_user, users.size(), Ref<Texture2D>()); - return users[p_user].lightmap; +int32_t BakedLightmapData::get_user_sub_instance(int p_user) const { + + ERR_FAIL_INDEX_V(p_user, users.size(), -1); + return users[p_user].sub_instance; +} + +Rect2 BakedLightmapData::get_user_lightmap_uv_scale(int p_user) const { + + ERR_FAIL_INDEX_V(p_user, users.size(), Rect2()); + return users[p_user].uv_scale; } -int BakedLightmapData::get_user_instance(int p_user) const { +int BakedLightmapData::get_user_lightmap_slice_index(int p_user) const { ERR_FAIL_INDEX_V(p_user, users.size(), -1); - return users[p_user].instance_index; + return users[p_user].slice_index; } void BakedLightmapData::clear_users() { @@ -124,10 +84,10 @@ void BakedLightmapData::clear_users() { void BakedLightmapData::_set_user_data(const Array &p_data) { - ERR_FAIL_COND((p_data.size() % 3) != 0); + ERR_FAIL_COND((p_data.size() % 4) != 0); - for (int i = 0; i < p_data.size(); i += 3) { - add_user(p_data[i], p_data[i + 1], p_data[i + 2]); + for (int i = 0; i < p_data.size(); i += 4) { + add_user(p_data[i + 0], p_data[i + 1], p_data[i + 2], p_data[i + 3]); } } @@ -136,522 +96,1132 @@ Array BakedLightmapData::_get_user_data() const { Array ret; for (int i = 0; i < users.size(); i++) { ret.push_back(users[i].path); - ret.push_back(users[i].lightmap); - ret.push_back(users[i].instance_index); + ret.push_back(users[i].uv_scale); + ret.push_back(users[i].slice_index); + ret.push_back(users[i].sub_instance); } return ret; } RID BakedLightmapData::get_rid() const { - return baked_light; + return lightmap; } -void BakedLightmapData::_bind_methods() { - ClassDB::bind_method(D_METHOD("_set_user_data", "data"), &BakedLightmapData::_set_user_data); - ClassDB::bind_method(D_METHOD("_get_user_data"), &BakedLightmapData::_get_user_data); +void BakedLightmapData::clear() { + users.clear(); +} - ClassDB::bind_method(D_METHOD("set_bounds", "bounds"), &BakedLightmapData::set_bounds); - ClassDB::bind_method(D_METHOD("get_bounds"), &BakedLightmapData::get_bounds); +void BakedLightmapData::set_light_texture(const Ref<TextureLayered> &p_light_texture) { + light_texture = p_light_texture; + RS::get_singleton()->lightmap_set_textures(lightmap, light_texture.is_valid() ? light_texture->get_rid() : RID(), uses_spherical_harmonics); +} - ClassDB::bind_method(D_METHOD("set_cell_space_transform", "xform"), &BakedLightmapData::set_cell_space_transform); - ClassDB::bind_method(D_METHOD("get_cell_space_transform"), &BakedLightmapData::get_cell_space_transform); +Ref<TextureLayered> BakedLightmapData::get_light_texture() const { + return light_texture; +} - ClassDB::bind_method(D_METHOD("set_cell_subdiv", "cell_subdiv"), &BakedLightmapData::set_cell_subdiv); - ClassDB::bind_method(D_METHOD("get_cell_subdiv"), &BakedLightmapData::get_cell_subdiv); +void BakedLightmapData::set_uses_spherical_harmonics(bool p_enable) { + uses_spherical_harmonics = p_enable; + RS::get_singleton()->lightmap_set_textures(lightmap, light_texture.is_valid() ? light_texture->get_rid() : RID(), uses_spherical_harmonics); +} - ClassDB::bind_method(D_METHOD("set_octree", "octree"), &BakedLightmapData::set_octree); - ClassDB::bind_method(D_METHOD("get_octree"), &BakedLightmapData::get_octree); +bool BakedLightmapData::is_using_spherical_harmonics() const { + return uses_spherical_harmonics; +} - ClassDB::bind_method(D_METHOD("set_energy", "energy"), &BakedLightmapData::set_energy); - ClassDB::bind_method(D_METHOD("get_energy"), &BakedLightmapData::get_energy); +void BakedLightmapData::set_capture_data(const AABB &p_bounds, bool p_interior, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree) { + if (p_points.size()) { + int pc = p_points.size(); + ERR_FAIL_COND(pc * 9 != p_point_sh.size()); + ERR_FAIL_COND((p_tetrahedra.size() % 4) != 0); + ERR_FAIL_COND((p_bsp_tree.size() % 6) != 0); + RS::get_singleton()->lightmap_set_probe_capture_data(lightmap, p_points, p_point_sh, p_tetrahedra, p_bsp_tree); + RS::get_singleton()->lightmap_set_probe_bounds(lightmap, p_bounds); + RS::get_singleton()->lightmap_set_probe_interior(lightmap, p_interior); + } else { + RS::get_singleton()->lightmap_set_probe_capture_data(lightmap, PackedVector3Array(), PackedColorArray(), PackedInt32Array(), PackedInt32Array()); + RS::get_singleton()->lightmap_set_probe_bounds(lightmap, AABB()); + RS::get_singleton()->lightmap_set_probe_interior(lightmap, false); + } + interior = p_interior; + bounds = p_bounds; +} - ClassDB::bind_method(D_METHOD("add_user", "path", "lightmap", "instance"), &BakedLightmapData::add_user); - ClassDB::bind_method(D_METHOD("get_user_count"), &BakedLightmapData::get_user_count); - ClassDB::bind_method(D_METHOD("get_user_path", "user_idx"), &BakedLightmapData::get_user_path); - ClassDB::bind_method(D_METHOD("get_user_lightmap", "user_idx"), &BakedLightmapData::get_user_lightmap); - ClassDB::bind_method(D_METHOD("clear_users"), &BakedLightmapData::clear_users); +PackedVector3Array BakedLightmapData::get_capture_points() const { + return RS::get_singleton()->lightmap_get_probe_capture_points(lightmap); +} +PackedColorArray BakedLightmapData::get_capture_sh() const { + return RS::get_singleton()->lightmap_get_probe_capture_sh(lightmap); +} +PackedInt32Array BakedLightmapData::get_capture_tetrahedra() const { + return RS::get_singleton()->lightmap_get_probe_capture_tetrahedra(lightmap); +} - ADD_PROPERTY(PropertyInfo(Variant::AABB, "bounds", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR), "set_bounds", "get_bounds"); - ADD_PROPERTY(PropertyInfo(Variant::TRANSFORM, "cell_space_transform", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR), "set_cell_space_transform", "get_cell_space_transform"); - ADD_PROPERTY(PropertyInfo(Variant::INT, "cell_subdiv", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR), "set_cell_subdiv", "get_cell_subdiv"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "energy", PROPERTY_HINT_RANGE, "0,16,0.01,or_greater"), "set_energy", "get_energy"); - ADD_PROPERTY(PropertyInfo(Variant::PACKED_BYTE_ARRAY, "octree", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR), "set_octree", "get_octree"); - ADD_PROPERTY(PropertyInfo(Variant::ARRAY, "user_data", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR | PROPERTY_USAGE_INTERNAL), "_set_user_data", "_get_user_data"); +PackedInt32Array BakedLightmapData::get_capture_bsp_tree() const { + return RS::get_singleton()->lightmap_get_probe_capture_bsp_tree(lightmap); } -BakedLightmapData::BakedLightmapData() { +AABB BakedLightmapData::get_capture_bounds() const { + return bounds; +} - baked_light = RS::get_singleton()->lightmap_capture_create(); - energy = 1; - cell_subdiv = 1; +bool BakedLightmapData::is_interior() const { + return interior; } -BakedLightmapData::~BakedLightmapData() { +void BakedLightmapData::_set_probe_data(const Dictionary &p_data) { + ERR_FAIL_COND(!p_data.has("bounds")); + ERR_FAIL_COND(!p_data.has("points")); + ERR_FAIL_COND(!p_data.has("tetrahedra")); + ERR_FAIL_COND(!p_data.has("bsp")); + ERR_FAIL_COND(!p_data.has("sh")); + ERR_FAIL_COND(!p_data.has("interior")); + set_capture_data(p_data["bounds"], p_data["interior"], p_data["points"], p_data["sh"], p_data["tetrahedra"], p_data["bsp"]); +} - RS::get_singleton()->free(baked_light); +Dictionary BakedLightmapData::_get_probe_data() const { + Dictionary d; + d["bounds"] = get_capture_bounds(); + d["points"] = get_capture_points(); + d["tetrahedra"] = get_capture_tetrahedra(); + d["bsp"] = get_capture_bsp_tree(); + d["sh"] = get_capture_sh(); + d["interior"] = is_interior(); + return d; } +void BakedLightmapData::_bind_methods() { -/////////////////////////// + ClassDB::bind_method(D_METHOD("_set_user_data", "data"), &BakedLightmapData::_set_user_data); + ClassDB::bind_method(D_METHOD("_get_user_data"), &BakedLightmapData::_get_user_data); -BakedLightmap::BakeBeginFunc BakedLightmap::bake_begin_function = nullptr; -BakedLightmap::BakeStepFunc BakedLightmap::bake_step_function = nullptr; -BakedLightmap::BakeEndFunc BakedLightmap::bake_end_function = nullptr; + ClassDB::bind_method(D_METHOD("set_light_texture", "light_texture"), &BakedLightmapData::set_light_texture); + ClassDB::bind_method(D_METHOD("get_light_texture"), &BakedLightmapData::get_light_texture); -void BakedLightmap::set_bake_cell_size(float p_cell_size) { - bake_cell_size = p_cell_size; -} + ClassDB::bind_method(D_METHOD("set_uses_spherical_harmonics", "uses_spherical_harmonics"), &BakedLightmapData::set_uses_spherical_harmonics); + ClassDB::bind_method(D_METHOD("is_using_spherical_harmonics"), &BakedLightmapData::is_using_spherical_harmonics); -float BakedLightmap::get_bake_cell_size() const { - return bake_cell_size; -} + ClassDB::bind_method(D_METHOD("add_user", "path", "lightmap", "offset"), &BakedLightmapData::add_user); + ClassDB::bind_method(D_METHOD("get_user_count"), &BakedLightmapData::get_user_count); + ClassDB::bind_method(D_METHOD("get_user_path", "user_idx"), &BakedLightmapData::get_user_path); + ClassDB::bind_method(D_METHOD("clear_users"), &BakedLightmapData::clear_users); -void BakedLightmap::set_capture_cell_size(float p_cell_size) { - capture_cell_size = p_cell_size; -} + ClassDB::bind_method(D_METHOD("_set_probe_data", "data"), &BakedLightmapData::_set_probe_data); + ClassDB::bind_method(D_METHOD("_get_probe_data"), &BakedLightmapData::_get_probe_data); -float BakedLightmap::get_capture_cell_size() const { - return capture_cell_size; + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "light_texture", PROPERTY_HINT_RESOURCE_TYPE, "TextureLayered"), "set_light_texture", "get_light_texture"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "uses_spherical_harmonics", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR | PROPERTY_USAGE_INTERNAL), "set_uses_spherical_harmonics", "is_using_spherical_harmonics"); + ADD_PROPERTY(PropertyInfo(Variant::ARRAY, "user_data", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR | PROPERTY_USAGE_INTERNAL), "_set_user_data", "_get_user_data"); + ADD_PROPERTY(PropertyInfo(Variant::DICTIONARY, "probe_data", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_NOEDITOR | PROPERTY_USAGE_INTERNAL), "_set_probe_data", "_get_probe_data"); } -void BakedLightmap::set_extents(const Vector3 &p_extents) { - extents = p_extents; - update_gizmo(); - _change_notify("bake_extents"); -} +BakedLightmapData::BakedLightmapData() { -Vector3 BakedLightmap::get_extents() const { - return extents; + lightmap = RS::get_singleton()->lightmap_create(); } -void BakedLightmap::set_bake_default_texels_per_unit(const float &p_bake_texels_per_unit) { - bake_default_texels_per_unit = p_bake_texels_per_unit; - update_gizmo(); -} +BakedLightmapData::~BakedLightmapData() { -float BakedLightmap::get_bake_default_texels_per_unit() const { - return bake_default_texels_per_unit; + RS::get_singleton()->free(lightmap); } -void BakedLightmap::_find_meshes_and_lights(Node *p_at_node, List<PlotMesh> &plot_meshes, List<PlotLight> &plot_lights) { +/////////////////////////// + +void BakedLightmap::_find_meshes_and_lights(Node *p_at_node, Vector<MeshesFound> &meshes, Vector<LightsFound> &lights, Vector<Vector3> &probes) { - MeshInstance *mi = Object::cast_to<MeshInstance>(p_at_node); - if (mi && mi->get_flag(GeometryInstance::FLAG_USE_BAKED_LIGHT) && mi->is_visible_in_tree()) { + MeshInstance3D *mi = Object::cast_to<MeshInstance3D>(p_at_node); + if (mi && mi->get_gi_mode() == GeometryInstance3D::GI_MODE_BAKED && mi->is_visible_in_tree()) { Ref<Mesh> mesh = mi->get_mesh(); if (mesh.is_valid()) { - bool all_have_uv2 = true; + bool all_have_uv2_and_normal = true; + bool surfaces_found = false; for (int i = 0; i < mesh->get_surface_count(); i++) { + + if (mesh->surface_get_primitive_type(i) != Mesh::PRIMITIVE_TRIANGLES) { + continue; + } if (!(mesh->surface_get_format(i) & Mesh::ARRAY_FORMAT_TEX_UV2)) { - all_have_uv2 = false; + all_have_uv2_and_normal = false; break; } + if (!(mesh->surface_get_format(i) & Mesh::ARRAY_FORMAT_NORMAL)) { + all_have_uv2_and_normal = false; + break; + } + surfaces_found = true; } - if (all_have_uv2) { + if (surfaces_found && all_have_uv2_and_normal) { //READY TO BAKE! size hint could be computed if not found, actually.. - AABB aabb = mesh->get_aabb(); - - Transform xf = get_global_transform().affine_inverse() * mi->get_global_transform(); - - if (AABB(-extents, extents * 2).intersects(xf.xform(aabb))) { - PlotMesh pm; - pm.local_xform = xf; - pm.mesh = mesh; - pm.path = get_path_to(mi); - pm.instance_idx = -1; - for (int i = 0; i < mesh->get_surface_count(); i++) { - pm.instance_materials.push_back(mi->get_surface_material(i)); + MeshesFound mf; + mf.xform = get_global_transform().affine_inverse() * mi->get_global_transform(); + mf.node_path = get_path_to(mi); + mf.subindex = -1; + mf.mesh = mesh; + + static const int lightmap_scale[GeometryInstance3D::LIGHTMAP_SCALE_MAX] = { 1, 2, 4, 8 }; + mf.lightmap_scale = lightmap_scale[mi->get_lightmap_scale()]; + + Ref<Material> all_override = mi->get_material_override(); + for (int i = 0; i < mesh->get_surface_count(); i++) { + if (all_override.is_valid()) { + mf.overrides.push_back(all_override); + } else { + mf.overrides.push_back(mi->get_surface_material(i)); } - pm.override_material = mi->get_material_override(); - plot_meshes.push_back(pm); } + + meshes.push_back(mf); } } } - Spatial *s = Object::cast_to<Spatial>(p_at_node); + Node3D *s = Object::cast_to<Node3D>(p_at_node); if (!mi && s) { - Array meshes = p_at_node->call("get_bake_meshes"); - if (meshes.size() && (meshes.size() & 1) == 0) { + Array bmeshes = p_at_node->call("get_bake_bmeshes"); + if (bmeshes.size() && (bmeshes.size() & 1) == 0) { Transform xf = get_global_transform().affine_inverse() * s->get_global_transform(); - for (int i = 0; i < meshes.size(); i += 2) { - PlotMesh pm; - Transform mesh_xf = meshes[i + 1]; - pm.local_xform = xf * mesh_xf; - pm.mesh = meshes[i]; - pm.instance_idx = i / 2; - if (!pm.mesh.is_valid()) + for (int i = 0; i < bmeshes.size(); i += 2) { + + Ref<Mesh> mesh = bmeshes[i]; + if (!mesh.is_valid()) continue; - pm.path = get_path_to(s); - plot_meshes.push_back(pm); + + MeshesFound mf; + + Transform mesh_xf = bmeshes[i + 1]; + mf.xform = xf * mesh_xf; + mf.node_path = get_path_to(s); + mf.subindex = i / 2; + mf.lightmap_scale = 1; + mf.mesh = mesh; + + meshes.push_back(mf); } } } - Light *light = Object::cast_to<Light>(p_at_node); + Light3D *light = Object::cast_to<Light3D>(p_at_node); + + if (light && light->get_bake_mode() != Light3D::BAKE_DISABLED) { + + LightsFound lf; + lf.xform = get_global_transform().affine_inverse() * light->get_global_transform(); + lf.light = light; + lights.push_back(lf); + } - if (light && light->get_bake_mode() != Light::BAKE_DISABLED) { - PlotLight pl; - Transform xf = get_global_transform().affine_inverse() * light->get_global_transform(); + LightmapProbe *probe = Object::cast_to<LightmapProbe>(p_at_node); - pl.local_xform = xf; - pl.light = light; - plot_lights.push_back(pl); + if (probe) { + Transform xf = get_global_transform().affine_inverse() * probe->get_global_transform(); + probes.push_back(xf.origin); } + for (int i = 0; i < p_at_node->get_child_count(); i++) { Node *child = p_at_node->get_child(i); if (!child->get_owner()) continue; //maybe a helper - _find_meshes_and_lights(child, plot_meshes, plot_lights); + _find_meshes_and_lights(child, meshes, lights, probes); } } -void BakedLightmap::set_hdr(bool p_enable) { - hdr = p_enable; -} +int BakedLightmap::_bsp_get_simplex_side(const Vector<Vector3> &p_points, const LocalVector<BSPSimplex> &p_simplices, const Plane &p_plane, uint32_t p_simplex) const { + + int over = 0; + int under = 0; + int coplanar = 0; + const BSPSimplex &s = p_simplices[p_simplex]; + for (int i = 0; i < 4; i++) { + const Vector3 v = p_points[s.vertices[i]]; + if (p_plane.has_point(v)) { //coplanar + coplanar++; + } else if (p_plane.is_point_over(v)) { + over++; + } else { + under++; + } + } -bool BakedLightmap::is_hdr() const { - return hdr; + ERR_FAIL_COND_V(under == 0 && over == 0, -2); //should never happen, we discarded flat simplices before, but in any case drop it from the bsp tree and throw an error + if (under == 0) { + return 1; // all over + } else if (over == 0) { + return -1; // all under + } else { + return 0; // crossing + } } -bool BakedLightmap::_bake_time(void *ud, float p_secs, float p_progress) { +//#define DEBUG_BSP - uint64_t time = OS::get_singleton()->get_ticks_usec(); - BakeTimeData *btd = (BakeTimeData *)ud; +int32_t BakedLightmap::_compute_bsp_tree(const Vector<Vector3> &p_points, const LocalVector<Plane> &p_planes, LocalVector<int32_t> &planes_tested, const LocalVector<BSPSimplex> &p_simplices, const LocalVector<int32_t> &p_simplex_indices, LocalVector<BSPNode> &bsp_nodes) { - if (time - btd->last_step > 1000000) { + //if we reach here, it means there is more than one simplex + int32_t node_index = (int32_t)bsp_nodes.size(); + bsp_nodes.push_back(BSPNode()); - int mins_left = p_secs / 60; - int secs_left = Math::fmod(p_secs, 60.0f); - int percent = p_progress * 100; - bool abort = bake_step_function(btd->pass + percent, btd->text + " " + vformat(RTR("%d%%"), percent) + " " + vformat(RTR("(Time Left: %d:%02d s)"), mins_left, secs_left)); - btd->last_step = time; - if (abort) - return true; - } + //test with all the simplex planes + Plane best_plane; + float best_plane_score = -1.0; - return false; -} + for (uint32_t i = 0; i < p_simplex_indices.size(); i++) { + const BSPSimplex &s = p_simplices[p_simplex_indices[i]]; + for (int j = 0; j < 4; j++) { + uint32_t plane_index = s.planes[j]; + if (planes_tested[plane_index] == node_index) { + continue; //tested this plane already + } + + planes_tested[plane_index] = node_index; + + static const int face_order[4][3] = { + { 0, 1, 2 }, + { 0, 2, 3 }, + { 0, 1, 3 }, + { 1, 2, 3 } + }; + + // despite getting rid of plane duplicates, we should still use here the actual plane to avoid numerical error + // from thinking this same simplex is intersecting rather than on a side + Vector3 v0 = p_points[s.vertices[face_order[j][0]]]; + Vector3 v1 = p_points[s.vertices[face_order[j][1]]]; + Vector3 v2 = p_points[s.vertices[face_order[j][2]]]; + + Plane plane(v0, v1, v2); + + //test with all the simplices + int over_count = 0; + int under_count = 0; + + for (uint32_t k = 0; k < p_simplex_indices.size(); k++) { + int side = _bsp_get_simplex_side(p_points, p_simplices, plane, p_simplex_indices[k]); + if (side == -2) { + continue; //this simplex is invalid, skip for now + } else if (side < 0) { + under_count++; + } else if (side > 0) { + over_count++; + } + } + + if (under_count == 0 && over_count == 0) { + continue; //most likely precision issue with a flat simplex, do not try this plane + } -BakedLightmap::BakeError BakedLightmap::bake(Node *p_from_node, bool p_create_visual_debug) { + if (under_count > over_count) { //make sure under is always less than over, so we can compute the same ratio + SWAP(under_count, over_count); + } - String save_path; + float score = 0; //by default, score is 0 (worst) + if (over_count > 0) { + //give score mainly based on ratio (under / over), this means that this plane is splitting simplices a lot, but its balanced + score = float(under_count) / over_count; + } - if (image_path.begins_with("res://")) { - save_path = image_path; - } else { - if (get_filename() != "") { - save_path = get_filename().get_base_dir(); - } else if (get_owner() && get_owner()->get_filename() != "") { - save_path = get_owner()->get_filename().get_base_dir(); + //adjusting priority over least splits, probably not a great idea + //score *= Math::sqrt(float(over_count + under_count) / p_simplex_indices.size()); //also multiply score + + if (score > best_plane_score) { + + best_plane = plane; + best_plane_score = score; + } } + } - if (save_path == "") { - return BAKE_ERROR_NO_SAVE_PATH; + LocalVector<int32_t> indices_over; + LocalVector<int32_t> indices_under; + + //split again, but add to list + for (uint32_t i = 0; i < p_simplex_indices.size(); i++) { + + uint32_t index = p_simplex_indices[i]; + int side = _bsp_get_simplex_side(p_points, p_simplices, best_plane, index); + + if (side == -2) { + continue; //simplex sits on the plane, does not make sense to use it } - if (image_path != "") { - save_path.plus_file(image_path); + if (side <= 0) { + indices_under.push_back(index); } - } - { - //check for valid save path - DirAccessRef d = DirAccess::open(save_path); - if (!d) { - ERR_PRINT("Invalid Save Path '" + save_path + "'."); - return BAKE_ERROR_NO_SAVE_PATH; + + if (side >= 0) { + indices_over.push_back(index); } } - Ref<BakedLightmapData> new_light_data; - new_light_data.instance(); +#ifdef DEBUG_BSP + print_line("node " + itos(node_index) + " found plane: " + best_plane + " score:" + rtos(best_plane_score) + " - over " + itos(indices_over.size()) + " under " + itos(indices_under.size()) + " intersecting " + itos(intersecting)); +#endif - Voxelizer baker; + if (best_plane_score < 0.0 || indices_over.size() == p_simplex_indices.size() || indices_under.size() == p_simplex_indices.size()) { + ERR_FAIL_COND_V(p_simplex_indices.size() <= 1, 0); //should not happen, this is a bug - int bake_subdiv; - int capture_subdiv; - AABB bake_bounds; - { - bake_bounds = AABB(-extents, extents * 2.0); - int subdiv = nearest_power_of_2_templated(int(bake_bounds.get_longest_axis_size() / bake_cell_size)); - bake_bounds.size[bake_bounds.get_longest_axis_index()] = subdiv * bake_cell_size; - bake_subdiv = nearest_shift(subdiv) + 1; + // Failed to separate the tetrahedrons using planes + // this means Delaunay borked at some point. + // Luckily, because we are using tetrahedrons, we can resort to + // less precise but still working ways to generate the separating plane + // this will most likely look bad when interpolating, but at least it will not crash. + // and the arctifact will most likely also be very small, so too difficult to notice. + + //find the longest axis + + WARN_PRINT("Inconsistency found in triangulation while building BSP, probe interpolation quality may degrade a bit."); + + LocalVector<Vector3> centers; + AABB bounds_all; + for (uint32_t i = 0; i < p_simplex_indices.size(); i++) { + AABB bounds; + for (uint32_t j = 0; j < 4; j++) { + + Vector3 p = p_points[p_simplices[p_simplex_indices[i]].vertices[j]]; + if (j == 0) { + bounds.position = p; + } else { + bounds.expand_to(p); + } + } + if (i == 0) { + centers.push_back(bounds.position + bounds.size * 0.5); + } else { + bounds_all.merge_with(bounds); + } + } + Vector3::Axis longest_axis = Vector3::Axis(bounds_all.get_longest_axis_index()); + + //find the simplex that will go under + uint32_t min_d_idx = 0xFFFFFFFF; + float min_d_dist = 1e20; + + for (uint32_t i = 0; i < centers.size(); i++) { + if (centers[i][longest_axis] < min_d_dist) { + min_d_idx = i; + min_d_dist = centers[i][longest_axis]; + } + } + //rebuild best_plane and over/under arrays + best_plane = Plane(); + best_plane.normal[longest_axis] = 1.0; + best_plane.d = min_d_dist; + + indices_under.clear(); + indices_under.push_back(min_d_idx); - capture_subdiv = bake_subdiv; - float css = bake_cell_size; - while (css < capture_cell_size && capture_subdiv > 2) { - capture_subdiv--; - css *= 2.0; + indices_over.clear(); + + for (uint32_t i = 0; i < p_simplex_indices.size(); i++) { + if (i == min_d_idx) { + continue; + } + indices_over.push_back(p_simplex_indices[i]); } } - baker.begin_bake(bake_subdiv, bake_bounds); + BSPNode node; + node.plane = best_plane; + + if (indices_under.size() == 0) { + //noting to do here + node.under = BSPNode::EMPTY_LEAF; + } else if (indices_under.size() == 1) { + node.under = -(indices_under[0] + 1); + } else { + node.under = _compute_bsp_tree(p_points, p_planes, planes_tested, p_simplices, indices_under, bsp_nodes); + } + + if (indices_over.size() == 0) { + //noting to do here + node.over = BSPNode::EMPTY_LEAF; + } else if (indices_over.size() == 1) { + node.over = -(indices_over[0] + 1); + } else { + node.over = _compute_bsp_tree(p_points, p_planes, planes_tested, p_simplices, indices_over, bsp_nodes); + } + + bsp_nodes[node_index] = node; - List<PlotMesh> mesh_list; - List<PlotLight> light_list; + return node_index; +} - _find_meshes_and_lights(p_from_node ? p_from_node : get_parent(), mesh_list, light_list); +bool BakedLightmap::_lightmap_bake_step_function(float p_completion, const String &p_text, void *ud, bool p_refresh) { - if (bake_begin_function) { - bake_begin_function(mesh_list.size() + light_list.size() + 1 + mesh_list.size() * 100); + BakeStepUD *bsud = (BakeStepUD *)ud; + bool ret = false; + if (bsud->func) { + ret = bsud->func(bsud->from_percent + p_completion * (bsud->to_percent - bsud->from_percent), p_text, bsud->ud, p_refresh); } + return ret; +} - int step = 0; +void BakedLightmap::_plot_triangle_into_octree(GenProbesOctree *p_cell, float p_cell_size, const Vector3 *p_triangle) { - int pmc = 0; + for (int i = 0; i < 8; i++) { + Vector3i pos = p_cell->offset; + uint32_t half_size = p_cell->size / 2; + if (i & 1) { + pos.x += half_size; + } + if (i & 2) { + pos.y += half_size; + } + if (i & 4) { + pos.z += half_size; + } - for (List<PlotMesh>::Element *E = mesh_list.front(); E; E = E->next()) { + AABB subcell; + subcell.position = Vector3(pos) * p_cell_size; + subcell.size = Vector3(half_size, half_size, half_size) * p_cell_size; - if (bake_step_function) { - bake_step_function(step++, RTR("Plotting Meshes: ") + " (" + itos(pmc + 1) + "/" + itos(mesh_list.size()) + ")"); + if (!Geometry::triangle_box_overlap(subcell.position + subcell.size * 0.5, subcell.size * 0.5, p_triangle)) + continue; + + if (p_cell->children[i] == nullptr) { + GenProbesOctree *child = memnew(GenProbesOctree); + child->offset = pos; + child->size = half_size; + p_cell->children[i] = child; } - pmc++; - baker.plot_mesh(E->get().local_xform, E->get().mesh, E->get().instance_materials, E->get().override_material); + if (half_size > 1) { + //still levels missing + _plot_triangle_into_octree(p_cell->children[i], p_cell_size, p_triangle); + } } +} +void BakedLightmap::_gen_new_positions_from_octree(const GenProbesOctree *p_cell, float p_cell_size, const Vector<Vector3> &probe_positions, LocalVector<Vector3> &new_probe_positions, HashMap<Vector3i, bool, Vector3iHash> &positions_used, const AABB &p_bounds) { - pmc = 0; - baker.begin_bake_light(Voxelizer::BakeQuality(bake_quality), Voxelizer::BakeMode(bake_mode), propagation, energy); + for (int i = 0; i < 8; i++) { - for (List<PlotLight>::Element *E = light_list.front(); E; E = E->next()) { + Vector3i pos = p_cell->offset; + if (i & 1) { + pos.x += p_cell->size; + } + if (i & 2) { + pos.y += p_cell->size; + } + if (i & 4) { + pos.z += p_cell->size; + } - if (bake_step_function) { - bake_step_function(step++, RTR("Plotting Lights:") + " (" + itos(pmc + 1) + "/" + itos(light_list.size()) + ")"); + if (p_cell->size == 1 && !positions_used.has(pos)) { + //new position to insert! + Vector3 real_pos = p_bounds.position + Vector3(pos) * p_cell_size; + //see if a user submitted probe is too close + int ppcount = probe_positions.size(); + const Vector3 *pp = probe_positions.ptr(); + bool exists = false; + for (int j = 0; j < ppcount; j++) { + + if (pp[j].distance_to(real_pos) < CMP_EPSILON) { + exists = true; + break; + } + } + + if (!exists) { + new_probe_positions.push_back(real_pos); + } + + positions_used[pos] = true; } - pmc++; - PlotLight pl = E->get(); - switch (pl.light->get_light_type()) { - case RS::LIGHT_DIRECTIONAL: { - baker.plot_light_directional(-pl.local_xform.basis.get_axis(2), pl.light->get_color(), pl.light->get_param(Light::PARAM_ENERGY), pl.light->get_param(Light::PARAM_INDIRECT_ENERGY), pl.light->get_bake_mode() == Light::BAKE_ALL); - } break; - case RS::LIGHT_OMNI: { - baker.plot_light_omni(pl.local_xform.origin, pl.light->get_color(), pl.light->get_param(Light::PARAM_ENERGY), pl.light->get_param(Light::PARAM_INDIRECT_ENERGY), pl.light->get_param(Light::PARAM_RANGE), pl.light->get_param(Light::PARAM_ATTENUATION), pl.light->get_bake_mode() == Light::BAKE_ALL); - } break; - case RS::LIGHT_SPOT: { - baker.plot_light_spot(pl.local_xform.origin, pl.local_xform.basis.get_axis(2), pl.light->get_color(), pl.light->get_param(Light::PARAM_ENERGY), pl.light->get_param(Light::PARAM_INDIRECT_ENERGY), pl.light->get_param(Light::PARAM_RANGE), pl.light->get_param(Light::PARAM_ATTENUATION), pl.light->get_param(Light::PARAM_SPOT_ANGLE), pl.light->get_param(Light::PARAM_SPOT_ATTENUATION), pl.light->get_bake_mode() == Light::BAKE_ALL); + if (p_cell->children[i] != nullptr) { + _gen_new_positions_from_octree(p_cell->children[i], p_cell_size, probe_positions, new_probe_positions, positions_used, p_bounds); + } + } +} +BakedLightmap::BakeError BakedLightmap::bake(Node *p_from_node, String p_image_data_path, Lightmapper::BakeStepFunc p_bake_step, void *p_bake_userdata) { - } break; + if (p_image_data_path == "" && (get_light_data().is_null() || !get_light_data()->get_path().is_resource_file())) { + return BAKE_ERROR_NO_SAVE_PATH; + } + + if (p_image_data_path == "") { + + if (get_light_data().is_null()) { + return BAKE_ERROR_NO_SAVE_PATH; + } + + p_image_data_path = get_light_data()->get_path(); + if (!p_image_data_path.is_resource_file()) { + return BAKE_ERROR_NO_SAVE_PATH; } } - /*if (bake_step_function) { - bake_step_function(pmc++, RTR("Finishing Plot")); - }*/ - baker.end_bake(); + Ref<Lightmapper> lightmapper = Lightmapper::create(); + ERR_FAIL_COND_V(lightmapper.is_null(), BAKE_ERROR_NO_LIGHTMAPPER); - Set<String> used_mesh_names; + BakeStepUD bsud; + bsud.func = p_bake_step; + bsud.ud = p_bake_userdata; + bsud.from_percent = 0.2; + bsud.to_percent = 0.8; - pmc = 0; - for (List<PlotMesh>::Element *E = mesh_list.front(); E; E = E->next()) { + if (p_bake_step) { + p_bake_step(0.0, TTR("Finding meshes, lights and probes"), p_bake_userdata, true); + } + /* STEP 1, FIND MESHES, LIGHTS AND PROBES */ + Vector<Lightmapper::MeshData> mesh_data; + Vector<LightsFound> lights_found; + Vector<Vector3> probes_found; + AABB bounds; + { + Vector<MeshesFound> meshes_found; + _find_meshes_and_lights(p_from_node ? p_from_node : get_parent(), meshes_found, lights_found, probes_found); - String mesh_name = E->get().mesh->get_name(); - if (mesh_name == "" || mesh_name.find(":") != -1 || mesh_name.find("/") != -1) { - mesh_name = "LightMap"; + if (meshes_found.size() == 0) { + return BAKE_ERROR_NO_MESHES; } + // create mesh data for insert - if (used_mesh_names.has(mesh_name)) { - int idx = 2; - String base = mesh_name; - while (true) { - mesh_name = base + itos(idx); - if (!used_mesh_names.has(mesh_name)) - break; - idx++; + //get the base material textures, help compute altlas size and bounds + for (int m_i = 0; m_i < meshes_found.size(); m_i++) { + + if (p_bake_step) { + float p = (float)(m_i) / meshes_found.size(); + p_bake_step(p * 0.1, vformat(TTR("Preparing geometry %d/%d"), m_i, meshes_found.size()), p_bake_userdata, false); } - } - used_mesh_names.insert(mesh_name); - pmc++; - Voxelizer::LightMapData lm; + MeshesFound &mf = meshes_found.write[m_i]; - Error err; - if (bake_step_function) { - BakeTimeData btd; - btd.text = RTR("Lighting Meshes: ") + mesh_name + " (" + itos(pmc) + "/" + itos(mesh_list.size()) + ")"; - btd.pass = step; - btd.last_step = 0; - err = baker.make_lightmap(E->get().local_xform, E->get().mesh, bake_default_texels_per_unit, lm, _bake_time, &btd); - if (err != OK) { - bake_end_function(); - if (err == ERR_SKIP) - return BAKE_ERROR_USER_ABORTED; - return BAKE_ERROR_CANT_CREATE_IMAGE; + Size2i lightmap_size = mf.mesh->get_lightmap_size_hint() * mf.lightmap_scale; + Vector<RID> overrides; + overrides.resize(mf.overrides.size()); + for (int i = 0; i < mf.overrides.size(); i++) { + if (mf.overrides[i].is_valid()) { + overrides.write[i] = mf.overrides[i]->get_rid(); + } } - step += 100; - } else { + TypedArray<Image> images = RS::get_singleton()->bake_render_uv2(mf.mesh->get_rid(), overrides, lightmap_size); - err = baker.make_lightmap(E->get().local_xform, E->get().mesh, bake_default_texels_per_unit, lm); - } + ERR_FAIL_COND_V(images.empty(), BAKE_ERROR_CANT_CREATE_IMAGE); + + Ref<Image> albedo = images[RS::BAKE_CHANNEL_ALBEDO_ALPHA]; + Ref<Image> orm = images[RS::BAKE_CHANNEL_ORM]; - if (err == OK) { + //multiply albedo by metal - Ref<Image> image; - image.instance(); + Lightmapper::MeshData md; - if (hdr) { + { + Dictionary d; + d["path"] = mf.node_path; + if (mf.subindex >= 0) { + d["subindex"] = mf.subindex; + } + md.userdata = d; + } - //just save a regular image - Vector<uint8_t> data; - int s = lm.light.size(); - data.resize(lm.light.size() * 2); - { + { - uint8_t* w = data.ptrw(); - const float* r = lm.light.ptr(); - uint16_t *hfw = (uint16_t *)w.ptr(); - for (int i = 0; i < s; i++) { - hfw[i] = Math::make_half_float(r[i]); - } + if (albedo->get_format() != Image::FORMAT_RGBA8) { + albedo->convert(Image::FORMAT_RGBA8); + } + if (orm->get_format() != Image::FORMAT_RGBA8) { + orm->convert(Image::FORMAT_RGBA8); + } + Vector<uint8_t> albedo_alpha = albedo->get_data(); + Vector<uint8_t> orm_data = orm->get_data(); + + Vector<uint8_t> albedom; + uint32_t len = albedo_alpha.size(); + albedom.resize(len); + const uint8_t *r_aa = albedo_alpha.ptr(); + const uint8_t *r_orm = orm_data.ptr(); + uint8_t *w_albedo = albedom.ptrw(); + + for (uint32_t i = 0; i < len; i += 4) { + w_albedo[i + 0] = uint8_t(CLAMP(float(r_aa[i + 0]) * (1.0 - float(r_orm[i + 2] / 255.0)), 0, 255)); + w_albedo[i + 1] = uint8_t(CLAMP(float(r_aa[i + 1]) * (1.0 - float(r_orm[i + 2] / 255.0)), 0, 255)); + w_albedo[i + 2] = uint8_t(CLAMP(float(r_aa[i + 2]) * (1.0 - float(r_orm[i + 2] / 255.0)), 0, 255)); + w_albedo[i + 3] = 255; } - image->create(lm.width, lm.height, false, Image::FORMAT_RGBH, data); + md.albedo_on_uv2.instance(); + md.albedo_on_uv2->create(lightmap_size.width, lightmap_size.height, false, Image::FORMAT_RGBA8, albedom); + } - } else { + md.emission_on_uv2 = images[RS::BAKE_CHANNEL_EMISSION]; + if (md.emission_on_uv2->get_format() != Image::FORMAT_RGBAH) { + md.emission_on_uv2->convert(Image::FORMAT_RGBAH); + } - //just save a regular image - Vector<uint8_t> data; - int s = lm.light.size(); - data.resize(lm.light.size()); - { - - uint8_t* w = data.ptrw(); - const float* r = lm.light.ptr(); - for (int i = 0; i < s; i += 3) { - Color c(r[i + 0], r[i + 1], r[i + 2]); - c = c.to_srgb(); - w[i + 0] = CLAMP(c.r * 255, 0, 255); - w[i + 1] = CLAMP(c.g * 255, 0, 255); - w[i + 2] = CLAMP(c.b * 255, 0, 255); - } + //get geometry + + Basis normal_xform = mf.xform.basis.inverse().transposed(); + + for (int i = 0; i < mf.mesh->get_surface_count(); i++) { + if (mf.mesh->surface_get_primitive_type(i) != Mesh::PRIMITIVE_TRIANGLES) { + continue; } + Array a = mf.mesh->surface_get_arrays(i); - image->create(lm.width, lm.height, false, Image::FORMAT_RGB8, data); + Vector<Vector3> vertices = a[Mesh::ARRAY_VERTEX]; + const Vector3 *vr = vertices.ptr(); + Vector<Vector2> uv = a[Mesh::ARRAY_TEX_UV2]; + const Vector2 *uvr = nullptr; + Vector<Vector3> normals = a[Mesh::ARRAY_NORMAL]; + const Vector3 *nr = nullptr; + Vector<int> index = a[Mesh::ARRAY_INDEX]; - //This texture is saved to SRGB for two reasons: - // 1) first is so it looks better when doing the LINEAR->SRGB conversion (more accurate) - // 2) So it can be used in the GLES2 backend, which does not support linkear workflow - } + ERR_CONTINUE(uv.size() == 0); + ERR_CONTINUE(normals.size() == 0); - String image_path = save_path.plus_file(mesh_name); - Ref<Texture2D> texture; + uvr = uv.ptr(); + nr = normals.ptr(); - if (ResourceLoader::import) { + int facecount; + const int *ir = nullptr; - bool srgb = false; - if (false && hdr) { - //save hdr + if (index.size()) { + + facecount = index.size() / 3; + ir = index.ptr(); } else { - image_path += ".png"; - print_line("image path saving png: " + image_path); - image->save_png(image_path); - srgb = true; + facecount = vertices.size() / 3; } - if (!FileAccess::exists(image_path + ".import")) { - Ref<ConfigFile> config; - config.instance(); - config->set_value("remap", "importer", "texture"); - config->set_value("remap", "type", "StreamTexture"); - config->set_value("params", "compress/mode", 2); - config->set_value("params", "detect_3d", false); - config->set_value("params", "flags/repeat", false); - config->set_value("params", "flags/filter", true); - config->set_value("params", "flags/mipmaps", false); - config->set_value("params", "flags/srgb", srgb); - - config->save(image_path + ".import"); - } + for (int j = 0; j < facecount; j++) { - ResourceLoader::import(image_path); - texture = ResourceLoader::load(image_path); //if already loaded, it will be updated on refocus? - } else { + uint32_t vidx[3]; - image_path += ".text"; - Ref<ImageTexture> tex; - bool set_path = true; - if (ResourceCache::has(image_path)) { - tex = Ref<Resource>((Resource *)ResourceCache::get(image_path)); - set_path = false; - } + if (ir) { + for (int k = 0; k < 3; k++) { + vidx[k] = ir[j * 3 + k]; + } + } else { + for (int k = 0; k < 3; k++) { + vidx[k] = j * 3 + k; + } + } - if (!tex.is_valid()) { - tex.instance(); + for (int k = 0; k < 3; k++) { + Vector3 v = mf.xform.xform(vr[vidx[k]]); + if (bounds == AABB()) { + bounds.position = v; + } else { + bounds.expand_to(v); + } + md.points.push_back(v); + + md.uv2.push_back(uvr[vidx[k]]); + md.normal.push_back(normal_xform.xform(nr[vidx[k]]).normalized()); + } } + } - tex->create_from_image(image); + mesh_data.push_back(md); + } + } - err = ResourceSaver::save(image_path, tex, ResourceSaver::FLAG_CHANGE_PATH); - if (set_path) { - tex->set_path(image_path); - } - texture = tex; + /* STEP 2, CREATE PROBES */ + + if (p_bake_step) { + p_bake_step(0.3, TTR("Creating probes"), p_bake_userdata, true); + } + + //bounds need to include the user probes + for (int i = 0; i < probes_found.size(); i++) { + bounds.expand_to(probes_found[i]); + } + + bounds.grow_by(bounds.size.length() * 0.001); + + if (gen_probes == GENERATE_PROBES_DISABLED) { + // generate 8 probes on bound endpoints + for (int i = 0; i < 8; i++) { + probes_found.push_back(bounds.get_endpoint(i)); + } + } else { + // detect probes from geometry + static const int subdiv_values[6] = { 0, 4, 8, 16, 32 }; + int subdiv = subdiv_values[gen_probes]; + + float subdiv_cell_size; + Vector3i bound_limit; + { + int longest_axis = bounds.get_longest_axis_index(); + subdiv_cell_size = bounds.size[longest_axis] / subdiv; + int axis_n1 = (longest_axis + 1) % 3; + int axis_n2 = (longest_axis + 2) % 3; + + bound_limit[longest_axis] = subdiv; + bound_limit[axis_n1] = int(Math::ceil(bounds.size[axis_n1] / subdiv_cell_size)); + bound_limit[axis_n2] = int(Math::ceil(bounds.size[axis_n2] / subdiv_cell_size)); + //compensate bounds + bounds.size[axis_n1] = bound_limit[axis_n1] * subdiv_cell_size; + bounds.size[axis_n2] = bound_limit[axis_n2] * subdiv_cell_size; + } + + GenProbesOctree octree; + octree.size = subdiv; + + for (int i = 0; i < mesh_data.size(); i++) { + if (p_bake_step) { + float p = (float)(i) / mesh_data.size(); + p_bake_step(0.3 + p * 0.1, vformat(TTR("Creating probes from mesh %d/%d"), i, mesh_data.size()), p_bake_userdata, false); } - if (err != OK) { - if (bake_end_function) { - bake_end_function(); - } - ERR_FAIL_COND_V(err != OK, BAKE_ERROR_CANT_CREATE_IMAGE); + + for (int j = 0; j < mesh_data[i].points.size(); j += 3) { + Vector3 points[3] = { mesh_data[i].points[j + 0] - bounds.position, mesh_data[i].points[j + 1] - bounds.position, mesh_data[i].points[j + 2] - bounds.position }; + _plot_triangle_into_octree(&octree, subdiv_cell_size, points); + } + } + + LocalVector<Vector3> new_probe_positions; + HashMap<Vector3i, bool, Vector3iHash> positions_used; + for (uint32_t i = 0; i < 8; i++) { //insert bounding endpoints + Vector3i pos; + if (i & 1) { + pos.x += bound_limit.x; + } + if (i & 2) { + pos.y += bound_limit.y; + } + if (i & 4) { + pos.z += bound_limit.z; + } + + positions_used[pos] = true; + Vector3 real_pos = bounds.position + Vector3(pos) * subdiv_cell_size; //use same formula for numerical stability + new_probe_positions.push_back(real_pos); + } + //skip first level, since probes are always added at bounds endpoints anyway (code above this) + for (int i = 0; i < 8; i++) { + + if (octree.children[i]) { + _gen_new_positions_from_octree(octree.children[i], subdiv_cell_size, probes_found, new_probe_positions, positions_used, bounds); } + } - new_light_data->add_user(E->get().path, texture, E->get().instance_idx); + for (uint32_t i = 0; i < new_probe_positions.size(); i++) { + probes_found.push_back(new_probe_positions[i]); } } - AABB bounds = AABB(-extents, extents * 2); - new_light_data->set_cell_subdiv(capture_subdiv); - new_light_data->set_bounds(bounds); - new_light_data->set_octree(baker.create_capture_octree(capture_subdiv)); + // Add everything to lightmapper + if (p_bake_step) { + p_bake_step(0.4, TTR("Preparing Lightmapper"), p_bake_userdata, true); + } + { - float bake_bound_size = bake_bounds.get_longest_axis_size(); - Transform to_bounds; - to_bounds.basis.scale(Vector3(bake_bound_size, bake_bound_size, bake_bound_size)); - to_bounds.origin = bounds.position; + for (int i = 0; i < mesh_data.size(); i++) { + lightmapper->add_mesh(mesh_data[i]); + } + for (int i = 0; i < lights_found.size(); i++) { + Light3D *light = lights_found[i].light; + Transform xf = lights_found[i].xform; + + if (Object::cast_to<DirectionalLight3D>(light)) { + DirectionalLight3D *l = Object::cast_to<DirectionalLight3D>(light); + lightmapper->add_directional_light(light->get_bake_mode() == Light3D::BAKE_ALL, -xf.basis.get_axis(Vector3::AXIS_Z).normalized(), l->get_color(), l->get_param(Light3D::PARAM_ENERGY), l->get_param(Light3D::PARAM_SIZE)); + } else if (Object::cast_to<OmniLight3D>(light)) { + OmniLight3D *l = Object::cast_to<OmniLight3D>(light); + lightmapper->add_omni_light(light->get_bake_mode() == Light3D::BAKE_ALL, xf.origin, l->get_color(), l->get_param(Light3D::PARAM_ENERGY), l->get_param(Light3D::PARAM_RANGE), l->get_param(Light3D::PARAM_ATTENUATION), l->get_param(Light3D::PARAM_SIZE)); + } else if (Object::cast_to<SpotLight3D>(light)) { + SpotLight3D *l = Object::cast_to<SpotLight3D>(light); + lightmapper->add_spot_light(light->get_bake_mode() == Light3D::BAKE_ALL, xf.origin, -xf.basis.get_axis(Vector3::AXIS_Z).normalized(), l->get_color(), l->get_param(Light3D::PARAM_ENERGY), l->get_param(Light3D::PARAM_RANGE), l->get_param(Light3D::PARAM_ATTENUATION), l->get_param(Light3D::PARAM_SPOT_ANGLE), l->get_param(Light3D::PARAM_SPOT_ATTENUATION), l->get_param(Light3D::PARAM_SIZE)); + } + } + for (int i = 0; i < probes_found.size(); i++) { + lightmapper->add_probe(probes_found[i]); + } + } + + Ref<Image> environment_image; + Basis environment_transform; + + // Add everything to lightmapper + if (environment_mode != ENVIRONMENT_MODE_DISABLED) { + if (p_bake_step) { + p_bake_step(4.1, TTR("Preparing Environment"), p_bake_userdata, true); + } + + environment_transform = get_global_transform().basis; - Transform to_grid; - to_grid.basis.scale(Vector3(1 << (capture_subdiv - 1), 1 << (capture_subdiv - 1), 1 << (capture_subdiv - 1))); + switch (environment_mode) { + case ENVIRONMENT_MODE_DISABLED: { + //nothing + } break; + case ENVIRONMENT_MODE_SCENE: { + Ref<World3D> world = get_world_3d(); + if (world.is_valid()) { + Ref<Environment> env = world->get_environment(); + if (env.is_null()) { + env = world->get_fallback_environment(); + } + + if (env.is_valid()) { + environment_image = RS::get_singleton()->environment_bake_panorama(env->get_rid(), true, Size2i(128, 64)); + } + } + } break; + case ENVIRONMENT_MODE_CUSTOM_SKY: { + if (environment_custom_sky.is_valid()) { + environment_image = RS::get_singleton()->sky_bake_panorama(environment_custom_sky->get_rid(), environment_custom_energy, true, Size2i(128, 64)); + } + + } break; + case ENVIRONMENT_MODE_CUSTOM_COLOR: { + environment_image.instance(); + environment_image->create(128, 64, false, Image::FORMAT_RGBAF); + Color c = environment_custom_color; + c.r *= environment_custom_energy; + c.g *= environment_custom_energy; + c.b *= environment_custom_energy; + for (int i = 0; i < 128; i++) { + for (int j = 0; j < 64; j++) { + environment_image->set_pixel(i, j, c); + } + } - Transform to_cell_space = to_grid * to_bounds.affine_inverse(); - new_light_data->set_cell_space_transform(to_cell_space); + } break; + } } - if (bake_end_function) { - bake_end_function(); + Lightmapper::BakeError bake_err = lightmapper->bake(Lightmapper::BakeQuality(bake_quality), use_denoiser, bounces, bias, max_texture_size, directional, Lightmapper::GenerateProbes(gen_probes), environment_image, environment_transform, _lightmap_bake_step_function, &bsud); + + if (bake_err == Lightmapper::BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES) { + return BAKE_ERROR_MESHES_INVALID; } - //create the data for visual server + /* POSTBAKE: Save Textures */ - if (p_create_visual_debug) { - MultiMeshInstance *mmi = memnew(MultiMeshInstance); - mmi->set_multimesh(baker.create_debug_multimesh(Voxelizer::DEBUG_LIGHT)); - add_child(mmi); -#ifdef TOOLS_ENABLED - if (get_tree()->get_edited_scene_root() == this) { - mmi->set_owner(this); - } else { - mmi->set_owner(get_owner()); + Ref<TextureLayered> texture; + { + + Vector<Ref<Image>> images; + for (int i = 0; i < lightmapper->get_bake_texture_count(); i++) { + images.push_back(lightmapper->get_bake_texture(i)); + } + //we assume they are all the same, so lets create a large one for saving + Ref<Image> large_image; + large_image.instance(); + + large_image->create(images[0]->get_width(), images[0]->get_height() * images.size(), false, images[0]->get_format()); + + for (int i = 0; i < lightmapper->get_bake_texture_count(); i++) { + large_image->blit_rect(images[i], Rect2(0, 0, images[i]->get_width(), images[i]->get_height()), Point2(0, images[i]->get_height() * i)); + } + + String base_path = p_image_data_path.get_basename() + ".exr"; + + Ref<ConfigFile> config; + + config.instance(); + if (FileAccess::exists(base_path + ".import")) { + + config->load(base_path + ".import"); + } + + config->set_value("remap", "importer", "2d_array_texture"); + config->set_value("remap", "type", "StreamTexture2DArray"); + if (!config->has_section_key("params", "compress/mode")) { + config->set_value("params", "compress/mode", 2); //user may want another compression, so leave it be + } + config->set_value("params", "compress/channel_pack", 1); + config->set_value("params", "mipmaps/generate", false); + config->set_value("params", "slices/horizontal", 1); + config->set_value("params", "slices/vertical", images.size()); + + config->save(base_path + ".import"); + + Error err = large_image->save_exr(base_path, false); + ERR_FAIL_COND_V(err, BAKE_ERROR_CANT_CREATE_IMAGE); + ResourceLoader::import(base_path); + Ref<Texture> t = ResourceLoader::load(base_path); //if already loaded, it will be updated on refocus? + ERR_FAIL_COND_V(t.is_null(), BAKE_ERROR_CANT_CREATE_IMAGE); + texture = t; + } + + /* POSTBAKE: Save Light Data */ + + Ref<BakedLightmapData> data; + if (get_light_data().is_valid()) { + data = get_light_data(); + set_light_data(Ref<BakedLightmapData>()); //clear + data->clear(); + } else { + data.instance(); + } + + data->set_light_texture(texture); + data->set_uses_spherical_harmonics(directional); + + for (int i = 0; i < lightmapper->get_bake_mesh_count(); i++) { + Dictionary d = lightmapper->get_bake_mesh_userdata(i); + NodePath np = d["path"]; + int32_t subindex = -1; + if (d.has("subindex")) { + subindex = d["subindex"]; } -#else - mmi->set_owner(get_owner()); + + Rect2 uv_scale = lightmapper->get_bake_mesh_uv_scale(i); + int slice_index = lightmapper->get_bake_mesh_texture_slice(i); + data->add_user(np, uv_scale, slice_index, subindex); + } + + { + // create tetrahedrons + Vector<Vector3> points; + Vector<Color> sh; + points.resize(lightmapper->get_bake_probe_count()); + sh.resize(lightmapper->get_bake_probe_count() * 9); + for (int i = 0; i < lightmapper->get_bake_probe_count(); i++) { + points.write[i] = lightmapper->get_bake_probe_point(i); + Vector<Color> colors = lightmapper->get_bake_probe_sh(i); + ERR_CONTINUE(colors.size() != 9); + for (int j = 0; j < 9; j++) { + sh.write[i * 9 + j] = colors[j]; + } + } + + //Obtain solved simplices + + if (p_bake_step) { + p_bake_step(0.8, TTR("Generating Probe Volumes"), p_bake_userdata, true); + } + Vector<Delaunay3D::OutputSimplex> solved_simplices = Delaunay3D::tetrahedralize(points); + + LocalVector<BSPSimplex> bsp_simplices; + LocalVector<Plane> bsp_planes; + LocalVector<int32_t> bsp_simplex_indices; + PackedInt32Array tetrahedrons; + + for (int i = 0; i < solved_simplices.size(); i++) { + + //Prepare a special representation of the simplex, which uses a BSP Tree + BSPSimplex bsp_simplex; + for (int j = 0; j < 4; j++) { + bsp_simplex.vertices[j] = solved_simplices[i].points[j]; + } + for (int j = 0; j < 4; j++) { + static const int face_order[4][3] = { + { 0, 1, 2 }, + { 0, 2, 3 }, + { 0, 1, 3 }, + { 1, 2, 3 } + }; + Vector3 a = points[solved_simplices[i].points[face_order[j][0]]]; + Vector3 b = points[solved_simplices[i].points[face_order[j][1]]]; + Vector3 c = points[solved_simplices[i].points[face_order[j][2]]]; + + //store planes in an array, but ensure they are reused, to speed up processing + + Plane p(a, b, c); + int plane_index = -1; + for (uint32_t k = 0; k < bsp_planes.size(); k++) { + + if (bsp_planes[k].is_equal_approx_any_side(p)) { + plane_index = k; + break; + } + } + + if (plane_index == -1) { + plane_index = bsp_planes.size(); + bsp_planes.push_back(p); + } + + bsp_simplex.planes[j] = plane_index; + + //also fill simplex array + tetrahedrons.push_back(solved_simplices[i].points[j]); + } + + bsp_simplex_indices.push_back(bsp_simplices.size()); + bsp_simplices.push_back(bsp_simplex); + } + +//#define DEBUG_SIMPLICES_AS_OBJ_FILE +#ifdef DEBUG_SIMPLICES_AS_OBJ_FILE + { + FileAccessRef f = FileAccess::open("res://bsp.obj", FileAccess::WRITE); + for (uint32_t i = 0; i < bsp_simplices.size(); i++) { + f->store_line("o Simplex" + itos(i)); + for (int j = 0; j < 4; j++) { + f->store_line(vformat("v %f %f %f", points[bsp_simplices[i].vertices[j]].x, points[bsp_simplices[i].vertices[j]].y, points[bsp_simplices[i].vertices[j]].z)); + } + static const int face_order[4][3] = { + { 1, 2, 3 }, + { 1, 3, 4 }, + { 1, 2, 4 }, + { 2, 3, 4 } + }; + + for (int j = 0; j < 4; j++) { + f->store_line(vformat("f %d %d %d", 4 * i + face_order[j][0], 4 * i + face_order[j][1], 4 * i + face_order[j][2])); + } + } + f->close(); + } +#endif + + LocalVector<BSPNode> bsp_nodes; + LocalVector<int32_t> planes_tested; + planes_tested.resize(bsp_planes.size()); + for (uint32_t i = 0; i < planes_tested.size(); i++) { + planes_tested[i] = 0x7FFFFFFF; + } + + if (p_bake_step) { + p_bake_step(0.9, TTR("Generating Probe Acceleration Structures"), p_bake_userdata, true); + } + + _compute_bsp_tree(points, bsp_planes, planes_tested, bsp_simplices, bsp_simplex_indices, bsp_nodes); + + PackedInt32Array bsp_array; + bsp_array.resize(bsp_nodes.size() * 6); // six 32 bits values used for each BSP node + { + float *fptr = (float *)bsp_array.ptrw(); + int32_t *iptr = (int32_t *)bsp_array.ptrw(); + for (uint32_t i = 0; i < bsp_nodes.size(); i++) { + fptr[i * 6 + 0] = bsp_nodes[i].plane.normal.x; + fptr[i * 6 + 1] = bsp_nodes[i].plane.normal.y; + fptr[i * 6 + 2] = bsp_nodes[i].plane.normal.z; + fptr[i * 6 + 3] = bsp_nodes[i].plane.d; + iptr[i * 6 + 4] = bsp_nodes[i].over; + iptr[i * 6 + 5] = bsp_nodes[i].under; + } +//#define DEBUG_BSP_TREE +#ifdef DEBUG_BSP_TREE + FileAccessRef f = FileAccess::open("res://bsp.txt", FileAccess::WRITE); + for (uint32_t i = 0; i < bsp_nodes.size(); i++) { + f->store_line(itos(i) + " - plane: " + bsp_nodes[i].plane + " over: " + itos(bsp_nodes[i].over) + " under: " + itos(bsp_nodes[i].under)); + } #endif + } + + /* Obtain the colors from the images, they will be re-created as cubemaps on the server, depending on the driver */ + + data->set_capture_data(bounds, interior, points, sh, tetrahedrons, bsp_array); + /* Compute a BSP tree of the simplices, so it's easy to find the exact one */ + } + + Error err = ResourceSaver::save(p_image_data_path, data); + data->set_path(p_image_data_path); + + if (err != OK) { + return BAKE_ERROR_CANT_CREATE_IMAGE; } - set_light_data(new_light_data); + set_light_data(data); return BAKE_ERROR_OK; } void BakedLightmap::_notification(int p_what) { - if (p_what == NOTIFICATION_READY) { + if (p_what == NOTIFICATION_POST_ENTER_TREE) { if (light_data.is_valid()) { _assign_lightmaps(); } - request_ready(); //will need ready again if re-enters tree } if (p_what == NOTIFICATION_EXIT_TREE) { @@ -667,20 +1237,18 @@ void BakedLightmap::_assign_lightmaps() { ERR_FAIL_COND(!light_data.is_valid()); for (int i = 0; i < light_data->get_user_count(); i++) { - Ref<Texture2D> lightmap = light_data->get_user_lightmap(i); - ERR_CONTINUE(!lightmap.is_valid()); Node *node = get_node(light_data->get_user_path(i)); - int instance_idx = light_data->get_user_instance(i); + int instance_idx = light_data->get_user_sub_instance(i); if (instance_idx >= 0) { RID instance = node->call("get_bake_mesh_instance", instance_idx); if (instance.is_valid()) { - RS::get_singleton()->instance_set_use_lightmap(instance, get_instance(), lightmap->get_rid()); + RS::get_singleton()->instance_geometry_set_lightmap(instance, get_instance(), light_data->get_user_lightmap_uv_scale(i), light_data->get_user_lightmap_slice_index(i)); } } else { - VisualInstance *vi = Object::cast_to<VisualInstance>(node); + VisualInstance3D *vi = Object::cast_to<VisualInstance3D>(node); ERR_CONTINUE(!vi); - RS::get_singleton()->instance_set_use_lightmap(vi->get_instance(), get_instance(), lightmap->get_rid()); + RS::get_singleton()->instance_geometry_set_lightmap(vi->get_instance(), get_instance(), light_data->get_user_lightmap_uv_scale(i), light_data->get_user_lightmap_slice_index(i)); } } } @@ -689,16 +1257,16 @@ void BakedLightmap::_clear_lightmaps() { ERR_FAIL_COND(!light_data.is_valid()); for (int i = 0; i < light_data->get_user_count(); i++) { Node *node = get_node(light_data->get_user_path(i)); - int instance_idx = light_data->get_user_instance(i); + int instance_idx = light_data->get_user_sub_instance(i); if (instance_idx >= 0) { RID instance = node->call("get_bake_mesh_instance", instance_idx); if (instance.is_valid()) { - RS::get_singleton()->instance_set_use_lightmap(instance, get_instance(), RID()); + RS::get_singleton()->instance_geometry_set_lightmap(instance, RID(), Rect2(), 0); } } else { - VisualInstance *vi = Object::cast_to<VisualInstance>(node); + VisualInstance3D *vi = Object::cast_to<VisualInstance3D>(node); ERR_CONTINUE(!vi); - RS::get_singleton()->instance_set_use_lightmap(vi->get_instance(), get_instance(), RID()); + RS::get_singleton()->instance_geometry_set_lightmap(vi->get_instance(), RID(), Rect2(), 0); } } } @@ -719,6 +1287,8 @@ void BakedLightmap::set_light_data(const Ref<BakedLightmapData> &p_data) { _assign_lightmaps(); } } + + update_gizmo(); } Ref<BakedLightmapData> BakedLightmap::get_light_data() const { @@ -726,57 +1296,122 @@ Ref<BakedLightmapData> BakedLightmap::get_light_data() const { return light_data; } -void BakedLightmap::_debug_bake() { - bake(get_parent(), true); +void BakedLightmap::set_bake_quality(BakeQuality p_quality) { + bake_quality = p_quality; +} + +BakedLightmap::BakeQuality BakedLightmap::get_bake_quality() const { + return bake_quality; } -void BakedLightmap::set_propagation(float p_propagation) { - propagation = p_propagation; +AABB BakedLightmap::get_aabb() const { + return AABB(); +} +Vector<Face3> BakedLightmap::get_faces(uint32_t p_usage_flags) const { + return Vector<Face3>(); } -float BakedLightmap::get_propagation() const { +void BakedLightmap::set_use_denoiser(bool p_enable) { - return propagation; + use_denoiser = p_enable; } -void BakedLightmap::set_energy(float p_energy) { - energy = p_energy; +bool BakedLightmap::is_using_denoiser() const { + + return use_denoiser; } -float BakedLightmap::get_energy() const { +void BakedLightmap::set_directional(bool p_enable) { + directional = p_enable; +} - return energy; +bool BakedLightmap::is_directional() const { + return directional; } -void BakedLightmap::set_bake_quality(BakeQuality p_quality) { - bake_quality = p_quality; +void BakedLightmap::set_interior(bool p_enable) { + interior = p_enable; +} +bool BakedLightmap::is_interior() const { + return interior; } -BakedLightmap::BakeQuality BakedLightmap::get_bake_quality() const { - return bake_quality; +void BakedLightmap::set_environment_mode(EnvironmentMode p_mode) { + environment_mode = p_mode; + _change_notify(); } -void BakedLightmap::set_bake_mode(BakeMode p_mode) { - bake_mode = p_mode; +BakedLightmap::EnvironmentMode BakedLightmap::get_environment_mode() const { + return environment_mode; } -BakedLightmap::BakeMode BakedLightmap::get_bake_mode() const { - return bake_mode; +void BakedLightmap::set_environment_custom_sky(const Ref<Sky> &p_sky) { + environment_custom_sky = p_sky; } -void BakedLightmap::set_image_path(const String &p_path) { - image_path = p_path; +Ref<Sky> BakedLightmap::get_environment_custom_sky() const { + return environment_custom_sky; } -String BakedLightmap::get_image_path() const { - return image_path; +void BakedLightmap::set_environment_custom_color(const Color &p_color) { + environment_custom_color = p_color; +} +Color BakedLightmap::get_environment_custom_color() const { + return environment_custom_color; } -AABB BakedLightmap::get_aabb() const { - return AABB(-extents, extents * 2); +void BakedLightmap::set_environment_custom_energy(float p_energy) { + environment_custom_energy = p_energy; } -Vector<Face3> BakedLightmap::get_faces(uint32_t p_usage_flags) const { - return Vector<Face3>(); +float BakedLightmap::get_environment_custom_energy() const { + return environment_custom_energy; +} + +void BakedLightmap::set_bounces(int p_bounces) { + ERR_FAIL_COND(p_bounces < 0 || p_bounces > 16); + bounces = p_bounces; +} + +int BakedLightmap::get_bounces() const { + return bounces; +} + +void BakedLightmap::set_bias(float p_bias) { + ERR_FAIL_COND(p_bias < 0.00001); + bias = p_bias; +} + +float BakedLightmap::get_bias() const { + return bias; +} + +void BakedLightmap::set_max_texture_size(int p_size) { + ERR_FAIL_COND(p_size < 2048); + max_texture_size = p_size; +} + +int BakedLightmap::get_max_texture_size() const { + return max_texture_size; +} + +void BakedLightmap::set_generate_probes(GenerateProbes p_generate_probes) { + gen_probes = p_generate_probes; +} + +BakedLightmap::GenerateProbes BakedLightmap::get_generate_probes() const { + return gen_probes; +} + +void BakedLightmap::_validate_property(PropertyInfo &property) const { + if (property.name == "environment_custom_sky" && environment_mode != ENVIRONMENT_MODE_CUSTOM_SKY) { + property.usage = 0; + } + if (property.name == "environment_custom_color" && environment_mode != ENVIRONMENT_MODE_CUSTOM_COLOR) { + property.usage = 0; + } + if (property.name == "environment_custom_energy" && environment_mode != ENVIRONMENT_MODE_CUSTOM_COLOR && environment_mode != ENVIRONMENT_MODE_CUSTOM_SKY) { + property.usage = 0; + } } void BakedLightmap::_bind_methods() { @@ -784,81 +1419,100 @@ void BakedLightmap::_bind_methods() { ClassDB::bind_method(D_METHOD("set_light_data", "data"), &BakedLightmap::set_light_data); ClassDB::bind_method(D_METHOD("get_light_data"), &BakedLightmap::get_light_data); - ClassDB::bind_method(D_METHOD("set_bake_cell_size", "bake_cell_size"), &BakedLightmap::set_bake_cell_size); - ClassDB::bind_method(D_METHOD("get_bake_cell_size"), &BakedLightmap::get_bake_cell_size); - - ClassDB::bind_method(D_METHOD("set_capture_cell_size", "capture_cell_size"), &BakedLightmap::set_capture_cell_size); - ClassDB::bind_method(D_METHOD("get_capture_cell_size"), &BakedLightmap::get_capture_cell_size); - ClassDB::bind_method(D_METHOD("set_bake_quality", "bake_quality"), &BakedLightmap::set_bake_quality); ClassDB::bind_method(D_METHOD("get_bake_quality"), &BakedLightmap::get_bake_quality); - ClassDB::bind_method(D_METHOD("set_bake_mode", "bake_mode"), &BakedLightmap::set_bake_mode); - ClassDB::bind_method(D_METHOD("get_bake_mode"), &BakedLightmap::get_bake_mode); + ClassDB::bind_method(D_METHOD("set_bounces", "bounces"), &BakedLightmap::set_bounces); + ClassDB::bind_method(D_METHOD("get_bounces"), &BakedLightmap::get_bounces); + + ClassDB::bind_method(D_METHOD("set_generate_probes", "subdivision"), &BakedLightmap::set_generate_probes); + ClassDB::bind_method(D_METHOD("get_generate_probes"), &BakedLightmap::get_generate_probes); + + ClassDB::bind_method(D_METHOD("set_bias", "bias"), &BakedLightmap::set_bias); + ClassDB::bind_method(D_METHOD("get_bias"), &BakedLightmap::get_bias); - ClassDB::bind_method(D_METHOD("set_extents", "extents"), &BakedLightmap::set_extents); - ClassDB::bind_method(D_METHOD("get_extents"), &BakedLightmap::get_extents); + ClassDB::bind_method(D_METHOD("set_environment_mode", "mode"), &BakedLightmap::set_environment_mode); + ClassDB::bind_method(D_METHOD("get_environment_mode"), &BakedLightmap::get_environment_mode); - ClassDB::bind_method(D_METHOD("set_bake_default_texels_per_unit", "texels"), &BakedLightmap::set_bake_default_texels_per_unit); - ClassDB::bind_method(D_METHOD("get_bake_default_texels_per_unit"), &BakedLightmap::get_bake_default_texels_per_unit); + ClassDB::bind_method(D_METHOD("set_environment_custom_sky", "sky"), &BakedLightmap::set_environment_custom_sky); + ClassDB::bind_method(D_METHOD("get_environment_custom_sky"), &BakedLightmap::get_environment_custom_sky); - ClassDB::bind_method(D_METHOD("set_propagation", "propagation"), &BakedLightmap::set_propagation); - ClassDB::bind_method(D_METHOD("get_propagation"), &BakedLightmap::get_propagation); + ClassDB::bind_method(D_METHOD("set_environment_custom_color", "color"), &BakedLightmap::set_environment_custom_color); + ClassDB::bind_method(D_METHOD("get_environment_custom_color"), &BakedLightmap::get_environment_custom_color); - ClassDB::bind_method(D_METHOD("set_energy", "energy"), &BakedLightmap::set_energy); - ClassDB::bind_method(D_METHOD("get_energy"), &BakedLightmap::get_energy); + ClassDB::bind_method(D_METHOD("set_environment_custom_energy", "energy"), &BakedLightmap::set_environment_custom_energy); + ClassDB::bind_method(D_METHOD("get_environment_custom_energy"), &BakedLightmap::get_environment_custom_energy); - ClassDB::bind_method(D_METHOD("set_hdr", "hdr"), &BakedLightmap::set_hdr); - ClassDB::bind_method(D_METHOD("is_hdr"), &BakedLightmap::is_hdr); + ClassDB::bind_method(D_METHOD("set_max_texture_size", "max_texture_size"), &BakedLightmap::set_max_texture_size); + ClassDB::bind_method(D_METHOD("get_max_texture_size"), &BakedLightmap::get_max_texture_size); - ClassDB::bind_method(D_METHOD("set_image_path", "image_path"), &BakedLightmap::set_image_path); - ClassDB::bind_method(D_METHOD("get_image_path"), &BakedLightmap::get_image_path); + ClassDB::bind_method(D_METHOD("set_use_denoiser", "use_denoiser"), &BakedLightmap::set_use_denoiser); + ClassDB::bind_method(D_METHOD("is_using_denoiser"), &BakedLightmap::is_using_denoiser); - ClassDB::bind_method(D_METHOD("bake", "from_node", "create_visual_debug"), &BakedLightmap::bake, DEFVAL(Variant()), DEFVAL(false)); - ClassDB::bind_method(D_METHOD("debug_bake"), &BakedLightmap::_debug_bake); - ClassDB::set_method_flags(get_class_static(), _scs_create("debug_bake"), METHOD_FLAGS_DEFAULT | METHOD_FLAG_EDITOR); + ClassDB::bind_method(D_METHOD("set_interior", "enable"), &BakedLightmap::set_interior); + ClassDB::bind_method(D_METHOD("is_interior"), &BakedLightmap::is_interior); - ADD_GROUP("Bake", "bake_"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bake_cell_size", PROPERTY_HINT_RANGE, "0.01,64,0.01"), "set_bake_cell_size", "get_bake_cell_size"); - ADD_PROPERTY(PropertyInfo(Variant::INT, "bake_quality", PROPERTY_HINT_ENUM, "Low,Medium,High"), "set_bake_quality", "get_bake_quality"); - ADD_PROPERTY(PropertyInfo(Variant::INT, "bake_mode", PROPERTY_HINT_ENUM, "ConeTrace,RayTrace"), "set_bake_mode", "get_bake_mode"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bake_propagation", PROPERTY_HINT_RANGE, "0,1,0.01"), "set_propagation", "get_propagation"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bake_energy", PROPERTY_HINT_RANGE, "0,32,0.01"), "set_energy", "get_energy"); - ADD_PROPERTY(PropertyInfo(Variant::BOOL, "bake_hdr"), "set_hdr", "is_hdr"); - ADD_PROPERTY(PropertyInfo(Variant::VECTOR3, "bake_extents"), "set_extents", "get_extents"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bake_default_texels_per_unit"), "set_bake_default_texels_per_unit", "get_bake_default_texels_per_unit"); - ADD_GROUP("Capture", "capture_"); - ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "capture_cell_size", PROPERTY_HINT_RANGE, "0.01,64,0.01"), "set_capture_cell_size", "get_capture_cell_size"); + ClassDB::bind_method(D_METHOD("set_directional", "directional"), &BakedLightmap::set_directional); + ClassDB::bind_method(D_METHOD("is_directional"), &BakedLightmap::is_directional); + + // ClassDB::bind_method(D_METHOD("bake", "from_node"), &BakedLightmap::bake, DEFVAL(Variant())); + + ADD_GROUP("Tweaks", ""); + ADD_PROPERTY(PropertyInfo(Variant::INT, "quality", PROPERTY_HINT_ENUM, "Low,Medium,High,Ultra"), "set_bake_quality", "get_bake_quality"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "bounces", PROPERTY_HINT_RANGE, "0,16,1"), "set_bounces", "get_bounces"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "directional"), "set_directional", "is_directional"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "interior"), "set_interior", "is_interior"); + ADD_PROPERTY(PropertyInfo(Variant::BOOL, "use_denoiser"), "set_use_denoiser", "is_using_denoiser"); + ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "bias", PROPERTY_HINT_RANGE, "0.00001,0.1,0.00001,or_greater"), "set_bias", "get_bias"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "max_texture_size"), "set_max_texture_size", "get_max_texture_size"); + ADD_GROUP("Environment", "environment_"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "environment_mode", PROPERTY_HINT_ENUM, "Disabled,Scene,Custom Sky,Custom Color"), "set_environment_mode", "get_environment_mode"); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "environment_custom_sky", PROPERTY_HINT_RESOURCE_TYPE, "Sky"), "set_environment_custom_sky", "get_environment_custom_sky"); + ADD_PROPERTY(PropertyInfo(Variant::COLOR, "environment_custom_color", PROPERTY_HINT_COLOR_NO_ALPHA), "set_environment_custom_color", "get_environment_custom_color"); + ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "environment_custom_energy", PROPERTY_HINT_RANGE, "0,64,0.01"), "set_environment_custom_energy", "get_environment_custom_energy"); + ADD_GROUP("Gen Probes", "generate_probes_"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "generate_probes_subdiv", PROPERTY_HINT_ENUM, "Disabled,4,8,16,32"), "set_generate_probes", "get_generate_probes"); ADD_GROUP("Data", ""); - ADD_PROPERTY(PropertyInfo(Variant::STRING, "image_path", PROPERTY_HINT_DIR), "set_image_path", "get_image_path"); ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "light_data", PROPERTY_HINT_RESOURCE_TYPE, "BakedLightmapData"), "set_light_data", "get_light_data"); BIND_ENUM_CONSTANT(BAKE_QUALITY_LOW); BIND_ENUM_CONSTANT(BAKE_QUALITY_MEDIUM); BIND_ENUM_CONSTANT(BAKE_QUALITY_HIGH); - BIND_ENUM_CONSTANT(BAKE_MODE_CONE_TRACE); - BIND_ENUM_CONSTANT(BAKE_MODE_RAY_TRACE); + BIND_ENUM_CONSTANT(BAKE_QUALITY_ULTRA); + + BIND_ENUM_CONSTANT(GENERATE_PROBES_DISABLED); + BIND_ENUM_CONSTANT(GENERATE_PROBES_SUBDIV_4); + BIND_ENUM_CONSTANT(GENERATE_PROBES_SUBDIV_8); + BIND_ENUM_CONSTANT(GENERATE_PROBES_SUBDIV_16); + BIND_ENUM_CONSTANT(GENERATE_PROBES_SUBDIV_32); BIND_ENUM_CONSTANT(BAKE_ERROR_OK); + BIND_ENUM_CONSTANT(BAKE_ERROR_NO_LIGHTMAPPER); BIND_ENUM_CONSTANT(BAKE_ERROR_NO_SAVE_PATH); BIND_ENUM_CONSTANT(BAKE_ERROR_NO_MESHES); + BIND_ENUM_CONSTANT(BAKE_ERROR_MESHES_INVALID); BIND_ENUM_CONSTANT(BAKE_ERROR_CANT_CREATE_IMAGE); BIND_ENUM_CONSTANT(BAKE_ERROR_USER_ABORTED); + + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_DISABLED); + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_SCENE); + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_CUSTOM_SKY); + BIND_ENUM_CONSTANT(ENVIRONMENT_MODE_CUSTOM_COLOR); } BakedLightmap::BakedLightmap() { - extents = Vector3(10, 10, 10); - bake_default_texels_per_unit = 20; - bake_cell_size = 0.25; - capture_cell_size = 0.5; + environment_mode = ENVIRONMENT_MODE_DISABLED; + environment_custom_color = Color(0.2, 0.7, 1.0); + environment_custom_energy = 1.0; bake_quality = BAKE_QUALITY_MEDIUM; - bake_mode = BAKE_MODE_CONE_TRACE; - energy = 1; - propagation = 1; - hdr = false; - image_path = "."; - set_disable_scale(true); + interior = false; + directional = false; + + gen_probes = GENERATE_PROBES_DISABLED; + use_denoiser = true; + bounces = 1; + bias = 0.0005; + max_texture_size = 16384; } -#endif diff --git a/scene/3d/baked_lightmap.h b/scene/3d/baked_lightmap.h index bc9e3f55ea..748fdf913f 100644 --- a/scene/3d/baked_lightmap.h +++ b/scene/3d/baked_lightmap.h @@ -28,189 +28,257 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#if 0 -#ifndef BAKED_INDIRECT_LIGHT_H -#define BAKED_INDIRECT_LIGHT_H +#ifndef BAKED_LIGHTMAP_H +#define BAKED_LIGHTMAP_H -#include "multimesh_instance.h" -#include "scene/3d/light.h" -#include "scene/3d/visual_instance.h" +#include "core/local_vector.h" +#include "scene/3d/light_3d.h" +#include "scene/3d/lightmapper.h" +#include "scene/3d/mesh_instance_3d.h" +#include "scene/3d/multimesh_instance_3d.h" +#include "scene/3d/visual_instance_3d.h" +#include "scene/resources/sky.h" class BakedLightmapData : public Resource { GDCLASS(BakedLightmapData, Resource); + RES_BASE_EXTENSION("lmbake") - RID baked_light; + Ref<TextureLayered> light_texture; + + bool uses_spherical_harmonics = false; + bool interior = false; + + RID lightmap; AABB bounds; - float energy; - int cell_subdiv; - Transform cell_space_xform; struct User { NodePath path; - Ref<Texture2D> lightmap; - int instance_index; + int32_t sub_instance; + Rect2 uv_scale; + int slice_index; }; Vector<User> users; void _set_user_data(const Array &p_data); Array _get_user_data() const; + void _set_probe_data(const Dictionary &p_data); + Dictionary _get_probe_data() const; protected: static void _bind_methods(); public: - void set_bounds(const AABB &p_bounds); - AABB get_bounds() const; + void add_user(const NodePath &p_path, const Rect2 &p_uv_scale, int p_slice_index, int32_t p_sub_instance = -1); + int get_user_count() const; + NodePath get_user_path(int p_user) const; + int32_t get_user_sub_instance(int p_user) const; + Rect2 get_user_lightmap_uv_scale(int p_user) const; + int get_user_lightmap_slice_index(int p_user) const; + void clear_users(); - void set_octree(const Vector<uint8_t> &p_octree); - Vector<uint8_t> get_octree() const; + void set_light_texture(const Ref<TextureLayered> &p_light_texture); + Ref<TextureLayered> get_light_texture() const; - void set_cell_space_transform(const Transform &p_xform); - Transform get_cell_space_transform() const; + void set_uses_spherical_harmonics(bool p_enable); + bool is_using_spherical_harmonics() const; - void set_cell_subdiv(int p_cell_subdiv); - int get_cell_subdiv() const; + bool is_interior() const; - void set_energy(float p_energy); - float get_energy() const; + void set_capture_data(const AABB &p_bounds, bool p_interior, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree); + PackedVector3Array get_capture_points() const; + PackedColorArray get_capture_sh() const; + PackedInt32Array get_capture_tetrahedra() const; + PackedInt32Array get_capture_bsp_tree() const; + AABB get_capture_bounds() const; - void add_user(const NodePath &p_path, const Ref<Texture2D> &p_lightmap, int p_instance = -1); - int get_user_count() const; - NodePath get_user_path(int p_user) const; - Ref<Texture2D> get_user_lightmap(int p_user) const; - int get_user_instance(int p_user) const; - void clear_users(); + void clear(); virtual RID get_rid() const; BakedLightmapData(); ~BakedLightmapData(); }; -class BakedLightmap : public VisualInstance { - GDCLASS(BakedLightmap, VisualInstance); +class BakedLightmap : public VisualInstance3D { + GDCLASS(BakedLightmap, VisualInstance3D); public: enum BakeQuality { BAKE_QUALITY_LOW, BAKE_QUALITY_MEDIUM, - BAKE_QUALITY_HIGH + BAKE_QUALITY_HIGH, + BAKE_QUALITY_ULTRA, }; - enum BakeMode { - BAKE_MODE_CONE_TRACE, - BAKE_MODE_RAY_TRACE, + enum GenerateProbes { + GENERATE_PROBES_DISABLED, + GENERATE_PROBES_SUBDIV_4, + GENERATE_PROBES_SUBDIV_8, + GENERATE_PROBES_SUBDIV_16, + GENERATE_PROBES_SUBDIV_32, }; enum BakeError { BAKE_ERROR_OK, + BAKE_ERROR_NO_LIGHTMAPPER, BAKE_ERROR_NO_SAVE_PATH, BAKE_ERROR_NO_MESHES, + BAKE_ERROR_MESHES_INVALID, BAKE_ERROR_CANT_CREATE_IMAGE, - BAKE_ERROR_USER_ABORTED - + BAKE_ERROR_USER_ABORTED, }; - typedef void (*BakeBeginFunc)(int); - typedef bool (*BakeStepFunc)(int, const String &); - typedef void (*BakeEndFunc)(); + enum EnvironmentMode { + ENVIRONMENT_MODE_DISABLED, + ENVIRONMENT_MODE_SCENE, + ENVIRONMENT_MODE_CUSTOM_SKY, + ENVIRONMENT_MODE_CUSTOM_COLOR, + }; private: - float bake_cell_size; - float capture_cell_size; - Vector3 extents; - float bake_default_texels_per_unit; - float propagation; - float energy; BakeQuality bake_quality; - BakeMode bake_mode; - bool hdr; - String image_path; + bool use_denoiser; + int bounces; + float bias; + int max_texture_size; + bool interior; + EnvironmentMode environment_mode; + Ref<Sky> environment_custom_sky; + Color environment_custom_color; + float environment_custom_energy; + bool directional; + GenerateProbes gen_probes; Ref<BakedLightmapData> light_data; - struct PlotMesh { - Ref<Material> override_material; - Vector<Ref<Material> > instance_materials; - Ref<Mesh> mesh; - Transform local_xform; - NodePath path; - int instance_idx; + struct LightsFound { + Transform xform; + Light3D *light; }; - struct PlotLight { - Light *light; - Transform local_xform; + struct MeshesFound { + Transform xform; + NodePath node_path; + int32_t subindex; + Ref<Mesh> mesh; + int32_t lightmap_scale; + Vector<Ref<Material>> overrides; }; - void _find_meshes_and_lights(Node *p_at_node, List<PlotMesh> &plot_meshes, List<PlotLight> &plot_lights); - - void _debug_bake(); + void _find_meshes_and_lights(Node *p_at_node, Vector<MeshesFound> &meshes, Vector<LightsFound> &lights, Vector<Vector3> &probes); void _assign_lightmaps(); void _clear_lightmaps(); - static bool _bake_time(void *ud, float p_secs, float p_progress); - struct BakeTimeData { String text; int pass; uint64_t last_step; }; + struct BSPSimplex { + int vertices[4]; + int planes[4]; + }; + + struct BSPNode { + static const int32_t EMPTY_LEAF = INT32_MIN; + Plane plane; + int32_t over = EMPTY_LEAF, under = EMPTY_LEAF; + }; + + int _bsp_get_simplex_side(const Vector<Vector3> &p_points, const LocalVector<BSPSimplex> &p_simplices, const Plane &p_plane, uint32_t p_simplex) const; + int32_t _compute_bsp_tree(const Vector<Vector3> &p_points, const LocalVector<Plane> &p_planes, LocalVector<int32_t> &planes_tested, const LocalVector<BSPSimplex> &p_simplices, const LocalVector<int32_t> &p_simplex_indices, LocalVector<BSPNode> &bsp_nodes); + + struct BakeStepUD { + Lightmapper::BakeStepFunc func; + void *ud; + float from_percent; + float to_percent; + }; + + static bool _lightmap_bake_step_function(float p_completion, const String &p_text, void *ud, bool p_refresh); + + struct GenProbesOctree { + Vector3i offset; + uint32_t size; + GenProbesOctree *children[8] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + ~GenProbesOctree() { + for (int i = 0; i < 8; i++) { + if (children[i] != nullptr) { + memdelete(children[i]); + } + } + } + }; + + struct Vector3iHash { + _FORCE_INLINE_ static uint32_t hash(const Vector3i &p_vtx) { + uint32_t h = hash_djb2_one_32(p_vtx.x); + h = hash_djb2_one_32(p_vtx.y, h); + return hash_djb2_one_32(p_vtx.z, h); + } + }; + + void _plot_triangle_into_octree(GenProbesOctree *p_cell, float p_cell_size, const Vector3 *p_triangle); + void _gen_new_positions_from_octree(const GenProbesOctree *p_cell, float p_cell_size, const Vector<Vector3> &probe_positions, LocalVector<Vector3> &new_probe_positions, HashMap<Vector3i, bool, Vector3iHash> &positions_used, const AABB &p_bounds); + protected: + void _validate_property(PropertyInfo &property) const; static void _bind_methods(); void _notification(int p_what); public: - static BakeBeginFunc bake_begin_function; - static BakeStepFunc bake_step_function; - static BakeEndFunc bake_end_function; - void set_light_data(const Ref<BakedLightmapData> &p_data); Ref<BakedLightmapData> get_light_data() const; - void set_bake_cell_size(float p_cell_size); - float get_bake_cell_size() const; + void set_bake_quality(BakeQuality p_quality); + BakeQuality get_bake_quality() const; + + void set_use_denoiser(bool p_enable); + bool is_using_denoiser() const; - void set_capture_cell_size(float p_cell_size); - float get_capture_cell_size() const; + void set_directional(bool p_enable); + bool is_directional() const; - void set_extents(const Vector3 &p_extents); - Vector3 get_extents() const; + void set_interior(bool p_interior); + bool is_interior() const; - void set_bake_default_texels_per_unit(const float &p_bake_texels_per_unit); - float get_bake_default_texels_per_unit() const; + void set_environment_mode(EnvironmentMode p_mode); + EnvironmentMode get_environment_mode() const; - void set_propagation(float p_propagation); - float get_propagation() const; + void set_environment_custom_sky(const Ref<Sky> &p_sky); + Ref<Sky> get_environment_custom_sky() const; - void set_energy(float p_energy); - float get_energy() const; + void set_environment_custom_color(const Color &p_color); + Color get_environment_custom_color() const; - void set_bake_quality(BakeQuality p_quality); - BakeQuality get_bake_quality() const; + void set_environment_custom_energy(float p_energy); + float get_environment_custom_energy() const; + + void set_bounces(int p_bounces); + int get_bounces() const; - void set_bake_mode(BakeMode p_mode); - BakeMode get_bake_mode() const; + void set_bias(float p_bias); + float get_bias() const; - void set_hdr(bool p_enable); - bool is_hdr() const; + void set_max_texture_size(int p_size); + int get_max_texture_size() const; - void set_image_path(const String &p_path); - String get_image_path() const; + void set_generate_probes(GenerateProbes p_generate_probes); + GenerateProbes get_generate_probes() const; AABB get_aabb() const; Vector<Face3> get_faces(uint32_t p_usage_flags) const; - BakeError bake(Node *p_from_node, bool p_create_visual_debug = false); + BakeError bake(Node *p_from_node, String p_image_data_path = "", Lightmapper::BakeStepFunc p_bake_step = nullptr, void *p_bake_userdata = nullptr); BakedLightmap(); }; VARIANT_ENUM_CAST(BakedLightmap::BakeQuality); -VARIANT_ENUM_CAST(BakedLightmap::BakeMode); +VARIANT_ENUM_CAST(BakedLightmap::GenerateProbes); VARIANT_ENUM_CAST(BakedLightmap::BakeError); +VARIANT_ENUM_CAST(BakedLightmap::EnvironmentMode); -#endif -#endif // BAKED_INDIRECT_LIGHT_H +#endif // BAKED_LIGHTMAP_H diff --git a/scene/3d/cpu_particles_3d.cpp b/scene/3d/cpu_particles_3d.cpp index 4c25f55f0b..414dc6f97b 100644 --- a/scene/3d/cpu_particles_3d.cpp +++ b/scene/3d/cpu_particles_3d.cpp @@ -1004,7 +1004,8 @@ void CPUParticles3D::_particles_process(float p_delta) { //scale by scale float base_scale = tex_scale * Math::lerp(parameters[PARAM_SCALE], 1.0f, p.scale_rand * randomness[PARAM_SCALE]); - if (base_scale < 0.000001) base_scale = 0.000001; + if (base_scale < 0.000001) + base_scale = 0.000001; p.transform.basis.scale(Vector3(1, 1, 1) * base_scale); @@ -1253,7 +1254,8 @@ void CPUParticles3D::convert_from_particles(Node *p_particles) { set_param(m_param, material->get_param(ParticlesMaterial::m_param)); \ { \ Ref<CurveTexture> ctex = material->get_param_texture(ParticlesMaterial::m_param); \ - if (ctex.is_valid()) set_param_curve(m_param, ctex->get_curve()); \ + if (ctex.is_valid()) \ + set_param_curve(m_param, ctex->get_curve()); \ } \ set_param_randomness(m_param, material->get_param_randomness(ParticlesMaterial::m_param)); diff --git a/scene/3d/gi_probe.cpp b/scene/3d/gi_probe.cpp index 6d571ee4f2..65a330ddc0 100644 --- a/scene/3d/gi_probe.cpp +++ b/scene/3d/gi_probe.cpp @@ -349,7 +349,7 @@ Vector3 GIProbe::get_extents() const { void GIProbe::_find_meshes(Node *p_at_node, List<PlotMesh> &plot_meshes) { MeshInstance3D *mi = Object::cast_to<MeshInstance3D>(p_at_node); - if (mi && mi->get_flag(GeometryInstance3D::FLAG_USE_BAKED_LIGHT) && mi->is_visible_in_tree()) { + if (mi && mi->get_gi_mode() == GeometryInstance3D::GI_MODE_BAKED && mi->is_visible_in_tree()) { Ref<Mesh> mesh = mi->get_mesh(); if (mesh.is_valid()) { diff --git a/scene/3d/gpu_particles_3d.cpp b/scene/3d/gpu_particles_3d.cpp index 7744c477cb..01886a730f 100644 --- a/scene/3d/gpu_particles_3d.cpp +++ b/scene/3d/gpu_particles_3d.cpp @@ -257,7 +257,8 @@ String GPUParticles3D::get_configuration_warning() const { StandardMaterial3D *spat = Object::cast_to<StandardMaterial3D>(draw_passes[i]->surface_get_material(j).ptr()); anim_material_found = anim_material_found || (spat && spat->get_billboard_mode() == StandardMaterial3D::BILLBOARD_PARTICLES); } - if (anim_material_found) break; + if (anim_material_found) + break; } } diff --git a/scene/3d/light_3d.cpp b/scene/3d/light_3d.cpp index c048f60ebd..0aa0f7e5ac 100644 --- a/scene/3d/light_3d.cpp +++ b/scene/3d/light_3d.cpp @@ -326,9 +326,15 @@ Light3D::Light3D(RenderingServer::LightType p_type) { type = p_type; switch (p_type) { - case RS::LIGHT_DIRECTIONAL: light = RenderingServer::get_singleton()->directional_light_create(); break; - case RS::LIGHT_OMNI: light = RenderingServer::get_singleton()->omni_light_create(); break; - case RS::LIGHT_SPOT: light = RenderingServer::get_singleton()->spot_light_create(); break; + case RS::LIGHT_DIRECTIONAL: + light = RenderingServer::get_singleton()->directional_light_create(); + break; + case RS::LIGHT_OMNI: + light = RenderingServer::get_singleton()->omni_light_create(); + break; + case RS::LIGHT_SPOT: + light = RenderingServer::get_singleton()->spot_light_create(); + break; default: { }; } diff --git a/scene/3d/lightmap_probe.cpp b/scene/3d/lightmap_probe.cpp new file mode 100644 index 0000000000..2da81337f0 --- /dev/null +++ b/scene/3d/lightmap_probe.cpp @@ -0,0 +1,4 @@ +#include "lightmap_probe.h" + +LightmapProbe::LightmapProbe() { +} diff --git a/scene/3d/lightmap_probe.h b/scene/3d/lightmap_probe.h new file mode 100644 index 0000000000..65bc6914f4 --- /dev/null +++ b/scene/3d/lightmap_probe.h @@ -0,0 +1,12 @@ +#ifndef LIGHTMAP_PROBE_H +#define LIGHTMAP_PROBE_H + +#include "scene/3d/node_3d.h" + +class LightmapProbe : public Node3D { + GDCLASS(LightmapProbe, Node3D) +public: + LightmapProbe(); +}; + +#endif // LIGHTMAP_PROBE_H diff --git a/scene/3d/lightmapper.cpp b/scene/3d/lightmapper.cpp new file mode 100644 index 0000000000..53ebd5ee2e --- /dev/null +++ b/scene/3d/lightmapper.cpp @@ -0,0 +1,37 @@ +#include "lightmapper.h" + +LightmapDenoiser *(*LightmapDenoiser::create_function)() = nullptr; + +Ref<LightmapDenoiser> LightmapDenoiser::create() { + if (create_function) { + return Ref<LightmapDenoiser>(create_function()); + } + return Ref<LightmapDenoiser>(); +} + +Lightmapper::CreateFunc Lightmapper::create_custom = nullptr; +Lightmapper::CreateFunc Lightmapper::create_gpu = nullptr; +Lightmapper::CreateFunc Lightmapper::create_cpu = nullptr; + +Ref<Lightmapper> Lightmapper::create() { + Lightmapper *lm = nullptr; + if (create_custom) { + lm = create_custom(); + } + + if (!lm && create_gpu) { + lm = create_gpu(); + } + + if (!lm && create_cpu) { + lm = create_cpu(); + } + if (!lm) { + return Ref<Lightmapper>(); + } else { + return Ref<Lightmapper>(lm); + } +} + +Lightmapper::Lightmapper() { +} diff --git a/scene/3d/lightmapper.h b/scene/3d/lightmapper.h new file mode 100644 index 0000000000..7c052a30a0 --- /dev/null +++ b/scene/3d/lightmapper.h @@ -0,0 +1,90 @@ +#ifndef LIGHTMAPPER_H +#define LIGHTMAPPER_H + +#include "scene/resources/mesh.h" +#include "servers/rendering/rendering_device.h" + +class LightmapDenoiser : public Reference { + GDCLASS(LightmapDenoiser, Reference) +protected: + static LightmapDenoiser *(*create_function)(); + +public: + virtual Ref<Image> denoise_image(const Ref<Image> &p_image) = 0; + static Ref<LightmapDenoiser> create(); +}; + +class Lightmapper : public Reference { + GDCLASS(Lightmapper, Reference) +public: + enum GenerateProbes { + GENERATE_PROBES_DISABLED, + GENERATE_PROBES_SUBDIV_4, + GENERATE_PROBES_SUBDIV_8, + GENERATE_PROBES_SUBDIV_16, + GENERATE_PROBES_SUBDIV_32, + + }; + + enum LightType { + LIGHT_TYPE_DIRECTIONAL, + LIGHT_TYPE_OMNI, + LIGHT_TYPE_SPOT + }; + + enum BakeError { + BAKE_ERROR_LIGHTMAP_TOO_SMALL, + BAKE_ERROR_LIGHTMAP_CANT_PRE_BAKE_MESHES, + BAKE_OK + }; + + enum BakeQuality { + BAKE_QUALITY_LOW, + BAKE_QUALITY_MEDIUM, + BAKE_QUALITY_HIGH, + BAKE_QUALITY_ULTRA, + }; + + typedef Lightmapper *(*CreateFunc)(); + + static CreateFunc create_custom; + static CreateFunc create_gpu; + static CreateFunc create_cpu; + +protected: +public: + typedef bool (*BakeStepFunc)(float, const String &, void *, bool); //step index, step total, step description, userdata + + struct MeshData { + //triangle data + Vector<Vector3> points; + Vector<Vector2> uv2; + Vector<Vector3> normal; + Ref<Image> albedo_on_uv2; + Ref<Image> emission_on_uv2; + Variant userdata; + }; + + virtual void add_mesh(const MeshData &p_mesh) = 0; + virtual void add_directional_light(bool p_static, const Vector3 &p_direction, const Color &p_color, float p_energy, float p_angular_distance) = 0; + virtual void add_omni_light(bool p_static, const Vector3 &p_position, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_size) = 0; + virtual void add_spot_light(bool p_static, const Vector3 &p_position, const Vector3 p_direction, const Color &p_color, float p_energy, float p_range, float p_attenuation, float p_spot_angle, float p_spot_attenuation, float p_size) = 0; + virtual void add_probe(const Vector3 &p_position) = 0; + virtual BakeError bake(BakeQuality p_quality, bool p_use_denoiser, int p_bounces, float p_bias, int p_max_texture_size, bool p_bake_sh, GenerateProbes p_generate_probes, const Ref<Image> &p_environment_panorama, const Basis &p_environment_transform, BakeStepFunc p_step_function = nullptr, void *p_step_userdata = nullptr) = 0; + + virtual int get_bake_texture_count() const = 0; + virtual Ref<Image> get_bake_texture(int p_index) const = 0; + virtual int get_bake_mesh_count() const = 0; + virtual Variant get_bake_mesh_userdata(int p_index) const = 0; + virtual Rect2 get_bake_mesh_uv_scale(int p_index) const = 0; + virtual int get_bake_mesh_texture_slice(int p_index) const = 0; + virtual int get_bake_probe_count() const = 0; + virtual Vector3 get_bake_probe_point(int p_probe) const = 0; + virtual Vector<Color> get_bake_probe_sh(int p_probe) const = 0; + + static Ref<Lightmapper> create(); + + Lightmapper(); +}; + +#endif // LIGHTMAPPER_H diff --git a/scene/3d/navigation_agent_3d.cpp b/scene/3d/navigation_agent_3d.cpp index 0449ab15b7..020d598f00 100644 --- a/scene/3d/navigation_agent_3d.cpp +++ b/scene/3d/navigation_agent_3d.cpp @@ -141,16 +141,7 @@ void NavigationAgent3D::_notification(int p_what) { } } -NavigationAgent3D::NavigationAgent3D() : - agent_parent(nullptr), - navigation(nullptr), - agent(RID()), - target_desired_distance(1.0), - navigation_height_offset(0.0), - path_max_distance(3.0), - velocity_submitted(false), - target_reached(false), - navigation_finished(true) { +NavigationAgent3D::NavigationAgent3D() { agent = NavigationServer3D::get_singleton()->agent_create(); set_neighbor_dist(50.0); set_max_neighbors(10); @@ -306,9 +297,12 @@ String NavigationAgent3D::get_configuration_warning() const { void NavigationAgent3D::update_navigation() { - if (agent_parent == nullptr) return; - if (navigation == nullptr) return; - if (update_frame_id == Engine::get_singleton()->get_physics_frames()) return; + if (agent_parent == nullptr) + return; + if (navigation == nullptr) + return; + if (update_frame_id == Engine::get_singleton()->get_physics_frames()) + return; update_frame_id = Engine::get_singleton()->get_physics_frames(); diff --git a/scene/3d/navigation_agent_3d.h b/scene/3d/navigation_agent_3d.h index 3558b4e51b..6dc375ef24 100644 --- a/scene/3d/navigation_agent_3d.h +++ b/scene/3d/navigation_agent_3d.h @@ -40,31 +40,31 @@ class Navigation3D; class NavigationAgent3D : public Node { GDCLASS(NavigationAgent3D, Node); - Node3D *agent_parent; - Navigation3D *navigation; + Node3D *agent_parent = nullptr; + Navigation3D *navigation = nullptr; RID agent; - real_t target_desired_distance; + real_t target_desired_distance = 1.0; real_t radius; - real_t navigation_height_offset; + real_t navigation_height_offset = 0.0; bool ignore_y; real_t neighbor_dist; int max_neighbors; real_t time_horizon; real_t max_speed; - real_t path_max_distance; + real_t path_max_distance = 3.0; Vector3 target_location; Vector<Vector3> navigation_path; int nav_path_index; - bool velocity_submitted; + bool velocity_submitted = false; Vector3 prev_safe_velocity; /// The submitted target velocity Vector3 target_velocity; - bool target_reached; - bool navigation_finished; + bool target_reached = false; + bool navigation_finished = true; // No initialized on purpose uint32_t update_frame_id; diff --git a/scene/3d/navigation_obstacle_3d.cpp b/scene/3d/navigation_obstacle_3d.cpp index 2ee2008799..2f99a5f99e 100644 --- a/scene/3d/navigation_obstacle_3d.cpp +++ b/scene/3d/navigation_obstacle_3d.cpp @@ -86,9 +86,7 @@ void NavigationObstacle3D::_notification(int p_what) { } } -NavigationObstacle3D::NavigationObstacle3D() : - navigation(nullptr), - agent(RID()) { +NavigationObstacle3D::NavigationObstacle3D() { agent = NavigationServer3D::get_singleton()->agent_create(); } diff --git a/scene/3d/navigation_obstacle_3d.h b/scene/3d/navigation_obstacle_3d.h index b58d7c4991..c7d2b556af 100644 --- a/scene/3d/navigation_obstacle_3d.h +++ b/scene/3d/navigation_obstacle_3d.h @@ -38,7 +38,7 @@ class Navigation3D; class NavigationObstacle3D : public Node { GDCLASS(NavigationObstacle3D, Node); - Navigation3D *navigation; + Navigation3D *navigation = nullptr; RID agent; diff --git a/scene/3d/navigation_region_3d.cpp b/scene/3d/navigation_region_3d.cpp index 043b816033..b15fa6166f 100644 --- a/scene/3d/navigation_region_3d.cpp +++ b/scene/3d/navigation_region_3d.cpp @@ -242,14 +242,8 @@ void NavigationRegion3D::_changed_callback(Object *p_changed, const char *p_prop } NavigationRegion3D::NavigationRegion3D() { - - enabled = true; set_notify_transform(true); region = NavigationServer3D::get_singleton()->region_create(); - - navigation = nullptr; - debug_view = nullptr; - bake_thread = nullptr; } NavigationRegion3D::~NavigationRegion3D() { diff --git a/scene/3d/navigation_region_3d.h b/scene/3d/navigation_region_3d.h index ae071e6b7a..a7b5077f53 100644 --- a/scene/3d/navigation_region_3d.h +++ b/scene/3d/navigation_region_3d.h @@ -41,13 +41,13 @@ class NavigationRegion3D : public Node3D { GDCLASS(NavigationRegion3D, Node3D); - bool enabled; + bool enabled = true; RID region; Ref<NavigationMesh> navmesh; - Navigation3D *navigation; - Node *debug_view; - Thread *bake_thread; + Navigation3D *navigation = nullptr; + Node *debug_view = nullptr; + Thread *bake_thread = nullptr; protected: void _notification(int p_what); diff --git a/scene/3d/physics_body_3d.cpp b/scene/3d/physics_body_3d.cpp index 3991efc7c0..d672c6f6b5 100644 --- a/scene/3d/physics_body_3d.cpp +++ b/scene/3d/physics_body_3d.cpp @@ -1346,7 +1346,8 @@ Vector3 KinematicCollision3D::get_remainder() const { return collision.remainder; } Object *KinematicCollision3D::get_local_shape() const { - if (!owner) return nullptr; + if (!owner) + return nullptr; uint32_t ownerid = owner->shape_find_owner(collision.local_shape); return owner->shape_owner_get_owner(ownerid); } @@ -2578,24 +2579,7 @@ bool PhysicalBone3D::get_axis_lock(PhysicsServer3D::BodyAxis p_axis) const { } PhysicalBone3D::PhysicalBone3D() : - PhysicsBody3D(PhysicsServer3D::BODY_MODE_STATIC), -#ifdef TOOLS_ENABLED - gizmo_move_joint(false), -#endif - joint_data(nullptr), - parent_skeleton(nullptr), - simulate_physics(false), - _internal_simulate_physics(false), - bone_id(-1), - bone_name(""), - bounce(0), - mass(1), - friction(1), - gravity_scale(1), - linear_damp(-1), - angular_damp(-1), - can_sleep(true) { - + PhysicsBody3D(PhysicsServer3D::BODY_MODE_STATIC) { reset_physics_simulation_state(); } diff --git a/scene/3d/physics_body_3d.h b/scene/3d/physics_body_3d.h index 0e719f5108..205052f798 100644 --- a/scene/3d/physics_body_3d.h +++ b/scene/3d/physics_body_3d.h @@ -396,14 +396,11 @@ public: virtual bool _get(const StringName &p_name, Variant &r_ret) const; virtual void _get_property_list(List<PropertyInfo> *p_list) const; - real_t bias; - real_t damping; - real_t impulse_clamp; - - PinJointData() : - bias(0.3), - damping(1.), - impulse_clamp(0) {} + real_t bias = 0.3; + real_t damping = 1.; + real_t impulse_clamp = 0; + + PinJointData() {} }; struct ConeJointData : public JointData { @@ -414,17 +411,13 @@ public: virtual void _get_property_list(List<PropertyInfo> *p_list) const; real_t swing_span; - real_t twist_span; - real_t bias; - real_t softness; - real_t relaxation; + real_t twist_span = Math_PI; + real_t bias = 0.3; + real_t softness = 0.8; + real_t relaxation = 1.; ConeJointData() : - swing_span(Math_PI * 0.25), - twist_span(Math_PI), - bias(0.3), - softness(0.8), - relaxation(1.) {} + swing_span(Math_PI * 0.25) {} }; struct HingeJointData : public JointData { @@ -434,20 +427,17 @@ public: virtual bool _get(const StringName &p_name, Variant &r_ret) const; virtual void _get_property_list(List<PropertyInfo> *p_list) const; - bool angular_limit_enabled; + bool angular_limit_enabled = false; real_t angular_limit_upper; real_t angular_limit_lower; - real_t angular_limit_bias; - real_t angular_limit_softness; - real_t angular_limit_relaxation; + real_t angular_limit_bias = 0.3; + real_t angular_limit_softness = 0.9; + real_t angular_limit_relaxation = 1.; HingeJointData() : - angular_limit_enabled(false), + angular_limit_upper(Math_PI * 0.5), - angular_limit_lower(-Math_PI * 0.5), - angular_limit_bias(0.3), - angular_limit_softness(0.9), - angular_limit_relaxation(1.) {} + angular_limit_lower(-Math_PI * 0.5) {} }; struct SliderJointData : public JointData { @@ -457,76 +447,45 @@ public: virtual bool _get(const StringName &p_name, Variant &r_ret) const; virtual void _get_property_list(List<PropertyInfo> *p_list) const; - real_t linear_limit_upper; - real_t linear_limit_lower; - real_t linear_limit_softness; - real_t linear_limit_restitution; - real_t linear_limit_damping; - real_t angular_limit_upper; - real_t angular_limit_lower; - real_t angular_limit_softness; - real_t angular_limit_restitution; - real_t angular_limit_damping; - - SliderJointData() : - linear_limit_upper(1.), - linear_limit_lower(-1.), - linear_limit_softness(1.), - linear_limit_restitution(0.7), - linear_limit_damping(1.), - angular_limit_upper(0), - angular_limit_lower(0), - angular_limit_softness(1.), - angular_limit_restitution(0.7), - angular_limit_damping(1.) {} + real_t linear_limit_upper = 1.; + real_t linear_limit_lower = -1.; + real_t linear_limit_softness = 1.; + real_t linear_limit_restitution = 0.7; + real_t linear_limit_damping = 1.; + real_t angular_limit_upper = 0; + real_t angular_limit_lower = 0; + real_t angular_limit_softness = 1.; + real_t angular_limit_restitution = 0.7; + real_t angular_limit_damping = 1.; + + SliderJointData() {} }; struct SixDOFJointData : public JointData { struct SixDOFAxisData { - bool linear_limit_enabled; - real_t linear_limit_upper; - real_t linear_limit_lower; - real_t linear_limit_softness; - real_t linear_restitution; - real_t linear_damping; - bool linear_spring_enabled; - real_t linear_spring_stiffness; - real_t linear_spring_damping; - real_t linear_equilibrium_point; - bool angular_limit_enabled; - real_t angular_limit_upper; - real_t angular_limit_lower; - real_t angular_limit_softness; - real_t angular_restitution; - real_t angular_damping; - real_t erp; - bool angular_spring_enabled; - real_t angular_spring_stiffness; - real_t angular_spring_damping; - real_t angular_equilibrium_point; - - SixDOFAxisData() : - linear_limit_enabled(true), - linear_limit_upper(0), - linear_limit_lower(0), - linear_limit_softness(0.7), - linear_restitution(0.5), - linear_damping(1.), - linear_spring_enabled(false), - linear_spring_stiffness(0), - linear_spring_damping(0), - linear_equilibrium_point(0), - angular_limit_enabled(true), - angular_limit_upper(0), - angular_limit_lower(0), - angular_limit_softness(0.5), - angular_restitution(0), - angular_damping(1.), - erp(0.5), - angular_spring_enabled(false), - angular_spring_stiffness(0), - angular_spring_damping(0.), - angular_equilibrium_point(0) {} + bool linear_limit_enabled = true; + real_t linear_limit_upper = 0; + real_t linear_limit_lower = 0; + real_t linear_limit_softness = 0.7; + real_t linear_restitution = 0.5; + real_t linear_damping = 1.; + bool linear_spring_enabled = false; + real_t linear_spring_stiffness = 0; + real_t linear_spring_damping = 0; + real_t linear_equilibrium_point = 0; + bool angular_limit_enabled = true; + real_t angular_limit_upper = 0; + real_t angular_limit_lower = 0; + real_t angular_limit_softness = 0.5; + real_t angular_restitution = 0; + real_t angular_damping = 1.; + real_t erp = 0.5; + bool angular_spring_enabled = false; + real_t angular_spring_stiffness = 0; + real_t angular_spring_damping = 0.; + real_t angular_equilibrium_point = 0; + + SixDOFAxisData() {} }; virtual JointType get_joint_type() { return JOINT_TYPE_6DOF; } @@ -543,28 +502,28 @@ public: private: #ifdef TOOLS_ENABLED // if false gizmo move body - bool gizmo_move_joint; + bool gizmo_move_joint = false; #endif - JointData *joint_data; + JointData *joint_data = nullptr; Transform joint_offset; RID joint; - Skeleton3D *parent_skeleton; + Skeleton3D *parent_skeleton = nullptr; Transform body_offset; Transform body_offset_inverse; - bool simulate_physics; - bool _internal_simulate_physics; - int bone_id; + bool simulate_physics = false; + bool _internal_simulate_physics = false; + int bone_id = -1; String bone_name; - real_t bounce; - real_t mass; - real_t friction; - real_t gravity_scale; - real_t linear_damp; - real_t angular_damp; - bool can_sleep; + real_t bounce = 0; + real_t mass = 1; + real_t friction = 1; + real_t gravity_scale = 1; + real_t linear_damp = -1; + real_t angular_damp = -1; + bool can_sleep = true; protected: bool _set(const StringName &p_name, const Variant &p_value); diff --git a/scene/3d/physics_joint_3d.cpp b/scene/3d/physics_joint_3d.cpp index 140d887d9a..b6953fafac 100644 --- a/scene/3d/physics_joint_3d.cpp +++ b/scene/3d/physics_joint_3d.cpp @@ -960,8 +960,7 @@ RID Generic6DOFJoint3D::_configure_joint(PhysicsBody3D *body_a, PhysicsBody3D *b return j; } -Generic6DOFJoint3D::Generic6DOFJoint3D() : - precision(1) { +Generic6DOFJoint3D::Generic6DOFJoint3D() { set_param_x(PARAM_LINEAR_LOWER_LIMIT, 0); set_param_x(PARAM_LINEAR_UPPER_LIMIT, 0); diff --git a/scene/3d/physics_joint_3d.h b/scene/3d/physics_joint_3d.h index ce0c7af5d1..38a3f314ba 100644 --- a/scene/3d/physics_joint_3d.h +++ b/scene/3d/physics_joint_3d.h @@ -305,7 +305,7 @@ protected: float params_z[PARAM_MAX]; bool flags_z[FLAG_MAX]; - int precision; + int precision = 1; virtual RID _configure_joint(PhysicsBody3D *body_a, PhysicsBody3D *body_b); static void _bind_methods(); diff --git a/scene/3d/skeleton_ik_3d.cpp b/scene/3d/skeleton_ik_3d.cpp index 10bdd71d73..5c0e48a5df 100644 --- a/scene/3d/skeleton_ik_3d.cpp +++ b/scene/3d/skeleton_ik_3d.cpp @@ -434,15 +434,7 @@ void SkeletonIK3D::_notification(int p_what) { } } -SkeletonIK3D::SkeletonIK3D() : - interpolation(1), - override_tip_basis(true), - use_magnet(false), - min_distance(0.01), - max_iterations(10), - skeleton(nullptr), - target_node_override(nullptr), - task(nullptr) { +SkeletonIK3D::SkeletonIK3D() { } SkeletonIK3D::~SkeletonIK3D() { diff --git a/scene/3d/skeleton_ik_3d.h b/scene/3d/skeleton_ik_3d.h index 5fbbe6e9e7..ad2623193b 100644 --- a/scene/3d/skeleton_ik_3d.h +++ b/scene/3d/skeleton_ik_3d.h @@ -50,36 +50,30 @@ class FabrikInverseKinematic { struct ChainItem { Vector<ChainItem> children; - ChainItem *parent_item; + ChainItem *parent_item = nullptr; // Bone info - BoneId bone; - PhysicalBone3D *pb; + BoneId bone = -1; + PhysicalBone3D *pb = nullptr; - real_t length; + real_t length = 0; /// Positions relative to root bone Transform initial_transform; Vector3 current_pos; // Direction from this bone to child Vector3 current_ori; - ChainItem() : - parent_item(nullptr), - bone(-1), - pb(nullptr), - length(0) {} + ChainItem() {} ChainItem *find_child(const BoneId p_bone_id); ChainItem *add_child(const BoneId p_bone_id); }; struct ChainTip { - ChainItem *chain_item; - const EndEffector *end_effector; + ChainItem *chain_item = nullptr; + const EndEffector *end_effector = nullptr; - ChainTip() : - chain_item(nullptr), - end_effector(nullptr) {} + ChainTip() {} ChainTip(ChainItem *p_chain_item, const EndEffector *p_end_effector) : chain_item(p_chain_item), @@ -100,25 +94,21 @@ class FabrikInverseKinematic { public: struct Task { RID self; - Skeleton3D *skeleton; + Skeleton3D *skeleton = nullptr; Chain chain; // Settings - real_t min_distance; - int max_iterations; + real_t min_distance = 0.01; + int max_iterations = 10; // Bone data - BoneId root_bone; + BoneId root_bone = -1; Vector<EndEffector> end_effectors; Transform goal_global_transform; - Task() : - skeleton(nullptr), - min_distance(0.01), - max_iterations(10), - root_bone(-1) {} + Task() {} }; private: @@ -146,19 +136,19 @@ class SkeletonIK3D : public Node { StringName root_bone; StringName tip_bone; - real_t interpolation; + real_t interpolation = 1; Transform target; NodePath target_node_path_override; - bool override_tip_basis; - bool use_magnet; + bool override_tip_basis = true; + bool use_magnet = false; Vector3 magnet_position; - real_t min_distance; - int max_iterations; + real_t min_distance = 0.01; + int max_iterations = 10; - Skeleton3D *skeleton; - Node3D *target_node_override; - FabrikInverseKinematic::Task *task; + Skeleton3D *skeleton = nullptr; + Node3D *target_node_override = nullptr; + FabrikInverseKinematic::Task *task = nullptr; protected: virtual void diff --git a/scene/3d/soft_body_3d.cpp b/scene/3d/soft_body_3d.cpp index 850ffab292..91b8b5c859 100644 --- a/scene/3d/soft_body_3d.cpp +++ b/scene/3d/soft_body_3d.cpp @@ -97,9 +97,7 @@ void SoftBodyRenderingServerHandler::set_aabb(const AABB &p_aabb) { RS::get_singleton()->mesh_set_custom_aabb(mesh, p_aabb); } -SoftBody3D::PinnedPoint::PinnedPoint() : - point_index(-1), - spatial_attachment(nullptr) { +SoftBody3D::PinnedPoint::PinnedPoint() { } SoftBody3D::PinnedPoint::PinnedPoint(const PinnedPoint &obj_tocopy) { @@ -702,13 +700,7 @@ bool SoftBody3D::is_ray_pickable() const { } SoftBody3D::SoftBody3D() : - physics_rid(PhysicsServer3D::get_singleton()->soft_body_create()), - mesh_owner(false), - collision_mask(1), - collision_layer(1), - simulation_started(false), - pinned_points_cache_dirty(true), - ray_pickable(true) { + physics_rid(PhysicsServer3D::get_singleton()->soft_body_create()) { PhysicsServer3D::get_singleton()->body_attach_object_instance_id(physics_rid, get_instance_id()); } diff --git a/scene/3d/soft_body_3d.h b/scene/3d/soft_body_3d.h index 7dd5880985..485f7427f8 100644 --- a/scene/3d/soft_body_3d.h +++ b/scene/3d/soft_body_3d.h @@ -68,9 +68,9 @@ class SoftBody3D : public MeshInstance3D { public: struct PinnedPoint { - int point_index; + int point_index = -1; NodePath spatial_attachment_path; - Node3D *spatial_attachment; // Cache + Node3D *spatial_attachment = nullptr; // Cache Vector3 offset; PinnedPoint(); @@ -83,19 +83,19 @@ private: RID physics_rid; - bool mesh_owner; - uint32_t collision_mask; - uint32_t collision_layer; + bool mesh_owner = false; + uint32_t collision_mask = 1; + uint32_t collision_layer = 1; NodePath parent_collision_ignore; Vector<PinnedPoint> pinned_points; - bool simulation_started; - bool pinned_points_cache_dirty; + bool simulation_started = false; + bool pinned_points_cache_dirty = true; Ref<ArrayMesh> debug_mesh_cache; class MeshInstance3D *debug_mesh; bool capture_input_on_drag; - bool ray_pickable; + bool ray_pickable = true; void _update_pickable(); diff --git a/scene/3d/spring_arm_3d.cpp b/scene/3d/spring_arm_3d.cpp index 1410b730fd..f61e6eb2a7 100644 --- a/scene/3d/spring_arm_3d.cpp +++ b/scene/3d/spring_arm_3d.cpp @@ -29,18 +29,12 @@ /*************************************************************************/ #include "spring_arm_3d.h" + #include "core/engine.h" #include "scene/3d/collision_object_3d.h" #include "scene/resources/sphere_shape_3d.h" #include "servers/physics_server_3d.h" -SpringArm3D::SpringArm3D() : - spring_length(1), - current_spring_length(0), - keep_child_basis(false), - mask(1), - margin(0.01) {} - void SpringArm3D::_notification(int p_what) { switch (p_what) { case NOTIFICATION_ENTER_TREE: diff --git a/scene/3d/spring_arm_3d.h b/scene/3d/spring_arm_3d.h index cb8a00ecf9..7f6fe2f1a2 100644 --- a/scene/3d/spring_arm_3d.h +++ b/scene/3d/spring_arm_3d.h @@ -38,11 +38,11 @@ class SpringArm3D : public Node3D { Ref<Shape3D> shape; Set<RID> excluded_objects; - float spring_length; - float current_spring_length; - bool keep_child_basis; - uint32_t mask; - float margin; + float spring_length = 1; + float current_spring_length = 0; + bool keep_child_basis = false; + uint32_t mask = 1; + float margin = 0.01; protected: void _notification(int p_what); @@ -62,7 +62,7 @@ public: void set_margin(float p_margin); float get_margin(); - SpringArm3D(); + SpringArm3D() {} private: void process_spring(); diff --git a/scene/3d/vehicle_body_3d.cpp b/scene/3d/vehicle_body_3d.cpp index 5c2fa59a21..66fcf0e40b 100644 --- a/scene/3d/vehicle_body_3d.cpp +++ b/scene/3d/vehicle_body_3d.cpp @@ -44,7 +44,7 @@ public: real_t getDiagonal() const { return m_Adiag; } - btVehicleJacobianEntry(){}; + btVehicleJacobianEntry() {} //constraint between two different rigidbodies btVehicleJacobianEntry( const Basis &world2A, diff --git a/scene/3d/visual_instance_3d.cpp b/scene/3d/visual_instance_3d.cpp index ce8672c1dd..4724c88a30 100644 --- a/scene/3d/visual_instance_3d.cpp +++ b/scene/3d/visual_instance_3d.cpp @@ -239,7 +239,17 @@ bool GeometryInstance3D::_set(const StringName &p_name, const Variant &p_value) set_shader_instance_uniform(*r, p_value); return true; } +#ifndef DISABLE_DEPRECATED + if (p_name == SceneStringNames::get_singleton()->use_in_baked_light && bool(p_value)) { + set_gi_mode(GI_MODE_BAKED); + return true; + } + if (p_name == SceneStringNames::get_singleton()->use_dynamic_gi && bool(p_value)) { + set_gi_mode(GI_MODE_DYNAMIC); + return true; + } +#endif return false; } @@ -273,23 +283,6 @@ void GeometryInstance3D::_get_property_list(List<PropertyInfo> *p_list) const { } } -void GeometryInstance3D::set_flag(Flags p_flag, bool p_value) { - - ERR_FAIL_INDEX(p_flag, FLAG_MAX); - if (flags[p_flag] == p_value) - return; - - flags[p_flag] = p_value; - RS::get_singleton()->instance_geometry_set_flag(get_instance(), (RS::InstanceFlags)p_flag, p_value); -} - -bool GeometryInstance3D::get_flag(Flags p_flag) const { - - ERR_FAIL_INDEX_V(p_flag, FLAG_MAX, false); - - return flags[p_flag]; -} - void GeometryInstance3D::set_cast_shadows_setting(ShadowCastingSetting p_shadow_casting_setting) { shadow_casting_setting = p_shadow_casting_setting; @@ -335,14 +328,45 @@ void GeometryInstance3D::set_custom_aabb(AABB aabb) { RS::get_singleton()->instance_set_custom_aabb(get_instance(), aabb); } +void GeometryInstance3D::set_lightmap_scale(LightmapScale p_scale) { + ERR_FAIL_INDEX(p_scale, LIGHTMAP_SCALE_MAX); + lightmap_scale = p_scale; +} + +GeometryInstance3D::LightmapScale GeometryInstance3D::get_lightmap_scale() const { + return lightmap_scale; +} + +void GeometryInstance3D::set_gi_mode(GIMode p_mode) { + + switch (p_mode) { + case GI_MODE_DISABLED: { + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_BAKED_LIGHT, false); + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_DYNAMIC_GI, false); + } break; + case GI_MODE_BAKED: { + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_BAKED_LIGHT, true); + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_DYNAMIC_GI, false); + + } break; + case GI_MODE_DYNAMIC: { + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_BAKED_LIGHT, false); + RS::get_singleton()->instance_geometry_set_flag(get_instance(), RS::INSTANCE_FLAG_USE_DYNAMIC_GI, true); + } break; + } + + gi_mode = p_mode; +} + +GeometryInstance3D::GIMode GeometryInstance3D::get_gi_mode() const { + return gi_mode; +} + void GeometryInstance3D::_bind_methods() { ClassDB::bind_method(D_METHOD("set_material_override", "material"), &GeometryInstance3D::set_material_override); ClassDB::bind_method(D_METHOD("get_material_override"), &GeometryInstance3D::get_material_override); - ClassDB::bind_method(D_METHOD("set_flag", "flag", "value"), &GeometryInstance3D::set_flag); - ClassDB::bind_method(D_METHOD("get_flag", "flag"), &GeometryInstance3D::get_flag); - ClassDB::bind_method(D_METHOD("set_cast_shadows_setting", "shadow_casting_setting"), &GeometryInstance3D::set_cast_shadows_setting); ClassDB::bind_method(D_METHOD("get_cast_shadows_setting"), &GeometryInstance3D::get_cast_shadows_setting); @@ -364,6 +388,12 @@ void GeometryInstance3D::_bind_methods() { ClassDB::bind_method(D_METHOD("set_extra_cull_margin", "margin"), &GeometryInstance3D::set_extra_cull_margin); ClassDB::bind_method(D_METHOD("get_extra_cull_margin"), &GeometryInstance3D::get_extra_cull_margin); + ClassDB::bind_method(D_METHOD("set_lightmap_scale", "scale"), &GeometryInstance3D::set_lightmap_scale); + ClassDB::bind_method(D_METHOD("get_lightmap_scale"), &GeometryInstance3D::get_lightmap_scale); + + ClassDB::bind_method(D_METHOD("set_gi_mode", "mode"), &GeometryInstance3D::set_gi_mode); + ClassDB::bind_method(D_METHOD("get_gi_mode"), &GeometryInstance3D::get_gi_mode); + ClassDB::bind_method(D_METHOD("set_custom_aabb", "aabb"), &GeometryInstance3D::set_custom_aabb); ClassDB::bind_method(D_METHOD("get_aabb"), &GeometryInstance3D::get_aabb); @@ -372,8 +402,9 @@ void GeometryInstance3D::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "material_override", PROPERTY_HINT_RESOURCE_TYPE, "ShaderMaterial,StandardMaterial3D", PROPERTY_USAGE_DEFAULT | PROPERTY_USAGE_DEFERRED_SET_RESOURCE), "set_material_override", "get_material_override"); ADD_PROPERTY(PropertyInfo(Variant::INT, "cast_shadow", PROPERTY_HINT_ENUM, "Off,On,Double-Sided,Shadows Only"), "set_cast_shadows_setting", "get_cast_shadows_setting"); ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "extra_cull_margin", PROPERTY_HINT_RANGE, "0,16384,0.01"), "set_extra_cull_margin", "get_extra_cull_margin"); - ADD_PROPERTYI(PropertyInfo(Variant::BOOL, "use_in_baked_light"), "set_flag", "get_flag", FLAG_USE_BAKED_LIGHT); - ADD_PROPERTYI(PropertyInfo(Variant::BOOL, "use_dynamic_gi"), "set_flag", "get_flag", FLAG_USE_DYNAMIC_GI); + ADD_GROUP("Global Illumination", "gi_"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "gi_mode", PROPERTY_HINT_ENUM, "Disabled,Baked,Dynamic"), "set_gi_mode", "get_gi_mode"); + ADD_PROPERTY(PropertyInfo(Variant::INT, "gi_lightmap_scale", PROPERTY_HINT_ENUM, "1x,2x,4x,8x"), "set_lightmap_scale", "get_lightmap_scale"); ADD_GROUP("LOD", "lod_"); ADD_PROPERTY(PropertyInfo(Variant::INT, "lod_min_distance", PROPERTY_HINT_RANGE, "0,32768,0.01"), "set_lod_min_distance", "get_lod_min_distance"); @@ -388,10 +419,15 @@ void GeometryInstance3D::_bind_methods() { BIND_ENUM_CONSTANT(SHADOW_CASTING_SETTING_DOUBLE_SIDED); BIND_ENUM_CONSTANT(SHADOW_CASTING_SETTING_SHADOWS_ONLY); - BIND_ENUM_CONSTANT(FLAG_USE_BAKED_LIGHT); - BIND_ENUM_CONSTANT(FLAG_USE_DYNAMIC_GI); - BIND_ENUM_CONSTANT(FLAG_DRAW_NEXT_FRAME_IF_VISIBLE); - BIND_ENUM_CONSTANT(FLAG_MAX); + BIND_ENUM_CONSTANT(GI_MODE_DISABLED); + BIND_ENUM_CONSTANT(GI_MODE_BAKED); + BIND_ENUM_CONSTANT(GI_MODE_DYNAMIC); + + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_1X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_2X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_4X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_8X); + BIND_ENUM_CONSTANT(LIGHTMAP_SCALE_MAX); } GeometryInstance3D::GeometryInstance3D() { @@ -400,9 +436,8 @@ GeometryInstance3D::GeometryInstance3D() { lod_min_hysteresis = 0; lod_max_hysteresis = 0; - for (int i = 0; i < FLAG_MAX; i++) { - flags[i] = false; - } + gi_mode = GI_MODE_DISABLED; + lightmap_scale = LIGHTMAP_SCALE_1X; shadow_casting_setting = SHADOW_CASTING_SETTING_ON; extra_cull_margin = 0; diff --git a/scene/3d/visual_instance_3d.h b/scene/3d/visual_instance_3d.h index cc5f92066f..a871c65b6a 100644 --- a/scene/3d/visual_instance_3d.h +++ b/scene/3d/visual_instance_3d.h @@ -85,13 +85,6 @@ class GeometryInstance3D : public VisualInstance3D { GDCLASS(GeometryInstance3D, VisualInstance3D); public: - enum Flags { - FLAG_USE_BAKED_LIGHT = RS::INSTANCE_FLAG_USE_BAKED_LIGHT, - FLAG_USE_DYNAMIC_GI = RS::INSTANCE_FLAG_USE_DYNAMIC_GI, - FLAG_DRAW_NEXT_FRAME_IF_VISIBLE = RS::INSTANCE_FLAG_DRAW_NEXT_FRAME_IF_VISIBLE, - FLAG_MAX = RS::INSTANCE_FLAG_MAX, - }; - enum ShadowCastingSetting { SHADOW_CASTING_SETTING_OFF = RS::SHADOW_CASTING_SETTING_OFF, SHADOW_CASTING_SETTING_ON = RS::SHADOW_CASTING_SETTING_ON, @@ -99,8 +92,21 @@ public: SHADOW_CASTING_SETTING_SHADOWS_ONLY = RS::SHADOW_CASTING_SETTING_SHADOWS_ONLY }; + enum GIMode { + GI_MODE_DISABLED, + GI_MODE_BAKED, + GI_MODE_DYNAMIC + }; + + enum LightmapScale { + LIGHTMAP_SCALE_1X, + LIGHTMAP_SCALE_2X, + LIGHTMAP_SCALE_4X, + LIGHTMAP_SCALE_8X, + LIGHTMAP_SCALE_MAX, + }; + private: - bool flags[FLAG_MAX]; ShadowCastingSetting shadow_casting_setting; Ref<Material> material_override; float lod_min_distance; @@ -112,6 +118,8 @@ private: mutable HashMap<StringName, StringName> instance_uniform_property_remap; float extra_cull_margin; + LightmapScale lightmap_scale; + GIMode gi_mode; const StringName *_instance_uniform_get_remap(const StringName p_name) const; @@ -124,9 +132,6 @@ protected: static void _bind_methods(); public: - void set_flag(Flags p_flag, bool p_value); - bool get_flag(Flags p_flag) const; - void set_cast_shadows_setting(ShadowCastingSetting p_shadow_casting_setting); ShadowCastingSetting get_cast_shadows_setting() const; @@ -148,6 +153,12 @@ public: void set_extra_cull_margin(float p_margin); float get_extra_cull_margin() const; + void set_gi_mode(GIMode p_mode); + GIMode get_gi_mode() const; + + void set_lightmap_scale(LightmapScale p_scale); + LightmapScale get_lightmap_scale() const; + void set_shader_instance_uniform(const StringName &p_uniform, const Variant &p_value); Variant get_shader_instance_uniform(const StringName &p_uniform) const; @@ -156,7 +167,8 @@ public: GeometryInstance3D(); }; -VARIANT_ENUM_CAST(GeometryInstance3D::Flags); VARIANT_ENUM_CAST(GeometryInstance3D::ShadowCastingSetting); +VARIANT_ENUM_CAST(GeometryInstance3D::LightmapScale); +VARIANT_ENUM_CAST(GeometryInstance3D::GIMode); #endif diff --git a/scene/3d/voxelizer.cpp b/scene/3d/voxelizer.cpp index f30c58be55..f9c3810843 100644 --- a/scene/3d/voxelizer.cpp +++ b/scene/3d/voxelizer.cpp @@ -29,192 +29,12 @@ /*************************************************************************/ #include "voxelizer.h" +#include "core/math/geometry.h" #include "core/os/os.h" #include "core/os/threaded_array_processor.h" #include <stdlib.h> -#define FINDMINMAX(x0, x1, x2, min, max) \ - min = max = x0; \ - if (x1 < min) min = x1; \ - if (x1 > max) max = x1; \ - if (x2 < min) min = x2; \ - if (x2 > max) max = x2; - -static bool planeBoxOverlap(Vector3 normal, float d, Vector3 maxbox) { - int q; - Vector3 vmin, vmax; - for (q = 0; q <= 2; q++) { - if (normal[q] > 0.0f) { - vmin[q] = -maxbox[q]; - vmax[q] = maxbox[q]; - } else { - vmin[q] = maxbox[q]; - vmax[q] = -maxbox[q]; - } - } - if (normal.dot(vmin) + d > 0.0f) return false; - if (normal.dot(vmax) + d >= 0.0f) return true; - - return false; -} - -/*======================== X-tests ========================*/ -#define AXISTEST_X01(a, b, fa, fb) \ - p0 = a * v0.y - b * v0.z; \ - p2 = a * v2.y - b * v2.z; \ - if (p0 < p2) { \ - min = p0; \ - max = p2; \ - } else { \ - min = p2; \ - max = p0; \ - } \ - rad = fa * boxhalfsize.y + fb * boxhalfsize.z; \ - if (min > rad || max < -rad) return false; - -#define AXISTEST_X2(a, b, fa, fb) \ - p0 = a * v0.y - b * v0.z; \ - p1 = a * v1.y - b * v1.z; \ - if (p0 < p1) { \ - min = p0; \ - max = p1; \ - } else { \ - min = p1; \ - max = p0; \ - } \ - rad = fa * boxhalfsize.y + fb * boxhalfsize.z; \ - if (min > rad || max < -rad) return false; - -/*======================== Y-tests ========================*/ -#define AXISTEST_Y02(a, b, fa, fb) \ - p0 = -a * v0.x + b * v0.z; \ - p2 = -a * v2.x + b * v2.z; \ - if (p0 < p2) { \ - min = p0; \ - max = p2; \ - } else { \ - min = p2; \ - max = p0; \ - } \ - rad = fa * boxhalfsize.x + fb * boxhalfsize.z; \ - if (min > rad || max < -rad) return false; - -#define AXISTEST_Y1(a, b, fa, fb) \ - p0 = -a * v0.x + b * v0.z; \ - p1 = -a * v1.x + b * v1.z; \ - if (p0 < p1) { \ - min = p0; \ - max = p1; \ - } else { \ - min = p1; \ - max = p0; \ - } \ - rad = fa * boxhalfsize.x + fb * boxhalfsize.z; \ - if (min > rad || max < -rad) return false; - -/*======================== Z-tests ========================*/ - -#define AXISTEST_Z12(a, b, fa, fb) \ - p1 = a * v1.x - b * v1.y; \ - p2 = a * v2.x - b * v2.y; \ - if (p2 < p1) { \ - min = p2; \ - max = p1; \ - } else { \ - min = p1; \ - max = p2; \ - } \ - rad = fa * boxhalfsize.x + fb * boxhalfsize.y; \ - if (min > rad || max < -rad) return false; - -#define AXISTEST_Z0(a, b, fa, fb) \ - p0 = a * v0.x - b * v0.y; \ - p1 = a * v1.x - b * v1.y; \ - if (p0 < p1) { \ - min = p0; \ - max = p1; \ - } else { \ - min = p1; \ - max = p0; \ - } \ - rad = fa * boxhalfsize.x + fb * boxhalfsize.y; \ - if (min > rad || max < -rad) return false; - -static bool fast_tri_box_overlap(const Vector3 &boxcenter, const Vector3 boxhalfsize, const Vector3 *triverts) { - - /* use separating axis theorem to test overlap between triangle and box */ - /* need to test for overlap in these directions: */ - /* 1) the {x,y,z}-directions (actually, since we use the AABB of the triangle */ - /* we do not even need to test these) */ - /* 2) normal of the triangle */ - /* 3) crossproduct(edge from tri, {x,y,z}-directin) */ - /* this gives 3x3=9 more tests */ - Vector3 v0, v1, v2; - float min, max, d, p0, p1, p2, rad, fex, fey, fez; - Vector3 normal, e0, e1, e2; - - /* This is the fastest branch on Sun */ - /* move everything so that the boxcenter is in (0,0,0) */ - - v0 = triverts[0] - boxcenter; - v1 = triverts[1] - boxcenter; - v2 = triverts[2] - boxcenter; - - /* compute triangle edges */ - e0 = v1 - v0; /* tri edge 0 */ - e1 = v2 - v1; /* tri edge 1 */ - e2 = v0 - v2; /* tri edge 2 */ - - /* Bullet 3: */ - /* test the 9 tests first (this was faster) */ - fex = Math::abs(e0.x); - fey = Math::abs(e0.y); - fez = Math::abs(e0.z); - AXISTEST_X01(e0.z, e0.y, fez, fey); - AXISTEST_Y02(e0.z, e0.x, fez, fex); - AXISTEST_Z12(e0.y, e0.x, fey, fex); - - fex = Math::abs(e1.x); - fey = Math::abs(e1.y); - fez = Math::abs(e1.z); - AXISTEST_X01(e1.z, e1.y, fez, fey); - AXISTEST_Y02(e1.z, e1.x, fez, fex); - AXISTEST_Z0(e1.y, e1.x, fey, fex); - - fex = Math::abs(e2.x); - fey = Math::abs(e2.y); - fez = Math::abs(e2.z); - AXISTEST_X2(e2.z, e2.y, fez, fey); - AXISTEST_Y1(e2.z, e2.x, fez, fex); - AXISTEST_Z12(e2.y, e2.x, fey, fex); - - /* Bullet 1: */ - /* first test overlap in the {x,y,z}-directions */ - /* find min, max of the triangle each direction, and test for overlap in */ - /* that direction -- this is equivalent to testing a minimal AABB around */ - /* the triangle against the AABB */ - - /* test in X-direction */ - FINDMINMAX(v0.x, v1.x, v2.x, min, max); - if (min > boxhalfsize.x || max < -boxhalfsize.x) return false; - - /* test in Y-direction */ - FINDMINMAX(v0.y, v1.y, v2.y, min, max); - if (min > boxhalfsize.y || max < -boxhalfsize.y) return false; - - /* test in Z-direction */ - FINDMINMAX(v0.z, v1.z, v2.z, min, max); - if (min > boxhalfsize.z || max < -boxhalfsize.z) return false; - - /* Bullet 2: */ - /* test if the box intersects the plane of the triangle */ - /* compute plane equation of triangle: normal*x+d=0 */ - normal = e0.cross(e1); - d = -normal.dot(v0); /* plane eq: normal.x+d=0 */ - return planeBoxOverlap(normal, d, boxhalfsize); /* if true, box and triangle overlaps */ -} - static _FORCE_INLINE_ void get_uv_and_normal(const Vector3 &p_pos, const Vector3 *p_vtx, const Vector2 *p_uv, const Vector3 *p_normal, Vector2 &r_uv, Vector3 &r_normal) { if (p_pos.distance_squared_to(p_vtx[0]) < CMP_EPSILON2) { @@ -309,7 +129,7 @@ void Voxelizer::_plot_face(int p_idx, int p_level, int p_x, int p_y, int p_z, co Vector3 half = (to - from) * 0.5; //is in this cell? - if (!fast_tri_box_overlap(from + half, half, p_vtx)) { + if (!Geometry::triangle_box_overlap(from + half, half, p_vtx)) { continue; //face does not span this cell } @@ -452,7 +272,7 @@ void Voxelizer::_plot_face(int p_idx, int p_level, int p_x, int p_y, int p_z, co //test_aabb.grow_by(test_aabb.get_longest_axis_size()*0.05); //grow a bit to avoid numerical error in real-time Vector3 qsize = test_aabb.size * 0.5; //quarter size, for fast aabb test - if (!fast_tri_box_overlap(test_aabb.position + qsize, qsize, p_vtx)) { + if (!Geometry::triangle_box_overlap(test_aabb.position + qsize, qsize, p_vtx)) { //if (!Face3(p_vtx[0],p_vtx[1],p_vtx[2]).intersects_aabb2(aabb)) { //does not fit in child, go on continue; @@ -633,7 +453,7 @@ void Voxelizer::plot_mesh(const Transform &p_xform, Ref<Mesh> &p_mesh, const Vec } //test against original bounds - if (!fast_tri_box_overlap(original_bounds.position + original_bounds.size * 0.5, original_bounds.size * 0.5, vtxs)) + if (!Geometry::triangle_box_overlap(original_bounds.position + original_bounds.size * 0.5, original_bounds.size * 0.5, vtxs)) continue; //plot _plot_face(0, 0, 0, 0, 0, vtxs, normal, uvs, material, po2_bounds); @@ -666,7 +486,7 @@ void Voxelizer::plot_mesh(const Transform &p_xform, Ref<Mesh> &p_mesh, const Vec } //test against original bounds - if (!fast_tri_box_overlap(original_bounds.position + original_bounds.size * 0.5, original_bounds.size * 0.5, vtxs)) + if (!Geometry::triangle_box_overlap(original_bounds.position + original_bounds.size * 0.5, original_bounds.size * 0.5, vtxs)) continue; //plot face _plot_face(0, 0, 0, 0, 0, vtxs, normal, uvs, material, po2_bounds); diff --git a/scene/3d/xr_nodes.cpp b/scene/3d/xr_nodes.cpp index 6f41629bac..1b13b64744 100644 --- a/scene/3d/xr_nodes.cpp +++ b/scene/3d/xr_nodes.cpp @@ -269,11 +269,11 @@ void XRController3D::set_controller_id(int p_controller_id) { update_configuration_warning(); }; -int XRController3D::get_controller_id(void) const { +int XRController3D::get_controller_id() const { return controller_id; }; -String XRController3D::get_controller_name(void) const { +String XRController3D::get_controller_name() const { // get our XRServer XRServer *xr_server = XRServer::get_singleton(); ERR_FAIL_NULL_V(xr_server, String()); @@ -465,7 +465,7 @@ void XRAnchor3D::set_anchor_id(int p_anchor_id) { update_configuration_warning(); }; -int XRAnchor3D::get_anchor_id(void) const { +int XRAnchor3D::get_anchor_id() const { return anchor_id; }; @@ -473,7 +473,7 @@ Vector3 XRAnchor3D::get_size() const { return size; }; -String XRAnchor3D::get_anchor_name(void) const { +String XRAnchor3D::get_anchor_name() const { // get our XRServer XRServer *xr_server = XRServer::get_singleton(); ERR_FAIL_NULL_V(xr_server, String()); diff --git a/scene/3d/xr_nodes.h b/scene/3d/xr_nodes.h index a2f16545d1..55dcfe087e 100644 --- a/scene/3d/xr_nodes.h +++ b/scene/3d/xr_nodes.h @@ -84,8 +84,8 @@ protected: public: void set_controller_id(int p_controller_id); - int get_controller_id(void) const; - String get_controller_name(void) const; + int get_controller_id() const; + String get_controller_name() const; int get_joystick_id() const; bool is_button_pressed(int p_button) const; @@ -97,7 +97,7 @@ public: bool get_is_active() const; XRPositionalTracker::TrackerHand get_hand() const; - Ref<Mesh> get_mesh(void) const; + Ref<Mesh> get_mesh() const; String get_configuration_warning() const; @@ -125,15 +125,15 @@ protected: public: void set_anchor_id(int p_anchor_id); - int get_anchor_id(void) const; - String get_anchor_name(void) const; + int get_anchor_id() const; + String get_anchor_name() const; bool get_is_active() const; Vector3 get_size() const; Plane get_plane() const; - Ref<Mesh> get_mesh(void) const; + Ref<Mesh> get_mesh() const; String get_configuration_warning() const; diff --git a/scene/animation/animation_blend_space_2d.cpp b/scene/animation/animation_blend_space_2d.cpp index 638531df41..ad60249f9a 100644 --- a/scene/animation/animation_blend_space_2d.cpp +++ b/scene/animation/animation_blend_space_2d.cpp @@ -29,7 +29,8 @@ /*************************************************************************/ #include "animation_blend_space_2d.h" -#include "core/math/delaunay.h" + +#include "core/math/delaunay_2d.h" void AnimationNodeBlendSpace2D::get_parameter_list(List<PropertyInfo> *r_list) const { r_list->push_back(PropertyInfo(Variant::VECTOR2, blend_position)); diff --git a/scene/animation/animation_player.cpp b/scene/animation/animation_player.cpp index 8228cf67bd..7bac09f839 100644 --- a/scene/animation/animation_player.cpp +++ b/scene/animation/animation_player.cpp @@ -1479,9 +1479,14 @@ void AnimationPlayer::_set_process(bool p_process, bool p_force) { switch (animation_process_mode) { - case ANIMATION_PROCESS_PHYSICS: set_physics_process_internal(p_process && active); break; - case ANIMATION_PROCESS_IDLE: set_process_internal(p_process && active); break; - case ANIMATION_PROCESS_MANUAL: break; + case ANIMATION_PROCESS_PHYSICS: + set_physics_process_internal(p_process && active); + break; + case ANIMATION_PROCESS_IDLE: + set_process_internal(p_process && active); + break; + case ANIMATION_PROCESS_MANUAL: + break; } processing = p_process; diff --git a/scene/animation/animation_player.h b/scene/animation/animation_player.h index c134aff707..d709082f62 100644 --- a/scene/animation/animation_player.h +++ b/scene/animation/animation_player.h @@ -87,41 +87,37 @@ private: struct TrackNodeCache { NodePath path; - uint32_t id; + uint32_t id = 0; RES resource; - Node *node; - Node3D *spatial; - Node2D *node_2d; - Skeleton3D *skeleton; - int bone_idx; + Node *node = nullptr; + Node3D *spatial = nullptr; + Node2D *node_2d = nullptr; + Skeleton3D *skeleton = nullptr; + int bone_idx = -1; // accumulated transforms Vector3 loc_accum; Quat rot_accum; Vector3 scale_accum; - uint64_t accum_pass; + uint64_t accum_pass = 0; - bool audio_playing; - float audio_start; - float audio_len; + bool audio_playing = false; + float audio_start = 0.0; + float audio_len = 0.0; - bool animation_playing; + bool animation_playing = false; struct PropertyAnim { - TrackNodeCache *owner; - SpecialProperty special; //small optimization + TrackNodeCache *owner = nullptr; + SpecialProperty special = SP_NONE; //small optimization Vector<StringName> subpath; - Object *object; + Object *object = nullptr; Variant value_accum; - uint64_t accum_pass; + uint64_t accum_pass = 0; Variant capture; - PropertyAnim() : - owner(nullptr), - special(SP_NONE), - object(nullptr), - accum_pass(0) {} + PropertyAnim() {} }; Map<StringName, PropertyAnim> property_anim; @@ -129,32 +125,17 @@ private: struct BezierAnim { Vector<StringName> bezier_property; - TrackNodeCache *owner; - float bezier_accum; - Object *object; - uint64_t accum_pass; - - BezierAnim() : - owner(nullptr), - bezier_accum(0.0), - object(nullptr), - accum_pass(0) {} + TrackNodeCache *owner = nullptr; + float bezier_accum = 0.0; + Object *object = nullptr; + uint64_t accum_pass = 0; + + BezierAnim() {} }; Map<StringName, BezierAnim> bezier_anim; - TrackNodeCache() : - id(0), - node(nullptr), - spatial(nullptr), - node_2d(nullptr), - skeleton(nullptr), - bone_idx(-1), - accum_pass(0), - audio_playing(false), - audio_start(0.0), - audio_len(0.0), - animation_playing(false) {} + TrackNodeCache() {} }; struct TrackNodeCacheKey { diff --git a/scene/animation/tween.cpp b/scene/animation/tween.cpp index d0c6cac8cf..b826907a3a 100644 --- a/scene/animation/tween.cpp +++ b/scene/animation/tween.cpp @@ -353,7 +353,8 @@ Variant Tween::_get_final_val(const InterpolateData &p_data) const { // If we're looking at an INT value, instead convert it to a FLOAT // This is better for interpolation - if (final_val.get_type() == Variant::INT) final_val = final_val.operator real_t(); + if (final_val.get_type() == Variant::INT) + final_val = final_val.operator real_t(); return final_val; } @@ -395,7 +396,8 @@ Variant &Tween::_get_delta_val(InterpolateData &p_data) { // If we're looking at an INT value, instead convert it to a FLOAT // This is better for interpolation - if (final_val.get_type() == Variant::INT) final_val = final_val.operator real_t(); + if (final_val.get_type() == Variant::INT) + final_val = final_val.operator real_t(); // Calculate the delta based on the initial value and the final value _calc_delta_val(p_data.initial_val, final_val, p_data.delta_val); @@ -409,7 +411,8 @@ Variant &Tween::_get_delta_val(InterpolateData &p_data) { // If we're looking at an INT value, instead convert it to a FLOAT // This is better for interpolation - if (initial_val.get_type() == Variant::INT) initial_val = initial_val.operator real_t(); + if (initial_val.get_type() == Variant::INT) + initial_val = initial_val.operator real_t(); // Calculate the delta based on the initial value and the final value _calc_delta_val(initial_val, p_data.final_val, p_data.delta_val); @@ -823,8 +826,12 @@ void Tween::set_active(bool p_active) { // Depending on physics or idle, set processing switch (tween_process_mode) { - case TWEEN_PROCESS_IDLE: set_process_internal(p_active); break; - case TWEEN_PROCESS_PHYSICS: set_physics_process_internal(p_active); break; + case TWEEN_PROCESS_IDLE: + set_process_internal(p_active); + break; + case TWEEN_PROCESS_PHYSICS: + set_physics_process_internal(p_active); + break; } } @@ -1334,11 +1341,14 @@ void Tween::interpolate_property(Object *p_object, NodePath p_property, Variant // If no initial value given, grab the initial value from the object // TODO: Is this documented? This is very useful and removes a lot of clutter from tweens! - if (p_initial_val.get_type() == Variant::NIL) p_initial_val = p_object->get_indexed(p_property.get_subnames()); + if (p_initial_val.get_type() == Variant::NIL) + p_initial_val = p_object->get_indexed(p_property.get_subnames()); // Convert any integers into REALs as they are better for interpolation - if (p_initial_val.get_type() == Variant::INT) p_initial_val = p_initial_val.operator real_t(); - if (p_final_val.get_type() == Variant::INT) p_final_val = p_final_val.operator real_t(); + if (p_initial_val.get_type() == Variant::INT) + p_initial_val = p_initial_val.operator real_t(); + if (p_final_val.get_type() == Variant::INT) + p_final_val = p_final_val.operator real_t(); // Build the interpolation data _build_interpolation(INTER_PROPERTY, p_object, &p_property, nullptr, p_initial_val, p_final_val, p_duration, p_trans_type, p_ease_type, p_delay); @@ -1352,8 +1362,10 @@ void Tween::interpolate_method(Object *p_object, StringName p_method, Variant p_ } // Convert any integers into REALs as they are better for interpolation - if (p_initial_val.get_type() == Variant::INT) p_initial_val = p_initial_val.operator real_t(); - if (p_final_val.get_type() == Variant::INT) p_final_val = p_final_val.operator real_t(); + if (p_initial_val.get_type() == Variant::INT) + p_initial_val = p_initial_val.operator real_t(); + if (p_final_val.get_type() == Variant::INT) + p_final_val = p_final_val.operator real_t(); // Build the interpolation data _build_interpolation(INTER_METHOD, p_object, nullptr, &p_method, p_initial_val, p_final_val, p_duration, p_trans_type, p_ease_type, p_delay); @@ -1486,10 +1498,12 @@ void Tween::follow_property(Object *p_object, NodePath p_property, Variant p_ini // If no initial value is given, grab it from the source object // TODO: Is this documented? It's really helpful for decluttering tweens - if (p_initial_val.get_type() == Variant::NIL) p_initial_val = p_object->get_indexed(p_property.get_subnames()); + if (p_initial_val.get_type() == Variant::NIL) + p_initial_val = p_object->get_indexed(p_property.get_subnames()); // Convert initial INT values to FLOAT as they are better for interpolation - if (p_initial_val.get_type() == Variant::INT) p_initial_val = p_initial_val.operator real_t(); + if (p_initial_val.get_type() == Variant::INT) + p_initial_val = p_initial_val.operator real_t(); // Confirm the source and target objects are valid ERR_FAIL_COND(p_object == nullptr); @@ -1515,7 +1529,8 @@ void Tween::follow_property(Object *p_object, NodePath p_property, Variant p_ini ERR_FAIL_COND(!target_prop_valid); // Convert target INT to FLOAT since it is better for interpolation - if (target_val.get_type() == Variant::INT) target_val = target_val.operator real_t(); + if (target_val.get_type() == Variant::INT) + target_val = target_val.operator real_t(); // Verify that the target value and initial value are the same type ERR_FAIL_COND(target_val.get_type() != p_initial_val.get_type()); @@ -1550,7 +1565,8 @@ void Tween::follow_method(Object *p_object, StringName p_method, Variant p_initi return; } // Convert initial INT values to FLOAT as they are better for interpolation - if (p_initial_val.get_type() == Variant::INT) p_initial_val = p_initial_val.operator real_t(); + if (p_initial_val.get_type() == Variant::INT) + p_initial_val = p_initial_val.operator real_t(); // Verify the source and target objects are valid ERR_FAIL_COND(p_object == nullptr); @@ -1576,7 +1592,8 @@ void Tween::follow_method(Object *p_object, StringName p_method, Variant p_initi ERR_FAIL_COND(error.error != Callable::CallError::CALL_OK); // Convert target INT values to FLOAT as they are better for interpolation - if (target_val.get_type() == Variant::INT) target_val = target_val.operator real_t(); + if (target_val.get_type() == Variant::INT) + target_val = target_val.operator real_t(); ERR_FAIL_COND(target_val.get_type() != p_initial_val.get_type()); // Make the new InterpolateData for the method follow @@ -1613,7 +1630,8 @@ void Tween::targeting_property(Object *p_object, NodePath p_property, Object *p_ p_initial_property = p_initial_property.get_as_property_path(); // Convert the initial INT values to FLOAT as they are better for Interpolation - if (p_final_val.get_type() == Variant::INT) p_final_val = p_final_val.operator real_t(); + if (p_final_val.get_type() == Variant::INT) + p_final_val = p_final_val.operator real_t(); // Verify both objects are valid ERR_FAIL_COND(p_object == nullptr); @@ -1639,7 +1657,8 @@ void Tween::targeting_property(Object *p_object, NodePath p_property, Object *p_ ERR_FAIL_COND(!initial_prop_valid); // Convert the initial INT value to FLOAT as it is better for interpolation - if (initial_val.get_type() == Variant::INT) initial_val = initial_val.operator real_t(); + if (initial_val.get_type() == Variant::INT) + initial_val = initial_val.operator real_t(); ERR_FAIL_COND(initial_val.get_type() != p_final_val.get_type()); // Build the InterpolateData object @@ -1679,7 +1698,8 @@ void Tween::targeting_method(Object *p_object, StringName p_method, Object *p_in } // Convert final INT values to FLOAT as they are better for interpolation - if (p_final_val.get_type() == Variant::INT) p_final_val = p_final_val.operator real_t(); + if (p_final_val.get_type() == Variant::INT) + p_final_val = p_final_val.operator real_t(); // Make sure the given objects are valid ERR_FAIL_COND(p_object == nullptr); @@ -1705,7 +1725,8 @@ void Tween::targeting_method(Object *p_object, StringName p_method, Object *p_in ERR_FAIL_COND(error.error != Callable::CallError::CALL_OK); // Convert initial INT values to FLOAT as they aer better for interpolation - if (initial_val.get_type() == Variant::INT) initial_val = initial_val.operator real_t(); + if (initial_val.get_type() == Variant::INT) + initial_val = initial_val.operator real_t(); ERR_FAIL_COND(initial_val.get_type() != p_final_val.get_type()); // Build the new InterpolateData object diff --git a/scene/debugger/scene_debugger.h b/scene/debugger/scene_debugger.h index e295510960..e8521d6a20 100644 --- a/scene/debugger/scene_debugger.h +++ b/scene/debugger/scene_debugger.h @@ -100,7 +100,7 @@ public: void serialize(Array &r_arr); void deserialize(const Array &p_arr); SceneDebuggerTree(Node *p_root); - SceneDebuggerTree(){}; + SceneDebuggerTree() {} }; class LiveEditor { diff --git a/scene/gui/file_dialog.cpp b/scene/gui/file_dialog.cpp index 89d13ecd7f..a449d680a8 100644 --- a/scene/gui/file_dialog.cpp +++ b/scene/gui/file_dialog.cpp @@ -410,7 +410,8 @@ void FileDialog::_tree_item_activated() { void FileDialog::update_file_name() { int idx = filter->get_selected() - 1; if ((idx == -1 && filter->get_item_count() == 2) || (filter->get_item_count() > 2 && idx >= 0 && idx < filter->get_item_count() - 2)) { - if (idx == -1) idx += 1; + if (idx == -1) + idx += 1; String filter_str = filters[idx]; String file_str = file->get_text(); String base_name = file_str.get_basename(); diff --git a/scene/gui/label.cpp b/scene/gui/label.cpp index bedcef2df2..7490101ee3 100644 --- a/scene/gui/label.cpp +++ b/scene/gui/label.cpp @@ -29,6 +29,7 @@ /*************************************************************************/ #include "label.h" + #include "core/print_string.h" #include "core/project_settings.h" #include "core/translation.h" @@ -693,24 +694,8 @@ void Label::_bind_methods() { } Label::Label(const String &p_text) { - - align = ALIGN_LEFT; - valign = VALIGN_TOP; - xl_text = ""; - word_cache = nullptr; - word_cache_dirty = true; - autowrap = false; - line_count = 0; - set_v_size_flags(0); - clip = false; set_mouse_filter(MOUSE_FILTER_IGNORE); - total_char_cache = 0; - visible_chars = -1; - percent_visible = 1; - lines_skipped = 0; - max_lines_visible = -1; set_text(p_text); - uppercase = false; set_v_size_flags(SIZE_SHRINK_CENTER); } diff --git a/scene/gui/label.h b/scene/gui/label.h index ba6e627c58..0f84b01604 100644 --- a/scene/gui/label.h +++ b/scene/gui/label.h @@ -55,15 +55,15 @@ public: }; private: - Align align; - VAlign valign; + Align align = ALIGN_LEFT; + VAlign valign = VALIGN_TOP; String text; String xl_text; - bool autowrap; - bool clip; + bool autowrap = false; + bool clip = false; Size2 minsize; - int line_count; - bool uppercase; + int line_count = 0; + bool uppercase = false; int get_longest_line_width() const; @@ -73,30 +73,23 @@ private: CHAR_NEWLINE = -1, CHAR_WRAPLINE = -2 }; - int char_pos; // if -1, then newline - int word_len; - int pixel_width; - int space_count; - WordCache *next; - WordCache() { - char_pos = 0; - word_len = 0; - pixel_width = 0; - next = 0; - space_count = 0; - } + int char_pos = 0; // if -1, then newline + int word_len = 0; + int pixel_width = 0; + int space_count = 0; + WordCache *next = nullptr; }; - bool word_cache_dirty; + bool word_cache_dirty = true; void regenerate_word_cache(); - float percent_visible; + float percent_visible = 1; - WordCache *word_cache; - int total_char_cache; - int visible_chars; - int lines_skipped; - int max_lines_visible; + WordCache *word_cache = nullptr; + int total_char_cache = 0; + int visible_chars = -1; + int lines_skipped = 0; + int max_lines_visible = -1; protected: void _notification(int p_what); diff --git a/scene/gui/line_edit.cpp b/scene/gui/line_edit.cpp index b9b7560f2e..bb177ae0e7 100644 --- a/scene/gui/line_edit.cpp +++ b/scene/gui/line_edit.cpp @@ -277,10 +277,14 @@ void LineEdit::_gui_input(Ref<InputEvent> p_event) { } break; #ifdef APPLE_STYLE_KEYS case (KEY_LEFT): { // Go to start of text - like HOME key. + shift_selection_check_pre(k->get_shift()); set_cursor_position(0); + shift_selection_check_post(k->get_shift()); } break; case (KEY_RIGHT): { // Go to end of text - like END key. + shift_selection_check_pre(k->get_shift()); set_cursor_position(text.length()); + shift_selection_check_post(k->get_shift()); } break; #endif default: { @@ -463,14 +467,16 @@ void LineEdit::_gui_input(Ref<InputEvent> p_event) { case KEY_UP: { shift_selection_check_pre(k->get_shift()); - if (get_cursor_position() == 0) handled = false; + if (get_cursor_position() == 0) + handled = false; set_cursor_position(0); shift_selection_check_post(k->get_shift()); } break; case KEY_DOWN: { shift_selection_check_pre(k->get_shift()); - if (get_cursor_position() == text.length()) handled = false; + if (get_cursor_position() == text.length()) + handled = false; set_cursor_position(text.length()); shift_selection_check_post(k->get_shift()); } break; @@ -988,7 +994,8 @@ void LineEdit::paste_text() { if (paste_buffer != "") { int prev_len = text.length(); - if (selection.enabled) selection_delete(); + if (selection.enabled) + selection_delete(); append_at_cursor(paste_buffer); if (!text_changed_dirty) { @@ -1204,7 +1211,8 @@ void LineEdit::_toggle_draw_caret() { void LineEdit::delete_char() { - if ((text.length() <= 0) || (cursor_pos == 0)) return; + if ((text.length() <= 0) || (cursor_pos == 0)) + return; Ref<Font> font = get_theme_font("font"); if (font != nullptr) { @@ -1379,7 +1387,8 @@ int LineEdit::get_cursor_position() const { void LineEdit::set_window_pos(int p_pos) { window_pos = p_pos; - if (window_pos < 0) window_pos = 0; + if (window_pos < 0) + window_pos = 0; } void LineEdit::append_at_cursor(String p_text) { diff --git a/scene/gui/rich_text_label.cpp b/scene/gui/rich_text_label.cpp index 84097eb6a1..a6c0c99bdb 100644 --- a/scene/gui/rich_text_label.cpp +++ b/scene/gui/rich_text_label.cpp @@ -230,10 +230,18 @@ int RichTextLabel::_process_line(ItemFrame *p_frame, const Vector2 &p_ofs, int & } else { \ int used = wofs - margin; \ switch (align) { \ - case ALIGN_LEFT: l.offset_caches.push_back(0); break; \ - case ALIGN_CENTER: l.offset_caches.push_back(((p_width - margin) - used) / 2); break; \ - case ALIGN_RIGHT: l.offset_caches.push_back(((p_width - margin) - used)); break; \ - case ALIGN_FILL: l.offset_caches.push_back(line_wrapped ? ((p_width - margin) - used) : 0); break; \ + case ALIGN_LEFT: \ + l.offset_caches.push_back(0); \ + break; \ + case ALIGN_CENTER: \ + l.offset_caches.push_back(((p_width - margin) - used) / 2); \ + break; \ + case ALIGN_RIGHT: \ + l.offset_caches.push_back(((p_width - margin) - used)); \ + break; \ + case ALIGN_FILL: \ + l.offset_caches.push_back(line_wrapped ? ((p_width - margin) - used) : 0); \ + break; \ } \ l.height_caches.push_back(line_height); \ l.ascent_caches.push_back(line_ascent); \ @@ -255,7 +263,8 @@ int RichTextLabel::_process_line(ItemFrame *p_frame, const Vector2 &p_ofs, int & line_descent = line < l.descent_caches.size() ? l.descent_caches[line] : 1; \ } \ if (p_mode == PROCESS_POINTER && r_click_item && p_click_pos.y >= p_ofs.y + y && p_click_pos.y <= p_ofs.y + y + lh && p_click_pos.x < p_ofs.x + wofs) { \ - if (r_outside) *r_outside = true; \ + if (r_outside) \ + *r_outside = true; \ *r_click_item = it; \ *r_click_char = rchar; \ RETURN; \ @@ -275,7 +284,8 @@ int RichTextLabel::_process_line(ItemFrame *p_frame, const Vector2 &p_ofs, int & } \ const bool x_in_range = (p_click_pos.x > p_ofs.x + wofs) && (!p_frame->cell || p_click_pos.x < p_ofs.x + p_width); \ if (p_mode == PROCESS_POINTER && r_click_item && p_click_pos.y >= p_ofs.y + y && p_click_pos.y <= p_ofs.y + y + lh && x_in_range) { \ - if (r_outside) *r_outside = true; \ + if (r_outside) \ + *r_outside = true; \ *r_click_item = it; \ *r_click_char = rchar; \ RETURN; \ @@ -286,7 +296,8 @@ int RichTextLabel::_process_line(ItemFrame *p_frame, const Vector2 &p_ofs, int & #define ADVANCE(m_width) \ { \ if (p_mode == PROCESS_POINTER && r_click_item && p_click_pos.y >= p_ofs.y + y && p_click_pos.y <= p_ofs.y + y + lh && p_click_pos.x >= p_ofs.x + wofs && p_click_pos.x < p_ofs.x + wofs + m_width) { \ - if (r_outside) *r_outside = false; \ + if (r_outside) \ + *r_outside = false; \ *r_click_item = it; \ *r_click_char = rchar; \ RETURN; \ @@ -855,7 +866,8 @@ int RichTextLabel::_process_line(ItemFrame *p_frame, const Vector2 &p_ofs, int & if (p_mode == PROCESS_POINTER && r_click_item && p_click_pos.y >= p_ofs.y + y && p_click_pos.y <= p_ofs.y + y + lh) { //went to next line, but pointer was on the previous one - if (r_outside) *r_outside = true; + if (r_outside) + *r_outside = true; *r_click_item = itp; *r_click_char = rchar; RETURN; diff --git a/scene/gui/slider.cpp b/scene/gui/slider.cpp index 1f135163d4..910d5f8230 100644 --- a/scene/gui/slider.cpp +++ b/scene/gui/slider.cpp @@ -182,7 +182,8 @@ void Slider::_notification(int p_what) { if (ticks > 1) { int grabber_offset = (grabber->get_size().height / 2 - tick->get_height() / 2); for (int i = 0; i < ticks; i++) { - if (!ticks_on_borders && (i == 0 || i + 1 == ticks)) continue; + if (!ticks_on_borders && (i == 0 || i + 1 == ticks)) + continue; int ofs = (i * areasize / (ticks - 1)) + grabber_offset; tick->draw(ci, Point2i((size.width - widget_width) / 2, ofs)); } @@ -199,7 +200,8 @@ void Slider::_notification(int p_what) { if (ticks > 1) { int grabber_offset = (grabber->get_size().width / 2 - tick->get_width() / 2); for (int i = 0; i < ticks; i++) { - if ((!ticks_on_borders) && ((i == 0) || ((i + 1) == ticks))) continue; + if ((!ticks_on_borders) && ((i == 0) || ((i + 1) == ticks))) + continue; int ofs = (i * areasize / (ticks - 1)) + grabber_offset; tick->draw(ci, Point2i(ofs, (size.height - widget_height) / 2)); } diff --git a/scene/gui/subviewport_container.cpp b/scene/gui/subviewport_container.cpp index 50f468741d..7252a43651 100644 --- a/scene/gui/subviewport_container.cpp +++ b/scene/gui/subviewport_container.cpp @@ -1,5 +1,5 @@ /*************************************************************************/ -/* subviewport_container.cpp */ +/* subviewport_container.cpp */ /*************************************************************************/ /* This file is part of: */ /* GODOT ENGINE */ diff --git a/scene/gui/subviewport_container.h b/scene/gui/subviewport_container.h index 6ff3d188e2..13b711f838 100644 --- a/scene/gui/subviewport_container.h +++ b/scene/gui/subviewport_container.h @@ -1,5 +1,5 @@ /*************************************************************************/ -/* subviewport_container.h */ +/* subviewport_container.h */ /*************************************************************************/ /* This file is part of: */ /* GODOT ENGINE */ diff --git a/scene/gui/tabs.cpp b/scene/gui/tabs.cpp index 1a3b53f489..b856d3ab3a 100644 --- a/scene/gui/tabs.cpp +++ b/scene/gui/tabs.cpp @@ -399,7 +399,8 @@ int Tabs::get_tab_count() const { void Tabs::set_current_tab(int p_current) { - if (current == p_current) return; + if (current == p_current) + return; ERR_FAIL_INDEX(p_current, get_tab_count()); current = p_current; @@ -856,7 +857,8 @@ void Tabs::ensure_tab_visible(int p_idx) { if (!is_inside_tree()) return; - if (tabs.size() == 0) return; + if (tabs.size() == 0) + return; ERR_FAIL_INDEX(p_idx, tabs.size()); if (p_idx == offset) { diff --git a/scene/gui/text_edit.cpp b/scene/gui/text_edit.cpp index aa518fbb7d..e986e350f3 100644 --- a/scene/gui/text_edit.cpp +++ b/scene/gui/text_edit.cpp @@ -941,7 +941,7 @@ void TextEdit::_notification(int p_what) { // calculate viewport size and y offset int viewport_height = (draw_amount - 1) * minimap_line_height; int control_height = _get_control_height() - viewport_height; - int viewport_offset_y = round(get_scroll_pos_for_line(first_visible_line) * control_height) / ((v_scroll->get_max() <= minimap_visible_lines) ? (minimap_visible_lines - draw_amount) : (v_scroll->get_max() - draw_amount)); + int viewport_offset_y = round(get_scroll_pos_for_line(first_visible_line + 1) * control_height) / ((v_scroll->get_max() <= minimap_visible_lines) ? (minimap_visible_lines - draw_amount) : (v_scroll->get_max() - draw_amount)); // calculate the first line. int num_lines_before = round((viewport_offset_y) / minimap_line_height); @@ -2991,7 +2991,8 @@ void TextEdit::_gui_input(const Ref<InputEvent> &p_gui_input) { } } break; case KEY_TAB: { - if (k->get_command()) break; // Avoid tab when command. + if (k->get_command()) + break; // Avoid tab when command. if (readonly) break; diff --git a/scene/gui/tree.cpp b/scene/gui/tree.cpp index 329c1085df..4a550727fc 100644 --- a/scene/gui/tree.cpp +++ b/scene/gui/tree.cpp @@ -2277,7 +2277,8 @@ void Tree::_gui_input(Ref<InputEvent> p_event) { bool is_command = k.is_valid() && k->get_command(); if (p_event->is_action("ui_right") && p_event->is_pressed()) { - if (!cursor_can_exit_tree) accept_event(); + if (!cursor_can_exit_tree) + accept_event(); if (!selected_item || select_mode == SELECT_ROW || selected_col > (columns.size() - 1)) { return; @@ -2294,7 +2295,8 @@ void Tree::_gui_input(Ref<InputEvent> p_event) { } } else if (p_event->is_action("ui_left") && p_event->is_pressed()) { - if (!cursor_can_exit_tree) accept_event(); + if (!cursor_can_exit_tree) + accept_event(); if (!selected_item || select_mode == SELECT_ROW || selected_col < 0) { return; @@ -2313,19 +2315,22 @@ void Tree::_gui_input(Ref<InputEvent> p_event) { } else if (p_event->is_action("ui_up") && p_event->is_pressed() && !is_command) { - if (!cursor_can_exit_tree) accept_event(); + if (!cursor_can_exit_tree) + accept_event(); _go_up(); } else if (p_event->is_action("ui_down") && p_event->is_pressed() && !is_command) { - if (!cursor_can_exit_tree) accept_event(); + if (!cursor_can_exit_tree) + accept_event(); _go_down(); } else if (p_event->is_action("ui_page_down") && p_event->is_pressed()) { - if (!cursor_can_exit_tree) accept_event(); + if (!cursor_can_exit_tree) + accept_event(); TreeItem *next = nullptr; if (!selected_item) @@ -2363,7 +2368,8 @@ void Tree::_gui_input(Ref<InputEvent> p_event) { ensure_cursor_is_visible(); } else if (p_event->is_action("ui_page_up") && p_event->is_pressed()) { - if (!cursor_can_exit_tree) accept_event(); + if (!cursor_can_exit_tree) + accept_event(); TreeItem *prev = nullptr; if (!selected_item) diff --git a/scene/gui/tree.h b/scene/gui/tree.h index 87c2588a12..3c67a79558 100644 --- a/scene/gui/tree.h +++ b/scene/gui/tree.h @@ -544,7 +544,7 @@ public: void clear(); - TreeItem *create_item(TreeItem *p_parent = 0, int p_idx = -1); + TreeItem *create_item(TreeItem *p_parent = nullptr, int p_idx = -1); TreeItem *get_root(); TreeItem *get_last_item(); diff --git a/scene/main/canvas_item.cpp b/scene/main/canvas_item.cpp index b5d54b2199..e91826d44b 100644 --- a/scene/main/canvas_item.cpp +++ b/scene/main/canvas_item.cpp @@ -95,18 +95,35 @@ void CanvasItemMaterial::_update_shader() { String code = "shader_type canvas_item;\nrender_mode "; switch (blend_mode) { - case BLEND_MODE_MIX: code += "blend_mix"; break; - case BLEND_MODE_ADD: code += "blend_add"; break; - case BLEND_MODE_SUB: code += "blend_sub"; break; - case BLEND_MODE_MUL: code += "blend_mul"; break; - case BLEND_MODE_PREMULT_ALPHA: code += "blend_premul_alpha"; break; - case BLEND_MODE_DISABLED: code += "blend_disabled"; break; + case BLEND_MODE_MIX: + code += "blend_mix"; + break; + case BLEND_MODE_ADD: + code += "blend_add"; + break; + case BLEND_MODE_SUB: + code += "blend_sub"; + break; + case BLEND_MODE_MUL: + code += "blend_mul"; + break; + case BLEND_MODE_PREMULT_ALPHA: + code += "blend_premul_alpha"; + break; + case BLEND_MODE_DISABLED: + code += "blend_disabled"; + break; } switch (light_mode) { - case LIGHT_MODE_NORMAL: break; - case LIGHT_MODE_UNSHADED: code += ",unshaded"; break; - case LIGHT_MODE_LIGHT_ONLY: code += ",light_only"; break; + case LIGHT_MODE_NORMAL: + break; + case LIGHT_MODE_UNSHADED: + code += ",unshaded"; + break; + case LIGHT_MODE_LIGHT_ONLY: + code += ",light_only"; + break; } code += ";\n"; @@ -1372,10 +1389,18 @@ void CanvasItem::_update_texture_filter_changed(bool p_propagate) { } else { //from viewport switch (get_viewport()->get_default_canvas_item_texture_filter()) { - case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_FILTER_NEAREST: texture_filter_cache = RS::CANVAS_ITEM_TEXTURE_FILTER_NEAREST; break; - case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_FILTER_LINEAR: texture_filter_cache = RS::CANVAS_ITEM_TEXTURE_FILTER_LINEAR; break; - case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_FILTER_LINEAR_WITH_MIPMAPS: texture_filter_cache = RS::CANVAS_ITEM_TEXTURE_FILTER_LINEAR_WITH_MIPMAPS; break; - case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_FILTER_NEAREST_WITH_MIPMAPS: texture_filter_cache = RS::CANVAS_ITEM_TEXTURE_FILTER_NEAREST_WITH_MIPMAPS; break; + case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_FILTER_NEAREST: + texture_filter_cache = RS::CANVAS_ITEM_TEXTURE_FILTER_NEAREST; + break; + case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_FILTER_LINEAR: + texture_filter_cache = RS::CANVAS_ITEM_TEXTURE_FILTER_LINEAR; + break; + case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_FILTER_LINEAR_WITH_MIPMAPS: + texture_filter_cache = RS::CANVAS_ITEM_TEXTURE_FILTER_LINEAR_WITH_MIPMAPS; + break; + case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_FILTER_NEAREST_WITH_MIPMAPS: + texture_filter_cache = RS::CANVAS_ITEM_TEXTURE_FILTER_NEAREST_WITH_MIPMAPS; + break; default: { } } @@ -1421,9 +1446,15 @@ void CanvasItem::_update_texture_repeat_changed(bool p_propagate) { } else { //from viewport switch (get_viewport()->get_default_canvas_item_texture_repeat()) { - case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_REPEAT_DISABLED: texture_repeat_cache = RS::CANVAS_ITEM_TEXTURE_REPEAT_DISABLED; break; - case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_REPEAT_ENABLED: texture_repeat_cache = RS::CANVAS_ITEM_TEXTURE_REPEAT_ENABLED; break; - case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_REPEAT_MIRROR: texture_repeat_cache = RS::CANVAS_ITEM_TEXTURE_REPEAT_MIRROR; break; + case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_REPEAT_DISABLED: + texture_repeat_cache = RS::CANVAS_ITEM_TEXTURE_REPEAT_DISABLED; + break; + case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_REPEAT_ENABLED: + texture_repeat_cache = RS::CANVAS_ITEM_TEXTURE_REPEAT_ENABLED; + break; + case Viewport::DEFAULT_CANVAS_ITEM_TEXTURE_REPEAT_MIRROR: + texture_repeat_cache = RS::CANVAS_ITEM_TEXTURE_REPEAT_MIRROR; + break; default: { } } diff --git a/scene/main/canvas_item.h b/scene/main/canvas_item.h index dc17c5283b..805659376a 100644 --- a/scene/main/canvas_item.h +++ b/scene/main/canvas_item.h @@ -247,9 +247,11 @@ private: protected: _FORCE_INLINE_ void _notify_transform() { - if (!is_inside_tree()) return; + if (!is_inside_tree()) + return; _notify_transform(this); - if (!block_transform_notify && notify_local_transform) notification(NOTIFICATION_LOCAL_TRANSFORM_CHANGED); + if (!block_transform_notify && notify_local_transform) + notification(NOTIFICATION_LOCAL_TRANSFORM_CHANGED); } void item_rect_changed(bool p_size_changed = true); @@ -275,7 +277,7 @@ public: virtual bool _edit_is_selected_on_click(const Point2 &p_point, double p_tolerance) const; // Save and restore a CanvasItem state - virtual void _edit_set_state(const Dictionary &p_state){}; + virtual void _edit_set_state(const Dictionary &p_state) {} virtual Dictionary _edit_get_state() const { return Dictionary(); }; // Used to move the node @@ -288,18 +290,18 @@ public: // Used to rotate the node virtual bool _edit_use_rotation() const { return false; }; - virtual void _edit_set_rotation(float p_rotation){}; + virtual void _edit_set_rotation(float p_rotation) {} virtual float _edit_get_rotation() const { return 0.0; }; // Used to resize/move the node virtual bool _edit_use_rect() const { return false; }; // MAYBE REPLACE BY A _edit_get_editmode() - virtual void _edit_set_rect(const Rect2 &p_rect){}; + virtual void _edit_set_rect(const Rect2 &p_rect) {} virtual Rect2 _edit_get_rect() const { return Rect2(0, 0, 0, 0); }; virtual Size2 _edit_get_minimum_size() const { return Size2(-1, -1); }; // LOOKS WEIRD // Used to set a pivot virtual bool _edit_use_pivot() const { return false; }; - virtual void _edit_set_pivot(const Point2 &p_pivot){}; + virtual void _edit_set_pivot(const Point2 &p_pivot) {} virtual Point2 _edit_get_pivot() const { return Point2(); }; virtual Transform2D _edit_get_transform() const; diff --git a/scene/main/node.cpp b/scene/main/node.cpp index 4c02a15531..3d56b51e26 100644 --- a/scene/main/node.cpp +++ b/scene/main/node.cpp @@ -819,10 +819,14 @@ String Node::get_rpc_md5() const { bool Node::can_process_notification(int p_what) const { switch (p_what) { - case NOTIFICATION_PHYSICS_PROCESS: return data.physics_process; - case NOTIFICATION_PROCESS: return data.idle_process; - case NOTIFICATION_INTERNAL_PROCESS: return data.idle_process_internal; - case NOTIFICATION_INTERNAL_PHYSICS_PROCESS: return data.physics_process_internal; + case NOTIFICATION_PHYSICS_PROCESS: + return data.physics_process; + case NOTIFICATION_PROCESS: + return data.idle_process; + case NOTIFICATION_INTERNAL_PROCESS: + return data.idle_process_internal; + case NOTIFICATION_INTERNAL_PHYSICS_PROCESS: + return data.physics_process_internal; } return true; @@ -2994,10 +2998,14 @@ void Node::_bind_methods() { String Node::_get_name_num_separator() { switch (ProjectSettings::get_singleton()->get("node/name_num_separator").operator int()) { - case 0: return ""; - case 1: return " "; - case 2: return "_"; - case 3: return "-"; + case 0: + return ""; + case 1: + return " "; + case 2: + return "_"; + case 3: + return "-"; } return " "; } diff --git a/scene/main/scene_tree.cpp b/scene/main/scene_tree.cpp index 7afc1b2edd..94be22ccd2 100644 --- a/scene/main/scene_tree.cpp +++ b/scene/main/scene_tree.cpp @@ -1372,7 +1372,8 @@ void SceneTree::get_argument_options(const StringName &p_function, int p_idx, Li SceneTree::SceneTree() { - if (singleton == nullptr) singleton = this; + if (singleton == nullptr) + singleton = this; _quit = false; accept_quit = true; quit_on_go_back = true; @@ -1478,5 +1479,6 @@ SceneTree::~SceneTree() { memdelete(root); } - if (singleton == this) singleton = nullptr; + if (singleton == this) + singleton = nullptr; } diff --git a/scene/main/shader_globals_override.cpp b/scene/main/shader_globals_override.cpp index 13582cf655..272cea8f2b 100644 --- a/scene/main/shader_globals_override.cpp +++ b/scene/main/shader_globals_override.cpp @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* shader_globals_override.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #include "shader_globals_override.h" #include "core/core_string_names.h" @@ -176,7 +206,7 @@ void ShaderGlobalsOverride::_get_property_list(List<PropertyInfo> *p_list) const Override o; o.in_use = false; Callable::CallError ce; - o.override = Variant::construct(pinfo.type, NULL, 0, ce); + o.override = Variant::construct(pinfo.type, nullptr, 0, ce); overrides[variables[i]] = o; } diff --git a/scene/main/shader_globals_override.h b/scene/main/shader_globals_override.h index 33d0dc948f..d470e6a7dc 100644 --- a/scene/main/shader_globals_override.h +++ b/scene/main/shader_globals_override.h @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* shader_globals_override.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #ifndef SHADER_GLOBALS_OVERRIDE_H #define SHADER_GLOBALS_OVERRIDE_H diff --git a/scene/main/timer.cpp b/scene/main/timer.cpp index 7c847095e1..7cab4028b8 100644 --- a/scene/main/timer.cpp +++ b/scene/main/timer.cpp @@ -173,8 +173,12 @@ Timer::TimerProcessMode Timer::get_timer_process_mode() const { void Timer::_set_process(bool p_process, bool p_force) { switch (timer_process_mode) { - case TIMER_PROCESS_PHYSICS: set_physics_process_internal(p_process && !paused); break; - case TIMER_PROCESS_IDLE: set_process_internal(p_process && !paused); break; + case TIMER_PROCESS_PHYSICS: + set_physics_process_internal(p_process && !paused); + break; + case TIMER_PROCESS_IDLE: + set_process_internal(p_process && !paused); + break; } processing = p_process; } diff --git a/scene/main/viewport.cpp b/scene/main/viewport.cpp index e9827d521b..58024dab38 100644 --- a/scene/main/viewport.cpp +++ b/scene/main/viewport.cpp @@ -173,7 +173,7 @@ class TooltipPanel : public PopupPanel { GDCLASS(TooltipPanel, PopupPanel); public: - TooltipPanel(){}; + TooltipPanel() {} }; class TooltipLabel : public Label { @@ -181,7 +181,7 @@ class TooltipLabel : public Label { GDCLASS(TooltipLabel, Label); public: - TooltipLabel(){}; + TooltipLabel() {} }; Viewport::GUI::GUI() { diff --git a/scene/register_scene_types.cpp b/scene/register_scene_types.cpp index dc3ef5b508..24f92bdc20 100644 --- a/scene/register_scene_types.cpp +++ b/scene/register_scene_types.cpp @@ -193,6 +193,7 @@ #include "scene/3d/gpu_particles_3d.h" #include "scene/3d/immediate_geometry_3d.h" #include "scene/3d/light_3d.h" +#include "scene/3d/lightmap_probe.h" #include "scene/3d/listener_3d.h" #include "scene/3d/mesh_instance_3d.h" #include "scene/3d/multimesh_instance_3d.h" @@ -225,8 +226,8 @@ static Ref<ResourceFormatLoaderText> resource_loader_text; static Ref<ResourceFormatLoaderDynamicFont> resource_loader_dynamic_font; -static Ref<ResourceFormatLoaderStreamTexture> resource_loader_stream_texture; -static Ref<ResourceFormatLoaderTextureLayered> resource_loader_texture_layered; +static Ref<ResourceFormatLoaderStreamTexture2D> resource_loader_stream_texture; +static Ref<ResourceFormatLoaderStreamTextureLayered> resource_loader_texture_layered; static Ref<ResourceFormatLoaderBMFont> resource_loader_bmfont; @@ -432,8 +433,10 @@ void register_scene_types() { ClassDB::register_class<Decal>(); ClassDB::register_class<GIProbe>(); ClassDB::register_class<GIProbeData>(); - //ClassDB::register_class<BakedLightmap>(); - //ClassDB::register_class<BakedLightmapData>(); + ClassDB::register_class<BakedLightmap>(); + ClassDB::register_class<BakedLightmapData>(); + ClassDB::register_class<LightmapProbe>(); + ClassDB::register_virtual_class<Lightmapper>(); ClassDB::register_class<GPUParticles3D>(); ClassDB::register_class<CPUParticles3D>(); ClassDB::register_class<Position3D>(); @@ -675,7 +678,7 @@ void register_scene_types() { ClassDB::register_virtual_class<Texture>(); ClassDB::register_virtual_class<Texture2D>(); ClassDB::register_class<Sky>(); - ClassDB::register_class<StreamTexture>(); + ClassDB::register_class<StreamTexture2D>(); ClassDB::register_class<ImageTexture>(); ClassDB::register_class<AtlasTexture>(); ClassDB::register_class<MeshTexture>(); @@ -686,9 +689,15 @@ void register_scene_types() { ClassDB::register_class<AnimatedTexture>(); ClassDB::register_class<CameraTexture>(); ClassDB::register_virtual_class<TextureLayered>(); + ClassDB::register_virtual_class<ImageTextureLayered>(); ClassDB::register_class<Cubemap>(); ClassDB::register_class<CubemapArray>(); ClassDB::register_class<Texture2DArray>(); + ClassDB::register_virtual_class<StreamTextureLayered>(); + ClassDB::register_class<StreamCubemap>(); + ClassDB::register_class<StreamCubemapArray>(); + ClassDB::register_class<StreamTexture2DArray>(); + ClassDB::register_class<Animation>(); ClassDB::register_virtual_class<Font>(); ClassDB::register_class<BitmapFont>(); @@ -867,6 +876,7 @@ void register_scene_types() { ClassDB::add_compatibility_class("VisualShaderNodeScalarOp", "VisualShaderNodeFloatOp"); ClassDB::add_compatibility_class("VisualShaderNodeScalarUniform", "VisualShaderNodeFloatUniform"); ClassDB::add_compatibility_class("World", "World3D"); + ClassDB::add_compatibility_class("StreamTexture", "StreamTexture2D"); #endif diff --git a/scene/resources/animation.cpp b/scene/resources/animation.cpp index ea4338519e..c5806ee7b3 100644 --- a/scene/resources/animation.cpp +++ b/scene/resources/animation.cpp @@ -349,12 +349,24 @@ bool Animation::_get(const StringName &p_name, Variant &r_ret) const { switch (track_get_type(track)) { - case TYPE_TRANSFORM: r_ret = "transform"; break; - case TYPE_VALUE: r_ret = "value"; break; - case TYPE_METHOD: r_ret = "method"; break; - case TYPE_BEZIER: r_ret = "bezier"; break; - case TYPE_AUDIO: r_ret = "audio"; break; - case TYPE_ANIMATION: r_ret = "animation"; break; + case TYPE_TRANSFORM: + r_ret = "transform"; + break; + case TYPE_VALUE: + r_ret = "value"; + break; + case TYPE_METHOD: + r_ret = "method"; + break; + case TYPE_BEZIER: + r_ret = "bezier"; + break; + case TYPE_AUDIO: + r_ret = "audio"; + break; + case TYPE_ANIMATION: + r_ret = "animation"; + break; } return true; @@ -1842,7 +1854,8 @@ T Animation::_interpolate(const Vector<TKey<T>> &p_keys, float p_time, Interpola return _cubic_interpolate(p_keys[pre].value, p_keys[idx].value, p_keys[next].value, p_keys[post].value, c); } break; - default: return p_keys[idx].value; + default: + return p_keys[idx].value; } // do a barrel roll diff --git a/scene/resources/audio_stream_sample.cpp b/scene/resources/audio_stream_sample.cpp index d630a1f3ee..fdf5e2c2d0 100644 --- a/scene/resources/audio_stream_sample.cpp +++ b/scene/resources/audio_stream_sample.cpp @@ -230,9 +230,15 @@ void AudioStreamPlaybackSample::mix(AudioFrame *p_buffer, float p_rate_scale, in int len = base->data_bytes; switch (base->format) { - case AudioStreamSample::FORMAT_8_BITS: len /= 1; break; - case AudioStreamSample::FORMAT_16_BITS: len /= 2; break; - case AudioStreamSample::FORMAT_IMA_ADPCM: len *= 2; break; + case AudioStreamSample::FORMAT_8_BITS: + len /= 1; + break; + case AudioStreamSample::FORMAT_16_BITS: + len /= 2; + break; + case AudioStreamSample::FORMAT_IMA_ADPCM: + len *= 2; + break; } if (base->stereo) { @@ -465,9 +471,15 @@ float AudioStreamSample::get_length() const { int len = data_bytes; switch (format) { - case AudioStreamSample::FORMAT_8_BITS: len /= 1; break; - case AudioStreamSample::FORMAT_16_BITS: len /= 2; break; - case AudioStreamSample::FORMAT_IMA_ADPCM: len *= 2; break; + case AudioStreamSample::FORMAT_8_BITS: + len /= 1; + break; + case AudioStreamSample::FORMAT_16_BITS: + len /= 2; + break; + case AudioStreamSample::FORMAT_IMA_ADPCM: + len *= 2; + break; } if (stereo) { @@ -536,9 +548,15 @@ Error AudioStreamSample::save_to_wav(const String &p_path) { int byte_pr_sample = 0; switch (format) { - case AudioStreamSample::FORMAT_8_BITS: byte_pr_sample = 1; break; - case AudioStreamSample::FORMAT_16_BITS: byte_pr_sample = 2; break; - case AudioStreamSample::FORMAT_IMA_ADPCM: byte_pr_sample = 4; break; + case AudioStreamSample::FORMAT_8_BITS: + byte_pr_sample = 1; + break; + case AudioStreamSample::FORMAT_16_BITS: + byte_pr_sample = 2; + break; + case AudioStreamSample::FORMAT_IMA_ADPCM: + byte_pr_sample = 4; + break; } String file_path = p_path; diff --git a/scene/resources/capsule_shape_2d.cpp b/scene/resources/capsule_shape_2d.cpp index ab2657c892..5a3282478c 100644 --- a/scene/resources/capsule_shape_2d.cpp +++ b/scene/resources/capsule_shape_2d.cpp @@ -72,6 +72,9 @@ real_t CapsuleShape2D::get_radius() const { void CapsuleShape2D::set_height(real_t p_height) { height = p_height; + if (height < 0) + height = 0; + _update_shape(); } diff --git a/scene/resources/environment.cpp b/scene/resources/environment.cpp index 02ea5b24b8..abbe579307 100644 --- a/scene/resources/environment.cpp +++ b/scene/resources/environment.cpp @@ -164,7 +164,7 @@ float Environment::get_ambient_light_sky_contribution() const { return ambient_sky_contribution; } -int Environment::get_camera_feed_id(void) const { +int Environment::get_camera_feed_id() const { return camera_feed_id; } @@ -1117,11 +1117,7 @@ void Environment::_bind_methods() { BIND_ENUM_CONSTANT(SSAO_BLUR_3x3); } -Environment::Environment() : - bg_mode(BG_CLEAR_COLOR), - tone_mapper(TONE_MAPPER_LINEAR), - ssao_blur(SSAO_BLUR_3x3), - glow_blend_mode(GLOW_BLEND_MODE_ADDITIVE) { +Environment::Environment() { environment = RS::get_singleton()->environment_create(); diff --git a/scene/resources/environment.h b/scene/resources/environment.h index 02434bc592..f0ea8489ff 100644 --- a/scene/resources/environment.h +++ b/scene/resources/environment.h @@ -90,7 +90,7 @@ public: private: RID environment; - BGMode bg_mode; + BGMode bg_mode = BG_CLEAR_COLOR; Ref<Sky> bg_sky; float bg_sky_custom_fov; Vector3 sky_rotation; @@ -105,7 +105,7 @@ private: AmbientSource ambient_source; ReflectionSource reflection_source; - ToneMapper tone_mapper; + ToneMapper tone_mapper = TONE_MAPPER_LINEAR; float tonemap_exposure; float tonemap_white; bool tonemap_auto_exposure; @@ -132,7 +132,7 @@ private: float ssao_bias; float ssao_direct_light_affect; float ssao_ao_channel_affect; - SSAOBlur ssao_blur; + SSAOBlur ssao_blur = SSAO_BLUR_3x3; float ssao_edge_sharpness; bool glow_enabled; @@ -141,7 +141,7 @@ private: float glow_strength; float glow_mix; float glow_bloom; - GlowBlendMode glow_blend_mode; + GlowBlendMode glow_blend_mode = GLOW_BLEND_MODE_ADDITIVE; float glow_hdr_bleed_threshold; float glow_hdr_bleed_scale; float glow_hdr_luminance_cap; @@ -200,7 +200,7 @@ public: Color get_ambient_light_color() const; float get_ambient_light_energy() const; float get_ambient_light_sky_contribution() const; - int get_camera_feed_id(void) const; + int get_camera_feed_id() const; void set_tonemapper(ToneMapper p_tone_mapper); ToneMapper get_tonemapper() const; diff --git a/scene/resources/material.cpp b/scene/resources/material.cpp index fd8cff7cd0..137657d239 100644 --- a/scene/resources/material.cpp +++ b/scene/resources/material.cpp @@ -414,13 +414,26 @@ void BaseMaterial3D::_update_shader() { String texfilter_str; switch (texture_filter) { - case TEXTURE_FILTER_NEAREST: texfilter_str = "filter_nearest"; break; - case TEXTURE_FILTER_LINEAR: texfilter_str = "filter_linear"; break; - case TEXTURE_FILTER_NEAREST_WITH_MIPMAPS: texfilter_str = "filter_nearest_mipmap"; break; - case TEXTURE_FILTER_LINEAR_WITH_MIPMAPS: texfilter_str = "filter_linear_mipmap"; break; - case TEXTURE_FILTER_NEAREST_WITH_MIPMAPS_ANISOTROPIC: texfilter_str = "filter_nearest_mipmap_aniso"; break; - case TEXTURE_FILTER_LINEAR_WITH_MIPMAPS_ANISOTROPIC: texfilter_str = "filter_linear_mipmap_aniso"; break; - case TEXTURE_FILTER_MAX: break; // Internal value, skip. + case TEXTURE_FILTER_NEAREST: + texfilter_str = "filter_nearest"; + break; + case TEXTURE_FILTER_LINEAR: + texfilter_str = "filter_linear"; + break; + case TEXTURE_FILTER_NEAREST_WITH_MIPMAPS: + texfilter_str = "filter_nearest_mipmap"; + break; + case TEXTURE_FILTER_LINEAR_WITH_MIPMAPS: + texfilter_str = "filter_linear_mipmap"; + break; + case TEXTURE_FILTER_NEAREST_WITH_MIPMAPS_ANISOTROPIC: + texfilter_str = "filter_nearest_mipmap_aniso"; + break; + case TEXTURE_FILTER_LINEAR_WITH_MIPMAPS_ANISOTROPIC: + texfilter_str = "filter_linear_mipmap_aniso"; + break; + case TEXTURE_FILTER_MAX: + break; // Internal value, skip. } if (flags[FLAG_USE_TEXTURE_REPEAT]) { @@ -433,10 +446,18 @@ void BaseMaterial3D::_update_shader() { String code = "shader_type spatial;\nrender_mode "; switch (blend_mode) { - case BLEND_MODE_MIX: code += "blend_mix"; break; - case BLEND_MODE_ADD: code += "blend_add"; break; - case BLEND_MODE_SUB: code += "blend_sub"; break; - case BLEND_MODE_MUL: code += "blend_mul"; break; + case BLEND_MODE_MIX: + code += "blend_mix"; + break; + case BLEND_MODE_ADD: + code += "blend_add"; + break; + case BLEND_MODE_SUB: + code += "blend_sub"; + break; + case BLEND_MODE_MUL: + code += "blend_mul"; + break; } DepthDrawMode ddm = depth_draw_mode; @@ -445,9 +466,15 @@ void BaseMaterial3D::_update_shader() { } switch (ddm) { - case DEPTH_DRAW_OPAQUE_ONLY: code += ",depth_draw_opaque"; break; - case DEPTH_DRAW_ALWAYS: code += ",depth_draw_always"; break; - case DEPTH_DRAW_DISABLED: code += ",depth_draw_never"; break; + case DEPTH_DRAW_OPAQUE_ONLY: + code += ",depth_draw_opaque"; + break; + case DEPTH_DRAW_ALWAYS: + code += ",depth_draw_always"; + break; + case DEPTH_DRAW_DISABLED: + code += ",depth_draw_never"; + break; } if (transparency == TRANSPARENCY_ALPHA_DEPTH_PRE_PASS) { @@ -455,23 +482,49 @@ void BaseMaterial3D::_update_shader() { } switch (cull_mode) { - case CULL_BACK: code += ",cull_back"; break; - case CULL_FRONT: code += ",cull_front"; break; - case CULL_DISABLED: code += ",cull_disabled"; break; + case CULL_BACK: + code += ",cull_back"; + break; + case CULL_FRONT: + code += ",cull_front"; + break; + case CULL_DISABLED: + code += ",cull_disabled"; + break; } switch (diffuse_mode) { - case DIFFUSE_BURLEY: code += ",diffuse_burley"; break; - case DIFFUSE_LAMBERT: code += ",diffuse_lambert"; break; - case DIFFUSE_LAMBERT_WRAP: code += ",diffuse_lambert_wrap"; break; - case DIFFUSE_OREN_NAYAR: code += ",diffuse_oren_nayar"; break; - case DIFFUSE_TOON: code += ",diffuse_toon"; break; + case DIFFUSE_BURLEY: + code += ",diffuse_burley"; + break; + case DIFFUSE_LAMBERT: + code += ",diffuse_lambert"; + break; + case DIFFUSE_LAMBERT_WRAP: + code += ",diffuse_lambert_wrap"; + break; + case DIFFUSE_OREN_NAYAR: + code += ",diffuse_oren_nayar"; + break; + case DIFFUSE_TOON: + code += ",diffuse_toon"; + break; } switch (specular_mode) { - case SPECULAR_SCHLICK_GGX: code += ",specular_schlick_ggx"; break; - case SPECULAR_BLINN: code += ",specular_blinn"; break; - case SPECULAR_PHONG: code += ",specular_phong"; break; - case SPECULAR_TOON: code += ",specular_toon"; break; - case SPECULAR_DISABLED: code += ",specular_disabled"; break; + case SPECULAR_SCHLICK_GGX: + code += ",specular_schlick_ggx"; + break; + case SPECULAR_BLINN: + code += ",specular_blinn"; + break; + case SPECULAR_PHONG: + code += ",specular_phong"; + break; + case SPECULAR_TOON: + code += ",specular_toon"; + break; + case SPECULAR_DISABLED: + code += ",specular_disabled"; + break; } if (features[FEATURE_SUBSURFACE_SCATTERING] && flags[FLAG_SUBSURFACE_MODE_SKIN]) { code += ",sss_mode_skin"; @@ -2637,7 +2690,7 @@ BaseMaterial3D::BaseMaterial3D(bool p_orm) : depth_draw_mode = DEPTH_DRAW_OPAQUE_ONLY; cull_mode = CULL_BACK; for (int i = 0; i < FLAG_MAX; i++) { - flags[i] = 0; + flags[i] = false; } flags[FLAG_USE_TEXTURE_REPEAT] = true; diff --git a/scene/resources/mesh.cpp b/scene/resources/mesh.cpp index 401b689145..6548c65cd7 100644 --- a/scene/resources/mesh.cpp +++ b/scene/resources/mesh.cpp @@ -472,11 +472,11 @@ Ref<Mesh> Mesh::create_outline(float p_margin) const { return newmesh; } -void Mesh::set_lightmap_size_hint(const Vector2 &p_size) { +void Mesh::set_lightmap_size_hint(const Size2i &p_size) { lightmap_size_hint = p_size; } -Size2 Mesh::get_lightmap_size_hint() const { +Size2i Mesh::get_lightmap_size_hint() const { return lightmap_size_hint; } @@ -486,7 +486,7 @@ void Mesh::_bind_methods() { ClassDB::bind_method(D_METHOD("get_lightmap_size_hint"), &Mesh::get_lightmap_size_hint); ClassDB::bind_method(D_METHOD("get_aabb"), &Mesh::get_aabb); - ADD_PROPERTY(PropertyInfo(Variant::VECTOR2, "lightmap_size_hint"), "set_lightmap_size_hint", "get_lightmap_size_hint"); + ADD_PROPERTY(PropertyInfo(Variant::VECTOR2I, "lightmap_size_hint"), "set_lightmap_size_hint", "get_lightmap_size_hint"); ClassDB::bind_method(D_METHOD("get_surface_count"), &Mesh::get_surface_count); ClassDB::bind_method(D_METHOD("surface_get_arrays", "surf_idx"), &Mesh::surface_get_arrays); diff --git a/scene/resources/mesh.h b/scene/resources/mesh.h index a65cf0a928..80cd57846b 100644 --- a/scene/resources/mesh.h +++ b/scene/resources/mesh.h @@ -43,7 +43,7 @@ class Mesh : public Resource { mutable Ref<TriangleMesh> triangle_mesh; //cached mutable Vector<Vector3> debug_lines; - Size2 lightmap_size_hint; + Size2i lightmap_size_hint; protected: static void _bind_methods(); @@ -138,8 +138,8 @@ public: virtual AABB get_aabb() const = 0; - void set_lightmap_size_hint(const Vector2 &p_size); - Size2 get_lightmap_size_hint() const; + void set_lightmap_size_hint(const Size2i &p_size); + Size2i get_lightmap_size_hint() const; void clear_cache() const; typedef Vector<Vector<Face3>> (*ConvexDecompositionFunc)(const Vector<Face3> &); diff --git a/scene/resources/particles_material.cpp b/scene/resources/particles_material.cpp index 83430aef9e..1a28e2586d 100644 --- a/scene/resources/particles_material.cpp +++ b/scene/resources/particles_material.cpp @@ -703,7 +703,8 @@ void ParticlesMaterial::set_param(Parameter p_param, float p_value) { case PARAM_ANIM_OFFSET: { RenderingServer::get_singleton()->material_set_param(_get_material(), shader_names->anim_offset, p_value); } break; - case PARAM_MAX: break; // Can't happen, but silences warning + case PARAM_MAX: + break; // Can't happen, but silences warning } } float ParticlesMaterial::get_param(Parameter p_param) const { @@ -756,7 +757,8 @@ void ParticlesMaterial::set_param_randomness(Parameter p_param, float p_value) { case PARAM_ANIM_OFFSET: { RenderingServer::get_singleton()->material_set_param(_get_material(), shader_names->anim_offset_random, p_value); } break; - case PARAM_MAX: break; // Can't happen, but silences warning + case PARAM_MAX: + break; // Can't happen, but silences warning } } float ParticlesMaterial::get_param_randomness(Parameter p_param) const { @@ -828,7 +830,8 @@ void ParticlesMaterial::set_param_texture(Parameter p_param, const Ref<Texture2D case PARAM_ANIM_OFFSET: { RenderingServer::get_singleton()->material_set_param(_get_material(), shader_names->anim_offset_texture, p_texture); } break; - case PARAM_MAX: break; // Can't happen, but silences warning + case PARAM_MAX: + break; // Can't happen, but silences warning } _queue_shader_change(); diff --git a/scene/resources/physics_material.cpp b/scene/resources/physics_material.cpp index 8ac0191452..89fbf7a248 100644 --- a/scene/resources/physics_material.cpp +++ b/scene/resources/physics_material.cpp @@ -69,9 +69,3 @@ void PhysicsMaterial::set_absorbent(bool p_val) { absorbent = p_val; emit_changed(); } - -PhysicsMaterial::PhysicsMaterial() : - friction(1), - rough(false), - bounce(0), - absorbent(false) {} diff --git a/scene/resources/physics_material.h b/scene/resources/physics_material.h index f4a77d9854..1fac3fa6ef 100644 --- a/scene/resources/physics_material.h +++ b/scene/resources/physics_material.h @@ -28,8 +28,8 @@ /* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /*************************************************************************/ -#ifndef physics_material_override_H -#define physics_material_override_H +#ifndef PHYSICS_MATERIAL_H +#define PHYSICS_MATERIAL_H #include "core/resource.h" #include "servers/physics_server_3d.h" @@ -40,10 +40,10 @@ class PhysicsMaterial : public Resource { OBJ_SAVE_TYPE(PhysicsMaterial); RES_BASE_EXTENSION("phymat"); - real_t friction; - bool rough; - real_t bounce; - bool absorbent; + real_t friction = 1; + bool rough = false; + real_t bounce = 0; + bool absorbent = false; protected: static void _bind_methods(); @@ -69,7 +69,7 @@ public: return absorbent ? -bounce : bounce; } - PhysicsMaterial(); + PhysicsMaterial() {} }; -#endif // physics_material_override_H +#endif // PHYSICS_MATERIAL_H diff --git a/scene/resources/shape_3d.cpp b/scene/resources/shape_3d.cpp index f4a5d91e52..97d03b3cfc 100644 --- a/scene/resources/shape_3d.cpp +++ b/scene/resources/shape_3d.cpp @@ -109,19 +109,13 @@ void Shape3D::_bind_methods() { ADD_PROPERTY(PropertyInfo(Variant::FLOAT, "margin", PROPERTY_HINT_RANGE, "0.001,10,0.001"), "set_margin", "get_margin"); } -Shape3D::Shape3D() : - margin(0.04) { - - ERR_PRINT("Constructor must not be called!"); +Shape3D::Shape3D() { + ERR_PRINT("Default constructor must not be called!"); } Shape3D::Shape3D(RID p_shape) : - margin(0.04) { - - shape = p_shape; -} + shape(p_shape) {} Shape3D::~Shape3D() { - PhysicsServer3D::get_singleton()->free(shape); } diff --git a/scene/resources/shape_3d.h b/scene/resources/shape_3d.h index e7a516412d..6516868fd5 100644 --- a/scene/resources/shape_3d.h +++ b/scene/resources/shape_3d.h @@ -41,7 +41,7 @@ class Shape3D : public Resource { OBJ_SAVE_TYPE(Shape3D); RES_BASE_EXTENSION("shape"); RID shape; - real_t margin; + real_t margin = 0.04; Ref<ArrayMesh> debug_mesh_cache; diff --git a/scene/resources/texture.cpp b/scene/resources/texture.cpp index d57af29599..b212bba826 100644 --- a/scene/resources/texture.cpp +++ b/scene/resources/texture.cpp @@ -355,7 +355,7 @@ ImageTexture::~ImageTexture() { ////////////////////////////////////////// -Ref<Image> StreamTexture::load_image_from_file(FileAccess *f, int p_size_limit) { +Ref<Image> StreamTexture2D::load_image_from_file(FileAccess *f, int p_size_limit) { uint32_t data_format = f->get_32(); uint32_t w = f->get_16(); @@ -492,7 +492,7 @@ Ref<Image> StreamTexture::load_image_from_file(FileAccess *f, int p_size_limit) return Ref<Image>(); } -void StreamTexture::set_path(const String &p_path, bool p_take_over) { +void StreamTexture2D::set_path(const String &p_path, bool p_take_over) { if (texture.is_valid()) { RenderingServer::get_singleton()->texture_set_path(texture, p_path); @@ -501,40 +501,40 @@ void StreamTexture::set_path(const String &p_path, bool p_take_over) { Resource::set_path(p_path, p_take_over); } -void StreamTexture::_requested_3d(void *p_ud) { +void StreamTexture2D::_requested_3d(void *p_ud) { - StreamTexture *st = (StreamTexture *)p_ud; - Ref<StreamTexture> stex(st); + StreamTexture2D *st = (StreamTexture2D *)p_ud; + Ref<StreamTexture2D> stex(st); ERR_FAIL_COND(!request_3d_callback); request_3d_callback(stex); } -void StreamTexture::_requested_roughness(void *p_ud, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_roughness_channel) { +void StreamTexture2D::_requested_roughness(void *p_ud, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_roughness_channel) { - StreamTexture *st = (StreamTexture *)p_ud; - Ref<StreamTexture> stex(st); + StreamTexture2D *st = (StreamTexture2D *)p_ud; + Ref<StreamTexture2D> stex(st); ERR_FAIL_COND(!request_roughness_callback); request_roughness_callback(stex, p_normal_path, p_roughness_channel); } -void StreamTexture::_requested_normal(void *p_ud) { +void StreamTexture2D::_requested_normal(void *p_ud) { - StreamTexture *st = (StreamTexture *)p_ud; - Ref<StreamTexture> stex(st); + StreamTexture2D *st = (StreamTexture2D *)p_ud; + Ref<StreamTexture2D> stex(st); ERR_FAIL_COND(!request_normal_callback); request_normal_callback(stex); } -StreamTexture::TextureFormatRequestCallback StreamTexture::request_3d_callback = nullptr; -StreamTexture::TextureFormatRoughnessRequestCallback StreamTexture::request_roughness_callback = nullptr; -StreamTexture::TextureFormatRequestCallback StreamTexture::request_normal_callback = nullptr; +StreamTexture2D::TextureFormatRequestCallback StreamTexture2D::request_3d_callback = nullptr; +StreamTexture2D::TextureFormatRoughnessRequestCallback StreamTexture2D::request_roughness_callback = nullptr; +StreamTexture2D::TextureFormatRequestCallback StreamTexture2D::request_normal_callback = nullptr; -Image::Format StreamTexture::get_format() const { +Image::Format StreamTexture2D::get_format() const { return format; } -Error StreamTexture::_load_data(const String &p_path, int &tw, int &th, int &tw_custom, int &th_custom, Ref<Image> &image, bool &r_request_3d, bool &r_request_normal, bool &r_request_roughness, int &mipmap_limit, int p_size_limit) { +Error StreamTexture2D::_load_data(const String &p_path, int &tw, int &th, int &tw_custom, int &th_custom, Ref<Image> &image, bool &r_request_3d, bool &r_request_normal, bool &r_request_roughness, int &mipmap_limit, int p_size_limit) { alpha_cache.unref(); @@ -595,7 +595,7 @@ Error StreamTexture::_load_data(const String &p_path, int &tw, int &th, int &tw_ return OK; } -Error StreamTexture::load(const String &p_path) { +Error StreamTexture2D::load(const String &p_path) { int lw, lh, lwc, lhc; Ref<Image> image; @@ -661,20 +661,20 @@ Error StreamTexture::load(const String &p_path) { emit_changed(); return OK; } -String StreamTexture::get_load_path() const { +String StreamTexture2D::get_load_path() const { return path_to_file; } -int StreamTexture::get_width() const { +int StreamTexture2D::get_width() const { return w; } -int StreamTexture::get_height() const { +int StreamTexture2D::get_height() const { return h; } -RID StreamTexture::get_rid() const { +RID StreamTexture2D::get_rid() const { if (!texture.is_valid()) { texture = RS::get_singleton()->texture_2d_placeholder_create(); @@ -682,7 +682,7 @@ RID StreamTexture::get_rid() const { return texture; } -void StreamTexture::draw(RID p_canvas_item, const Point2 &p_pos, const Color &p_modulate, bool p_transpose, const Ref<Texture2D> &p_normal_map, const Ref<Texture2D> &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat) const { +void StreamTexture2D::draw(RID p_canvas_item, const Point2 &p_pos, const Color &p_modulate, bool p_transpose, const Ref<Texture2D> &p_normal_map, const Ref<Texture2D> &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat) const { if ((w | h) == 0) return; @@ -690,7 +690,7 @@ void StreamTexture::draw(RID p_canvas_item, const Point2 &p_pos, const Color &p_ RID specular_rid = p_specular_map.is_valid() ? p_specular_map->get_rid() : RID(); RenderingServer::get_singleton()->canvas_item_add_texture_rect(p_canvas_item, Rect2(p_pos, Size2(w, h)), texture, false, p_modulate, p_transpose, normal_rid, specular_rid, p_specular_color_shininess, p_texture_filter, p_texture_repeat); } -void StreamTexture::draw_rect(RID p_canvas_item, const Rect2 &p_rect, bool p_tile, const Color &p_modulate, bool p_transpose, const Ref<Texture2D> &p_normal_map, const Ref<Texture2D> &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat) const { +void StreamTexture2D::draw_rect(RID p_canvas_item, const Rect2 &p_rect, bool p_tile, const Color &p_modulate, bool p_transpose, const Ref<Texture2D> &p_normal_map, const Ref<Texture2D> &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat) const { if ((w | h) == 0) return; @@ -698,7 +698,7 @@ void StreamTexture::draw_rect(RID p_canvas_item, const Rect2 &p_rect, bool p_til RID specular_rid = p_specular_map.is_valid() ? p_specular_map->get_rid() : RID(); RenderingServer::get_singleton()->canvas_item_add_texture_rect(p_canvas_item, p_rect, texture, p_tile, p_modulate, p_transpose, normal_rid, specular_rid, p_specular_color_shininess, p_texture_filter, p_texture_repeat); } -void StreamTexture::draw_rect_region(RID p_canvas_item, const Rect2 &p_rect, const Rect2 &p_src_rect, const Color &p_modulate, bool p_transpose, const Ref<Texture2D> &p_normal_map, const Ref<Texture2D> &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat, bool p_clip_uv) const { +void StreamTexture2D::draw_rect_region(RID p_canvas_item, const Rect2 &p_rect, const Rect2 &p_src_rect, const Color &p_modulate, bool p_transpose, const Ref<Texture2D> &p_normal_map, const Ref<Texture2D> &p_specular_map, const Color &p_specular_color_shininess, RS::CanvasItemTextureFilter p_texture_filter, RS::CanvasItemTextureRepeat p_texture_repeat, bool p_clip_uv) const { if ((w | h) == 0) return; @@ -707,12 +707,12 @@ void StreamTexture::draw_rect_region(RID p_canvas_item, const Rect2 &p_rect, con RenderingServer::get_singleton()->canvas_item_add_texture_rect_region(p_canvas_item, p_rect, texture, p_src_rect, p_modulate, p_transpose, normal_rid, specular_rid, p_specular_color_shininess, p_clip_uv, p_texture_filter, p_texture_repeat); } -bool StreamTexture::has_alpha() const { +bool StreamTexture2D::has_alpha() const { return false; } -Ref<Image> StreamTexture::get_data() const { +Ref<Image> StreamTexture2D::get_data() const { if (texture.is_valid()) { return RS::get_singleton()->texture_2d_get(texture); @@ -721,7 +721,7 @@ Ref<Image> StreamTexture::get_data() const { } } -bool StreamTexture::is_pixel_opaque(int p_x, int p_y) const { +bool StreamTexture2D::is_pixel_opaque(int p_x, int p_y) const { if (!alpha_cache.is_valid()) { Ref<Image> img = get_data(); @@ -757,7 +757,7 @@ bool StreamTexture::is_pixel_opaque(int p_x, int p_y) const { return true; } -void StreamTexture::reload_from_file() { +void StreamTexture2D::reload_from_file() { String path = get_path(); if (!path.is_resource_file()) @@ -771,34 +771,34 @@ void StreamTexture::reload_from_file() { load(path); } -void StreamTexture::_validate_property(PropertyInfo &property) const { +void StreamTexture2D::_validate_property(PropertyInfo &property) const { } -void StreamTexture::_bind_methods() { +void StreamTexture2D::_bind_methods() { - ClassDB::bind_method(D_METHOD("load", "path"), &StreamTexture::load); - ClassDB::bind_method(D_METHOD("get_load_path"), &StreamTexture::get_load_path); + ClassDB::bind_method(D_METHOD("load", "path"), &StreamTexture2D::load); + ClassDB::bind_method(D_METHOD("get_load_path"), &StreamTexture2D::get_load_path); ADD_PROPERTY(PropertyInfo(Variant::STRING, "load_path", PROPERTY_HINT_FILE, "*.stex"), "load", "get_load_path"); } -StreamTexture::StreamTexture() { +StreamTexture2D::StreamTexture2D() { format = Image::FORMAT_MAX; w = 0; h = 0; } -StreamTexture::~StreamTexture() { +StreamTexture2D::~StreamTexture2D() { if (texture.is_valid()) { RS::get_singleton()->free(texture); } } -RES ResourceFormatLoaderStreamTexture::load(const String &p_path, const String &p_original_path, Error *r_error, bool p_use_sub_threads, float *r_progress, bool p_no_cache) { +RES ResourceFormatLoaderStreamTexture2D::load(const String &p_path, const String &p_original_path, Error *r_error, bool p_use_sub_threads, float *r_progress, bool p_no_cache) { - Ref<StreamTexture> st; + Ref<StreamTexture2D> st; st.instance(); Error err = st->load(p_path); if (r_error) @@ -809,17 +809,17 @@ RES ResourceFormatLoaderStreamTexture::load(const String &p_path, const String & return st; } -void ResourceFormatLoaderStreamTexture::get_recognized_extensions(List<String> *p_extensions) const { +void ResourceFormatLoaderStreamTexture2D::get_recognized_extensions(List<String> *p_extensions) const { p_extensions->push_back("stex"); } -bool ResourceFormatLoaderStreamTexture::handles_type(const String &p_type) const { - return p_type == "StreamTexture"; +bool ResourceFormatLoaderStreamTexture2D::handles_type(const String &p_type) const { + return p_type == "StreamTexture2D"; } -String ResourceFormatLoaderStreamTexture::get_resource_type(const String &p_path) const { +String ResourceFormatLoaderStreamTexture2D::get_resource_type(const String &p_path) const { if (p_path.get_extension().to_lower() == "stex") - return "StreamTexture"; + return "StreamTexture2D"; return ""; } @@ -1037,8 +1037,10 @@ bool AtlasTexture::is_pixel_opaque(int p_x, int p_y) const { int y = p_y + region.position.y - margin.position.y; // margin edge may outside of atlas - if (x < 0 || x >= atlas->get_width()) return false; - if (y < 0 || y >= atlas->get_height()) return false; + if (x < 0 || x >= atlas->get_width()) + return false; + if (y < 0 || y >= atlas->get_height()) + return false; return atlas->is_pixel_opaque(x, y); } @@ -1928,23 +1930,47 @@ AnimatedTexture::~AnimatedTexture() { } /////////////////////////////// -Image::Format TextureLayered::get_format() const { +void TextureLayered::_bind_methods() { + + ClassDB::bind_method(D_METHOD("get_format"), &TextureLayered::get_format); + ClassDB::bind_method(D_METHOD("get_layered_type"), &TextureLayered::get_layered_type); + ClassDB::bind_method(D_METHOD("get_width"), &TextureLayered::get_width); + ClassDB::bind_method(D_METHOD("get_height"), &TextureLayered::get_height); + ClassDB::bind_method(D_METHOD("get_layers"), &TextureLayered::get_layers); + ClassDB::bind_method(D_METHOD("has_mipmaps"), &TextureLayered::has_mipmaps); + ClassDB::bind_method(D_METHOD("get_layer_data", "layer"), &TextureLayered::get_layer_data); + + BIND_ENUM_CONSTANT(LAYERED_TYPE_2D_ARRAY); + BIND_ENUM_CONSTANT(LAYERED_TYPE_CUBEMAP); + BIND_ENUM_CONSTANT(LAYERED_TYPE_CUBEMAP_ARRAY); +} + +/////////////////////////////// +Image::Format ImageTextureLayered::get_format() const { return format; } -uint32_t TextureLayered::get_width() const { +int ImageTextureLayered::get_width() const { return width; } -uint32_t TextureLayered::get_height() const { +int ImageTextureLayered::get_height() const { return height; } -uint32_t TextureLayered::get_layers() const { +int ImageTextureLayered::get_layers() const { return layers; } -Error TextureLayered::_create_from_images(const Array &p_images) { +bool ImageTextureLayered::has_mipmaps() const { + return mipmaps; +} + +ImageTextureLayered::LayeredType ImageTextureLayered::get_layered_type() const { + return layered_type; +} + +Error ImageTextureLayered::_create_from_images(const Array &p_images) { Vector<Ref<Image>> images; for (int i = 0; i < p_images.size(); i++) { Ref<Image> img = p_images[i]; @@ -1955,7 +1981,7 @@ Error TextureLayered::_create_from_images(const Array &p_images) { return create_from_images(images); } -Array TextureLayered::_get_images() const { +Array ImageTextureLayered::_get_images() const { Array images; for (int i = 0; i < layers; i++) { images.push_back(get_layer_data(i)); @@ -1963,14 +1989,14 @@ Array TextureLayered::_get_images() const { return images; } -Error TextureLayered::create_from_images(Vector<Ref<Image>> p_images) { +Error ImageTextureLayered::create_from_images(Vector<Ref<Image>> p_images) { int new_layers = p_images.size(); ERR_FAIL_COND_V(new_layers == 0, ERR_INVALID_PARAMETER); - if (layered_type == RS::TEXTURE_LAYERED_CUBEMAP) { + if (layered_type == LAYERED_TYPE_CUBEMAP) { ERR_FAIL_COND_V_MSG(new_layers != 6, ERR_INVALID_PARAMETER, "Cubemaps require exactly 6 layers"); - } else if (layered_type == RS::TEXTURE_LAYERED_CUBEMAP_ARRAY) { + } else if (layered_type == LAYERED_TYPE_CUBEMAP_ARRAY) { ERR_FAIL_COND_V_MSG((new_layers % 6) != 0, ERR_INVALID_PARAMETER, "Cubemap array layers must be a multiple of 6"); } @@ -1992,11 +2018,11 @@ Error TextureLayered::create_from_images(Vector<Ref<Image>> p_images) { } if (texture.is_valid()) { - RID new_texture = RS::get_singleton()->texture_2d_layered_create(p_images, layered_type); + RID new_texture = RS::get_singleton()->texture_2d_layered_create(p_images, RS::TextureLayeredType(layered_type)); ERR_FAIL_COND_V(!new_texture.is_valid(), ERR_CANT_CREATE); RS::get_singleton()->texture_replace(texture, new_texture); } else { - texture = RS::get_singleton()->texture_2d_layered_create(p_images, layered_type); + texture = RS::get_singleton()->texture_2d_layered_create(p_images, RS::TextureLayeredType(layered_type)); ERR_FAIL_COND_V(!texture.is_valid(), ERR_CANT_CREATE); } @@ -2008,7 +2034,7 @@ Error TextureLayered::create_from_images(Vector<Ref<Image>> p_images) { return OK; } -void TextureLayered::update_layer(const Ref<Image> &p_image, int p_layer) { +void ImageTextureLayered::update_layer(const Ref<Image> &p_image, int p_layer) { ERR_FAIL_COND(texture.is_valid()); ERR_FAIL_COND(p_image.is_null()); ERR_FAIL_COND(p_image->get_format() != format); @@ -2018,19 +2044,19 @@ void TextureLayered::update_layer(const Ref<Image> &p_image, int p_layer) { RS::get_singleton()->texture_2d_update(texture, p_image, p_layer); } -Ref<Image> TextureLayered::get_layer_data(int p_layer) const { +Ref<Image> ImageTextureLayered::get_layer_data(int p_layer) const { ERR_FAIL_INDEX_V(p_layer, layers, Ref<Image>()); return RS::get_singleton()->texture_2d_layer_get(texture, p_layer); } -RID TextureLayered::get_rid() const { +RID ImageTextureLayered::get_rid() const { if (texture.is_null()) { - texture = RS::get_singleton()->texture_2d_layered_placeholder_create(); + texture = RS::get_singleton()->texture_2d_layered_placeholder_create(RS::TextureLayeredType(layered_type)); } return texture; } -void TextureLayered::set_path(const String &p_path, bool p_take_over) { +void ImageTextureLayered::set_path(const String &p_path, bool p_take_over) { if (texture.is_valid()) { RS::get_singleton()->texture_set_path(texture, p_path); } @@ -2038,24 +2064,17 @@ void TextureLayered::set_path(const String &p_path, bool p_take_over) { Resource::set_path(p_path, p_take_over); } -void TextureLayered::_bind_methods() { +void ImageTextureLayered::_bind_methods() { - ClassDB::bind_method(D_METHOD("get_format"), &TextureLayered::get_format); + ClassDB::bind_method(D_METHOD("create_from_images", "images"), &ImageTextureLayered::_create_from_images); + ClassDB::bind_method(D_METHOD("update_layer", "image", "layer"), &ImageTextureLayered::update_layer); - ClassDB::bind_method(D_METHOD("get_width"), &TextureLayered::get_width); - ClassDB::bind_method(D_METHOD("get_height"), &TextureLayered::get_height); - ClassDB::bind_method(D_METHOD("get_layers"), &TextureLayered::get_layers); - - ClassDB::bind_method(D_METHOD("create_from_images", "images"), &TextureLayered::_create_from_images); - ClassDB::bind_method(D_METHOD("update_layer", "image", "layer"), &TextureLayered::update_layer); - ClassDB::bind_method(D_METHOD("get_layer_data", "layer"), &TextureLayered::get_layer_data); - - ClassDB::bind_method(D_METHOD("_get_images"), &TextureLayered::_get_images); + ClassDB::bind_method(D_METHOD("_get_images"), &ImageTextureLayered::_get_images); ADD_PROPERTY(PropertyInfo(Variant::ARRAY, "_images", PROPERTY_HINT_NONE, "", PROPERTY_USAGE_INTERNAL), "create_from_images", "_get_images"); } -TextureLayered::TextureLayered(RenderingServer::TextureLayeredType p_layered_type) { +ImageTextureLayered::ImageTextureLayered(LayeredType p_layered_type) { layered_type = p_layered_type; format = Image::FORMAT_MAX; @@ -2064,193 +2083,241 @@ TextureLayered::TextureLayered(RenderingServer::TextureLayeredType p_layered_typ layers = 0; } -TextureLayered::~TextureLayered() { +ImageTextureLayered::~ImageTextureLayered() { if (texture.is_valid()) { RS::get_singleton()->free(texture); } } -RES ResourceFormatLoaderTextureLayered::load(const String &p_path, const String &p_original_path, Error *r_error, bool p_use_sub_threads, float *r_progress, bool p_no_cache) { +/////////////////////////////////////////// + +void StreamTextureLayered::set_path(const String &p_path, bool p_take_over) { - if (r_error) { - *r_error = ERR_CANT_OPEN; + if (texture.is_valid()) { + RenderingServer::get_singleton()->texture_set_path(texture, p_path); } - Ref<TextureLayered> lt; + Resource::set_path(p_path, p_take_over); +} - if (p_path.ends_with("cube")) { - Ref<Cubemap> cube; - cube.instance(); - lt = cube; - } else if (p_path.ends_with("cubearr")) { - Ref<CubemapArray> cubearr; - cubearr.instance(); - lt = cubearr; - } else if (p_path.ends_with("tex2darr")) { - Ref<Texture2DArray> t2darr; - t2darr.instance(); - lt = t2darr; - } else { - ERR_FAIL_V_MSG(RES(), "Unrecognized layered texture extension."); +Image::Format StreamTextureLayered::get_format() const { + + return format; +} + +Error StreamTextureLayered::_load_data(const String &p_path, Vector<Ref<Image>> &images, int &mipmap_limit, int p_size_limit) { + + ERR_FAIL_COND_V(images.size() != 0, ERR_INVALID_PARAMETER); + + FileAccessRef f = FileAccess::open(p_path, FileAccess::READ); + ERR_FAIL_COND_V(!f, ERR_CANT_OPEN); + + uint8_t header[4]; + f->get_buffer(header, 4); + if (header[0] != 'G' || header[1] != 'S' || header[2] != 'T' || header[3] != 'L') { + ERR_FAIL_V_MSG(ERR_FILE_CORRUPT, "Stream texture layered file is corrupt (Bad header)."); } - FileAccess *f = FileAccess::open(p_path, FileAccess::READ); - ERR_FAIL_COND_V_MSG(!f, RES(), "Cannot open file '" + p_path + "'."); + uint32_t version = f->get_32(); - char header[5] = { 0, 0, 0, 0, 0 }; - f->get_buffer((uint8_t *)header, 4); + if (version > FORMAT_VERSION) { + ERR_FAIL_V_MSG(ERR_FILE_CORRUPT, "Stream texture file is too new."); + } - if (String(header) != "GDLT") { - f->close(); - memdelete(f); - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - // FIXME: It's bogus that we fail in both branches. Seen while rebasing - // vulkan branch on master branch. - ERR_FAIL_V_MSG(RES(), "Unrecognized layered texture."); - } else { + uint32_t layer_count = f->get_32(); //layer count + uint32_t type = f->get_32(); //layer count + ERR_FAIL_COND_V(type != layered_type, ERR_INVALID_DATA); - f->close(); - memdelete(f); - ERR_FAIL_V_MSG(RES(), "Unrecognized layered texture file format '" + String((const char *)header) + "'."); + uint32_t df = f->get_32(); //data format + mipmap_limit = int(f->get_32()); + //reserverd + f->get_32(); + f->get_32(); + f->get_32(); + + if (!(df & FORMAT_BIT_STREAM)) { + p_size_limit = 0; } - int tw = f->get_32(); - int th = f->get_32(); - int td = f->get_32(); - bool use_mipmaps = f->get_32() != 0; //texture flags (deprecated) - Image::Format format = Image::Format(f->get_32()); - uint32_t compression = f->get_32(); // 0 - lossless (PNG), 1 - vram, 2 - uncompressed + images.resize(layer_count); - Vector<Ref<Image>> images; - for (int layer = 0; layer < td; layer++) { + for (uint32_t i = 0; i < layer_count; i++) { + Ref<Image> image = StreamTexture2D::load_image_from_file(f, p_size_limit); + ERR_FAIL_COND_V(image.is_null() || image->empty(), ERR_CANT_OPEN); + images.write[i] = image; + } - Ref<Image> image; - image.instance(); + return OK; +} - if (compression == COMPRESSION_LOSSLESS) { - //look for a PNG file inside +Error StreamTextureLayered::load(const String &p_path) { - int mipmaps = f->get_32(); - Vector<Ref<Image>> mipmap_images; + Vector<Ref<Image>> images; - for (int i = 0; i < mipmaps; i++) { - uint32_t size = f->get_32(); + int mipmap_limit; - Vector<uint8_t> pv; - pv.resize(size); - { - uint8_t *w = pv.ptrw(); - f->get_buffer(w, size); - } + Error err = _load_data(p_path, images, mipmap_limit); + if (err) + return err; - Ref<Image> img = Image::lossless_unpacker(pv); + if (texture.is_valid()) { + RID new_texture = RS::get_singleton()->texture_2d_layered_create(images, RS::TextureLayeredType(layered_type)); + RS::get_singleton()->texture_replace(texture, new_texture); + } else { + texture = RS::get_singleton()->texture_2d_layered_create(images, RS::TextureLayeredType(layered_type)); + } - if (img.is_null() || img->empty() || format != img->get_format()) { - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - f->close(); - memdelete(f); - ERR_FAIL_V(RES()); - } + w = images[0]->get_width(); + h = images[0]->get_height(); + mipmaps = images[0]->has_mipmaps(); + format = images[0]->get_format(); + layers = images.size(); - mipmap_images.push_back(img); - } + path_to_file = p_path; - if (mipmap_images.size() == 1) { + if (get_path() == String()) { + //temporarily set path if no path set for resource, helps find errors + RenderingServer::get_singleton()->texture_set_path(texture, p_path); + } - image = mipmap_images[0]; + _change_notify(); + emit_changed(); + return OK; +} +String StreamTextureLayered::get_load_path() const { - } else { - int total_size = Image::get_image_data_size(tw, th, format, true); - Vector<uint8_t> img_data; - img_data.resize(total_size); - - { - uint8_t *w = img_data.ptrw(); - - int ofs = 0; - for (int i = 0; i < mipmap_images.size(); i++) { - - Vector<uint8_t> id = mipmap_images[i]->get_data(); - int len = id.size(); - const uint8_t *r = id.ptr(); - copymem(&w[ofs], r, len); - ofs += len; - } - } + return path_to_file; +} - image->create(tw, th, true, format, img_data); - if (image->empty()) { - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - f->close(); - memdelete(f); - ERR_FAIL_V(RES()); - } - } +int StreamTextureLayered::get_width() const { - } else { + return w; +} +int StreamTextureLayered::get_height() const { - //look for regular format + return h; +} +int StreamTextureLayered::get_layers() const { - int total_size = Image::get_image_data_size(tw, th, format, use_mipmaps); + return layers; +} +bool StreamTextureLayered::has_mipmaps() const { + return mipmaps; +} - Vector<uint8_t> img_data; - img_data.resize(total_size); +TextureLayered::LayeredType StreamTextureLayered::get_layered_type() const { - { - uint8_t *w = img_data.ptrw(); - int bytes = f->get_buffer(w, total_size); - if (bytes != total_size) { - if (r_error) { - *r_error = ERR_FILE_CORRUPT; - } - f->close(); - memdelete(f); - ERR_FAIL_V(RES()); - } - } + return layered_type; +} - image->create(tw, th, use_mipmaps, format, img_data); - } +RID StreamTextureLayered::get_rid() const { - images.push_back(image); + if (!texture.is_valid()) { + texture = RS::get_singleton()->texture_2d_layered_placeholder_create(RS::TextureLayeredType(layered_type)); } + return texture; +} - Error err = lt->create_from_images(images); - if (err != OK) { - *r_error = err; - return RES(); +Ref<Image> StreamTextureLayered::get_layer_data(int p_layer) const { + + if (texture.is_valid()) { + return RS::get_singleton()->texture_2d_layer_get(texture, p_layer); } else { + return Ref<Image>(); + } +} + +void StreamTextureLayered::reload_from_file() { - if (r_error) - *r_error = OK; + String path = get_path(); + if (!path.is_resource_file()) + return; + + path = ResourceLoader::path_remap(path); //remap for translation + path = ResourceLoader::import_remap(path); //remap for import + if (!path.is_resource_file()) + return; + + load(path); +} + +void StreamTextureLayered::_validate_property(PropertyInfo &property) const { +} + +void StreamTextureLayered::_bind_methods() { + + ClassDB::bind_method(D_METHOD("load", "path"), &StreamTextureLayered::load); + ClassDB::bind_method(D_METHOD("get_load_path"), &StreamTextureLayered::get_load_path); + + ADD_PROPERTY(PropertyInfo(Variant::STRING, "load_path", PROPERTY_HINT_FILE, "*.stex"), "load", "get_load_path"); +} + +StreamTextureLayered::StreamTextureLayered(LayeredType p_type) { + + layered_type = p_type; + format = Image::FORMAT_MAX; + w = 0; + h = 0; + layers = 0; + mipmaps = false; +} + +StreamTextureLayered::~StreamTextureLayered() { + + if (texture.is_valid()) { + RS::get_singleton()->free(texture); + } +} + +///////////////////////////////////////////////// + +RES ResourceFormatLoaderStreamTextureLayered::load(const String &p_path, const String &p_original_path, Error *r_error, bool p_use_sub_threads, float *r_progress, bool p_no_cache) { + + Ref<StreamTextureLayered> st; + if (p_path.get_extension().to_lower() == "stexarray") { + Ref<StreamTexture2DArray> s; + s.instance(); + st = s; + } else if (p_path.get_extension().to_lower() == "scube") { + Ref<StreamCubemap> s; + s.instance(); + st = s; + } else if (p_path.get_extension().to_lower() == "scubearray") { + Ref<StreamCubemapArray> s; + s.instance(); + st = s; + } else { + if (r_error) { + *r_error = ERR_FILE_UNRECOGNIZED; + } + return RES(); } + Error err = st->load(p_path); + if (r_error) + *r_error = err; + if (err != OK) + return RES(); - return lt; + return st; } -void ResourceFormatLoaderTextureLayered::get_recognized_extensions(List<String> *p_extensions) const { +void ResourceFormatLoaderStreamTextureLayered::get_recognized_extensions(List<String> *p_extensions) const { - p_extensions->push_back("cube"); - p_extensions->push_back("cubearr"); - p_extensions->push_back("tex2darr"); + p_extensions->push_back("stexarray"); + p_extensions->push_back("scube"); + p_extensions->push_back("scubearray"); } -bool ResourceFormatLoaderTextureLayered::handles_type(const String &p_type) const { - return p_type == "Texture2DArray" || p_type == "Cubemap" || p_type == "CubemapArray"; +bool ResourceFormatLoaderStreamTextureLayered::handles_type(const String &p_type) const { + return p_type == "StreamTexture2DArray" || p_type == "StreamCubemap" || p_type == "StreamCubemapArray"; } -String ResourceFormatLoaderTextureLayered::get_resource_type(const String &p_path) const { +String ResourceFormatLoaderStreamTextureLayered::get_resource_type(const String &p_path) const { - if (p_path.get_extension().to_lower() == "cube") - return "Cubemap"; - if (p_path.get_extension().to_lower() == "cubearr") - return "CubemapArray"; - if (p_path.get_extension().to_lower() == "tex2darr") - return "Texture2DArray"; + if (p_path.get_extension().to_lower() == "stexarray") + return "StreamTexture2DArray"; + if (p_path.get_extension().to_lower() == "scube") + return "StreamCubemap"; + if (p_path.get_extension().to_lower() == "scubearray") + return "StreamCubemapArray"; return ""; } diff --git a/scene/resources/texture.h b/scene/resources/texture.h index 5d5f438eba..5d4131ec4c 100644 --- a/scene/resources/texture.h +++ b/scene/resources/texture.h @@ -132,9 +132,9 @@ public: ~ImageTexture(); }; -class StreamTexture : public Texture2D { +class StreamTexture2D : public Texture2D { - GDCLASS(StreamTexture, Texture2D); + GDCLASS(StreamTexture2D, Texture2D); public: enum DataFormat { @@ -181,8 +181,8 @@ protected: public: static Ref<Image> load_image_from_file(FileAccess *p_file, int p_size_limit); - typedef void (*TextureFormatRequestCallback)(const Ref<StreamTexture> &); - typedef void (*TextureFormatRoughnessRequestCallback)(const Ref<StreamTexture> &, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_roughness_channel); + typedef void (*TextureFormatRequestCallback)(const Ref<StreamTexture2D> &); + typedef void (*TextureFormatRoughnessRequestCallback)(const Ref<StreamTexture2D> &, const String &p_normal_path, RS::TextureDetectRoughnessChannel p_roughness_channel); static TextureFormatRequestCallback request_3d_callback; static TextureFormatRoughnessRequestCallback request_roughness_callback; @@ -207,11 +207,11 @@ public: virtual Ref<Image> get_data() const; - StreamTexture(); - ~StreamTexture(); + StreamTexture2D(); + ~StreamTexture2D(); }; -class ResourceFormatLoaderStreamTexture : public ResourceFormatLoader { +class ResourceFormatLoaderStreamTexture2D : public ResourceFormatLoader { public: virtual RES load(const String &p_path, const String &p_original_path = "", Error *r_error = nullptr, bool p_use_sub_threads = false, float *r_progress = nullptr, bool p_no_cache = false); virtual void get_recognized_extensions(List<String> *p_extensions) const; @@ -349,10 +349,34 @@ public: }; class TextureLayered : public Texture { - GDCLASS(TextureLayered, Texture); - RS::TextureLayeredType layered_type; +protected: + static void _bind_methods(); + +public: + enum LayeredType { + LAYERED_TYPE_2D_ARRAY, + LAYERED_TYPE_CUBEMAP, + LAYERED_TYPE_CUBEMAP_ARRAY + }; + + virtual Image::Format get_format() const = 0; + virtual LayeredType get_layered_type() const = 0; + virtual int get_width() const = 0; + virtual int get_height() const = 0; + virtual int get_layers() const = 0; + virtual bool has_mipmaps() const = 0; + virtual Ref<Image> get_layer_data(int p_layer) const = 0; +}; + +VARIANT_ENUM_CAST(TextureLayered::LayeredType) + +class ImageTextureLayered : public TextureLayered { + + GDCLASS(ImageTextureLayered, TextureLayered); + + LayeredType layered_type; mutable RID texture; Image::Format format; @@ -370,57 +394,137 @@ protected: static void _bind_methods(); public: - Image::Format get_format() const; - uint32_t get_width() const; - uint32_t get_height() const; - uint32_t get_layers() const; - bool has_mipmaps() const; + virtual Image::Format get_format() const; + virtual int get_width() const; + virtual int get_height() const; + virtual int get_layers() const; + virtual bool has_mipmaps() const; + virtual LayeredType get_layered_type() const; Error create_from_images(Vector<Ref<Image>> p_images); void update_layer(const Ref<Image> &p_image, int p_layer); - Ref<Image> get_layer_data(int p_layer) const; + virtual Ref<Image> get_layer_data(int p_layer) const; virtual RID get_rid() const; virtual void set_path(const String &p_path, bool p_take_over = false); - TextureLayered(RS::TextureLayeredType p_layered_type); - ~TextureLayered(); + ImageTextureLayered(LayeredType p_layered_type); + ~ImageTextureLayered(); }; -class Texture2DArray : public TextureLayered { +class Texture2DArray : public ImageTextureLayered { - GDCLASS(Texture2DArray, TextureLayered) + GDCLASS(Texture2DArray, ImageTextureLayered) public: Texture2DArray() : - TextureLayered(RS::TEXTURE_LAYERED_2D_ARRAY) {} + ImageTextureLayered(LAYERED_TYPE_2D_ARRAY) {} }; -class Cubemap : public TextureLayered { +class Cubemap : public ImageTextureLayered { - GDCLASS(Cubemap, TextureLayered); + GDCLASS(Cubemap, ImageTextureLayered); public: Cubemap() : - TextureLayered(RS::TEXTURE_LAYERED_CUBEMAP) {} + ImageTextureLayered(LAYERED_TYPE_CUBEMAP) {} }; -class CubemapArray : public TextureLayered { +class CubemapArray : public ImageTextureLayered { - GDCLASS(CubemapArray, TextureLayered); + GDCLASS(CubemapArray, ImageTextureLayered); public: CubemapArray() : - TextureLayered(RS::TEXTURE_LAYERED_CUBEMAP_ARRAY) {} + ImageTextureLayered(LAYERED_TYPE_CUBEMAP_ARRAY) {} }; -class ResourceFormatLoaderTextureLayered : public ResourceFormatLoader { +class StreamTextureLayered : public TextureLayered { + + GDCLASS(StreamTextureLayered, TextureLayered); + public: - enum Compression { - COMPRESSION_LOSSLESS, - COMPRESSION_VRAM, - COMPRESSION_UNCOMPRESSED + enum DataFormat { + DATA_FORMAT_IMAGE, + DATA_FORMAT_LOSSLESS, + DATA_FORMAT_LOSSY, + DATA_FORMAT_BASIS_UNIVERSAL, + }; + + enum { + FORMAT_VERSION = 1 }; + enum FormatBits { + FORMAT_MASK_IMAGE_FORMAT = (1 << 20) - 1, + FORMAT_BIT_LOSSLESS = 1 << 20, + FORMAT_BIT_LOSSY = 1 << 21, + FORMAT_BIT_STREAM = 1 << 22, + FORMAT_BIT_HAS_MIPMAPS = 1 << 23, + }; + +private: + Error _load_data(const String &p_path, Vector<Ref<Image>> &images, int &mipmap_limit, int p_size_limit = 0); + String path_to_file; + mutable RID texture; + Image::Format format; + int w, h, layers; + bool mipmaps; + LayeredType layered_type; + + virtual void reload_from_file(); + +protected: + static void _bind_methods(); + void _validate_property(PropertyInfo &property) const; + +public: + Image::Format get_format() const; + Error load(const String &p_path); + String get_load_path() const; + virtual LayeredType get_layered_type() const; + + int get_width() const; + int get_height() const; + int get_layers() const; + virtual bool has_mipmaps() const; + virtual RID get_rid() const; + + virtual void set_path(const String &p_path, bool p_take_over); + + virtual Ref<Image> get_layer_data(int p_layer) const; + + StreamTextureLayered(LayeredType p_layered_type); + ~StreamTextureLayered(); +}; + +class StreamTexture2DArray : public StreamTextureLayered { + + GDCLASS(StreamTexture2DArray, StreamTextureLayered) +public: + StreamTexture2DArray() : + StreamTextureLayered(LAYERED_TYPE_2D_ARRAY) {} +}; + +class StreamCubemap : public StreamTextureLayered { + + GDCLASS(StreamCubemap, StreamTextureLayered); + +public: + StreamCubemap() : + StreamTextureLayered(LAYERED_TYPE_CUBEMAP) {} +}; + +class StreamCubemapArray : public StreamTextureLayered { + + GDCLASS(StreamCubemapArray, StreamTextureLayered); + +public: + StreamCubemapArray() : + StreamTextureLayered(LAYERED_TYPE_CUBEMAP_ARRAY) {} +}; + +class ResourceFormatLoaderStreamTextureLayered : public ResourceFormatLoader { +public: virtual RES load(const String &p_path, const String &p_original_path = "", Error *r_error = nullptr, bool p_use_sub_threads = false, float *r_progress = nullptr, bool p_no_cache = false); virtual void get_recognized_extensions(List<String> *p_extensions) const; virtual bool handles_type(const String &p_type) const; diff --git a/scene/resources/theme.cpp b/scene/resources/theme.cpp index 98ebf048dc..325ada69ea 100644 --- a/scene/resources/theme.cpp +++ b/scene/resources/theme.cpp @@ -69,7 +69,7 @@ Vector<String> Theme::_get_stylebox_list(const String &p_type) const { return ilret; } -Vector<String> Theme::_get_stylebox_types(void) const { +Vector<String> Theme::_get_stylebox_types() const { Vector<String> ilret; List<StringName> il; diff --git a/scene/resources/theme.h b/scene/resources/theme.h index d6d724e3f7..b5043c35e8 100644 --- a/scene/resources/theme.h +++ b/scene/resources/theme.h @@ -54,7 +54,7 @@ class Theme : public Resource { Vector<String> _get_icon_list(const String &p_type) const; Vector<String> _get_stylebox_list(const String &p_type) const; - Vector<String> _get_stylebox_types(void) const; + Vector<String> _get_stylebox_types() const; Vector<String> _get_font_list(const String &p_type) const; Vector<String> _get_color_list(const String &p_type) const; Vector<String> _get_constant_list(const String &p_type) const; diff --git a/scene/resources/tile_set.cpp b/scene/resources/tile_set.cpp index 6f8a53be1a..1b68b7486b 100644 --- a/scene/resources/tile_set.cpp +++ b/scene/resources/tile_set.cpp @@ -993,7 +993,8 @@ void TileSet::_tile_set_shapes(int p_id, const Array &p_shapes) { if (p_shapes[i].get_type() == Variant::OBJECT) { Ref<Shape2D> shape = p_shapes[i]; - if (shape.is_null()) continue; + if (shape.is_null()) + continue; s.shape = shape; s.shape_transform = default_transform; diff --git a/scene/resources/tile_set.h b/scene/resources/tile_set.h index 05b43dfb89..5f01959253 100644 --- a/scene/resources/tile_set.h +++ b/scene/resources/tile_set.h @@ -48,13 +48,10 @@ public: Ref<Shape2D> shape; Transform2D shape_transform; Vector2 autotile_coord; - bool one_way_collision; - float one_way_collision_margin; + bool one_way_collision = false; + float one_way_collision_margin = 1.0; - ShapeData() { - one_way_collision = false; - one_way_collision_margin = 1.0; - } + ShapeData() {} }; enum BitmaskMode { @@ -92,22 +89,18 @@ public: }; struct AutotileData { - BitmaskMode bitmask_mode; - Size2 size; - int spacing; - Vector2 icon_coord; + BitmaskMode bitmask_mode = BITMASK_2X2; + // Default size to prevent invalid value + Size2 size = Size2(64, 64); + Vector2 icon_coord = Vector2(0, 0); + int spacing = 0; Map<Vector2, uint32_t> flags; Map<Vector2, Ref<OccluderPolygon2D>> occluder_map; Map<Vector2, Ref<NavigationPolygon>> navpoly_map; Map<Vector2, int> priority_map; Map<Vector2, int> z_index_map; - // Default size to prevent invalid value - explicit AutotileData() : - bitmask_mode(BITMASK_2X2), - size(64, 64), - spacing(0), - icon_coord(0, 0) {} + explicit AutotileData() {} }; private: @@ -124,16 +117,13 @@ private: Vector2 navigation_polygon_offset; Ref<NavigationPolygon> navigation_polygon; Ref<ShaderMaterial> material; - TileMode tile_mode; - Color modulate; + TileMode tile_mode = SINGLE_TILE; + // Default modulate for back-compat + Color modulate = Color(1, 1, 1); AutotileData autotile_data; - int z_index; + int z_index = 0; - // Default modulate for back-compat - explicit TileData() : - tile_mode(SINGLE_TILE), - modulate(1, 1, 1), - z_index(0) {} + explicit TileData() {} }; Map<int, TileData> tile_map; diff --git a/scene/resources/visual_shader.cpp b/scene/resources/visual_shader.cpp index 3b245f908a..5637aaec9a 100644 --- a/scene/resources/visual_shader.cpp +++ b/scene/resources/visual_shader.cpp @@ -1229,11 +1229,21 @@ Error VisualShader::_write_node(Type type, StringBuilder &global_code, StringBui for (int i = 0; i < output_count; i++) { String var_name = "n_out" + itos(node) + "p" + itos(i); switch (vsnode->get_output_port_type(i)) { - case VisualShaderNode::PORT_TYPE_SCALAR: outputs[i] = "float " + var_name; break; - case VisualShaderNode::PORT_TYPE_SCALAR_INT: outputs[i] = "int " + var_name; break; - case VisualShaderNode::PORT_TYPE_VECTOR: outputs[i] = "vec3 " + var_name; break; - case VisualShaderNode::PORT_TYPE_BOOLEAN: outputs[i] = "bool " + var_name; break; - case VisualShaderNode::PORT_TYPE_TRANSFORM: outputs[i] = "mat4 " + var_name; break; + case VisualShaderNode::PORT_TYPE_SCALAR: + outputs[i] = "float " + var_name; + break; + case VisualShaderNode::PORT_TYPE_SCALAR_INT: + outputs[i] = "int " + var_name; + break; + case VisualShaderNode::PORT_TYPE_VECTOR: + outputs[i] = "vec3 " + var_name; + break; + case VisualShaderNode::PORT_TYPE_BOOLEAN: + outputs[i] = "bool " + var_name; + break; + case VisualShaderNode::PORT_TYPE_TRANSFORM: + outputs[i] = "mat4 " + var_name; + break; default: { } } @@ -1243,11 +1253,21 @@ Error VisualShader::_write_node(Type type, StringBuilder &global_code, StringBui for (int i = 0; i < output_count; i++) { outputs[i] = "n_out" + itos(node) + "p" + itos(i); switch (vsnode->get_output_port_type(i)) { - case VisualShaderNode::PORT_TYPE_SCALAR: code += String() + "\tfloat " + outputs[i] + ";\n"; break; - case VisualShaderNode::PORT_TYPE_SCALAR_INT: code += String() + "\tint " + outputs[i] + ";\n"; break; - case VisualShaderNode::PORT_TYPE_VECTOR: code += String() + "\tvec3 " + outputs[i] + ";\n"; break; - case VisualShaderNode::PORT_TYPE_BOOLEAN: code += String() + "\tbool " + outputs[i] + ";\n"; break; - case VisualShaderNode::PORT_TYPE_TRANSFORM: code += String() + "\tmat4 " + outputs[i] + ";\n"; break; + case VisualShaderNode::PORT_TYPE_SCALAR: + code += String() + "\tfloat " + outputs[i] + ";\n"; + break; + case VisualShaderNode::PORT_TYPE_SCALAR_INT: + code += String() + "\tint " + outputs[i] + ";\n"; + break; + case VisualShaderNode::PORT_TYPE_VECTOR: + code += String() + "\tvec3 " + outputs[i] + ";\n"; + break; + case VisualShaderNode::PORT_TYPE_BOOLEAN: + code += String() + "\tbool " + outputs[i] + ";\n"; + break; + case VisualShaderNode::PORT_TYPE_TRANSFORM: + code += String() + "\tmat4 " + outputs[i] + ";\n"; + break; default: { } } diff --git a/scene/resources/visual_shader_nodes.cpp b/scene/resources/visual_shader_nodes.cpp index 7b9953a90f..03db8c3ac5 100644 --- a/scene/resources/visual_shader_nodes.cpp +++ b/scene/resources/visual_shader_nodes.cpp @@ -497,9 +497,14 @@ String VisualShaderNodeTexture::generate_global(Shader::Mode p_mode, VisualShade String u = "uniform sampler2D " + make_unique_id(p_type, p_id, "tex"); switch (texture_type) { - case TYPE_DATA: break; - case TYPE_COLOR: u += " : hint_albedo"; break; - case TYPE_NORMALMAP: u += " : hint_normal"; break; + case TYPE_DATA: + break; + case TYPE_COLOR: + u += " : hint_albedo"; + break; + case TYPE_NORMALMAP: + u += " : hint_normal"; + break; } return u + ";\n"; } @@ -869,9 +874,14 @@ String VisualShaderNodeCubemap::generate_global(Shader::Mode p_mode, VisualShade if (source == SOURCE_TEXTURE) { String u = "uniform samplerCube " + make_unique_id(p_type, p_id, "cube"); switch (texture_type) { - case TYPE_DATA: break; - case TYPE_COLOR: u += " : hint_albedo"; break; - case TYPE_NORMALMAP: u += " : hint_normal"; break; + case TYPE_DATA: + break; + case TYPE_COLOR: + u += " : hint_albedo"; + break; + case TYPE_NORMALMAP: + u += " : hint_normal"; + break; } return u + ";\n"; } @@ -1032,16 +1042,36 @@ String VisualShaderNodeFloatOp::generate_code(Shader::Mode p_mode, VisualShader: String code = "\t" + p_output_vars[0] + " = "; switch (op) { - case OP_ADD: code += p_input_vars[0] + " + " + p_input_vars[1] + ";\n"; break; - case OP_SUB: code += p_input_vars[0] + " - " + p_input_vars[1] + ";\n"; break; - case OP_MUL: code += p_input_vars[0] + " * " + p_input_vars[1] + ";\n"; break; - case OP_DIV: code += p_input_vars[0] + " / " + p_input_vars[1] + ";\n"; break; - case OP_MOD: code += "mod(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_POW: code += "pow(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_MAX: code += "max(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_MIN: code += "min(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_ATAN2: code += "atan(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_STEP: code += "step(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; + case OP_ADD: + code += p_input_vars[0] + " + " + p_input_vars[1] + ";\n"; + break; + case OP_SUB: + code += p_input_vars[0] + " - " + p_input_vars[1] + ";\n"; + break; + case OP_MUL: + code += p_input_vars[0] + " * " + p_input_vars[1] + ";\n"; + break; + case OP_DIV: + code += p_input_vars[0] + " / " + p_input_vars[1] + ";\n"; + break; + case OP_MOD: + code += "mod(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_POW: + code += "pow(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_MAX: + code += "max(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_MIN: + code += "min(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_ATAN2: + code += "atan(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_STEP: + code += "step(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; } return code; @@ -1124,13 +1154,27 @@ String VisualShaderNodeIntOp::generate_code(Shader::Mode p_mode, VisualShader::T String code = "\t" + p_output_vars[0] + " = "; switch (op) { - case OP_ADD: code += p_input_vars[0] + " + " + p_input_vars[1] + ";\n"; break; - case OP_SUB: code += p_input_vars[0] + " - " + p_input_vars[1] + ";\n"; break; - case OP_MUL: code += p_input_vars[0] + " * " + p_input_vars[1] + ";\n"; break; - case OP_DIV: code += p_input_vars[0] + " / " + p_input_vars[1] + ";\n"; break; - case OP_MOD: code += p_input_vars[0] + " % " + p_input_vars[1] + ";\n"; break; - case OP_MAX: code += "max(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_MIN: code += "min(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; + case OP_ADD: + code += p_input_vars[0] + " + " + p_input_vars[1] + ";\n"; + break; + case OP_SUB: + code += p_input_vars[0] + " - " + p_input_vars[1] + ";\n"; + break; + case OP_MUL: + code += p_input_vars[0] + " * " + p_input_vars[1] + ";\n"; + break; + case OP_DIV: + code += p_input_vars[0] + " / " + p_input_vars[1] + ";\n"; + break; + case OP_MOD: + code += p_input_vars[0] + " % " + p_input_vars[1] + ";\n"; + break; + case OP_MAX: + code += "max(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_MIN: + code += "min(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; } return code; @@ -1209,18 +1253,42 @@ String VisualShaderNodeVectorOp::generate_code(Shader::Mode p_mode, VisualShader String code = "\t" + p_output_vars[0] + " = "; switch (op) { - case OP_ADD: code += p_input_vars[0] + " + " + p_input_vars[1] + ";\n"; break; - case OP_SUB: code += p_input_vars[0] + " - " + p_input_vars[1] + ";\n"; break; - case OP_MUL: code += p_input_vars[0] + " * " + p_input_vars[1] + ";\n"; break; - case OP_DIV: code += p_input_vars[0] + " / " + p_input_vars[1] + ";\n"; break; - case OP_MOD: code += "mod(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_POW: code += "pow(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_MAX: code += "max(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_MIN: code += "min(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_CROSS: code += "cross(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_ATAN2: code += "atan(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_REFLECT: code += "reflect(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; - case OP_STEP: code += "step(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; break; + case OP_ADD: + code += p_input_vars[0] + " + " + p_input_vars[1] + ";\n"; + break; + case OP_SUB: + code += p_input_vars[0] + " - " + p_input_vars[1] + ";\n"; + break; + case OP_MUL: + code += p_input_vars[0] + " * " + p_input_vars[1] + ";\n"; + break; + case OP_DIV: + code += p_input_vars[0] + " / " + p_input_vars[1] + ";\n"; + break; + case OP_MOD: + code += "mod(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_POW: + code += "pow(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_MAX: + code += "max(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_MIN: + code += "min(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_CROSS: + code += "cross(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_ATAN2: + code += "atan(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_REFLECT: + code += "reflect(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; + case OP_STEP: + code += "step(" + p_input_vars[0] + ", " + p_input_vars[1] + ");\n"; + break; } return code; @@ -3752,8 +3820,12 @@ String VisualShaderNodeTextureUniform::generate_global(Shader::Mode p_mode, Visu else code += " : hint_albedo;\n"; break; - case TYPE_NORMALMAP: code += " : hint_normal;\n"; break; - case TYPE_ANISO: code += " : hint_aniso;\n"; break; + case TYPE_NORMALMAP: + code += " : hint_normal;\n"; + break; + case TYPE_ANISO: + code += " : hint_aniso;\n"; + break; } return code; @@ -4003,8 +4075,12 @@ String VisualShaderNodeCubemapUniform::generate_global(Shader::Mode p_mode, Visu else code += " : hint_albedo;\n"; break; - case TYPE_NORMALMAP: code += " : hint_normal;\n"; break; - case TYPE_ANISO: code += " : hint_aniso;\n"; break; + case TYPE_NORMALMAP: + code += " : hint_normal;\n"; + break; + case TYPE_ANISO: + code += " : hint_aniso;\n"; + break; } return code; diff --git a/scene/scene_string_names.cpp b/scene/scene_string_names.cpp index 5e3f8b803b..e761354cc9 100644 --- a/scene/scene_string_names.cpp +++ b/scene/scene_string_names.cpp @@ -205,4 +205,9 @@ SceneStringNames::SceneStringNames() { shader_overrides_group = StaticCString::create("_shader_overrides_group_"); shader_overrides_group_active = StaticCString::create("_shader_overrides_group_active_"); + +#ifndef DISABLE_DEPRECATED + use_in_baked_light = StaticCString::create("use_in_baked_light"); + use_dynamic_gi = StaticCString::create("use_dynamic_gi"); +#endif } diff --git a/scene/scene_string_names.h b/scene/scene_string_names.h index c5de10a6f6..c5c98ba9e5 100644 --- a/scene/scene_string_names.h +++ b/scene/scene_string_names.h @@ -210,6 +210,10 @@ public: StringName shader_overrides_group; StringName shader_overrides_group_active; +#ifndef DISABLE_DEPRECATED + StringName use_in_baked_light; + StringName use_dynamic_gi; +#endif enum { MAX_MATERIALS = 32 }; diff --git a/servers/audio/audio_filter_sw.cpp b/servers/audio/audio_filter_sw.cpp index 2771fc177b..3928ba1388 100644 --- a/servers/audio/audio_filter_sw.cpp +++ b/servers/audio/audio_filter_sw.cpp @@ -58,7 +58,8 @@ void AudioFilterSW::prepare_coefficients(Coeffs *p_coeffs) { int sr_limit = (sampling_rate / 2) + 512; double final_cutoff = (cutoff > sr_limit) ? sr_limit : cutoff; - if (final_cutoff < 1) final_cutoff = 1; //don't allow less than this + if (final_cutoff < 1) + final_cutoff = 1; //don't allow less than this double omega = 2.0 * Math_PI * final_cutoff / sampling_rate; diff --git a/servers/audio/audio_rb_resampler.cpp b/servers/audio/audio_rb_resampler.cpp index 0ac7ddc7a9..7097c3934f 100644 --- a/servers/audio/audio_rb_resampler.cpp +++ b/servers/audio/audio_rb_resampler.cpp @@ -119,10 +119,18 @@ bool AudioRBResampler::mix(AudioFrame *p_dest, int p_frames) { { int src_read = 0; switch (channels) { - case 1: src_read = _resample<1>(p_dest, target_todo, increment); break; - case 2: src_read = _resample<2>(p_dest, target_todo, increment); break; - case 4: src_read = _resample<4>(p_dest, target_todo, increment); break; - case 6: src_read = _resample<6>(p_dest, target_todo, increment); break; + case 1: + src_read = _resample<1>(p_dest, target_todo, increment); + break; + case 2: + src_read = _resample<2>(p_dest, target_todo, increment); + break; + case 4: + src_read = _resample<4>(p_dest, target_todo, increment); + break; + case 6: + src_read = _resample<6>(p_dest, target_todo, increment); + break; } if (src_read > read_space) diff --git a/servers/audio/effects/audio_effect_filter.h b/servers/audio/effects/audio_effect_filter.h index 0088118d8c..c439c5a5b5 100644 --- a/servers/audio/effects/audio_effect_filter.h +++ b/servers/audio/effects/audio_effect_filter.h @@ -99,7 +99,8 @@ class AudioEffectLowPassFilter : public AudioEffectFilter { GDCLASS(AudioEffectLowPassFilter, AudioEffectFilter); void _validate_property(PropertyInfo &property) const { - if (property.name == "gain") property.usage = 0; + if (property.name == "gain") + property.usage = 0; } public: @@ -110,7 +111,8 @@ public: class AudioEffectHighPassFilter : public AudioEffectFilter { GDCLASS(AudioEffectHighPassFilter, AudioEffectFilter); void _validate_property(PropertyInfo &property) const { - if (property.name == "gain") property.usage = 0; + if (property.name == "gain") + property.usage = 0; } public: @@ -121,7 +123,8 @@ public: class AudioEffectBandPassFilter : public AudioEffectFilter { GDCLASS(AudioEffectBandPassFilter, AudioEffectFilter); void _validate_property(PropertyInfo &property) const { - if (property.name == "gain") property.usage = 0; + if (property.name == "gain") + property.usage = 0; } public: diff --git a/servers/audio/effects/audio_effect_record.h b/servers/audio/effects/audio_effect_record.h index 09101033be..68d968c04b 100644 --- a/servers/audio/effects/audio_effect_record.h +++ b/servers/audio/effects/audio_effect_record.h @@ -49,7 +49,7 @@ class AudioEffectRecordInstance : public AudioEffectInstance { bool is_recording; Thread *io_thread; - bool thread_active; + bool thread_active = false; Vector<AudioFrame> ring_buffer; Vector<float> recording_data; @@ -71,8 +71,7 @@ public: virtual void process(const AudioFrame *p_src_frames, AudioFrame *p_dst_frames, int p_frame_count); virtual bool process_silence() const; - AudioEffectRecordInstance() : - thread_active(false) {} + AudioEffectRecordInstance() {} ~AudioEffectRecordInstance(); }; diff --git a/servers/audio/effects/audio_effect_spectrum_analyzer.cpp b/servers/audio/effects/audio_effect_spectrum_analyzer.cpp index 47aee02de2..680ef567fd 100644 --- a/servers/audio/effects/audio_effect_spectrum_analyzer.cpp +++ b/servers/audio/effects/audio_effect_spectrum_analyzer.cpp @@ -50,7 +50,8 @@ static void smbFft(float *fftBuffer, long fftFrameSize, long sign) for (i = 2; i < 2 * fftFrameSize - 2; i += 2) { for (bitm = 2, j = 0; bitm < 2 * fftFrameSize; bitm <<= 1) { - if (i & bitm) j++; + if (i & bitm) + j++; j <<= 1; } if (i < j) { diff --git a/servers/audio/effects/eq.cpp b/servers/audio/effects/eq.cpp index 426f178dcb..b90e8a2726 100644 --- a/servers/audio/effects/eq.cpp +++ b/servers/audio/effects/eq.cpp @@ -125,6 +125,7 @@ void EQ::set_preset_band_mode(Preset p_preset) { for (int i = 0; i < m_bands; i++) { \ Band b; \ b.freq = bands[i]; \ + b.c1 = b.c2 = b.c3 = 0; \ band.push_back(b); \ } diff --git a/servers/audio/effects/reverb.cpp b/servers/audio/effects/reverb.cpp index ea2174f1d4..02565f4516 100644 --- a/servers/audio/effects/reverb.cpp +++ b/servers/audio/effects/reverb.cpp @@ -31,7 +31,9 @@ // Author: Juan Linietsky <reduzio@gmail.com>, (C) 2006 #include "reverb.h" + #include "core/math/math_funcs.h" + #include <math.h> const float Reverb::comb_tunings[MAX_COMBS] = { @@ -338,11 +340,8 @@ Reverb::Reverb() { params.predelay = 150; params.predelay_fb = 0.4; params.hpf = 0; - hpf_h1 = 0; - hpf_h2 = 0; input_buffer = memnew_arr(float, INPUT_BUFFER_MAX_SIZE); - echo_buffer = nullptr; configure_buffers(); update_parameters(); diff --git a/servers/audio/effects/reverb.h b/servers/audio/effects/reverb.h index 92e4aed435..ed8c824148 100644 --- a/servers/audio/effects/reverb.h +++ b/servers/audio/effects/reverb.h @@ -58,44 +58,34 @@ private: struct Comb { - int size; - float *buffer; - float feedback; - float damp; //lowpass - float damp_h; //history - int pos; - int extra_spread_frames; - - Comb() { - size = 0; - buffer = 0; - feedback = 0; - damp_h = 0; - pos = 0; - } + int size = 0; + float *buffer = nullptr; + float feedback = 0; + float damp = 0; //lowpass + float damp_h = 0; //history + int pos = 0; + int extra_spread_frames = 0; + + Comb() {} }; struct AllPass { - int size; - float *buffer; - int pos; - int extra_spread_frames; - AllPass() { - size = 0; - buffer = 0; - pos = 0; - } + int size = 0; + float *buffer = nullptr; + int pos = 0; + int extra_spread_frames = 0; + AllPass() {} }; Comb comb[MAX_COMBS]; AllPass allpass[MAX_ALLPASS]; float *input_buffer; - float *echo_buffer; + float *echo_buffer = nullptr; int echo_buffer_size; int echo_buffer_pos; - float hpf_h1, hpf_h2; + float hpf_h1, hpf_h2 = 0; struct Parameters { diff --git a/servers/audio_server.cpp b/servers/audio_server.cpp index 90033d4a87..146e70f38d 100644 --- a/servers/audio_server.cpp +++ b/servers/audio_server.cpp @@ -109,9 +109,12 @@ void AudioDriver::input_buffer_write(int32_t sample) { AudioDriver::SpeakerMode AudioDriver::get_speaker_mode_by_total_channels(int p_channels) const { switch (p_channels) { - case 4: return SPEAKER_SURROUND_31; - case 6: return SPEAKER_SURROUND_51; - case 8: return SPEAKER_SURROUND_71; + case 4: + return SPEAKER_SURROUND_31; + case 6: + return SPEAKER_SURROUND_51; + case 8: + return SPEAKER_SURROUND_71; } // Default to STEREO @@ -120,10 +123,14 @@ AudioDriver::SpeakerMode AudioDriver::get_speaker_mode_by_total_channels(int p_c int AudioDriver::get_total_channels_by_speaker_mode(AudioDriver::SpeakerMode p_mode) const { switch (p_mode) { - case SPEAKER_MODE_STEREO: return 2; - case SPEAKER_SURROUND_31: return 4; - case SPEAKER_SURROUND_51: return 6; - case SPEAKER_SURROUND_71: return 8; + case SPEAKER_MODE_STEREO: + return 2; + case SPEAKER_SURROUND_31: + return 4; + case SPEAKER_SURROUND_51: + return 6; + case SPEAKER_SURROUND_71: + return 8; } ERR_FAIL_V(2); diff --git a/servers/audio_server.h b/servers/audio_server.h index a1a3dde719..f10c2ecdc0 100644 --- a/servers/audio_server.h +++ b/servers/audio_server.h @@ -258,10 +258,14 @@ protected: public: _FORCE_INLINE_ int get_channel_count() const { switch (get_speaker_mode()) { - case SPEAKER_MODE_STEREO: return 1; - case SPEAKER_SURROUND_31: return 2; - case SPEAKER_SURROUND_51: return 3; - case SPEAKER_SURROUND_71: return 4; + case SPEAKER_MODE_STEREO: + return 1; + case SPEAKER_SURROUND_31: + return 2; + case SPEAKER_SURROUND_51: + return 3; + case SPEAKER_SURROUND_71: + return 4; } ERR_FAIL_V(1); } diff --git a/servers/physics_2d/area_2d_sw.cpp b/servers/physics_2d/area_2d_sw.cpp index 85ec2aae47..6ae7c90c89 100644 --- a/servers/physics_2d/area_2d_sw.cpp +++ b/servers/physics_2d/area_2d_sw.cpp @@ -129,28 +129,52 @@ void Area2DSW::set_space_override_mode(PhysicsServer2D::AreaSpaceOverrideMode p_ void Area2DSW::set_param(PhysicsServer2D::AreaParameter p_param, const Variant &p_value) { switch (p_param) { - case PhysicsServer2D::AREA_PARAM_GRAVITY: gravity = p_value; break; - case PhysicsServer2D::AREA_PARAM_GRAVITY_VECTOR: gravity_vector = p_value; break; - case PhysicsServer2D::AREA_PARAM_GRAVITY_IS_POINT: gravity_is_point = p_value; break; - case PhysicsServer2D::AREA_PARAM_GRAVITY_DISTANCE_SCALE: gravity_distance_scale = p_value; break; - case PhysicsServer2D::AREA_PARAM_GRAVITY_POINT_ATTENUATION: point_attenuation = p_value; break; - case PhysicsServer2D::AREA_PARAM_LINEAR_DAMP: linear_damp = p_value; break; - case PhysicsServer2D::AREA_PARAM_ANGULAR_DAMP: angular_damp = p_value; break; - case PhysicsServer2D::AREA_PARAM_PRIORITY: priority = p_value; break; + case PhysicsServer2D::AREA_PARAM_GRAVITY: + gravity = p_value; + break; + case PhysicsServer2D::AREA_PARAM_GRAVITY_VECTOR: + gravity_vector = p_value; + break; + case PhysicsServer2D::AREA_PARAM_GRAVITY_IS_POINT: + gravity_is_point = p_value; + break; + case PhysicsServer2D::AREA_PARAM_GRAVITY_DISTANCE_SCALE: + gravity_distance_scale = p_value; + break; + case PhysicsServer2D::AREA_PARAM_GRAVITY_POINT_ATTENUATION: + point_attenuation = p_value; + break; + case PhysicsServer2D::AREA_PARAM_LINEAR_DAMP: + linear_damp = p_value; + break; + case PhysicsServer2D::AREA_PARAM_ANGULAR_DAMP: + angular_damp = p_value; + break; + case PhysicsServer2D::AREA_PARAM_PRIORITY: + priority = p_value; + break; } } Variant Area2DSW::get_param(PhysicsServer2D::AreaParameter p_param) const { switch (p_param) { - case PhysicsServer2D::AREA_PARAM_GRAVITY: return gravity; - case PhysicsServer2D::AREA_PARAM_GRAVITY_VECTOR: return gravity_vector; - case PhysicsServer2D::AREA_PARAM_GRAVITY_IS_POINT: return gravity_is_point; - case PhysicsServer2D::AREA_PARAM_GRAVITY_DISTANCE_SCALE: return gravity_distance_scale; - case PhysicsServer2D::AREA_PARAM_GRAVITY_POINT_ATTENUATION: return point_attenuation; - case PhysicsServer2D::AREA_PARAM_LINEAR_DAMP: return linear_damp; - case PhysicsServer2D::AREA_PARAM_ANGULAR_DAMP: return angular_damp; - case PhysicsServer2D::AREA_PARAM_PRIORITY: return priority; + case PhysicsServer2D::AREA_PARAM_GRAVITY: + return gravity; + case PhysicsServer2D::AREA_PARAM_GRAVITY_VECTOR: + return gravity_vector; + case PhysicsServer2D::AREA_PARAM_GRAVITY_IS_POINT: + return gravity_is_point; + case PhysicsServer2D::AREA_PARAM_GRAVITY_DISTANCE_SCALE: + return gravity_distance_scale; + case PhysicsServer2D::AREA_PARAM_GRAVITY_POINT_ATTENUATION: + return point_attenuation; + case PhysicsServer2D::AREA_PARAM_LINEAR_DAMP: + return linear_damp; + case PhysicsServer2D::AREA_PARAM_ANGULAR_DAMP: + return angular_damp; + case PhysicsServer2D::AREA_PARAM_PRIORITY: + return priority; } return Variant(); diff --git a/servers/physics_2d/body_2d_sw.h b/servers/physics_2d/body_2d_sw.h index 0514b263b4..a36dc3bfe2 100644 --- a/servers/physics_2d/body_2d_sw.h +++ b/servers/physics_2d/body_2d_sw.h @@ -158,7 +158,8 @@ public: _FORCE_INLINE_ void set_max_contacts_reported(int p_size) { contacts.resize(p_size); contact_count = 0; - if (mode == PhysicsServer2D::BODY_MODE_KINEMATIC && p_size) set_active(true); + if (mode == PhysicsServer2D::BODY_MODE_KINEMATIC && p_size) + set_active(true); } _FORCE_INLINE_ int get_max_contacts_reported() const { return contacts.size(); } diff --git a/servers/physics_2d/physics_server_2d_sw.cpp b/servers/physics_2d/physics_server_2d_sw.cpp index 871e2aba1d..bc3b468cea 100644 --- a/servers/physics_2d/physics_server_2d_sw.cpp +++ b/servers/physics_2d/physics_server_2d_sw.cpp @@ -1091,9 +1091,15 @@ void PhysicsServer2DSW::joint_set_param(RID p_joint, JointParam p_param, real_t ERR_FAIL_COND(!joint); switch (p_param) { - case JOINT_PARAM_BIAS: joint->set_bias(p_value); break; - case JOINT_PARAM_MAX_BIAS: joint->set_max_bias(p_value); break; - case JOINT_PARAM_MAX_FORCE: joint->set_max_force(p_value); break; + case JOINT_PARAM_BIAS: + joint->set_bias(p_value); + break; + case JOINT_PARAM_MAX_BIAS: + joint->set_max_bias(p_value); + break; + case JOINT_PARAM_MAX_FORCE: + joint->set_max_force(p_value); + break; } } @@ -1103,9 +1109,15 @@ real_t PhysicsServer2DSW::joint_get_param(RID p_joint, JointParam p_param) const ERR_FAIL_COND_V(!joint, -1); switch (p_param) { - case JOINT_PARAM_BIAS: return joint->get_bias(); break; - case JOINT_PARAM_MAX_BIAS: return joint->get_max_bias(); break; - case JOINT_PARAM_MAX_FORCE: return joint->get_max_force(); break; + case JOINT_PARAM_BIAS: + return joint->get_bias(); + break; + case JOINT_PARAM_MAX_BIAS: + return joint->get_max_bias(); + break; + case JOINT_PARAM_MAX_FORCE: + return joint->get_max_force(); + break; } return 0; diff --git a/servers/physics_2d/space_2d_sw.cpp b/servers/physics_2d/space_2d_sw.cpp index 7ae2e9769f..46fe0afda6 100644 --- a/servers/physics_2d/space_2d_sw.cpp +++ b/servers/physics_2d/space_2d_sw.cpp @@ -311,7 +311,7 @@ bool PhysicsDirectSpaceState2DSW::cast_motion(const RID &p_shape, const Transfor bool PhysicsDirectSpaceState2DSW::collide_shape(RID p_shape, const Transform2D &p_shape_xform, const Vector2 &p_motion, real_t p_margin, Vector2 *r_results, int p_result_max, int &r_result_count, const Set<RID> &p_exclude, uint32_t p_collision_mask, bool p_collide_with_bodies, bool p_collide_with_areas) { if (p_result_max <= 0) - return 0; + return false; Shape2DSW *shape = PhysicsServer2DSW::singletonsw->shape_owner.getornull(p_shape); ERR_FAIL_COND_V(!shape, 0); @@ -1281,14 +1281,30 @@ void Space2DSW::set_param(PhysicsServer2D::SpaceParameter p_param, real_t p_valu switch (p_param) { - case PhysicsServer2D::SPACE_PARAM_CONTACT_RECYCLE_RADIUS: contact_recycle_radius = p_value; break; - case PhysicsServer2D::SPACE_PARAM_CONTACT_MAX_SEPARATION: contact_max_separation = p_value; break; - case PhysicsServer2D::SPACE_PARAM_BODY_MAX_ALLOWED_PENETRATION: contact_max_allowed_penetration = p_value; break; - case PhysicsServer2D::SPACE_PARAM_BODY_LINEAR_VELOCITY_SLEEP_THRESHOLD: body_linear_velocity_sleep_threshold = p_value; break; - case PhysicsServer2D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_SLEEP_THRESHOLD: body_angular_velocity_sleep_threshold = p_value; break; - case PhysicsServer2D::SPACE_PARAM_BODY_TIME_TO_SLEEP: body_time_to_sleep = p_value; break; - case PhysicsServer2D::SPACE_PARAM_CONSTRAINT_DEFAULT_BIAS: constraint_bias = p_value; break; - case PhysicsServer2D::SPACE_PARAM_TEST_MOTION_MIN_CONTACT_DEPTH: test_motion_min_contact_depth = p_value; break; + case PhysicsServer2D::SPACE_PARAM_CONTACT_RECYCLE_RADIUS: + contact_recycle_radius = p_value; + break; + case PhysicsServer2D::SPACE_PARAM_CONTACT_MAX_SEPARATION: + contact_max_separation = p_value; + break; + case PhysicsServer2D::SPACE_PARAM_BODY_MAX_ALLOWED_PENETRATION: + contact_max_allowed_penetration = p_value; + break; + case PhysicsServer2D::SPACE_PARAM_BODY_LINEAR_VELOCITY_SLEEP_THRESHOLD: + body_linear_velocity_sleep_threshold = p_value; + break; + case PhysicsServer2D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_SLEEP_THRESHOLD: + body_angular_velocity_sleep_threshold = p_value; + break; + case PhysicsServer2D::SPACE_PARAM_BODY_TIME_TO_SLEEP: + body_time_to_sleep = p_value; + break; + case PhysicsServer2D::SPACE_PARAM_CONSTRAINT_DEFAULT_BIAS: + constraint_bias = p_value; + break; + case PhysicsServer2D::SPACE_PARAM_TEST_MOTION_MIN_CONTACT_DEPTH: + test_motion_min_contact_depth = p_value; + break; } } @@ -1296,14 +1312,22 @@ real_t Space2DSW::get_param(PhysicsServer2D::SpaceParameter p_param) const { switch (p_param) { - case PhysicsServer2D::SPACE_PARAM_CONTACT_RECYCLE_RADIUS: return contact_recycle_radius; - case PhysicsServer2D::SPACE_PARAM_CONTACT_MAX_SEPARATION: return contact_max_separation; - case PhysicsServer2D::SPACE_PARAM_BODY_MAX_ALLOWED_PENETRATION: return contact_max_allowed_penetration; - case PhysicsServer2D::SPACE_PARAM_BODY_LINEAR_VELOCITY_SLEEP_THRESHOLD: return body_linear_velocity_sleep_threshold; - case PhysicsServer2D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_SLEEP_THRESHOLD: return body_angular_velocity_sleep_threshold; - case PhysicsServer2D::SPACE_PARAM_BODY_TIME_TO_SLEEP: return body_time_to_sleep; - case PhysicsServer2D::SPACE_PARAM_CONSTRAINT_DEFAULT_BIAS: return constraint_bias; - case PhysicsServer2D::SPACE_PARAM_TEST_MOTION_MIN_CONTACT_DEPTH: return test_motion_min_contact_depth; + case PhysicsServer2D::SPACE_PARAM_CONTACT_RECYCLE_RADIUS: + return contact_recycle_radius; + case PhysicsServer2D::SPACE_PARAM_CONTACT_MAX_SEPARATION: + return contact_max_separation; + case PhysicsServer2D::SPACE_PARAM_BODY_MAX_ALLOWED_PENETRATION: + return contact_max_allowed_penetration; + case PhysicsServer2D::SPACE_PARAM_BODY_LINEAR_VELOCITY_SLEEP_THRESHOLD: + return body_linear_velocity_sleep_threshold; + case PhysicsServer2D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_SLEEP_THRESHOLD: + return body_angular_velocity_sleep_threshold; + case PhysicsServer2D::SPACE_PARAM_BODY_TIME_TO_SLEEP: + return body_time_to_sleep; + case PhysicsServer2D::SPACE_PARAM_CONSTRAINT_DEFAULT_BIAS: + return constraint_bias; + case PhysicsServer2D::SPACE_PARAM_TEST_MOTION_MIN_CONTACT_DEPTH: + return test_motion_min_contact_depth; } return 0; } diff --git a/servers/physics_2d/space_2d_sw.h b/servers/physics_2d/space_2d_sw.h index c6b324c928..e87d842ef1 100644 --- a/servers/physics_2d/space_2d_sw.h +++ b/servers/physics_2d/space_2d_sw.h @@ -192,7 +192,8 @@ public: void set_debug_contacts(int p_amount) { contact_debug.resize(p_amount); } _FORCE_INLINE_ bool is_debugging_contacts() const { return !contact_debug.empty(); } _FORCE_INLINE_ void add_debug_contact(const Vector2 &p_contact) { - if (contact_debug_count < contact_debug.size()) contact_debug.write[contact_debug_count++] = p_contact; + if (contact_debug_count < contact_debug.size()) + contact_debug.write[contact_debug_count++] = p_contact; } _FORCE_INLINE_ Vector<Vector2> get_debug_contacts() { return contact_debug; } _FORCE_INLINE_ int get_debug_contact_count() { return contact_debug_count; } diff --git a/servers/physics_3d/area_3d_sw.cpp b/servers/physics_3d/area_3d_sw.cpp index 911a664a10..2367b653a5 100644 --- a/servers/physics_3d/area_3d_sw.cpp +++ b/servers/physics_3d/area_3d_sw.cpp @@ -129,28 +129,52 @@ void Area3DSW::set_space_override_mode(PhysicsServer3D::AreaSpaceOverrideMode p_ void Area3DSW::set_param(PhysicsServer3D::AreaParameter p_param, const Variant &p_value) { switch (p_param) { - case PhysicsServer3D::AREA_PARAM_GRAVITY: gravity = p_value; break; - case PhysicsServer3D::AREA_PARAM_GRAVITY_VECTOR: gravity_vector = p_value; break; - case PhysicsServer3D::AREA_PARAM_GRAVITY_IS_POINT: gravity_is_point = p_value; break; - case PhysicsServer3D::AREA_PARAM_GRAVITY_DISTANCE_SCALE: gravity_distance_scale = p_value; break; - case PhysicsServer3D::AREA_PARAM_GRAVITY_POINT_ATTENUATION: point_attenuation = p_value; break; - case PhysicsServer3D::AREA_PARAM_LINEAR_DAMP: linear_damp = p_value; break; - case PhysicsServer3D::AREA_PARAM_ANGULAR_DAMP: angular_damp = p_value; break; - case PhysicsServer3D::AREA_PARAM_PRIORITY: priority = p_value; break; + case PhysicsServer3D::AREA_PARAM_GRAVITY: + gravity = p_value; + break; + case PhysicsServer3D::AREA_PARAM_GRAVITY_VECTOR: + gravity_vector = p_value; + break; + case PhysicsServer3D::AREA_PARAM_GRAVITY_IS_POINT: + gravity_is_point = p_value; + break; + case PhysicsServer3D::AREA_PARAM_GRAVITY_DISTANCE_SCALE: + gravity_distance_scale = p_value; + break; + case PhysicsServer3D::AREA_PARAM_GRAVITY_POINT_ATTENUATION: + point_attenuation = p_value; + break; + case PhysicsServer3D::AREA_PARAM_LINEAR_DAMP: + linear_damp = p_value; + break; + case PhysicsServer3D::AREA_PARAM_ANGULAR_DAMP: + angular_damp = p_value; + break; + case PhysicsServer3D::AREA_PARAM_PRIORITY: + priority = p_value; + break; } } Variant Area3DSW::get_param(PhysicsServer3D::AreaParameter p_param) const { switch (p_param) { - case PhysicsServer3D::AREA_PARAM_GRAVITY: return gravity; - case PhysicsServer3D::AREA_PARAM_GRAVITY_VECTOR: return gravity_vector; - case PhysicsServer3D::AREA_PARAM_GRAVITY_IS_POINT: return gravity_is_point; - case PhysicsServer3D::AREA_PARAM_GRAVITY_DISTANCE_SCALE: return gravity_distance_scale; - case PhysicsServer3D::AREA_PARAM_GRAVITY_POINT_ATTENUATION: return point_attenuation; - case PhysicsServer3D::AREA_PARAM_LINEAR_DAMP: return linear_damp; - case PhysicsServer3D::AREA_PARAM_ANGULAR_DAMP: return angular_damp; - case PhysicsServer3D::AREA_PARAM_PRIORITY: return priority; + case PhysicsServer3D::AREA_PARAM_GRAVITY: + return gravity; + case PhysicsServer3D::AREA_PARAM_GRAVITY_VECTOR: + return gravity_vector; + case PhysicsServer3D::AREA_PARAM_GRAVITY_IS_POINT: + return gravity_is_point; + case PhysicsServer3D::AREA_PARAM_GRAVITY_DISTANCE_SCALE: + return gravity_distance_scale; + case PhysicsServer3D::AREA_PARAM_GRAVITY_POINT_ATTENUATION: + return point_attenuation; + case PhysicsServer3D::AREA_PARAM_LINEAR_DAMP: + return linear_damp; + case PhysicsServer3D::AREA_PARAM_ANGULAR_DAMP: + return angular_damp; + case PhysicsServer3D::AREA_PARAM_PRIORITY: + return priority; } return Variant(); diff --git a/servers/physics_3d/body_3d_sw.cpp b/servers/physics_3d/body_3d_sw.cpp index fea5aed6ad..2f2525bb75 100644 --- a/servers/physics_3d/body_3d_sw.cpp +++ b/servers/physics_3d/body_3d_sw.cpp @@ -764,7 +764,7 @@ void Body3DSW::set_kinematic_margin(real_t p_margin) { Body3DSW::Body3DSW() : CollisionObject3DSW(TYPE_BODY), - locked_axis(0), + active_list(this), inertia_update_list(this), direct_state_query_list(this) { diff --git a/servers/physics_3d/body_3d_sw.h b/servers/physics_3d/body_3d_sw.h index 2bd335e6c0..0308d8689e 100644 --- a/servers/physics_3d/body_3d_sw.h +++ b/servers/physics_3d/body_3d_sw.h @@ -54,7 +54,7 @@ class Body3DSW : public CollisionObject3DSW { real_t angular_damp; real_t gravity_scale; - uint16_t locked_axis; + uint16_t locked_axis = 0; real_t kinematic_safe_margin; real_t _inv_mass; @@ -175,7 +175,8 @@ public: _FORCE_INLINE_ void set_max_contacts_reported(int p_size) { contacts.resize(p_size); contact_count = 0; - if (mode == PhysicsServer3D::BODY_MODE_KINEMATIC && p_size) set_active(true); + if (mode == PhysicsServer3D::BODY_MODE_KINEMATIC && p_size) + set_active(true); } _FORCE_INLINE_ int get_max_contacts_reported() const { return contacts.size(); } diff --git a/servers/physics_3d/collision_object_3d_sw.h b/servers/physics_3d/collision_object_3d_sw.h index c5773d0c61..9dd798af26 100644 --- a/servers/physics_3d/collision_object_3d_sw.h +++ b/servers/physics_3d/collision_object_3d_sw.h @@ -92,7 +92,8 @@ protected: #endif transform = p_transform; - if (p_update_shapes) _update_shapes(); + if (p_update_shapes) + _update_shapes(); } _FORCE_INLINE_ void _set_inv_transform(const Transform &p_transform) { inv_transform = p_transform; } void _set_static(bool p_static); diff --git a/servers/physics_3d/gjk_epa.cpp b/servers/physics_3d/gjk_epa.cpp index db37f261ce..aaa7de7531 100644 --- a/servers/physics_3d/gjk_epa.cpp +++ b/servers/physics_3d/gjk_epa.cpp @@ -513,16 +513,16 @@ struct GJK }; struct sList { - sFace* root; - U count; - sList() : root(nullptr),count(0) {} + sFace* root = nullptr; + U count = 0; + sList() {} }; struct sHorizon { - sFace* cf; - sFace* ff; - U nf; - sHorizon() : cf(nullptr),ff(nullptr),nf(0) {} + sFace* cf = nullptr; + sFace* ff = nullptr; + U nf = 0; + sHorizon() {} }; struct eStatus { enum _ { Valid, diff --git a/servers/physics_3d/joints/cone_twist_joint_3d_sw.cpp b/servers/physics_3d/joints/cone_twist_joint_3d_sw.cpp index a072c1f3a0..56dd8044e4 100644 --- a/servers/physics_3d/joints/cone_twist_joint_3d_sw.cpp +++ b/servers/physics_3d/joints/cone_twist_joint_3d_sw.cpp @@ -332,7 +332,8 @@ void ConeTwistJoint3DSW::set_param(PhysicsServer3D::ConeTwistJointParam p_param, m_relaxationFactor = p_value; } break; - case PhysicsServer3D::CONE_TWIST_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::CONE_TWIST_MAX: + break; // Can't happen, but silences warning } } @@ -359,7 +360,8 @@ real_t ConeTwistJoint3DSW::get_param(PhysicsServer3D::ConeTwistJointParam p_para return m_relaxationFactor; } break; - case PhysicsServer3D::CONE_TWIST_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::CONE_TWIST_MAX: + break; // Can't happen, but silences warning } return 0; diff --git a/servers/physics_3d/joints/generic_6dof_joint_3d_sw.cpp b/servers/physics_3d/joints/generic_6dof_joint_3d_sw.cpp index e15aeca842..9f387ea5c5 100644 --- a/servers/physics_3d/joints/generic_6dof_joint_3d_sw.cpp +++ b/servers/physics_3d/joints/generic_6dof_joint_3d_sw.cpp @@ -83,7 +83,8 @@ int G6DOFRotationalLimitMotor3DSW::testLimitValue(real_t test_value) { real_t G6DOFRotationalLimitMotor3DSW::solveAngularLimits( real_t timeStep, Vector3 &axis, real_t jacDiagABInv, Body3DSW *body0, Body3DSW *body1) { - if (!needApplyTorques()) return 0.0f; + if (!needApplyTorques()) + return 0.0f; real_t target_velocity = m_targetVelocity; real_t maxMotorForce = m_maxMotorForce; @@ -137,7 +138,8 @@ real_t G6DOFRotationalLimitMotor3DSW::solveAngularLimits( Vector3 motorImp = clippedMotorImpulse * axis; body0->apply_torque_impulse(motorImp); - if (body1) body1->apply_torque_impulse(-motorImp); + if (body1) + body1->apply_torque_impulse(-motorImp); return clippedMotorImpulse; } @@ -409,7 +411,7 @@ real_t Generic6DOFJoint3DSW::getAngle(int axis_index) const { return m_calculatedAxisAngleDiff[axis_index]; } -void Generic6DOFJoint3DSW::calcAnchorPos(void) { +void Generic6DOFJoint3DSW::calcAnchorPos() { real_t imA = A->get_inv_mass(); real_t imB = B->get_inv_mass(); real_t weight; @@ -520,7 +522,8 @@ void Generic6DOFJoint3DSW::set_param(Vector3::Axis p_axis, PhysicsServer3D::G6DO case PhysicsServer3D::G6DOF_JOINT_ANGULAR_SPRING_EQUILIBRIUM_POINT: { // Not implemented in GodotPhysics3D backend } break; - case PhysicsServer3D::G6DOF_JOINT_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::G6DOF_JOINT_MAX: + break; // Can't happen, but silences warning } } @@ -620,7 +623,8 @@ real_t Generic6DOFJoint3DSW::get_param(Vector3::Axis p_axis, PhysicsServer3D::G6 case PhysicsServer3D::G6DOF_JOINT_ANGULAR_SPRING_EQUILIBRIUM_POINT: { // Not implemented in GodotPhysics3D backend } break; - case PhysicsServer3D::G6DOF_JOINT_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::G6DOF_JOINT_MAX: + break; // Can't happen, but silences warning } return 0; } @@ -651,7 +655,8 @@ void Generic6DOFJoint3DSW::set_flag(Vector3::Axis p_axis, PhysicsServer3D::G6DOF case PhysicsServer3D::G6DOF_JOINT_FLAG_ENABLE_ANGULAR_SPRING: { // Not implemented in GodotPhysics3D backend } break; - case PhysicsServer3D::G6DOF_JOINT_FLAG_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::G6DOF_JOINT_FLAG_MAX: + break; // Can't happen, but silences warning } } bool Generic6DOFJoint3DSW::get_flag(Vector3::Axis p_axis, PhysicsServer3D::G6DOFJointAxisFlag p_flag) const { @@ -679,8 +684,9 @@ bool Generic6DOFJoint3DSW::get_flag(Vector3::Axis p_axis, PhysicsServer3D::G6DOF case PhysicsServer3D::G6DOF_JOINT_FLAG_ENABLE_ANGULAR_SPRING: { // Not implemented in GodotPhysics3D backend } break; - case PhysicsServer3D::G6DOF_JOINT_FLAG_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::G6DOF_JOINT_FLAG_MAX: + break; // Can't happen, but silences warning } - return 0; + return false; } diff --git a/servers/physics_3d/joints/generic_6dof_joint_3d_sw.h b/servers/physics_3d/joints/generic_6dof_joint_3d_sw.h index f7aa607901..cc1423a1cb 100644 --- a/servers/physics_3d/joints/generic_6dof_joint_3d_sw.h +++ b/servers/physics_3d/joints/generic_6dof_joint_3d_sw.h @@ -389,7 +389,7 @@ public: return B; } - virtual void calcAnchorPos(void); // overridable + virtual void calcAnchorPos(); // overridable void set_param(Vector3::Axis p_axis, PhysicsServer3D::G6DOFJointAxisParam p_param, real_t p_value); real_t get_param(Vector3::Axis p_axis, PhysicsServer3D::G6DOFJointAxisParam p_param) const; diff --git a/servers/physics_3d/joints/hinge_joint_3d_sw.cpp b/servers/physics_3d/joints/hinge_joint_3d_sw.cpp index e76d366422..dea1be5040 100644 --- a/servers/physics_3d/joints/hinge_joint_3d_sw.cpp +++ b/servers/physics_3d/joints/hinge_joint_3d_sw.cpp @@ -400,15 +400,32 @@ void HingeJoint3DSW::set_param(PhysicsServer3D::HingeJointParam p_param, real_t switch (p_param) { - case PhysicsServer3D::HINGE_JOINT_BIAS: tau = p_value; break; - case PhysicsServer3D::HINGE_JOINT_LIMIT_UPPER: m_upperLimit = p_value; break; - case PhysicsServer3D::HINGE_JOINT_LIMIT_LOWER: m_lowerLimit = p_value; break; - case PhysicsServer3D::HINGE_JOINT_LIMIT_BIAS: m_biasFactor = p_value; break; - case PhysicsServer3D::HINGE_JOINT_LIMIT_SOFTNESS: m_limitSoftness = p_value; break; - case PhysicsServer3D::HINGE_JOINT_LIMIT_RELAXATION: m_relaxationFactor = p_value; break; - case PhysicsServer3D::HINGE_JOINT_MOTOR_TARGET_VELOCITY: m_motorTargetVelocity = p_value; break; - case PhysicsServer3D::HINGE_JOINT_MOTOR_MAX_IMPULSE: m_maxMotorImpulse = p_value; break; - case PhysicsServer3D::HINGE_JOINT_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::HINGE_JOINT_BIAS: + tau = p_value; + break; + case PhysicsServer3D::HINGE_JOINT_LIMIT_UPPER: + m_upperLimit = p_value; + break; + case PhysicsServer3D::HINGE_JOINT_LIMIT_LOWER: + m_lowerLimit = p_value; + break; + case PhysicsServer3D::HINGE_JOINT_LIMIT_BIAS: + m_biasFactor = p_value; + break; + case PhysicsServer3D::HINGE_JOINT_LIMIT_SOFTNESS: + m_limitSoftness = p_value; + break; + case PhysicsServer3D::HINGE_JOINT_LIMIT_RELAXATION: + m_relaxationFactor = p_value; + break; + case PhysicsServer3D::HINGE_JOINT_MOTOR_TARGET_VELOCITY: + m_motorTargetVelocity = p_value; + break; + case PhysicsServer3D::HINGE_JOINT_MOTOR_MAX_IMPULSE: + m_maxMotorImpulse = p_value; + break; + case PhysicsServer3D::HINGE_JOINT_MAX: + break; // Can't happen, but silences warning } } @@ -416,15 +433,24 @@ real_t HingeJoint3DSW::get_param(PhysicsServer3D::HingeJointParam p_param) const switch (p_param) { - case PhysicsServer3D::HINGE_JOINT_BIAS: return tau; - case PhysicsServer3D::HINGE_JOINT_LIMIT_UPPER: return m_upperLimit; - case PhysicsServer3D::HINGE_JOINT_LIMIT_LOWER: return m_lowerLimit; - case PhysicsServer3D::HINGE_JOINT_LIMIT_BIAS: return m_biasFactor; - case PhysicsServer3D::HINGE_JOINT_LIMIT_SOFTNESS: return m_limitSoftness; - case PhysicsServer3D::HINGE_JOINT_LIMIT_RELAXATION: return m_relaxationFactor; - case PhysicsServer3D::HINGE_JOINT_MOTOR_TARGET_VELOCITY: return m_motorTargetVelocity; - case PhysicsServer3D::HINGE_JOINT_MOTOR_MAX_IMPULSE: return m_maxMotorImpulse; - case PhysicsServer3D::HINGE_JOINT_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::HINGE_JOINT_BIAS: + return tau; + case PhysicsServer3D::HINGE_JOINT_LIMIT_UPPER: + return m_upperLimit; + case PhysicsServer3D::HINGE_JOINT_LIMIT_LOWER: + return m_lowerLimit; + case PhysicsServer3D::HINGE_JOINT_LIMIT_BIAS: + return m_biasFactor; + case PhysicsServer3D::HINGE_JOINT_LIMIT_SOFTNESS: + return m_limitSoftness; + case PhysicsServer3D::HINGE_JOINT_LIMIT_RELAXATION: + return m_relaxationFactor; + case PhysicsServer3D::HINGE_JOINT_MOTOR_TARGET_VELOCITY: + return m_motorTargetVelocity; + case PhysicsServer3D::HINGE_JOINT_MOTOR_MAX_IMPULSE: + return m_maxMotorImpulse; + case PhysicsServer3D::HINGE_JOINT_MAX: + break; // Can't happen, but silences warning } return 0; @@ -433,17 +459,25 @@ real_t HingeJoint3DSW::get_param(PhysicsServer3D::HingeJointParam p_param) const void HingeJoint3DSW::set_flag(PhysicsServer3D::HingeJointFlag p_flag, bool p_value) { switch (p_flag) { - case PhysicsServer3D::HINGE_JOINT_FLAG_USE_LIMIT: m_useLimit = p_value; break; - case PhysicsServer3D::HINGE_JOINT_FLAG_ENABLE_MOTOR: m_enableAngularMotor = p_value; break; - case PhysicsServer3D::HINGE_JOINT_FLAG_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::HINGE_JOINT_FLAG_USE_LIMIT: + m_useLimit = p_value; + break; + case PhysicsServer3D::HINGE_JOINT_FLAG_ENABLE_MOTOR: + m_enableAngularMotor = p_value; + break; + case PhysicsServer3D::HINGE_JOINT_FLAG_MAX: + break; // Can't happen, but silences warning } } bool HingeJoint3DSW::get_flag(PhysicsServer3D::HingeJointFlag p_flag) const { switch (p_flag) { - case PhysicsServer3D::HINGE_JOINT_FLAG_USE_LIMIT: return m_useLimit; - case PhysicsServer3D::HINGE_JOINT_FLAG_ENABLE_MOTOR: return m_enableAngularMotor; - case PhysicsServer3D::HINGE_JOINT_FLAG_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::HINGE_JOINT_FLAG_USE_LIMIT: + return m_useLimit; + case PhysicsServer3D::HINGE_JOINT_FLAG_ENABLE_MOTOR: + return m_enableAngularMotor; + case PhysicsServer3D::HINGE_JOINT_FLAG_MAX: + break; // Can't happen, but silences warning } return false; diff --git a/servers/physics_3d/joints/jacobian_entry_3d_sw.h b/servers/physics_3d/joints/jacobian_entry_3d_sw.h index 7e605ab173..1737c21b3d 100644 --- a/servers/physics_3d/joints/jacobian_entry_3d_sw.h +++ b/servers/physics_3d/joints/jacobian_entry_3d_sw.h @@ -54,7 +54,7 @@ subject to the following restrictions: class JacobianEntry3DSW { public: - JacobianEntry3DSW(){}; + JacobianEntry3DSW() {} //constraint between two different rigidbodies JacobianEntry3DSW( const Basis &world2A, diff --git a/servers/physics_3d/joints/pin_joint_3d_sw.cpp b/servers/physics_3d/joints/pin_joint_3d_sw.cpp index 95c01bc463..dd6b315152 100644 --- a/servers/physics_3d/joints/pin_joint_3d_sw.cpp +++ b/servers/physics_3d/joints/pin_joint_3d_sw.cpp @@ -129,18 +129,27 @@ void PinJoint3DSW::solve(real_t p_step) { void PinJoint3DSW::set_param(PhysicsServer3D::PinJointParam p_param, real_t p_value) { switch (p_param) { - case PhysicsServer3D::PIN_JOINT_BIAS: m_tau = p_value; break; - case PhysicsServer3D::PIN_JOINT_DAMPING: m_damping = p_value; break; - case PhysicsServer3D::PIN_JOINT_IMPULSE_CLAMP: m_impulseClamp = p_value; break; + case PhysicsServer3D::PIN_JOINT_BIAS: + m_tau = p_value; + break; + case PhysicsServer3D::PIN_JOINT_DAMPING: + m_damping = p_value; + break; + case PhysicsServer3D::PIN_JOINT_IMPULSE_CLAMP: + m_impulseClamp = p_value; + break; } } real_t PinJoint3DSW::get_param(PhysicsServer3D::PinJointParam p_param) const { switch (p_param) { - case PhysicsServer3D::PIN_JOINT_BIAS: return m_tau; - case PhysicsServer3D::PIN_JOINT_DAMPING: return m_damping; - case PhysicsServer3D::PIN_JOINT_IMPULSE_CLAMP: return m_impulseClamp; + case PhysicsServer3D::PIN_JOINT_BIAS: + return m_tau; + case PhysicsServer3D::PIN_JOINT_DAMPING: + return m_damping; + case PhysicsServer3D::PIN_JOINT_IMPULSE_CLAMP: + return m_impulseClamp; } return 0; diff --git a/servers/physics_3d/joints/slider_joint_3d_sw.cpp b/servers/physics_3d/joints/slider_joint_3d_sw.cpp index 066c30e0f3..57ad64ca21 100644 --- a/servers/physics_3d/joints/slider_joint_3d_sw.cpp +++ b/servers/physics_3d/joints/slider_joint_3d_sw.cpp @@ -304,7 +304,7 @@ void SliderJoint3DSW::solve(real_t p_step) { //----------------------------------------------------------------------------- -void SliderJoint3DSW::calculateTransforms(void) { +void SliderJoint3DSW::calculateTransforms() { m_calculatedTransformA = A->get_transform() * m_frameInA; m_calculatedTransformB = B->get_transform() * m_frameInB; m_realPivotAInW = m_calculatedTransformA.origin; @@ -323,7 +323,7 @@ void SliderJoint3DSW::calculateTransforms(void) { //----------------------------------------------------------------------------- -void SliderJoint3DSW::testLinLimits(void) { +void SliderJoint3DSW::testLinLimits() { m_solveLinLim = false; m_linPos = m_depth[0]; if (m_lowerLinLimit <= m_upperLinLimit) { @@ -343,7 +343,7 @@ void SliderJoint3DSW::testLinLimits(void) { //----------------------------------------------------------------------------- -void SliderJoint3DSW::testAngLimits(void) { +void SliderJoint3DSW::testAngLimits() { m_angDepth = real_t(0.); m_solveAngLim = false; if (m_lowerAngLimit <= m_upperAngLimit) { @@ -363,7 +363,7 @@ void SliderJoint3DSW::testAngLimits(void) { //----------------------------------------------------------------------------- -Vector3 SliderJoint3DSW::getAncorInA(void) { +Vector3 SliderJoint3DSW::getAncorInA() { Vector3 ancorInA; ancorInA = m_realPivotAInW + (m_lowerLinLimit + m_upperLinLimit) * real_t(0.5) * m_sliderAxis; ancorInA = A->get_transform().inverse().xform(ancorInA); @@ -372,7 +372,7 @@ Vector3 SliderJoint3DSW::getAncorInA(void) { //----------------------------------------------------------------------------- -Vector3 SliderJoint3DSW::getAncorInB(void) { +Vector3 SliderJoint3DSW::getAncorInB() { Vector3 ancorInB; ancorInB = m_frameInB.origin; return ancorInB; @@ -381,62 +381,130 @@ Vector3 SliderJoint3DSW::getAncorInB(void) { void SliderJoint3DSW::set_param(PhysicsServer3D::SliderJointParam p_param, real_t p_value) { switch (p_param) { - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_UPPER: m_upperLinLimit = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_LOWER: m_lowerLinLimit = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_SOFTNESS: m_softnessLimLin = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_RESTITUTION: m_restitutionLimLin = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_DAMPING: m_dampingLimLin = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_SOFTNESS: m_softnessDirLin = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_RESTITUTION: m_restitutionDirLin = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_DAMPING: m_dampingDirLin = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_SOFTNESS: m_softnessOrthoLin = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_RESTITUTION: m_restitutionOrthoLin = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_DAMPING: m_dampingOrthoLin = p_value; break; - - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_UPPER: m_upperAngLimit = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_LOWER: m_lowerAngLimit = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_SOFTNESS: m_softnessLimAng = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_RESTITUTION: m_restitutionLimAng = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_DAMPING: m_dampingLimAng = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_SOFTNESS: m_softnessDirAng = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_RESTITUTION: m_restitutionDirAng = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_DAMPING: m_dampingDirAng = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_SOFTNESS: m_softnessOrthoAng = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_RESTITUTION: m_restitutionOrthoAng = p_value; break; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_DAMPING: m_dampingOrthoAng = p_value; break; - - case PhysicsServer3D::SLIDER_JOINT_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_UPPER: + m_upperLinLimit = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_LOWER: + m_lowerLinLimit = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_SOFTNESS: + m_softnessLimLin = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_RESTITUTION: + m_restitutionLimLin = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_DAMPING: + m_dampingLimLin = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_SOFTNESS: + m_softnessDirLin = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_RESTITUTION: + m_restitutionDirLin = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_DAMPING: + m_dampingDirLin = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_SOFTNESS: + m_softnessOrthoLin = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_RESTITUTION: + m_restitutionOrthoLin = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_DAMPING: + m_dampingOrthoLin = p_value; + break; + + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_UPPER: + m_upperAngLimit = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_LOWER: + m_lowerAngLimit = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_SOFTNESS: + m_softnessLimAng = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_RESTITUTION: + m_restitutionLimAng = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_DAMPING: + m_dampingLimAng = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_SOFTNESS: + m_softnessDirAng = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_RESTITUTION: + m_restitutionDirAng = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_DAMPING: + m_dampingDirAng = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_SOFTNESS: + m_softnessOrthoAng = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_RESTITUTION: + m_restitutionOrthoAng = p_value; + break; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_DAMPING: + m_dampingOrthoAng = p_value; + break; + + case PhysicsServer3D::SLIDER_JOINT_MAX: + break; // Can't happen, but silences warning } } real_t SliderJoint3DSW::get_param(PhysicsServer3D::SliderJointParam p_param) const { switch (p_param) { - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_UPPER: return m_upperLinLimit; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_LOWER: return m_lowerLinLimit; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_SOFTNESS: return m_softnessLimLin; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_RESTITUTION: return m_restitutionLimLin; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_DAMPING: return m_dampingLimLin; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_SOFTNESS: return m_softnessDirLin; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_RESTITUTION: return m_restitutionDirLin; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_DAMPING: return m_dampingDirLin; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_SOFTNESS: return m_softnessOrthoLin; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_RESTITUTION: return m_restitutionOrthoLin; - case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_DAMPING: return m_dampingOrthoLin; - - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_UPPER: return m_upperAngLimit; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_LOWER: return m_lowerAngLimit; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_SOFTNESS: return m_softnessLimAng; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_RESTITUTION: return m_restitutionLimAng; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_DAMPING: return m_dampingLimAng; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_SOFTNESS: return m_softnessDirAng; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_RESTITUTION: return m_restitutionDirAng; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_DAMPING: return m_dampingDirAng; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_SOFTNESS: return m_softnessOrthoAng; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_RESTITUTION: return m_restitutionOrthoAng; - case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_DAMPING: return m_dampingOrthoAng; - - case PhysicsServer3D::SLIDER_JOINT_MAX: break; // Can't happen, but silences warning + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_UPPER: + return m_upperLinLimit; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_LOWER: + return m_lowerLinLimit; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_SOFTNESS: + return m_softnessLimLin; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_RESTITUTION: + return m_restitutionLimLin; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_LIMIT_DAMPING: + return m_dampingLimLin; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_SOFTNESS: + return m_softnessDirLin; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_RESTITUTION: + return m_restitutionDirLin; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_MOTION_DAMPING: + return m_dampingDirLin; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_SOFTNESS: + return m_softnessOrthoLin; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_RESTITUTION: + return m_restitutionOrthoLin; + case PhysicsServer3D::SLIDER_JOINT_LINEAR_ORTHOGONAL_DAMPING: + return m_dampingOrthoLin; + + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_UPPER: + return m_upperAngLimit; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_LOWER: + return m_lowerAngLimit; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_SOFTNESS: + return m_softnessLimAng; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_RESTITUTION: + return m_restitutionLimAng; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_LIMIT_DAMPING: + return m_dampingLimAng; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_SOFTNESS: + return m_softnessDirAng; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_RESTITUTION: + return m_restitutionDirAng; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_MOTION_DAMPING: + return m_dampingDirAng; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_SOFTNESS: + return m_softnessOrthoAng; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_RESTITUTION: + return m_restitutionOrthoAng; + case PhysicsServer3D::SLIDER_JOINT_ANGULAR_ORTHOGONAL_DAMPING: + return m_dampingOrthoAng; + + case PhysicsServer3D::SLIDER_JOINT_MAX: + break; // Can't happen, but silences warning } return 0; diff --git a/servers/physics_3d/joints/slider_joint_3d_sw.h b/servers/physics_3d/joints/slider_joint_3d_sw.h index 18287db9c2..37394a1580 100644 --- a/servers/physics_3d/joints/slider_joint_3d_sw.h +++ b/servers/physics_3d/joints/slider_joint_3d_sw.h @@ -230,12 +230,12 @@ public: bool getSolveAngLimit() { return m_solveAngLim; } real_t getAngDepth() { return m_angDepth; } // shared code used by ODE solver - void calculateTransforms(void); - void testLinLimits(void); - void testAngLimits(void); + void calculateTransforms(); + void testLinLimits(); + void testAngLimits(); // access for PE Solver - Vector3 getAncorInA(void); - Vector3 getAncorInB(void); + Vector3 getAncorInA(); + Vector3 getAncorInB(); void set_param(PhysicsServer3D::SliderJointParam p_param, real_t p_value); real_t get_param(PhysicsServer3D::SliderJointParam p_param) const; diff --git a/servers/physics_3d/physics_server_3d_sw.h b/servers/physics_3d/physics_server_3d_sw.h index 6e79d9eceb..481cb667c3 100644 --- a/servers/physics_3d/physics_server_3d_sw.h +++ b/servers/physics_3d/physics_server_3d_sw.h @@ -307,7 +307,7 @@ public: virtual void soft_body_remove_all_pinned_points(RID p_body) {} virtual void soft_body_pin_point(RID p_body, int p_point_index, bool p_pin) {} - virtual bool soft_body_is_point_pinned(RID p_body, int p_point_index) { return 0; } + virtual bool soft_body_is_point_pinned(RID p_body, int p_point_index) { return false; } /* JOINT API */ diff --git a/servers/physics_3d/space_3d_sw.cpp b/servers/physics_3d/space_3d_sw.cpp index bd8e89a8f5..6b31e0e094 100644 --- a/servers/physics_3d/space_3d_sw.cpp +++ b/servers/physics_3d/space_3d_sw.cpp @@ -331,7 +331,7 @@ bool PhysicsDirectSpaceState3DSW::cast_motion(const RID &p_shape, const Transfor bool PhysicsDirectSpaceState3DSW::collide_shape(RID p_shape, const Transform &p_shape_xform, real_t p_margin, Vector3 *r_results, int p_result_max, int &r_result_count, const Set<RID> &p_exclude, uint32_t p_collision_mask, bool p_collide_with_bodies, bool p_collide_with_areas) { if (p_result_max <= 0) - return 0; + return false; Shape3DSW *shape = static_cast<PhysicsServer3DSW *>(PhysicsServer3D::get_singleton())->shape_owner.getornull(p_shape); ERR_FAIL_COND_V(!shape, 0); @@ -1154,15 +1154,33 @@ void Space3DSW::set_param(PhysicsServer3D::SpaceParameter p_param, real_t p_valu switch (p_param) { - case PhysicsServer3D::SPACE_PARAM_CONTACT_RECYCLE_RADIUS: contact_recycle_radius = p_value; break; - case PhysicsServer3D::SPACE_PARAM_CONTACT_MAX_SEPARATION: contact_max_separation = p_value; break; - case PhysicsServer3D::SPACE_PARAM_BODY_MAX_ALLOWED_PENETRATION: contact_max_allowed_penetration = p_value; break; - case PhysicsServer3D::SPACE_PARAM_BODY_LINEAR_VELOCITY_SLEEP_THRESHOLD: body_linear_velocity_sleep_threshold = p_value; break; - case PhysicsServer3D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_SLEEP_THRESHOLD: body_angular_velocity_sleep_threshold = p_value; break; - case PhysicsServer3D::SPACE_PARAM_BODY_TIME_TO_SLEEP: body_time_to_sleep = p_value; break; - case PhysicsServer3D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_DAMP_RATIO: body_angular_velocity_damp_ratio = p_value; break; - case PhysicsServer3D::SPACE_PARAM_CONSTRAINT_DEFAULT_BIAS: constraint_bias = p_value; break; - case PhysicsServer3D::SPACE_PARAM_TEST_MOTION_MIN_CONTACT_DEPTH: test_motion_min_contact_depth = p_value; break; + case PhysicsServer3D::SPACE_PARAM_CONTACT_RECYCLE_RADIUS: + contact_recycle_radius = p_value; + break; + case PhysicsServer3D::SPACE_PARAM_CONTACT_MAX_SEPARATION: + contact_max_separation = p_value; + break; + case PhysicsServer3D::SPACE_PARAM_BODY_MAX_ALLOWED_PENETRATION: + contact_max_allowed_penetration = p_value; + break; + case PhysicsServer3D::SPACE_PARAM_BODY_LINEAR_VELOCITY_SLEEP_THRESHOLD: + body_linear_velocity_sleep_threshold = p_value; + break; + case PhysicsServer3D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_SLEEP_THRESHOLD: + body_angular_velocity_sleep_threshold = p_value; + break; + case PhysicsServer3D::SPACE_PARAM_BODY_TIME_TO_SLEEP: + body_time_to_sleep = p_value; + break; + case PhysicsServer3D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_DAMP_RATIO: + body_angular_velocity_damp_ratio = p_value; + break; + case PhysicsServer3D::SPACE_PARAM_CONSTRAINT_DEFAULT_BIAS: + constraint_bias = p_value; + break; + case PhysicsServer3D::SPACE_PARAM_TEST_MOTION_MIN_CONTACT_DEPTH: + test_motion_min_contact_depth = p_value; + break; } } @@ -1170,15 +1188,24 @@ real_t Space3DSW::get_param(PhysicsServer3D::SpaceParameter p_param) const { switch (p_param) { - case PhysicsServer3D::SPACE_PARAM_CONTACT_RECYCLE_RADIUS: return contact_recycle_radius; - case PhysicsServer3D::SPACE_PARAM_CONTACT_MAX_SEPARATION: return contact_max_separation; - case PhysicsServer3D::SPACE_PARAM_BODY_MAX_ALLOWED_PENETRATION: return contact_max_allowed_penetration; - case PhysicsServer3D::SPACE_PARAM_BODY_LINEAR_VELOCITY_SLEEP_THRESHOLD: return body_linear_velocity_sleep_threshold; - case PhysicsServer3D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_SLEEP_THRESHOLD: return body_angular_velocity_sleep_threshold; - case PhysicsServer3D::SPACE_PARAM_BODY_TIME_TO_SLEEP: return body_time_to_sleep; - case PhysicsServer3D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_DAMP_RATIO: return body_angular_velocity_damp_ratio; - case PhysicsServer3D::SPACE_PARAM_CONSTRAINT_DEFAULT_BIAS: return constraint_bias; - case PhysicsServer3D::SPACE_PARAM_TEST_MOTION_MIN_CONTACT_DEPTH: return test_motion_min_contact_depth; + case PhysicsServer3D::SPACE_PARAM_CONTACT_RECYCLE_RADIUS: + return contact_recycle_radius; + case PhysicsServer3D::SPACE_PARAM_CONTACT_MAX_SEPARATION: + return contact_max_separation; + case PhysicsServer3D::SPACE_PARAM_BODY_MAX_ALLOWED_PENETRATION: + return contact_max_allowed_penetration; + case PhysicsServer3D::SPACE_PARAM_BODY_LINEAR_VELOCITY_SLEEP_THRESHOLD: + return body_linear_velocity_sleep_threshold; + case PhysicsServer3D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_SLEEP_THRESHOLD: + return body_angular_velocity_sleep_threshold; + case PhysicsServer3D::SPACE_PARAM_BODY_TIME_TO_SLEEP: + return body_time_to_sleep; + case PhysicsServer3D::SPACE_PARAM_BODY_ANGULAR_VELOCITY_DAMP_RATIO: + return body_angular_velocity_damp_ratio; + case PhysicsServer3D::SPACE_PARAM_CONSTRAINT_DEFAULT_BIAS: + return constraint_bias; + case PhysicsServer3D::SPACE_PARAM_TEST_MOTION_MIN_CONTACT_DEPTH: + return test_motion_min_contact_depth; } return 0; } diff --git a/servers/physics_3d/space_3d_sw.h b/servers/physics_3d/space_3d_sw.h index 3634834952..a0b4379376 100644 --- a/servers/physics_3d/space_3d_sw.h +++ b/servers/physics_3d/space_3d_sw.h @@ -187,7 +187,8 @@ public: void set_debug_contacts(int p_amount) { contact_debug.resize(p_amount); } _FORCE_INLINE_ bool is_debugging_contacts() const { return !contact_debug.empty(); } _FORCE_INLINE_ void add_debug_contact(const Vector3 &p_contact) { - if (contact_debug_count < contact_debug.size()) contact_debug.write[contact_debug_count++] = p_contact; + if (contact_debug_count < contact_debug.size()) + contact_debug.write[contact_debug_count++] = p_contact; } _FORCE_INLINE_ Vector<Vector3> get_debug_contacts() { return contact_debug; } _FORCE_INLINE_ int get_debug_contact_count() { return contact_debug_count; } diff --git a/servers/physics_server_2d.h b/servers/physics_server_2d.h index 8c833b390f..7f73a50b34 100644 --- a/servers/physics_server_2d.h +++ b/servers/physics_server_2d.h @@ -640,11 +640,9 @@ typedef PhysicsServer2D *(*CreatePhysicsServer2DCallback)(); class PhysicsServer2DManager { struct ClassInfo { String name; - CreatePhysicsServer2DCallback create_callback; + CreatePhysicsServer2DCallback create_callback = nullptr; - ClassInfo() : - name(""), - create_callback(nullptr) {} + ClassInfo() {} ClassInfo(String p_name, CreatePhysicsServer2DCallback p_create_callback) : name(p_name), diff --git a/servers/physics_server_3d.h b/servers/physics_server_3d.h index 8ea8b22455..6900da1393 100644 --- a/servers/physics_server_3d.h +++ b/servers/physics_server_3d.h @@ -781,11 +781,9 @@ typedef PhysicsServer3D *(*CreatePhysicsServer3DCallback)(); class PhysicsServer3DManager { struct ClassInfo { String name; - CreatePhysicsServer3DCallback create_callback; + CreatePhysicsServer3DCallback create_callback = nullptr; - ClassInfo() : - name(""), - create_callback(nullptr) {} + ClassInfo() {} ClassInfo(String p_name, CreatePhysicsServer3DCallback p_create_callback) : name(p_name), diff --git a/servers/rendering/rasterizer.h b/servers/rendering/rasterizer.h index 955241e79c..48ee6e02d0 100644 --- a/servers/rendering/rasterizer.h +++ b/servers/rendering/rasterizer.h @@ -57,6 +57,7 @@ public: virtual void sky_set_radiance_size(RID p_sky, int p_radiance_size) = 0; virtual void sky_set_mode(RID p_sky, RS::SkyMode p_samples) = 0; virtual void sky_set_material(RID p_sky, RID p_material) = 0; + virtual Ref<Image> sky_bake_panorama(RID p_sky, float p_energy, bool p_bake_irradiance, const Size2i &p_size) = 0; /* ENVIRONMENT API */ @@ -94,6 +95,8 @@ public: virtual void environment_set_fog_depth(RID p_env, bool p_enable, float p_depth_begin, float p_depth_end, float p_depth_curve, bool p_transmit, float p_transmit_curve) = 0; virtual void environment_set_fog_height(RID p_env, bool p_enable, float p_min_height, float p_max_height, float p_height_curve) = 0; + virtual Ref<Image> environment_bake_panorama(RID p_env, bool p_bake_irradiance, const Size2i &p_size) = 0; + virtual bool is_environment(RID p_env) const = 0; virtual RS::EnvironmentBG environment_get_background(RID p_env) const = 0; virtual int environment_get_canvas_max_layer(RID p_env) const = 0; @@ -160,9 +163,11 @@ public: SelfList<InstanceBase> dependency_item; - InstanceBase *lightmap_capture; - RID lightmap; - Vector<Color> lightmap_capture_data; //in a array (12 values) to avoid wasting space if unused. Alpha is unused, but needed to send to shader + InstanceBase *lightmap; + Rect2 lightmap_uv_scale; + int lightmap_slice_index; + uint32_t lightmap_cull_index; + Vector<Color> lightmap_sh; //spherical harmonic AABB aabb; AABB transformed_aabb; @@ -178,8 +183,8 @@ public: bool instance_allocated_shader_parameters = false; int32_t instance_allocated_shader_parameters_offset = -1; - virtual void dependency_deleted(RID p_dependency) = 0; - virtual void dependency_changed(bool p_aabb, bool p_dependencies) = 0; + virtual void dependency_deleted(RID p_dependency) {} + virtual void dependency_changed(bool p_aabb, bool p_dependencies) {} Set<InstanceDependency *> dependencies; @@ -233,7 +238,9 @@ public: baked_light = false; dynamic_gi = false; redraw_if_visible = false; - lightmap_capture = nullptr; + lightmap_slice_index = 0; + lightmap = nullptr; + lightmap_cull_index = 0; } virtual ~InstanceBase() { @@ -268,7 +275,7 @@ public: virtual bool gi_probe_needs_update(RID p_probe) const = 0; virtual void gi_probe_update(RID p_probe, bool p_update_light_instances, const Vector<RID> &p_light_instances, int p_dynamic_object_count, InstanceBase **p_dynamic_objects) = 0; - virtual void render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass) = 0; + virtual void render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass) = 0; virtual void render_shadow(RID p_light, RID p_shadow_atlas, int p_pass, InstanceBase **p_cull_result, int p_cull_count) = 0; virtual void render_material(const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) = 0; @@ -286,6 +293,8 @@ public: virtual void sub_surface_scattering_set_quality(RS::SubSurfaceScatteringQuality p_quality) = 0; virtual void sub_surface_scattering_set_scale(float p_scale, float p_depth_scale) = 0; + virtual TypedArray<Image> bake_render_uv2(RID p_base, const Vector<RID> &p_material_overrides, const Size2i &p_image_size) = 0; + virtual bool free(RID p_rid) = 0; virtual void update() = 0; @@ -311,7 +320,7 @@ public: //these two APIs can be used together or in combination with the others. virtual RID texture_2d_placeholder_create() = 0; - virtual RID texture_2d_layered_placeholder_create() = 0; + virtual RID texture_2d_layered_placeholder_create(RenderingServer::TextureLayeredType p_layered_type) = 0; virtual RID texture_3d_placeholder_create() = 0; virtual Ref<Image> texture_2d_get(RID p_texture) const = 0; @@ -593,29 +602,21 @@ public: /* LIGHTMAP CAPTURE */ - struct LightmapCaptureOctree { - - enum { - CHILD_EMPTY = 0xFFFFFFFF - }; - - uint16_t light[6][3]; //anisotropic light - float alpha; - uint32_t children[8]; - }; - - virtual RID lightmap_capture_create() = 0; - virtual void lightmap_capture_set_bounds(RID p_capture, const AABB &p_bounds) = 0; - virtual AABB lightmap_capture_get_bounds(RID p_capture) const = 0; - virtual void lightmap_capture_set_octree(RID p_capture, const Vector<uint8_t> &p_octree) = 0; - virtual Vector<uint8_t> lightmap_capture_get_octree(RID p_capture) const = 0; - virtual void lightmap_capture_set_octree_cell_transform(RID p_capture, const Transform &p_xform) = 0; - virtual Transform lightmap_capture_get_octree_cell_transform(RID p_capture) const = 0; - virtual void lightmap_capture_set_octree_cell_subdiv(RID p_capture, int p_subdiv) = 0; - virtual int lightmap_capture_get_octree_cell_subdiv(RID p_capture) const = 0; - virtual void lightmap_capture_set_energy(RID p_capture, float p_energy) = 0; - virtual float lightmap_capture_get_energy(RID p_capture) const = 0; - virtual const Vector<LightmapCaptureOctree> *lightmap_capture_get_octree_ptr(RID p_capture) const = 0; + virtual RID lightmap_create() = 0; + + virtual void lightmap_set_textures(RID p_lightmap, RID p_light, bool p_uses_spherical_haromics) = 0; + virtual void lightmap_set_probe_bounds(RID p_lightmap, const AABB &p_bounds) = 0; + virtual void lightmap_set_probe_interior(RID p_lightmap, bool p_interior) = 0; + virtual void lightmap_set_probe_capture_data(RID p_lightmap, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree) = 0; + virtual PackedVector3Array lightmap_get_probe_capture_points(RID p_lightmap) const = 0; + virtual PackedColorArray lightmap_get_probe_capture_sh(RID p_lightmap) const = 0; + virtual PackedInt32Array lightmap_get_probe_capture_tetrahedra(RID p_lightmap) const = 0; + virtual PackedInt32Array lightmap_get_probe_capture_bsp_tree(RID p_lightmap) const = 0; + virtual AABB lightmap_get_aabb(RID p_lightmap) const = 0; + virtual void lightmap_tap_sh_light(RID p_lightmap, const Vector3 &p_point, Color *r_sh) = 0; + virtual bool lightmap_is_interior(RID p_lightmap) const = 0; + virtual void lightmap_set_probe_capture_update_speed(float p_speed) = 0; + virtual float lightmap_get_probe_capture_update_speed() const = 0; /* PARTICLES */ @@ -721,14 +722,16 @@ public: Color get_default_clear_color() const { return default_clear_color; } -#define TIMESTAMP_BEGIN() \ - { \ - if (RSG::storage->capturing_timestamps) RSG::storage->capture_timestamps_begin(); \ +#define TIMESTAMP_BEGIN() \ + { \ + if (RSG::storage->capturing_timestamps) \ + RSG::storage->capture_timestamps_begin(); \ } -#define RENDER_TIMESTAMP(m_text) \ - { \ - if (RSG::storage->capturing_timestamps) RSG::storage->capture_timestamp(m_text); \ +#define RENDER_TIMESTAMP(m_text) \ + { \ + if (RSG::storage->capturing_timestamps) \ + RSG::storage->capture_timestamp(m_text); \ } bool capturing_timestamps = false; @@ -857,7 +860,8 @@ public: _FORCE_INLINE_ TextureBinding() { binding_id = 0; } _FORCE_INLINE_ ~TextureBinding() { - if (binding_id) singleton->free_texture_binding(binding_id); + if (binding_id) + singleton->free_texture_binding(binding_id); } }; @@ -886,7 +890,8 @@ public: _FORCE_INLINE_ Polygon() { polygon_id = 0; } _FORCE_INLINE_ ~Polygon() { - if (polygon_id) singleton->free_polygon(polygon_id); + if (polygon_id) + singleton->free_polygon(polygon_id); } }; @@ -1288,7 +1293,8 @@ public: for (int i = 0; i < blocks.size(); i++) { memfree(blocks[i].memory); } - if (copy_back_buffer) memdelete(copy_back_buffer); + if (copy_back_buffer) + memdelete(copy_back_buffer); if (custom_data) { memdelete(custom_data); } @@ -1365,6 +1371,8 @@ public: virtual void end_frame(bool p_swap_buffers) = 0; virtual void finalize() = 0; + virtual uint64_t get_frame_number() const = 0; + virtual float get_frame_delta_time() const = 0; virtual bool is_low_end() const = 0; diff --git a/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.cpp index 956bf54d01..3505b18c8a 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.cpp @@ -1539,8 +1539,8 @@ void RasterizerCanvasRD::canvas_render_items(RID p_to_render_target, Item *p_ite } } - if (md->last_frame != RasterizerRD::get_frame_number()) { - md->last_frame = RasterizerRD::get_frame_number(); + if (md->last_frame != RasterizerRD::singleton->get_frame_number()) { + md->last_frame = RasterizerRD::singleton->get_frame_number(); if (!RD::get_singleton()->uniform_set_is_valid(md->uniform_set)) { // uniform set may be gone because a dependency was erased. In this case, it will happen // if a texture is deleted, so just re-create it. diff --git a/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.h b/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.h index 4d47b3e13b..48467cbd8c 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_canvas_rd.h @@ -489,7 +489,7 @@ public: void canvas_render_items(RID p_to_render_target, Item *p_item_list, const Color &p_modulate, Light *p_light_list, const Transform2D &p_canvas_transform); - void canvas_debug_viewport_shadows(Light *p_lights_with_shadow){}; + void canvas_debug_viewport_shadows(Light *p_lights_with_shadow) {} void draw_window_margins(int *p_margins, RID *p_margin_textures) {} diff --git a/servers/rendering/rasterizer_rd/rasterizer_effects_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_effects_rd.cpp index d469dd97ca..ed25cc4139 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_effects_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_effects_rd.cpp @@ -282,6 +282,30 @@ void RasterizerEffectsRD::copy_to_rect(RID p_source_rd_texture, RID p_dest_textu RD::get_singleton()->compute_list_end(); } +void RasterizerEffectsRD::copy_cubemap_to_panorama(RID p_source_cube, RID p_dest_panorama, const Size2i &p_panorama_size, float p_lod, bool p_is_array) { + + zeromem(©.push_constant, sizeof(CopyPushConstant)); + + copy.push_constant.section[0] = 0; + copy.push_constant.section[1] = 0; + copy.push_constant.section[2] = p_panorama_size.width; + copy.push_constant.section[3] = p_panorama_size.height; + copy.push_constant.target[0] = 0; + copy.push_constant.target[1] = 0; + copy.push_constant.camera_z_far = p_lod; + + int32_t x_groups = (p_panorama_size.width - 1) / 8 + 1; + int32_t y_groups = (p_panorama_size.height - 1) / 8 + 1; + + RD::ComputeListID compute_list = RD::get_singleton()->compute_list_begin(); + RD::get_singleton()->compute_list_bind_compute_pipeline(compute_list, copy.pipelines[p_is_array ? COPY_MODE_CUBE_ARRAY_TO_PANORAMA : COPY_MODE_CUBE_TO_PANORAMA]); + RD::get_singleton()->compute_list_bind_uniform_set(compute_list, _get_compute_uniform_set_from_texture(p_source_cube), 0); + RD::get_singleton()->compute_list_bind_uniform_set(compute_list, _get_uniform_set_from_image(p_dest_panorama), 3); + RD::get_singleton()->compute_list_set_push_constant(compute_list, ©.push_constant, sizeof(CopyPushConstant)); + RD::get_singleton()->compute_list_dispatch(compute_list, x_groups, y_groups, 1); + RD::get_singleton()->compute_list_end(); +} + void RasterizerEffectsRD::copy_depth_to_rect_and_linearize(RID p_source_rd_texture, RID p_dest_texture, const Rect2i &p_rect, bool p_flip_y, float p_z_near, float p_z_far) { zeromem(©.push_constant, sizeof(CopyPushConstant)); @@ -1202,7 +1226,9 @@ void RasterizerEffectsRD::render_sky(RD::DrawListID p_list, float p_time, RID p_ RD::get_singleton()->draw_list_bind_render_pipeline(draw_list, p_pipeline->get_render_pipeline(RD::INVALID_ID, fb_format)); RD::get_singleton()->draw_list_bind_uniform_set(draw_list, p_samplers, 0); - RD::get_singleton()->draw_list_bind_uniform_set(draw_list, p_uniform_set, 1); + if (p_uniform_set.is_valid()) { //material may not have uniform set + RD::get_singleton()->draw_list_bind_uniform_set(draw_list, p_uniform_set, 1); + } RD::get_singleton()->draw_list_bind_uniform_set(draw_list, p_texture_set, 2); RD::get_singleton()->draw_list_bind_uniform_set(draw_list, p_lights, 3); @@ -1226,6 +1252,8 @@ RasterizerEffectsRD::RasterizerEffectsRD() { copy_modes.push_back("\n#define MODE_SIMPLE_COPY_DEPTH\n"); copy_modes.push_back("\n#define MODE_MIPMAP\n"); copy_modes.push_back("\n#define MODE_LINEARIZE_DEPTH_COPY\n"); + copy_modes.push_back("\n#define MODE_CUBEMAP_TO_PANORAMA\n"); + copy_modes.push_back("\n#define MODE_CUBEMAP_ARRAY_TO_PANORAMA\n"); copy.shader.initialize(copy_modes); zeromem(©.push_constant, sizeof(CopyPushConstant)); diff --git a/servers/rendering/rasterizer_rd/rasterizer_effects_rd.h b/servers/rendering/rasterizer_rd/rasterizer_effects_rd.h index 531591442b..1b16648ca6 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_effects_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_effects_rd.h @@ -66,6 +66,8 @@ class RasterizerEffectsRD { COPY_MODE_SIMPLY_COPY_DEPTH, COPY_MODE_MIPMAP, COPY_MODE_LINEARIZE_DEPTH, + COPY_MODE_CUBE_TO_PANORAMA, + COPY_MODE_CUBE_ARRAY_TO_PANORAMA, COPY_MODE_MAX, }; @@ -564,6 +566,7 @@ class RasterizerEffectsRD { public: void copy_to_fb_rect(RID p_source_rd_texture, RID p_dest_framebuffer, const Rect2i &p_rect, bool p_flip_y = false, bool p_force_luminance = false, bool p_alpha_to_zero = false); void copy_to_rect(RID p_source_rd_texture, RID p_dest_texture, const Rect2i &p_rect, bool p_flip_y = false, bool p_force_luminance = false, bool p_all_source = false, bool p_8_bit_dst = false); + void copy_cubemap_to_panorama(RID p_source_cube, RID p_dest_panorama, const Size2i &p_panorama_size, float p_lod, bool p_is_array); void copy_depth_to_rect(RID p_source_rd_texture, RID p_dest_framebuffer, const Rect2i &p_rect, bool p_flip_y = false); void copy_depth_to_rect_and_linearize(RID p_source_rd_texture, RID p_dest_texture, const Rect2i &p_rect, bool p_flip_y, float p_z_near, float p_z_far); void copy_to_atlas_fb(RID p_source_rd_texture, RID p_dest_framebuffer, const Rect2 &p_uv_rect, RD::DrawListID p_draw_list, bool p_flip_y = false, bool p_panorama = false); diff --git a/servers/rendering/rasterizer_rd/rasterizer_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_rd.cpp index 4c92912e9c..4267a087b6 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_rd.cpp @@ -79,6 +79,7 @@ void RasterizerRD::blit_render_targets_to_screen(DisplayServer::WindowID p_scree void RasterizerRD::begin_frame(double frame_step) { frame++; + delta = frame_step; time += frame_step; double time_roll_over = GLOBAL_GET("rendering/limits/time/time_rollover_secs"); @@ -157,7 +158,7 @@ void RasterizerRD::initialize() { } ThreadWorkPool RasterizerRD::thread_work_pool; -uint32_t RasterizerRD::frame = 1; +uint64_t RasterizerRD::frame = 1; void RasterizerRD::finalize() { @@ -173,7 +174,10 @@ void RasterizerRD::finalize() { RD::get_singleton()->free(copy_viewports_sampler); } +RasterizerRD *RasterizerRD::singleton = nullptr; + RasterizerRD::RasterizerRD() { + singleton = this; thread_work_pool.init(); time = 0; diff --git a/servers/rendering/rasterizer_rd/rasterizer_rd.h b/servers/rendering/rasterizer_rd/rasterizer_rd.h index 756b9499ca..cb53a531ac 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_rd.h @@ -53,8 +53,9 @@ protected: Map<RID, RID> render_target_descriptors; double time; + float delta; - static uint32_t frame; + static uint64_t frame; public: RasterizerStorage *get_storage() { return storage; } @@ -71,7 +72,8 @@ public: void end_frame(bool p_swap_buffers); void finalize(); - static _ALWAYS_INLINE_ uint64_t get_frame_number() { return frame; } + _ALWAYS_INLINE_ uint64_t get_frame_number() const { return frame; } + _ALWAYS_INLINE_ float get_frame_delta_time() const { return delta; } static Error is_viable() { return OK; @@ -89,6 +91,7 @@ public: static ThreadWorkPool thread_work_pool; + static RasterizerRD *singleton; RasterizerRD(); ~RasterizerRD() {} }; diff --git a/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.cpp index 6986f82065..b639b4200a 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.cpp @@ -67,18 +67,18 @@ static _FORCE_INLINE_ void store_basis_3x4(const Basis &p_mtx, float *p_array) { p_array[11] = 0; } -static _FORCE_INLINE_ void store_transform_3x3(const Transform &p_mtx, float *p_array) { - p_array[0] = p_mtx.basis.elements[0][0]; - p_array[1] = p_mtx.basis.elements[1][0]; - p_array[2] = p_mtx.basis.elements[2][0]; +static _FORCE_INLINE_ void store_transform_3x3(const Basis &p_mtx, float *p_array) { + p_array[0] = p_mtx.elements[0][0]; + p_array[1] = p_mtx.elements[1][0]; + p_array[2] = p_mtx.elements[2][0]; p_array[3] = 0; - p_array[4] = p_mtx.basis.elements[0][1]; - p_array[5] = p_mtx.basis.elements[1][1]; - p_array[6] = p_mtx.basis.elements[2][1]; + p_array[4] = p_mtx.elements[0][1]; + p_array[5] = p_mtx.elements[1][1]; + p_array[6] = p_mtx.elements[2][1]; p_array[7] = 0; - p_array[8] = p_mtx.basis.elements[0][2]; - p_array[9] = p_mtx.basis.elements[1][2]; - p_array[10] = p_mtx.basis.elements[2][2]; + p_array[8] = p_mtx.elements[0][2]; + p_array[9] = p_mtx.elements[1][2]; + p_array[10] = p_mtx.elements[2][2]; p_array[11] = 0; } @@ -841,6 +841,8 @@ bool RasterizerSceneHighEndRD::free(RID p_rid) { void RasterizerSceneHighEndRD::_fill_instances(RenderList::Element **p_elements, int p_element_count, bool p_for_depth) { + uint32_t lightmap_captures_used = 0; + for (int i = 0; i < p_element_count; i++) { const RenderList::Element *e = p_elements[i]; @@ -898,6 +900,7 @@ void RasterizerSceneHighEndRD::_fill_instances(RenderList::Element **p_elements, if (written == 0) { id.gi_offset = index; + id.flags |= INSTANCE_DATA_FLAG_USE_GIPROBE; written = 1; } else { id.gi_offset = index << 16; @@ -910,17 +913,53 @@ void RasterizerSceneHighEndRD::_fill_instances(RenderList::Element **p_elements, } else if (written == 1) { id.gi_offset |= 0xFFFF0000; } + } else if (e->instance->lightmap) { + + int32_t lightmap_index = storage->lightmap_get_array_index(e->instance->lightmap->base); + if (lightmap_index >= 0) { + id.gi_offset = lightmap_index; + id.gi_offset |= e->instance->lightmap_slice_index << 12; + id.gi_offset |= e->instance->lightmap_cull_index << 20; + id.lightmap_uv_scale[0] = e->instance->lightmap_uv_scale.position.x; + id.lightmap_uv_scale[1] = e->instance->lightmap_uv_scale.position.y; + id.lightmap_uv_scale[2] = e->instance->lightmap_uv_scale.size.width; + id.lightmap_uv_scale[3] = e->instance->lightmap_uv_scale.size.height; + id.flags |= INSTANCE_DATA_FLAG_USE_LIGHTMAP; + if (storage->lightmap_uses_spherical_harmonics(e->instance->lightmap->base)) { + id.flags |= INSTANCE_DATA_FLAG_USE_SH_LIGHTMAP; + } + } else { + id.gi_offset = 0xFFFFFFFF; + } + } else if (!e->instance->lightmap_sh.empty()) { + if (lightmap_captures_used < scene_state.max_lightmap_captures) { + + const Color *src_capture = e->instance->lightmap_sh.ptr(); + LightmapCaptureData &lcd = scene_state.lightmap_captures[lightmap_captures_used]; + for (int j = 0; j < 9; j++) { + lcd.sh[j * 4 + 0] = src_capture[j].r; + lcd.sh[j * 4 + 1] = src_capture[j].g; + lcd.sh[j * 4 + 2] = src_capture[j].b; + lcd.sh[j * 4 + 3] = src_capture[j].a; + } + id.flags |= INSTANCE_DATA_FLAG_USE_LIGHTMAP_CAPTURE; + id.gi_offset = lightmap_captures_used; + lightmap_captures_used++; + } } else { id.gi_offset = 0xFFFFFFFF; } } RD::get_singleton()->buffer_update(scene_state.instance_buffer, 0, sizeof(InstanceData) * p_element_count, scene_state.instances, true); + if (lightmap_captures_used) { + RD::get_singleton()->buffer_update(scene_state.lightmap_capture_buffer, 0, sizeof(LightmapCaptureData) * lightmap_captures_used, scene_state.lightmap_captures, true); + } } /// RENDERING /// -void RasterizerSceneHighEndRD::_render_list(RenderingDevice::DrawListID p_draw_list, RenderingDevice::FramebufferFormatID p_framebuffer_Format, RenderList::Element **p_elements, int p_element_count, bool p_reverse_cull, PassMode p_pass_mode, bool p_no_gi, RID p_radiance_uniform_set, RID p_render_buffers_uniform_set) { +void RasterizerSceneHighEndRD::_render_list(RenderingDevice::DrawListID p_draw_list, RenderingDevice::FramebufferFormatID p_framebuffer_Format, RenderList::Element **p_elements, int p_element_count, bool p_reverse_cull, PassMode p_pass_mode, bool p_no_gi, RID p_radiance_uniform_set, RID p_render_buffers_uniform_set, bool p_force_wireframe, const Vector2 &p_uv_offset) { RD::DrawListID draw_list = p_draw_list; RD::FramebufferFormatID framebuffer_format = p_framebuffer_Format; @@ -949,6 +988,8 @@ void RasterizerSceneHighEndRD::_render_list(RenderingDevice::DrawListID p_draw_l PushConstant push_constant; zeromem(&push_constant, sizeof(PushConstant)); + push_constant.bake_uv2_offset[0] = p_uv_offset.x; + push_constant.bake_uv2_offset[1] = p_uv_offset.y; for (int i = 0; i < p_element_count; i++) { @@ -961,7 +1002,7 @@ void RasterizerSceneHighEndRD::_render_list(RenderingDevice::DrawListID p_draw_l //find cull variant ShaderData::CullVariant cull_variant; - if ((p_pass_mode == PASS_MODE_SHADOW || p_pass_mode == PASS_MODE_SHADOW_DP) && e->instance->cast_shadows == RS::SHADOW_CASTING_SETTING_DOUBLE_SIDED) { + if (p_pass_mode == PASS_MODE_DEPTH_MATERIAL || ((p_pass_mode == PASS_MODE_SHADOW || p_pass_mode == PASS_MODE_SHADOW_DP) && e->instance->cast_shadows == RS::SHADOW_CASTING_SETTING_DOUBLE_SIDED)) { cull_variant = ShaderData::CULL_VARIANT_DOUBLE_SIDED; } else { bool mirror = e->instance->mirror; @@ -1080,7 +1121,7 @@ void RasterizerSceneHighEndRD::_render_list(RenderingDevice::DrawListID p_draw_l prev_index_array_rd = index_array_rd; } - RID pipeline_rd = pipeline->get_render_pipeline(vertex_format, framebuffer_format); + RID pipeline_rd = pipeline->get_render_pipeline(vertex_format, framebuffer_format, p_force_wireframe); if (pipeline_rd != prev_pipeline_rd) { // checking with prev shader does not make so much sense, as @@ -1255,6 +1296,7 @@ void RasterizerSceneHighEndRD::_setup_environment(RID p_environment, const Camer scene_state.ubo.use_ambient_cubemap = false; scene_state.ubo.use_reflection_cubemap = false; + scene_state.ubo.ssao_enabled = false; } scene_state.ubo.roughness_limiter_enabled = p_opaque_render_buffers && screen_space_roughness_limiter_is_active(); @@ -1271,8 +1313,6 @@ void RasterizerSceneHighEndRD::_add_geometry(InstanceBase *p_instance, uint32_t if (unlikely(get_debug_draw_mode() != RS::VIEWPORT_DEBUG_DRAW_DISABLED)) { if (get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_OVERDRAW) { m_src = overdraw_material; - } else if (get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_WIREFRAME) { - m_src = wireframe_material; } else if (get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_LIGHTING) { m_src = default_material; } @@ -1374,7 +1414,7 @@ void RasterizerSceneHighEndRD::_add_geometry_with_material(InstanceBase *p_insta e->geometry_index = p_geometry_index; e->material_index = e->material->index; e->uses_instancing = e->instance->base_type == RS::INSTANCE_MULTIMESH; - e->uses_lightmap = e->instance->lightmap.is_valid(); + e->uses_lightmap = e->instance->lightmap != nullptr || !e->instance->lightmap_sh.empty(); e->uses_vct = e->instance->gi_probe_instances.size(); e->shader_index = e->shader_index; e->depth_layer = e->instance->depth_layer; @@ -1575,6 +1615,26 @@ void RasterizerSceneHighEndRD::_setup_reflections(RID *p_reflection_probe_cull_r } } +void RasterizerSceneHighEndRD::_setup_lightmaps(InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, const Transform &p_cam_transform) { + + uint32_t lightmaps_used = 0; + for (int i = 0; i < p_lightmap_cull_count; i++) { + if (i >= (int)scene_state.max_lightmaps) { + break; + } + + InstanceBase *lm = p_lightmap_cull_result[i]; + Basis to_lm = lm->transform.basis.inverse() * p_cam_transform.basis; + to_lm = to_lm.inverse().transposed(); //will transform normals + store_transform_3x3(to_lm, scene_state.lightmaps[i].normal_xform); + lm->lightmap_cull_index = i; + lightmaps_used++; + } + if (lightmaps_used > 0) { + RD::get_singleton()->buffer_update(scene_state.lightmap_buffer, 0, sizeof(LightmapData) * lightmaps_used, scene_state.lightmaps, true); + } +} + void RasterizerSceneHighEndRD::_setup_gi_probes(RID *p_gi_probe_probe_cull_result, int p_gi_probe_probe_cull_count, const Transform &p_camera_transform) { int index = 0; @@ -2118,7 +2178,7 @@ void RasterizerSceneHighEndRD::_setup_decals(const RID *p_decal_instances, int p } } -void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_bg_color) { +void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_bg_color) { RenderBufferDataHighEnd *render_buffer = nullptr; if (p_render_buffer.is_valid()) { @@ -2238,6 +2298,7 @@ void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transfor _setup_decals(p_decal_cull_result, p_decal_cull_count, p_cam_transform.affine_inverse()); _setup_reflections(p_reflection_probe_cull_result, p_reflection_probe_cull_count, p_cam_transform.affine_inverse(), p_environment); _setup_gi_probes(p_gi_probe_cull_result, p_gi_probe_cull_count, p_cam_transform); + _setup_lightmaps(p_lightmap_cull_result, p_lightmap_cull_count, p_cam_transform); _setup_environment(p_environment, p_cam_projection, p_cam_transform, p_reflection_probe, p_reflection_probe.is_valid(), screen_pixel_size, p_shadow_atlas, !p_reflection_probe.is_valid(), p_default_bg_color, p_cam_projection.get_z_near(), p_cam_projection.get_z_far(), false); cluster_builder.bake_cluster(); //bake to cluster @@ -2338,7 +2399,7 @@ void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transfor bool finish_depth = using_ssao; RD::DrawListID draw_list = RD::get_singleton()->draw_list_begin(depth_framebuffer, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_READ, RD::INITIAL_ACTION_CLEAR, finish_depth ? RD::FINAL_ACTION_READ : RD::FINAL_ACTION_CONTINUE, depth_pass_clear); - _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(depth_framebuffer), render_list.elements, render_list.element_count, false, depth_pass_mode, render_buffer == nullptr, radiance_uniform_set, RID()); + _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(depth_framebuffer), render_list.elements, render_list.element_count, false, depth_pass_mode, render_buffer == nullptr, radiance_uniform_set, RID(), get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_WIREFRAME); RD::get_singleton()->draw_list_end(); if (render_buffer && render_buffer->msaa != RS::VIEWPORT_MSAA_DISABLED) { @@ -2394,7 +2455,7 @@ void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transfor RID framebuffer = using_separate_specular ? opaque_specular_framebuffer : opaque_framebuffer; RD::DrawListID draw_list = RD::get_singleton()->draw_list_begin(framebuffer, keep_color ? RD::INITIAL_ACTION_KEEP : RD::INITIAL_ACTION_CLEAR, will_continue_color ? RD::FINAL_ACTION_CONTINUE : RD::FINAL_ACTION_READ, depth_pre_pass ? (using_ssao ? RD::INITIAL_ACTION_KEEP : RD::INITIAL_ACTION_CONTINUE) : RD::INITIAL_ACTION_CLEAR, will_continue_depth ? RD::FINAL_ACTION_CONTINUE : RD::FINAL_ACTION_READ, c, 1.0, 0); - _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(framebuffer), render_list.elements, render_list.element_count, false, using_separate_specular ? PASS_MODE_COLOR_SPECULAR : PASS_MODE_COLOR, render_buffer == nullptr, radiance_uniform_set, render_buffers_uniform_set); + _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(framebuffer), render_list.elements, render_list.element_count, false, using_separate_specular ? PASS_MODE_COLOR_SPECULAR : PASS_MODE_COLOR, render_buffer == nullptr, radiance_uniform_set, render_buffers_uniform_set, get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_WIREFRAME); RD::get_singleton()->draw_list_end(); if (will_continue_color && using_separate_specular) { @@ -2472,7 +2533,7 @@ void RasterizerSceneHighEndRD::_render_scene(RID p_render_buffer, const Transfor { RD::DrawListID draw_list = RD::get_singleton()->draw_list_begin(alpha_framebuffer, can_continue_color ? RD::INITIAL_ACTION_CONTINUE : RD::INITIAL_ACTION_KEEP, RD::FINAL_ACTION_READ, can_continue_depth ? RD::INITIAL_ACTION_CONTINUE : RD::INITIAL_ACTION_KEEP, RD::FINAL_ACTION_READ); - _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(alpha_framebuffer), &render_list.elements[render_list.max_elements - render_list.alpha_element_count], render_list.alpha_element_count, false, PASS_MODE_COLOR, render_buffer == nullptr, radiance_uniform_set, render_buffers_uniform_set); + _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(alpha_framebuffer), &render_list.elements[render_list.max_elements - render_list.alpha_element_count], render_list.alpha_element_count, false, PASS_MODE_COLOR, render_buffer == nullptr, radiance_uniform_set, render_buffers_uniform_set, get_debug_draw_mode() == RS::VIEWPORT_DEBUG_DRAW_WIREFRAME); RD::get_singleton()->draw_list_end(); } @@ -2517,13 +2578,14 @@ void RasterizerSceneHighEndRD::_render_shadow(RID p_framebuffer, InstanceBase ** } void RasterizerSceneHighEndRD::_render_material(const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) { - RENDER_TIMESTAMP("Setup Rendering Shadow"); + RENDER_TIMESTAMP("Setup Rendering Material"); _update_render_base_uniform_set(); render_pass++; scene_state.ubo.dual_paraboloid_side = 0; + scene_state.ubo.material_uv2_mode = true; _setup_environment(RID(), p_cam_projection, p_cam_transform, RID(), true, Vector2(1, 1), RID(), false, Color(), 0, 0); @@ -2554,6 +2616,67 @@ void RasterizerSceneHighEndRD::_render_material(const Transform &p_cam_transform } } +void RasterizerSceneHighEndRD::_render_uv2(InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) { + RENDER_TIMESTAMP("Setup Rendering UV2"); + + _update_render_base_uniform_set(); + + render_pass++; + + scene_state.ubo.dual_paraboloid_side = 0; + scene_state.ubo.material_uv2_mode = true; + + _setup_environment(RID(), CameraMatrix(), Transform(), RID(), true, Vector2(1, 1), RID(), false, Color(), 0, 0); + + render_list.clear(); + + PassMode pass_mode = PASS_MODE_DEPTH_MATERIAL; + _fill_render_list(p_cull_result, p_cull_count, pass_mode, true); + + _setup_view_dependant_uniform_set(RID(), RID()); + + RENDER_TIMESTAMP("Render Material"); + + render_list.sort_by_key(false); + + _fill_instances(render_list.elements, render_list.element_count, true); + + { + //regular forward for now + Vector<Color> clear; + clear.push_back(Color(0, 0, 0, 0)); + clear.push_back(Color(0, 0, 0, 0)); + clear.push_back(Color(0, 0, 0, 0)); + clear.push_back(Color(0, 0, 0, 0)); + clear.push_back(Color(0, 0, 0, 0)); + RD::DrawListID draw_list = RD::get_singleton()->draw_list_begin(p_framebuffer, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_READ, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_READ, clear, 1.0, 0, p_region); + + const int uv_offset_count = 9; + static const Vector2 uv_offsets[uv_offset_count] = { + Vector2(-1, 1), + Vector2(1, 1), + Vector2(1, -1), + Vector2(-1, -1), + Vector2(-1, 0), + Vector2(1, 0), + Vector2(0, -1), + Vector2(0, 1), + Vector2(0, 0), + + }; + + for (int i = 0; i < uv_offset_count; i++) { + Vector2 ofs = uv_offsets[i]; + ofs.x /= p_region.size.width; + ofs.y /= p_region.size.height; + _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(p_framebuffer), render_list.elements, render_list.element_count, true, pass_mode, true, RID(), RID(), true, ofs); //first wireframe, for pseudo conservative + } + _render_list(draw_list, RD::get_singleton()->framebuffer_get_format(p_framebuffer), render_list.elements, render_list.element_count, true, pass_mode, true, RID(), RID(), false); //second regular triangles + + RD::get_singleton()->draw_list_end(); + } +} + void RasterizerSceneHighEndRD::_base_uniforms_changed() { if (!render_base_uniform_set.is_null() && RD::get_singleton()->uniform_set_is_valid(render_base_uniform_set)) { @@ -2564,12 +2687,14 @@ void RasterizerSceneHighEndRD::_base_uniforms_changed() { void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { - if (render_base_uniform_set.is_null() || !RD::get_singleton()->uniform_set_is_valid(render_base_uniform_set)) { + if (render_base_uniform_set.is_null() || !RD::get_singleton()->uniform_set_is_valid(render_base_uniform_set) || (lightmap_texture_array_version != storage->lightmap_array_get_version())) { if (render_base_uniform_set.is_valid() && RD::get_singleton()->uniform_set_is_valid(render_base_uniform_set)) { RD::get_singleton()->free(render_base_uniform_set); } + lightmap_texture_array_version = storage->lightmap_array_get_version(); + Vector<RD::Uniform> uniforms; { @@ -2685,6 +2810,27 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { { RD::Uniform u; u.binding = 10; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.ids.push_back(scene_state.lightmap_buffer); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.binding = 11; + u.type = RD::UNIFORM_TYPE_TEXTURE; + u.ids = storage->lightmap_array_get_textures(); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.binding = 12; + u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; + u.ids.push_back(scene_state.lightmap_capture_buffer); + uniforms.push_back(u); + } + { + RD::Uniform u; + u.binding = 13; u.type = RD::UNIFORM_TYPE_TEXTURE; RID decal_atlas = storage->decal_atlas_get_texture(); u.ids.push_back(decal_atlas); @@ -2692,7 +2838,7 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { } { RD::Uniform u; - u.binding = 11; + u.binding = 14; u.type = RD::UNIFORM_TYPE_TEXTURE; RID decal_atlas = storage->decal_atlas_get_texture_srgb(); u.ids.push_back(decal_atlas); @@ -2700,7 +2846,7 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { } { RD::Uniform u; - u.binding = 12; + u.binding = 15; u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; u.ids.push_back(scene_state.decal_buffer); uniforms.push_back(u); @@ -2708,14 +2854,14 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { { RD::Uniform u; - u.binding = 13; + u.binding = 16; u.type = RD::UNIFORM_TYPE_TEXTURE; u.ids.push_back(cluster_builder.get_cluster_texture()); uniforms.push_back(u); } { RD::Uniform u; - u.binding = 14; + u.binding = 17; u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; u.ids.push_back(cluster_builder.get_cluster_indices_buffer()); uniforms.push_back(u); @@ -2723,7 +2869,7 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { { RD::Uniform u; - u.binding = 15; + u.binding = 18; u.type = RD::UNIFORM_TYPE_TEXTURE; if (directional_shadow_get_texture().is_valid()) { u.ids.push_back(directional_shadow_get_texture()); @@ -2736,7 +2882,7 @@ void RasterizerSceneHighEndRD::_update_render_base_uniform_set() { { RD::Uniform u; u.type = RD::UNIFORM_TYPE_STORAGE_BUFFER; - u.binding = 16; + u.binding = 19; u.ids.push_back(storage->global_variables_get_storage_buffer()); uniforms.push_back(u); } @@ -2951,7 +3097,21 @@ RasterizerSceneHighEndRD::RasterizerSceneHighEndRD(RasterizerStorageRD *p_storag scene_state.gi_probe_buffer = RD::get_singleton()->uniform_buffer_create(sizeof(GIProbeData) * scene_state.max_gi_probes); defines += "\n#define MAX_GI_PROBES " + itos(scene_state.max_gi_probes) + "\n"; } + { + //lightmaps + scene_state.max_lightmaps = storage->lightmap_array_get_size(); + defines += "\n#define MAX_LIGHTMAP_TEXTURES " + itos(scene_state.max_lightmaps) + "\n"; + defines += "\n#define MAX_LIGHTMAPS " + itos(scene_state.max_lightmaps) + "\n"; + scene_state.lightmaps = memnew_arr(LightmapData, scene_state.max_lightmaps); + scene_state.lightmap_buffer = RD::get_singleton()->storage_buffer_create(sizeof(LightmapData) * scene_state.max_lightmaps); + } + { + //captures + scene_state.max_lightmap_captures = 2048; + scene_state.lightmap_captures = memnew_arr(LightmapCaptureData, scene_state.max_lightmap_captures); + scene_state.lightmap_capture_buffer = RD::get_singleton()->storage_buffer_create(sizeof(LightmapCaptureData) * scene_state.max_lightmap_captures); + } { //decals scene_state.max_decals = MIN(1024 * 1024, uniform_max_size) / sizeof(DecalData); //1mb of decals uint32_t decal_buffer_size = scene_state.max_decals * sizeof(DecalData); @@ -2959,6 +3119,11 @@ RasterizerSceneHighEndRD::RasterizerSceneHighEndRD(RasterizerStorageRD *p_storag scene_state.decal_buffer = RD::get_singleton()->storage_buffer_create(decal_buffer_size); } + { + + defines += "\n#define MATERIAL_UNIFORM_SET " + itos(MATERIAL_UNIFORM_SET) + "\n"; + } + Vector<String> shader_versions; shader_versions.push_back("\n#define MODE_RENDER_DEPTH\n"); shader_versions.push_back("\n#define MODE_RENDER_DEPTH\n#define MODE_DUAL_PARABOLOID\n"); @@ -3239,12 +3404,16 @@ RasterizerSceneHighEndRD::~RasterizerSceneHighEndRD() { RD::get_singleton()->free(scene_state.gi_probe_buffer); RD::get_singleton()->free(scene_state.directional_light_buffer); RD::get_singleton()->free(scene_state.light_buffer); + RD::get_singleton()->free(scene_state.lightmap_buffer); + RD::get_singleton()->free(scene_state.lightmap_capture_buffer); RD::get_singleton()->free(scene_state.reflection_buffer); RD::get_singleton()->free(scene_state.decal_buffer); memdelete_arr(scene_state.instances); memdelete_arr(scene_state.gi_probes); memdelete_arr(scene_state.directional_lights); memdelete_arr(scene_state.lights); + memdelete_arr(scene_state.lightmaps); + memdelete_arr(scene_state.lightmap_captures); memdelete_arr(scene_state.reflections); memdelete_arr(scene_state.decals); } diff --git a/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.h b/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.h index a48e2e2259..e8736a0e53 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_scene_high_end_rd.h @@ -193,7 +193,8 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { struct PushConstant { uint32_t index; - uint32_t pad[3]; + uint32_t pad; + float bake_uv2_offset[2]; }; /* Framebuffer */ @@ -241,6 +242,8 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { RID render_base_uniform_set; RID view_dependant_uniform_set; + uint64_t lightmap_texture_array_version = 0xFFFFFFFF; + virtual void _base_uniforms_changed(); void _render_buffers_clear_uniform_set(RenderBufferDataHighEnd *rb); virtual void _render_buffers_uniform_set_changed(RID p_render_buffers); @@ -331,6 +334,10 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { uint32_t pad[1]; }; + struct LightmapData { + float normal_xform[12]; + }; + struct DecalData { float xform[16]; float inv_extents[3]; @@ -349,7 +356,15 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { float normal_fade; }; + struct LightmapCaptureData { + float sh[9 * 4]; + }; + enum { + INSTANCE_DATA_FLAG_USE_LIGHTMAP_CAPTURE = 1 << 8, + INSTANCE_DATA_FLAG_USE_LIGHTMAP = 1 << 9, + INSTANCE_DATA_FLAG_USE_SH_LIGHTMAP = 1 << 10, + INSTANCE_DATA_FLAG_USE_GIPROBE = 1 << 11, INSTANCE_DATA_FLAG_MULTIMESH = 1 << 12, INSTANCE_DATA_FLAG_MULTIMESH_FORMAT_2D = 1 << 13, INSTANCE_DATA_FLAG_MULTIMESH_HAS_COLOR = 1 << 14, @@ -366,6 +381,7 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { uint32_t instance_uniforms_ofs; //instance_offset in instancing/skeleton buffer uint32_t gi_offset; //GI information when using lightmapping (VCT or lightmap) uint32_t mask; + float lightmap_uv_scale[4]; }; struct SceneState { @@ -418,6 +434,9 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { uint32_t roughness_limiter_enabled; float ao_color[4]; + + uint32_t material_uv2_mode; + uint32_t pad_material[3]; }; UBO ubo; @@ -434,6 +453,10 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { RID gi_probe_buffer; uint32_t max_gi_probe_probes_per_instance; + LightmapData *lightmaps; + uint32_t max_lightmaps; + RID lightmap_buffer; + DecalData *decals; uint32_t max_decals; RID decal_buffer; @@ -446,6 +469,10 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { uint32_t max_directional_lights; RID directional_light_buffer; + LightmapCaptureData *lightmap_captures; + uint32_t max_lightmap_captures; + RID lightmap_capture_buffer; + RID instance_buffer; InstanceData *instances; uint32_t max_instances; @@ -456,6 +483,7 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { bool used_sss = false; uint32_t current_shader_index = 0; uint32_t current_material_index = 0; + } scene_state; /* Render List */ @@ -632,18 +660,20 @@ class RasterizerSceneHighEndRD : public RasterizerSceneRD { void _setup_decals(const RID *p_decal_instances, int p_decal_count, const Transform &p_camera_inverse_xform); void _setup_reflections(RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, const Transform &p_camera_inverse_transform, RID p_environment); void _setup_gi_probes(RID *p_gi_probe_probe_cull_result, int p_gi_probe_probe_cull_count, const Transform &p_camera_transform); + void _setup_lightmaps(InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, const Transform &p_cam_transform); void _fill_instances(RenderList::Element **p_elements, int p_element_count, bool p_for_depth); - void _render_list(RenderingDevice::DrawListID p_draw_list, RenderingDevice::FramebufferFormatID p_framebuffer_Format, RenderList::Element **p_elements, int p_element_count, bool p_reverse_cull, PassMode p_pass_mode, bool p_no_gi, RID p_radiance_uniform_set, RID p_render_buffers_uniform_set); + void _render_list(RenderingDevice::DrawListID p_draw_list, RenderingDevice::FramebufferFormatID p_framebuffer_Format, RenderList::Element **p_elements, int p_element_count, bool p_reverse_cull, PassMode p_pass_mode, bool p_no_gi, RID p_radiance_uniform_set, RID p_render_buffers_uniform_set, bool p_force_wireframe = false, const Vector2 &p_uv_offset = Vector2()); _FORCE_INLINE_ void _add_geometry(InstanceBase *p_instance, uint32_t p_surface, RID p_material, PassMode p_pass_mode, uint32_t p_geometry_index); _FORCE_INLINE_ void _add_geometry_with_material(InstanceBase *p_instance, uint32_t p_surface, MaterialData *p_material, RID p_material_rid, PassMode p_pass_mode, uint32_t p_geometry_index); void _fill_render_list(InstanceBase **p_cull_result, int p_cull_count, PassMode p_pass_mode, bool p_no_gi); protected: - virtual void _render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_bg_color); + virtual void _render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_bg_color); virtual void _render_shadow(RID p_framebuffer, InstanceBase **p_cull_result, int p_cull_count, const CameraMatrix &p_projection, const Transform &p_transform, float p_zfar, float p_bias, float p_normal_bias, bool p_use_dp, bool p_use_dp_flip, bool p_use_pancake); virtual void _render_material(const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region); + virtual void _render_uv2(InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region); public: virtual void set_time(double p_time, double p_step); diff --git a/servers/rendering/rasterizer_rd/rasterizer_scene_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_scene_rd.cpp index 8877de87ac..02221d1536 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_scene_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_scene_rd.cpp @@ -263,7 +263,47 @@ void RasterizerSceneRD::sky_set_material(RID p_sky, RID p_material) { Sky *sky = sky_owner.getornull(p_sky); ERR_FAIL_COND(!sky); sky->material = p_material; + _sky_invalidate(sky); +} + +Ref<Image> RasterizerSceneRD::sky_bake_panorama(RID p_sky, float p_energy, bool p_bake_irradiance, const Size2i &p_size) { + + Sky *sky = sky_owner.getornull(p_sky); + ERR_FAIL_COND_V(!sky, Ref<Image>()); + + _update_dirty_skys(); + + if (sky->radiance.is_valid()) { + + RD::TextureFormat tf; + tf.format = RD::DATA_FORMAT_R32G32B32A32_SFLOAT; + tf.width = p_size.width; + tf.height = p_size.height; + tf.usage_bits = RD::TEXTURE_USAGE_STORAGE_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT; + + RID rad_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + storage->get_effects()->copy_cubemap_to_panorama(sky->radiance, rad_tex, p_size, p_bake_irradiance ? roughness_layers : 0, sky->reflection.layers.size() > 1); + Vector<uint8_t> data = RD::get_singleton()->texture_get_data(rad_tex, 0); + RD::get_singleton()->free(rad_tex); + + Ref<Image> img; + img.instance(); + img->create(p_size.width, p_size.height, false, Image::FORMAT_RGBAF, data); + for (int i = 0; i < p_size.width; i++) { + for (int j = 0; j < p_size.height; j++) { + Color c = img->get_pixel(i, j); + c.r *= p_energy; + c.g *= p_energy; + c.b *= p_energy; + img->set_pixel(i, j, c); + } + } + return img; + } + + return Ref<Image>(); } + void RasterizerSceneRD::_update_dirty_skys() { Sky *sky = dirty_sky_list; @@ -1336,6 +1376,43 @@ bool RasterizerSceneRD::is_environment(RID p_env) const { return environment_owner.owns(p_env); } +Ref<Image> RasterizerSceneRD::environment_bake_panorama(RID p_env, bool p_bake_irradiance, const Size2i &p_size) { + Environent *env = environment_owner.getornull(p_env); + ERR_FAIL_COND_V(!env, Ref<Image>()); + + if (env->background == RS::ENV_BG_CAMERA_FEED || env->background == RS::ENV_BG_CANVAS || env->background == RS::ENV_BG_KEEP) { + return Ref<Image>(); //nothing to bake + } + + if (env->background == RS::ENV_BG_CLEAR_COLOR || env->background == RS::ENV_BG_COLOR) { + Color color; + if (env->background == RS::ENV_BG_CLEAR_COLOR) { + color = storage->get_default_clear_color(); + } else { + color = env->bg_color; + } + color.r *= env->bg_energy; + color.g *= env->bg_energy; + color.b *= env->bg_energy; + + Ref<Image> ret; + ret.instance(); + ret->create(p_size.width, p_size.height, false, Image::FORMAT_RGBAF); + for (int i = 0; i < p_size.width; i++) { + for (int j = 0; j < p_size.height; j++) { + ret->set_pixel(i, j, color); + } + } + return ret; + } + + if (env->background == RS::ENV_BG_SKY && env->sky.is_valid()) { + return sky_bake_panorama(env->sky, env->bg_energy, p_bake_irradiance, p_size); + } + + return Ref<Image>(); +} + //////////////////////////////////////////////////////////// RID RasterizerSceneRD::reflection_atlas_create() { @@ -1990,8 +2067,12 @@ int RasterizerSceneRD::get_directional_light_shadow_size(RID p_light_intance) { switch (storage->light_directional_get_shadow_mode(light_instance->light)) { case RS::LIGHT_DIRECTIONAL_SHADOW_ORTHOGONAL: break; //none - case RS::LIGHT_DIRECTIONAL_SHADOW_PARALLEL_2_SPLITS: r.size.height /= 2; break; - case RS::LIGHT_DIRECTIONAL_SHADOW_PARALLEL_4_SPLITS: r.size /= 2; break; + case RS::LIGHT_DIRECTIONAL_SHADOW_PARALLEL_2_SPLITS: + r.size.height /= 2; + break; + case RS::LIGHT_DIRECTIONAL_SHADOW_PARALLEL_4_SPLITS: + r.size /= 2; + break; } return MAX(r.size.width, r.size.height); @@ -3737,7 +3818,7 @@ RasterizerSceneRD::RenderBufferData *RasterizerSceneRD::render_buffers_get_data( return rb->data; } -void RasterizerSceneRD::render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass) { +void RasterizerSceneRD::render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass) { Color clear_color; if (p_render_buffers.is_valid()) { @@ -3748,7 +3829,7 @@ void RasterizerSceneRD::render_scene(RID p_render_buffers, const Transform &p_ca clear_color = storage->get_default_clear_color(); } - _render_scene(p_render_buffers, p_cam_transform, p_cam_projection, p_cam_ortogonal, p_cull_result, p_cull_count, p_light_cull_result, p_light_cull_count, p_reflection_probe_cull_result, p_reflection_probe_cull_count, p_gi_probe_cull_result, p_gi_probe_cull_count, p_decal_cull_result, p_decal_cull_count, p_environment, p_camera_effects, p_shadow_atlas, p_reflection_atlas, p_reflection_probe, p_reflection_probe_pass, clear_color); + _render_scene(p_render_buffers, p_cam_transform, p_cam_projection, p_cam_ortogonal, p_cull_result, p_cull_count, p_light_cull_result, p_light_cull_count, p_reflection_probe_cull_result, p_reflection_probe_cull_count, p_gi_probe_cull_result, p_gi_probe_cull_count, p_decal_cull_result, p_decal_cull_count, p_lightmap_cull_result, p_lightmap_cull_count, p_environment, p_camera_effects, p_shadow_atlas, p_reflection_atlas, p_reflection_probe, p_reflection_probe_pass, clear_color); if (p_render_buffers.is_valid()) { RENDER_TIMESTAMP("Tonemap"); @@ -4075,6 +4156,98 @@ float RasterizerSceneRD::screen_space_roughness_limiter_get_curve() const { return screen_space_roughness_limiter_curve; } +TypedArray<Image> RasterizerSceneRD::bake_render_uv2(RID p_base, const Vector<RID> &p_material_overrides, const Size2i &p_image_size) { + + RD::TextureFormat tf; + tf.format = RD::DATA_FORMAT_R8G8B8A8_UNORM; + tf.width = p_image_size.width; // Always 64x64 + tf.height = p_image_size.height; + tf.usage_bits = RD::TEXTURE_USAGE_COLOR_ATTACHMENT_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT; + + RID albedo_alpha_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + RID normal_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + RID orm_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + + tf.format = RD::DATA_FORMAT_R16G16B16A16_SFLOAT; + RID emission_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + + tf.format = RD::DATA_FORMAT_R32_SFLOAT; + RID depth_write_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + + tf.usage_bits = RD::TEXTURE_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT; + tf.format = RD::get_singleton()->texture_is_format_supported_for_usage(RD::DATA_FORMAT_D32_SFLOAT, RD::TEXTURE_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT) ? RD::DATA_FORMAT_D32_SFLOAT : RD::DATA_FORMAT_X8_D24_UNORM_PACK32; + RID depth_tex = RD::get_singleton()->texture_create(tf, RD::TextureView()); + + Vector<RID> fb_tex; + fb_tex.push_back(albedo_alpha_tex); + fb_tex.push_back(normal_tex); + fb_tex.push_back(orm_tex); + fb_tex.push_back(emission_tex); + fb_tex.push_back(depth_write_tex); + fb_tex.push_back(depth_tex); + + RID fb = RD::get_singleton()->framebuffer_create(fb_tex); + + //RID sampled_light; + + InstanceBase ins; + + ins.base_type = RSG::storage->get_base_type(p_base); + ins.base = p_base; + ins.materials.resize(RSG::storage->mesh_get_surface_count(p_base)); + for (int i = 0; i < ins.materials.size(); i++) { + if (i < p_material_overrides.size()) { + ins.materials.write[i] = p_material_overrides[i]; + } + } + + InstanceBase *cull = &ins; + _render_uv2(&cull, 1, fb, Rect2i(0, 0, p_image_size.width, p_image_size.height)); + + TypedArray<Image> ret; + + { + PackedByteArray data = RD::get_singleton()->texture_get_data(albedo_alpha_tex, 0); + Ref<Image> img; + img.instance(); + img->create(p_image_size.width, p_image_size.height, false, Image::FORMAT_RGBA8, data); + RD::get_singleton()->free(albedo_alpha_tex); + ret.push_back(img); + } + + { + PackedByteArray data = RD::get_singleton()->texture_get_data(normal_tex, 0); + Ref<Image> img; + img.instance(); + img->create(p_image_size.width, p_image_size.height, false, Image::FORMAT_RGBA8, data); + RD::get_singleton()->free(normal_tex); + ret.push_back(img); + } + + { + PackedByteArray data = RD::get_singleton()->texture_get_data(orm_tex, 0); + Ref<Image> img; + img.instance(); + img->create(p_image_size.width, p_image_size.height, false, Image::FORMAT_RGBA8, data); + RD::get_singleton()->free(orm_tex); + ret.push_back(img); + } + + { + PackedByteArray data = RD::get_singleton()->texture_get_data(emission_tex, 0); + Ref<Image> img; + img.instance(); + img->create(p_image_size.width, p_image_size.height, false, Image::FORMAT_RGBAH, data); + RD::get_singleton()->free(emission_tex); + ret.push_back(img); + } + + RD::get_singleton()->free(depth_write_tex); + RD::get_singleton()->free(depth_tex); + + return ret; +} + RasterizerSceneRD *RasterizerSceneRD::singleton = nullptr; RasterizerSceneRD::RasterizerSceneRD(RasterizerStorageRD *p_storage) { diff --git a/servers/rendering/rasterizer_rd/rasterizer_scene_rd.h b/servers/rendering/rasterizer_rd/rasterizer_scene_rd.h index a511838e16..5aaa15f441 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_scene_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_scene_rd.h @@ -80,9 +80,10 @@ protected: }; virtual RenderBufferData *_create_render_buffer_data() = 0; - virtual void _render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_color) = 0; + virtual void _render_scene(RID p_render_buffer, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_camera_effects, RID p_shadow_atlas, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass, const Color &p_default_color) = 0; virtual void _render_shadow(RID p_framebuffer, InstanceBase **p_cull_result, int p_cull_count, const CameraMatrix &p_projection, const Transform &p_transform, float p_zfar, float p_bias, float p_normal_bias, bool p_use_dp, bool use_dp_flip, bool p_use_pancake) = 0; virtual void _render_material(const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) = 0; + virtual void _render_uv2(InstanceBase **p_cull_result, int p_cull_count, RID p_framebuffer, const Rect2i &p_region) = 0; virtual void _debug_giprobe(RID p_gi_probe, RenderingDevice::DrawListID p_draw_list, RID p_framebuffer, const CameraMatrix &p_camera_with_transform, bool p_lighting, bool p_emission, float p_alpha); @@ -843,6 +844,7 @@ public: void sky_set_radiance_size(RID p_sky, int p_radiance_size); void sky_set_mode(RID p_sky, RS::SkyMode p_mode); void sky_set_material(RID p_sky, RID p_material); + Ref<Image> sky_bake_panorama(RID p_sky, float p_energy, bool p_bake_irradiance, const Size2i &p_size); RID sky_get_radiance_texture_rd(RID p_sky) const; RID sky_get_radiance_uniform_set_rd(RID p_sky, RID p_shader, int p_set) const; @@ -900,6 +902,8 @@ public: void environment_set_fog_depth(RID p_env, bool p_enable, float p_depth_begin, float p_depth_end, float p_depth_curve, bool p_transmit, float p_transmit_curve) {} void environment_set_fog_height(RID p_env, bool p_enable, float p_min_height, float p_max_height, float p_height_curve) {} + virtual Ref<Image> environment_bake_panorama(RID p_env, bool p_bake_irradiance, const Size2i &p_size); + virtual RID camera_effects_create(); virtual void camera_effects_set_dof_blur_quality(RS::DOFBlurQuality p_quality, bool p_use_jitter); @@ -1194,7 +1198,7 @@ public: RID render_buffers_get_ao_texture(RID p_render_buffers); RID render_buffers_get_back_buffer_texture(RID p_render_buffers); - void render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, RID p_environment, RID p_shadow_atlas, RID p_camera_effects, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass); + void render_scene(RID p_render_buffers, const Transform &p_cam_transform, const CameraMatrix &p_cam_projection, bool p_cam_ortogonal, InstanceBase **p_cull_result, int p_cull_count, RID *p_light_cull_result, int p_light_cull_count, RID *p_reflection_probe_cull_result, int p_reflection_probe_cull_count, RID *p_gi_probe_cull_result, int p_gi_probe_cull_count, RID *p_decal_cull_result, int p_decal_cull_count, InstanceBase **p_lightmap_cull_result, int p_lightmap_cull_count, RID p_environment, RID p_shadow_atlas, RID p_camera_effects, RID p_reflection_atlas, RID p_reflection_probe, int p_reflection_probe_pass); void render_shadow(RID p_light, RID p_shadow_atlas, int p_pass, InstanceBase **p_cull_result, int p_cull_count); @@ -1235,6 +1239,8 @@ public: int get_roughness_layers() const; bool is_using_radiance_cubemap_array() const; + virtual TypedArray<Image> bake_render_uv2(RID p_base, const Vector<RID> &p_material_overrides, const Size2i &p_image_size); + virtual bool free(RID p_rid); virtual void update(); diff --git a/servers/rendering/rasterizer_rd/rasterizer_storage_rd.cpp b/servers/rendering/rasterizer_rd/rasterizer_storage_rd.cpp index 8d299d623a..5f3803e8be 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_storage_rd.cpp +++ b/servers/rendering/rasterizer_rd/rasterizer_storage_rd.cpp @@ -610,7 +610,113 @@ RID RasterizerStorageRD::texture_2d_create(const Ref<Image> &p_image) { RID RasterizerStorageRD::texture_2d_layered_create(const Vector<Ref<Image>> &p_layers, RS::TextureLayeredType p_layered_type) { - return RID(); + ERR_FAIL_COND_V(p_layers.size() == 0, RID()); + + ERR_FAIL_COND_V(p_layered_type == RS::TEXTURE_LAYERED_CUBEMAP && p_layers.size() != 6, RID()); + ERR_FAIL_COND_V(p_layered_type == RS::TEXTURE_LAYERED_CUBEMAP_ARRAY && (p_layers.size() < 6 || (p_layers.size() % 6) != 0), RID()); + + TextureToRDFormat ret_format; + Vector<Ref<Image>> images; + { + int valid_width = 0; + int valid_height = 0; + bool valid_mipmaps = false; + Image::Format valid_format = Image::FORMAT_MAX; + + for (int i = 0; i < p_layers.size(); i++) { + ERR_FAIL_COND_V(p_layers[i]->empty(), RID()); + + if (i == 0) { + valid_width = p_layers[i]->get_width(); + valid_height = p_layers[i]->get_height(); + valid_format = p_layers[i]->get_format(); + valid_mipmaps = p_layers[i]->has_mipmaps(); + } else { + ERR_FAIL_COND_V(p_layers[i]->get_width() != valid_width, RID()); + ERR_FAIL_COND_V(p_layers[i]->get_height() != valid_height, RID()); + ERR_FAIL_COND_V(p_layers[i]->get_format() != valid_format, RID()); + ERR_FAIL_COND_V(p_layers[i]->has_mipmaps() != valid_mipmaps, RID()); + } + + images.push_back(_validate_texture_format(p_layers[i], ret_format)); + } + } + + Texture texture; + + texture.type = Texture::TYPE_LAYERED; + texture.layered_type = p_layered_type; + + texture.width = p_layers[0]->get_width(); + texture.height = p_layers[0]->get_height(); + texture.layers = p_layers.size(); + texture.mipmaps = p_layers[0]->get_mipmap_count() + 1; + texture.depth = 1; + texture.format = p_layers[0]->get_format(); + texture.validated_format = images[0]->get_format(); + + switch (p_layered_type) { + case RS::TEXTURE_LAYERED_2D_ARRAY: { + texture.rd_type = RD::TEXTURE_TYPE_2D_ARRAY; + } break; + case RS::TEXTURE_LAYERED_CUBEMAP: { + texture.rd_type = RD::TEXTURE_TYPE_CUBE; + } break; + case RS::TEXTURE_LAYERED_CUBEMAP_ARRAY: { + texture.rd_type = RD::TEXTURE_TYPE_CUBE_ARRAY; + } break; + } + + texture.rd_format = ret_format.format; + texture.rd_format_srgb = ret_format.format_srgb; + + RD::TextureFormat rd_format; + RD::TextureView rd_view; + { //attempt register + rd_format.format = texture.rd_format; + rd_format.width = texture.width; + rd_format.height = texture.height; + rd_format.depth = 1; + rd_format.array_layers = texture.layers; + rd_format.mipmaps = texture.mipmaps; + rd_format.type = texture.rd_type; + rd_format.samples = RD::TEXTURE_SAMPLES_1; + rd_format.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT | RD::TEXTURE_USAGE_CAN_COPY_FROM_BIT; + if (texture.rd_format_srgb != RD::DATA_FORMAT_MAX) { + rd_format.shareable_formats.push_back(texture.rd_format); + rd_format.shareable_formats.push_back(texture.rd_format_srgb); + } + } + { + rd_view.swizzle_r = ret_format.swizzle_r; + rd_view.swizzle_g = ret_format.swizzle_g; + rd_view.swizzle_b = ret_format.swizzle_b; + rd_view.swizzle_a = ret_format.swizzle_a; + } + Vector<Vector<uint8_t>> data_slices; + for (int i = 0; i < images.size(); i++) { + Vector<uint8_t> data = images[i]->get_data(); //use image data + data_slices.push_back(data); + } + texture.rd_texture = RD::get_singleton()->texture_create(rd_format, rd_view, data_slices); + ERR_FAIL_COND_V(texture.rd_texture.is_null(), RID()); + if (texture.rd_format_srgb != RD::DATA_FORMAT_MAX) { + rd_view.format_override = texture.rd_format_srgb; + texture.rd_texture_srgb = RD::get_singleton()->texture_create_shared(rd_view, texture.rd_texture); + if (texture.rd_texture_srgb.is_null()) { + RD::get_singleton()->free(texture.rd_texture); + ERR_FAIL_COND_V(texture.rd_texture_srgb.is_null(), RID()); + } + } + + //used for 2D, overridable + texture.width_2d = texture.width; + texture.height_2d = texture.height; + texture.is_render_target = false; + texture.rd_view = rd_view; + texture.is_proxy = false; + + return texture_owner.make_rid(texture); } RID RasterizerStorageRD::texture_3d_create(const Vector<Ref<Image>> &p_slices) { @@ -729,9 +835,31 @@ RID RasterizerStorageRD::texture_2d_placeholder_create() { return texture_2d_create(image); } -RID RasterizerStorageRD::texture_2d_layered_placeholder_create() { +RID RasterizerStorageRD::texture_2d_layered_placeholder_create(RS::TextureLayeredType p_layered_type) { - return RID(); + //this could be better optimized to reuse an existing image , done this way + //for now to get it working + Ref<Image> image; + image.instance(); + image->create(4, 4, false, Image::FORMAT_RGBA8); + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + image->set_pixel(i, j, Color(1, 0, 1, 1)); + } + } + + Vector<Ref<Image>> images; + if (p_layered_type == RS::TEXTURE_LAYERED_2D_ARRAY) { + images.push_back(image); + } else { + //cube + for (int i = 0; i < 6; i++) { + images.push_back(image); + } + } + + return texture_2d_layered_create(images, p_layered_type); } RID RasterizerStorageRD::texture_3d_placeholder_create() { @@ -2604,7 +2732,7 @@ void RasterizerStorageRD::_multimesh_make_local(MultiMesh *multimesh) const { uint32_t data_cache_dirty_region_count = (multimesh->instances - 1) / MULTIMESH_DIRTY_REGION_SIZE + 1; multimesh->data_cache_dirty_regions = memnew_arr(bool, data_cache_dirty_region_count); for (uint32_t i = 0; i < data_cache_dirty_region_count; i++) { - multimesh->data_cache_dirty_regions[i] = 0; + multimesh->data_cache_dirty_regions[i] = false; } multimesh->data_cache_used_dirty_regions = 0; } @@ -4139,6 +4267,180 @@ RID RasterizerStorageRD::gi_probe_get_sdf_texture(RID p_gi_probe) { return gi_probe->sdf_texture; } +/* LIGHTMAP API */ + +RID RasterizerStorageRD::lightmap_create() { + return lightmap_owner.make_rid(Lightmap()); +} + +void RasterizerStorageRD::lightmap_set_textures(RID p_lightmap, RID p_light, bool p_uses_spherical_haromics) { + + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND(!lm); + + lightmap_array_version++; + + //erase lightmap users + if (lm->light_texture.is_valid()) { + Texture *t = texture_owner.getornull(lm->light_texture); + if (t) { + t->lightmap_users.erase(p_lightmap); + } + } + + Texture *t = texture_owner.getornull(p_light); + lm->light_texture = p_light; + lm->uses_spherical_harmonics = p_uses_spherical_haromics; + + RID default_2d_array = default_rd_textures[DEFAULT_RD_TEXTURE_2D_ARRAY_WHITE]; + if (!t) { + + if (using_lightmap_array) { + if (lm->array_index >= 0) { + lightmap_textures.write[lm->array_index] = default_2d_array; + lm->array_index = -1; + } + } + + return; + } + + t->lightmap_users.insert(p_lightmap); + + if (using_lightmap_array) { + if (lm->array_index < 0) { + //not in array, try to put in array + for (int i = 0; i < lightmap_textures.size(); i++) { + if (lightmap_textures[i] == default_2d_array) { + lm->array_index = i; + break; + } + } + } + ERR_FAIL_COND_MSG(lm->array_index < 0, "Maximum amount of lightmaps in use (" + itos(lightmap_textures.size()) + ") has been exceeded, lightmap will nod display properly."); + + lightmap_textures.write[lm->array_index] = t->rd_texture; + } +} + +void RasterizerStorageRD::lightmap_set_probe_bounds(RID p_lightmap, const AABB &p_bounds) { + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND(!lm); + lm->bounds = p_bounds; +} + +void RasterizerStorageRD::lightmap_set_probe_interior(RID p_lightmap, bool p_interior) { + + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND(!lm); + lm->interior = p_interior; +} + +void RasterizerStorageRD::lightmap_set_probe_capture_data(RID p_lightmap, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree) { + + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND(!lm); + + if (p_points.size()) { + ERR_FAIL_COND(p_points.size() * 9 != p_point_sh.size()); + ERR_FAIL_COND((p_tetrahedra.size() % 4) != 0); + ERR_FAIL_COND((p_bsp_tree.size() % 6) != 0); + } + + lm->points = p_points; + lm->bsp_tree = p_bsp_tree; + lm->point_sh = p_point_sh; + lm->tetrahedra = p_tetrahedra; +} + +PackedVector3Array RasterizerStorageRD::lightmap_get_probe_capture_points(RID p_lightmap) const { + + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, PackedVector3Array()); + + return lm->points; +} +PackedColorArray RasterizerStorageRD::lightmap_get_probe_capture_sh(RID p_lightmap) const { + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, PackedColorArray()); + return lm->point_sh; +} +PackedInt32Array RasterizerStorageRD::lightmap_get_probe_capture_tetrahedra(RID p_lightmap) const { + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, PackedInt32Array()); + return lm->tetrahedra; +} +PackedInt32Array RasterizerStorageRD::lightmap_get_probe_capture_bsp_tree(RID p_lightmap) const { + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, PackedInt32Array()); + return lm->bsp_tree; +} + +void RasterizerStorageRD::lightmap_set_probe_capture_update_speed(float p_speed) { + lightmap_probe_capture_update_speed = p_speed; +} + +void RasterizerStorageRD::lightmap_tap_sh_light(RID p_lightmap, const Vector3 &p_point, Color *r_sh) { + Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND(!lm); + + for (int i = 0; i < 9; i++) { + r_sh[i] = Color(0, 0, 0, 0); + } + + if (!lm->points.size() || !lm->bsp_tree.size() || !lm->tetrahedra.size()) { + return; + } + + static_assert(sizeof(Lightmap::BSP) == 24); + + const Lightmap::BSP *bsp = (const Lightmap::BSP *)lm->bsp_tree.ptr(); + int32_t node = 0; + while (node >= 0) { + + if (Plane(bsp[node].plane[0], bsp[node].plane[1], bsp[node].plane[2], bsp[node].plane[3]).is_point_over(p_point)) { +#ifdef DEBUG_ENABLED + ERR_FAIL_COND(bsp[node].over >= 0 && bsp[node].over < node); +#endif + + node = bsp[node].over; + } else { +#ifdef DEBUG_ENABLED + ERR_FAIL_COND(bsp[node].under >= 0 && bsp[node].under < node); +#endif + node = bsp[node].under; + } + } + + if (node == Lightmap::BSP::EMPTY_LEAF) { + return; //nothing could be done + } + + node = ABS(node) - 1; + + uint32_t *tetrahedron = (uint32_t *)&lm->tetrahedra[node * 4]; + Vector3 points[4] = { lm->points[tetrahedron[0]], lm->points[tetrahedron[1]], lm->points[tetrahedron[2]], lm->points[tetrahedron[3]] }; + const Color *sh_colors[4]{ &lm->point_sh[tetrahedron[0] * 9], &lm->point_sh[tetrahedron[1] * 9], &lm->point_sh[tetrahedron[2] * 9], &lm->point_sh[tetrahedron[3] * 9] }; + Color barycentric = Geometry::tetrahedron_get_barycentric_coords(points[0], points[1], points[2], points[3], p_point); + + for (int i = 0; i < 4; i++) { + float c = CLAMP(barycentric[i], 0.0, 1.0); + for (int j = 0; j < 9; j++) { + r_sh[j] += sh_colors[i][j] * c; + } + } +} + +bool RasterizerStorageRD::lightmap_is_interior(RID p_lightmap) const { + const Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, false); + return lm->interior; +} +AABB RasterizerStorageRD::lightmap_get_aabb(RID p_lightmap) const { + const Lightmap *lm = lightmap_owner.getornull(p_lightmap); + ERR_FAIL_COND_V(!lm, AABB()); + return lm->bounds; +} /* RENDER TARGET API */ @@ -4491,6 +4793,9 @@ void RasterizerStorageRD::base_update_dependency(RID p_base, RasterizerScene::In } else if (gi_probe_owner.owns(p_base)) { GIProbe *gip = gi_probe_owner.getornull(p_base); p_instance->update_dependency(&gip->instance_dependency); + } else if (lightmap_owner.owns(p_base)) { + Lightmap *lm = lightmap_owner.getornull(p_base); + p_instance->update_dependency(&lm->instance_dependency); } else if (light_owner.owns(p_base)) { Light *l = light_owner.getornull(p_base); p_instance->update_dependency(&l->instance_dependency); @@ -4525,6 +4830,9 @@ RS::InstanceType RasterizerStorageRD::get_base_type(RID p_rid) const { if (light_owner.owns(p_rid)) { return RS::INSTANCE_LIGHT; } + if (lightmap_owner.owns(p_rid)) { + return RS::INSTANCE_LIGHTMAP; + } return RS::INSTANCE_NONE; } @@ -4588,7 +4896,7 @@ void RasterizerStorageRD::_update_decal_atlas() { Vector<DecalAtlas::SortItem> itemsv; itemsv.resize(decal_atlas.textures.size()); int base_size = 8; - const RID *K = NULL; + const RID *K = nullptr; int idx = 0; while ((K = decal_atlas.textures.next(K))) { @@ -4678,7 +4986,7 @@ void RasterizerStorageRD::_update_decal_atlas() { DecalAtlas::Texture *t = decal_atlas.textures.getptr(items[i].texture); t->uv_rect.position = items[i].pos * border + Vector2i(border / 2, border / 2); t->uv_rect.size = items[i].pixel_size; - //print_line("blitrect: " + t->uv_rect); + t->uv_rect.position /= Size2(decal_atlas.size); t->uv_rect.size /= Size2(decal_atlas.size); } @@ -4742,7 +5050,7 @@ void RasterizerStorageRD::_update_decal_atlas() { RD::DrawListID draw_list = RD::get_singleton()->draw_list_begin(mm.fb, RD::INITIAL_ACTION_CLEAR, RD::FINAL_ACTION_READ, RD::INITIAL_ACTION_DROP, RD::FINAL_ACTION_DISCARD, cc); - const RID *K = NULL; + const RID *K = nullptr; while ((K = decal_atlas.textures.next(K))) { DecalAtlas::Texture *t = decal_atlas.textures.getptr(*K); Texture *src_tex = texture_owner.getornull(*K); @@ -5151,7 +5459,7 @@ Vector<StringName> RasterizerStorageRD::global_variable_get_list() const { ERR_FAIL_V_MSG(Vector<StringName>(), "This function should never be used outside the editor, it can severely damage performance."); } - const StringName *K = NULL; + const StringName *K = nullptr; Vector<StringName> names; while ((K = global_variables.variables.next(K))) { names.push_back(*K); @@ -5563,6 +5871,11 @@ bool RasterizerStorageRD::free(RID p_rid) { GIProbe *gi_probe = gi_probe_owner.getornull(p_rid); gi_probe->instance_dependency.instance_notify_deleted(p_rid); gi_probe_owner.free(p_rid); + } else if (lightmap_owner.owns(p_rid)) { + lightmap_set_textures(p_rid, RID(), false); + Lightmap *lightmap = lightmap_owner.getornull(p_rid); + lightmap->instance_dependency.instance_notify_deleted(p_rid); + lightmap_owner.free(p_rid); } else if (light_owner.owns(p_rid)) { @@ -5801,6 +6114,32 @@ RasterizerStorageRD::RasterizerStorageRD() { } } + { //create default array + + RD::TextureFormat tformat; + tformat.format = RD::DATA_FORMAT_R8G8B8A8_UNORM; + tformat.width = 4; + tformat.height = 4; + tformat.array_layers = 1; + tformat.usage_bits = RD::TEXTURE_USAGE_SAMPLING_BIT | RD::TEXTURE_USAGE_CAN_UPDATE_BIT; + tformat.type = RD::TEXTURE_TYPE_2D_ARRAY; + + Vector<uint8_t> pv; + pv.resize(16 * 4); + for (int i = 0; i < 16; i++) { + pv.set(i * 4 + 0, 255); + pv.set(i * 4 + 1, 255); + pv.set(i * 4 + 2, 255); + pv.set(i * 4 + 3, 255); + } + + { + Vector<Vector<uint8_t>> vpv; + vpv.push_back(pv); + default_rd_textures[DEFAULT_RD_TEXTURE_2D_ARRAY_WHITE] = RD::get_singleton()->texture_create(tformat, RD::TextureView(), vpv); + } + } + //default samplers for (int i = 1; i < RS::CANVAS_ITEM_TEXTURE_FILTER_MAX; i++) { for (int j = 1; j < RS::CANVAS_ITEM_TEXTURE_REPEAT_MAX; j++) { @@ -5872,124 +6211,133 @@ RasterizerStorageRD::RasterizerStorageRD() { //default rd buffers { - //vertex + Vector<uint8_t> buffer; { - Vector<uint8_t> buffer; + buffer.resize(sizeof(float) * 3); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 0.0; + fptr[1] = 0.0; + fptr[2] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_VERTEX] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } - buffer.resize(sizeof(float) * 3); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 0.0; - fptr[1] = 0.0; - fptr[2] = 0.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_VERTEX] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} + { //normal + buffer.resize(sizeof(float) * 3); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 1.0; + fptr[1] = 0.0; + fptr[2] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_NORMAL] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } -{ //normal - Vector<uint8_t> buffer; - buffer.resize(sizeof(float) * 3); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 1.0; - fptr[1] = 0.0; - fptr[2] = 0.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_NORMAL] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} + { //tangent + buffer.resize(sizeof(float) * 4); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 1.0; + fptr[1] = 0.0; + fptr[2] = 0.0; + fptr[3] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TANGENT] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } -{ //tangent - Vector<uint8_t> buffer; - buffer.resize(sizeof(float) * 4); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 1.0; - fptr[1] = 0.0; - fptr[2] = 0.0; - fptr[3] = 0.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TANGENT] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} + { //color + buffer.resize(sizeof(float) * 4); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 1.0; + fptr[1] = 1.0; + fptr[2] = 1.0; + fptr[3] = 1.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_COLOR] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } -{ //color - Vector<uint8_t> buffer; - buffer.resize(sizeof(float) * 4); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 1.0; - fptr[1] = 1.0; - fptr[2] = 1.0; - fptr[3] = 1.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_COLOR] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} + { //tex uv 1 + buffer.resize(sizeof(float) * 2); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 0.0; + fptr[1] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TEX_UV] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } + { //tex uv 2 + buffer.resize(sizeof(float) * 2); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 0.0; + fptr[1] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TEX_UV2] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } -{ //tex uv 1 - Vector<uint8_t> buffer; - buffer.resize(sizeof(float) * 2); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 0.0; - fptr[1] = 0.0; - } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TEX_UV] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} -{ //tex uv 2 - Vector<uint8_t> buffer; - buffer.resize(sizeof(float) * 2); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 0.0; - fptr[1] = 0.0; + { //bones + buffer.resize(sizeof(uint32_t) * 4); + { + uint8_t *w = buffer.ptrw(); + uint32_t *fptr = (uint32_t *)w; + fptr[0] = 0; + fptr[1] = 0; + fptr[2] = 0; + fptr[3] = 0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_BONES] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } + + { //weights + buffer.resize(sizeof(float) * 4); + { + uint8_t *w = buffer.ptrw(); + float *fptr = (float *)w; + fptr[0] = 0.0; + fptr[1] = 0.0; + fptr[2] = 0.0; + fptr[3] = 0.0; + } + mesh_default_rd_buffers[DEFAULT_RD_BUFFER_WEIGHTS] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); + } } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_TEX_UV2] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} -{ //bones - Vector<uint8_t> buffer; - buffer.resize(sizeof(uint32_t) * 4); { - uint8_t *w = buffer.ptrw(); - uint32_t *fptr = (uint32_t *)w; - fptr[0] = 0; - fptr[1] = 0; - fptr[2] = 0; - fptr[3] = 0; + Vector<String> sdf_versions; + sdf_versions.push_back(""); //one only + giprobe_sdf_shader.initialize(sdf_versions); + giprobe_sdf_shader_version = giprobe_sdf_shader.version_create(); + giprobe_sdf_shader.version_set_compute_code(giprobe_sdf_shader_version, "", "", "", Vector<String>()); + giprobe_sdf_shader_version_shader = giprobe_sdf_shader.version_get_shader(giprobe_sdf_shader_version, 0); + giprobe_sdf_shader_pipeline = RD::get_singleton()->compute_pipeline_create(giprobe_sdf_shader_version_shader); } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_BONES] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} -{ //weights - Vector<uint8_t> buffer; - buffer.resize(sizeof(float) * 4); - { - uint8_t *w = buffer.ptrw(); - float *fptr = (float *)w; - fptr[0] = 0.0; - fptr[1] = 0.0; - fptr[2] = 0.0; - fptr[3] = 0.0; + using_lightmap_array = true; // high end + if (using_lightmap_array) { + + uint32_t textures_per_stage = RD::get_singleton()->limit_get(RD::LIMIT_MAX_TEXTURES_PER_SHADER_STAGE); + + if (textures_per_stage <= 256) { + lightmap_textures.resize(32); + } else { + lightmap_textures.resize(1024); + } + + for (int i = 0; i < lightmap_textures.size(); i++) { + lightmap_textures.write[i] = default_rd_textures[DEFAULT_RD_TEXTURE_2D_ARRAY_WHITE]; + } } - mesh_default_rd_buffers[DEFAULT_RD_BUFFER_WEIGHTS] = RD::get_singleton()->vertex_buffer_create(buffer.size(), buffer); -} -} -{ - Vector<String> sdf_versions; - sdf_versions.push_back(""); //one only - giprobe_sdf_shader.initialize(sdf_versions); - giprobe_sdf_shader_version = giprobe_sdf_shader.version_create(); - giprobe_sdf_shader.version_set_compute_code(giprobe_sdf_shader_version, "", "", "", Vector<String>()); - giprobe_sdf_shader_version_shader = giprobe_sdf_shader.version_get_shader(giprobe_sdf_shader_version, 0); - giprobe_sdf_shader_pipeline = RD::get_singleton()->compute_pipeline_create(giprobe_sdf_shader_version_shader); -} + lightmap_probe_capture_update_speed = GLOBAL_GET("rendering/lightmapper/probe_capture_update_speed"); } RasterizerStorageRD::~RasterizerStorageRD() { diff --git a/servers/rendering/rasterizer_rd/rasterizer_storage_rd.h b/servers/rendering/rasterizer_rd/rasterizer_storage_rd.h index f874c3baf8..94b373247f 100644 --- a/servers/rendering/rasterizer_rd/rasterizer_storage_rd.h +++ b/servers/rendering/rasterizer_rd/rasterizer_storage_rd.h @@ -92,6 +92,7 @@ public: DEFAULT_RD_TEXTURE_CUBEMAP_BLACK, DEFAULT_RD_TEXTURE_CUBEMAP_ARRAY_BLACK, DEFAULT_RD_TEXTURE_3D_WHITE, + DEFAULT_RD_TEXTURE_2D_ARRAY_WHITE, DEFAULT_RD_TEXTURE_MAX }; @@ -118,6 +119,7 @@ private: }; Type type; + RS::TextureLayeredType layered_type = RS::TEXTURE_LAYERED_2D_ARRAY; RenderingDevice::TextureType rd_type; RID rd_texture; @@ -147,6 +149,7 @@ private: RID proxy_to; Vector<RID> proxies; + Set<RID> lightmap_users; RS::TextureDetectCallback detect_3d_callback = nullptr; void *detect_3d_callback_ud = nullptr; @@ -524,6 +527,40 @@ private: mutable RID_Owner<GIProbe> gi_probe_owner; + /* REFLECTION PROBE */ + + struct Lightmap { + + RID light_texture; + bool uses_spherical_harmonics = false; + bool interior = false; + AABB bounds = AABB(Vector3(), Vector3(1, 1, 1)); + int32_t array_index = -1; //unassigned + PackedVector3Array points; + PackedColorArray point_sh; + PackedInt32Array tetrahedra; + PackedInt32Array bsp_tree; + + struct BSP { + static const int32_t EMPTY_LEAF = INT32_MIN; + float plane[4]; + int32_t over = EMPTY_LEAF, under = EMPTY_LEAF; + }; + + RasterizerScene::InstanceDependency instance_dependency; + }; + + bool using_lightmap_array; //high end uses this + /* for high end */ + + Vector<RID> lightmap_textures; + + uint64_t lightmap_array_version = 0; + + mutable RID_Owner<Lightmap> lightmap_owner; + + float lightmap_probe_capture_update_speed = 4; + /* RENDER TARGET */ struct RenderTarget { @@ -653,7 +690,7 @@ public: //these two APIs can be used together or in combination with the others. virtual RID texture_2d_placeholder_create(); - virtual RID texture_2d_layered_placeholder_create(); + virtual RID texture_2d_layered_placeholder_create(RenderingServer::TextureLayeredType p_layered_type); virtual RID texture_3d_placeholder_create(); virtual Ref<Image> texture_2d_get(RID p_texture) const; @@ -1270,23 +1307,47 @@ public: /* LIGHTMAP CAPTURE */ - void lightmap_capture_set_bounds(RID p_capture, const AABB &p_bounds) {} - AABB lightmap_capture_get_bounds(RID p_capture) const { return AABB(); } - void lightmap_capture_set_octree(RID p_capture, const Vector<uint8_t> &p_octree) {} - RID lightmap_capture_create() { - return RID(); + virtual RID lightmap_create(); + + virtual void lightmap_set_textures(RID p_lightmap, RID p_light, bool p_uses_spherical_haromics); + virtual void lightmap_set_probe_bounds(RID p_lightmap, const AABB &p_bounds); + virtual void lightmap_set_probe_interior(RID p_lightmap, bool p_interior); + virtual void lightmap_set_probe_capture_data(RID p_lightmap, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree); + virtual PackedVector3Array lightmap_get_probe_capture_points(RID p_lightmap) const; + virtual PackedColorArray lightmap_get_probe_capture_sh(RID p_lightmap) const; + virtual PackedInt32Array lightmap_get_probe_capture_tetrahedra(RID p_lightmap) const; + virtual PackedInt32Array lightmap_get_probe_capture_bsp_tree(RID p_lightmap) const; + virtual AABB lightmap_get_aabb(RID p_lightmap) const; + virtual bool lightmap_is_interior(RID p_lightmap) const; + virtual void lightmap_tap_sh_light(RID p_lightmap, const Vector3 &p_point, Color *r_sh); + virtual void lightmap_set_probe_capture_update_speed(float p_speed); + _FORCE_INLINE_ float lightmap_get_probe_capture_update_speed() const { + return lightmap_probe_capture_update_speed; + } + + _FORCE_INLINE_ int32_t lightmap_get_array_index(RID p_lightmap) const { + ERR_FAIL_COND_V(!using_lightmap_array, -1); //only for arrays + const Lightmap *lm = lightmap_owner.getornull(p_lightmap); + return lm->array_index; + } + _FORCE_INLINE_ bool lightmap_uses_spherical_harmonics(RID p_lightmap) const { + ERR_FAIL_COND_V(!using_lightmap_array, false); //only for arrays + const Lightmap *lm = lightmap_owner.getornull(p_lightmap); + return lm->uses_spherical_harmonics; } - Vector<uint8_t> lightmap_capture_get_octree(RID p_capture) const { - return Vector<uint8_t>(); + _FORCE_INLINE_ uint64_t lightmap_array_get_version() const { + ERR_FAIL_COND_V(!using_lightmap_array, 0); //only for arrays + return lightmap_array_version; } - void lightmap_capture_set_octree_cell_transform(RID p_capture, const Transform &p_xform) {} - Transform lightmap_capture_get_octree_cell_transform(RID p_capture) const { return Transform(); } - void lightmap_capture_set_octree_cell_subdiv(RID p_capture, int p_subdiv) {} - int lightmap_capture_get_octree_cell_subdiv(RID p_capture) const { return 0; } - void lightmap_capture_set_energy(RID p_capture, float p_energy) {} - float lightmap_capture_get_energy(RID p_capture) const { return 0.0; } - const Vector<LightmapCaptureOctree> *lightmap_capture_get_octree_ptr(RID p_capture) const { - return nullptr; + + _FORCE_INLINE_ int lightmap_array_get_size() const { + ERR_FAIL_COND_V(!using_lightmap_array, 0); //only for arrays + return lightmap_textures.size(); + } + + _FORCE_INLINE_ const Vector<RID> &lightmap_array_get_textures() const { + ERR_FAIL_COND_V(!using_lightmap_array, lightmap_textures); //only for arrays + return lightmap_textures; } /* PARTICLES */ diff --git a/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.cpp b/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.cpp index 2bfdb7fffe..5838936f35 100644 --- a/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.cpp +++ b/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.cpp @@ -31,16 +31,20 @@ #include "render_pipeline_vertex_format_cache_rd.h" #include "core/os/memory.h" -RID RenderPipelineVertexFormatCacheRD::_generate_version(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id) { +RID RenderPipelineVertexFormatCacheRD::_generate_version(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id, bool p_wireframe) { RD::PipelineMultisampleState multisample_state_version = multisample_state; multisample_state_version.sample_count = RD::get_singleton()->framebuffer_format_get_texture_samples(p_framebuffer_format_id); - RID pipeline = RD::get_singleton()->render_pipeline_create(shader, p_framebuffer_format_id, p_vertex_format_id, render_primitive, rasterization_state, multisample_state_version, depth_stencil_state, blend_state, dynamic_state_flags); + RD::PipelineRasterizationState raster_state_version = rasterization_state; + raster_state_version.wireframe = p_wireframe; + + RID pipeline = RD::get_singleton()->render_pipeline_create(shader, p_framebuffer_format_id, p_vertex_format_id, render_primitive, raster_state_version, multisample_state_version, depth_stencil_state, blend_state, dynamic_state_flags); ERR_FAIL_COND_V(pipeline.is_null(), RID()); versions = (Version *)memrealloc(versions, sizeof(Version) * (version_count + 1)); versions[version_count].framebuffer_id = p_framebuffer_format_id; versions[version_count].vertex_id = p_vertex_format_id; + versions[version_count].wireframe = p_wireframe; versions[version_count].pipeline = pipeline; version_count++; return pipeline; diff --git a/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.h b/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.h index ecb1b42b06..a8bfdb5a26 100644 --- a/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.h +++ b/servers/rendering/rasterizer_rd/render_pipeline_vertex_format_cache_rd.h @@ -51,13 +51,14 @@ class RenderPipelineVertexFormatCacheRD { struct Version { RD::VertexFormatID vertex_id; RD::FramebufferFormatID framebuffer_id; + bool wireframe; RID pipeline; }; Version *versions; uint32_t version_count; - RID _generate_version(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id); + RID _generate_version(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id, bool p_wireframe); void _clear(); @@ -65,7 +66,7 @@ public: void setup(RID p_shader, RD::RenderPrimitive p_primitive, const RD::PipelineRasterizationState &p_rasterization_state, RD::PipelineMultisampleState p_multisample, const RD::PipelineDepthStencilState &p_depth_stencil_state, const RD::PipelineColorBlendState &p_blend_state, int p_dynamic_state_flags = 0); void update_shader(RID p_shader); - _FORCE_INLINE_ RID get_render_pipeline(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id) { + _FORCE_INLINE_ RID get_render_pipeline(RD::VertexFormatID p_vertex_format_id, RD::FramebufferFormatID p_framebuffer_format_id, bool p_wireframe = false) { #ifdef DEBUG_ENABLED ERR_FAIL_COND_V_MSG(shader.is_null(), RID(), "Attempted to use an unused shader variant (shader is null),"); @@ -74,13 +75,13 @@ public: spin_lock.lock(); RID result; for (uint32_t i = 0; i < version_count; i++) { - if (versions[i].vertex_id == p_vertex_format_id && versions[i].framebuffer_id == p_framebuffer_format_id) { + if (versions[i].vertex_id == p_vertex_format_id && versions[i].framebuffer_id == p_framebuffer_format_id && versions[i].wireframe == p_wireframe) { result = versions[i].pipeline; spin_lock.unlock(); return result; } } - result = _generate_version(p_vertex_format_id, p_framebuffer_format_id); + result = _generate_version(p_vertex_format_id, p_framebuffer_format_id, p_wireframe); spin_lock.unlock(); return result; } diff --git a/servers/rendering/rasterizer_rd/shader_compiler_rd.cpp b/servers/rendering/rasterizer_rd/shader_compiler_rd.cpp index d4e6576125..2ef29e97ff 100644 --- a/servers/rendering/rasterizer_rd/shader_compiler_rd.cpp +++ b/servers/rendering/rasterizer_rd/shader_compiler_rd.cpp @@ -60,39 +60,71 @@ static int _get_datatype_size(SL::DataType p_type) { switch (p_type) { - case SL::TYPE_VOID: return 0; - case SL::TYPE_BOOL: return 4; - case SL::TYPE_BVEC2: return 8; - case SL::TYPE_BVEC3: return 12; - case SL::TYPE_BVEC4: return 16; - case SL::TYPE_INT: return 4; - case SL::TYPE_IVEC2: return 8; - case SL::TYPE_IVEC3: return 12; - case SL::TYPE_IVEC4: return 16; - case SL::TYPE_UINT: return 4; - case SL::TYPE_UVEC2: return 8; - case SL::TYPE_UVEC3: return 12; - case SL::TYPE_UVEC4: return 16; - case SL::TYPE_FLOAT: return 4; - case SL::TYPE_VEC2: return 8; - case SL::TYPE_VEC3: return 12; - case SL::TYPE_VEC4: return 16; + case SL::TYPE_VOID: + return 0; + case SL::TYPE_BOOL: + return 4; + case SL::TYPE_BVEC2: + return 8; + case SL::TYPE_BVEC3: + return 12; + case SL::TYPE_BVEC4: + return 16; + case SL::TYPE_INT: + return 4; + case SL::TYPE_IVEC2: + return 8; + case SL::TYPE_IVEC3: + return 12; + case SL::TYPE_IVEC4: + return 16; + case SL::TYPE_UINT: + return 4; + case SL::TYPE_UVEC2: + return 8; + case SL::TYPE_UVEC3: + return 12; + case SL::TYPE_UVEC4: + return 16; + case SL::TYPE_FLOAT: + return 4; + case SL::TYPE_VEC2: + return 8; + case SL::TYPE_VEC3: + return 12; + case SL::TYPE_VEC4: + return 16; case SL::TYPE_MAT2: return 32; //4 * 4 + 4 * 4 case SL::TYPE_MAT3: return 48; // 4 * 4 + 4 * 4 + 4 * 4 - case SL::TYPE_MAT4: return 64; - case SL::TYPE_SAMPLER2D: return 16; - case SL::TYPE_ISAMPLER2D: return 16; - case SL::TYPE_USAMPLER2D: return 16; - case SL::TYPE_SAMPLER2DARRAY: return 16; - case SL::TYPE_ISAMPLER2DARRAY: return 16; - case SL::TYPE_USAMPLER2DARRAY: return 16; - case SL::TYPE_SAMPLER3D: return 16; - case SL::TYPE_ISAMPLER3D: return 16; - case SL::TYPE_USAMPLER3D: return 16; - case SL::TYPE_SAMPLERCUBE: return 16; - case SL::TYPE_STRUCT: return 0; + case SL::TYPE_MAT4: + return 64; + case SL::TYPE_SAMPLER2D: + return 16; + case SL::TYPE_ISAMPLER2D: + return 16; + case SL::TYPE_USAMPLER2D: + return 16; + case SL::TYPE_SAMPLER2DARRAY: + return 16; + case SL::TYPE_ISAMPLER2DARRAY: + return 16; + case SL::TYPE_USAMPLER2DARRAY: + return 16; + case SL::TYPE_SAMPLER3D: + return 16; + case SL::TYPE_ISAMPLER3D: + return 16; + case SL::TYPE_USAMPLER3D: + return 16; + case SL::TYPE_SAMPLERCUBE: + return 16; + case SL::TYPE_SAMPLERCUBEARRAY: + return 16; + case SL::TYPE_STRUCT: + return 0; + case SL::TYPE_MAX: { ERR_FAIL_V(0); }; @@ -105,37 +137,70 @@ static int _get_datatype_alignment(SL::DataType p_type) { switch (p_type) { - case SL::TYPE_VOID: return 0; - case SL::TYPE_BOOL: return 4; - case SL::TYPE_BVEC2: return 8; - case SL::TYPE_BVEC3: return 16; - case SL::TYPE_BVEC4: return 16; - case SL::TYPE_INT: return 4; - case SL::TYPE_IVEC2: return 8; - case SL::TYPE_IVEC3: return 16; - case SL::TYPE_IVEC4: return 16; - case SL::TYPE_UINT: return 4; - case SL::TYPE_UVEC2: return 8; - case SL::TYPE_UVEC3: return 16; - case SL::TYPE_UVEC4: return 16; - case SL::TYPE_FLOAT: return 4; - case SL::TYPE_VEC2: return 8; - case SL::TYPE_VEC3: return 16; - case SL::TYPE_VEC4: return 16; - case SL::TYPE_MAT2: return 16; - case SL::TYPE_MAT3: return 16; - case SL::TYPE_MAT4: return 16; - case SL::TYPE_SAMPLER2D: return 16; - case SL::TYPE_ISAMPLER2D: return 16; - case SL::TYPE_USAMPLER2D: return 16; - case SL::TYPE_SAMPLER2DARRAY: return 16; - case SL::TYPE_ISAMPLER2DARRAY: return 16; - case SL::TYPE_USAMPLER2DARRAY: return 16; - case SL::TYPE_SAMPLER3D: return 16; - case SL::TYPE_ISAMPLER3D: return 16; - case SL::TYPE_USAMPLER3D: return 16; - case SL::TYPE_SAMPLERCUBE: return 16; - case SL::TYPE_STRUCT: return 0; + case SL::TYPE_VOID: + return 0; + case SL::TYPE_BOOL: + return 4; + case SL::TYPE_BVEC2: + return 8; + case SL::TYPE_BVEC3: + return 16; + case SL::TYPE_BVEC4: + return 16; + case SL::TYPE_INT: + return 4; + case SL::TYPE_IVEC2: + return 8; + case SL::TYPE_IVEC3: + return 16; + case SL::TYPE_IVEC4: + return 16; + case SL::TYPE_UINT: + return 4; + case SL::TYPE_UVEC2: + return 8; + case SL::TYPE_UVEC3: + return 16; + case SL::TYPE_UVEC4: + return 16; + case SL::TYPE_FLOAT: + return 4; + case SL::TYPE_VEC2: + return 8; + case SL::TYPE_VEC3: + return 16; + case SL::TYPE_VEC4: + return 16; + case SL::TYPE_MAT2: + return 16; + case SL::TYPE_MAT3: + return 16; + case SL::TYPE_MAT4: + return 16; + case SL::TYPE_SAMPLER2D: + return 16; + case SL::TYPE_ISAMPLER2D: + return 16; + case SL::TYPE_USAMPLER2D: + return 16; + case SL::TYPE_SAMPLER2DARRAY: + return 16; + case SL::TYPE_ISAMPLER2DARRAY: + return 16; + case SL::TYPE_USAMPLER2DARRAY: + return 16; + case SL::TYPE_SAMPLER3D: + return 16; + case SL::TYPE_ISAMPLER3D: + return 16; + case SL::TYPE_USAMPLER3D: + return 16; + case SL::TYPE_SAMPLERCUBE: + return 16; + case SL::TYPE_SAMPLERCUBEARRAY: + return 16; + case SL::TYPE_STRUCT: + return 0; case SL::TYPE_MAX: { ERR_FAIL_V(0); } @@ -146,8 +211,10 @@ static int _get_datatype_alignment(SL::DataType p_type) { static String _interpstr(SL::DataInterpolation p_interp) { switch (p_interp) { - case SL::INTERPOLATION_FLAT: return "flat "; - case SL::INTERPOLATION_SMOOTH: return ""; + case SL::INTERPOLATION_FLAT: + return "flat "; + case SL::INTERPOLATION_SMOOTH: + return ""; } return ""; } @@ -155,10 +222,14 @@ static String _interpstr(SL::DataInterpolation p_interp) { static String _prestr(SL::DataPrecision p_pres) { switch (p_pres) { - case SL::PRECISION_LOWP: return "lowp "; - case SL::PRECISION_MEDIUMP: return "mediump "; - case SL::PRECISION_HIGHP: return "highp "; - case SL::PRECISION_DEFAULT: return ""; + case SL::PRECISION_LOWP: + return "lowp "; + case SL::PRECISION_MEDIUMP: + return "mediump "; + case SL::PRECISION_HIGHP: + return "highp "; + case SL::PRECISION_DEFAULT: + return ""; } return ""; } @@ -166,9 +237,12 @@ static String _prestr(SL::DataPrecision p_pres) { static String _qualstr(SL::ArgumentQualifier p_qual) { switch (p_qual) { - case SL::ARGUMENT_QUALIFIER_IN: return ""; - case SL::ARGUMENT_QUALIFIER_OUT: return "out "; - case SL::ARGUMENT_QUALIFIER_INOUT: return "inout "; + case SL::ARGUMENT_QUALIFIER_IN: + return ""; + case SL::ARGUMENT_QUALIFIER_OUT: + return "out "; + case SL::ARGUMENT_QUALIFIER_INOUT: + return "inout "; } return ""; } @@ -196,7 +270,8 @@ static String f2sp0(float p_float) { static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNode::Value> &p_values) { switch (p_type) { - case SL::TYPE_BOOL: return p_values[0].boolean ? "true" : "false"; + case SL::TYPE_BOOL: + return p_values[0].boolean ? "true" : "false"; case SL::TYPE_BVEC2: case SL::TYPE_BVEC3: case SL::TYPE_BVEC4: { @@ -212,7 +287,8 @@ static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNo return text; } - case SL::TYPE_INT: return itos(p_values[0].sint); + case SL::TYPE_INT: + return itos(p_values[0].sint); case SL::TYPE_IVEC2: case SL::TYPE_IVEC3: case SL::TYPE_IVEC4: { @@ -228,7 +304,8 @@ static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNo return text; } break; - case SL::TYPE_UINT: return itos(p_values[0].uint) + "u"; + case SL::TYPE_UINT: + return itos(p_values[0].uint) + "u"; case SL::TYPE_UVEC2: case SL::TYPE_UVEC3: case SL::TYPE_UVEC4: { @@ -243,7 +320,8 @@ static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNo text += ")"; return text; } break; - case SL::TYPE_FLOAT: return f2sp0(p_values[0].real); + case SL::TYPE_FLOAT: + return f2sp0(p_values[0].real); case SL::TYPE_VEC2: case SL::TYPE_VEC3: case SL::TYPE_VEC4: { @@ -274,7 +352,8 @@ static String get_constant_text(SL::DataType p_type, const Vector<SL::ConstantNo return text; } break; - default: ERR_FAIL_V(String()); + default: + ERR_FAIL_V(String()); } } diff --git a/servers/rendering/rasterizer_rd/shaders/copy.glsl b/servers/rendering/rasterizer_rd/shaders/copy.glsl index 2d7661f65f..075ee2af22 100644 --- a/servers/rendering/rasterizer_rd/shaders/copy.glsl +++ b/servers/rendering/rasterizer_rd/shaders/copy.glsl @@ -39,7 +39,13 @@ layout(push_constant, binding = 1, std430) uniform Params { } params; +#ifdef MODE_CUBEMAP_ARRAY_TO_PANORAMA +layout(set = 0, binding = 0) uniform samplerCubeArray source_color; +#elif defined(MODE_CUBEMAP_TO_PANORAMA) +layout(set = 0, binding = 0) uniform samplerCube source_color; +#else layout(set = 0, binding = 0) uniform sampler2D source_color; +#endif #ifdef GLOW_USE_AUTO_EXPOSURE layout(set = 1, binding = 0) uniform sampler2D source_auto_exposure; @@ -57,7 +63,7 @@ void main() { // Pixel being shaded ivec2 pos = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(pos, params.section.zw))) { //too large, do nothing + if (any(greaterThanEqual(pos, params.section.zw))) { //too large, do nothing return; } @@ -217,4 +223,25 @@ void main() { imageStore(dest_buffer, pos + params.target, color); #endif + +#if defined(MODE_CUBEMAP_TO_PANORAMA) || defined(MODE_CUBEMAP_ARRAY_TO_PANORAMA) + + const float PI = 3.14159265359; + vec2 uv = vec2(pos) / vec2(params.section.zw); + uv.y = 1.0 - uv.y; + float phi = uv.x * 2.0 * PI; + float theta = uv.y * PI; + + vec3 normal; + normal.x = sin(phi) * sin(theta) * -1.0; + normal.y = cos(theta); + normal.z = cos(phi) * sin(theta) * -1.0; + +#ifdef MODE_CUBEMAP_TO_PANORAMA + vec4 color = textureLod(source_color, normal, params.camera_z_far); //the biggest the lod the least the acne +#else + vec4 color = textureLod(source_color, vec4(normal, params.camera_z_far), 0.0); //the biggest the lod the least the acne +#endif + imageStore(dest_buffer, pos + params.target, color); +#endif } diff --git a/servers/rendering/rasterizer_rd/shaders/giprobe_debug.glsl b/servers/rendering/rasterizer_rd/shaders/giprobe_debug.glsl index b1784e7eee..c0fcb9a8c4 100644 --- a/servers/rendering/rasterizer_rd/shaders/giprobe_debug.glsl +++ b/servers/rendering/rasterizer_rd/shaders/giprobe_debug.glsl @@ -130,12 +130,24 @@ void main() { float strength = 0.0; switch (side) { - case POS_X: strength = aniso_pos.x; break; - case POS_Y: strength = aniso_pos.y; break; - case POS_Z: strength = aniso_pos.z; break; - case NEG_X: strength = aniso_neg.x; break; - case NEG_Y: strength = aniso_neg.y; break; - case NEG_Z: strength = aniso_neg.z; break; + case POS_X: + strength = aniso_pos.x; + break; + case POS_Y: + strength = aniso_pos.y; + break; + case POS_Z: + strength = aniso_pos.z; + break; + case NEG_X: + strength = aniso_neg.x; + break; + case NEG_Y: + strength = aniso_neg.y; + break; + case NEG_Z: + strength = aniso_neg.z; + break; } color_interp.xyz *= strength; @@ -184,22 +196,38 @@ void main() { int index = x + y * 4; float limit = 0.0; if (x < 8) { - if (index == 0) limit = 0.0625; - if (index == 1) limit = 0.5625; - if (index == 2) limit = 0.1875; - if (index == 3) limit = 0.6875; - if (index == 4) limit = 0.8125; - if (index == 5) limit = 0.3125; - if (index == 6) limit = 0.9375; - if (index == 7) limit = 0.4375; - if (index == 8) limit = 0.25; - if (index == 9) limit = 0.75; - if (index == 10) limit = 0.125; - if (index == 11) limit = 0.625; - if (index == 12) limit = 1.0; - if (index == 13) limit = 0.5; - if (index == 14) limit = 0.875; - if (index == 15) limit = 0.375; + if (index == 0) + limit = 0.0625; + if (index == 1) + limit = 0.5625; + if (index == 2) + limit = 0.1875; + if (index == 3) + limit = 0.6875; + if (index == 4) + limit = 0.8125; + if (index == 5) + limit = 0.3125; + if (index == 6) + limit = 0.9375; + if (index == 7) + limit = 0.4375; + if (index == 8) + limit = 0.25; + if (index == 9) + limit = 0.75; + if (index == 10) + limit = 0.125; + if (index == 11) + limit = 0.625; + if (index == 12) + limit = 1.0; + if (index == 13) + limit = 0.5; + if (index == 14) + limit = 0.875; + if (index == 15) + limit = 0.375; } if (frag_color.a < limit) { discard; diff --git a/servers/rendering/rasterizer_rd/shaders/scene_high_end.glsl b/servers/rendering/rasterizer_rd/shaders/scene_high_end.glsl index ec47887036..90e37b3ec4 100644 --- a/servers/rendering/rasterizer_rd/shaders/scene_high_end.glsl +++ b/servers/rendering/rasterizer_rd/shaders/scene_high_end.glsl @@ -22,7 +22,7 @@ layout(location = 3) in vec4 color_attrib; layout(location = 4) in vec2 uv_attrib; -#if defined(UV2_USED) || defined(USE_LIGHTMAP) +#if defined(UV2_USED) || defined(USE_LIGHTMAP) || defined(MODE_RENDER_MATERIAL) layout(location = 5) in vec2 uv2_attrib; #endif @@ -49,7 +49,7 @@ layout(location = 6) out vec3 binormal_interp; #endif #ifdef USE_MATERIAL_UNIFORMS -layout(set = 5, binding = 0, std140) uniform MaterialUniforms{ +layout(set = MATERIAL_UNIFORM_SET, binding = 0, std140) uniform MaterialUniforms{ /* clang-format off */ MATERIAL_UNIFORMS /* clang-format on */ @@ -263,6 +263,14 @@ VERTEX_SHADER_CODE } } #endif + +#ifdef MODE_RENDER_MATERIAL + if (scene_data.material_uv2_mode) { + gl_Position.xy = (uv2_attrib.xy + draw_call.bake_uv2_offset) * 2.0 - 1.0; + gl_Position.z = 0.00001; + gl_Position.w = 1.0; + } +#endif } /* clang-format off */ @@ -315,7 +323,7 @@ layout(location = 8) in float dp_clip; #endif #ifdef USE_MATERIAL_UNIFORMS -layout(set = 5, binding = 0, std140) uniform MaterialUniforms{ +layout(set = MATERIAL_UNIFORM_SET, binding = 0, std140) uniform MaterialUniforms{ /* clang-format off */ MATERIAL_UNIFORMS /* clang-format on */ @@ -420,7 +428,8 @@ float SchlickFresnel(float u) { } float GTR1(float NdotH, float a) { - if (a >= 1.0) return 1.0 / M_PI; + if (a >= 1.0) + return 1.0 / M_PI; float a2 = a * a; float t = 1.0 + (a2 - 1.0) * NdotH * NdotH; return (a2 - 1.0) / (M_PI * log(a2) * t); @@ -1916,42 +1925,96 @@ FRAGMENT_SHADER_CODE #if !defined(MODE_RENDER_DEPTH) && !defined(MODE_UNSHADED) //gi probes +#ifdef USE_LIGHTMAP + //lightmap + if (bool(instances.data[instance_index].flags & INSTANCE_FLAGS_USE_LIGHTMAP_CAPTURE)) { //has lightmap capture + uint index = instances.data[instance_index].gi_offset; + + vec3 wnormal = mat3(scene_data.camera_matrix) * normal; + const float c1 = 0.429043; + const float c2 = 0.511664; + const float c3 = 0.743125; + const float c4 = 0.886227; + const float c5 = 0.247708; + ambient_light += (c1 * lightmap_captures.data[index].sh[8].rgb * (wnormal.x * wnormal.x - wnormal.y * wnormal.y) + + c3 * lightmap_captures.data[index].sh[6].rgb * wnormal.z * wnormal.z + + c4 * lightmap_captures.data[index].sh[0].rgb - + c5 * lightmap_captures.data[index].sh[6].rgb + + 2.0 * c1 * lightmap_captures.data[index].sh[4].rgb * wnormal.x * wnormal.y + + 2.0 * c1 * lightmap_captures.data[index].sh[7].rgb * wnormal.x * wnormal.z + + 2.0 * c1 * lightmap_captures.data[index].sh[5].rgb * wnormal.y * wnormal.z + + 2.0 * c2 * lightmap_captures.data[index].sh[3].rgb * wnormal.x + + 2.0 * c2 * lightmap_captures.data[index].sh[1].rgb * wnormal.y + + 2.0 * c2 * lightmap_captures.data[index].sh[2].rgb * wnormal.z); + + } else if (bool(instances.data[instance_index].flags & INSTANCE_FLAGS_USE_LIGHTMAP)) { // has actual lightmap + bool uses_sh = bool(instances.data[instance_index].flags & INSTANCE_FLAGS_USE_SH_LIGHTMAP); + uint ofs = instances.data[instance_index].gi_offset & 0xFFF; + vec3 uvw; + uvw.xy = uv2 * instances.data[instance_index].lightmap_uv_scale.zw + instances.data[instance_index].lightmap_uv_scale.xy; + uvw.z = float((instances.data[instance_index].gi_offset >> 12) & 0xFF); + + if (uses_sh) { + uvw.z *= 4.0; //SH textures use 4 times more data + vec3 lm_light_l0 = textureLod(sampler2DArray(lightmap_textures[ofs], material_samplers[SAMPLER_LINEAR_CLAMP]), uvw + vec3(0.0, 0.0, 0.0), 0.0).rgb; + vec3 lm_light_l1n1 = textureLod(sampler2DArray(lightmap_textures[ofs], material_samplers[SAMPLER_LINEAR_CLAMP]), uvw + vec3(0.0, 0.0, 1.0), 0.0).rgb; + vec3 lm_light_l1_0 = textureLod(sampler2DArray(lightmap_textures[ofs], material_samplers[SAMPLER_LINEAR_CLAMP]), uvw + vec3(0.0, 0.0, 2.0), 0.0).rgb; + vec3 lm_light_l1p1 = textureLod(sampler2DArray(lightmap_textures[ofs], material_samplers[SAMPLER_LINEAR_CLAMP]), uvw + vec3(0.0, 0.0, 3.0), 0.0).rgb; + + uint idx = instances.data[instance_index].gi_offset >> 20; + vec3 n = normalize(lightmaps.data[idx].normal_xform * normal); + + ambient_light += lm_light_l0 * 0.282095f; + ambient_light += lm_light_l1n1 * 0.32573 * n.y; + ambient_light += lm_light_l1_0 * 0.32573 * n.z; + ambient_light += lm_light_l1p1 * 0.32573 * n.x; + if (metallic > 0.01) { // since the more direct bounced light is lost, we can kind of fake it with this trick + vec3 r = reflect(normalize(-vertex), normal); + specular_light += lm_light_l1n1 * 0.32573 * r.y; + specular_light += lm_light_l1_0 * 0.32573 * r.z; + specular_light += lm_light_l1p1 * 0.32573 * r.x; + } + + } else { + ambient_light += textureLod(sampler2DArray(lightmap_textures[ofs], material_samplers[SAMPLER_LINEAR_CLAMP]), uvw, 0.0).rgb; + } + } +#endif //lightmap capture #ifdef USE_VOXEL_CONE_TRACING - { // process giprobes - uint index1 = instances.data[instance_index].gi_offset & 0xFFFF; - if (index1 != 0xFFFF) { - vec3 ref_vec = normalize(reflect(normalize(vertex), normal)); - //find arbitrary tangent and bitangent, then build a matrix - vec3 v0 = abs(normal.z) < 0.999 ? vec3(0.0, 0.0, 1.0) : vec3(0.0, 1.0, 0.0); - vec3 tangent = normalize(cross(v0, normal)); - vec3 bitangent = normalize(cross(tangent, normal)); - mat3 normal_mat = mat3(tangent, bitangent, normal); + if (bool(instances.data[instance_index].flags & INSTANCE_FLAGS_USE_GIPROBE)) { // process giprobes - vec4 amb_accum = vec4(0.0); - vec4 spec_accum = vec4(0.0); - gi_probe_compute(index1, vertex, normal, ref_vec, normal_mat, roughness * roughness, ambient_light, specular_light, spec_accum, amb_accum); + uint index1 = instances.data[instance_index].gi_offset & 0xFFFF; + vec3 ref_vec = normalize(reflect(normalize(vertex), normal)); + //find arbitrary tangent and bitangent, then build a matrix + vec3 v0 = abs(normal.z) < 0.999 ? vec3(0.0, 0.0, 1.0) : vec3(0.0, 1.0, 0.0); + vec3 tangent = normalize(cross(v0, normal)); + vec3 bitangent = normalize(cross(tangent, normal)); + mat3 normal_mat = mat3(tangent, bitangent, normal); - uint index2 = instances.data[instance_index].gi_offset >> 16; + vec4 amb_accum = vec4(0.0); + vec4 spec_accum = vec4(0.0); + gi_probe_compute(index1, vertex, normal, ref_vec, normal_mat, roughness * roughness, ambient_light, specular_light, spec_accum, amb_accum); - if (index2 != 0xFFFF) { - gi_probe_compute(index2, vertex, normal, ref_vec, normal_mat, roughness * roughness, ambient_light, specular_light, spec_accum, amb_accum); - } + uint index2 = instances.data[instance_index].gi_offset >> 16; - if (amb_accum.a > 0.0) { - amb_accum.rgb /= amb_accum.a; - } + if (index2 != 0xFFFF) { + gi_probe_compute(index2, vertex, normal, ref_vec, normal_mat, roughness * roughness, ambient_light, specular_light, spec_accum, amb_accum); + } - if (spec_accum.a > 0.0) { - spec_accum.rgb /= spec_accum.a; - } + if (amb_accum.a > 0.0) { + amb_accum.rgb /= amb_accum.a; + } - specular_light = spec_accum.rgb; - ambient_light = amb_accum.rgb; + if (spec_accum.a > 0.0) { + spec_accum.rgb /= spec_accum.a; } + + specular_light = spec_accum.rgb; + ambient_light = amb_accum.rgb; } #endif @@ -2423,7 +2486,6 @@ FRAGMENT_SHADER_CODE ao_light_affect = mix(1.0, ao, ao_light_affect); specular_light = mix(scene_data.ao_color.rgb, specular_light, ao_light_affect); diffuse_light = mix(scene_data.ao_color.rgb, diffuse_light, ao_light_affect); - #else if (scene_data.ssao_enabled) { diff --git a/servers/rendering/rasterizer_rd/shaders/scene_high_end_inc.glsl b/servers/rendering/rasterizer_rd/shaders/scene_high_end_inc.glsl index ce4fabf9f2..89706b74d6 100644 --- a/servers/rendering/rasterizer_rd/shaders/scene_high_end_inc.glsl +++ b/servers/rendering/rasterizer_rd/shaders/scene_high_end_inc.glsl @@ -3,7 +3,8 @@ layout(push_constant, binding = 0, std430) uniform DrawCall { uint instance_index; - uint pad[3]; //16 bits minimum size + uint pad; //16 bits minimum size + vec2 bake_uv2_offset; //used for bake to uv2, ignored otherwise } draw_call; @@ -77,6 +78,10 @@ layout(set = 0, binding = 3, std140) uniform SceneData { bool roughness_limiter_enabled; vec4 ao_color; + bool material_uv2_mode; + uint pad_material0; + uint pad_material1; + uint pad_material2; #if 0 vec4 ambient_light_color; @@ -115,11 +120,10 @@ layout(set = 0, binding = 3, std140) uniform SceneData { } scene_data; -#define INSTANCE_FLAGS_FORWARD_MASK 0x7 -#define INSTANCE_FLAGS_FORWARD_OMNI_LIGHT_SHIFT 3 -#define INSTANCE_FLAGS_FORWARD_SPOT_LIGHT_SHIFT 6 -#define INSTANCE_FLAGS_FORWARD_DECAL_SHIFT 9 - +#define INSTANCE_FLAGS_USE_LIGHTMAP_CAPTURE (1 << 8) +#define INSTANCE_FLAGS_USE_LIGHTMAP (1 << 9) +#define INSTANCE_FLAGS_USE_SH_LIGHTMAP (1 << 10) +#define INSTANCE_FLAGS_USE_GIPROBE (1 << 11) #define INSTANCE_FLAGS_MULTIMESH (1 << 12) #define INSTANCE_FLAGS_MULTIMESH_FORMAT_2D (1 << 13) #define INSTANCE_FLAGS_MULTIMESH_HAS_COLOR (1 << 14) @@ -135,8 +139,9 @@ struct InstanceData { mat4 normal_transform; uint flags; uint instance_uniforms_ofs; //base offset in global buffer for instance variables - uint gi_offset; //GI information when using lightmapping (VCT or lightmap) + uint gi_offset; //GI information when using lightmapping (VCT or lightmap index) uint layer_mask; + vec4 lightmap_uv_scale; }; layout(set = 0, binding = 4, std430) restrict readonly buffer Instances { @@ -248,12 +253,35 @@ gi_probes; layout(set = 0, binding = 9) uniform texture3D gi_probe_textures[MAX_GI_PROBE_TEXTURES]; +#define LIGHTMAP_FLAG_USE_DIRECTION 1 +#define LIGHTMAP_FLAG_USE_SPECULAR_DIRECTION 2 + +struct Lightmap { + mat3 normal_xform; +}; + +layout(set = 0, binding = 10, std140) restrict readonly buffer Lightmaps { + Lightmap data[]; +} +lightmaps; + +layout(set = 0, binding = 11) uniform texture2DArray lightmap_textures[MAX_LIGHTMAP_TEXTURES]; + +struct LightmapCapture { + vec4 sh[9]; +}; + +layout(set = 0, binding = 12, std140) restrict readonly buffer LightmapCaptures { + LightmapCapture data[]; +} +lightmap_captures; + #define CLUSTER_COUNTER_SHIFT 20 #define CLUSTER_POINTER_MASK ((1 << CLUSTER_COUNTER_SHIFT) - 1) #define CLUSTER_COUNTER_MASK 0xfff -layout(set = 0, binding = 10) uniform texture2D decal_atlas; -layout(set = 0, binding = 11) uniform texture2D decal_atlas_srgb; +layout(set = 0, binding = 13) uniform texture2D decal_atlas; +layout(set = 0, binding = 14) uniform texture2D decal_atlas_srgb; struct DecalData { mat4 xform; //to decal transform @@ -273,21 +301,21 @@ struct DecalData { float normal_fade; }; -layout(set = 0, binding = 12, std430) restrict readonly buffer Decals { +layout(set = 0, binding = 15, std430) restrict readonly buffer Decals { DecalData data[]; } decals; -layout(set = 0, binding = 13) uniform utexture3D cluster_texture; +layout(set = 0, binding = 16) uniform utexture3D cluster_texture; -layout(set = 0, binding = 14, std430) restrict readonly buffer ClusterData { +layout(set = 0, binding = 17, std430) restrict readonly buffer ClusterData { uint indices[]; } cluster_data; -layout(set = 0, binding = 15) uniform texture2D directional_shadow_atlas; +layout(set = 0, binding = 18) uniform texture2D directional_shadow_atlas; -layout(set = 0, binding = 16, std430) restrict readonly buffer GlobalVariableData { +layout(set = 0, binding = 19, std430) restrict readonly buffer GlobalVariableData { vec4 data[]; } global_variables; @@ -312,7 +340,7 @@ layout(set = 2, binding = 0) uniform textureCubeArray reflection_atlas; layout(set = 2, binding = 1) uniform texture2D shadow_atlas; -/* Set 1, Render Buffers */ +/* Set 3, Render Buffers */ layout(set = 3, binding = 0) uniform texture2D depth_buffer; layout(set = 3, binding = 1) uniform texture2D color_buffer; diff --git a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection.glsl b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection.glsl index e3c26c9b72..11a0d85c58 100644 --- a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection.glsl +++ b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection.glsl @@ -68,7 +68,7 @@ void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } diff --git a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_filter.glsl b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_filter.glsl index 1a5dd5ab55..8571d9d6d1 100644 --- a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_filter.glsl +++ b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_filter.glsl @@ -120,7 +120,7 @@ void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } diff --git a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_scale.glsl b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_scale.glsl index cec6c14c76..f2c3230679 100644 --- a/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_scale.glsl +++ b/servers/rendering/rasterizer_rd/shaders/screen_space_reflection_scale.glsl @@ -34,7 +34,7 @@ void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } //do not filter, SSR will generate arctifacts if this is done diff --git a/servers/rendering/rasterizer_rd/shaders/ssao.glsl b/servers/rendering/rasterizer_rd/shaders/ssao.glsl index c9d7134610..0175e26b85 100644 --- a/servers/rendering/rasterizer_rd/shaders/ssao.glsl +++ b/servers/rendering/rasterizer_rd/shaders/ssao.glsl @@ -212,7 +212,7 @@ float sampleAO(in ivec2 ssC, in vec3 C, in vec3 n_C, in float ssDiskRadius, in f void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } diff --git a/servers/rendering/rasterizer_rd/shaders/ssao_blur.glsl b/servers/rendering/rasterizer_rd/shaders/ssao_blur.glsl index e90c788e08..877e5d50fe 100644 --- a/servers/rendering/rasterizer_rd/shaders/ssao_blur.glsl +++ b/servers/rendering/rasterizer_rd/shaders/ssao_blur.glsl @@ -49,7 +49,7 @@ void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } diff --git a/servers/rendering/rasterizer_rd/shaders/subsurface_scattering.glsl b/servers/rendering/rasterizer_rd/shaders/subsurface_scattering.glsl index 41f8fde3ca..4cb486a499 100644 --- a/servers/rendering/rasterizer_rd/shaders/subsurface_scattering.glsl +++ b/servers/rendering/rasterizer_rd/shaders/subsurface_scattering.glsl @@ -142,7 +142,7 @@ void main() { // Pixel being shaded ivec2 ssC = ivec2(gl_GlobalInvocationID.xy); - if (any(greaterThan(ssC, params.screen_size))) { //too large, do nothing + if (any(greaterThanEqual(ssC, params.screen_size))) { //too large, do nothing return; } diff --git a/servers/rendering/rendering_device.cpp b/servers/rendering/rendering_device.cpp index a3bb39cd90..aeac6f2eff 100644 --- a/servers/rendering/rendering_device.cpp +++ b/servers/rendering/rendering_device.cpp @@ -147,7 +147,7 @@ Ref<RDShaderBytecode> RenderingDevice::_shader_compile_from_source(const Ref<RDS return bytecode; } -RID RenderingDevice::_shader_create(const Ref<RDShaderBytecode> &p_bytecode) { +RID RenderingDevice::shader_create_from_bytecode(const Ref<RDShaderBytecode> &p_bytecode) { ERR_FAIL_COND_V(p_bytecode.is_null(), RID()); Vector<ShaderStageData> stage_data; @@ -276,7 +276,7 @@ void RenderingDevice::_bind_methods() { ClassDB::bind_method(D_METHOD("index_array_create", "index_buffer", "index_offset", "index_count"), &RenderingDevice::index_array_create); ClassDB::bind_method(D_METHOD("shader_compile_from_source", "shader_source", "allow_cache"), &RenderingDevice::_shader_compile_from_source, DEFVAL(true)); - ClassDB::bind_method(D_METHOD("shader_create", "shader_data"), &RenderingDevice::_shader_create); + ClassDB::bind_method(D_METHOD("shader_create", "shader_data"), &RenderingDevice::shader_create_from_bytecode); ClassDB::bind_method(D_METHOD("shader_get_vertex_input_attribute_mask", "shader"), &RenderingDevice::shader_get_vertex_input_attribute_mask); ClassDB::bind_method(D_METHOD("uniform_buffer_create", "size_bytes", "data"), &RenderingDevice::uniform_buffer_create, DEFVAL(Vector<uint8_t>())); diff --git a/servers/rendering/rendering_device.h b/servers/rendering/rendering_device.h index c76fce5b5c..c7d0a1cdd2 100644 --- a/servers/rendering/rendering_device.h +++ b/servers/rendering/rendering_device.h @@ -596,6 +596,7 @@ public: } }; + RID shader_create_from_bytecode(const Ref<RDShaderBytecode> &p_bytecode); virtual RID shader_create(const Vector<ShaderStageData> &p_stages) = 0; virtual uint32_t shader_get_vertex_input_attribute_mask(RID p_shader) = 0; @@ -1045,6 +1046,8 @@ public: virtual void submit() = 0; virtual void sync() = 0; + virtual uint64_t get_memory_usage() const = 0; + virtual RenderingDevice *create_local_device() = 0; static RenderingDevice *get_singleton(); @@ -1063,7 +1066,6 @@ protected: RID _vertex_array_create(uint32_t p_vertex_count, VertexFormatID p_vertex_format, const TypedArray<RID> &p_src_buffers); Ref<RDShaderBytecode> _shader_compile_from_source(const Ref<RDShaderSource> &p_source, bool p_allow_cache = true); - RID _shader_create(const Ref<RDShaderBytecode> &p_bytecode); RID _uniform_set_create(const Array &p_uniforms, RID p_shader, uint32_t p_shader_set); diff --git a/servers/rendering/rendering_device_binds.cpp b/servers/rendering/rendering_device_binds.cpp index 111755eba3..91076a538e 100644 --- a/servers/rendering/rendering_device_binds.cpp +++ b/servers/rendering/rendering_device_binds.cpp @@ -1,6 +1,36 @@ +/*************************************************************************/ +/* rendering_device_binds.cpp */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #include "rendering_device_binds.h" -Error RDShaderFile::parse_versions_from_text(const String &p_text, OpenIncludeFunction p_include_func, void *p_include_func_userdata) { +Error RDShaderFile::parse_versions_from_text(const String &p_text, const String p_defines, OpenIncludeFunction p_include_func, void *p_include_func_userdata) { Vector<String> lines = p_text.split("\n"); @@ -26,6 +56,9 @@ Error RDShaderFile::parse_versions_from_text(const String &p_text, OpenIncludeFu { String ls = line.strip_edges(); + if (ls.begins_with("#[")) { //workaround for clang format + ls = ls.replace_first("#[", "["); + } if (ls.begins_with("[") && ls.ends_with("]")) { String section = ls.substr(1, ls.length() - 2).strip_edges(); if (section == "versions") { @@ -60,9 +93,17 @@ Error RDShaderFile::parse_versions_from_text(const String &p_text, OpenIncludeFu } } + if (stage == RD::SHADER_STAGE_MAX && line.strip_edges() != "") { + line = line.strip_edges(); + if (line.begins_with("//") || line.begins_with("/*")) { + continue; //assuming comment (single line) + } + } + if (reading_versions) { String l = line.strip_edges(); if (l != "") { + int eqpos = l.find("="); if (eqpos == -1) { base_error = "Version syntax is version=\"<defines with C escaping>\"."; @@ -80,7 +121,7 @@ Error RDShaderFile::parse_versions_from_text(const String &p_text, OpenIncludeFu } define = "\n" + define.substr(1, define.length() - 2).c_unescape() + "\n"; //add newline before and after jsut in case - version_texts[version] = define; + version_texts[version] = define + "\n" + p_defines; } } else { if (stage == RD::SHADER_STAGE_MAX && line.strip_edges() != "") { diff --git a/servers/rendering/rendering_device_binds.h b/servers/rendering/rendering_device_binds.h index f57f59876d..fe8d554594 100644 --- a/servers/rendering/rendering_device_binds.h +++ b/servers/rendering/rendering_device_binds.h @@ -1,3 +1,33 @@ +/*************************************************************************/ +/* rendering_device_binds.h */ +/*************************************************************************/ +/* This file is part of: */ +/* GODOT ENGINE */ +/* https://godotengine.org */ +/*************************************************************************/ +/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur. */ +/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md). */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining */ +/* a copy of this software and associated documentation files (the */ +/* "Software"), to deal in the Software without restriction, including */ +/* without limitation the rights to use, copy, modify, merge, publish, */ +/* distribute, sublicense, and/or sell copies of the Software, and to */ +/* permit persons to whom the Software is furnished to do so, subject to */ +/* the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be */ +/* included in all copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */ +/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */ +/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/ +/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */ +/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */ +/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */ +/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ +/*************************************************************************/ + #ifndef RENDERING_DEVICE_BINDS_H #define RENDERING_DEVICE_BINDS_H @@ -292,8 +322,31 @@ public: return base_error; } + void print_errors(const String &p_file) { + if (base_error != "") { + ERR_PRINT("Error parsing shader '" + p_file + "':\n\n" + base_error); + } else { + for (Map<StringName, Ref<RDShaderBytecode>>::Element *E = versions.front(); E; E = E->next()) { + for (int i = 0; i < RD::SHADER_STAGE_MAX; i++) { + String error = E->get()->get_stage_compile_error(RD::ShaderStage(i)); + if (error != String()) { + static const char *stage_str[RD::SHADER_STAGE_MAX] = { + "vertex", + "fragment", + "tesselation_control", + "tesselation_evaluation", + "compute" + }; + + ERR_PRINT("Error parsing shader '" + p_file + "', version '" + String(E->key()) + "', stage '" + stage_str[i] + "':\n\n" + error); + } + } + } + } + } + typedef String (*OpenIncludeFunction)(const String &, void *userdata); - Error parse_versions_from_text(const String &p_text, OpenIncludeFunction p_include_func = nullptr, void *p_include_func_userdata = nullptr); + Error parse_versions_from_text(const String &p_text, const String p_defines = String(), OpenIncludeFunction p_include_func = nullptr, void *p_include_func_userdata = nullptr); protected: Dictionary _get_versions() const { diff --git a/servers/rendering/rendering_server_raster.h b/servers/rendering/rendering_server_raster.h index f7b963a015..5dd146861d 100644 --- a/servers/rendering/rendering_server_raster.h +++ b/servers/rendering/rendering_server_raster.h @@ -108,8 +108,12 @@ public: m_r m_name(m_type1 arg1, m_type2 arg2) { return BINDBASE->m_name(arg1, arg2); } #define BIND2RC(m_r, m_name, m_type1, m_type2) \ m_r m_name(m_type1 arg1, m_type2 arg2) const { return BINDBASE->m_name(arg1, arg2); } +#define BIND3R(m_r, m_name, m_type1, m_type2, m_type3) \ + m_r m_name(m_type1 arg1, m_type2 arg2, m_type3 arg3) { return BINDBASE->m_name(arg1, arg2, arg3); } #define BIND3RC(m_r, m_name, m_type1, m_type2, m_type3) \ m_r m_name(m_type1 arg1, m_type2 arg2, m_type3 arg3) const { return BINDBASE->m_name(arg1, arg2, arg3); } +#define BIND4R(m_r, m_name, m_type1, m_type2, m_type3, m_type4) \ + m_r m_name(m_type1 arg1, m_type2 arg2, m_type3 arg3, m_type4 arg4) { return BINDBASE->m_name(arg1, arg2, arg3, arg4); } #define BIND4RC(m_r, m_name, m_type1, m_type2, m_type3, m_type4) \ m_r m_name(m_type1 arg1, m_type2 arg2, m_type3 arg3, m_type4 arg4) const { return BINDBASE->m_name(arg1, arg2, arg3, arg4); } @@ -170,7 +174,7 @@ public: //these also go pass-through BIND0R(RID, texture_2d_placeholder_create) - BIND0R(RID, texture_2d_layered_placeholder_create) + BIND1R(RID, texture_2d_layered_placeholder_create, TextureLayeredType) BIND0R(RID, texture_3d_placeholder_create) BIND1RC(Ref<Image>, texture_2d_get, RID) @@ -404,23 +408,19 @@ public: BIND2(gi_probe_set_anisotropy_strength, RID, float) BIND1RC(float, gi_probe_get_anisotropy_strength, RID) - /* LIGHTMAP CAPTURE */ + /* LIGHTMAP */ - BIND0R(RID, lightmap_capture_create) + BIND0R(RID, lightmap_create) - BIND2(lightmap_capture_set_bounds, RID, const AABB &) - BIND1RC(AABB, lightmap_capture_get_bounds, RID) - - BIND2(lightmap_capture_set_octree, RID, const Vector<uint8_t> &) - BIND1RC(Vector<uint8_t>, lightmap_capture_get_octree, RID) - - BIND2(lightmap_capture_set_octree_cell_transform, RID, const Transform &) - BIND1RC(Transform, lightmap_capture_get_octree_cell_transform, RID) - BIND2(lightmap_capture_set_octree_cell_subdiv, RID, int) - BIND1RC(int, lightmap_capture_get_octree_cell_subdiv, RID) - - BIND2(lightmap_capture_set_energy, RID, float) - BIND1RC(float, lightmap_capture_get_energy, RID) + BIND3(lightmap_set_textures, RID, RID, bool) + BIND2(lightmap_set_probe_bounds, RID, const AABB &) + BIND2(lightmap_set_probe_interior, RID, bool) + BIND5(lightmap_set_probe_capture_data, RID, const PackedVector3Array &, const PackedColorArray &, const PackedInt32Array &, const PackedInt32Array &) + BIND1RC(PackedVector3Array, lightmap_get_probe_capture_points, RID) + BIND1RC(PackedColorArray, lightmap_get_probe_capture_sh, RID) + BIND1RC(PackedInt32Array, lightmap_get_probe_capture_tetrahedra, RID) + BIND1RC(PackedInt32Array, lightmap_get_probe_capture_bsp_tree, RID) + BIND1(lightmap_set_probe_capture_update_speed, float) /* PARTICLES */ @@ -532,6 +532,7 @@ public: BIND2(sky_set_radiance_size, RID, int) BIND2(sky_set_mode, RID, SkyMode) BIND2(sky_set_material, RID, RID) + BIND4R(Ref<Image>, sky_bake_panorama, RID, float, bool, const Size2i &) BIND0R(RID, environment_create) @@ -565,6 +566,8 @@ public: BIND7(environment_set_fog_depth, RID, bool, float, float, float, bool, float) BIND5(environment_set_fog_height, RID, bool, float, float, float) + BIND3R(Ref<Image>, environment_bake_panorama, RID, bool, const Size2i &) + BIND2(screen_space_roughness_limiter_set_active, bool, float) BIND1(sub_surface_scattering_set_quality, SubSurfaceScatteringQuality) BIND2(sub_surface_scattering_set_scale, float, float) @@ -605,7 +608,6 @@ public: BIND3(instance_set_blend_shape_weight, RID, int, float) BIND3(instance_set_surface_material, RID, int, RID) BIND2(instance_set_visible, RID, bool) - BIND3(instance_set_use_lightmap, RID, RID, RID) BIND2(instance_set_custom_aabb, RID, AABB) @@ -625,12 +627,15 @@ public: BIND5(instance_geometry_set_draw_range, RID, float, float, float, float) BIND2(instance_geometry_set_as_instance_lod, RID, RID) + BIND4(instance_geometry_set_lightmap, RID, RID, const Rect2 &, int) BIND3(instance_geometry_set_shader_parameter, RID, const StringName &, const Variant &) BIND2RC(Variant, instance_geometry_get_shader_parameter, RID, const StringName &) BIND2RC(Variant, instance_geometry_get_shader_parameter_default_value, RID, const StringName &) BIND2C(instance_geometry_get_shader_parameter_list, RID, List<PropertyInfo> *) + BIND3R(TypedArray<Image>, bake_render_uv2, RID, const Vector<RID> &, const Size2i &) + #undef BINDBASE //from now on, calls forwarded to this singleton #define BINDBASE RSG::canvas diff --git a/servers/rendering/rendering_server_scene.cpp b/servers/rendering/rendering_server_scene.cpp index 2c3c2730d5..95334ee102 100644 --- a/servers/rendering/rendering_server_scene.cpp +++ b/servers/rendering/rendering_server_scene.cpp @@ -169,19 +169,22 @@ void *RenderingServerScene::_instance_pair(void *p_self, OctreeElementID, Instan geom->decal_dirty = true; return E; //this element should make freeing faster - } else if (B->base_type == RS::INSTANCE_LIGHTMAP_CAPTURE && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { + } else if (B->base_type == RS::INSTANCE_LIGHTMAP && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { - InstanceLightmapCaptureData *lightmap_capture = static_cast<InstanceLightmapCaptureData *>(B->base_data); + InstanceLightmapData *lightmap_data = static_cast<InstanceLightmapData *>(B->base_data); InstanceGeometryData *geom = static_cast<InstanceGeometryData *>(A->base_data); - InstanceLightmapCaptureData::PairInfo pinfo; - pinfo.geometry = A; - pinfo.L = geom->lightmap_captures.push_back(B); - - List<InstanceLightmapCaptureData::PairInfo>::Element *E = lightmap_capture->geometries.push_back(pinfo); - ((RenderingServerScene *)p_self)->_instance_queue_update(A, false, false); //need to update capture + if (A->dynamic_gi) { + InstanceLightmapData::PairInfo pinfo; + pinfo.geometry = A; + pinfo.L = geom->lightmap_captures.push_back(B); + List<InstanceLightmapData::PairInfo>::Element *E = lightmap_data->geometries.push_back(pinfo); + ((RenderingServerScene *)p_self)->_instance_queue_update(A, false, false); //need to update capture + return E; //this element should make freeing faster + } else { + return nullptr; + } - return E; //this element should make freeing faster } else if (B->base_type == RS::INSTANCE_GI_PROBE && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { InstanceGIProbeData *gi_probe = static_cast<InstanceGIProbeData *>(B->base_data); @@ -258,16 +261,18 @@ void RenderingServerScene::_instance_unpair(void *p_self, OctreeElementID, Insta decal->geometries.erase(E); geom->decal_dirty = true; - } else if (B->base_type == RS::INSTANCE_LIGHTMAP_CAPTURE && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { + } else if (B->base_type == RS::INSTANCE_LIGHTMAP && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { - InstanceLightmapCaptureData *lightmap_capture = static_cast<InstanceLightmapCaptureData *>(B->base_data); - InstanceGeometryData *geom = static_cast<InstanceGeometryData *>(A->base_data); + if (udata) { //only for dynamic geometries + InstanceLightmapData *lightmap_data = static_cast<InstanceLightmapData *>(B->base_data); + InstanceGeometryData *geom = static_cast<InstanceGeometryData *>(A->base_data); - List<InstanceLightmapCaptureData::PairInfo>::Element *E = reinterpret_cast<List<InstanceLightmapCaptureData::PairInfo>::Element *>(udata); + List<InstanceLightmapData::PairInfo>::Element *E = reinterpret_cast<List<InstanceLightmapData::PairInfo>::Element *>(udata); - geom->lightmap_captures.erase(E->get().L); - lightmap_capture->geometries.erase(E); - ((RenderingServerScene *)p_self)->_instance_queue_update(A, false, false); //need to update capture + geom->lightmap_captures.erase(E->get().L); + lightmap_data->geometries.erase(E); + ((RenderingServerScene *)p_self)->_instance_queue_update(A, false, false); //need to update capture + } } else if (B->base_type == RS::INSTANCE_GI_PROBE && ((1 << A->base_type) & RS::INSTANCE_GEOMETRY_MASK)) { @@ -418,12 +423,12 @@ void RenderingServerScene::instance_set_base(RID p_instance, RID p_base) { RSG::scene_render->free(decal->instance); } break; - case RS::INSTANCE_LIGHTMAP_CAPTURE: { + case RS::INSTANCE_LIGHTMAP: { - InstanceLightmapCaptureData *lightmap_capture = static_cast<InstanceLightmapCaptureData *>(instance->base_data); + InstanceLightmapData *lightmap_data = static_cast<InstanceLightmapData *>(instance->base_data); //erase dependencies, since no longer a lightmap - while (lightmap_capture->users.front()) { - instance_set_use_lightmap(lightmap_capture->users.front()->get()->self, RID(), RID()); + while (lightmap_data->users.front()) { + instance_geometry_set_lightmap(lightmap_data->users.front()->get()->self, RID(), Rect2(), 0); } } break; case RS::INSTANCE_GI_PROBE: { @@ -443,14 +448,6 @@ void RenderingServerScene::instance_set_base(RID p_instance, RID p_base) { gi_probe_update_list.remove(&gi_probe->update_element); } - if (instance->lightmap_capture) { - Instance *capture = (Instance *)instance->lightmap_capture; - InstanceLightmapCaptureData *lightmap_capture = static_cast<InstanceLightmapCaptureData *>(capture->base_data); - lightmap_capture->users.erase(instance); - instance->lightmap_capture = nullptr; - instance->lightmap = RID(); - } - RSG::scene_render->free(gi_probe->probe_instance); } break; @@ -515,11 +512,11 @@ void RenderingServerScene::instance_set_base(RID p_instance, RID p_base) { decal->instance = RSG::scene_render->decal_instance_create(p_base); } break; - case RS::INSTANCE_LIGHTMAP_CAPTURE: { + case RS::INSTANCE_LIGHTMAP: { - InstanceLightmapCaptureData *lightmap_capture = memnew(InstanceLightmapCaptureData); - instance->base_data = lightmap_capture; - //lightmap_capture->instance = RSG::scene_render->lightmap_capture_instance_create(p_base); + InstanceLightmapData *lightmap_data = memnew(InstanceLightmapData); + instance->base_data = lightmap_data; + //lightmap_data->instance = RSG::scene_render->lightmap_data_instance_create(p_base); } break; case RS::INSTANCE_GI_PROBE: { @@ -736,9 +733,9 @@ void RenderingServerScene::instance_set_visible(RID p_instance, bool p_visible) } } break; - case RS::INSTANCE_LIGHTMAP_CAPTURE: { + case RS::INSTANCE_LIGHTMAP: { if (instance->octree_id && instance->scenario) { - instance->scenario->octree.set_pairable(instance->octree_id, p_visible, 1 << RS::INSTANCE_LIGHTMAP_CAPTURE, p_visible ? RS::INSTANCE_GEOMETRY_MASK : 0); + instance->scenario->octree.set_pairable(instance->octree_id, p_visible, 1 << RS::INSTANCE_LIGHTMAP, p_visible ? RS::INSTANCE_GEOMETRY_MASK : 0); } } break; @@ -756,30 +753,6 @@ inline bool is_geometry_instance(RenderingServer::InstanceType p_type) { return p_type == RS::INSTANCE_MESH || p_type == RS::INSTANCE_MULTIMESH || p_type == RS::INSTANCE_PARTICLES || p_type == RS::INSTANCE_IMMEDIATE; } -void RenderingServerScene::instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap) { - - Instance *instance = instance_owner.getornull(p_instance); - ERR_FAIL_COND(!instance); - - if (instance->lightmap_capture) { - InstanceLightmapCaptureData *lightmap_capture = static_cast<InstanceLightmapCaptureData *>(((Instance *)instance->lightmap_capture)->base_data); - lightmap_capture->users.erase(instance); - instance->lightmap = RID(); - instance->lightmap_capture = nullptr; - } - - if (p_lightmap_instance.is_valid()) { - Instance *lightmap_instance = instance_owner.getornull(p_lightmap_instance); - ERR_FAIL_COND(!lightmap_instance); - ERR_FAIL_COND(lightmap_instance->base_type != RS::INSTANCE_LIGHTMAP_CAPTURE); - instance->lightmap_capture = lightmap_instance; - - InstanceLightmapCaptureData *lightmap_capture = static_cast<InstanceLightmapCaptureData *>(((Instance *)instance->lightmap_capture)->base_data); - lightmap_capture->users.insert(instance); - instance->lightmap = p_lightmap; - } -} - void RenderingServerScene::instance_set_custom_aabb(RID p_instance, AABB p_aabb) { Instance *instance = instance_owner.getornull(p_instance); @@ -968,6 +941,29 @@ void RenderingServerScene::instance_geometry_set_draw_range(RID p_instance, floa void RenderingServerScene::instance_geometry_set_as_instance_lod(RID p_instance, RID p_as_lod_of_instance) { } +void RenderingServerScene::instance_geometry_set_lightmap(RID p_instance, RID p_lightmap, const Rect2 &p_lightmap_uv_scale, int p_slice_index) { + + Instance *instance = instance_owner.getornull(p_instance); + ERR_FAIL_COND(!instance); + + if (instance->lightmap) { + InstanceLightmapData *lightmap_data = static_cast<InstanceLightmapData *>(((Instance *)instance->lightmap)->base_data); + lightmap_data->users.erase(instance); + instance->lightmap = nullptr; + } + + Instance *lightmap_instance = instance_owner.getornull(p_lightmap); + + instance->lightmap = lightmap_instance; + instance->lightmap_uv_scale = p_lightmap_uv_scale; + instance->lightmap_slice_index = p_slice_index; + + if (lightmap_instance) { + InstanceLightmapData *lightmap_data = static_cast<InstanceLightmapData *>(lightmap_instance->base_data); + lightmap_data->users.insert(instance); + } +} + void RenderingServerScene::instance_geometry_set_shader_parameter(RID p_instance, const StringName &p_parameter, const Variant &p_value) { Instance *instance = instance_owner.getornull(p_instance); @@ -1084,16 +1080,29 @@ void RenderingServerScene::_update_instance(Instance *p_instance) { } } - if (!p_instance->lightmap_capture && geom->lightmap_captures.size()) { + if (!p_instance->lightmap && geom->lightmap_captures.size()) { //affected by lightmap captures, must update capture info! _update_instance_lightmap_captures(p_instance); } else { - if (!p_instance->lightmap_capture_data.empty()) { - p_instance->lightmap_capture_data.resize(0); //not in use, clear capture data + if (!p_instance->lightmap_sh.empty()) { + p_instance->lightmap_sh.clear(); //don't need SH + p_instance->lightmap_target_sh.clear(); //don't need SH } } } + if (p_instance->base_type == RS::INSTANCE_LIGHTMAP) { + + //if this moved, update the captured objects + InstanceLightmapData *lightmap_data = static_cast<InstanceLightmapData *>(p_instance->base_data); + //erase dependencies, since no longer a lightmap + + for (List<InstanceLightmapData::PairInfo>::Element *E = lightmap_data->geometries.front(); E; E = E->next()) { + Instance *geom = E->get().geometry; + _instance_queue_update(geom, true, false); + } + } + p_instance->mirror = p_instance->transform.basis.determinant() < 0.0; AABB new_aabb; @@ -1113,7 +1122,7 @@ void RenderingServerScene::_update_instance(Instance *p_instance) { uint32_t pairable_mask = 0; bool pairable = false; - if (p_instance->base_type == RS::INSTANCE_LIGHT || p_instance->base_type == RS::INSTANCE_REFLECTION_PROBE || p_instance->base_type == RS::INSTANCE_DECAL || p_instance->base_type == RS::INSTANCE_LIGHTMAP_CAPTURE) { + if (p_instance->base_type == RS::INSTANCE_LIGHT || p_instance->base_type == RS::INSTANCE_REFLECTION_PROBE || p_instance->base_type == RS::INSTANCE_DECAL || p_instance->base_type == RS::INSTANCE_LIGHTMAP) { pairable_mask = p_instance->visible ? RS::INSTANCE_GEOMETRY_MASK : 0; pairable = true; @@ -1203,9 +1212,9 @@ void RenderingServerScene::_update_instance_aabb(Instance *p_instance) { new_aabb = RSG::storage->gi_probe_get_bounds(p_instance->base); } break; - case RenderingServer::INSTANCE_LIGHTMAP_CAPTURE: { + case RenderingServer::INSTANCE_LIGHTMAP: { - new_aabb = RSG::storage->lightmap_capture_get_bounds(p_instance->base); + new_aabb = RSG::storage->lightmap_get_aabb(p_instance->base); } break; default: { @@ -1219,235 +1228,82 @@ void RenderingServerScene::_update_instance_aabb(Instance *p_instance) { p_instance->aabb = new_aabb; } -_FORCE_INLINE_ static void _light_capture_sample_octree(const RasterizerStorage::LightmapCaptureOctree *p_octree, int p_cell_subdiv, const Vector3 &p_pos, const Vector3 &p_dir, float p_level, Vector3 &r_color, float &r_alpha) { - - static const Vector3 aniso_normal[6] = { - Vector3(-1, 0, 0), - Vector3(1, 0, 0), - Vector3(0, -1, 0), - Vector3(0, 1, 0), - Vector3(0, 0, -1), - Vector3(0, 0, 1) - }; - - int size = 1 << (p_cell_subdiv - 1); - - int clamp_v = size - 1; - //first of all, clamp - Vector3 pos; - pos.x = CLAMP(p_pos.x, 0, clamp_v); - pos.y = CLAMP(p_pos.y, 0, clamp_v); - pos.z = CLAMP(p_pos.z, 0, clamp_v); - - float level = (p_cell_subdiv - 1) - p_level; - - int target_level; - float level_filter; - if (level <= 0.0) { - level_filter = 0; - target_level = 0; - } else { - target_level = Math::ceil(level); - level_filter = target_level - level; - } - - Vector3 color[2][8]; - float alpha[2][8]; - zeromem(alpha, sizeof(float) * 2 * 8); - - //find cell at given level first - - for (int c = 0; c < 2; c++) { - - int current_level = MAX(0, target_level - c); - int level_cell_size = (1 << (p_cell_subdiv - 1)) >> current_level; - - for (int n = 0; n < 8; n++) { +void RenderingServerScene::_update_instance_lightmap_captures(Instance *p_instance) { - int x = int(pos.x); - int y = int(pos.y); - int z = int(pos.z); + bool first_set = p_instance->lightmap_sh.size() == 0; + p_instance->lightmap_sh.resize(9); //using SH + p_instance->lightmap_target_sh.resize(9); //using SH + Color *instance_sh = p_instance->lightmap_target_sh.ptrw(); + bool inside = false; + Color accum_sh[9]; + float accum_blend = 0.0; - if (n & 1) - x += level_cell_size; - if (n & 2) - y += level_cell_size; - if (n & 4) - z += level_cell_size; + InstanceGeometryData *geom = static_cast<InstanceGeometryData *>(p_instance->base_data); + for (List<Instance *>::Element *E = geom->lightmap_captures.front(); E; E = E->next()) { + Instance *lightmap = E->get(); - int ofs_x = 0; - int ofs_y = 0; - int ofs_z = 0; + bool interior = RSG::storage->lightmap_is_interior(lightmap->base); - x = CLAMP(x, 0, clamp_v); - y = CLAMP(y, 0, clamp_v); - z = CLAMP(z, 0, clamp_v); + if (inside && !interior) { + continue; //we are inside, ignore exteriors + } - int half = size / 2; - uint32_t cell = 0; - for (int i = 0; i < current_level; i++) { + Transform to_bounds = lightmap->transform.affine_inverse(); + Vector3 center = p_instance->transform.xform(p_instance->aabb.position + p_instance->aabb.size * 0.5); //use aabb center - const RasterizerStorage::LightmapCaptureOctree *bc = &p_octree[cell]; + Vector3 lm_pos = to_bounds.xform(center); - int child = 0; - if (x >= ofs_x + half) { - child |= 1; - ofs_x += half; - } - if (y >= ofs_y + half) { - child |= 2; - ofs_y += half; - } - if (z >= ofs_z + half) { - child |= 4; - ofs_z += half; - } + AABB bounds = RSG::storage->lightmap_get_aabb(lightmap->base); + if (!bounds.has_point(lm_pos)) { + continue; //not in this lightmap + } - cell = bc->children[child]; - if (cell == RasterizerStorage::LightmapCaptureOctree::CHILD_EMPTY) - break; + Color sh[9]; + RSG::storage->lightmap_tap_sh_light(lightmap->base, lm_pos, sh); - half >>= 1; + //rotate it + Basis rot = lightmap->transform.basis.orthonormalized(); + for (int i = 0; i < 3; i++) { + float csh[9]; + for (int j = 0; j < 9; j++) { + csh[j] = sh[j][i]; } - - if (cell == RasterizerStorage::LightmapCaptureOctree::CHILD_EMPTY) { - alpha[c][n] = 0; - } else { - alpha[c][n] = p_octree[cell].alpha; - - for (int i = 0; i < 6; i++) { - //anisotropic read light - float amount = p_dir.dot(aniso_normal[i]); - if (amount < 0) - amount = 0; - color[c][n].x += p_octree[cell].light[i][0] / 1024.0 * amount; - color[c][n].y += p_octree[cell].light[i][1] / 1024.0 * amount; - color[c][n].z += p_octree[cell].light[i][2] / 1024.0 * amount; - } + rot.rotate_sh(csh); + for (int j = 0; j < 9; j++) { + sh[j][i] = csh[j]; } - - //print_line("\tlev " + itos(c) + " - " + itos(n) + " alpha: " + rtos(cells[test_cell].alpha) + " col: " + color[c][n]); } - } - - float target_level_size = size >> target_level; - Vector3 pos_fract[2]; - - pos_fract[0].x = Math::fmod(pos.x, target_level_size) / target_level_size; - pos_fract[0].y = Math::fmod(pos.y, target_level_size) / target_level_size; - pos_fract[0].z = Math::fmod(pos.z, target_level_size) / target_level_size; - - target_level_size = size >> MAX(0, target_level - 1); - - pos_fract[1].x = Math::fmod(pos.x, target_level_size) / target_level_size; - pos_fract[1].y = Math::fmod(pos.y, target_level_size) / target_level_size; - pos_fract[1].z = Math::fmod(pos.z, target_level_size) / target_level_size; - - float alpha_interp[2]; - Vector3 color_interp[2]; - - for (int i = 0; i < 2; i++) { - - Vector3 color_x00 = color[i][0].lerp(color[i][1], pos_fract[i].x); - Vector3 color_xy0 = color[i][2].lerp(color[i][3], pos_fract[i].x); - Vector3 blend_z0 = color_x00.lerp(color_xy0, pos_fract[i].y); - - Vector3 color_x0z = color[i][4].lerp(color[i][5], pos_fract[i].x); - Vector3 color_xyz = color[i][6].lerp(color[i][7], pos_fract[i].x); - Vector3 blend_z1 = color_x0z.lerp(color_xyz, pos_fract[i].y); - - color_interp[i] = blend_z0.lerp(blend_z1, pos_fract[i].z); - - float alpha_x00 = Math::lerp(alpha[i][0], alpha[i][1], pos_fract[i].x); - float alpha_xy0 = Math::lerp(alpha[i][2], alpha[i][3], pos_fract[i].x); - float alpha_z0 = Math::lerp(alpha_x00, alpha_xy0, pos_fract[i].y); - - float alpha_x0z = Math::lerp(alpha[i][4], alpha[i][5], pos_fract[i].x); - float alpha_xyz = Math::lerp(alpha[i][6], alpha[i][7], pos_fract[i].x); - float alpha_z1 = Math::lerp(alpha_x0z, alpha_xyz, pos_fract[i].y); - - alpha_interp[i] = Math::lerp(alpha_z0, alpha_z1, pos_fract[i].z); - } - - r_color = color_interp[0].lerp(color_interp[1], level_filter); - r_alpha = Math::lerp(alpha_interp[0], alpha_interp[1], level_filter); - - //print_line("pos: " + p_posf + " level " + rtos(p_level) + " down to " + itos(target_level) + "." + rtos(level_filter) + " color " + r_color + " alpha " + rtos(r_alpha)); -} - -_FORCE_INLINE_ static Color _light_capture_voxel_cone_trace(const RasterizerStorage::LightmapCaptureOctree *p_octree, const Vector3 &p_pos, const Vector3 &p_dir, float p_aperture, int p_cell_subdiv) { - - float bias = 0.0; //no need for bias here - float max_distance = (Vector3(1, 1, 1) * (1 << (p_cell_subdiv - 1))).length(); - - float dist = bias; - float alpha = 0.0; - Vector3 color; - - Vector3 scolor; - float salpha; - - while (dist < max_distance && alpha < 0.95) { - float diameter = MAX(1.0, 2.0 * p_aperture * dist); - _light_capture_sample_octree(p_octree, p_cell_subdiv, p_pos + dist * p_dir, p_dir, log2(diameter), scolor, salpha); - float a = (1.0 - alpha); - color += scolor * a; - alpha += a * salpha; - dist += diameter * 0.5; - } - - return Color(color.x, color.y, color.z, alpha); -} -void RenderingServerScene::_update_instance_lightmap_captures(Instance *p_instance) { + Vector3 inner_pos = ((lm_pos - bounds.position) / bounds.size) * 2.0 - Vector3(1.0, 1.0, 1.0); - InstanceGeometryData *geom = static_cast<InstanceGeometryData *>(p_instance->base_data); + float blend = MAX(inner_pos.x, MAX(inner_pos.y, inner_pos.z)); + //make blend more rounded + blend = Math::lerp(inner_pos.length(), blend, blend); + blend *= blend; + blend = MAX(0.0, 1.0 - blend); - static const Vector3 cone_traces[12] = { - Vector3(0, 0, 1), - Vector3(0.866025, 0, 0.5), - Vector3(0.267617, 0.823639, 0.5), - Vector3(-0.700629, 0.509037, 0.5), - Vector3(-0.700629, -0.509037, 0.5), - Vector3(0.267617, -0.823639, 0.5), - Vector3(0, 0, -1), - Vector3(0.866025, 0, -0.5), - Vector3(0.267617, 0.823639, -0.5), - Vector3(-0.700629, 0.509037, -0.5), - Vector3(-0.700629, -0.509037, -0.5), - Vector3(0.267617, -0.823639, -0.5) - }; - - float cone_aperture = 0.577; // tan(angle) 60 degrees - - if (p_instance->lightmap_capture_data.empty()) { - p_instance->lightmap_capture_data.resize(12); + if (interior && !inside) { + //do not blend, just replace + for (int j = 0; j < 9; j++) { + accum_sh[j] = sh[j] * blend; + } + accum_blend = blend; + inside = true; + } else { + for (int j = 0; j < 9; j++) { + accum_sh[j] += sh[j] * blend; + } + accum_blend += blend; + } } - //print_line("update captures for pos: " + p_instance->transform.origin); - - for (int i = 0; i < 12; i++) - new (&p_instance->lightmap_capture_data.ptrw()[i]) Color; - - //this could use some sort of blending.. - for (List<Instance *>::Element *E = geom->lightmap_captures.front(); E; E = E->next()) { - const Vector<RasterizerStorage::LightmapCaptureOctree> *octree = RSG::storage->lightmap_capture_get_octree_ptr(E->get()->base); - //print_line("octree size: " + itos(octree->size())); - if (octree->size() == 0) - continue; - Transform to_cell_xform = RSG::storage->lightmap_capture_get_octree_cell_transform(E->get()->base); - int cell_subdiv = RSG::storage->lightmap_capture_get_octree_cell_subdiv(E->get()->base); - to_cell_xform = to_cell_xform * E->get()->transform.affine_inverse(); - - const RasterizerStorage::LightmapCaptureOctree *octree_r = octree->ptr(); - - Vector3 pos = to_cell_xform.xform(p_instance->transform.origin); - - for (int i = 0; i < 12; i++) { + if (accum_blend > 0.0) { + for (int j = 0; j < 9; j++) { - Vector3 dir = to_cell_xform.basis.xform(cone_traces[i]).normalized(); - Color capture = _light_capture_voxel_cone_trace(octree_r, pos, dir, cone_aperture, cell_subdiv); - p_instance->lightmap_capture_data.write[i] += capture; + instance_sh[j] = accum_sh[j] / accum_blend; + if (first_set) { + p_instance->lightmap_sh.write[j] = instance_sh[j]; + } } } } @@ -1523,9 +1379,15 @@ bool RenderingServerScene::_light_instance_update_shadow(Instance *p_instance, c int splits = 0; switch (RSG::storage->light_directional_get_shadow_mode(p_instance->base)) { - case RS::LIGHT_DIRECTIONAL_SHADOW_ORTHOGONAL: splits = 1; break; - case RS::LIGHT_DIRECTIONAL_SHADOW_PARALLEL_2_SPLITS: splits = 2; break; - case RS::LIGHT_DIRECTIONAL_SHADOW_PARALLEL_4_SPLITS: splits = 4; break; + case RS::LIGHT_DIRECTIONAL_SHADOW_ORTHOGONAL: + splits = 1; + break; + case RS::LIGHT_DIRECTIONAL_SHADOW_PARALLEL_2_SPLITS: + splits = 2; + break; + case RS::LIGHT_DIRECTIONAL_SHADOW_PARALLEL_4_SPLITS: + splits = 4; + break; } real_t distances[5]; @@ -1756,10 +1618,6 @@ bool RenderingServerScene::_light_instance_update_shadow(Instance *p_instance, c } else { camera_matrix_square.set_frustum(vp_he.y * 2.0, 1.0, Vector2(), distances[(i == 0 || !overlap) ? i : i - 1], distances[i + 1], false); } - - if (i == 0) { - //print_line("prev he: " + vp_he + " new he: " + camera_matrix_square.get_viewport_half_extents()); - } } Vector3 endpoints_square[8]; // frustum plane endpoints @@ -2141,6 +1999,7 @@ void RenderingServerScene::_prepare_scene(const Transform p_cam_transform, const reflection_probe_cull_count = 0; decal_cull_count = 0; gi_probe_cull_count = 0; + lightmap_cull_count = 0; //light_samplers_culled=0; @@ -2155,6 +2014,8 @@ void RenderingServerScene::_prepare_scene(const Transform p_cam_transform, const //removed, will replace with culling /* STEP 4 - REMOVE FURTHER CULLED OBJECTS, ADD LIGHTS */ + uint64_t frame_number = RSG::rasterizer->get_frame_number(); + float lightmap_probe_update_speed = RSG::storage->lightmap_get_probe_capture_update_speed() * RSG::rasterizer->get_frame_delta_time(); for (int i = 0; i < instance_cull_count; i++) { @@ -2233,6 +2094,12 @@ void RenderingServerScene::_prepare_scene(const Transform p_cam_transform, const gi_probe_instance_cull_result[gi_probe_cull_count] = gi_probe->probe_instance; gi_probe_cull_count++; } + } else if (ins->base_type == RS::INSTANCE_LIGHTMAP && ins->visible) { + + if (lightmap_cull_count < MAX_LIGHTMAPS_CULLED) { + lightmap_cull_result[lightmap_cull_count] = ins; + lightmap_cull_count++; + } } else if (((1 << ins->base_type) & RS::INSTANCE_GEOMETRY_MASK) && ins->visible && ins->cast_shadows != RS::SHADOW_CASTING_SETTING_SHADOWS_ONLY) { @@ -2301,6 +2168,14 @@ void RenderingServerScene::_prepare_scene(const Transform p_cam_transform, const geom->gi_probes_dirty = false; } + if (ins->last_frame_pass != frame_number && !ins->lightmap_target_sh.empty() && !ins->lightmap_sh.empty()) { + Color *sh = ins->lightmap_sh.ptrw(); + const Color *target_sh = ins->lightmap_target_sh.ptr(); + for (uint32_t j = 0; j < 9; j++) { + sh[j] = sh[j].lerp(target_sh[j], MIN(1.0, lightmap_probe_update_speed)); + } + } + ins->depth = near_plane.distance_to(ins->transform.origin); ins->depth_layer = CLAMP(int(ins->depth * 16 / z_far), 0, 15); } @@ -2315,6 +2190,7 @@ void RenderingServerScene::_prepare_scene(const Transform p_cam_transform, const ins->last_render_pass = render_pass; } + ins->last_frame_pass = frame_number; } /* STEP 5 - PROCESS LIGHTS */ @@ -2488,7 +2364,7 @@ void RenderingServerScene::_render_scene(RID p_render_buffers, const Transform p /* PROCESS GEOMETRY AND DRAW SCENE */ RENDER_TIMESTAMP("Render Scene "); - RSG::scene_render->render_scene(p_render_buffers, p_cam_transform, p_cam_projection, p_cam_orthogonal, (RasterizerScene::InstanceBase **)instance_cull_result, instance_cull_count, light_instance_cull_result, light_cull_count + directional_light_count, reflection_probe_instance_cull_result, reflection_probe_cull_count, gi_probe_instance_cull_result, gi_probe_cull_count, decal_instance_cull_result, decal_cull_count, environment, camera_effects, p_shadow_atlas, p_reflection_probe.is_valid() ? RID() : scenario->reflection_atlas, p_reflection_probe, p_reflection_probe_pass); + RSG::scene_render->render_scene(p_render_buffers, p_cam_transform, p_cam_projection, p_cam_orthogonal, (RasterizerScene::InstanceBase **)instance_cull_result, instance_cull_count, light_instance_cull_result, light_cull_count + directional_light_count, reflection_probe_instance_cull_result, reflection_probe_cull_count, gi_probe_instance_cull_result, gi_probe_cull_count, decal_instance_cull_result, decal_cull_count, (RasterizerScene::InstanceBase **)lightmap_cull_result, lightmap_cull_count, environment, camera_effects, p_shadow_atlas, p_reflection_probe.is_valid() ? RID() : scenario->reflection_atlas, p_reflection_probe, p_reflection_probe_pass); } void RenderingServerScene::render_empty_scene(RID p_render_buffers, RID p_scenario, RID p_shadow_atlas) { @@ -2503,7 +2379,7 @@ void RenderingServerScene::render_empty_scene(RID p_render_buffers, RID p_scenar else environment = scenario->fallback_environment; RENDER_TIMESTAMP("Render Empty Scene "); - RSG::scene_render->render_scene(p_render_buffers, Transform(), CameraMatrix(), true, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, environment, RID(), p_shadow_atlas, scenario->reflection_atlas, RID(), 0); + RSG::scene_render->render_scene(p_render_buffers, Transform(), CameraMatrix(), true, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, environment, RID(), p_shadow_atlas, scenario->reflection_atlas, RID(), 0); #endif } @@ -3117,7 +2993,7 @@ bool RenderingServerScene::free(RID p_rid) { Instance *instance = instance_owner.getornull(p_rid); - instance_set_use_lightmap(p_rid, RID(), RID()); + instance_geometry_set_lightmap(p_rid, RID(), Rect2(), 0); instance_set_scenario(p_rid, RID()); instance_set_base(p_rid, RID()); instance_geometry_set_material_override(p_rid, RID()); @@ -3138,6 +3014,10 @@ bool RenderingServerScene::free(RID p_rid) { return true; } +TypedArray<Image> RenderingServerScene::bake_render_uv2(RID p_base, const Vector<RID> &p_material_overrides, const Size2i &p_image_size) { + return RSG::scene_render->bake_render_uv2(p_base, p_material_overrides, p_image_size); +} + RenderingServerScene *RenderingServerScene::singleton = nullptr; RenderingServerScene::RenderingServerScene() { diff --git a/servers/rendering/rendering_server_scene.h b/servers/rendering/rendering_server_scene.h index eb66cea3aa..df9e650ac7 100644 --- a/servers/rendering/rendering_server_scene.h +++ b/servers/rendering/rendering_server_scene.h @@ -51,6 +51,7 @@ public: MAX_DECALS_CULLED = 4096, MAX_GI_PROBES_CULLED = 4096, MAX_ROOM_CULL = 32, + MAX_LIGHTMAPS_CULLED = 4096, MAX_EXTERIOR_PORTALS = 128, }; @@ -171,6 +172,8 @@ public: float lod_end_hysteresis; RID lod_instance; + Vector<Color> lightmap_target_sh; //target is used for incrementally changing the SH over time, this avoids pops in some corner cases and when going interior <-> exterior + uint64_t last_render_pass; uint64_t last_frame_pass; @@ -374,7 +377,7 @@ public: SelfList<InstanceGIProbeData>::List gi_probe_update_list; - struct InstanceLightmapCaptureData : public InstanceBaseData { + struct InstanceLightmapData : public InstanceBaseData { struct PairInfo { List<Instance *>::Element *L; //iterator in geometry @@ -384,7 +387,7 @@ public: Set<Instance *> users; - InstanceLightmapCaptureData() { + InstanceLightmapData() { } }; @@ -401,6 +404,8 @@ public: int decal_cull_count; RID gi_probe_instance_cull_result[MAX_GI_PROBES_CULLED]; int gi_probe_cull_count; + Instance *lightmap_cull_result[MAX_LIGHTS_CULLED]; + int lightmap_cull_count; RID_PtrOwner<Instance> instance_owner; @@ -414,7 +419,6 @@ public: virtual void instance_set_blend_shape_weight(RID p_instance, int p_shape, float p_weight); virtual void instance_set_surface_material(RID p_instance, int p_surface, RID p_material); virtual void instance_set_visible(RID p_instance, bool p_visible); - virtual void instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap); virtual void instance_set_custom_aabb(RID p_instance, AABB p_aabb); @@ -434,6 +438,7 @@ public: virtual void instance_geometry_set_draw_range(RID p_instance, float p_min, float p_max, float p_min_margin, float p_max_margin); virtual void instance_geometry_set_as_instance_lod(RID p_instance, RID p_as_lod_of_instance); + virtual void instance_geometry_set_lightmap(RID p_instance, RID p_lightmap, const Rect2 &p_lightmap_uv_scale, int p_slice_index); void _update_instance_shader_parameters_from_material(Map<StringName, RasterizerScene::InstanceBase::InstanceShaderParameter> &isparams, const Map<StringName, RasterizerScene::InstanceBase::InstanceShaderParameter> &existing_isparams, RID p_material); @@ -460,6 +465,8 @@ public: void render_probes(); + TypedArray<Image> bake_render_uv2(RID p_base, const Vector<RID> &p_material_overrides, const Size2i &p_image_size); + bool free(RID p_rid); RenderingServerScene(); diff --git a/servers/rendering/rendering_server_wrap_mt.cpp b/servers/rendering/rendering_server_wrap_mt.cpp index 4ca13dbef9..a7a166bec6 100644 --- a/servers/rendering/rendering_server_wrap_mt.cpp +++ b/servers/rendering/rendering_server_wrap_mt.cpp @@ -138,7 +138,7 @@ void RenderingServerWrapMT::finish() { spot_light_free_cached_ids(); reflection_probe_free_cached_ids(); gi_probe_free_cached_ids(); - lightmap_capture_free_cached_ids(); + lightmap_free_cached_ids(); particles_free_cached_ids(); camera_free_cached_ids(); viewport_free_cached_ids(); diff --git a/servers/rendering/rendering_server_wrap_mt.h b/servers/rendering/rendering_server_wrap_mt.h index d4e58485b8..6580b71508 100644 --- a/servers/rendering/rendering_server_wrap_mt.h +++ b/servers/rendering/rendering_server_wrap_mt.h @@ -92,7 +92,7 @@ public: //these also go pass-through virtual RID texture_2d_placeholder_create() { return rendering_server->texture_2d_placeholder_create(); } - virtual RID texture_2d_layered_placeholder_create() { return rendering_server->texture_2d_layered_placeholder_create(); } + virtual RID texture_2d_layered_placeholder_create(TextureLayeredType p_type) { return rendering_server->texture_2d_layered_placeholder_create(p_type); } virtual RID texture_3d_placeholder_create() { return rendering_server->texture_3d_placeholder_create(); } FUNC1RC(Ref<Image>, texture_2d_get, RID) @@ -324,19 +324,17 @@ public: /* LIGHTMAP CAPTURE */ - FUNCRID(lightmap_capture) + FUNCRID(lightmap) + FUNC3(lightmap_set_textures, RID, RID, bool) + FUNC2(lightmap_set_probe_bounds, RID, const AABB &) + FUNC2(lightmap_set_probe_interior, RID, bool) + FUNC5(lightmap_set_probe_capture_data, RID, const PackedVector3Array &, const PackedColorArray &, const PackedInt32Array &, const PackedInt32Array &) + FUNC1RC(PackedVector3Array, lightmap_get_probe_capture_points, RID) + FUNC1RC(PackedColorArray, lightmap_get_probe_capture_sh, RID) + FUNC1RC(PackedInt32Array, lightmap_get_probe_capture_tetrahedra, RID) + FUNC1RC(PackedInt32Array, lightmap_get_probe_capture_bsp_tree, RID) - FUNC2(lightmap_capture_set_bounds, RID, const AABB &) - FUNC1RC(AABB, lightmap_capture_get_bounds, RID) - - FUNC2(lightmap_capture_set_octree, RID, const Vector<uint8_t> &) - FUNC1RC(Vector<uint8_t>, lightmap_capture_get_octree, RID) - FUNC2(lightmap_capture_set_octree_cell_transform, RID, const Transform &) - FUNC1RC(Transform, lightmap_capture_get_octree_cell_transform, RID) - FUNC2(lightmap_capture_set_octree_cell_subdiv, RID, int) - FUNC1RC(int, lightmap_capture_get_octree_cell_subdiv, RID) - FUNC2(lightmap_capture_set_energy, RID, float) - FUNC1RC(float, lightmap_capture_get_energy, RID) + FUNC1(lightmap_set_probe_capture_update_speed, float) /* PARTICLES */ @@ -442,6 +440,7 @@ public: FUNC2(sky_set_radiance_size, RID, int) FUNC2(sky_set_mode, RID, SkyMode) FUNC2(sky_set_material, RID, RID) + FUNC4R(Ref<Image>, sky_bake_panorama, RID, float, bool, const Size2i &) /* ENVIRONMENT API */ @@ -478,6 +477,8 @@ public: FUNC7(environment_set_fog_depth, RID, bool, float, float, float, bool, float) FUNC5(environment_set_fog_height, RID, bool, float, float, float) + FUNC3R(Ref<Image>, environment_bake_panorama, RID, bool, const Size2i &) + FUNC2(screen_space_roughness_limiter_set_active, bool, float) FUNC1(sub_surface_scattering_set_quality, SubSurfaceScatteringQuality) FUNC2(sub_surface_scattering_set_scale, float, float) @@ -511,7 +512,6 @@ public: FUNC3(instance_set_blend_shape_weight, RID, int, float) FUNC3(instance_set_surface_material, RID, int, RID) FUNC2(instance_set_visible, RID, bool) - FUNC3(instance_set_use_lightmap, RID, RID, RID) FUNC2(instance_set_custom_aabb, RID, AABB) @@ -531,12 +531,17 @@ public: FUNC5(instance_geometry_set_draw_range, RID, float, float, float, float) FUNC2(instance_geometry_set_as_instance_lod, RID, RID) + FUNC4(instance_geometry_set_lightmap, RID, RID, const Rect2 &, int) FUNC3(instance_geometry_set_shader_parameter, RID, const StringName &, const Variant &) FUNC2RC(Variant, instance_geometry_get_shader_parameter, RID, const StringName &) FUNC2RC(Variant, instance_geometry_get_shader_parameter_default_value, RID, const StringName &) FUNC2SC(instance_geometry_get_shader_parameter_list, RID, List<PropertyInfo> *) + /* BAKE */ + + FUNC3R(TypedArray<Image>, bake_render_uv2, RID, const Vector<RID> &, const Size2i &) + /* CANVAS (2D) */ FUNCRID(canvas) diff --git a/servers/rendering/shader_language.cpp b/servers/rendering/shader_language.cpp index 93593effd4..e3725043d9 100644 --- a/servers/rendering/shader_language.cpp +++ b/servers/rendering/shader_language.cpp @@ -132,6 +132,7 @@ const char *ShaderLanguage::token_names[TK_MAX] = { "TYPE_ISAMPLER3D", "TYPE_USAMPLER3D", "TYPE_SAMPLERCUBE", + "TYPE_SAMPLERCUBEARRAY", "INTERPOLATION_FLAT", "INTERPOLATION_SMOOTH", "CONST", @@ -283,6 +284,7 @@ const ShaderLanguage::KeyWord ShaderLanguage::keyword_list[] = { { TK_TYPE_ISAMPLER3D, "isampler3D" }, { TK_TYPE_USAMPLER3D, "usampler3D" }, { TK_TYPE_SAMPLERCUBE, "samplerCube" }, + { TK_TYPE_SAMPLERCUBEARRAY, "samplerCubeArray" }, { TK_INTERPOLATION_FLAT, "flat" }, { TK_INTERPOLATION_SMOOTH, "smooth" }, { TK_CONST, "const" }, @@ -783,7 +785,8 @@ bool ShaderLanguage::is_token_datatype(TokenType p_type) { p_type == TK_TYPE_SAMPLER3D || p_type == TK_TYPE_ISAMPLER3D || p_type == TK_TYPE_USAMPLER3D || - p_type == TK_TYPE_SAMPLERCUBE); + p_type == TK_TYPE_SAMPLERCUBE || + p_type == TK_TYPE_SAMPLERCUBEARRAY); } ShaderLanguage::DataType ShaderLanguage::get_token_datatype(TokenType p_type) { @@ -826,9 +829,12 @@ ShaderLanguage::DataPrecision ShaderLanguage::get_token_precision(TokenType p_ty String ShaderLanguage::get_precision_name(DataPrecision p_type) { switch (p_type) { - case PRECISION_LOWP: return "lowp"; - case PRECISION_MEDIUMP: return "mediump"; - case PRECISION_HIGHP: return "highp"; + case PRECISION_LOWP: + return "lowp"; + case PRECISION_MEDIUMP: + return "mediump"; + case PRECISION_HIGHP: + return "highp"; default: break; } @@ -839,38 +845,72 @@ String ShaderLanguage::get_datatype_name(DataType p_type) { switch (p_type) { - case TYPE_VOID: return "void"; - case TYPE_BOOL: return "bool"; - case TYPE_BVEC2: return "bvec2"; - case TYPE_BVEC3: return "bvec3"; - case TYPE_BVEC4: return "bvec4"; - case TYPE_INT: return "int"; - case TYPE_IVEC2: return "ivec2"; - case TYPE_IVEC3: return "ivec3"; - case TYPE_IVEC4: return "ivec4"; - case TYPE_UINT: return "uint"; - case TYPE_UVEC2: return "uvec2"; - case TYPE_UVEC3: return "uvec3"; - case TYPE_UVEC4: return "uvec4"; - case TYPE_FLOAT: return "float"; - case TYPE_VEC2: return "vec2"; - case TYPE_VEC3: return "vec3"; - case TYPE_VEC4: return "vec4"; - case TYPE_MAT2: return "mat2"; - case TYPE_MAT3: return "mat3"; - case TYPE_MAT4: return "mat4"; - case TYPE_SAMPLER2D: return "sampler2D"; - case TYPE_ISAMPLER2D: return "isampler2D"; - case TYPE_USAMPLER2D: return "usampler2D"; - case TYPE_SAMPLER2DARRAY: return "sampler2DArray"; - case TYPE_ISAMPLER2DARRAY: return "isampler2DArray"; - case TYPE_USAMPLER2DARRAY: return "usampler2DArray"; - case TYPE_SAMPLER3D: return "sampler3D"; - case TYPE_ISAMPLER3D: return "isampler3D"; - case TYPE_USAMPLER3D: return "usampler3D"; - case TYPE_SAMPLERCUBE: return "samplerCube"; - case TYPE_STRUCT: return "struct"; - case TYPE_MAX: return "invalid"; + case TYPE_VOID: + return "void"; + case TYPE_BOOL: + return "bool"; + case TYPE_BVEC2: + return "bvec2"; + case TYPE_BVEC3: + return "bvec3"; + case TYPE_BVEC4: + return "bvec4"; + case TYPE_INT: + return "int"; + case TYPE_IVEC2: + return "ivec2"; + case TYPE_IVEC3: + return "ivec3"; + case TYPE_IVEC4: + return "ivec4"; + case TYPE_UINT: + return "uint"; + case TYPE_UVEC2: + return "uvec2"; + case TYPE_UVEC3: + return "uvec3"; + case TYPE_UVEC4: + return "uvec4"; + case TYPE_FLOAT: + return "float"; + case TYPE_VEC2: + return "vec2"; + case TYPE_VEC3: + return "vec3"; + case TYPE_VEC4: + return "vec4"; + case TYPE_MAT2: + return "mat2"; + case TYPE_MAT3: + return "mat3"; + case TYPE_MAT4: + return "mat4"; + case TYPE_SAMPLER2D: + return "sampler2D"; + case TYPE_ISAMPLER2D: + return "isampler2D"; + case TYPE_USAMPLER2D: + return "usampler2D"; + case TYPE_SAMPLER2DARRAY: + return "sampler2DArray"; + case TYPE_ISAMPLER2DARRAY: + return "isampler2DArray"; + case TYPE_USAMPLER2DARRAY: + return "usampler2DArray"; + case TYPE_SAMPLER3D: + return "sampler3D"; + case TYPE_ISAMPLER3D: + return "isampler3D"; + case TYPE_USAMPLER3D: + return "usampler3D"; + case TYPE_SAMPLERCUBE: + return "samplerCube"; + case TYPE_SAMPLERCUBEARRAY: + return "samplerCubeArray"; + case TYPE_STRUCT: + return "struct"; + case TYPE_MAX: + return "invalid"; } return ""; @@ -2011,6 +2051,7 @@ const ShaderLanguage::BuiltinFuncDef ShaderLanguage::builtin_func_defs[] = { { "textureSize", TYPE_IVEC3, { TYPE_ISAMPLER3D, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, { "textureSize", TYPE_IVEC3, { TYPE_USAMPLER3D, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, { "textureSize", TYPE_IVEC2, { TYPE_SAMPLERCUBE, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, + { "textureSize", TYPE_IVEC2, { TYPE_SAMPLERCUBEARRAY, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, { "texture", TYPE_VEC4, { TYPE_SAMPLER2D, TYPE_VEC2, TYPE_VOID }, TAG_GLOBAL, false }, { "texture", TYPE_VEC4, { TYPE_SAMPLER2D, TYPE_VEC2, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, false }, @@ -2032,6 +2073,8 @@ const ShaderLanguage::BuiltinFuncDef ShaderLanguage::builtin_func_defs[] = { { "texture", TYPE_IVEC4, { TYPE_ISAMPLER3D, TYPE_VEC3, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, true }, { "texture", TYPE_VEC4, { TYPE_SAMPLERCUBE, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, false }, { "texture", TYPE_VEC4, { TYPE_SAMPLERCUBE, TYPE_VEC3, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, false }, + { "texture", TYPE_VEC4, { TYPE_SAMPLERCUBEARRAY, TYPE_VEC4, TYPE_VOID }, TAG_GLOBAL, false }, + { "texture", TYPE_VEC4, { TYPE_SAMPLERCUBEARRAY, TYPE_VEC4, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, false }, { "textureProj", TYPE_VEC4, { TYPE_SAMPLER2D, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, true }, { "textureProj", TYPE_VEC4, { TYPE_SAMPLER2D, TYPE_VEC4, TYPE_VOID }, TAG_GLOBAL, true }, @@ -2062,6 +2105,7 @@ const ShaderLanguage::BuiltinFuncDef ShaderLanguage::builtin_func_defs[] = { { "textureLod", TYPE_IVEC4, { TYPE_ISAMPLER3D, TYPE_VEC3, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, true }, { "textureLod", TYPE_UVEC4, { TYPE_USAMPLER3D, TYPE_VEC3, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, true }, { "textureLod", TYPE_VEC4, { TYPE_SAMPLERCUBE, TYPE_VEC3, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, false }, + { "textureLod", TYPE_VEC4, { TYPE_SAMPLERCUBEARRAY, TYPE_VEC4, TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, false }, { "texelFetch", TYPE_VEC4, { TYPE_SAMPLER2D, TYPE_IVEC2, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, { "texelFetch", TYPE_IVEC4, { TYPE_ISAMPLER2D, TYPE_IVEC2, TYPE_INT, TYPE_VOID }, TAG_GLOBAL, true }, @@ -2093,6 +2137,7 @@ const ShaderLanguage::BuiltinFuncDef ShaderLanguage::builtin_func_defs[] = { { "textureGrad", TYPE_IVEC4, { TYPE_ISAMPLER3D, TYPE_VEC3, TYPE_VEC3, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, true }, { "textureGrad", TYPE_UVEC4, { TYPE_USAMPLER3D, TYPE_VEC3, TYPE_VEC3, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, true }, { "textureGrad", TYPE_VEC4, { TYPE_SAMPLERCUBE, TYPE_VEC3, TYPE_VEC3, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, true }, + { "textureGrad", TYPE_VEC4, { TYPE_SAMPLERCUBEARRAY, TYPE_VEC4, TYPE_VEC3, TYPE_VEC3, TYPE_VOID }, TAG_GLOBAL, true }, { "dFdx", TYPE_FLOAT, { TYPE_FLOAT, TYPE_VOID }, TAG_GLOBAL, true }, { "dFdx", TYPE_VEC2, { TYPE_VEC2, TYPE_VOID }, TAG_GLOBAL, true }, @@ -2583,7 +2628,8 @@ bool ShaderLanguage::is_sampler_type(DataType p_type) { p_type == TYPE_SAMPLER3D || p_type == TYPE_ISAMPLER3D || p_type == TYPE_USAMPLER3D || - p_type == TYPE_SAMPLERCUBE; + p_type == TYPE_SAMPLERCUBE || + p_type == TYPE_SAMPLERCUBEARRAY; } Variant ShaderLanguage::constant_value_to_variant(const Vector<ShaderLanguage::ConstantNode::Value> &p_value, DataType p_type, ShaderLanguage::ShaderNode::Uniform::Hint p_hint) { @@ -2677,7 +2723,9 @@ Variant ShaderLanguage::constant_value_to_variant(const Vector<ShaderLanguage::C case ShaderLanguage::TYPE_USAMPLER2DARRAY: case ShaderLanguage::TYPE_USAMPLER2D: case ShaderLanguage::TYPE_USAMPLER3D: - case ShaderLanguage::TYPE_SAMPLERCUBE: { + case ShaderLanguage::TYPE_SAMPLERCUBE: + case ShaderLanguage::TYPE_SAMPLERCUBEARRAY: { + // Texture types, likely not relevant here. break; } @@ -2696,8 +2744,12 @@ Variant ShaderLanguage::constant_value_to_variant(const Vector<ShaderLanguage::C PropertyInfo ShaderLanguage::uniform_to_property_info(const ShaderNode::Uniform &p_uniform) { PropertyInfo pi; switch (p_uniform.type) { - case ShaderLanguage::TYPE_VOID: pi.type = Variant::NIL; break; - case ShaderLanguage::TYPE_BOOL: pi.type = Variant::BOOL; break; + case ShaderLanguage::TYPE_VOID: + pi.type = Variant::NIL; + break; + case ShaderLanguage::TYPE_BOOL: + pi.type = Variant::BOOL; + break; case ShaderLanguage::TYPE_BVEC2: pi.type = Variant::INT; pi.hint = PROPERTY_HINT_FLAGS; @@ -2739,8 +2791,12 @@ PropertyInfo ShaderLanguage::uniform_to_property_info(const ShaderNode::Uniform } } break; - case ShaderLanguage::TYPE_VEC2: pi.type = Variant::VECTOR2; break; - case ShaderLanguage::TYPE_VEC3: pi.type = Variant::VECTOR3; break; + case ShaderLanguage::TYPE_VEC2: + pi.type = Variant::VECTOR2; + break; + case ShaderLanguage::TYPE_VEC3: + pi.type = Variant::VECTOR3; + break; case ShaderLanguage::TYPE_VEC4: { if (p_uniform.hint == ShaderLanguage::ShaderNode::Uniform::HINT_COLOR) { pi.type = Variant::COLOR; @@ -2748,9 +2804,15 @@ PropertyInfo ShaderLanguage::uniform_to_property_info(const ShaderNode::Uniform pi.type = Variant::PLANE; } } break; - case ShaderLanguage::TYPE_MAT2: pi.type = Variant::TRANSFORM2D; break; - case ShaderLanguage::TYPE_MAT3: pi.type = Variant::BASIS; break; - case ShaderLanguage::TYPE_MAT4: pi.type = Variant::TRANSFORM; break; + case ShaderLanguage::TYPE_MAT2: + pi.type = Variant::TRANSFORM2D; + break; + case ShaderLanguage::TYPE_MAT3: + pi.type = Variant::BASIS; + break; + case ShaderLanguage::TYPE_MAT4: + pi.type = Variant::TRANSFORM; + break; case ShaderLanguage::TYPE_SAMPLER2D: case ShaderLanguage::TYPE_ISAMPLER2D: case ShaderLanguage::TYPE_USAMPLER2D: { @@ -2765,7 +2827,7 @@ PropertyInfo ShaderLanguage::uniform_to_property_info(const ShaderNode::Uniform pi.type = Variant::OBJECT; pi.hint = PROPERTY_HINT_RESOURCE_TYPE; - pi.hint_string = "TextureArray"; + pi.hint_string = "TextureLayered"; } break; case ShaderLanguage::TYPE_SAMPLER3D: case ShaderLanguage::TYPE_ISAMPLER3D: @@ -2774,11 +2836,12 @@ PropertyInfo ShaderLanguage::uniform_to_property_info(const ShaderNode::Uniform pi.hint = PROPERTY_HINT_RESOURCE_TYPE; pi.hint_string = "Texture3D"; } break; - case ShaderLanguage::TYPE_SAMPLERCUBE: { + case ShaderLanguage::TYPE_SAMPLERCUBE: + case ShaderLanguage::TYPE_SAMPLERCUBEARRAY: { pi.type = Variant::OBJECT; pi.hint = PROPERTY_HINT_RESOURCE_TYPE; - pi.hint_string = "CubeMap"; + pi.hint_string = "TextureLayered"; } break; case ShaderLanguage::TYPE_STRUCT: { // FIXME: Implement this. @@ -2829,6 +2892,7 @@ uint32_t ShaderLanguage::get_type_size(DataType p_type) { case TYPE_ISAMPLER3D: case TYPE_USAMPLER3D: case TYPE_SAMPLERCUBE: + case TYPE_SAMPLERCUBEARRAY: return 4; //not really, but useful for indices case TYPE_STRUCT: // FIXME: Implement. @@ -3758,12 +3822,23 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons e.is_op = true; switch (tk.type) { - case TK_OP_SUB: e.op = OP_NEGATE; break; - case TK_OP_NOT: e.op = OP_NOT; break; - case TK_OP_BIT_INVERT: e.op = OP_BIT_INVERT; break; - case TK_OP_INCREMENT: e.op = OP_INCREMENT; break; - case TK_OP_DECREMENT: e.op = OP_DECREMENT; break; - default: ERR_FAIL_V(nullptr); + case TK_OP_SUB: + e.op = OP_NEGATE; + break; + case TK_OP_NOT: + e.op = OP_NOT; + break; + case TK_OP_BIT_INVERT: + e.op = OP_BIT_INVERT; + break; + case TK_OP_INCREMENT: + e.op = OP_INCREMENT; + break; + case TK_OP_DECREMENT: + e.op = OP_DECREMENT; + break; + default: + ERR_FAIL_V(nullptr); } expression.push_back(e); @@ -4157,12 +4232,23 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons } switch (expr->get_datatype()) { - case TYPE_BVEC2: member_type = TYPE_BOOL; break; - case TYPE_VEC2: member_type = TYPE_FLOAT; break; - case TYPE_IVEC2: member_type = TYPE_INT; break; - case TYPE_UVEC2: member_type = TYPE_UINT; break; - case TYPE_MAT2: member_type = TYPE_VEC2; break; - default: break; + case TYPE_BVEC2: + member_type = TYPE_BOOL; + break; + case TYPE_VEC2: + member_type = TYPE_FLOAT; + break; + case TYPE_IVEC2: + member_type = TYPE_INT; + break; + case TYPE_UVEC2: + member_type = TYPE_UINT; + break; + case TYPE_MAT2: + member_type = TYPE_VEC2; + break; + default: + break; } break; @@ -4180,12 +4266,23 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons } switch (expr->get_datatype()) { - case TYPE_BVEC3: member_type = TYPE_BOOL; break; - case TYPE_VEC3: member_type = TYPE_FLOAT; break; - case TYPE_IVEC3: member_type = TYPE_INT; break; - case TYPE_UVEC3: member_type = TYPE_UINT; break; - case TYPE_MAT3: member_type = TYPE_VEC3; break; - default: break; + case TYPE_BVEC3: + member_type = TYPE_BOOL; + break; + case TYPE_VEC3: + member_type = TYPE_FLOAT; + break; + case TYPE_IVEC3: + member_type = TYPE_INT; + break; + case TYPE_UVEC3: + member_type = TYPE_UINT; + break; + case TYPE_MAT3: + member_type = TYPE_VEC3; + break; + default: + break; } break; case TYPE_BVEC4: @@ -4202,12 +4299,23 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons } switch (expr->get_datatype()) { - case TYPE_BVEC4: member_type = TYPE_BOOL; break; - case TYPE_VEC4: member_type = TYPE_FLOAT; break; - case TYPE_IVEC4: member_type = TYPE_INT; break; - case TYPE_UVEC4: member_type = TYPE_UINT; break; - case TYPE_MAT4: member_type = TYPE_VEC4; break; - default: break; + case TYPE_BVEC4: + member_type = TYPE_BOOL; + break; + case TYPE_VEC4: + member_type = TYPE_FLOAT; + break; + case TYPE_IVEC4: + member_type = TYPE_INT; + break; + case TYPE_UVEC4: + member_type = TYPE_UINT; + break; + case TYPE_MAT4: + member_type = TYPE_VEC4; + break; + default: + break; } break; default: { @@ -4267,37 +4375,99 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons switch (tk.type) { - case TK_OP_EQUAL: o.op = OP_EQUAL; break; - case TK_OP_NOT_EQUAL: o.op = OP_NOT_EQUAL; break; - case TK_OP_LESS: o.op = OP_LESS; break; - case TK_OP_LESS_EQUAL: o.op = OP_LESS_EQUAL; break; - case TK_OP_GREATER: o.op = OP_GREATER; break; - case TK_OP_GREATER_EQUAL: o.op = OP_GREATER_EQUAL; break; - case TK_OP_AND: o.op = OP_AND; break; - case TK_OP_OR: o.op = OP_OR; break; - case TK_OP_ADD: o.op = OP_ADD; break; - case TK_OP_SUB: o.op = OP_SUB; break; - case TK_OP_MUL: o.op = OP_MUL; break; - case TK_OP_DIV: o.op = OP_DIV; break; - case TK_OP_MOD: o.op = OP_MOD; break; - case TK_OP_SHIFT_LEFT: o.op = OP_SHIFT_LEFT; break; - case TK_OP_SHIFT_RIGHT: o.op = OP_SHIFT_RIGHT; break; - case TK_OP_ASSIGN: o.op = OP_ASSIGN; break; - case TK_OP_ASSIGN_ADD: o.op = OP_ASSIGN_ADD; break; - case TK_OP_ASSIGN_SUB: o.op = OP_ASSIGN_SUB; break; - case TK_OP_ASSIGN_MUL: o.op = OP_ASSIGN_MUL; break; - case TK_OP_ASSIGN_DIV: o.op = OP_ASSIGN_DIV; break; - case TK_OP_ASSIGN_MOD: o.op = OP_ASSIGN_MOD; break; - case TK_OP_ASSIGN_SHIFT_LEFT: o.op = OP_ASSIGN_SHIFT_LEFT; break; - case TK_OP_ASSIGN_SHIFT_RIGHT: o.op = OP_ASSIGN_SHIFT_RIGHT; break; - case TK_OP_ASSIGN_BIT_AND: o.op = OP_ASSIGN_BIT_AND; break; - case TK_OP_ASSIGN_BIT_OR: o.op = OP_ASSIGN_BIT_OR; break; - case TK_OP_ASSIGN_BIT_XOR: o.op = OP_ASSIGN_BIT_XOR; break; - case TK_OP_BIT_AND: o.op = OP_BIT_AND; break; - case TK_OP_BIT_OR: o.op = OP_BIT_OR; break; - case TK_OP_BIT_XOR: o.op = OP_BIT_XOR; break; - case TK_QUESTION: o.op = OP_SELECT_IF; break; - case TK_COLON: o.op = OP_SELECT_ELSE; break; + case TK_OP_EQUAL: + o.op = OP_EQUAL; + break; + case TK_OP_NOT_EQUAL: + o.op = OP_NOT_EQUAL; + break; + case TK_OP_LESS: + o.op = OP_LESS; + break; + case TK_OP_LESS_EQUAL: + o.op = OP_LESS_EQUAL; + break; + case TK_OP_GREATER: + o.op = OP_GREATER; + break; + case TK_OP_GREATER_EQUAL: + o.op = OP_GREATER_EQUAL; + break; + case TK_OP_AND: + o.op = OP_AND; + break; + case TK_OP_OR: + o.op = OP_OR; + break; + case TK_OP_ADD: + o.op = OP_ADD; + break; + case TK_OP_SUB: + o.op = OP_SUB; + break; + case TK_OP_MUL: + o.op = OP_MUL; + break; + case TK_OP_DIV: + o.op = OP_DIV; + break; + case TK_OP_MOD: + o.op = OP_MOD; + break; + case TK_OP_SHIFT_LEFT: + o.op = OP_SHIFT_LEFT; + break; + case TK_OP_SHIFT_RIGHT: + o.op = OP_SHIFT_RIGHT; + break; + case TK_OP_ASSIGN: + o.op = OP_ASSIGN; + break; + case TK_OP_ASSIGN_ADD: + o.op = OP_ASSIGN_ADD; + break; + case TK_OP_ASSIGN_SUB: + o.op = OP_ASSIGN_SUB; + break; + case TK_OP_ASSIGN_MUL: + o.op = OP_ASSIGN_MUL; + break; + case TK_OP_ASSIGN_DIV: + o.op = OP_ASSIGN_DIV; + break; + case TK_OP_ASSIGN_MOD: + o.op = OP_ASSIGN_MOD; + break; + case TK_OP_ASSIGN_SHIFT_LEFT: + o.op = OP_ASSIGN_SHIFT_LEFT; + break; + case TK_OP_ASSIGN_SHIFT_RIGHT: + o.op = OP_ASSIGN_SHIFT_RIGHT; + break; + case TK_OP_ASSIGN_BIT_AND: + o.op = OP_ASSIGN_BIT_AND; + break; + case TK_OP_ASSIGN_BIT_OR: + o.op = OP_ASSIGN_BIT_OR; + break; + case TK_OP_ASSIGN_BIT_XOR: + o.op = OP_ASSIGN_BIT_XOR; + break; + case TK_OP_BIT_AND: + o.op = OP_BIT_AND; + break; + case TK_OP_BIT_OR: + o.op = OP_BIT_OR; + break; + case TK_OP_BIT_XOR: + o.op = OP_BIT_XOR; + break; + case TK_QUESTION: + o.op = OP_SELECT_IF; + break; + case TK_COLON: + o.op = OP_SELECT_ELSE; + break; default: { _set_error("Invalid token for operator: " + get_token_text(tk)); return nullptr; @@ -4333,14 +4503,30 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons int priority; switch (expression[i].op) { - case OP_EQUAL: priority = 8; break; - case OP_NOT_EQUAL: priority = 8; break; - case OP_LESS: priority = 7; break; - case OP_LESS_EQUAL: priority = 7; break; - case OP_GREATER: priority = 7; break; - case OP_GREATER_EQUAL: priority = 7; break; - case OP_AND: priority = 12; break; - case OP_OR: priority = 14; break; + case OP_EQUAL: + priority = 8; + break; + case OP_NOT_EQUAL: + priority = 8; + break; + case OP_LESS: + priority = 7; + break; + case OP_LESS_EQUAL: + priority = 7; + break; + case OP_GREATER: + priority = 7; + break; + case OP_GREATER_EQUAL: + priority = 7; + break; + case OP_AND: + priority = 12; + break; + case OP_OR: + priority = 14; + break; case OP_NOT: priority = 3; unary = true; @@ -4349,27 +4535,69 @@ ShaderLanguage::Node *ShaderLanguage::_parse_expression(BlockNode *p_block, cons priority = 3; unary = true; break; - case OP_ADD: priority = 5; break; - case OP_SUB: priority = 5; break; - case OP_MUL: priority = 4; break; - case OP_DIV: priority = 4; break; - case OP_MOD: priority = 4; break; - case OP_SHIFT_LEFT: priority = 6; break; - case OP_SHIFT_RIGHT: priority = 6; break; - case OP_ASSIGN: priority = 16; break; - case OP_ASSIGN_ADD: priority = 16; break; - case OP_ASSIGN_SUB: priority = 16; break; - case OP_ASSIGN_MUL: priority = 16; break; - case OP_ASSIGN_DIV: priority = 16; break; - case OP_ASSIGN_MOD: priority = 16; break; - case OP_ASSIGN_SHIFT_LEFT: priority = 16; break; - case OP_ASSIGN_SHIFT_RIGHT: priority = 16; break; - case OP_ASSIGN_BIT_AND: priority = 16; break; - case OP_ASSIGN_BIT_OR: priority = 16; break; - case OP_ASSIGN_BIT_XOR: priority = 16; break; - case OP_BIT_AND: priority = 9; break; - case OP_BIT_OR: priority = 11; break; - case OP_BIT_XOR: priority = 10; break; + case OP_ADD: + priority = 5; + break; + case OP_SUB: + priority = 5; + break; + case OP_MUL: + priority = 4; + break; + case OP_DIV: + priority = 4; + break; + case OP_MOD: + priority = 4; + break; + case OP_SHIFT_LEFT: + priority = 6; + break; + case OP_SHIFT_RIGHT: + priority = 6; + break; + case OP_ASSIGN: + priority = 16; + break; + case OP_ASSIGN_ADD: + priority = 16; + break; + case OP_ASSIGN_SUB: + priority = 16; + break; + case OP_ASSIGN_MUL: + priority = 16; + break; + case OP_ASSIGN_DIV: + priority = 16; + break; + case OP_ASSIGN_MOD: + priority = 16; + break; + case OP_ASSIGN_SHIFT_LEFT: + priority = 16; + break; + case OP_ASSIGN_SHIFT_RIGHT: + priority = 16; + break; + case OP_ASSIGN_BIT_AND: + priority = 16; + break; + case OP_ASSIGN_BIT_OR: + priority = 16; + break; + case OP_ASSIGN_BIT_XOR: + priority = 16; + break; + case OP_BIT_AND: + priority = 9; + break; + case OP_BIT_OR: + priority = 11; + break; + case OP_BIT_XOR: + priority = 10; + break; case OP_BIT_INVERT: priority = 3; unary = true; @@ -5972,7 +6200,7 @@ Error ShaderLanguage::_parse_shader(const Map<StringName, FunctionInfo> &p_funct return ERR_PARSE_ERROR; } } else { - if (uniform_scope == ShaderNode::Uniform::SCOPE_LOCAL && (type == TYPE_MAT2 || type == TYPE_MAT3 || type == TYPE_MAT4)) { + if (uniform_scope == ShaderNode::Uniform::SCOPE_INSTANCE && (type == TYPE_MAT2 || type == TYPE_MAT3 || type == TYPE_MAT4)) { _set_error("Uniforms with 'instance' qualifiers can't be of matrix type."); return ERR_PARSE_ERROR; } @@ -6688,7 +6916,8 @@ static int _get_first_ident_pos(const String &p_code) { if (GETCHAR(0) == '/' && GETCHAR(1) == '/') { idx += 2; while (true) { - if (GETCHAR(0) == 0) return 0; + if (GETCHAR(0) == 0) + return 0; if (GETCHAR(0) == '\n') { idx++; break; // loop @@ -6698,7 +6927,8 @@ static int _get_first_ident_pos(const String &p_code) { } else if (GETCHAR(0) == '/' && GETCHAR(1) == '*') { idx += 2; while (true) { - if (GETCHAR(0) == 0) return 0; + if (GETCHAR(0) == 0) + return 0; if (GETCHAR(0) == '*' && GETCHAR(1) == '/') { idx += 2; break; // loop @@ -7085,9 +7315,15 @@ Error ShaderLanguage::complete(const String &p_code, const Map<StringName, Funct limit = 4; } break; - case TYPE_MAT2: limit = 2; break; - case TYPE_MAT3: limit = 3; break; - case TYPE_MAT4: limit = 4; break; + case TYPE_MAT2: + limit = 2; + break; + case TYPE_MAT3: + limit = 3; + break; + case TYPE_MAT4: + limit = 4; + break; default: { } } diff --git a/servers/rendering/shader_language.h b/servers/rendering/shader_language.h index 973e1c4937..c78d850e21 100644 --- a/servers/rendering/shader_language.h +++ b/servers/rendering/shader_language.h @@ -79,6 +79,7 @@ public: TK_TYPE_ISAMPLER3D, TK_TYPE_USAMPLER3D, TK_TYPE_SAMPLERCUBE, + TK_TYPE_SAMPLERCUBEARRAY, TK_INTERPOLATION_FLAT, TK_INTERPOLATION_SMOOTH, TK_CONST, @@ -218,6 +219,7 @@ public: TYPE_ISAMPLER3D, TYPE_USAMPLER3D, TYPE_SAMPLERCUBE, + TYPE_SAMPLERCUBEARRAY, TYPE_STRUCT, TYPE_MAX }; @@ -326,7 +328,7 @@ public: }; struct Node { - Node *next; + Node *next = nullptr; enum Type { TYPE_SHADER, @@ -350,7 +352,6 @@ public: virtual String get_datatype_name() const { return ""; } Node(Type t) : - next(nullptr), type(t) {} virtual ~Node() {} }; @@ -366,41 +367,35 @@ public: Node *nodes; struct OperatorNode : public Node { - DataType return_cache; - DataPrecision return_precision_cache; - Operator op; + DataType return_cache = TYPE_VOID; + DataPrecision return_precision_cache = PRECISION_DEFAULT; + Operator op = OP_EQUAL; StringName struct_name; Vector<Node *> arguments; virtual DataType get_datatype() const { return return_cache; } virtual String get_datatype_name() const { return String(struct_name); } OperatorNode() : - Node(TYPE_OPERATOR), - return_cache(TYPE_VOID), - return_precision_cache(PRECISION_DEFAULT), - op(OP_EQUAL), - struct_name("") {} + Node(TYPE_OPERATOR) {} }; struct VariableNode : public Node { - DataType datatype_cache; + DataType datatype_cache = TYPE_VOID; StringName name; StringName struct_name; virtual DataType get_datatype() const { return datatype_cache; } virtual String get_datatype_name() const { return String(struct_name); } - bool is_const; + bool is_const = false; VariableNode() : - Node(TYPE_VARIABLE), - datatype_cache(TYPE_VOID), - is_const(false) {} + Node(TYPE_VARIABLE) {} }; struct VariableDeclarationNode : public Node { - DataPrecision precision; - DataType datatype; + DataPrecision precision = PRECISION_DEFAULT; + DataType datatype = TYPE_VOID; String struct_name; - bool is_const; + bool is_const = false; struct Declaration { StringName name; @@ -411,47 +406,38 @@ public: virtual DataType get_datatype() const { return datatype; } VariableDeclarationNode() : - Node(TYPE_VARIABLE_DECLARATION), - precision(PRECISION_DEFAULT), - datatype(TYPE_VOID), - is_const(false) {} + Node(TYPE_VARIABLE_DECLARATION) {} }; struct ArrayNode : public Node { - DataType datatype_cache; + DataType datatype_cache = TYPE_VOID; StringName struct_name; StringName name; - Node *index_expression; - Node *call_expression; - bool is_const; + Node *index_expression = nullptr; + Node *call_expression = nullptr; + bool is_const = false; virtual DataType get_datatype() const { return datatype_cache; } virtual String get_datatype_name() const { return String(struct_name); } ArrayNode() : - Node(TYPE_ARRAY), - datatype_cache(TYPE_VOID), - index_expression(nullptr), - call_expression(nullptr), - is_const(false) {} + Node(TYPE_ARRAY) {} }; struct ArrayConstructNode : public Node { - DataType datatype; + DataType datatype = TYPE_VOID; String struct_name; Vector<Node *> initializer; ArrayConstructNode() : - Node(TYPE_ARRAY_CONSTRUCT), - datatype(TYPE_VOID) { - } + Node(TYPE_ARRAY_CONSTRUCT) {} }; struct ArrayDeclarationNode : public Node { - DataPrecision precision; - DataType datatype; + DataPrecision precision = PRECISION_DEFAULT; + DataType datatype = TYPE_VOID; String struct_name; - bool is_const; + bool is_const = false; struct Declaration { StringName name; @@ -463,14 +449,11 @@ public: virtual DataType get_datatype() const { return datatype; } ArrayDeclarationNode() : - Node(TYPE_ARRAY_DECLARATION), - precision(PRECISION_DEFAULT), - datatype(TYPE_VOID), - is_const(false) {} + Node(TYPE_ARRAY_DECLARATION) {} }; struct ConstantNode : public Node { - DataType datatype; + DataType datatype = TYPE_VOID; union Value { bool boolean; @@ -483,15 +466,14 @@ public: virtual DataType get_datatype() const { return datatype; } ConstantNode() : - Node(TYPE_CONSTANT), - datatype(TYPE_VOID) {} + Node(TYPE_CONSTANT) {} }; struct FunctionNode; struct BlockNode : public Node { - FunctionNode *parent_function; - BlockNode *parent_block; + FunctionNode *parent_function = nullptr; + BlockNode *parent_block = nullptr; enum BlockType { BLOCK_TYPE_STANDART, @@ -501,8 +483,8 @@ public: BLOCK_TYPE_DEFAULT, }; - int block_type; - SubClassTag block_tag; + int block_type = BLOCK_TYPE_STANDART; + SubClassTag block_tag = SubClassTag::TAG_GLOBAL; struct Variable { DataType type; @@ -515,52 +497,39 @@ public: Map<StringName, Variable> variables; List<Node *> statements; - bool single_statement; + bool single_statement = false; BlockNode() : - Node(TYPE_BLOCK), - parent_function(nullptr), - parent_block(nullptr), - block_type(BLOCK_TYPE_STANDART), - block_tag(SubClassTag::TAG_GLOBAL), - single_statement(false) {} + Node(TYPE_BLOCK) {} }; struct ControlFlowNode : public Node { - FlowOperation flow_op; + FlowOperation flow_op = FLOW_OP_IF; Vector<Node *> expressions; Vector<BlockNode *> blocks; ControlFlowNode() : - Node(TYPE_CONTROL_FLOW), - flow_op(FLOW_OP_IF) {} + Node(TYPE_CONTROL_FLOW) {} }; struct MemberNode : public Node { - DataType basetype; - bool basetype_const; + DataType basetype = TYPE_VOID; + bool basetype_const = false; StringName base_struct_name; DataPrecision precision; - DataType datatype; - int array_size; + DataType datatype = TYPE_VOID; + int array_size = 0; StringName struct_name; StringName name; - Node *owner; - Node *index_expression; - bool has_swizzling_duplicates; + Node *owner = nullptr; + Node *index_expression = nullptr; + bool has_swizzling_duplicates = false; virtual DataType get_datatype() const { return datatype; } virtual String get_datatype_name() const { return String(struct_name); } MemberNode() : - Node(TYPE_MEMBER), - basetype(TYPE_VOID), - basetype_const(false), - datatype(TYPE_VOID), - array_size(0), - owner(nullptr), - index_expression(nullptr), - has_swizzling_duplicates(false) {} + Node(TYPE_MEMBER) {} }; struct StructNode : public Node { @@ -589,19 +558,15 @@ public: }; StringName name; - DataType return_type; + DataType return_type = TYPE_VOID; StringName return_struct_name; - DataPrecision return_precision; + DataPrecision return_precision = PRECISION_DEFAULT; Vector<Argument> arguments; - BlockNode *body; - bool can_discard; + BlockNode *body = nullptr; + bool can_discard = false; FunctionNode() : - Node(TYPE_FUNCTION), - return_type(TYPE_VOID), - return_precision(PRECISION_DEFAULT), - body(nullptr), - can_discard(false) {} + Node(TYPE_FUNCTION) {} }; struct ShaderNode : public Node { @@ -627,16 +592,12 @@ public: }; struct Varying { - DataType type; - DataInterpolation interpolation; - DataPrecision precision; - int array_size; + DataType type = TYPE_VOID; + DataInterpolation interpolation = INTERPOLATION_FLAT; + DataPrecision precision = PRECISION_DEFAULT; + int array_size = 0; - Varying() : - type(TYPE_VOID), - interpolation(INTERPOLATION_FLAT), - precision(PRECISION_DEFAULT), - array_size(0) {} + Varying() {} }; struct Uniform { @@ -665,27 +626,19 @@ public: SCOPE_GLOBAL, }; - int order; - int texture_order; - DataType type; - DataPrecision precision; + int order = 0; + int texture_order = 0; + DataType type = TYPE_VOID; + DataPrecision precision = PRECISION_DEFAULT; Vector<ConstantNode::Value> default_value; - Scope scope; - Hint hint; - TextureFilter filter; - TextureRepeat repeat; + Scope scope = SCOPE_LOCAL; + Hint hint = HINT_NONE; + TextureFilter filter = FILTER_DEFAULT; + TextureRepeat repeat = REPEAT_DEFAULT; float hint_range[3]; - int instance_index; - - Uniform() : - order(0), - texture_order(0), - type(TYPE_VOID), - precision(PRECISION_DEFAULT), - hint(HINT_NONE), - filter(FILTER_DEFAULT), - repeat(REPEAT_DEFAULT), - instance_index(0) { + int instance_index = 0; + + Uniform() { hint_range[0] = 0.0f; hint_range[1] = 1.0f; hint_range[2] = 0.001f; @@ -765,12 +718,10 @@ public: static void get_builtin_funcs(List<String> *r_keywords); struct BuiltInInfo { - DataType type; - bool constant; + DataType type = TYPE_VOID; + bool constant = false; - BuiltInInfo() : - type(TYPE_VOID), - constant(false) {} + BuiltInInfo() {} BuiltInInfo(DataType p_type, bool p_constant = false) : type(p_type), diff --git a/servers/rendering_server.cpp b/servers/rendering_server.cpp index dace935e89..3dac846357 100644 --- a/servers/rendering_server.cpp +++ b/servers/rendering_server.cpp @@ -1569,35 +1569,64 @@ Array RenderingServer::_mesh_surface_get_skeleton_aabb_bind(RID p_mesh, int p_su ShaderLanguage::DataType RenderingServer::global_variable_type_get_shader_datatype(GlobalVariableType p_type) { switch (p_type) { - case RS::GLOBAL_VAR_TYPE_BOOL: return ShaderLanguage::TYPE_BOOL; - case RS::GLOBAL_VAR_TYPE_BVEC2: return ShaderLanguage::TYPE_BVEC2; - case RS::GLOBAL_VAR_TYPE_BVEC3: return ShaderLanguage::TYPE_BVEC3; - case RS::GLOBAL_VAR_TYPE_BVEC4: return ShaderLanguage::TYPE_BVEC4; - case RS::GLOBAL_VAR_TYPE_INT: return ShaderLanguage::TYPE_INT; - case RS::GLOBAL_VAR_TYPE_IVEC2: return ShaderLanguage::TYPE_IVEC2; - case RS::GLOBAL_VAR_TYPE_IVEC3: return ShaderLanguage::TYPE_IVEC3; - case RS::GLOBAL_VAR_TYPE_IVEC4: return ShaderLanguage::TYPE_IVEC4; - case RS::GLOBAL_VAR_TYPE_RECT2I: return ShaderLanguage::TYPE_IVEC4; - case RS::GLOBAL_VAR_TYPE_UINT: return ShaderLanguage::TYPE_UINT; - case RS::GLOBAL_VAR_TYPE_UVEC2: return ShaderLanguage::TYPE_UVEC2; - case RS::GLOBAL_VAR_TYPE_UVEC3: return ShaderLanguage::TYPE_UVEC3; - case RS::GLOBAL_VAR_TYPE_UVEC4: return ShaderLanguage::TYPE_UVEC4; - case RS::GLOBAL_VAR_TYPE_FLOAT: return ShaderLanguage::TYPE_FLOAT; - case RS::GLOBAL_VAR_TYPE_VEC2: return ShaderLanguage::TYPE_VEC2; - case RS::GLOBAL_VAR_TYPE_VEC3: return ShaderLanguage::TYPE_VEC3; - case RS::GLOBAL_VAR_TYPE_VEC4: return ShaderLanguage::TYPE_VEC4; - case RS::GLOBAL_VAR_TYPE_COLOR: return ShaderLanguage::TYPE_VEC4; - case RS::GLOBAL_VAR_TYPE_RECT2: return ShaderLanguage::TYPE_VEC4; - case RS::GLOBAL_VAR_TYPE_MAT2: return ShaderLanguage::TYPE_MAT2; - case RS::GLOBAL_VAR_TYPE_MAT3: return ShaderLanguage::TYPE_MAT3; - case RS::GLOBAL_VAR_TYPE_MAT4: return ShaderLanguage::TYPE_MAT4; - case RS::GLOBAL_VAR_TYPE_TRANSFORM_2D: return ShaderLanguage::TYPE_MAT3; - case RS::GLOBAL_VAR_TYPE_TRANSFORM: return ShaderLanguage::TYPE_MAT4; - case RS::GLOBAL_VAR_TYPE_SAMPLER2D: return ShaderLanguage::TYPE_SAMPLER2D; - case RS::GLOBAL_VAR_TYPE_SAMPLER2DARRAY: return ShaderLanguage::TYPE_SAMPLER2DARRAY; - case RS::GLOBAL_VAR_TYPE_SAMPLER3D: return ShaderLanguage::TYPE_SAMPLER3D; - case RS::GLOBAL_VAR_TYPE_SAMPLERCUBE: return ShaderLanguage::TYPE_SAMPLERCUBE; - default: return ShaderLanguage::TYPE_MAX; //invalid or not found + case RS::GLOBAL_VAR_TYPE_BOOL: + return ShaderLanguage::TYPE_BOOL; + case RS::GLOBAL_VAR_TYPE_BVEC2: + return ShaderLanguage::TYPE_BVEC2; + case RS::GLOBAL_VAR_TYPE_BVEC3: + return ShaderLanguage::TYPE_BVEC3; + case RS::GLOBAL_VAR_TYPE_BVEC4: + return ShaderLanguage::TYPE_BVEC4; + case RS::GLOBAL_VAR_TYPE_INT: + return ShaderLanguage::TYPE_INT; + case RS::GLOBAL_VAR_TYPE_IVEC2: + return ShaderLanguage::TYPE_IVEC2; + case RS::GLOBAL_VAR_TYPE_IVEC3: + return ShaderLanguage::TYPE_IVEC3; + case RS::GLOBAL_VAR_TYPE_IVEC4: + return ShaderLanguage::TYPE_IVEC4; + case RS::GLOBAL_VAR_TYPE_RECT2I: + return ShaderLanguage::TYPE_IVEC4; + case RS::GLOBAL_VAR_TYPE_UINT: + return ShaderLanguage::TYPE_UINT; + case RS::GLOBAL_VAR_TYPE_UVEC2: + return ShaderLanguage::TYPE_UVEC2; + case RS::GLOBAL_VAR_TYPE_UVEC3: + return ShaderLanguage::TYPE_UVEC3; + case RS::GLOBAL_VAR_TYPE_UVEC4: + return ShaderLanguage::TYPE_UVEC4; + case RS::GLOBAL_VAR_TYPE_FLOAT: + return ShaderLanguage::TYPE_FLOAT; + case RS::GLOBAL_VAR_TYPE_VEC2: + return ShaderLanguage::TYPE_VEC2; + case RS::GLOBAL_VAR_TYPE_VEC3: + return ShaderLanguage::TYPE_VEC3; + case RS::GLOBAL_VAR_TYPE_VEC4: + return ShaderLanguage::TYPE_VEC4; + case RS::GLOBAL_VAR_TYPE_COLOR: + return ShaderLanguage::TYPE_VEC4; + case RS::GLOBAL_VAR_TYPE_RECT2: + return ShaderLanguage::TYPE_VEC4; + case RS::GLOBAL_VAR_TYPE_MAT2: + return ShaderLanguage::TYPE_MAT2; + case RS::GLOBAL_VAR_TYPE_MAT3: + return ShaderLanguage::TYPE_MAT3; + case RS::GLOBAL_VAR_TYPE_MAT4: + return ShaderLanguage::TYPE_MAT4; + case RS::GLOBAL_VAR_TYPE_TRANSFORM_2D: + return ShaderLanguage::TYPE_MAT3; + case RS::GLOBAL_VAR_TYPE_TRANSFORM: + return ShaderLanguage::TYPE_MAT4; + case RS::GLOBAL_VAR_TYPE_SAMPLER2D: + return ShaderLanguage::TYPE_SAMPLER2D; + case RS::GLOBAL_VAR_TYPE_SAMPLER2DARRAY: + return ShaderLanguage::TYPE_SAMPLER2DARRAY; + case RS::GLOBAL_VAR_TYPE_SAMPLER3D: + return ShaderLanguage::TYPE_SAMPLER3D; + case RS::GLOBAL_VAR_TYPE_SAMPLERCUBE: + return ShaderLanguage::TYPE_SAMPLERCUBE; + default: + return ShaderLanguage::TYPE_MAX; //invalid or not found } } @@ -1755,8 +1784,8 @@ void RenderingServer::_bind_methods() { ClassDB::bind_method(D_METHOD("gi_probe_set_compress", "probe", "enable"), &RenderingServer::gi_probe_set_compress); ClassDB::bind_method(D_METHOD("gi_probe_is_compressed", "probe"), &RenderingServer::gi_probe_is_compressed); #endif - - ClassDB::bind_method(D_METHOD("lightmap_capture_create"), &RenderingServer::lightmap_capture_create); +/* + ClassDB::bind_method(D_METHOD("lightmap_create()"), &RenderingServer::lightmap_capture_create); ClassDB::bind_method(D_METHOD("lightmap_capture_set_bounds", "capture", "bounds"), &RenderingServer::lightmap_capture_set_bounds); ClassDB::bind_method(D_METHOD("lightmap_capture_get_bounds", "capture"), &RenderingServer::lightmap_capture_get_bounds); ClassDB::bind_method(D_METHOD("lightmap_capture_set_octree", "capture", "octree"), &RenderingServer::lightmap_capture_set_octree); @@ -1767,6 +1796,7 @@ void RenderingServer::_bind_methods() { ClassDB::bind_method(D_METHOD("lightmap_capture_get_octree", "capture"), &RenderingServer::lightmap_capture_get_octree); ClassDB::bind_method(D_METHOD("lightmap_capture_set_energy", "capture", "energy"), &RenderingServer::lightmap_capture_set_energy); ClassDB::bind_method(D_METHOD("lightmap_capture_get_energy", "capture"), &RenderingServer::lightmap_capture_get_energy); +*/ #endif ClassDB::bind_method(D_METHOD("particles_create"), &RenderingServer::particles_create); ClassDB::bind_method(D_METHOD("particles_set_emitting", "particles", "emitting"), &RenderingServer::particles_set_emitting); @@ -1866,7 +1896,7 @@ void RenderingServer::_bind_methods() { ClassDB::bind_method(D_METHOD("instance_set_blend_shape_weight", "instance", "shape", "weight"), &RenderingServer::instance_set_blend_shape_weight); ClassDB::bind_method(D_METHOD("instance_set_surface_material", "instance", "surface", "material"), &RenderingServer::instance_set_surface_material); ClassDB::bind_method(D_METHOD("instance_set_visible", "instance", "visible"), &RenderingServer::instance_set_visible); - ClassDB::bind_method(D_METHOD("instance_set_use_lightmap", "instance", "lightmap_instance", "lightmap"), &RenderingServer::instance_set_use_lightmap); + // ClassDB::bind_method(D_METHOD("instance_set_use_lightmap", "instance", "lightmap_instance", "lightmap"), &RenderingServer::instance_set_use_lightmap); ClassDB::bind_method(D_METHOD("instance_set_custom_aabb", "instance", "aabb"), &RenderingServer::instance_set_custom_aabb); ClassDB::bind_method(D_METHOD("instance_attach_skeleton", "instance", "skeleton"), &RenderingServer::instance_attach_skeleton); ClassDB::bind_method(D_METHOD("instance_set_exterior", "instance", "enabled"), &RenderingServer::instance_set_exterior); @@ -2237,7 +2267,7 @@ void RenderingServer::_bind_methods() { BIND_ENUM_CONSTANT(INSTANCE_REFLECTION_PROBE); BIND_ENUM_CONSTANT(INSTANCE_DECAL); BIND_ENUM_CONSTANT(INSTANCE_GI_PROBE); - BIND_ENUM_CONSTANT(INSTANCE_LIGHTMAP_CAPTURE); + BIND_ENUM_CONSTANT(INSTANCE_LIGHTMAP); BIND_ENUM_CONSTANT(INSTANCE_MAX); BIND_ENUM_CONSTANT(INSTANCE_GEOMETRY_MASK); @@ -2484,6 +2514,9 @@ RenderingServer::RenderingServer() { ProjectSettings::get_singleton()->set_custom_property_info("rendering/quality/subsurface_scattering/subsurface_scattering_depth_scale", PropertyInfo(Variant::FLOAT, "rendering/quality/subsurface_scattering/subsurface_scattering_depth_scale", PROPERTY_HINT_RANGE, "0.001,1,0.001")); GLOBAL_DEF("rendering/high_end/global_shader_variables_buffer_size", 65536); + + GLOBAL_DEF("rendering/lightmapper/probe_capture_update_speed", 15); + ProjectSettings::get_singleton()->set_custom_property_info("rendering/lightmapper/probe_capture_update_speed", PropertyInfo(Variant::FLOAT, "rendering/lightmapper/probe_capture_update_speed", PROPERTY_HINT_RANGE, "0.001,256,0.001")); } RenderingServer::~RenderingServer() { diff --git a/servers/rendering_server.h b/servers/rendering_server.h index 8ca070b4a9..d426f205d0 100644 --- a/servers/rendering_server.h +++ b/servers/rendering_server.h @@ -36,6 +36,7 @@ #include "core/math/transform_2d.h" #include "core/object.h" #include "core/rid.h" +#include "core/typed_array.h" #include "core/variant.h" #include "servers/display_server.h" #include "servers/rendering/shader_language.h" @@ -106,7 +107,7 @@ public: //these two APIs can be used together or in combination with the others. virtual RID texture_2d_placeholder_create() = 0; - virtual RID texture_2d_layered_placeholder_create() = 0; + virtual RID texture_2d_layered_placeholder_create(TextureLayeredType p_layered_type) = 0; virtual RID texture_3d_placeholder_create() = 0; virtual Ref<Image> texture_2d_get(RID p_texture) const = 0; @@ -522,19 +523,20 @@ public: virtual void gi_probe_set_anisotropy_strength(RID p_gi_probe, float p_strength) = 0; virtual float gi_probe_get_anisotropy_strength(RID p_gi_probe) const = 0; - /* LIGHTMAP CAPTURE */ + /* LIGHTMAP */ - virtual RID lightmap_capture_create() = 0; - virtual void lightmap_capture_set_bounds(RID p_capture, const AABB &p_bounds) = 0; - virtual AABB lightmap_capture_get_bounds(RID p_capture) const = 0; - virtual void lightmap_capture_set_octree(RID p_capture, const Vector<uint8_t> &p_octree) = 0; - virtual void lightmap_capture_set_octree_cell_transform(RID p_capture, const Transform &p_xform) = 0; - virtual Transform lightmap_capture_get_octree_cell_transform(RID p_capture) const = 0; - virtual void lightmap_capture_set_octree_cell_subdiv(RID p_capture, int p_subdiv) = 0; - virtual int lightmap_capture_get_octree_cell_subdiv(RID p_capture) const = 0; - virtual Vector<uint8_t> lightmap_capture_get_octree(RID p_capture) const = 0; - virtual void lightmap_capture_set_energy(RID p_capture, float p_energy) = 0; - virtual float lightmap_capture_get_energy(RID p_capture) const = 0; + virtual RID lightmap_create() = 0; + + virtual void lightmap_set_textures(RID p_lightmap, RID p_light, bool p_uses_spherical_haromics) = 0; + virtual void lightmap_set_probe_bounds(RID p_lightmap, const AABB &p_bounds) = 0; + virtual void lightmap_set_probe_interior(RID p_lightmap, bool p_interior) = 0; + virtual void lightmap_set_probe_capture_data(RID p_lightmap, const PackedVector3Array &p_points, const PackedColorArray &p_point_sh, const PackedInt32Array &p_tetrahedra, const PackedInt32Array &p_bsp_tree) = 0; + virtual PackedVector3Array lightmap_get_probe_capture_points(RID p_lightmap) const = 0; + virtual PackedColorArray lightmap_get_probe_capture_sh(RID p_lightmap) const = 0; + virtual PackedInt32Array lightmap_get_probe_capture_tetrahedra(RID p_lightmap) const = 0; + virtual PackedInt32Array lightmap_get_probe_capture_bsp_tree(RID p_lightmap) const = 0; + + virtual void lightmap_set_probe_capture_update_speed(float p_speed) = 0; /* PARTICLES API */ @@ -713,6 +715,7 @@ public: virtual void sky_set_radiance_size(RID p_sky, int p_radiance_size) = 0; virtual void sky_set_mode(RID p_sky, SkyMode p_mode) = 0; virtual void sky_set_material(RID p_sky, RID p_material) = 0; + virtual Ref<Image> sky_bake_panorama(RID p_sky, float p_energy, bool p_bake_irradiance, const Size2i &p_size) = 0; /* ENVIRONMENT API */ @@ -809,6 +812,8 @@ public: virtual void environment_set_fog_depth(RID p_env, bool p_enable, float p_depth_begin, float p_depth_end, float p_depth_curve, bool p_transmit, float p_transmit_curve) = 0; virtual void environment_set_fog_height(RID p_env, bool p_enable, float p_min_height, float p_max_height, float p_height_curve) = 0; + virtual Ref<Image> environment_bake_panorama(RID p_env, bool p_bake_irradiance, const Size2i &p_size) = 0; + virtual void screen_space_roughness_limiter_set_active(bool p_enable, float p_curve) = 0; enum SubSurfaceScatteringQuality { @@ -885,7 +890,7 @@ public: INSTANCE_REFLECTION_PROBE, INSTANCE_DECAL, INSTANCE_GI_PROBE, - INSTANCE_LIGHTMAP_CAPTURE, + INSTANCE_LIGHTMAP, INSTANCE_MAX, INSTANCE_GEOMETRY_MASK = (1 << INSTANCE_MESH) | (1 << INSTANCE_MULTIMESH) | (1 << INSTANCE_IMMEDIATE) | (1 << INSTANCE_PARTICLES) @@ -904,8 +909,6 @@ public: virtual void instance_set_surface_material(RID p_instance, int p_surface, RID p_material) = 0; virtual void instance_set_visible(RID p_instance, bool p_visible) = 0; - virtual void instance_set_use_lightmap(RID p_instance, RID p_lightmap_instance, RID p_lightmap) = 0; - virtual void instance_set_custom_aabb(RID p_instance, AABB aabb) = 0; virtual void instance_attach_skeleton(RID p_instance, RID p_skeleton) = 0; @@ -942,12 +945,24 @@ public: virtual void instance_geometry_set_draw_range(RID p_instance, float p_min, float p_max, float p_min_margin, float p_max_margin) = 0; virtual void instance_geometry_set_as_instance_lod(RID p_instance, RID p_as_lod_of_instance) = 0; + virtual void instance_geometry_set_lightmap(RID p_instance, RID p_lightmap, const Rect2 &p_lightmap_uv_scale, int p_lightmap_slice) = 0; virtual void instance_geometry_set_shader_parameter(RID p_instance, const StringName &, const Variant &p_value) = 0; virtual Variant instance_geometry_get_shader_parameter(RID p_instance, const StringName &) const = 0; virtual Variant instance_geometry_get_shader_parameter_default_value(RID p_instance, const StringName &) const = 0; virtual void instance_geometry_get_shader_parameter_list(RID p_instance, List<PropertyInfo> *p_parameters) const = 0; + /* Bake 3D objects */ + + enum BakeChannels { + BAKE_CHANNEL_ALBEDO_ALPHA, + BAKE_CHANNEL_NORMAL, + BAKE_CHANNEL_ORM, + BAKE_CHANNEL_EMISSION + }; + + virtual TypedArray<Image> bake_render_uv2(RID p_base, const Vector<RID> &p_material_overrides, const Size2i &p_image_size) = 0; + /* CANVAS (2D) */ virtual RID canvas_create() = 0; @@ -1186,8 +1201,6 @@ public: virtual Vector<FrameProfileArea> get_frame_profile() = 0; virtual uint64_t get_frame_profile_frame() = 0; - /* Materials for 2D on 3D */ - /* TESTING */ virtual RID get_test_cube() = 0; diff --git a/servers/xr/xr_interface.cpp b/servers/xr/xr_interface.cpp index c1233ae810..5ffe50894a 100644 --- a/servers/xr/xr_interface.cpp +++ b/servers/xr/xr_interface.cpp @@ -122,7 +122,7 @@ XRInterface::XRInterface() { tracking_state = XR_UNKNOWN_TRACKING; }; -XRInterface::~XRInterface(){}; +XRInterface::~XRInterface() {} // optional render to external texture which enhances performance on those platforms that require us to submit our end result into special textures. unsigned int XRInterface::get_external_texture_for_eye(XRInterface::Eyes p_eye) { diff --git a/thirdparty/README.md b/thirdparty/README.md index e51e7d7f24..c000133fe7 100644 --- a/thirdparty/README.md +++ b/thirdparty/README.md @@ -384,11 +384,19 @@ Collection of single-file libraries used in Godot components. * Upstream: http://www.pcg-random.org * Version: minimal C implementation, http://www.pcg-random.org/download.html * License: Apache 2.0 +- `r128.h` + * Upstream: https://github.com/fahickman/r128 + * Version: 1.4.3 (2019) + * License: Public Domain - `smaz.{c,h}` * Upstream: https://github.com/antirez/smaz * Version: git (150e125cbae2e8fd20dd332432776ce13395d4d4, 2009) * License: BSD-3-Clause * Modifications: use `const char*` instead of `char*` for input string +- `stb_rect_pack.h` + * Upstream: https://github.com/nothings/stb + * Version: 1.00 (2019) + * License: Public Domain (Unlicense) or MIT - `triangulator.{cpp,h}` * Upstream: https://github.com/ivanfratric/polypartition (`src/polypartition.cpp`) * Version: TBD, class was renamed @@ -437,6 +445,17 @@ Files extracted from the upstream source: - LICENSE.txt +## oidn + +- Upstream: https://github.com/OpenImageDenoise/oidn +- Version: TBD +- License: Apache 2.0 + +Files extracted from upstream source: + +- TBD + + ## opus - Upstream: https://opus-codec.org diff --git a/thirdparty/misc/r128.c b/thirdparty/misc/r128.c new file mode 100644 index 0000000000..6b981aa693 --- /dev/null +++ b/thirdparty/misc/r128.c @@ -0,0 +1,2 @@ +#define R128_IMPLEMENTATION +#include "r128.h" diff --git a/thirdparty/misc/r128.h b/thirdparty/misc/r128.h new file mode 100644 index 0000000000..be7cd3024d --- /dev/null +++ b/thirdparty/misc/r128.h @@ -0,0 +1,2123 @@ +/* +r128.h: 128-bit (64.64) signed fixed-point arithmetic. Version 1.4.3 + +COMPILATION +----------- +Drop this header file somewhere in your project and include it wherever it is +needed. There is no separate .c file for this library. To get the code, in ONE +file in your project, put: + +#define R128_IMPLEMENTATION + +before you include this file. You may also provide a definition for R128_ASSERT +to force the library to use a custom assert macro. + +COMPILER/LIBRARY SUPPORT +------------------------ +This library requires a C89 compiler with support for 64-bit integers. If your +compiler does not support the long long data type, the R128_U64, etc. macros +must be set appropriately. On x86 and x64 targets, Intel intrinsics are used +for speed. If your compiler does not support these intrinsics, you can add +#define R128_STDC_ONLY +in your implementation file before including r128.h. + +The only C runtime library functionality used by this library is <assert.h>. +This can be avoided by defining an R128_ASSERT macro in your implementation +file. Since this library uses 64-bit arithmetic, this may implicitly add a +runtime library dependency on 32-bit platforms. + +C++ SUPPORT +----------- +Operator overloads are supplied for C++ files that include this file. Since all +C++ functions are declared inline (or static inline), the R128_IMPLEMENTATION +file can be either C++ or C. + +LICENSE +------- +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or +distribute this software, either in source code form or as a compiled +binary, for any purpose, commercial or non-commercial, and by any +means. + +In jurisdictions that recognize copyright laws, the author or authors +of this software dedicate any and all copyright interest in the +software to the public domain. We make this dedication for the benefit +of the public at large and to the detriment of our heirs and +successors. We intend this dedication to be an overt act of +relinquishment in perpetuity of all present and future rights to this +software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. +*/ + +#ifndef H_R128_H +#define H_R128_H + +#include <stddef.h> + +// 64-bit integer support +// If your compiler does not have stdint.h, add appropriate defines for these macros. +#if defined(_MSC_VER) && (_MSC_VER < 1600) +# define R128_S32 __int32 +# define R128_U32 unsigned __int32 +# define R128_S64 __int64 +# define R128_U64 unsigned __int64 +# define R128_LIT_S64(x) x##i64 +# define R128_LIT_U64(x) x##ui64 +#else +# include <stdint.h> +# define R128_S32 int32_t +# define R128_U32 uint32_t +# define R128_S64 int64_t +# define R128_U64 uint64_t +# define R128_LIT_S64(x) x##ll +# define R128_LIT_U64(x) x##ull +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct R128 { + R128_U64 lo; + R128_U64 hi; + +#ifdef __cplusplus + R128(); + R128(double); + R128(int); + R128(R128_S64); + R128(R128_U64 low, R128_U64 high); + + operator double() const; + operator R128_S64() const; + operator int() const; + operator bool() const; + + bool operator!() const; + R128 operator~() const; + R128 operator-() const; + R128 &operator|=(const R128 &rhs); + R128 &operator&=(const R128 &rhs); + R128 &operator^=(const R128 &rhs); + R128 &operator+=(const R128 &rhs); + R128 &operator-=(const R128 &rhs); + R128 &operator*=(const R128 &rhs); + R128 &operator/=(const R128 &rhs); + R128 &operator%=(const R128 &rhs); + R128 &operator<<=(int amount); + R128 &operator>>=(int amount); +#endif //__cplusplus +} R128; + +// Type conversion +extern void r128FromInt(R128 *dst, R128_S64 v); +extern void r128FromFloat(R128 *dst, double v); +extern R128_S64 r128ToInt(const R128 *v); +extern double r128ToFloat(const R128 *v); + +// Copy +extern void r128Copy(R128 *dst, const R128 *src); + +// Negate +extern void r128Neg(R128 *dst, const R128 *src); + +// Bitwise operations +extern void r128Not(R128 *dst, const R128 *src); // ~a +extern void r128Or(R128 *dst, const R128 *a, const R128 *b); // a | b +extern void r128And(R128 *dst, const R128 *a, const R128 *b); // a & b +extern void r128Xor(R128 *dst, const R128 *a, const R128 *b); // a ^ b +extern void r128Shl(R128 *dst, const R128 *src, int amount); // shift left by amount mod 128 +extern void r128Shr(R128 *dst, const R128 *src, int amount); // shift right logical by amount mod 128 +extern void r128Sar(R128 *dst, const R128 *src, int amount); // shift right arithmetic by amount mod 128 + +// Arithmetic +extern void r128Add(R128 *dst, const R128 *a, const R128 *b); // a + b +extern void r128Sub(R128 *dst, const R128 *a, const R128 *b); // a - b +extern void r128Mul(R128 *dst, const R128 *a, const R128 *b); // a * b +extern void r128Div(R128 *dst, const R128 *a, const R128 *b); // a / b +extern void r128Mod(R128 *dst, const R128 *a, const R128 *b); // a - toInt(a / b) * b + +extern void r128Sqrt(R128 *dst, const R128 *v); // sqrt(v) +extern void r128Rsqrt(R128 *dst, const R128 *v); // 1 / sqrt(v) + +// Comparison +extern int r128Cmp(const R128 *a, const R128 *b); // sign of a-b +extern void r128Min(R128 *dst, const R128 *a, const R128 *b); +extern void r128Max(R128 *dst, const R128 *a, const R128 *b); +extern void r128Floor(R128 *dst, const R128 *v); +extern void r128Ceil(R128 *dst, const R128 *v); +extern int r128IsNeg(const R128 *v); // quick check for < 0 + +// String conversion +// +typedef enum R128ToStringSign { + R128ToStringSign_Default, // no sign character for positive values + R128ToStringSign_Space, // leading space for positive values + R128ToStringSign_Plus, // leading '+' for positive values +} R128ToStringSign; + +// Formatting options for use with r128ToStringOpt. The "defaults" correspond +// to a format string of "%f". +// +typedef struct R128ToStringFormat { + // sign character for positive values. Default is R128ToStringSign_Default. + R128ToStringSign sign; + + // minimum number of characters to write. Default is 0. + int width; + + // place to the right of the decimal at which rounding is performed. If negative, + // a maximum of 20 decimal places will be written, with no trailing zeroes. + // (20 places is sufficient to ensure that r128FromString will convert back to the + // original value.) Default is -1. NOTE: This is not the same default that the C + // standard library uses for %f. + int precision; + + // If non-zero, pads the output string with leading zeroes if the final result is + // fewer than width characters. Otherwise, leading spaces are used. Default is 0. + int zeroPad; + + // Always print a decimal point, even if the value is an integer. Default is 0. + int decimal; + + // Left-align output if width specifier requires padding. + // Default is 0 (right align). + int leftAlign; +} R128ToStringFormat; + +// r128ToStringOpt: convert R128 to a decimal string, with formatting. +// +// dst and dstSize: specify the buffer to write into. At most dstSize bytes will be written +// (including null terminator). No additional rounding is performed if dstSize is not large +// enough to hold the entire string. +// +// opt: an R128ToStringFormat struct (q.v.) with formatting options. +// +// Uses the R128_decimal global as the decimal point character. +// Always writes a null terminator, even if the destination buffer is not large enough. +// +// Number of bytes that will be written (i.e. how big does dst need to be?): +// If width is specified: width + 1 bytes. +// If precision is specified: at most precision + 22 bytes. +// If neither is specified: at most 42 bytes. +// +// Returns the number of bytes that would have been written if dst was sufficiently large, +// not including the final null terminator. +// +extern int r128ToStringOpt(char *dst, size_t dstSize, const R128 *v, const R128ToStringFormat *opt); + +// r128ToStringf: convert R128 to a decimal string, with formatting. +// +// dst and dstSize: specify the buffer to write into. At most dstSize bytes will be written +// (including null terminator). +// +// format: a printf-style format specifier, as one would use with floating point types. +// e.g. "%+5.2f". (The leading % and trailing f are optional.) +// NOTE: This is NOT a full replacement for sprintf. Any characters in the format string +// that do not correspond to a format placeholder are ignored. +// +// Uses the R128_decimal global as the decimal point character. +// Always writes a null terminator, even if the destination buffer is not large enough. +// +// Number of bytes that will be written (i.e. how big does dst need to be?): +// If the precision field is specified: at most max(width, precision + 21) + 1 bytes +// Otherwise: at most max(width, 41) + 1 bytes. +// +// Returns the number of bytes that would have been written if dst was sufficiently large, +// not including the final null terminator. +// +extern int r128ToStringf(char *dst, size_t dstSize, const char *format, const R128 *v); + +// r128ToString: convert R128 to a decimal string, with default formatting. +// Equivalent to r128ToStringf(dst, dstSize, "%f", v). +// +// Uses the R128_decimal global as the decimal point character. +// Always writes a null terminator, even if the destination buffer is not large enough. +// +// Will write at most 42 bytes (including NUL) to dst. +// +// Returns the number of bytes that would have been written if dst was sufficiently large, +// not including the final null terminator. +// +extern int r128ToString(char *dst, size_t dstSize, const R128 *v); + +// r128FromString: Convert string to R128. +// +// The string can be formatted either as a decimal number with optional sign +// or as hexadecimal with a prefix of 0x or 0X. +// +// endptr, if not NULL, is set to the character following the last character +// used in the conversion. +// +extern void r128FromString(R128 *dst, const char *s, char **endptr); + +// Constants +extern const R128 R128_min; // minimum (most negative) value +extern const R128 R128_max; // maximum (most positive) value +extern const R128 R128_smallest; // smallest positive value +extern const R128 R128_zero; // zero +extern const R128 R128_one; // 1.0 + +extern char R128_decimal; // decimal point character used by r128From/ToString. defaults to '.' + +#ifdef __cplusplus +} + +#include <limits> +namespace std { +template<> +struct numeric_limits<R128> +{ + static const bool is_specialized = true; + + static R128 min() throw() { return R128_min; } + static R128 max() throw() { return R128_max; } + + static const int digits = 127; + static const int digits10 = 38; + static const bool is_signed = true; + static const bool is_integer = false; + static const bool is_exact = false; + static const int radix = 2; + static R128 epsilon() throw() { return R128_smallest; } + static R128 round_error() throw() { return R128_one; } + + static const int min_exponent = 0; + static const int min_exponent10 = 0; + static const int max_exponent = 0; + static const int max_exponent10 = 0; + + static const bool has_infinity = false; + static const bool has_quiet_NaN = false; + static const bool has_signaling_NaN = false; + static const float_denorm_style has_denorm = denorm_absent; + static const bool has_denorm_loss = false; + + static R128 infinity() throw() { return R128_zero; } + static R128 quiet_NaN() throw() { return R128_zero; } + static R128 signaling_NaN() throw() { return R128_zero; } + static R128 denorm_min() throw() { return R128_zero; } + + static const bool is_iec559 = false; + static const bool is_bounded = true; + static const bool is_modulo = true; + + static const bool traps = numeric_limits<R128_U64>::traps; + static const bool tinyness_before = false; + static const float_round_style round_style = round_toward_zero; +}; +} //namespace std + +inline R128::R128() {} + +inline R128::R128(double v) +{ + r128FromFloat(this, v); +} + +inline R128::R128(int v) +{ + r128FromInt(this, v); +} + +inline R128::R128(R128_S64 v) +{ + r128FromInt(this, v); +} + +inline R128::R128(R128_U64 low, R128_U64 high) +{ + lo = low; + hi = high; +} + +inline R128::operator double() const +{ + return r128ToFloat(this); +} + +inline R128::operator R128_S64() const +{ + return r128ToInt(this); +} + +inline R128::operator int() const +{ + return (int) r128ToInt(this); +} + +inline R128::operator bool() const +{ + return lo || hi; +} + +inline bool R128::operator!() const +{ + return !lo && !hi; +} + +inline R128 R128::operator~() const +{ + R128 r; + r128Not(&r, this); + return r; +} + +inline R128 R128::operator-() const +{ + R128 r; + r128Neg(&r, this); + return r; +} + +inline R128 &R128::operator|=(const R128 &rhs) +{ + r128Or(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator&=(const R128 &rhs) +{ + r128And(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator^=(const R128 &rhs) +{ + r128Xor(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator+=(const R128 &rhs) +{ + r128Add(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator-=(const R128 &rhs) +{ + r128Sub(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator*=(const R128 &rhs) +{ + r128Mul(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator/=(const R128 &rhs) +{ + r128Div(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator%=(const R128 &rhs) +{ + r128Mod(this, this, &rhs); + return *this; +} + +inline R128 &R128::operator<<=(int amount) +{ + r128Shl(this, this, amount); + return *this; +} + +inline R128 &R128::operator>>=(int amount) +{ + r128Sar(this, this, amount); + return *this; +} + +static inline R128 operator|(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r |= rhs; +} + +static inline R128 operator&(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r &= rhs; +} + +static inline R128 operator^(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r ^= rhs; +} + +static inline R128 operator+(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r += rhs; +} + +static inline R128 operator-(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r -= rhs; +} + +static inline R128 operator*(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r *= rhs; +} + +static inline R128 operator/(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r /= rhs; +} + +static inline R128 operator%(const R128 &lhs, const R128 &rhs) +{ + R128 r(lhs); + return r %= rhs; +} + +static inline R128 operator<<(const R128 &lhs, int amount) +{ + R128 r(lhs); + return r <<= amount; +} + +static inline R128 operator>>(const R128 &lhs, int amount) +{ + R128 r(lhs); + return r >>= amount; +} + +static inline bool operator<(const R128 &lhs, const R128 &rhs) +{ + return r128Cmp(&lhs, &rhs) < 0; +} + +static inline bool operator>(const R128 &lhs, const R128 &rhs) +{ + return r128Cmp(&lhs, &rhs) > 0; +} + +static inline bool operator<=(const R128 &lhs, const R128 &rhs) +{ + return r128Cmp(&lhs, &rhs) <= 0; +} + +static inline bool operator>=(const R128 &lhs, const R128 &rhs) +{ + return r128Cmp(&lhs, &rhs) >= 0; +} + +static inline bool operator==(const R128 &lhs, const R128 &rhs) +{ + return lhs.lo == rhs.lo && lhs.hi == rhs.hi; +} + +static inline bool operator!=(const R128 &lhs, const R128 &rhs) +{ + return lhs.lo != rhs.lo || lhs.hi != rhs.hi; +} + +#endif //__cplusplus +#endif //H_R128_H + +#ifdef R128_IMPLEMENTATION + +#ifdef R128_DEBUG_VIS +# define R128_DEBUG_SET(x) r128ToString(R128_last, sizeof(R128_last), x) +#else +# define R128_DEBUG_SET(x) +#endif + +#define R128_SET2(x, l, h) do { (x)->lo = (R128_U64)(l); (x)->hi = (R128_U64)(h); } while(0) +#define R128_R0(x) ((R128_U32)(x)->lo) +#define R128_R2(x) ((R128_U32)(x)->hi) +#if defined(_M_IX86) +// workaround: MSVC x86's handling of 64-bit values is not great +# define R128_SET4(x, r0, r1, r2, r3) do { \ + ((R128_U32*)&(x)->lo)[0] = (R128_U32)(r0); \ + ((R128_U32*)&(x)->lo)[1] = (R128_U32)(r1); \ + ((R128_U32*)&(x)->hi)[0] = (R128_U32)(r2); \ + ((R128_U32*)&(x)->hi)[1] = (R128_U32)(r3); \ + } while(0) +# define R128_R1(x) (((R128_U32*)&(x)->lo)[1]) +# define R128_R3(x) (((R128_U32*)&(x)->hi)[1]) +#else +# define R128_SET4(x, r0, r1, r2, r3) do { (x)->lo = (R128_U64)(r0) | ((R128_U64)(r1) << 32); \ + (x)->hi = (R128_U64)(r2) | ((R128_U64)(r3) << 32); } while(0) +# define R128_R1(x) ((R128_U32)((x)->lo >> 32)) +# define R128_R3(x) ((R128_U32)((x)->hi >> 32)) +#endif + +#if defined(_M_X64) +# define R128_INTEL 1 +# define R128_64BIT 1 +# ifndef R128_STDC_ONLY +# include <intrin.h> +# endif +#elif defined(__x86_64__) +# define R128_INTEL 1 +# define R128_64BIT 1 +# ifndef R128_STDC_ONLY +# include <x86intrin.h> +# endif +#elif defined(_M_IX86) +# define R128_INTEL 1 +# ifndef R128_STDC_ONLY +# include <intrin.h> +# endif +#elif defined(__i386__) +# define R128_INTEL 1 +# ifndef R128_STDC_ONLY +# include <x86intrin.h> +# endif +#elif defined(_M_ARM) +# ifndef R128_STDC_ONLY +# include <intrin.h> +# endif +#elif defined(_M_ARM64) +# define R128_64BIT 1 +# ifndef R128_STDC_ONLY +# include <intrin.h> +# endif +#elif defined(__aarch64__) +# define R128_64BIT 1 +#endif + +#ifndef R128_INTEL +# define R128_INTEL 0 +#endif + +#ifndef R128_64BIT +# define R128_64BIT 0 +#endif + +#ifndef R128_ASSERT +# include <assert.h> +# define R128_ASSERT(x) assert(x) +#endif + +#include <stdlib.h> // for NULL + +static const R128ToStringFormat R128__defaultFormat = { + R128ToStringSign_Default, + 0, + -1, + 0, + 0, + 0 +}; + +const R128 R128_min = { 0, R128_LIT_U64(0x8000000000000000) }; +const R128 R128_max = { R128_LIT_U64(0xffffffffffffffff), R128_LIT_U64(0x7fffffffffffffff) }; +const R128 R128_smallest = { 1, 0 }; +const R128 R128_zero = { 0, 0 }; +const R128 R128_one = { 0, 1 }; +char R128_decimal = '.'; +#ifdef R128_DEBUG_VIS +char R128_last[42]; +#endif + +static int r128__clz64(R128_U64 x) +{ +#if defined(R128_STDC_ONLY) + R128_U64 n = 64, y; + y = x >> 32; if (y) { n -= 32; x = y; } + y = x >> 16; if (y) { n -= 16; x = y; } + y = x >> 8; if (y) { n -= 8; x = y; } + y = x >> 4; if (y) { n -= 4; x = y; } + y = x >> 2; if (y) { n -= 2; x = y; } + y = x >> 1; if (y) { n -= 1; x = y; } + return (int)(n - x); +#elif defined(_M_X64) || defined(_M_ARM64) + unsigned long idx; + if (_BitScanReverse64(&idx, x)) { + return 63 - (int)idx; + } else { + return 64; + } +#elif defined(_MSC_VER) + unsigned long idx; + if (_BitScanReverse(&idx, (R128_U32)(x >> 32))) { + return 31 - (int)idx; + } else if (_BitScanReverse(&idx, (R128_U32)x)) { + return 63 - (int)idx; + } else { + return 64; + } +#else + return x ? __builtin_clzll(x) : 64; +#endif +} + +#if !R128_64BIT +// 32*32->64 +static R128_U64 r128__umul64(R128_U32 a, R128_U32 b) +{ +# if defined(_M_IX86) && !defined(R128_STDC_ONLY) + return __emulu(a, b); +# elif defined(_M_ARM) && !defined(R128_STDC_ONLY) + return _arm_umull(a, b); +# else + return a * (R128_U64)b; +# endif +} + +// 64/32->32 +static R128_U32 r128__udiv64(R128_U32 nlo, R128_U32 nhi, R128_U32 d, R128_U32 *rem) +{ +# if defined(_M_IX86) && (_MSC_VER >= 1920) && !defined(R128_STDC_ONLY) + unsigned __int64 n = ((unsigned __int64)nhi << 32) | nlo; + return _udiv64(n, d, rem); +# elif defined(_M_IX86) && !defined(R128_STDC_ONLY) + __asm { + mov eax, nlo + mov edx, nhi + div d + mov ecx, rem + mov dword ptr [ecx], edx + } +# elif defined(__i386__) && !defined(R128_STDC_ONLY) + R128_U32 q, r; + __asm("divl %4" + : "=a"(q), "=d"(r) + : "a"(nlo), "d"(nhi), "X"(d)); + *rem = r; + return q; +# else + R128_U64 n64 = ((R128_U64)nhi << 32) | nlo; + *rem = (R128_U32)(n64 % d); + return (R128_U32)(n64 / d); +# endif +} +#elif !defined(_M_X64) || defined(R128_STDC_ONLY) +#define r128__umul64(a, b) ((a) * (R128_U64)(b)) +static R128_U32 r128__udiv64(R128_U32 nlo, R128_U32 nhi, R128_U32 d, R128_U32 *rem) +{ + R128_U64 n64 = ((R128_U64)nhi << 32) | nlo; + *rem = (R128_U32)(n64 % d); + return (R128_U32)(n64 / d); +} +#endif //!R128_64BIT + +static void r128__neg(R128 *dst, const R128 *src) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + +#if R128_INTEL && !defined(R128_STDC_ONLY) + { + unsigned char carry = 0; +# if R128_64BIT + carry = _addcarry_u64(carry, ~src->lo, 1, &dst->lo); + carry = _addcarry_u64(carry, ~src->hi, 0, &dst->hi); +# else + R128_U32 r0, r1, r2, r3; + carry = _addcarry_u32(carry, ~R128_R0(src), 1, &r0); + carry = _addcarry_u32(carry, ~R128_R1(src), 0, &r1); + carry = _addcarry_u32(carry, ~R128_R2(src), 0, &r2); + carry = _addcarry_u32(carry, ~R128_R3(src), 0, &r3); + R128_SET4(dst, r0, r1, r2, r3); +# endif //R128_64BIT + } +#else + if (src->lo) { + dst->lo = ~src->lo + 1; + dst->hi = ~src->hi; + } else { + dst->lo = 0; + dst->hi = ~src->hi + 1; + } +#endif //R128_INTEL +} + +// 64*64->128 +static void r128__umul128(R128 *dst, R128_U64 a, R128_U64 b) +{ +#if defined(_M_X64) && !defined(R128_STDC_ONLY) + dst->lo = _umul128(a, b, &dst->hi); +#elif R128_64BIT && !defined(_MSC_VER) && !defined(R128_STDC_ONLY) + unsigned __int128 p0 = a * (unsigned __int128)b; + dst->hi = (R128_U64)(p0 >> 64); + dst->lo = (R128_U64)p0; +#else + R128_U32 alo = (R128_U32)a; + R128_U32 ahi = (R128_U32)(a >> 32); + R128_U32 blo = (R128_U32)b; + R128_U32 bhi = (R128_U32)(b >> 32); + R128_U64 p0, p1, p2, p3; + + p0 = r128__umul64(alo, blo); + p1 = r128__umul64(alo, bhi); + p2 = r128__umul64(ahi, blo); + p3 = r128__umul64(ahi, bhi); + + { +#if R128_INTEL && !defined(R128_STDC_ONLY) + R128_U32 r0, r1, r2, r3; + unsigned char carry; + + r0 = (R128_U32)(p0); + r1 = (R128_U32)(p0 >> 32); + r2 = (R128_U32)(p1 >> 32); + r3 = (R128_U32)(p3 >> 32); + + carry = _addcarry_u32(0, r1, (R128_U32)p1, &r1); + carry = _addcarry_u32(carry, r2, (R128_U32)(p2 >> 32), &r2); + _addcarry_u32(carry, r3, 0, &r3); + carry = _addcarry_u32(0, r1, (R128_U32)p2, &r1); + carry = _addcarry_u32(carry, r2, (R128_U32)p3, &r2); + _addcarry_u32(carry, r3, 0, &r3); + + R128_SET4(dst, r0, r1, r2, r3); +#else + R128_U64 carry, lo, hi; + carry = ((R128_U64)(R128_U32)p1 + (R128_U64)(R128_U32)p2 + (p0 >> 32)) >> 32; + + lo = p0 + ((p1 + p2) << 32); + hi = p3 + ((R128_U32)(p1 >> 32) + (R128_U32)(p2 >> 32)) + carry; + + R128_SET2(dst, lo, hi); +#endif + } +#endif +} + +// 128/64->64 +#if defined(_M_X64) && (_MSC_VER < 1920) && !defined(R128_STDC_ONLY) +// MSVC x64 provides neither inline assembly nor (pre-2019) a div intrinsic, so we do fake +// "inline assembly" to avoid long division or outline assembly. +#pragma code_seg(".text") +__declspec(allocate(".text")) static const unsigned char r128__udiv128Code[] = { + 0x48, 0x8B, 0xC1, //mov rax, rcx + 0x49, 0xF7, 0xF0, //div rax, r8 + 0x49, 0x89, 0x11, //mov qword ptr [r9], rdx + 0xC3 //ret +}; +typedef R128_U64 (*r128__udiv128Proc)(R128_U64 nlo, R128_U64 nhi, R128_U64 d, R128_U64 *rem); +static const r128__udiv128Proc r128__udiv128 = (r128__udiv128Proc)(void*)r128__udiv128Code; +#else +static R128_U64 r128__udiv128(R128_U64 nlo, R128_U64 nhi, R128_U64 d, R128_U64 *rem) +{ +#if defined(_M_X64) && !defined(R128_STDC_ONLY) + return _udiv128(nhi, nlo, d, rem); +#elif defined(__x86_64__) && !defined(R128_STDC_ONLY) + R128_U64 q, r; + __asm("divq %4" + : "=a"(q), "=d"(r) + : "a"(nlo), "d"(nhi), "X"(d)); + *rem = r; + return q; +#else + R128_U64 tmp; + R128_U32 d0, d1; + R128_U32 n3, n2, n1, n0; + R128_U32 q0, q1; + R128_U32 r; + int shift; + + R128_ASSERT(d != 0); //division by zero + R128_ASSERT(nhi < d); //overflow + + // normalize + shift = r128__clz64(d); + + if (shift) { + R128 tmp128; + R128_SET2(&tmp128, nlo, nhi); + r128Shl(&tmp128, &tmp128, shift); + n3 = R128_R3(&tmp128); + n2 = R128_R2(&tmp128); + n1 = R128_R1(&tmp128); + n0 = R128_R0(&tmp128); + d <<= shift; + } else { + n3 = (R128_U32)(nhi >> 32); + n2 = (R128_U32)nhi; + n1 = (R128_U32)(nlo >> 32); + n0 = (R128_U32)nlo; + } + + d1 = (R128_U32)(d >> 32); + d0 = (R128_U32)d; + + // first digit + R128_ASSERT(n3 <= d1); + if (n3 < d1) { + q1 = r128__udiv64(n2, n3, d1, &r); + } else { + q1 = 0xffffffffu; + r = n2 + d1; + } +refine1: + if (r128__umul64(q1, d0) > ((R128_U64)r << 32) + n1) { + --q1; + if (r < ~d1 + 1) { + r += d1; + goto refine1; + } + } + + tmp = ((R128_U64)n2 << 32) + n1 - (r128__umul64(q1, d0) + (r128__umul64(q1, d1) << 32)); + n2 = (R128_U32)(tmp >> 32); + n1 = (R128_U32)tmp; + + // second digit + R128_ASSERT(n2 <= d1); + if (n2 < d1) { + q0 = r128__udiv64(n1, n2, d1, &r); + } else { + q0 = 0xffffffffu; + r = n1 + d1; + } +refine0: + if (r128__umul64(q0, d0) > ((R128_U64)r << 32) + n0) { + --q0; + if (r < ~d1 + 1) { + r += d1; + goto refine0; + } + } + + tmp = ((R128_U64)n1 << 32) + n0 - (r128__umul64(q0, d0) + (r128__umul64(q0, d1) << 32)); + n1 = (R128_U32)(tmp >> 32); + n0 = (R128_U32)tmp; + + *rem = (((R128_U64)n1 << 32) + n0) >> shift; + return ((R128_U64)q1 << 32) + q0; +#endif +} +#endif + +static int r128__ucmp(const R128 *a, const R128 *b) +{ + if (a->hi != b->hi) { + if (a->hi > b->hi) { + return 1; + } else { + return -1; + } + } else { + if (a->lo == b->lo) { + return 0; + } else if (a->lo > b->lo) { + return 1; + } else { + return -1; + } + } +} + +static void r128__umul(R128 *dst, const R128 *a, const R128 *b) +{ +#if defined(_M_X64) && !defined(R128_STDC_ONLY) + R128_U64 t0, t1; + R128_U64 lo, hi = 0; + unsigned char carry; + + t0 = _umul128(a->lo, b->lo, &t1); + carry = _addcarry_u64(0, t1, t0 >> 63, &lo); + _addcarry_u64(carry, hi, hi, &hi); + + t0 = _umul128(a->lo, b->hi, &t1); + carry = _addcarry_u64(0, lo, t0, &lo); + _addcarry_u64(carry, hi, t1, &hi); + + t0 = _umul128(a->hi, b->lo, &t1); + carry = _addcarry_u64(0, lo, t0, &lo); + _addcarry_u64(carry, hi, t1, &hi); + + t0 = _umul128(a->hi, b->hi, &t1); + hi += t0; + + R128_SET2(dst, lo, hi); +#elif defined(__x86_64__) && !defined(R128_STDC_ONLY) + unsigned __int128 p0, p1, p2, p3; + p0 = a->lo * (unsigned __int128)b->lo; + p1 = a->lo * (unsigned __int128)b->hi; + p2 = a->hi * (unsigned __int128)b->lo; + p3 = a->hi * (unsigned __int128)b->hi; + + p0 = (p3 << 64) + p2 + p1 + (p0 >> 64) + ((R128_U64)p0 >> 63); + dst->lo = (R128_U64)p0; + dst->hi = (R128_U64)(p0 >> 64); +#else + R128 p0, p1, p2, p3, round; + + r128__umul128(&p0, a->lo, b->lo); + round.hi = 0; round.lo = p0.lo >> 63; + p0.lo = p0.hi; p0.hi = 0; //r128Shr(&p0, &p0, 64); + r128Add(&p0, &p0, &round); + + r128__umul128(&p1, a->hi, b->lo); + r128Add(&p0, &p0, &p1); + + r128__umul128(&p2, a->lo, b->hi); + r128Add(&p0, &p0, &p2); + + r128__umul128(&p3, a->hi, b->hi); + p3.hi = p3.lo; p3.lo = 0; //r128Shl(&p3, &p3, 64); + r128Add(&p0, &p0, &p3); + + R128_SET2(dst, p0.lo, p0.hi); +#endif +} + +// Shift d left until the high bit is set, and shift n left by the same amount. +// returns non-zero on overflow. +static int r128__norm(R128 *n, R128 *d, R128_U64 *n2) +{ + R128_U64 d0, d1; + R128_U64 n0, n1; + int shift; + + d1 = d->hi; + d0 = d->lo; + n1 = n->hi; + n0 = n->lo; + + if (d1) { + shift = r128__clz64(d1); + if (shift) { + d1 = (d1 << shift) | (d0 >> (64 - shift)); + d0 = d0 << shift; + *n2 = n1 >> (64 - shift); + n1 = (n1 << shift) | (n0 >> (64 - shift)); + n0 = n0 << shift; + } else { + *n2 = 0; + } + } else { + shift = r128__clz64(d0); + if (r128__clz64(n1) <= shift) { + return 1; // overflow + } + + if (shift) { + d1 = d0 << shift; + d0 = 0; + *n2 = (n1 << shift) | (n0 >> (64 - shift)); + n1 = n0 << shift; + n0 = 0; + } else { + d1 = d0; + d0 = 0; + *n2 = n1; + n1 = n0; + n0 = 0; + } + } + + R128_SET2(n, n0, n1); + R128_SET2(d, d0, d1); + return 0; +} + +static void r128__udiv(R128 *quotient, const R128 *dividend, const R128 *divisor) +{ + R128 tmp; + R128_U64 d0, d1; + R128_U64 n1, n2, n3; + R128 q; + + R128_ASSERT(dividend != NULL); + R128_ASSERT(divisor != NULL); + R128_ASSERT(quotient != NULL); + R128_ASSERT(divisor->hi != 0 || divisor->lo != 0); // divide by zero + + // scale dividend and normalize + { + R128 n, d; + R128_SET2(&n, dividend->lo, dividend->hi); + R128_SET2(&d, divisor->lo, divisor->hi); + if (r128__norm(&n, &d, &n3)) { + R128_SET2(quotient, R128_max.lo, R128_max.hi); + return; + } + + d1 = d.hi; + d0 = d.lo; + n2 = n.hi; + n1 = n.lo; + } + + // first digit + R128_ASSERT(n3 <= d1); + { + R128 t0, t1; + t0.lo = n1; + if (n3 < d1) { + q.hi = r128__udiv128(n2, n3, d1, &t0.hi); + } else { + q.hi = R128_LIT_U64(0xffffffffffffffff); + t0.hi = n2 + d1; + } + +refine1: + r128__umul128(&t1, q.hi, d0); + if (r128__ucmp(&t1, &t0) > 0) { + --q.hi; + if (t0.hi < ~d1 + 1) { + t0.hi += d1; + goto refine1; + } + } + } + + { + R128 t0, t1, t2; + t0.hi = n2; + t0.lo = n1; + + r128__umul128(&t1, q.hi, d0); + r128__umul128(&t2, q.hi, d1); + + t2.hi = t2.lo; t2.lo = 0; //r128Shl(&t2, &t2, 64); + r128Add(&tmp, &t1, &t2); + r128Sub(&tmp, &t0, &tmp); + } + n2 = tmp.hi; + n1 = tmp.lo; + + // second digit + R128_ASSERT(n2 <= d1); + { + R128 t0, t1; + t0.lo = 0; + if (n2 < d1) { + q.lo = r128__udiv128(n1, n2, d1, &t0.hi); + } else { + q.lo = R128_LIT_U64(0xffffffffffffffff); + t0.hi = n1 + d1; + } + + refine0: + r128__umul128(&t1, q.lo, d0); + if (r128__ucmp(&t1, &t0) > 0) { + --q.lo; + if (t0.hi < ~d1 + 1) { + t0.hi += d1; + goto refine0; + } + } + } + + R128_SET2(quotient, q.lo, q.hi); +} + +static R128_U64 r128__umod(R128 *n, R128 *d) +{ + R128_U64 d0, d1; + R128_U64 n3, n2, n1; + R128_U64 q; + + R128_ASSERT(d != NULL); + R128_ASSERT(n != NULL); + R128_ASSERT(d->hi != 0 || d->lo != 0); // divide by zero + + if (r128__norm(n, d, &n3)) { + return R128_LIT_U64(0xffffffffffffffff); + } + + d1 = d->hi; + d0 = d->lo; + n2 = n->hi; + n1 = n->lo; + + R128_ASSERT(n3 < d1); + { + R128 t0, t1; + t0.lo = n1; + q = r128__udiv128(n2, n3, d1, &t0.hi); + + refine1: + r128__umul128(&t1, q, d0); + if (r128__ucmp(&t1, &t0) > 0) { + --q; + if (t0.hi < ~d1 + 1) { + t0.hi += d1; + goto refine1; + } + } + } + + return q; +} + +static int r128__format(char *dst, size_t dstSize, const R128 *v, const R128ToStringFormat *format) +{ + char buf[128]; + R128 tmp; + R128_U64 whole; + char *cursor, *decimal, *dstp = dst; + int sign = 0; + int fullPrecision = 1; + int width, precision; + int padCnt, trail = 0; + + R128_ASSERT(dst != NULL && dstSize > 0); + R128_ASSERT(v != NULL); + R128_ASSERT(format != NULL); + + --dstSize; + + R128_SET2(&tmp, v->lo, v->hi); + if (r128IsNeg(&tmp)) { + r128__neg(&tmp, &tmp); + sign = 1; + } + + width = format->width; + if (width < 0) { + width = 0; + } + + precision = format->precision; + if (precision < 0) { + // print a maximum of 20 digits + fullPrecision = 0; + precision = 20; + } else if (precision > sizeof(buf) - 21) { + trail = precision - (sizeof(buf) - 21); + precision -= trail; + } + + whole = tmp.hi; + decimal = cursor = buf; + + // fractional part first in case a carry into the whole part is required + if (tmp.lo || format->decimal) { + while (tmp.lo || (fullPrecision && precision)) { + if ((int)(cursor - buf) == precision) { + if ((R128_S64)tmp.lo < 0) { + // round up, propagate carry backwards + char *c; + for (c = cursor - 1; c >= buf; --c) { + char d = ++*c; + if (d <= '9') { + goto endfrac; + } else { + *c = '0'; + } + } + + // carry out into the whole part + whole++; + } + + break; + } + + r128__umul128(&tmp, tmp.lo, 10); + *cursor++ = (char)tmp.hi + '0'; + } + + endfrac: + if (format->decimal || precision) { + decimal = cursor; + *cursor++ = R128_decimal; + } + } + + // whole part + do { + char digit = (char)(whole % 10); + whole /= 10; + *cursor++ = digit + '0'; + } while (whole); + +#define R128__WRITE(c) do { if (dstp < dst + dstSize) *dstp = c; ++dstp; } while(0) + + padCnt = width - (int)(cursor - buf) - 1; + + // left padding + if (!format->leftAlign) { + char padChar = format->zeroPad ? '0' : ' '; + if (format->zeroPad) { + if (sign) { + R128__WRITE('-'); + } else if (format->sign == R128ToStringSign_Plus) { + R128__WRITE('+'); + } else if (format->sign == R128ToStringSign_Space) { + R128__WRITE(' '); + } else { + ++padCnt; + } + } + + for (; padCnt > 0; --padCnt) { + R128__WRITE(padChar); + } + } + + if (format->leftAlign || !format->zeroPad) { + if (sign) { + R128__WRITE('-'); + } else if (format->sign == R128ToStringSign_Plus) { + R128__WRITE('+'); + } else if (format->sign == R128ToStringSign_Space) { + R128__WRITE(' '); + } else { + ++padCnt; + } + } + + { + char *i; + + // reverse the whole part + for (i = cursor - 1; i >= decimal; --i) { + R128__WRITE(*i); + } + + // copy the fractional part + for (i = buf; i < decimal; ++i) { + R128__WRITE(*i); + } + } + + // right padding + if (format->leftAlign) { + char padChar = format->zeroPad ? '0' : ' '; + for (; padCnt > 0; --padCnt) { + R128__WRITE(padChar); + } + } + + // trailing zeroes for very large precision + while (trail--) { + R128__WRITE('0'); + } + +#undef R128__WRITE + + if (dstp <= dst + dstSize) { + *dstp = '\0'; + } else { + dst[dstSize] = '\0'; + } + return (int)(dstp - dst); +} + +void r128FromInt(R128 *dst, R128_S64 v) +{ + R128_ASSERT(dst != NULL); + dst->lo = 0; + dst->hi = (R128_U64)v; + R128_DEBUG_SET(dst); +} + +void r128FromFloat(R128 *dst, double v) +{ + R128_ASSERT(dst != NULL); + + if (v < -9223372036854775808.0) { + r128Copy(dst, &R128_min); + } else if (v >= 9223372036854775808.0) { + r128Copy(dst, &R128_max); + } else { + R128 r; + int sign = 0; + + if (v < 0) { + v = -v; + sign = 1; + } + + r.hi = (R128_U64)(R128_S64)v; + v -= (R128_S64)v; + r.lo = (R128_U64)(v * 18446744073709551616.0); + + if (sign) { + r128__neg(&r, &r); + } + + r128Copy(dst, &r); + } +} + +void r128FromString(R128 *dst, const char *s, char **endptr) +{ + R128_U64 lo = 0, hi = 0; + R128_U64 base = 10; + + int sign = 0; + + R128_ASSERT(dst != NULL); + R128_ASSERT(s != NULL); + + R128_SET2(dst, 0, 0); + + // consume whitespace + for (;;) { + if (*s == ' ' || *s == '\t' || *s == '\r' || *s == '\n' || *s == '\v') { + ++s; + } else { + break; + } + } + + // sign + if (*s == '-') { + sign = 1; + ++s; + } else if (*s == '+') { + ++s; + } + + // parse base prefix + if (s[0] == '0' && (s[1] == 'x' || s[1] == 'X')) { + base = 16; + s += 2; + } + + // whole part + for (;; ++s) { + R128_U64 digit; + + if ('0' <= *s && *s <= '9') { + digit = *s - '0'; + } else if (base == 16 && 'a' <= *s && *s <= 'f') { + digit = *s - 'a' + 10; + } else if (base == 16 && 'A' <= *s && *s <= 'F') { + digit = *s - 'A' + 10; + } else { + break; + } + + hi = hi * base + digit; + } + + // fractional part + if (*s == R128_decimal) { + const char *exp = ++s; + + // find the last digit and work backwards + for (;; ++s) { + if ('0' <= *s && *s <= '9') { + } else if (base == 16 && ('a' <= *s && *s <= 'f')) { + } else if (base == 16 && ('A' <= *s && *s <= 'F')) { + } else { + break; + } + } + + for (--s; s >= exp; --s) { + R128_U64 digit, unused; + + if ('0' <= *s && *s <= '9') { + digit = *s - '0'; + } else if ('a' <= *s && *s <= 'f') { + digit = *s - 'a' + 10; + } else { + digit = *s - 'A' + 10; + } + + lo = r128__udiv128(lo, digit, base, &unused); + } + } + + R128_SET2(dst, lo, hi); + if (sign) { + r128__neg(dst, dst); + } + + if (endptr) { + *endptr = (char *) s; + } +} + +R128_S64 r128ToInt(const R128 *v) +{ + R128_ASSERT(v != NULL); + return (R128_S64)v->hi; +} + +double r128ToFloat(const R128 *v) +{ + R128 tmp; + int sign = 0; + double d; + + R128_ASSERT(v != NULL); + + R128_SET2(&tmp, v->lo, v->hi); + if (r128IsNeg(&tmp)) { + r128__neg(&tmp, &tmp); + sign = 1; + } + + d = tmp.hi + tmp.lo * (1 / 18446744073709551616.0); + if (sign) { + d = -d; + } + + return d; +} + +int r128ToStringOpt(char *dst, size_t dstSize, const R128 *v, const R128ToStringFormat *opt) +{ + return r128__format(dst, dstSize, v, opt); +} + +int r128ToStringf(char *dst, size_t dstSize, const char *format, const R128 *v) +{ + R128ToStringFormat opts; + + R128_ASSERT(dst != NULL && dstSize); + R128_ASSERT(format != NULL); + R128_ASSERT(v != NULL); + + opts.sign = R128__defaultFormat.sign; + opts.precision = R128__defaultFormat.precision; + opts.zeroPad = R128__defaultFormat.zeroPad; + opts.decimal = R128__defaultFormat.decimal; + opts.leftAlign = R128__defaultFormat.leftAlign; + + if (*format == '%') { + ++format; + } + + // flags field + for (;; ++format) { + if (*format == ' ' && opts.sign != R128ToStringSign_Plus) { + opts.sign = R128ToStringSign_Space; + } else if (*format == '+') { + opts.sign = R128ToStringSign_Plus; + } else if (*format == '0') { + opts.zeroPad = 1; + } else if (*format == '-') { + opts.leftAlign = 1; + } else if (*format == '#') { + opts.decimal = 1; + } else { + break; + } + } + + // width field + opts.width = 0; + for (;;) { + if ('0' <= *format && *format <= '9') { + opts.width = opts.width * 10 + *format++ - '0'; + } else { + break; + } + } + + // precision field + if (*format == '.') { + opts.precision = 0; + ++format; + for (;;) { + if ('0' <= *format && *format <= '9') { + opts.precision = opts.precision * 10 + *format++ - '0'; + } else { + break; + } + } + } + + return r128__format(dst, dstSize, v, &opts); +} + +int r128ToString(char *dst, size_t dstSize, const R128 *v) +{ + return r128__format(dst, dstSize, v, &R128__defaultFormat); +} + +void r128Copy(R128 *dst, const R128 *src) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + dst->lo = src->lo; + dst->hi = src->hi; + R128_DEBUG_SET(dst); +} + +void r128Neg(R128 *dst, const R128 *src) +{ + r128__neg(dst, src); + R128_DEBUG_SET(dst); +} + +void r128Not(R128 *dst, const R128 *src) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + + dst->lo = ~src->lo; + dst->hi = ~src->hi; + R128_DEBUG_SET(dst); +} + +void r128Or(R128 *dst, const R128 *a, const R128 *b) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + dst->lo = a->lo | b->lo; + dst->hi = a->hi | b->hi; + R128_DEBUG_SET(dst); +} + +void r128And(R128 *dst, const R128 *a, const R128 *b) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + dst->lo = a->lo & b->lo; + dst->hi = a->hi & b->hi; + R128_DEBUG_SET(dst); +} + +void r128Xor(R128 *dst, const R128 *a, const R128 *b) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + dst->lo = a->lo ^ b->lo; + dst->hi = a->hi ^ b->hi; + R128_DEBUG_SET(dst); +} + +void r128Shl(R128 *dst, const R128 *src, int amount) +{ + R128_U64 r[4]; + + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + +#if defined(_M_IX86) && !defined(R128_STDC_ONLY) + __asm { + // load src + mov edx, dword ptr[src] + mov ecx, amount + + mov edi, dword ptr[edx] + mov esi, dword ptr[edx + 4] + mov ebx, dword ptr[edx + 8] + mov eax, dword ptr[edx + 12] + + // shift mod 32 + shld eax, ebx, cl + shld ebx, esi, cl + shld esi, edi, cl + shl edi, cl + + // clear out low 12 bytes of stack + xor edx, edx + mov dword ptr[r], edx + mov dword ptr[r + 4], edx + mov dword ptr[r + 8], edx + + // store shifted amount offset by count/32 bits + shr ecx, 5 + and ecx, 3 + mov dword ptr[r + ecx * 4 + 0], edi + mov dword ptr[r + ecx * 4 + 4], esi + mov dword ptr[r + ecx * 4 + 8], ebx + mov dword ptr[r + ecx * 4 + 12], eax + } +#else + + r[0] = src->lo; + r[1] = src->hi; + + amount &= 127; + if (amount >= 64) { + r[1] = r[0] << (amount - 64); + r[0] = 0; + } else if (amount) { +# ifdef _M_X64 + r[1] = __shiftleft128(r[0], r[1], (char) amount); +# else + r[1] = (r[1] << amount) | (r[0] >> (64 - amount)); +# endif + r[0] = r[0] << amount; + } +#endif //_M_IX86 + + dst->lo = r[0]; + dst->hi = r[1]; + R128_DEBUG_SET(dst); +} + +void r128Shr(R128 *dst, const R128 *src, int amount) +{ + R128_U64 r[4]; + + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + +#if defined(_M_IX86) && !defined(R128_STDC_ONLY) + __asm { + // load src + mov edx, dword ptr[src] + mov ecx, amount + + mov edi, dword ptr[edx] + mov esi, dword ptr[edx + 4] + mov ebx, dword ptr[edx + 8] + mov eax, dword ptr[edx + 12] + + // shift mod 32 + shrd edi, esi, cl + shrd esi, ebx, cl + shrd ebx, eax, cl + shr eax, cl + + // clear out high 12 bytes of stack + xor edx, edx + mov dword ptr[r + 20], edx + mov dword ptr[r + 24], edx + mov dword ptr[r + 28], edx + + // store shifted amount offset by -count/32 bits + shr ecx, 5 + and ecx, 3 + neg ecx + mov dword ptr[r + ecx * 4 + 16], edi + mov dword ptr[r + ecx * 4 + 20], esi + mov dword ptr[r + ecx * 4 + 24], ebx + mov dword ptr[r + ecx * 4 + 28], eax + } +#else + r[2] = src->lo; + r[3] = src->hi; + + amount &= 127; + if (amount >= 64) { + r[2] = r[3] >> (amount - 64); + r[3] = 0; + } else if (amount) { +#ifdef _M_X64 + r[2] = __shiftright128(r[2], r[3], (char) amount); +#else + r[2] = (r[2] >> amount) | (r[3] << (64 - amount)); +#endif + r[3] = r[3] >> amount; + } +#endif + + dst->lo = r[2]; + dst->hi = r[3]; + R128_DEBUG_SET(dst); +} + +void r128Sar(R128 *dst, const R128 *src, int amount) +{ + R128_U64 r[4]; + + R128_ASSERT(dst != NULL); + R128_ASSERT(src != NULL); + +#if defined(_M_IX86) && !defined(R128_STDC_ONLY) + __asm { + // load src + mov edx, dword ptr[src] + mov ecx, amount + + mov edi, dword ptr[edx] + mov esi, dword ptr[edx + 4] + mov ebx, dword ptr[edx + 8] + mov eax, dword ptr[edx + 12] + + // shift mod 32 + shrd edi, esi, cl + shrd esi, ebx, cl + shrd ebx, eax, cl + sar eax, cl + + // copy sign to high 12 bytes of stack + cdq + mov dword ptr[r + 20], edx + mov dword ptr[r + 24], edx + mov dword ptr[r + 28], edx + + // store shifted amount offset by -count/32 bits + shr ecx, 5 + and ecx, 3 + neg ecx + mov dword ptr[r + ecx * 4 + 16], edi + mov dword ptr[r + ecx * 4 + 20], esi + mov dword ptr[r + ecx * 4 + 24], ebx + mov dword ptr[r + ecx * 4 + 28], eax + } +#else + r[2] = src->lo; + r[3] = src->hi; + + amount &= 127; + if (amount >= 64) { + r[2] = (R128_U64)((R128_S64)r[3] >> (amount - 64)); + r[3] = (R128_U64)((R128_S64)r[3] >> 63); + } else if (amount) { + r[2] = (r[2] >> amount) | (R128_U64)((R128_S64)r[3] << (64 - amount)); + r[3] = (R128_U64)((R128_S64)r[3] >> amount); + } +#endif + + dst->lo = r[2]; + dst->hi = r[3]; + R128_DEBUG_SET(dst); +} + +void r128Add(R128 *dst, const R128 *a, const R128 *b) +{ + unsigned char carry = 0; + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + +#if R128_INTEL && !defined(R128_STDC_ONLY) +# if R128_64BIT + carry = _addcarry_u64(carry, a->lo, b->lo, &dst->lo); + carry = _addcarry_u64(carry, a->hi, b->hi, &dst->hi); +# else + R128_U32 r0, r1, r2, r3; + carry = _addcarry_u32(carry, R128_R0(a), R128_R0(b), &r0); + carry = _addcarry_u32(carry, R128_R1(a), R128_R1(b), &r1); + carry = _addcarry_u32(carry, R128_R2(a), R128_R2(b), &r2); + carry = _addcarry_u32(carry, R128_R3(a), R128_R3(b), &r3); + R128_SET4(dst, r0, r1, r2, r3); +# endif //R128_64BIT +#else + { + R128_U64 r = a->lo + b->lo; + carry = r < a->lo; + dst->lo = r; + dst->hi = a->hi + b->hi + carry; + } +#endif //R128_INTEL + + R128_DEBUG_SET(dst); +} + +void r128Sub(R128 *dst, const R128 *a, const R128 *b) +{ + unsigned char borrow = 0; + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + +#if R128_INTEL && !defined(R128_STDC_ONLY) +# if R128_64BIT + borrow = _subborrow_u64(borrow, a->lo, b->lo, &dst->lo); + borrow = _subborrow_u64(borrow, a->hi, b->hi, &dst->hi); +# else + R128_U32 r0, r1, r2, r3; + borrow = _subborrow_u32(borrow, R128_R0(a), R128_R0(b), &r0); + borrow = _subborrow_u32(borrow, R128_R1(a), R128_R1(b), &r1); + borrow = _subborrow_u32(borrow, R128_R2(a), R128_R2(b), &r2); + borrow = _subborrow_u32(borrow, R128_R3(a), R128_R3(b), &r3); + R128_SET4(dst, r0, r1, r2, r3); +# endif //R128_64BIT +#else + { + R128_U64 r = a->lo - b->lo; + borrow = r > a->lo; + dst->lo = r; + dst->hi = a->hi - b->hi - borrow; + } +#endif //R128_INTEL + + R128_DEBUG_SET(dst); +} + +void r128Mul(R128 *dst, const R128 *a, const R128 *b) +{ + int sign = 0; + R128 ta, tb, tc; + + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + R128_SET2(&ta, a->lo, a->hi); + R128_SET2(&tb, b->lo, b->hi); + + if (r128IsNeg(&ta)) { + r128__neg(&ta, &ta); + sign = !sign; + } + if (r128IsNeg(&tb)) { + r128__neg(&tb, &tb); + sign = !sign; + } + + r128__umul(&tc, &ta, &tb); + if (sign) { + r128__neg(&tc, &tc); + } + + r128Copy(dst, &tc); +} + +void r128Div(R128 *dst, const R128 *a, const R128 *b) +{ + int sign = 0; + R128 tn, td, tq; + + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + R128_SET2(&tn, a->lo, a->hi); + R128_SET2(&td, b->lo, b->hi); + + if (r128IsNeg(&tn)) { + r128__neg(&tn, &tn); + sign = !sign; + } + + if (td.lo == 0 && td.hi == 0) { + // divide by zero + if (sign) { + r128Copy(dst, &R128_min); + } else { + r128Copy(dst, &R128_max); + } + return; + } else if (r128IsNeg(&td)) { + r128__neg(&td, &td); + sign = !sign; + } + + r128__udiv(&tq, &tn, &td); + + if (sign) { + r128__neg(&tq, &tq); + } + + r128Copy(dst, &tq); +} + +void r128Mod(R128 *dst, const R128 *a, const R128 *b) +{ + int sign = 0; + R128 tn, td, tq; + + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + R128_SET2(&tn, a->lo, a->hi); + R128_SET2(&td, b->lo, b->hi); + + if (r128IsNeg(&tn)) { + r128__neg(&tn, &tn); + sign = !sign; + } + + if (td.lo == 0 && td.hi == 0) { + // divide by zero + if (sign) { + r128Copy(dst, &R128_min); + } else { + r128Copy(dst, &R128_max); + } + return; + } else if (r128IsNeg(&td)) { + r128__neg(&td, &td); + sign = !sign; + } + + tq.hi = r128__umod(&tn, &td); + tq.lo = 0; + + if (sign) { + tq.hi = ~tq.hi + 1; + } + + r128Mul(&tq, &tq, b); + r128Sub(dst, a, &tq); +} + +void r128Rsqrt(R128 *dst, const R128 *v) +{ + static const R128 threeHalves = { R128_LIT_U64(0x8000000000000000), 1 }; + R128 x, est; + int i; + + if ((R128_S64)v->hi < 0) { + r128Copy(dst, &R128_min); + return; + } + + R128_SET2(&x, v->lo, v->hi); + + // get initial estimate + if (x.hi) { + int shift = (64 + r128__clz64(x.hi)) >> 1; + est.lo = R128_LIT_U64(1) << shift; + est.hi = 0; + } else if (x.lo) { + int shift = r128__clz64(x.lo) >> 1; + est.hi = R128_LIT_U64(1) << shift; + est.lo = 0; + } else { + R128_SET2(dst, 0, 0); + return; + } + + // x /= 2 + r128Shr(&x, &x, 1); + + // Newton-Raphson iterate + for (i = 0; i < 7; ++i) { + R128 newEst; + + // newEst = est * (threeHalves - (x / 2) * est * est); + r128__umul(&newEst, &est, &est); + r128__umul(&newEst, &newEst, &x); + r128Sub(&newEst, &threeHalves, &newEst); + r128__umul(&newEst, &est, &newEst); + + if (newEst.lo == est.lo && newEst.hi == est.hi) { + break; + } + R128_SET2(&est, newEst.lo, newEst.hi); + } + + r128Copy(dst, &est); +} + +void r128Sqrt(R128 *dst, const R128 *v) +{ + R128 x, est; + int i; + + if ((R128_S64)v->hi < 0) { + r128Copy(dst, &R128_min); + return; + } + + R128_SET2(&x, v->lo, v->hi); + + // get initial estimate + if (x.hi) { + int shift = (63 - r128__clz64(x.hi)) >> 1; + r128Shr(&est, &x, shift); + } else if (x.lo) { + int shift = (1 + r128__clz64(x.lo)) >> 1; + r128Shl(&est, &x, shift); + } else { + R128_SET2(dst, 0, 0); + return; + } + + // Newton-Raphson iterate + for (i = 0; i < 7; ++i) { + R128 newEst; + + // newEst = (est + x / est) / 2 + r128__udiv(&newEst, &x, &est); + r128Add(&newEst, &newEst, &est); + r128Shr(&newEst, &newEst, 1); + + if (newEst.lo == est.lo && newEst.hi == est.hi) { + break; + } + R128_SET2(&est, newEst.lo, newEst.hi); + } + + r128Copy(dst, &est); +} + +int r128Cmp(const R128 *a, const R128 *b) +{ + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + if (a->hi == b->hi) { + if (a->lo == b->lo) { + return 0; + } else if (a->lo > b->lo) { + return 1; + } else { + return -1; + } + } else if ((R128_S64)a->hi > (R128_S64)b->hi) { + return 1; + } else { + return -1; + } +} + +int r128IsNeg(const R128 *v) +{ + R128_ASSERT(v != NULL); + + return (R128_S64)v->hi < 0; +} + +void r128Min(R128 *dst, const R128 *a, const R128 *b) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + if (r128Cmp(a, b) < 0) { + r128Copy(dst, a); + } else { + r128Copy(dst, b); + } +} + +void r128Max(R128 *dst, const R128 *a, const R128 *b) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(a != NULL); + R128_ASSERT(b != NULL); + + if (r128Cmp(a, b) > 0) { + r128Copy(dst, a); + } else { + r128Copy(dst, b); + } +} + +void r128Floor(R128 *dst, const R128 *v) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(v != NULL); + + if ((R128_S64)v->hi < 0) { + dst->hi = v->hi - (v->lo != 0); + } else { + dst->hi = v->hi; + } + dst->lo = 0; + R128_DEBUG_SET(dst); +} + +void r128Ceil(R128 *dst, const R128 *v) +{ + R128_ASSERT(dst != NULL); + R128_ASSERT(v != NULL); + + if ((R128_S64)v->hi > 0) { + dst->hi = v->hi + (v->lo != 0); + } else { + dst->hi = v->hi; + } + dst->lo = 0; + R128_DEBUG_SET(dst); +} + +#endif //R128_IMPLEMENTATION diff --git a/thirdparty/misc/stb_rect_pack.h b/thirdparty/misc/stb_rect_pack.h new file mode 100644 index 0000000000..5c848de0e7 --- /dev/null +++ b/thirdparty/misc/stb_rect_pack.h @@ -0,0 +1,628 @@ +// stb_rect_pack.h - v1.00 - public domain - rectangle packing +// Sean Barrett 2014 +// +// Useful for e.g. packing rectangular textures into an atlas. +// Does not do rotation. +// +// Not necessarily the awesomest packing method, but better than +// the totally naive one in stb_truetype (which is primarily what +// this is meant to replace). +// +// Has only had a few tests run, may have issues. +// +// More docs to come. +// +// No memory allocations; uses qsort() and assert() from stdlib. +// Can override those by defining STBRP_SORT and STBRP_ASSERT. +// +// This library currently uses the Skyline Bottom-Left algorithm. +// +// Please note: better rectangle packers are welcome! Please +// implement them to the same API, but with a different init +// function. +// +// Credits +// +// Library +// Sean Barrett +// Minor features +// Martins Mozeiko +// github:IntellectualKitty +// +// Bugfixes / warning fixes +// Jeremy Jaussaud +// Fabian Giesen +// +// Version history: +// +// 1.00 (2019-02-25) avoid small space waste; gracefully fail too-wide rectangles +// 0.99 (2019-02-07) warning fixes +// 0.11 (2017-03-03) return packing success/fail result +// 0.10 (2016-10-25) remove cast-away-const to avoid warnings +// 0.09 (2016-08-27) fix compiler warnings +// 0.08 (2015-09-13) really fix bug with empty rects (w=0 or h=0) +// 0.07 (2015-09-13) fix bug with empty rects (w=0 or h=0) +// 0.06 (2015-04-15) added STBRP_SORT to allow replacing qsort +// 0.05: added STBRP_ASSERT to allow replacing assert +// 0.04: fixed minor bug in STBRP_LARGE_RECTS support +// 0.01: initial release +// +// LICENSE +// +// See end of file for license information. + +////////////////////////////////////////////////////////////////////////////// +// +// INCLUDE SECTION +// + +#ifndef STB_INCLUDE_STB_RECT_PACK_H +#define STB_INCLUDE_STB_RECT_PACK_H + +#define STB_RECT_PACK_VERSION 1 + +#ifdef STBRP_STATIC +#define STBRP_DEF static +#else +#define STBRP_DEF extern +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct stbrp_context stbrp_context; +typedef struct stbrp_node stbrp_node; +typedef struct stbrp_rect stbrp_rect; + +#ifdef STBRP_LARGE_RECTS +typedef int stbrp_coord; +#else +typedef unsigned short stbrp_coord; +#endif + +STBRP_DEF int stbrp_pack_rects (stbrp_context *context, stbrp_rect *rects, int num_rects); +// Assign packed locations to rectangles. The rectangles are of type +// 'stbrp_rect' defined below, stored in the array 'rects', and there +// are 'num_rects' many of them. +// +// Rectangles which are successfully packed have the 'was_packed' flag +// set to a non-zero value and 'x' and 'y' store the minimum location +// on each axis (i.e. bottom-left in cartesian coordinates, top-left +// if you imagine y increasing downwards). Rectangles which do not fit +// have the 'was_packed' flag set to 0. +// +// You should not try to access the 'rects' array from another thread +// while this function is running, as the function temporarily reorders +// the array while it executes. +// +// To pack into another rectangle, you need to call stbrp_init_target +// again. To continue packing into the same rectangle, you can call +// this function again. Calling this multiple times with multiple rect +// arrays will probably produce worse packing results than calling it +// a single time with the full rectangle array, but the option is +// available. +// +// The function returns 1 if all of the rectangles were successfully +// packed and 0 otherwise. + +struct stbrp_rect +{ + // reserved for your use: + int id; + + // input: + stbrp_coord w, h; + + // output: + stbrp_coord x, y; + int was_packed; // non-zero if valid packing + +}; // 16 bytes, nominally + + +STBRP_DEF void stbrp_init_target (stbrp_context *context, int width, int height, stbrp_node *nodes, int num_nodes); +// Initialize a rectangle packer to: +// pack a rectangle that is 'width' by 'height' in dimensions +// using temporary storage provided by the array 'nodes', which is 'num_nodes' long +// +// You must call this function every time you start packing into a new target. +// +// There is no "shutdown" function. The 'nodes' memory must stay valid for +// the following stbrp_pack_rects() call (or calls), but can be freed after +// the call (or calls) finish. +// +// Note: to guarantee best results, either: +// 1. make sure 'num_nodes' >= 'width' +// or 2. call stbrp_allow_out_of_mem() defined below with 'allow_out_of_mem = 1' +// +// If you don't do either of the above things, widths will be quantized to multiples +// of small integers to guarantee the algorithm doesn't run out of temporary storage. +// +// If you do #2, then the non-quantized algorithm will be used, but the algorithm +// may run out of temporary storage and be unable to pack some rectangles. + +STBRP_DEF void stbrp_setup_allow_out_of_mem (stbrp_context *context, int allow_out_of_mem); +// Optionally call this function after init but before doing any packing to +// change the handling of the out-of-temp-memory scenario, described above. +// If you call init again, this will be reset to the default (false). + + +STBRP_DEF void stbrp_setup_heuristic (stbrp_context *context, int heuristic); +// Optionally select which packing heuristic the library should use. Different +// heuristics will produce better/worse results for different data sets. +// If you call init again, this will be reset to the default. + +enum +{ + STBRP_HEURISTIC_Skyline_default=0, + STBRP_HEURISTIC_Skyline_BL_sortHeight = STBRP_HEURISTIC_Skyline_default, + STBRP_HEURISTIC_Skyline_BF_sortHeight +}; + + +////////////////////////////////////////////////////////////////////////////// +// +// the details of the following structures don't matter to you, but they must +// be visible so you can handle the memory allocations for them + +struct stbrp_node +{ + stbrp_coord x,y; + stbrp_node *next; +}; + +struct stbrp_context +{ + int width; + int height; + int align; + int init_mode; + int heuristic; + int num_nodes; + stbrp_node *active_head; + stbrp_node *free_head; + stbrp_node extra[2]; // we allocate two extra nodes so optimal user-node-count is 'width' not 'width+2' +}; + +#ifdef __cplusplus +} +#endif + +#endif + +////////////////////////////////////////////////////////////////////////////// +// +// IMPLEMENTATION SECTION +// + +#ifdef STB_RECT_PACK_IMPLEMENTATION +#ifndef STBRP_SORT +#include <stdlib.h> +#define STBRP_SORT qsort +#endif + +#ifndef STBRP_ASSERT +#include <assert.h> +#define STBRP_ASSERT assert +#endif + +#ifdef _MSC_VER +#define STBRP__NOTUSED(v) (void)(v) +#else +#define STBRP__NOTUSED(v) (void)sizeof(v) +#endif + +enum +{ + STBRP__INIT_skyline = 1 +}; + +STBRP_DEF void stbrp_setup_heuristic(stbrp_context *context, int heuristic) +{ + switch (context->init_mode) { + case STBRP__INIT_skyline: + STBRP_ASSERT(heuristic == STBRP_HEURISTIC_Skyline_BL_sortHeight || heuristic == STBRP_HEURISTIC_Skyline_BF_sortHeight); + context->heuristic = heuristic; + break; + default: + STBRP_ASSERT(0); + } +} + +STBRP_DEF void stbrp_setup_allow_out_of_mem(stbrp_context *context, int allow_out_of_mem) +{ + if (allow_out_of_mem) + // if it's ok to run out of memory, then don't bother aligning them; + // this gives better packing, but may fail due to OOM (even though + // the rectangles easily fit). @TODO a smarter approach would be to only + // quantize once we've hit OOM, then we could get rid of this parameter. + context->align = 1; + else { + // if it's not ok to run out of memory, then quantize the widths + // so that num_nodes is always enough nodes. + // + // I.e. num_nodes * align >= width + // align >= width / num_nodes + // align = ceil(width/num_nodes) + + context->align = (context->width + context->num_nodes-1) / context->num_nodes; + } +} + +STBRP_DEF void stbrp_init_target(stbrp_context *context, int width, int height, stbrp_node *nodes, int num_nodes) +{ + int i; +#ifndef STBRP_LARGE_RECTS + STBRP_ASSERT(width <= 0xffff && height <= 0xffff); +#endif + + for (i=0; i < num_nodes-1; ++i) + nodes[i].next = &nodes[i+1]; + nodes[i].next = NULL; + context->init_mode = STBRP__INIT_skyline; + context->heuristic = STBRP_HEURISTIC_Skyline_default; + context->free_head = &nodes[0]; + context->active_head = &context->extra[0]; + context->width = width; + context->height = height; + context->num_nodes = num_nodes; + stbrp_setup_allow_out_of_mem(context, 0); + + // node 0 is the full width, node 1 is the sentinel (lets us not store width explicitly) + context->extra[0].x = 0; + context->extra[0].y = 0; + context->extra[0].next = &context->extra[1]; + context->extra[1].x = (stbrp_coord) width; +#ifdef STBRP_LARGE_RECTS + context->extra[1].y = (1<<30); +#else + context->extra[1].y = 65535; +#endif + context->extra[1].next = NULL; +} + +// find minimum y position if it starts at x1 +static int stbrp__skyline_find_min_y(stbrp_context *c, stbrp_node *first, int x0, int width, int *pwaste) +{ + stbrp_node *node = first; + int x1 = x0 + width; + int min_y, visited_width, waste_area; + + STBRP__NOTUSED(c); + + STBRP_ASSERT(first->x <= x0); + + #if 0 + // skip in case we're past the node + while (node->next->x <= x0) + ++node; + #else + STBRP_ASSERT(node->next->x > x0); // we ended up handling this in the caller for efficiency + #endif + + STBRP_ASSERT(node->x <= x0); + + min_y = 0; + waste_area = 0; + visited_width = 0; + while (node->x < x1) { + if (node->y > min_y) { + // raise min_y higher. + // we've accounted for all waste up to min_y, + // but we'll now add more waste for everything we've visted + waste_area += visited_width * (node->y - min_y); + min_y = node->y; + // the first time through, visited_width might be reduced + if (node->x < x0) + visited_width += node->next->x - x0; + else + visited_width += node->next->x - node->x; + } else { + // add waste area + int under_width = node->next->x - node->x; + if (under_width + visited_width > width) + under_width = width - visited_width; + waste_area += under_width * (min_y - node->y); + visited_width += under_width; + } + node = node->next; + } + + *pwaste = waste_area; + return min_y; +} + +typedef struct +{ + int x,y; + stbrp_node **prev_link; +} stbrp__findresult; + +static stbrp__findresult stbrp__skyline_find_best_pos(stbrp_context *c, int width, int height) +{ + int best_waste = (1<<30), best_x, best_y = (1 << 30); + stbrp__findresult fr; + stbrp_node **prev, *node, *tail, **best = NULL; + + // align to multiple of c->align + width = (width + c->align - 1); + width -= width % c->align; + STBRP_ASSERT(width % c->align == 0); + + // if it can't possibly fit, bail immediately + if (width > c->width || height > c->height) { + fr.prev_link = NULL; + fr.x = fr.y = 0; + return fr; + } + + node = c->active_head; + prev = &c->active_head; + while (node->x + width <= c->width) { + int y,waste; + y = stbrp__skyline_find_min_y(c, node, node->x, width, &waste); + if (c->heuristic == STBRP_HEURISTIC_Skyline_BL_sortHeight) { // actually just want to test BL + // bottom left + if (y < best_y) { + best_y = y; + best = prev; + } + } else { + // best-fit + if (y + height <= c->height) { + // can only use it if it first vertically + if (y < best_y || (y == best_y && waste < best_waste)) { + best_y = y; + best_waste = waste; + best = prev; + } + } + } + prev = &node->next; + node = node->next; + } + + best_x = (best == NULL) ? 0 : (*best)->x; + + // if doing best-fit (BF), we also have to try aligning right edge to each node position + // + // e.g, if fitting + // + // ____________________ + // |____________________| + // + // into + // + // | | + // | ____________| + // |____________| + // + // then right-aligned reduces waste, but bottom-left BL is always chooses left-aligned + // + // This makes BF take about 2x the time + + if (c->heuristic == STBRP_HEURISTIC_Skyline_BF_sortHeight) { + tail = c->active_head; + node = c->active_head; + prev = &c->active_head; + // find first node that's admissible + while (tail->x < width) + tail = tail->next; + while (tail) { + int xpos = tail->x - width; + int y,waste; + STBRP_ASSERT(xpos >= 0); + // find the left position that matches this + while (node->next->x <= xpos) { + prev = &node->next; + node = node->next; + } + STBRP_ASSERT(node->next->x > xpos && node->x <= xpos); + y = stbrp__skyline_find_min_y(c, node, xpos, width, &waste); + if (y + height <= c->height) { + if (y <= best_y) { + if (y < best_y || waste < best_waste || (waste==best_waste && xpos < best_x)) { + best_x = xpos; + STBRP_ASSERT(y <= best_y); + best_y = y; + best_waste = waste; + best = prev; + } + } + } + tail = tail->next; + } + } + + fr.prev_link = best; + fr.x = best_x; + fr.y = best_y; + return fr; +} + +static stbrp__findresult stbrp__skyline_pack_rectangle(stbrp_context *context, int width, int height) +{ + // find best position according to heuristic + stbrp__findresult res = stbrp__skyline_find_best_pos(context, width, height); + stbrp_node *node, *cur; + + // bail if: + // 1. it failed + // 2. the best node doesn't fit (we don't always check this) + // 3. we're out of memory + if (res.prev_link == NULL || res.y + height > context->height || context->free_head == NULL) { + res.prev_link = NULL; + return res; + } + + // on success, create new node + node = context->free_head; + node->x = (stbrp_coord) res.x; + node->y = (stbrp_coord) (res.y + height); + + context->free_head = node->next; + + // insert the new node into the right starting point, and + // let 'cur' point to the remaining nodes needing to be + // stiched back in + + cur = *res.prev_link; + if (cur->x < res.x) { + // preserve the existing one, so start testing with the next one + stbrp_node *next = cur->next; + cur->next = node; + cur = next; + } else { + *res.prev_link = node; + } + + // from here, traverse cur and free the nodes, until we get to one + // that shouldn't be freed + while (cur->next && cur->next->x <= res.x + width) { + stbrp_node *next = cur->next; + // move the current node to the free list + cur->next = context->free_head; + context->free_head = cur; + cur = next; + } + + // stitch the list back in + node->next = cur; + + if (cur->x < res.x + width) + cur->x = (stbrp_coord) (res.x + width); + +#ifdef _DEBUG + cur = context->active_head; + while (cur->x < context->width) { + STBRP_ASSERT(cur->x < cur->next->x); + cur = cur->next; + } + STBRP_ASSERT(cur->next == NULL); + + { + int count=0; + cur = context->active_head; + while (cur) { + cur = cur->next; + ++count; + } + cur = context->free_head; + while (cur) { + cur = cur->next; + ++count; + } + STBRP_ASSERT(count == context->num_nodes+2); + } +#endif + + return res; +} + +static int rect_height_compare(const void *a, const void *b) +{ + const stbrp_rect *p = (const stbrp_rect *) a; + const stbrp_rect *q = (const stbrp_rect *) b; + if (p->h > q->h) + return -1; + if (p->h < q->h) + return 1; + return (p->w > q->w) ? -1 : (p->w < q->w); +} + +static int rect_original_order(const void *a, const void *b) +{ + const stbrp_rect *p = (const stbrp_rect *) a; + const stbrp_rect *q = (const stbrp_rect *) b; + return (p->was_packed < q->was_packed) ? -1 : (p->was_packed > q->was_packed); +} + +#ifdef STBRP_LARGE_RECTS +#define STBRP__MAXVAL 0xffffffff +#else +#define STBRP__MAXVAL 0xffff +#endif + +STBRP_DEF int stbrp_pack_rects(stbrp_context *context, stbrp_rect *rects, int num_rects) +{ + int i, all_rects_packed = 1; + + // we use the 'was_packed' field internally to allow sorting/unsorting + for (i=0; i < num_rects; ++i) { + rects[i].was_packed = i; + } + + // sort according to heuristic + STBRP_SORT(rects, num_rects, sizeof(rects[0]), rect_height_compare); + + for (i=0; i < num_rects; ++i) { + if (rects[i].w == 0 || rects[i].h == 0) { + rects[i].x = rects[i].y = 0; // empty rect needs no space + } else { + stbrp__findresult fr = stbrp__skyline_pack_rectangle(context, rects[i].w, rects[i].h); + if (fr.prev_link) { + rects[i].x = (stbrp_coord) fr.x; + rects[i].y = (stbrp_coord) fr.y; + } else { + rects[i].x = rects[i].y = STBRP__MAXVAL; + } + } + } + + // unsort + STBRP_SORT(rects, num_rects, sizeof(rects[0]), rect_original_order); + + // set was_packed flags and all_rects_packed status + for (i=0; i < num_rects; ++i) { + rects[i].was_packed = !(rects[i].x == STBRP__MAXVAL && rects[i].y == STBRP__MAXVAL); + if (!rects[i].was_packed) + all_rects_packed = 0; + } + + // return the all_rects_packed status + return all_rects_packed; +} +#endif + +/* +------------------------------------------------------------------------------ +This software is available under 2 licenses -- choose whichever you prefer. +------------------------------------------------------------------------------ +ALTERNATIVE A - MIT License +Copyright (c) 2017 Sean Barrett +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +------------------------------------------------------------------------------ +ALTERNATIVE B - Public Domain (www.unlicense.org) +This is free and unencumbered software released into the public domain. +Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +------------------------------------------------------------------------------ +*/ diff --git a/thirdparty/oidn/LICENSE.txt b/thirdparty/oidn/LICENSE.txt new file mode 100644 index 0000000000..d645695673 --- /dev/null +++ b/thirdparty/oidn/LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/thirdparty/oidn/common/barrier.h b/thirdparty/oidn/common/barrier.h new file mode 100644 index 0000000000..b20f670053 --- /dev/null +++ b/thirdparty/oidn/common/barrier.h @@ -0,0 +1,52 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include <mutex> +#include <condition_variable> + +namespace oidn { + + class Barrier + { + private: + std::mutex m; + std::condition_variable cv; + volatile int count; + + public: + Barrier(int count) : count(count) {} + + void wait() + { + std::unique_lock<std::mutex> lk(m); + count--; + + if (count == 0) + { + lk.unlock(); + cv.notify_all(); + } + else + { + cv.wait(lk, [&]{ return count == 0; }); + } + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/common/exception.h b/thirdparty/oidn/common/exception.h new file mode 100644 index 0000000000..18069c6a7d --- /dev/null +++ b/thirdparty/oidn/common/exception.h @@ -0,0 +1,45 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include <exception> +#include "platform.h" + +namespace oidn { + + class Exception : public std::exception + { + private: + Error error; + const char* message; + + public: + Exception(Error error, const char* message) + : error(error), message(message) {} + + Error code() const noexcept + { + return error; + } + + const char* what() const noexcept override + { + return message; + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/common/platform.cpp b/thirdparty/oidn/common/platform.cpp new file mode 100644 index 0000000000..59a14ff47c --- /dev/null +++ b/thirdparty/oidn/common/platform.cpp @@ -0,0 +1,114 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "platform.h" + +namespace oidn { + + // ---------------------------------------------------------------------------- + // Common functions + // ---------------------------------------------------------------------------- + + void* alignedMalloc(size_t size, size_t alignment) + { + if (size == 0) + return nullptr; + + assert((alignment & (alignment-1)) == 0); + void* ptr = _mm_malloc(size, alignment); + + if (ptr == nullptr) + throw std::bad_alloc(); + + return ptr; + } + + void alignedFree(void* ptr) + { + if (ptr) + _mm_free(ptr); + } + + // ---------------------------------------------------------------------------- + // System information + // ---------------------------------------------------------------------------- + + std::string getPlatformName() + { + std::string name; + + #if defined(__linux__) + name = "Linux"; + #elif defined(__FreeBSD__) + name = "FreeBSD"; + #elif defined(__CYGWIN__) + name = "Cygwin"; + #elif defined(_WIN32) + name = "Windows"; + #elif defined(__APPLE__) + name = "macOS"; + #elif defined(__unix__) + name = "Unix"; + #else + return "Unknown"; + #endif + + #if defined(__x86_64__) || defined(_M_X64) || defined(__ia64__) || defined(__aarch64__) + name += " (64-bit)"; + #else + name += " (32-bit)"; + #endif + + return name; + } + + std::string getCompilerName() + { + #if defined(__INTEL_COMPILER) + int mayor = __INTEL_COMPILER / 100 % 100; + int minor = __INTEL_COMPILER % 100; + std::string version = "Intel Compiler "; + version += toString(mayor); + version += "." + toString(minor); + #if defined(__INTEL_COMPILER_UPDATE) + version += "." + toString(__INTEL_COMPILER_UPDATE); + #endif + return version; + #elif defined(__clang__) + return "Clang " __clang_version__; + #elif defined(__GNUC__) + return "GCC " __VERSION__; + #elif defined(_MSC_VER) + std::string version = toString(_MSC_FULL_VER); + version.insert(4, "."); + version.insert(9, "."); + version.insert(2, "."); + return "Visual C++ Compiler " + version; + #else + return "Unknown"; + #endif + } + + std::string getBuildName() + { + #if defined(NDEBUG) + return "Release"; + #else + return "Debug"; + #endif + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/platform.h b/thirdparty/oidn/common/platform.h new file mode 100644 index 0000000000..205ac8981d --- /dev/null +++ b/thirdparty/oidn/common/platform.h @@ -0,0 +1,131 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #define NOMINMAX + #include <Windows.h> +#elif defined(__APPLE__) + #include <sys/sysctl.h> +#endif + +#include <xmmintrin.h> +#include <cstdint> +#include <climits> +#include <limits> +#include <atomic> +#include <algorithm> +#include <memory> +#include <cmath> +#include <string> +#include <sstream> +#include <iostream> +#include <cassert> +#include "include/OpenImageDenoise/oidn.hpp" + +namespace oidn { + + // ---------------------------------------------------------------------------- + // Macros + // ---------------------------------------------------------------------------- + + #if defined(_WIN32) + // Windows + #if !defined(__noinline) + #define __noinline __declspec(noinline) + #endif + #else + // Unix + #if !defined(__forceinline) + #define __forceinline inline __attribute__((always_inline)) + #endif + #if !defined(__noinline) + #define __noinline __attribute__((noinline)) + #endif + #endif + + #ifndef UNUSED + #define UNUSED(x) ((void)x) + #endif + #ifndef MAYBE_UNUSED + #define MAYBE_UNUSED(x) UNUSED(x) + #endif + + // ---------------------------------------------------------------------------- + // Error handling and debugging + // ---------------------------------------------------------------------------- + + struct Verbose + { + int verbose; + + Verbose(int v = 0) : verbose(v) {} + __forceinline bool isVerbose(int v = 1) const { return v <= verbose; } + }; + + #define OIDN_WARNING(message) { if (isVerbose()) std::cerr << "Warning: " << message << std::endl; } + #define OIDN_FATAL(message) throw std::runtime_error(message); + + // ---------------------------------------------------------------------------- + // Common functions + // ---------------------------------------------------------------------------- + + using std::min; + using std::max; + + template<typename T> + __forceinline T clamp(const T& value, const T& minValue, const T& maxValue) + { + return min(max(value, minValue), maxValue); + } + + void* alignedMalloc(size_t size, size_t alignment); + void alignedFree(void* ptr); + + template<typename T> + inline std::string toString(const T& a) + { + std::stringstream sm; + sm << a; + return sm.str(); + } + +#if defined(__APPLE__) + template<typename T> + bool getSysctl(const char* name, T& value) + { + int64_t result = 0; + size_t size = sizeof(result); + + if (sysctlbyname(name, &result, &size, nullptr, 0) != 0) + return false; + + value = T(result); + return true; + } +#endif + + // ---------------------------------------------------------------------------- + // System information + // ---------------------------------------------------------------------------- + + std::string getPlatformName(); + std::string getCompilerName(); + std::string getBuildName(); + +} // namespace oidn diff --git a/thirdparty/oidn/common/ref.h b/thirdparty/oidn/common/ref.h new file mode 100644 index 0000000000..de44603af2 --- /dev/null +++ b/thirdparty/oidn/common/ref.h @@ -0,0 +1,163 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" + +namespace oidn { + + class RefCount + { + private: + std::atomic<size_t> count; + + public: + __forceinline RefCount(int count = 0) noexcept : count(count) {} + + __forceinline size_t incRef() noexcept + { + return count.fetch_add(1) + 1; + } + + __forceinline size_t decRef() + { + const size_t newCount = decRefKeep(); + if (newCount == 0) + destroy(); + return newCount; + } + + __forceinline size_t decRefKeep() noexcept + { + return count.fetch_add(-1) - 1; + } + + __forceinline void destroy() + { + delete this; + } + + protected: + // Disable copying + RefCount(const RefCount&) = delete; + RefCount& operator =(const RefCount&) = delete; + + virtual ~RefCount() noexcept = default; + }; + + template<typename T> + class Ref + { + private: + T* ptr; + + public: + __forceinline Ref() noexcept : ptr(nullptr) {} + __forceinline Ref(std::nullptr_t) noexcept : ptr(nullptr) {} + __forceinline Ref(const Ref& other) noexcept : ptr(other.ptr) { if (ptr) ptr->incRef(); } + __forceinline Ref(Ref&& other) noexcept : ptr(other.ptr) { other.ptr = nullptr; } + __forceinline Ref(T* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); } + + template<typename Y> + __forceinline Ref(const Ref<Y>& other) noexcept : ptr(other.get()) { if (ptr) ptr->incRef(); } + + template<typename Y> + __forceinline explicit Ref(Y* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); } + + __forceinline ~Ref() { if (ptr) ptr->decRef(); } + + __forceinline Ref& operator =(const Ref& other) + { + if (other.ptr) + other.ptr->incRef(); + if (ptr) + ptr->decRef(); + ptr = other.ptr; + return *this; + } + + __forceinline Ref& operator =(Ref&& other) + { + if (ptr) + ptr->decRef(); + ptr = other.ptr; + other.ptr = nullptr; + return *this; + } + + __forceinline Ref& operator =(T* other) + { + if (other) + other->incRef(); + if (ptr) + ptr->decRef(); + ptr = other; + return *this; + } + + __forceinline Ref& operator =(std::nullptr_t) + { + if (ptr) + ptr->decRef(); + ptr = nullptr; + return *this; + } + + __forceinline operator bool() const noexcept { return ptr != nullptr; } + + __forceinline T& operator *() const noexcept { return *ptr; } + __forceinline T* operator ->() const noexcept { return ptr; } + + __forceinline T* get() const noexcept { return ptr; } + + __forceinline T* detach() noexcept + { + T* res = ptr; + ptr = nullptr; + return res; + } + }; + + template<typename T> __forceinline bool operator < (const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr < b.ptr; } + + template<typename T> __forceinline bool operator ==(const Ref<T>& a, std::nullptr_t) noexcept { return a.ptr == nullptr; } + template<typename T> __forceinline bool operator ==(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr == b.ptr; } + template<typename T> __forceinline bool operator ==(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr == b.ptr; } + + template<typename T> __forceinline bool operator !=(const Ref<T>& a, std::nullptr_t) noexcept { return a.ptr != nullptr; } + template<typename T> __forceinline bool operator !=(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr != b.ptr; } + template<typename T> __forceinline bool operator !=(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr != b.ptr; } + + template<typename T, typename... Args> + __forceinline Ref<T> makeRef(Args&&... args) + { + return Ref<T>(new T(std::forward<Args>(args)...)); + } + + template<typename T, typename Y> + __forceinline Ref<Y> staticRefCast(const Ref<T>& a) + { + return Ref<Y>(static_cast<Y*>(a.get())); + } + + template<typename T, typename Y> + __forceinline Ref<Y> dynamicRefCast(const Ref<T>& a) + { + return Ref<Y>(dynamic_cast<Y*>(a.get())); + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/tensor.cpp b/thirdparty/oidn/common/tensor.cpp new file mode 100644 index 0000000000..0249f2e141 --- /dev/null +++ b/thirdparty/oidn/common/tensor.cpp @@ -0,0 +1,83 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "exception.h" +#include "tensor.h" + +namespace oidn { + + std::map<std::string, Tensor> parseTensors(void* buffer) + { + char* input = (char*)buffer; + + // Parse the magic value + const int magic = *(unsigned short*)input; + if (magic != 0x41D7) + throw Exception(Error::InvalidOperation, "invalid tensor archive"); + input += sizeof(unsigned short); + + // Parse the version + const int majorVersion = *(unsigned char*)input++; + const int minorVersion = *(unsigned char*)input++; + UNUSED(minorVersion); + if (majorVersion > 1) + throw Exception(Error::InvalidOperation, "unsupported tensor archive version"); + + // Parse the number of tensors + const int numTensors = *(int*)input; + input += sizeof(int); + + // Parse the tensors + std::map<std::string, Tensor> tensorMap; + for (int i = 0; i < numTensors; ++i) + { + Tensor tensor; + + // Parse the name + const int nameLen = *(unsigned char*)input++; + std::string name(input, nameLen); + input += nameLen; + + // Parse the number of dimensions + const int ndims = *(unsigned char*)input++; + + // Parse the shape of the tensor + tensor.dims.resize(ndims); + for (int i = 0; i < ndims; ++i) + tensor.dims[i] = ((int*)input)[i]; + input += ndims * sizeof(int); + + // Parse the format of the tensor + tensor.format = std::string(input, input + ndims); + input += ndims; + + // Parse the data type of the tensor + const char type = *(unsigned char*)input++; + if (type != 'f') // only float32 is supported + throw Exception(Error::InvalidOperation, "unsupported tensor data type"); + + // Skip the data + tensor.data = (float*)input; + input += tensor.size() * sizeof(float); + + // Add the tensor to the map + tensorMap.emplace(name, std::move(tensor)); + } + + return tensorMap; + } + +} // namespace oidn diff --git a/thirdparty/oidn/common/tensor.h b/thirdparty/oidn/common/tensor.h new file mode 100644 index 0000000000..48e7d1123d --- /dev/null +++ b/thirdparty/oidn/common/tensor.h @@ -0,0 +1,66 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include <vector> +#include <map> + +namespace oidn { + + template<typename T> + using shared_vector = std::shared_ptr<std::vector<T>>; + + // Generic tensor + struct Tensor + { + float* data; + std::vector<int64_t> dims; + std::string format; + shared_vector<char> buffer; // optional, only for reference counting + + __forceinline Tensor() : data(nullptr) {} + + __forceinline Tensor(const std::vector<int64_t>& dims, const std::string& format) + : dims(dims), + format(format) + { + buffer = std::make_shared<std::vector<char>>(size() * sizeof(float)); + data = (float*)buffer->data(); + } + + __forceinline operator bool() const { return data != nullptr; } + + __forceinline int ndims() const { return (int)dims.size(); } + + // Returns the number of values + __forceinline size_t size() const + { + size_t size = 1; + for (int i = 0; i < ndims(); ++i) + size *= dims[i]; + return size; + } + + __forceinline float& operator [](size_t i) { return data[i]; } + __forceinline const float& operator [](size_t i) const { return data[i]; } + }; + + // Parses tensors from a buffer + std::map<std::string, Tensor> parseTensors(void* buffer); + +} // namespace oidn diff --git a/thirdparty/oidn/common/thread.cpp b/thirdparty/oidn/common/thread.cpp new file mode 100644 index 0000000000..48c489c57b --- /dev/null +++ b/thirdparty/oidn/common/thread.cpp @@ -0,0 +1,297 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#if defined(_MSC_VER) + #pragma warning (disable : 4146) // unary minus operator applied to unsigned type, result still unsigned +#endif + +#if defined(__APPLE__) + #include <mach/thread_act.h> + #include <mach/mach_init.h> +#endif + +#include "thread.h" +#include <fstream> + +namespace oidn { + +#if defined(_WIN32) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Windows + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + HMODULE hLib = GetModuleHandle(TEXT("kernel32")); + pGetLogicalProcessorInformationEx = (GetLogicalProcessorInformationExFunc)GetProcAddress(hLib, "GetLogicalProcessorInformationEx"); + pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hLib, "SetThreadGroupAffinity"); + + if (pGetLogicalProcessorInformationEx && pSetThreadGroupAffinity) + { + // Get logical processor information + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX buffer = nullptr; + DWORD bufferSize = 0; + + // First call the function with an empty buffer to get the required buffer size + BOOL result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize); + if (result || GetLastError() != ERROR_INSUFFICIENT_BUFFER) + { + OIDN_WARNING("GetLogicalProcessorInformationEx failed"); + return; + } + + // Allocate the buffer + buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)malloc(bufferSize); + if (!buffer) + { + OIDN_WARNING("SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX allocation failed"); + return; + } + + // Call again the function but now with the properly sized buffer + result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize); + if (!result) + { + OIDN_WARNING("GetLogicalProcessorInformationEx failed"); + free(buffer); + return; + } + + // Iterate over the logical processor information structures + // There should be one structure for each physical core + char* ptr = (char*)buffer; + while (ptr < (char*)buffer + bufferSize) + { + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX item = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)ptr; + if (item->Relationship == RelationProcessorCore && item->Processor.GroupCount > 0) + { + // Iterate over the groups + int numThreads = 0; + for (int group = 0; (group < item->Processor.GroupCount) && (numThreads < numThreadsPerCore); ++group) + { + GROUP_AFFINITY coreAffinity = item->Processor.GroupMask[group]; + while ((coreAffinity.Mask != 0) && (numThreads < numThreadsPerCore)) + { + // Extract the next set bit/thread from the mask + GROUP_AFFINITY threadAffinity = coreAffinity; + threadAffinity.Mask = threadAffinity.Mask & -threadAffinity.Mask; + + // Push the affinity for this thread + affinities.push_back(threadAffinity); + oldAffinities.push_back(threadAffinity); + numThreads++; + + // Remove this bit/thread from the mask + coreAffinity.Mask ^= threadAffinity.Mask; + } + } + } + + // Next structure + ptr += item->Size; + } + + // Free the buffer + free(buffer); + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + // Save the current affinity and set the new one + const HANDLE thread = GetCurrentThread(); + if (!pSetThreadGroupAffinity(thread, &affinities[threadIndex], &oldAffinities[threadIndex])) + OIDN_WARNING("SetThreadGroupAffinity failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + // Restore the original affinity + const HANDLE thread = GetCurrentThread(); + if (!pSetThreadGroupAffinity(thread, &oldAffinities[threadIndex], nullptr)) + OIDN_WARNING("SetThreadGroupAffinity failed"); + } + +#elif defined(__linux__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Linux + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + std::vector<int> threadIds; + + // Parse the thread/CPU topology + for (int cpuId = 0; ; cpuId++) + { + std::fstream fs; + std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string(cpuId) + std::string("/topology/thread_siblings_list"); + fs.open(cpu.c_str(), std::fstream::in); + if (fs.fail()) break; + + int i; + int j = 0; + while ((j < numThreadsPerCore) && (fs >> i)) + { + if (std::none_of(threadIds.begin(), threadIds.end(), [&](int id) { return id == i; })) + threadIds.push_back(i); + + if (fs.peek() == ',') + fs.ignore(); + j++; + } + + fs.close(); + } + + #if 0 + for (size_t i = 0; i < thread_ids.size(); ++i) + std::cout << "thread " << i << " -> " << thread_ids[i] << std::endl; + #endif + + // Create the affinity structures + affinities.resize(threadIds.size()); + oldAffinities.resize(threadIds.size()); + + for (size_t i = 0; i < threadIds.size(); ++i) + { + cpu_set_t affinity; + CPU_ZERO(&affinity); + CPU_SET(threadIds[i], &affinity); + + affinities[i] = affinity; + oldAffinities[i] = affinity; + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const pthread_t thread = pthread_self(); + + // Save the current affinity + if (pthread_getaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0) + { + OIDN_WARNING("pthread_getaffinity_np failed"); + oldAffinities[threadIndex] = affinities[threadIndex]; + return; + } + + // Set the new affinity + if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinities[threadIndex]) != 0) + OIDN_WARNING("pthread_setaffinity_np failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const pthread_t thread = pthread_self(); + + // Restore the original affinity + if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0) + OIDN_WARNING("pthread_setaffinity_np failed"); + } + +#elif defined(__APPLE__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - macOS + // -------------------------------------------------------------------------- + + ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose) + : Verbose(verbose) + { + // Query the thread/CPU topology + int numPhysicalCpus; + int numLogicalCpus; + + if (!getSysctl("hw.physicalcpu", numPhysicalCpus) || !getSysctl("hw.logicalcpu", numLogicalCpus)) + { + OIDN_WARNING("sysctlbyname failed"); + return; + } + + if ((numLogicalCpus % numPhysicalCpus != 0) && (numThreadsPerCore > 1)) + return; // this shouldn't happen + const int maxThreadsPerCore = numLogicalCpus / numPhysicalCpus; + + // Create the affinity structures + // macOS doesn't support binding a thread to a specific core, but we can at least group threads which + // should be on the same core together + for (int core = 1; core <= numPhysicalCpus; ++core) // tags start from 1! + { + thread_affinity_policy affinity; + affinity.affinity_tag = core; + + for (int thread = 0; thread < min(numThreadsPerCore, maxThreadsPerCore); ++thread) + { + affinities.push_back(affinity); + oldAffinities.push_back(affinity); + } + } + } + + void ThreadAffinity::set(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const auto thread = mach_thread_self(); + + // Save the current affinity + mach_msg_type_number_t policyCount = THREAD_AFFINITY_POLICY_COUNT; + boolean_t getDefault = FALSE; + if (thread_policy_get(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], &policyCount, &getDefault) != KERN_SUCCESS) + { + OIDN_WARNING("thread_policy_get failed"); + oldAffinities[threadIndex] = affinities[threadIndex]; + return; + } + + // Set the new affinity + if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&affinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) + OIDN_WARNING("thread_policy_set failed"); + } + + void ThreadAffinity::restore(int threadIndex) + { + if (threadIndex >= (int)affinities.size()) + return; + + const auto thread = mach_thread_self(); + + // Restore the original affinity + if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS) + OIDN_WARNING("thread_policy_set failed"); + } + +#endif + +} // namespace oidn diff --git a/thirdparty/oidn/common/thread.h b/thirdparty/oidn/common/thread.h new file mode 100644 index 0000000000..2c731367da --- /dev/null +++ b/thirdparty/oidn/common/thread.h @@ -0,0 +1,202 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" + +#if !defined(_WIN32) + #include <pthread.h> + #include <sched.h> + #if defined(__APPLE__) + #include <mach/thread_policy.h> + #endif +#endif + +#include <vector> +#include <mutex> + +namespace oidn { + + // -------------------------------------------------------------------------- + // ThreadLocal + // -------------------------------------------------------------------------- + + // Wrapper which makes any variable thread-local + template<typename T> + class ThreadLocal : public Verbose + { + private: + #if defined(_WIN32) + DWORD key; + #else + pthread_key_t key; + #endif + + std::vector<T*> instances; + std::mutex mutex; + + public: + ThreadLocal(int verbose = 0) + : Verbose(verbose) + { + #if defined(_WIN32) + key = TlsAlloc(); + if (key == TLS_OUT_OF_INDEXES) + OIDN_FATAL("TlsAlloc failed"); + #else + if (pthread_key_create(&key, nullptr) != 0) + OIDN_FATAL("pthread_key_create failed"); + #endif + } + + ~ThreadLocal() + { + std::lock_guard<std::mutex> lock(mutex); + for (T* ptr : instances) + delete ptr; + + #if defined(_WIN32) + if (!TlsFree(key)) + OIDN_WARNING("TlsFree failed"); + #else + if (pthread_key_delete(key) != 0) + OIDN_WARNING("pthread_key_delete failed"); + #endif + } + + T& get() + { + #if defined(_WIN32) + T* ptr = (T*)TlsGetValue(key); + #else + T* ptr = (T*)pthread_getspecific(key); + #endif + + if (ptr) + return *ptr; + + ptr = new T; + std::lock_guard<std::mutex> lock(mutex); + instances.push_back(ptr); + + #if defined(_WIN32) + if (!TlsSetValue(key, ptr)) + OIDN_FATAL("TlsSetValue failed"); + #else + if (pthread_setspecific(key, ptr) != 0) + OIDN_FATAL("pthread_setspecific failed"); + #endif + + return *ptr; + } + }; + +#if defined(_WIN32) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Windows + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + typedef BOOL (WINAPI *GetLogicalProcessorInformationExFunc)(LOGICAL_PROCESSOR_RELATIONSHIP, + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, + PDWORD); + + typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE, + CONST GROUP_AFFINITY*, + PGROUP_AFFINITY); + + GetLogicalProcessorInformationExFunc pGetLogicalProcessorInformationEx = nullptr; + SetThreadGroupAffinityFunc pSetThreadGroupAffinity = nullptr; + + std::vector<GROUP_AFFINITY> affinities; // thread affinities + std::vector<GROUP_AFFINITY> oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#elif defined(__linux__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - Linux + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + std::vector<cpu_set_t> affinities; // thread affinities + std::vector<cpu_set_t> oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#elif defined(__APPLE__) + + // -------------------------------------------------------------------------- + // ThreadAffinity - macOS + // -------------------------------------------------------------------------- + + class ThreadAffinity : public Verbose + { + private: + std::vector<thread_affinity_policy> affinities; // thread affinities + std::vector<thread_affinity_policy> oldAffinities; // original thread affinities + + public: + ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0); + + int getNumThreads() const + { + return (int)affinities.size(); + } + + // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity + void set(int threadIndex); + + // Restores the affinity of the thread + void restore(int threadIndex); + }; + +#endif + +} // namespace oidn diff --git a/thirdparty/oidn/common/timer.h b/thirdparty/oidn/common/timer.h new file mode 100644 index 0000000000..62aaaa1c33 --- /dev/null +++ b/thirdparty/oidn/common/timer.h @@ -0,0 +1,49 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "platform.h" +#include <chrono> + +namespace oidn { + + class Timer + { + private: + using clock = std::chrono::high_resolution_clock; + + std::chrono::time_point<clock> start; + + public: + Timer() + { + reset(); + } + + void reset() + { + start = clock::now(); + } + + double query() const + { + auto end = clock::now(); + return std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count(); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/api.cpp b/thirdparty/oidn/core/api.cpp new file mode 100644 index 0000000000..7353fe4e25 --- /dev/null +++ b/thirdparty/oidn/core/api.cpp @@ -0,0 +1,408 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#ifdef _WIN32 +# define OIDN_API extern "C" __declspec(dllexport) +#else +# define OIDN_API extern "C" __attribute__ ((visibility ("default"))) +#endif + +// Locks the device that owns the specified object +// Use *only* inside OIDN_TRY/CATCH! +#define OIDN_LOCK(obj) \ + std::lock_guard<std::mutex> lock(obj->getDevice()->getMutex()); + +// Try/catch for converting exceptions to errors +#define OIDN_TRY \ + try { + +#define OIDN_CATCH(obj) \ + } catch (Exception& e) { \ + Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what()); \ + } catch (std::bad_alloc&) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \ + } catch (mkldnn::error& e) { \ + if (e.status == mkldnn_out_of_memory) \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \ + else \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message); \ + } catch (std::exception& e) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what()); \ + } catch (...) { \ + Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \ + } + +#include "device.h" +#include "filter.h" +#include <mutex> + +namespace oidn { + + namespace + { + __forceinline void checkHandle(void* handle) + { + if (handle == nullptr) + throw Exception(Error::InvalidArgument, "invalid handle"); + } + + template<typename T> + __forceinline void retainObject(T* obj) + { + if (obj) + { + obj->incRef(); + } + else + { + OIDN_TRY + checkHandle(obj); + OIDN_CATCH(obj) + } + } + + template<typename T> + __forceinline void releaseObject(T* obj) + { + if (obj == nullptr || obj->decRefKeep() == 0) + { + OIDN_TRY + checkHandle(obj); + OIDN_LOCK(obj); + obj->destroy(); + OIDN_CATCH(obj) + } + } + + template<> + __forceinline void releaseObject(Device* obj) + { + if (obj == nullptr || obj->decRefKeep() == 0) + { + OIDN_TRY + checkHandle(obj); + // Do NOT lock the device because it owns the mutex + obj->destroy(); + OIDN_CATCH(obj) + } + } + } + + OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type) + { + Ref<Device> device = nullptr; + OIDN_TRY + if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT) + device = makeRef<Device>(); + else + throw Exception(Error::InvalidArgument, "invalid device type"); + OIDN_CATCH(device) + return (OIDNDevice)device.detach(); + } + + OIDN_API void oidnRetainDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + retainObject(device); + } + + OIDN_API void oidnReleaseDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + releaseObject(device); + } + + OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->set1i(name, value); + OIDN_CATCH(device) + } + + OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->set1i(name, value); + OIDN_CATCH(device) + } + + OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + return device->get1i(name); + OIDN_CATCH(device) + return false; + } + + OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + return device->get1i(name); + OIDN_CATCH(device) + return 0; + } + + OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->setErrorFunction((ErrorFunction)func, userPtr); + OIDN_CATCH(device) + } + + OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage) + { + Device* device = (Device*)hDevice; + OIDN_TRY + return (OIDNError)Device::getError(device, outMessage); + OIDN_CATCH(device) + if (outMessage) *outMessage = ""; + return OIDN_ERROR_UNKNOWN; + } + + OIDN_API void oidnCommitDevice(OIDNDevice hDevice) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + device->commit(); + OIDN_CATCH(device) + } + + OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref<Buffer> buffer = device->newBuffer(byteSize); + return (OIDNBuffer)buffer.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref<Buffer> buffer = device->newBuffer(ptr, byteSize); + return (OIDNBuffer)buffer.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer) + { + Buffer* buffer = (Buffer*)hBuffer; + retainObject(buffer); + } + + OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer) + { + Buffer* buffer = (Buffer*)hBuffer; + releaseObject(buffer); + } + + OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize) + { + Buffer* buffer = (Buffer*)hBuffer; + OIDN_TRY + checkHandle(hBuffer); + OIDN_LOCK(buffer); + return buffer->map(byteOffset, byteSize); + OIDN_CATCH(buffer) + return nullptr; + } + + OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr) + { + Buffer* buffer = (Buffer*)hBuffer; + OIDN_TRY + checkHandle(hBuffer); + OIDN_LOCK(buffer); + return buffer->unmap(mappedPtr); + OIDN_CATCH(buffer) + } + + OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type) + { + Device* device = (Device*)hDevice; + OIDN_TRY + checkHandle(hDevice); + OIDN_LOCK(device); + Ref<Filter> filter = device->newFilter(type); + return (OIDNFilter)filter.detach(); + OIDN_CATCH(device) + return nullptr; + } + + OIDN_API void oidnRetainFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + retainObject(filter); + } + + OIDN_API void oidnReleaseFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + releaseObject(filter); + } + + OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name, + OIDNBuffer hBuffer, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + checkHandle(hBuffer); + OIDN_LOCK(filter); + Ref<Buffer> buffer = (Buffer*)hBuffer; + if (buffer->getDevice() != filter->getDevice()) + throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices"); + Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); + filter->setImage(name, data); + OIDN_CATCH(filter) + } + + OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name, + void* ptr, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride); + filter->setImage(name, data); + OIDN_CATCH(filter) + } + + OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1i(name, int(value)); + OIDN_CATCH(filter) + } + + OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1i(name); + OIDN_CATCH(filter) + return false; + } + + OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1i(name, value); + OIDN_CATCH(filter) + } + + OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1i(name); + OIDN_CATCH(filter) + return 0; + } + + OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->set1f(name, value); + OIDN_CATCH(filter) + } + + OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + return filter->get1f(name); + OIDN_CATCH(filter) + return 0; + } + + OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->setProgressMonitorFunction(func, userPtr); + OIDN_CATCH(filter) + } + + OIDN_API void oidnCommitFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->commit(); + OIDN_CATCH(filter) + } + + OIDN_API void oidnExecuteFilter(OIDNFilter hFilter) + { + Filter* filter = (Filter*)hFilter; + OIDN_TRY + checkHandle(hFilter); + OIDN_LOCK(filter); + filter->execute(); + OIDN_CATCH(filter) + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/autoencoder.cpp b/thirdparty/oidn/core/autoencoder.cpp new file mode 100644 index 0000000000..8ae2421fa6 --- /dev/null +++ b/thirdparty/oidn/core/autoencoder.cpp @@ -0,0 +1,519 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "autoencoder.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // AutoencoderFilter + // -------------------------------------------------------------------------- + + AutoencoderFilter::AutoencoderFilter(const Ref<Device>& device) + : Filter(device) + { + } + + void AutoencoderFilter::setImage(const std::string& name, const Image& data) + { + if (name == "color") + color = data; + else if (name == "albedo") + albedo = data; + else if (name == "normal") + normal = data; + else if (name == "output") + output = data; + + dirty = true; + } + + void AutoencoderFilter::set1i(const std::string& name, int value) + { + if (name == "hdr") + hdr = value; + else if (name == "srgb") + srgb = value; + else if (name == "maxMemoryMB") + maxMemoryMB = value; + + dirty = true; + } + + int AutoencoderFilter::get1i(const std::string& name) + { + if (name == "hdr") + return hdr; + else if (name == "srgb") + return srgb; + else if (name == "maxMemoryMB") + return maxMemoryMB; + else if (name == "alignment") + return alignment; + else if (name == "overlap") + return overlap; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void AutoencoderFilter::set1f(const std::string& name, float value) + { + if (name == "hdrScale") + hdrScale = value; + + dirty = true; + } + + float AutoencoderFilter::get1f(const std::string& name) + { + if (name == "hdrScale") + return hdrScale; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void AutoencoderFilter::commit() + { + if (!dirty) + return; + + { + if (mayiuse(avx512_common)) + net = buildNet<16>(); + else + net = buildNet<8>(); + } + + dirty = false; + } + + void AutoencoderFilter::execute() + { + if (dirty) + throw Exception(Error::InvalidOperation, "changes to the filter are not committed"); + + if (!net) + return; + + { + Progress progress; + progress.func = progressFunc; + progress.userPtr = progressUserPtr; + progress.taskCount = tileCountH * tileCountW; + + // Iterate over the tiles + int tileIndex = 0; + + for (int i = 0; i < tileCountH; ++i) + { + const int h = i * (tileH - 2*overlap); // input tile position (including overlap) + const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top + const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom + const int tileH1 = min(H - h, tileH); // input tile size (including overlap) + const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size + const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer + + for (int j = 0; j < tileCountW; ++j) + { + const int w = j * (tileW - 2*overlap); // input tile position (including overlap) + const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left + const int overlapEndW = j < tileCountW-1 ? overlap : 0; // overlap on the right + const int tileW1 = min(W - w, tileW); // input tile size (including overlap) + const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size + const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer + + // Set the input tile + inputReorder->setTile(h, w, + alignOffsetH, alignOffsetW, + tileH1, tileW1); + + // Set the output tile + outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW, + h + overlapBeginH, w + overlapBeginW, + tileH2, tileW2); + + //printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2); + + // Denoise the tile + net->execute(progress, tileIndex); + + // Next tile + tileIndex++; + } + } + } + } + + void AutoencoderFilter::computeTileSize() + { + const int minTileSize = 3*overlap; + const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8; + const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel; + + tileCountH = 1; + tileCountW = 1; + tileH = roundUp(H, alignment); + tileW = roundUp(W, alignment); + + // Divide the image into tiles until the tile size gets below the threshold + while (int64_t(tileH) * tileW > maxTilePixels) + { + if (tileH > minTileSize && tileH > tileW) + { + tileCountH++; + tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize); + } + else if (tileW > minTileSize) + { + tileCountW++; + tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize); + } + else + break; + } + + // Compute the final number of tiles + tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1; + tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1; + + if (device->isVerbose(2)) + { + std::cout << "Tile size : " << tileW << "x" << tileH << std::endl; + std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl; + } + } + + template<int K> + std::shared_ptr<Executable> AutoencoderFilter::buildNet() + { + H = color.height; + W = color.width; + + // Configure the network + int inputC; + void* weightPtr; + + if (srgb && hdr) + throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time"); + + if (color && !albedo && !normal && weightData.hdr) + { + inputC = 3; + weightPtr = hdr ? weightData.hdr : weightData.ldr; + } + else if (color && albedo && !normal && weightData.hdr_alb) + { + inputC = 6; + weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb; + } + else if (color && albedo && normal && weightData.hdr_alb_nrm) + { + inputC = 9; + weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm; + } + else + { + throw Exception(Error::InvalidOperation, "unsupported combination of input features"); + } + + if (!output) + throw Exception(Error::InvalidOperation, "output image not specified"); + + if ((color.format != Format::Float3) + || (albedo && albedo.format != Format::Float3) + || (normal && normal.format != Format::Float3) + || (output.format != Format::Float3)) + throw Exception(Error::InvalidOperation, "unsupported image format"); + + if ((albedo && (albedo.width != W || albedo.height != H)) + || (normal && (normal.width != W || normal.height != H)) + || (output.width != W || output.height != H)) + throw Exception(Error::InvalidOperation, "image size mismatch"); + + // Compute the tile size + computeTileSize(); + + // If the image size is zero, there is nothing else to do + if (H <= 0 || W <= 0) + return nullptr; + + // Parse the weights + const auto weightMap = parseTensors(weightPtr); + + // Create the network + std::shared_ptr<Network<K>> net = std::make_shared<Network<K>>(device, weightMap); + + // Compute the tensor sizes + const auto inputDims = memory::dims({1, inputC, tileH, tileW}); + const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0 + + const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0 + const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1 + const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1 + const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0 + const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2 + const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0 + const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3 + const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0 + const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4 + const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0 + const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1 + const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4 + const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims); + const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0 + const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1 + const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3 + const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims); + const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0 + const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1 + const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2 + const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims); + const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0 + const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1 + const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1 + const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims); + const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0 + const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1 + const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0 + const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims); + const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0 + const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1 + const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0 + + const auto outputDims = memory::dims({1, 3, tileH, tileW}); + + // Allocate two temporary ping-pong buffers to decrease memory usage + const auto temp0Dims = getMaxTensorDims({ + conv1Dims, + conv2Dims, + conv3Dims, + conv4Dims, + conv5Dims, + conv6Dims, + conv7Dims, + conv8Dims, + conv9Dims, + conv10Dims, + conv11Dims + }); + + const auto temp1Dims = getMaxTensorDims({ + conv1bDims, + pool5Dims, + conv6bDims, + conv7bDims, + conv8bDims, + conv9bDims, + conv10bDims, + }); + + auto temp0 = net->allocTensor(temp0Dims); + auto temp1 = net->allocTensor(temp1Dims); + + // Allocate enough memory to hold the concat outputs. Then use the first + // half to hold the previous conv output and the second half to hold the + // pool/orig image output. This works because everything is C dimension + // outermost, padded to K floats, and all the concats are on the C dimension. + auto concat0Dst = net->allocTensor(concat0Dims); + auto concat1Dst = net->allocTensor(concat1Dims); + auto concat2Dst = net->allocTensor(concat2Dims); + auto concat3Dst = net->allocTensor(concat3Dims); + auto concat4Dst = net->allocTensor(concat4Dims); + + // Transfer function + std::shared_ptr<TransferFunction> transferFunc = makeTransferFunc(); + + // Autoexposure + if (auto tf = std::dynamic_pointer_cast<HDRTransferFunction>(transferFunc)) + { + if (isnan(hdrScale)) + net->addAutoexposure(color, tf); + else + tf->setExposure(hdrScale); + } + + // Input reorder + auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims); + inputReorder = net->addInputReorder(color, albedo, normal, + transferFunc, + alignment, inputReorderDst); + + // conv1 + auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0); + + // conv1b + auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1); + + // pool1 + // Adjust pointer for pool1 to eliminate concat1 + auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims); + auto pool1 = net->addPool(conv1b->getDst(), pool1Dst); + + // conv2 + auto conv2 = net->addConv("conv2", pool1->getDst(), temp0); + + // pool2 + // Adjust pointer for pool2 to eliminate concat2 + auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims); + auto pool2 = net->addPool(conv2->getDst(), pool2Dst); + + // conv3 + auto conv3 = net->addConv("conv3", pool2->getDst(), temp0); + + // pool3 + // Adjust pointer for pool3 to eliminate concat3 + auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims); + auto pool3 = net->addPool(conv3->getDst(), pool3Dst); + + // conv4 + auto conv4 = net->addConv("conv4", pool3->getDst(), temp0); + + // pool4 + // Adjust pointer for pool4 to eliminate concat4 + auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims); + auto pool4 = net->addPool(conv4->getDst(), pool4Dst); + + // conv5 + auto conv5 = net->addConv("conv5", pool4->getDst(), temp0); + + // pool5 + auto pool5 = net->addPool(conv5->getDst(), temp1); + + // upsample4 + auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst); + auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst); + + // conv6 + auto conv6 = net->addConv("conv6", concat4Dst, temp0); + + // conv6b + auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1); + + // upsample3 + auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst); + auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst); + + // conv7 + auto conv7 = net->addConv("conv7", concat3Dst, temp0); + + // conv7b + auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1); + + // upsample2 + auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst); + auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst); + + // conv8 + auto conv8 = net->addConv("conv8", concat2Dst, temp0); + + // conv8b + auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1); + + // upsample1 + auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst); + auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst); + + // conv9 + auto conv9 = net->addConv("conv9", concat1Dst, temp0); + + // conv9b + auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1); + + // upsample0 + auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst); + auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst); + + // conv10 + auto conv10 = net->addConv("conv10", concat0Dst, temp0); + + // conv10b + auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1); + + // conv11 + auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */); + + // Output reorder + outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output); + + net->finalize(); + return net; + } + + std::shared_ptr<TransferFunction> AutoencoderFilter::makeTransferFunc() + { + if (hdr) + return std::make_shared<PQXTransferFunction>(); + else if (srgb) + return std::make_shared<LinearTransferFunction>(); + else + return std::make_shared<GammaTransferFunction>(); + } + +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 + // -------------------------------------------------------------------------- + // RTFilter + // -------------------------------------------------------------------------- + + namespace weights + { + // LDR + extern unsigned char rt_ldr[]; // color + extern unsigned char rt_ldr_alb[]; // color, albedo + extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal + + // HDR + extern unsigned char rt_hdr[]; // color + extern unsigned char rt_hdr_alb[]; // color, albedo + extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal + } + + RTFilter::RTFilter(const Ref<Device>& device) + : AutoencoderFilter(device) + { + weightData.ldr = weights::rt_ldr; + weightData.ldr_alb = weights::rt_ldr_alb; + weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm; + weightData.hdr = weights::rt_hdr; + weightData.hdr_alb = weights::rt_hdr_alb; + weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm; + } +#endif + + // -------------------------------------------------------------------------- + // RTLightmapFilter + // -------------------------------------------------------------------------- + + namespace weights + { + // HDR + extern unsigned char rtlightmap_hdr[]; // color + } + + RTLightmapFilter::RTLightmapFilter(const Ref<Device>& device) + : AutoencoderFilter(device) + { + weightData.hdr = weights::rtlightmap_hdr; + + hdr = true; + } + + std::shared_ptr<TransferFunction> RTLightmapFilter::makeTransferFunc() + { + return std::make_shared<LogTransferFunction>(); + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/autoencoder.h b/thirdparty/oidn/core/autoencoder.h new file mode 100644 index 0000000000..97432f2bbd --- /dev/null +++ b/thirdparty/oidn/core/autoencoder.h @@ -0,0 +1,116 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "filter.h" +#include "network.h" +#include "transfer_function.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // AutoencoderFilter - Direct-predicting autoencoder + // -------------------------------------------------------------------------- + + class AutoencoderFilter : public Filter + { + protected: + static constexpr int alignment = 32; // required spatial alignment in pixels (padding may be necessary) + static constexpr int receptiveField = 222; // receptive field in pixels + static constexpr int overlap = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels + + static constexpr int estimatedBytesBase = 16*1024*1024; // estimated base memory usage + static constexpr int estimatedBytesPerPixel8 = 889; // estimated memory usage per pixel for K=8 + static constexpr int estimatedBytesPerPixel16 = 2185; // estimated memory usage per pixel for K=16 + + Image color; + Image albedo; + Image normal; + Image output; + bool hdr = false; + float hdrScale = std::numeric_limits<float>::quiet_NaN(); + bool srgb = false; + int maxMemoryMB = 6000; // approximate maximum memory usage in MBs + + int H = 0; // image height + int W = 0; // image width + int tileH = 0; // tile height + int tileW = 0; // tile width + int tileCountH = 1; // number of tiles in H dimension + int tileCountW = 1; // number of tiles in W dimension + + std::shared_ptr<Executable> net; + std::shared_ptr<Node> inputReorder; + std::shared_ptr<Node> outputReorder; + + struct + { + void* ldr = nullptr; + void* ldr_alb = nullptr; + void* ldr_alb_nrm = nullptr; + void* hdr = nullptr; + void* hdr_alb = nullptr; + void* hdr_alb_nrm = nullptr; + } weightData; + + explicit AutoencoderFilter(const Ref<Device>& device); + virtual std::shared_ptr<TransferFunction> makeTransferFunc(); + + public: + void setImage(const std::string& name, const Image& data) override; + void set1i(const std::string& name, int value) override; + int get1i(const std::string& name) override; + void set1f(const std::string& name, float value) override; + float get1f(const std::string& name) override; + + void commit() override; + void execute() override; + + private: + void computeTileSize(); + + template<int K> + std::shared_ptr<Executable> buildNet(); + + bool isCommitted() const { return bool(net); } + }; + + // -------------------------------------------------------------------------- + // RTFilter - Generic ray tracing denoiser + // -------------------------------------------------------------------------- + +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 + class RTFilter : public AutoencoderFilter + { + public: + explicit RTFilter(const Ref<Device>& device); + }; +#endif + + // -------------------------------------------------------------------------- + // RTLightmapFilter - Ray traced lightmap denoiser + // -------------------------------------------------------------------------- + + class RTLightmapFilter : public AutoencoderFilter + { + public: + explicit RTLightmapFilter(const Ref<Device>& device); + std::shared_ptr<TransferFunction> makeTransferFunc() override; + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/buffer.h b/thirdparty/oidn/core/buffer.h new file mode 100644 index 0000000000..b95109152e --- /dev/null +++ b/thirdparty/oidn/core/buffer.h @@ -0,0 +1,75 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "device.h" + +namespace oidn { + + class Device; + + // Buffer which may or may not own its data + class Buffer : public RefCount + { + private: + char* ptr; + size_t byteSize; + bool shared; + Ref<Device> device; + + public: + __forceinline Buffer(const Ref<Device>& device, size_t size) + : ptr((char*)alignedMalloc(size, 64)), + byteSize(size), + shared(false), + device(device) {} + + __forceinline Buffer(const Ref<Device>& device, void* data, size_t size) + : ptr((char*)data), + byteSize(size), + shared(true), + device(device) + { + if (data == nullptr) + throw Exception(Error::InvalidArgument, "buffer pointer null"); + } + + __forceinline ~Buffer() + { + if (!shared) + alignedFree(ptr); + } + + __forceinline char* data() { return ptr; } + __forceinline const char* data() const { return ptr; } + __forceinline size_t size() const { return byteSize; } + + void* map(size_t offset, size_t size) + { + if (offset + size > byteSize) + throw Exception(Error::InvalidArgument, "buffer region out of range"); + + return ptr + offset; + } + + void unmap(void* mappedPtr) {} + + Device* getDevice() { return device.get(); } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/common.h b/thirdparty/oidn/core/common.h new file mode 100644 index 0000000000..6c87f377bc --- /dev/null +++ b/thirdparty/oidn/core/common.h @@ -0,0 +1,133 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common/platform.h" + +#include "mkl-dnn/include/mkldnn.hpp" +#include "mkl-dnn/include/mkldnn_debug.h" +#include "mkl-dnn/src/common/mkldnn_thread.hpp" +#include "mkl-dnn/src/common/type_helpers.hpp" +#include "mkl-dnn/src/cpu/jit_generator.hpp" + +#include "common/ref.h" +#include "common/exception.h" +#include "common/thread.h" +#include "math.h" + +namespace oidn { + + using namespace mkldnn; + using namespace mkldnn::impl::cpu; + using mkldnn::impl::parallel_nd; + using mkldnn::impl::memory_desc_matches_tag; + + + inline size_t getFormatBytes(Format format) + { + switch (format) + { + case Format::Undefined: return 1; + case Format::Float: return sizeof(float); + case Format::Float2: return sizeof(float)*2; + case Format::Float3: return sizeof(float)*3; + case Format::Float4: return sizeof(float)*4; + } + assert(0); + return 0; + } + + + inline memory::dims getTensorDims(const std::shared_ptr<memory>& mem) + { + const mkldnn_memory_desc_t& desc = mem->get_desc().data; + return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]); + } + + inline memory::data_type getTensorType(const std::shared_ptr<memory>& mem) + { + const mkldnn_memory_desc_t& desc = mem->get_desc().data; + return memory::data_type(desc.data_type); + } + + // Returns the number of values in a tensor + inline size_t getTensorSize(const memory::dims& dims) + { + size_t res = 1; + for (int i = 0; i < (int)dims.size(); ++i) + res *= dims[i]; + return res; + } + + inline memory::dims getMaxTensorDims(const std::vector<memory::dims>& dims) + { + memory::dims result; + size_t maxSize = 0; + + for (const auto& d : dims) + { + const size_t size = getTensorSize(d); + if (size > maxSize) + { + result = d; + maxSize = size; + } + } + + return result; + } + + inline size_t getTensorSize(const std::shared_ptr<memory>& mem) + { + return getTensorSize(getTensorDims(mem)); + } + + + template<int K> + inline int getPadded(int dim) + { + return (dim + (K-1)) & ~(K-1); + } + + template<int K> + inline memory::dims getPadded_nchw(const memory::dims& dims) + { + assert(dims.size() == 4); + memory::dims padDims = dims; + padDims[1] = getPadded<K>(dims[1]); // pad C + return padDims; + } + + + template<int K> + struct BlockedFormat; + + template<> + struct BlockedFormat<8> + { + static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c; + static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o; + }; + + template<> + struct BlockedFormat<16> + { + static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c; + static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o; + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/device.cpp b/thirdparty/oidn/core/device.cpp new file mode 100644 index 0000000000..0812624bb5 --- /dev/null +++ b/thirdparty/oidn/core/device.cpp @@ -0,0 +1,205 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "device.h" +#include "autoencoder.h" + +namespace oidn { + + thread_local Device::ErrorState Device::globalError; + + Device::Device() + { + if (!mayiuse(sse41)) + throw Exception(Error::UnsupportedHardware, "SSE4.1 support is required at minimum"); + } + + Device::~Device() + { + } + + void Device::setError(Device* device, Error code, const std::string& message) + { + // Update the stored error only if the previous error was queried + if (device) + { + ErrorState& curError = device->error.get(); + + if (curError.code == Error::None) + { + curError.code = code; + curError.message = message; + } + + // Print the error message in verbose mode + if (device->isVerbose()) + std::cerr << "Error: " << message << std::endl; + + // Call the error callback function + ErrorFunction errorFunc; + void* errorUserPtr; + + { + std::lock_guard<std::mutex> lock(device->mutex); + errorFunc = device->errorFunc; + errorUserPtr = device->errorUserPtr; + } + + if (errorFunc) + errorFunc(errorUserPtr, code, (code == Error::None) ? nullptr : message.c_str()); + } + else + { + if (globalError.code == Error::None) + { + globalError.code = code; + globalError.message = message; + } + } + } + + Error Device::getError(Device* device, const char** outMessage) + { + // Return and clear the stored error code, but keep the error message so pointers to it will + // remain valid until the next getError call + if (device) + { + ErrorState& curError = device->error.get(); + const Error code = curError.code; + if (outMessage) + *outMessage = (code == Error::None) ? nullptr : curError.message.c_str(); + curError.code = Error::None; + return code; + } + else + { + const Error code = globalError.code; + if (outMessage) + *outMessage = (code == Error::None) ? nullptr : globalError.message.c_str(); + globalError.code = Error::None; + return code; + } + } + + void Device::setErrorFunction(ErrorFunction func, void* userPtr) + { + errorFunc = func; + errorUserPtr = userPtr; + } + + int Device::get1i(const std::string& name) + { + if (name == "numThreads") + return numThreads; + else if (name == "setAffinity") + return setAffinity; + else if (name == "verbose") + return verbose; + else if (name == "version") + return OIDN_VERSION; + else if (name == "versionMajor") + return OIDN_VERSION_MAJOR; + else if (name == "versionMinor") + return OIDN_VERSION_MINOR; + else if (name == "versionPatch") + return OIDN_VERSION_PATCH; + else + throw Exception(Error::InvalidArgument, "invalid parameter"); + } + + void Device::set1i(const std::string& name, int value) + { + if (name == "numThreads") + numThreads = value; + else if (name == "setAffinity") + setAffinity = value; + else if (name == "verbose") + { + verbose = value; + error.verbose = value; + } + + dirty = true; + } + + void Device::commit() + { + if (isCommitted()) + throw Exception(Error::InvalidOperation, "device can be committed only once"); + + // Create the task arena + const int maxNumThreads = 1; //affinity ? affinity->getNumThreads() : tbb::this_task_arena::max_concurrency(); + numThreads = (numThreads > 0) ? min(numThreads, maxNumThreads) : maxNumThreads; + + dirty = false; + + if (isVerbose()) + print(); + } + + void Device::checkCommitted() + { + if (dirty) + throw Exception(Error::InvalidOperation, "changes to the device are not committed"); + } + + Ref<Buffer> Device::newBuffer(size_t byteSize) + { + checkCommitted(); + return makeRef<Buffer>(Ref<Device>(this), byteSize); + } + + Ref<Buffer> Device::newBuffer(void* ptr, size_t byteSize) + { + checkCommitted(); + return makeRef<Buffer>(Ref<Device>(this), ptr, byteSize); + } + + Ref<Filter> Device::newFilter(const std::string& type) + { + checkCommitted(); + + if (isVerbose()) + std::cout << "Filter: " << type << std::endl; + + Ref<Filter> filter; + +// Godot doesn't need Raytracing filters. Removing them saves space in the weights files. +#if 0 + if (type == "RT") + filter = makeRef<RTFilter>(Ref<Device>(this)); +#endif + if (type == "RTLightmap") + filter = makeRef<RTLightmapFilter>(Ref<Device>(this)); + else + throw Exception(Error::InvalidArgument, "unknown filter type"); + + return filter; + } + + void Device::print() + { + std::cout << std::endl; + + std::cout << "Intel(R) Open Image Denoise " << OIDN_VERSION_STRING << std::endl; + std::cout << " Compiler: " << getCompilerName() << std::endl; + std::cout << " Build : " << getBuildName() << std::endl; + std::cout << " Platform: " << getPlatformName() << std::endl; + + std::cout << std::endl; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/device.h b/thirdparty/oidn/core/device.h new file mode 100644 index 0000000000..93a83eb731 --- /dev/null +++ b/thirdparty/oidn/core/device.h @@ -0,0 +1,78 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" + +namespace oidn { + + class Buffer; + class Filter; + + class Device : public RefCount, public Verbose + { + private: + // Thread-safety + std::mutex mutex; + + // Error handling + struct ErrorState + { + Error code = Error::None; + std::string message; + }; + + static thread_local ErrorState globalError; + ThreadLocal<ErrorState> error; + ErrorFunction errorFunc = nullptr; + void* errorUserPtr = nullptr; + + // Parameters + int numThreads = 0; // autodetect by default + bool setAffinity = true; + + bool dirty = true; + + public: + Device(); + ~Device(); + + static void setError(Device* device, Error code, const std::string& message); + static Error getError(Device* device, const char** outMessage); + + void setErrorFunction(ErrorFunction func, void* userPtr); + + int get1i(const std::string& name); + void set1i(const std::string& name, int value); + + void commit(); + + Ref<Buffer> newBuffer(size_t byteSize); + Ref<Buffer> newBuffer(void* ptr, size_t byteSize); + Ref<Filter> newFilter(const std::string& type); + + __forceinline Device* getDevice() { return this; } + __forceinline std::mutex& getMutex() { return mutex; } + + private: + bool isCommitted() const { return false; } + void checkCommitted(); + + void print(); + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/filter.cpp b/thirdparty/oidn/core/filter.cpp new file mode 100644 index 0000000000..ec1f10af87 --- /dev/null +++ b/thirdparty/oidn/core/filter.cpp @@ -0,0 +1,27 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "filter.h" + +namespace oidn { + + void Filter::setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr) + { + progressFunc = func; + progressUserPtr = userPtr; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/filter.h b/thirdparty/oidn/core/filter.h new file mode 100644 index 0000000000..935fa202f4 --- /dev/null +++ b/thirdparty/oidn/core/filter.h @@ -0,0 +1,52 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "device.h" +#include "image.h" + +namespace oidn { + + class Filter : public RefCount + { + protected: + Ref<Device> device; + + ProgressMonitorFunction progressFunc = nullptr; + void* progressUserPtr = nullptr; + + bool dirty = true; + + public: + explicit Filter(const Ref<Device>& device) : device(device) {} + + virtual void setImage(const std::string& name, const Image& data) = 0; + virtual void set1i(const std::string& name, int value) = 0; + virtual int get1i(const std::string& name) = 0; + virtual void set1f(const std::string& name, float value) = 0; + virtual float get1f(const std::string& name) = 0; + + void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr); + + virtual void commit() = 0; + virtual void execute() = 0; + + Device* getDevice() { return device.get(); } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/image.h b/thirdparty/oidn/core/image.h new file mode 100644 index 0000000000..748f49c4e5 --- /dev/null +++ b/thirdparty/oidn/core/image.h @@ -0,0 +1,111 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include "buffer.h" + +namespace oidn { + + struct Image + { + static constexpr int maxSize = 65536; + + char* ptr; // pointer to the first pixel + int width; // width in number of pixels + int height; // height in number of pixels + size_t bytePixelStride; // pixel stride in number of *bytes* + size_t rowStride; // row stride in number of *pixel strides* + Format format; // pixel format + Ref<Buffer> buffer; // buffer containing the image data + + Image() : ptr(nullptr), width(0), height(0), bytePixelStride(0), rowStride(0), format(Format::Undefined) {} + + Image(void* ptr, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride) + { + if (ptr == nullptr) + throw Exception(Error::InvalidArgument, "buffer pointer null"); + + init((char*)ptr + byteOffset, format, width, height, inBytePixelStride, inByteRowStride); + } + + Image(const Ref<Buffer>& buffer, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride) + { + init(buffer->data() + byteOffset, format, width, height, inBytePixelStride, inByteRowStride); + + if (byteOffset + height * rowStride * bytePixelStride > buffer->size()) + throw Exception(Error::InvalidArgument, "buffer region out of range"); + } + + void init(char* ptr, Format format, int width, int height, size_t inBytePixelStride, size_t inByteRowStride) + { + assert(width >= 0); + assert(height >= 0); + if (width > maxSize || height > maxSize) + throw Exception(Error::InvalidArgument, "image size too large"); + + this->ptr = ptr; + this->width = width; + this->height = height; + + const size_t pixelSize = getFormatBytes(format); + if (inBytePixelStride != 0) + { + if (inBytePixelStride < pixelSize) + throw Exception(Error::InvalidArgument, "pixel stride smaller than pixel size"); + + this->bytePixelStride = inBytePixelStride; + } + else + { + this->bytePixelStride = pixelSize; + } + + if (inByteRowStride != 0) + { + if (inByteRowStride < width * this->bytePixelStride) + throw Exception(Error::InvalidArgument, "row stride smaller than width * pixel stride"); + if (inByteRowStride % this->bytePixelStride != 0) + throw Exception(Error::InvalidArgument, "row stride not integer multiple of pixel stride"); + + this->rowStride = inByteRowStride / this->bytePixelStride; + } + else + { + this->rowStride = width; + } + + this->format = format; + } + + __forceinline char* get(int y, int x) + { + return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride); + } + + __forceinline const char* get(int y, int x) const + { + return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride); + } + + operator bool() const + { + return ptr != nullptr; + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/input_reorder.h b/thirdparty/oidn/core/input_reorder.h new file mode 100644 index 0000000000..966856afe9 --- /dev/null +++ b/thirdparty/oidn/core/input_reorder.h @@ -0,0 +1,232 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" +#include "image.h" + +namespace oidn { + + // Input reorder node + template<int K, class TransferFunction> + class InputReorderNode : public Node + { + private: + // Source + Image color; + Image albedo; + Image normal; + + // Destination + std::shared_ptr<memory> dst; + float* dstPtr; + int C2; + int H2; + int W2; + + // Tile + int h1Begin; + int w1Begin; + int h2Begin; + int w2Begin; + int H; + int W; + + std::shared_ptr<TransferFunction> transferFunc; + + public: + InputReorderNode(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr<memory>& dst, + const std::shared_ptr<TransferFunction>& transferFunc) + : color(color), albedo(albedo), normal(normal), + dst(dst), + h1Begin(0), w1Begin(0), + H(color.height), W(color.width), + transferFunc(transferFunc) + { + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc))); + assert(dstDesc.ndims == 4); + assert(dstDesc.data_type == memory::data_type::f32); + assert(dstDesc.dims[0] == 1); + //assert(dstDesc.dims[1] >= getPadded<K>(C1)); + + dstPtr = (float*)dst->get_data_handle(); + C2 = dstDesc.dims[1]; + H2 = dstDesc.dims[2]; + W2 = dstDesc.dims[3]; + } + + void setTile(int h1, int w1, int h2, int w2, int H, int W) override + { + h1Begin = h1; + w1Begin = w1; + h2Begin = h2; + w2Begin = w2; + this->H = H; + this->W = W; + } + + void execute(stream& sm) override + { + assert(H + h1Begin <= color.height); + assert(W + w1Begin <= color.width); + assert(H + h2Begin <= H2); + assert(W + w2Begin <= W2); + + parallel_nd(H2, [&](int h2) + { + const int h = h2 - h2Begin; + + if (h >= 0 && h < H) + { + const int h1 = h + h1Begin; + + // Zero pad + for (int w2 = 0; w2 < w2Begin; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + + // Reorder + for (int w = 0; w < W; ++w) + { + const int w1 = w + w1Begin; + const int w2 = w + w2Begin; + + int c = 0; + storeColor(h2, w2, c, (float*)color.get(h1, w1)); + if (albedo) + storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1)); + if (normal) + storeNormal(h2, w2, c, (float*)normal.get(h1, w1)); + while (c < C2) + store(h2, w2, c, 0.f); + } + + // Zero pad + for (int w2 = W + w2Begin; w2 < W2; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + } + else + { + // Zero pad + for (int w2 = 0; w2 < W2; ++w2) + { + int c = 0; + while (c < C2) + store(h2, w2, c, 0.f); + } + } + }); + } + + std::shared_ptr<memory> getDst() const override { return dst; } + + private: + // Stores a single value + __forceinline void store(int h, int w, int& c, float value) + { + // Destination is in nChwKc format + float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K); + *dst_c = value; + c++; + } + + // Stores a color + __forceinline void storeColor(int h, int w, int& c, const float* values) + { + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = values[i]; + + // Sanitize the value + x = maxSafe(x, 0.f); + + // Apply the transfer function + x = transferFunc->forward(x); + + // Store the value + store(h, w, c, x); + } + } + + // Stores an albedo + __forceinline void storeAlbedo(int h, int w, int& c, const float* values) + { + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = values[i]; + + // Sanitize the value + x = clampSafe(x, 0.f, 1.f); + + // Store the value + store(h, w, c, x); + } + } + + // Stores a normal + __forceinline void storeNormal(int h, int w, int& c, const float* values) + { + // Load the normal + float x = values[0]; + float y = values[1]; + float z = values[2]; + + // Compute the length of the normal + const float lengthSqr = sqr(x) + sqr(y) + sqr(z); + + // Normalize the normal and transform it to [0..1] + if (isfinite(lengthSqr)) + { + const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f; + + const float scale = invLength * 0.5f; + const float offset = 0.5f; + + x = x * scale + offset; + y = y * scale + offset; + z = z * scale + offset; + } + else + { + x = 0.f; + y = 0.f; + z = 0.f; + } + + // Store the normal + store(h, w, c, x); + store(h, w, c, y); + store(h, w, c, z); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/math.h b/thirdparty/oidn/core/math.h new file mode 100644 index 0000000000..a844ef0d1d --- /dev/null +++ b/thirdparty/oidn/core/math.h @@ -0,0 +1,78 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common/platform.h" + +namespace oidn { + + constexpr float minVectorLength = 1e-10f; + constexpr float minVectorLengthSqr = minVectorLength * minVectorLength; + + using std::log; + using std::log2; + using std::exp; + using std::exp2; + using std::pow; + using std::isfinite; + using std::isnan; + + __forceinline float sqr(float x) + { + return x * x; + } + + __forceinline float rcp(float x) + { + __m128 r = _mm_rcp_ss(_mm_set_ss(x)); + return _mm_cvtss_f32(_mm_sub_ss(_mm_add_ss(r, r), _mm_mul_ss(_mm_mul_ss(r, r), _mm_set_ss(x)))); + } + + __forceinline float rsqrt(float x) + { + __m128 r = _mm_rsqrt_ss(_mm_set_ss(x)); + return _mm_cvtss_f32(_mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r), + _mm_mul_ss(_mm_mul_ss(_mm_mul_ss(_mm_set_ss(x), _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r)))); + } + + __forceinline float maxSafe(float value, float minValue) + { + return isfinite(value) ? max(value, minValue) : minValue; + } + + __forceinline float clampSafe(float value, float minValue, float maxValue) + { + return isfinite(value) ? clamp(value, minValue, maxValue) : minValue; + } + + // Returns ceil(a / b) for non-negative integers + template<class Int> + __forceinline constexpr Int ceilDiv(Int a, Int b) + { + //assert(a >= 0); + //assert(b > 0); + return (a + b - 1) / b; + } + + // Returns a rounded up to multiple of b + template<class Int> + __forceinline constexpr Int roundUp(Int a, Int b) + { + return ceilDiv(a, b) * b; + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/network.cpp b/thirdparty/oidn/core/network.cpp new file mode 100644 index 0000000000..4da32073cd --- /dev/null +++ b/thirdparty/oidn/core/network.cpp @@ -0,0 +1,434 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "network.h" +#include "upsample.h" +#include "weights_reorder.h" +#include <cstring> + +namespace oidn { + + template<int K> + Network<K>::Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap) + : device(device), + eng(engine::cpu, 0), + sm(eng), + weightMap(weightMap) + { + } + + template<int K> + void Network<K>::execute(const Progress& progress, int taskIndex) + { + if (progress.func) + { + const double value = double(taskIndex) / double(progress.taskCount); + if (!progress.func(progress.userPtr, value)) + throw Exception(Error::Cancelled, "execution was cancelled"); + } + + for (size_t i = 0; i < nodes.size(); ++i) + { + nodes[i]->execute(sm); + + if (progress.func) + { + const double value = (double(taskIndex) + double(i+1) / double(nodes.size())) / double(progress.taskCount); + if (!progress.func(progress.userPtr, value)) + throw Exception(Error::Cancelled, "execution was cancelled"); + } + } + } + + template<int K> + std::shared_ptr<memory> Network<K>::allocTensor(const memory::dims& dims, + memory::format_tag format, + void* data) + { + if (format == memory::format_tag::any) + { + if (dims.size() == 4) + format = BlockedFormat<K>::nChwKc; + else if (dims.size() == 1) + format = memory::format_tag::x; + else + assert(0); + } + memory::desc desc(dims, memory::data_type::f32, format); + if (data == nullptr) + { + const size_t bytes = getTensorSize(dims) * sizeof(float); + if (format == BlockedFormat<K>::nChwKc) + activationAllocBytes += bytes; + totalAllocBytes += bytes; + + return std::make_shared<memory>(desc, eng); + } + else + { + return std::make_shared<memory>(desc, eng, data); + } + } + + template<int K> + std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims, + const std::shared_ptr<memory>& src, + size_t srcOffset, + memory::format_tag format) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + MAYBE_UNUSED(srcDesc); + assert(srcDesc.data_type == memory::data_type::f32); + assert(getTensorSize(src) >= srcOffset + getTensorSize(dims)); + + if (format == memory::format_tag::any) + { + if (dims.size() == 4) + format = BlockedFormat<K>::nChwKc; + else if (dims.size() == 1) + format = memory::format_tag::x; + else + assert(0); + } + memory::desc desc(dims, memory::data_type::f32, format); + float* srcPtr = (float*)src->get_data_handle() + srcOffset; + return std::make_shared<memory>(desc, eng, srcPtr); + } + + template<int K> + std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims, + const std::shared_ptr<memory>& src, + const memory::dims& srcOffset) + { + return castTensor(dims, src, getTensorSize(srcOffset)); + } + + template<int K> + void Network<K>::zeroTensor(const std::shared_ptr<memory>& dst) + { + assert(getTensorType(dst) == memory::data_type::f32); + memset(dst->get_data_handle(), 0, getTensorSize(dst)*sizeof(float)); + } + + template<int K> + memory::dims Network<K>::getInputReorderDims(const memory::dims& srcDims, int alignment) + { + memory::dims dstDims = srcDims; + dstDims[1] = getPadded<K>(srcDims[1]); // round up C + dstDims[2] = roundUp(srcDims[2], memory::dim(alignment)); // round up H + dstDims[3] = roundUp(srcDims[3], memory::dim(alignment)); // round up W + return dstDims; + } + + template<int K> + std::shared_ptr<Node> Network<K>::addInputReorder(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr<TransferFunction>& transferFunc, + int alignment, + const std::shared_ptr<memory>& userDst) + { + assert(color); + int inputC = 3; + if (albedo) inputC += 3; + if (normal) inputC += 3; + + memory::dims srcDims = {1, inputC, color.height, color.width}; + memory::dims dstDims = getInputReorderDims(srcDims, alignment); + + // Allocate padded memory + auto dst = userDst; + if (!dst) + dst = allocTensor(dstDims); + + // Push node + std::shared_ptr<Node> node; + + if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc)) + node = std::make_shared<InputReorderNode<K, LinearTransferFunction>>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc)) + node = std::make_shared<InputReorderNode<K, GammaTransferFunction>>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc)) + node = std::make_shared<InputReorderNode<K, LogTransferFunction>>(color, albedo, normal, dst, tf); + else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc)) + node = std::make_shared<InputReorderNode<K, PQXTransferFunction>>(color, albedo, normal, dst, tf); + else + assert(0); + + nodes.push_back(node); + return node; + } + + template<int K> + std::shared_ptr<Node> Network<K>::addOutputReorder(const std::shared_ptr<memory>& src, + const std::shared_ptr<TransferFunction>& transferFunc, + const Image& output) + { + memory::dims srcDims = getTensorDims(src); + assert(srcDims[1] == K); + + // Push node + std::shared_ptr<Node> node; + + if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc)) + node = std::make_shared<OutputReorderNode<K, LinearTransferFunction>>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc)) + node = std::make_shared<OutputReorderNode<K, GammaTransferFunction>>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc)) + node = std::make_shared<OutputReorderNode<K, LogTransferFunction>>(src, output, tf); + else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc)) + node = std::make_shared<OutputReorderNode<K, PQXTransferFunction>>(src, output, tf); + else + assert(0); + + nodes.push_back(node); + return node; + } + + template<int K> + memory::dims Network<K>::getConvDims(const std::string& name, const memory::dims& srcDims) + { + auto b = weightMap[name + "/b"]; + memory::dims dstDims = srcDims; + dstDims[1] = getPadded<K>(b.dims[0]); // dstDims[C] = getPadded(OC) + return dstDims; + } + + template<int K> + std::shared_ptr<Node> Network<K>::addConv(const std::string& name, + const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& userDst, + bool relu) + { + const memory::dims strides = {1, 1}; + const memory::dims padding = {1, 1}; + + memory::dims srcDims = getTensorDims(src); + + // Get the weights + const auto& W = weightMap[name + "/W"]; + if (W.ndims() != 4 || W.format != "oihw") + throw Exception(Error::InvalidOperation, "invalid convolution weights"); + memory::dims weightsDims = W.dims; + auto userWeights = allocTensor(weightsDims, memory::format_tag::oihw, W.data); + + // Pad the weights + memory::dims weightsPadDims = weightsDims; + weightsPadDims[1] = getPadded<K>(weightsDims[1]); // IC + weightsPadDims[0] = getPadded<K>(weightsDims[0]); // OC + assert(srcDims[1] == weightsPadDims[1]); // srcDims[C] == weightsPadDims[IC] + auto weightsPad = allocTensor(weightsPadDims, memory::format_tag::oihw); + WeightsReorderNode<K>(userWeights, weightsPad).execute(sm); + + // Get the biases + const auto& b = weightMap[name + "/b"]; + if (b.ndims() != 1) + throw Exception(Error::InvalidOperation, "invalid convolution biases"); + memory::dims biasDims = b.dims; + + // Copy/pad the biases + memory::dims biasPadDims = {getPadded<K>(biasDims[0])}; + auto bias = allocTensor(biasPadDims); + if (biasDims[0] != biasPadDims[0]) + memset(bias->get_data_handle(), 0, biasPadDims[0]*sizeof(float)); + memcpy(bias->get_data_handle(), b.data, biasDims[0]*sizeof(float)); + + // Allocate memory for destination + memory::dims dstDims = srcDims; + dstDims[1] = weightsPadDims[0]; // dstDims[C] = weightsPadDims[OC] + + std::shared_ptr<memory> dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + // Create a convolution + // Let the convolution primitive choose the weights format + auto weightsDesc = memory::desc({ weightsPadDims }, memory::data_type::f32, memory::format_tag::any); + + auto convAlgo = (K == 16) ? convolution_winograd : convolution_direct; + auto convDesc = convolution_forward::desc( + prop_kind::forward_inference, convAlgo, + src->get_desc(), + weightsDesc, + bias->get_desc(), + dst->get_desc(), + strides, padding, padding, padding_kind::zero); + + // Incorporate relu + mkldnn::primitive_attr convAttr; + if (relu) + { + mkldnn::post_ops ops; + ops.append_eltwise( + 1.f, // scale factor, not used + algorithm::eltwise_relu, + 0.f, // max with + 0.f // unused + ); + convAttr.set_post_ops(ops); + } + convAttr.set_scratchpad_mode(scratchpad_mode_user); + + auto convPrimDesc = convolution_forward::primitive_desc(convDesc, convAttr, eng); + + // Reorder the weights to the final format, if necessary + auto weights = weightsPad; + if (convPrimDesc.weights_desc() != weightsPad->get_desc()) + { + weights = std::make_shared<memory>(convPrimDesc.weights_desc(), eng); + ReorderNode(weightsPad, weights).execute(sm); + } + + // Create convolution node and add it to the net + auto node = std::make_shared<ConvNode>(convPrimDesc, src, weights, bias, dst); + nodes.push_back(node); + return node; + } + + template<int K> + memory::dims Network<K>::getPoolDims(const memory::dims& srcDims) + { + memory::dims dstDims = srcDims; + dstDims[2] /= 2; // H/2 + dstDims[3] /= 2; // W/2 + return dstDims; + } + + template<int K> + std::shared_ptr<Node> Network<K>::addPool(const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& userDst) + { + const memory::dims kernel = {2, 2}; + const memory::dims strides = {2, 2}; + const memory::dims padding = {0, 0}; + + memory::dims srcDims = getTensorDims(src); + memory::dims dstDims = getPoolDims(srcDims); + + std::shared_ptr<memory> dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + auto poolDesc = pooling_forward::desc( + prop_kind::forward_inference, pooling_max, + src->get_desc(), + dst->get_desc(), + strides, kernel, padding, padding, padding_kind::zero); + + mkldnn::primitive_attr poolAttr; + poolAttr.set_scratchpad_mode(scratchpad_mode_user); + + auto poolPrimDesc = pooling_forward::primitive_desc(poolDesc, poolAttr, eng); + + auto node = std::make_shared<PoolNode>(poolPrimDesc, src, dst); + nodes.push_back(node); + return node; + } + + template<int K> + memory::dims Network<K>::getUpsampleDims(const memory::dims& srcDims) + { + memory::dims dstDims = srcDims; + dstDims[2] *= 2; // H*2 + dstDims[3] *= 2; // W*2 + return dstDims; + } + + template<int K> + std::shared_ptr<Node> Network<K>::addUpsample(const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& userDst) + { + memory::dims srcDims = getTensorDims(src); + memory::dims dstDims = getUpsampleDims(srcDims); + + std::shared_ptr<memory> dst; + if (!userDst) + dst = allocTensor(dstDims); + else if (getTensorDims(userDst) == dstDims) + dst = userDst; + else + dst = castTensor(dstDims, userDst); + + // Create upsampling node and add it to net + auto node = std::make_shared<UpsampleNode<K>>(src, dst); + nodes.push_back(node); + return node; + } + + template<int K> + memory::dims Network<K>::getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims) + { + assert(src1Dims[0] == src2Dims[0]); // N + assert(src1Dims[2] == src2Dims[2]); // H + assert(src1Dims[3] == src2Dims[3]); // W + + memory::dims dstDims = src1Dims; + dstDims[1] += src2Dims[1]; // C + return dstDims; + } + + template<int K> + std::shared_ptr<Node> Network<K>::addAutoexposure(const Image& color, + const std::shared_ptr<HDRTransferFunction>& transferFunc) + { + auto node = std::make_shared<AutoexposureNode>(color, transferFunc); + nodes.push_back(node); + return node; + } + + template <int K> + void Network<K>::finalize() + { + // Compute the size of the scratchpad + size_t scratchpadSize = 0; + for (const auto& node : nodes) + scratchpadSize = max(scratchpadSize, node->getScratchpadSize()); + + // Allocate the scratchpad + memory::dims scratchpadDims = { memory::dim(scratchpadSize) }; + memory::desc scratchpadDesc(scratchpadDims, memory::data_type::u8, memory::format_tag::x); + auto scratchpad = std::make_shared<memory>(scratchpadDesc, eng); + activationAllocBytes += scratchpadSize; + totalAllocBytes += scratchpadSize; + + // Set the scratchpad for the nodes + for (auto& node : nodes) + node->setScratchpad(scratchpad); + + // Free the weights + weightMap.clear(); + + // Print statistics + if (device->isVerbose(2)) + { + std::cout << "Activation bytes: " << activationAllocBytes << std::endl; + std::cout << "Scratchpad bytes: " << scratchpadSize << std::endl; + std::cout << "Total bytes : " << totalAllocBytes << std::endl; + } + } + + template class Network<8>; + template class Network<16>; + +} // namespace oidn diff --git a/thirdparty/oidn/core/network.h b/thirdparty/oidn/core/network.h new file mode 100644 index 0000000000..7a696fd355 --- /dev/null +++ b/thirdparty/oidn/core/network.h @@ -0,0 +1,112 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "common/tensor.h" +#include "image.h" +#include "node.h" +#include "input_reorder.h" +#include "output_reorder.h" +#include "transfer_function.h" + +#pragma once + +namespace oidn { + + // Progress state + struct Progress + { + ProgressMonitorFunction func; + void* userPtr; + int taskCount; + }; + + class Executable + { + public: + virtual ~Executable() {} + virtual void execute(const Progress& progress, int taskIndex) = 0; + }; + + template<int K> + class Network : public Executable + { + public: + Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap); + + void execute(const Progress& progress, int taskIndex) override; + + std::shared_ptr<memory> allocTensor(const memory::dims& dims, + memory::format_tag format = memory::format_tag::any, + void* data = nullptr); + + std::shared_ptr<memory> castTensor(const memory::dims& dims, + const std::shared_ptr<memory>& src, + size_t srcOffset = 0, + memory::format_tag format = memory::format_tag::any); + + std::shared_ptr<memory> castTensor(const memory::dims& dims, + const std::shared_ptr<memory>& src, + const memory::dims& srcOffset); + + void zeroTensor(const std::shared_ptr<memory>& dst); + + memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment); + + std::shared_ptr<Node> addInputReorder(const Image& color, + const Image& albedo, + const Image& normal, + const std::shared_ptr<TransferFunction>& transferFunc, + int alignment, + const std::shared_ptr<memory>& userDst = nullptr); + + std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src, + const std::shared_ptr<TransferFunction>& transferFunc, + const Image& output); + + memory::dims getConvDims(const std::string& name, const memory::dims& srcDims); + std::shared_ptr<Node> addConv(const std::string& name, + const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& userDst = nullptr, + bool relu = true); + + memory::dims getPoolDims(const memory::dims& srcDims); + std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& userDst = nullptr); + + memory::dims getUpsampleDims(const memory::dims& srcDims); + std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& userDst = nullptr); + + memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims); + + std::shared_ptr<Node> addAutoexposure(const Image& color, + const std::shared_ptr<HDRTransferFunction>& transferFunc); + + void finalize(); + + private: + Ref<Device> device; + engine eng; + stream sm; + std::vector<std::shared_ptr<Node>> nodes; + std::map<std::string, Tensor> weightMap; + + // Memory allocation statistics + size_t activationAllocBytes = 0; // number of allocated activation bytes + size_t totalAllocBytes = 0; // total number of allocated bytes + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/node.h b/thirdparty/oidn/core/node.h new file mode 100644 index 0000000000..b9ffe906df --- /dev/null +++ b/thirdparty/oidn/core/node.h @@ -0,0 +1,142 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "common.h" +#include <vector> + +namespace oidn { + + class Node + { + public: + virtual ~Node() = default; + + virtual void execute(stream& sm) = 0; + + virtual std::shared_ptr<memory> getDst() const { return nullptr; } + + virtual size_t getScratchpadSize() const { return 0; } + virtual void setScratchpad(const std::shared_ptr<memory>& mem) {} + + virtual void setTile(int h1, int w1, int h2, int w2, int H, int W) + { + assert(0); // not supported + } + }; + + // Node wrapping an MKL-DNN primitive + class MklNode : public Node + { + private: + primitive prim; + std::unordered_map<int, memory> args; + std::shared_ptr<memory> scratchpad; + + public: + MklNode(const primitive& prim, const std::unordered_map<int, memory>& args) + : prim(prim), + args(args) + {} + + size_t getScratchpadSize() const override + { + const auto primDesc = prim.get_primitive_desc(); + const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0); + if (scratchpadDesc == nullptr) + return 0; + return mkldnn_memory_desc_get_size(scratchpadDesc); + } + + void setScratchpad(const std::shared_ptr<memory>& mem) override + { + scratchpad = mem; + args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad)); + } + + void execute(stream& sm) override + { + prim.execute(sm, args); + } + }; + + // Convolution node + class ConvNode : public MklNode + { + private: + std::shared_ptr<memory> src; + std::shared_ptr<memory> weights; + std::shared_ptr<memory> bias; + std::shared_ptr<memory> dst; + + public: + ConvNode(const convolution_forward::primitive_desc& desc, + const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& weights, + const std::shared_ptr<memory>& bias, + const std::shared_ptr<memory>& dst) + : MklNode(convolution_forward(desc), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_WEIGHTS, *weights }, + { MKLDNN_ARG_BIAS, *bias }, + { MKLDNN_ARG_DST, *dst } }), + src(src), weights(weights), bias(bias), dst(dst) + {} + + std::shared_ptr<memory> getDst() const override { return dst; } + }; + + // Pooling node + class PoolNode : public MklNode + { + private: + std::shared_ptr<memory> src; + std::shared_ptr<memory> dst; + + public: + PoolNode(const pooling_forward::primitive_desc& desc, + const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& dst) + : MklNode(pooling_forward(desc), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_DST, *dst } }), + src(src), dst(dst) + {} + + std::shared_ptr<memory> getDst() const override { return dst; } + }; + + // Reorder node + class ReorderNode : public MklNode + { + private: + std::shared_ptr<memory> src; + std::shared_ptr<memory> dst; + + public: + ReorderNode(const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& dst) + : MklNode(reorder(reorder::primitive_desc(*src, *dst)), + { { MKLDNN_ARG_SRC, *src }, + { MKLDNN_ARG_DST, *dst } }), + src(src), dst(dst) + {} + + std::shared_ptr<memory> getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/output_reorder.h b/thirdparty/oidn/core/output_reorder.h new file mode 100644 index 0000000000..7918d48e15 --- /dev/null +++ b/thirdparty/oidn/core/output_reorder.h @@ -0,0 +1,126 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" +#include "image.h" + +namespace oidn { + + // Output reorder node + template<int K, class TransferFunction> + class OutputReorderNode : public Node + { + private: + // Source + std::shared_ptr<memory> src; + const float* srcPtr; + int H1; + int W1; + + // Destination + Image output; + + // Tile + int h1Begin; + int w1Begin; + int h2Begin; + int w2Begin; + int H; + int W; + + std::shared_ptr<TransferFunction> transferFunc; + + public: + OutputReorderNode(const std::shared_ptr<memory>& src, + const Image& output, + const std::shared_ptr<TransferFunction>& transferFunc) + : src(src), + output(output), + h1Begin(0), w1Begin(0), + h2Begin(0), w2Begin(0), + H(output.height), W(output.width), + transferFunc(transferFunc) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + MAYBE_UNUSED(srcDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc))); + assert(srcDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(srcDesc.dims[0] == 1); + // We assume output data is <= K OC + assert(srcDesc.dims[1] == K); + + srcPtr = (float*)src->get_data_handle(); + H1 = srcDesc.dims[2]; + W1 = srcDesc.dims[3]; + } + + void setTile(int h1, int w1, int h2, int w2, int H, int W) override + { + h1Begin = h1; + w1Begin = w1; + h2Begin = h2; + w2Begin = w2; + this->H = H; + this->W = W; + } + + void execute(stream& sm) override + { + assert(h1Begin + H <= H1); + assert(w1Begin + W <= W1); + assert(h2Begin + H <= output.height); + assert(w2Begin + W <= output.width); + + const int C1 = K; + + parallel_nd(H, [&](int h) + { + const int h1 = h + h1Begin; + const int h2 = h + h2Begin; + + for (int w = 0; w < W; ++w) + { + const int w1 = w + w1Begin; + const int w2 = w + w2Begin; + float* dstPtr_C = (float*)output.get(h2, w2); + + // Source is in nChwKc format. In this case C is 1 so this is really nhwc + const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1; + + #pragma unroll + for (int i = 0; i < 3; ++i) + { + // Load the value + float x = srcPtr_C[i]; + + // The CNN output may contain negative values or even NaNs, so it must be sanitized + x = maxSafe(x, 0.f); + + // Apply the inverse transfer function + x = transferFunc->inverse(x); + + // Sanitize and store the final value + dstPtr_C[i] = max(x, 0.f); + } + } + }); + } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/transfer_function.cpp b/thirdparty/oidn/core/transfer_function.cpp new file mode 100644 index 0000000000..a33e3c84bc --- /dev/null +++ b/thirdparty/oidn/core/transfer_function.cpp @@ -0,0 +1,95 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#include "transfer_function.h" + +namespace oidn { + + const float LogTransferFunction::xScale = 1.f / log(LogTransferFunction::yMax + 1.f); + const float PQXTransferFunction::xScale = 1.f / PQXTransferFunction::pqxForward(PQXTransferFunction::yMax * PQXTransferFunction::yScale); + + float AutoexposureNode::autoexposure(const Image& color) + { + assert(color.format == Format::Float3); + return 1.0f; + + /*constexpr float key = 0.18f; + constexpr float eps = 1e-8f; + constexpr int K = 16; // downsampling amount + + // Downsample the image to minimize sensitivity to noise + const int H = color.height; // original height + const int W = color.width; // original width + const int HK = (H + K/2) / K; // downsampled height + const int WK = (W + K/2) / K; // downsampled width + + // Compute the average log luminance of the downsampled image + using Sum = std::pair<float, int>; + + Sum sum = + tbb::parallel_reduce( + tbb::blocked_range2d<int>(0, HK, 0, WK), + Sum(0.f, 0), + [&](const tbb::blocked_range2d<int>& r, Sum sum) -> Sum + { + // Iterate over blocks + for (int i = r.rows().begin(); i != r.rows().end(); ++i) + { + for (int j = r.cols().begin(); j != r.cols().end(); ++j) + { + // Compute the average luminance in the current block + const int beginH = int(ptrdiff_t(i) * H / HK); + const int beginW = int(ptrdiff_t(j) * W / WK); + const int endH = int(ptrdiff_t(i+1) * H / HK); + const int endW = int(ptrdiff_t(j+1) * W / WK); + + float L = 0.f; + + for (int h = beginH; h < endH; ++h) + { + for (int w = beginW; w < endW; ++w) + { + const float* rgb = (const float*)color.get(h, w); + + const float r = maxSafe(rgb[0], 0.f); + const float g = maxSafe(rgb[1], 0.f); + const float b = maxSafe(rgb[2], 0.f); + + L += luminance(r, g, b); + } + } + + L /= (endH - beginH) * (endW - beginW); + + // Accumulate the log luminance + if (L > eps) + { + sum.first += log2(L); + sum.second++; + } + } + } + + return sum; + }, + [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); }, + tbb::static_partitioner() + ); + + return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f;*/ + } + +} // namespace oidn diff --git a/thirdparty/oidn/core/transfer_function.h b/thirdparty/oidn/core/transfer_function.h new file mode 100644 index 0000000000..35f2833092 --- /dev/null +++ b/thirdparty/oidn/core/transfer_function.h @@ -0,0 +1,201 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "image.h" +#include "node.h" + +namespace oidn { + + __forceinline float luminance(float r, float g, float b) + { + return 0.212671f * r + 0.715160f * g + 0.072169f * b; + } + + // Color transfer function base class + class TransferFunction + { + public: + virtual ~TransferFunction() = default; + + virtual float forward(float y) const = 0; + virtual float inverse(float x) const = 0; + }; + + // HDR transfer function base class + class HDRTransferFunction : public TransferFunction + { + protected: + static constexpr float yMax = 65504.f; + + float exposure; + float rcpExposure; + + public: + HDRTransferFunction(float exposure = 1.f) + { + setExposure(exposure); + } + + void setExposure(float exposure) + { + this->exposure = exposure; + this->rcpExposure = (exposure != 0.f) ? (1.f / exposure) : 0.f; + } + }; + + // Linear transfer function (LDR) + class LinearTransferFunction : public TransferFunction + { + public: + __forceinline float forward(float y) const override + { + return min(y, 1.f); + } + + __forceinline float inverse(float x) const override + { + return min(x, 1.f); + } + }; + + // 2.2 gamma transfer function (LDR) + class GammaTransferFunction : public TransferFunction + { + public: + __forceinline float forward(float y) const override + { + return min(pow(y, 1.f/2.2f), 1.f); + } + + __forceinline float inverse(float x) const override + { + return min(pow(x, 2.2f), 1.f); + } + }; + + // Logarithmic transfer function (HDR) + // Compresses [0..65504] to [0..1] + class LogTransferFunction : public HDRTransferFunction + { + private: + static const float xScale; + + public: + LogTransferFunction(float exposure = 1.f) + : HDRTransferFunction(exposure) + { + } + + __forceinline float forward(float y) const override + { + return log(y * exposure + 1.f) * xScale; + } + + __forceinline float inverse(float x) const override + { + return (exp(x * (1.f/xScale)) - 1.f) * rcpExposure; + } + }; + + // PQX transfer function (HDR) + // Compresses [0..65504] to [0..1] + class PQXTransferFunction : public HDRTransferFunction + { + private: + static constexpr float m1 = 2610.f / 4096.f / 4.f; + static constexpr float m2 = 2523.f / 4096.f * 128.f; + static constexpr float c1 = 3424.f / 4096.f; + static constexpr float c2 = 2413.f / 4096.f * 32.f; + static constexpr float c3 = 2392.f / 4096.f * 32.f; + static constexpr float a = 3711.f / 4096.f / 8.f; + + static constexpr float yScale = 100.f / 10000.f; + static const float xScale; + + public: + PQXTransferFunction(float exposure = 1.f) + : HDRTransferFunction(exposure) + { + } + + __forceinline float forward(float y) const override + { + return pqxForward(y * exposure * yScale) * xScale; + } + + __forceinline float inverse(float x) const override + { + return pqxInverse(x * (1.f/xScale)) * (1.f/yScale) * rcpExposure; + } + + private: + static __forceinline float pqForward(float y) + { + const float yp = pow(y, m1); + return pow((c1 + c2 * yp) * rcp(1.f + c3 * yp), m2); + } + + static __forceinline float pqxForward(float y) + { + if (y <= 1.f) + return pqForward(y); + else + return a * log(y) + 1.f; + } + + static __forceinline float pqInverse(float x) + { + const float xp = pow(x, 1.f/m2); + return pow(max((xp - c1) * rcp(c2 - c3 * xp), 0.f), 1.f/m1); + } + + static __forceinline float pqxInverse(float x) + { + if (x <= 1.f) + return pqInverse(x); + else + return exp((x - 1.f) * (1.f/a)); + } + }; + + // Autoexposure node + class AutoexposureNode : public Node + { + private: + Image color; + std::shared_ptr<HDRTransferFunction> transferFunc; + + public: + AutoexposureNode(const Image& color, + const std::shared_ptr<HDRTransferFunction>& transferFunc) + : color(color), + transferFunc(transferFunc) + {} + + void execute(stream& sm) override + { + const float exposure = autoexposure(color); + //printf("exposure = %f\n", exposure); + transferFunc->setExposure(exposure); + } + + private: + static float autoexposure(const Image& color); + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/upsample.h b/thirdparty/oidn/core/upsample.h new file mode 100644 index 0000000000..f6cace44cd --- /dev/null +++ b/thirdparty/oidn/core/upsample.h @@ -0,0 +1,92 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" + +namespace oidn { + + // 2x2 nearest-neighbor upsampling node + template<int K> + class UpsampleNode : public Node + { + private: + std::shared_ptr<memory> src; + std::shared_ptr<memory> dst; + + public: + UpsampleNode(const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& dst) + : src(src), + dst(dst) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + MAYBE_UNUSED(srcDesc); + MAYBE_UNUSED(dstDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc))); + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc))); + assert(srcDesc.ndims == 4); + assert(dstDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(dstDesc.data_type == memory::data_type::f32); + assert(srcDesc.dims[0] == 1); + assert(dstDesc.dims[0] == 1); + // 2x2 upsampling + assert(dstDesc.dims[2] == srcDesc.dims[2] * 2); + assert(dstDesc.dims[3] == srcDesc.dims[3] * 2); + } + + void execute(stream& sm) override + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + + const float* srcPtr = (float*)src->get_data_handle(); + float* dstPtr = (float*)dst->get_data_handle(); + + const int C = srcDesc.dims[1]; + const int H = srcDesc.dims[2]; + const int W = srcDesc.dims[3]; + const int CK = C / K; + + parallel_nd(CK, H, [&](int ck, int h) + { + const size_t offset = ck*H*W*K + h*W*K; + const float* srcPtr_line = srcPtr + offset; + float* dstPtr_line0 = dstPtr + offset * 4; + float* dstPtr_line1 = dstPtr_line0 + W*2*K; // next line + + for (int w = 0; w < W; ++w) + { + #pragma unroll + for (int k = 0; k < K; k += 4) + { + const __m128 m = _mm_load_ps(&srcPtr_line[w*K + k]); + + _mm_stream_ps(&dstPtr_line0[w*2*K + k], m); + _mm_stream_ps(&dstPtr_line0[w*2*K+K + k], m); + _mm_stream_ps(&dstPtr_line1[w*2*K + k], m); + _mm_stream_ps(&dstPtr_line1[w*2*K+K + k], m); + } + } + }); + } + + std::shared_ptr<memory> getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/core/weights_reorder.h b/thirdparty/oidn/core/weights_reorder.h new file mode 100644 index 0000000000..6c5dacb8aa --- /dev/null +++ b/thirdparty/oidn/core/weights_reorder.h @@ -0,0 +1,99 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include "node.h" + +namespace oidn { + + // Reorders weights from oihw to padded oihw format + template<int K> + class WeightsReorderNode : public Node + { + private: + std::shared_ptr<memory> src; + std::shared_ptr<memory> dst; + + public: + WeightsReorderNode(const std::shared_ptr<memory>& src, + const std::shared_ptr<memory>& dst) + : src(src), + dst(dst) + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + MAYBE_UNUSED(srcDesc); + MAYBE_UNUSED(dstDesc); + assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(memory::format_tag::oihw))); + assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(memory::format_tag::oihw))); + assert(srcDesc.ndims == 4); + assert(dstDesc.ndims == 4); + assert(srcDesc.data_type == memory::data_type::f32); + assert(dstDesc.data_type == memory::data_type::f32); + assert(getPadded<K>(srcDesc.dims[0]) == dstDesc.dims[0]); // OC + assert(getPadded<K>(srcDesc.dims[1]) == dstDesc.dims[1]); // IC + assert(srcDesc.dims[2] == dstDesc.dims[2]); + assert(srcDesc.dims[3] == dstDesc.dims[3]); + } + + void execute(stream& sm) override + { + const mkldnn_memory_desc_t& srcDesc = src->get_desc().data; + const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data; + + const float* srcPtr = (float*)src->get_data_handle(); + float* dstPtr = (float*)dst->get_data_handle(); + + const int OC1 = srcDesc.dims[0]; + const int OC2 = dstDesc.dims[0]; + const int IC1 = srcDesc.dims[1]; + const int IC2 = dstDesc.dims[1]; + const int H = dstDesc.dims[2]; + const int W = dstDesc.dims[3]; + + for (int oc = 0; oc < OC2; ++oc) + { + for (int ic = 0; ic < IC2; ++ic) + { + for (int h = 0; h < H; ++h) + { + for (int w = 0; w < W; ++w) + { + // Output is in oihw format + float* dstPtr_c = dstPtr + oc*IC2*H*W + ic*H*W + h*W + w; + + if (oc < OC1 && ic < IC1) + { + // Input is in oihw format + const float* srcPtr_c = srcPtr + oc*IC1*H*W + ic*H*W + h*W + w; + *dstPtr_c = *srcPtr_c; + } + else + { + // padding + *dstPtr_c = 0; + } + } + } + } + } + } + + std::shared_ptr<memory> getDst() const override { return dst; } + }; + +} // namespace oidn diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.h b/thirdparty/oidn/include/OpenImageDenoise/oidn.h new file mode 100644 index 0000000000..57ba6baa21 --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/oidn.h @@ -0,0 +1,214 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include <stddef.h> +#include <stdbool.h> +#include <stdint.h> + +#include "version.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +#ifndef OIDN_API +#if defined(_WIN32) && !defined(OIDN_STATIC_LIB) +# define OIDN_API __declspec(dllimport) +#else +# define OIDN_API +#endif +#endif + +// ---------------------------------------------------------------------------- +// Device +// ---------------------------------------------------------------------------- + +// Device types +typedef enum +{ + OIDN_DEVICE_TYPE_DEFAULT = 0, // select device automatically + + OIDN_DEVICE_TYPE_CPU = 1, // CPU device +} OIDNDeviceType; + +// Error codes +typedef enum +{ + OIDN_ERROR_NONE = 0, // no error occurred + OIDN_ERROR_UNKNOWN = 1, // an unknown error occurred + OIDN_ERROR_INVALID_ARGUMENT = 2, // an invalid argument was specified + OIDN_ERROR_INVALID_OPERATION = 3, // the operation is not allowed + OIDN_ERROR_OUT_OF_MEMORY = 4, // not enough memory to execute the operation + OIDN_ERROR_UNSUPPORTED_HARDWARE = 5, // the hardware (e.g. CPU) is not supported + OIDN_ERROR_CANCELLED = 6, // the operation was cancelled by the user +} OIDNError; + +// Error callback function +typedef void (*OIDNErrorFunction)(void* userPtr, OIDNError code, const char* message); + +// Device handle +typedef struct OIDNDeviceImpl* OIDNDevice; + +// Creates a new device. +OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type); + +// Retains the device (increments the reference count). +OIDN_API void oidnRetainDevice(OIDNDevice device); + +// Releases the device (decrements the reference count). +OIDN_API void oidnReleaseDevice(OIDNDevice device); + +// Sets a boolean parameter of the device. +OIDN_API void oidnSetDevice1b(OIDNDevice device, const char* name, bool value); + +// Sets an integer parameter of the device. +OIDN_API void oidnSetDevice1i(OIDNDevice device, const char* name, int value); + +// Gets a boolean parameter of the device. +OIDN_API bool oidnGetDevice1b(OIDNDevice device, const char* name); + +// Gets an integer parameter of the device (e.g. "version"). +OIDN_API int oidnGetDevice1i(OIDNDevice device, const char* name); + +// Sets the error callback function of the device. +OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice device, OIDNErrorFunction func, void* userPtr); + +// Returns the first unqueried error code stored in the device for the current +// thread, optionally also returning a string message (if not NULL), and clears +// the stored error. Can be called with a NULL device as well to check why a +// device creation failed. +OIDN_API OIDNError oidnGetDeviceError(OIDNDevice device, const char** outMessage); + +// Commits all previous changes to the device. +// Must be called before first using the device (e.g. creating filters). +OIDN_API void oidnCommitDevice(OIDNDevice device); + +// ---------------------------------------------------------------------------- +// Buffer +// ---------------------------------------------------------------------------- + +// Formats for images and other data stored in buffers +typedef enum +{ + OIDN_FORMAT_UNDEFINED = 0, + + // 32-bit single-precision floating point scalar and vector formats + OIDN_FORMAT_FLOAT = 1, + OIDN_FORMAT_FLOAT2 = 2, + OIDN_FORMAT_FLOAT3 = 3, + OIDN_FORMAT_FLOAT4 = 4, +} OIDNFormat; + +// Access modes for mapping buffers +typedef enum +{ + OIDN_ACCESS_READ = 0, // read-only access + OIDN_ACCESS_WRITE = 1, // write-only access + OIDN_ACCESS_READ_WRITE = 2, // read and write access + OIDN_ACCESS_WRITE_DISCARD = 3, // write-only access, previous contents discarded +} OIDNAccess; + +// Buffer handle +typedef struct OIDNBufferImpl* OIDNBuffer; + +// Creates a new buffer (data allocated and owned by the device). +OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice device, size_t byteSize); + +// Creates a new shared buffer (data allocated and owned by the user). +OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice device, void* ptr, size_t byteSize); + +// Maps a region of the buffer to host memory. +// If byteSize is 0, the maximum available amount of memory will be mapped. +OIDN_API void* oidnMapBuffer(OIDNBuffer buffer, OIDNAccess access, size_t byteOffset, size_t byteSize); + +// Unmaps a region of the buffer. +// mappedPtr must be a pointer returned by a previous call to oidnMapBuffer. +OIDN_API void oidnUnmapBuffer(OIDNBuffer buffer, void* mappedPtr); + +// Retains the buffer (increments the reference count). +OIDN_API void oidnRetainBuffer(OIDNBuffer buffer); + +// Releases the buffer (decrements the reference count). +OIDN_API void oidnReleaseBuffer(OIDNBuffer buffer); + +// ---------------------------------------------------------------------------- +// Filter +// ---------------------------------------------------------------------------- + +// Progress monitor callback function +typedef bool (*OIDNProgressMonitorFunction)(void* userPtr, double n); + +// Filter handle +typedef struct OIDNFilterImpl* OIDNFilter; + +// Creates a new filter of the specified type (e.g. "RT"). +OIDN_API OIDNFilter oidnNewFilter(OIDNDevice device, const char* type); + +// Retains the filter (increments the reference count). +OIDN_API void oidnRetainFilter(OIDNFilter filter); + +// Releases the filter (decrements the reference count). +OIDN_API void oidnReleaseFilter(OIDNFilter filter); + +// Sets an image parameter of the filter (stored in a buffer). +// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically. +OIDN_API void oidnSetFilterImage(OIDNFilter filter, const char* name, + OIDNBuffer buffer, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride); + +// Sets an image parameter of the filter (owned by the user). +// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically. +OIDN_API void oidnSetSharedFilterImage(OIDNFilter filter, const char* name, + void* ptr, OIDNFormat format, + size_t width, size_t height, + size_t byteOffset, + size_t bytePixelStride, size_t byteRowStride); + +// Sets a boolean parameter of the filter. +OIDN_API void oidnSetFilter1b(OIDNFilter filter, const char* name, bool value); + +// Gets a boolean parameter of the filter. +OIDN_API bool oidnGetFilter1b(OIDNFilter filter, const char* name); + +// Sets an integer parameter of the filter. +OIDN_API void oidnSetFilter1i(OIDNFilter filter, const char* name, int value); + +// Gets an integer parameter of the filter. +OIDN_API int oidnGetFilter1i(OIDNFilter filter, const char* name); + +// Sets a float parameter of the filter. +OIDN_API void oidnSetFilter1f(OIDNFilter filter, const char* name, float value); + +// Gets a float parameter of the filter. +OIDN_API float oidnGetFilter1f(OIDNFilter filter, const char* name); + +// Sets the progress monitor callback function of the filter. +OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter filter, OIDNProgressMonitorFunction func, void* userPtr); + +// Commits all previous changes to the filter. +// Must be called before first executing the filter. +OIDN_API void oidnCommitFilter(OIDNFilter filter); + +// Executes the filter. +OIDN_API void oidnExecuteFilter(OIDNFilter filter); + +#if defined(__cplusplus) +} +#endif diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp b/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp new file mode 100644 index 0000000000..9f95a56fe1 --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp @@ -0,0 +1,468 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#include <algorithm> +#include "oidn.h" + +namespace oidn { + + // -------------------------------------------------------------------------- + // Buffer + // -------------------------------------------------------------------------- + + // Formats for images and other data stored in buffers + enum class Format + { + Undefined = OIDN_FORMAT_UNDEFINED, + + // 32-bit single-precision floating point scalar and vector formats + Float = OIDN_FORMAT_FLOAT, + Float2 = OIDN_FORMAT_FLOAT2, + Float3 = OIDN_FORMAT_FLOAT3, + Float4 = OIDN_FORMAT_FLOAT4, + }; + + // Access modes for mapping buffers + enum class Access + { + Read = OIDN_ACCESS_READ, // read-only access + Write = OIDN_ACCESS_WRITE, // write-only access + ReadWrite = OIDN_ACCESS_READ_WRITE, // read and write access + WriteDiscard = OIDN_ACCESS_WRITE_DISCARD, // write-only access, previous contents discarded + }; + + // Buffer object with automatic reference counting + class BufferRef + { + private: + OIDNBuffer handle; + + public: + BufferRef() : handle(nullptr) {} + BufferRef(OIDNBuffer handle) : handle(handle) {} + + BufferRef(const BufferRef& other) : handle(other.handle) + { + if (handle) + oidnRetainBuffer(handle); + } + + BufferRef(BufferRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + BufferRef& operator =(const BufferRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainBuffer(other.handle); + if (handle) + oidnReleaseBuffer(handle); + handle = other.handle; + } + return *this; + } + + BufferRef& operator =(BufferRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + BufferRef& operator =(OIDNBuffer other) + { + if (other) + oidnRetainBuffer(other); + if (handle) + oidnReleaseBuffer(handle); + handle = other; + return *this; + } + + ~BufferRef() + { + if (handle) + oidnReleaseBuffer(handle); + } + + OIDNBuffer getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Maps a region of the buffer to host memory. + // If byteSize is 0, the maximum available amount of memory will be mapped. + void* map(Access access = Access::ReadWrite, size_t byteOffset = 0, size_t byteSize = 0) + { + return oidnMapBuffer(handle, (OIDNAccess)access, byteOffset, byteSize); + } + + // Unmaps a region of the buffer. + // mappedPtr must be a pointer returned by a previous call to map. + void unmap(void* mappedPtr) + { + oidnUnmapBuffer(handle, mappedPtr); + } + }; + + // -------------------------------------------------------------------------- + // Filter + // -------------------------------------------------------------------------- + + // Progress monitor callback function + typedef bool (*ProgressMonitorFunction)(void* userPtr, double n); + + // Filter object with automatic reference counting + class FilterRef + { + private: + OIDNFilter handle; + + public: + FilterRef() : handle(nullptr) {} + FilterRef(OIDNFilter handle) : handle(handle) {} + + FilterRef(const FilterRef& other) : handle(other.handle) + { + if (handle) + oidnRetainFilter(handle); + } + + FilterRef(FilterRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + FilterRef& operator =(const FilterRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainFilter(other.handle); + if (handle) + oidnReleaseFilter(handle); + handle = other.handle; + } + return *this; + } + + FilterRef& operator =(FilterRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + FilterRef& operator =(OIDNFilter other) + { + if (other) + oidnRetainFilter(other); + if (handle) + oidnReleaseFilter(handle); + handle = other; + return *this; + } + + ~FilterRef() + { + if (handle) + oidnReleaseFilter(handle); + } + + OIDNFilter getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Sets an image parameter of the filter (stored in a buffer). + void setImage(const char* name, + const BufferRef& buffer, Format format, + size_t width, size_t height, + size_t byteOffset = 0, + size_t bytePixelStride = 0, size_t byteRowStride = 0) + { + oidnSetFilterImage(handle, name, + buffer.getHandle(), (OIDNFormat)format, + width, height, + byteOffset, + bytePixelStride, byteRowStride); + } + + // Sets an image parameter of the filter (owned by the user). + void setImage(const char* name, + void* ptr, Format format, + size_t width, size_t height, + size_t byteOffset = 0, + size_t bytePixelStride = 0, size_t byteRowStride = 0) + { + oidnSetSharedFilterImage(handle, name, + ptr, (OIDNFormat)format, + width, height, + byteOffset, + bytePixelStride, byteRowStride); + } + + // Sets a boolean parameter of the filter. + void set(const char* name, bool value) + { + oidnSetFilter1b(handle, name, value); + } + + // Sets an integer parameter of the filter. + void set(const char* name, int value) + { + oidnSetFilter1i(handle, name, value); + } + + // Sets a float parameter of the filter. + void set(const char* name, float value) + { + oidnSetFilter1f(handle, name, value); + } + + // Gets a parameter of the filter. + template<typename T> + T get(const char* name); + + // Sets the progress monitor callback function of the filter. + void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr = nullptr) + { + oidnSetFilterProgressMonitorFunction(handle, (OIDNProgressMonitorFunction)func, userPtr); + } + + // Commits all previous changes to the filter. + void commit() + { + oidnCommitFilter(handle); + } + + // Executes the filter. + void execute() + { + oidnExecuteFilter(handle); + } + }; + + // Gets a boolean parameter of the filter. + template<> + inline bool FilterRef::get(const char* name) + { + return oidnGetFilter1b(handle, name); + } + + // Gets an integer parameter of the filter. + template<> + inline int FilterRef::get(const char* name) + { + return oidnGetFilter1i(handle, name); + } + + // Gets a float parameter of the filter. + template<> + inline float FilterRef::get(const char* name) + { + return oidnGetFilter1f(handle, name); + } + + // -------------------------------------------------------------------------- + // Device + // -------------------------------------------------------------------------- + + // Device types + enum class DeviceType + { + Default = OIDN_DEVICE_TYPE_DEFAULT, // select device automatically + + CPU = OIDN_DEVICE_TYPE_CPU, // CPU device + }; + + // Error codes + enum class Error + { + None = OIDN_ERROR_NONE, // no error occurred + Unknown = OIDN_ERROR_UNKNOWN, // an unknown error occurred + InvalidArgument = OIDN_ERROR_INVALID_ARGUMENT, // an invalid argument was specified + InvalidOperation = OIDN_ERROR_INVALID_OPERATION, // the operation is not allowed + OutOfMemory = OIDN_ERROR_OUT_OF_MEMORY, // not enough memory to execute the operation + UnsupportedHardware = OIDN_ERROR_UNSUPPORTED_HARDWARE, // the hardware (e.g. CPU) is not supported + Cancelled = OIDN_ERROR_CANCELLED, // the operation was cancelled by the user + }; + + // Error callback function + typedef void (*ErrorFunction)(void* userPtr, Error code, const char* message); + + // Device object with automatic reference counting + class DeviceRef + { + private: + OIDNDevice handle; + + public: + DeviceRef() : handle(nullptr) {} + DeviceRef(OIDNDevice handle) : handle(handle) {} + + DeviceRef(const DeviceRef& other) : handle(other.handle) + { + if (handle) + oidnRetainDevice(handle); + } + + DeviceRef(DeviceRef&& other) : handle(other.handle) + { + other.handle = nullptr; + } + + DeviceRef& operator =(const DeviceRef& other) + { + if (&other != this) + { + if (other.handle) + oidnRetainDevice(other.handle); + if (handle) + oidnReleaseDevice(handle); + handle = other.handle; + } + return *this; + } + + DeviceRef& operator =(DeviceRef&& other) + { + std::swap(handle, other.handle); + return *this; + } + + DeviceRef& operator =(OIDNDevice other) + { + if (other) + oidnRetainDevice(other); + if (handle) + oidnReleaseDevice(handle); + handle = other; + return *this; + } + + ~DeviceRef() + { + if (handle) + oidnReleaseDevice(handle); + } + + OIDNDevice getHandle() const + { + return handle; + } + + operator bool() const + { + return handle != nullptr; + } + + // Sets a boolean parameter of the device. + void set(const char* name, bool value) + { + oidnSetDevice1b(handle, name, value); + } + + // Sets an integer parameter of the device. + void set(const char* name, int value) + { + oidnSetDevice1i(handle, name, value); + } + + // Gets a parameter of the device. + template<typename T> + T get(const char* name); + + // Sets the error callback function of the device. + void setErrorFunction(ErrorFunction func, void* userPtr = nullptr) + { + oidnSetDeviceErrorFunction(handle, (OIDNErrorFunction)func, userPtr); + } + + // Returns the first unqueried error code and clears the stored error. + // Can be called for a null device as well to check why a device creation failed. + Error getError() + { + return (Error)oidnGetDeviceError(handle, nullptr); + } + + // Returns the first unqueried error code and string message, and clears the stored error. + // Can be called for a null device as well to check why a device creation failed. + Error getError(const char*& outMessage) + { + return (Error)oidnGetDeviceError(handle, &outMessage); + } + + // Commits all previous changes to the device. + // Must be called before first using the device (e.g. creating filters). + void commit() + { + oidnCommitDevice(handle); + } + + // Creates a new buffer (data allocated and owned by the device). + BufferRef newBuffer(size_t byteSize) + { + return oidnNewBuffer(handle, byteSize); + } + + // Creates a new shared buffer (data allocated and owned by the user). + BufferRef newBuffer(void* ptr, size_t byteSize) + { + return oidnNewSharedBuffer(handle, ptr, byteSize); + } + + // Creates a new filter of the specified type (e.g. "RT"). + FilterRef newFilter(const char* type) + { + return oidnNewFilter(handle, type); + } + }; + + // Gets a boolean parameter of the device. + template<> + inline bool DeviceRef::get(const char* name) + { + return oidnGetDevice1b(handle, name); + } + + // Gets an integer parameter of the device (e.g. "version"). + template<> + inline int DeviceRef::get(const char* name) + { + return oidnGetDevice1i(handle, name); + } + + // Creates a new device. + inline DeviceRef newDevice(DeviceType type = DeviceType::Default) + { + return DeviceRef(oidnNewDevice((OIDNDeviceType)type)); + } + +} // namespace oidn diff --git a/thirdparty/oidn/include/OpenImageDenoise/version.h b/thirdparty/oidn/include/OpenImageDenoise/version.h new file mode 100644 index 0000000000..66b347c992 --- /dev/null +++ b/thirdparty/oidn/include/OpenImageDenoise/version.h @@ -0,0 +1,23 @@ +// ======================================================================== // +// Copyright 2009-2019 Intel Corporation // +// // +// Licensed under the Apache License, Version 2.0 (the "License"); // +// you may not use this file except in compliance with the License. // +// You may obtain a copy of the License at // +// // +// http://www.apache.org/licenses/LICENSE-2.0 // +// // +// Unless required by applicable law or agreed to in writing, software // +// distributed under the License is distributed on an "AS IS" BASIS, // +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // +// See the License for the specific language governing permissions and // +// limitations under the License. // +// ======================================================================== // + +#pragma once + +#define OIDN_VERSION_MAJOR 1 +#define OIDN_VERSION_MINOR 1 +#define OIDN_VERSION_PATCH 0 +#define OIDN_VERSION 10100 +#define OIDN_VERSION_STRING "1.1.0" diff --git a/thirdparty/oidn/mkl-dnn/LICENSE b/thirdparty/oidn/mkl-dnn/LICENSE new file mode 100644 index 0000000000..d13f7b7ca0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/LICENSE @@ -0,0 +1,214 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ============================================================================ + + Intel MKL-DNN includes components with separate copyright + notices and license terms. + + XByak, 3-clause BSD license + Copyright (c) 2007 MITSUNARI Shigeo + See full copyright notice and license text in src/cpu/xbyak/COPYRIGHT + + gtest, 3-clause BSD license + Copyright 2008, Google Inc. + See full copyright notice and license text in tests/gtests/gtest/LICENSE diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.h b/thirdparty/oidn/mkl-dnn/include/mkldnn.h new file mode 100644 index 0000000000..9b64994922 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn.h @@ -0,0 +1,1771 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_H +#define MKLDNN_H + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +/* All symbols shall be internal unless marked as MKLDNN_API */ +#if defined _WIN32 || defined __CYGWIN__ +# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport) +# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport) +#else +# if __GNUC__ >= 4 +# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default"))) +# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default"))) +# else +# define MKLDNN_HELPER_DLL_IMPORT +# define MKLDNN_HELPER_DLL_EXPORT +# endif +#endif + +#ifdef MKLDNN_DLL +# ifdef MKLDNN_DLL_EXPORTS +# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT +# else +# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT +# endif +#else +# define MKLDNN_API +#endif + +#if defined (__GNUC__) +# define MKLDNN_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +# define MKLDNN_DEPRECATED __declspec(deprecated) +#else +# define MKLDNN_DEPRECATED +#endif + +#include "mkldnn_types.h" +#include "mkldnn_version.h" +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +#ifdef __cplusplus +extern "C" { +#endif + +/** @addtogroup c_api C API + * @{ */ + +/** @addtogroup c_api_primitive Primitive operations + * @{ */ + +/** @addtogroup c_api_primitive_common Common primitive operations + * @{ */ + +/** Creates a primitive descriptor @p iterator for given @p op_desc, @p attr, + * @p engine, and optionally a hint primitive descriptor from forward + * propagation (required for backward propagation). Pass @c NULL for forward + * propagation. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create( + mkldnn_primitive_desc_iterator_t *iterator, + const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine, + const_mkldnn_primitive_desc_t hint_forward_primitive_desc); + +/** Iterates over primitive descriptors. Returns #mkldnn_iterator_ends if no + * more primitive descriptors are available. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next( + mkldnn_primitive_desc_iterator_t iterator); + +/** Fetches the current primitive descriptor. + * + * @note + * The user should delete the fetched primitive descriptor using + * mkldnn_primitive_desc_destroy() once it is no longer needed. */ +mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch( + const_mkldnn_primitive_desc_iterator_t iterator); + +/** Deletes a primitive descriptor @p iterator */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy( + mkldnn_primitive_desc_iterator_t iterator); + +/** Creates a @p primitive_desc using @p op_desc, @p attr, @p engine, and + * optionally a hint primitive descriptor from forward propagation. The call is + * equivalent to creating a primitive descriptor iterator, immediately fetching + * a primitive descriptor, and then destroying the iterator. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create( + mkldnn_primitive_desc_t *primitive_desc, + const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine, + const_mkldnn_primitive_desc_t hint_forward_primitive_desc); + +/** Makes a copy of a @p primitive_desc. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone( + mkldnn_primitive_desc_t *primitive_desc, + const_mkldnn_primitive_desc_t existing_primitive_desc); + +/** Returns a constant reference to the attribute of a @p primitive_desc. + * + * @warning + * The user should not destroy the obtained @p attr. + * + * @warning + * The lifetime of an @p attr is the same as that of a @p primitive_desc, + * so it is illegal to use the @p attr once @p primitive_desc has been + * destroyed. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr( + const_mkldnn_primitive_desc_t primitive_desc, + const_mkldnn_primitive_attr_t *attr); + +/** Deletes a @p primitive_desc. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy( + mkldnn_primitive_desc_t primitive_desc); + +/** Queries primitive descriptor + * + * One of the most typical use cases is to query a convolution primitive + * descriptor created with source, weights, and destination formats equal + * to #mkldnn_format_tag_any about the corresponding memory descriptors + * (@p what equals #mkldnn_query_src_md, #mkldnn_query_weights_md, and + * #mkldnn_query_dst_md respectively) to be able to prepare memory and + * create reorders if required. + * + * Another quite typical use case is to query an operation primitive + * descriptor for a workspace (@p what equals #mkldnn_query_workspace_md). + * The returned status #mkldnn_not_required indicates that a workspace is + * not required. + * + * A few other possibilities: + * - query an operation primitive descriptor for the underlying operation + * descriptor (#mkldnn_query_convolution_d, #mkldnn_query_eltwise_d, + * #mkldnn_query_rnn_d, etc.) + * - query an operation primitive descriptor for the implementation + * information string (#mkldnn_query_impl_info_str) + * - query an operation primitive descriptor for the number of inputs and + * outputs (#mkldnn_query_num_of_inputs_s32 and + * #mkldnn_query_num_of_outputs_s32 respectively) + * + * @sa mkldnn_query_t for more options + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index, void *result); + +/** Queries primitive descriptor for memory descriptor + * + * @returns NULL in case of any error. + * + * This is just a specialized version of mkldnn_primitive_desc_query + * used for convenience. + */ +const mkldnn_memory_desc_t MKLDNN_API *mkldnn_primitive_desc_query_md( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index); + +/** Queries primitive descriptor for signed 32bit int + * + * @returns 0 in case of any error (in particular if the queried entity is + * not of type int32_t). Note that 0 might also be the actual returned + * value. + * + * This is just a specialized version of mkldnn_primitive_desc_query + * used for convenience. + */ +int MKLDNN_API mkldnn_primitive_desc_query_s32( + const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, + int index); + +/** Creates a @p primitive using a @p primitive_desc descriptor. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_create( + mkldnn_primitive_t *primitive, + const_mkldnn_primitive_desc_t primitive_desc); + +/** Executes a @p primitive using a @p stream, and @p nargs arguments + * @p args. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_execute( + const_mkldnn_primitive_t primitive, mkldnn_stream_t stream, + int nargs, const mkldnn_exec_arg_t *args); + +/** Retrieves a reference to the @p primitive_desc descriptor of given @p + * primitive. + * + * @warning + * The returned object must not be destroyed by the user. The @c const + * qualifier of the returned object prevents such attempts. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc( + const_mkldnn_primitive_t primitive, + const_mkldnn_primitive_desc_t *primitive_desc); + +/** Deletes a @p primitive. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy( + mkldnn_primitive_t primitive); + +/** @} */ + +/** @addtogroup c_api_attributes Attributes + * An extension for controlling primitive behavior. + * @{ */ + +/** Creates an empty (default) @p attr attribute. All the parameters are set to + * default values. + * + * An empty attribute is used in primitive descriptor creation whenever it + * is not passed explicitly, e.g. in mkldnn_primitive_desc_create. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create( + mkldnn_primitive_attr_t *attr); + +/** Makes a copy of an @p existing_attr. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone( + mkldnn_primitive_attr_t *attr, + const_mkldnn_primitive_attr_t existing_attr); + +/** Deletes an @p attr. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy( + mkldnn_primitive_attr_t attr); + +/** Returns the scratchpad @p mode set in the attribute @p attr */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_scratchpad_mode( + const_mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t *mode); + +/** Sets scratchpad @p mode. + * + * The possible values are: #mkldnn_scratchpad_mode_library (default) and + * #mkldnn_scratchpad_mode_user. */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_scratchpad_mode( + mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t mode); + +/** Returns @p count, correspondence scale @p mask, and a pointer to a constant + * floating point array of output @p scales for given @p attr, previously set + * by mkldnn_primitive_attr_set_output_scales. + * + * @warning + * The @p scales array points to the internal @p attr field, so the user + * should not modify or destroy @p scales. + * + * @warning + * The lifetime of @p scales is the same as that of the @p attr to which it + * belongs, so it is illegal to use @p scales after @p attr is destroyed. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales( + const_mkldnn_primitive_attr_t attr, mkldnn_dim_t *count, int *mask, + const float **scales); + +/** Sets output @p scales for primitive operations. The number of elements @p + * count and correspondence scale @p mask are stored for future use. + * + * The @p mask argument defines the correspondence between the output tensor + * dimensions and the @p scales array. Set the i-th bit of @p mask to 1 to use a + * dedicated scaling factor for each slice of the output tensor over the i-th + * dimension. Set @p mask to 0 to use a common scaling factor for the whole + * output tensor. + * + * @note + * The dimension order is always native and does not depend on the actual + * layout used. Examples: + * - 2D dimensional data the order of dimensions is always: (n, c) + * - 4D dimensional data the order is always: (n, c, h, w) + * - 5D dimensional weights the order is always: (g, oc, ic, kh, kw) + * + * Example usage: + * @code + * int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params + * float scales[oc] = { ... }; // unique output scales per output channel + * int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ... + * + * mkldnn_convolution_desc_t cd; // create & configure convolution op_desc + * + * mkldnn_primitive_attr_t attr; + * mkldnn_primitive_attr_create(&attr); // create default attributes + * mkldnn_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales); + * + * mkldnn_primitive_desc_t cpd; + * mkldnn_primitive_desc_create(&cpd, &cd, attr, NULL); + * @endcode + * + * @note + * There is no way to check that @p count corresponds to @p mask until an + * actual primitive descriptor is created, so it is the user's + * responsibility to set proper values. The following formula must hold: + * + * \f[count = \prod\limits_{d \in mask} output.dims[d]\f] + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales( + mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask, + const float *scales); + +/** Returns @p post_ops for given @p attr. + * + * @warning + * @p post_ops points to the internal @p attr field, so the user should not + * modify or destroy @p post_ops. Also, the lifetime of @p post_ops is the + * same as that of the @p attr it belongs to, so it is illegal to use @p + * post_ops after @p attr has been destroyed. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops( + const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops); + +/** Sets configured @p post_ops to an attribute @p attr for future use (when + * primitive descriptor is being created). + * + * @note + * At this point in time, there is no way to check whether the primitive + * descriptor does or does not support a given sequence of post operations. + * Therefore the user should handle an error that might occur at the + * mkldnn_primitive_desc_create call. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops( + mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops); + +/** @addtogroup c_api_attributes_post_ops Sequence of post operations + * An extension for performing extra operations after a base operation. + * @{ */ + +/** Creates an empty sequence of post operations @p post_ops. */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops); + +/** Deletes a @p post_ops sequence. */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops); + +/** Returns the @p length of post operations for given @p post_ops. */ +int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops); + +/** Returns the type of post operation with index @p index in given + * @p post_ops. In case of error, returns #mkldnn_undefined_primitive. */ +mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind( + const_mkldnn_post_ops_t post_ops, int index); + +/** Appends accumulation (sum) post operation to the @p post_ops. Prior to + * accumulating the result, the previous value would be multiplied by @p scale. + * + * The kind of this post operation is #mkldnn_sum. + * + * This feature might improve performance for cases like residual learning + * blocks, where the result of convolution is accumulated to the previously + * computed activations. The parameter @p scale might be extreme for the + * integer-based computations when the result and previous activations have + * different logical scaling factors. + * + * In the simplest case when the accumulation is the only post operation, the + * computations would be: + * dst[] <- scale * dst[] + op(...) // instead of dst[] <- op(...) + * + * @note + * This post operation (as well as all the others) disregards the original + * layout of the destination; that is, the layout of the original + * destination is expected to be the same as the layout of the stored + * destination. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum( + mkldnn_post_ops_t post_ops, float scale); + +/** Gets the parameters of the accumulation (sum) post operation with index + * @p index in the sequence of @p post_ops. + * + * @note + * If index @p index would not correspond to the accumulation post + * operation, the function returns #mkldnn_invalid_arguments. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum( + const_mkldnn_post_ops_t post_ops, int index, float *scale); + +/** Appends eltwise post operation to the @p post_ops with given parameters + * @p kind, @p alpha, and @p beta (@sa mkldnn_eltwise_forward_desc_init and + * mkldnn_eltwise_desc_t). + * + * The kind of this post operation is #mkldnn_eltwise. + * + * In the simplest case when the eltwise is the only post operation, the + * computations would be: + * dst[] <- scale * eltwise_op ( op(...) ) // instead of dst[] <- op(...) + * where eltwise_op is configured with the given parameters. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise( + mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, + float alpha, float beta); + +/** Gets the eltwise parameters of the post operation with index @p index in + * the sequence of @p post_ops. + */ +mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise( + const_mkldnn_post_ops_t post_ops, int index, float *scale, + mkldnn_alg_kind_t *alg, float *alpha, float *beta); + +/** @} */ + +/** @} */ + +/** @addtogroup c_api_memory Memory + * A primitive to describe and store data. + * + * The library supports various data types and formats. Memory hierarchy + * consists of three levels of abstraction: + * 1. **Memory descriptor** -- engine agnostic logical description of data + * (number of dimensions, dimensions themselves, and data type), and + * optionally the format/layout that describes the physical representation + * of data in memory. If the format is not known yet, one can pass + * #mkldnn_format_tag_any. This approach is used to allow compute-intensive + * primitives to specify the most appropriate format on their own with + * users required to reorder the data if the incoming format doesn't match + * the primitive's selection. Memory descriptor can be initialized with + * mkldnn_memory_desc_init_by_tag() or mkldnn_memory_desc_init_by_strides() + * functions, or by directly filling the mkldnn_memory_desc_t structure. + * The latter requires deep knowledge of how the physical data + * representation is mapped to the structure. + * The @ref understanding_memory_formats topic should shed some light on + * that. + * For the fully defined memory descriptors (i.e. where the format kind is + * not equal to #mkldnn_format_kind_any) a user can the size, using the + * mkldnn_memory_desc_get_size() function. As described in + * @ref understanding_memory_formats, the size of data sometimes cannot + * be computed as the product of dimensions times the size of the data + * type. So users are encouraged to use this function for better code + * portability. + * Two memory descriptors can be compared with mkldnn_memory_desc_equal(). + * The comparison is especially useful when checking whether a primitive + * requires reorder from the user's data format to the primitive's format. + * 2. **Memory** -- an engine-specific object that handles the data and its + * description (a memory descriptor). For CPU enigne, the data handle is + * simply a pointer to @c void. The data handle can be queried using + * mkldnn_memory_get_data_handle() and set using + * mkldnn_memory_set_data_handle(). The latter function always sets the + * memory in the padding region to zero, which is the invariant maintained + * by all the primitives in Intel MKL-DNN. + * See @ref understanding_memory_formats for more details. + * A memory can be created using mkldnn_memory_create() function. + * A memory can also be queried for the underlying memory descriptor and + * engine using mkldnn_memory_get_memory_desc() and + * mkldnn_memory_get_engine() functions. + * + * Along with ordinary memory with all dimensions being positive, Intel + * MKL-DNN supports *zero-volume* memory with one or more dimensions set to + * zero. This is to support the NumPy\* convention. + * If a *zero-volume* memory is passed to a primitive, the primitive does + * not perform any computations on this memory. For example: + * - Convolution with `(0 batch, 3 input channels, 13 height, 13 width)` + * source and `(16 output channels, 3 inputs, channel, 3 height, 3 width)` + * weights would produce `(0 batch, 16 output channels, 11 height, 11 width)` + * destination (assuming strides are `1` and paddings are zero) and perform + * zero multiply-add operations. + * - Concatenation of three memories of shapes `(3, 4, 13, 13)`, + * `(3, 0, 13, 13)`, and `(3, 1, 13, 13)` along the second axis would produce + * the output of the shape `(3, 5, 13, 13)`, effectively ignoring the second + * input (however, if the user created a concatenation primitive descriptor + * with three inputs they should also provide all three memories to the + * concatenation primitive, including the one with zero second dimension). + * - However, Intel MKL-DNN would return an error when attempting to create a + * convolution with *zero-volume* memory passed for weights because such a + * convolution is not well-defined: + * ~~~ + * dst(1, 16, 11, 11) <-- src(1, 0, 13, 13) (*) wei(16, 0, 3, 3) + * ~~~ + * Should the values in the destination be zeroes or just not accessed at + * all? Moreover, backward pass w.r.t. weights in such cases is also not + * well-defined. + * + * Data handle of *zero-volume* memory is never accessed and hence can be + * unset (NULL in case of CPU engine). + * + * @sa @ref understanding_memory_formats + * @{ */ + +/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p + * data_type, and @p strides. + * + * The @p strides might be NULL, which means the order of physical dimensions + * is the same as the order of logical ones. + * + * @note The logical order of dimensions is defined by a primitive that + * consumes the memory. + */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_strides( + mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, + mkldnn_data_type_t data_type, const mkldnn_dims_t strides); + +/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p + * data_type, and format @p tag. + * + * @p tag can be #mkldnn_format_tag_any, which allows a primitive to define + * the appropriate memory format. In this case, the @p format_kind would be set + * to #mkldnn_format_kind_any */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_tag( + mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, + mkldnn_data_type_t data_type, mkldnn_format_tag_t tag); + +/** Initializes a @p memory_desc for a given @p parent_memory_desc, with + * @p dims sizes and @p offsets. May fail if layout used does not allow + * obtain desired submemory. In this case consider using `extract` or `insert` + * primitive */ +mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_submemory( + mkldnn_memory_desc_t *memory_desc, + const mkldnn_memory_desc_t *parent_memory_desc, + const mkldnn_dims_t dims, const mkldnn_dims_t offsets); + +/** Compares two memory descriptors. + * @return 1 if the descriptors are the same. + * @return 0 if the descriptors are different. + * + * Use this function to identify whether a reorder is required between the + * two memories */ +int MKLDNN_API mkldnn_memory_desc_equal( + const mkldnn_memory_desc_t *lhs, + const mkldnn_memory_desc_t *rhs); + +/** Returns the size (in bytes) that is required for given @p memory_desc */ +size_t MKLDNN_API mkldnn_memory_desc_get_size( + const mkldnn_memory_desc_t *memory_desc); + +/** Creates a memory for given @p memory_desc and @p engine. Also sets handle + * to @p native_handle. + * The @p native_handle can: + * - point to the user allocated memory, i.e. valid handle. In this case the + * library doesn't own allocated memory. + * - be MKLDNN_NATIVE_HANDLE_ALLOCATE to ask the library to allocate and + * attach memory. In this case the library owns allocated memory. + * - be MKLDNN_NATIVE_HANDLE_NONE to create mkldnn_memory w/o attached memory. + */ +mkldnn_status_t MKLDNN_API mkldnn_memory_create(mkldnn_memory_t *memory, + const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine, + void *native_handle); + +/** Returns a @p memory_desc associated with @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_memory_desc( + const_mkldnn_memory_t memory, + const mkldnn_memory_desc_t **memory_desc); + +/** Returns an @p engine associated with @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_engine( + const_mkldnn_memory_t memory, mkldnn_engine_t *engine); + +/** For a @p memory, returns the data @p handle. + * + * For the CPU engine, the data handle is a pointer to the actual data. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle( + const_mkldnn_memory_t memory, void **handle); + +/** For a @p memory, sets the data @p handle. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle( + mkldnn_memory_t memory, void *handle); + +/** Deletes a @p memory. */ +mkldnn_status_t MKLDNN_API mkldnn_memory_destroy(mkldnn_memory_t memory); + +/** @} */ + +/** @addtogroup c_api_reorder Reorder + * A primitive to copy data between memory formats. + * @{ */ + +/** Initializes a @p reorder_primitive_desc using the description of the source + * (@p src_engine and @p src_md) and destination (@p dst_engine and @p dst_md) + * memory, and an @p attr attribute. + * + * Inputs: + * - input (#mkldnn_query_src_md, 0) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create( + mkldnn_primitive_desc_t *reorder_primitive_desc, + mkldnn_engine_t src_engine, const mkldnn_memory_desc_t *src_md, + mkldnn_engine_t dst_engine, const mkldnn_memory_desc_t *dst_md, + const_mkldnn_primitive_attr_t attr); + +/** @} */ + +/** @addtogroup c_api_concat Concat + * A primitive to concatenate data by arbitrary dimension. + * @{ */ + +/** Creates out-of-place @p concat_primitive_desc for concatenation of @p n + * inputs by @p concat_dimension with resulting @p output_desc memory + * descriptor. @p output_desc can be NULL or specified with the + * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory + * format would be chosen automatically. + * + * Inputs: + * - input 0 (#mkldnn_query_src_md, 0) + * - input 1 (#mkldnn_query_src_md, 1) + * - ... + * - input @p n - 1 (#mkldnn_query_src_md, @p n - 1) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create( + mkldnn_primitive_desc_t *concat_primitive_desc, + const mkldnn_memory_desc_t *dst_md, + int n, int concat_dimension, + const mkldnn_memory_desc_t *src_mds, + const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_sum Sum + * A primitive to sum data. + * @{ */ + +/** Creates out-of-place @p sum_primitive_desc for sum of @p n + * inputs multiplied by scale with resulting @p output_desc memory + * descriptor. @p output_desc can be NULL or specified with the + * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory + * format would be chosen automatically. + * + * Inputs: + * - src 0 (#mkldnn_query_src_md, 0) + * - src 1 (#mkldnn_query_src_md, 1) + * - ... + * - src @p n - 1 (#mkldnn_query_src_md, @p n - 1) + * + * Outputs: + * - output (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create( + mkldnn_primitive_desc_t *sum_primitive_desc, + const mkldnn_memory_desc_t *dst_mds, + int n, const float *scales, + const mkldnn_memory_desc_t *src_mds, + const_mkldnn_primitive_attr_t attr, + mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_convolution Convolution + * A primitive to compute convolution using different algorithms. + * + * \f[dst[n][oc][oh][ow] = + * \sum_{kw=0}^{KW}\sum_{kh=0}^{KH}\sum_{ic=0}^{IC} + * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw] + * \cdot weights[g][oc][ic][kh][kw] + * + bias[g][oc],\f] + * + * where size of output spatial domain is given by + * \f$ OH = \left\lfloor{\frac{IH - KH + p_l[0] + p_r[0]}{s_h}} + * \right\rfloor + 1\f$, + * \f$ OW = \left\lfloor{\frac{IW - KW + p_l[1] + p_r[1]}{s_w}} + * \right\rfloor + 1\f$, + * + * and summation is carried over input channels \f$ic\f$ in + * group \f$g\f$, and \f$s_h, s_w\f$ are @p strides and + * \f$p_l, p_r\f$ are @p padding_l and @p padding_r. + * @{ */ + +/** Initializes a convolution descriptor @p conv_desc for forward propagation + * using @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. In order to create a + * convolution without bias, @p bias_desc should either be @c NULL or point to + * a descriptor with memory format kind equal to #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated convolution descriptor @p conv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * In order to create a dilated convolution without bias, @p bias_desc + * should either be @c NULL or point to a descriptor with memory format kind + * equals #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to data using @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated convolution descriptor @p conv_desc for backward + * propagation with respect to data using @p alg_kind, memory descriptors, @p + * strides, @p dilates @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a convolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p dilates @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API +mkldnn_dilated_convolution_backward_weights_desc_init( + mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_deconvolution Deconvolution + * A primitive to compute deconvolution using different algorithms. + * + * @{ */ + + +/** Initializes a deconvolution descriptor @p deconv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. In order to create a + * deconvolution without bias, @p bias_desc should either be @c NULL or point to + * a descriptor with memory format kind equals #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p deconv_desc for forward + * propagation using @p prop_kind (possible values are #mkldnn_forward_training + * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, + * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. In order to + * create a dilated deconvolution without bias, @p bias_desc should either be + * @c NULL or point to a descriptor with memory format kind equal + * #mkldnn_format_kind_undef. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a deconvolution descriptor @p conv_desc for backward propagation + * with respect to data using @p alg_kind, memory descriptors, @p strides, @p + * padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p conv_desc for backward + * propagation with respect to data using @p alg_kind, memory descriptors, @p + * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a deconvolution descriptor @p conv_desc for backward propagation + * with respect to weights using @p alg_kind, memory descriptors, @p strides, + * @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, + mkldnn_padding_kind_t padding_kind); + +/** Initializes a dilated deconvolution descriptor @p conv_desc for backward + * propagation with respect to weights using @p alg_kind, memory descriptors, + * @p strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init( + mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_shuffle Shuffle + * A primitive to shuffle data along the axis. + * @{ */ + +/** Initializes a @p shuffle_desc for forward propagation using @p prop_kind, + * memory descriptor @p data_desc, @p axis, and @p group_size. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * + */ +mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init( + mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *data_desc, int axis, + mkldnn_dim_t group_size); + +/** Initializes a @p shuffle_desc for backward propagation using memory + * descriptor @p diff_data_desc, @p axis, and @p group_size. + * + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + * + */ +mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init( + mkldnn_shuffle_desc_t *shuffle_desc, + const mkldnn_memory_desc_t *diff_data_desc, int axis, + mkldnn_dim_t group_size); + +/** @} */ + +/** @addtogroup c_api_eltwise Eltwise + * A primitive to compute element-wise operations like parametric rectifier + * linear unit (ReLU). + * + * Both forward and backward passes support in-place operation; that is, src + * and dst point to the same memory for forward pass, and diff_dst and diff_src + * point to the same memory for backward pass. + * + * @warning Because the original src is required for backward pass, in-place + * forward pass in general cannot be applied during training. However, for some + * kinds of element-wise operations (namely ReLU with alpha parameter equals 0), + * dst and src can be interchangeable for the backward pass, which enables + * performing in-place forward even for training. + * + * @{ */ + +/** Initializes an @p eltwise_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference), + * @p alg_kind algorithm, memory descriptor @p data_desc, @p alpha, and + * @p beta parameters. + * + * @sa mkldnn_eltwise_desc_t for details. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init( + mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, + float alpha, float beta); + +/** Initializes an @p eltwise_desc for backward propagation using @p alg_kind + * algorithm memory descriptors @p diff_data_desc and @p data_desc, and the + * @p alpha and @p beta parameters. + * + * @sa mkldnn_eltwise_desc_t for details. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init( + mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, float alpha, float beta); + +/** @} */ + +/** @addtogroup c_api_softmax Softmax + * A primitive to perform softmax. + * + * \f[dst[u][c][in] = + * \frac{\exp(src[ou][c][in]) - \max\limits_{c}(src[ou][c][in])} + * {\sum\limits_{c}\{\exp(src[ou][c][in]) + * - \max\limits_{c}(src[ou][c][in])\}},\f] + * + * where \f$ou, iu\f$ are outer and inner sizes repectively, defined + * by @p data_desc.dims and @p softmax_axis. + * @{ */ + +/** Initializes a @p softmax_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference) + * and memory descriptor @p data_desc. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init( + mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *data_desc, int softmax_axis); + +/** Initializes a @p softmax_desc for backward propagation using memory + * descriptors @p diff_desc and @p data_desc. + * + * Inputs: + * - dst (#mkldnn_query_dst_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init( + mkldnn_softmax_desc_t *softmax_desc, + const mkldnn_memory_desc_t *diff_desc, + const mkldnn_memory_desc_t *data_desc, int softmax_axis); + +/** @} */ + +/** @addtogroup c_api_pooling Pooling + * A primitive to perform max or average pooling. + * + * Max pooling: + * \f[dst[n][oc][oh][ow] = + * \max\limits_{kw,kh} + * (src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw]),\f] + * + * Average pooling: + * \f[dst[n][oc][oh][ow] = + * \frac{1}{KW \cdot KH}\sum\limits_{kw,kh} + * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw],\f] + * + * where \f$p_l, p_r\f$ are @p padding_l and @p padding_r respectively, and + * output spatial dimensions are calculated similarly to how they are done in + * convolution. + * + * During training, max pooling requires a workspace on forward + * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes to + * save indices where maximum was found. The workspace layout is opaque, and + * the indices cannot be restored from it. However, one can use backward + * pooling to perform up-sampling (used in some detection topologies). + * + * @{ */ + +/** Initializes a pooling descriptor @p pool_desc for forward propagation using + * @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference), @p alg_kind, memory descriptors, and pooling + * parameters in the spatial domain: @p strides, @p kernel sizes, @p padding_l, + * @p padding_r, and @p padding_kind. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p alg_kind = #mkldnn_pooling_max and + * @p prop_kind = #mkldnn_forward_training + */ +mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init( + mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** Initializes a pooling descriptor @p pool_desc for backward propagation + * using @p alg_kind, memory descriptors, and pooling parameters in the spatial + * domain: @p strides, @p kernel sizes, @p padding_l, @p padding_r, and @p + * padding_kind. + * + * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p alg_kind = #mkldnn_pooling_max + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init( + mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, + const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, + const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind); + +/** @} */ + +/** @addtogroup c_api_lrn LRN + * A primitive to perform local response normalization (LRN) across or within + * channels. + * + * LRN accross channels: + * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}} + * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2} + * (src[n][c+i][h][w])^2\right\}^{-\beta} + * src[n][c][h][w],\f] + * + * LRN within channels: + * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}} + * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2} + * (src[n][c][h+i][w+i])^2\right\}^{-\beta} + * src[n][c][h][w],\f] + * + * where \f$n_{l}\f$ is the @p local_size. + * + * During training, LRN might or might not require a workspace on forward + * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes. The + * behavior is implementation specific. Optimized implementations typically + * require a workspace and use it to save some intermediate results from the + * forward pass that accelerate computations on the backward pass. + * + * To check whether a workspace is required, query the LRN primitive descriptor + * for the workspace (#mkldnn_query_workspace_md). Success indicates that the + * workspace is required and its description will be returned. + * @sa mkldnn_primitive_desc_query and mkldnn_primitive_desc_query_pd + * + * @{ */ + +/** Initializes an @p lrn_desc for forward propagation using @p prop_kind + * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference), + * @p alg_kind, memory descriptor @p data_desc, and regularization + * parameters @p local_size, @p alpha, @p beta, and @p k. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if the underlying implementation requires + */ +mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init( + mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, + mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, + mkldnn_dim_t local_size, float alpha, float beta, float k); + +/** Initializes an @p lrn_desc for backward propagation using @p alg_kind, + * memory descriptors @p data_desc and @p diff_data_desc, and regularization + * parameters @p local_size, @p alpha, @p beta, and @p k. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - workspace (#mkldnn_query_workspace_md, 0), + * if the underlying implementation requires + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init( + mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, mkldnn_dim_t local_size, + float alpha, float beta, float k); + +/** @} */ + +/** @addtogroup c_api_batch_normalization Batch Normalization + * A primitive to perform batch normalization. + * + * \f[dst[n][c][h][w] = \gamma[c] \frac{src[n][c][h][w] - \mu[c]} + * {\sqrt{\sigma[c] + eps}} + \beta[c],\f] + * + * where \f$\gamma[c], \beta[c]\f$ are weights and bias for a channel and, + * + * \f$\mu[c] = \frac{1}{NHW} \sum\limits_{whn} src[n][c][h][w]\f$, + * \f$\sigma[c] = \frac{1}{NHW} \sum\limits_{whn} + * (src[n][c][h][w] - \mu[c])^2\f$, + * + * and @c eps is a constant to improve numerical stability. + * + * Both forward and backward passes support in-place operation; that is, src + * and dst point to the same memory for forward pass, and diff_dst and diff_src + * point to the same memory for backward pass. + * + * Batch normalization supports different flavors controlled by + * mkldnn_batch_normalization_desc_t. For example, batch normalization can + * compute the mean and variance on its own or take them as inputs. It can + * either perform scaling and shifting using gamma and beta parameters or not. + * Optionally it can also perform a fused ReLU, which in case of training would + * also require a workspace. + * + * @sa mkldnn_batch_normalization_desc_t + * @{ */ + +/** Initializes a batch normalization descriptor @p bnrm_desc for forward + * propagation using @p prop_kind (possible values are + * #mkldnn_forward_training and #mkldnn_forward_inference), memory descriptor + * @p data_desc, normalization parameter @p epsilon, and @p flags set using bit + * flags of type mkldnn_batch_normalization_desc_t. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - mean (#mkldnn_query_src_md, 1), + * if #mkldnn_use_global_stats bit-flags is set in @p flags + * - variance (#mkldnn_query_src_md, 2), + * if #mkldnn_use_global_stats bit-flags is set in @p flags + * - scale_and_shift (#mkldnn_query_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + * - mean (#mkldnn_query_dst_md, 1), + * if #mkldnn_use_global_stats bit-flags is not set in @p flags + * @p prop_kind = #mkldnn_forward_training + * - variance (#mkldnn_query_dst_md, 2), + * if #mkldnn_use_global_stats bit-flags is not set in @p flags + * and @p prop_kind = #mkldnn_forward_training + * - workspace (#mkldnn_query_workspace_md, 0), + * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags + * and @p prop_kind = #mkldnn_forward_training + * + * @note In-place operation is supported; that is, dst points to the same memory + * as src. + * + * @sa mkldnn_batch_normalization_desc_t + */ +mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init( + mkldnn_batch_normalization_desc_t *bnrm_desc, + mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, + float epsilon, unsigned flags); + +/** Initializes a batch normalization descriptor @p bnrm_desc for backward + * propagation with respect to data and scale-shift parameters using memory + * descriptors @p data_desc and @p diff_data_desc, normalization parameter + * @p epsilon, and @p flags set using bit flags of type + * mkldnn_batch_normalization_desc_t. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - mean (#mkldnn_query_src_md, 1) + * - variance (#mkldnn_query_src_md, 2) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - scale_and_shift (#mkldnn_query_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * - workspace (#mkldnn_query_workspace_md, 0), + * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + * - diff_scale_and_shift (#mkldnn_query_diff_weights_md, 0), + * if #mkldnn_use_scaleshift bit-flags is set in @p flags + * and @p prop_kind = #mkldnn_backward + * + * @note in-place operation is supported, + * i.e. diff_src points to the same memory as diff_dst. + * + * @sa mkldnn_batch_normalization_desc_t + */ +mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init( + mkldnn_batch_normalization_desc_t *bnrm_desc, + mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *diff_data_desc, + const mkldnn_memory_desc_t *data_desc, + float epsilon, unsigned flags); + +/** @} */ + +/** @addtogroup c_api_inner_product Inner product + * A primitive to compute an inner product. + * + * Inner product layer is also known as fully connected layer. + * With spatial dimension: + * + * \f[dst[n][oc] = \sum\limits_{ic, kh, kw} + * src[n][ic][kh][kw] \cdot weights[oc][ic][kh][kw] + * + bias[oc]\f] + * @{ */ + +/** Initializes an inner product descriptor @p ip_desc for forward propagation + * using @p prop_kind (possible values are #mkldnn_forward_training and + * #mkldnn_forward_inference) and memory descriptors. In order to create an + * inner product without bias, @p bias_desc should be either @c NULL or a + * pointer to a descriptor with memory format kind equals + * #mkldnn_format_kind_undef. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * - bias (#mkldnn_query_weights_md, 1), if created with bias + * + * Outputs: + * - dst (#mkldnn_query_dst_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init( + mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_desc); + +/** Initializes an inner product descriptor @p ip_desc for backward propagation + * with respect to data using memory descriptors. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * - weights (#mkldnn_query_weights_md, 0) + * + * Outputs: + * - diff_src (#mkldnn_query_diff_src_md, 0) + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init( + mkldnn_inner_product_desc_t *ip_desc, + const mkldnn_memory_desc_t *diff_src_desc, + const mkldnn_memory_desc_t *weights_desc, + const mkldnn_memory_desc_t *diff_dst_desc); + +/** Initializes an inner product descriptor @p ip_desc for backward propagation + * with respect to weights using memory descriptors. + * + * @note Memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src (#mkldnn_query_src_md, 0) + * - diff_dst (#mkldnn_query_diff_dst_md, 0) + * + * Outputs: + * - diff_weights (#mkldnn_query_diff_weights_md, 0) + * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias + */ +mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init( + mkldnn_inner_product_desc_t *ip_desc, + const mkldnn_memory_desc_t *src_desc, + const mkldnn_memory_desc_t *diff_weights_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_desc); + +/** @} */ + +/** @addtogroup c_api_rnn RNN + * A primitive to compute the common recurrent layer. + * @todo add additional description for the group + * @{ */ + +/** + * Initializes a recurrent cell descriptor @p rnn_cell_desc + * using @p rnn_cell_desc, @p kind (possible values are + * #mkldnn_vanilla_rnn, #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, and + * #mkldnn_gru_linear_before_reset), + * @p f (possible values are #mkldnn_eltwise_relu and + * #mkldnn_eltwise_tanh), @p flags, @p alpha, and @p clipping. + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init( + mkldnn_rnn_cell_desc_t *rnn_cell_desc, + mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, + unsigned int flags, float alpha, float clipping); + +/** Returns the number of gates of a particular @p rnn_cell_desc. */ +int MKLDNN_API mkldnn_rnn_cell_get_gates_count( + const mkldnn_rnn_cell_desc_t *rnn_cell_desc); + +/** Returns the number of states of a particular @p rnn_cell_desc. */ +int MKLDNN_API mkldnn_rnn_cell_get_states_count( + const mkldnn_rnn_cell_desc_t *rnn_cell_desc); + +/** Sets quantization @p scale and @p shift for RNN data tensors. + * For performance reasons, low precision configuration of RNN primitive + * expects input activations to have unsigned int8 data type. Scale and shift + * used to quantize floating point data to unsigned integer must be passed to + * RNN primitive using attributes. + * Example usage: + * @code + * // rnn parameters + * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; + * // activations quantization parameters + * float scale = ..., shift = ..; + * + * mkldnn_primitive_attr_t rnn_attr; + * // create default attributes + * mkldnn_primitive_attr_create(&rnn_attr); + * + * // set scale and shift for int8 quantization of activation + * mkldnn_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift); + * + * // create & configure rnn op_desc + * mkldnn_rnn_desc_t rnn_d; + * mkldnn_primitive_desc_t rnn_pd; + * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL); + * @endcode + * @note + * Quantization scale and shift are common for src_layer, src_iter, + * dst_iter and dst_layer. + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams( + mkldnn_primitive_attr_t attr, const float scale, const float shift); + +/** Sets quantization scales @p weights_scales for RNN weights tensors. + * Low precision configuration of RNN primitive expects input weights to have + * signed int8 data type. Scales used to quantize floating point data + * to signed integer must be passed to RNN primitive using attributes. + * The @p mask argument defines correspondence between output tensor dimensions + * and the @p weights_scales array. Set i-th bit of @p mask to 1 to use + * dedicated scaling factor for each slice of the output tensor over i-th + * dimension. Set @p mask to 0 to use common scaling factor for the whole output + * tensor. Example usage: + * @code + * // rnn parameters + * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32; + * // unique output scales per output channel + * float weights_scales[dic * n_gates] = { ... }; + * // mask that specifies last two dimensions of ldigo format + * int mask = 0x3; + * + * mkldnn_primitive_attr_t attr; + * // create default attributes + * mkldnn_primitive_attr_create(&attr); + * + * // set output channel-wise weights scales + * mkldnn_primitive_attr_set_rnn_weights_qparams(attr, dic * n_gates, mask, + * weights_scales); + * + * // create & configure rnn op_desc + * mkldnn_rnn_desc_t rnn_d; + * mkldnn_primitive_desc_t rnn_pd; + * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL); + * @endcode + * @note + * The dimension order is always native and does not depend on the actual + * layout used. For example, 5 dimensional weights always have + * (l, d, i, g, o) logical dimension ordering. + * @note + * Quantization sales are common for weights_layer and weights_iteration + * @note + * There is no way to check that @p count corresponds to @p mask until an + * actual primitive descriptor is created, so it is user's responsibility + * to set proper values. The following formula must be held: + * + * \f[count = \prod\limits_{d \in mask} output.dims[d]\f] + */ +mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams ( + mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask, + const float *weights_scales); + +/** Initializes a rnn descriptor @p rnn_desc for forward propagation + * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors. + * @note If @p prop_kind equals #mkldnn_forward_training, you must query a + * workspace memory descriptor before creating the primitive. + * + * @p src_iter_desc, @p bias_desc, and @p dst_iter_desc are allowed to either be + * @c NULL or point to a zero memory descriptor, which would indicate that the + * RNN primitive should not use them. + * + * @note All memory descriptors except @p src_iter_desc are allowed to be + * initialized with #mkldnn_format_kind_any value of @p format_kind. + * + * Inputs: + * - src_layer (#mkldnn_query_src_md, 0) + * - src_iter (#mkldnn_query_src_md, 1), if used + * - weights_layer (#mkldnn_query_weights_md, 0) + * - weights_iter (#mkldnn_query_weights_md, 1) + * - bias (#mkldnn_query_weights_md, 2), if used + * + * Outputs: + * - dst_layer (#mkldnn_query_dst_md, 0) + * - dst_iter (#mkldnn_query_dst_md, 1), if used + * - workspace (#mkldnn_query_workspace_md, 0), + * if @p prop_kind equals #mkldnn_forward_training + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init( + mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_rnn_cell_desc_t *rnn_cell_desc, + const mkldnn_rnn_direction_t direction, + const mkldnn_memory_desc_t *src_layer_desc, + const mkldnn_memory_desc_t *src_iter_desc, + const mkldnn_memory_desc_t *weights_layer_desc, + const mkldnn_memory_desc_t *weights_iter_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_layer_desc, + const mkldnn_memory_desc_t *dst_iter_desc); + +/** Initializes a rnn descriptor @p rnn_desc for backward propagation + * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors. + * + * @note All memory descriptors are allowed to be initialized with + * #mkldnn_format_kind_any value of @p format_kind. + * + * @p src_iter_desc (simultaneously with @p diff_src_iter_desc), + * @p bias_desc (simultaneously with @p diff_bias_desc), and + * @p dst_iter_desc (simultaneously with @p diff_src_iter_desc) are allowed to + * either be @c NULL or point to a zero memory descriptor, which would indicate + * that the RNN primitive should not use them. + * + * Inputs: + * - src_layer (#mkldnn_query_src_md, 0) + * - src_iter (#mkldnn_query_src_md, 1), if used + * - weights_layer (#mkldnn_query_weights_md, 0) + * - weights_iter (#mkldnn_query_weights_md, 1) + * - bias (#mkldnn_query_weights_md, 2), if used + * - dst_layer (#mkldnn_query_dst_md, 0) + * - dst_iter (#mkldnn_query_dst_md, 1), if used + * - diff_dst_layer (#mkldnn_query_diff_dst_md, 0) + * - diff_dst_iter (#mkldnn_query_diff_dst_md, 1), if used + * - workspace (#mkldnn_query_workspace_md, 0) + * + * Outputs: + * - diff_src_layer (#mkldnn_query_diff_src_md, 0) + * - diff_src_iter (#mkldnn_query_diff_src_md, 1), if used + * - diff_weights_layer (#mkldnn_query_diff_weights_md, 0) + * - diff_weights_iter (#mkldnn_query_diff_weights_md, 1) + * - diff_bias (#mkldnn_query_diff_weights_md, 2), if used + */ +mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init( + mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, + const mkldnn_rnn_cell_desc_t *rnn_cell_desc, + const mkldnn_rnn_direction_t direction, + const mkldnn_memory_desc_t *src_layer_desc, + const mkldnn_memory_desc_t *src_iter_desc, + const mkldnn_memory_desc_t *weights_layer_desc, + const mkldnn_memory_desc_t *weights_iter_desc, + const mkldnn_memory_desc_t *bias_desc, + const mkldnn_memory_desc_t *dst_layer_desc, + const mkldnn_memory_desc_t *dst_iter_desc, + const mkldnn_memory_desc_t *diff_src_layer_desc, + const mkldnn_memory_desc_t *diff_src_iter_desc, + const mkldnn_memory_desc_t *diff_weights_layer_desc, + const mkldnn_memory_desc_t *diff_weights_iter_desc, + const mkldnn_memory_desc_t *diff_bias_desc, + const mkldnn_memory_desc_t *diff_dst_layer, + const mkldnn_memory_desc_t *diff_dst_iter_desc); + +/** @} */ + +/** @} */ + +/** @addtogroup c_api_engine Engine operations + * @{ */ + +/** Returns the number of engines of a particular @p kind. */ +size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind); + +/** Creates an @p engine of particular @p kind and @p index. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, + mkldnn_engine_kind_t kind, size_t index); + +/** Returns the kind of an @p engine. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_get_kind(mkldnn_engine_t engine, + mkldnn_engine_kind_t *kind); + +/** Destroys an @p engine. */ +mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine); + +/** @} */ + +/** @addtogroup c_api_stream Execution stream operations + * @{ */ + +/** Creates an execution @p stream for @p engine and with @p flags. */ +mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, + mkldnn_engine_t engine, unsigned flags); + +/** Destroys an execution @p stream. */ +mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream); + +/** @} */ + +/** @addtogroup c_api_service Service functions + * @{ */ + +/** Sets verbosity level (print information to stdout). + * Possible levels are: + * - 0 -- no verbose output (default) + * - 1 -- primitive information at execution + * - 2 -- primitive information at creation and execution + * + * @note + * Dumping information might affect performance. + * This setting overrides the MKLDNN_VERBOSE environment variable. */ +mkldnn_status_t MKLDNN_API mkldnn_set_verbose(int level); + +/** Enables or disables dumping of JIT-generated code. + * The enable parameter can be: + * - 0 -- disable + * - any other value -- enable + * + * @note + * This setting overrides the MKLDNN_JIT_DUMP environment variable. */ +mkldnn_status_t MKLDNN_API mkldnn_set_jit_dump(int enable); + +/** Gets library version information. + * Version information includes: + * - major -- major version number + * - minor -- minor version number + * - patch -- patch release number + * - hash -- git commit hash */ +const mkldnn_version_t MKLDNN_API *mkldnn_version(); + +/** @} */ + +/** @addtogroup c_api_blas BLAS functions + * A subset of Basic Linear ALgebra (BLAS) functions to perform + * matrix-matrix multiplication. + * @{ */ + +/** SGEMM performs a matrix-matrix multiplication operation defined as + * + * C := alpha*op( A )*op( B ) + beta*C + * + * where + * - op( X ) is one of op( X ) = X or op( X ) = X**T, + * - alpha and beta are scalars, + * - A, B and C are matrices, with op( A ) an m by k matrix, op( B ) a k by n matrix + * and C an m by n matrix. + * + * The matrices are assumed to be stored in column-major order (the elements + * in a matrix columns are contiguous in memory). + * + * @note + * The API is different from the standard BLAS routine + * because it returns mkldnn_status_t for error handling. + * XERBLA is not supported: no error message will be printed + * in case of incorrect parameters. */ +mkldnn_status_t MKLDNN_API mkldnn_sgemm( + const char *transa, const char *transb, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, const float *A, const mkldnn_dim_t *lda, + const float *B, const mkldnn_dim_t *ldb, + const float *beta, float *C, const mkldnn_dim_t *ldc); + +/** gemm_s8u8s32 and gemm_s8s8s32 perform a matrix-matrix multiplication + * operation and add the result to a scalar-matrix product. For the final + * result, a vector is added to each row or column of the output matrix. + * The operation is defined as: + * + * C := alpha*(op(A) + A_offset) * (op(B) + B_offset) + beta*C + C_offset + * + * where + * - op( X ) = X or op( X ) = X**T, + * - A_offset is an m-by-k matrix with every element equal to the value oa, + * - B_offset is an k-by-n matrix with every element equal to the value ob, + * - C_offset is an m-by-n matrix defined by the oc array, size len: + * - if offsetc = F: len must be at least 1 + * - if offsetc = C: len must be at least max(1, m) + * - if offsetc = R: len must be at least max(1, n) + * - alpha and beta are scalars, and A, B and C are matrices, with op( A ) + * an m-by-k matrix, op( B ) a k-by-n matrix and C an m-by-n matrix. + * + * The matrices are assumed to be stored in column-major order (the elements + * in a matrix columns are contiguous in memory). + * + * @note + * The API is different compared with the standard BLAS routine + * because it returns mkldnn_status_t for error handling. + * XERBLA is not supported: no error message will be printed + * in case of incorrect parameters. */ +mkldnn_status_t MKLDNN_API mkldnn_gemm_s8u8s32( + const char *transa, const char *transb, const char *offsetc, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, + const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao, + const uint8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo, + const float *beta, + int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co); + +mkldnn_status_t MKLDNN_API mkldnn_gemm_s8s8s32( + const char *transa, const char *transb, const char *offsetc, + const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K, + const float *alpha, + const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao, + const int8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo, + const float *beta, + int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co); +/** @} */ + +/** @} */ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp new file mode 100644 index 0000000000..581400a013 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp @@ -0,0 +1,2615 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_HPP +#define MKLDNN_HPP + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#include <stdlib.h> +#include <memory> +#include <vector> +#include <unordered_map> +#include <algorithm> +#include <iterator> + +#include "mkldnn.h" +#endif + +namespace mkldnn { + +/// @addtogroup cpp_api C++ API +/// @{ + +/// @addtogroup cpp_api_utils Utils +/// @{ + +/// A class that provides the destructor for an Intel(R) MKL-DNN C handle +template <typename T> class handle_traits {}; + +/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base +/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and +/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class +/// can be passed by value. This class enables wrapping: +/// - Newly constructed handles. +/// @n In this case, the constructed handle uses reference counting provided +/// by @p std::shared_ptr with a proper deleter function specified through +/// the @p handle_traits class. +/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for +/// example, through mkldnn_primitive_get_primitive_desc()). +/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a +/// deleter because it is assumed that the handle wrapper for the original +/// object deletes the handle (this model is similar to @p std::weak_ptr). +template <typename T, typename traits=handle_traits<T>> class handle { +private: + std::shared_ptr<typename std::remove_pointer<T>::type> _data; + handle(const handle &&) = delete; + handle &operator=(const handle &&other) = delete; +protected: + bool operator==(const T other) const { return other == _data.get(); } + bool operator!=(const T other) const { return !(*this == other); } +public: + /// Constructs a C handle wrapper. + /// @param t The C handle to wrap. + /// @param weak A flag to specify whether to construct a weak wrapper. + handle(T t = 0, bool weak = false): _data(0) { + reset(t, weak); + } + + handle(const handle &other): _data(other._data) {} + handle &operator=(const handle &other) { + _data = other._data; + return *this; + } + /// Resets the value of a C handle. + /// @param t The new value of the C handle. + /// @param weak A flag to specify whether the wrapper should be weak. + void reset(T t, bool weak = false) { + auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); }; + _data.reset(t, weak ? dummy_destructor : traits::destructor); + } + + /// Returns the value of the underlying C handle. + T get() const { return _data.get(); } + + bool operator==(const handle &other) const { return other._data.get() == _data.get(); } + bool operator!=(const handle &other) const { return !(*this == other); } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits<mkldnn_memory_t> { + static constexpr auto destructor = &mkldnn_memory_destroy; +}; + +template <> struct handle_traits<mkldnn_primitive_desc_t> { + static constexpr auto destructor = &mkldnn_primitive_desc_destroy; +}; + +template <> struct handle_traits<mkldnn_primitive_t> { + static constexpr auto destructor = &mkldnn_primitive_destroy; +}; + +template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> { + static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy; +}; +#endif + +struct memory; +struct primitive_desc; + +/// Base class for all computational primitives. +class primitive: public handle<mkldnn_primitive_t> { + friend struct error; + friend struct stream; + using handle::handle; +public: + /// A proxy to C primitive kind enum + enum class kind { + undefined_primitive = mkldnn_undefined_primitive, + reorder = mkldnn_reorder, + concat = mkldnn_concat, + sum = mkldnn_sum, + convolution = mkldnn_convolution, + deconvolution = mkldnn_deconvolution, + shuffle = mkldnn_shuffle, + eltwise = mkldnn_eltwise, + softmax = mkldnn_softmax, + pooling = mkldnn_pooling, + lrn = mkldnn_lrn, + batch_normalization = mkldnn_batch_normalization, + inner_product = mkldnn_inner_product, + rnn = mkldnn_rnn, + }; + + primitive(const_mkldnn_primitive_desc_t c_pd); + primitive(const primitive_desc &pd); + + /// Returns the descriptor of the underlying C API primitive. + inline const_mkldnn_primitive_desc_t get_primitive_desc() const; + // TODO: use the C++ API wrapper structure. + + void execute(struct stream &astream, + const std::unordered_map<int, memory> &args) const; +}; + +inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) { + return static_cast<mkldnn_primitive_kind_t>(akind); +} +/// Intel(R) MKL-DNN exception class. +/// +/// This class captures the status returned by the failed C API function, error +/// message, and, optionally, handle of the primitive that caused the error. +struct error: public std::exception { + mkldnn_status_t status; + const char *message; + + /// Constructs an error instance. + /// + /// @param astatus The error status returned by the C API. + /// @param amessage The error message. + error(mkldnn_status_t astatus, const char *amessage) + : status(astatus), message(amessage) {} + + /// A convenience function for wrapping calls to the C API. Checks the + /// return status and throws an #error in case of failure. + /// + /// @param status The error status returned by the C API. + /// @param message The error message. + static void wrap_c_api(mkldnn_status_t status, const char *message) { + if (status != mkldnn_success) + throw error(status, message); + } +}; + +const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const { + const_mkldnn_primitive_desc_t pd; + error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd), + "could not get primitive descriptor by primitive"); + return pd; +} +/// @} + +/// @addtogroup cpp_api_enums Common data types and enumerations +/// A proxy to @ref c_api_types in @ref c_api. +/// +/// @{ + +enum scratchpad_mode { + scratchpad_mode_library = mkldnn_scratchpad_mode_library, + scratchpad_mode_user = mkldnn_scratchpad_mode_user, +}; + +inline mkldnn_scratchpad_mode_t convert_to_c(scratchpad_mode mode) { + return static_cast<mkldnn_scratchpad_mode_t>(mode); +} + +enum padding_kind { + zero = mkldnn_padding_zero +}; + +inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) { + return static_cast<mkldnn_padding_kind_t>(kind); +} + +enum prop_kind { + forward_training = mkldnn_forward_training, + forward_scoring = mkldnn_forward_scoring, + forward_inference = mkldnn_forward_inference, + forward = mkldnn_forward, + backward = mkldnn_backward, + backward_data = mkldnn_backward_data, + backward_weights = mkldnn_backward_weights, + backward_bias = mkldnn_backward_bias +}; + +inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) { + return static_cast<mkldnn_prop_kind_t>(kind); +} + +enum algorithm { + algorithm_undef = mkldnn_alg_kind_undef, + convolution_auto = mkldnn_convolution_auto, + convolution_direct = mkldnn_convolution_direct, + convolution_winograd = mkldnn_convolution_winograd, + deconvolution_direct = mkldnn_deconvolution_direct, + deconvolution_winograd = mkldnn_deconvolution_winograd, + eltwise_relu = mkldnn_eltwise_relu, + eltwise_tanh = mkldnn_eltwise_tanh, + eltwise_elu = mkldnn_eltwise_elu, + eltwise_square = mkldnn_eltwise_square, + eltwise_abs = mkldnn_eltwise_abs, + eltwise_sqrt = mkldnn_eltwise_sqrt, + eltwise_linear = mkldnn_eltwise_linear, + eltwise_bounded_relu = mkldnn_eltwise_bounded_relu, + eltwise_soft_relu = mkldnn_eltwise_soft_relu, + eltwise_logistic = mkldnn_eltwise_logistic, + lrn_across_channels = mkldnn_lrn_across_channels, + lrn_within_channel = mkldnn_lrn_within_channel, + pooling_max = mkldnn_pooling_max, + pooling_avg = mkldnn_pooling_avg, + pooling_avg_include_padding = mkldnn_pooling_avg_include_padding, + pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding, + vanilla_rnn = mkldnn_vanilla_rnn, + vanilla_lstm = mkldnn_vanilla_lstm, + vanilla_gru = mkldnn_vanilla_gru, + gru_linear_before_reset = mkldnn_gru_linear_before_reset +}; + +inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) { + return static_cast<mkldnn_alg_kind_t>(aalgorithm); +} + +enum batch_normalization_flag { + use_global_stats = mkldnn_use_global_stats, + use_scale_shift = mkldnn_use_scaleshift, + fuse_bn_relu = mkldnn_fuse_bn_relu +}; + +inline mkldnn_batch_normalization_flag_t convert_to_c( + batch_normalization_flag aflag) { + return static_cast<mkldnn_batch_normalization_flag_t>(aflag); +} + +enum rnn_direction { + unidirectional_left2right = mkldnn_unidirectional_left2right, + unidirectional_right2left = mkldnn_unidirectional_right2left, + unidirectional = mkldnn_unidirectional, + bidirectional_concat = mkldnn_bidirectional_concat, + bidirectional_sum = mkldnn_bidirectional_sum, +}; + +inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) { + return static_cast<mkldnn_rnn_direction_t>(adir); +} + +enum query { + undef = mkldnn_query_undef, + + query_engine = mkldnn_query_engine, + primitive_kind = mkldnn_query_primitive_kind, + + num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32, + num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32, + + time_estimate_f64 = mkldnn_query_time_estimate_f64, + memory_consumption_s64 = mkldnn_query_memory_consumption_s64, + + query_scratchpad_engine = mkldnn_query_scratchpad_engine, + + impl_info_str = mkldnn_query_impl_info_str, + + op_d = mkldnn_query_op_d, + convolution_d = mkldnn_query_convolution_d, + deconvolution_d = mkldnn_query_deconvolution_d, + shuffle_d = mkldnn_query_shuffle_d, + eltwise_d = mkldnn_query_eltwise_d, + softmax_d = mkldnn_query_softmax_d, + pooling_d = mkldnn_query_pooling_d, + lrn_d = mkldnn_query_lrn_d, + batch_normalization_d = mkldnn_query_batch_normalization_d, + inner_product_d = mkldnn_query_inner_product_d, + rnn_d = mkldnn_query_rnn_d, + + src_md = mkldnn_query_src_md, + diff_src_md = mkldnn_query_diff_src_md, + weights_md = mkldnn_query_weights_md, + diff_weights_md = mkldnn_query_diff_weights_md, + dst_md = mkldnn_query_dst_md, + diff_dst_md = mkldnn_query_diff_dst_md, + workspace_md = mkldnn_query_workspace_md, + scratchpad_md = mkldnn_query_scratchpad_md, +}; + +inline mkldnn_query_t convert_to_c(query aquery) { + return static_cast<mkldnn_query_t>(aquery); +} + +/// @} + +/// @addtogroup cpp_api_attr Attributes +/// An extension for controlling primitive behavior. +/// +/// @sa @ref c_api_attributes in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits<mkldnn_post_ops_t> { + static constexpr auto destructor = &mkldnn_post_ops_destroy; +}; +#endif + +struct post_ops: public handle<mkldnn_post_ops_t> { + post_ops() { + mkldnn_post_ops_t result; + error::wrap_c_api(mkldnn_post_ops_create(&result), + "could not create post operation sequence"); + reset(result); + } + + int len() const { return mkldnn_post_ops_len(get()); } + + primitive::kind kind(int index) const { + error::wrap_c_api( + index < len() ? mkldnn_success : mkldnn_invalid_arguments, + "post_ops index is out of range"); + return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(), + index)); + } + + void append_sum(float scale = 1.) { + error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale), + "could not append sum"); + } + + void get_params_sum(int index, float &scale) const { + error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale), + "could not get sum params"); + } + + void append_eltwise(float scale, algorithm alg, float alpha, + float beta) { + error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale, + convert_to_c(alg), alpha, beta), + "could not append eltwise"); + } + + void get_params_eltwise(int index, float &scale, algorithm &alg, + float &alpha, float &beta) const { + mkldnn_alg_kind_t c_alg; + error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index, + &scale, &c_alg, &alpha, &beta), + "could not get eltwise params"); + alg = static_cast<algorithm>(c_alg); + } +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits<mkldnn_primitive_attr_t> { + static constexpr auto destructor = &mkldnn_primitive_attr_destroy; +}; +#endif + +struct primitive_attr: public handle<mkldnn_primitive_attr_t> { + primitive_attr() { + mkldnn_primitive_attr_t result; + error::wrap_c_api(mkldnn_primitive_attr_create(&result), + "could not create a primitive attr"); + reset(result); + } + + scratchpad_mode get_scratchpad_mode() const { + mkldnn_scratchpad_mode_t result; + error::wrap_c_api(mkldnn_primitive_attr_get_scratchpad_mode( + get(), &result), "could not get scratchpad mode"); + return scratchpad_mode(result); + } + + void set_scratchpad_mode(scratchpad_mode mode) { + error::wrap_c_api(mkldnn_primitive_attr_set_scratchpad_mode( + get(), mkldnn::convert_to_c(mode)), + "could not set scratchpad mode"); + } + + void get_output_scales(int &mask, std::vector<float> &scales) const + { + mkldnn_dim_t count; + int c_mask; + const float *c_scales; + error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(), + &count, &c_mask, &c_scales), + "could not get int output scales"); + scales.resize(count); + + mask = c_mask; + for (mkldnn_dim_t c = 0; c < count; ++c) + scales[c] = c_scales[c]; + } + + void set_output_scales(int mask, const std::vector<float> &scales) + { + error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(), + (mkldnn_dim_t)scales.size(), mask, &scales[0]), + "could not set int output scales"); + } + + const post_ops get_post_ops() const { + post_ops result; + const_mkldnn_post_ops_t c_result; + error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result), + "could not get post operation sequence"); + result.reset(const_cast<mkldnn_post_ops_t>(c_result), true); + return result; + } + + void set_post_ops(post_ops ops) { + error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()), + "could not set post operation sequence"); + } + + void set_rnn_data_qparams(const float scale, const float shift) + { + error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(), + scale, shift), "could not set rnn data int scale/shift"); + } + + void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) + { + error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(), + (int)scales.size(), mask, &scales[0]), + "could not set rnn weights int scales"); + } +}; + +/// @} + +/// @addtogroup cpp_api_engine Engine +/// Engine operations. +/// +/// @sa @ref c_api_engine in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits<mkldnn_engine_t> { + static constexpr auto destructor = &mkldnn_engine_destroy; +}; +#endif + +/// An execution engine. +struct engine: public handle<mkldnn_engine_t> { + friend class primitive; + // gcc bug??? using handle::handle; + + /// Kinds of engines. + enum kind { + /// An unspecified engine + any = mkldnn_any_engine, + /// CPU engine + cpu = mkldnn_cpu, + }; + + /// Returns the number of engines of a certain kind. + /// + /// @param akind The kind of engines to count. + + static size_t get_count(kind akind) { + return mkldnn_engine_get_count(convert_to_c(akind)); + } + + /// Constructs an engine. + /// + /// @param akind The kind of engine to construct. + /// @param index The index of the engine. Must be less than the value + /// returned by #get_count() for this particular kind of engine. + + engine(kind akind, size_t index) { + mkldnn_engine_t aengine; + error::wrap_c_api( + mkldnn_engine_create(&aengine, + convert_to_c(akind), index), + "could not create an engine"); + reset(aengine); + } + + explicit engine(const mkldnn_engine_t& aengine) + : handle(aengine, true) {} + + engine(const handle<mkldnn_primitive_desc_t> &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(pd.get(), + mkldnn::convert_to_c(query_engine), 0, &engine_q), + "could not get engine from primitive_desc"); + reset(engine_q, true); + } + + template <class primitive_desc> + static engine query(const primitive_desc &pd) { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(pd.get(), + mkldnn::convert_to_c(query_engine), 0, &engine_q), + "could not get engine from primitive_desc"); + + return engine(engine_q); + } + +private: + static mkldnn_engine_kind_t convert_to_c(kind akind) { + return static_cast<mkldnn_engine_kind_t>(akind); + } +}; + +/// @} + +/// @addtogroup cpp_api_stream Stream +/// Execution stream operations +/// +/// @sa @ref c_api_stream in @ref c_api +/// @{ + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template <> struct handle_traits<mkldnn_stream_t> { + static constexpr auto destructor = &mkldnn_stream_destroy; +}; +#endif + +struct stream: public handle<mkldnn_stream_t> { + using handle::handle; + + enum: unsigned { + default_flags = mkldnn_stream_default_flags, + }; + + /// Constructs a stream. + stream(const engine &aengine, + unsigned flags = static_cast<unsigned>(default_flags)) { + mkldnn_stream_t astream; + error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(), flags), + "could not create a stream"); + reset(astream); + } +}; + +/// @} + +/// @addtogroup cpp_api_memory_related Memory and memory related operations +/// @{ + +/// @addtogroup cpp_api_memory Memory +/// A primitive to describe and store data. +/// +/// For more information, refer to @ref c_api_memory in @ref c_api. +/// @{ + +/// Memory that describes the data. +struct memory: public handle<mkldnn_memory_t> { + public: + typedef mkldnn_dim_t dim; + typedef std::vector<dim> dims; + + template <typename T> static void validate_dims(const std::vector<T> &v) { + if (v.size() > MKLDNN_MAX_NDIMS) + throw error(mkldnn_invalid_arguments, "invalid dimensions"); + } + + /// Data type specification. See #mkldnn_data_type_t for a detailed + /// description. + enum data_type { + data_undef = mkldnn_data_type_undef, + f32 = mkldnn_f32, + s32 = mkldnn_s32, + s8 = mkldnn_s8, + u8 = mkldnn_u8, + }; + + /// Memory format tag specification. See #mkldnn_format_tag_t + /// for a detailed description. + enum format_tag { + format_tag_undef = mkldnn_format_tag_undef, + any = mkldnn_format_tag_any, + a = mkldnn_a, + ab = mkldnn_ab, + abc = mkldnn_abc, + abcd = mkldnn_abcd, + abcde = mkldnn_abcde, + abcdef = mkldnn_abcdef, + abdec = mkldnn_abdec, + acb = mkldnn_acb, + acbde = mkldnn_acbde, + acdb = mkldnn_acdb, + acdeb = mkldnn_acdeb, + ba = mkldnn_ba, + bac = mkldnn_bac, + bacd = mkldnn_bacd, + bcda = mkldnn_bcda, + cba = mkldnn_cba, + cdba = mkldnn_cdba, + cdeba = mkldnn_cdeba, + decab = mkldnn_decab, + Abc16a = mkldnn_Abc16a, + ABc16a16b = mkldnn_ABc16a16b, + aBc16b = mkldnn_aBc16b, + ABc16b16a = mkldnn_ABc16b16a, + Abc4a = mkldnn_Abc4a, + aBc4b = mkldnn_aBc4b, + ABc4b16a4b = mkldnn_ABc4b16a4b, + ABc4b4a = mkldnn_ABc4b4a, + ABc8a16b2a = mkldnn_ABc8a16b2a, + ABc8a8b = mkldnn_ABc8a8b, + aBc8b = mkldnn_aBc8b, + ABc8b16a2b = mkldnn_ABc8b16a2b, + ABc8b8a = mkldnn_ABc8b8a, + Abcd16a = mkldnn_Abcd16a, + ABcd16a16b = mkldnn_ABcd16a16b, + aBcd16b = mkldnn_aBcd16b, + ABcd16b16a = mkldnn_ABcd16b16a, + aBCd16b16c = mkldnn_aBCd16b16c, + aBCd16c16b = mkldnn_aBCd16c16b, + Abcd4a = mkldnn_Abcd4a, + aBcd4b = mkldnn_aBcd4b, + ABcd4b16a4b = mkldnn_ABcd4b16a4b, + ABcd4b4a = mkldnn_ABcd4b4a, + aBCd4c16b4c = mkldnn_aBCd4c16b4c, + aBCd4c4b = mkldnn_aBCd4c4b, + ABcd8a16b2a = mkldnn_ABcd8a16b2a, + ABcd8a8b = mkldnn_ABcd8a8b, + aBcd8b = mkldnn_aBcd8b, + ABcd8b16a2b = mkldnn_ABcd8b16a2b, + aBCd8b16c2b = mkldnn_aBCd8b16c2b, + ABcd8b8a = mkldnn_ABcd8b8a, + aBCd8b8c = mkldnn_aBCd8b8c, + aBCd8c16b2c = mkldnn_aBCd8c16b2c, + aBCd8c8b = mkldnn_aBCd8c8b, + Abcde16a = mkldnn_Abcde16a, + ABcde16a16b = mkldnn_ABcde16a16b, + aBcde16b = mkldnn_aBcde16b, + ABcde16b16a = mkldnn_ABcde16b16a, + aBCde16b16c = mkldnn_aBCde16b16c, + aBCde16c16b = mkldnn_aBCde16c16b, + aBCde2c8b4c = mkldnn_aBCde2c8b4c, + Abcde4a = mkldnn_Abcde4a, + aBcde4b = mkldnn_aBcde4b, + ABcde4b4a = mkldnn_ABcde4b4a, + aBCde4b4c = mkldnn_aBCde4b4c, + aBCde4c16b4c = mkldnn_aBCde4c16b4c, + aBCde4c4b = mkldnn_aBCde4c4b, + Abcde8a = mkldnn_Abcde8a, + ABcde8a8b = mkldnn_ABcde8a8b, + aBcde8b = mkldnn_aBcde8b, + ABcde8b16a2b = mkldnn_ABcde8b16a2b, + aBCde8b16c2b = mkldnn_aBCde8b16c2b, + ABcde8b8a = mkldnn_ABcde8b8a, + aBCde8b8c = mkldnn_aBCde8b8c, + aBCde8c16b2c = mkldnn_aBCde8c16b2c, + aBCde8c8b = mkldnn_aBCde8c8b, + aBcdef16b = mkldnn_aBcdef16b, + aBCdef16b16c = mkldnn_aBCdef16b16c, + aBCdef16c16b = mkldnn_aBCdef16c16b, + aBcdef4b = mkldnn_aBcdef4b, + aBCdef4c4b = mkldnn_aBCdef4c4b, + aBCdef8b8c = mkldnn_aBCdef8b8c, + aBCdef8c16b2c = mkldnn_aBCdef8c16b2c, + aBCdef8c8b = mkldnn_aBCdef8c8b, + aBdc16b = mkldnn_aBdc16b, + aBdc4b = mkldnn_aBdc4b, + aBdc8b = mkldnn_aBdc8b, + aBdec16b = mkldnn_aBdec16b, + aBdec4b = mkldnn_aBdec4b, + aBdec8b = mkldnn_aBdec8b, + aBdefc16b = mkldnn_aBdefc16b, + aBdefc4b = mkldnn_aBdefc4b, + aBdefc8b = mkldnn_aBdefc8b, + Acb16a = mkldnn_Acb16a, + Acb4a = mkldnn_Acb4a, + Acb8a = mkldnn_Acb8a, + aCBd16b16c = mkldnn_aCBd16b16c, + aCBde16b16c = mkldnn_aCBde16b16c, + Acdb16a = mkldnn_Acdb16a, + Acdb4a = mkldnn_Acdb4a, + Acdb8a = mkldnn_Acdb8a, + Acdeb16a = mkldnn_Acdeb16a, + Acdeb4a = mkldnn_Acdeb4a, + Acdeb8a = mkldnn_Acdeb8a, + BAc16a16b = mkldnn_BAc16a16b, + BAcd16a16b = mkldnn_BAcd16a16b, + format_tag_last = mkldnn_format_tag_last, + + x = mkldnn_x, + nc = mkldnn_nc, + cn = mkldnn_cn, + ncw = mkldnn_ncw, + nwc = mkldnn_nwc, + nchw = mkldnn_nchw, + nhwc = mkldnn_nhwc, + chwn = mkldnn_chwn, + ncdhw = mkldnn_ncdhw, + ndhwc = mkldnn_ndhwc, + oi = mkldnn_oi, + io = mkldnn_io, + oiw = mkldnn_oiw, + wio = mkldnn_wio, + oihw = mkldnn_oihw, + hwio = mkldnn_hwio, + ihwo = mkldnn_ihwo, + iohw = mkldnn_iohw, + oidhw = mkldnn_oidhw, + dhwio = mkldnn_dhwio, + goiw = mkldnn_goiw, + goihw = mkldnn_goihw, + hwigo = mkldnn_hwigo, + giohw = mkldnn_giohw, + goidhw = mkldnn_goidhw, + tnc = mkldnn_tnc, + ntc = mkldnn_ntc, + ldsnc = mkldnn_ldsnc, + ldigo = mkldnn_ldigo, + ldgoi = mkldnn_ldgoi, + ldgo = mkldnn_ldgo, + nCdhw16c = mkldnn_nCdhw16c, + nCdhw4c = mkldnn_nCdhw4c, + nCdhw8c = mkldnn_nCdhw8c, + nChw16c = mkldnn_nChw16c, + nChw4c = mkldnn_nChw4c, + nChw8c = mkldnn_nChw8c, + nCw16c = mkldnn_nCw16c, + nCw4c = mkldnn_nCw4c, + nCw8c = mkldnn_nCw8c, + IOw16o16i = mkldnn_IOw16o16i, + OIw16i16o = mkldnn_OIw16i16o, + OIw16o16i = mkldnn_OIw16o16i, + Oiw16o = mkldnn_Oiw16o, + OIw4i16o4i = mkldnn_OIw4i16o4i, + OIw4i4o = mkldnn_OIw4i4o, + Oiw4o = mkldnn_Oiw4o, + OIw8i16o2i = mkldnn_OIw8i16o2i, + OIw8i8o = mkldnn_OIw8i8o, + OIw8o16i2o = mkldnn_OIw8o16i2o, + OIw8o8i = mkldnn_OIw8o8i, + Owi16o = mkldnn_Owi16o, + Owi4o = mkldnn_Owi4o, + Owi8o = mkldnn_Owi8o, + IOhw16o16i = mkldnn_IOhw16o16i, + Ohwi16o = mkldnn_Ohwi16o, + Ohwi4o = mkldnn_Ohwi4o, + Ohwi8o = mkldnn_Ohwi8o, + OIhw16i16o = mkldnn_OIhw16i16o, + OIhw16o16i = mkldnn_OIhw16o16i, + Oihw16o = mkldnn_Oihw16o, + OIhw4i16o4i = mkldnn_OIhw4i16o4i, + OIhw4i4o = mkldnn_OIhw4i4o, + Oihw4o = mkldnn_Oihw4o, + OIhw8i16o2i = mkldnn_OIhw8i16o2i, + OIhw8i8o = mkldnn_OIhw8i8o, + OIhw8o16i2o = mkldnn_OIhw8o16i2o, + OIhw8o8i = mkldnn_OIhw8o8i, + Odhwi16o = mkldnn_Odhwi16o, + Odhwi4o = mkldnn_Odhwi4o, + Odhwi8o = mkldnn_Odhwi8o, + OIdhw16i16o = mkldnn_OIdhw16i16o, + OIdhw16o16i = mkldnn_OIdhw16o16i, + Oidhw16o = mkldnn_Oidhw16o, + OIdhw4i4o = mkldnn_OIdhw4i4o, + Oidhw4o = mkldnn_Oidhw4o, + OIdhw8i16o2i = mkldnn_OIdhw8i16o2i, + OIdhw8i8o = mkldnn_OIdhw8i8o, + OIdhw8o8i = mkldnn_OIdhw8o8i, + gIOw16o16i = mkldnn_gIOw16o16i, + gOIw16i16o = mkldnn_gOIw16i16o, + gOIw16o16i = mkldnn_gOIw16o16i, + gOiw16o = mkldnn_gOiw16o, + gOIw4i16o4i = mkldnn_gOIw4i16o4i, + gOIw4i4o = mkldnn_gOIw4i4o, + gOiw4o = mkldnn_gOiw4o, + gOIw8i16o2i = mkldnn_gOIw8i16o2i, + gOIw8i8o = mkldnn_gOIw8i8o, + gOIw8o16i2o = mkldnn_gOIw8o16i2o, + gOIw8o8i = mkldnn_gOIw8o8i, + gOwi16o = mkldnn_gOwi16o, + gOwi4o = mkldnn_gOwi4o, + gOwi8o = mkldnn_gOwi8o, + gIOhw16o16i = mkldnn_gIOhw16o16i, + gOhwi16o = mkldnn_gOhwi16o, + gOhwi4o = mkldnn_gOhwi4o, + gOhwi8o = mkldnn_gOhwi8o, + Goihw16g = mkldnn_Goihw16g, + gOIhw16i16o = mkldnn_gOIhw16i16o, + gOIhw16o16i = mkldnn_gOIhw16o16i, + gOihw16o = mkldnn_gOihw16o, + gOIhw2i8o4i = mkldnn_gOIhw2i8o4i, + gOIhw4i16o4i = mkldnn_gOIhw4i16o4i, + gOIhw4i4o = mkldnn_gOIhw4i4o, + gOIhw4o4i = mkldnn_gOIhw4o4i, + gOihw4o = mkldnn_gOihw4o, + Goihw8g = mkldnn_Goihw8g, + gOIhw8i16o2i = mkldnn_gOIhw8i16o2i, + gOIhw8i8o = mkldnn_gOIhw8i8o, + gOIhw8o16i2o = mkldnn_gOIhw8o16i2o, + gOIhw8o8i = mkldnn_gOIhw8o8i, + gOdhwi16o = mkldnn_gOdhwi16o, + gOdhwi4o = mkldnn_gOdhwi4o, + gOdhwi8o = mkldnn_gOdhwi8o, + gOIdhw16i16o = mkldnn_gOIdhw16i16o, + gOIdhw16o16i = mkldnn_gOIdhw16o16i, + gOidhw16o = mkldnn_gOidhw16o, + gOIdhw4i4o = mkldnn_gOIdhw4i4o, + gOidhw4o = mkldnn_gOidhw4o, + gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i, + gOIdhw8i8o = mkldnn_gOIdhw8i8o, + gOIdhw8o8i = mkldnn_gOIdhw8o8i, + }; + + /// A memory descriptor. + struct desc { + friend struct memory; + /// The underlying C API data structure. + mkldnn_memory_desc_t data; + + /// Constructs a zero memory descriptor + desc(): data() {} + + /// Constructs a memory descriptor. + /// + /// @param adims Data dimensions + /// @param adata_type Data precision/type. + /// @param aformat Data layout format tag. + desc(const dims &adims, data_type adata_type, + format_tag aformat) { + validate_dims(adims); + error::wrap_c_api(mkldnn_memory_desc_init_by_tag(&data, (int)adims.size(), + adims.size() == 0 ? nullptr : &adims[0], + convert_to_c(adata_type), convert_to_c(aformat)), + "could not initialize a memory descriptor"); + } + + /// Constructs a memory descriptor from a C API data structure. + /// + /// @param adata A C API #mkldnn_memory_desc_t structure. + desc(const mkldnn_memory_desc_t &adata): data(adata) {} + + /// Constructs a sub-memory descriptor + // + /// @param adims Sizes of a sub-memory + /// @param offsets Offsets of a sub-memory + desc submemory_desc(const dims &adims, const dims &offsets) { + mkldnn_memory_desc_t sub_md; + error::wrap_c_api(mkldnn_memory_desc_init_submemory(&sub_md, + &data, &adims[0], &offsets[0]), + "could not initialize a sub-memory"); + return desc(sub_md); + } + + /// Returns the number of bytes required to allocate the memory described + /// including the padding area. + size_t get_size() const { return mkldnn_memory_desc_get_size(&data); } + + bool operator==(const desc &other) const { + return mkldnn_memory_desc_equal(&data, &other.data) != 0; + } + + bool operator!=(const desc &other) const { return !operator==(other); } + }; + + /// Constructs a memory. + /// + /// @param md Memory descriptor. + /// @param aengine Engine. + /// @param ahandle Native handle. + memory(const desc &md, const engine &aengine, void *ahandle) { + mkldnn_memory_t result; + error::wrap_c_api(mkldnn_memory_create(&result, &md.data, + aengine.get(), ahandle), "could not create a memory"); + reset(result); + } + + /// Constructs a memory. + /// + /// @param md Memory descriptor. + /// @param aengine Engine. + memory(const desc &md, const engine &aengine) + : memory(md, aengine, MKLDNN_NATIVE_HANDLE_ALLOCATE) {} + + /// Returns the descriptor of the memory. + desc get_desc() const { + const mkldnn_memory_desc_t *cdesc; + error::wrap_c_api(mkldnn_memory_get_memory_desc(get(), &cdesc), + "could not get memory descriptor from a memory"); + return desc(*cdesc); + } + + /// Returns the engine of the memory. + engine get_engine() const { + mkldnn_engine_t engine_q; + error::wrap_c_api(mkldnn_memory_get_engine(get(), &engine_q), + "could not get engine from a memory"); + return engine(engine_q); + } + + /// Returns a handle of the data contained in the memory. + /// + /// On the CPU engine, this is a pointer to the allocated memory. + void *get_data_handle() const { + void *handle; + error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle), + "could not get native handle"); + return handle; + } + + void set_data_handle(void *handle) const { + error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle), + "could not set native handle"); + } + + // Must go away or be private: + static mkldnn_data_type_t convert_to_c(data_type adata_type) { + return static_cast<mkldnn_data_type_t>(adata_type); + } + static mkldnn_format_tag_t convert_to_c(format_tag aformat) { + return static_cast<mkldnn_format_tag_t>(aformat); + } +}; + +inline bool operator==(mkldnn_data_type_t a, memory::data_type b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) { + return !(a == b); +} +inline bool operator==(memory::data_type a, mkldnn_data_type_t b) { + return b == a; +} +inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) { + return !(a == b); +} + +inline bool operator==(mkldnn_format_tag_t a, memory::format_tag b) { + return a == memory::convert_to_c(b); +} +inline bool operator!=(mkldnn_format_tag_t a, memory::format_tag b) { + return !(a == b); +} +inline bool operator==(memory::format_tag a, mkldnn_format_tag_t b) { + return b == a; +} +inline bool operator!=(memory::format_tag a, mkldnn_format_tag_t b) { + return !(a == b); +} + +/// @} + +/// @addtogroup cpp_api_reorder Reorder +/// A primitive to copy data between memory formats. +/// +/// @sa @ref c_api_reorder in @ref c_api +/// @{ + +struct reorder : public primitive { + struct primitive_desc : public handle<mkldnn_primitive_desc_t> { + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md, + const primitive_attr &aattr) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src_engine.get(), &src_md.data, + dst_engine.get(), &dst_md.data, aattr.get()), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const engine &src_engine, const memory::desc &src_md, + const engine &dst_engine, const memory::desc &dst_md) { + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src_engine.get(), &src_md.data, + dst_engine.get(), &dst_md.data, nullptr), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const memory &src, const memory &dst, + const primitive_attr &aattr) { + mkldnn_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src.get_engine().get(), &src_md.data, + dst.get_engine().get(), &dst_md.data, aattr.get()), + "could not create a reorder primitive descriptor"); + reset(result); + } + + primitive_desc(const memory &src, const memory &dst) { + mkldnn_primitive_desc_t result; + auto src_md = src.get_desc(); + auto dst_md = dst.get_desc(); + error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result, + src.get_engine().get(), &src_md.data, + dst.get_engine().get(), &dst_md.data, nullptr), + "could not create a reorder primitive descriptor"); + reset(result); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine scratchpad_engine() { + mkldnn_engine_t engine_q; + error::wrap_c_api( + mkldnn_primitive_desc_query(get(), + mkldnn::convert_to_c(query_scratchpad_engine), 0, &engine_q), + "could not get scratchpad engine from reorder primitive_desc"); + + return engine(engine_q); + } + + engine get_engine() { return engine::query(*this); } + }; + + reorder(const primitive_desc &pd): primitive(pd.get()) {} + + reorder(const memory &src, const memory &dst): + primitive(primitive_desc(src, dst).get()) {} + + void execute(stream astream, memory &src, memory &dst) { + primitive::execute(astream, + {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}}); + } +}; + +/// @} + +/// @addtogroup cpp_api_concat Concat +/// A primitive to concatenate data by arbitrary dimension. +/// +/// @sa @ref c_api_concat in @ref c_api +/// @{ + +struct concat : public primitive { + struct primitive_desc : public handle<mkldnn_primitive_desc_t> { + std::vector<mkldnn_memory_desc_t> cpp_to_c( + const std::vector<memory::desc> &srcs) { + std::vector<mkldnn_memory_desc_t> c_api_srcs; + c_api_srcs.reserve(srcs.size()); + for (const auto &s : srcs) c_api_srcs.push_back(s.data); + return c_api_srcs; + } + + primitive_desc(const memory::desc &dst, int concat_dimension, + const std::vector<memory::desc> &srcs, const engine &aengine) { + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_concat_primitive_desc_create( + &result, &dst.data, (int)c_api_srcs.size(), + concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), + "could not create a concat primitive descriptor"); + reset(result); + } + + primitive_desc(int concat_dimension, + const std::vector<memory::desc> &srcs, const engine &aengine) { + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_concat_primitive_desc_create( + &result, nullptr, (int)c_api_srcs.size(), + concat_dimension, &c_api_srcs[0], nullptr, aengine.get()), + "could not create a concat primitive descriptor"); + reset(result); + } + + memory::desc dst_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(dst_md), 0); + error::wrap_c_api( + cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, + "could not get a dst memory descriptor"); + return memory::desc(*cdesc); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine get_engine() { return engine::query(*this); } + }; + + concat(const primitive_desc &pd): primitive(pd.get()) {} +}; + +/// @} + +/// @addtogroup cpp_api_sum Sum +/// A primitive to sum data. +/// +/// @sa @ref c_api_sum in @ref c_api +/// @{ + +struct sum : public primitive { + struct primitive_desc : public handle<mkldnn_primitive_desc_t> { + std::vector<mkldnn_memory_desc_t> cpp_to_c( + const std::vector<memory::desc> &srcs) { + std::vector<mkldnn_memory_desc_t> c_api_srcs; + c_api_srcs.reserve(srcs.size()); + for (const auto &s : srcs) c_api_srcs.push_back(s.data); + return c_api_srcs; + } + + primitive_desc(const memory::desc &dst, + const std::vector<float> &scales, + const std::vector<memory::desc> &srcs, const engine &aengine) { + error::wrap_c_api(scales.size() == srcs.size() + ? mkldnn_success : mkldnn_invalid_arguments, + "number of scales not equal to number of srcs"); + + auto c_api_srcs = cpp_to_c(srcs); + + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_sum_primitive_desc_create( + &result, &dst.data, (int)c_api_srcs.size(), + &scales[0], &c_api_srcs[0], nullptr, aengine.get()), + "could not create a sum primitive descriptor"); + reset(result); + } + + primitive_desc(const std::vector<float> &scales, + const std::vector<memory::desc> &srcs, const engine &aengine) { + error::wrap_c_api(scales.size() == srcs.size() + ? mkldnn_success : mkldnn_invalid_arguments, + "number of scales not equal to number of srcs"); + + auto c_api_srcs = cpp_to_c(srcs); + mkldnn_primitive_desc_t result; + error::wrap_c_api(mkldnn_sum_primitive_desc_create(&result, + nullptr, (int)c_api_srcs.size(), &scales[0], + &c_api_srcs[0], nullptr, aengine.get()), + "could not create a sum primitive descriptor"); + reset(result); + } + + memory::desc dst_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(dst_md), 0); + error::wrap_c_api( + cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success, + "could not get a dst memory descriptor"); + return memory::desc(*cdesc); + } + + memory::desc scratchpad_desc() const { + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(scratchpad_md), 0); + if (cdesc == nullptr) + return memory::desc(); + return memory::desc(*cdesc); + } + + engine get_engine() { return engine::query(*this); } + }; + + sum(const primitive_desc &pd): primitive(pd.get()) {} +}; + +/// @} + +/// @} + +/// @addtogroup cpp_api_primitives Primitives +/// @{ + +/// @addtogroup cpp_api_primitive_descriptors Primitive descriptors +/// @{ + +/// A base class for all primitive descriptors. +struct primitive_desc : public handle<mkldnn_primitive_desc_t> { + primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, + const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) { + mkldnn_primitive_desc_iterator_t iterator = nullptr; + mkldnn_status_t status = mkldnn_primitive_desc_iterator_create( + &iterator, desc, attr ? attr->get() : nullptr, e.get(), + hint_fwd_pd); + error::wrap_c_api(status, + "could not create a primitive descriptor iterator"); + pd_iterator.reset(iterator); + fetch_impl(); + } + + engine get_engine() { return engine::query(*this); } + + primitive_attr get_primitive_attr() const { + const_mkldnn_primitive_attr_t const_cattr; + error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr), + "could not get attributes"); + mkldnn_primitive_attr_t cattr; + error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr), + "could not clone attributes"); + + primitive_attr attr; + attr.reset(cattr); + return attr; + } + + /// Returns implementation name + const char *impl_info_str() const { + const char *res; + error::wrap_c_api(mkldnn_primitive_desc_query(get(), + mkldnn_query_impl_info_str, 0, &res), + "could not query implementation info string"); + return res; + } + + /// Queries the memory::dim value (same as int64_t) + memory::dim query_s64(query q) const { + memory::dim res; + mkldnn_status_t status = mkldnn_primitive_desc_query(get(), + mkldnn::convert_to_c(q), 0, &res); + return status == mkldnn_success ? res : 0; + } + + /// Advances the next implementation for the given op descriptor. + /// + /// Returns: + /// - @c true on success + /// - @c false if the last implementation reached, and + /// the primitive descriptor itself is kept unchanged + bool next_impl() { + mkldnn_status_t status = mkldnn_primitive_desc_iterator_next( + pd_iterator.get()); + if (status == mkldnn_iterator_ends) return false; + error::wrap_c_api(status, "primitive descriptor iterator next failed"); + + fetch_impl(); + return true; + } + + /// Queries and returns requested memory descriptor. + memory::desc query_md(query what, int idx = 0) const { + std::vector<query> valid_q{src_md, diff_src_md, weights_md, + diff_weights_md, dst_md, diff_dst_md, workspace_md, scratchpad_md}; + if (!std::any_of(valid_q.cbegin(), valid_q.cend(), + [=](query q) { return what == q; })) + throw error(mkldnn_invalid_arguments, "invalid memory query"); + + const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md( + get(), mkldnn::convert_to_c(what), idx); + if (cdesc == nullptr) return memory::desc(); + + return memory::desc(*cdesc); + } + + // register specialized queries, e.g. src_desc() +# define REG_QUERY_MD(name, what, idx) \ + memory::desc name ## _desc() const { return query_md(what ## _md, idx); } + + private: + handle<mkldnn_primitive_desc_iterator_t> pd_iterator; + void fetch_impl() { + mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch( + pd_iterator.get()); + error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error, + "could not fetch a primitive descriptor from the iterator"); + reset(pd); + } +}; + +/// @} + +/// @addtogroup cpp_api_convolution Convolution +/// A primitive to compute convolution using different algorithms. +/// +/// @sa @ref c_api_convolution in @ref c_api +/// @{ + +struct convolution_forward: public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &dilates[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &dilates[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated convolution forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct convolution_backward_data : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api( + mkldnn_dilated_convolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct convolution_backward_weights : public primitive { + struct desc { + mkldnn_convolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a convolution backward weights descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const convolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + convolution_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} +// +/// @addtogroup cpp_api_deconvolution Deconvolution +/// A primitive to compute deconvolution using different algorithms. +/// +/// @sa @ref c_api_deconvolution in @ref c_api +/// @{ + +struct deconvolution_forward: public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, &bias_desc.data, + &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], + &padding_r[0], mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution forward descriptor"); + } + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, &weights_desc.data, nullptr, + &dst_desc.data, &strides[0], &dilates[0], &padding_l[0], + &padding_r[0], mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct deconvolution_backward_data : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward data descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init( + &data, convert_to_c(aalgorithm), &diff_src_desc.data, + &weights_desc.data, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct deconvolution_backward_weights : public primitive { + struct desc { + mkldnn_deconvolution_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, &diff_bias_desc.data, + &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward weights descriptor"); + } + desc(algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc, + const memory::dims strides, + const memory::dims dilates, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(dilates); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init( + &data, convert_to_c(aalgorithm), &src_desc.data, + &diff_weights_desc.data, nullptr, &diff_dst_desc.data, + &strides[0], &dilates[0], &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not create a dilated deconvolution backward weights descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const deconvolution_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + deconvolution_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_lrn LRN +/// A primitive to perform local response normalization (LRN) across or within +/// channels. +/// +/// @sa @ref c_api_lrn in @ref c_api +/// @{ + +struct lrn_forward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, memory::dim local_size, + float alpha, float beta, float k = 1.f) { + error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm), + &src_desc.data, local_size, alpha, beta, k), + "could not create a lrn forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + lrn_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct lrn_backward : public primitive { + struct desc { + mkldnn_lrn_desc_t data; + + desc(algorithm aalgorithm, const memory::desc &data_desc, + const memory::desc &diff_data_desc, memory::dim local_size, + float alpha, float beta, float k = 1.f) { + error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data, + convert_to_c(aalgorithm), &diff_data_desc.data, + &data_desc.data, local_size, alpha, beta, k), + "could not create a lrn backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const lrn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const lrn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + lrn_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_pooling Pooling +/// A primitive to perform max or average pooling. +/// +/// @sa @ref c_api_pooling in @ref c_api +/// @{ + +struct pooling_forward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(prop_kind aprop_kind, algorithm aalgorithm, + const memory::desc &src_desc, + const memory::desc &dst_desc, + const memory::dims strides, + const memory::dims kernel, + const memory::dims padding_l, + const memory::dims padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + convert_to_c(aalgorithm), + &src_desc.data, &dst_desc.data, + &strides[0], &kernel[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a forward pooling descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + pooling_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct pooling_backward : public primitive { + struct desc { + mkldnn_pooling_desc_t data; + desc(algorithm aalgorithm, + const memory::desc &diff_src_desc, + const memory::desc &diff_dst_desc, + const memory::dims &strides, + const memory::dims &kernel, + const memory::dims &padding_l, + const memory::dims &padding_r, + const padding_kind apadding_kind) { + memory::validate_dims(strides); + memory::validate_dims(kernel); + memory::validate_dims(padding_l); + memory::validate_dims(padding_r); + error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data, + convert_to_c(aalgorithm), + &diff_src_desc.data, &diff_dst_desc.data, + &strides[0], &kernel[0], + &padding_l[0], &padding_r[0], + mkldnn::convert_to_c(apadding_kind)), + "could not init a backward pooling descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const pooling_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const pooling_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + pooling_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_eltwise Eltwise +/// A primitive to compute element-wise operations like parametric rectifier +/// linear unit (ReLU). +/// +/// @sa @ref c_api_eltwise in @ref c_api +/// @{ + +struct eltwise_forward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + template <typename T> + desc(prop_kind aprop_kind, algorithm alg_kind, + const memory::desc &src_desc, T alpha = 0, T beta = 0) { + error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + mkldnn::convert_to_c(alg_kind), &src_desc.data, + static_cast<float>(alpha), static_cast<float>(beta)), + "could not create a eltwise forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + eltwise_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct eltwise_backward : public primitive { + struct desc { + mkldnn_eltwise_desc_t data; + + template <typename T> + desc(algorithm alg_kind, const memory::desc &diff_data_desc, + const memory::desc &data_desc, T alpha = 0, T beta = 0) { + error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data, + mkldnn::convert_to_c(alg_kind), &diff_data_desc.data, + &data_desc.data, static_cast<float>(alpha), + static_cast<float>(beta)), + "could not create a eltwise backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const eltwise_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const eltwise_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + eltwise_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_softmax Softmax +/// A primitive to perform softmax. +/// +/// @sa @ref c_api_softmax in @ref c_api +/// @{ + +struct softmax_forward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &data_desc.data, + softmax_axis), + "could not create a softmax forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + softmax_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct softmax_backward : public primitive { + struct desc { + mkldnn_softmax_desc_t data; + desc(const memory::desc &diff_desc, const memory::desc &data_desc, + int softmax_axis) { + error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data, + &diff_desc.data, &data_desc.data, softmax_axis), + "could not init a backward softmax descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const softmax_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const softmax_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + softmax_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_batch_norm Batch normalization +/// A primitive to perform batch normalization. +/// +/// @sa @ref c_api_batch_normalization in @ref c_api +/// @{ + +struct batch_normalization_forward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template <typename T> + desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, + unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + static_cast<float>(epsilon), flags), + "could not create a batch normalization forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + + memory::desc mean_desc() const { return stat_desc(mean); } + memory::desc variance_desc() const { return stat_desc(var); } + + private: + enum { mean = 1, var = 2, }; + memory::desc stat_desc(int kind) const { + mkldnn_batch_normalization_desc_t *p; + error::wrap_c_api(mkldnn_primitive_desc_query( + get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p), + "could not get a batch-normalization descriptor"); + return query_md(p->flags & use_global_stats ? src_md : dst_md, kind); + } + }; + + batch_normalization_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct batch_normalization_backward : public primitive { + struct desc { + mkldnn_batch_normalization_desc_t data; + template <typename T> + desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, + const memory::desc &data_desc, T epsilon, unsigned flags) { + error::wrap_c_api( + mkldnn_batch_normalization_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), + &diff_data_desc.data, &data_desc.data, + static_cast<float>(epsilon), flags), + "could not create a batch normalization backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const batch_normalization_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const batch_normalization_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(mean, src, 1); + REG_QUERY_MD(variance, src, 2); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(workspace, workspace, 0); + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + batch_normalization_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_inner_product Inner Product +/// A primitive to compute an inner product. +/// +/// @sa @ref c_api_inner_product in @ref c_api +/// @{ + +struct inner_product_forward: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &bias_desc, + const memory::desc &dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + &weights_desc.data, &bias_desc.data, &dst_desc.data), + "could not create a inner product forward descriptor"); + } + + desc(prop_kind aprop_kind, const memory::desc &src_desc, + const memory::desc &weights_desc, + const memory::desc &dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &src_desc.data, + &weights_desc.data, nullptr, &dst_desc.data), + "could not create a inner product forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(bias, weights, 1); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct inner_product_backward_data: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &diff_src_desc, + const memory::desc &weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_data_desc_init(&data, + &diff_src_desc.data, &weights_desc.data, + &diff_dst_desc.data), + "could not create a inner product backward data descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(weights, weights, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_backward_data(const primitive_desc &pd): primitive(pd) {} +}; + +struct inner_product_backward_weights: public primitive { + struct desc { + mkldnn_inner_product_desc_t data; + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, &src_desc.data, &diff_weights_desc.data, + &diff_bias_desc.data, &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + desc(const memory::desc &src_desc, + const memory::desc &diff_weights_desc, + const memory::desc &diff_dst_desc) { + error::wrap_c_api( + mkldnn_inner_product_backward_weights_desc_init( + &data, &src_desc.data, &diff_weights_desc.data, + nullptr, &diff_dst_desc.data), + "could not create a inner product backward weights descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const inner_product_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(diff_weights, diff_weights, 0); + REG_QUERY_MD(diff_bias, diff_weights, 1); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + inner_product_backward_weights(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_rnn RNN +/// A primitive to compute common recurrent layer. +/// +/// @sa @ref c_api_rnn in @ref c_api +/// @{ + +struct rnn_cell { + struct desc { + mkldnn_rnn_cell_desc_t c_rnn_cell_; + + desc(algorithm kind, algorithm activation_f) { + error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_, + mkldnn::convert_to_c(kind), + mkldnn::convert_to_c(activation_f), 0U, 0, 0), + "could not init an rnn cell descriptor"); + } + desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {} + + operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; } + + algorithm get_cell_kind() const + { return algorithm(c_rnn_cell_.cell_kind); } + algorithm get_activation() const + { return algorithm(c_rnn_cell_.activation_kind); } + + float get_alpha() const { return c_rnn_cell_.alpha; } + void set_alpha(float alpha) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu; + c_rnn_cell_.alpha = alpha; + } + + float get_clipping() const { return c_rnn_cell_.clipping; } + void set_clipping(float clipping) { + c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping; + c_rnn_cell_.clipping = clipping; + } + + int get_gates_count() const { + return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_); + } + int get_state_count() const { + return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_); + } + }; +}; + +struct rnn_forward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc + ) { + error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, &src_iter_desc.data, + &weights_layer_desc.data, &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, &dst_iter_desc.data), + "could not create an RNN forward descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e) + : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {} + + REG_QUERY_MD(src_layer, src, 0); + REG_QUERY_MD(src_iter, src, 1); + REG_QUERY_MD(weights_layer, weights, 0); + REG_QUERY_MD(weights_iter, weights, 1); + REG_QUERY_MD(bias, weights, 2); + REG_QUERY_MD(dst_layer, dst, 0); + REG_QUERY_MD(dst_iter, dst, 1); + REG_QUERY_MD(workspace, workspace, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + rnn_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct rnn_backward : public primitive { + struct desc { + mkldnn_rnn_desc_t data; + desc(prop_kind aprop_kind, rnn_cell::desc cell, + const rnn_direction direction, + const memory::desc &src_layer_desc, + const memory::desc &src_iter_desc, + const memory::desc &weights_layer_desc, + const memory::desc &weights_iter_desc, + const memory::desc &bias_desc, + const memory::desc &dst_layer_desc, + const memory::desc &dst_iter_desc, + const memory::desc &diff_src_layer_desc, + const memory::desc &diff_src_iter_desc, + const memory::desc &diff_weights_layer_desc, + const memory::desc &diff_weights_iter_desc, + const memory::desc &diff_bias_desc, + const memory::desc &diff_dst_layer_desc, + const memory::desc &diff_dst_iter_desc) { + error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), cell, + mkldnn::convert_to_c(direction), + &src_layer_desc.data, &src_iter_desc.data, + &weights_layer_desc.data, &weights_iter_desc.data, + &bias_desc.data, + &dst_layer_desc.data, &dst_iter_desc.data, + &diff_src_layer_desc.data, &diff_src_iter_desc.data, + &diff_weights_layer_desc.data, + &diff_weights_iter_desc.data, &diff_bias_desc.data, + &diff_dst_layer_desc.data, &diff_dst_iter_desc.data), + "could not create an RNN backward descriptor"); + } + + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const rnn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, + const rnn_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(src_layer, src, 0); + REG_QUERY_MD(src_iter, src, 1); + REG_QUERY_MD(weights_layer, weights, 0); + REG_QUERY_MD(weights_iter, weights, 1); + REG_QUERY_MD(bias, weights, 2); + REG_QUERY_MD(dst_layer, dst, 0); + REG_QUERY_MD(dst_iter, dst, 1); + REG_QUERY_MD(workspace, workspace, 0); + + REG_QUERY_MD(diff_src_layer, diff_src, 0); + REG_QUERY_MD(diff_src_iter, diff_src, 1); + REG_QUERY_MD(diff_weights_layer, diff_weights, 0); + REG_QUERY_MD(diff_weights_iter, diff_weights, 1); + REG_QUERY_MD(diff_bias, diff_weights, 2); + REG_QUERY_MD(diff_dst_layer, diff_dst, 0); + REG_QUERY_MD(diff_dst_iter, diff_dst, 1); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + // With last iteration (with and without input src_iter) + rnn_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @addtogroup cpp_api_shuffle Shuffle +/// A primitive to shuffle data along the axis. +/// +/// @sa @ref c_api_shuffle in @ref c_api +/// @{ + +struct shuffle_forward : public primitive { + struct desc { + mkldnn_shuffle_desc_t data; + desc(prop_kind aprop_kind, const memory::desc &data_desc, + int axis, int group_size) { + error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data, + mkldnn::convert_to_c(aprop_kind), &data_desc.data, + axis, group_size), + "could not create a shuffle forward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e) + : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {} + + REG_QUERY_MD(src, src, 0); + REG_QUERY_MD(dst, dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + shuffle_forward(const primitive_desc &pd): primitive(pd) {} +}; + +struct shuffle_backward : public primitive { + struct desc { + mkldnn_shuffle_desc_t data; + desc(const memory::desc &diff_data_desc, int axis, int group_size) { + error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data, + &diff_data_desc.data, axis, group_size), + "could not create a shuffle backward descriptor"); + } + }; + + struct primitive_desc : public mkldnn::primitive_desc { + primitive_desc(const desc &desc, const engine &e, + const shuffle_forward::primitive_desc &hint_fwd_pd) + : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {} + + REG_QUERY_MD(diff_src, diff_src, 0); + REG_QUERY_MD(diff_dst, diff_dst, 0); + REG_QUERY_MD(scratchpad, scratchpad, 0); + }; + + shuffle_backward(const primitive_desc &pd): primitive(pd) {} +}; + +/// @} + +/// @} Primitives + +/// @} C++ API + +#undef REG_QUERY_MD + +// implementation section +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +inline primitive::primitive(const_mkldnn_primitive_desc_t c_pd) { + mkldnn_primitive_t result; + error::wrap_c_api(mkldnn_primitive_create(&result, c_pd), + "could not create a primitive"); + reset(result); +} + +inline primitive::primitive(const primitive_desc &pd): primitive(pd.get()) {} + +inline void primitive::execute(stream &astream, + const std::unordered_map<int, memory> &args) const { + std::vector<mkldnn_exec_arg_t> c_args; + c_args.reserve(args.size()); + for (const auto &a: args) + c_args.push_back({a.first, a.second.get()}); + + error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get(), + (int)c_args.size(), c_args.data()), + "primitive execution fail"); +} +#endif // DOXYGEN_SHOULD_SKIP_THIS + +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h new file mode 100644 index 0000000000..f4dc2fdfa6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h @@ -0,0 +1,98 @@ +/******************************************************************************* +* Copyright 2018-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* DO NOT EDIT, AUTO-GENERATED */ + +#ifndef MKLDNN_DEBUG_H +#define MKLDNN_DEBUG_H + +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +/* All symbols shall be internal unless marked as MKLDNN_API */ +#if defined _WIN32 || defined __CYGWIN__ +# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport) +# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport) +#else +# if __GNUC__ >= 4 +# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default"))) +# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default"))) +# else +# define MKLDNN_HELPER_DLL_IMPORT +# define MKLDNN_HELPER_DLL_EXPORT +# endif +#endif + +#ifdef MKLDNN_DLL +# ifdef MKLDNN_DLL_EXPORTS +# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT +# else +# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT +# endif +#else +# define MKLDNN_API +#endif + +#if defined (__GNUC__) +# define MKLDNN_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +# define MKLDNN_DEPRECATED __declspec(deprecated) +#else +# define MKLDNN_DEPRECATED +#endif + +#include "mkldnn_types.h" +#endif /* DOXYGEN_SHOULD_SKIP_THIS */ + +#ifdef __cplusplus +extern "C" { +#endif + +const char MKLDNN_API *mkldnn_status2str(mkldnn_status_t v); +const char MKLDNN_API *mkldnn_dt2str(mkldnn_data_type_t v); +const char MKLDNN_API *mkldnn_fmt_kind2str(mkldnn_format_kind_t v); +const char MKLDNN_API *mkldnn_fmt_tag2str(mkldnn_format_tag_t v); +const char MKLDNN_API *mkldnn_prop_kind2str(mkldnn_prop_kind_t v); +const char MKLDNN_API *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v); +const char MKLDNN_API *mkldnn_alg_kind2str(mkldnn_alg_kind_t v); +const char MKLDNN_API *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v); + +/** Forms a format string for a given memory descriptor. + * + * The format is defined as: 'dt:[p|o|0]:fmt_kind:fmt:extra'. + * Here: + * - dt -- data type + * - p -- indicates there is non-trivial padding + * - o -- indicates there is non-trivial padding offset + * - 0 -- indicates there is non-trivial offset0 + * - fmt_kind -- format kind (blocked, wino, etc...) + * - fmt -- extended format string (format_kind specific) + * - extra -- shows extra fields (underspecified) + */ +int MKLDNN_API mkldnn_md2fmt_str(char *fmt_str, size_t fmt_str_len, + const mkldnn_memory_desc_t *md); + +/** Forms a dimension string for a given memory descriptor. + * + * The format is defined as: 'dim0xdim1x...xdimN + */ +int MKLDNN_API mkldnn_md2dim_str(char *dim_str, size_t dim_str_len, + const mkldnn_memory_desc_t *md); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h new file mode 100644 index 0000000000..1b6c356982 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h @@ -0,0 +1,1415 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_TYPES_H +#define MKLDNN_TYPES_H + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +#include <stddef.h> +#include <stdint.h> +#endif + +/** @addtogroup c_api C API + * @{ + * + * @addtogroup c_api_types Types + * @{ + * + * @addtogroup c_api_types_generic Generic + * @{ */ + +/** Intel(R) MKL-DNN Version type */ +typedef struct { + int major; + int minor; + int patch; + const char *hash; +} mkldnn_version_t; + +/** Status values returned by Intel(R) MKL-DNN functions. */ +typedef enum { + /** The operation was successful */ + mkldnn_success = 0, + /** The operation failed due to an out-of-memory condition */ + mkldnn_out_of_memory = 1, + /** The operation failed and should be retried */ + mkldnn_try_again = 2, + /** The operation failed because of incorrect function arguments */ + mkldnn_invalid_arguments = 3, + /** The operation failed because a primitive was not ready for execution */ + mkldnn_not_ready = 4, + /** The operation failed because requested functionality is not implemented + */ + mkldnn_unimplemented = 5, + /** Primitive iterator passed over last primitive descriptor */ + mkldnn_iterator_ends = 6, + /** Primitive or engine failed on execution */ + mkldnn_runtime_error = 7, + /** Queried element is not required for given primitive */ + mkldnn_not_required = 8, +} mkldnn_status_t; + +/** Data type specification */ +typedef enum { + /** Undefined data type, used for empty memory descriptors. */ + mkldnn_data_type_undef = 0, + /** 32-bit/single-precision floating point. */ + mkldnn_f32 = 1, + /** 32-bit signed integer. */ + mkldnn_s32 = 2, + /** 8-bit signed integer. */ + mkldnn_s8 = 3, + /** 8-bit unsigned integer. */ + mkldnn_u8 = 4, +} mkldnn_data_type_t; + +/** Memory format kind */ +typedef enum { + /** Undefined memory format, used for empty memory descriptors. */ + mkldnn_format_kind_undef = 0, + /** Unspecified format. The primitive selects a format automatically. */ + mkldnn_format_kind_any, + /** A tensor in a generic format described by the stride and blocking + * values in each dimension. See #mkldnn_blocking_desc_t for more + * information. */ + mkldnn_blocked, + /** Weights format used in 8bit Winograd convolution */ + mkldnn_format_kind_wino, + /** Packed weights format used in RNN */ + mkldnn_format_kind_rnn_packed, +} mkldnn_format_kind_t; + +/** Memory format tag specification. + * + * Intel MKL-DNN formats describe physical data layout. The physical layout + * is described as a sequence of the dimensions as they are laid out in the + * memory (from the outer-most to the inner-most). Note that this order + * doesn't affect the logical order of the dimensions that is kept in the + * `dims` field of the mkldnn_memory_desc_t structure. The logical order of the + * dimensions is specified by the type of tensor. + * + * For example, CNN 5D tensor always has its logical dimensions in the order + * `(batch, channels, depth, height, width)`, while the physical layout might be + * #mkldnn_ncdhw or #mkldnn_ndhwc: + * + * ~~~cpp + * int batch = 2, channels = 16, depth = 13, height = 13, width = 13; + * + * int ndims = 5; // 5D tensor + * mkldnn_dims_t dims = {batch, channels, depth, height, width}; + * mkldnn_memory_desc_t data_in_ncdhw; + * mkldnn_memory_desc_init_by_tag( + * &data_in_ncdhw, 5, dims, mkldnn_f32, mkldnn_ncdhw); + * + * // note that in both cases dims passed are the same + * mkldnn_memory_desc_t data_in_ndhwc; + * mkldnn_memory_desc_init_by_tag( + * &data_in_ndhwc, 5, dims, mkldnn_f32, mkldnn_ndhwc); + * ~~~ + * + * The following notation applies to memory format names: + * - @c 'n' denotes the mini-batch dimension + * - @c 'c' denotes a channels dimension + * - When there are multiple channel dimensions (for example, in convolution + * weights tensor), @c 'i' and @c 'o' denote dimensions of input and output + * channels + * - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width + * respectively + * - Upper-case letters indicate that the data is laid out in blocks + * for a particular dimension. In such cases, the format name contains both + * upper- and lower-case letters for that dimension with a lower-case letter + * preceded by the block size. For example: @c 'mkldnn_nChw8c' describes a + * format where the outermost dimension is mini-batch, followed by the + * channel block number, followed by the spatial height and width, and + * finally followed by 8-element channel blocks. + * + * @note + * Channel designations can be different. For example, both the @c + * 'mkldnn_nc' and @c 'mkldnn_io' formats can be used to describe a 2D + * tensor. + * + * @sa @ref understanding_memory_formats + */ +typedef enum { + /** Undefined memory format tag */ + mkldnn_format_tag_undef = 0, + /** Undefined memory format tag. + * The primitive selects a format automatically. */ + mkldnn_format_tag_any, + + /* Semantic agnostic section */ + /* The physical order of dimensions is defined by the permutation of the + * characters, assuming that ab..z defines the natural order. + */ + + /* Plain formats */ + + mkldnn_a, + mkldnn_ab, + mkldnn_abc, + mkldnn_abcd, + mkldnn_abcde, + mkldnn_abcdef, + mkldnn_abdec, + mkldnn_acb, + mkldnn_acbde, + mkldnn_acdb, + mkldnn_acdeb, + mkldnn_ba, + mkldnn_bac, + mkldnn_bacd, + mkldnn_bcda, + mkldnn_cba, + mkldnn_cdba, + mkldnn_cdeba, + mkldnn_decab, + + /* Opaque blocked formats */ + + mkldnn_Abc16a, + mkldnn_ABc16a16b, + mkldnn_aBc16b, + mkldnn_ABc16b16a, + mkldnn_Abc4a, + mkldnn_aBc4b, + mkldnn_ABc4b16a4b, + mkldnn_ABc4b4a, + mkldnn_ABc8a16b2a, + mkldnn_ABc8a8b, + mkldnn_aBc8b, + mkldnn_ABc8b16a2b, + mkldnn_ABc8b8a, + mkldnn_Abcd16a, + mkldnn_ABcd16a16b, + mkldnn_aBcd16b, + mkldnn_ABcd16b16a, + mkldnn_aBCd16b16c, + mkldnn_aBCd16c16b, + mkldnn_Abcd4a, + mkldnn_aBcd4b, + mkldnn_ABcd4b16a4b, + mkldnn_ABcd4b4a, + mkldnn_aBCd4c16b4c, + mkldnn_aBCd4c4b, + mkldnn_ABcd8a16b2a, + mkldnn_ABcd8a8b, + mkldnn_aBcd8b, + mkldnn_ABcd8b16a2b, + mkldnn_aBCd8b16c2b, + mkldnn_ABcd8b8a, + mkldnn_aBCd8b8c, + mkldnn_aBCd8c16b2c, + mkldnn_aBCd8c8b, + mkldnn_Abcde16a, + mkldnn_ABcde16a16b, + mkldnn_aBcde16b, + mkldnn_ABcde16b16a, + mkldnn_aBCde16b16c, + mkldnn_aBCde16c16b, + mkldnn_aBCde2c8b4c, + mkldnn_Abcde4a, + mkldnn_aBcde4b, + mkldnn_ABcde4b4a, + mkldnn_aBCde4b4c, + mkldnn_aBCde4c16b4c, + mkldnn_aBCde4c4b, + mkldnn_Abcde8a, + mkldnn_ABcde8a8b, + mkldnn_aBcde8b, + mkldnn_ABcde8b16a2b, + mkldnn_aBCde8b16c2b, + mkldnn_ABcde8b8a, + mkldnn_aBCde8b8c, + mkldnn_aBCde8c16b2c, + mkldnn_aBCde8c8b, + mkldnn_aBcdef16b, + mkldnn_aBCdef16b16c, + mkldnn_aBCdef16c16b, + mkldnn_aBcdef4b, + mkldnn_aBCdef4c4b, + mkldnn_aBCdef8b8c, + mkldnn_aBCdef8c16b2c, + mkldnn_aBCdef8c8b, + mkldnn_aBdc16b, + mkldnn_aBdc4b, + mkldnn_aBdc8b, + mkldnn_aBdec16b, + mkldnn_aBdec4b, + mkldnn_aBdec8b, + mkldnn_aBdefc16b, + mkldnn_aBdefc4b, + mkldnn_aBdefc8b, + mkldnn_Acb16a, + mkldnn_Acb4a, + mkldnn_Acb8a, + mkldnn_aCBd16b16c, + mkldnn_aCBde16b16c, + mkldnn_Acdb16a, + mkldnn_Acdb4a, + mkldnn_Acdb8a, + mkldnn_Acdeb16a, + mkldnn_Acdeb4a, + mkldnn_Acdeb8a, + mkldnn_BAc16a16b, + mkldnn_BAcd16a16b, + + /** Just a sentinel, not real memory format tag. Must be changed after new + * format tag is added. */ + mkldnn_format_tag_last, + + /* Aliases */ + + mkldnn_x = mkldnn_a, + mkldnn_nc = mkldnn_ab, + mkldnn_cn = mkldnn_ba, + mkldnn_ncw = mkldnn_abc, + mkldnn_nwc = mkldnn_acb, + mkldnn_nchw = mkldnn_abcd, + mkldnn_nhwc = mkldnn_acdb, + mkldnn_chwn = mkldnn_bcda, + mkldnn_ncdhw = mkldnn_abcde, + mkldnn_ndhwc = mkldnn_acdeb, + + mkldnn_oi = mkldnn_ab, + mkldnn_io = mkldnn_ba, + mkldnn_oiw = mkldnn_abc, + mkldnn_wio = mkldnn_cba, + mkldnn_oihw = mkldnn_abcd, + mkldnn_hwio = mkldnn_cdba, + mkldnn_ihwo = mkldnn_bcda, + mkldnn_iohw = mkldnn_bacd, + mkldnn_oidhw = mkldnn_abcde, + mkldnn_dhwio = mkldnn_cdeba, + mkldnn_goiw = mkldnn_abcd, + mkldnn_goihw = mkldnn_abcde, + mkldnn_hwigo = mkldnn_decab, + mkldnn_giohw = mkldnn_acbde, + mkldnn_goidhw = mkldnn_abcdef, + + /** 3D RNN data tensor in the format (seq_length, batch, input channels). */ + mkldnn_tnc = mkldnn_abc, + /** 3D RNN data tensor in the format (batch, seq_length, input channels). */ + mkldnn_ntc = mkldnn_bac, + /** 5D RNN states tensor in the format (num_layers, num_directions, + * num_states, batch, state channels). */ + mkldnn_ldsnc = mkldnn_abcde, + /** 5D RNN weights tensor in the format (num_layers, num_directions, + * input_channels, num_gates, output_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldigo = mkldnn_abcde, + /** 5D RNN weights tensor in the format (num_layers, num_directions, + * num_gates, output_channels, input_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldgoi = mkldnn_abdec, + /** 4D RNN bias tensor in the format (num_layers, num_directions, + * num_gates, output_channels). + * + * - For LSTM cells, the gates order is input, forget, candidate + * and output gate. + * - For GRU cells, the gates order is update, reset and output gate. */ + mkldnn_ldgo = mkldnn_abcd, + + /* Opaque data types, are not to be used explicitly */ + + /* data */ + mkldnn_nCdhw16c = mkldnn_aBcde16b, + mkldnn_nCdhw4c = mkldnn_aBcde4b, + mkldnn_nCdhw8c = mkldnn_aBcde8b, + mkldnn_nChw16c = mkldnn_aBcd16b, + mkldnn_nChw4c = mkldnn_aBcd4b, + mkldnn_nChw8c = mkldnn_aBcd8b, + mkldnn_nCw16c = mkldnn_aBc16b, + mkldnn_nCw4c = mkldnn_aBc4b, + mkldnn_nCw8c = mkldnn_aBc8b, + + /* weights, 3D */ + mkldnn_IOw16o16i = mkldnn_BAc16a16b, + mkldnn_OIw16i16o = mkldnn_ABc16b16a, + mkldnn_OIw16o16i = mkldnn_ABc16a16b, + mkldnn_Oiw16o = mkldnn_Abc16a, + mkldnn_OIw4i16o4i = mkldnn_ABc4b16a4b, + mkldnn_OIw4i4o = mkldnn_ABc4b4a, + mkldnn_Oiw4o = mkldnn_Abc4a, + mkldnn_OIw8i16o2i = mkldnn_ABc8b16a2b, + mkldnn_OIw8i8o = mkldnn_ABc8b8a, + mkldnn_OIw8o16i2o = mkldnn_ABc8a16b2a, + mkldnn_OIw8o8i = mkldnn_ABc8a8b, + mkldnn_Owi16o = mkldnn_Acb16a, + mkldnn_Owi4o = mkldnn_Acb4a, + mkldnn_Owi8o = mkldnn_Acb8a, + + /* weights, 4D */ + mkldnn_IOhw16o16i = mkldnn_BAcd16a16b, + mkldnn_Ohwi16o = mkldnn_Acdb16a, + mkldnn_Ohwi4o = mkldnn_Acdb4a, + mkldnn_Ohwi8o = mkldnn_Acdb8a, + mkldnn_OIhw16i16o = mkldnn_ABcd16b16a, + mkldnn_OIhw16o16i = mkldnn_ABcd16a16b, + mkldnn_Oihw16o = mkldnn_Abcd16a, + mkldnn_OIhw4i16o4i = mkldnn_ABcd4b16a4b, + mkldnn_OIhw4i4o = mkldnn_ABcd4b4a, + mkldnn_Oihw4o = mkldnn_Abcd4a, + mkldnn_OIhw8i16o2i = mkldnn_ABcd8b16a2b, + mkldnn_OIhw8i8o = mkldnn_ABcd8b8a, + mkldnn_OIhw8o16i2o = mkldnn_ABcd8a16b2a, + mkldnn_OIhw8o8i = mkldnn_ABcd8a8b, + + /* weights, 5D */ + mkldnn_Odhwi16o = mkldnn_Acdeb16a, + mkldnn_Odhwi4o = mkldnn_Acdeb4a, + mkldnn_Odhwi8o = mkldnn_Acdeb8a, + mkldnn_OIdhw16i16o = mkldnn_ABcde16b16a, + mkldnn_OIdhw16o16i = mkldnn_ABcde16a16b, + mkldnn_Oidhw16o = mkldnn_Abcde16a, + mkldnn_OIdhw4i4o = mkldnn_ABcde4b4a, + mkldnn_Oidhw4o = mkldnn_Abcde4a, + mkldnn_OIdhw8i16o2i = mkldnn_ABcde8b16a2b, + mkldnn_OIdhw8i8o = mkldnn_ABcde8b8a, + mkldnn_OIdhw8o8i = mkldnn_ABcde8a8b, + + /* weights w/ groups, 3D */ + mkldnn_Goiw16g = mkldnn_Abcd16a, + mkldnn_gIOw16o16i = mkldnn_aCBd16b16c, + mkldnn_gOIw16i16o = mkldnn_aBCd16c16b, + mkldnn_gOIw16o16i = mkldnn_aBCd16b16c, + mkldnn_gOiw16o = mkldnn_aBcd16b, + mkldnn_gOIw4i16o4i = mkldnn_aBCd4c16b4c, + mkldnn_gOIw4i4o = mkldnn_aBCd4c4b, + mkldnn_gOiw4o = mkldnn_aBcd4b, + mkldnn_gOIw8i16o2i = mkldnn_aBCd8c16b2c, + mkldnn_gOIw8i8o = mkldnn_aBCd8c8b, + mkldnn_gOIw8o16i2o = mkldnn_aBCd8b16c2b, + mkldnn_gOIw8o8i = mkldnn_aBCd8b8c, + mkldnn_gOwi16o = mkldnn_aBdc16b, + mkldnn_gOwi4o = mkldnn_aBdc4b, + mkldnn_gOwi8o = mkldnn_aBdc8b, + + /* weights w/ groups, 4D */ + mkldnn_gIOhw16o16i = mkldnn_aCBde16b16c, + mkldnn_gOhwi16o = mkldnn_aBdec16b, + mkldnn_gOhwi4o = mkldnn_aBdec4b, + mkldnn_gOhwi8o = mkldnn_aBdec8b, + mkldnn_Goihw16g = mkldnn_Abcde16a, + mkldnn_gOIhw16i16o = mkldnn_aBCde16c16b, + mkldnn_gOIhw16o16i = mkldnn_aBCde16b16c, + mkldnn_gOihw16o = mkldnn_aBcde16b, + mkldnn_gOIhw2i8o4i = mkldnn_aBCde2c8b4c, + mkldnn_gOIhw4i16o4i = mkldnn_aBCde4c16b4c, + mkldnn_gOIhw4i4o = mkldnn_aBCde4c4b, + mkldnn_gOIhw4o4i = mkldnn_aBCde4b4c, + mkldnn_gOihw4o = mkldnn_aBcde4b, + mkldnn_Goihw8g = mkldnn_Abcde8a, + mkldnn_gOIhw8i16o2i = mkldnn_aBCde8c16b2c, + mkldnn_gOIhw8i8o = mkldnn_aBCde8c8b, + mkldnn_gOIhw8o16i2o = mkldnn_aBCde8b16c2b, + mkldnn_gOIhw8o8i = mkldnn_aBCde8b8c, + + /* weights w/ groups, 6D */ + mkldnn_gOdhwi16o = mkldnn_aBdefc16b, + mkldnn_gOdhwi4o = mkldnn_aBdefc4b, + mkldnn_gOdhwi8o = mkldnn_aBdefc8b, + mkldnn_gOIdhw16i16o = mkldnn_aBCdef16c16b, + mkldnn_gOIdhw16o16i = mkldnn_aBCdef16b16c, + mkldnn_gOidhw16o = mkldnn_aBcdef16b, + mkldnn_gOIdhw4i4o = mkldnn_aBCdef4c4b, + mkldnn_gOidhw4o = mkldnn_aBcdef4b, + mkldnn_gOIdhw8i16o2i = mkldnn_aBCdef8c16b2c, + mkldnn_gOIdhw8i8o = mkldnn_aBCdef8c8b, + mkldnn_gOIdhw8o8i = mkldnn_aBCdef8b8c, +} mkldnn_format_tag_t; + +/** Kinds of padding. Define how to interpret the data in padding regions. */ +typedef enum { + /** The data in padding regions is zero. */ + mkldnn_padding_zero, +} mkldnn_padding_kind_t; + +/** Kinds of propagation. */ +typedef enum { + /* TODO: suggest renames */ + /** Undefined propagation type. */ + mkldnn_prop_kind_undef = 0, + /** Forward data propagation (training mode). In this mode primitives + * perform computations necessary for subsequent backward propagation. */ + mkldnn_forward_training = 64, + /** Forward data propagation (inference mode). In this mode primitives + * perform only computations that are necessary for inference and omit + * computations that are necessary only for backward propagation. */ + mkldnn_forward_inference = 96, + /** Forward data propagation (alias for @c mkldnn_forward_inference) */ + mkldnn_forward_scoring = mkldnn_forward_inference, + /** Forward data propagation (alias for @c mkldnn_forward_training) */ + mkldnn_forward = mkldnn_forward_training, + /** Backward propagation (with respect to all parameters */ + mkldnn_backward = 128, + /** Backward data propagation */ + mkldnn_backward_data = 160, + /** Backward weights propagation */ + mkldnn_backward_weights = 192, + /** Backward bias propagation */ + mkldnn_backward_bias = 193, +} mkldnn_prop_kind_t; + +/** Kinds of primitives. Used to implement a way to extend the library with new + * primitives without changing the ABI. */ +typedef enum { + /** Undefined primitive (XXX: why do we have it?). */ + mkldnn_undefined_primitive, + /** A reorder primitive.*/ + mkldnn_reorder, + /** A shuffle primitive.*/ + mkldnn_shuffle, + /** A (out-of-place) concat primitive. */ + mkldnn_concat, + /** A sum primitive. */ + mkldnn_sum, + /** A convolution primitive. */ + mkldnn_convolution, + /** A deconvolution primitive. */ + mkldnn_deconvolution, + /** An element-wise primitive. */ + mkldnn_eltwise, + /** A Softmax primitive. */ + mkldnn_softmax, + /** A pooling primitive. */ + mkldnn_pooling, + /** An LRN primitive. */ + mkldnn_lrn, + /** An batch normalization primitive. */ + mkldnn_batch_normalization, + /** An inner product primitive. */ + mkldnn_inner_product, + /** A rnn primitive. */ + mkldnn_rnn, +} mkldnn_primitive_kind_t; + +/** Kinds of algorithms. */ +typedef enum { + mkldnn_alg_kind_undef, + /** Direct convolution */ + mkldnn_convolution_direct = 0x1, + /** Winograd convolution */ + mkldnn_convolution_winograd = 0x2, + /** Convolution algorithm(either direct or Winograd) is chosen just in time **/ + mkldnn_convolution_auto = 0x3, + /** Direct deconvolution */ + mkldnn_deconvolution_direct = 0xa, + /** Winograd deconvolution */ + mkldnn_deconvolution_winograd = 0xb, + /** Eltwise: ReLU */ + mkldnn_eltwise_relu = 0x1f, + /** Eltwise: hyperbolic tangent non-linearity (tanh) */ + mkldnn_eltwise_tanh = 0x2f, + /** Eltwise: parametric exponential linear unit (elu) */ + mkldnn_eltwise_elu = 0x3f, + /** Eltwise: square */ + mkldnn_eltwise_square = 0x4f, + /** Eltwise: abs */ + mkldnn_eltwise_abs = 0x5f, + /** Eltwise: square root */ + mkldnn_eltwise_sqrt = 0x6f, + /** Eltwise: linear */ + mkldnn_eltwise_linear = 0x7f, + /** Eltwise: bounded_relu */ + mkldnn_eltwise_bounded_relu = 0x8f, + /** Eltwise: soft_relu */ + mkldnn_eltwise_soft_relu = 0x9f, + /** Eltwise: logistic */ + mkldnn_eltwise_logistic = 0xaf, + /** Max pooling */ + mkldnn_pooling_max = 0x1ff, + /** Average pooling include padding */ + mkldnn_pooling_avg_include_padding = 0x2ff, + /** Average pooling exclude padding */ + mkldnn_pooling_avg_exclude_padding = 0x3ff, + mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding, + /** Local response normalization (LRN) across multiple channels */ + mkldnn_lrn_across_channels = 0xaff, + /** LRN within a single channel */ + mkldnn_lrn_within_channel = 0xbff, + /** RNN cell */ + mkldnn_vanilla_rnn = 0x1fff, + /** LSTM cell */ + mkldnn_vanilla_lstm = 0x2fff, + /** GRU cell */ + mkldnn_vanilla_gru = 0x3fff, + /** GRU cell with linear before reset + * + * Modification of original GRU cell. Differs from #mkldnn_vanilla_gru + * in how the new memory gate is calculated: + * \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f] + * Primitive expects 4 biases on input: + * \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$ + * */ + mkldnn_gru_linear_before_reset = 0x4fff, +} mkldnn_alg_kind_t; + +/** Flags for batch-normalization primititve. */ +typedef enum { + /** Use global statistics + * + * If specified + * - on forward propagation use mean and variance provided by user (input) + * - on backward propagation reduces the amount of computations, since + * mean and variance are considered as constants + * + * If not specified: + * - on forward propagation mean and variance are computed and stored in + * output + * - on backward propagation compute full derivative wrt to data + */ + mkldnn_use_global_stats = 0x1U, + /** Use scale and shift parameters + * + * If specified: + * - on forward propagation use scale and shift (aka scale and bias) for + * the batch normalization results + * - on backward propagation (for prop_kind == #mkldnn_backward) compute + * diff wrt to scale and shift (hence one extra output used) + * + * If no specified: + * - on backward propagation prop_kind == #mkldnn_backward_data has the + * same behavior as prop_kind == #mkldnn_backward + */ + mkldnn_use_scaleshift = 0x2U, + /** Fuse with ReLU + * + * If specified: + * - on inference this option behaves the same as if the primitive were + * fused with ReLU via post ops API + * - on training primitive requires workspace (required to be able to + * perform backward pass) + */ + mkldnn_fuse_bn_relu = 0x4U, +} mkldnn_batch_normalization_flag_t; + +/** @} */ + +/** @addtogroup c_api_types_memory Memory + * @{ */ + +/** Maximum number of dimensions a tensor can have. Only restricts the amount + * of space used for the tensor description. Individual computational + * primitives may support only tensors of certain dimensions. */ +#define MKLDNN_MAX_NDIMS 12 + +/** A type to describe tensor dimension. */ +typedef int64_t mkldnn_dim_t; + +/** A type to describe tensor dimensions. */ +typedef mkldnn_dim_t mkldnn_dims_t[MKLDNN_MAX_NDIMS]; + +/** A type to describe strides within a tensor. */ +typedef mkldnn_dim_t mkldnn_strides_t[MKLDNN_MAX_NDIMS]; + +/** Generic description of blocked data layout for most memory formats. + * + * @sa @ref understanding_memory_formats */ +typedef struct { + /** The strides between the outermost blocks. + * In case of plain (non-blocked) formats the strides between dimensions. */ + mkldnn_dims_t strides; + /* Innermost section + * ASSUMPTION: the innermost blocks are always dense */ + /** The number of innermost blocks, e.g. 3 in case of `OIhw_4i16o4i_` */ + int inner_nblks; + /** The size of the blocks, e.g. `{4, 16, 4}` in case of `OIhw_4i16o4i` */ + mkldnn_dims_t inner_blks; + /** The logical indices of the blocks, e.g. `{1, 0, 1}` in case of + * `4i16o4i`, because `i` is the 1st dim and `o` is the 0st dim */ + mkldnn_dims_t inner_idxs; +} mkldnn_blocking_desc_t; + +typedef enum { + /** Undefined memory format, used for empty memory descriptors. */ + mkldnn_wino_undef = 0, + /** Tensors of weights for 2x3 winograd convolutions. */ + mkldnn_wino_wei_aaOIoi, + mkldnn_wino_wei_aaOio, + mkldnn_wino_wei_aaOBiOo, + /** Tensor of weights for 4x3 convolution. */ + mkldnn_wino_wei_OBaaIBOIio +} mkldnn_wino_memory_format_t; + +/** Description of tensor of weights for winograd 2x3 convolution. */ +typedef struct { + mkldnn_wino_memory_format_t wino_format; + int r; + int alpha; + int ic; + int oc; + int ic_block; + int oc_block; + int ic2_block; + int oc2_block; + float adj_scale; + size_t size; +} mkldnn_wino_desc_t; + +typedef enum { + mkldnn_packed_format_undef = 0, + mkldnn_ldigo_p, + mkldnn_ldgoi_p +} mkldnn_rnn_packed_memory_format_t; + +/* Maximum number of parts of RNN weights tensor that require separate + * computation. */ +#define MKLDNN_RNN_MAX_N_PARTS 4 + +/** Description of tensor of packed weights for rnn. */ +typedef struct { + mkldnn_rnn_packed_memory_format_t format; + int n_parts; + int n; + int parts[MKLDNN_RNN_MAX_N_PARTS]; + size_t part_pack_size[MKLDNN_RNN_MAX_N_PARTS]; + size_t offset_compensation; + size_t size; +} mkldnn_rnn_packed_desc_t; + +typedef enum { + mkldnn_memory_extra_flag_none = 0x0U, + /** Indicates the weights have an additional buffer, that depends on the + * @p compensation_mask. + * + * For instance, in 4D case with the compensation mask equals (1 << 0) + * the additional buffer would consist of OC values: + * O[oc : 0,OC] = + * -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) } + */ + mkldnn_memory_extra_flag_compensation_conv_s8s8 = 0x1U, + mkldnn_memory_extra_flag_scale_adjust = 0x2U, +} mkldnn_memory_extra_flags_t; + +/** Description of extra information stored in memory */ +typedef struct { + /** The flags contain arbitrary extra information, such as compensation. + * @sa mkldnn_memory_extra_flags_t */ + uint64_t flags; + /** Compensation mask */ + int compensation_mask; + /** Scale applied to the data */ + float scale_adjust; + /** For future backwards compatibility */ + char reserved[64]; +} mkldnn_memory_extra_desc_t; + +/** Memory descriptor. The description is based on a number of dimensions, + * dimensions themselves, plus information about elements type and memory + * format. Additionally, contains format-specific descriptions of the data + * layout. */ +typedef struct { + /** Number of dimensions */ + int ndims; + /** Dimensions in the following order: + * - CNN data tensors: mini-batch, channel, spatial + * (<code>{N, C, [[D,] H,] W}</code>) + * - CNN weight tensors: group (optional), output channel, input channel, + * spatial (<code>{[G,] O, I, [[D,] H,] W}</code>) + * - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>) + * or layers, directions, states, mini-batch, channels (<code>{L, D, S, N, C}</code>) + * - RNN weight tensor: layers, directions, input channel, gates, output channels + * (<code>{L, D, I, G, O}</code>). + * + * @note + * The order of dimensions does not depend on the memory format, so + * whether the data is laid out in #mkldnn_nchw or #mkldnn_nhwc + * the dims for 4D CN data tensor would be <code>{N, C, H, W}</code>. + */ + mkldnn_dims_t dims; + /** Data type of the tensor elements. */ + mkldnn_data_type_t data_type; + + /** Size of the data including padding in each dimension. */ + mkldnn_dims_t padded_dims; + /** Per-dimension offset from the padding to actual data, the top-level + * tensor with offsets applied must lie within the padding area. */ + mkldnn_dims_t padded_offsets; + + /** Offset from memory origin to the current block, non-zero only in + * a description of a memory sub-block. */ + mkldnn_dim_t offset0; + + /** Memory format kind. */ + mkldnn_format_kind_t format_kind; + union { + /** Description of the data layout for memory formats that use + * blocking. */ + mkldnn_blocking_desc_t blocking; + /** Tensor of weights for integer 8bit winograd convolution. */ + mkldnn_wino_desc_t wino_desc; + /** Tensor of packed weights for RNN. */ + mkldnn_rnn_packed_desc_t rnn_packed_desc; + /* ... other descriptions possible */ + } format_desc; + + mkldnn_memory_extra_desc_t extra; +} mkldnn_memory_desc_t; + +/** @struct mkldnn_memory + * An opaque structure to describe a memory. */ +struct mkldnn_memory; + +/** A memory handle. */ +typedef struct mkldnn_memory *mkldnn_memory_t; + +/** A constant memory handle. */ +typedef const struct mkldnn_memory *const_mkldnn_memory_t; + +#define MKLDNN_NATIVE_HANDLE_NONE (NULL) +#define MKLDNN_NATIVE_HANDLE_ALLOCATE ((void *)(size_t)-1) + +/** @} */ + +/** @addtogroup c_api_types_op_descs Operation descriptors + * @{*/ + +/** A pointer to any of the operation descriptors. */ +typedef void *mkldnn_op_desc_t; +/** A pointer to any of the operation descriptors (constant variant). */ +typedef const void *const_mkldnn_op_desc_t; + +/** A descriptor of a convolution operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_convolution. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward_data, + * #mkldnn_backward_weights, and #mkldnn_backward_bias. */ + mkldnn_prop_kind_t prop_kind; + /** The kind of the convolution algorithm. Possible values: + * #mkldnn_convolution_direct. */ + mkldnn_alg_kind_t alg_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Weights memory descriptor. */ + mkldnn_memory_desc_t weights_desc; + /** Weights gradient memory descriptor. */ + mkldnn_memory_desc_t diff_weights_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** Convolution strides in each spatial dimension. */ + mkldnn_dims_t strides; + /** Convolution dilates in each spatial dimension. */ + mkldnn_dims_t dilates; + /** Padding in each spatial dimension. padding[0] is a padding in the + * beginning (@p padding_l), padding[1] is a padding in the end (@p + * padding_r). */ + mkldnn_dims_t padding[2]; + /** The kind of padding to use. */ + mkldnn_padding_kind_t padding_kind; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_convolution_desc_t; + +/** A descriptor of a deconvolution operation. */ +typedef mkldnn_convolution_desc_t mkldnn_deconvolution_desc_t; + +/** A descriptor of a shuffle operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_convolution. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, and #mkldnn_backward_data. */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor, + * and source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** axis for shuffling. */ + int axis; + /** number of groups in group convolution */ + mkldnn_dim_t group_size; +} mkldnn_shuffle_desc_t; + +/** A descriptor of a element-wise operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_eltwise. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu, + * #mkldnn_eltwise_tanh, #mkldnn_eltwise_elu, #mkldnn_eltwise_square, + * #mkldnn_eltwise_abs, #mkldnn_eltwise_sqrt, #mkldnn_eltwise_linear, + * #mkldnn_eltwise_bounded_relu, #mkldnn_eltwise_soft_relu, and + * #mkldnn_eltwise_logistic. */ + mkldnn_alg_kind_t alg_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_data_desc; + /** Algorithm specific parameter. + * Accordance table: + * - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored + * - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored + * - #mkldnn_eltwise_elu: @p alpha -- negative slope, @p beta ignored + * - #mkldnn_eltwise_square: @p alpha and @p beta ignored + * - #mkldnn_eltwise_abs: @p alpha and @p beta ignored + * - #mkldnn_eltwise_sqrt: @p alpha and @p beta ignored + * - #mkldnn_eltwise_linear: @p alpha -- scale, @p beta -- shift + * - #mkldnn_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored + * - #mkldnn_eltwise_soft_relu: @p alpha and @p beta ignored + * - #mkldnn_eltwise_logistic: @p alpha and @p beta ignored + */ + float alpha, beta; +} mkldnn_eltwise_desc_t; + +/** A descriptor of a Softmax operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_softmax. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training and + * #mkldnn_forward_inference. */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and Destination of gradient memory descriptor. */ + mkldnn_memory_desc_t diff_desc; + /** The axis along which to perform the softmax. */ + int softmax_axis; +} mkldnn_softmax_desc_t; + +/** A descriptor of a pooling operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_pooling. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** The kind of pooling algorithm. Possible values: #mkldnn_pooling_max and + * #mkldnn_pooling_avg. */ + mkldnn_alg_kind_t alg_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** Pooling kernel strides for spatial dimensions. */ + mkldnn_dims_t strides; + /** Pooling kernel spatial dimensions. */ + mkldnn_dims_t kernel; + /** Padding in each spatial dimension. padding[0] is a padding in the + * beginning (@p padding_l), padding[1] is a padding in the end (@p + * padding_r). */ + mkldnn_dims_t padding[2]; + /** The kind of padding to use. */ + mkldnn_padding_kind_t padding_kind; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_pooling_desc_t; + +/** A descriptor of a Local Response Normalization (LRN) operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_lrn. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** LRN algorithm. Possible values: #mkldnn_lrn_within_channel and + * #mkldnn_lrn_across_channels. */ + mkldnn_alg_kind_t alg_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_data_desc; + /** The number of channels to sum over (for cross-channel LRN) or the side + * length of the square region to sum over (for within-channel LRN). */ + mkldnn_dim_t local_size; + /** LRN alpha parameter. */ + float lrn_alpha; + /** LRN beta parameter. */ + float lrn_beta; + /** LRN k parameter. */ + float lrn_k; +} mkldnn_lrn_desc_t; + +/** A descriptor of a Batch Normalization operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_batch_normalization. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data. + */ + mkldnn_prop_kind_t prop_kind; + /** Source and destination memory descriptor. */ + mkldnn_memory_desc_t data_desc; + /** Source and destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_data_desc; + /** Scale and shift data and gradient memory descriptors. + * + * Scaleshift memory descriptor uses 2D #mkldnn_nc format[2,Channels]. 1-st + * dimension contains gamma parameter, 2-nd dimension contains beta + * parameter. */ + mkldnn_memory_desc_t data_scaleshift_desc; + mkldnn_memory_desc_t diff_data_scaleshift_desc; + /** Mean and variance data memory descriptors. + * + * Mean and variance memory descriptors use 1D #mkldnn_x format[Channels]. + */ + mkldnn_memory_desc_t mean_desc; + mkldnn_memory_desc_t variance_desc; + /** Batch normalization epsilon parameter. */ + float batch_norm_epsilon; + unsigned flags; +} mkldnn_batch_normalization_desc_t; + +/** A descriptor of an inner product operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_inner_product. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, #mkldnn_backward_data, + * #mkldnn_backward_weights, and #mkldnn_backward_bias. */ + mkldnn_prop_kind_t prop_kind; + /** Source memory descriptor. */ + mkldnn_memory_desc_t src_desc; + /** Source gradient memory descriptor. */ + mkldnn_memory_desc_t diff_src_desc; + /** Weights memory descriptor. */ + mkldnn_memory_desc_t weights_desc; + /** Weights gradient memory descriptor. */ + mkldnn_memory_desc_t diff_weights_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination memory descriptor. */ + mkldnn_memory_desc_t dst_desc; + /** Destination gradient memory descriptor. */ + mkldnn_memory_desc_t diff_dst_desc; + /** The accumulator data type. Initialized automatically. */ + mkldnn_data_type_t accum_data_type; +} mkldnn_inner_product_desc_t; + +/** Flags for RNN cell. */ +typedef enum { + mkldnn_rnn_cell_with_relu = 0x1U, + mkldnn_rnn_cell_with_clipping = 0x2U, +} mkldnn_rnn_cell_flags_t; + +typedef struct { + /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn, + * #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, + * or #mkldnn_gru_linear_before_reset. */ + mkldnn_alg_kind_t cell_kind; + /** Activation function used. Must be either #mkldnn_eltwise_relu or + * #mkldnn_eltwise_tanh. */ + mkldnn_alg_kind_t activation_kind; + /** RNN cell flags */ + unsigned int flags; + /** @c alpha is a negative slope parameter (used only if + * `(flags & #mkldnn_rnn_cell_with_relu) != 0`) */ + float alpha; + /** clipping parameter (used only if + * `(flags & #mkldnn_rnn_cell_with_clipping) != 0`) */ + float clipping; +} mkldnn_rnn_cell_desc_t; + +/** A direction of RNN primitive execution. */ +typedef enum { + /* Unidirectional execution of RNN primitive from left to right. */ + mkldnn_unidirectional_left2right, + /* Unidirectional execution of RNN primitive from right to left. */ + mkldnn_unidirectional_right2left, + /* Bidirectional execution of RNN primitive with concatenation of the + * results. */ + mkldnn_bidirectional_concat, + /* Bidirectional execution of RNN primitive with summation of the + * results. */ + mkldnn_bidirectional_sum, + mkldnn_unidirectional = mkldnn_unidirectional_left2right, +} mkldnn_rnn_direction_t; + +/** A descriptor for an RNN operation. */ +typedef struct { + /** The kind of primitive. Used for self-identifying the primitive + * descriptor. Must be #mkldnn_rnn. */ + mkldnn_primitive_kind_t primitive_kind; + /** The kind of propagation. Possible values: #mkldnn_forward_training, + * #mkldnn_forward_inference, and #mkldnn_backward. */ + mkldnn_prop_kind_t prop_kind; + /** The RNN cell desc. */ + mkldnn_rnn_cell_desc_t cell_desc; + /** The direction of RNN primitive execution. */ + mkldnn_rnn_direction_t direction; + /** Source layer memory descriptor. */ + mkldnn_memory_desc_t src_layer_desc; + /** Source iteration memory descriptor. */ + mkldnn_memory_desc_t src_iter_desc; + /** Weights layer memory descriptor. */ + mkldnn_memory_desc_t weights_layer_desc; + /** Weights iteration memory descriptor. */ + mkldnn_memory_desc_t weights_iter_desc; + /** Bias memory descriptor. */ + mkldnn_memory_desc_t bias_desc; + /** Destination layer memory descriptor. */ + mkldnn_memory_desc_t dst_layer_desc; + /** Destination iter memory descriptor. */ + mkldnn_memory_desc_t dst_iter_desc; + /** Source gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_src_layer_desc; + /** Source gradient iter memory descriptor. */ + mkldnn_memory_desc_t diff_src_iter_desc; + /** Weights gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_weights_layer_desc; + /** Weights gradient iter memory descriptor. */ + mkldnn_memory_desc_t diff_weights_iter_desc; + /** Bias gradient memory descriptor. */ + mkldnn_memory_desc_t diff_bias_desc; + /** Destination gradient layer memory descriptor. */ + mkldnn_memory_desc_t diff_dst_layer_desc; + /** Destination gradient iteration memory descriptor. */ + mkldnn_memory_desc_t diff_dst_iter_desc; +} mkldnn_rnn_desc_t; + +/** @} */ + +/** @addtogroup c_api_engine_types Engine + * @{ */ + +/** @brief Kinds of engines. */ +typedef enum { + /** An unspecified engine. */ + mkldnn_any_engine, + /** CPU engine. */ + mkldnn_cpu, +} mkldnn_engine_kind_t; + +/** @struct mkldnn_engine + * @brief An opaque structure to describe an engine. */ +struct mkldnn_engine; +/** @brief An engine handle. */ +typedef struct mkldnn_engine *mkldnn_engine_t; +#if 0 +/* FIXME: looks like this never happens */ +/** @brief A constant engine handle. */ +typedef const struct mkldnn_engine *const_mkldnn_engine_t; +#endif + +/** @} */ + +/** @addtogroup c_api_primitive_desc_iterators Primitive descriptor iterators + * @{ */ + +/** @struct mkldnn_primitive_desc_iterator + * @brief An opaque structure to describe a primitive descriptor iterator. */ +struct mkldnn_primitive_desc_iterator; + +/** @brief A primitive descriptor iterator handle. */ +typedef struct mkldnn_primitive_desc_iterator + *mkldnn_primitive_desc_iterator_t; + +/** @brief A constant primitive descriptor iterator handle. */ +typedef const struct mkldnn_primitive_desc_iterator + *const_mkldnn_primitive_desc_iterator_t; + +/** @} */ + +/** @addtogroup c_api_primitive_descs Primitive descriptors + * @{ */ + +/** @struct mkldnn_primitive_desc + * @brief An opaque structure to describe a primitive descriptor. */ +struct mkldnn_primitive_desc; + +/** @brief A primitive descriptor handle. */ +typedef struct mkldnn_primitive_desc *mkldnn_primitive_desc_t; + +/** @brief A constant primitive descriptor handle. */ +typedef const struct mkldnn_primitive_desc *const_mkldnn_primitive_desc_t; + +/** @} */ + +/** @addtogroup c_api_primitive_attr Primitive descriptor attributes + * @{ */ + +/** Scratchpad mode */ +typedef enum { + /** The library manages scratchpad (default) */ + mkldnn_scratchpad_mode_library, + /** A user shall query and provide the scratchpad memory to primitives */ + mkldnn_scratchpad_mode_user, +} mkldnn_scratchpad_mode_t; + +/** @struct mkldnn_primitive_attr + * @brief An opaque structure for primitive descriptor attributes. + * + * Attributes may contain: + * - output scales (to scale the result prior to storing it to the memory) + */ +struct mkldnn_primitive_attr; + +/** @brief A primitive descriptor attributes handle that controls primitive + * behavior. */ +typedef struct mkldnn_primitive_attr *mkldnn_primitive_attr_t; + +/** @brief A constant primitive descriptor attributes handle. */ +typedef const struct mkldnn_primitive_attr *const_mkldnn_primitive_attr_t; + +/** @struct mkldnn_post_ops + * @brief An opaque structure for a chain of post operations. + * + * mkldnn_post_ops can be used to perform some (trivial) operations like + * accumulation or eltwise after certain primitives like convolution. + * + * Post operations might be combined together, making a chain of post + * operations. For instance one can configure convolution followed by + * accumulation followed by eltwise. This might be especially beneficial + * for residual learning blocks. + * + * @warning + * Of course not all combinations are supported, so the user should handle + * errors accordingly. + * + * Supported post operations: + * - accumulation (base primitive: convolution) + * - eltwise (base primitive: convolution) + */ +struct mkldnn_post_ops; + +/** @brief A post operation chain handle. */ +typedef struct mkldnn_post_ops *mkldnn_post_ops_t; + +/** @brief A constant post operation chain handle. */ +typedef const struct mkldnn_post_ops *const_mkldnn_post_ops_t; + +/** @} */ + +/** @addtogroup c_api_types_primitive Primitive + * @{ */ + +/** @struct mkldnn_primitive + * An opaque structure to describe a primitive. */ +struct mkldnn_primitive; +/** A primitive handle. */ +typedef struct mkldnn_primitive *mkldnn_primitive_t; +/** A constant primitive handle. */ +typedef const struct mkldnn_primitive *const_mkldnn_primitive_t; + +/** @addtogroup c_api_types_arguments Argument indices + * @{ */ + +#define MKLDNN_ARG_SRC_0 1 +#define MKLDNN_ARG_SRC MKLDNN_ARG_SRC_0 +#define MKLDNN_ARG_SRC_LAYER MKLDNN_ARG_SRC_0 +#define MKLDNN_ARG_FROM MKLDNN_ARG_SRC_0 + +#define MKLDNN_ARG_SRC_1 2 +#define MKLDNN_ARG_SRC_ITER MKLDNN_ARG_SRC_1 + +#define MKLDNN_ARG_DST_0 17 +#define MKLDNN_ARG_DST MKLDNN_ARG_DST_0 +#define MKLDNN_ARG_TO MKLDNN_ARG_DST_0 +#define MKLDNN_ARG_DST_LAYER MKLDNN_ARG_DST_0 + +#define MKLDNN_ARG_DST_1 18 +#define MKLDNN_ARG_DST_ITER MKLDNN_ARG_DST_1 + +#define MKLDNN_ARG_WEIGHTS_0 33 +#define MKLDNN_ARG_WEIGHTS MKLDNN_ARG_WEIGHTS_0 +#define MKLDNN_ARG_SCALE_SHIFT MKLDNN_ARG_WEIGHTS_0 +#define MKLDNN_ARG_WEIGHTS_LAYER MKLDNN_ARG_WEIGHTS_0 + +#define MKLDNN_ARG_WEIGHTS_1 34 +#define MKLDNN_ARG_WEIGHTS_ITER MKLDNN_ARG_WEIGHTS_1 + +#define MKLDNN_ARG_BIAS 41 + +#define MKLDNN_ARG_MEAN 49 +#define MKLDNN_ARG_VARIANCE 50 + +#define MKLDNN_ARG_WORKSPACE 64 +#define MKLDNN_ARG_SCRATCHPAD 80 + +#define MKLDNN_ARG_DIFF_SRC_0 129 +#define MKLDNN_ARG_DIFF_SRC MKLDNN_ARG_DIFF_SRC_0 +#define MKLDNN_ARG_DIFF_SRC_LAYER MKLDNN_ARG_DIFF_SRC_0 + +#define MKLDNN_ARG_DIFF_SRC_1 130 +#define MKLDNN_ARG_DIFF_SRC_ITER MKLDNN_ARG_DIFF_SRC_1 + +#define MKLDNN_ARG_DIFF_DST_0 145 +#define MKLDNN_ARG_DIFF_DST MKLDNN_ARG_DIFF_DST_0 +#define MKLDNN_ARG_DIFF_DST_LAYER MKLDNN_ARG_DIFF_DST_0 + +#define MKLDNN_ARG_DIFF_DST_1 146 +#define MKLDNN_ARG_DIFF_DST_ITER MKLDNN_ARG_DIFF_DST_1 + +#define MKLDNN_ARG_DIFF_WEIGHTS_0 161 +#define MKLDNN_ARG_DIFF_WEIGHTS MKLDNN_ARG_DIFF_WEIGHTS_0 +#define MKLDNN_ARG_DIFF_SCALE_SHIFT MKLDNN_ARG_DIFF_WEIGHTS_0 +#define MKLDNN_ARG_DIFF_WEIGHTS_LAYER MKLDNN_ARG_DIFF_WEIGHTS_0 + +#define MKLDNN_ARG_DIFF_WEIGHTS_1 162 +#define MKLDNN_ARG_DIFF_WEIGHTS_ITER MKLDNN_ARG_DIFF_WEIGHTS_1 + +#define MKLDNN_ARG_DIFF_BIAS 169 + +#define MKLDNN_ARG_MULTIPLE_SRC 1024 +#define MKLDNN_ARG_MULTIPLE_DST 2048 + +/** @} */ + +/** An auxiliary structure to specify primitive's inputs/outputs at execution + * + * @warning + * With this API it's impossible to preserve constness of memory, so all + * memories are passed w/o const qualifier. However only memories with + * output semantics might be changed during the execution */ +typedef struct { + int arg; /**< An argument index, e.g. MKLDNN_ARG_SRC */ + mkldnn_memory_t memory; /**< Input/output memory */ +} mkldnn_exec_arg_t; + +/** @} */ + +/** @addtogroup c_api_types_query Queries + * @{ */ + +/** Primitive descriptor query specification + * + * For generic function mkldnn_primitive_desc_query(), the type of result must + * agree with the queried argument. The correspondence table: + * Query | type of result + * -------------------------------------------------------------- + * #mkldnn_query_engine | mkldnn_engine_t * + * #mkldnn_query_scratchpad_engine | mkldnn_engine_t * + * #mkldnn_query_primitive_kind | mkldnn_primitive_kind_t * + * *_s32 | int * + * *_s64 | mkldnn_dim_t * (same as int64_t *) + * *_f64 | double * + * *_str | const char ** + * #mkldnn_query_op_d | const_mkldnn_op_desc_t * + * *_md | const mkldnn_memory_desc_t ** + * *_${op}_d | const mkldnn_${op}_desc_t ** + * *_pd | const_mkldnn_primitive_desc_t * + * + * @note + * Rule of thumb: all opaque types and structures are returned by + * reference. All numbers are returned by value. + * + * @warning + * All returned references point to constant objects and are valid only + * during the lifetime of the queried primitive descriptor. Returned objects + * must not be destroyed by the user. If you need to keep the object longer + * than the lifetime of the queried primitive descriptor, use + * mkldnn_primitive_desc_clone() to make a copy. */ +typedef enum { + mkldnn_query_undef = 0, /**< no query */ + + mkldnn_query_engine, /**< execution engine */ + mkldnn_query_primitive_kind, /**< primitive kind */ + + mkldnn_query_num_of_inputs_s32, /**< number of inputs expected */ + mkldnn_query_num_of_outputs_s32, /**< number of outputs expected */ + + mkldnn_query_time_estimate_f64, /**< runtime estimation (seconds) */ + mkldnn_query_memory_consumption_s64, /**< memory consumption -- extra + (scratch) memory, additional to all + inputs and outputs memory (bytes) */ + + mkldnn_query_scratchpad_engine, /**< scratchpad engine -- engine to be used + for creating scratchpad memory */ + + mkldnn_query_impl_info_str, /**< implementation name */ + + /* memory and op descriptor section */ + mkldnn_query_some_d = 64, /**< stub */ + mkldnn_query_op_d, /**< op descriptor */ + mkldnn_query_convolution_d, /**< convolution descriptor */ + mkldnn_query_deconvolution_d, /**< deconvolution descriptor */ + mkldnn_query_shuffle_d, /**< shuffle descriptor */ + mkldnn_query_eltwise_d, /**< eltwise descriptor */ + mkldnn_query_softmax_d, /**< softmax descriptor */ + mkldnn_query_pooling_d, /**< pooling descriptor */ + mkldnn_query_lrn_d, /**< lrn descriptor */ + mkldnn_query_batch_normalization_d, /**< batch normalization descriptor */ + mkldnn_query_inner_product_d, /**< inner product descriptor */ + mkldnn_query_rnn_d, /**< rnn descriptor */ + + /* memory descriptor section */ + mkldnn_query_some_md = 128, /**< stub */ + mkldnn_query_src_md, /**< source memory desc */ + mkldnn_query_diff_src_md, /**< source gradient memory desc */ + mkldnn_query_weights_md, /**< weights memory descriptor desc */ + mkldnn_query_diff_weights_md, /**< weights grad. memory desc */ + mkldnn_query_dst_md, /**< destination memory desc */ + mkldnn_query_diff_dst_md, /**< destination grad. memory desc */ + mkldnn_query_workspace_md, /**< workspace memory desc */ + mkldnn_query_scratchpad_md, /**< scratchpad memory desc */ +} mkldnn_query_t; + +/** @} */ + +/** @addtogroup c_api_types_stream Execution stream + * @{ */ + +/** @brief Stream flags. */ +typedef enum { + /** A default stream configuration. */ + mkldnn_stream_default_flags = 0x0U, +} mkldnn_stream_flags_t; + +/** @struct mkldnn_stream + * An opaque structure to describe an execution stream. */ +struct mkldnn_stream; +/** An execution stream handle. */ +typedef struct mkldnn_stream *mkldnn_stream_t; +/** A constant execution stream handle. */ +typedef const struct mkldnn_stream *const_mkldnn_stream_t; + +/** @} */ +/** @} */ +/** @} */ + +#ifdef __cplusplus +} +#endif + + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h new file mode 100644 index 0000000000..a2713deccb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_VERSION_H +#define MKLDNN_VERSION_H + +/* Major version of MKL-DNN */ +#define MKLDNN_VERSION_MAJOR 0 + +/* Minor version of MKL-DNN */ +#define MKLDNN_VERSION_MINOR 90 + +/* Patch version of MKL-DNN */ +#define MKLDNN_VERSION_PATCH 0 + +/* Git Commit Hash of MKL-DNN */ +#define MKLDNN_VERSION_HASH "096bda1ca23324879f2df5a129e610e4405f775c" + +#endif diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in new file mode 100644 index 0000000000..5ee0126188 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_VERSION_H +#define MKLDNN_VERSION_H + +/* Major version of MKL-DNN */ +#define MKLDNN_VERSION_MAJOR @MKLDNN_VERSION_MAJOR@ + +/* Minor version of MKL-DNN */ +#define MKLDNN_VERSION_MINOR @MKLDNN_VERSION_MINOR@ + +/* Patch version of MKL-DNN */ +#define MKLDNN_VERSION_PATCH @MKLDNN_VERSION_PATCH@ + +/* Git Commit Hash of MKL-DNN */ +#define MKLDNN_VERSION_HASH "@MKLDNN_VERSION_HASH@" + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp new file mode 100644 index 0000000000..1a51d8562b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp @@ -0,0 +1,104 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) { + bool args_ok = true + && !any_null(bnrm_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward_data, backward) + && IMPLICATION(prop_kind & backward, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto bd = batch_normalization_desc_t(); + bd.primitive_kind = primitive_kind::batch_normalization; + bd.prop_kind = prop_kind; + + bd.data_desc = *data_desc; + bd.diff_data_desc = zero_md(); + if ( one_of(bd.prop_kind,backward_data, backward) ) + bd.diff_data_desc = *diff_data_desc; + + dims_t scaleshift_dims = { 2, data_desc->dims[1] }; + mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2, + scaleshift_dims, data_type::f32, mkldnn_nc); + bd.diff_data_scaleshift_desc = zero_md(); + if (bd.prop_kind == backward) { + bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc; + } + + dims_t stats_dims = { data_desc->dims[1] }; + mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims, + data_type::f32, mkldnn_x); + bd.variance_desc = bd.mean_desc; + bd.batch_norm_epsilon = epsilon; + + unsigned bnorm_flags = + mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu; + if ((~bnorm_flags & flags) != 0) return invalid_arguments; + + bd.flags = flags; + + bool consistency = true + && utils::one_of(bd.data_desc.ndims, 2, 4, 5); + if (bd.prop_kind == backward_data) + consistency = consistency + && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5) + && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims, + bd.diff_data_desc.ndims); + if (!consistency) return invalid_arguments; + + *bnrm_desc = bd; + return success; +} +} + +status_t mkldnn_batch_normalization_forward_desc_init( + batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, float epsilon, unsigned flags) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr, + epsilon, flags); +} + +status_t mkldnn_batch_normalization_backward_desc_init( + batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind, + const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc, + float epsilon, unsigned flags) { + if (!one_of(prop_kind, backward, backward_data)) + return invalid_arguments; + return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc, + epsilon, flags); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp new file mode 100644 index 0000000000..f61410b33c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp @@ -0,0 +1,240 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef BATCH_NORMALIZATION_PD_HPP +#define BATCH_NORMALIZATION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct batch_normalization_fwd_pd_t; + +struct batch_normalization_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::batch_normalization; + + batch_normalization_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + , stat_md_(desc_.mean_desc) + , scaleshift_md_(desc_.data_scaleshift_desc) + , ws_md_() + {} + + const batch_normalization_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::batch_normalization_d: + *(const batch_normalization_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common batch_normalization aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return desc_.data_desc.ndims; } + + bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; } + bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; } + bool use_global_stats() const + { return desc_.flags & mkldnn_use_global_stats; } + bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; } + bool with_relu_post_op() const { + const auto &p = this->attr()->post_ops_; + return p.len_ == 1 && p.entry_[0].is_relu(true, true); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + bool is_bwd() const { return !this->is_fwd(); } + bool is_training() const + { return desc_.prop_kind == prop_kind::forward_training; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + +protected: + batch_normalization_desc_t desc_; + const batch_normalization_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + memory_desc_t stat_md_; + memory_desc_t scaleshift_md_; + + memory_desc_t ws_md_; + + void init_default_ws(size_t bits_per_element) { + const auto data_mdw = memory_desc_wrapper(data_md_); + + const dim_t data_nelems = data_mdw.nelems(true); + const dim_t bits_per_byte = 8; + const dims_t ws_sz = { (dim_t)utils::div_up( + data_nelems * bits_per_element, bits_per_byte) }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8, + format_tag::x); + } + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t { + typedef batch_normalization_fwd_pd_t base_class; + typedef batch_normalization_fwd_pd_t hint_class; + + batch_normalization_fwd_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input; + if (arg == MKLDNN_ARG_DST) return arg_usage_t::output; + + if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) { + if (stats_is_src()) return arg_usage_t::input; + if (!stats_is_src() && is_training()) return arg_usage_t::output; + return arg_usage_t::unused; + } + + if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override { + if (index == 0) return &data_md_; + if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_; + return nullptr; + } + + virtual const memory_desc_t *dst_md(int index = 0) const override { + if (index == 0) return &data_md_; + if (!stats_is_src() && is_training() && (index == 1 || index == 2)) + return &stat_md_; + return nullptr; + } + + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &scaleshift_md_ : nullptr; } + + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; } + + const memory_desc_t *stat_md() const + { return stats_is_src() ? src_md(1) : dst_md(1); } + + virtual int n_inputs() const override + { return 1 + 2 * stats_is_src() + use_scaleshift(); } + virtual int n_outputs() const override + { return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); } +}; + +struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t { + typedef batch_normalization_bwd_pd_t base_class; + typedef batch_normalization_fwd_pd_t hint_class; + + batch_normalization_bwd_pd_t(engine_t *engine, + const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN, + MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &scaleshift_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override + { return index == 0 ? &diff_scaleshift_md_ : nullptr; } + + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; } + + const memory_desc_t *stat_md() const { return src_md(1); } + + virtual int n_inputs() const override + { return 4 + use_scaleshift() + fuse_bn_relu(); } + virtual int n_outputs() const override + { return 1 + (desc_.prop_kind == prop_kind::backward); } + +protected: + memory_desc_t diff_data_md_; + memory_desc_t diff_scaleshift_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp new file mode 100644 index 0000000000..3d43a0fbee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp @@ -0,0 +1,550 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TYPE_MAPPING_HPP +#define TYPE_MAPPING_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { + +// TODO: autogenerate this + +using dim_t = mkldnn_dim_t; +using dims_t = mkldnn_dims_t; +using stride_t = mkldnn_dim_t; +using strides_t = mkldnn_strides_t; + +using status_t = mkldnn_status_t; +namespace status { + const status_t success = mkldnn_success; + const status_t out_of_memory = mkldnn_out_of_memory; + const status_t try_again = mkldnn_try_again; + const status_t invalid_arguments = mkldnn_invalid_arguments; + const status_t not_ready = mkldnn_not_ready; + const status_t unimplemented = mkldnn_unimplemented; + const status_t iterator_ends = mkldnn_iterator_ends; + const status_t runtime_error = mkldnn_runtime_error; + const status_t not_required = mkldnn_not_required; +} + +using prop_kind_t = mkldnn_prop_kind_t; +namespace prop_kind { + const prop_kind_t undef = mkldnn_prop_kind_undef; + const prop_kind_t forward_training = mkldnn_forward_training; + const prop_kind_t forward_inference = mkldnn_forward_inference; + const prop_kind_t forward_scoring = mkldnn_forward_scoring; + const prop_kind_t forward = mkldnn_forward; + const prop_kind_t backward = mkldnn_backward; + const prop_kind_t backward_data = mkldnn_backward_data; + const prop_kind_t backward_weights = mkldnn_backward_weights; + const prop_kind_t backward_bias = mkldnn_backward_bias; +} + +using alg_kind_t = mkldnn_alg_kind_t; +namespace alg_kind { + const alg_kind_t undef = mkldnn_alg_kind_undef; + const alg_kind_t convolution_auto = mkldnn_convolution_auto; + const alg_kind_t convolution_direct = mkldnn_convolution_direct; + const alg_kind_t convolution_winograd = mkldnn_convolution_winograd; + const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct; + const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd; + const alg_kind_t eltwise_relu = mkldnn_eltwise_relu; + const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh; + const alg_kind_t eltwise_elu = mkldnn_eltwise_elu; + const alg_kind_t eltwise_square = mkldnn_eltwise_square; + const alg_kind_t eltwise_abs = mkldnn_eltwise_abs; + const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt; + const alg_kind_t eltwise_linear = mkldnn_eltwise_linear; + const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu; + const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu; + const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic; + const alg_kind_t pooling_max = mkldnn_pooling_max; + const alg_kind_t pooling_avg = mkldnn_pooling_avg; + const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding; + const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding; + const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels; + const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel; + const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn; + const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm; + const alg_kind_t vanilla_gru = mkldnn_vanilla_gru; + const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset; +} + +using data_type_t = mkldnn_data_type_t; +namespace data_type { + const data_type_t undef = mkldnn_data_type_undef; + const data_type_t f32 = mkldnn_f32; + const data_type_t s32 = mkldnn_s32; + const data_type_t s8 = mkldnn_s8; + const data_type_t u8 = mkldnn_u8; +} + +using scratchpad_mode_t = mkldnn_scratchpad_mode_t; +namespace scratchpad_mode { + const scratchpad_mode_t library = mkldnn_scratchpad_mode_library; + const scratchpad_mode_t user = mkldnn_scratchpad_mode_user; +} + +using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t; +namespace rnn_packed_format { + const rnn_packed_format_t undef = mkldnn_packed_format_undef; + const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p; + const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p; +} + +using format_kind_t = mkldnn_format_kind_t; +namespace format_kind { + const format_kind_t undef = mkldnn_format_kind_undef; + const format_kind_t any = mkldnn_format_kind_any; + const format_kind_t blocked = mkldnn_blocked; + const format_kind_t wino = mkldnn_format_kind_wino; + const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed; +} + +using format_tag_t = mkldnn_format_tag_t; +namespace format_tag { + const format_tag_t undef = mkldnn_format_tag_undef; + const format_tag_t any = mkldnn_format_tag_any; + const format_tag_t a = mkldnn_a; + const format_tag_t ab = mkldnn_ab; + const format_tag_t abc = mkldnn_abc; + const format_tag_t abcd = mkldnn_abcd; + const format_tag_t abcde = mkldnn_abcde; + const format_tag_t abcdef = mkldnn_abcdef; + const format_tag_t abdec = mkldnn_abdec; + const format_tag_t acb = mkldnn_acb; + const format_tag_t acbde = mkldnn_acbde; + const format_tag_t acdb = mkldnn_acdb; + const format_tag_t acdeb = mkldnn_acdeb; + const format_tag_t ba = mkldnn_ba; + const format_tag_t bac = mkldnn_bac; + const format_tag_t bacd = mkldnn_bacd; + const format_tag_t bcda = mkldnn_bcda; + const format_tag_t cba = mkldnn_cba; + const format_tag_t cdba = mkldnn_cdba; + const format_tag_t cdeba = mkldnn_cdeba; + const format_tag_t decab = mkldnn_decab; + const format_tag_t Abc16a = mkldnn_Abc16a; + const format_tag_t ABc16a16b = mkldnn_ABc16a16b; + const format_tag_t aBc16b = mkldnn_aBc16b; + const format_tag_t ABc16b16a = mkldnn_ABc16b16a; + const format_tag_t Abc4a = mkldnn_Abc4a; + const format_tag_t aBc4b = mkldnn_aBc4b; + const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b; + const format_tag_t ABc4b4a = mkldnn_ABc4b4a; + const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a; + const format_tag_t ABc8a8b = mkldnn_ABc8a8b; + const format_tag_t aBc8b = mkldnn_aBc8b; + const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b; + const format_tag_t ABc8b8a = mkldnn_ABc8b8a; + const format_tag_t Abcd16a = mkldnn_Abcd16a; + const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b; + const format_tag_t aBcd16b = mkldnn_aBcd16b; + const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a; + const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c; + const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b; + const format_tag_t Abcd4a = mkldnn_Abcd4a; + const format_tag_t aBcd4b = mkldnn_aBcd4b; + const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b; + const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a; + const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c; + const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b; + const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a; + const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b; + const format_tag_t aBcd8b = mkldnn_aBcd8b; + const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b; + const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b; + const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a; + const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c; + const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c; + const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b; + const format_tag_t Abcde16a = mkldnn_Abcde16a; + const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b; + const format_tag_t aBcde16b = mkldnn_aBcde16b; + const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a; + const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c; + const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b; + const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c; + const format_tag_t Abcde4a = mkldnn_Abcde4a; + const format_tag_t aBcde4b = mkldnn_aBcde4b; + const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a; + const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c; + const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c; + const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b; + const format_tag_t Abcde8a = mkldnn_Abcde8a; + const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b; + const format_tag_t aBcde8b = mkldnn_aBcde8b; + const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b; + const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b; + const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a; + const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c; + const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c; + const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b; + const format_tag_t aBcdef16b = mkldnn_aBcdef16b; + const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c; + const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b; + const format_tag_t aBcdef4b = mkldnn_aBcdef4b; + const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b; + const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c; + const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c; + const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b; + const format_tag_t aBdc16b = mkldnn_aBdc16b; + const format_tag_t aBdc4b = mkldnn_aBdc4b; + const format_tag_t aBdc8b = mkldnn_aBdc8b; + const format_tag_t aBdec16b = mkldnn_aBdec16b; + const format_tag_t aBdec4b = mkldnn_aBdec4b; + const format_tag_t aBdec8b = mkldnn_aBdec8b; + const format_tag_t aBdefc16b = mkldnn_aBdefc16b; + const format_tag_t aBdefc4b = mkldnn_aBdefc4b; + const format_tag_t aBdefc8b = mkldnn_aBdefc8b; + const format_tag_t Acb16a = mkldnn_Acb16a; + const format_tag_t Acb4a = mkldnn_Acb4a; + const format_tag_t Acb8a = mkldnn_Acb8a; + const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c; + const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c; + const format_tag_t Acdb16a = mkldnn_Acdb16a; + const format_tag_t Acdb4a = mkldnn_Acdb4a; + const format_tag_t Acdb8a = mkldnn_Acdb8a; + const format_tag_t Acdeb16a = mkldnn_Acdeb16a; + const format_tag_t Acdeb4a = mkldnn_Acdeb4a; + const format_tag_t Acdeb8a = mkldnn_Acdeb8a; + const format_tag_t BAc16a16b = mkldnn_BAc16a16b; + const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b; + const format_tag_t last = mkldnn_format_tag_last; + + const format_tag_t x = mkldnn_x; + const format_tag_t nc = mkldnn_nc; + const format_tag_t cn = mkldnn_cn; + const format_tag_t ncw = mkldnn_ncw; + const format_tag_t nwc = mkldnn_nwc; + const format_tag_t nchw = mkldnn_nchw; + const format_tag_t nhwc = mkldnn_nhwc; + const format_tag_t chwn = mkldnn_chwn; + const format_tag_t ncdhw = mkldnn_ncdhw; + const format_tag_t ndhwc = mkldnn_ndhwc; + const format_tag_t oi = mkldnn_oi; + const format_tag_t io = mkldnn_io; + const format_tag_t oiw = mkldnn_oiw; + const format_tag_t wio = mkldnn_wio; + const format_tag_t oihw = mkldnn_oihw; + const format_tag_t hwio = mkldnn_hwio; + const format_tag_t ihwo = mkldnn_ihwo; + const format_tag_t iohw = mkldnn_iohw; + const format_tag_t oidhw = mkldnn_oidhw; + const format_tag_t dhwio = mkldnn_dhwio; + const format_tag_t goiw = mkldnn_goiw; + const format_tag_t goihw = mkldnn_goihw; + const format_tag_t hwigo = mkldnn_hwigo; + const format_tag_t giohw = mkldnn_giohw; + const format_tag_t goidhw = mkldnn_goidhw; + const format_tag_t tnc = mkldnn_tnc; + const format_tag_t ntc = mkldnn_ntc; + const format_tag_t ldsnc = mkldnn_ldsnc; + const format_tag_t ldigo = mkldnn_ldigo; + const format_tag_t ldgoi = mkldnn_ldgoi; + const format_tag_t ldgo = mkldnn_ldgo; + const format_tag_t nCdhw16c = mkldnn_nCdhw16c; + const format_tag_t nCdhw4c = mkldnn_nCdhw4c; + const format_tag_t nCdhw8c = mkldnn_nCdhw8c; + const format_tag_t nChw16c = mkldnn_nChw16c; + const format_tag_t nChw4c = mkldnn_nChw4c; + const format_tag_t nChw8c = mkldnn_nChw8c; + const format_tag_t nCw16c = mkldnn_nCw16c; + const format_tag_t nCw4c = mkldnn_nCw4c; + const format_tag_t nCw8c = mkldnn_nCw8c; + const format_tag_t IOw16o16i = mkldnn_IOw16o16i; + const format_tag_t OIw16i16o = mkldnn_OIw16i16o; + const format_tag_t OIw16o16i = mkldnn_OIw16o16i; + const format_tag_t Oiw16o = mkldnn_Oiw16o; + const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i; + const format_tag_t OIw4i4o = mkldnn_OIw4i4o; + const format_tag_t Oiw4o = mkldnn_Oiw4o; + const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i; + const format_tag_t OIw8i8o = mkldnn_OIw8i8o; + const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o; + const format_tag_t OIw8o8i = mkldnn_OIw8o8i; + const format_tag_t Owi16o = mkldnn_Owi16o; + const format_tag_t Owi4o = mkldnn_Owi4o; + const format_tag_t Owi8o = mkldnn_Owi8o; + const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i; + const format_tag_t Ohwi16o = mkldnn_Ohwi16o; + const format_tag_t Ohwi4o = mkldnn_Ohwi4o; + const format_tag_t Ohwi8o = mkldnn_Ohwi8o; + const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o; + const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i; + const format_tag_t Oihw16o = mkldnn_Oihw16o; + const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i; + const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o; + const format_tag_t Oihw4o = mkldnn_Oihw4o; + const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i; + const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o; + const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o; + const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i; + const format_tag_t Odhwi16o = mkldnn_Odhwi16o; + const format_tag_t Odhwi4o = mkldnn_Odhwi4o; + const format_tag_t Odhwi8o = mkldnn_Odhwi8o; + const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o; + const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i; + const format_tag_t Oidhw16o = mkldnn_Oidhw16o; + const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o; + const format_tag_t Oidhw4o = mkldnn_Oidhw4o; + const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i; + const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o; + const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i; + const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i; + const format_tag_t Goiw16g = mkldnn_Goiw16g; + const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o; + const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i; + const format_tag_t gOiw16o = mkldnn_gOiw16o; + const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i; + const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o; + const format_tag_t gOiw4o = mkldnn_gOiw4o; + const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i; + const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o; + const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o; + const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i; + const format_tag_t gOwi16o = mkldnn_gOwi16o; + const format_tag_t gOwi4o = mkldnn_gOwi4o; + const format_tag_t gOwi8o = mkldnn_gOwi8o; + const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i; + const format_tag_t gOhwi16o = mkldnn_gOhwi16o; + const format_tag_t gOhwi4o = mkldnn_gOhwi4o; + const format_tag_t gOhwi8o = mkldnn_gOhwi8o; + const format_tag_t Goihw16g = mkldnn_Goihw16g; + const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o; + const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i; + const format_tag_t gOihw16o = mkldnn_gOihw16o; + const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i; + const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i; + const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o; + const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i; + const format_tag_t gOihw4o = mkldnn_gOihw4o; + const format_tag_t Goihw8g = mkldnn_Goihw8g; + const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i; + const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o; + const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o; + const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i; + const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o; + const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o; + const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o; + const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o; + const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i; + const format_tag_t gOidhw16o = mkldnn_gOidhw16o; + const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o; + const format_tag_t gOidhw4o = mkldnn_gOidhw4o; + const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i; + const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o; + const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i; +} + +using memory_extra_flags_t = mkldnn_memory_extra_flags_t; +namespace memory_extra_flags { + const memory_extra_flags_t none = mkldnn_memory_extra_flag_none; + const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8; + const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust; +} + +using padding_kind_t = mkldnn_padding_kind_t; +namespace padding_kind { + const padding_kind_t padding_zero = mkldnn_padding_zero; +} + +using engine_kind_t = mkldnn_engine_kind_t; +namespace engine_kind { + const engine_kind_t any_engine = mkldnn_any_engine; + const engine_kind_t cpu = mkldnn_cpu; +} + +using primitive_kind_t = mkldnn_primitive_kind_t; +namespace primitive_kind { + const primitive_kind_t undefined = mkldnn_undefined_primitive; + const primitive_kind_t reorder = mkldnn_reorder; + const primitive_kind_t concat = mkldnn_concat; + const primitive_kind_t sum = mkldnn_sum; + const primitive_kind_t convolution = mkldnn_convolution; + const primitive_kind_t deconvolution = mkldnn_deconvolution; + const primitive_kind_t shuffle = mkldnn_shuffle; + const primitive_kind_t eltwise = mkldnn_eltwise; + const primitive_kind_t softmax = mkldnn_softmax; + const primitive_kind_t pooling = mkldnn_pooling; + const primitive_kind_t lrn = mkldnn_lrn; + const primitive_kind_t batch_normalization = mkldnn_batch_normalization; + const primitive_kind_t inner_product = mkldnn_inner_product; + const primitive_kind_t rnn = mkldnn_rnn; +} + +using query_t = mkldnn_query_t; +namespace query { + const query_t undef = mkldnn_query_undef; + + const query_t engine = mkldnn_query_engine; + const query_t primitive_kind = mkldnn_query_primitive_kind; + + const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32; + const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32; + + const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64; + const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64; + + const query_t scratchpad_engine = mkldnn_query_scratchpad_engine; + + const query_t impl_info_str = mkldnn_query_impl_info_str; + + const query_t some_d = mkldnn_query_some_d; + const query_t op_d = mkldnn_query_op_d; + const query_t convolution_d = mkldnn_query_convolution_d; + const query_t deconvolution_d = mkldnn_query_deconvolution_d; + const query_t shuffle_d = mkldnn_query_shuffle_d; + const query_t eltwise_d = mkldnn_query_eltwise_d; + const query_t softmax_d = mkldnn_query_softmax_d; + const query_t pooling_d = mkldnn_query_pooling_d; + const query_t lrn_d = mkldnn_query_lrn_d; + const query_t batch_normalization_d = mkldnn_query_batch_normalization_d; + const query_t inner_product_d = mkldnn_query_inner_product_d; + const query_t rnn_d = mkldnn_query_rnn_d; + + const query_t some_md = mkldnn_query_some_md; + const query_t src_md = mkldnn_query_src_md; + const query_t diff_src_md = mkldnn_query_diff_src_md; + const query_t weights_md = mkldnn_query_weights_md; + const query_t diff_weights_md = mkldnn_query_diff_weights_md; + const query_t dst_md = mkldnn_query_dst_md; + const query_t diff_dst_md = mkldnn_query_diff_dst_md; + + const query_t workspace_md = mkldnn_query_workspace_md; + const query_t scratchpad_md = mkldnn_query_scratchpad_md; +} + +using blocking_desc_t = mkldnn_blocking_desc_t; +using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t; +using wino_desc_t = mkldnn_wino_desc_t; +using memory_extra_desc_t = mkldnn_memory_extra_desc_t; +using memory_desc_t = mkldnn_memory_desc_t; +using convolution_desc_t = mkldnn_convolution_desc_t; +using deconvolution_desc_t = mkldnn_deconvolution_desc_t; +using shuffle_desc_t = mkldnn_shuffle_desc_t; +using pooling_desc_t = mkldnn_pooling_desc_t; +using eltwise_desc_t = mkldnn_eltwise_desc_t; +using softmax_desc_t = mkldnn_softmax_desc_t; +using lrn_desc_t = mkldnn_lrn_desc_t; +using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t; +using inner_product_desc_t = mkldnn_inner_product_desc_t; + +using rnn_direction_t = mkldnn_rnn_direction_t; +using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t; +using rnn_desc_t = mkldnn_rnn_desc_t; + +/* C op_desc_t, which eventually are just (void*) */ +using c_op_desc_t = mkldnn_op_desc_t; +using const_c_op_desc_t = const_mkldnn_op_desc_t; + +struct op_desc_t { + union { + primitive_kind_t kind; + convolution_desc_t convolution; + deconvolution_desc_t deconvolution; + shuffle_desc_t shuffle; + pooling_desc_t pooling; + eltwise_desc_t eltwise; + softmax_desc_t softmax; + lrn_desc_t lrn; + batch_normalization_desc_t batch_normalization; + inner_product_desc_t inner_product; + rnn_desc_t rnn; + }; + + op_desc_t(const primitive_kind_t &_): kind(_) {} + +# define DECL_CTOR_AND_CONVERTERS(c_type, name) \ + op_desc_t(const c_type &_): name(_) {} \ + static op_desc_t *convert_from_c(c_type *_) \ + { return reinterpret_cast<op_desc_t*>(_); } \ + static const op_desc_t *convert_from_c(const c_type *_) \ + { return reinterpret_cast<const op_desc_t*>(_); } + + DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution); + DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle); + DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling); + DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise); + DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax); + DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn); + DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization); + DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product); + DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn); + +# undef DECL_CTOR_AND_CONVERTERS +}; + +using engine_t = mkldnn_engine; +using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator; +using primitive_desc_t = mkldnn_primitive_desc; +using primitive_attr_t = mkldnn_primitive_attr; +using post_ops_t = mkldnn_post_ops; +using memory_t = mkldnn_memory; +using primitive_t = mkldnn_primitive; + +using primitive_arg_index_t = int; + +using stream_flags_t = mkldnn_stream_flags_t; +namespace stream_flags { + const stream_flags_t default_flags = mkldnn_stream_default_flags; +} +using stream_t = mkldnn_stream; + +/* forward declaration of the internal primitive_desc types */ +struct batch_normalization_bwd_pd_t; +struct batch_normalization_fwd_pd_t; +struct batch_normalization_pd_t; +struct concat_pd_t; +struct convolution_bwd_data_pd_t; +struct convolution_bwd_weights_pd_t; +struct convolution_fwd_pd_t; +struct convolution_pd_t; +struct deconvolution_bwd_data_pd_t; +struct deconvolution_bwd_weights_pd_t; +struct deconvolution_fwd_pd_t; +struct deconvolution_pd_t; +struct eltwise_bwd_pd_t; +struct eltwise_fwd_pd_t; +struct eltwise_pd_t; +struct inner_product_bwd_data_pd_t; +struct inner_product_bwd_weights_pd_t; +struct inner_product_fwd_pd_t; +struct inner_product_pd_t; +struct lrn_bwd_pd_t; +struct lrn_fwd_pd_t; +struct lrn_pd_t; +struct pooling_bwd_pd_t; +struct pooling_fwd_pd_t; +struct pooling_pd_t; +struct reorder_pd_t; +struct rnn_bwd_pd_t; +struct rnn_fwd_pd_t; +struct rnn_pd_t; +struct shuffle_pd_t; +struct softmax_bwd_pd_t; +struct softmax_fwd_pd_t; +struct softmax_pd_t; +struct sum_pd_t; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp new file mode 100644 index 0000000000..ed4c35c6e9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "concat_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd, + const memory_desc_t *dst_md, int n, int concat_dim, + const memory_desc_t *src_mds, + const primitive_attr_t *attr, + engine_t *engine) { + bool args_ok = !any_null(concat_pd, src_mds) && n > 0; + if (!args_ok) return invalid_arguments; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + const int ndims = src_mds[0].ndims; + const dims_t &dims = src_mds[0].dims; + const data_type_t dt = src_mds[0].data_type; + + int concat_dim_sz = dims[concat_dim]; + for (int i = 1; i < n; ++i) { + if (src_mds[i].ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (d == concat_dim) continue; + if (src_mds[i].dims[d] != dims[d]) + return invalid_arguments; + } + if (src_mds[i].data_type != dt) return invalid_arguments; + concat_dim_sz += src_mds[i].dims[concat_dim]; + } + + memory_desc_t dummy_dst_md; + if (dst_md) { + if (dst_md->ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (dst_md->dims[d] != + (d == concat_dim ? concat_dim_sz : dims[d])) + return invalid_arguments; + } + } else { + dummy_dst_md = src_mds[0]; + dummy_dst_md.dims[concat_dim] = concat_dim_sz; + dummy_dst_md.format_kind = format_kind::any; + dst_md = &dummy_dst_md; + } + + auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd); + + for (auto c = engine->get_concat_implementation_list(); *c; ++c) { + if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds) + == success) { + (*c_pd)->init_info(); + (*c_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp new file mode 100644 index 0000000000..29311927e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp @@ -0,0 +1,211 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CONCAT_PD_HPP +#define CONCAT_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct concat_pd_t: public primitive_desc_t { + concat_pd_t(engine_t *engine, const primitive_attr_t *attr, + const memory_desc_t *dst_md, int n, int concat_dim, + const memory_desc_t *src_mds) + : primitive_desc_t(engine, attr, primitive_kind::concat) + , n_(n), concat_dim_(concat_dim), dst_md_(*dst_md) + { + src_mds_.reserve(n_); + for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); + } + + concat_pd_t(const concat_pd_t &rhs) = default; + + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg >= MKLDNN_ARG_MULTIPLE_SRC + && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index < n_inputs() ? &src_mds_[index] : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return n_; } + virtual int n_outputs() const override { return 1; } + + int concat_dim() const { return concat_dim_; } + + const memory_desc_t *src_image_md(int index = 0) const + { return index < n_inputs() ? &src_image_mds_[index] : nullptr; } + +protected: + int n_, concat_dim_; + memory_desc_t dst_md_; + nstl::vector<memory_desc_t> src_mds_; + + /* contains images of srcs in the dst memory (if possible) + * Lives here to simplify some implementations. An implementation might + * use this auxiliary array iff init() returned success */ + nstl::vector<memory_desc_t> src_image_mds_; + +protected: + /* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */ + status_t init() { + bool ok = true + && set_default_params() == status::success + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + if (!i_d.is_blocking_desc() || i_d.is_additional_buffer()) + return status::unimplemented; + } + + const int ndims = dst_md_.ndims; + int current_concat_dim_offset = 0; + for (int i = 0; i < n_; ++i) { + const int dim = src_mds_[i].dims[concat_dim_]; + dims_t dims, offsets = {}; + utils::array_copy(dims, dst_md_.dims, ndims); + dims[concat_dim_] = dim; + offsets[concat_dim_] = current_concat_dim_offset; + + memory_desc_t src_img_d; + status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, + &dst_md_, dims, offsets); + if (status != status::success) return status; + src_image_mds_.push_back(src_img_d); + current_concat_dim_offset += dim; + } + + return status::success; + } + + status_t set_default_params() { + if (dst_md_.format_kind != format_kind::any) + return status::success; + + const int ndims = dst_md_.ndims; + + /* The stupidest ever heuristics (but not the same as we had before): + * - Pick the first non-plain format; + * - If all formats are plain or it is not possible to create a + * blocked format for the output, pick the format of the plain input + * - If this fails as well, use plain layout (abcd...) + */ + status_t status = status::unimplemented; + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (src_d.is_blocking_desc() && !src_d.is_plain()) { + status = memory_desc_init_by_blocking_desc(dst_md_, + src_d.blocking_desc()); + if (status == status::success) break; + } + } + + if (status == status::success) { + /* check if we can create a sub-memory for the dst */ + bool desired_format_ok = true; + int current_concat_dim_offset = 0; + for (int i = 0; i < n_; ++i) { + const int dim = src_mds_[i].dims[concat_dim_]; + dims_t dims, offsets = {}; + utils::array_copy(dims, dst_md_.dims, ndims); + dims[concat_dim_] = dim; + offsets[concat_dim_] = current_concat_dim_offset; + + memory_desc_t src_img_d; + status_t status = mkldnn_memory_desc_init_submemory(&src_img_d, + &dst_md_, dims, offsets); + if (status != status::success) { + desired_format_ok = false; + break; + } + current_concat_dim_offset += dim; + } + + if (!desired_format_ok) + status = status::unimplemented; + } + + /* if no success so far, try using the format of the first plain input */ + if (status != status::success) { + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (src_d.is_blocking_desc() && src_d.is_plain()) { + status = memory_desc_init_by_blocking_desc(dst_md_, + memory_desc_wrapper(src_mds_[0]).blocking_desc()); + if (status == status::success) return status; + } + } + } + + /* the last line of defense: use plain abcd... format */ + if (status != status::success) + status = memory_desc_init_by_strides(dst_md_, nullptr); + + return status; + } +}; + +#define DECLARE_CONCAT_PD_t(impl_name, ...) \ + static status_t create(concat_pd_t **concat_pd, \ + engine_t *engine, const primitive_attr_t *attr, \ + const memory_desc_t *dst_md, int n, int concat_dim, \ + const memory_desc_t *src_mds) { \ + using namespace status; \ + auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \ + if (_pd == nullptr) return out_of_memory; \ + if (_pd->init() != success) { delete _pd; return unimplemented; } \ + return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd); \ + } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual const char *name() const override { return impl_name; } \ + +#define DECLARE_CONCAT_PD_T(impl_name, ...) \ + DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__) + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp new file mode 100644 index 0000000000..0c5c02bcd1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp @@ -0,0 +1,200 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace mkldnn { +namespace impl { +status_t conv_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides, + padding_l) + && one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) return invalid_arguments; + + if (padding_r == nullptr) padding_r = padding_l; + + auto cd = convolution_desc_t(); + cd.primitive_kind = primitive_kind::convolution; + cd.prop_kind = prop_kind; + cd.alg_kind = alg_kind; + + cd.diff_src_desc = cd.src_desc = zero_md(); + cd.diff_dst_desc = cd.dst_desc = zero_md(); + cd.diff_weights_desc = cd.weights_desc = zero_md(); + cd.diff_bias_desc = cd.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias = + bias_desc && bias_desc->format_kind != format_kind::undef; + const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; + + (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc; + (is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) = + *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) = + *bias_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(cd.strides, strides, sp_dims); + utils::array_copy(cd.padding[0], padding_l, sp_dims); + utils::array_copy(cd.padding[1], padding_r, sp_dims); + if (dilates) + utils::array_copy(cd.dilates, dilates, sp_dims); + else + utils::array_set(cd.dilates, 0, sp_dims); + + cd.padding_kind = padding_kind; + cd.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + const int g = with_groups ? weights_desc->dims[0] : 1; + const int bias_dim = prop_kind == backward_data + ? src_desc->dims[1] + : dst_desc->dims[1]; + + bool consistency = true + && memory_desc_wrapper(weights_desc).nelems() + && src_desc->ndims == dst_desc->ndims + && utils::one_of(src_desc->ndims, 3, 4, 5) + && utils::one_of(weights_desc->ndims, src_desc->ndims, + src_desc->ndims + 1) + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == bias_dim : true) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] + && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; + for (int i = 2; i < src_desc->ndims; ++i) + { + int src = src_desc->dims[i]; + int ker = weights_desc->dims[with_groups + i]; + int dil = cd.dilates[i - 2]; + int pad_l = padding_l[i - 2]; + int pad_r = padding_r[i - 2]; + int str = strides[i - 2]; + int dst = dst_desc->dims[i]; + int ker_range = 1 + (ker - 1) * (dil + 1); + + if (str < 1) return invalid_arguments; + consistency = consistency + && dil >= 0 + && pad_l >= 0 + && pad_r + str > 0 + && (src - ker_range + pad_l + pad_r) / str + 1 == dst; + } + if (!consistency) return invalid_arguments; + + *conv_desc = cd; + return success; +} +} +} + +status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_forward_desc_init( + convolution_desc_t *conv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_convolution_backward_data_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_backward_data_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_convolution_backward_weights_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, + nullptr, padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_convolution_backward_weights_desc_init( + convolution_desc_t *conv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, + dilates, padding_l, padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp new file mode 100644 index 0000000000..9604e0acf5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "utils.hpp" + +#include "convolution_pd.hpp" + +namespace mkldnn { +namespace impl { + +using namespace prop_kind; + +memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_data + ? &desc->diff_src_desc : &desc->src_desc; +} + +memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_weights_desc : &desc->weights_desc; +} + +memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_bias_desc : &desc->bias_desc; +} + +memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) { + return utils::one_of(desc->prop_kind, forward_inference, forward_training) + ? &desc->dst_desc : &desc->diff_dst_desc; +} + +const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc)); } +const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc)); } +const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc)); } +const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc) +{ return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc)); } + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp new file mode 100644 index 0000000000..b10c36db49 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp @@ -0,0 +1,348 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CONVOLUTION_PD_HPP +#define CONVOLUTION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +status_t conv_desc_init(convolution_desc_t *conv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind); + +memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc); +memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc); +const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc); + +struct convolution_fwd_pd_t; + +struct convolution_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::convolution; + + convolution_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const convolution_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case pkind_traits<base_pkind>::query_d: + *(const convolution_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common conv aux functions */ + + dim_t MB() const { return _src_md()->dims[0]; } + + dim_t IC() const { return _src_md()->dims[1]; } + dim_t OC() const { return _dst_md()->dims[1]; } + dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; } + + dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; } + dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; } + dim_t IW() const { return _src_md()->dims[ndims() - 1]; } + + dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; } + dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; } + dim_t OW() const { return _dst_md()->dims[ndims() - 1]; } + + dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; } + dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; } + dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } + dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } + dim_t KDW() const { return desc_.dilates[ndims() - 3]; } + + dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + int ndims() const { return _src_md()->ndims; } + + bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); } + bool with_groups() const { return _wei_md()->ndims == ndims() + 1; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*_src_md()); + const auto d_d = memory_desc_wrapper(*_dst_md()); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + +protected: + convolution_desc_t desc_; + const convolution_fwd_pd_t *hint_fwd_pd_; + + bool set_default_formats_common_template( + memory_desc_t &src_md, format_tag_t src_tag, + memory_desc_t &wei_md, format_tag_t wei_tag, + memory_desc_t &dst_md, format_tag_t dst_tag, + memory_desc_t &bia_md) { + using namespace format_tag; + +# define IS_OK(f) \ + do { if ((f) != status::success) return false; } while(0) + if (src_md.format_kind == format_kind::any + && !utils::one_of(src_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(src_md, src_tag)); + if (dst_md.format_kind == format_kind::any + && !utils::one_of(dst_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(dst_md, dst_tag)); + if (wei_md.format_kind == format_kind::any + && !utils::one_of(wei_tag, any, undef)) + IS_OK(memory_desc_init_by_tag(wei_md, wei_tag)); + if (with_bias() && bia_md.format_kind == format_kind::any) + IS_OK(memory_desc_init_by_tag(bia_md, x)); +# undef IS_OK + + return true; + } + + bool set_default_alg_kind(alg_kind_t alg_kind) { + assert(utils::one_of(alg_kind, alg_kind::convolution_direct, + alg_kind::convolution_winograd)); + if (desc_.alg_kind == alg_kind::convolution_auto) + desc_.alg_kind = alg_kind; + return desc_.alg_kind == alg_kind; + } + + bool expect_data_types(data_type_t src_dt, data_type_t wei_dt, + data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const { + bool ok = true + && (src_dt == data_type::undef || _src_md()->data_type == src_dt) + && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt) + && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt) + && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt); + if (with_bias() && bia_dt != data_type::undef) + ok = ok && _bia_md()->data_type == bia_dt; + return ok; + } + +private: + const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); } + const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); } + const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); } + const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); } +}; + +struct convolution_fwd_pd_t: public convolution_pd_t { + typedef convolution_fwd_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_fwd_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; + + bool set_default_formats_common(format_tag_t src_tag, + format_tag_t wei_tag, format_tag_t dst_tag) { + return set_default_formats_common_template(src_md_, src_tag, + weights_md_, wei_tag, dst_md_, dst_tag, bias_md_); + } +}; + +struct convolution_bwd_data_pd_t: public convolution_pd_t { + typedef convolution_bwd_data_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_bwd_data_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + + virtual bool support_bias() const { return false; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t diff_dst_md_; + + bool set_default_formats_common(format_tag_t diff_src_tag, + format_tag_t wei_tag, format_tag_t diff_dst_tag) { + return set_default_formats_common_template(diff_src_md_, diff_src_tag, + weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_); + } +}; + +struct convolution_bwd_weights_pd_t: public convolution_pd_t { + typedef convolution_bwd_weights_pd_t base_class; + typedef convolution_fwd_pd_t hint_class; + + convolution_bwd_weights_pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : convolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; + + bool set_default_formats_common(format_tag_t src_tag, + format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) { + return set_default_formats_common_template(src_md_, src_tag, + diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag, + diff_bias_md_); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp new file mode 100644 index 0000000000..98063c1c37 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp @@ -0,0 +1,188 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t deconv_desc_init(deconvolution_desc_t *deconv_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides, + padding_l) + && one_of(alg_kind, deconvolution_direct, deconvolution_winograd) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) + return invalid_arguments; + + if (padding_r == nullptr) + padding_r = padding_l; + + auto dd = deconvolution_desc_t(); + dd.primitive_kind = primitive_kind::deconvolution; + dd.prop_kind = prop_kind; + dd.alg_kind = alg_kind; + + dd.diff_src_desc = dd.src_desc = zero_md(); + dd.diff_dst_desc = dd.dst_desc = zero_md(); + dd.diff_weights_desc = dd.weights_desc = zero_md(); + dd.diff_bias_desc = dd.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias + = bias_desc && bias_desc->format_kind != format_kind::undef; + const bool with_groups = weights_desc->ndims == src_desc->ndims + 1; + + (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc; + (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc) + = *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc) + = *bias_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(dd.strides, strides, sp_dims); + utils::array_copy(dd.padding[0], padding_l, sp_dims); + utils::array_copy(dd.padding[1], padding_r, sp_dims); + if (dilates) + utils::array_copy(dd.dilates, dilates, sp_dims); + else + utils::array_set(dd.dilates, 0, sp_dims); + + dd.padding_kind = padding_kind; + dd.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + const int g = with_groups ? weights_desc->dims[0] : 1; + bool consistency = true + && src_desc->ndims == dst_desc->ndims + && utils::one_of(src_desc->ndims, 3, 4, 5) + && utils::one_of(weights_desc->ndims, src_desc->ndims, + src_desc->ndims + 1) + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1] + && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0]; + for (int i = 2; i < src_desc->ndims; ++i) { + int src = src_desc->dims[i]; + int ker = weights_desc->dims[with_groups + i]; + int dil = dd.dilates[i - 2]; + int pad = padding_l[i - 2] + padding_r[i - 2]; + int str = strides[i - 2]; + int dst = dst_desc->dims[i]; + int ker_range = 1 + (ker - 1) * (dil + 1); + + consistency + = consistency && (dst - ker_range + pad) / str + 1 == src; + } + if (!consistency) + return invalid_arguments; + + *deconv_desc = dd; + return success; +} +} + +status_t mkldnn_deconvolution_forward_desc_init( + deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_forward_desc_init( + deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc, + weights_desc, bias_desc, dst_desc, strides, dilates, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_deconvolution_backward_data_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_backward_data_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t dilates, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc, + weights_desc, nullptr, diff_dst_desc, strides,dilates, padding_l, + padding_r, padding_kind); +} + +status_t mkldnn_deconvolution_backward_weights_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, + const dims_t strides, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr, + padding_l, padding_r, padding_kind); +} + +status_t mkldnn_dilated_deconvolution_backward_weights_desc_init( + deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc, + const dims_t strides, const dims_t dilates, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc, + diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates, + padding_l, padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp new file mode 100644 index 0000000000..539e44bd9b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp @@ -0,0 +1,293 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DECONVOLUTION_PD_HPP +#define DECONVOLUTION_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "convolution_pd.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct deconvolution_fwd_pd_t; + +struct deconvolution_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::deconvolution; + + deconvolution_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const deconvolution_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case pkind_traits<base_pkind>::query_d: + *(const deconvolution_desc_t **)result = desc(); + break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */ + + dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; } + + dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; } + dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; } + dim_t G() const + { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; } + + dim_t ID() const { + return ndims() >= 5 + ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t IH() const { + return ndims() >= 4 + ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t IW() const { + return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1]; + } + + dim_t OD() const { + return ndims() >= 5 + ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t OH() const { + return ndims() >= 4 + ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t OW() const { + return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1]; + } + + dim_t KD() const { + const int w_ndims = ndims() + with_groups(); + return ndims() >= 5 + ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1; + } + dim_t KH() const { + const int w_ndims = ndims() + with_groups(); + return ndims() >= 4 + ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1; + } + dim_t KW() const { + const int w_ndims = ndims() + with_groups(); + return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1]; + } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; } + dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; } + dim_t KDW() const { return desc_.dilates[ndims() - 3]; } + + dim_t padFront() const + { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const + { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const + { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const + { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + bool with_bias() const { + return + !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero(); + } + + bool with_groups() const + { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; } + + int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_)); + const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_)); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + +protected: + deconvolution_desc_t desc_; + const deconvolution_fwd_pd_t *hint_fwd_pd_; +}; + +struct deconvolution_fwd_pd_t: public deconvolution_pd_t { + typedef deconvolution_fwd_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_fwd_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; +}; + +struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t { + typedef deconvolution_bwd_data_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_bwd_data_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &weights_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t diff_dst_md_; +}; + +struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t { + typedef deconvolution_bwd_weights_pd_t base_class; + typedef deconvolution_fwd_pd_t hint_class; + + deconvolution_bwd_weights_pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp new file mode 100644 index 0000000000..f1708fca52 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp @@ -0,0 +1,84 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind, + alg_kind_t alg_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, float alpha, float beta) { + bool args_ok = true + && !any_null(eltwise_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward_data) + && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic) + && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto ed = eltwise_desc_t(); + ed.primitive_kind = primitive_kind::eltwise; + ed.prop_kind = prop_kind; + ed.alg_kind = alg_kind; + + ed.data_desc = *data_desc; + ed.diff_data_desc = + (ed.prop_kind == backward_data) ? *diff_data_desc : zero_md(); + + ed.alpha = alpha; + ed.beta = beta; + + bool consistency = true + && IMPLICATION(ed.prop_kind == backward_data, + array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims, + ed.diff_data_desc.ndims)); + if (!consistency) return invalid_arguments; + + *eltwise_desc = ed; + return success; +} +} + +status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, float alpha, float beta) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc, + nullptr, alpha, beta); +} + +status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc, + alg_kind_t alg_kind, const memory_desc_t *diff_data_desc, + const memory_desc_t *data_desc, float alpha, float beta) { + return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc, + diff_data_desc, alpha, beta); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp new file mode 100644 index 0000000000..9fd260fcee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp @@ -0,0 +1,161 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ELTWISE_PD_HPP +#define ELTWISE_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct eltwise_fwd_pd_t; + +struct eltwise_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::eltwise; + + eltwise_pd_t(mkldnn::impl::engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const eltwise_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::eltwise_d: + *(const eltwise_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common eltwise aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + +protected: + eltwise_desc_t desc_; + const eltwise_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct eltwise_fwd_pd_t: public eltwise_pd_t { + typedef eltwise_fwd_pd_t base_class; + typedef eltwise_fwd_pd_t hint_class; + + eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + bool is_zero_preserved() const + { return math::eltwise_fwd_preserves_zero(desc_.alg_kind); } +}; + +struct eltwise_bwd_pd_t: public eltwise_pd_t { + typedef eltwise_bwd_pd_t base_class; + typedef eltwise_fwd_pd_t hint_class; + + eltwise_bwd_pd_t(engine_t *engine, + const eltwise_desc_t *adesc, + const primitive_attr_t *attr, + const eltwise_fwd_pd_t *hint_fwd_pd) + : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + + bool is_zero_preserved() const { return true; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp new file mode 100644 index 0000000000..3b3e25456d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp @@ -0,0 +1,75 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" +#include "engine.hpp" +#include "nstl.hpp" + +#include "c_types_map.hpp" +#include "../cpu/cpu_engine.hpp" + +namespace mkldnn { +namespace impl { + +engine_factory_t *engine_factories[] = { + &cpu::engine_factory, + nullptr, +}; + +static inline engine_factory_t *get_engine_factory(engine_kind_t kind) { + for (engine_factory_t **ef = engine_factories; *ef; ef++) + if ((*ef)->kind() == kind) + return *ef; + return nullptr; +} + +} +} + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +size_t mkldnn_engine_get_count(engine_kind_t kind) { + engine_factory_t *ef = get_engine_factory(kind); + return ef != nullptr ? ef->count() : 0; +} + +status_t mkldnn_engine_create(engine_t **engine, + engine_kind_t kind, size_t index) { + if (engine == nullptr) + return invalid_arguments; + + engine_factory_t *ef = get_engine_factory(kind); + if (ef == nullptr || index >= ef->count()) + return invalid_arguments; + + return ef->engine_create(engine, index); +} + +status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) { + if (engine == nullptr) + return invalid_arguments; + *kind = engine->kind(); + return success; +} + +status_t mkldnn_engine_destroy(engine_t *engine) { + /* TODO: engine->dec_ref_count(); */ + delete engine; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp new file mode 100644 index 0000000000..8ac8a29de5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp @@ -0,0 +1,119 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef ENGINE_HPP +#define ENGINE_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive.hpp" +#include "utils.hpp" + +/** \brief An abstraction of an execution unit with shared resources + * + * Responsibilities: + * - Provide engine specific memory allocation + * - Provide engine specific primitive_desc_t creators + */ +struct mkldnn_engine: public mkldnn::impl::c_compatible { + mkldnn_engine(mkldnn::impl::engine_kind_t kind) + : kind_(kind) + {} + virtual ~mkldnn_engine() {} + + /** get kind of the current engine */ + virtual mkldnn::impl::engine_kind_t kind() const { return kind_; } + + /** allocate memory */ + virtual mkldnn::impl::status_t memory_create( + mkldnn::impl::memory_t **memory, + const mkldnn::impl::memory_desc_t *md, + void *handle) = 0; + + /** implementation section (typedefs) */ + + // TODO: remove engine? + typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)( + mkldnn::impl::reorder_pd_t **reorder_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *src_engine, + const mkldnn::impl::memory_desc_t *src_md, + mkldnn::impl::engine_t *dst_engine, + const mkldnn::impl::memory_desc_t *dst_md); + + typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)( + mkldnn::impl::concat_pd_t **concat_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + const mkldnn::impl::memory_desc_t *dst_md, + int n, int concat_dim, + const mkldnn::impl::memory_desc_t *src_mds); + + typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)( + mkldnn::impl::sum_pd_t **sum_pd, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + const mkldnn::impl::memory_desc_t *dst_md, + int n, const float *scales, + const mkldnn::impl::memory_desc_t *src_mds); + + typedef mkldnn::impl::status_t (*primitive_desc_create_f)( + mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *); + + /* implementation section */ + + /** return the list of reorder implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const reorder_primitive_desc_create_f* + get_reorder_implementation_list() const = 0; + + /** return the list of concat implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const concat_primitive_desc_create_f* + get_concat_implementation_list() const = 0; + + /** return the list of sum implementations. engine guarantees to return + * a NULL-terminated list */ + virtual const sum_primitive_desc_create_f* + get_sum_implementation_list() const = 0; + + /** return the list of implementations. engine guarantees to return a + * NULL-terminated list */ + virtual const primitive_desc_create_f* get_implementation_list() const = 0; + +protected: + mkldnn::impl::engine_kind_t kind_; +}; + +namespace mkldnn { +namespace impl { + +struct engine_factory_t: public c_compatible { + virtual size_t count() const = 0; + virtual engine_kind_t kind() const = 0; + virtual status_t engine_create(engine_t **engine, size_t index) const = 0; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp new file mode 100644 index 0000000000..5a9f58cb1e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp @@ -0,0 +1,106 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind, + const memory_desc_t *src_desc, const memory_desc_t *weights_desc, + const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) { + bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc); + if (!args_ok) return invalid_arguments; + + auto id = inner_product_desc_t(); + id.primitive_kind = primitive_kind::inner_product; + id.prop_kind = prop_kind; + + id.diff_src_desc = id.src_desc = zero_md(); + id.diff_dst_desc = id.dst_desc = zero_md(); + id.diff_weights_desc = id.weights_desc = zero_md(); + id.diff_bias_desc = id.bias_desc = zero_md(); + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + const bool with_bias = + bias_desc && bias_desc->format_kind != format_kind::undef; + + (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc; + (is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc; + (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) = + *weights_desc; + if (with_bias) + (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) = + *bias_desc; + + id.accum_data_type = types::default_accum_data_type(src_desc->data_type, + weights_desc->data_type, dst_desc->data_type, prop_kind); + + bool consistency = true + && memory_desc_wrapper(weights_desc).nelems() + && one_of(src_desc->ndims, 2, 3, 4, 5) + && dst_desc->ndims == 2 + && weights_desc->ndims == src_desc->ndims + && (with_bias ? bias_desc->ndims == 1 : true) + && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true) + && src_desc->dims[0] == dst_desc->dims[0] + && array_cmp(&src_desc->dims[1], &weights_desc->dims[1], + src_desc->ndims - 1) + && dst_desc->dims[1] == weights_desc->dims[0]; + if (!consistency) return invalid_arguments; + + *ip_desc = id; + return success; +} +} + +status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc, + prop_kind_t prop_kind, const memory_desc_t *src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_desc) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc, + dst_desc); +} + +status_t mkldnn_inner_product_backward_data_desc_init( + inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc, + const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc) +{ + return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc, + nullptr, diff_dst_desc); +} + +status_t mkldnn_inner_product_backward_weights_desc_init( + inner_product_desc_t *ip_desc, const memory_desc_t *src_desc, + const memory_desc_t *diff_weights_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_desc) { + return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc, + diff_bias_desc, diff_dst_desc); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp new file mode 100644 index 0000000000..091cf0f5d6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "utils.hpp" + +#include "inner_product_pd.hpp" + +namespace mkldnn { +namespace impl { + +using namespace prop_kind; + +memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_data + ? &desc->diff_src_desc : &desc->src_desc; +} + +memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_weights_desc : &desc->weights_desc; +} + +memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) { + return desc->prop_kind == backward_weights + ? &desc->diff_bias_desc : &desc->bias_desc; +} + +memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) { + return utils::one_of(desc->prop_kind, forward_inference, forward_training) + ? &desc->dst_desc : &desc->diff_dst_desc; +} + +const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); } +const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); } +const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); } +const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc) +{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); } + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp new file mode 100644 index 0000000000..c426de632c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp @@ -0,0 +1,321 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef INNER_PRODUCT_PD_HPP +#define INNER_PRODUCT_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc); +memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc); +const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc); + +struct inner_product_fwd_pd_t; + +struct inner_product_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::inner_product; + + inner_product_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + {} + + const inner_product_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::inner_product_d: + *(const inner_product_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common inner_product aux functions */ + + dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; } + dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; } + dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; } + + dim_t ID() const { + return ndims() >= 5 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t IH() const { + return ndims() >= 4 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t IW() const { + return ndims() >= 3 + ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t OD() const { + return ndims() >= 5 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t OH() const { + return ndims() >= 4 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t OW() const { + return ndims() >= 3 + ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t KD() const { + return ndims() >= 5 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1; + } + dim_t KH() const { + return ndims() >= 4 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1; + } + dim_t KW() const { + return ndims() >= 3 + ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1; + } + + dim_t IC_total() const { + return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1], + ndims() - 1); + } + + dim_t IC_total_padded() const { + auto src_d = desc()->prop_kind == prop_kind::backward_data + ? memory_desc_wrapper(diff_src_md()) + : memory_desc_wrapper(src_md()); + assert(src_d.is_blocking_desc()); + if (!src_d.is_blocking_desc()) return -1; + return utils::array_product(src_d.padded_dims() + 1, ndims() - 1); + } + + int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; } + + bool with_bias() const + { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); } + + bool has_zero_dim_memory() const { + const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_)); + const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_)); + return s_d.has_zero_dim() || d_d.has_zero_dim(); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + inner_product_desc_t desc_; + const inner_product_fwd_pd_t *hint_fwd_pd_; + + status_t template_set_default_params(memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t *bias_md) { + using namespace format_tag; + if (src_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, + utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw))); + } + if (dst_md.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_md, nc)); + if (weights_md.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, + utils::pick(ndims() - 2, oi, oiw, oihw, oidhw))); + } + if (bias_md && bias_md->format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(*bias_md, x)); + return status::success; + } +}; + +struct inner_product_fwd_pd_t: public inner_product_pd_t { + typedef inner_product_fwd_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_fwd_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , weights_md_(desc_.weights_desc) + , bias_md_(desc_.bias_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_md_; + if (index == 1 && with_bias()) return &bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2 + with_bias(); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t src_md_; + memory_desc_t weights_md_; + memory_desc_t bias_md_; + memory_desc_t dst_md_; + + status_t set_default_params() { + return template_set_default_params(src_md_, weights_md_, dst_md_, + &bias_md_); + } +}; + +struct inner_product_bwd_data_pd_t: public inner_product_pd_t { + typedef inner_product_bwd_data_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_bwd_data_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , weights_md_(desc_.weights_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *weights_md(int index = 0) const override + { return index == 0 ? &weights_md_ : nullptr; } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t weights_md_; + memory_desc_t diff_dst_md_; + + status_t set_default_params() { + return template_set_default_params(diff_src_md_, weights_md_, + diff_dst_md_, nullptr); + } +}; + +struct inner_product_bwd_weights_pd_t: public inner_product_pd_t { + typedef inner_product_bwd_weights_pd_t base_class; + typedef inner_product_fwd_pd_t hint_class; + + inner_product_bwd_weights_pd_t(engine_t *engine, + const inner_product_desc_t *adesc, + const primitive_attr_t *attr, + const inner_product_fwd_pd_t *hint_fwd_pd) + : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , diff_weights_md_(desc_.diff_weights_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_WEIGHTS) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *diff_weights_md(int index = 0) const override { + if (index == 0) return &diff_weights_md_; + if (index == 1 && with_bias()) return &diff_bias_md_; + return nullptr; + } + + virtual int n_inputs() const override { return 2; } + virtual int n_outputs() const override { return 1 + with_bias(); } + +protected: + memory_desc_t src_md_; + memory_desc_t diff_weights_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_md_; + + status_t set_default_params() { + return template_set_default_params(src_md_, diff_weights_md_, + diff_dst_md_, &diff_bias_md_); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp new file mode 100644 index 0000000000..fcf18b556f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t lrn_desc_init(lrn_desc_t *lrn_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc, + dim_t local_size, float alpha, float beta, float k) { + bool args_ok = true + && !any_null(lrn_desc, data_desc) + && one_of(alg_kind, lrn_within_channel, lrn_across_channels) + && one_of(prop_kind, forward_training, forward_inference, backward_data) + && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr); + if (!args_ok) return invalid_arguments; + + auto ld = lrn_desc_t(); + ld.primitive_kind = primitive_kind::lrn; + ld.prop_kind = prop_kind; + ld.alg_kind = alg_kind; + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + + ld.data_desc = *data_desc; + if (!is_fwd) + ld.diff_data_desc = *diff_data_desc; + else + ld.diff_data_desc = zero_md(); + ld.local_size = local_size; + ld.lrn_alpha = alpha; + ld.lrn_beta = beta; + ld.lrn_k = k; + + bool consistency = true + && ld.data_desc.ndims == 4; + if (ld.prop_kind == backward_data) + consistency = consistency + && ld.diff_data_desc.ndims == 4 + && array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4); + if (!consistency) return invalid_arguments; + + *lrn_desc = ld; + return success; +} +} + +status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *data_desc, dim_t local_size, float alpha, + float beta, float k) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr, + local_size, alpha, beta, k); +} + +status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc, + alg_kind_t alg_kind, const memory_desc_t *data_desc, + const memory_desc_t *diff_data_desc, dim_t local_size, float alpha, + float beta, float k) { + return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc, + diff_data_desc, local_size, alpha, beta, k); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp new file mode 100644 index 0000000000..90886e9656 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef LRN_PD_HPP +#define LRN_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct lrn_fwd_pd_t; + +struct lrn_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::lrn; + + lrn_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + , ws_md_() + {} + + const lrn_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::lrn_d: + *(const lrn_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common lrn aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + lrn_desc_t desc_; + const lrn_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + memory_desc_t ws_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct lrn_fwd_pd_t: public lrn_pd_t { + typedef lrn_fwd_pd_t base_class; + typedef lrn_fwd_pd_t hint_class; + + lrn_fwd_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } +}; + +struct lrn_bwd_pd_t: public lrn_pd_t { + typedef lrn_bwd_pd_t base_class; + typedef lrn_fwd_pd_t hint_class; + + lrn_bwd_pd_t(engine_t *engine, + const lrn_desc_t *adesc, + const primitive_attr_t *attr, + const lrn_fwd_pd_t *hint_fwd_pd) + : lrn_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_data_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override + { return 2 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp new file mode 100644 index 0000000000..3fddc0bd45 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MATH_UTILS_HPP +#define MATH_UTILS_HPP + +#include <stdint.h> +#include <math.h> + +#include "utils.hpp" +#include "nstl.hpp" +#include "mkldnn_traits.hpp" + +#if defined(MKLDNN_X86_64) +#include "immintrin.h" +#endif + +namespace mkldnn { +namespace impl { +namespace math { + +/** rounds @p f to an integer according to the mxcsr register */ +inline int mxcsr_round(float f) { +#if defined(MKLDNN_X86_64) + return _mm_cvtss_si32(_mm_load_ss(&f)); +#else + return (int)nearbyintf(f); // optimism +#endif +} + +template <typename data_t, typename acc_t> +inline typename utils::enable_if<!nstl::is_integral<data_t>::value, + typename utils::remove_reference<data_t>::type>::type +saturate(const acc_t &x) { + return (typename utils::remove_reference<data_t>::type)x; +} + +template <typename data_t, typename acc_t> +inline typename utils::enable_if<nstl::is_integral<data_t>::value, + typename utils::remove_reference<data_t>::type>::type +saturate(const acc_t &x) { + acc_t v = x; + if (v < (acc_t)nstl::numeric_limits<data_t>::lowest()) + v = (acc_t)nstl::numeric_limits<data_t>::lowest(); + if (v > (acc_t)nstl::numeric_limits<data_t>::max()) + v = (acc_t)nstl::numeric_limits<data_t>::max(); + return (typename utils::remove_reference<data_t>::type)v; +} + +template <typename data_t> +double saturate(const double &x) { + double v = x; + if (v < (double)nstl::numeric_limits<data_t>::lowest()) + v = (double)nstl::numeric_limits<data_t>::lowest(); + if (v > (double)nstl::numeric_limits<data_t>::max()) + v = (double)nstl::numeric_limits<data_t>::max(); + return v; +} + +template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) { + return x <= 127u ? x : 127; +} + +template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) { + return x >= 0 ? x : 0; +} + +template <typename out_t> +typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type +out_round(float v) { return (out_t)mxcsr_round(v); } + +template <typename out_t> +typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type +out_round(double v) { return (out_t)mxcsr_round((float)v); } + +template <typename out_t> +typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type +out_round(float v) { return v; } + +inline int gcd(int a, int b) { + a = impl::nstl::abs(a); + b = impl::nstl::abs(b); + if (a < b) { int x = a; a = b; b = x; } + + if (b == 0) return a; + + int r; + while ((r = a % b) != 0) { a = b; b = r; } + + return b; +} + +template <typename T> +inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; } + +/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */ +inline int ilog2q(size_t v) { + if (v == 0) + return -1; + + int p = 0; +# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0) + CP(32); CP(16); CP(8); CP(4); CP(2); CP(1); +# undef CP + return p; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U one_m_square(T x) { + return (U)(1 - x) * (1 + x); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U x_m_square(T x) { + return (U)(1 - x) * x; +} + +/* activation */ +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U relu_fwd(T s, A alpha) { + return s > 0 ? s : (U)(s * alpha); +} +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U relu_bwd(T dd, T s, A alpha) { + return s > 0 ? dd : (U)(dd * alpha); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U tanh_fwd(T s) { + const float e = tanhf((float) s); + return (U)e; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U tanh_bwd(T dd, T s) { + const float e = tanh_fwd<float>((float) s); + return (U)(dd * (1 - e) * (1 + e)); +} + +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U elu_fwd(T s, A alpha) { + return s > 0 ? s : (U)(alpha * (::expm1f((float)s))); +} +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> + inline U elu_bwd(T dd, T s, A alpha) { + return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s))); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U square_fwd(T s) { + return s * s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U square_bwd(T dd, T s) { + return dd * 2 * s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U abs_fwd(T s) { + return s > 0 ? s : -s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U abs_bwd(T dd, T s) { + return s > 0 ? dd : s < 0 ? -dd : 0; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U sqrt_fwd(T s) { + return s > 0 ? (U)(::sqrtf((float)(s))) : 0; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U sqrt_bwd(T dd, T s) { + return s > 0 + ? (U)(dd / (2 * ::sqrtf((float)(s)))) + : 0; +} + +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U linear_fwd(T s, A alpha, A beta) { + return (U)(alpha * s + beta); +} + +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U linear_bwd(T dd, T s, A alpha, A beta) { + (void) s; + (void) beta; + return (U)(dd * alpha); +} + +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U bounded_relu_fwd(T s, A alpha) { + s = s > 0 ? s : 0; + return s > alpha ? (U)(alpha) : s; +} + +template <typename T, typename A, + typename U = typename utils::remove_reference<T>::type> +inline U bounded_relu_bwd(T dd, T s, A alpha) { + return dd * (0 < s && s < alpha ? 1 : 0); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U soft_relu_fwd(T s) { + float max_logf = 8.872284e+01; //::logf(FLT_MAX) + return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s; +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U soft_relu_bwd(T dd, T s) { + return (U)(dd / (1 + ::expf((float)(-s)))); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U logistic_fwd(T s) { + U v = (U)(::expf((float) -s)); + return 1 / (1 + v); +} + +template <typename T, typename U = typename utils::remove_reference<T>::type> +inline U logistic_bwd(T dd, T s) { + U v = logistic_fwd<T, U>(s); + return dd * v * (1 - v); +} + +inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) { + using namespace alg_kind; + using namespace utils; + const bool preserves_zero = true + && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic) + && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh)); + return preserves_zero; +} + +inline float get_bias(const char *bias, size_t offset, data_type_t data_type) +{ + if (!bias) + return 0.0f; + +#define CASE(dt) \ + case dt: return (float)((const prec_traits<dt>::type *)bias)[offset] + + switch (data_type) { + CASE(data_type::s8); + CASE(data_type::u8); + CASE(data_type::s32); + CASE(data_type::f32); + default: assert(!"unimplemented"); + } + return 0; // never happens (should probably be a NaN) +#undef CASE +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp new file mode 100644 index 0000000000..cea849c96e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp @@ -0,0 +1,238 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::data_type; + +namespace { +bool memory_desc_sanity_check(int ndims,const dims_t dims, + data_type_t data_type, format_kind_t format_kind) { + if (ndims == 0) return true; + + bool ok = true + && dims != nullptr + && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS + && one_of(data_type, f32, s32, s8, u8) + && format_kind != format_kind::undef; + if (!ok) return false; + for (int d = 0; d < ndims; ++d) + if (dims[d] < 0) return false; + + return true; +} + +bool memory_desc_sanity_check(const memory_desc_t *md) { + if (md == nullptr) return false; + return memory_desc_sanity_check(md->ndims, md->dims, md->data_type, + format_kind::any); +} +} + +status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims, + const dims_t dims, data_type_t data_type, format_tag_t tag) { + if (any_null(memory_desc)) return invalid_arguments; + if (ndims == 0 || tag == format_tag::undef) { + *memory_desc = types::zero_md(); + return success; + } + + format_kind_t format_kind = types::format_tag_to_kind(tag); + + /* memory_desc != 0 */ + bool args_ok = !any_null(memory_desc) + && memory_desc_sanity_check(ndims, dims, data_type, format_kind); + if (!args_ok) return invalid_arguments; + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind; + + status_t status = success; + if (tag == format_tag::undef) { + status = invalid_arguments; + } else if (tag == format_tag::any) { + // nop + } else if (format_kind == format_kind::blocked) { + status = memory_desc_wrapper::compute_blocking(md, tag); + } else { + assert(!"unreachable"); + status = invalid_arguments; + } + + if (status == success) + *memory_desc = md; + + return status; +} + +status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc, + int ndims, const dims_t dims, data_type_t data_type, + const dims_t strides) { + if (any_null(memory_desc)) return invalid_arguments; + if (ndims == 0) { + *memory_desc = types::zero_md(); + return success; + } + + /* memory_desc != 0 */ + bool args_ok = !any_null(memory_desc) + && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any); + if (!args_ok) return invalid_arguments; + + auto md = memory_desc_t(); + md.ndims = ndims; + array_copy(md.dims, dims, ndims); + md.data_type = data_type; + array_copy(md.padded_dims, dims, ndims); + md.format_kind = format_kind::blocked; + + dims_t default_strides = {0}; + if (strides == nullptr) { + default_strides[md.ndims - 1] = 1; + for (int d = md.ndims - 2; d >= 0; --d) + default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1]; + strides = default_strides; + } else { + /* TODO: add sanity check for the provided strides */ + } + + array_copy(md.format_desc.blocking.strides, strides, md.ndims); + + *memory_desc = md; + + return status::success; +} + +status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md, + const memory_desc_t *parent_md, const dims_t dims, + const dims_t offsets) { + if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md)) + return invalid_arguments; + + const memory_desc_wrapper src_d(parent_md); + + for (int d = 0; d < src_d.ndims(); ++d) { + if (dims[d] < 0 || offsets[d] < 0 + || (offsets[d] + dims[d] > src_d.dims()[d])) + return invalid_arguments; + } + + if (src_d.format_kind() != format_kind::blocked) + return unimplemented; + + dims_t blocks; + src_d.compute_blocks(blocks); + + memory_desc_t dst_d = *parent_md; + auto &dst_d_blk = dst_d.format_desc.blocking; + + /* TODO: put this into memory_desc_wrapper */ + for (int d = 0; d < src_d.ndims(); ++d) { + /* very limited functionality for now */ + const bool ok = true + && offsets[d] % blocks[d] == 0 /* [r1] */ + && src_d.padded_offsets()[d] == 0 + && (false + || dims[d] % blocks[d] == 0 + || dims[d] < blocks[d]); + if (!ok) + return unimplemented; + + const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d]; + + dst_d.dims[d] = dims[d]; + dst_d.padded_dims[d] = is_right_border + ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d]; + dst_d.padded_offsets[d] = src_d.padded_offsets()[d]; + dst_d.offset0 += /* [r1] */ + offsets[d] / blocks[d] * dst_d_blk.strides[d]; + } + + *md = dst_d; + + return success; +} + +int mkldnn_memory_desc_equal(const memory_desc_t *lhs, + const memory_desc_t *rhs) { + if (lhs == rhs) return 1; + if (any_null(lhs, rhs)) return 0; + return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs); +} + +size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) { + if (md == nullptr) return 0; + return memory_desc_wrapper(*md).size(); +} + +status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md, + engine_t *engine, void *handle) { + if (any_null(memory, engine)) return invalid_arguments; + memory_desc_t z_md = types::zero_md(); + return engine->memory_create(memory, md ? md : &z_md, handle); +} + +status_t mkldnn_memory_get_memory_desc(const memory_t *memory, + const memory_desc_t **md) { + if (any_null(memory, md)) return invalid_arguments; + *md = memory->md(); + return success; +} + +status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) { + if (any_null(memory, engine)) return invalid_arguments; + *engine = memory->engine(); + return success; +} + +status_t mkldnn_memory_get_data_handle(const memory_t *memory, + void **handle) { + if (any_null(handle)) + return invalid_arguments; + if (memory == nullptr) { + *handle = nullptr; + return success; + } + return memory->get_data_handle(handle); +} + +status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) { + if (any_null(memory)) return invalid_arguments; + return memory->set_data_handle(handle); +} + +status_t mkldnn_memory_destroy(memory_t *memory) { + delete memory; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp new file mode 100644 index 0000000000..03dfee01ff --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp @@ -0,0 +1,63 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_HPP +#define MEMORY_HPP + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" + +struct mkldnn_memory: public mkldnn::impl::c_compatible { + mkldnn_memory(mkldnn::impl::engine_t *engine, + const mkldnn::impl::memory_desc_t *md) + : engine_(engine), md_(*md) {} + virtual ~mkldnn_memory() {} + + /** allocates/initializes memory */ + virtual mkldnn::impl::status_t init() = 0; + + /** returns memory's engine */ + mkldnn::impl::engine_t *engine() const { return engine_; } + /** returns memory's description */ + const mkldnn::impl::memory_desc_t *md() const { return &md_; } + + /** returns data handle */ + virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0; + + /** sets data handle */ + virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0; + + /** zeros padding */ + virtual mkldnn::impl::status_t zero_pad() const + { return mkldnn::impl::status::success; } + +protected: + mkldnn::impl::engine_t *engine_; + const mkldnn::impl::memory_desc_t md_; + +private: + mkldnn_memory() = delete; + mkldnn_memory(const mkldnn_memory &) = delete; + mkldnn_memory(mkldnn_memory &&) = delete; + mkldnn_memory &operator=(const mkldnn_memory &) = delete; + mkldnn_memory &operator=(mkldnn_memory &&) = delete; +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp new file mode 100644 index 0000000000..8a99be33f3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp @@ -0,0 +1,212 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include <initializer_list> + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +status_t fill_blocked(memory_desc_t &md, + std::initializer_list<int> perm, + std::initializer_list<int> inner_blks, + std::initializer_list<int> inner_idxs) { + const bool ok = true + && perm.size() == (size_t)md.ndims + && inner_blks.size() == inner_idxs.size(); + if (!ok) return status::invalid_arguments; + + md.offset0 = 0; + + blocking_desc_t &blk = md.format_desc.blocking; + + dim_t block_size = 1; + dims_t blocks = {0}; + utils::array_set(blocks, 1, md.ndims); + + blk.inner_nblks = (int)inner_blks.size(); + + int iblk = 0; + for (const auto &b: inner_idxs) + blk.inner_idxs[iblk++] = b; + + iblk = 0; + for (const auto &b: inner_blks) { + int dim = blk.inner_idxs[iblk]; + block_size *= b; + blocks[dim] *= b; + blk.inner_blks[iblk++] = b; + } + + utils::array_set(md.padded_offsets, 0, md.ndims); + for (int d = 0; d < md.ndims; ++d) + md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); + + dim_t stride = block_size; + // if only we use C++14, the initializer_list would have rbegin()/rend()... + for (int d = 0; d < md.ndims; ++d) + stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d]; + + for (const auto &d: perm) { + if (md.padded_dims[d] == 0) { + blk.strides[d] = 1; + continue; + } + stride /= md.padded_dims[d] / blocks[d]; + blk.strides[d] = stride; + } + + assert(stride == block_size); + + return status::success; +} + +status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc, + format_tag_t tag) +{ + using namespace format_tag; + + if (memory_desc.ndims == 0) return status::invalid_arguments; + +# define C(tag, ... /* perm, inner_blks, inner_idxs */) \ + case tag: return fill_blocked(memory_desc, __VA_ARGS__) + + switch (tag) { + C(a, {0}, {}, {}); + C(ab, {0, 1}, {}, {}); + C(abc, {0, 1, 2}, {}, {}); + C(abcd, {0, 1, 2, 3}, {}, {}); + C(abcde, {0, 1, 2, 3, 4}, {}, {}); + C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {}); + C(abdec, {0, 1, 3, 4, 2}, {}, {}); + C(acb, {0, 2, 1}, {}, {}); + C(acbde, {0, 2, 1, 3, 4}, {}, {}); + C(acdb, {0, 2, 3, 1}, {}, {}); + C(acdeb, {0, 2, 3, 4, 1}, {}, {}); + C(ba, {1, 0}, {}, {}); + C(bac, {1, 0, 2}, {}, {}); + C(bacd, {1, 0, 2, 3}, {}, {}); + C(bcda, {1, 2, 3, 0}, {}, {}); + C(cba, {2, 1, 0}, {}, {}); + C(cdba, {2, 3, 1, 0}, {}, {}); + C(cdeba, {2, 3, 4, 1, 0}, {}, {}); + C(decab, {3, 4, 2, 0, 1}, {}, {}); + + C(Abc4a, {0, 1, 2}, {4}, {0}); + C(aBc4b, {0, 1, 2}, {4}, {1}); + C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1}); + C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0}); + C(Abcd4a, {0, 1, 2, 3}, {4}, {0}); + C(aBcd4b, {0, 1, 2, 3}, {4}, {1}); + C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0}); + C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2}); + C(aBCd4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); + C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0}); + C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1}); + C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0}); + C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1}); + C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1}); + C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1}); + C(aBdc4b, {0, 1, 3, 2}, {4}, {1}); + C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1}); + C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1}); + C(Acb4a, {0, 2, 1}, {4}, {0}); + C(Acdb4a, {0, 2, 3, 1}, {4}, {0}); + C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0}); + + C(Abc16a, {0, 1, 2}, {16}, {0}); + C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1}); + C(aBc16b, {0, 1, 2}, {16}, {1}); + C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0}); + C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0}); + C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1}); + C(aBc8b, {0, 1, 2}, {8}, {1}); + C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1}); + C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0}); + C(Abcd16a, {0, 1, 2, 3}, {16}, {0}); + C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1}); + C(aBcd16b, {0, 1, 2, 3}, {16}, {1}); + C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0}); + C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2}); + C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1}); + C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1}); + C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0}); + C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1}); + C(aBcd8b, {0, 1, 2, 3}, {8}, {1}); + C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1}); + C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1}); + C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0}); + C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2}); + C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2}); + C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1}); + C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0}); + C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1}); + C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1}); + C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0}); + C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2}); + C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1}); + C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2}); + C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2}); + C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2}); + C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0}); + C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1}); + C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1}); + C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1}); + C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1}); + C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0}); + C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2}); + C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2}); + C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1}); + C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1}); + C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2}); + C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1}); + C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2}); + C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2}); + C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1}); + C(aBdc16b, {0, 1, 3, 2}, {16}, {1}); + C(aBdc8b, {0, 1, 3, 2}, {8}, {1}); + C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1}); + C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1}); + C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1}); + C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1}); + C(Acb16a, {0, 2, 1}, {16}, {0}); + C(Acb8a, {0, 2, 1}, {8}, {0}); + C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2}); + C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2}); + C(Acdb16a, {0, 2, 3, 1}, {16}, {0}); + C(Acdb8a, {0, 2, 3, 1}, {8}, {0}); + C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0}); + C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0}); + C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1}); + C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1}); + default: break; + } + +#undef C + + return status::invalid_arguments; +} + +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp new file mode 100644 index 0000000000..1758f9078a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp @@ -0,0 +1,400 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_DESC_WRAPPER_HPP +#define MEMORY_DESC_WRAPPER_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +/** thin wrapper class over \struct memory_desc_t which allows easy + * manipulations with underlying C structure, which is taken by reference */ +struct memory_desc_wrapper: public c_compatible { + const memory_desc_t *md_; + + /** constructor which takes a reference to a constant underlying C memory + * descriptor \param md */ + memory_desc_wrapper(const memory_desc_t *md): md_(md) {} + memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {} + + /* implementing attributes */ + int ndims() const { return md_->ndims; } + const dims_t &dims() const { return md_->dims; } + data_type_t data_type() const { return md_->data_type; } + + const dims_t &padded_dims() const { return md_->padded_dims; } + const dims_t &padded_offsets() const { return md_->padded_offsets; } + dim_t offset0() const { return md_->offset0; } + + format_kind_t format_kind() const { return md_->format_kind; } + + bool is_blocking_desc() const + { return format_kind() == format_kind::blocked; } + bool is_wino_desc() const + { return format_kind() == format_kind::wino; } + bool is_rnn_packed_desc() const + { return format_kind() == format_kind::rnn_packed; } + + const blocking_desc_t &blocking_desc() const { + assert(is_blocking_desc()); + return md_->format_desc.blocking; + } + const wino_desc_t &wino_desc() const { + assert(is_wino_desc()); + return md_->format_desc.wino_desc; + } + const rnn_packed_desc_t &rnn_packed_desc() const { + assert(is_rnn_packed_desc()); + return md_->format_desc.rnn_packed_desc; + } + + const memory_extra_desc_t &extra() const { return md_->extra; } + + /* some useful function */ + + /** returns the number of elements including padding if \param with_padding + * is true, and the number of data elements otherwise */ + dim_t nelems(bool with_padding = false) const { + if (is_zero()) return 0; + return utils::array_product( + with_padding ? padded_dims() : dims(), ndims()); + } + + /** returns true if memory descriptor is zero */ + bool is_zero() const { return ndims() == 0; } + + /** returns true if memory descriptor contains zero as one of its dim */ + bool has_zero_dim() const { return nelems() == 0; } + + /** return the size of data type (a shortcut) */ + size_t data_type_size() const + { return types::data_type_size(data_type()); } + + /** return the size of data type of additional buffer */ + size_t additional_buffer_data_size() const { + if (extra().flags & memory_extra_flags::compensation_conv_s8s8) + return sizeof(int32_t); + return 0; + } + + /** return true if memory format has additional buffer */ + bool is_additional_buffer() const { + return (extra().flags & memory_extra_flags::compensation_conv_s8s8); + } + + /** returns the size of additional buffer */ + size_t additional_buffer_size() const { + if (extra().flags & memory_extra_flags::compensation_conv_s8s8) { + int cmask = extra().compensation_mask; + assert(cmask == 1 || cmask == 3); + dim_t prod = 1; + for (int d = 0; d < ndims(); ++d) + if (cmask & (1<<d)) prod *= padded_dims()[d]; + return prod * additional_buffer_data_size(); + } + + return 0; + } + + /** returns the size required to store described memory + * note: if offset0 != 0 returns 0 (need to specify the behavior) */ + size_t size() const { + if (is_zero() || has_zero_dim() || format_kind() == format_kind::any) + return 0; + + if (format_kind() == format_kind::wino) { + return wino_desc().size; + } else if (format_kind() == format_kind::rnn_packed) { + return rnn_packed_desc().size; + } else { + if (offset0() != 0) return 0; + + dims_t blocks = {0}; + compute_blocks(blocks); + + const auto &bd = blocking_desc(); + + size_t max_size = 0; + for (int d = 0; d < ndims(); ++d) + max_size = nstl::max<size_t>(max_size, + padded_dims()[d] / blocks[d] * bd.strides[d]); + + if (max_size == 1 && bd.inner_nblks != 0) { + max_size = utils::array_product(bd.inner_blks, bd.inner_nblks); + } + + return max_size * data_type_size() + additional_buffer_size(); + } + } + + /** returns true if data is dense in memory */ + bool is_dense(bool with_padding = false) const { + if (utils::one_of(format_kind(), format_kind::undef, format_kind::any)) + return false; + return nelems(with_padding) * data_type_size() == size(); + } + + /** returns true if memory desc is fully defined */ + bool is_defined() const { return format_kind() != format_kind::any; } + + /** returns true if the only (potentially) padded dim is \param dim */ + bool only_padded_dim(int dim) const { + for (int d = 0; d < ndims(); ++d) + if (d != dim && dims()[d] != padded_dims()[d]) + return false; + return true; + } + + /** returns true if memory desc has blocked layout and block dims are 1s */ + bool is_plain() const { + if (!is_blocking_desc()) return false; + return blocking_desc().inner_nblks == 0; + } + + /** returns overall block sizes */ + void compute_blocks(dims_t blocks) const { + if (!is_blocking_desc()) { + utils::array_set(blocks, 0, ndims()); + return; + } + + utils::array_set(blocks, 1, ndims()); + + const auto &bd = blocking_desc(); + for (int iblk = 0; iblk < bd.inner_nblks; ++iblk) + blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk]; + } + + /* comparison section */ + + bool operator==(const memory_desc_wrapper &rhs) const + { return *this->md_ == *rhs.md_; } + bool operator!=(const memory_desc_wrapper &rhs) const + { return !operator==(rhs); } + bool operator==(const memory_desc_t &rhs) const + { return operator==(memory_desc_wrapper(rhs)); } + bool operator!=(const memory_desc_t &rhs) const + { return !operator==(rhs); } + + /** returns true if data (w/o padding if with_padding == false and w/ + * padding otherwise) have the same physical structure, i.e. dimensions, + * strides, and blocked structure. Depending on with_data_type flag + * data_type is taken or not taken into account. dim_start allows to check + * similarity for the logical part of data [dim_start .. ndims()]. + * CAUTION: format kind any and undef are not similar to whatever, hence the + * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */ + /* TODO: revise */ + bool similar_to(const memory_desc_wrapper &rhs, + bool with_padding = true, bool with_data_type = true, + int dim_start = 0) const; + + /** returns true if one memory can be reordered to another */ + bool consistent_with(const memory_desc_wrapper &rhs) const; + + /** returns true if the memory desc corresponds to the given format tag and + * strides. + * @sa memory_desc_matches_tag */ + bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const { + return memory_desc_matches_tag(*md_, tag, strides); + } + + /** returns matching tag (or undef if match is not found) + * XXX: This is a workaround that eventually should go away! */ + template <typename... Tags> + format_tag_t matches_one_of_tag(Tags ...tags) const { + for (const auto tag: {tags...}) { + if (memory_desc_matches_tag(*md_, tag)) + return tag; + } + return format_tag::undef; + } + + /* offset section */ + + /** returns physical offset by logical one. logical offset is represented by + * an array \param pos. if \param is_pos_padded is true \param pos + * represents the position in already padded area */ + dim_t off_v(const dims_t pos, bool is_pos_padded = false) const { + assert(is_blocking_desc()); + const blocking_desc_t &blk = blocking_desc(); + + dims_t pos_copy = {0}; + for (int d = 0; d < ndims(); ++d) + pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]); + + dim_t phys_offset = offset0(); + + if (blk.inner_nblks > 0) { + dim_t blk_stride = 1; + for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) { + const int d = blk.inner_idxs[iblk]; + const dim_t p = pos_copy[d] % blk.inner_blks[iblk]; + + phys_offset += p * blk_stride; + + pos_copy[d] /= blk.inner_blks[iblk]; + + blk_stride *= blk.inner_blks[iblk]; + } + } + + for (int d = 0; d < ndims(); ++d) { + const dim_t p = pos_copy[d]; + phys_offset += p * blk.strides[d]; + } + + return phys_offset; + } + + /** returns physical offset by logical one. logical offset is represented by + * a scalar \param l_offset. if \param is_pos_padded is true, \param + * l_offset represents logical offset in already padded area */ + dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const { + assert(is_blocking_desc()); + dims_t pos; + for (int rd = 0; rd < ndims(); ++rd) { + const int d = ndims() - 1 - rd; + const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d]; + pos[d] = l_offset % cur_dim; + l_offset /= cur_dim; + } + return off_v(pos, is_pos_padded); + } + + /** returns physical offset by logical one. logical offset is represented by + * a tuple of indices (\param xn, ..., \param x1, \param x0) */ + template<typename... Args> + dim_t off(Args... args) const { + assert(sizeof...(args) == ndims()); + dims_t pos = { args... }; + return off_v(pos, false); + } + + /** returns physical offset by logical one. logical offset is represented by + * a tuple of indices (\param xn, ..., \param x1, \param x0) in already + * padded area */ + template<typename... Args> + dim_t off_padding(Args... args) const { + assert(sizeof...(args) == ndims()); + dims_t pos = { args... }; + return off_v(pos, true); + } + + /** returns physical offset by logical one. Logical offset is represented by + * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a + * user responsibility to adjust the result to get offset within blocks */ + template<typename ...Args> + dim_t blk_off(Args... args) const { + return _blk_off<sizeof...(args), Args...>(args...); + } + + template<bool skip_first, typename T, typename ...Args> + dim_t blk_off(T xn, Args... args) const { + return skip_first + ? blk_off<Args...>(args...) + : blk_off<T, Args...>(xn, args...); + } + + /* static functions section */ + /* TODO: replace with non-static, once md_ becomes non-const ref */ + + static status_t compute_blocking(memory_desc_t &memory_desc, + format_tag_t tag); + +private: + /* TODO: put logical_offset in utils */ + template<typename T> + dim_t logical_offset(T x0) const { return x0; } + + template<typename T, typename... Args> + dim_t logical_offset(T xn, Args... args) const { + const size_t n_args = sizeof...(args); + return xn * utils::array_product<n_args>( + &dims()[ndims() - n_args]) + logical_offset(args...); + } + + template<int ORIG_LEN, typename ...Void> + dim_t _blk_off() const { return offset0(); } + + template<int ORIG_LEN, typename T, typename ...Args> + dim_t _blk_off(T xc, Args ...args) const { + assert(is_blocking_desc()); + constexpr int dc = ORIG_LEN - sizeof...(args) - 1; + return xc * blocking_desc().strides[dc] + + _blk_off<ORIG_LEN, Args...>(args...); + } +}; + +inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs, + bool with_padding, bool with_data_type, int dim_start) const { + using namespace utils; + + if (one_of(format_kind(), format_kind::undef, format_kind::any)) + return false; + if (is_wino_desc() || is_rnn_packed_desc()) + return false; + + const int ds = dim_start; + const auto &blk = blocking_desc(); + const auto &r_blk = rhs.blocking_desc(); + + return ndims() == rhs.ndims() + && dim_start <= ndims() /* guard */ + && format_kind() == rhs.format_kind() + && IMPLICATION(with_data_type, data_type() == rhs.data_type()) + && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds) + && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds) + && blk.inner_nblks == r_blk.inner_nblks + && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks) + && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks) + && IMPLICATION(with_padding, true + && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds, + ndims() - ds) + && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds, + ndims() - ds)); +} + +inline bool memory_desc_wrapper::consistent_with( + const memory_desc_wrapper &rhs) const { + if (ndims() == rhs.ndims()) { + for (int d = 0; d < ndims(); ++d) { + if (dims()[d] != rhs.dims()[d]) return false; + } + return true; + } else { + /* TODO: revise. + * is the following possible? + * [1, a, b] <--reorder--> [a, b] + * [a, 1, b] <--reorder--> [a, b] + * not, at least for now */ + return false; + } +} + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp new file mode 100644 index 0000000000..ec077b308c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp @@ -0,0 +1,295 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MEMORY_TRACKING_HPP +#define MEMORY_TRACKING_HPP + +#include <assert.h> +#include <unordered_map> + +#include "nstl.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace memory_tracking { + +/* Memory tracking capabilities + * + * The main purpose of this header file is to provide uniform way to register + * required memory for a scratchpad at a primitive descriptor creation time + * and then easily access it having only the base address of the scratchpad. + * + * Primitives might contain multiple disjoint parts that require temporary + * buffers (known as scratchpad) during their execution. A primitive descriptor + * should summarize all the needs into one single number -- the buffer size + * that would be requested from a user. At execution time, the corresponding + * primitive will receive a base pointer to a scratchpad. It then needs to + * provide each part of algorithm the corresponding piece of memory. Three main + * challenges here are: + * 1. Track correct offset (from the base scratchpad address) for each piece + * 2. Algorithm might require that different memory pieces to be aligned, so + * the scratchpad size is no more just a sum of size of the corresponding + * subparts. + * 3. While a primitive is responsible for its scratchpad, the implementation + * might use some other basic blocks (e.g. cpu_reducer) that also require + * scratchpad memory. So there should be a simple way of passing the + * information back and force between the main algorithm (a primitive) and + * auxiliary stuff that lives completely separately from it (e.g. reducer). + * + * To address these challenges this header file provides 3 structures: + * 1. registry_t -- the class the stores the information about requested + * memory. The information includes required size and desired + * alignment for each piece. This class is also responsible + * for computing the right offset to a given piece using the + * base pointer. + * This class is basically a ledger with all entries. + * Lives in primitive descriptors. + * + * 2. registrar_t -- the interface to a registry_t to book memory. Used at + * primitive descriptor creation time only. Contains a + * reference to the corresponding *mutable* registry. + * Always modifiable. + * Allows chaining (using prefixes). + * + * 3. grantor_t -- the interface to a registry_t to access memory. Used at + * primitive execution time only. Contains a reference to + * the corresponding *constant* registry and base pointer. + * Always constant. + * Allows chaining (using prefixes). + * + * Both registrar_t and grantor_t allow chaining with extra prefix provided. + * The feature is useful when a primitive offload a part of computations to + * some other primitives which require their own scratchpad space + * (e.g. reducer). Prefixes are used to avoid key collision in cases when + * multiple sub-primitive (e.g. multiple reducers) are used. + * + * A short example below demonstrates how to use aforementioned classes. In it + * the main primitive is convolution that uses scratchpad for keeping padded + * bias. It also needs a reducer, that needs its own space as well. + * + * ``` c++ + * struct reducer_t { + * static void init(registrar_t &scratchpad) { + * // preserve space for the reduction (one page aligned) + * scratchpad.book(key_space, sizeof(float) * 980 * 1024, 4096); + * } + * + * void exec(const grantor_t &scratchpad) { + * // get the pointer to preserved space. scratchpad came from + * // upper primitive (convolution in this example) + * auto space = scratchpad.get<float>(key_reducer_space); + * + * space[:] += ...; + * } + * }; + * + * struct conv_t { + * struct pd_t { + * void init() { + * registrar_t scratchpad(scratchpad_registry_); + * + * // preserve a space for padded bias (using default alignment) + * scratchpad.book(key_conv_padded_bias, 128); + * + * // create a proxy registrar for the reducer All entries made + * // by reducer would live in convolution's registry, but would + * // have their own `prefix`, so no interference with conv's + * // buffers. + * registrar_t reducer_scratchpad(scratchpad, prefix_reducer); + * + * reducer_t::init(reducer_scratchpad); + * } + * + * registry_t scratchpad_registry_; + * } + * + * void exec() { + * // get the base pointer to a scratchpad memory from a user + * void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD); + * + * // create a grantor to the scratchpad (and provide the base + * // pointer). + * grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr); + * + * // access the padded_bias (need only key name and the grantor) + * auto padded_bias = scratchpad.get<float>(key_conv_padded_bias); + * + * // to give the `right` grantor to reducer we need to add the + * // corresponding prefix, so that reducer would be able to access + * // its keys. The call is very similar to the one in pd_t::init + * // with only difference in types: grantor_t vs registrar_t. + * grantor_t reducer_scratchpad(scratchpad, prefix_reducer); + * reducer->exec(reducer_scratchpad); + * } + * }; + * ``` + */ + + +/* namespace with common keys and prefixes */ +namespace names { +enum { + key_none = 0, + key_bnorm_tmp_mean, + key_bnorm_tmp_var, + key_bnorm_tmp_diff_ss, + key_bnorm_tmp_stats, + key_bnorm_reduction, + key_concat_iptrs, + key_concat_istrides, + key_concat_nelems, + key_concat_optrs, + key_conv_adjusted_scales, + key_conv_bia_reduction, + key_conv_gemm_col, + key_conv_gemm_imtr, + key_conv_int_dat_in_acc_dt, + key_conv_padded_bias, + key_conv_rtus_space, + key_conv_tr_diff_dst, + key_conv_tr_diff_dst_bctx, + key_conv_tr_src, + key_conv_tr_src_bctx, + key_conv_wei_reduction, + key_conv_wei_bia_reduction, + key_conv_wei_bia_reduction_bctx, + key_iprod_int_dat_in_acc_dt, + key_reducer_space, + key_reducer_space_bctx, + key_reorder_wino_plain, + key_reorder_wino_transform_space, + key_reorder_rnn_weights_quantization, + key_reorder_rnn_weights_reduction, + key_rnn_space, + key_rnn_ptrs_bia, + key_rnn_ptrs_wei_layer, + key_rnn_ptrs_wei_iter, + key_softmax_reduction, + key_wino_U, + key_wino_V, + key_wino_M, + key_barrier, +}; + +enum { + prefix_none = 0, + prefix_reducer_bia, + prefix_reducer_wei, +}; +} + +// level 0: 00 00 00 xxx +// level 1: 00 00 aa xxx +// level 2: 00 aa bb xxx +// level 3: aa bb cc xxx +// max # of levels: 3 + 1 (base_level) +// here: +// xxx : [1 .. MAX_KEY) : key +// aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3 + +using key_t = uint32_t; +enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), }; + +/// generates global key based on a prefix and a local key +inline key_t make_key(key_t prefix, key_t key) { return prefix + key; } + +/// generates global prefix based on the global parent and the local ones +inline key_t make_prefix(key_t parent_prefix, key_t prefix) +{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; } + +struct registrar_t; +struct grantor_t; + +struct registry_t { + void book(const key_t &key, size_t size, size_t alignment) { + if (size == 0) return; + assert(offset_map_.count(key) == 0); + + size = utils::rnd_up(size, minimal_alignment); + alignment = nstl::max<size_t>(alignment, minimal_alignment); + offset_map_[key] = entry_t{size_, size, alignment}; + + size_ += size + alignment - minimal_alignment; + } + + void *get(const key_t &key, void *base_ptr) const { + if (base_ptr == nullptr) { assert(size() == 0); return nullptr; } + if (offset_map_.count(key) != 1) return nullptr; + + const auto &e = offset_map_.at(key); + base_ptr = utils::align_ptr<void>(base_ptr, minimal_alignment); + char *ptr = (char *)base_ptr + e.offset; + return utils::align_ptr<void>(ptr, e.alignment); + } + + size_t size() const + { return size_ > 0 ? size_ + minimal_alignment - 1 : 0; } + + registrar_t registrar(); + grantor_t grantor(void *base_ptr) const; + +protected: + enum { minimal_alignment = 64 }; + struct entry_t { size_t offset, size, alignment; }; + + std::unordered_map<key_t, entry_t> offset_map_; + size_t size_ = 0; +}; + +struct registrar_t { + enum { default_alignment = 64 }; + + registrar_t(registry_t ®istry): registry_(registry), prefix_(0) {} + registrar_t(registrar_t &parent, const key_t &prefix) + : registry_(parent.registry_) + , prefix_(make_prefix(parent.prefix_, prefix)) {} + + void book(const key_t &key, size_t size, + size_t alignment = default_alignment) + { registry_.book(make_key(prefix_, key), size, alignment); } + +protected: + registry_t ®istry_; + const key_t prefix_; +}; + +struct grantor_t { + grantor_t(const registry_t ®istry, void *base_ptr) + : registry_(registry), prefix_(0), base_ptr_(base_ptr) {} + grantor_t(const grantor_t &parent, const key_t &prefix) + : registry_(parent.registry_) + , prefix_(make_prefix(parent.prefix_, prefix)) + , base_ptr_(parent.base_ptr_) {} + + template <typename T = void> T *get(const key_t &key) const + { return (T *)registry_.get(make_key(prefix_, key), base_ptr_); } + +protected: + const registry_t ®istry_; + const key_t prefix_; + void *base_ptr_; +}; + +inline registrar_t registry_t::registrar() { return registrar_t(*this); } +inline grantor_t registry_t::grantor(void *base_ptr) const +{ return grantor_t(*this, base_ptr); } + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp new file mode 100644 index 0000000000..2ef4a8fddc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <stdio.h> +#include <cinttypes> + +#include "mkldnn_debug.h" +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#define DPRINT(...) do { \ + int l = snprintf(str + written_len, str_len, __VA_ARGS__); \ + if (l < 0) return l; \ + if ((size_t)l >= str_len) return -1; \ + written_len += l; str_len -= l; \ +} while(0) + +int mkldnn_md2fmt_str(char *str, size_t str_len, + const mkldnn_memory_desc_t *mdesc) { + using namespace mkldnn::impl; + + if (str == nullptr || str_len <= 1u) + return -1; + + int written_len = 0; + + if (mdesc == nullptr) { + DPRINT("%s::%s::", + mkldnn_dt2str(data_type::undef), + mkldnn_fmt_kind2str(format_kind::undef)); + return written_len; + } + + memory_desc_wrapper md(mdesc); + + DPRINT("%s:", mkldnn_dt2str(md.data_type())); + + bool padded_dims = false, padded_offsets = false; + for (int d = 0; d < md.ndims(); ++d) { + if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true; + if (md.padded_offsets()[d] != 0) padded_offsets = true; + } + bool offset0 = md.offset0(); + DPRINT("%s%s%s:", + padded_dims ? "p" : "", + padded_offsets ? "o" : "", + offset0 ? "0" : ""); + + DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind())); + + if (!md.is_blocking_desc()) { + /* TODO: extend */ + DPRINT("%s:", ""); + } else { + const auto &blk = md.blocking_desc(); + + dims_t blocks; + md.compute_blocks(blocks); + + char dim_chars[MKLDNN_MAX_NDIMS + 1]; + + bool plain = true; + for (int d = 0; d < md.ndims(); ++d) { + dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d; + if (blocks[d] != 1) plain = false; + } + + dims_t strides; + utils::array_copy(strides, blk.strides, md.ndims()); + utils::simultaneous_sort(strides, dim_chars, md.ndims(), + [](dim_t a, dim_t b) { return b - a; }); + + dim_chars[md.ndims()] = '\0'; + DPRINT("%s", dim_chars); + + if (!plain) { + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + DPRINT("%d%c", (int)blk.inner_blks[iblk], + 'a' + (char)blk.inner_idxs[iblk]); + } + } + + DPRINT("%s", ":"); + } + + DPRINT("f%lx", (long)md.extra().flags); + + return written_len; +} + +int mkldnn_md2dim_str(char *str, size_t str_len, + const mkldnn_memory_desc_t *mdesc) { + using namespace mkldnn::impl; + + if (str == nullptr || str_len <= 1) + return -1; + + int written_len = 0; + + if (mdesc == nullptr || mdesc->ndims == 0) { + DPRINT("%s", ""); + return written_len; + } + + memory_desc_wrapper md(mdesc); + + for (int d = 0; d < md.ndims() - 1; ++d) + DPRINT("%" PRId64 "x", md.dims()[d]); + DPRINT("%" PRId64, md.dims()[md.ndims() - 1]); + + return written_len; +} + +#undef DPRINT diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp new file mode 100644 index 0000000000..16a8f7ea5e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp @@ -0,0 +1,365 @@ +/******************************************************************************* +* Copyright 2018-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* DO NOT EDIT, AUTO-GENERATED */ + +#include <assert.h> + +#include "mkldnn_debug.h" +#include "mkldnn_types.h" + +const char *mkldnn_status2str(mkldnn_status_t v) { + if (v == mkldnn_success) return "success"; + if (v == mkldnn_out_of_memory) return "out_of_memory"; + if (v == mkldnn_try_again) return "try_again"; + if (v == mkldnn_invalid_arguments) return "invalid_arguments"; + if (v == mkldnn_not_ready) return "not_ready"; + if (v == mkldnn_unimplemented) return "unimplemented"; + if (v == mkldnn_iterator_ends) return "iterator_ends"; + if (v == mkldnn_runtime_error) return "runtime_error"; + if (v == mkldnn_not_required) return "not_required"; + assert(!"unknown status"); + return "unknown status"; +} + +const char *mkldnn_dt2str(mkldnn_data_type_t v) { + if (v == mkldnn_data_type_undef) return "undef"; + if (v == mkldnn_f32) return "f32"; + if (v == mkldnn_s32) return "s32"; + if (v == mkldnn_s8) return "s8"; + if (v == mkldnn_u8) return "u8"; + assert(!"unknown dt"); + return "unknown dt"; +} + +const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) { + if (v == mkldnn_format_kind_undef) return "undef"; + if (v == mkldnn_format_kind_any) return "any"; + if (v == mkldnn_blocked) return "blocked"; + if (v == mkldnn_format_kind_wino) return "wino"; + if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed"; + assert(!"unknown fmt_kind"); + return "unknown fmt_kind"; +} + +const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) { + if (v == mkldnn_format_tag_undef) return "undef"; + if (v == mkldnn_format_tag_any) return "format_tag_any"; + if (v == mkldnn_a) return "a"; + if (v == mkldnn_ab) return "ab"; + if (v == mkldnn_abc) return "abc"; + if (v == mkldnn_abcd) return "abcd"; + if (v == mkldnn_abcde) return "abcde"; + if (v == mkldnn_abcdef) return "abcdef"; + if (v == mkldnn_abdec) return "abdec"; + if (v == mkldnn_acb) return "acb"; + if (v == mkldnn_acbde) return "acbde"; + if (v == mkldnn_acdb) return "acdb"; + if (v == mkldnn_acdeb) return "acdeb"; + if (v == mkldnn_ba) return "ba"; + if (v == mkldnn_bac) return "bac"; + if (v == mkldnn_bacd) return "bacd"; + if (v == mkldnn_bcda) return "bcda"; + if (v == mkldnn_cba) return "cba"; + if (v == mkldnn_cdba) return "cdba"; + if (v == mkldnn_cdeba) return "cdeba"; + if (v == mkldnn_decab) return "decab"; + if (v == mkldnn_Abc16a) return "Abc16a"; + if (v == mkldnn_ABc16a16b) return "ABc16a16b"; + if (v == mkldnn_aBc16b) return "aBc16b"; + if (v == mkldnn_ABc16b16a) return "ABc16b16a"; + if (v == mkldnn_Abc4a) return "Abc4a"; + if (v == mkldnn_aBc4b) return "aBc4b"; + if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b"; + if (v == mkldnn_ABc4b4a) return "ABc4b4a"; + if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a"; + if (v == mkldnn_ABc8a8b) return "ABc8a8b"; + if (v == mkldnn_aBc8b) return "aBc8b"; + if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b"; + if (v == mkldnn_ABc8b8a) return "ABc8b8a"; + if (v == mkldnn_Abcd16a) return "Abcd16a"; + if (v == mkldnn_ABcd16a16b) return "ABcd16a16b"; + if (v == mkldnn_aBcd16b) return "aBcd16b"; + if (v == mkldnn_ABcd16b16a) return "ABcd16b16a"; + if (v == mkldnn_aBCd16b16c) return "aBCd16b16c"; + if (v == mkldnn_aBCd16c16b) return "aBCd16c16b"; + if (v == mkldnn_Abcd4a) return "Abcd4a"; + if (v == mkldnn_aBcd4b) return "aBcd4b"; + if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b"; + if (v == mkldnn_ABcd4b4a) return "ABcd4b4a"; + if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c"; + if (v == mkldnn_aBCd4c4b) return "aBCd4c4b"; + if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a"; + if (v == mkldnn_ABcd8a8b) return "ABcd8a8b"; + if (v == mkldnn_aBcd8b) return "aBcd8b"; + if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b"; + if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b"; + if (v == mkldnn_ABcd8b8a) return "ABcd8b8a"; + if (v == mkldnn_aBCd8b8c) return "aBCd8b8c"; + if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c"; + if (v == mkldnn_aBCd8c8b) return "aBCd8c8b"; + if (v == mkldnn_Abcde16a) return "Abcde16a"; + if (v == mkldnn_ABcde16a16b) return "ABcde16a16b"; + if (v == mkldnn_aBcde16b) return "aBcde16b"; + if (v == mkldnn_ABcde16b16a) return "ABcde16b16a"; + if (v == mkldnn_aBCde16b16c) return "aBCde16b16c"; + if (v == mkldnn_aBCde16c16b) return "aBCde16c16b"; + if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c"; + if (v == mkldnn_Abcde4a) return "Abcde4a"; + if (v == mkldnn_aBcde4b) return "aBcde4b"; + if (v == mkldnn_ABcde4b4a) return "ABcde4b4a"; + if (v == mkldnn_aBCde4b4c) return "aBCde4b4c"; + if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c"; + if (v == mkldnn_aBCde4c4b) return "aBCde4c4b"; + if (v == mkldnn_Abcde8a) return "Abcde8a"; + if (v == mkldnn_ABcde8a8b) return "ABcde8a8b"; + if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b"; + if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b"; + if (v == mkldnn_ABcde8b8a) return "ABcde8b8a"; + if (v == mkldnn_aBCde8b8c) return "aBCde8b8c"; + if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c"; + if (v == mkldnn_aBCde8c8b) return "aBCde8c8b"; + if (v == mkldnn_aBcdef16b) return "aBcdef16b"; + if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c"; + if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b"; + if (v == mkldnn_aBcdef4b) return "aBcdef4b"; + if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b"; + if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c"; + if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c"; + if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b"; + if (v == mkldnn_aBdc16b) return "aBdc16b"; + if (v == mkldnn_aBdc4b) return "aBdc4b"; + if (v == mkldnn_aBdc8b) return "aBdc8b"; + if (v == mkldnn_aBdec16b) return "aBdec16b"; + if (v == mkldnn_aBdec4b) return "aBdec4b"; + if (v == mkldnn_aBdec8b) return "aBdec8b"; + if (v == mkldnn_aBdefc16b) return "aBdefc16b"; + if (v == mkldnn_aBdefc4b) return "aBdefc4b"; + if (v == mkldnn_aBdefc8b) return "aBdefc8b"; + if (v == mkldnn_Acb16a) return "Acb16a"; + if (v == mkldnn_Acb4a) return "Acb4a"; + if (v == mkldnn_Acb8a) return "Acb8a"; + if (v == mkldnn_aCBd16b16c) return "aCBd16b16c"; + if (v == mkldnn_aCBde16b16c) return "aCBde16b16c"; + if (v == mkldnn_Acdb16a) return "Acdb16a"; + if (v == mkldnn_Acdb4a) return "Acdb4a"; + if (v == mkldnn_Acdb8a) return "Acdb8a"; + if (v == mkldnn_Acdeb16a) return "Acdeb16a"; + if (v == mkldnn_Acdeb4a) return "Acdeb4a"; + if (v == mkldnn_Acdeb8a) return "Acdeb8a"; + if (v == mkldnn_BAc16a16b) return "BAc16a16b"; + if (v == mkldnn_BAcd16a16b) return "BAcd16a16b"; + if (v == mkldnn_format_tag_last) return "format_tag_last"; + if (v == mkldnn_x) return "x"; + if (v == mkldnn_nc) return "nc"; + if (v == mkldnn_cn) return "cn"; + if (v == mkldnn_ncw) return "ncw"; + if (v == mkldnn_nwc) return "nwc"; + if (v == mkldnn_nchw) return "nchw"; + if (v == mkldnn_nhwc) return "nhwc"; + if (v == mkldnn_chwn) return "chwn"; + if (v == mkldnn_ncdhw) return "ncdhw"; + if (v == mkldnn_ndhwc) return "ndhwc"; + if (v == mkldnn_oi) return "oi"; + if (v == mkldnn_io) return "io"; + if (v == mkldnn_oiw) return "oiw"; + if (v == mkldnn_wio) return "wio"; + if (v == mkldnn_oihw) return "oihw"; + if (v == mkldnn_hwio) return "hwio"; + if (v == mkldnn_ihwo) return "ihwo"; + if (v == mkldnn_iohw) return "iohw"; + if (v == mkldnn_oidhw) return "oidhw"; + if (v == mkldnn_dhwio) return "dhwio"; + if (v == mkldnn_goiw) return "goiw"; + if (v == mkldnn_goihw) return "goihw"; + if (v == mkldnn_hwigo) return "hwigo"; + if (v == mkldnn_giohw) return "giohw"; + if (v == mkldnn_goidhw) return "goidhw"; + if (v == mkldnn_tnc) return "tnc"; + if (v == mkldnn_ntc) return "ntc"; + if (v == mkldnn_ldsnc) return "ldsnc"; + if (v == mkldnn_ldigo) return "ldigo"; + if (v == mkldnn_ldgoi) return "ldgoi"; + if (v == mkldnn_ldgo) return "ldgo"; + if (v == mkldnn_nCdhw16c) return "nCdhw16c"; + if (v == mkldnn_nCdhw4c) return "nCdhw4c"; + if (v == mkldnn_nCdhw8c) return "nCdhw8c"; + if (v == mkldnn_nChw16c) return "nChw16c"; + if (v == mkldnn_nChw4c) return "nChw4c"; + if (v == mkldnn_nChw8c) return "nChw8c"; + if (v == mkldnn_nCw16c) return "nCw16c"; + if (v == mkldnn_nCw4c) return "nCw4c"; + if (v == mkldnn_nCw8c) return "nCw8c"; + if (v == mkldnn_IOw16o16i) return "IOw16o16i"; + if (v == mkldnn_OIw16i16o) return "OIw16i16o"; + if (v == mkldnn_OIw16o16i) return "OIw16o16i"; + if (v == mkldnn_Oiw16o) return "Oiw16o"; + if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i"; + if (v == mkldnn_OIw4i4o) return "OIw4i4o"; + if (v == mkldnn_Oiw4o) return "Oiw4o"; + if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i"; + if (v == mkldnn_OIw8i8o) return "OIw8i8o"; + if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o"; + if (v == mkldnn_OIw8o8i) return "OIw8o8i"; + if (v == mkldnn_Owi16o) return "Owi16o"; + if (v == mkldnn_Owi4o) return "Owi4o"; + if (v == mkldnn_Owi8o) return "Owi8o"; + if (v == mkldnn_IOhw16o16i) return "IOhw16o16i"; + if (v == mkldnn_Ohwi16o) return "Ohwi16o"; + if (v == mkldnn_Ohwi4o) return "Ohwi4o"; + if (v == mkldnn_Ohwi8o) return "Ohwi8o"; + if (v == mkldnn_OIhw16i16o) return "OIhw16i16o"; + if (v == mkldnn_OIhw16o16i) return "OIhw16o16i"; + if (v == mkldnn_Oihw16o) return "Oihw16o"; + if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i"; + if (v == mkldnn_OIhw4i4o) return "OIhw4i4o"; + if (v == mkldnn_Oihw4o) return "Oihw4o"; + if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i"; + if (v == mkldnn_OIhw8i8o) return "OIhw8i8o"; + if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o"; + if (v == mkldnn_OIhw8o8i) return "OIhw8o8i"; + if (v == mkldnn_Odhwi16o) return "Odhwi16o"; + if (v == mkldnn_Odhwi4o) return "Odhwi4o"; + if (v == mkldnn_Odhwi8o) return "Odhwi8o"; + if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o"; + if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i"; + if (v == mkldnn_Oidhw16o) return "Oidhw16o"; + if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o"; + if (v == mkldnn_Oidhw4o) return "Oidhw4o"; + if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i"; + if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o"; + if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i"; + if (v == mkldnn_Goiw16g) return "Goiw16g"; + if (v == mkldnn_gIOw16o16i) return "gIOw16o16i"; + if (v == mkldnn_gOIw16i16o) return "gOIw16i16o"; + if (v == mkldnn_gOIw16o16i) return "gOIw16o16i"; + if (v == mkldnn_gOiw16o) return "gOiw16o"; + if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i"; + if (v == mkldnn_gOIw4i4o) return "gOIw4i4o"; + if (v == mkldnn_gOiw4o) return "gOiw4o"; + if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i"; + if (v == mkldnn_gOIw8i8o) return "gOIw8i8o"; + if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o"; + if (v == mkldnn_gOIw8o8i) return "gOIw8o8i"; + if (v == mkldnn_gOwi16o) return "gOwi16o"; + if (v == mkldnn_gOwi4o) return "gOwi4o"; + if (v == mkldnn_gOwi8o) return "gOwi8o"; + if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i"; + if (v == mkldnn_gOhwi16o) return "gOhwi16o"; + if (v == mkldnn_gOhwi4o) return "gOhwi4o"; + if (v == mkldnn_gOhwi8o) return "gOhwi8o"; + if (v == mkldnn_Goihw16g) return "Goihw16g"; + if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o"; + if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i"; + if (v == mkldnn_gOihw16o) return "gOihw16o"; + if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i"; + if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i"; + if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o"; + if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i"; + if (v == mkldnn_gOihw4o) return "gOihw4o"; + if (v == mkldnn_Goihw8g) return "Goihw8g"; + if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i"; + if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o"; + if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o"; + if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i"; + if (v == mkldnn_gOdhwi16o) return "gOdhwi16o"; + if (v == mkldnn_gOdhwi4o) return "gOdhwi4o"; + if (v == mkldnn_gOdhwi8o) return "gOdhwi8o"; + if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o"; + if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i"; + if (v == mkldnn_gOidhw16o) return "gOidhw16o"; + if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o"; + if (v == mkldnn_gOidhw4o) return "gOidhw4o"; + if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i"; + if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o"; + if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i"; + assert(!"unknown fmt_tag"); + return "unknown fmt_tag"; +} + +const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) { + if (v == mkldnn_prop_kind_undef) return "undef"; + if (v == mkldnn_forward_training) return "forward_training"; + if (v == mkldnn_forward_inference) return "forward_inference"; + if (v == mkldnn_forward_scoring) return "forward_scoring"; + if (v == mkldnn_forward) return "forward"; + if (v == mkldnn_backward) return "backward"; + if (v == mkldnn_backward_data) return "backward_data"; + if (v == mkldnn_backward_weights) return "backward_weights"; + if (v == mkldnn_backward_bias) return "backward_bias"; + assert(!"unknown prop_kind"); + return "unknown prop_kind"; +} + +const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) { + if (v == mkldnn_undefined_primitive) return "undef"; + if (v == mkldnn_reorder) return "reorder"; + if (v == mkldnn_shuffle) return "shuffle"; + if (v == mkldnn_concat) return "concat"; + if (v == mkldnn_sum) return "sum"; + if (v == mkldnn_convolution) return "convolution"; + if (v == mkldnn_deconvolution) return "deconvolution"; + if (v == mkldnn_eltwise) return "eltwise"; + if (v == mkldnn_softmax) return "softmax"; + if (v == mkldnn_pooling) return "pooling"; + if (v == mkldnn_lrn) return "lrn"; + if (v == mkldnn_batch_normalization) return "batch_normalization"; + if (v == mkldnn_inner_product) return "inner_product"; + if (v == mkldnn_rnn) return "rnn"; + assert(!"unknown prim_kind"); + return "unknown prim_kind"; +} + +const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) { + if (v == mkldnn_alg_kind_undef) return "undef"; + if (v == mkldnn_convolution_direct) return "convolution_direct"; + if (v == mkldnn_convolution_winograd) return "convolution_winograd"; + if (v == mkldnn_convolution_auto) return "convolution_auto"; + if (v == mkldnn_deconvolution_direct) return "deconvolution_direct"; + if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd"; + if (v == mkldnn_eltwise_relu) return "eltwise_relu"; + if (v == mkldnn_eltwise_tanh) return "eltwise_tanh"; + if (v == mkldnn_eltwise_elu) return "eltwise_elu"; + if (v == mkldnn_eltwise_square) return "eltwise_square"; + if (v == mkldnn_eltwise_abs) return "eltwise_abs"; + if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt"; + if (v == mkldnn_eltwise_linear) return "eltwise_linear"; + if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu"; + if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu"; + if (v == mkldnn_eltwise_logistic) return "eltwise_logistic"; + if (v == mkldnn_pooling_max) return "pooling_max"; + if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding"; + if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding"; + if (v == mkldnn_pooling_avg) return "pooling_avg"; + if (v == mkldnn_lrn_across_channels) return "lrn_across_channels"; + if (v == mkldnn_lrn_within_channel) return "lrn_within_channel"; + if (v == mkldnn_vanilla_rnn) return "vanilla_rnn"; + if (v == mkldnn_vanilla_lstm) return "vanilla_lstm"; + if (v == mkldnn_vanilla_gru) return "vanilla_gru"; + if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset"; + assert(!"unknown alg_kind"); + return "unknown alg_kind"; +} + +const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) { + if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right"; + if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left"; + if (v == mkldnn_bidirectional_concat) return "bidirectional_concat"; + if (v == mkldnn_bidirectional_sum) return "bidirectional_sum"; + if (v == mkldnn_unidirectional) return "unidirectional"; + assert(!"unknown rnn_direction"); + return "unknown rnn_direction"; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp new file mode 100644 index 0000000000..7e5789e2c3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_THREAD_HPP +#define MKLDNN_THREAD_HPP + +#include "utils.hpp" +#include "z_magic.hpp" + +#define MKLDNN_THR_SEQ 0 +#define MKLDNN_THR_OMP 1 +#define MKLDNN_THR_TBB 2 + +/* Ideally this condition below should never happen (if the library is built + * using regular cmake). For the 3rd-party projects that build the library + * from the sources on their own try to guess the right threading... */ +#if !defined(MKLDNN_THR) +# define MKLDNN_THR MKLDNN_THR_TBB +#endif + +#if MKLDNN_THR == MKLDNN_THR_SEQ +#define MKLDNN_THR_SYNC 1 +inline int mkldnn_get_max_threads() { return 1; } +inline int mkldnn_get_num_threads() { return 1; } +inline int mkldnn_get_thread_num() { return 0; } +inline int mkldnn_in_parallel() { return 0; } +inline void mkldnn_thr_barrier() {} + +#define PRAGMA_OMP(...) + +#elif MKLDNN_THR == MKLDNN_THR_OMP +#include <omp.h> +#define MKLDNN_THR_SYNC 1 + +inline int mkldnn_get_max_threads() { return omp_get_max_threads(); } +inline int mkldnn_get_num_threads() { return omp_get_num_threads(); } +inline int mkldnn_get_thread_num() { return omp_get_thread_num(); } +inline int mkldnn_in_parallel() { return omp_in_parallel(); } +inline void mkldnn_thr_barrier() { +# pragma omp barrier +} + +#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__)) + +#elif MKLDNN_THR == MKLDNN_THR_TBB +#include "tbb/task_arena.h" +#include "tbb/parallel_for.h" +#define MKLDNN_THR_SYNC 0 + +inline int mkldnn_get_max_threads() +{ return tbb::this_task_arena::max_concurrency(); } +inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); } +inline int mkldnn_get_thread_num() +{ return tbb::this_task_arena::current_thread_index(); } +inline int mkldnn_in_parallel() { return 0; } +inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); } + +#define PRAGMA_OMP(...) + +#endif + +/* MSVC still supports omp 2.0 only */ +#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER) +# define collapse(x) +# define PRAGMA_OMP_SIMD(...) +#else +# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__)) +#endif // defined(_MSC_VER) && !defined(__INTEL_COMPILER) + +namespace mkldnn { +namespace impl { + +inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; } + +template <typename T, typename U> +inline void balance211(T n, U team, U tid, T &n_start, T &n_end) { + T n_min = 1; + T &n_my = n_end; + if (team <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else if (n_min == 1) { + // team = T1 + T2 + // n = T1*n1 + T2*n2 (n1 - n2 = 1) + T n1 = utils::div_up(n, (T)team); + T n2 = n1 - 1; + T T1 = n - n2 * (T)team; + n_my = (T)tid < T1 ? n1 : n2; + n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2; + } + + n_end += n_start; +} + +} // namespace impl +} // namespace mkldnn + +#include "mkldnn_thread_parallel_nd.hpp" + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp new file mode 100644 index 0000000000..50f9b29622 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp @@ -0,0 +1,277 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP +#define MKLDNN_THREAD_PARALLEL_ND_HPP + +/* This header must be included by mkldnn_thread.hpp only */ + +/* Functions: + * - parallel(nthr, f) - executes f in parallel using at most + * nthr threads. If nthr equals 0 + * mkldnn_get_max_threads() threads is + * used + * - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already + * created threads + * - parallel_nd(dims..., f) - creates a parallel section and then + * calls for_nd + * - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then + * calls for_nd (mostly for convenience) + */ + +namespace mkldnn { +namespace impl { + +/* general parallelization */ +template <typename F> +void parallel(int nthr, F f) { + if (nthr == 0) nthr = mkldnn_get_max_threads(); +#if MKLDNN_THR == MKLDNN_THR_SEQ + assert(nthr == 1); + f(0, 1); +#elif MKLDNN_THR == MKLDNN_THR_OMP + if (nthr == 1) { f(0, 1); return; } +# pragma omp parallel num_threads(nthr) + f(mkldnn_get_thread_num(), mkldnn_get_num_threads()); +#elif MKLDNN_THR == MKLDNN_THR_TBB + if (nthr == 1) { f(0, 1); return; } + tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner()); +#endif +} + +/* for_nd section */ + +template <typename T0, typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, F f) { + T0 start{0}, end{0}; + balance211(D0, nthr, ithr, start, end); + for (T0 d0 = start; d0 < end; ++d0) f(d0); +} + +template <typename T0, typename T1, typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1); + utils::nd_iterator_step(d0, D0, d1, D1); + } +} + +template <typename T0, typename T1, typename T2, typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); + } +} + +template <typename T0, typename T1, typename T2, typename T3, typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, + typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + } +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, + typename T5, typename F> +void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, + const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; + utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, + d5, D5); + for (size_t iwork = start; iwork < end; ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } +} + +// Skip a lambda function in the parameter pack. +template <typename T> +constexpr size_t get_work_amount(const T &v) { return 1; } +template <typename T, typename ...Args> +constexpr size_t get_work_amount(const T &v, Args &&...args) +{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); } + +/* parallel_nd and parallel_nd_in_omp section */ + +#if MKLDNN_THR != MKLDNN_THR_TBB +template <typename ...Args> +void parallel_nd(Args &&...args) { +#if MKLDNN_THR == MKLDNN_THR_SEQ + for_nd(0, 1, utils::forward<Args>(args)...); +#elif MKLDNN_THR == MKLDNN_THR_OMP + const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1; +# pragma omp parallel if (do_parallel) + { + const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads(); + const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num(); + for_nd(ithr, nthr, utils::forward<Args>(args)...); + } +#endif +} +#else // MKLDNN_THR != MKLDNN_THR_TBB + +// gcc 4.8 has a bug with passing parameter pack to lambdas. +// So have to explicitly instantiate all the cases. + +template <typename T0, typename F> +void parallel_nd(const T0 &D0, F f) { + const size_t work_amount = (size_t)D0; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(T0(iwork)); + } + }, tbb::static_partitioner()); +} + +template <typename T0, typename T1, typename F> +void parallel_nd(const T0 &D0, const T1 &D1, F f) { + const size_t work_amount = (size_t)D0 * D1; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + T0 d0{0}; T1 d1{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1); + utils::nd_iterator_step(d0, D0, d1, D1); + } + }, tbb::static_partitioner()); +} + +template <typename T0, typename T1, typename T2, typename F> +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); + } + }, tbb::static_partitioner()); +} + +template <typename T0, typename T1, typename T2, typename T3, typename F> +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); + } + }, tbb::static_partitioner()); +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, + typename F> +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3, d4); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); + } + }, tbb::static_partitioner()); +} + +template <typename T0, typename T1, typename T2, typename T3, typename T4, + typename T5, typename F> +void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, + const T4 &D4, const T5 &D5, F f) { + const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5; + if (work_amount == 0) return; + tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) { + T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0}; + utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, + d5, D5); + for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) { + f(d0, d1, d2, d3, d4, d5); + utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); + } + }, tbb::static_partitioner()); +} +#endif + +template <typename ...Args> +void parallel_nd_in_omp(Args &&...args) { +#if MKLDNN_THR == MKLDNN_THR_SEQ + for_nd(0, 1, utils::forward<Args>(args)...); +#elif MKLDNN_THR == MKLDNN_THR_OMP + for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(), + utils::forward<Args>(args)...); +#elif MKLDNN_THR == MKLDNN_THR_TBB + assert(!"unsupported parallel_nd_in_omp()"); +#endif +} + +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp new file mode 100644 index 0000000000..aa671a0b6e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp @@ -0,0 +1,77 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef MKLDNN_TRAITS_HPP +#define MKLDNN_TRAITS_HPP + +#include <assert.h> +#include <stdint.h> + +#include "mkldnn.h" +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +template <data_type_t> struct prec_traits {}; /* ::type -> float */ +template <typename> struct data_traits {}; /* ::data_type -> f32 */ +template <int> struct typesize_traits {}; /* ::data_type_size -> f32 */ +template <primitive_kind_t> struct pkind_traits {}; /* ::desc_type, ::query_d */ + +template <> struct prec_traits<data_type::f32> { typedef float type; }; +template <> struct prec_traits<data_type::s32> { typedef int32_t type; }; +template <> struct prec_traits<data_type::s8> { typedef int8_t type; }; +template <> struct prec_traits<data_type::u8> { typedef uint8_t type; }; + +template <> struct data_traits<float> +{ static constexpr data_type_t data_type = data_type::f32; }; +template <> struct data_traits<int32_t> +{ static constexpr data_type_t data_type = data_type::s32; }; +template <> struct data_traits<int8_t> +{ static constexpr data_type_t data_type = data_type::s8; }; +template <> struct data_traits<uint8_t> +{ static constexpr data_type_t data_type = data_type::u8; }; + +template <> struct typesize_traits<4> { typedef float type; }; +template <> struct typesize_traits<2> { typedef int16_t type; }; +template <> struct typesize_traits<1> { typedef uint8_t type; }; + +#define PKIND_TRAITS_INST(op) \ +template <> struct pkind_traits<primitive_kind::op> { \ + typedef CONCAT2(op, _desc_t) desc_type; \ + static constexpr query_t query_d = query::CONCAT2(op, _d); \ +} +PKIND_TRAITS_INST(convolution); +PKIND_TRAITS_INST(deconvolution); +PKIND_TRAITS_INST(shuffle); +PKIND_TRAITS_INST(eltwise); +PKIND_TRAITS_INST(softmax); +PKIND_TRAITS_INST(pooling); +PKIND_TRAITS_INST(lrn); +PKIND_TRAITS_INST(batch_normalization); +PKIND_TRAITS_INST(inner_product); +PKIND_TRAITS_INST(rnn); +#undef PKIND_TRAITS_INST + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp new file mode 100644 index 0000000000..f89ea999e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp @@ -0,0 +1,193 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef NSTL_HPP +#define NSTL_HPP + +#include <stdint.h> +#include <limits.h> +#include <float.h> + +#include <vector> +#include <map> + +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +void *malloc(size_t size, int alignment); +void free(void *p); + +struct c_compatible { + enum { default_alignment = 64 }; + static void *operator new(size_t sz) { + return malloc(sz, default_alignment); + } + static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; } + static void *operator new[](size_t sz) { + return malloc(sz, default_alignment); + } + static void operator delete(void *p) { free(p); } + static void operator delete[](void *p) { free(p); } +}; + +namespace nstl { + +template<typename T> +inline const T abs(const T& a) { + return a >= 0 ? a : -a; +} + +template<typename T> +inline const T& max(const T& a, const T& b) { + return a > b ? a : b; +} + +template<typename T> +inline const T& min(const T& a, const T& b) { + return a < b ? a : b; +} + +template<typename T> void swap(T& t1, T& t2) { + T tmp(t1); + t1 = t2; + t2 = tmp; +} + +// Rationale: MKL-DNN needs numeric limits implementation that does not +// generate dependencies on C++ run-time libraries. + +template<typename T> struct numeric_limits; + +template<> struct numeric_limits<float> { + static constexpr float lowest() { return -FLT_MAX; } + static constexpr float max() { return FLT_MAX; } +}; + +template<> struct numeric_limits<int32_t> { + static constexpr int lowest() { return INT32_MIN; } + static constexpr int max() { return INT32_MAX; } +}; + +template<> struct numeric_limits<int16_t> { + static constexpr int16_t lowest() { return INT16_MIN; } + static constexpr int16_t max() { return INT16_MAX; } +}; + +template<> struct numeric_limits<int8_t> { + static constexpr int8_t lowest() { return INT8_MIN; } + static constexpr int8_t max() { return INT8_MAX; } +}; + +template<> struct numeric_limits<uint8_t> { + static constexpr uint8_t lowest() { return 0; } + static constexpr uint8_t max() { return UINT8_MAX; } +}; + +template<typename T> struct is_integral +{ static constexpr bool value = false; }; +template<> struct is_integral<int32_t> { static constexpr bool value = true; }; +template<> struct is_integral<int16_t> { static constexpr bool value = true; }; +template<> struct is_integral<int8_t> { static constexpr bool value = true; }; +template<> struct is_integral<uint8_t> { static constexpr bool value = true; }; + +template <typename T, typename U> struct is_same +{ static constexpr bool value = false; }; +template <typename T> struct is_same<T, T> +{ static constexpr bool value = true; }; + +// Rationale: MKL-DNN needs container implementations that do not generate +// dependencies on C++ run-time libraries. +// +// Implementation philosophy: caller is responsible to check if the operation +// is valid. The only functions that have to return status are those that +// depend on memory allocation or similar operations. +// +// This means that e.g. an operator [] does not have to check for boundaries. +// The caller should have checked the boundaries. If it did not we crash and +// burn: this is a bug in MKL-DNN and throwing an exception would not have been +// recoverable. +// +// On the other hand, insert() or resize() or a similar operation needs to +// return a status because the outcome depends on factors external to the +// caller. The situation is probably also not recoverable also, but MKL-DNN +// needs to be nice and report "out of memory" to the users. + +enum nstl_status_t { + success = 0, + out_of_memory +}; + +template <typename T> class vector: public c_compatible { +private: + std::vector<T> _impl; +public: + typedef typename std::vector<T>::iterator iterator; + typedef typename std::vector<T>::const_iterator const_iterator; + typedef typename std::vector<T>::size_type size_type; + vector() {} + vector(size_type n): _impl(n) {} + vector(size_type n, const T &value): _impl(n, value) {} + template <typename input_iterator> + vector(input_iterator first, input_iterator last): _impl(first, last) {} + ~vector() {} + size_type size() const { return _impl.size(); } + T& operator[] (size_type i) { return _impl[i]; } + const T& operator[] (size_type i) const { return _impl[i]; } + iterator begin() { return _impl.begin(); } + const_iterator begin() const { return _impl.begin(); } + iterator end() { return _impl.end(); } + const_iterator end() const { return _impl.end(); } + template <typename input_iterator> + nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end) + { + _impl.insert(pos, begin, end); + return success; + } + void clear() { _impl.clear(); } + void push_back(const T& t) { _impl.push_back(t); } + void resize(size_type count) { _impl.resize(count); } + void reserve(size_type count) { _impl.reserve(count); } +}; + +template <typename Key, typename T> class map: public c_compatible { +private: + std::map<Key, T> _impl; +public: + typedef typename std::map<Key, T>::iterator iterator; + typedef typename std::map<Key, T>::const_iterator const_iterator; + typedef typename std::map<Key, T>::size_type size_type; + map() {} + ~map() {} + size_type size() const { return _impl.size(); } + T& operator[](const Key &k) { return _impl[k]; } + const T& operator[](const Key &k) const { return _impl[k]; } + iterator begin() { return _impl.begin(); } + const_iterator begin() const { return _impl.begin(); } + iterator end() { return _impl.end(); } + const_iterator end() const { return _impl.end(); } + template <typename input_iterator> + void clear() { _impl.clear(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp new file mode 100644 index 0000000000..be96e654ff --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp @@ -0,0 +1,114 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t pooling_desc_init(pooling_desc_t *pool_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t kernel, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + bool args_ok = true + && !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l) + && one_of(alg_kind, pooling_max, + pooling_avg_include_padding, + pooling_avg_exclude_padding) + && one_of(padding_kind, padding_kind::padding_zero); + if (!args_ok) return invalid_arguments; + + if (padding_r == nullptr) padding_r = padding_l; + + auto pd = pooling_desc_t(); + pd.primitive_kind = primitive_kind::pooling; + pd.prop_kind = prop_kind; + pd.alg_kind = alg_kind; + pd.src_desc.ndims = src_desc->ndims; + + const bool is_fwd = one_of(prop_kind, forward_training, forward_inference); + + pd.diff_src_desc = pd.src_desc = zero_md(); + pd.diff_dst_desc = pd.dst_desc = zero_md(); + + (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc; + (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc; + + int sp_dims = src_desc->ndims - 2; + utils::array_copy(pd.strides, strides, sp_dims); + utils::array_copy(pd.kernel, kernel, sp_dims); + utils::array_copy(pd.padding[0], padding_l, sp_dims); + utils::array_copy(pd.padding[1], padding_r, sp_dims); + + pd.padding_kind = padding_kind; + if (one_of(alg_kind, pooling_max, pooling_avg_include_padding, + pooling_avg_exclude_padding)) { + pd.accum_data_type = types::default_accum_data_type( + src_desc->data_type, dst_desc->data_type); + } else { + pd.accum_data_type = dst_desc->data_type; + } + + bool consistency = true + && utils::one_of(src_desc->ndims, 4, 5) + && utils::one_of(dst_desc->ndims, 4, 5) + && src_desc->dims[0] == dst_desc->dims[0] + && src_desc->dims[1] == dst_desc->dims[1]; + for (int i = 2; i < src_desc->ndims; ++i) + consistency = consistency && ( + (src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2] + + padding_r[i - 2]) / strides[i - 2] + 1 + == dst_desc->dims[i]); + if (!consistency) return invalid_arguments; + + *pool_desc = pd; + return success; +} +} + +status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc, + prop_kind_t prop_kind, alg_kind_t alg_kind, + const memory_desc_t *src_desc, const memory_desc_t *dst_desc, + const dims_t strides, const dims_t kernel, const dims_t padding_l, + const dims_t padding_r, padding_kind_t padding_kind) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc, + dst_desc, strides, kernel, padding_l, padding_r, padding_kind); +} + +status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc, + alg_kind_t alg_kind, const memory_desc_t *diff_src_desc, + const memory_desc_t *diff_dst_desc, const dims_t strides, + const dims_t kernel, const dims_t padding_l, const dims_t padding_r, + padding_kind_t padding_kind) { + return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind, + diff_src_desc, diff_dst_desc, strides, kernel, padding_l, + padding_r, padding_kind); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp new file mode 100644 index 0000000000..4c9c009412 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp @@ -0,0 +1,238 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef POOLING_PD_HPP +#define POOLING_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +struct pooling_fwd_pd_t; + +struct pooling_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::pooling; + + pooling_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , ws_md_() + {} + + const pooling_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::pooling_d: + *(const pooling_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common pooling aux functions */ + + dim_t MB() const { return src_desc().dims[0]; } + dim_t C() const { return src_desc().dims[1]; } + + dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; } + dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; } + dim_t IW() const { return src_desc().dims[ndims() - 1]; } + + dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; } + dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; } + dim_t OW() const { return dst_desc().dims[ndims() - 1]; } + + dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; } + dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; } + dim_t KW() const { return desc_.kernel[ndims() - 3]; } + + dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; } + dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; } + dim_t KSW() const { return desc_.strides[ndims() - 3]; } + + dim_t padFront() const + { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; } + dim_t padBack() const + { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; } + dim_t padT() const + { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; } + dim_t padB() const + { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; } + dim_t padL() const { return desc_.padding[0][ndims() - 3]; } + dim_t padR() const { return desc_.padding[1][ndims() - 3]; } + + int ndims() const { return src_desc().ndims; } + bool is_3d() const { return ndims() == 5; } + + bool has_zero_dim_memory() const + { return memory_desc_wrapper(src_desc()).has_zero_dim(); } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + pooling_desc_t desc_; + const pooling_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t ws_md_; + + void init_default_ws() { + ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md(); + ws_md_.data_type = indices_data_type(); + } + + data_type_t indices_data_type() const { + /* the simplest way to express 256... */ + const int u8_max = nstl::numeric_limits< + typename prec_traits<data_type::u8>::type>::max(); + return utils::array_product(desc()->kernel, ndims()) <= u8_max + ? data_type::u8 : data_type::s32; + } + +private: + const memory_desc_t &src_desc() const + { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; } + const memory_desc_t &dst_desc() const + { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; } +}; + +struct pooling_fwd_pd_t: public pooling_pd_t { + typedef pooling_fwd_pd_t base_class; + typedef pooling_fwd_pd_t hint_class; + + pooling_fwd_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) + , src_md_(desc_.src_desc) + , dst_md_(desc_.dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } + +protected: + memory_desc_t src_md_; + memory_desc_t dst_md_; + + virtual status_t set_default_params() { + if (dst_md()->format_kind != format_kind::any) + return status::success; + + if (src_md()->format_kind != format_kind::blocked) + return status::unimplemented; + + return memory_desc_init_by_blocking_desc(dst_md_, + src_md_.format_desc.blocking); + } +}; + +struct pooling_bwd_pd_t: public pooling_pd_t { + typedef pooling_bwd_pd_t base_class; + typedef pooling_fwd_pd_t hint_class; + + pooling_bwd_pd_t(engine_t *engine, + const pooling_desc_t *adesc, + const primitive_attr_t *attr, + const pooling_fwd_pd_t *hint_fwd_pd) + : pooling_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_md_(desc_.diff_src_desc) + , diff_dst_md_(desc_.diff_dst_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_DIFF_DST) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_src_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_dst_md_ : nullptr; } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + virtual int n_inputs() const override + { return 1 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_src_md_; + memory_desc_t diff_dst_md_; + + virtual status_t set_default_params() { + if (diff_src_md()->format_kind != format_kind::any) + return status::success; + + if (diff_dst_md()->format_kind != format_kind::blocked) + return status::unimplemented; + + return memory_desc_init_by_blocking_desc(diff_src_md_, + diff_dst_md_.format_desc.blocking); + } +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp new file mode 100644 index 0000000000..fdf6522f62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "primitive.hpp" +#include "type_helpers.hpp" +#include "stream.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::primitive_kind; + +namespace { +// XXX: this is a huge hammer. This disables all and any msan checks on +// primitives outputs. +// +// A proper approach would be an implementation-specific unpoisoning. +void unpoison_outputs(const exec_args_t &args) { + for(const auto &arg: args) { + if (arg.second.is_const) continue; + auto *mem = arg.second.mem; + void *p; + mem->get_data_handle(&p); + size_t s = memory_desc_wrapper(*mem->md()).size(); + msan_unpoison(p, s); + } +} +} + +status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) { + if (primitive_desc) delete primitive_desc; + return success; +} + +status_t mkldnn_primitive_create(primitive_t **primitive, + const primitive_desc_t *primitive_desc) { + if (utils::any_null(primitive, primitive_desc)) + return invalid_arguments; + return primitive_desc->create_primitive(primitive); +} + +status_t mkldnn_primitive_execute(const primitive_t *primitive, + stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) { + bool ok = true + && !utils::any_null(primitive, stream) + && primitive->engine() == stream->engine() + && IMPLICATION(nargs > 0, c_args != nullptr); + if (!ok) return invalid_arguments; + + exec_args_t args; + status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args); + if (status != status::success) return status; + + exec_ctx_t ctx(stream, std::move(args)); + + if (mkldnn_verbose()->level) { + double ms = get_msec(); + status = primitive->execute(ctx); + ms = get_msec() - ms; + printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms); + fflush(0); + } else { + status = primitive->execute(ctx); + } + + if (msan_enabled) unpoison_outputs(ctx.args()); + + return status; +} + +status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive, + const primitive_desc_t **primitive_desc) { + if (utils::any_null(primitive, primitive_desc)) + return invalid_arguments; + return safe_ptr_assign<const primitive_desc_t>(*primitive_desc, + primitive->pd()); +} + +status_t mkldnn_primitive_destroy(primitive_t *primitive) { + if (primitive != nullptr) + delete primitive; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp new file mode 100644 index 0000000000..3b506d6d1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp @@ -0,0 +1,76 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_HPP +#define PRIMITIVE_HPP + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "primitive_exec_types.hpp" + +/** \brief A pure virtual primitive class + * + * Primitive contains links to its inputs & outputs, though it does not track + * their readiness on execution step. + * + * @remark @b Rational. + * Dependencies are essential through-out the whole MKL-DNN library, so it + * makes sense to include them on the very low level. On the other hand, + * tracking them should be a task for corresponding essence, like scheduler, + * stream or whatever. Primitive itself should know nothing about the + * environment it is running in. + * + * @note + * To make user experience better we should provide API which allows + * achieving the best (or good enough) performance when creating primitives + * in natural order: i.e. from bottom to top for forward pass and from top to + * bottom for backward pass. Please consider restriction [1] in Level 0. + */ +struct mkldnn_primitive: public mkldnn::impl::c_compatible { + mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd) + : pd_(pd->clone()) {} + virtual ~mkldnn_primitive() { delete pd_; } + + /** returns primitive's engine */ + mkldnn::impl::engine_t *engine() const { return pd_->engine(); } + /** returns primitive's inputs */ + const mkldnn::impl::primitive_desc_t *pd() const { return pd_; } + /** returns primitive's kind */ + mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); } + + /** executes primitive with execution context @p ctx */ + virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx) + const = 0; + +protected: + const mkldnn::impl::primitive_desc_t *pd_; + +private: + mkldnn_primitive() = delete; + mkldnn_primitive(const mkldnn_primitive &) = delete; + mkldnn_primitive(mkldnn_primitive &&) = delete; + mkldnn_primitive &operator=(const mkldnn_primitive &) = delete; + mkldnn_primitive &operator=(mkldnn_primitive &&) = delete; +}; + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp new file mode 100644 index 0000000000..9fd638842c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp @@ -0,0 +1,290 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +namespace mkldnn { +namespace impl { + +status_t scales_t::set(dim_t count, int mask, const float *scales) { + cleanup(); + + count_ = count; + mask_ = mask; + + if (count_ == 1) { + scales_ = scales_buf_; + utils::array_set(scales_, scales[0], scales_buf_size); + } else { + scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64); + if (scales_ == nullptr) + return status::out_of_memory; + + for (dim_t c = 0; c < count_; ++c) + scales_[c] = scales[c]; + } + + return status::success; +} + +} +} + +status_t post_ops_t::append_sum(float scale) { + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::sum; + entry_[len_].sum.scale = scale; + + len_++; + + return success; +} + +status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha, + float beta) { + using namespace mkldnn::impl::alg_kind; + bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic); + if (!known_alg) + return invalid_arguments; + + if (len_ == capacity) + return out_of_memory; + + entry_[len_].kind = primitive_kind::eltwise; + entry_[len_].eltwise.scale = scale; + entry_[len_].eltwise.alg = alg; + entry_[len_].eltwise.alpha = alpha; + entry_[len_].eltwise.beta = beta; + + len_++; + + return success; +} + +status_t primitive_attr_t::set_scratchpad_mode( + scratchpad_mode_t scratchpad_mode) { + using namespace mkldnn::impl::scratchpad_mode; + + const bool ok = one_of(scratchpad_mode, library, user); + if (!ok) + return invalid_arguments; + + scratchpad_mode_ = scratchpad_mode; + return success; +} + +status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) { + this->post_ops_ = post_ops; + return success; +} + +/* Public C API */ + +status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) { + if (attr == nullptr) + return invalid_arguments; + + return safe_ptr_assign<mkldnn_primitive_attr>(*attr, + new mkldnn_primitive_attr); +} + +status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr, + const primitive_attr_t *existing_attr) { + if (any_null(attr, existing_attr)) + return invalid_arguments; + + return safe_ptr_assign<mkldnn_primitive_attr>(*attr, + existing_attr->clone()); +} + +status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) { + if (attr) + delete attr; + + return success; +} + +status_t mkldnn_primitive_attr_get_scratchpad_mode( + const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) { + if (any_null(attr, scratchpad_mode)) + return invalid_arguments; + + *scratchpad_mode = attr->scratchpad_mode_; + + return success; +} + +status_t mkldnn_primitive_attr_set_scratchpad_mode( + primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) { + if (any_null(attr)) + return invalid_arguments; + + return attr->set_scratchpad_mode(scratchpad_mode); +} + +status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr, + dim_t *count, int *mask, const float **scales) { + if (any_null(attr, count, mask, scales)) + return invalid_arguments; + + *count = attr->output_scales_.count_; + *mask = attr->output_scales_.mask_; + *scales = attr->output_scales_.scales_; + + return success; +} + +status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr, + dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->output_scales_.set(count, mask, scales); +} + +status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr, + const post_ops_t **post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + *post_ops = &attr->post_ops_; + return success; +} + +status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr, + const post_ops_t *post_ops) { + if (any_null(attr, post_ops)) + return invalid_arguments; + + return attr->set_post_ops(*post_ops); +} + +status_t mkldnn_post_ops_create(post_ops_t **post_ops) { + if (post_ops == nullptr) + return invalid_arguments; + + return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops); +} + +status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) { + if (post_ops) + delete post_ops; + + return success; +} + +int mkldnn_post_ops_len(const post_ops_t *post_ops) { + if (post_ops) + return post_ops->len_; + + return 0; +} + +primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops, + int index) { + bool ok = post_ops && 0 <= index && index < post_ops->len_; + if (!ok) + return primitive_kind::undefined; + + return post_ops->entry_[index].kind; +} + +status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_sum(scale); +} + +namespace { +bool simple_get_params_check(const post_ops_t *post_ops, int index, + primitive_kind_t kind) { + bool ok = true + && post_ops != nullptr + && 0 <= index + && index < post_ops->len_ + && post_ops->entry_[index].kind == kind; + return ok; +} +} + +status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index, + float *scale) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::sum) + && !any_null(scale); + if (!ok) + return invalid_arguments; + + *scale = post_ops->entry_[index].sum.scale; + return success; +} + +status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale, + alg_kind_t kind, float alpha, float beta) { + if (post_ops == nullptr) + return invalid_arguments; + + return post_ops->append_eltwise(scale, kind, alpha, beta); +} + +status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops, + int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) { + bool ok = true + && simple_get_params_check(post_ops, index, primitive_kind::eltwise) + && !any_null(scale, alpha, beta); + if (!ok) + return invalid_arguments; + + const auto &e = post_ops->entry_[index].eltwise; + *scale = e.scale; + *alg = e.alg; + *alpha = e.alpha; + *beta = e.beta; + + return success; +} + +status_t mkldnn_primitive_attr_set_rnn_data_qparams( + primitive_attr_t *attr, const float scale, const float shift) { + if (attr == nullptr) + return invalid_arguments; + + return attr->rnn_data_qparams_.set(scale, shift); +} + +status_t mkldnn_primitive_attr_set_rnn_weights_qparams( + primitive_attr_t *attr, dim_t count, int mask, const float *scales) { + bool ok = !any_null(attr, scales) && count > 0 && mask >= 0; + if (!ok) + return invalid_arguments; + + return attr->rnn_weights_qparams_.set(count, mask, scales); +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp new file mode 100644 index 0000000000..e2130c7ab1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp @@ -0,0 +1,183 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_ATTR_HPP +#define PRIMITIVE_ATTR_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct rnn_data_qparams_t : public c_compatible { + rnn_data_qparams_t() : scale_(1.), shift_(0.) {} + bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); } + + status_t set(float scale, float shift) { + scale_ = scale; + shift_ = shift; + return status::success; + } + + float scale_; + float shift_; +}; + +struct scales_t: public c_compatible { + scales_t(): count_(1), mask_(0), scales_(scales_buf_) + { set(1.); } + + scales_t(const scales_t &rhs): scales_t() + { set(rhs.count_, rhs.mask_, rhs.scales_); } + + ~scales_t() { cleanup(); } + + scales_t &operator=(const scales_t &rhs) { + if (&rhs == this) + return *this; + status_t status = set(rhs.count_, rhs.mask_, rhs.scales_); + assert(status == status::success); + (void)status; + return *this; + } + + bool has_default_values() const { + for (dim_t c = 0; c < count_; ++c) { + if(scales_[c] != 1.) return false; + } + return true; + } + + status_t set(dim_t count, int mask, const float *scales); + status_t set(float single_scale) { return this->set(1, 0, &single_scale); } + + dim_t count_; + int mask_; + float *scales_; + +private: + enum { scales_buf_size = 16 }; + float scales_buf_[scales_buf_size]; + + void cleanup() { + if (scales_ != scales_buf_ && scales_ != nullptr) + impl::free(scales_); + + count_ = 1; + mask_ = 0; + scales_ = scales_buf_; + } +}; + +} +} + +struct mkldnn_post_ops: public mkldnn::impl::c_compatible { + struct entry_t { + struct eltwise_t { + mkldnn::impl::alg_kind_t alg; + float scale, alpha, beta; + }; + + mkldnn::impl::primitive_kind_t kind; + union { + struct { float scale; } sum; + eltwise_t eltwise; + }; + + bool is_eltwise(bool require_scale_one = true) const { + using namespace mkldnn::impl; + return kind == primitive_kind::eltwise + && IMPLICATION(require_scale_one, eltwise.scale == 1.f); + } + + bool is_relu(bool require_scale_one = true, + bool require_nslope_zero = true) const { + using namespace mkldnn::impl; + return is_eltwise(require_scale_one) + && eltwise.alg == alg_kind::eltwise_relu + && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f); + } + + bool is_sum(bool require_scale_one = true) const { + using namespace mkldnn::impl; + return kind == primitive_kind::sum + && IMPLICATION(require_scale_one, sum.scale == 1.f); + } + }; + + mkldnn_post_ops(): len_(0) {} + + mkldnn::impl::status_t append_sum(float scale); + mkldnn::impl::status_t append_eltwise(float scale, + mkldnn::impl::alg_kind_t alg, float alpha, float beta); + + int find(mkldnn::impl::primitive_kind_t kind, int start = 0, + int stop = -1) const { + if (stop == -1) stop = len_; + stop = mkldnn::impl::nstl::min(stop, len_); + for (int idx = start; idx < stop; ++idx) + if (entry_[idx].kind == kind) return idx; + return -1; + } + + bool has_default_values() const { return len_ == 0; } + + bool contain(mkldnn::impl::primitive_kind_t kind, int index) const + { return find(kind, index, index + 1) == index; } + + enum { capacity = 4 }; + + int len_; + entry_t entry_[capacity]; +}; + +struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible { + mkldnn_primitive_attr() + : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library) + {} + + mkldnn_primitive_attr *clone() const + { return new mkldnn_primitive_attr(*this); } + + /** Returns true if the attributes have default values. + * + * @note The scratchpad_mode_ is not take into account */ + bool has_default_values() const { + return true + && output_scales_.has_default_values() + && post_ops_.has_default_values() + && rnn_data_qparams_.has_default_values() + && rnn_weights_qparams_.has_default_values(); + } + + mkldnn::impl::status_t set_scratchpad_mode( + mkldnn::impl::scratchpad_mode_t scratchpad_mode); + mkldnn::impl::status_t set_post_ops( + const mkldnn::impl::post_ops_t &post_ops); + + mkldnn::impl::scratchpad_mode_t scratchpad_mode_; + mkldnn::impl::scales_t output_scales_; + mkldnn::impl::post_ops_t post_ops_; + mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_; + mkldnn::impl::scales_t rnn_weights_qparams_; +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp new file mode 100644 index 0000000000..723c41e05a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +status_t primitive_desc_t::query(query_t what, int idx, void *result) const { + auto safe_ret_md = [&](const memory_desc_t *_) { + if (_ == nullptr) return not_required; + *(const memory_desc_t **)result = _; + return success; + }; + + switch (what) { + case query::engine: *(engine_t**)result = engine(); break; + case query::primitive_kind: *(primitive_kind_t*)result = kind(); break; + + case query::scratchpad_engine: + *(engine_t**)result = scratchpad_engine(); break; + + case query::memory_consumption_s64: + *(dim_t *)result = scratchpad_size(scratchpad_mode::library); break; + + case query::op_d: + if (idx != 0 || op_desc() == nullptr) return invalid_arguments; + *(const_c_op_desc_t *)result + = static_cast<const_c_op_desc_t>(op_desc()); break; + + case query::src_md: return safe_ret_md(src_md(idx)); + case query::diff_src_md: return safe_ret_md(diff_src_md(idx)); + case query::dst_md: return safe_ret_md(dst_md(idx)); + case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx)); + case query::weights_md: return safe_ret_md(weights_md(idx)); + case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx)); + case query::workspace_md: + if (idx != 0) return status::invalid_arguments; + return safe_ret_md(workspace_md(idx)); + case query::scratchpad_md: + if (idx != 0) return status::invalid_arguments; + return safe_ret_md(scratchpad_md(idx)); + + case query::num_of_inputs_s32: *(int*)result = n_inputs(); break; + case query::num_of_outputs_s32: *(int*)result = n_outputs(); break; + + case query::impl_info_str: *(const char **)result = name(); break; + + default: return unimplemented; + } + return success; +} + +status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc, + const primitive_attr_t **attr) { + if (utils::any_null(primitive_desc, attr)) + return invalid_arguments; + + *attr = primitive_desc->attr(); + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp new file mode 100644 index 0000000000..536dcfa1d0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp @@ -0,0 +1,174 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_DESC_HPP +#define PRIMITIVE_DESC_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "primitive_attr.hpp" +#include "verbose.hpp" + +struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible { + using md_t = mkldnn::impl::memory_desc_t; + + mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::primitive_kind_t kind) + : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; } + + mkldnn_primitive_desc(mkldnn::impl::engine_t *engine, + mkldnn::impl::primitive_kind_t kind) + : engine_(engine), kind_(kind) { info_[0] = '\0'; } + + virtual mkldnn_primitive_desc *clone() const = 0; + virtual ~mkldnn_primitive_desc() {} + + const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; } + mkldnn::impl::engine_t *engine() const { return engine_; } + mkldnn::impl::primitive_kind_t kind() const { return kind_; } + + virtual void init_info() {} + const char *info() const { return info_; } + + mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() + { return scratchpad_registry_; } + const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const + { return scratchpad_registry_; } + virtual mkldnn::impl::engine_t *scratchpad_engine() const + { return engine_; } + + virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; } + + enum class arg_usage_t { unused, input, output }; + virtual arg_usage_t arg_usage( + mkldnn::impl::primitive_arg_index_t arg) const { + using mkldnn::impl::types::is_zero_md; + if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md())) + return arg_usage_t::output; + return arg_usage_t::unused; + } + +# define DECLARE_MD_STUB(stub) \ + virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \ + { return nullptr; } + + DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md); + DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md); + DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md); + DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md); + DECLARE_MD_STUB(workspace_md); +# undef DECLARE_MD_STUB + + const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const { + return idx == 0 ? &scratchpad_md_ : nullptr; + } + + virtual void init_scratchpad_md() { + auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user); + mkldnn::impl::dims_t dims = { size }; + mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims, + mkldnn::impl::data_type::u8, mkldnn_x); + } + + /** returns the scratchpad size for the given scratchpad mode. */ + mkldnn::impl::dim_t scratchpad_size( + mkldnn::impl::scratchpad_mode_t mode) const { + if (mode != attr_.scratchpad_mode_) return 0; + return scratchpad_registry().size(); + } + + virtual int n_inputs() const { return 0; } + virtual int n_outputs() const { return 0; } + + virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx, + void *result) const; + + virtual mkldnn::impl::status_t create_primitive( + mkldnn::impl::primitive_t **primitive) const = 0; + + virtual const char *name() const { return "mkldnn_primitive_desc"; } + + /* static magic */ + + template<typename pd_t> + static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd, + const mkldnn::impl::op_desc_t *adesc, + const mkldnn::impl::primitive_attr_t *attr, + mkldnn::impl::engine_t *engine, + const mkldnn::impl::primitive_desc_t *hint_fwd) { + using namespace mkldnn::impl; + using namespace mkldnn::impl::status; + using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type; + if (adesc->kind != pd_t::base_pkind) return invalid_arguments; + assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true); + auto hint = + reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd); + auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + _pd->init_info(); + _pd->init_scratchpad_md(); + *pd = _pd; + return success; + } + +protected: + mkldnn::impl::engine_t *engine_; + mkldnn::impl::primitive_attr_t attr_; + mkldnn::impl::primitive_kind_t kind_; + + mkldnn::impl::memory_desc_t scratchpad_md_; + + char info_[MKLDNN_VERBOSE_BUF_LEN]; + + mkldnn::impl::memory_tracking::registry_t scratchpad_registry_; + +protected: + /** compares ws between fwd_pd and this (make sense to use for bwd_pd) + * Expectation: this already set workspace, and this workspace should + * exactly match the one from fwd_pd */ + bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const { + using namespace mkldnn::impl; + if (!workspace_md()) return true; // the impl lives fine w/o workspace + return fwd_pd && fwd_pd->workspace_md() + && *fwd_pd->workspace_md() == *workspace_md(); + } +}; + +#define DECLARE_COMMON_PD_t(impl_name, ...) \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual const char *name() const override { return impl_name; } +#define DECLARE_COMMON_PD_T(impl_name, ...) \ + DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__) + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp new file mode 100644 index 0000000000..43e5a31ef3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "memory.hpp" +#include "primitive.hpp" +#include "primitive_exec_types.hpp" + +namespace mkldnn { +namespace impl { + +status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, + const mkldnn_exec_arg_t *c_args, exec_args_t &args) { + using namespace status; + + if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments; + + int n_inputs = 0; + int n_outputs = 0; + + for (int i = 0; i < nargs; ++i) { + primitive_arg_index_t arg = c_args[i].arg; + auto *mem = c_args[i].memory; + + switch (pd->arg_usage(arg)) { + case primitive_desc_t::arg_usage_t::input: + if (args.count(arg) != 0) return invalid_arguments; + args[arg] = {mem, true}; + n_inputs++; + break; + case primitive_desc_t::arg_usage_t::output: + if (args.count(arg) != 0) return invalid_arguments; + args[arg] = {mem, false}; + n_outputs++; + break; + case primitive_desc_t::arg_usage_t::unused: + break; + } + } + + bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md()); + + if (n_inputs != pd->n_inputs()) return invalid_arguments; + if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0)) + return invalid_arguments; + + return success; +} + +const void *exec_ctx_t::input(primitive_arg_index_t arg) const { + if (args_.count(arg) != 1) return nullptr; + const auto ma = args_.at(arg); + assert(ma.is_const); + void *ptr; + status_t status = ma.mem->get_data_handle(&ptr); + assert(status == status::success); MAYBE_UNUSED(status); + return ptr; +} + +void *exec_ctx_t::output(primitive_arg_index_t arg) const { + if (args_.count(arg) != 1) return nullptr; + const auto ma = args_.at(arg); + assert(!ma.is_const); + void *ptr; + status_t status = ma.mem->get_data_handle(&ptr); + assert(status == status::success); MAYBE_UNUSED(status); + return ptr; +} + +const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const { + assert(args_.count(arg) == 1); + const auto ma = args_.at(arg); + assert(!ma.is_const); + return ma.mem; +} + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp new file mode 100644 index 0000000000..0645891da7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef PRIMITIVE_EXEC_TYPES_HPP +#define PRIMITIVE_EXEC_TYPES_HPP + +#include <unordered_map> + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "memory.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct memory_arg_t { + memory_t *mem; + bool is_const; +}; + +using exec_args_t = std::unordered_map<primitive_arg_index_t, memory_arg_t>; + +status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs, + const mkldnn_exec_arg_t *c_args, exec_args_t &args); + +/** Primitive execution context (helps passing stream, memories, and events. */ +struct exec_ctx_t { + exec_ctx_t(const exec_ctx_t &) = default; + exec_ctx_t(exec_ctx_t &&) = default; + + exec_ctx_t(stream_t *stream): stream_(stream) {} + exec_ctx_t(stream_t *stream, exec_args_t &&args) + : stream_(stream) + , args_(std::move(args)) {} + + stream_t *stream() const { return stream_; } + const exec_args_t &args() const { return args_; } + + /* tentative solution... TODO: replace with functions return memory_t */ + const void *input(primitive_arg_index_t arg) const; + void *output(primitive_arg_index_t arg) const; + + const memory_t *memory(primitive_arg_index_t arg) const; + +private: + stream_t *stream_; + exec_args_t args_; +}; + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp new file mode 100644 index 0000000000..5a1cd7d379 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "primitive_iterator.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +status_t mkldnn_primitive_desc_iterator_create( + primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc, + const primitive_attr_t *attr, engine_t *engine, + const primitive_desc_t *hint_fwd_pd) { + const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; + + auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd); + if (it == nullptr) return out_of_memory; + + ++(*it); + if (*it == it->end()) { + delete it; + return unimplemented; + } + + *iterator = it; + return success; +} + +status_t mkldnn_primitive_desc_iterator_next( + primitive_desc_iterator_t *iterator) { + if (iterator == nullptr) return invalid_arguments; + ++(*iterator); + return *iterator == iterator->end() ? iterator_ends : success; +} + +primitive_desc_t *mkldnn_primitive_desc_iterator_fetch( + const primitive_desc_iterator_t *iterator) { + if (iterator == nullptr) return nullptr; + return *(*iterator); +} + +status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc, + const primitive_desc_t *existing_primitive_desc) { + if (utils::any_null(primitive_desc, existing_primitive_desc)) + return invalid_arguments; + return safe_ptr_assign<primitive_desc_t>(*primitive_desc, + existing_primitive_desc->clone()); +} + +status_t mkldnn_primitive_desc_iterator_destroy( + primitive_desc_iterator_t *iterator) { + if (iterator != nullptr) + delete iterator; + return success; +} + +status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc, + const_c_op_desc_t c_op_desc, const primitive_attr_t *attr, + engine_t *engine, const primitive_desc_t *hint_fwd_pd) { + const op_desc_t *op_desc = (const op_desc_t *)c_op_desc; + + mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd); + ++it; + if (it == it.end()) return unimplemented; + + return safe_ptr_assign<primitive_desc_t>(*primitive_desc, *it); +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp new file mode 100644 index 0000000000..4e88ab3aa5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#ifndef PRIMITIVE_ITERATOR_HPP +#define PRIMITIVE_ITERATOR_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible { + using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; + + mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc, + const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd) + : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc) + , attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd) + , impl_list_(engine_->get_implementation_list()), last_idx_(0) + { + while (impl_list_[last_idx_] != nullptr) ++last_idx_; + } + ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; } + + bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const + { return idx_ == rhs.idx_ && engine_ == rhs.engine_; } + bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const + { return !operator==(rhs); } + + mkldnn::impl::primitive_desc_iterator_t end() const + { return mkldnn_primitive_desc_iterator(engine_, last_idx_); } + + mkldnn::impl::primitive_desc_iterator_t &operator++() { + if (pd_) { delete pd_; pd_ = nullptr; } + while (++idx_ != last_idx_) { + auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_, + hint_fwd_pd_); + if (s == mkldnn::impl::status::success) break; + } + return *this; + } + + mkldnn::impl::primitive_desc_t *operator*() const { + if (*this == end() || pd_ == nullptr) return nullptr; + return pd_->clone(); + } + +protected: + int idx_; + mkldnn::impl::engine_t *engine_; + mkldnn::impl::primitive_desc_t *pd_; + const mkldnn::impl::op_desc_t *op_desc_; + const mkldnn::impl::primitive_attr_t attr_; + const mkldnn::impl::primitive_desc_t *hint_fwd_pd_; + const pd_create_f *impl_list_; + int last_idx_; + +private: + mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx) + : idx_(last_idx), engine_(engine), pd_(nullptr) + , op_desc_(nullptr), hint_fwd_pd_(nullptr) + , impl_list_(nullptr), last_idx_(last_idx) {} +}; + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/query.cpp b/thirdparty/oidn/mkl-dnn/src/common/query.cpp new file mode 100644 index 0000000000..835cd73581 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/query.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "primitive_desc.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc, + query_t what, int index, void *result) { + if (any_null(primitive_desc, result)) + return invalid_arguments; + + return primitive_desc->query(what, index, result); +} + +const memory_desc_t *mkldnn_primitive_desc_query_md( + const primitive_desc_t *primitive_desc, query_t what, int index) { + const memory_desc_t *res_md = nullptr; + bool args_ok = true + && primitive_desc != nullptr + && (what & query::some_md) == query::some_md + && what != query::some_md + && mkldnn_primitive_desc_query(primitive_desc, + what, index, &res_md) == success; + return args_ok ? res_md : nullptr; +} + +int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc, + query_t what, int index) { + int res_s32; + bool args_ok = primitive_desc != nullptr + && one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32) + && mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32) + == success; + return args_ok ? res_s32 : 0; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp new file mode 100644 index 0000000000..d11f1a0361 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "reorder_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_reorder_primitive_desc_create( + primitive_desc_t **reorder_pd, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md, + const primitive_attr_t *attr) { + if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md)) + return invalid_arguments; + + auto s_ek = src_engine->kind(); + auto d_ek = dst_engine->kind(); + if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek))) + return invalid_arguments; + + auto r_pd = reinterpret_cast<reorder_pd_t **>(reorder_pd); + auto s_mdw = memory_desc_wrapper(*src_md); + auto d_mdw = memory_desc_wrapper(*dst_md); + + if (!s_mdw.consistent_with(d_mdw)) + return invalid_arguments; + + auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + for (auto r = e->get_reorder_implementation_list(); *r; ++r) { + if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md) + == success) { + (*r_pd)->init_info(); + (*r_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp new file mode 100644 index 0000000000..963cb0f58a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp @@ -0,0 +1,85 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REORDER_PD_HPP +#define REORDER_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "primitive_attr.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct reorder_pd_t: public primitive_desc_t { + reorder_pd_t(engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) + : primitive_desc_t(engine, attr, primitive_kind::reorder) + , src_engine_(src_engine) + , dst_engine_(dst_engine) + , scratchpad_engine_(nullptr) + , src_md_(*src_md) + , dst_md_(*dst_md) + {} + + virtual const op_desc_t *op_desc() const override { return nullptr; } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_FROM) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_TO) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &src_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + float alpha() const { return attr()->output_scales_.scales_[0]; } + float beta() const { + const int sum_idx = attr()->post_ops_.find(primitive_kind::sum); + return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale; + } + virtual mkldnn::impl::engine_t *scratchpad_engine() const override + { return scratchpad_engine_; } + +protected: + engine_t *src_engine_; + engine_t *dst_engine_; + engine_t *scratchpad_engine_; + + memory_desc_t src_md_; + memory_desc_t dst_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp new file mode 100644 index 0000000000..36967431a6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp @@ -0,0 +1,400 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu/gemm/os_blas.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::types; +using namespace mkldnn::impl::utils; + +namespace { +memory_desc_t copy_maybe_null(const memory_desc_t *md) { + return md ? *md : zero_md(); +} + +rnn_desc_t zero_rnn_desc() { + auto rd = rnn_desc_t(); + rd.src_layer_desc = zero_md(); + rd.src_iter_desc = zero_md(); + rd.weights_layer_desc = zero_md(); + rd.weights_iter_desc = zero_md(); + rd.bias_desc = zero_md(); + rd.dst_layer_desc = zero_md(); + rd.dst_iter_desc = zero_md(); + rd.diff_src_layer_desc = zero_md(); + rd.diff_src_iter_desc = zero_md(); + rd.diff_weights_layer_desc = zero_md(); + rd.diff_weights_iter_desc = zero_md(); + rd.diff_bias_desc = zero_md(); + rd.diff_dst_layer_desc = zero_md(); + rd.diff_dst_iter_desc = zero_md(); + return rd; +} +} + +/* Public C Api */ + +status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc, + mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f, + unsigned int flags, float alpha, float clipping) { + using namespace mkldnn::impl::alg_kind; + + bool args_ok = true + && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru, + gru_linear_before_reset) + && IMPLICATION(cell_kind == vanilla_rnn, + one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic)); + if (!args_ok) + return invalid_arguments; + + auto rcd = mkldnn_rnn_cell_desc_t(); + + rcd.cell_kind = cell_kind; + rcd.activation_kind = act_f; + rcd.flags = flags; + rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0; + rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0; + + *rnn_cell_desc = rcd; + + return success; +} + +int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 3; + case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3; + case mkldnn::impl::alg_kind::vanilla_lstm: return 4; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) { + switch (rnn_cell_desc->cell_kind) { + case mkldnn::impl::alg_kind::vanilla_rnn: return 1; + case mkldnn::impl::alg_kind::vanilla_gru: return 1; + case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1; + case mkldnn::impl::alg_kind::vanilla_lstm: return 2; + default: assert(!"unknown cell kind"); return 0; + } + return 0; +} + +status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc, + prop_kind_t prop_kind, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + using namespace data_type; + data_type_t src_layer_dt = src_layer_desc->data_type; + data_type_t dst_layer_dt = dst_layer_desc->data_type; + data_type_t weights_iter_dt = weights_iter_desc->data_type; + data_type_t weights_layer_dt = weights_layer_desc->data_type; + + bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt, + weights_layer_dt) + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + +#if USE_MKL_PACKED_GEMM + bool is_u8u8u8 = src_layer_dt == u8 + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == u8) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == u8) + && one_of(dst_layer_dt, u8, f32) + && everyone_is(s8, weights_iter_dt, weights_layer_dt) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + + bool is_f32u8f32 = src_layer_dt == u8 + && IMPLICATION(!is_zero_md(src_iter_desc), + src_iter_desc->data_type == f32) + && IMPLICATION(!is_zero_md(dst_iter_desc), + dst_iter_desc->data_type == f32) + && one_of(dst_layer_dt, u8, f32) + && everyone_is(s8, weights_iter_dt, weights_layer_dt) + && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32); + + bool is_inference = prop_kind == prop_kind::forward_inference; + bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm; + + return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference)) + ? success + : unimplemented; +#else + return is_f32 ? success : unimplemented; +#endif +} + +status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc, + rnn_direction_t direction, int L, int D, int T, int N, int S, int G, + int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok; + + // * algorithm specific + args_ok = true + && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru, + DIC == SIC); + if (!args_ok) return invalid_arguments; + int extra_bias = + rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset; + + // * on num layers + args_ok = true + && L == weights_layer_desc->dims[0] + && L == weights_iter_desc->dims[0] + && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0]) + && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0]) + && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]); + if (!args_ok) return invalid_arguments; + + // * on num directions + args_ok = true + && D == weights_layer_desc->dims[1] + && D == weights_iter_desc->dims[1] + && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1]) + && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1]) + && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]); + if (!args_ok) return invalid_arguments; + + // * on num iterations + args_ok = true + && T == src_layer_desc->dims[0] + && T == dst_layer_desc->dims[0]; + if (!args_ok) return invalid_arguments; + + // * on mb + args_ok = true + && N == src_layer_desc->dims[1] + && N == dst_layer_desc->dims[1] + && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3]) + && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]); + if (!args_ok) return invalid_arguments; + + // * on num gates + args_ok = true + && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc) + && G == weights_layer_desc->dims[3] + && G == weights_iter_desc->dims[3] + && IMPLICATION(!is_zero_md(bias_desc), + G + extra_bias == bias_desc->dims[2]); + if (!args_ok) return invalid_arguments; + + // * on num states + args_ok = true + && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc) + && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2]) + && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]); + if (!args_ok) return invalid_arguments; + + // * on slc + args_ok = true + && SLC == weights_layer_desc->dims[2] + && SLC == src_layer_desc->dims[2]; + if (!args_ok) return invalid_arguments; + + // * on sic + args_ok = true + && SIC == weights_iter_desc->dims[2] + && IMPLICATION(!is_zero_md(src_iter_desc), + SIC == src_iter_desc->dims[4]); + if (!args_ok) return invalid_arguments; + + // * on dlc + int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1; + args_ok = true + && DLC == dlc_multiplier * DIC + && DLC == dst_layer_desc->dims[2]; + if (!args_ok) return invalid_arguments; + + // * on dic + args_ok = true + && DIC == weights_layer_desc->dims[4] + && DIC == weights_iter_desc->dims[4] + && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3]) + && IMPLICATION(!is_zero_md(dst_iter_desc), + DIC == dst_iter_desc->dims[4]); + if (!args_ok) return invalid_arguments; + + // * unrolling/fusion conditions + args_ok = true + && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC) + && IMPLICATION(T > 1, SIC == DIC); + if (!args_ok) return invalid_arguments; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, + const memory_desc_t *dst_iter_desc) { + bool args_ok = true && rnn_cell_desc != nullptr + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc); + if (!args_ok) return invalid_arguments; + + //check dimensions consistency + int L = weights_layer_desc->dims[0]; + int T = src_layer_desc->dims[0]; + int N = src_layer_desc->dims[1]; + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); + int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); + int SLC = src_layer_desc->dims[2]; + int SIC = weights_iter_desc->dims[2]; + int DLC = dst_layer_desc->dims[2]; + int DIC = weights_layer_desc->dims[4]; + + CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, + weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, + dst_iter_desc)); + + CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind, + src_layer_desc, src_iter_desc, weights_layer_desc, + weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc)); + + // Create the descriptor + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + + *rnn_desc = rd; + + return success; +} + +status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, + prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc, + const rnn_direction_t direction, const memory_desc_t *src_layer_desc, + const memory_desc_t *src_iter_desc, + const memory_desc_t *weights_layer_desc, + const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc, + const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc, + const memory_desc_t *diff_src_layer_desc, + const memory_desc_t *diff_src_iter_desc, + const memory_desc_t *diff_weights_layer_desc, + const memory_desc_t *diff_weights_iter_desc, + const memory_desc_t *diff_bias_desc, + const memory_desc_t *diff_dst_layer_desc, + const memory_desc_t *diff_dst_iter_desc) { + bool args_ok = true + && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc, + dst_layer_desc, diff_src_layer_desc, + diff_weights_layer_desc, diff_weights_iter_desc, + diff_dst_layer_desc); + if (!args_ok) + return invalid_arguments; + + auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) { + return is_zero_md(a_md) == is_zero_md(b_md); + }; + + args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc) + && xnor_md(dst_iter_desc, diff_dst_iter_desc) + && xnor_md(src_iter_desc, diff_src_iter_desc); + if (!args_ok) + return invalid_arguments; + + //check dimensions consistency + int L = weights_layer_desc->dims[0]; + int T = src_layer_desc->dims[0]; + int N = src_layer_desc->dims[1]; + const int D = one_of(direction, mkldnn_unidirectional_left2right, + mkldnn_unidirectional_right2left) ? + 1 : + 2; + int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc); + int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc); + int SLC = src_layer_desc->dims[2]; + int SIC = weights_iter_desc->dims[2]; + int DLC = dst_layer_desc->dims[2]; + int DIC = weights_layer_desc->dims[4]; + + status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc, + weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc, + dst_iter_desc); + if (st != success) return st; + + st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S, + G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc, + diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc, + diff_dst_layer_desc, diff_dst_iter_desc); + if (st != success) return st; + + mkldnn_rnn_desc_t rd = zero_rnn_desc(); + + rd.primitive_kind = primitive_kind::rnn; + rd.prop_kind = prop_kind; + rd.cell_desc = *rnn_cell_desc; + rd.direction = direction; + + rd.src_layer_desc = copy_maybe_null(src_layer_desc); + rd.src_iter_desc = copy_maybe_null(src_iter_desc); + rd.weights_layer_desc = copy_maybe_null(weights_layer_desc); + rd.weights_iter_desc = copy_maybe_null(weights_iter_desc); + rd.bias_desc = copy_maybe_null(bias_desc); + rd.dst_layer_desc = copy_maybe_null(dst_layer_desc); + rd.dst_iter_desc = copy_maybe_null(dst_iter_desc); + rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc); + rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc); + rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc); + rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc); + rd.diff_bias_desc = copy_maybe_null(diff_bias_desc); + rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc); + rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc); + + *rnn_desc = rd; + + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp new file mode 100644 index 0000000000..1ee2ba1114 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef RNN_PD_HPP +#define RNN_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { + +struct rnn_fwd_pd_t; + +struct rnn_pd_t : public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::rnn; + + rnn_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , src_layer_md_(desc_.src_layer_desc) + , src_iter_md_(desc_.src_iter_desc) + , weights_layer_md_(desc_.weights_layer_desc) + , weights_iter_md_(desc_.weights_iter_desc) + , bias_md_(desc_.bias_desc) + , dst_layer_md_(desc_.dst_layer_desc) + , dst_iter_md_(desc_.dst_iter_desc) + , ws_md_() + {} + + const rnn_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::rnn_d: *(const rnn_desc_t **)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + virtual const memory_desc_t *src_md(int index = 0) const override { + if (index == 0) return &src_layer_md_; + if (index == 1 && with_src_iter()) return &src_iter_md_; + return nullptr; + } + virtual const memory_desc_t *weights_md(int index = 0) const override { + if (index == 0) return &weights_layer_md_; + if (index == 1) return &weights_iter_md_; + if (index == 2 && with_bias()) return &bias_md_; + return nullptr; + } + virtual const memory_desc_t *dst_md(int index = 0) const override { + if (index == 0) return &dst_layer_md_; + if (index == 1 && with_dst_iter()) return &dst_iter_md_; + return nullptr; + } + virtual const memory_desc_t *workspace_md(int index = 0) const override + { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; } + + /* common pooling aux functions */ + + bool is_training() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::backward); + } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + dim_t T() const { return desc_.src_layer_desc.dims[0]; } + dim_t MB() const { return desc_.src_layer_desc.dims[1]; } + + dim_t L() const { return desc_.weights_layer_desc.dims[0]; } + dim_t D() const { return desc_.weights_layer_desc.dims[1]; } + + dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; } + + dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; } + dim_t G() const { return desc_.weights_layer_desc.dims[3]; } + dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; } + + dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; } + + bool with_bias() const + { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); } + + bool with_src_iter() const + { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); } + + bool with_dst_iter() const + { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); } + + mkldnn::impl::alg_kind_t cell_kind() const + { return desc_.cell_desc.cell_kind; } + mkldnn::impl::alg_kind_t activation_kind() const + { return desc_.cell_desc.activation_kind; } + + bool is_lbr() const + { return cell_kind() == mkldnn_gru_linear_before_reset; } + + mkldnn_rnn_direction_t direction() const { return desc_.direction; } + +protected: + rnn_desc_t desc_; + const rnn_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t src_layer_md_; + memory_desc_t src_iter_md_; + memory_desc_t weights_layer_md_; + memory_desc_t weights_iter_md_; + memory_desc_t bias_md_; + memory_desc_t dst_layer_md_; + memory_desc_t dst_iter_md_; + + memory_desc_t ws_md_; +}; + +struct rnn_fwd_pd_t: public rnn_pd_t { + typedef rnn_fwd_pd_t base_class; + typedef rnn_fwd_pd_t hint_class; + + rnn_fwd_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC_LAYER) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter()) + return arg_usage_t::input; + + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, + MKLDNN_ARG_WEIGHTS_ITER)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_BIAS && with_bias()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST_LAYER) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter()) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && is_training()) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual int n_inputs() const override + { return 3 + with_bias() + with_src_iter(); } + virtual int n_outputs() const override + { return 1 + with_dst_iter() + is_training(); } +}; + +struct rnn_bwd_pd_t : public rnn_pd_t { + typedef rnn_bwd_pd_t base_class; + typedef rnn_fwd_pd_t hint_class; + + rnn_bwd_pd_t(engine_t *engine, + const rnn_desc_t *adesc, + const primitive_attr_t *attr, + const rnn_fwd_pd_t *hint_fwd_pd) + : rnn_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_src_layer_md_(desc_.diff_src_layer_desc) + , diff_src_iter_md_(desc_.diff_src_iter_desc) + , diff_weights_layer_md_(desc_.diff_weights_layer_desc) + , diff_weights_iter_md_(desc_.diff_weights_iter_desc) + , diff_bias_md_(desc_.diff_bias_desc) + , diff_dst_layer_md_(desc_.diff_dst_layer_desc) + , diff_dst_iter_md_(desc_.diff_dst_iter_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER, + MKLDNN_ARG_DIFF_DST_LAYER)) + return arg_usage_t::input; + + if (with_src_iter()) { + if (arg == MKLDNN_ARG_SRC_ITER) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC_ITER) + return arg_usage_t::output; + } + + if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER, + MKLDNN_ARG_WEIGHTS_ITER)) + return arg_usage_t::input; + + if (with_bias()) { + if (arg == MKLDNN_ARG_BIAS) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_BIAS) + return arg_usage_t::output; + } + + if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER) + && with_dst_iter()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_WORKSPACE) + return arg_usage_t::input; + + if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER, + MKLDNN_ARG_DIFF_WEIGHTS_LAYER, + MKLDNN_ARG_DIFF_WEIGHTS_ITER)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override { + if (index == 0) return &diff_src_layer_md_; + if (index == 1 && with_src_iter()) return &diff_src_iter_md_; + return nullptr; + } + virtual const memory_desc_t *diff_weights_md( + int index = 0) const override { + if (index == 0) return &diff_weights_layer_md_; + if (index == 1) return &diff_weights_iter_md_; + if (index == 2 && with_bias()) return &diff_bias_md_; + return nullptr; + } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override { + if (index == 0) return &diff_dst_layer_md_; + if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_; + return nullptr; + } + + virtual int n_inputs() const override + { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); } + virtual int n_outputs() const override + { return 3 + with_src_iter() + with_bias(); } + +protected: + memory_desc_t diff_src_layer_md_; + memory_desc_t diff_src_iter_md_; + memory_desc_t diff_weights_layer_md_; + memory_desc_t diff_weights_iter_md_; + memory_desc_t diff_bias_md_; + memory_desc_t diff_dst_layer_md_; + memory_desc_t diff_dst_iter_md_; +}; + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp new file mode 100644 index 0000000000..6bc14fc72a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "scratchpad.hpp" + +namespace mkldnn { +namespace impl { + +/* Allocating memory buffers on a page boundary to reduce TLB/page misses */ +const size_t page_size = 2097152; + +/* + Implementation of the scratchpad_t interface that is compatible with + a concurrent execution +*/ +struct concurent_scratchpad_t : public scratchpad_t { + concurent_scratchpad_t(size_t size) { + size_ = size; + scratchpad_ = (char *) malloc(size, page_size); + assert(scratchpad_ != nullptr); + } + + ~concurent_scratchpad_t() { + free(scratchpad_); + } + + virtual char *get() const { + return scratchpad_; + } + +private: + char *scratchpad_; + size_t size_; +}; + +/* + Implementation of the scratchpad_t interface that uses a global + scratchpad +*/ + +struct global_scratchpad_t : public scratchpad_t { + global_scratchpad_t(size_t size) { + if (size > size_) { + if (scratchpad_ != nullptr) free(scratchpad_); + size_ = size; + scratchpad_ = (char *) malloc(size, page_size); + assert(scratchpad_ != nullptr); + } + reference_count_++; + } + + ~global_scratchpad_t() { + reference_count_--; + if (reference_count_ == 0) { + free(scratchpad_); + scratchpad_ = nullptr; + size_ = 0; + } + } + + virtual char *get() const { + return scratchpad_; + } + +private: + /* + Using thread-local here is unnecessary and even buggy! All threads + actually share the same scratchpad, which is created and queried only + on the main thread. If the scratchpad is queried on some thread other + than the one it was created on (e.g. the application calls the API from + multiple threads), thread-local causes a segfault because the scratchpad + is uninitialized on the current thread. + */ + /*thread_local*/ static char *scratchpad_; + /*thread_local*/ static size_t size_; + /*thread_local*/ static unsigned int reference_count_; +}; + +/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr; +/*thread_local*/ size_t global_scratchpad_t::size_ = 0; +/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0; + + +/* + Scratchpad creation routine +*/ +scratchpad_t *create_scratchpad(size_t size) { +#ifndef MKLDNN_ENABLE_CONCURRENT_EXEC + return new global_scratchpad_t(size); +#else + return new concurent_scratchpad_t(size); +#endif +} + +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp new file mode 100644 index 0000000000..f7a246bc99 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_SCRATCHPAD_HPP +#define COMMON_SCRATCHPAD_HPP + +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct scratchpad_t { + virtual ~scratchpad_t() {} + virtual char *get() const = 0; +}; + +scratchpad_t *create_scratchpad(size_t size); + +} +} +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp new file mode 100644 index 0000000000..e32e735224 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp @@ -0,0 +1,72 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t shuffle_desc_init(shuffle_desc_t *shuffle_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, int axis, dim_t group_size) { + bool args_ok = true + && !any_null(shuffle_desc, data_desc) + && one_of(prop_kind, forward_training, forward_inference, + backward, backward_data) + && axis >= 0 && axis < data_desc->ndims + && group_size > 0 && group_size <= data_desc->dims[axis]; + if (!args_ok) return invalid_arguments; + + auto sd = shuffle_desc_t(); + sd.primitive_kind = primitive_kind::shuffle; + sd.prop_kind = prop_kind; + sd.data_desc = *data_desc; + sd.axis = axis; + sd.group_size = group_size; + + bool consistency = true + && sd.data_desc.dims[axis] % sd.group_size == 0; + if (!consistency) return invalid_arguments; + + *shuffle_desc = sd; + return success; +} +} + +status_t mkldnn_shuffle_forward_desc_init(shuffle_desc_t *shuffle_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, int axis, + dim_t group_size) { + if (!one_of(prop_kind, forward_training, forward_inference)) + return invalid_arguments; + return shuffle_desc_init(shuffle_desc, prop_kind, data_desc, axis, + group_size); +} + +status_t mkldnn_shuffle_backward_desc_init(shuffle_desc_t *shuffle_desc, + const memory_desc_t *diff_data_desc, int axis, dim_t group_size) { + return shuffle_desc_init(shuffle_desc, backward_data, diff_data_desc, axis, + group_size); +} + +// vim: et ts=5 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp new file mode 100644 index 0000000000..cc5553fe7f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp @@ -0,0 +1,121 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SHUFFLE_PD_HPP +#define SHUFFLE_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct shuffle_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::shuffle; + + typedef shuffle_pd_t base_class; + typedef shuffle_pd_t hint_class; + + shuffle_pd_t(engine_t *engine, + const shuffle_desc_t *adesc, + const primitive_attr_t *attr, + const shuffle_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const shuffle_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::shuffle_d: + *(const shuffle_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (is_fwd()) { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + } else { + if (arg == MKLDNN_ARG_DIFF_DST) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + } + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 && is_fwd() ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 && is_fwd() ? &data_md_ : nullptr; } + + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 && !is_fwd() ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override { return 1; } + + /* shuffle aux functions */ + + dim_t MB() const { return data_md()->dims[0]; } + dim_t C() const { return ndims() >= 2 ? data_md()->dims[1] : 1; } + dim_t D() const { return ndims() >= 5 ? data_md()->dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_md()->dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_md()->dims[ndims() - 1] : 1; } + + int ndims() const { return data_md()->ndims; } + + int axis() const { return desc_.axis; } + dim_t group_size() const { return desc_.group_size; } + dim_t axis_size() const { return data_md()->dims[axis()]; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + + const memory_desc_t *data_md() const { return &data_md_; } + +protected: + shuffle_desc_t desc_; + const shuffle_pd_t *hint_fwd_pd_; + memory_desc_t data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp new file mode 100644 index 0000000000..82848e3d1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::alg_kind; +using namespace mkldnn::impl::types; + +namespace { +status_t softmax_desc_init(softmax_desc_t *softmax_desc, prop_kind_t prop_kind, + const memory_desc_t *data_desc, const memory_desc_t *diff_desc, int softmax_axis) { + bool args_ok = true + && !any_null(softmax_desc, data_desc) + && 0 <= softmax_axis + && softmax_axis < data_desc->ndims; + if (!args_ok) return invalid_arguments; + + auto sd = softmax_desc_t(); + sd.primitive_kind = primitive_kind::softmax; + sd.prop_kind = prop_kind; + + bool is_bwd = (sd.prop_kind == backward_data); + sd.data_desc = *data_desc; + sd.diff_desc = is_bwd ? *diff_desc : zero_md(); + sd.softmax_axis = softmax_axis; + + *softmax_desc = sd; + return success; +} +} + +status_t mkldnn_softmax_forward_desc_init(softmax_desc_t *softmax_desc, + prop_kind_t prop_kind, const memory_desc_t *data_desc, + int softmax_axis) { + if (!one_of(prop_kind, forward_inference, forward_training)) + return invalid_arguments; + return softmax_desc_init(softmax_desc, prop_kind, data_desc, nullptr, softmax_axis); +} + +status_t mkldnn_softmax_backward_desc_init(softmax_desc_t *softmax_desc, + const memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, + int softmax_axis) { + return softmax_desc_init(softmax_desc, prop_kind::backward_data, + data_desc, diff_desc, softmax_axis); +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp new file mode 100644 index 0000000000..8a16ce901c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp @@ -0,0 +1,161 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SOFTMAX_PD_HPP +#define SOFTMAX_PD_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "primitive_desc.hpp" + +namespace mkldnn { +namespace impl { + +struct softmax_fwd_pd_t; + +struct softmax_pd_t: public primitive_desc_t { + static constexpr auto base_pkind = primitive_kind::softmax; + + softmax_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : primitive_desc_t(engine, attr, base_pkind) + , desc_(*adesc) + , hint_fwd_pd_(hint_fwd_pd) + , data_md_(desc_.data_desc) + {} + + const softmax_desc_t *desc() const { return &desc_; } + virtual const op_desc_t *op_desc() const override + { return reinterpret_cast<const op_desc_t *>(this->desc()); } + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual status_t query(query_t what, int idx, void *result) const override { + switch (what) { + case query::softmax_d: + *(const softmax_desc_t**)result = desc(); break; + default: return primitive_desc_t::query(what, idx, result); + } + return status::success; + } + + /* common softmax aux functions */ + + dim_t MB() const { return data_desc().dims[0]; } + dim_t C() const { return data_desc().dims[1]; } + dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; } + dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; } + dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; } + + int ndims() const { return data_desc().ndims; } + + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + } + +protected: + softmax_desc_t desc_; + const softmax_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t data_md_; + +private: + const memory_desc_t &data_desc() const { return desc_.data_desc; } +}; + +struct softmax_fwd_pd_t: public softmax_pd_t { + typedef softmax_fwd_pd_t base_class; + typedef softmax_fwd_pd_t hint_class; + + softmax_fwd_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg == MKLDNN_ARG_SRC) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + + virtual int n_inputs() const override { return 1; } + virtual int n_outputs() const override + { return 1 + (workspace_md() != nullptr); } +}; + +struct softmax_bwd_pd_t: public softmax_pd_t { + typedef softmax_bwd_pd_t base_class; + typedef softmax_fwd_pd_t hint_class; + + softmax_bwd_pd_t(engine_t *engine, + const softmax_desc_t *adesc, + const primitive_attr_t *attr, + const softmax_fwd_pd_t *hint_fwd_pd) + : softmax_pd_t(engine, adesc, attr, hint_fwd_pd) + , diff_data_md_(desc_.diff_desc) + {} + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (utils::one_of(arg, MKLDNN_ARG_DST, MKLDNN_ARG_DIFF_DST)) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DIFF_SRC) + return arg_usage_t::output; + + if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr)) + return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &data_md_ : nullptr; } + virtual const memory_desc_t *diff_dst_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + virtual const memory_desc_t *diff_src_md(int index = 0) const override + { return index == 0 ? &diff_data_md_ : nullptr; } + + virtual int n_inputs() const override + { return 2 + (workspace_md() != nullptr); } + virtual int n_outputs() const override { return 1; } + +protected: + memory_desc_t diff_data_md_; +}; + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp new file mode 100644 index 0000000000..00af8935c0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "stream.hpp" +#include "utils.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::status; + +/* API */ + +status_t mkldnn_stream_create(stream_t **stream, engine_t *engine, + unsigned flags) { + bool args_ok = true + && !utils::any_null(stream, engine) + && flags == stream_flags::default_flags; + if (!args_ok) + return invalid_arguments; + + return safe_ptr_assign<stream_t>(*stream, new stream_t(engine, flags)); +} + +status_t mkldnn_stream_destroy(stream_t *stream) { + delete stream; + return success; +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp new file mode 100644 index 0000000000..f010e5f6ed --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp @@ -0,0 +1,44 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef STREAM_HPP +#define STREAM_HPP + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" + +struct mkldnn_stream: public mkldnn::impl::c_compatible { + mkldnn_stream(mkldnn::impl::engine_t *engine, unsigned flags) + : engine_(engine), flags_(flags) {} + virtual ~mkldnn_stream() {} + + /** returns stream's engine */ + mkldnn::impl::engine_t *engine() const { return engine_; } + + /** returns stream's kind */ + unsigned flags() const { return flags_; } + +protected: + mkldnn::impl::engine_t *engine_; + unsigned flags_; +}; + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp new file mode 100644 index 0000000000..365663c0f8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp @@ -0,0 +1,79 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "engine.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "sum_pd.hpp" + +using namespace mkldnn::impl; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::status; + +status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd, + const memory_desc_t *dst_md, int n, const float *scales, + const memory_desc_t *src_mds, const primitive_attr_t *attr, + engine_t *engine) { + bool args_ok = !any_null(sum_pd, src_mds, scales) && n > 0; + if (!args_ok) return invalid_arguments; + + const primitive_attr_t dummy_attr; + if (attr == NULL) + attr = &dummy_attr; + + const int ndims = src_mds[0].ndims; + const dims_t &dims = src_mds[0].dims; + const data_type_t dt = src_mds[0].data_type; + + for (int i = 1; i < n; ++i) { + if (src_mds[i].ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (src_mds[i].dims[d] != dims[d]) + return invalid_arguments; + } + if (src_mds[i].data_type != dt) return invalid_arguments; + } + + memory_desc_t dummy_dst_md; + if (dst_md) { + if (dst_md->ndims != ndims) return invalid_arguments; + for (int d = 0; d < ndims; ++d) { + if (dst_md->dims[d] != dims[d]) + return invalid_arguments; + } + } else { + dummy_dst_md = src_mds[0]; + dummy_dst_md.format_kind = format_kind::any; + dst_md = &dummy_dst_md; + } + + auto s_pd = reinterpret_cast<sum_pd_t **>(sum_pd); + + for (auto s = engine->get_sum_implementation_list(); *s; ++s) { + if ((*s)(s_pd, engine, attr, dst_md, n, scales, src_mds) == success) { + (*s_pd)->init_info(); + (*s_pd)->init_scratchpad_md(); + return success; + } + } + return unimplemented; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp new file mode 100644 index 0000000000..80254667df --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp @@ -0,0 +1,143 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SUM_PD_HPP +#define SUM_PD_HPP + +#include <assert.h> +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "primitive_desc.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +struct sum_pd_t: public primitive_desc_t { + sum_pd_t(engine_t *engine, const primitive_attr_t *attr, + const memory_desc_t *dst_md, int n, const float *scales, + const memory_desc_t *src_mds) + : primitive_desc_t(engine, attr, primitive_kind::sum) + , n_(n), dst_md_(*dst_md) + { + scales_.reserve(n_); + for (int i = 0; i < n_; ++i) scales_.push_back(scales[i]); + src_mds_.reserve(n_); + for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]); + } + + virtual void init_info() override { impl::init_info(this, this->info_); } + + virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override { + if (arg >= MKLDNN_ARG_MULTIPLE_SRC + && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs()) + return arg_usage_t::input; + + if (arg == MKLDNN_ARG_DST) + return arg_usage_t::output; + + return primitive_desc_t::arg_usage(arg); + } + + virtual const memory_desc_t *src_md(int index = 0) const override + { return index < n_inputs() ? &src_mds_[index] : nullptr; } + virtual const memory_desc_t *dst_md(int index = 0) const override + { return index == 0 ? &dst_md_ : nullptr; } + + virtual int n_inputs() const override { return n_; } + virtual int n_outputs() const override { return 1; } + + const float *scales() const { return &scales_[0]; } + +protected: + int n_; + nstl::vector<float> scales_; + memory_desc_t dst_md_; + nstl::vector<memory_desc_t> src_mds_; + +protected: + /* inits dst_md_ in simple cases. The call may fail. */ + status_t init() { + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(&src_mds_[i]); + if (!src_d.is_blocking_desc() || src_d.is_additional_buffer()) + return status::unimplemented; + } + bool ok = true + && set_default_params() == status::success + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + status_t set_default_params() { + if (dst_md_.format_kind != format_kind::any) + return status::success; + + /* The stupidest ever heuristics (but not the same as we had before): + * - Pick the first non-plain format; + * - If all formats are plain, pick the format of the first input + */ + for (int i = 0; i < n_; ++i) { + const memory_desc_wrapper src_d(src_mds_[i]); + if (!src_d.is_plain() && src_d.is_blocking_desc()) { + return memory_desc_init_by_blocking_desc(dst_md_, + src_d.blocking_desc()); + } + } + + if (src_mds_[0].format_kind != format_kind::blocked) + return status::unimplemented; + + dst_md_ = src_mds_[0]; + + return status::success; + } +}; + +#define DECLARE_SUM_PD_t(impl_name, ...) \ + static status_t create(sum_pd_t **sum_pd, \ + engine_t *engine, const primitive_attr_t *attr, \ + const memory_desc_t *dst_md, int n, const float *scales, \ + const memory_desc_t *src_mds) { \ + using namespace status; \ + auto _pd = new pd_t(engine, attr, dst_md, n, scales, src_mds); \ + if (_pd == nullptr) return out_of_memory; \ + if (_pd->init() != success) { delete _pd; return unimplemented; } \ + return safe_ptr_assign<sum_pd_t>(*sum_pd, _pd); \ + } \ + virtual status_t create_primitive(primitive_t **p) const override { \ + double ms = get_msec(); \ + auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \ + ms = get_msec() - ms; \ + if (mkldnn_verbose()->level >= 2) { \ + printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \ + fflush(0); \ + } \ + return ret; \ + } \ + virtual pd_t *clone() const override { return new pd_t(*this); } \ + virtual const char *name() const override { return impl_name; } \ + +#define DECLARE_SUM_PD_T(impl_name, ...) \ + DECLARE_SUM_PD_t(impl_name, __VA_ARGS__) + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp new file mode 100644 index 0000000000..a408f45980 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp @@ -0,0 +1,200 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TAG_TRAITS_HPP +#define TAG_TRAITS_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +enum class block_dim_t { + _, + _A, _B, + _AB, _BC, +}; + +enum class inner_blk_t { + _, + _4a, _4b, + _8a, _8b, + _16a, _16b, + + _4b4a, _4b4c, _4c4b, + _8a8b, _8b8a, _8b8c, _8c8b, + _16a16b, _16a4b, _16b16a, _16b4c, _16b16c, _16c16b, + + _2c8b4c, _8a16b2a, _4b16a4b, _8b16a2b, _8b16c2b, _4c16b4c, _8c16b2c, +}; + +/** returns the offset within the block for weights blocked over oc and ic */ +template <inner_blk_t f> +constexpr int AB_or_BC_blk_off(int x0, int x1) { + using ib = inner_blk_t; + static_assert(utils::one_of(f, ib::_4b4a, ib::_4b4c, ib::_4c4b, ib::_8a8b, + ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b, ib::_16a4b, + ib::_16b16a, ib::_16b4c, ib::_16b16c, ib::_16c16b, ib::_2c8b4c, + ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, + ib::_4c16b4c, ib::_8c16b2c), + "unexpected inner_blk format"); + return false ? 0 + : (f == ib::_4b4c) ? 4 * x0 + x1 + : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0 + : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1 + : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0 + : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1 + : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0 + : (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1 + : (f == ib::_8a16b2a || f == ib::_8b16c2b) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2 + : (f == ib::_4b16a4b || f == ib::_4c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4 + : (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2 + : (f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4 + : INT_MIN; +} + +template <inner_blk_t b> struct inner_blk_traits { + using ib = inner_blk_t; +}; + +template <format_tag_t> struct tag_traits { + // block_dim_t block_dims; + // inner_blk_t inner_blks; + // int ndims; +}; + +#define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \ +template <> struct tag_traits<format_tag::_tag> { \ + static constexpr block_dim_t block_dims = block_dim_t::_blk_fmt; \ + static constexpr inner_blk_t inner_blks = inner_blk_t::_inner_blk; \ + static constexpr int ndims = _ndims; \ +} + +DECL_TRAITS(a, _, _, 1); +DECL_TRAITS(ab, _, _, 2); +DECL_TRAITS(abc, _, _, 3); +DECL_TRAITS(abcd, _, _, 4); +DECL_TRAITS(abcde, _, _, 5); +DECL_TRAITS(abcdef, _, _, 6); +DECL_TRAITS(abdec, _, _, 5); +DECL_TRAITS(acb, _, _, 3); +DECL_TRAITS(acbde, _, _, 5); +DECL_TRAITS(acdb, _, _, 4); +DECL_TRAITS(acdeb, _, _, 5); +DECL_TRAITS(ba, _, _, 2); +DECL_TRAITS(bac, _, _, 3); +DECL_TRAITS(bacd, _, _, 4); +DECL_TRAITS(bcda, _, _, 4); +DECL_TRAITS(cba, _, _, 3); +DECL_TRAITS(cdba, _, _, 4); +DECL_TRAITS(cdeba, _, _, 5); +DECL_TRAITS(decab, _, _, 5); + +DECL_TRAITS(Abc4a, _A, _4a, 3); +DECL_TRAITS(aBc4b, _B, _4b, 3); +DECL_TRAITS(ABc4b16a4b, _AB, _4b16a4b, 3); +DECL_TRAITS(ABc4b4a, _AB, _4b4a, 3); +DECL_TRAITS(Abcd4a, _A, _4a, 4); +DECL_TRAITS(aBcd4b, _B, _4b, 4); +DECL_TRAITS(ABcd4b4a, _AB, _4b4a, 4); +DECL_TRAITS(aBCd4c16b4c, _BC, _4c16b4c, 4); +DECL_TRAITS(aBCd4c4b, _BC, _4c4b, 4); +DECL_TRAITS(Abcde4a, _A, _4a, 5); +DECL_TRAITS(aBcde4b, _B, _4b, 5); +DECL_TRAITS(ABcde4b4a, _AB, _4b4a, 5); +DECL_TRAITS(aBCde4c4b, _BC, _4c4b, 5); +DECL_TRAITS(aBcdef4b, _B, _4b, 6); +DECL_TRAITS(aBCdef4c4b, _BC, _4c4b, 6); +DECL_TRAITS(aBdc4b, _B, _4b, 4); +DECL_TRAITS(aBdec4b, _B, _4b, 5); +DECL_TRAITS(aBdefc4b, _B, _4b, 6); +DECL_TRAITS(Acb4a, _A, _4a, 3); +DECL_TRAITS(Acdb4a, _A, _4a, 4); +DECL_TRAITS(Acdeb4a, _A, _4a, 5); + +DECL_TRAITS(Abc16a, _A, _16a, 3); +DECL_TRAITS(ABc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(aBc16b, _B, _16b, 3); +DECL_TRAITS(ABc16b16a, _AB, _16b16a, 3); +DECL_TRAITS(ABc8a16b2a, _AB, _8a16b2a, 3); +DECL_TRAITS(ABc8a8b, _AB, _8a8b, 3); +DECL_TRAITS(aBc8b, _B, _8b, 3); +DECL_TRAITS(ABc8b16a2b, _AB, _8b16a2b, 3); +DECL_TRAITS(ABc8b8a, _AB, _8b8a, 3); +DECL_TRAITS(Abcd16a, _A, _16a, 4); +DECL_TRAITS(ABcd16a16b, _AB, _16a16b, 4); +DECL_TRAITS(aBcd16b, _B, _16b, 4); +DECL_TRAITS(ABcd16b16a, _AB, _16b16a, 4); +DECL_TRAITS(aBCd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aBCd16c16b, _BC, _16c16b, 4); +DECL_TRAITS(ABcd4b16a4b, _AB, _4b16a4b, 4); +DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4); +DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4); +DECL_TRAITS(aBcd8b, _B, _8b, 4); +DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4); +DECL_TRAITS(aBCd8b16c2b, _BC, _8b16c2b, 4); +DECL_TRAITS(ABcd8b8a, _AB, _8b8a, 4); +DECL_TRAITS(aBCd8b8c, _BC, _8b8c, 4); +DECL_TRAITS(aBCd8c16b2c, _BC, _8c16b2c, 4); +DECL_TRAITS(aBCd8c8b, _BC, _8c8b, 4); +DECL_TRAITS(Abcde16a, _A, _16a, 5); +DECL_TRAITS(ABcde16a16b, _AB, _16a16b, 5); +DECL_TRAITS(aBcde16b, _B, _16b, 5); +DECL_TRAITS(ABcde16b16a, _AB, _16b16a, 5); +DECL_TRAITS(aBCde16b16c, _BC, _16b16c, 5); +DECL_TRAITS(aBCde16c16b, _BC, _16c16b, 5); +DECL_TRAITS(aBCde4c16b4c, _BC, _4c16b4c, 5); +DECL_TRAITS(Abcde8a, _A, _8a, 5); +DECL_TRAITS(ABcde8a8b, _AB, _8a8b, 5); +DECL_TRAITS(aBcde8b, _B, _8b, 5); +DECL_TRAITS(ABcde8b16a2b, _AB, _8b16a2b, 5); +DECL_TRAITS(aBCde8b16c2b, _BC, _8b16c2b, 5); +DECL_TRAITS(ABcde8b8a, _AB, _8b8a, 5); +DECL_TRAITS(aBCde8b8c, _BC, _8b8c, 5); +DECL_TRAITS(aBCde2c8b4c, _BC, _2c8b4c, 5); +DECL_TRAITS(aBCde8c16b2c, _BC, _8c16b2c, 5); +DECL_TRAITS(aBCde4b4c, _BC, _4b4c, 5); +DECL_TRAITS(aBCde8c8b, _BC, _8c8b, 5); +DECL_TRAITS(aBcdef16b, _B, _16b, 6); +DECL_TRAITS(aBCdef16b16c, _BC, _16b16c, 6); +DECL_TRAITS(aBCdef16c16b, _BC, _16c16b, 6); +DECL_TRAITS(aBCdef8b8c, _BC, _8b8c, 6); +DECL_TRAITS(aBCdef8c16b2c, _BC, _8c16b2c, 6); +DECL_TRAITS(aBCdef8c8b, _BC, _8c8b, 6); +DECL_TRAITS(aBdc16b, _B, _16b, 4); +DECL_TRAITS(aBdc8b, _B, _8b, 4); +DECL_TRAITS(aBdec16b, _B, _16b, 5); +DECL_TRAITS(aBdec8b, _B, _8b, 5); +DECL_TRAITS(aBdefc16b, _B, _16b, 6); +DECL_TRAITS(aBdefc8b, _B, _8b, 6); +DECL_TRAITS(Acb16a, _A, _16a, 3); +DECL_TRAITS(Acb8a, _A, _8a, 3); +DECL_TRAITS(aCBd16b16c, _BC, _16b16c, 4); +DECL_TRAITS(aCBde16b16c, _BC, _16b16c, 5); +DECL_TRAITS(Acdb16a, _A, _16a, 4); +DECL_TRAITS(Acdb8a, _A, _8a, 4); +DECL_TRAITS(Acdeb16a, _A, _16a, 5); +DECL_TRAITS(Acdeb8a, _A, _8a, 5); +DECL_TRAITS(BAc16a16b, _AB, _16a16b, 3); +DECL_TRAITS(BAcd16a16b, _AB, _16a16b, 4); + +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp new file mode 100644 index 0000000000..4f06368738 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp @@ -0,0 +1,348 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef TYPE_HELPERS_HPP +#define TYPE_HELPERS_HPP + +#include <assert.h> +#include <math.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "mkldnn_traits.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "math_utils.hpp" + +namespace mkldnn { +namespace impl { + +template <typename T> +status_t safe_ptr_assign(T * &lhs, T* rhs) { + if (rhs == nullptr) return status::out_of_memory; + lhs = rhs; + return status::success; +} + +template <typename T, typename U> struct is_subset +{ static constexpr bool value = false; }; +template <typename T> struct is_subset<T, T> +{ static constexpr bool value = true; }; +template <typename T> struct is_subset<T, + typename utils::enable_if<nstl::is_integral<T>::value, float>::type> +{ static constexpr bool value = true; }; +#define ISSPEC(t1, t2) template <> \ + struct is_subset<t1, t2> { static constexpr bool value = true; } +ISSPEC(int16_t, int32_t); +ISSPEC(int8_t, int32_t); +ISSPEC(uint8_t, int32_t); +ISSPEC(int8_t, int16_t); +ISSPEC(uint8_t, int16_t); +#undef ISSPEC + +inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs); + +namespace types { + +inline size_t data_type_size(data_type_t data_type) { + using namespace data_type; + switch (data_type) { + case f32: return sizeof(prec_traits<f32>::type); + case s32: return sizeof(prec_traits<s32>::type); + case s8: return sizeof(prec_traits<s8>::type); + case u8: return sizeof(prec_traits<u8>::type); + case data_type::undef: + default: assert(!"unknown data_type"); + } + return 0; /* not supposed to be reachable */ +} + +inline format_kind_t format_tag_to_kind(format_tag_t tag) { + switch (tag) { + case format_tag::undef: return format_kind::undef; + case format_tag::any: return format_kind::any; + case format_tag::last: return format_kind::undef; + default: return format_kind::blocked; + } + + assert(!"unreachable"); + return format_kind::undef; +} + +inline bool memory_extra_desc_is_equal(const memory_extra_desc_t &lhs, + const memory_extra_desc_t &rhs) { + return true + && lhs.flags == rhs.flags + && IMPLICATION(lhs.flags & memory_extra_flags::compensation_conv_s8s8, + lhs.compensation_mask == rhs.compensation_mask) + && IMPLICATION(lhs.flags & memory_extra_flags::scale_adjust, + lhs.scale_adjust == rhs.scale_adjust); +} + +inline bool blocking_desc_is_equal(const blocking_desc_t &lhs, + const blocking_desc_t &rhs, int ndims = MKLDNN_MAX_NDIMS) { + using mkldnn::impl::utils::array_cmp; + return true + && lhs.inner_nblks == rhs.inner_nblks + && array_cmp(lhs.strides, rhs.strides, ndims) + && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks) + && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks); +} + +inline bool wino_desc_is_equal(const wino_desc_t &lhs, + const wino_desc_t &rhs) { + return lhs.wino_format == rhs.wino_format + && lhs.alpha == rhs.alpha + && lhs.ic == rhs.ic + && lhs.oc == rhs.oc + && lhs.ic_block == rhs.ic_block + && lhs.oc_block == rhs.oc_block + && lhs.ic2_block == rhs.ic2_block + && lhs.oc2_block == rhs.oc2_block + && lhs.r == rhs.r; +} + +inline bool rnn_packed_desc_is_equal( + const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) { + bool ok = true + && lhs.format == rhs.format + && lhs.n_parts == rhs.n_parts + && lhs.offset_compensation == rhs.offset_compensation + && lhs.size == rhs.size + && lhs.n == rhs.n; + if (!ok) + return false; + + for (int i = 0; i < rhs.n_parts; i++) + ok = ok && lhs.parts[i] == rhs.parts[i]; + for (int i = 0; i < rhs.n_parts; i++) + ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i]; + return ok; +} + +inline memory_desc_t zero_md() { + auto zero = memory_desc_t(); + return zero; +} + +inline bool is_zero_md(const memory_desc_t *md) { + return md == nullptr || *md == zero_md(); +} + +inline data_type_t default_accum_data_type(data_type_t src_dt, + data_type_t dst_dt) { + using namespace utils; + using namespace data_type; + + if (one_of(f32, src_dt, dst_dt)) return f32; + if (one_of(s32, src_dt, dst_dt)) return s32; + + if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32; + + assert(!"unimplemented use-case: no default parameters available"); + return dst_dt; +} + +inline data_type_t default_accum_data_type(data_type_t src_dt, + data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) { + using namespace utils; + using namespace data_type; + using namespace prop_kind; + + /* prop_kind doesn't matter */ + if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32; + + if (one_of(prop_kind, forward_training, forward_inference)) { + if ((src_dt == u8 || src_dt == s8) + && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8)) + return s32; + } else if (prop_kind == backward_data) { + if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 && + one_of(dst_dt, s8, u8)) + return s32; + } + + assert(!"unimplemented use-case: no default parameters available"); + return dst_dt; +} + +} + +inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) { + using namespace mkldnn::impl::utils; + bool base_equal = true + && lhs.ndims == rhs.ndims + && array_cmp(lhs.dims, rhs.dims, lhs.ndims) + && lhs.data_type == rhs.data_type + && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims) + && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims) + && lhs.offset0 == rhs.offset0 + && lhs.format_kind == rhs.format_kind; + if (!base_equal) return false; + if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false; + if (lhs.format_kind == format_kind::blocked) + return types::blocking_desc_is_equal(lhs.format_desc.blocking, + rhs.format_desc.blocking, lhs.ndims); + else if (lhs.format_kind == format_kind::wino) + return types::wino_desc_is_equal(lhs.format_desc.wino_desc, + rhs.format_desc.wino_desc); + else if (lhs.format_kind == format_kind::rnn_packed) + return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc, + rhs.format_desc.rnn_packed_desc); + return true; +} + +inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) { + return !operator==(lhs, rhs); +} + +inline status_t memory_desc_init_by_strides(memory_desc_t &md, + const dims_t strides) { + return mkldnn_memory_desc_init_by_strides( + &md, md.ndims, md.dims, md.data_type, strides); +} + +inline status_t memory_desc_init_by_tag(memory_desc_t &md, format_tag_t tag, + const dims_t strides = nullptr) { + status_t status = mkldnn_memory_desc_init_by_tag( + &md, md.ndims, md.dims, md.data_type, tag); + if (status != status::success || strides == nullptr) + return status; + + /* TODO: add consistency check */ + + for (int d = 0; d < md.ndims; ++d) + md.format_desc.blocking.strides[d] = strides[d]; + + return status::success; +} + +/** inits memory descriptor based on logical dimensions kept in @p md, and the + * blocking structure @p blk. + * + * @note blk.strides represent the order only (from smaller to bigger) + * + * TODO: move md related functions to one single place + */ +inline status_t memory_desc_init_by_blocking_desc(memory_desc_t &md, + const blocking_desc_t &blk) { + dims_t blocks = {0}; + utils::array_set(blocks, 1, md.ndims); + dim_t block_size = 1; + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk]; + block_size *= blk.inner_blks[iblk]; + } + + for (int d = 0; d < md.ndims; ++d) { + md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]); + md.padded_offsets[d] = 0; + } + md.offset0 = 0; + + md.format_kind = format_kind::blocked; + auto &mblk = md.format_desc.blocking; + mblk = blk; + + const int ndims = nstl::min(MKLDNN_MAX_NDIMS, md.ndims); // make GCC 5 happy + utils::array_copy(mblk.strides, blk.strides, ndims); + + int perm[MKLDNN_MAX_NDIMS]; + for (int d = 0; d < ndims; ++d) perm[d] = d; + + utils::simultaneous_sort(mblk.strides, perm, ndims, + [](stride_t a, stride_t b) { return b - a; }); + + dim_t stride = block_size; + for (int _d = ndims - 1; _d >= 0; --_d) { + const int d = perm[_d]; + md.format_desc.blocking.strides[d] = stride; + stride *= md.padded_dims[d] / blocks[d]; + } + + md.extra = utils::zero<memory_extra_desc_t>(); + + return status::success; +} + +/** returns true if memory desc @p md corresponds to the given format tag and + * strides. + * If strides are not passed (or passed as nullptr) the dense structure is + * assumed (i.e. the one that mkldnn_memory_desc_init_by_tag() returns). + * Strides might contain `0` value, indicating the stride must match the one + * that mkldnn_memory_desc_init_by_tag() returns. + * Strides might contain `-1` values, that would be ignored during the + * comparison. For instance, this can be used if a stride along minibatch + * doesn't matter. */ +inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag, + const dims_t strides = nullptr) { + if (md.format_kind != types::format_tag_to_kind(tag)) + return false; + + memory_desc_t md_gold; + status_t status = mkldnn_memory_desc_init_by_tag( + &md_gold, md.ndims, md.dims, md.data_type, tag); + if (status != status::success) return false; + + if (md.format_kind != format_kind::blocked) + return false; // unimplemented yet + + const auto &blk = md.format_desc.blocking; + const auto &blk_gold = md_gold.format_desc.blocking; + + using utils::array_cmp; + bool same_blocks = true + && blk.inner_nblks == blk_gold.inner_nblks + && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks) + && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks); + + if (!same_blocks) + return false; + + if (strides == nullptr) + return array_cmp(blk.strides, blk_gold.strides, md.ndims); + + for (int d = 0; d < md.ndims; ++d) { + dim_t stride = strides[d]; + if (stride == -1) continue; + if (stride == 0) stride = blk_gold.strides[d]; + if (blk.strides[d] != stride) return false; + } + + return true; +} + +/** returns matching tag (or undef if match is not found) + * XXX: This is a workaround that eventually should go away! */ +template <typename... Tags> +format_tag_t memory_desc_matches_one_of_tag(const memory_desc_t &md, + Tags ...tags) { + for (const auto tag: {tags...}) { + if (memory_desc_matches_tag(md, tag)) + return tag; + } + return format_tag::undef; +} + +} +} + +#include "memory_desc_wrapper.hpp" + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp new file mode 100644 index 0000000000..d23f4682dc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp @@ -0,0 +1,135 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <string.h> +#ifdef _WIN32 +#include <malloc.h> +#include <windows.h> +#endif +#include <limits.h> +#include <stdlib.h> +#include <stdio.h> + +#include "mkldnn.h" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { + +int getenv(const char *name, char *buffer, int buffer_size) { + if (name == NULL || buffer_size < 0 || (buffer == NULL && buffer_size > 0)) + return INT_MIN; + + int result = 0; + int term_zero_idx = 0; + size_t value_length = 0; + +#ifdef _WIN32 + value_length = GetEnvironmentVariable(name, buffer, buffer_size); +#else + const char *value = ::getenv(name); + value_length = value == NULL ? 0 : strlen(value); +#endif + + if (value_length > INT_MAX) + result = INT_MIN; + else { + int int_value_length = (int)value_length; + if (int_value_length >= buffer_size) { + result = -int_value_length; + } else { + term_zero_idx = int_value_length; + result = int_value_length; +#ifndef _WIN32 + strncpy(buffer, value, value_length); +#endif + } + } + + if (buffer != NULL) + buffer[term_zero_idx] = '\0'; + return result; +} + +int getenv_int(const char *name, int default_value) +{ + int value = default_value; + // # of digits in the longest 32-bit signed int + sign + terminating null + const int len = 12; + char value_str[len]; + if (getenv(name, value_str, len) > 0) + value = atoi(value_str); + return value; +} + +FILE *fopen(const char *filename, const char *mode) { +#ifdef _WIN32 + FILE *fp = NULL; + return ::fopen_s(&fp, filename, mode) ? NULL : fp; +#else + return ::fopen(filename, mode); +#endif +} + +void *malloc(size_t size, int alignment) { + void *ptr; + +#ifdef _WIN32 + ptr = _aligned_malloc(size, alignment); + int rc = ptr ? 0 : -1; +#else + int rc = ::posix_memalign(&ptr, alignment, size); +#endif + + return (rc == 0) ? ptr : 0; +} + +void free(void *p) { +#ifdef _WIN32 + _aligned_free(p); +#else + ::free(p); +#endif +} + +// Atomic operations +int32_t fetch_and_add(int32_t *dst, int32_t val) { +#ifdef _WIN32 + return InterlockedExchangeAdd(reinterpret_cast<long*>(dst), val); +#else + return __sync_fetch_and_add(dst, val); +#endif +} + +static int jit_dump_flag = 0; +static bool jit_dump_flag_initialized = false; +bool jit_dump_enabled() { + if (!jit_dump_flag_initialized) { + jit_dump_flag = getenv_int("MKLDNN_JIT_DUMP"); + jit_dump_flag_initialized = true; + } + return jit_dump_flag != 0; +} + +} +} + +mkldnn_status_t mkldnn_set_jit_dump(int enabled) { + using namespace mkldnn::impl::status; + mkldnn::impl::jit_dump_flag = enabled; + mkldnn::impl::jit_dump_flag_initialized = true; + return success; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp new file mode 100644 index 0000000000..d5a8ec5139 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp @@ -0,0 +1,370 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef UTILS_HPP +#define UTILS_HPP + +#include <stddef.h> +#include <stdio.h> +#include <stdlib.h> +#include <assert.h> +#include <stdint.h> + +#if defined(__x86_64__) || defined(_M_X64) +#define MKLDNN_X86_64 +#endif + +#define MSAN_ENABLED 0 +#if defined(__has_feature) +#if __has_feature(memory_sanitizer) +#undef MSAN_ENABLED +#define MSAN_ENABLED 1 +#include <sanitizer/msan_interface.h> +#endif +#endif + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +// Sanity check for 64 bits +static_assert(sizeof(void*) == 8, "Intel(R) MKL-DNN supports 64 bit only"); + +#define CHECK(f) do { \ + status_t status = f; \ + if (status != status::success) \ + return status; \ +} while (0) + +#define IMPLICATION(cause, effect) (!(cause) || !!(effect)) + +namespace utils { + +/* a bunch of std:: analogues to be compliant with any msvs version + * + * Rationale: msvs c++ (and even some c) headers contain special pragma that + * injects msvs-version check into object files in order to abi-mismatches + * during the static linking. This makes sense if e.g. std:: objects are passed + * through between application and library, which is not the case for mkl-dnn + * (since there is no any c++-rt dependent stuff, ideally...). */ + +/* SFINAE helper -- analogue to std::enable_if */ +template<bool expr, class T = void> struct enable_if {}; +template<class T> struct enable_if<true, T> { typedef T type; }; + +/* analogue std::conditional */ +template <bool, typename, typename> struct conditional {}; +template <typename T, typename F> struct conditional<true, T, F> +{ typedef T type; }; +template <typename T, typename F> struct conditional<false, T, F> +{ typedef F type; }; + +template <bool, typename, bool, typename, typename> struct conditional3 {}; +template <typename T, typename FT, typename FF> +struct conditional3<true, T, false, FT, FF> { typedef T type; }; +template <typename T, typename FT, typename FF> +struct conditional3<false, T, true, FT, FF> { typedef FT type; }; +template <typename T, typename FT, typename FF> +struct conditional3<false, T, false, FT, FF> { typedef FF type; }; + +template <bool, typename U, U, U> struct conditional_v {}; +template <typename U, U t, U f> struct conditional_v<true, U, t, f> +{ static constexpr U value = t; }; +template <typename U, U t, U f> struct conditional_v<false, U, t, f> +{ static constexpr U value = f; }; + +template <typename T> struct remove_reference { typedef T type; }; +template <typename T> struct remove_reference<T&> { typedef T type; }; +template <typename T> struct remove_reference<T&&> { typedef T type; }; + +template <typename T> +inline T&& forward(typename utils::remove_reference<T>::type &t) +{ return static_cast<T&&>(t); } +template <typename T> +inline T&& forward(typename utils::remove_reference<T>::type &&t) +{ return static_cast<T&&>(t); } + +template <typename T> +inline typename remove_reference<T>::type zero() +{ auto zero = typename remove_reference<T>::type(); return zero; } + +template <typename T, typename P> +inline bool everyone_is(T val, P item) { return val == item; } +template <typename T, typename P, typename... Args> +inline bool everyone_is(T val, P item, Args... item_others) { + return val == item && everyone_is(val, item_others...); +} + +template <typename T, typename P> +constexpr bool one_of(T val, P item) { return val == item; } +template <typename T, typename P, typename... Args> +constexpr bool one_of(T val, P item, Args... item_others) { + return val == item || one_of(val, item_others...); +} + +template <typename... Args> +inline bool any_null(Args... ptrs) { return one_of(nullptr, ptrs...); } + +template<typename T> +inline void array_copy(T *dst, const T *src, size_t size) { + for (size_t i = 0; i < size; ++i) dst[i] = src[i]; +} +template<typename T> +inline bool array_cmp(const T *a1, const T *a2, size_t size) { + for (size_t i = 0; i < size; ++i) if (a1[i] != a2[i]) return false; + return true; +} +template<typename T, typename U> +inline void array_set(T *arr, const U& val, size_t size) { + for (size_t i = 0; i < size; ++i) arr[i] = static_cast<T>(val); +} + +namespace product_impl { +template<size_t> struct int2type{}; + +template <typename T> +constexpr int product_impl(const T *arr, int2type<0>) { return arr[0]; } + +template <typename T, size_t num> +inline T product_impl(const T *arr, int2type<num>) { + return arr[0]*product_impl(arr+1, int2type<num-1>()); } +} + +template <size_t num, typename T> +inline T array_product(const T *arr) { + return product_impl::product_impl(arr, product_impl::int2type<num-1>()); +} + +template<typename T, typename R = T> +inline R array_product(const T *arr, size_t size) { + R prod = 1; + for (size_t i = 0; i < size; ++i) prod *= arr[i]; + return prod; +} + +/** sorts an array of values using @p comparator. While sorting the array + * of value, the function permutes an array of @p keys accordingly. + * + * @note The arrays of @p keys can be omitted. In this case the function + * sorts the array of @vals only. + */ +template <typename T, typename U, typename F> +inline void simultaneous_sort(T *vals, U *keys, size_t size, F comparator) { + if (size == 0) return; + + for (size_t i = 0; i < size - 1; ++i) { + bool swapped = false; + + for (size_t j = 0; j < size - i - 1; j++) { + if (comparator(vals[j], vals[j + 1]) > 0) { + nstl::swap(vals[j], vals[j + 1]); + if (keys) nstl::swap(keys[j], keys[j + 1]); + swapped = true; + } + } + + if (swapped == false) break; + } +} + +template <typename T, typename U> +inline typename remove_reference<T>::type div_up(const T a, const U b) { + assert(b); + return (a + b - 1) / b; +} + +template <typename T, typename U> +inline typename remove_reference<T>::type rnd_up(const T a, const U b) { + return div_up(a, b) * b; +} + +template <typename T, typename U> +inline typename remove_reference<T>::type rnd_dn(const T a, const U b) { + return (a / b) * b; +} + +template <typename T> T *align_ptr(T *ptr, uintptr_t alignment) +{ return (T *)(((uintptr_t)ptr + alignment - 1) & ~(alignment - 1)); } + +template <typename T, typename U, typename V> +inline U this_block_size(const T offset, const U max, const V block_size) { + assert(offset < max); + // TODO (Roma): can't use nstl::max() due to circular dependency... we + // need to fix this + const T block_boundary = offset + block_size; + if (block_boundary > max) + return max - offset; + else + return block_size; +} + +template<typename T> +inline T nd_iterator_init(T start) { return start; } +template<typename T, typename U, typename W, typename... Args> +inline T nd_iterator_init(T start, U &x, const W &X, Args &&... tuple) { + start = nd_iterator_init(start, utils::forward<Args>(tuple)...); + x = start % X; + return start / X; +} + +inline bool nd_iterator_step() { return true; } +template<typename U, typename W, typename... Args> +inline bool nd_iterator_step(U &x, const W &X, Args &&... tuple) { + if (nd_iterator_step(utils::forward<Args>(tuple)...) ) { + x = (x + 1) % X; + return x == 0; + } + return false; +} + +template<typename U, typename W, typename Y> +inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X) +{ + U max_jump = end - cur; + U dim_jump = X - x; + if (dim_jump <= max_jump) { + x = 0; + cur += dim_jump; + return true; + } else { + cur += max_jump; + x += max_jump; + return false; + } +} +template<typename U, typename W, typename Y, typename... Args> +inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X, + Args &&... tuple) +{ + if (nd_iterator_jump(cur, end, utils::forward<Args>(tuple)...)) { + x = (x + 1) % X; + return x == 0; + } + return false; +} + +template <typename T> +inline T pick(size_t i, const T &x0) { return x0; } +template <typename T, typename ...Args> +inline T pick(size_t i, const T &x0, Args &&... args) { + return i == 0 ? x0 : pick(i - 1, utils::forward<Args>(args)...); +} + +template <typename T> +T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd_inference, + const T &val_fwd_training, const T &val_bwd_d, const T &val_bwd_w) { + switch (prop_kind) { + case prop_kind::forward_inference: return val_fwd_inference; + case prop_kind::forward_training: return val_fwd_training; + case prop_kind::backward_data: return val_bwd_d; + case prop_kind::backward_weights: return val_bwd_w; + default: assert(!"unsupported prop_kind"); + } + return T(); +} + +template <typename T> +T pick_by_prop_kind(prop_kind_t prop_kind, + const T &val_fwd, const T &val_bwd_d, const T &val_bwd_w) +{ return pick_by_prop_kind(prop_kind, val_fwd, val_fwd, val_bwd_d, val_bwd_w); } + +template <typename Telem, size_t Tdims> +struct array_offset_calculator { + template <typename... Targs> + array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... } + { + _base_ptr = base; + } + template <typename... Targs> + inline Telem &operator()(Targs... Fargs) + { + return *(_base_ptr + _offset(1, Fargs...)); + } + +private: + template <typename... Targs> + inline size_t _offset(size_t const dimension, size_t element) + { + return element; + } + + template <typename... Targs> + inline size_t _offset(size_t const dimension, size_t theta, size_t element) + { + return element + (_dims[dimension] * theta); + } + + template <typename... Targs> + inline size_t _offset(size_t const dimension, size_t theta, size_t element, + Targs... Fargs) + { + size_t t_prime = element + (_dims[dimension] * theta); + return _offset(dimension + 1, t_prime, Fargs...); + } + + Telem *_base_ptr; + const int _dims[Tdims]; +}; + +} + +int32_t fetch_and_add(int32_t *dst, int32_t val); +inline void yield_thread() {} + +// Reads an environment variable 'name' and stores its string value in the +// 'buffer' of 'buffer_size' bytes on success. +// +// - Returns the length of the environment variable string value (excluding +// the terminating 0) if it is set and its contents (including the terminating +// 0) can be stored in the 'buffer' without truncation. +// +// - Returns negated length of environment variable string value and writes +// "\0" to the buffer (if it is not NULL) if the 'buffer_size' is to small to +// store the value (including the terminating 0) without truncation. +// +// - Returns 0 and writes "\0" to the buffer (if not NULL) if the environment +// variable is not set. +// +// - Returns INT_MIN if the 'name' is NULL. +// +// - Returns INT_MIN if the 'buffer_size' is negative. +// +// - Returns INT_MIN if the 'buffer' is NULL and 'buffer_size' is greater than +// zero. Passing NULL 'buffer' with 'buffer_size' set to 0 can be used to +// retrieve the length of the environment variable value string. +// +int getenv(const char *name, char *buffer, int buffer_size); +// Reads an integer from the environment +int getenv_int(const char *name, int default_value = 0); +bool jit_dump_enabled(); +FILE *fopen(const char *filename, const char *mode); + +constexpr int msan_enabled = MSAN_ENABLED; +inline void msan_unpoison(void *ptr, size_t size) { +#if MSAN_ENABLED + __msan_unpoison(ptr, size); +#endif +} + +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp new file mode 100644 index 0000000000..89a57772cf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp @@ -0,0 +1,665 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <stdlib.h> +#ifndef _WIN32 +#include <sys/time.h> +#endif + +#include "mkldnn.h" +#include "mkldnn_version.h" +#include "c_types_map.hpp" +#include "verbose.hpp" +#include "cpu/cpu_isa_traits.hpp" + +#include "batch_normalization_pd.hpp" +#include "pooling_pd.hpp" +#include "concat_pd.hpp" +#include "reorder_pd.hpp" +#include "convolution_pd.hpp" +#include "rnn_pd.hpp" +#include "deconvolution_pd.hpp" +#include "shuffle_pd.hpp" +#include "eltwise_pd.hpp" +#include "softmax_pd.hpp" +#include "inner_product_pd.hpp" +#include "sum_pd.hpp" +#include "lrn_pd.hpp" + +/* MKL-DNN CPU ISA info */ +#define ISA_ANY "No instruction set specific optimizations" +#define SSE42 "Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2)" +#define AVX "Intel(R) Advanced Vector Extensions (Intel(R) AVX)" +#define AVX2 "Intel(R) Advanced Vector Extensions 2 (Intel(R) AVX2)" +#define AVX512_COMMON "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512)" +#define AVX512_CORE "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512BW, AVX512VL, and AVX512DQ extensions" +#define AVX512_CORE_VNNI "Intel(R) AVX512-Deep Learning Boost (Intel(R) " \ + "AVX512-DL Boost)" +#define AVX512_MIC "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512CD, AVX512ER, and AVX512PF extensions" +#define AVX512_MIC_4OPS "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \ + "AVX-512) with AVX512_4FMAPS and AVX512_4VNNIW extensions" + +namespace mkldnn { +namespace impl { + +static verbose_t verbose; +static bool initialized; +static bool version_printed = false; + +const verbose_t *mkldnn_verbose() { +#if !defined(DISABLE_VERBOSE) + if (!initialized) { + const int len = 2; + char val[len] = {0}; + if (getenv("MKLDNN_VERBOSE", val, len) == 1) + verbose.level = atoi(val); + initialized = true; + } + if (!version_printed && verbose.level > 0) { + printf("mkldnn_verbose,info," + "Intel(R) MKL-DNN v%d.%d.%d (Git Hash %s),%s\n", + mkldnn_version()->major, mkldnn_version()->minor, + mkldnn_version()->patch, mkldnn_version()->hash, + get_isa_info()); + version_printed = true; + } +#else + verbose.level = 0; +#endif + return &verbose; +} + +double get_msec() { +#ifdef _WIN32 + static LARGE_INTEGER frequency; + if (frequency.QuadPart == 0) + QueryPerformanceFrequency(&frequency); + LARGE_INTEGER now; + QueryPerformanceCounter(&now); + return 1e+3 * now.QuadPart / frequency.QuadPart; +#else + struct timeval time; + gettimeofday(&time, NULL); + return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec; +#endif +} + +const char *get_isa_info() { + using namespace mkldnn::impl::cpu; + if (mayiuse(avx512_mic_4ops)) return AVX512_MIC_4OPS; + if (mayiuse(avx512_mic)) return AVX512_MIC; + if (mayiuse(avx512_core_vnni)) return AVX512_CORE_VNNI; + if (mayiuse(avx512_core)) return AVX512_CORE; + if (mayiuse(avx512_common)) return AVX512_COMMON; + if (mayiuse(avx2)) return AVX2; + if (mayiuse(avx)) return AVX; + if (mayiuse(sse42)) return SSE42; + return ISA_ANY; +} + +/* init_info section */ +namespace { +#if !defined(DISABLE_VERBOSE) +#define MKLDNN_VERBOSE_DAT_LEN 256 +#define MKLDNN_VERBOSE_AUX_LEN 384 +#define MKLDNN_VERBOSE_PRB_LEN 384 + +#define DECL_DAT_AUX_PRB_STRS() \ + int dat_written = 0, aux_written = 0, prb_written = 0; \ + MAYBE_UNUSED((dat_written * aux_written * prb_written)); \ + char dat_str[MKLDNN_VERBOSE_DAT_LEN] = {'\0'}; MAYBE_UNUSED(dat_str); \ + char aux_str[MKLDNN_VERBOSE_AUX_LEN] = {'\0'}; MAYBE_UNUSED(aux_str); \ + char prb_str[MKLDNN_VERBOSE_PRB_LEN] = {'\0'}; MAYBE_UNUSED(prb_str) + +#define DFMT "%" PRId64 + +void clear_buf(char *buf, int &written) { + /* TODO: do it better */ + buf[0] = '#'; + buf[1] = '\0'; + written = 1; +} + +#define DPRINT(buf, buf_len, written, ...) do { \ + int l = snprintf(buf + written, buf_len - written, __VA_ARGS__); \ + if (l < 0 || written + l > buf_len) { \ + clear_buf(buf, written); \ + } else { \ + written += l; \ + } \ +} while(0) + +// XXX: Outputs strings corresponding to memory formats used for data tensors. +void format_prb_desc_str(char *str, int len, const memory_desc_t *md) { + const auto dims = md->dims; + int written = 0; + if (md->ndims == 1) + DPRINT(str, len, written, + "x" DFMT, dims[0]); + else if (md->ndims == 2) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT, dims[0], dims[1]); + else if (md->ndims == 3) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "iw" DFMT, + dims[0], dims[1], dims[2]); + else if (md->ndims == 4) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "ih" DFMT "iw" DFMT, + dims[0], dims[1], dims[2], dims[3]); + else if (md->ndims == 5) + DPRINT(str, len, written, + "mb" DFMT "ic" DFMT "id" DFMT "ih" DFMT "iw" DFMT, + dims[0], dims[1], dims[2], dims[3], dims[4]); + else + mkldnn_md2dim_str(str, len, md); +} + +void verbose_templ(char *buffer, mkldnn_primitive_kind_t prim_kind, + const char *impl_str, mkldnn_prop_kind_t prop_kind, + const char *data_str, const char *aux_str, const char *prb_str) { + MAYBE_UNUSED(verbose_templ); + int written = 0; + DPRINT(buffer, MKLDNN_VERBOSE_BUF_LEN, written, "%s,%s,%s,%s,%s,%s", + mkldnn_prim_kind2str(prim_kind), impl_str, + mkldnn_prop_kind2str(prop_kind), data_str, aux_str, prb_str); +} + +template <typename pd_t> static void init_info_bnorm(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "flags:%u", s->desc()->flags); + + format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_conv(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->desc()->prop_kind == prop_kind::backward_data + ? s->diff_src_md() : s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md() : s->weights_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bia + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md(1) : s->weights_md(1); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + if (1) { // dst + auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + if (s->ndims() == 5) { + if (s->with_groups()) + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT + "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->G(), s->IC(), s->OC(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + else + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_ic" DFMT "oc" DFMT + "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->IC(), s->OC(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + } else { + if (s->with_groups()) + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->G(), s->IC(), s->OC(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + else + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "_ic" DFMT "oc" DFMT + "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT + "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT, + s->MB(), s->IC(), s->OC(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL()); + } + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_shuffle(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + auto md = s->is_fwd() ? s->src_md() : s->diff_dst_md(); + + if (1) { // data + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "axis:%d group_size:" DFMT, s->axis(), s->group_size()); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, md); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_eltwise(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_iprod(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->desc()->prop_kind == prop_kind::backward_data + ? s->diff_src_md() : s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md() : s->weights_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bia + auto md = s->desc()->prop_kind == prop_kind::backward_weights + ? s->diff_weights_md(1) : s->weights_md(1); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + if (1) { // dst + auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "oc" DFMT, s->MB(), s->IC_total(), s->OC()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_lrn(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_mem(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst + auto md = s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "num:%d", s->n_inputs()); + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); + + verbose_templ(buffer, s->kind(), s->name(), prop_kind::undef, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_pool(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src + auto md = s->is_fwd() ? s->src_md() : s->diff_src_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst + auto md = s->is_fwd() ? s->dst_md() : s->diff_dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // ws + auto md = s->workspace_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " ws_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind)); + + if (s->is_3d()) { + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "_" + "id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "pd" DFMT "_" + "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" + "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT "", + s->MB(), s->C(), + s->ID(), s->OD(), s->KD(), s->KSD(), s->padFront(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); + } else { + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "mb" DFMT "ic" DFMT "_" + "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_" + "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT, + s->MB(), s->C(), + s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(), + s->IW(), s->OW(), s->KW(), s->KSW(), s->padL()); + } + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_softmax(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // data + auto md = s->dst_md(); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // diff data + auto md = s->diff_src_md(); + if (md) { + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + } + + mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +template <typename pd_t> static void init_info_rnn(pd_t *s, char *buffer) { + DECL_DAT_AUX_PRB_STRS(); + + if (1) { // src layer + auto md = s->is_fwd() ? s->src_md(0) : s->diff_src_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // src iter + auto md = s->is_fwd() ? s->src_md(1) : s->diff_src_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_iter_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei_layer + auto md = s->is_fwd() ? s->weights_md(0) : s->diff_weights_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // wei_iter + auto md = s->is_fwd() ? s->weights_md(1) : s->diff_weights_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // bias + auto md = s->is_fwd() ? s->weights_md(2) : s->diff_weights_md(2); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bias_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst layer + auto md = s->is_fwd() ? s->dst_md(0) : s->diff_dst_md(0); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_layer_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + if (1) { // dst iter + auto md = s->is_fwd() ? s->dst_md(1) : s->diff_dst_md(1); + DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_iter_"); + int l = mkldnn_md2fmt_str(dat_str + dat_written, + MKLDNN_VERBOSE_DAT_LEN - dat_written, md); + if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written); + } + + alg_kind_t alg_kind = s->cell_kind(); + rnn_direction_t rnn_dir = s->direction(); + DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written, + "alg:%s_%s", mkldnn_alg_kind2str(alg_kind), + mkldnn_rnn_direction2str(rnn_dir)); + + DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written, + "l" DFMT "t" DFMT "mb" DFMT + "sic" DFMT "slc" DFMT "dic" DFMT "dlc" DFMT, + s->L(), s->T(), s->MB(), + s->SIC(), s->SLC(), s->DIC(), s->DLC()); + + verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str, + aux_str, prb_str); +} + +#undef DPRINT + +#else // !defined(DISABLE_VERBOSE) + +#define DEFINE_STUB(name) \ + template <typename pd_t> \ + static void CONCAT2(init_info_, name)(pd_t *s, char *buffer) \ + { UNUSED(s); UNUSED(buffer); } + +DEFINE_STUB(bnorm); +DEFINE_STUB(conv); +DEFINE_STUB(eltwise); +DEFINE_STUB(iprod); +DEFINE_STUB(lrn); +DEFINE_STUB(mem); +DEFINE_STUB(pool); +DEFINE_STUB(softmax); +DEFINE_STUB(rnn); +DEFINE_STUB(shuffle); +#undef DEFINE_STUB + +#endif // !defined(DISABLE_VERBOSE) +} + +void init_info(batch_normalization_pd_t *s, char *b) +{ init_info_bnorm(s, b); } +void init_info(concat_pd_t *s, char *b) +{ init_info_mem(s, b); } +void init_info(convolution_pd_t *s, char *b) +{ init_info_conv(s, b); } +void init_info(deconvolution_pd_t *s, char *b) +{ init_info_conv(s, b); } +void init_info(eltwise_pd_t *s, char *b) +{ init_info_eltwise(s, b); } +void init_info(inner_product_pd_t *s, char *b) +{ init_info_iprod(s, b); } +void init_info(lrn_pd_t *s, char *b) +{ init_info_lrn(s, b); } +void init_info(pooling_pd_t *s, char *b) +{ init_info_pool(s, b); } +void init_info(reorder_pd_t *s, char *b) +{ init_info_mem(s, b); } +void init_info(rnn_pd_t *s, char *b) +{ init_info_rnn(s, b); } +void init_info(shuffle_pd_t *s, char *b) +{ init_info_shuffle(s, b); } +void init_info(softmax_pd_t *s, char *b) +{ init_info_softmax(s, b); } +void init_info(sum_pd_t *s, char *b) +{ init_info_mem(s, b); } + +} +} + +mkldnn_status_t mkldnn_set_verbose(int level) { + using namespace mkldnn::impl::status; + if (level < 0 || level > 2) return invalid_arguments; + mkldnn::impl::verbose.level = level; + mkldnn::impl::initialized = true; + return success; +} + +const mkldnn_version_t *mkldnn_version() { + static mkldnn_version_t ver = { + MKLDNN_VERSION_MAJOR, + MKLDNN_VERSION_MINOR, + MKLDNN_VERSION_PATCH, + MKLDNN_VERSION_HASH}; + return &ver; +} diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp new file mode 100644 index 0000000000..e3049750cb --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp @@ -0,0 +1,62 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef VERBOSE_HPP +#define VERBOSE_HPP + +#include <stdio.h> +#include <cinttypes> + +#include "mkldnn_debug.h" +#include "c_types_map.hpp" +#include "utils.hpp" +#include "z_magic.hpp" + +namespace mkldnn { +namespace impl { + +struct verbose_t { + int level; +}; + +const verbose_t *mkldnn_verbose(); +double get_msec(); +const char *get_isa_info(); + +#if !defined(DISABLE_VERBOSE) +#define MKLDNN_VERBOSE_BUF_LEN 1024 +#else +#define MKLDNN_VERBOSE_BUF_LEN 1 +#endif + +void init_info(batch_normalization_pd_t *s, char *buffer); +void init_info(concat_pd_t *s, char *buffer); +void init_info(convolution_pd_t *s, char *buffer); +void init_info(deconvolution_pd_t *s, char *buffer); +void init_info(eltwise_pd_t *s, char *buffer); +void init_info(inner_product_pd_t *s, char *buffer); +void init_info(lrn_pd_t *s, char *buffer); +void init_info(pooling_pd_t *s, char *buffer); +void init_info(reorder_pd_t *s, char *buffer); +void init_info(rnn_pd_t *s, char *buffer); +void init_info(shuffle_pd_t *s, char *buffer); +void init_info(softmax_pd_t *s, char *buffer); +void init_info(sum_pd_t *s, char *buffer); + +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp new file mode 100644 index 0000000000..520bd4710b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef Z_MAGIC_HPP +#define Z_MAGIC_HPP + +#define CHAIn2(a,b) a b +#define CHAIN2(a,b) CHAIn2(a,b) + +#define CONCAt2(a,b) a ## b +#define CONCAT2(a,b) CONCAt2(a,b) + +#define STRINGIFy(s) #s +#define STRINGIFY(s) STRINGIFy(s) + +#ifdef _MSC_VER +# define PRAGMA_MACRo(x) __pragma(x) +# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) +#else +# define PRAGMA_MACRo(x) _Pragma(#x) +# define PRAGMA_MACRO(x) PRAGMA_MACRo(x) +#endif + +#define UNUSED(x) ((void)x) +#define MAYBE_UNUSED(x) UNUSED(x) + +#if defined(_WIN32) && !defined(__GNUC__) +#define __PRETTY_FUNCTION__ __FUNCSIG__ +#endif + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp new file mode 100644 index 0000000000..7cf7822d90 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "cpu_barrier.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace simple_barrier { + +void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, + Xbyak::Reg64 reg_nthr) { +# define BAR_CTR_OFF offsetof(ctx_t, ctr) +# define BAR_SENSE_OFF offsetof(ctx_t, sense) + using namespace Xbyak; + + Xbyak::Reg64 reg_tmp = [&]() { + /* returns register which is neither reg_ctx nor reg_nthr */ + Xbyak::Reg64 regs[] = { util::rax, util::rbx, util::rcx }; + for (size_t i = 0; i < sizeof(regs) / sizeof(regs[0]); ++i) + if (!utils::one_of(regs[i], reg_ctx, reg_nthr)) + return regs[i]; + return regs[0]; /* should not happen */ + }(); + + Label barrier_exit_label, barrier_exit_restore_label, spin_label; + + code.cmp(reg_nthr, 1); + code.jbe(barrier_exit_label); + + code.push(reg_tmp); + + /* take and save current sense */ + code.mov(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]); + code.push(reg_tmp); + code.mov(reg_tmp, 1); + + if (mayiuse(avx512_mic)) { + code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]); + code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]); + } + + code.lock(); code.xadd(code.ptr[reg_ctx + BAR_CTR_OFF], reg_tmp); + code.add(reg_tmp, 1); + code.cmp(reg_tmp, reg_nthr); + code.pop(reg_tmp); /* restore previous sense */ + code.jne(spin_label); + + /* the last thread {{{ */ + code.mov(code.qword[reg_ctx + BAR_CTR_OFF], 0); // reset ctx + + // notify waiting threads + code.not_(reg_tmp); + code.mov(code.ptr[reg_ctx + BAR_SENSE_OFF], reg_tmp); + code.jmp(barrier_exit_restore_label); + /* }}} the last thread */ + + code.CodeGenerator::L(spin_label); + code.pause(); + code.cmp(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]); + code.je(spin_label); + + code.CodeGenerator::L(barrier_exit_restore_label); + code.pop(reg_tmp); + + code.CodeGenerator::L(barrier_exit_label); +# undef BAR_CTR_OFF +# undef BAR_SENSE_OFF +} + +/** jit barrier generator */ +struct jit_t: public jit_generator { + void (*barrier)(ctx_t *ctx, size_t nthr); + + jit_t() { + generate(*this, abi_param1, abi_param2); + ret(); + barrier = reinterpret_cast<decltype(barrier)>(const_cast<uint8_t*>( + this->getCode())); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_t) +}; + +void barrier(ctx_t *ctx, int nthr) { + static jit_t j; /* XXX: constructed on load ... */ + j.barrier(ctx, nthr); +} + +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp new file mode 100644 index 0000000000..0f55e33aa8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp @@ -0,0 +1,60 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_BARRIER_HPP +#define CPU_BARRIER_HPP + +#include <assert.h> + +#include "jit_generator.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace simple_barrier { + +STRUCT_ALIGN(64, +struct ctx_t { + enum { CACHE_LINE_SIZE = 64 }; + volatile size_t ctr; + char pad1[CACHE_LINE_SIZE - 1 * sizeof(size_t)]; + volatile size_t sense; + char pad2[CACHE_LINE_SIZE - 1 * sizeof(size_t)]; +}); + +inline void ctx_init(ctx_t *ctx) { *ctx = utils::zero<ctx_t>(); } +void barrier(ctx_t *ctx, int nthr); + +/** injects actual barrier implementation into another jitted code + * @params: + * code -- jit_generator object where the barrier is to be injected + * reg_ctx -- read-only register with pointer to the barrier context + * reg_nnthr -- read-only register with the # of synchronizing threads + */ +void generate(jit_generator &code, Xbyak::Reg64 reg_ctx, + Xbyak::Reg64 reg_nthr); + +} + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp new file mode 100644 index 0000000000..1ed5ad57b9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp @@ -0,0 +1,40 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_BATCH_NORMALIZATION_PD_HPP +#define CPU_BATCH_NORMALIZATION_PD_HPP + +#include "batch_normalization_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_batch_normalization_fwd_pd_t: public batch_normalization_fwd_pd_t { + using batch_normalization_fwd_pd_t::batch_normalization_fwd_pd_t; +}; + +struct cpu_batch_normalization_bwd_pd_t: public batch_normalization_bwd_pd_t { + using batch_normalization_bwd_pd_t::batch_normalization_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp new file mode 100644 index 0000000000..b8d5c4fcaf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp @@ -0,0 +1,140 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "cpu_batch_normalization_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace bnorm_utils { + +void cache_balance(size_t working_set_size, dim_t C_blks, + dim_t &C_blks_per_iter, int64_t &iters) { + int nthrs = mkldnn_get_max_threads(); + int l3_size = get_cache_size(3, true) * nthrs / 2; + + C_blks_per_iter = l3_size / working_set_size; + + if (C_blks_per_iter == 0) + C_blks_per_iter = 1; + if (C_blks_per_iter > C_blks) + C_blks_per_iter = C_blks; + + iters = (C_blks + C_blks_per_iter - 1) / C_blks_per_iter; +} + +bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr, + int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr, + dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s, + dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e) { + if (nthr <= C_blks || !mkldnn_thr_syncable()) { + C_ithr = ithr; C_nthr = nthr; + N_ithr = 0; N_nthr = 1; + S_ithr = 0; S_nthr = 1; + N_s = 0; N_e = N; S_s = 0; S_e = SP; + balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); + } else { + if (do_blocking) { + N_nthr = (int)nstl::min<dim_t>(N, nthr); + C_nthr = (int)nstl::min<dim_t>(C_blks, nthr / N_nthr); + S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr)); + } else { + C_nthr = (int)math::gcd((dim_t)nthr, C_blks); + N_nthr = (int)nstl::min<dim_t>(N, nthr / C_nthr); + S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr)); + } + + if (!spatial_thr_allowed) + S_nthr = 1; + + if (S_nthr < 1) S_nthr = 1; + if (ithr < C_nthr * N_nthr * S_nthr) { + N_ithr = (ithr / S_nthr) % N_nthr ; + C_ithr = ithr / (N_nthr * S_nthr); + S_ithr = ithr % S_nthr; + balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e); + balance211(N, N_nthr, N_ithr, N_s, N_e); + balance211(SP, S_nthr, S_ithr, S_s, S_e); + } else { + S_ithr = N_ithr = C_ithr = -ithr; + S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1; + } + } + + // spatial_thr_allowed is meant to help maintain + // consistent decisions about spatial threading + // between mutiple invocations of this routine. + // It is caller's responsibility to check the + // return value and pass it as a flag to the + // next call if needed. + if (S_nthr == 1) + spatial_thr_allowed = false; + + return spatial_thr_allowed; +} + +bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w, + int data_size) { + if (!mkldnn_thr_syncable()) return false; + + dim_t nthr = mkldnn_get_max_threads(); + dim_t SP = bdesc->W() * bdesc->D() * bdesc->H(); + dim_t C_PADDED = memory_desc_wrapper(bdesc->src_md()) + .padded_dims()[1]; + assert(C_PADDED % simd_w == 0); + + size_t data = bdesc->MB() * C_PADDED * SP * data_size; + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + bool do_blocking = (data >= l3_size_ / 2 && l3_size_ > 0); + dim_t C_blks_per_iter{ 1 }, iters{ 1 }; + dim_t C_blks = C_PADDED / simd_w; + + if (do_blocking) { + int num_tensors = bdesc->is_fwd() ? 1 : 2; + size_t working_set_size + = (bdesc->MB() * SP * simd_w * data_size) * num_tensors; + cache_balance(working_set_size, C_blks, C_blks_per_iter, iters); + } + + // Spatial threading decision made in this function shall be consistent + // with thread_balance() behavior. + C_blks = do_blocking ? C_blks_per_iter : C_blks; + + if (nthr <= C_blks) return false; + + dim_t S_nthr = 1; + if (do_blocking) { + dim_t N_nthr = nstl::min(bdesc->MB(), nthr); + dim_t C_nthr = nstl::min(C_blks, nthr / N_nthr); + S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr)); + } else { + dim_t C_nthr = math::gcd(nthr, C_blks); + dim_t N_nthr = nstl::min(bdesc->MB(), nthr / C_nthr); + S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr)); + } + + return S_nthr > 1; +} + +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp new file mode 100644 index 0000000000..0daef0716c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp @@ -0,0 +1,43 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_BATCH_NORMALIZATION_UTILS_HPP +#define CPU_BATCH_NORMALIZATION_UTILS_HPP + +#include "batch_normalization_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace bnorm_utils { + +void cache_balance(size_t working_set_size, dim_t C_blks, + dim_t &C_blks_per_iter, int64_t &iters); + +bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr, + int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr, + dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s, + dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e); + +bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w, + int data_size); + +} +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp new file mode 100644 index 0000000000..b926491202 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu_engine.hpp" + +/* +#include "cpu/ref_concat.hpp" +#include "cpu/simple_concat.hpp" +*/ + +namespace mkldnn { +namespace impl { +namespace cpu { + +using cpd_create_f = mkldnn::impl::engine_t::concat_primitive_desc_create_f; + +namespace { +#define INSTANCE(...) __VA_ARGS__::pd_t::create +static const cpd_create_f cpu_concat_impl_list[] = { + /* + INSTANCE(simple_concat_t<data_type::f32>), + INSTANCE(simple_concat_t<data_type::u8>), + INSTANCE(simple_concat_t<data_type::s8>), + INSTANCE(simple_concat_t<data_type::s32>), + INSTANCE(ref_concat_t), + */ + nullptr, +}; +#undef INSTANCE +} + +const cpd_create_f *cpu_engine_t::get_concat_implementation_list() const { + return cpu_concat_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp new file mode 100644 index 0000000000..0b01bcf163 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp @@ -0,0 +1,41 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_CONCAT_PD_HPP +#define CPU_CONCAT_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "concat_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_concat_pd_t: public concat_pd_t { + using concat_pd_t::concat_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp new file mode 100644 index 0000000000..52a38a2294 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_CONVOLUTION_PD_HPP +#define CPU_CONVOLUTION_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "convolution_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t { + using convolution_fwd_pd_t::convolution_fwd_pd_t; + + bool has_padded_dst() const { + memory_desc_wrapper dst_d(&dst_md_); + return OC() != dst_d.padded_dims()[1]; + } + + bool wants_padded_bias() const { + if (!with_bias()) return false; + return has_padded_dst(); + } + + bool wants_zero_pad_dst(bool jit_impl = true) const { + if (!has_padded_dst()) return false; + const auto &po = attr()->post_ops_; + int idx; + if ((idx = po.find(primitive_kind::eltwise)) == -1) return false; + return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg, + jit_impl); + } +}; + +struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t { + using convolution_bwd_data_pd_t::convolution_bwd_data_pd_t; +}; + +struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t { + using convolution_bwd_weights_pd_t::convolution_bwd_weights_pd_t; + + bool wants_padded_bias() const { + if (!with_bias()) return false; + memory_desc_wrapper diff_dst_d(&diff_dst_md_); + return OC() != diff_dst_d.padded_dims()[1]; + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp new file mode 100644 index 0000000000..164c8601d7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_DECONVOLUTION_PD_HPP +#define CPU_DECONVOLUTION_PD_HPP + +#include <assert.h> + +#include "deconvolution_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_deconvolution_fwd_pd_t: public deconvolution_fwd_pd_t { + using deconvolution_fwd_pd_t::deconvolution_fwd_pd_t; +}; + +struct cpu_deconvolution_bwd_data_pd_t: public deconvolution_bwd_data_pd_t { + using deconvolution_bwd_data_pd_t::deconvolution_bwd_data_pd_t; +}; + +struct cpu_deconvolution_bwd_weights_pd_t: public deconvolution_bwd_weights_pd_t { + using deconvolution_bwd_weights_pd_t::deconvolution_bwd_weights_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp new file mode 100644 index 0000000000..c52f00026e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp @@ -0,0 +1,45 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_ELTWISE_PD_HPP +#define CPU_ELTWISE_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "eltwise_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_eltwise_fwd_pd_t: public eltwise_fwd_pd_t { + using eltwise_fwd_pd_t::eltwise_fwd_pd_t; +}; + +struct cpu_eltwise_bwd_pd_t: public eltwise_bwd_pd_t { + using eltwise_bwd_pd_t::eltwise_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp new file mode 100644 index 0000000000..ce0a3667ad --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp @@ -0,0 +1,324 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "type_helpers.hpp" +#include "verbose.hpp" + +#include "cpu_engine.hpp" +#include "cpu_memory.hpp" + +//#include "cpu/rnn/ref_rnn.hpp" + +//#include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp" +//#include "cpu/jit_avx512_common_1x1_convolution.hpp" +#include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp" +#include "cpu/jit_avx512_common_convolution_winograd.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp" +#include "cpu/jit_avx512_common_convolution.hpp" +//#include "cpu/jit_avx2_1x1_convolution.hpp" +//#include "cpu/jit_sse42_1x1_convolution.hpp" +#include "cpu/jit_avx2_convolution.hpp" +#include "cpu/jit_sse42_convolution.hpp" +//#include "cpu/gemm_convolution.hpp" +//#include "cpu/gemm_x8s8s32x_convolution.hpp" +//#include "cpu/ref_convolution.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp" +//#include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp" +//#include "cpu/ref_deconvolution.hpp" +//#include "cpu/ref_shuffle.hpp" +//#include "cpu/jit_uni_eltwise.hpp" +//#include "cpu/ref_eltwise.hpp" +//#include "cpu/ref_softmax.hpp" +#include "cpu/jit_uni_pooling.hpp" +//#include "cpu/jit_uni_i8i8_pooling.hpp" +//#include "cpu/ref_pooling.hpp" +//#include "cpu/nchw_pooling.hpp" +//#include "cpu/nhwc_pooling.hpp" +//#include "cpu/jit_avx512_common_lrn.hpp" +//#include "cpu/jit_uni_lrn.hpp" +//#include "cpu/ref_lrn.hpp" +//#include "cpu/jit_uni_batch_normalization.hpp" +//#include "cpu/ref_batch_normalization.hpp" +//#include "cpu/ncsp_batch_normalization.hpp" +//#include "cpu/nspc_batch_normalization.hpp" +//#include "cpu/ref_inner_product.hpp" +//#include "cpu/gemm_inner_product.hpp" +//#include "cpu/gemm_x8s8s32x_inner_product.hpp" +//#include "cpu/jit_uni_dw_convolution.hpp" +//#include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp" +#include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +status_t cpu_engine_t::memory_create(memory_t **memory, + const memory_desc_t *md, void *handle) { + auto _memory = new cpu_memory_t(this, md, handle); + if (_memory == nullptr) + return status::out_of_memory; + + status_t status = _memory->init(); + if (status != status::success) { + delete _memory; + return status; + } + + return safe_ptr_assign<memory_t>(*memory, _memory); +} + +using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; + +namespace { +using namespace mkldnn::impl::data_type; + +#define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t> +static const pd_create_f cpu_impl_list[] = { + /* RNN */ + /* + INSTANCE(ref_rnn_fwd_f32_t), + INSTANCE(ref_rnn_fwd_u8s8_t), + INSTANCE(ref_rnn_bwd_f32_t), + */ + /* conv */ + /* + INSTANCE(jit_avx512_common_dw_convolution_fwd_t), + INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t), + INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t), + INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t), + INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t), + INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t), + */ + INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t), + INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t), + //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t), + //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t), + INSTANCE(jit_avx512_common_convolution_winograd_fwd_t), + //INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t), + //INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t), + INSTANCE(jit_avx512_common_convolution_fwd_t<f32>), + //INSTANCE(jit_avx512_common_convolution_bwd_data_t<f32>), + //INSTANCE(jit_avx512_common_convolution_bwd_weights_t<f32>), + /* + INSTANCE(jit_avx2_dw_convolution_fwd_t), + INSTANCE(jit_avx2_dw_convolution_bwd_data_t), + INSTANCE(jit_avx2_dw_convolution_bwd_weights_t), + INSTANCE(jit_avx2_1x1_convolution_fwd_t), + INSTANCE(jit_avx2_1x1_convolution_bwd_data_t), + INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t), + INSTANCE(jit_sse42_dw_convolution_fwd_t), + INSTANCE(jit_sse42_dw_convolution_bwd_data_t), + INSTANCE(jit_sse42_dw_convolution_bwd_weights_t), + INSTANCE(jit_sse42_1x1_convolution_fwd_t), + */ + INSTANCE(jit_avx2_convolution_fwd_t), + //INSTANCE(jit_avx2_convolution_bwd_data_t), + //INSTANCE(jit_avx2_convolution_bwd_weights_t), + INSTANCE(jit_sse42_convolution_fwd_t), + /* + INSTANCE(gemm_convolution_fwd_t), + INSTANCE(gemm_convolution_bwd_data_t), + INSTANCE(gemm_convolution_bwd_weights_t), + INSTANCE(ref_convolution_fwd_t<f32>), + INSTANCE(ref_convolution_bwd_data_t<f32, f32, f32, f32>), + INSTANCE(ref_convolution_bwd_weights_t<f32, f32, f32, f32>), + */ + /* conv (int) */ + /* + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<f32>), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s32>), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s8>), + INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<u8>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,f32>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s32>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,u8>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s8>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,f32>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s32>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,u8>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s8>), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,f32>), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s32>), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,u8>), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s8>), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,f32>), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s32>), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,u8>), + INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s8>), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s32>), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, u8>), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s8>), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, f32>), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s32>), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, u8>), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s8>), + INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, f32>), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s32>), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<u8>), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s8>), + INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<f32>), + INSTANCE(ref_convolution_fwd_t<u8, s8, f32, s32>), + INSTANCE(ref_convolution_fwd_t<u8, s8, s32, s32>), + INSTANCE(ref_convolution_fwd_t<u8, s8, s8, s32>), + INSTANCE(ref_convolution_fwd_t<u8, s8, u8, s32>), + INSTANCE(ref_convolution_bwd_data_t<f32, s8, u8, s32>), + INSTANCE(ref_convolution_bwd_data_t<s32, s8, u8, s32>), + INSTANCE(ref_convolution_bwd_data_t<s8, s8, u8, s32>), + INSTANCE(ref_convolution_bwd_data_t<u8, s8, u8, s32>), + */ + /* deconv */ + /* + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,f32>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s32>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,u8>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s8>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,f32>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s32>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,u8>), + INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s8>), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s32>), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,u8>), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s8>), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,f32>), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s32>), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,u8>), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s8>), + INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,f32>), + INSTANCE(ref_deconvolution_bwd_weights_t), + INSTANCE(ref_deconvolution_bwd_data_t), + INSTANCE(ref_deconvolution_fwd_t), + */ + /* shuffle */ + /* + INSTANCE(ref_shuffle_t<4>), // f32 or s32 + INSTANCE(ref_shuffle_t<1>), // s8 or u8 + */ + /* eltwise */ + /* + INSTANCE(jit_uni_eltwise_fwd_t<avx512_common>), + INSTANCE(jit_uni_eltwise_bwd_t<avx512_common>), + INSTANCE(jit_uni_eltwise_fwd_t<avx2>), + INSTANCE(jit_uni_eltwise_bwd_t<avx2>), + INSTANCE(jit_uni_eltwise_fwd_t<sse42>), + INSTANCE(jit_uni_eltwise_bwd_t<sse42>), + INSTANCE(ref_eltwise_fwd_t<f32>), + INSTANCE(ref_eltwise_bwd_t<f32>), + */ + /* eltwise (int) */ + /* + INSTANCE(ref_eltwise_fwd_t<s32>), + INSTANCE(ref_eltwise_fwd_t<s8>), + INSTANCE(ref_eltwise_fwd_t<u8>), + INSTANCE(ref_eltwise_bwd_t<s32>), + */ + /* softmax */ + /* + INSTANCE(ref_softmax_fwd_t<f32>), + INSTANCE(ref_softmax_bwd_t<f32>), + */ + /* pool */ + INSTANCE(jit_uni_pooling_fwd_t<avx512_common>), + //INSTANCE(jit_uni_pooling_bwd_t<avx512_common>), + INSTANCE(jit_uni_pooling_fwd_t<avx>), + //INSTANCE(jit_uni_pooling_bwd_t<avx>), + INSTANCE(jit_uni_pooling_fwd_t<sse42>), + //INSTANCE(jit_uni_pooling_bwd_t<sse42>), + /* + INSTANCE(nchw_pooling_fwd_t<f32>), + INSTANCE(nchw_pooling_bwd_t<f32>), + INSTANCE(nhwc_pooling_fwd_t<f32>), + INSTANCE(nhwc_pooling_bwd_t<f32>), + INSTANCE(ref_pooling_fwd_t<f32>), + INSTANCE(ref_pooling_bwd_t<f32>), + */ + /* pool (int) */ + /* + INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx512_core>), + INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx2>), + INSTANCE(ref_pooling_fwd_t<s32>), + INSTANCE(ref_pooling_fwd_t<s8, s32>), + INSTANCE(ref_pooling_fwd_t<u8, s32>), + INSTANCE(ref_pooling_bwd_t<s32>), + */ + /* lrn */ + /* + INSTANCE(jit_avx512_common_lrn_fwd_t), + INSTANCE(jit_avx512_common_lrn_bwd_t), + INSTANCE(jit_uni_lrn_fwd_t<avx2>), + INSTANCE(jit_uni_lrn_bwd_t<avx2>), + INSTANCE(jit_uni_lrn_fwd_t<sse42>), + INSTANCE(ref_lrn_fwd_t<f32>), + INSTANCE(ref_lrn_bwd_t<f32>), + */ + /* batch normalization */ + /* + INSTANCE(jit_uni_batch_normalization_fwd_t<avx512_common>), + INSTANCE(jit_uni_batch_normalization_bwd_t<avx512_common>), + INSTANCE(jit_uni_batch_normalization_fwd_t<avx2>), + INSTANCE(jit_uni_batch_normalization_bwd_t<avx2>), + INSTANCE(jit_uni_batch_normalization_fwd_t<sse42>), + INSTANCE(jit_uni_batch_normalization_bwd_t<sse42>), + INSTANCE(ncsp_batch_normalization_fwd_t), + INSTANCE(ncsp_batch_normalization_bwd_t), + INSTANCE(nspc_batch_normalization_fwd_t), + INSTANCE(nspc_batch_normalization_bwd_t), + INSTANCE(ref_batch_normalization_fwd_t<f32>), + INSTANCE(ref_batch_normalization_bwd_t<f32>), + INSTANCE(ref_batch_normalization_fwd_t<s8>), + */ + /* inner product */ + /* + INSTANCE(gemm_inner_product_fwd_t<f32>), + INSTANCE(gemm_inner_product_bwd_data_t<f32>), + INSTANCE(gemm_inner_product_bwd_weights_t<f32>), + INSTANCE(ref_inner_product_fwd_t<f32>), + INSTANCE(ref_inner_product_bwd_data_t<f32, f32, f32, f32>), + INSTANCE(ref_inner_product_bwd_weights_t<f32>), + */ + /* inner product (int) */ + /* + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, u8>), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s8>), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s32>), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, f32>), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, u8>), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s8>), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s32>), + INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, f32>), + INSTANCE(ref_inner_product_fwd_t<u8, s8, u8, s32>), + INSTANCE(ref_inner_product_fwd_t<u8, s8, s8, s32>), + INSTANCE(ref_inner_product_fwd_t<u8, s8, s32, s32>), + INSTANCE(ref_inner_product_fwd_t<u8, s8, f32, s32>), + */ + /* eol */ + nullptr, +}; +#undef INSTANCE +} + +const pd_create_f* cpu_engine_t::get_implementation_list() const { + return cpu_impl_list; +} + +cpu_engine_factory_t engine_factory; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp new file mode 100644 index 0000000000..e4c877ee05 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp @@ -0,0 +1,70 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_ENGINE_HPP +#define CPU_ENGINE_HPP + +#include <assert.h> + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "../common/engine.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +class cpu_engine_t: public engine_t { +public: + cpu_engine_t(): engine_t(engine_kind::cpu) {} + + /* implementation part */ + + virtual status_t memory_create(memory_t **memory, + const memory_desc_t *md, void *handle) override; + + virtual const concat_primitive_desc_create_f* + get_concat_implementation_list() const override; + virtual const reorder_primitive_desc_create_f* + get_reorder_implementation_list() const override; + virtual const sum_primitive_desc_create_f* + get_sum_implementation_list() const override; + virtual const primitive_desc_create_f* + get_implementation_list() const override; +}; + +class cpu_engine_factory_t: public engine_factory_t { +public: + virtual size_t count() const override { return 1; } + virtual engine_kind_t kind() const override { return engine_kind::cpu; } + virtual status_t engine_create(engine_t **engine, + size_t index) const override { + assert(index == 0); + *engine = new cpu_engine_t(); + return status::success; + }; +}; + +extern cpu_engine_factory_t engine_factory; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp new file mode 100644 index 0000000000..5880d3450c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp @@ -0,0 +1,84 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_INNER_PRODUCT_PD_HPP +#define CPU_INNER_PRODUCT_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "inner_product_pd.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { +inline bool dense_gemm_consitency_check(const memory_desc_wrapper &src_d, + const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) { + using namespace utils; + + auto strides_compatible = [&]() { + bool ok = true; + auto w_str = wei_d.blocking_desc().strides; + auto d_str = src_d.blocking_desc().strides; + for (int i = 1; i < src_d.ndims() - 1; i++) { + ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1]; + } + return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]); + }; + return true && src_d.is_blocking_desc() && wei_d.is_blocking_desc() + && src_d.ndims() == wei_d.ndims() + && src_d.blocking_desc().inner_nblks + == wei_d.blocking_desc().inner_nblks + && utils::one_of(src_d.blocking_desc().inner_nblks, 0, 1) + && array_cmp(src_d.blocking_desc().inner_blks, + wei_d.blocking_desc().inner_blks, + wei_d.blocking_desc().inner_nblks) + && array_cmp(src_d.blocking_desc().inner_idxs, + wei_d.blocking_desc().inner_idxs, + wei_d.blocking_desc().inner_nblks) + && strides_compatible() + && dst_d.matches_tag(format_tag::nc) + && src_d.only_padded_dim(1) + && wei_d.only_padded_dim(1) + && src_d.padded_dims()[1] == wei_d.padded_dims()[1] + && src_d.is_dense(true) + && dst_d.is_dense() + && wei_d.is_dense(true); +} +} + +struct cpu_inner_product_fwd_pd_t: public inner_product_fwd_pd_t { + using inner_product_fwd_pd_t::inner_product_fwd_pd_t; +}; + +struct cpu_inner_product_bwd_data_pd_t: public inner_product_bwd_data_pd_t { + using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t; +}; + +struct cpu_inner_product_bwd_weights_pd_t: public inner_product_bwd_weights_pd_t { + using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp new file mode 100644 index 0000000000..da6e9dac8e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp @@ -0,0 +1,151 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_ISA_TRAITS_HPP +#define CPU_ISA_TRAITS_HPP + +#include <type_traits> + +#define XBYAK64 +#define XBYAK_NO_OP_NAMES +/* in order to make selinux happy memory that would be marked with X-bit should + * be obtained with mmap */ +#define XBYAK_USE_MMAP_ALLOCATOR +#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) +/* turn off `size_t to other-type implicit casting` warning + * currently we have a lot of jit-generated instructions that + * take uint32_t, but we pass size_t (e.g. due to using sizeof). + * FIXME: replace size_t parameters with the appropriate ones */ +#pragma warning (disable: 4267) +#endif +#include "xbyak/xbyak.h" +#include "xbyak/xbyak_util.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef enum { + isa_any, + sse41, + sse42, + avx, + avx2, + avx512_common, + avx512_core, + avx512_core_vnni, + avx512_mic, + avx512_mic_4ops, +} cpu_isa_t; + +template <cpu_isa_t> struct cpu_isa_traits {}; /* ::vlen -> 32 (for avx2) */ + +template <> struct cpu_isa_traits<sse42> { + typedef Xbyak::Xmm Vmm; + static constexpr int vlen_shift = 4; + static constexpr int vlen = 16; + static constexpr int n_vregs = 16; +}; +template <> struct cpu_isa_traits<avx> { + typedef Xbyak::Ymm Vmm; + static constexpr int vlen_shift = 5; + static constexpr int vlen = 32; + static constexpr int n_vregs = 16; +}; +template <> struct cpu_isa_traits<avx2>: + public cpu_isa_traits<avx> {}; + +template <> struct cpu_isa_traits<avx512_common> { + typedef Xbyak::Zmm Vmm; + static constexpr int vlen_shift = 6; + static constexpr int vlen = 64; + static constexpr int n_vregs = 32; +}; +template <> struct cpu_isa_traits<avx512_core>: + public cpu_isa_traits<avx512_common> {}; + +template <> struct cpu_isa_traits<avx512_mic>: + public cpu_isa_traits<avx512_common> {}; + +template <> struct cpu_isa_traits<avx512_mic_4ops>: + public cpu_isa_traits<avx512_common> {}; + +namespace { + +static Xbyak::util::Cpu cpu; +static inline bool mayiuse(const cpu_isa_t cpu_isa) { + using namespace Xbyak::util; + + switch (cpu_isa) { + case sse41: + case sse42: + // FIXME: SSE4.2 is actually NOT required + //return cpu.has(Cpu::tSSE42); + return cpu.has(Cpu::tSSE41); + case avx: + return cpu.has(Cpu::tAVX); + case avx2: + return cpu.has(Cpu::tAVX2); + case avx512_common: + return cpu.has(Cpu::tAVX512F); + case avx512_core: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512BW) + && cpu.has(Cpu::tAVX512VL) + && cpu.has(Cpu::tAVX512DQ); + case avx512_core_vnni: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512BW) + && cpu.has(Cpu::tAVX512VL) + && cpu.has(Cpu::tAVX512DQ) + && cpu.has(Cpu::tAVX512_VNNI); + case avx512_mic: + return true + && cpu.has(Cpu::tAVX512F) + && cpu.has(Cpu::tAVX512CD) + && cpu.has(Cpu::tAVX512ER) + && cpu.has(Cpu::tAVX512PF); + case avx512_mic_4ops: + return true + && mayiuse(avx512_mic) + && cpu.has(Cpu::tAVX512_4FMAPS) + && cpu.has(Cpu::tAVX512_4VNNIW); + case isa_any: + return true; + } + return false; +} +} + +/* whatever is required to generate string literals... */ +#include "z_magic.hpp" +#define JIT_IMPL_NAME_HELPER(prefix, isa, suffix_if_any) \ + (isa == sse42 ? prefix STRINGIFY(sse42) : \ + (isa == avx ? prefix STRINGIFY(avx) : \ + (isa == avx2 ? prefix STRINGIFY(avx2) : \ + (isa == avx512_common ? prefix STRINGIFY(avx512_common) : \ + (isa == avx512_core ? prefix STRINGIFY(avx512_core) : \ + (isa == avx512_mic ? prefix STRINGIFY(avx512_mic) : \ + (isa == avx512_mic_4ops ? prefix STRINGIFY(avx512_mic_4ops) : \ + prefix suffix_if_any))))))) + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp new file mode 100644 index 0000000000..49988f4c2d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp @@ -0,0 +1,42 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_LRN_PD_HPP +#define CPU_LRN_PD_HPP + +#include <assert.h> + +#include "lrn_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_lrn_fwd_pd_t: public lrn_fwd_pd_t { + using lrn_fwd_pd_t::lrn_fwd_pd_t; +}; + +struct cpu_lrn_bwd_pd_t: public lrn_bwd_pd_t { + using lrn_bwd_pd_t::lrn_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp new file mode 100644 index 0000000000..3c0624cf46 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp @@ -0,0 +1,277 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "mkldnn_traits.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl; +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::format_tag; + +enum blk_kind_t { a, b, c, ab, ba, bc, cb }; + +template <data_type_t dt, blk_kind_t blk_kind, int blksize> +void typed_zero_pad_blk( + const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) { + using data_t = typename prec_traits<dt>::type; + const auto &dims = m_d.dims(); + const auto &pdims = m_d.padded_dims(); + const auto &blk = m_d.blocking_desc(); + auto dim_is_blocked = [&](int dim) { + for (int i = 0; i < blk.inner_nblks; i++) + if (blk.inner_idxs[i] == dim) + return true; + return false; + }; + bool A_blocked = dim_is_blocked(0), B_blocked = dim_is_blocked(1), + C_blocked = dim_is_blocked(2); + + assert(blk.inner_nblks < 4); + assert((A_blocked || B_blocked || C_blocked) || (A_blocked && B_blocked) + || (C_blocked && B_blocked)); + + const int a_tail_s = A_blocked ? dims[0] % blksize : 0; + const int b_tail_s = B_blocked ? dims[1] % blksize : 0; + const int c_tail_s = C_blocked ? dims[2] % blksize : 0; + assert(a_tail_s || b_tail_s || c_tail_s); + + const int A = A_blocked ? pdims[0] / blksize : dims[0]; + const int B = B_blocked ? pdims[1] / blksize : dims[1]; + const int C = C_blocked ? pdims[2] / blksize : dims[2]; + const int D = m_d.ndims() > 3 ? dims[3] : 1; + const int E = m_d.ndims() > 4 ? dims[4] : 1; + const int F = m_d.ndims() > 5 ? dims[5] : 1; + const int inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1; + + auto zeroize_tail = [&](data_t *d, const int tail_s) { + for (int b = tail_s; b < blksize; ++b) + d[b] = 0; + }; + auto zeroize_tail_inner = [&](data_t *d, const int tail_s) { + for (int b1 = 0; b1 < blksize; ++b1) + for (int b2 = tail_s; b2 < blksize; ++b2) + d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2 + + b1 % inner_blk] + = 0; + }; + auto zeroize_tail_outer = [&](data_t *d, const int tail_s) { + for (int b1 = tail_s; b1 < blksize; ++b1) + for (int b2 = 0; b2 < blksize; ++b2) + d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2 + + b1 % inner_blk] + = 0; + }; + + if (c_tail_s) { + parallel_nd(A, B, D, E, F, [&](int a, int b, int d, int e, int f) { + auto x = &data[m_d.blk_off(a, b, C - 1, d, e, f)]; + if (blk_kind == c) + zeroize_tail(x, c_tail_s); + else if (blk_kind == bc) + zeroize_tail_inner(x, c_tail_s); + else if (blk_kind == cb) + zeroize_tail_outer(x, c_tail_s); + }); + } + + if (b_tail_s) { + parallel_nd(A, C, D, E, F, [&](int a, int c, int d, int e, int f) { + auto x = &data[m_d.blk_off(a, B - 1, c, d, e, f)]; + if (blk_kind == b) + zeroize_tail(x, b_tail_s); + else if (blk_kind == ab || blk_kind == cb) + zeroize_tail_inner(x, b_tail_s); + else if (blk_kind == ba || blk_kind == bc) + zeroize_tail_outer(x, b_tail_s); + }); + } + + if (a_tail_s) { + parallel_nd(B, C, D, E, F, [&](int b, int c, int d, int e, int f) { + auto x = &data[m_d.blk_off(A - 1, b, c, d, e, f)]; + if (blk_kind == a) + zeroize_tail(x, a_tail_s); + else if (blk_kind == ba) + zeroize_tail_inner(x, a_tail_s); + else if (blk_kind == ab) + zeroize_tail_outer(x, a_tail_s); + }); + } +} + +/* + * all + */ +template <data_type_t dt> +void typed_zero_pad_generic_blocked( + const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) { + const int ndims = m_d.ndims(); + const auto &dims = m_d.dims(); + const auto &pdims = m_d.padded_dims(); + + const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true); + + /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1] + * | \ / + * | --------------------- + * has contiguous + * padding + * + * step <-- D_k+1 * ... * D_ndims-1 + * step_dim <-- k + */ + + ptrdiff_t step = 1; + int step_dim = ndims - 1; + for (; step_dim >= 0; --step_dim) { + if (dims[step_dim] != pdims[step_dim]) + break; + step *= dims[step_dim]; + } + + assert(step_dim >= 0 && "no zero padding is required"); + if (step_dim < 0) + return; + + parallel_nd(nelems / step, [&](ptrdiff_t e1) { + bool need_zero = false; + + ptrdiff_t idx = e1; + for (int d = step_dim; d >= 0; --d) { + if (idx % pdims[d] >= dims[d]) { + need_zero = true; + break; + } + idx /= pdims[d]; + } + + if (need_zero) { + for (ptrdiff_t e0 = 0; e0 < step; ++e0) + data[m_d.off_l(e1 * step + e0, true)] = 0; + } + }); +} + +template <data_type_t dt> +status_t cpu_memory_t::typed_zero_pad() const { + const memory_desc_wrapper mdw(md()); + + if (mdw.format_kind() != format_kind::blocked) + return unimplemented; + + if (mdw.nelems(false) == mdw.nelems(true)) + return success; + + auto *data = (typename prec_traits<dt>::type *)data_; + auto blk = mdw.blocking_desc(); + + auto get_blksize = [&](int ind) { + int blksize = 1; + for (int i = 0; i < blk.inner_nblks; i++) { + if (blk.inner_idxs[i] == ind) + blksize *= blk.inner_blks[i]; + } + return blksize; + }; + const int blksize = get_blksize(blk.inner_idxs[0]); + +# define CASE(blksize_, blk_kind) \ + do { \ + if (blksize == blksize_) { \ + typed_zero_pad_blk<dt, blk_kind, blksize_>(mdw, data); \ + return success; \ + } \ + } while(0) + + switch (blk.inner_nblks) { + case 1: + if (blk.inner_idxs[0] == 0) { + CASE(4, a); + CASE(8, a); + CASE(16, a); + } else if (blk.inner_idxs[0] == 1) { + CASE(4, b); + CASE(8, b); + CASE(16, b); + } + break; + case 2: + case 3: + if (!IMPLICATION(blk.inner_nblks == 3, + blk.inner_idxs[0] == blk.inner_idxs[2])) + break; + + if (blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) { + CASE(4, ab); + CASE(8, ab); + CASE(16, ab); + } else if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0) { + CASE(4, ba); + CASE(8, ba); + CASE(16, ba); + } + if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) { + CASE(4, bc); + CASE(8, bc); + CASE(16, bc); + } else if (blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1) { + CASE(4, cb); + CASE(8, cb); + CASE(16, cb); + } + break; + default: break; + } + +# undef CASE + + // the last line of defence + typed_zero_pad_generic_blocked<dt>(mdw, data); + return success; +} + +status_t cpu_memory_t::zero_pad() const { + memory_desc_wrapper mdw(md()); + const bool skip_zeroing = false + || data_ == nullptr + || mdw.is_zero() + || !mdw.is_blocking_desc(); + if (skip_zeroing) return success; + + switch (mdw.data_type()) { + case f32: return typed_zero_pad<f32>(); + case s32: return typed_zero_pad<s32>(); + case s8: return typed_zero_pad<s8>(); + case u8: return typed_zero_pad<u8>(); + default: assert(!"memory is undefined"); return unimplemented; + } + return unimplemented; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp new file mode 100644 index 0000000000..2c01bcc6af --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_MEMORY_HPP +#define CPU_MEMORY_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory.hpp" +#include "memory_desc_wrapper.hpp" + +#include "cpu_engine.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_memory_t: public memory_t { + cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md, void *handle) + : memory_t(engine, md) + , own_data_(handle == MKLDNN_NATIVE_HANDLE_ALLOCATE) + , data_((char *)handle) {} + + cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md) + : cpu_memory_t(engine, md, nullptr) {} + + ~cpu_memory_t() { if (own_data_) free(data_); } + + virtual status_t init() override { + if (own_data_) { + data_ = nullptr; + const size_t size = memory_desc_wrapper(this->md()).size(); + if (size) { + data_ = (char *)malloc(size, 64); + if (data_ == nullptr) + return status::out_of_memory; + } + } + return zero_pad(); + } + + cpu_engine_t *engine() const { return (cpu_engine_t *)memory_t::engine(); } + + virtual status_t get_data_handle(void **handle) const override { + *handle = static_cast<void *>(data_); + return status::success; + } + + virtual mkldnn::impl::status_t set_data_handle(void *handle) override { + if (own_data_) { free(data_); own_data_ = false; } + data_ = static_cast<char *>(handle); + return zero_pad(); + } + + virtual mkldnn::impl::status_t zero_pad() const override; + +private: + bool own_data_; + char *data_; + + template <mkldnn::impl::data_type_t> + mkldnn::impl::status_t typed_zero_pad() const; + + cpu_memory_t(const cpu_memory_t &) = delete; + cpu_memory_t &operator=(const cpu_memory_t &) = delete; + cpu_memory_t &operator=(cpu_memory_t &&) = delete; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp new file mode 100644 index 0000000000..ac2daa415e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp @@ -0,0 +1,40 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_POOLING_PD_HPP +#define CPU_POOLING_PD_HPP + +#include "pooling_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_pooling_fwd_pd_t: public pooling_fwd_pd_t { + using pooling_fwd_pd_t::pooling_fwd_pd_t; +}; + +struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t { + using pooling_bwd_pd_t::pooling_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp new file mode 100644 index 0000000000..56127f36c2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp @@ -0,0 +1,83 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_PRIMITIVE_HPP +#define CPU_PRIMITIVE_HPP + +#include "mkldnn.h" + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "primitive.hpp" +#include "scratchpad.hpp" + +#define CTX_IN_MEM(type, arg) static_cast<type>(ctx.input(arg)) +#define CTX_OUT_MEM(type, arg) static_cast<type>(ctx.output(arg)) + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_memory_t; + +struct cpu_primitive_t: public primitive_t { + cpu_primitive_t(const primitive_desc_t *pd, + bool use_global_scratchpad = false) + : primitive_t(pd) + , scratchpad_buffer_(nullptr) + , global_scratchpad_(nullptr) + { + const size_t scratchpad_size = + this->pd()->scratchpad_size(scratchpad_mode::library); + + if (scratchpad_size) { + if (use_global_scratchpad) + global_scratchpad_ = create_scratchpad(scratchpad_size); + else + scratchpad_buffer_ = malloc(scratchpad_size, 64); + } + } + + virtual ~cpu_primitive_t() { + delete global_scratchpad_; + free(scratchpad_buffer_); + } + +protected: + memory_tracking::grantor_t scratchpad(const exec_ctx_t &ctx) const { + void *ptr = nullptr; + if (pd()->attr()->scratchpad_mode_ == scratchpad_mode::user) { + ptr = CTX_OUT_MEM(void *, MKLDNN_ARG_SCRATCHPAD); + } else { + ptr = global_scratchpad_ + ? global_scratchpad_->get() : scratchpad_buffer_; + } + + return pd()->scratchpad_registry().grantor(ptr); + } + +private: + void *scratchpad_buffer_; + scratchpad_t *global_scratchpad_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp new file mode 100644 index 0000000000..1d41ac5cea --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp @@ -0,0 +1,544 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "mkldnn_thread.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "utils.hpp" + +#include "cpu_reducer.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void reduce_balancer_t::balance() { + using namespace nstl; + using namespace utils; + + assert(nthr_ > 0 && job_size_ > 0 && njobs_ > 0 && reduction_size_ > 0); + + const int job_complexity = 1; + + const int min_njobs_per_group = max(1, njobs_ / nthr_); + const int max_njobs_per_group = max(1, + static_cast<int>(max_buffer_size_ / (nthr_ * job_size_))); + + /* initial guess */ + int ngroups = min(njobs_ / min_njobs_per_group, nthr_); + int nthr_per_group = syncable_ ? min(nthr_ / ngroups, reduction_size_) : 1; + int njobs_per_group_ub = div_up(njobs_, ngroups); + + /* rough upper-bound estimation, will be fixed during brute force */ + size_t thread_complexity_ub = njobs_ * job_size_ * reduction_size_; + + /* brute force parameters for the best balance... */ + for (int c_njobs_per_group = min_njobs_per_group; + c_njobs_per_group < njobs_; ++c_njobs_per_group) { + /* current assumption */ + int c_ngroups = min(njobs_ / c_njobs_per_group, nthr_); + int c_nthr_per_group = syncable_ + ? min(nthr_ / c_ngroups, reduction_size_) : 1; + int c_njobs_per_group_ub = div_up(njobs_, c_ngroups); + + if (c_nthr_per_group > 1 && c_njobs_per_group_ub > max_njobs_per_group) + continue; + + int c_thread_reduction_ub = div_up(reduction_size_, c_nthr_per_group); + size_t c_group_size_ub = job_size_ * c_njobs_per_group_ub; + size_t c_thread_complexity_ub = c_group_size_ub * ( + job_complexity * c_thread_reduction_ub + + (c_nthr_per_group != 1)); + + if (c_thread_complexity_ub < thread_complexity_ub) { + ngroups = c_ngroups; + nthr_per_group = c_nthr_per_group; + njobs_per_group_ub = c_njobs_per_group_ub; + thread_complexity_ub = c_thread_complexity_ub; + } + } + + assert(njobs_per_group_ub <= max_njobs_per_group || nthr_per_group == 1); + assert(ngroups * nthr_per_group <= nthr_); + assert((size_t)njobs_per_group_ub * job_size_ * nthr_ <= max_buffer_size_ + || nthr_per_group == 1); /* no reduction buffer overflow */ + assert(IMPLICATION(!syncable_, nthr_per_group == 1)); + + ngroups_ = ngroups; + nthr_per_group_ = nthr_per_group; + njobs_per_group_ub_ = njobs_per_group_ub; +} + +/* reducer jit-ted driver */ + +using namespace Xbyak; + +template <impl::data_type_t data_type> +struct reducer_2d_driver_t: public c_compatible { + typedef typename prec_traits<data_type>::type data_t; + + reducer_2d_driver_t(int n_src, size_t src_ld, + size_t src_step, size_t dst_step, bool nullify_dst) + : n_src_(n_src), src_ld_(src_ld), src_step_(src_step) + , dst_step_(dst_step), nullify_dst_(nullify_dst), ker_(nullptr) {} + virtual ~reducer_2d_driver_t() {} + void operator()(data_t *dst, const data_t *srcs, size_t ny, size_t nx) + { assert(ker_); ker_(dst, srcs, ny, nx); } + +protected: + int n_src_; + size_t src_ld_, src_step_, dst_step_; + bool nullify_dst_; + void (*ker_)(data_t *dst, const data_t *srcs, size_t ny, size_t nx); +}; + +template <impl::data_type_t data_type, cpu_isa_t isa> +struct reducer_2d_driver_f_s_32_t: public reducer_2d_driver_t<data_type>, + public jit_generator +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(reducer_2d_driver_f_s_32_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional<isa == avx2, Ymm, Zmm>::type; + const AddressFrame &vmmword = (isa == avx2) ? yword : zword; + void uni_vadd(const Xmm& x1, const Xmm& x2, const Operand& op) + { if (data_type == data_type::f32) vaddps(x1, x2, op); + else vpaddd(x1, x2, op); } + void uni_add(const Xmm& x1, const Operand& op) + { if (data_type == data_type::f32) addss(x1, op); else paddd(x1, op); } + + const int vlen = cpu_isa_traits<isa>::vlen; + const int typesize + = sizeof(typename mkldnn::impl::prec_traits<data_type>::type); + Xbyak::Reg64 reg_dst = abi_param1; + Xbyak::Reg64 reg_src = abi_param2; + Xbyak::Reg64 reg_ny = abi_param3; + Xbyak::Reg64 reg_nx = abi_param4; + + Xbyak::Reg64 reg_x = rax; + Xbyak::Reg64 reg_src_id = r10; + + reducer_2d_driver_f_s_32_t(int n_src, size_t src_ld, size_t src_step, + size_t dst_step, bool nullify_dst) + : reducer_2d_driver_t<data_type>(n_src, src_ld, src_step, + dst_step, nullify_dst) + { generate(); } + + void nullify_dst(int nloads, int load_len) { + UNUSED(load_len); + for (int i = 0; i < nloads; ++i) + uni_vpxor(Vmm(i), Vmm(i), Vmm(i)); + /* prefetches[dst] ? */ + } + + void load_dst(int nloads, int load_len) { + for (int i = 0; i < nloads; ++i) { + if (load_len == typesize) + movd(Xmm(i), ptr[reg_dst + i * load_len]); + else if (load_len == vlen) + vmovups(Vmm(i), ptr[reg_dst + i * load_len]); + else + assert(!"unsupported"); + } + } + + void store_dst(int nloads, int load_len) { + for (int i = 0; i < nloads; ++i) { + if (load_len == typesize) + movd(ptr[reg_dst + i * load_len], Xmm(i)); + else if (load_len == vlen) + vmovups(ptr[reg_dst + i * load_len], Vmm(i)); + else + assert(!"unsupported"); + } + } + + void accumulate(int nloads, int load_len, size_t base_off) { + for (int i = 0; i < nloads; ++i) { + size_t off = base_off + i * load_len; + + if (load_len == typesize) + uni_add(Xmm(i), ptr[reg_src + off]); + else if (load_len == vlen) + uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]); + else + assert(!"unsupported"); + } + } + + void loop_x() { + const int nloads[] = {cpu_isa_traits<isa>::n_vregs, 1, 1}; + const int nbranches = sizeof(nloads) / sizeof(nloads[0]); + + const int load_len[nbranches] = {vlen, vlen, typesize}; + Label loop_x_label[nbranches + 1]; + + mov(reg_x, reg_nx); + + for (int id = 0; id < nbranches; ++id) { + L(loop_x_label[id]); + + cmp(reg_x, nloads[id] * load_len[id]); + jl(loop_x_label[id + 1], T_NEAR); + + if (this->nullify_dst_) + nullify_dst(nloads[id], load_len[id]); + else + load_dst(nloads[id], load_len[id]); + + if (nloads[id] > 1) { + Label loop_srcs; + mov(reg_src_id, this->n_src_); + L(loop_srcs); + + accumulate(nloads[id], load_len[id], 0); + add(reg_src, this->src_ld_ * typesize); + + dec(reg_src_id); + jnz(loop_srcs, T_NEAR); + + sub(reg_src, this->n_src_ * this->src_ld_ * typesize); + } else { + for (int src_id = 0; src_id < this->n_src_; ++src_id) { + const size_t base_off = src_id * this->src_ld_ * typesize; + accumulate(nloads[id], load_len[id], base_off); + } + } + + store_dst(nloads[id], load_len[id]); + + add(reg_src, nloads[id] * load_len[id]); + add(reg_dst, nloads[id] * load_len[id]); + + sub(reg_x, nloads[id] * load_len[id]); + + jmp(loop_x_label[id], T_NEAR); + } + + L(loop_x_label[nbranches]); + + /* restore address registers */ + sub(reg_src, reg_nx); + sub(reg_dst, reg_nx); + } + + void generate() { + assert(isa == avx2 || isa == avx512_common || isa == avx512_mic); + + preamble(); + + shl(reg_nx, 2); + + Label ny_loop; + L(ny_loop); + + loop_x(); + + add(reg_dst, this->dst_step_ * typesize); + add(reg_src, this->src_step_ * typesize); + + dec(reg_ny); + jnz(ny_loop, T_NEAR); + + postamble(); + this->ker_ = reinterpret_cast<decltype(this->ker_)>( + const_cast<uint8_t*>(this->getCode())); + } +}; + +template <impl::data_type_t data_type> +inline reducer_2d_driver_t<data_type> *create_reduce_2d_drv(int n_src, + size_t src_ld, size_t src_step, size_t dst_step, bool nullify_dst) { + if (mayiuse(avx512_common)) + return new reducer_2d_driver_f_s_32_t<data_type, avx512_common>(n_src, + src_ld, src_step, dst_step, nullify_dst); + else if (mayiuse(avx2)) + return new reducer_2d_driver_f_s_32_t<data_type, avx2>(n_src, src_ld, + src_step, dst_step, nullify_dst); + assert(!"unimplemented"); + return nullptr; +} + +/* cpu_reducer_t */ + +template <impl::data_type_t data_type> +void cpu_reducer_t<data_type>::conf_t::init_scratchpad( + memory_tracking::registrar_t &scratchpad) const { + if (balancer_.nthr_per_group_ == 1) return; + + const size_t space_size = balancer_.ngroups_ + * (balancer_.nthr_per_group_ - 1) + * cpu_reducer_t<data_type>::space_per_thread(balancer_); + scratchpad.book(key_reducer_space, sizeof(data_t) * space_size, PAGE_4K); + scratchpad.book(key_reducer_space_bctx, + sizeof(simple_barrier::ctx_t) * balancer_.ngroups_); +} + +template <impl::data_type_t data_type> +cpu_reducer_t<data_type>::cpu_reducer_t(const conf_t &conf) + : conf_(conf), drv_(nullptr) +{ + if (balancer().nthr_per_group_ == 1) return; + + drv_ = create_reduce_2d_drv<data_type>(balancer().nthr_per_group_ - 1, + space_per_thread(balancer()), 0, 0, false); +} + +template <impl::data_type_t data_type> +cpu_reducer_t<data_type>::~cpu_reducer_t() { delete drv_; } + +template <impl::data_type_t data_type> +typename cpu_reducer_t<data_type>::data_t * +cpu_reducer_t<data_type>::get_local_ptr(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const int id_in_grp = balancer().id_in_group(ithr); + + /* threads 0 from each group writes directly to the destination */ + if (id_in_grp == 0) + return dst + balancer().ithr_job_off(ithr) * balancer().job_size_; + + const int grp_id = balancer().group_id(ithr); + const int offset_factor = grp_id * (balancer().nthr_per_group_ - 1) + + (id_in_grp - 1); + + auto space = scratchpad.template get<data_t>(key_reducer_space); + return space + offset_factor * space_per_thread(balancer()); +} + +template <impl::data_type_t data_type> +void cpu_reducer_t<data_type>::reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + +#ifdef SIMPLE_IMPL + if (balancer().id_in_group(ithr) != 0) + return; /* only threads 0 do the reduction */ + + const int njobs_in_grp = balancer().ithr_njobs(ithr); + data_t *d = get_local_ptr(ithr, dst, scratchpad); + for (int id_in_grp = 1; id_in_grp < balancer_.nthr_per_group_; ++id_in_grp) + { + const data_t *space = get_local_ptr(ithr + id_in_grp, dst, scratchpad); + for (size_t i = 0; i < (size_t)njobs_in_grp * balancer().job_size_; ++i) + d[i] += space[i]; + } +#else + using namespace utils; + + const int id_in_grp = balancer().id_in_group(ithr); + const int njobs_in_grp = balancer().ithr_njobs(ithr); + const size_t cl = 64 / sizeof(data_t); + + const size_t reduction_size = njobs_in_grp * balancer().job_size_; + size_t start{0}, end{0}; + balance211(div_up(reduction_size, cl), balancer().nthr_per_group_, + id_in_grp, start, end); + + if (start == end) return; + + data_t *d = get_local_ptr(ithr - id_in_grp, dst, scratchpad) + start * cl; + const data_t *space = get_local_ptr(ithr - id_in_grp + 1, dst, scratchpad) + + start * cl; + const size_t len = nstl::min(end * cl, reduction_size) - start * cl; + + (*drv_)(d, space, 1, len); +#endif +} + +template struct cpu_reducer_t<data_type::f32>; +template struct cpu_reducer_t<data_type::s32>; + +/* cpu_reducer_2d_t */ + +template <impl::data_type_t data_type> +void cpu_reducer_2d_t<data_type>::conf_t::init_scratchpad( + memory_tracking::registrar_t &scratchpad) const { + if (balancer_.nthr_per_group_ == 1) return; + + const size_t space_size = balancer_.ngroups_ * balancer_.nthr_per_group_ + * cpu_reducer_2d_t<data_type>::space_per_thread(balancer_); + scratchpad.book(key_reducer_space, sizeof(data_t) * space_size); + scratchpad.book(key_reducer_space_bctx, + sizeof(simple_barrier::ctx_t) * balancer_.ngroups_); +} + +template <impl::data_type_t data_type> +cpu_reducer_2d_t<data_type>::cpu_reducer_2d_t(const conf_t &conf) + : conf_(conf), drv_(nullptr) +{ + if (balancer().nthr_per_group_ == 1) return; + + drv_ = create_reduce_2d_drv<data_type>(balancer().nthr_per_group_, + space_per_thread(balancer()), conf_.job_size_x_, conf_.dst_x_, + true); +} + +template <impl::data_type_t data_type> +cpu_reducer_2d_t<data_type>::~cpu_reducer_2d_t() { delete drv_; } + +template <impl::data_type_t data_type> +typename cpu_reducer_2d_t<data_type>::data_t *cpu_reducer_2d_t<data_type>:: +get_local_ptr(int ithr, const memory_tracking::grantor_t &scratchpad) const { + const int id_in_grp = balancer().id_in_group(ithr); + const int grp_id = balancer().group_id(ithr); + const int offset_factor = grp_id * balancer().nthr_per_group_ + id_in_grp; + auto space = scratchpad.template get<data_t>(key_reducer_space); + return space + offset_factor * space_per_thread(balancer()); +} + +template <impl::data_type_t data_type> +int cpu_reducer_2d_t<data_type>::choose_x_blocking(int nx, int ny, + int nthr_per_grp) const { + // find x_blocking for better balance reducing work between threads + assert(conf_.x_block_ > 0 && nx > conf_.x_block_ + && nx % conf_.x_block_ == 0); + int x_blocking = nx / conf_.x_block_; + int min_x_blocking = + utils::div_up(x_blocking, nstl::max(1, nthr_per_grp / ny)); + while (true) { + if (x_blocking % 2 == 0 && x_blocking >= min_x_blocking * 2) + x_blocking /= 2; + else if (x_blocking % 3 == 0 && x_blocking >= min_x_blocking * 3) + x_blocking /= 3; + else + break; + } + if (x_blocking >= min_x_blocking * 4) x_blocking = 1; + x_blocking *= conf_.x_block_; + return x_blocking; +} + +template <impl::data_type_t data_type> +void cpu_reducer_2d_t<data_type>::reduce_block(const data_t* space_base, + data_t *dst, int job, int start_y, int start_x, + int ny_start, int nx_start, int ny_step, int nx_step) const { + data_t *d = dst + (start_y + ny_start) * conf_.dst_x_ + + start_x + nx_start; + const data_t *space = space_base + job * balancer().job_size_ + + ny_start * conf_.job_size_x_ + nx_start; +#ifdef SIMPLE_IMPL + for (int idg = 0; idg < balancer().nthr_per_group_; ++idg) { + const data_t *w = &space[idg * space_per_thread(balancer())]; + for (int y = 0; y < ny_step; ++y) + for (int x = 0; x < nx_step; ++x) { + d[y * conf_.dst_x_ + x] + = (idg == 0 ? 0 : d[y * conf_.dst_x_ + x]) + + w[y * conf_.job_size_x_ + x]; + } + } +#else + (*drv_)(d, space, ny_step, nx_step); +#endif +} + +template <impl::data_type_t data_type> +void cpu_reducer_2d_t<data_type>::reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + const int id_in_grp = balancer().id_in_group(ithr); + const int njobs_in_grp = balancer().ithr_njobs(ithr); + const int njobs_x = utils::div_up(conf_.dst_x_, conf_.job_size_x_); + const int global_job_start = balancer().ithr_job_off(ithr); + + const data_t *space_base = get_local_ptr(ithr - id_in_grp, scratchpad); + + const int pr_grps = nstl::min(njobs_in_grp, balancer().nthr_per_group_); + const int pr_nthr_per_grp = balancer().nthr_per_group_ / pr_grps; + + if (id_in_grp >= pr_grps * pr_nthr_per_grp) + return; /* idle */ + + const int pr_my_grp = id_in_grp / pr_nthr_per_grp; + const int pr_my_id = id_in_grp % pr_nthr_per_grp; + + int pr_job_start{0}, pr_job_end{0}; + balance211(njobs_in_grp, pr_grps, pr_my_grp, pr_job_start, pr_job_end); + + for (int j = pr_job_start; j < pr_job_end; ++j) { + const int global_job = global_job_start + j; + const int j_y = global_job / njobs_x; + const int j_x = global_job % njobs_x; + const int start_y = j_y * conf_.job_size_y_; + const int start_x = j_x * conf_.job_size_x_; + const int ny = nstl::min(conf_.dst_y_ - start_y, conf_.job_size_y_); + const int nx = nstl::min(conf_.dst_x_ - start_x, conf_.job_size_x_); + int x_blocking = choose_x_blocking(nx, ny, pr_nthr_per_grp); + + int nxy_start{0}, nxy_end{0}; + balance211(ny * nx / x_blocking, pr_nthr_per_grp, pr_my_id, + nxy_start, nxy_end); + if (nxy_start == nxy_end) continue; + nxy_start *= x_blocking; + nxy_end *= x_blocking; + + int nxy = nxy_start; + if (nxy % nx != 0) { + int nx_step = nstl::min(nx - nxy % nx, nxy_end - nxy); + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, 1, nx_step); + nxy += nx_step; + } + if ((nxy_end - nxy) > nx) { + int ny_step = (nxy_end - nxy) / nx; + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, ny_step, nx); + nxy += nx * ny_step; + } + if ((nxy_end - nxy) > 0) { + reduce_block(space_base, dst, j, start_y, start_x, + nxy / nx, nxy % nx, 1, nxy_end - nxy); + } + } +} + +template struct cpu_reducer_2d_t<data_type::f32>; +template struct cpu_reducer_2d_t<data_type::s32>; + +/* accumulator section */ + +template <impl::data_type_t data_type> +cpu_accumulator_1d_t<data_type>::cpu_accumulator_1d_t(): drv_(nullptr) { + drv_ = create_reduce_2d_drv<data_type>(1, 0, 0, 0, false); +} + +template <impl::data_type_t data_type> +cpu_accumulator_1d_t<data_type>::~cpu_accumulator_1d_t() { + delete drv_; +} + +template <impl::data_type_t data_type> +void cpu_accumulator_1d_t<data_type>::accumulate(data_t *dst, + const data_t *src, size_t size) { + (*drv_)(dst, src, 1, size); +} + +template struct cpu_accumulator_1d_t<data_type::f32>; +template struct cpu_accumulator_1d_t<data_type::s32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp new file mode 100644 index 0000000000..27f5939cd2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp @@ -0,0 +1,334 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REDUCER_HPP +#define CPU_REDUCER_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "cpu_barrier.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +/** class to perform balancing over 3D array + * + * Conceptually the reduction happens according to the picture below: + * + * <--job_size-> + * +-----------+ +-----------+ +-----------+ ^ + * | | | | | | | + * | | | | | | | + * | 1 | | 2 | . . . | njobs | | reduction_size + * | | | | | | | + * | | | | | | | + * +-----------+ +-----------+ +-----------+ v + * + * | | | | | | | | | + * v v v v v v v v v + * ===================================================== vertical reduction + * + * +-----------+ +-----------+ . . . +-----------+ result + * + * In a simple case the result must be contiguous in memory. + * @class cpu_reducer_t is an implementation. + * + * Threads are divided into groups. The groups are independent of each other. + * Each group may work on several jobs (the distribution is not uniform, since + * njobs might be not a multiple of groups). Threads within a group work on + * different parts of the reduction dimension. Thread 0 in each group is called + * master (@sa reduce_balancer_t::master()). + * + * If threading driver does not allow sync between sub-group of threads (e.g. + * Intel(R) TBB) the # of thread per group is enforced to be 1. + */ +struct reduce_balancer_t { + reduce_balancer_t() { init(1, 1, 1, 1, 0); } /* trivial balance */ + reduce_balancer_t(int nthr, int job_size, int njobs, int reduction_size, + size_t max_buffer_size) + { init(nthr, job_size, njobs, reduction_size, max_buffer_size); } + + reduce_balancer_t &init(int nthr, int job_size, int njobs, + int reduction_size, size_t max_buffer_size) + { + syncable_ = mkldnn_thr_syncable(); + nthr_ = nthr; + job_size_ = job_size; + njobs_ = njobs; + reduction_size_ = reduction_size; + max_buffer_size_ = max_buffer_size; + balance(); + return *this; + } + + bool syncable_; + int nthr_; + int job_size_, njobs_, reduction_size_; + + int ngroups_; /** number of independent work (thread) groups */ + int nthr_per_group_; /** number of threads within a single work group */ + int njobs_per_group_ub_; /** the max # of jobs within a work group */ + + bool master(int ithr) const { return id_in_group(ithr) == 0; } + bool idle(int ithr) const { return ithr >= nthr_per_group_ * ngroups_; } + + int group_id(int ithr) const { return ithr / nthr_per_group_; } + int id_in_group(int ithr) const { return ithr % nthr_per_group_; } + + int grp_njobs(int grp) const { + if (grp >= ngroups_) return 0; + return njobs_ / ngroups_ + (grp < njobs_ % ngroups_); + } + int grp_job_off(int grp) const { + if (grp >= ngroups_) return njobs_; + return njobs_ / ngroups_ * grp + nstl::min(grp, njobs_ % ngroups_); + } + + int ithr_njobs(int ithr) const { return grp_njobs(group_id(ithr)); } + int ithr_job_off(int ithr) const { return grp_job_off(group_id(ithr)); } + +private: + size_t max_buffer_size_; + void balance(); +}; + +/** forward declaration of reduce driver */ +template <impl::data_type_t data_type> struct reducer_2d_driver_t; + +/** class to perform a reduction over 3D array + * + * Balancing is based on @class reduce_balancer_t. + * Restrictions: the result of the reduction must be contiguous in memory. * + * The reduction happens according to the picture below (once more): + * + * <--job_size-> + * +-----------+ +-----------+ +-----------+ ^ + * | | | | | | | + * | | | | | | | + * | 1 | | 2 | . . . | njobs | | reduction_size + * | | | | | | | + * | | | | | | | + * +-----------+ +-----------+ +-----------+ v + * + * | | | | | | | | | + * v v v v v v v v v + * ===================================================== vertical reduction + * + * +-----------+ +-----------+ . . . +-----------+ (contiguous) result + * + * An example how work might be shared is shown below. + * + * In this example group 0 owns 2 (independent) jobs -- 2 big squares. + * The number of threads per group is also 2 (thread 0 of group 0 and thread 1 + * of group 0). Master threads (i.e. threads with id 0 in corresponding group) + * from each group put the partial result directly into destination memory, + * while all the other threads with-in the group use workspace (on the picture + * the only thread 1). Once intermediate results obtained each group reduces + * corresponding part (own jobs) to the destination memory. + * + * <------- group 0 -------> + * + * +-----------+ +-----------+ ^ + * | | | | | thread 0 of reduces to the dest-memory + * | | | | | group 0 +-----------+ +-----------+ + * |- - - - - -| |- - - - - -| X + * | | | | | thread 1 of reduces to workspace[tid=1]: + * | | | | | group 0 +-----------+ +-----------+ + * +-----------+ +-----------+ v + * | | | | | | + * v v v v v v + * ((barrier)) ============================= + * + * dest-memory: +-----------+ +-----------+ + */ +template <impl::data_type_t data_type> +struct cpu_reducer_t { + typedef typename prec_traits<data_type>::type data_t; + + struct conf_t { + conf_t() = default; + conf_t &init(const reduce_balancer_t &balancer) + { balancer_ = balancer; return *this; } + + void init_scratchpad(memory_tracking::registrar_t &scratchpad) const; + + reduce_balancer_t balancer_; + }; + + cpu_reducer_t(const conf_t &conf); + ~cpu_reducer_t(); + + /** initializes reducer. + * Must be called from a single thread prior to actual usage */ + void init(const memory_tracking::grantor_t &scratchpad) const { + if (balancer().nthr_per_group_ == 1) return; + + auto bctx = scratchpad.template get<simple_barrier::ctx_t>( + memory_tracking::names::key_reducer_space_bctx); + for (int i = 0; i < balancer().ngroups_; ++i) + simple_barrier::ctx_init(&bctx[i]); + } + + /** for given thread returns the pointer where to put partial results. + * Reduction destination @p dst must be provided as well (master threads + * from each group will use it for partial result to reduce memory + * pressure). + * + * @note: job offset is already applied by get_local_ptr(), which means all + * threads should start writing from the very beginning of returned + * address. + */ + data_t *get_local_ptr(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + + /** performs the reduction with built-in synchronization. */ + void reduce(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + auto bctx = scratchpad.template get<simple_barrier::ctx_t>( + memory_tracking::names::key_reducer_space_bctx); + simple_barrier::barrier(&bctx[balancer().group_id(ithr)], + balancer().nthr_per_group_); + + reduce_nolock(ithr, dst, scratchpad); + } + + const reduce_balancer_t &balancer() const { return conf_.balancer_; } + +private: + static size_t space_per_thread(const reduce_balancer_t &balancer) + { return balancer.njobs_per_group_ub_ * balancer.job_size_; } + + /* The scratchpad is organized as follows: + * + * data_t space[nthr_][njobs_per_group_ub_][jobs_size_]; + * simple_barrier::ctx_t barriers[groups_]; */ + + const conf_t conf_; + reducer_2d_driver_t<data_type> *drv_; + + void reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; +}; + +template <impl::data_type_t data_type> +struct cpu_reducer_2d_t { + typedef typename prec_traits<data_type>::type data_t; + + struct conf_t { + conf_t() = default; + conf_t &init(const reduce_balancer_t &balancer, int job_size_x, + int job_size_y, int x_block, int dst_x, int dst_y) { + balancer_ = balancer; + job_size_x_ = job_size_x; + job_size_y_ = job_size_y; + x_block_ = x_block; + dst_x_ = dst_x; + dst_y_ = dst_y; + return *this; + } + + void init_scratchpad(memory_tracking::registrar_t &scratchpad) const; + + reduce_balancer_t balancer_; + int job_size_x_, job_size_y_, x_block_, dst_x_, dst_y_; + }; + + cpu_reducer_2d_t(const conf_t &conf); + ~cpu_reducer_2d_t(); + + /** initializes reducer. + * Must be called from a single thread prior to actual usage */ + void init(const memory_tracking::grantor_t &scratchpad) const { + if (balancer().nthr_per_group_ == 1) return; + + auto bctx = scratchpad.template get<simple_barrier::ctx_t>( + memory_tracking::names::key_reducer_space_bctx); + for (int i = 0; i < balancer().ngroups_; ++i) + simple_barrier::ctx_init(&bctx[i]); + } + + /** for given thread returns the pointer where to put partial results */ + data_t *get_local_ptr(int ithr, + const memory_tracking::grantor_t &scratchpad) const; + + /** performs the reduction with built-in synchronization. */ + void reduce(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + bool redundant_reduction = balancer().nthr_per_group_ == 1 + || balancer().idle(ithr); + if (redundant_reduction) return; + + auto bctx = scratchpad.template get<simple_barrier::ctx_t>( + memory_tracking::names::key_reducer_space_bctx); + simple_barrier::barrier(&bctx[balancer().group_id(ithr)], + balancer().nthr_per_group_); + + reduce_nolock(ithr, dst, scratchpad); + } + + const reduce_balancer_t &balancer() const { return conf_.balancer_; } + +private: + static size_t space_per_thread(const reduce_balancer_t &balancer) + { return balancer.njobs_per_group_ub_ * balancer.job_size_; } + + /* The scratchpad is organized as follows: + * + * data_t space[nthr_][njobs_per_group_ub_][jobs_size_]; + * simple_barrier::ctx_t barriers[groups_]; */ + + const conf_t conf_; + reducer_2d_driver_t<data_type> *drv_; + + int choose_x_blocking(int nx, int ny, int nthr_per_grp) const; + void reduce_block(const data_t* space_base, data_t *dst, + int job, int start_y, int start_x, + int ny_start, int nx_start, int ny_step, int nx_step) const; + void reduce_nolock(int ithr, data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; +}; + +/** simple 1d accumulator: y[:] += x[:] */ +template <impl::data_type_t data_type> +struct cpu_accumulator_1d_t { + typedef typename prec_traits<data_type>::type data_t; + + cpu_accumulator_1d_t(); + ~cpu_accumulator_1d_t(); + void accumulate(data_t *dst, const data_t *src, size_t size); + + reducer_2d_driver_t<data_type> *drv_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp new file mode 100644 index 0000000000..82be70353d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp @@ -0,0 +1,262 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "cpu_engine.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" +#include "cpu_memory.hpp" +#include "type_helpers.hpp" + +#include "cpu/jit_uni_reorder.hpp" +#include "cpu/simple_reorder.hpp" +#include "cpu/wino_reorder.hpp" +#include "cpu/rnn/rnn_reorders.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f; + +namespace { +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::format_tag; + +#define REG_SR(idt, ifmt, odt, ofmt, ...) \ + simple_reorder_t<idt, ifmt, odt, ofmt, __VA_ARGS__>::pd_t::create + +#define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \ + REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse) + +#define REG_SR_DIRECT_COPY(idt, odt) \ + REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \ + REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0) + +static const rpd_create_f cpu_reorder_impl_list[] = { + /* winograd */ + wino_reorder_t<f32, f32>::pd_t::create, + //wino_reorder_t<f32, s8>::pd_t::create, + + /* rnn reorders */ + rnn_data_reorder_t<f32, u8>::pd_t::create, + rnn_weights_reorder_t<f32, f32>::pd_t::create, + rnn_weights_reorder_t<f32, s8>::pd_t::create, + + /* conv reorders w/ compensation */ + REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8), + + REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), + REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8), + + /* regular reorders */ + +#if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__)) + /* Direct copy for icc which is faster than jitted code; + * Direct copy for gcc which might or might not be faster than jitted + * code, but still worth it because doesn't require jitting, i.e. much + * faster creation time. This is tentative solution and should be removed + * later (when we will cache jitted code?...). */ + REG_SR_DIRECT_COPY(f32, f32), +#endif + +#ifdef __INTEL_COMPILER + /* direct copy for icc, which is faster than jitted code */ + /* + REG_SR_DIRECT_COPY(f32, s32), + REG_SR_DIRECT_COPY(f32, s8), + REG_SR_DIRECT_COPY(f32, u8), + REG_SR_DIRECT_COPY(s32, f32), + REG_SR_DIRECT_COPY(s32, s32), + REG_SR_DIRECT_COPY(s32, s8), + REG_SR_DIRECT_COPY(s32, u8), + REG_SR_DIRECT_COPY(s8, f32), + REG_SR_DIRECT_COPY(s8, s32), + REG_SR_DIRECT_COPY(s8, s8), + REG_SR_DIRECT_COPY(s8, u8), + REG_SR_DIRECT_COPY(u8, f32), + REG_SR_DIRECT_COPY(u8, s32), + REG_SR_DIRECT_COPY(u8, s8), + REG_SR_DIRECT_COPY(u8, u8), + */ +#endif + + /* jit */ + jit_uni_reorder_create, + + /* fp32: flat <-> blocked with tail */ + /* + REG_SR_BIDIR(f32, any, f32, nCw4c), + REG_SR_BIDIR(f32, any, f32, nCw8c), + REG_SR_BIDIR(f32, any, f32, OIw4i4o), + REG_SR_BIDIR(f32, any, f32, OIw8i8o), + REG_SR_BIDIR(f32, any, f32, OIw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIw4i4o), + REG_SR_BIDIR(f32, any, f32, gOIw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIw8o8i), + + REG_SR_BIDIR(f32, any, f32, nCw16c), + REG_SR_BIDIR(f32, any, f32, OIw16o16i), + REG_SR_BIDIR(f32, any, f32, OIw16i16o), + REG_SR_BIDIR(f32, any, f32, IOw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIw16i16o), + REG_SR_BIDIR(f32, any, f32, gIOw16o16i), + + REG_SR_BIDIR(f32, any, f32, nChw4c), + REG_SR_BIDIR(f32, any, f32, nChw8c), + REG_SR_BIDIR(f32, any, f32, OIhw4i4o), + REG_SR_BIDIR(f32, any, f32, Ohwi8o), + + REG_SR_BIDIR(f32, any, f32, OIhw8i8o), + REG_SR_BIDIR(f32, any, f32, OIhw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIhw4i4o), + REG_SR_BIDIR(f32, any, f32, gOIhw4o4i), + REG_SR_BIDIR(f32, any, f32, gOhwi8o), + REG_SR_BIDIR(f32, any, f32, gOIhw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIhw8o8i), + + REG_SR_BIDIR(f32, any, f32, nChw16c), + REG_SR_BIDIR(f32, any, f32, Oihw4o), + REG_SR_BIDIR(f32, any, f32, Oihw16o), + REG_SR_BIDIR(f32, any, f32, Ohwi4o), + REG_SR_BIDIR(f32, any, f32, Ohwi16o), + REG_SR_BIDIR(f32, any, f32, OIhw16o16i), + REG_SR_BIDIR(f32, any, f32, OIhw16i16o), + REG_SR_BIDIR(f32, any, f32, IOhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOihw4o), + REG_SR_BIDIR(f32, any, f32, gOihw16o), + REG_SR_BIDIR(f32, any, f32, gOhwi4o), + REG_SR_BIDIR(f32, any, f32, gOhwi16o), + REG_SR_BIDIR(f32, any, f32, gOIhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIhw16i16o), + REG_SR_BIDIR(f32, any, f32, gIOhw16o16i), + + REG_SR_BIDIR(f32, any, f32, nCdhw4c), + REG_SR_BIDIR(f32, any, f32, nCdhw8c), + REG_SR_BIDIR(f32, any, f32, OIdhw4i4o), + REG_SR_BIDIR(f32, any, f32, Odhwi8o), + REG_SR_BIDIR(f32, any, f32, OIdhw8i8o), + REG_SR_BIDIR(f32, any, f32, OIdhw8o8i), + REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o), + REG_SR_BIDIR(f32, any, f32, gOdhwi8o), + REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o), + REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i), + + REG_SR_BIDIR(f32, any, f32, nCdhw16c), + REG_SR_BIDIR(f32, any, f32, Oidhw4o), + REG_SR_BIDIR(f32, any, f32, Oidhw16o), + REG_SR_BIDIR(f32, any, f32, Odhwi16o), + REG_SR_BIDIR(f32, any, f32, OIdhw16o16i), + REG_SR_BIDIR(f32, any, f32, OIdhw16i16o), + REG_SR_BIDIR(f32, any, f32, gOidhw4o), + REG_SR_BIDIR(f32, any, f32, gOidhw16o), + REG_SR_BIDIR(f32, any, f32, gOdhwi16o), + REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i), + REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o), + */ + + /* fp32: blocked <-> blocked with tail */ + REG_SR_BIDIR(f32, nCw8c, f32, nCw16c), + REG_SR_BIDIR(f32, nChw8c, f32, nChw16c), + REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c), + + /* int: flat <-> blocked with tail */ + /* + REG_SR_BIDIR(f32, any, s32, nChw16c), + REG_SR_BIDIR(f32, any, s8, nChw16c), + REG_SR_BIDIR(f32, any, u8, nChw16c), + REG_SR_BIDIR(s32, any, f32, nChw16c), + REG_SR_BIDIR(s32, any, s32, nChw16c), + REG_SR_BIDIR(s32, any, s8, nChw16c), + REG_SR_BIDIR(s32, any, u8, nChw16c), + REG_SR_BIDIR(s8, any, f32, nChw16c), + REG_SR_BIDIR(s8, any, s32, nChw16c), + REG_SR_BIDIR(s8, any, s8, nChw16c), + REG_SR_BIDIR(s8, any, u8, nChw16c), + REG_SR_BIDIR(u8, any, f32, nChw16c), + REG_SR_BIDIR(u8, any, s32, nChw16c), + REG_SR_BIDIR(u8, any, s8, nChw16c), + REG_SR_BIDIR(u8, any, u8, nChw16c), + + REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i), + REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i), + REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i), + REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i), + REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i), + REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i), + REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i), + REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i), + */ + + /* reference: the last line of defence */ + /* + REG_SR(f32, any, f32, any, fmt_order::any, spec::reference), + REG_SR(f32, any, s32, any, fmt_order::any, spec::reference), + REG_SR(f32, any, s8, any, fmt_order::any, spec::reference), + REG_SR(f32, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(s32, any, f32, any, fmt_order::any, spec::reference), + REG_SR(s32, any, s32, any, fmt_order::any, spec::reference), + REG_SR(s32, any, s8, any, fmt_order::any, spec::reference), + REG_SR(s32, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(s8, any, f32, any, fmt_order::any, spec::reference), + REG_SR(s8, any, s32, any, fmt_order::any, spec::reference), + REG_SR(s8, any, s8, any, fmt_order::any, spec::reference), + REG_SR(s8, any, u8, any, fmt_order::any, spec::reference), + + REG_SR(u8, any, f32, any, fmt_order::any, spec::reference), + REG_SR(u8, any, s32, any, fmt_order::any, spec::reference), + REG_SR(u8, any, u8, any, fmt_order::any, spec::reference), + REG_SR(u8, any, s8, any, fmt_order::any, spec::reference), + */ + + /* eol */ + nullptr, +}; +} + +const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const { + return cpu_reorder_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp new file mode 100644 index 0000000000..1622eb6849 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REORDER_PD_HPP +#define CPU_REORDER_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "reorder_pd.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_reorder_pd_t: public reorder_pd_t { + using reorder_pd_t::reorder_pd_t; + + status_t init() { + const auto &post_ops = attr()->post_ops_; + bool args_ok = IMPLICATION(post_ops.len_ != 0, post_ops.len_ == 1 + && post_ops.entry_[0].kind == primitive_kind::sum); + scratchpad_engine_ = src_engine_; + return args_ok ? status::success : status::unimplemented; + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp new file mode 100644 index 0000000000..f16587b99f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp @@ -0,0 +1,41 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SHUFFLE_PD_HPP +#define CPU_SHUFFLE_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "shuffle_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_shuffle_pd_t: public shuffle_pd_t { + using shuffle_pd_t::shuffle_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp new file mode 100644 index 0000000000..3a39eab974 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp @@ -0,0 +1,45 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SOFTMAX_PD_HPP +#define CPU_SOFTMAX_PD_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "softmax_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_softmax_fwd_pd_t: public softmax_fwd_pd_t { + using softmax_fwd_pd_t::softmax_fwd_pd_t; +}; + +struct cpu_softmax_bwd_pd_t: public softmax_bwd_pd_t { + using softmax_bwd_pd_t::softmax_bwd_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp new file mode 100644 index 0000000000..1ab5d9f174 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu_engine.hpp" + +/* +#include "cpu/ref_sum.hpp" +#include "cpu/simple_sum.hpp" +*/ + +namespace mkldnn { +namespace impl { +namespace cpu { + +using spd_create_f = mkldnn::impl::engine_t::sum_primitive_desc_create_f; + +namespace { +#define INSTANCE(...) __VA_ARGS__::pd_t::create +static const spd_create_f cpu_sum_impl_list[] = { + /* + INSTANCE(simple_sum_t<data_type::f32>), + INSTANCE(ref_sum_t), + */ + nullptr, +}; +#undef INSTANCE +} + +const spd_create_f *cpu_engine_t::get_sum_implementation_list() const { + return cpu_sum_impl_list; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp new file mode 100644 index 0000000000..0965129f9b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp @@ -0,0 +1,39 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SUM_PD_HPP +#define CPU_SUM_PD_HPP + +#include "c_types_map.hpp" +#include "sum_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_sum_pd_t: public sum_pd_t { + using sum_pd_t::sum_pd_t; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp new file mode 100644 index 0000000000..a9810dec28 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp @@ -0,0 +1,372 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#include <cmath> + +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "gemm_utils_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace gemm_utils { +#define BM_NOCOPY_AVX 64 +#define BN_NOCOPY_AVX 48 +#define BK_NOCOPY_AVX 384 +#define BN_LARGE_NOCOPY_AVX 192 +#define BM_SMALL_NOCOPY_AVX 16 +#define BN_SMALL_NOCOPY_AVX 1 +#define BK_SMALL_NOCOPY_AVX 4 +// Determine number of threads for each dimension of a 3-D partitioning +// algorithm based on input parameters +// m/n/k - First/second/third parameter for GEMM +// nthrs - total available number of threads +// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension +// BM/BN/BK - blocking values +void calc_nthr_nocopy_avx(int m, int n, int k, + int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN, + int *BK) +{ + int nthr, nthr_m, nthr_n, nthr_k; + int MB, NB, KB; + + nthr = nthrs; + nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX; + nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX; + nthr_k = 1; + + // Partition along K dimension + // - if threading allows having barriers (e.g. OMP) + // - if there is not enough parallelism along M or N + if (mkldnn_thr_syncable()) { + int nthr_other = nthr_k = 1; + while ((nthr_m * nthr_n * nthr_other < nthr) + && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) { + nthr_other++; + if ((nthr / nthr_other) * nthr_other > 0.9 * nthr) + nthr_k = nthr_other; + } + } + nthr /= nthr_k; + + if (nthr_m == 1) + nthr_n = nthr; + if (nthr_n == 1) + nthr_m = nthr; + + // Simple partition reduction + while (nthr_m * nthr_n > nthr) + if (nthr_m > nthr_n) + nthr_m--; + else + nthr_n--; + while (nthr_m * nthr_n < nthr) + if (nthr_m < nthr_n) + nthr_m++; + else + nthr_n++; + + if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) { + + if (nthr_m <= nthr_n) { + nthr_m = (int)sqrt((double)nthr); + if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX) + nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX; + nthr_n = nthr / nthr_m; + + while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { + nthr_m--; + nthr_n = nthr / nthr_m; + } + } else { + nthr_n = (int)sqrt((double)nthr); + if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX) + nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX; + nthr_m = nthr / nthr_n; + + while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { + nthr_n--; + nthr_m = nthr / nthr_n; + } + } + } + + MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1; + MB -= MB % BM_SMALL_NOCOPY_AVX; + NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1; + NB -= NB % BN_SMALL_NOCOPY_AVX; + KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1; + KB -= KB % BK_SMALL_NOCOPY_AVX; + + if (MB * nthr_m > m) + nthr_m = (m + MB - 1) / MB; + if (NB * nthr_n > n) + nthr_n = (n + NB - 1) / NB; + if (KB * nthr_k > k) + nthr_k = (k + KB - 1) / KB; + + *nthrs_m = nthr_m; + *nthrs_n = nthr_n; + *nthrs_k = nthr_k; + + *BM = MB; + *BN = NB; + *BK = KB; +} +#undef BM_NOCOPY_AVX +#undef BN_NOCOPY_AVX +#undef BK_NOCOPY_AVX +#undef BN_LARGE_NOCOPY_AVX +#undef BM_SMALL_NOCOPY_AVX +#undef BN_SMALL_NOCOPY_AVX +#undef BK_SMALL_NOCOPY_AVX + +#define BM_NOCOPY_AVX512_COMMON 32 +#define BN_NOCOPY_AVX512_COMMON 64 +#define BK_NOCOPY_AVX512_COMMON 192 +#define BN_LARGE_NOCOPY_AVX512_COMMON 192 +#define BM_SMALL_NOCOPY_AVX512_COMMON 16 +#define BN_SMALL_NOCOPY_AVX512_COMMON 1 +#define BK_SMALL_NOCOPY_AVX512_COMMON 4 +// Determine number of threads for each dimension of a 3-D partitioning +// algorithm based on input parameters +// m/n/k - First/second/third parameter for GEMM +// nthrs - total available number of threads +// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension +// BM/BN/BK - blocking values +void calc_nthr_nocopy_avx512_common(int m, + int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, + int *BM, int *BN, int *BK) +{ + int nthr, nthr_m, nthr_n, nthr_k = 1; + int MB, NB, KB; + nthr = nthrs; + + int counter = 0; + float ratio_float = 1.; + int ratio = 1; + nthr = nthrs; + int nthr_m_gt_n; + + // Partition along K dimension + // - if threading allows having barriers (e.g. OMP) + // - if there is not enough parallelism along M or N + if (mkldnn_thr_syncable()) { + if (n <= 2 * BN_NOCOPY_AVX512_COMMON && + m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) { + nthr_k = k / BK_NOCOPY_AVX512_COMMON; + if (nthr_k > nthr / 4) + nthr_k = nthr / 4; + if (nthr_k < 1) + nthr_k = 1; + + while ((nthr_k > 1) && (nthr % nthr_k)) { + nthr_k--; + } + nthr /= nthr_k; + } else { + nthr_k = 1; + } + } + nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON; + nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON; + + if (nthr_m < 1) + nthr_m = 1; + if (nthr_n < 1) + nthr_n = 1; + + nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0; + ratio_float = (float)nthr_m / nthr_n; + + if (nthr_m_gt_n) + ratio = (int)ratio_float; + else + ratio = (int)(1. / ratio_float); + + // scale down nthr_m and nthr_n if they are too large + while (nthr_m * nthr_n > 4 * nthr) { + nthr_m /= 2; + nthr_n /= 2; + } + + if (nthr_m < 1) + nthr_m = 1; + if (nthr_n < 1) + nthr_n = 1; + + // Simple partition reduction + counter = 0; + while (nthr_m * nthr_n > nthr) { + if (nthr_m > nthr_n) { + if (counter < ratio) + nthr_m--; + else { + nthr_n--; + counter = -1; + } + } else { + if (counter < ratio) + nthr_n--; + else { + nthr_m--; + counter = -1; + } + } + counter++; + } + + // Simple partition increment + counter = 0; + while (nthr_m * nthr_n < 0.95 * nthr) { + if (nthr_m > nthr_n) { + if (counter < ratio) + nthr_m++; + else { + nthr_n++; + counter = -1; + } + } else { + if (counter < ratio) + nthr_n++; + else { + nthr_m++; + counter = -1; + } + } + counter++; + } + + // if nothing works out, then this should work + if ((nthr_m * nthr_n > nthr)) { + + if (nthr_m <= nthr_n) { + nthr_m = (int)sqrt((double)nthr); + if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) + / BM_SMALL_NOCOPY_AVX512_COMMON) + nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) + / BM_SMALL_NOCOPY_AVX512_COMMON; + nthr_n = nthr / nthr_m; + + while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { + nthr_m--; + nthr_n = nthr / nthr_m; + } + } else { + nthr_n = (int)sqrt((double)nthr); + if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) + / BN_SMALL_NOCOPY_AVX512_COMMON) + nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) + / BN_SMALL_NOCOPY_AVX512_COMMON; + nthr_m = nthr / nthr_n; + + while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { + nthr_n--; + nthr_m = nthr / nthr_n; + } + } + } + + MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1; + MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON; + NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1; + NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON; + KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1; + KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON; + + if (MB * nthr_m > m) + nthr_m = (m + MB - 1) / MB; + if (NB * nthr_n > n) + nthr_n = (n + NB - 1) / NB; + if (KB * nthr_k > k) + nthr_k = (k + KB - 1) / KB; + + *nthrs_m = nthr_m; + *nthrs_n = nthr_n; + *nthrs_k = nthr_k; + + *BM = MB; + *BN = NB; + *BK = KB; +} +#undef BM_NOCOPY_AVX512_COMMON +#undef BN_NOCOPY_AVX512_COMMON +#undef BK_NOCOPY_AVX512_COMMON +#undef BN_LARGE_NOCOPY_AVX512_COMMON +#undef BM_SMALL_NOCOPY_AVX512_COMMON +#undef BN_SMALL_NOCOPY_AVX512_COMMON +#undef BK_SMALL_NOCOPY_AVX512_COMMON + +// Partition n values as equally as possible among nthr threads +// and set the offset (t_offset) and number of values (t_block) for ithr +// Assumption: 0 <= ithr < nthr +void partition_unit_diff( + int ithr, int nthr, int n, int *t_offset, int *t_block) +{ + int band = n / nthr; + if (band == 0) + band = 1; + int tail = n - band * nthr; + if (tail < 0) + tail = 0; + + if (ithr < tail) { + band++; + *t_offset = band * ithr; + *t_block = band; + } else { + *t_offset = band * ithr + tail; + *t_block = band; + } + + if (*t_offset >= n) { + *t_offset = 0; + *t_block = 0; + } + + if (*t_offset + *t_block > n) { + *t_block = n - *t_offset; + } +} + +// Sum the m*n values from p_src into p_dst, assuming the two-dimensional +// arrays have leading dimensions ld_src and ld_dst, respectively +template<typename data_t> +void sum_two_matrices(int m, int n, + data_t * __restrict p_src, dim_t ld_src, + data_t * __restrict p_dst, dim_t ld_dst) +{ + int i, j; + for (j = 0; j < n; j++) { + for (i = 0; i < m; i++) { + p_dst[i + j * ld_dst] += p_src[i + j * ld_src]; + } + } +} + +template +void sum_two_matrices<float>(int m, int n, + float * __restrict p_src, dim_t ld_src, + float * __restrict p_dst, dim_t ld_dst); + +template +void sum_two_matrices<double>(int m, int n, + double * __restrict p_src, dim_t ld_src, + double * __restrict p_dst, dim_t ld_dst); +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp new file mode 100644 index 0000000000..3352298b4a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp @@ -0,0 +1,72 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_UTILS_HPP +#define GEMM_UTILS_HPP + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace gemm_utils { +// Alias for any dimension related variable. +typedef ptrdiff_t dim_t; + +template <typename T, bool isTransA, bool isTransB> +struct gemm_traits {}; + +template <bool isTransA, bool isTransB> +struct gemm_traits<double, isTransA, isTransB> { + static constexpr int m = 8; + static constexpr int n = 6; + static constexpr int BM = 4032; + static constexpr int BN = isTransA ? 96 : 192; + static constexpr int BK = isTransB ? 96 : 512; +}; + +template <bool isTransA, bool isTransB> +struct gemm_traits<float, isTransA, isTransB> { + static constexpr int m = 16; + static constexpr int n = 6; + static constexpr int BM = 4032; + static constexpr int BN = isTransA ? 96 : 48; + static constexpr int BK = isTransB ? 96 : 256; +}; + +template <typename T> +using unroll_factor = gemm_traits<T, false, false>; + +template <typename data_t> +void sum_two_matrices(int m, int n, + data_t * __restrict p_src, dim_t ld_src, + data_t * __restrict p_dst, dim_t ld_dst); + +void calc_nthr_nocopy_avx512_common(int m, + int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, + int *BM, int *BN, int *BK); + +void calc_nthr_nocopy_avx(int m, int n, int k, + int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN, + int *BK); + +void partition_unit_diff( + int ithr, int nthr, int n, int *t_offset, int *t_block); +}; + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp new file mode 100644 index 0000000000..d7be43e392 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp @@ -0,0 +1,2131 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <cmath> +#include <mutex> + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "ref_gemm_f32.hpp" +#include "gemm_utils_f32.hpp" +#include "jit_avx512_common_gemm_f32.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define CACHE_LINE_SIZE 64 + +#define STACKSIZE get_size_of_abi_save_regs() +#ifdef _WIN32 +#define STACK_K_CAPACITY 32 +#else +#define STACK_K_CAPACITY 2048 +#endif +#define SIZE 4 +#define OFFSET 128 +#define BASE_SHIFT 2 +#define SECOND_FETCH unroll_n +#define UNROLL_M 48 +#define UNROLL_N 8 + +namespace avx512_common_gemm_f32 { +using namespace gemm_utils; + +struct xbyak_gemm : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm) + + xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false, + void *code_ptr = nullptr, + size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + { + using namespace Xbyak; + + enum { ver_avx512_core, ver_avx512_mic } ver = + mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic; + + bool isBeta0 = (beta == 0.0); + bool isBetaN = (!isBeta0 && beta != 1.0); + + // various definitions for convenience + auto ARG_M = abi_param1; + auto ARG_N = abi_param2; + auto K = abi_param3; + auto ARG_ALPHA = abi_param4; +#ifdef _WIN32 + auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE]; + auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE]; + const auto stackOffset = OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE; + auto A = rsi; + auto LDA = rdi; +#else + auto ARG_A = r8; + auto ARG_LDA = r9; + const auto stackOffset = STACKSIZE; + auto A = ARG_A; + auto LDA = ARG_LDA; +#endif + auto ARG_B = ptr[rsp + 8 + stackOffset]; + auto ARG_LDB = ptr[rsp + 16 + stackOffset]; + auto ARG_BETA = ptr[rsp + 24 + stackOffset]; + auto ARG_C = ptr[rsp + 32 + stackOffset]; + auto ARG_LDC = ptr[rsp + 40 + stackOffset]; + auto ARG_BIAS = ptr[rsp + 48 + stackOffset]; + auto ARG_WS = ptr[rsp + 56 + stackOffset]; + + auto B = r11; + auto LDB = rbx; + auto LDC = r13; + auto LL = rax; + auto AO1 = abi_param2; + auto BO1 = abi_param4; + auto BO2 = rbp; + auto CO1 = r14; + auto CO2 = r15; + auto LDB3 = r10; + auto LDA4 = abi_param1; + auto AA = r12; + auto BIAS1 = abi_param1; + + auto M = qword[rsp + 0]; + auto N = qword[rsp + 8]; + auto FLAG = qword[rsp + 16]; + auto I = qword[rsp + 24]; + auto C = qword[rsp + 32]; + auto BIAS = qword[rsp + 40]; + auto ALPHA = qword[rsp + 48]; + auto BETA = qword[rsp + 64]; + auto ORIG_A = qword[rsp + 80]; + auto ORIG_SP = qword[rsp + 120]; + + auto ZSTRIDE = zmm4; + auto VALPHA = zmm6; + auto VBETA = zmm7; + auto VBIAS1 = zmm1; + auto VBIAS2 = zmm2; + auto VBIAS3 = zmm3; + + auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80; + auto PREFETCHSIZEB = 16; + + Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15, + zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24, + zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 }; + + // Function for packing if needed + auto do_pack = [&](int unroll_m) { + Label pack2, pack3, pack4, pack10; + + mov(BO1, A); + lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); + mov(LL, K); + sar(LL, 2); + jle(pack3, T_NEAR); + align(16); + + L(pack2); + if (!isTransA) { + for (int i = 0; i < 4; i++) { + vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); + if (unroll_m > 16) + vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]); + if (unroll_m > 32) + vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]); + add(BO1, LDA); + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE] + | k1, + zmm0); + if (unroll_m > 16) + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE] + | k2, + zmm1); + if (unroll_m > 32) + vmovups(ptr[AO1 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE] + | k3, + zmm2); + } + } else { + for (int i = 0; i < 4; i++) { + kmovw(k4, k1); + vgatherqps(ymm5 | k4, + ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO1 + LDA * 8]); + kshiftrw(k4, k1, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + vshuff64x2(zmm0, zmm5, zmm6, 0x44); + + if (unroll_m > 16) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k2); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k2, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + vshuff64x2(zmm1, zmm5, zmm6, 0x44); + } + + if (unroll_m > 32) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k3); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k3, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + vshuff64x2(zmm2, zmm5, zmm6, 0x44); + } + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + if (unroll_m > 16) + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + if (unroll_m > 32) + vmovups(ptr[AO1 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + add(BO1, 4 * SIZE); + } + add(AO1, unroll_m * 4 * SIZE); + + sub(LL, 1); + jg(pack2, T_NEAR); + align(16); + + L(pack3); + mov(LL, K); + and_(LL, 3); + jle(pack10, T_NEAR); + align(16); + + L(pack4); + if (!isTransA) { + vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]); + if (unroll_m > 16) + vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]); + if (unroll_m > 32) + vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]); + add(BO1, LDA); + } else { + kmovw(k4, k1); + vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO1 + LDA * 8]); + kshiftrw(k4, k1, 8); + vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + vshuff64x2(zmm0, zmm5, zmm6, 0x44); + + if (unroll_m > 16) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k2); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k2, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + vshuff64x2(zmm1, zmm5, zmm6, 0x44); + } + + if (unroll_m > 32) { + lea(BO2, ptr[BO2 + LDA * 8]); + kmovw(k4, k3); + vgatherqps(ymm5 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + kshiftrw(k4, k3, 8); + vgatherqps(ymm6 | k4, + ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 8]); + vshuff64x2(zmm2, zmm5, zmm6, 0x44); + } + add(BO1, SIZE); + } + + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + if (unroll_m > 16) + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE], + zmm1 | k2); + if (unroll_m > 32) + vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE], + zmm2 | k3); + + add(AO1, unroll_m * SIZE); + sub(LL, 1); + jg(pack4, T_NEAR); + align(16); + + L(pack10); + }; + + // Function to update C, covering masking and other considerations + auto update = [&](Zmm reg, bool useCO1, int offset, int mask, + bool useScale = false) { + vmulps(reg, reg, VALPHA); + if (!isBeta0) { + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(zmm0, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0, ptr[CO2 + offset * SIZE]); + break; + case 1: + if (useCO1) + vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]); + break; + case 2: + if (useCO1) + vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]); + break; + case 3: + if (useCO1) + vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]); + else + vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]); + break; + case 1: + if (useCO1) + vmovups(zmm0 | k1 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k1 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + case 2: + if (useCO1) + vmovups(zmm0 | k2 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k2 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + case 3: + if (useCO1) + vmovups(zmm0 | k3 | T_z, + ptr[CO1 + LDC + offset * SIZE]); + else + vmovups(zmm0 | k3 | T_z, + ptr[CO2 + LDC + offset * SIZE]); + break; + } + } + if (!isBetaN) { + vaddps(zmm0, reg, zmm0); + } else { + vfmadd132ps(zmm0, reg, VBETA); + } + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3); + else + vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3); + break; + } + } + } else { + if (!useScale) { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg); + else + vmovups(ptr[CO2 + offset * SIZE], reg); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k1); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k2); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + offset * SIZE], reg | k3); + else + vmovups(ptr[CO2 + offset * SIZE], reg | k3); + break; + } + } else { + switch (mask) { + case 0: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg); + break; + case 1: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1); + break; + case 2: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2); + break; + case 3: + if (useCO1) + vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3); + else + vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3); + break; + } + } + } + vpxorq(reg, reg, reg); + }; + + // Loop with unroll_n - 2 FMAs; called by innerkernel + auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) { + for (int i = 2; i < unroll_n; i++) { + if (ver == ver_avx512_core) { + if (!isTransB) { + switch (i) { + case 2: + vbroadcastss( + zmm3, + ptr[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 3: + vbroadcastss( + zmm3, + ptr[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + case 4: + vbroadcastss(zmm3, + ptr[BO2 + (iteration - OFFSET) * SIZE]); + break; + case 5: + vbroadcastss( + zmm3, + ptr[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + break; + case 6: + vbroadcastss( + zmm3, + ptr[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 7: + vbroadcastss( + zmm3, + ptr[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + } + } else { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } + vfmadd231ps(regs[i], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm3, zmm2); + } else { + if (!isTransB) { + switch (i) { + case 2: + vfmadd231ps(regs[i], zmm0, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 3: + vfmadd231ps(regs[i], zmm0, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + case 4: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + (iteration - OFFSET) * SIZE]); + break; + case 5: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB * 1 + + (iteration - OFFSET) * SIZE]); + break; + case 6: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB * 2 + + (iteration - OFFSET) * SIZE]); + break; + case 7: + vfmadd231ps(regs[i], zmm0, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO2 + LDB3 + + (iteration - OFFSET) * SIZE]); + break; + } + } else { + vfmadd231ps( + regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[i + 8], zmm1, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[i + 16], zmm2, + zword_b[BO1 + (i - OFFSET) * SIZE]); + } + } + } + }; + + // Innerkernel; called by kernel + auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect, + bool isCopy, bool doCPrefetch, bool isUnmasked = true) { + for (int i = 0; i < 8; i++) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET) + * SIZE]); + if (unroll_m >= 32) + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET) + * SIZE]); + if (unroll_m >= 48) + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]); + if (unroll_m >= 32) + prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]); + if (unroll_m >= 48) + prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]); + } + + if (!isDirect) { + if (i != 0) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * i + 1 * 16 + - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * i + 1 * 16 + - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * i + 2 * 16 + - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * i + 2 * 16 + - OFFSET) + * SIZE]); + } + } + } + } else { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (ver == ver_avx512_core) { + if (!isTransB) { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + vfmadd231ps(regs[0], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm3, zmm2); + } else { + if (!isTransB) { + vfmadd231ps(regs[0], zmm0, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm1, + zword_b[BO1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm2, + zword_b[BO1 + (i - OFFSET) * SIZE]); + } else { + vfmadd231ps(regs[0], zmm0, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[0 + 8], zmm1, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[0 + 16], zmm2, + zword_b[BO1 + (0 - OFFSET) * SIZE]); + } + } + + if (unroll_n >= i + 1) { + if (!isTransB) { + switch (i) { + case 0: + prefetcht0( + ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 1: + prefetcht0(ptr[BO1 + LDB + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 2: + prefetcht0(ptr[BO1 + LDB * 2 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 3: + prefetcht0(ptr[BO1 + LDB3 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 4: + prefetcht0( + ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 5: + prefetcht0(ptr[BO2 + LDB + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 6: + prefetcht0(ptr[BO2 + LDB * 2 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + case 7: + prefetcht0(ptr[BO2 + LDB3 + + (PREFETCHSIZEB - OFFSET) * SIZE]); + break; + } + } + } + + if (unroll_n >= 2) { + if (ver == ver_avx512_core) { + if (!isTransB) { + vbroadcastss(zmm3, + ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + vfmadd231ps(regs[1], zmm3, zmm0); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm3, zmm1); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm3, zmm2); + } else { + if (!isTransB) { + vfmadd231ps(regs[1], zmm0, + zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm1, + zword_b[BO1 + LDB * 1 + + (i - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm2, + zword_b[BO1 + LDB * 1 + + (i - OFFSET) * SIZE]); + } else { + vfmadd231ps(regs[1], zmm0, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + if (unroll_m >= 32) + vfmadd231ps(regs[1 + 8], zmm1, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + if (unroll_m >= 48) + vfmadd231ps(regs[1 + 16], zmm2, + zword_b[BO1 + (1 - OFFSET) * SIZE]); + } + } + } + + if (isCopy) { + if (isUnmasked || unroll_m > 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE], + zmm0); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 0 * 16 - OFFSET) + * SIZE], + zmm0 | k1); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(ptr[LDA4 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2); + } else { + vmovups(ptr[LDA4 + + (unroll_m * i + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + } + if (i == 7) + sub(LDA4, -unroll_m * 8 * SIZE); + } + fmaloop(unroll_m, unroll_n, i); + + if (i == 1) { + if (doCPrefetch) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 0 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 0 * 16 * SIZE]); + } + } + if (i == 3) { + if (doCPrefetch && unroll_m >= 32) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 1 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 1 * 16 * SIZE]); + } + if (!isTransA) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 0 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 0 * SIZE]); + } + } + if (i == 5) { + if (doCPrefetch) { + if (unroll_m >= 48) { + if (ver == ver_avx512_core) + prefetchw(ptr[CO2 + 2 * 16 * SIZE]); + else + prefetcht0(ptr[CO2 + 2 * 16 * SIZE]); + } + add(CO2, LDC); + } + if (!isTransA) { + if (unroll_m >= 32) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 1 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 1 * SIZE]); + } + } + } + + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + } // end of for loop + + if (!isTransB) { + sub(BO1, -8 * SIZE); + if (unroll_n >= 4) + sub(BO2, -8 * SIZE); + } + if (!isTransA) { + if (unroll_m >= 48) { + if (ver == ver_avx512_core) + prefetcht0(ptr[AA + 16 * 2 * SIZE]); + else + prefetcht2(ptr[AA + 16 * 2 * SIZE]); + } + lea(AA, ptr[AA + LDA]); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 8 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 8 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 8 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 8 + 2 * 16 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * 8 * SIZE); + } + + sub(LL, 1); + }; + + // Main kernel; does prefetching and calls innerkernel + // After calculating results in registers, writes back to C matrix by + // calling update + auto kernel = [&](int unroll_m, int unroll_n, bool isDirect, + bool isCopy, bool isUnmasked = true) { + if (!isDirect) { + lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]); + } else { + mov(AO1, A); + } + + if (isCopy) { + lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]); + } else { + auto step = ver == ver_avx512_core ? 2 : 4; + lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]); + } + + if (isTransB) { + lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE]); + } + } + } + + Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18; + + mov(LL, K); + sar(LL, 3); + sub(LL, SECOND_FETCH); + jle(kernel13, T_NEAR); + align(16); + + L(kernel12); + innerkernel( + unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked); + jg(kernel12, T_NEAR); + align(16); + + L(kernel13); + lea(CO2, ptr[CO1 + (16 - 1) * SIZE]); + add(LL, unroll_n); + jle(kernel15, T_NEAR); + align(16); + + L(kernel14); + innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked); + jg(kernel14, T_NEAR); + align(16); + + L(kernel15); + mov(LL, K); + and_(LL, 7); + jle(kernel18, T_NEAR); + align(16); + + L(kernel16); + if (isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + (1 * 16 - OFFSET) * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + (2 * 16 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + for (int i = 0; i < unroll_n; i++) { + if (!isTransB) { + switch (i) { + case 0: + vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]); + break; + case 1: + vbroadcastss( + zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + break; + case 2: + vbroadcastss( + zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + break; + case 3: + vbroadcastss( + zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]); + break; + case 4: + vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]); + break; + case 5: + vbroadcastss( + zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + break; + case 6: + vbroadcastss( + zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + break; + case 7: + vbroadcastss( + zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]); + break; + } + } else { + vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]); + } + vfmadd231ps(regs[i], zmm3, zmm0); + if (unroll_m >= 32) { + vfmadd231ps(regs[i + 8], zmm3, zmm1); + } + if (unroll_m >= 48) { + vfmadd231ps(regs[i + 16], zmm3, zmm2); + } + } + + if (isCopy) { + if (isUnmasked || unroll_m > 16) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0); + } else { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE], + zmm0 | k1); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE], + zmm1); + } else { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 16 - OFFSET) + * SIZE], + zmm1 | k2); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE], + zmm2); + } else { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 2 * 16 - OFFSET) + * SIZE], + zmm2 | k3); + } + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isUnmasked || unroll_m > 16) { + vmovups(zmm0, + ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]); + } else { + vmovups(zmm0 | k1 | T_z, + ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]); + } + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) { + vmovups(zmm1, ptr[AO1 + + (unroll_m * 1 + 1 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm1 | k2 | T_z, + ptr[AO1 + + (unroll_m * 1 + 1 * 16 - OFFSET) + * SIZE]); + } + } + if (unroll_m >= 48) { + if (isUnmasked) { + vmovups(zmm2, ptr[AO1 + + (unroll_m * 1 + 2 * 16 - OFFSET) + * SIZE]); + } else { + vmovups(zmm2 | k3 | T_z, + ptr[AO1 + + (unroll_m * 1 + 2 * 16 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + + sub(LL, 1); + jg(kernel16, T_NEAR); + align(16); + + L(kernel18); + vbroadcastss(VALPHA, ALPHA); + + if (isBetaN) { + vbroadcastss(VBETA, BETA); + } + + // Write back the results; all beta cases need to be handled + if (hasBias) { + mov(BIAS1, BIAS); + if (isUnmasked || unroll_m > 16) + vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]); + else + vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]); + if (unroll_m >= 32) { + if (isUnmasked || unroll_m > 32) + vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]); + else + vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]); + } + if (unroll_m >= 48) { + if (isUnmasked) + vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]); + else + vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]); + } + } + + for (int i = 0; i < unroll_n; i++) { + bool useScale = i % 2 != 0; + bool useCO1 = i < 2; + if (i == 2) + lea(CO2, ptr[CO1 + LDC * 2]); + if (i == 4 || i == 6) + lea(CO2, ptr[CO2 + LDC * 2]); + if (hasBias) + vaddps(regs[i], VBIAS1, regs[i]); + if (isUnmasked || unroll_m > 16) { + update(regs[i], useCO1, 0, 0, useScale); + } else { + update(regs[i], useCO1, 0, 1, useScale); + } + if (unroll_m >= 32) { + if (hasBias) + vaddps(regs[i + 8], VBIAS2, regs[i + 8]); + if (isUnmasked || unroll_m > 32) { + update(regs[i + 8], useCO1, 16, 0, useScale); + } else { + update(regs[i + 8], useCO1, 16, 2, useScale); + } + } + if (unroll_m >= 48) { + if (hasBias) + vaddps(regs[i + 16], VBIAS3, regs[i + 16]); + if (isUnmasked) { + update(regs[i + 16], useCO1, 32, 0, useScale); + } else { + update(regs[i + 16], useCO1, 32, 3, useScale); + } + } + } + + switch (unroll_n) { + case 1: add(CO1, LDC); break; + case 2: lea(CO1, ptr[CO1 + LDC * 2]); break; + case 3: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 4: lea(CO1, ptr[CO2 + LDC * 2]); break; + case 5: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 6: lea(CO1, ptr[CO2 + LDC * 2]); break; + case 7: lea(CO1, ptr[CO2 + LDC * 1]); break; + case 8: lea(CO1, ptr[CO2 + LDC * 2]); break; + } + + // Compute next address of B + if (!isTransB) { + lea(rax, ptr[K * SIZE]); + switch (unroll_n) { + case 1: + add(BO1, LDB); + add(BO2, LDB); + break; + case 2: + lea(BO1, ptr[BO1 + LDB * 2]); + lea(BO2, ptr[BO2 + LDB * 2]); + break; + case 3: + lea(BO1, ptr[BO1 + LDB3]); + lea(BO2, ptr[BO2 + LDB3]); + break; + case 4: + lea(BO1, ptr[BO1 + LDB * 4]); + lea(BO2, ptr[BO2 + LDB * 4]); + break; + case 5: + lea(BO1, ptr[BO1 + LDB * 4]); + add(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 4]); + add(BO2, LDB); + break; + case 6: + lea(BO1, ptr[BO1 + LDB3 * 2]); + lea(BO2, ptr[BO2 + LDB3 * 2]); + break; + case 7: + lea(BO1, ptr[BO1 + LDB * 8]); + sub(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 8]); + sub(BO2, LDB); + break; + case 8: + lea(BO1, ptr[BO1 + LDB * 8]); + lea(BO2, ptr[BO2 + LDB * 8]); + break; + } + sub(BO1, rax); + sub(BO2, rax); + } else { + mov(rax, LDB); + imul(rax, K); + sub(BO1, rax); + add(BO1, unroll_n * SIZE); + } + }; + + // High-level subroutine; does packing if needed, then splits C matrix. + // Operates on chunks of 48 rows, 8 columns at a time (handling tail + // cases appropriately by doing 32 or 16 rows, and/or with masking, + // and/or fewer columns). + auto subloop = [&](int unroll_m) { + Label l_subloop_20x[8], l_subloop_mask_20x[8]; + Label l_subloop_30x[8], l_subloop_mask_30x[8]; + + Label subloop11, subloop11mask; + Label subloop30, subloop30mask; + Label subloop31, subloop31mask; + Label subloop96; + Label subloop98, subloop98mask; + Label subloop99; + + // Create mask + mov(BO1, rcx); + mov(rcx, M); + sub(rcx, unroll_m - 16); + mov(CO1, 16); + cmp(rcx, 16); + + cmovg(rcx, CO1); + mov(rax, 1); + sal(rax, cl); + sub(rax, 1); + mov(rcx, 0xffff); + + if (unroll_m == 16) { + kmovw(k1, eax); + } else if (unroll_m == 32) { + kmovw(k1, ecx); + kmovw(k2, eax); + } else { + kmovw(k1, ecx); + kmovw(k2, ecx); + kmovw(k3, eax); + } + mov(rcx, BO1); + + and_(rax, 0xffff); + cmp(rax, 0xffff); + jne(subloop96, T_NEAR); + + if (isTransA) { + do_pack(unroll_m); + } + + mov(CO1, C); + add(C, unroll_m * SIZE); + + mov(BO1, B); + if (!isTransB) { + lea(BO2, ptr[B + LDB * 4]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98, T_NEAR); + + mov(AA, ORIG_A); + lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); + L(subloop98); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(l_subloop_20x[1], T_NEAR); + } + align(16); + + if (!isTransA) { + kernel(unroll_m, UNROLL_N, true, true); + } else { + kernel(unroll_m, UNROLL_N, false, false); + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(l_subloop_20x[1], T_NEAR); + align(16); + + L(subloop11); + kernel(unroll_m, UNROLL_N, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_20x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_20x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, false, false); + jmp(subloop99, T_NEAR); + align(16); + } + + if (!isTransA) { + L(subloop30); + cmp(I, UNROLL_N); + jl(l_subloop_30x[1], T_NEAR); + align(16); + + L(subloop31); + kernel(unroll_m, UNROLL_N, true, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_30x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_30x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, true, false); + if (i < 7) + jmp(subloop99, T_NEAR); + align(16); + } + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop96); + if (isTransA) { + do_pack(unroll_m); + } + + mov(CO1, C); + add(C, unroll_m * SIZE); + mov(BO1, B); + if (!isTransB) { + lea(BO2, ptr[B + LDB * 4]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98mask, T_NEAR); + mov(AA, ORIG_A); + lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]); + L(subloop98mask); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30mask, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30mask, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(l_subloop_mask_20x[1], T_NEAR); + } + align(16); + + if (!isTransA) { + kernel(unroll_m, UNROLL_N, true, true, false); + } else { + kernel(unroll_m, UNROLL_N, false, false, false); + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(l_subloop_mask_20x[1], T_NEAR); + align(16); + + L(subloop11mask); + kernel(unroll_m, UNROLL_N, false, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11mask, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_mask_20x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_mask_20x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, false, false, false); + jmp(subloop99, T_NEAR); + align(16); + } + + if (!isTransA) { + L(subloop30mask); + cmp(I, UNROLL_N); + jl(l_subloop_mask_30x[1], T_NEAR); + align(16); + + L(subloop31mask); + kernel(unroll_m, UNROLL_N, true, false, false); + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31mask, T_NEAR); + align(16); + + for (int i = 1; i <= 7; i++) { + L(l_subloop_mask_30x[i]); + cmp(I, i); + if (i < 7) { + jne(l_subloop_mask_30x[i + 1], T_NEAR); + } else { + jne(subloop99, T_NEAR); + } + kernel(unroll_m, i, true, false, false); + if (i < 7) + jmp(subloop99, T_NEAR); + align(16); + } + } + + L(subloop99); + // Compute address for A + if (!isTransA) { + add(A, unroll_m * SIZE); + } else { + mov(rax, LDA); + imul(rax, rax, unroll_m); + add(A, rax); + } + + // Compute next address of BIAS + if (hasBias) { + add(BIAS, unroll_m * SIZE); + } + }; + + preamble(); + + Label buffer_in_ws, buffer_allocated; + + // Get the registers + mov(B, ARG_B); + mov(LDB, ARG_LDB); + mov(r15, ARG_BETA); + mov(r12, ARG_C); + if (hasBias) + mov(r10, ARG_BIAS); + mov(LDC, ARG_LDC); + mov(rbp, rsp); + + vmovss(xmm0, ptr[ARG_ALPHA]); + vmovss(xmm1, ptr[r15]); + +#if _WIN32 + mov(A, ARG_A); + mov(LDA, ARG_LDA); +#endif + + cmp(K, STACK_K_CAPACITY); + jg(buffer_in_ws, T_NEAR); + + // Create buffer and align to 4kB page + lea(rax, ptr[K * SIZE]); + imul(rax, rax, 0x30); + add(rax, 256); + sub(rsp, rax); + and_(rsp, -PAGE_4K); + jmp(buffer_allocated, T_NEAR); + + L(buffer_in_ws); + mov(rsp, ARG_WS); + + L(buffer_allocated); + + mov(ORIG_SP, rbp); + mov(M, ARG_M); + mov(N, ARG_N); + mov(C, r12); + if (hasBias) + mov(BIAS, r10); + vmovss(ALPHA, xmm0); + vmovss(BETA, xmm1); + sub(A, -OFFSET * SIZE); + sub(B, -OFFSET * SIZE); + mov(ORIG_A, A); + sal(LDA, BASE_SHIFT); + sal(LDB, BASE_SHIFT); + sal(LDC, BASE_SHIFT); + lea(LDB3, ptr[LDB + LDB * 2]); + + if (isTransA) { + vpbroadcastq(zmm2, LDA); + vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE); + mov(rax, -2); + kmovw(k4, eax); + + for (int i = 0; i < 6; i++) { + vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2); + kshiftlw(k4, k4, 1); + } + vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2); + } + + // Check A alignment and leading dimension; take copy-based path as + // needed + mov(rax, LDA); + or_(rax, A); + and_(rax, ver == ver_avx512_core ? 0x07 : 0x3f); + mov(FLAG, rax); + + for (int i = 8; i < 16; i++) { + for (int j = 0; j < 3; j++) { + vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j)); + } + } + + Label main0, main1, main2, main999; + + cmp(M, 32); + jle(main0, T_NEAR); + align(16); + + L(main1); + subloop(48); + sub(M, UNROLL_M); + cmp(M, 32); + jg(main1, T_NEAR); + align(16); + + L(main0); + cmp(M, 16); + jle(main2, T_NEAR); + + subloop(32); + jmp(main999, T_NEAR); + align(16); + + L(main2); + cmp(M, 0); + jle(main999, T_NEAR); + subloop(16); + align(16); + + L(main999); + // Restore original stack + mov(rsp, ORIG_SP); + + vzeroupper(); + postamble(); + + ker_ = this->getCode<ker_t>(); + } + + typedef void (*ker_t)(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws); + + void operator()(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws) const + { + ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws); + } + +private: + ker_t ker_; +}; + +const xbyak_gemm *get_xbyak_gemm( + bool isTransA, bool isTransB, float beta, bool hasBias) { + auto beta_idx = [](float beta) { + return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2); + }; + + // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)] + static xbyak_gemm *kernel_table[2][2][2][3]; + static std::once_flag initialized; + std::call_once(initialized, [=]{ + for (bool isTransA: {false, true}) + for (bool isTransB: {false, true}) + for (bool hasBias: {false, true}) + for (float beta: {0.0f, 1.0f, 2.0f}) { + // nocopy sgemm with bias for beta != 0.0 is not supported + if (hasBias && beta != 0.0) + continue; + kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] = + new xbyak_gemm(isTransA, isTransB, beta, hasBias); + } + }); + + return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)]; +} + +void sgemm_nocopy_driver(const char *transa, + const char *transb, int m, int n, int k, const float *alpha, + const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta, + float *c, dim_t ldc, const float *bias, float *ws) +{ + bool isTransA = (*transa == 'T' || *transa == 't'); + bool isTransB = (*transb == 'T' || *transb == 't'); + + int Bm, sizeM, Bn, sizeN, Bk, sizeK; + + int i, j; + + if ((m <= 0) || (n <= 0)) + return; + + if ((k <= 0) || (alpha[0] == 0.)) { + + if (beta[0] == 0.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] = 0.0; + } else if (beta[0] != 1.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] *= beta[0]; + } + + return; + } + + assert(IMPLICATION(bias != nullptr, *beta == 0.0)); + + // XXX: this happens on every thread... + bool hasBias = (bias != nullptr); + auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias); + auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false); + auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false); + assert(ker_bn && ker_b1 && ker_b0); + + int BM = 4032, BN, BK; + if (mayiuse(avx512_core)) { + BN = isTransA ? 384 : 64; + BK = 384; + } else { + BN = isTransA ? 96 : 64; + BK = isTransB ? 96 : 192; + if (!isTransA && !isTransB) + BK = 128; + } + const float *curA, *curB, *curBias = nullptr; + float *curC; + + for (Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK >= BK * 2) + sizeK = BK; + else { + if (sizeK > BK) + sizeK = (sizeK + 1) / 2; + } + + for (Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM >= BM * 2) + sizeM = BM; + else { + if (sizeM > BM + BM / 2) + sizeM = (sizeM + 1) / 2; + } + + for (Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN >= BN * 2) + sizeN = BN; + else { + if (sizeN > BN + BN / 2) + sizeN = (sizeN + 1) / 2; + } + + if (!isTransA) { + curA = a + Bm + Bk * lda; + } else { + curA = a + Bk + Bm * lda; + } + if (!isTransB) { + curB = b + Bk + Bn * ldb; + } else { + curB = b + Bn + Bk * ldb; + } + curC = c + Bm + (size_t)Bn * ldc; + if (bias != nullptr) { + if (Bk == 0) { + curBias = bias + Bm; + } else { + curBias = nullptr; + } + } + if (Bk == 0) { + if (*beta == 0.0 && bias == nullptr) + (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + else + (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } else { + (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } + } + } + } +} + +} + +mkldnn_status_t jit_avx512_common_gemm_f32( + const char *transa, const char *transb, + const int *p_m, const int *p_n, const int *p_k, const float *p_alpha, + const float *A, const int *p_lda, const float *B, const int *p_ldb, + const float *p_beta, float *C, const int *p_ldc, const float *bias) +{ + using namespace mkldnn::impl::utils; + using namespace avx512_common_gemm_f32; + using namespace gemm_utils; + + if (*p_beta != 0 && bias) + return ref_gemm(transa, transb, p_m, p_n, p_k, + p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias); + + int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + int m = *p_m; + int n = *p_n; + int k = *p_k; + dim_t lda = *p_lda; + dim_t ldb = *p_ldb; + dim_t ldc = *p_ldc; + float beta = *p_beta; + int MB, NB, KB; + + int nthr_m, nthr_n, nthr_k, nthr_mn; + + // Determine threading partitioning + calc_nthr_nocopy_avx512_common( + m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + // May not happen, but just in case + if (nthr < nthr_m * nthr_n * nthr_k) + nthr = nthr_m * nthr_n * nthr_k; + + nthr_mn = nthr_m * nthr_n; + + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; + + float *c_buffers = nullptr; + float *ws_buffers = nullptr; + + if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); + + for (int i = 0; i < nthr; i++) + ompstatus[i * CACHE_LINE_SIZE] = 0; + + c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(float), PAGE_4K); + } + + const size_t ws_elems_per_thr = (size_t)k * 48 + 64; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); + if (k > STACK_K_CAPACITY) { + ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); + } + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int k_from, k_to, myK; + int cbase, ibase; + const float *myA, *myB, *myBias = nullptr; + float *myC = C, myBeta; + float *ws = ws_buffers ? + ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; + dim_t ld = ldc; + + int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k); + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + k_from = KB * (ithr_k); + k_to = KB * (ithr_k + 1); + if (k_to > k) + k_to = k; + myK = k_to - k_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + ibase = (ithr_m + nthr_m * ithr_n) * nthr_k; + + if ((myM > 0) && (myN > 0)) { + + if (*transa == 'N' || *transa == 'n') { + myA = &(A[m_from + k_from * lda]); + } else { + myA = &(A[k_from + m_from * lda]); + } + if (*transb == 'N' || *transb == 'n') { + myB = &(B[k_from + n_from * ldb]); + } else { + myB = &(B[n_from + k_from * ldb]); + } + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + if (bias) + myBias = &(bias[m_from]); + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0; + ld = MB; + myBias = nullptr; + } + + sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, + lda, myB, ldb, &myBeta, myC, ld, myBias, ws); + + if (nthr_k > 1 && !sum_later) + ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1; + } + + if (nthr_k > 1 && !sum_later) { + + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + /* need to wait until main thread finishes */ + while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) { + }; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) { + }; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + + + // handle C summation later + if (nthr_k > 1 && ompstatus[0] == 0) { + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int cbase; + float *myC = C; + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + if (nthr_k > 1) { + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + } + + free(c_buffers); + free(ompstatus_); + free(ws_buffers); + + return mkldnn_success; +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp new file mode 100644 index 0000000000..d581b7fd71 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_GEMM_F32_HPP +#define JIT_AVX512_COMMON_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx512_common_gemm_f32( + const char *transa, const char *transb, const int *M, + const int *N, const int *K, const float *alpha, const float *A, + const int *lda, const float *B, const int *ldb, const float *beta, + float *C, const int *ldc, const float *bias = nullptr); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp new file mode 100644 index 0000000000..60d4220837 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp @@ -0,0 +1,2705 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <cmath> +#include <mutex> + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "ref_gemm_f32.hpp" +#include "gemm_utils_f32.hpp" +#include "jit_avx_gemm_f32.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define CACHE_LINE_SIZE 64 + +#define STACKSIZE get_size_of_abi_save_regs() +#if _WIN32 +#define STACK_K_CAPACITY 128 +#else +#define STACK_K_CAPACITY 8192 +#endif +#define SIZE 4 +#define OFFSET 32 +#define BASE_SHIFT 2 +#define SECOND_FETCH 14 + +namespace avx_gemm_f32 { +using namespace gemm_utils; + +struct xbyak_gemm : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm) + + xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false, + void *code_ptr = nullptr, + size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + { + using namespace Xbyak; + + const bool is_avx2 = mayiuse(avx2); + assert(IMPLICATION(!is_avx2, mayiuse(avx))); + + const int UNROLL_M = is_avx2 ? 16 : 8; + const int UNROLL_N = 6; + + bool isBeta0 = (beta == 0.0); + bool isBetaN = (!isBeta0 && beta != 1.0); + + // various definitions for convenience + auto ARG_M = abi_param1; + auto ARG_N = abi_param2; + auto K = abi_param3; + auto ARG_ALPHA = abi_param4; +#ifdef _WIN32 + auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE]; + auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE]; + const auto stackOffset = OFFSET_SHADOWSPACE + + sizeof(float *) + STACKSIZE; + auto A = rsi; + auto LDA = rdi; +#else + auto ARG_A = r8; + auto ARG_LDA = r9; + const auto stackOffset = STACKSIZE; + auto A = ARG_A; + auto LDA = ARG_LDA; +#endif + auto ARG_B = ptr[rsp + 8 + stackOffset]; + auto ARG_LDB = ptr[rsp + 16 + stackOffset]; + auto ARG_BETA = ptr[rsp + 24 + stackOffset]; + auto ARG_C = ptr[rsp + 32 + stackOffset]; + auto ARG_LDC = ptr[rsp + 40 + stackOffset]; + auto ARG_BIAS = ptr[rsp + 48 + stackOffset]; + auto ARG_WS = ptr[rsp + 56 + stackOffset]; + + auto B = r11; + auto LDB = rbx; + auto LDC = r13; + auto LL = rax; + auto AO1 = abi_param2; + auto BO1 = abi_param4; + auto BO2 = rbp; + auto CO1 = r14; + auto CO2 = r15; + auto LDB3 = r10; + auto LDA4 = abi_param1; + auto AA = r12; + auto BIAS1 = abi_param1; + + auto M = qword[rsp + 0]; + auto N = qword[rsp + 8]; + auto FLAG = qword[rsp + 16]; + auto I = qword[rsp + 24]; + auto C = qword[rsp + 32]; + auto BIAS = qword[rsp + 40]; + auto ALPHA = qword[rsp + 48]; + auto BETA = qword[rsp + 64]; + auto ORIG_A = qword[rsp + 80]; + auto MASK = dword[rsp + 88]; + auto STRIDE = qword[rsp + 120]; + auto ORIG_SP = qword[rsp + 152]; + + auto VALPHA = ymm1; + auto VBETA = ymm2; + auto VMASK = ymm3; + auto VBIAS1 = ymm2; + auto VBIAS2 = ymm4; + + auto PREFETCHSIZEA = 128; + auto PREFETCHSIZEB = (!isTransB) ? -16 : 0; + + // Function for packing if needed + auto do_pack = [&]( + int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { + Label pack2, pack3, pack4, pack10; + + int regIdx; + Reg64 reg; + + mov(BO1, A); + lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); + + if (isTransA) { + lea(BO2, ptr[BO1 + LDA * 4]); + lea(CO1, ptr[LDA + LDA * 2]); + vmovupd(ymm7, STRIDE); + } + + mov(LL, K); + sar(LL, 2); + jle(pack3, T_NEAR); + align(16); + + L(pack2); + if (!isTransA) { + for (int i = 0; i < 4; i++) { + regIdx = (i % 2 == 0) ? 4 : 6; + if (isLoad1Unmasked) { + vmovups(Ymm(regIdx), + ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(Ymm(regIdx), VMASK, + ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m > 8) { + if (isLoad2Unmasked) { + vmovups(Ymm(regIdx + 1), + ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(Ymm(regIdx + 1), VMASK, + ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(BO1, LDA); + + vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + Ymm(regIdx)); + if (unroll_m > 8) { + vmovups(ptr[AO1 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + Ymm(regIdx + 1)); + } + } + + } else { + if (isLoad1Unmasked) { + for (int i = 0; i < 2; i++) { + reg = (i % 2 == 0) ? BO1 : BO2; + vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, + ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[reg + LDA * 2]); + vunpcklps(xmm4, xmm0, xmm1); + vunpckhps(xmm5, xmm0, xmm1); + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, + ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm6, xmm0, xmm1); + vunpckhps(xmm2, xmm0, xmm1); + + vunpcklpd(xmm0, xmm4, xmm6); + vunpckhpd(xmm1, xmm4, xmm6); + vmovups(ptr[AO1 + + (unroll_m * 0 + i * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 1 + i * 4 - OFFSET) + * SIZE], + xmm1); + vunpcklpd(xmm0, xmm5, xmm2); + vunpckhpd(xmm1, xmm5, xmm2); + vmovups(ptr[AO1 + + (unroll_m * 2 + i * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 3 + i * 4 - OFFSET) + * SIZE], + xmm1); + } + } else if (is_avx2) { + for (int i = 0; i < 2; i++) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm0, + ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 0 * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 0 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO1 + LDA * 4]); + + for (int i = 0; i < 2; i++) { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 1 * 4 - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 1 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + } else { + vxorps(xmm4, xmm4, xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + + auto el_cp = [&](int section, int ld_step) { + RegExp src_addr = section == 0 ? BO1 : BO2; + if (ld_step == 1 || ld_step == 2) + src_addr = src_addr + LDA * ld_step; + else if (ld_step == 3) + src_addr = src_addr + CO1; + src_addr = src_addr - OFFSET * SIZE; + + vmovups(Xmm(ld_step % 2), ptr[src_addr]); + RegExp dst_addr = AO1 + + (ld_step + section * 4 - OFFSET) * SIZE; + for (int off = 0; off < 4; ++off) + pextrd(ptr[dst_addr + unroll_m * off * SIZE], + Xmm(ld_step % 2), off); + }; + + Label l_end; + el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR); + el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR); + el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR); + el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR); + el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR); + el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR); + el_cp(1, 2); + L(l_end); + + lea(BO2, ptr[BO2 + LDA * 4]); + } + + if (unroll_m >= 16) { + assert(is_avx2); + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm4, xmm0, xmm1); + vunpckhps(xmm5, xmm0, xmm1); + vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovups(xmm1, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + if (i == 0) + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(xmm6, xmm0, xmm1); + vunpckhps(xmm2, xmm0, xmm1); + + vunpcklpd(xmm0, xmm4, xmm6); + vunpckhpd(xmm1, xmm4, xmm6); + vmovups(ptr[AO1 + + (unroll_m * 0 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 1 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm1); + vunpcklpd(xmm0, xmm5, xmm2); + vunpckhpd(xmm1, xmm5, xmm2); + vmovups(ptr[AO1 + + (unroll_m * 2 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * 3 + (i + 2) * 4 + - OFFSET) + * SIZE], + xmm1); + } + } else { + for (int i = 0; i < 2; i++) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 2 * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 2 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + + for (int i = 0; i < 2; i++) { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm0, + ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE], + xmm4); + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + + ((2 * i + 1) - OFFSET) * SIZE], + xmm4); + + vmovups(ptr[AO1 + + (unroll_m * (2 * i) + 3 * 4 + - OFFSET) + * SIZE], + xmm0); + vmovups(ptr[AO1 + + (unroll_m * (2 * i + 1) + 3 * 4 + - OFFSET) + * SIZE], + xmm1); + } + + lea(BO2, ptr[BO2 + LDA * 4]); + } + } + add(BO1, (4 * SIZE)); + } + + add(AO1, unroll_m * 4 * SIZE); + sub(LL, 1); + jg(pack2, T_NEAR); + align(16); + + L(pack3); + mov(LL, K); + and_(LL, 3); + jle(pack10, T_NEAR); + align(16); + + L(pack4); + if (!isTransA) { + if (isLoad1Unmasked) { + vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m > 8) { + if (isLoad2Unmasked) { + vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm5, VMASK, + ptr[BO1 + (1 + 8 - OFFSET) * SIZE]); + } + } + add(BO1, LDA); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm4); + if (unroll_m > 8) { + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE], + ymm5); + } + } else { + if (isLoad1Unmasked) { + for (int i = 0; i < 2; i++) { + reg = (i % 2 == 0) ? BO1 : BO2; + vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, + ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[reg + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE], + xmm1); + + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, + ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE], + xmm1); + } else if (is_avx2) { + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE], + xmm1); + + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO2 + LDA * 4]); + vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE], + xmm1); + } else { + vxorps(xmm4, xmm4, xmm4); + lea(BO2, ptr[BO1 + LDA * 4]); + + auto el_cp = [&](int section, int ld_step) { + RegExp src_addr = section == 0 ? BO1 : BO2; + if (ld_step == 1 || ld_step == 2) + src_addr = src_addr + LDA * ld_step; + else if (ld_step == 3) + src_addr = src_addr + CO1; + src_addr = src_addr - OFFSET * SIZE; + + vmovss(xmm1, ptr[src_addr]); + RegExp dst_addr = AO1 + + (ld_step + section * 4 - OFFSET) * SIZE; + movss(ptr[dst_addr], xmm1); + }; + + Label l_end; + el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR); + el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR); + el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR); + el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR); + el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR); + el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR); + el_cp(1, 2); + L(l_end); + + lea(BO2, ptr[BO2 + LDA * 4]); + } + + if (unroll_m >= 16) { + assert(is_avx2); + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), + ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + } else { + vmovaps(xmm4, xmm3); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + lea(BO2, ptr[BO2 + LDA * 4]); + } + vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE], + xmm1); + + if (isLoad2Unmasked) { + for (int i = 0; i < 2; i++) { + vmovss(Xmm(i + 1), + ptr[BO2 + (0 * 8 - OFFSET) * SIZE]); + vmovss(xmm0, ptr[BO2 + LDA * 1 + + (0 * 8 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDA * 2]); + vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0)); + } + vunpcklpd(xmm1, xmm1, xmm2); + } else { + vextractf128(xmm4, ymm3, 1); + vgatherqps(xmm1, + ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE], + xmm4); + } + vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE], + xmm1); + } + add(BO1, SIZE); + } + + add(AO1, unroll_m * SIZE); + sub(LL, 1); + jg(pack4, T_NEAR); + align(16); + + L(pack10); + }; + + // Fused multiply add; may become one or two instructions + auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2, + bool overWrite = false) { + if (useFma) { + if (is_avx2) { + vfmadd231ps(reg2, reg1, reg0); + } else { + assert(UNROLL_M == 8); + auto tent_vreg = overWrite ? reg1 : ymm1; + vmulps(tent_vreg, reg1, reg0); + vaddps(reg2, reg2, tent_vreg); + } + } else { + if (!overWrite) { + vmulps(ymm15, reg1, reg0); + vaddps(reg2, reg2, ymm15); + } else { + vmulps(reg1, reg1, reg0); + vaddps(reg2, reg2, reg1); + } + } + }; + + // Inner kernel with k=8 + auto innerkernel8 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + if (!isDirect) { + prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + + for (int i = 0; i < 8; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (i == 0) { + if (!isTransB) { + prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]); + } + } + if (unroll_n >= 2) { + if (!isTransB) { + if (i == 1) { + prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + if (i == 7) { + sub(LDA4, -unroll_m * 8 * SIZE); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (i == 7) { + if (!isTransB) { + sub(BO1, -8 * SIZE); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + if (i == 3) { + prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + if (i == 4) { + prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + if (i == 5) { + prefetcht0( + ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + + if (i == 0) { + if (unroll_m >= 4) { + if (!isDirect) { + prefetcht0( + ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 1 || i == 2) { + if (unroll_m >= 8) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 3 || i == 4 || i == 5 || i == 6) { + if (unroll_m >= 16) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 7) { + if (!isTransB) { + if (unroll_n >= 4) { + sub(BO2, -8 * SIZE); + } + } + if (!isTransA) { + prefetcht2(ptr[AA]); + lea(AA, ptr[AA + LDA]); + } + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps( + ymm0, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } + } + } + } + + if (!isDirect) { + sub(AO1, -unroll_m * 8 * SIZE); + } + sub(LL, 1); + + }; + + // Inner kernel with k=4 + auto innerkernel4 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + if (!isDirect) { + prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + + for (int i = 0; i < 4; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (i == 0) { + if (!isTransB) { + prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]); + } + } + if (unroll_n >= 2) { + if (!isTransB) { + if (i == 1) { + prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * i + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + if (i == 3) { + sub(LDA4, -unroll_m * 4 * SIZE); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (i == 7) { + if (!isTransB) { + sub(BO1, -8 * SIZE); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + if (i == 3) { + prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + if (i == 4) { + prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + if (i == 5) { + prefetcht0( + ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + if (isTransB) { + prefetcht0(ptr[BO1 + BO2]); + add(BO1, LDB); + } + + if (i == 0) { + if (unroll_m >= 4) { + if (!isDirect) { + prefetcht0( + ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 1 || i == 2) { + if (unroll_m >= 8) { + if (!isDirect) { + prefetcht0(ptr[AO1 + + (PREFETCHSIZEA + (2 + 2 * i) * 8) + * SIZE]); + } else { + prefetcht0(ptr[AO1 + LDA4]); + } + } + } + if (i == 3) { + if (!isTransB) { + sub(BO1, -4 * SIZE); + if (unroll_n >= 4) { + sub(BO2, -4 * SIZE); + } + } + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps( + ymm0, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * (i + 1) + 1 * 8 + - OFFSET) + * SIZE]); + } + } + } + } + + if (!isDirect) { + sub(AO1, -unroll_m * 4 * SIZE); + } + + }; + + // Inner kernel with k=2 + auto innerkernel2 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12, + Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17, + Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22, + Ymm reg23) { + + Ymm fmareg; + + for (int i = 0; i < 2; i++) { + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg00 : reg12; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg06 : reg18; + fma(useFma, ymm1, ymm2, fmareg); + } + if (unroll_n >= 2) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg01 : reg13; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg07 : reg19; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + if (i == 2) { + prefetcht0( + ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]); + } + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg02 : reg14; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg08 : reg20; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg03 : reg15; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg09 : reg21; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg04 : reg16; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg10 : reg22; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fmareg = (i % 2 == 0) ? reg05 : reg17; + fma(useFma, ymm0, ymm2, fmareg); + if (unroll_m >= 16) { + fmareg = (i % 2 == 0) ? reg11 : reg23; + fma(useFma, ymm1, ymm2, fmareg); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE], + ymm1); + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + + (unroll_m * 1 + 0 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + + (unroll_m * 1 + 0 * 8 - OFFSET) + * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + } + + }; + + // Inner kernel with k=1 + auto innerkernel1 = [&](int unroll_m, int unroll_n, + bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect, + bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02, + Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07, + Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) { + + if (isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + (1 * 8 - OFFSET) * SIZE]); + } + } + add(AO1, LDA); + } + + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg00); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg06); + } + + if (unroll_n >= 2) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg01); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg07); + } + } + + if (unroll_n >= 3) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg02); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg08); + } + } + + if (unroll_n >= 4) { + if (!isTransB) { + vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg03); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg09); + } + } + + if (unroll_n >= 5) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg04); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg10); + } + } + + if (unroll_n >= 6) { + if (!isTransB) { + vbroadcastss( + ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]); + } else { + vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]); + } + fma(useFma, ymm0, ymm2, reg05); + if (unroll_m >= 16) { + fma(useFma, ymm1, ymm2, reg11); + } + } + + if (isCopy) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE], + ymm0); + if (unroll_m >= 16) { + vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE], + ymm1); + } + sub(LDA4, -unroll_m * SIZE); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 1 + 1 * 8 - OFFSET) + * SIZE]); + } + } + sub(AO1, -unroll_m * SIZE); + } + + if (!isTransB) { + sub(BO1, -SIZE); + if (unroll_n >= 4) { + sub(BO2, -SIZE); + } + } else { + add(BO1, LDB); + } + + }; + + // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as + // appropriate + // After calculating results in registers, writes back to C matrix + auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma, + Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6), + Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9), + Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12), + Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15), + Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6), + Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9), + Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12), + Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) { + if (!isDirect) { + lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]); + } else { + mov(AO1, A); + } + + if (isCopy) { + lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]); + } else { + lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]); + } + + if (isTransB) { + lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]); + lea(BO2, ptr[BO2 + LDB * 2]); + } + + if (!isDirect) { + if (isLoad1Unmasked) { + vmovups(ymm0, + ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]); + } else { + vmaskmovps(ymm0, VMASK, + ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]); + } + if (unroll_m >= 16) { + if (isLoad2Unmasked) { + vmovups(ymm1, ptr[AO1 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE]); + } else { + vmaskmovps(ymm1, VMASK, + ptr[AO1 + + (unroll_m * 0 + 1 * 8 - OFFSET) + * SIZE]); + } + } + } + + for (int i = 4; i < 10; i++) { + vxorps(Ymm(i), Ymm(i), Ymm(i)); + vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6)); + } + + mov(LL, K); + sar(LL, 3); + + Label kernel12, kernel13, kernel14, kernel15; + Label kernel16, kernel17, kernel18; + + sub(LL, SECOND_FETCH); + jle(kernel13, T_NEAR); + align(16); + + L(kernel12); + innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + jg(kernel12, T_NEAR); + align(16); + + L(kernel13); + prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 2) + prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]); + if (unroll_n >= 3) + prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 4) + prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]); + if (unroll_n >= 5) + prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]); + if (unroll_n >= 6) + prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]); + + add(LL, SECOND_FETCH); + jle(kernel15, T_NEAR); + align(16); + + L(kernel14); + innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + jg(kernel14, T_NEAR); + align(16); + + L(kernel15); + test(K, 4); + jle(kernel16, T_NEAR); + innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + + L(kernel16); + test(K, 2); + jle(kernel17, T_NEAR); + innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12, + reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20, + reg21, reg22, reg23); + align(16); + + L(kernel17); + if (unroll_m == 16) { + if (unroll_n <= 3) { + vaddps(reg00, reg00, reg12); + vaddps(reg01, reg01, reg13); + vaddps(reg02, reg02, reg14); + vaddps(reg06, reg06, reg18); + vaddps(reg07, reg07, reg19); + vaddps(reg08, reg08, reg20); + } + } + + if (unroll_m <= 8) { + vaddps(reg00, reg00, reg12); + vaddps(reg01, reg01, reg13); + vaddps(reg02, reg02, reg14); + vaddps(reg03, reg03, reg15); + vaddps(reg04, reg04, reg16); + vaddps(reg05, reg05, reg17); + } + + test(K, 1); + jle(kernel18, T_NEAR); + innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04, + reg05, reg06, reg07, reg08, reg09, reg10, reg11); + align(16); + + L(kernel18); + vbroadcastss(VALPHA, ALPHA); + + if (isBetaN) { + vbroadcastss(VBETA, BETA); + } + + // Write back the results; all beta and bias cases need to be + // handled + switch (unroll_n) { + case 1: mov(rax, LDC); break; + case 2: lea(rax, ptr[LDC * 2]); break; + case 3: lea(rax, ptr[LDC + LDC * 2]); break; + case 4: lea(rax, ptr[LDC + LDC * 4]); break; + case 5: + lea(rax, ptr[LDC * 4]); + add(rax, LDC); + break; + case 6: + lea(rax, ptr[LDC + LDC * 2]); + add(rax, rax); + break; + } + + if (hasBias) { + mov(BIAS1, BIAS); + if (isLoad1Unmasked) { + vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]); + } else { + vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]); + } + } + + for (int i = 0; i < unroll_n; i++) { + vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA); + if (!isBeta0) { + if (isLoad1Unmasked) { + switch (i) { + case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break; + case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break; + case 2: + vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]); + break; + case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break; + case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break; + case 5: + vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]); + break; + case 1: + vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]); + break; + case 2: + vmaskmovps( + ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]); + break; + case 3: + vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]); + break; + case 4: + vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]); + break; + case 5: + vmaskmovps( + ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]); + break; + } + } + + if (!isBetaN) { + vaddps(Ymm(i + 4), ymm0, Ymm(i + 4)); + } else { + fma(useFma, VBETA, ymm0, Ymm(i + 4), true); + } + } + if (hasBias) { + vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4)); + } + if (isLoad1Unmasked) { + switch (i) { + case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break; + case 1: + vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4)); + break; + case 2: + vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4)); + break; + case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break; + case 4: + vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4)); + break; + case 5: + vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4)); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 1: + vmaskmovps( + ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 2: + vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK, + Ymm(i + 4)); + break; + case 3: + vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 4: + vmaskmovps( + ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4)); + break; + case 5: + vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK, + Ymm(i + 4)); + break; + } + } + + if (unroll_m >= 16) { + // Re-use ymm4 (VBIAS2) + if (i == 0) { + if (hasBias) { + if (isLoad1Unmasked) { + vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]); + } else { + vmaskmovps( + VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]); + } + } + } + vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA); + if (!isBeta0) { + if (isLoad2Unmasked) { + switch (i) { + case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break; + case 1: + vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]); + break; + case 2: + vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]); + break; + case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break; + case 4: + vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]); + break; + case 5: + vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]); + break; + case 1: + vmaskmovps( + ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]); + break; + case 2: + vmaskmovps(ymm0, VMASK, + ptr[CO1 + LDC * 2 + 8 * SIZE]); + break; + case 3: + vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]); + break; + case 4: + vmaskmovps( + ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]); + break; + case 5: + vmaskmovps(ymm0, VMASK, + ptr[CO2 + LDC * 2 + 8 * SIZE]); + break; + } + } + if (!isBetaN) { + vaddps(Ymm(i + 10), ymm0, Ymm(i + 10)); + } else { + fma(useFma, VBETA, ymm0, Ymm(i + 10), true); + } + } + if (hasBias) { + vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10)); + } + if (isLoad2Unmasked) { + switch (i) { + case 0: + vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10)); + break; + case 1: + vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10)); + break; + case 2: + vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10)); + break; + case 3: + vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10)); + break; + case 4: + vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10)); + break; + case 5: + vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10)); + break; + } + } else { + switch (i) { + case 0: + vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10)); + break; + case 1: + vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 2: + vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 3: + vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10)); + break; + case 4: + vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + case 5: + vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK, + Ymm(i + 10)); + break; + } + } + } + if (i == 2) + add(CO1, rax); + } + if (unroll_n >= 4) { + add(CO2, rax); + } + + // Compute next address of B + if (!isTransB) { + lea(rax, ptr[K * SIZE]); + switch (unroll_n) { + case 1: + add(BO1, LDB); + add(BO2, LDB); + break; + case 2: + lea(BO1, ptr[BO1 + LDB * 2]); + lea(BO2, ptr[BO2 + LDB * 2]); + break; + case 3: + lea(BO1, ptr[BO1 + LDB3]); + lea(BO2, ptr[BO2 + LDB3]); + break; + case 4: + lea(BO1, ptr[BO1 + LDB * 4]); + lea(BO2, ptr[BO2 + LDB * 4]); + break; + case 5: + lea(BO1, ptr[BO1 + LDB * 4]); + add(BO1, LDB); + lea(BO2, ptr[BO2 + LDB * 4]); + add(BO2, LDB); + break; + case 6: + lea(BO1, ptr[BO1 + LDB3 * 2]); + lea(BO2, ptr[BO2 + LDB3 * 2]); + break; + } + sub(BO1, rax); + sub(BO2, rax); + } else { + mov(rax, LDB); + imul(rax, K); + sub(BO1, rax); + add(BO1, unroll_n * SIZE); + } + }; + + auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, true); + }; + + auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9), + Ymm(13), Ymm(14), Ymm(15)); + }; + + auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15)); + }; + + auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy); + }; + + auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy); + }; + + auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy, + bool useFma = true) { + kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7), + Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14), + Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9), + Ymm(13), Ymm(14), Ymm(15)); + }; + + auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked, + bool isLoad2Unmasked, bool isDirect, bool isCopy) { + kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked, + isDirect, isCopy, false); + }; + + // High-level subroutine; does packing if needed, then splits C matrix. + // Operates on chunks of 16 rows, 6 columns at a time (handling tail + // cases appropriately). + // Masking is used for tail cases where M is not divisible by 8. + auto subloop = [&]( + int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) { + if (isTransA) { + do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked); + } + + Label subloop11, subloop11mask; + Label subloop20, subloop21, subloop22, subloop23; + Label subloop24, subloop25; + Label subloop30, subloop31, subloop32, subloop33; + Label subloop34, subloop35; + Label subloop98, subloop98mask; + Label subloop99, subloop99mask; + + mov(CO1, C); + lea(CO2, ptr[CO1 + LDC * 2]); + add(CO2, LDC); + add(C, unroll_m * SIZE); + mov(BO1, B); + if (!isTransB) { + lea(BO2, qword[B + LDB3]); + } + + if (!isTransA) { + lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]); + cmp(M, UNROLL_M); + jg(subloop98, T_NEAR); + + mov(AA, ORIG_A); + lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]); + L(subloop98); + } + + mov(LL, N); + mov(I, LL); + if (!isTransA) { + // If N is too small, skip copy operation + cmp(LL, UNROLL_N * 3); + jle(subloop30, T_NEAR); + + // If A is not aligned to cache line + cmp(FLAG, 0); + je(subloop30, T_NEAR); + } else { + cmp(LL, UNROLL_N); + jl(subloop20, T_NEAR); + } + align(16); + + if (!isTransA) { + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, true); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, true); + } + } else { + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } + } + + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jl(subloop20, T_NEAR); + align(16); + + L(subloop11); + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, false, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop11, T_NEAR); + align(16); + + L(subloop20); + cmp(I, 1); + jne(subloop21, T_NEAR); + if (unroll_m == 16) { + kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop21); + cmp(I, 2); + jne(subloop22, T_NEAR); + if (unroll_m == 16) { + kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop22); + cmp(I, 3); + jne(subloop23, T_NEAR); + if (unroll_m == 16) { + kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop23); + cmp(I, 4); + jne(subloop24, T_NEAR); + if (unroll_m == 16) { + kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop24); + cmp(I, 5); + jne(subloop99, T_NEAR); + if (unroll_m == 16) { + kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + false, false); + } else { + kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false, + false); + } + jmp(subloop99, T_NEAR); + align(16); + + if (!isTransA) { + L(subloop30); + cmp(I, UNROLL_N); + jl(subloop25, T_NEAR); + align(16); + + L(subloop31); + if (unroll_m == 16) { + kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, false); + } else { + kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, + isLoad2Unmasked, true, false); + } + sub(I, UNROLL_N); + cmp(I, UNROLL_N); + jge(subloop31, T_NEAR); + align(16); + + L(subloop25); + cmp(I, 1); + jne(subloop32, T_NEAR); + if (unroll_m == 16) { + kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop32); + cmp(I, 2); + jne(subloop33, T_NEAR); + if (unroll_m == 16) { + kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop33); + cmp(I, 3); + jne(subloop34, T_NEAR); + if (unroll_m == 16) { + kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop34); + cmp(I, 4); + jne(subloop35, T_NEAR); + if (unroll_m == 16) { + kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + jmp(subloop99, T_NEAR); + align(16); + + L(subloop35); + cmp(I, 5); + jne(subloop99, T_NEAR); + if (unroll_m == 16) { + kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } else { + kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, + true, false); + } + align(16); + } + + L(subloop99); + // Compute address for A + if (!isTransA) { + add(A, unroll_m * SIZE); + } else { + mov(rax, LDA); + imul(rax, rax, unroll_m); + add(A, rax); + } + + // Compute next address of BIAS + if (hasBias) { + add(BIAS, unroll_m * SIZE); + } + }; + + preamble(); + + Label buffer_in_ws, buffer_allocated; + + // Get the registers + mov(B, ARG_B); + mov(LDB, ARG_LDB); + mov(r15, ARG_BETA); + mov(r12, ARG_C); + if (hasBias) + mov(r10, ARG_BIAS); + mov(LDC, ARG_LDC); + mov(rbp, rsp); + + vmovss(xmm0, ptr[ARG_ALPHA]); + vmovss(xmm1, ptr[r15]); + +#if _WIN32 + mov(A, ARG_A); + mov(LDA, ARG_LDA); +#endif + + cmp(K, STACK_K_CAPACITY); + jg(buffer_in_ws, T_NEAR); + + // Create buffer and align to 4kB page + lea(rax, ptr[K * SIZE]); + sal(rax, 4); + add(rax, 256); + sub(rsp, rax); + and_(rsp, -PAGE_4K); + jmp(buffer_allocated, T_NEAR); + + L(buffer_in_ws); + mov(rsp, ARG_WS); + + L(buffer_allocated); + + mov(ORIG_SP, rbp); + mov(M, ARG_M); + mov(N, ARG_N); + mov(C, r12); + if (hasBias) + mov(BIAS, r10); + vmovss(ALPHA, xmm0); + vmovss(BETA, xmm1); + sub(A, -OFFSET * SIZE); + sub(B, -OFFSET * SIZE); + mov(ORIG_A, A); + sal(LDA, BASE_SHIFT); + sal(LDB, BASE_SHIFT); + sal(LDC, BASE_SHIFT); + lea(LDB3, ptr[LDB + LDB * 2]); + + for (int i = 0; i < 8; i++) { + mov(dword[rsp + 88 + i * 4], i); + } + + if (isTransA && is_avx2) { + movq(xmm0, LDA); + vpbroadcastq(ymm1, xmm0); + vinsertf128(ymm0, ymm0, xmm0, 1); + vpermilpd(ymm0, ymm0, 5); + vpaddq(ymm1, ymm1, ymm1); + vperm2f128(ymm1, ymm1, ymm1, 8); + vpaddq(ymm0, ymm0, ymm1); + vmovups(STRIDE, ymm0); + } + + // Check A alignment and leading dimension; take copy-based path as + // needed + mov(rax, LDA); + or_(rax, A); + and_(rax, 0x1f); + mov(FLAG, rax); + + Label main0, main1, main2, main3, main999; + + cmp(M, UNROLL_M); + jl(main0, T_NEAR); + align(16); + + L(main1); + subloop(UNROLL_M, true, true); + sub(M, UNROLL_M); + cmp(M, UNROLL_M); + jge(main1, T_NEAR); + align(16); + + L(main0); + cmp(M, 0); + jle(main999, T_NEAR); + + if (UNROLL_M > 8) { + cmp(M, 8); + jle(main2, T_NEAR); + + sub(M, 8); + vbroadcastss(VMASK, M); + vpcmpgtd(VMASK, VMASK, MASK); + + subloop(16, true, false); + jmp(main999, T_NEAR); + align(16); + + L(main2); + cmp(M, 8); + jne(main3, T_NEAR); + subloop(8, true, true); + jmp(main999, T_NEAR); + } + + align(16); + + L(main3); + vbroadcastss(VMASK, M); + if (is_avx2) { + vpcmpgtd(VMASK, VMASK, MASK); + } else { + auto xmask = Xmm(VMASK.getIdx()); + auto xmm_tmp = xmm4; + + vextractf128(xmm_tmp, VMASK, 1); + vpcmpgtd(xmask, xmask, MASK); + vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4 + vinsertf128(VMASK, VMASK, xmm_tmp, 1); + } + subloop(8, false, false); + align(16); + + L(main999); + // Restore original stack + mov(rsp, ORIG_SP); + + vzeroupper(); + postamble(); + + ker_ = this->getCode<ker_t>(); + } + + typedef void (*ker_t)(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws); + + void operator()(dim_t m, dim_t n, dim_t k, + const float *alpha, const float *a, dim_t lda, + const float *b, dim_t ldb, const float *beta, float *c, + dim_t ldc, const float *bias, float *ws) const + { + ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws); + } + +private: + ker_t ker_; +}; + +const xbyak_gemm *get_xbyak_gemm( + bool isTransA, bool isTransB, float beta, bool hasBias) { + auto beta_idx = [](float beta) { + return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2); + }; + + // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)] + static xbyak_gemm *kernel_table[2][2][2][3]; + static std::once_flag initialized; + std::call_once(initialized, [=]{ + for (bool isTransA: {false, true}) + for (bool isTransB: {false, true}) + for (bool hasBias: {false, true}) + for (float beta: {0.0f, 1.0f, 2.0f}) { + // nocopy sgemm with bias for beta != 0.0 is not supported + if (hasBias && beta != 0.0) + continue; + kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] = + new xbyak_gemm(isTransA, isTransB, beta, hasBias); + } + }); + + return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)]; +} + +void sgemm_nocopy_driver(const char *transa, + const char *transb, int m, int n, int k, const float *alpha, + const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta, + float *c, dim_t ldc, const float *bias, float *ws) +{ + bool isTransA = (*transa == 'T' || *transa == 't'); + bool isTransB = (*transb == 'T' || *transb == 't'); + + int Bm, sizeM, Bn, sizeN, Bk, sizeK; + + int i, j; + + if ((m <= 0) || (n <= 0)) + return; + + if ((k <= 0) || (alpha[0] == 0.)) { + + if (beta[0] == 0.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] = 0.0; + } else if (beta[0] != 1.) { + for (j = 0; j < n; j++) + for (i = 0; i < m; i++) + c[i + j * ldc] *= beta[0]; + } + + return; + } + + assert(IMPLICATION(bias != nullptr, *beta == 0.0)); + + // XXX: this happens on every thread... + bool hasBias = (bias != nullptr); + auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias); + auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false); + auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false); + assert(ker_bn && ker_b1 && ker_b0); + + int BM = 4032; + int BN = isTransA ? 96 : 48; + int BK = isTransB ? 96 : 256; + const float *curA, *curB, *curBias = nullptr; + float *curC; + + for (Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK >= BK * 2) + sizeK = BK; + else { + if (sizeK > BK) + sizeK = (sizeK + 1) / 2; + } + + for (Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM >= BM * 2) + sizeM = BM; + else { + if (sizeM > BM + BM / 2) + sizeM = (sizeM + 1) / 2; + } + + for (Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN >= BN * 2) + sizeN = BN; + else { + if (sizeN > BN + BN / 2) + sizeN = (sizeN + 1) / 2; + } + + if (!isTransA) { + curA = a + Bm + Bk * lda; + } else { + curA = a + Bk + Bm * lda; + } + if (!isTransB) { + curB = b + Bk + Bn * ldb; + } else { + curB = b + Bn + Bk * ldb; + } + curC = c + Bm + (size_t)Bn * ldc; + if (bias != nullptr) { + if (Bk == 0) { + curBias = bias + Bm; + } else { + curBias = nullptr; + } + } + if (Bk == 0) { + if (*beta == 0.0 && bias == nullptr) + (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + else + (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } else { + (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK, + alpha, curA, lda, curB, ldb, beta, curC, ldc, + curBias, ws); + } + } + } + } +} + +} + +mkldnn_status_t jit_avx_gemm_f32( + const char *transa, const char *transb, + const int *p_m, const int *p_n, const int *p_k, const float *p_alpha, + const float *A, const int *p_lda, const float *B, const int *p_ldb, + const float *p_beta, float *C, const int *p_ldc, const float *bias) +{ + using namespace mkldnn::impl::utils; + using namespace avx_gemm_f32; + using namespace gemm_utils; + + if (*p_beta != 0 && bias) + return ref_gemm(transa, transb, p_m, p_n, p_k, + p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias); + + int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + int m = *p_m; + int n = *p_n; + int k = *p_k; + dim_t lda = *p_lda; + dim_t ldb = *p_ldb; + dim_t ldc = *p_ldc; + float beta = *p_beta; + int MB, NB, KB; + + int nthr_m, nthr_n, nthr_k, nthr_mn; + + // Determine threading partitioning + calc_nthr_nocopy_avx( + m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + // May not happen, but just in case + if (nthr < nthr_m * nthr_n * nthr_k) + nthr = nthr_m * nthr_n * nthr_k; + + nthr_mn = nthr_m * nthr_n; + + unsigned char * ompstatus_ = nullptr; + unsigned char volatile *ompstatus = nullptr; + + float *c_buffers = nullptr; + float *ws_buffers = nullptr; + + if (nthr_k > 1) { + ompstatus_ = (unsigned char *) malloc( + nthr * CACHE_LINE_SIZE, + CACHE_LINE_SIZE); + ompstatus = (unsigned char volatile *) ompstatus_; + assert(ompstatus); + + for (int i = 0; i < nthr; i++) + ompstatus[i * CACHE_LINE_SIZE] = 0; + + c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(float), PAGE_4K); + } + + const size_t ws_elems_per_thr = (size_t)k * 16 + 64; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K); + if (k > STACK_K_CAPACITY) { + ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K); + } + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int k_from, k_to, myK; + int cbase, ibase; + const float *myA, *myB, *myBias = nullptr; + float *myC = C, myBeta; + float *ws = ws_buffers ? + ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0; + dim_t ld = ldc; + + int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k); + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + k_from = KB * (ithr_k); + k_to = KB * (ithr_k + 1); + if (k_to > k) + k_to = k; + myK = k_to - k_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + ibase = (ithr_m + nthr_m * ithr_n) * nthr_k; + + if ((myM > 0) && (myN > 0)) { + + if (*transa == 'N' || *transa == 'n') { + myA = &(A[m_from + k_from * lda]); + } else { + myA = &(A[k_from + m_from * lda]); + } + if (*transb == 'N' || *transb == 'n') { + myB = &(B[k_from + n_from * ldb]); + } else { + myB = &(B[n_from + k_from * ldb]); + } + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + if (bias) + myBias = &(bias[m_from]); + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0; + ld = MB; + myBias = nullptr; + } + + sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA, + lda, myB, ldb, &myBeta, myC, ld, myBias, ws); + + if (nthr_k > 1 && !sum_later) + ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1; + } + + if (nthr_k > 1 && !sum_later) { + + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + /* need to wait until main thread finishes */ + while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) { + }; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) { + }; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + + // handle C summation later + if (nthr_k > 1 && ompstatus[0] == 0) { + + parallel_nd(nthr, [&](const int ithr) { + int ithr_m, ithr_n, ithr_k, ithr_mn; + int m_from, m_to, myM; + int n_from, n_to, myN; + int cbase; + float *myC = C; + + if (ithr < nthr_m * nthr_n * nthr_k) { + + ithr_mn = ithr % nthr_mn; + ithr_m = ithr_mn % nthr_m; + ithr_n = ithr_mn / nthr_m; + ithr_k = ithr / nthr_mn; + + /* swap ithr_k for performance improvement */ + if (ithr_k == 0) + ithr_k = nthr_k - 1; + else if (ithr_k == nthr_k - 1) + ithr_k = 0; + + m_from = MB * (ithr_m); + m_to = MB * (ithr_m + 1); + if (m_to > m) + m_to = m; + myM = m_to - m_from; + + n_from = NB * (ithr_n); + n_to = NB * (ithr_n + 1); + if (n_to > n) + n_to = n; + myN = n_to - n_from; + + cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + if (nthr_k > 1) { + // sum matrices partitioned along K dimension + int n1, n2; + + partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2); + + if (ithr_k > 0) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1) + + (dim_t)n1 * MB; + + /* my cache is hot */ + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + + for (int ik = 1; ik < nthr_k; ++ik) { + if (ik != ithr_k) { + + myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1) + + (dim_t)n1 * MB; + + sum_two_matrices(myM, n2, myC, MB, + &C[m_from + (n_from + n1) * ldc], ldc); + } + } + } + } + }); + } + + + free(c_buffers); + free(ompstatus_); + free(ws_buffers); + + return mkldnn_success; +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp new file mode 100644 index 0000000000..aabf520a3c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp @@ -0,0 +1,37 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX_GEMM_F32_HPP +#define JIT_AVX_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx_gemm_f32( + const char *transa, const char *transb, const int *M, + const int *N, const int *K, const float *alpha, const float *A, + const int *lda, const float *B, const int *ldb, const float *beta, + float *C, const int *ldc, const float *bias = nullptr); + + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp new file mode 100644 index 0000000000..5147885a89 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp @@ -0,0 +1,346 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "gemm_utils_f32.hpp" +#include "ref_gemm_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace gemm_utils; + +namespace { + +template <typename data_t> +void copy_A( + bool isTransA, int K, const data_t *A, const dim_t lda, data_t *ws) { + for (int k = 0; k < K; k++) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor<data_t>::m; i++) { + ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda]; + } + ws += unroll_factor<data_t>::m; + } +} + +template <typename data_t, bool isTransA, bool isTransB> +void kernel_mxn(int K, const data_t *A, const dim_t lda, + const data_t *B, const dim_t ldb, data_t *C, const dim_t ldc, + const data_t alpha, const data_t beta) { + data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] = + { static_cast<data_t>(0.) }; + for (int k = 0; k < K; k++) { + for (int j = 0; j < unroll_factor<data_t>::n; j++) { + data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb]; + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor<data_t>::m; i++) { + data_t a = isTransA ? A[i * lda + k] : A[i + lda * k]; + c[i + unroll_factor<data_t>::m * j] += a * b; + } + } + } + for (int j = 0; j < unroll_factor<data_t>::n; j++) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < unroll_factor<data_t>::m; i++) { + C[i + j * ldc] = (beta == static_cast<data_t>(0.)) + ? alpha * c[i + unroll_factor<data_t>::m * j] + : alpha * c[i + unroll_factor<data_t>::m * j] + + beta * C[i + j * ldc]; + } + } +} + +template <typename data_t, bool isTransA, bool isTransB> +void block_ker(const int M, const int N, const int K, + const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, + data_t *C, const dim_t ldc, const data_t alpha, const data_t beta, + data_t *ws, bool do_copy) { + int Nu = rnd_dn(N, unroll_factor<data_t>::n); + int Mu = rnd_dn(M, unroll_factor<data_t>::m); + for (int i = 0; i < Mu; i += unroll_factor<data_t>::m) { + for (int j = 0; j < Nu; j += unroll_factor<data_t>::n) { + const data_t *b = isTransB ? &B[j] : &B[j * ldb]; + const data_t *a = isTransA ? &A[i * lda] : &A[i]; + if (do_copy) { + if (j == 0) { + copy_A<data_t>(isTransA, K, a, lda, ws); + } + kernel_mxn<data_t, false, isTransB>( + K, ws, unroll_factor<data_t>::m, b, ldb, + &C[i + j * ldc], ldc, alpha, beta); + } else { + kernel_mxn<data_t, isTransA, isTransB>( + K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta); + } + } + } + // tail processing + for (int i = 0; i < M; i++) { + for (int j = Nu; j < N; j++) { + data_t c = beta == static_cast<data_t>(0.) + ? static_cast<data_t>(0.) + : beta * C[i + j * ldc]; + for (int p = 0; p < K; p++) { + data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; + c += alpha * a * b; + } + C[i + j * ldc] = c; + } + } + for (int i = Mu; i < M; i++) { + for (int j = 0; j < Nu; j++) { + data_t c = beta == static_cast<data_t>(0.) + ? static_cast<data_t>(0.) + : beta * C[i + j * ldc]; + for (int p = 0; p < K; p++) { + data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; + data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; + c += alpha * a * b; + } + C[i + j * ldc] = c; + } + } +} + +template <typename data_t, bool isTransA, bool isTransB> +void gemm_ithr(const int M, const int N, const int K, const data_t alpha, + const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, + const data_t beta, data_t *C, const dim_t ldc, bool do_copy, + data_t *ws) { + constexpr int BM = gemm_traits<data_t, isTransA, isTransB>::BM; + constexpr int BN = gemm_traits<data_t, isTransA, isTransB>::BN; + constexpr int BK = gemm_traits<data_t, isTransA, isTransB>::BK; + + const data_t *curA; + const data_t *curB; + data_t *curC; + + if ((M <= 0) || (N <= 0)) + return; + + if ((K <= 0) || (alpha == static_cast<data_t>(0))) { + dim_t MN = N * M; + if (beta == static_cast<data_t>(0.)) { + for (dim_t j = 0; j < MN; j++) + C[j] = static_cast<data_t>(0.); + } else if (beta != static_cast<data_t>(1.)) { + for (dim_t j = 0; j < MN; j++) + C[j] *= beta; + } + return; + } + + for (int Bk = 0; Bk < K; Bk += BK) { + int kb = nstl::min(K - Bk, BK); + for (int Bm = 0; Bm < M; Bm += BM) { + int mb = nstl::min(M - Bm, BM); + for (int Bn = 0; Bn < N; Bn += BN) { + int nb = nstl::min(N - Bn, BN); + curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda; + curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb; + curC = C + Bm + Bn * ldc; + if (Bk == 0) { + block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda, + curB, ldb, curC, ldc, alpha, beta, ws, do_copy); + } else { + block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda, + curB, ldb, curC, ldc, alpha, static_cast<data_t>(1.0), + ws, do_copy); + } + } + } + } +} + +} + +template <typename data_t> +mkldnn_status_t ref_gemm( + const char *transa_, const char *transb_, const int *M_, + const int *N_, const int *K_, const data_t *alpha_, const data_t *A, + const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_, + data_t *C, const int *ldc_, const data_t *bias) { + + bool isTransA = (*transa_ == 'T' || *transa_ == 't'); + bool isTransB = (*transb_ == 'T' || *transb_ == 't'); + const int M = *M_, N = *N_, K = *K_; + const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_; + const data_t alpha = *alpha_, beta = *beta_; + + int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads(); + int nthr_m, nthr_n, nthr_k; + int MB, NB, KB; + // thread balancing over M, N, K & size of blocking dimensions + calc_nthr_nocopy_avx( + M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1)); + + data_t *c_buffers = nullptr; + data_t *ws_buffers = nullptr; + if (nthr_k > 1) { + c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB + * sizeof(data_t), PAGE_4K); + if (!c_buffers) { + nthr_k = 1; + KB = K; + } + } + + bool do_copy = (NB / unroll_factor<data_t>::n > 3); + const int nthr_mn = nthr_m * nthr_n; + const int nthr = nthr_mn * nthr_k; + const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m; + const size_t ws_size_per_thr + = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K); + if (do_copy) { + ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K); + if (!ws_buffers) + do_copy = false; + } + + auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N, + int ithr) { + from = NB * (ithr); + to = NB * (ithr + 1); + if (to > N) + to = N; + myN = to - from; + }; + + parallel_nd(nthr, [&](const int ithr) { + int ithr_mn = ithr % nthr_mn; + int ithr_m = ithr_mn % nthr_m; + int ithr_n = ithr_mn / nthr_m; + int ithr_k = ithr / nthr_mn; + + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + data_t *ws = do_copy + ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t) + : nullptr; + + int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0, + k_from = 0, k_to = 0, myK = 0; + + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); + get_thr_block(k_from, k_to, myK, KB, K, ithr_k); + + if (myM > 0 && myN > 0) { + data_t myBeta, *myC; + dim_t ld; + if (ithr_k == 0) { + myC = &(C[m_from + n_from * ldc]); + myBeta = beta; + ld = ldc; + } else { + myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1); + myBeta = 0.0f; + ld = MB; + } + const data_t *myA = isTransA + ? &(A[k_from + m_from * lda]) + : &(A[m_from + k_from * lda]); + const data_t *myB = isTransB + ? &(B[n_from + k_from * ldb]) + : &(B[k_from + n_from * ldb]); + + if (!isTransA) { + if (!isTransB) { + gemm_ithr<data_t, false, false>(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } else { + gemm_ithr<data_t, false, true>(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } + } else { + if (!isTransB) { + gemm_ithr<data_t, true, false>(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } else { + gemm_ithr<data_t, true, true>(myM, myN, myK, alpha, myA, + lda, myB, ldb, myBeta, myC, ld, do_copy, ws); + } + } + } + }); + + if (nthr_k > 1) { + parallel_nd(nthr, [&](const int ithr) { + int ithr_mn = ithr % nthr_mn; + int ithr_m = ithr_mn % nthr_m; + int ithr_k = ithr / nthr_mn; + int ithr_n = ithr_mn / nthr_m; + + int n_from = 0, n_to = 0, myN = 0; + int m_from = 0, m_to = 0, myM = 0; + + int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); + + get_thr_block(n_from, n_to, myN, NB, N, ithr_n); + get_thr_block(m_from, m_to, myM, MB, M, ithr_m); + + // sum matrices partitioned along K dimension + int offset = 0, block = 0; + gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset, + &block); + for (int ik = 1; ik < nthr_k; ++ik) { + data_t *myC = c_buffers + + MB * ((dim_t)NB * (cbase + ik - 1) + offset); + + gemm_utils::sum_two_matrices(myM, block, myC, MB, + &C[m_from + (n_from + offset) * ldc], ldc); + } + }); + } + + if (bias) { + parallel_nd(N, M, [&](int i, int j) { + C[i*ldc + j] += bias[j]; + }); + } + + free(ws_buffers); + free(c_buffers); + + return mkldnn_success; +} + +template mkldnn_status_t ref_gemm<float>( + const char *transa_, const char *transb_, + const int *M_, const int *N_, const int *K_, const float *alpha_, + const float *A, const int *lda_, const float *B, const int *ldb_, + const float *beta_, float *C, const int *ldc_, const float *bias); + +template mkldnn_status_t ref_gemm<double>( + const char *transa_, const char *transb_, + const int *M_, const int *N_, const int *K_, const double *alpha_, + const double *A, const int *lda_, const double *B, const int *ldb_, + const double *beta_, double *C, const int *ldc_, const double *bias); +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp new file mode 100644 index 0000000000..7c90ba6277 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_GEMM_F32_HPP +#define REF_GEMM_F32_HPP + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <typename data_t> +mkldnn_status_t ref_gemm(const char *transa, const char *transb, const int *M, + const int *N, const int *K, const data_t *alpha, const data_t *A, + const int *lda, const data_t *B, const int *ldb, const data_t *beta, + data_t *C, const int *ldc, const data_t *bias); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp new file mode 100644 index 0000000000..3dbe07d743 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn.h" + +#include "mkldnn_traits.hpp" +#include "nstl.hpp" + +#include "jit_generator.hpp" + +#include "gemm.hpp" + +#include "f32/jit_avx512_common_gemm_f32.hpp" +#include "f32/jit_avx_gemm_f32.hpp" +#include "f32/ref_gemm_f32.hpp" + +#include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp" +#include "s8x8s32/simple_gemm_s8s8s32.hpp" +#include "s8x8s32/ref_gemm_s8x8s32.hpp" + +#include "os_blas.hpp" + +/* USE_MKL USE_CBLAS effect + * ------- --------- ------ + * yes yes use Intel(R) MKL CBLAS + * yes no use jit + * no yes system-dependent CBLAS + * no no use jit + */ + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t check_gemm_input(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const int *lda, + const int *ldb, const int *ldc, const float *alpha, const float *beta, + const bool with_bias) { + if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta)) + return mkldnn_invalid_arguments; + if (with_bias && *beta != 0) + return mkldnn_unimplemented; + bool consistency = true + && utils::one_of(*transa, 'T', 't', 'N', 'n') + && utils::one_of(*transb, 'T', 't', 'N', 'n') + && *M >= 0 + && *N >= 0 + && *K >= 0; + + if (!consistency) + return mkldnn_invalid_arguments; + bool isTransA = utils::one_of(*transa, 'T', 't'); + bool isTransB = utils::one_of(*transb, 'T', 't'); + int nrowA = isTransA ? *K : *M; + int nrowB = isTransB ? *N : *K; + consistency = true + && *lda >= nstl::max(1, nrowA) + && *ldb >= nstl::max(1, nrowB) + && *ldc >= nstl::max(1, *M); + if (!consistency) + return mkldnn_invalid_arguments; + + return mkldnn_success; +} + +mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc, + const char *transa, const char *transb, const int *M, const int *N, + const int *K, const int *lda, const int *ldb, const int *ldc, + const float *alpha, const float *beta, const bool with_bias) { + if (offsetc == nullptr) + return mkldnn_invalid_arguments; + if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r')) + return mkldnn_invalid_arguments; + + return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha, + beta, with_bias); +} + +mkldnn_status_t extended_sgemm(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const float *alpha, + const float *A, const int *lda, const float *B, const int *ldb, + const float *beta, float *C, const int *ldc, + const float *bias, const bool force_jit_gemm) { + mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K, + lda, ldb, ldc, alpha, beta, bias != nullptr); + if (status != mkldnn_success) + return status; + +#ifdef USE_CBLAS + if (!force_jit_gemm) { + bool trA = *transa == 't' || *transa == 'T'; + bool trB = *transb == 't' || *transb == 'T'; + CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans; + cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB, + *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc); + + if (bias) { + // Add bias if necessary (bias is applied to columns of C) + cblas_int incx = 1, incy = 1; + parallel_nd(*N, [&](int n) { + ptrdiff_t offset = (ptrdiff_t)n * (*ldc); + cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy); + }); + } + return mkldnn_success; + } +#endif + + if (mayiuse(avx512_common)) + return jit_avx512_common_gemm_f32(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); + else if (mayiuse(avx)) + return jit_avx_gemm_f32(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); + else + return ref_gemm<float>(transa, transb, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); +} + +template <typename b_dt> +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co) { + mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, + M, N, K, LDA, LDB, LDC, alpha, beta, false); + if (status != mkldnn_success) + return status; + + if (*M == 0 || *N == 0 || *K == 0) + return mkldnn_success; + +#if USE_MKL_IGEMM + bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); + bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); + bool AisN = (*transa == 'N' || *transa == 'n'); + bool BisN = (*transb == 'N' || *transb == 'n'); + + if (data_traits<b_dt>::data_type == data_type::u8) { + CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans; + CBLAS_OFFSET Cblas_offsetc = + OCisR + ? CblasRowOffset + : OCisC + ? CblasColOffset + : CblasFixOffset; + cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc, + *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo, + *beta, C, *LDC, co); + return mkldnn_success; + } else { + assert(data_traits<b_dt>::data_type == data_type::s8); + // TODO CBLAS implementation of gemm_s8s8s32 goes here. + // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo + if (utils::everyone_is(0, *ao, *bo)) { + return simple_gemm_s8s8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, + C, LDC, co); + } else { + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } +#else + cpu_isa_t isa = isa_any; + if (mayiuse(avx512_core_vnni)) { + isa = avx512_core_vnni; + } else if (mayiuse(avx512_core)) { + isa = avx512_core; + } + + if (data_traits<b_dt>::data_type == data_type::u8) { + switch (isa) { + case avx512_core: + case avx512_core_vnni: + return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta, + C, LDC, co); + default: + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } else { + assert(data_traits<b_dt>::data_type == data_type::s8); + // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo + if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni)) + && *ao == 0 && *bo == 0) { + return simple_gemm_s8s8s32(transa, transb, offsetc, M, + N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, + C, LDC, co); + } else { + return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, + alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); + } + } +#endif +} + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const int8_t *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +template +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const uint8_t *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +} +} +} + +using namespace mkldnn::impl; +using namespace mkldnn::impl::cpu; + +mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb, + const int64_t *M, const int64_t *N, const int64_t *K, const float *alpha, + const float *A, const int64_t *lda, const float *B, const int64_t *ldb, + const float *beta, float *C, const int64_t *ldc) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + + return extended_sgemm(transa, transb, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, B, &ldb_s32, beta, C, &ldc_s32); +} + +mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb, + const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, + const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, + const uint8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, + int32_t *C, const int64_t *ldc, const int32_t *co) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); +} + +mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb, + const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, + const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, + const int8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, + int32_t *C, const int64_t *ldc, const int32_t *co) { + int M_s32 = (int)*M; + int N_s32 = (int)*N; + int K_s32 = (int)*K; + int lda_s32 = (int)*lda; + int ldb_s32 = (int)*ldb; + int ldc_s32 = (int)*ldc; + + return gemm_s8x8s32<int8_t>(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, + alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp new file mode 100644 index 0000000000..dc15ff7130 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp @@ -0,0 +1,58 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_HPP +#define GEMM_HPP + +#include "mkldnn_types.h" +#include "os_blas.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t extended_sgemm(const char *transa, const char *transb, + const int *M, const int *N, const int *K, const float *alpha, + const float *A, const int *lda, const float *B, const int *ldb, + const float *beta, float *C, const int *ldc, + const float *bias = nullptr, bool force_jit_gemm = false); + +template <typename b_dt> +mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *lda, const int8_t *ao, + const b_dt *B, const int *ldb, const int8_t *bo, const float *beta, + int32_t *c, const int *ldc, const int32_t *co); + +#ifdef USE_CBLAS +#define GEMM_IMPL_STR "gemm:blas" +#else +#define GEMM_IMPL_STR "gemm:jit" +#endif + +#if USE_MKL_IGEMM +#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:blas" +#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:blas" +#else +#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:jit" +#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:jit" +#endif + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp new file mode 100644 index 0000000000..4d34ede0bd --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp @@ -0,0 +1,86 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef OS_BLAS_HPP +#define OS_BLAS_HPP + +/** \file + * Common stuff respecting USE_MKL and USE_CBLAS compile flags + * + * USE_MKL USE_CBLAS effect + * ------- --------- ------ + * yes yes normal compile: jit *may* be preferred over Intel(R) MKL CBLAS + * yes no jit calls OK; assert if cblas is ever called + * no yes system-dependent CBLAS + * no no gemm convolution (or other blas) N/A; create stubs + */ + +#if defined(USE_MKL) + +#include "mkl_version.h" + +#define USE_MKL_PACKED_GEMM (INTEL_MKL_VERSION >= 20190001) +#define USE_MKL_IGEMM \ + (INTEL_MKL_VERSION >= 20180000 && __INTEL_MKL_BUILD_DATE >= 20170628) + +#include "mkl_cblas.h" +#if !defined(USE_CBLAS) +#define cblas_sgemm(...) assert(!"CBLAS is unavailable") +#endif + +#else /* defined(USE_MKL) */ + +#define USE_MKL_PACKED_GEMM 0 +#define USE_MKL_IGEMM 0 + +#if defined(_SX) +/* TODO: _SX should also define USE_CBLAS in case the later is available */ +extern "C" { +#include "cblas.h" // CHECK: does SX also have a fortran API sgemm? +} + +#elif defined(USE_CBLAS) +#include "cblas.h" // Maybe a system/cmake cblas works for you? +#else +/* put the stubs to make a code compilable but not workable */ +#define cblas_sgemm(...) assert(!"CBLAS is unavailable") +#endif /* defined(_SX) */ + +#endif /* defined(USE_MKL) */ + +namespace mkldnn { +namespace impl { +namespace cpu { + +#if defined(USE_MKL) && defined(USE_CBLAS) +typedef MKL_INT cblas_int; + +#elif defined(USE_CBLAS) +typedef int cblas_int; + +#if defined(_SX) +/* this cblas.h is peculiar... */ +typedef CBLAS_ORDER CBLAS_LAYOUT; +#endif +#endif + +} +} +} + +#endif /* OS_BLAS_HPP */ + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp new file mode 100644 index 0000000000..dde72f4a17 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp @@ -0,0 +1,206 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_H +#define COMMON_H + +#define GEMM_CODE_SIZE (4096L * 32) + +#define AVX512_UNROLL_M 48 +#define AVX512_UNROLL_N 8 +#define AVX512_UNROLL_K 1 +#define AVX512_BM 9984 +#define AVX512_BN 384 +#define AVX512_BK 768 +#define AVX512_BK_VNNI 1536 +#define AVX512_BK_TRADITIONAL 384 +#define AVX512_BLOCKING_SMALL_K 48 +#define AVX512_BN_SMALL_K 24 + + +#define PAGESIZE 4096 + +#define PADD_BYTESIZE_ONPAGE(x, size) (((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE +#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size)) / size + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +enum { + PARTITION_1D_ROW, + PARTITION_1D_COL, + PARTITION_2D_COL_MAJOR, + PARTITION_2D = PARTITION_2D_COL_MAJOR, +}; + +enum { + COPY_NONE, + COPY_A, +}; + +enum { + NO_OFFSET, + FIX_OFFSET, + COL_OFFSET, + ROW_OFFSET, +}; + +// Alias for any dimension related variable. +typedef long long int dim_t; + +typedef struct { + // Interface arguments. + int transa, transb, offsetc; + dim_t m, n, k; + dim_t lda, ldb, ldc; + const int8_t *a; + const uint8_t *b; + int32_t *c; + const float *alpha, *beta; + + int8_t ao, bo; + const int32_t *co; + + // Kernel parameters. + dim_t um, un, uk, bm, bn, bk; + dim_t bn_small_k, bk_traditional, blocking_small_k; + + int (*copyA)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + int (*copyB)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + int (*kernel)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + int (*kernel_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + // Gemv kernels + void (*gemv_s8u8s32_kernel)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + + void (*gemv_u8s8s32_kernel)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + // Gemv parameters + int swap; + +} blas_t; + + +class jit_avx512_core_u8_copy_an_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern); + + public: + jit_avx512_core_u8_copy_an_kern(); +}; + +class jit_avx512_core_u8_copy_at_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern); + + public: + jit_avx512_core_u8_copy_at_kern(); +}; + +class jit_avx512_core_u8_copy_bn_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern); + + public: + jit_avx512_core_u8_copy_bn_kern(); +}; + +class jit_avx512_core_u8_copy_bt_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern); + + public: + jit_avx512_core_u8_copy_bt_kern(); +}; + +class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern); + + public: + jit_avx512_core_u8_copy_sum_an_kern(); +}; + +class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern); + + public: + jit_avx512_core_u8_copy_sum_at_kern(); +}; + +class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern); + + public: + jit_avx512_core_u8_copy_sum_bn_kern(); +}; + +class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern); + + public: + jit_avx512_core_u8_copy_sum_bt_kern(); +}; + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp new file mode 100644 index 0000000000..db9dd9ef97 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp @@ -0,0 +1,28 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg); +int gemv_threading_driver(blas_t *arg); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp new file mode 100644 index 0000000000..e4b8e1cde2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp @@ -0,0 +1,1409 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <cstdint> +#include <mutex> + +#include "common.hpp" +#include "mkldnn_types.h" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_gemm_s8u8s32.hpp" +#include "jit_avx512_core_gemm_s8u8s32_kern.hpp" +#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp" +#include "gemv.hpp" + +#if defined(_MSC_VER) +#include <malloc.h> +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef struct { + int nthrs_m, nthrs_n; + int partition; + int copy_type; +} blas_thread_t; + +static inline void round_to_nearest(int32_t *rounded_val, double fp_val) { + if (fp_val >= 0.) { + fp_val += 0.5; + if (fp_val > INT32_MAX) { + fp_val = INT32_MAX; + } + } else { + fp_val -= 0.5; + if (fp_val < INT32_MIN) { + fp_val = INT32_MIN; + } + } + *rounded_val = (int32_t) fp_val; +} + +static inline void add_results(const dim_t m, const dim_t n, const dim_t k, + const float alpha, const float beta, const int32_t *c_partial_sum, + const dim_t ldcp, int32_t *c_data, const dim_t ldc, + const int32_t *a_row_sum, const int32_t *b_col_sum, const int8_t ao, + const int8_t bo, const int32_t *co, const int offsetc) +{ + for (dim_t j = 0; j < n; ++j) { + for (dim_t i = 0; i < m; ++i) { + int32_t ctemp = c_partial_sum[i + j * ldcp]; + + if (alpha == 1.0f) { + if (beta == 0.0f) { + c_data[i + j * ldc] = ctemp; + } else { + double c_float = (double) beta + * (double) c_data[i + j * ldc]; + c_float += (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } else if (alpha == -1.0f) { + if (beta == 0.0f) { + c_data[i + j * ldc] = -ctemp; + } else { + double c_float = (double) beta + * (double) c_data[i + j * ldc]; + c_float -= (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } else { + if (beta == 0.0f) { + double c_float = alpha * (double) ctemp; + round_to_nearest(&c_data[i + j * ldc], c_float); + } else { + double c_float = alpha * (double) ctemp + + beta * (double) c_data[i + j * ldc]; + round_to_nearest(&c_data[i + j * ldc], c_float); + } + } + + if (offsetc == FIX_OFFSET) { + c_data[i + j * ldc] += co[0]; + } else if (offsetc == ROW_OFFSET) { + c_data[i + j * ldc] += co[j]; + } else if (offsetc == COL_OFFSET) { + c_data[i + j * ldc] += co[i]; + } + } + } +} + +// TODO Find a better place for those functions. +static inline dim_t ld_padd(const dim_t x) +{ + return ((x + ((2048 / sizeof(int32_t)) - 1)) / (2048 / sizeof(int32_t))) + * (2048 / sizeof(int32_t)) + (64 / sizeof(int32_t)); +} + +void igemm_inner_kernel(const dim_t m, const dim_t n, const dim_t k, + const int8_t *a, const uint8_t *b, float beta, int32_t *c, + const dim_t ldc, const int32_t *a_row_sum, const int32_t *b_col_sum, + const int32_t *co, const int offsetc, const blas_t *arg) +{ + int8_t ao = arg->ao; + int8_t bo = arg->bo; + int32_t co_0 = (offsetc == NO_OFFSET)? 0 : co[0]; + + // Since m and n are limited by blocking, stack overflow may not happen; + // it's up to 32kB +#if !defined(_MSC_VER) + int32_t col_offset[m]; + int32_t row_offset[n]; +#else + int32_t *col_offset = (int32_t *) _alloca(sizeof(*col_offset) * m); + int32_t *row_offset = (int32_t *) _alloca(sizeof(*row_offset) * n); +#endif + + int col_req = 0; + int row_req = 0; + + if ((bo != 0) || (offsetc == COL_OFFSET)) + col_req = 1; + if ((ao != 0) || (offsetc == ROW_OFFSET)) + row_req = 1; + + // It needs one of colum or row offsets, but it doesn't need both + if (((ao != 0) && (bo != 0)) || ((offsetc == FIX_OFFSET) && (co_0 != 0))) { + if ((col_req == 0) && (row_req == 0)) { + if (m <= n) { + col_req = 1; + } else { + row_req = 1; + } + } + } + + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] = 0; + + if (offsetc == COL_OFFSET) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += co[i]; + } + + if (bo != 0) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += bo * a_row_sum[i]; + } + } + + if (row_req) { + for (dim_t i = 0; i < n; i++) + row_offset[i] = 0; + + if (offsetc == ROW_OFFSET) { + for (dim_t i = 0; i < n; i++) + row_offset[i] += co[i]; + } + + if (ao != 0) { + for (dim_t i = 0; i < n; i++) + row_offset[i] += ao * b_col_sum[i]; + } + } + + if ((offsetc == FIX_OFFSET) && (co_0 != 0)) { + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += co_0; + } else { + for (dim_t i = 0; i < n; i++) + row_offset[i] += co_0; + } + } + + if ((ao != 0) && (bo != 0)) { + if (col_req) { + for (dim_t i = 0; i < m; i++) + col_offset[i] += (int32_t) k * ao * bo; + } else { + for (dim_t i = 0; i < n; i++) + row_offset[i] += (int32_t) k * ao * bo; + } + } + + if (col_req == 0) { + if (row_req == 0) { + if (beta == 0.0) { + arg->kernel_b0(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } else { + if (beta == 0.0) { + arg->kernel_b0_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } + } else { + if (row_req == 0) { + if (beta == 0.0) { + arg->kernel_b0_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } else { + if (beta == 0.0) { + arg->kernel_b0_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } else { + arg->kernel_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset, + row_offset); + } + } + } +} + +static inline void *align(void *ptr, size_t alignment) +{ + return (void *) utils::rnd_up((uintptr_t) ptr, alignment); +} + +static int gemm_kernel_driver(const dim_t m, const dim_t n, const dim_t k, + const int8_t *a, const uint8_t *b, int32_t *c, const int32_t *co, + const blas_t *arg) +{ + dim_t lda = arg->lda; + dim_t ldb = arg->ldb; + dim_t ldc = arg->ldc; + int8_t ao = arg->ao; + int8_t bo = arg->bo; + float alpha = *arg->alpha; + float beta = *arg->beta; + + if (m <= 0 || n <= 0) { + return 0; + } + + // Padding along K dimension. + dim_t k_padd = 0; + if (k <= arg->bk_traditional) { + k_padd = utils::rnd_up(k, arg->uk); + k_padd = nstl::max(128LL, k_padd); + } else if (k < 2 * arg->bk) { + k_padd = utils::rnd_up(k / 2, arg->uk); + } else { + k_padd = arg->bk; + } + + // Padding along M dimension. + dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm), + arg->um); + + // Padding along N dimension. + dim_t n_padd = 0; + if (k < arg->blocking_small_k) { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), + arg->bn_small_k), arg->un); + } else { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn), + arg->un); + } + + // Padding for temporary buffer for C + dim_t ldc_buf = ld_padd(m_padd); + + dim_t strideAm = (arg->transa == 0)? 1 : lda; + dim_t strideAn = (arg->transa != 0)? 1 : lda; + dim_t strideBm = (arg->transb == 0)? 1 : ldb; + dim_t strideBn = (arg->transb != 0)? 1 : ldb; + + size_t a_buf_nelems = m_padd * k_padd; + size_t b_buf_nelems = k_padd * n_padd; + size_t a_row_sum_nelems = m_padd; + size_t b_col_sum_nelems = n_padd; + + size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K + + b_buf_nelems * sizeof(*b) + PAGE_4K + + a_row_sum_nelems * sizeof(*c) + PAGE_4K + + b_col_sum_nelems * sizeof(*c) + PAGE_4K; + + bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0); + if (need_c_buffer) { + size_t c_buf_nelems = ldc_buf * n_padd; + mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; + } + + char *mem = (char *) malloc(mem_size, 128); + + if (!mem) { + return -1; + } + + int8_t *bufferA = (int8_t *) align(mem, PAGE_4K); + uint8_t *bufferB = (uint8_t *) align(bufferA + a_buf_nelems, PAGE_4K); + int32_t *a_row_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K); + int32_t *b_col_sum = (int32_t *) align(a_row_sum + a_row_sum_nelems, + PAGE_4K); + + int32_t *bufferC = NULL; + if (need_c_buffer) { + bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K); + } + + float beta_saved = beta; + + int a_block_copied = 0; + dim_t sizeM = 0; + for (dim_t Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM > m_padd) + sizeM = m_padd; + + dim_t sizeK = 0; + for (dim_t Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK > k_padd) + sizeK = k_padd; + + // Scale C blocks by beta only for the first time + if (Bk == 0) + beta = beta_saved; + else + beta = 1.0f; + + // Apply C offset when to the last k-block of the partial sum. + int offsetc = NO_OFFSET; + if (Bk + sizeK == k) + offsetc = arg->offsetc; + + dim_t sizeN = 0; + for (dim_t Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN > n_padd) + sizeN = n_padd; + + const uint8_t *b_block = b + Bk * strideBm + Bn * strideBn; + arg->copyB(&sizeK, &sizeN, b_block, &ldb, NULL, bufferB, NULL, + NULL, b_col_sum); + + dim_t sizeUM = 0; + for (dim_t Um = 0; Um < sizeM; Um += sizeUM) { + sizeUM = sizeM - Um; + if (sizeUM > arg->um) + sizeUM = arg->um; + + /* + * Use the whole A buffer only if we have multiple B blocks + * for k-dimension, otherwise we are wasting cache to store + * B and C blocks. + */ + dim_t Um_forA = 0; + if (sizeN < n) + Um_forA = Um; + + const int8_t *a_block = a + (Bm + Um) * strideAm + + Bk * strideAn; + if (!a_block_copied) { + arg->copyA(&sizeK, &sizeUM, a_block, &lda, NULL, + bufferA + Um_forA * sizeK, NULL, NULL, + a_row_sum + Um_forA); + } + + int32_t *c_block = c + (Bm + Um) + Bn * ldc; + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = Bn; + } else if (offsetc == COL_OFFSET) { + co_stride = Bm + Um; + } + if (need_c_buffer) { + igemm_inner_kernel(sizeUM, sizeN, sizeK, + bufferA + Um_forA * sizeK, bufferB, 0.0f, + bufferC + Um, ldc_buf, a_row_sum + Um_forA, + b_col_sum, NULL, NO_OFFSET, arg); + + // Finish the block adding the necessary alpha, beta + // and offsets. + add_results(sizeUM, sizeN, sizeK, alpha, beta, + bufferC + Um, ldc_buf, c_block, ldc, + a_row_sum + Um_forA, b_col_sum, ao, bo, + co + co_stride, offsetc); + } else { + igemm_inner_kernel(sizeUM, sizeN, sizeK, + bufferA + Um_forA * sizeK, bufferB, beta, + c_block, ldc, a_row_sum + Um_forA, b_col_sum, + co + co_stride, offsetc, arg); + } + } + a_block_copied = 1; + } + a_block_copied = 0; + } + } + + free(mem); + + return 0; +} + +static int kernel_driver_parallel_acopiedbcopy(const dim_t m, const dim_t n, + const dim_t k, const int8_t *bufferA, const uint8_t *b, + const float beta, int32_t *c, const int offsetc, const int32_t *co, + const int32_t *a_row_sum, const blas_t *arg) +{ + dim_t ldb = arg->ldb; + dim_t ldc = arg->ldc; + int8_t ao = arg->ao; + int8_t bo = arg->bo; + float alpha = *arg->alpha; + + if (m <= 0 || n <= 0) { + return 0; + } + + // Padding along N dimension. + dim_t n_padd = 0; + if (k < arg->blocking_small_k) { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), + arg->bn_small_k), arg->un); + } else { + n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn), + arg->un); + } + + // Padding for temporary buffer for C + dim_t ldc_buf = ld_padd(m); + + dim_t strideBn = (arg->transb != 0)? 1 : ldb; + + size_t b_buf_nelems = k * n_padd; + size_t b_col_sum_nelems = n_padd; + + size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K + + b_col_sum_nelems * sizeof(*c) + PAGE_4K; + + bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0); + if (need_c_buffer) { + size_t c_buf_nelems = ldc_buf * n_padd; + mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; + } + + char *mem = (char *) malloc(mem_size, 128); + + if (!mem) { + return -1; + } + + uint8_t *bufferB = (uint8_t *) align(mem, PAGE_4K); + int32_t *b_col_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K); + + int32_t *bufferC = NULL; + if (need_c_buffer) { + bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K); + } + + dim_t sizeN = 0; + for (dim_t Bn = 0; Bn < n; Bn += sizeN) { + sizeN = n - Bn; + if (sizeN > n_padd) + sizeN = n_padd; + + // Implement the kernel here. + const uint8_t *b_block = b + Bn * strideBn; + arg->copyB(&k, &sizeN, b_block, &ldb, NULL, bufferB, NULL, NULL, + b_col_sum); + + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = Bn; + } else if (offsetc == COL_OFFSET) { + co_stride = 0; + } + int32_t *c_block = c + Bn * ldc; + if (need_c_buffer) { + igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, 0.0f, bufferC, + ldc_buf, a_row_sum, b_col_sum, NULL, NO_OFFSET, arg); + + // Finish the block adding the necessary alpha, beta and offsets. + add_results(m, sizeN, k, alpha, beta, bufferC, ldc_buf, c_block, + ldc, a_row_sum, b_col_sum, ao, bo, co + co_stride, + offsetc); + } else { + igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, beta, c_block, + ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg); + } + } + + free(mem); + + return 0; + +} + +#define N2D_MAX_AVX512 384 +#define M2D_MIN_AVX512 384 +#define VECLEN 16 +#define NCONS 1 +static inline void set_thread_opts_avx512(int *p_nthrs, + blas_thread_t *thread_info, const blas_t *arg) +{ + int nthrs = *p_nthrs; + dim_t m = arg->m; + dim_t n = arg->n; + + thread_info->nthrs_m = 0; + thread_info->nthrs_n = 0; + thread_info->copy_type = COPY_NONE; // By default don't do parallel copy. + + int condition_2D_bsrc = -1; + if ((256 * m > nthrs * n) && (nthrs * m < 256 * n)) { + condition_2D_bsrc = 1; + } else { + condition_2D_bsrc = 0; + } + + int condition_1D_copya = 0; + if ((m >= 1000) && (n >= nthrs * N2D_MAX_AVX512 / 4)) { + condition_2D_bsrc = 0; + condition_1D_copya = 1; + } + + // If offset is non-zero, we need to keep 1D_copya to reduce update overhead + if (arg->ao != 0 || arg->bo != 0 || arg->co[0] != 0 + || arg->offsetc != FIX_OFFSET) { + condition_2D_bsrc = 0; + condition_1D_copya = 1; + } + + if (condition_2D_bsrc == 1) { + int nthrs_m = 1; + int nthrs_n = nthrs; + + while ((nthrs_n % 2 == 0) && + (n / nthrs > N2D_MAX_AVX512 || + n / nthrs_n <= N2D_MAX_AVX512 / 2) && + (m / nthrs_m >= 2 * M2D_MIN_AVX512) && + (nthrs_m < 4)) { + nthrs_m *= 2; + nthrs_n /= 2; + } + + thread_info->nthrs_m = nthrs_m; + thread_info->nthrs_n = nthrs_n; + thread_info->partition = PARTITION_2D; + + // Reset the total number of threads that will be used. + *p_nthrs = nthrs_m * nthrs_n; + + } else if (condition_1D_copya && mkldnn_thr_syncable()) { + // Use parallel copy A algorithm + thread_info->copy_type = COPY_A; + thread_info->partition = PARTITION_1D_COL; + } else { + if ((m > n) && (m / nthrs >= VECLEN || n < NCONS * nthrs)) { + thread_info->partition = PARTITION_1D_ROW; + } else { + thread_info->partition = PARTITION_1D_COL; + } + } +} +#undef N2D_MAX_AVX512 +#undef M2D_MIN_AVX512 +#undef VECLEN +#undef NCONS + +static inline void partition_1d(const int ithr, const int nthrs, const dim_t n, + dim_t *t_offset, dim_t *t_block) +{ + dim_t band = n / nthrs; + + dim_t tail = n - (nthrs - 1) * band; + if (tail > (band + 1)) + band++; + tail = n - (nthrs - 1) * band; + + if (ithr < (nthrs - 1)) + *t_block = band; + else + *t_block = tail; + + *t_offset = ithr * band; + + if (*t_offset >= n) { + *t_block = 0; + *t_offset = 0; + } else if ((*t_offset + *t_block) > n) { + *t_block = n - *t_offset; + } +} + +static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i, + const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m, + const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp, + dim_t *p_n_band) +{ + dim_t m_disp = 0, n_disp = 0; + dim_t m_band = 0, n_band = 0; + + int mdiv = nthrs_m; + int ndiv = nthrs_n; + + dim_t m_bandt = m / mdiv; /* size per thread */ + dim_t n_bandt = n / ndiv; /* size per thread */ + int firstmgroup = mdiv - 1; + int firstngroup = ndiv - 1; + dim_t firstmval = m_bandt; + dim_t firstnval = n_bandt; + + int mthr_used = mdiv; + if (m - (mdiv - 1) * m_bandt > m_bandt + 1) { + if (m - (mdiv - 1) * m_bandt > mdiv) + ++m_bandt; + + firstmval = m_bandt + 1; + mthr_used = (int) (m / firstmval); + + if (mthr_used * firstmval < m) + ++mthr_used; + + firstmgroup = mthr_used - 1; + } + + int nthr_used = ndiv; + if (n - (ndiv - 1) * n_bandt > n_bandt + 1) { + firstnval = n_bandt + 1; + nthr_used = (int) (n / firstnval); + + if (nthr_used * firstnval < n) + ++nthr_used; + + firstngroup = nthr_used - 1; + } + + *nthrs = mthr_used * nthr_used; + + if (ithr < *nthrs) { + if (ithr_i < firstmgroup) { + m_band = firstmval; + m_disp = ithr_i * firstmval; + } else if (ithr_i <= mthr_used - 2) { + m_band = m_bandt; + m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt; + } else { + m_disp = firstmgroup * firstmval + + (mthr_used - 1 - firstmgroup) * m_bandt; + m_band = nstl::max(0LL, m - m_disp); + } + + if (ithr_j < firstngroup) { + n_band = firstnval; + n_disp = ithr_j * firstnval; + } else if (ithr_j <= nthr_used - 2) { + n_band = n_bandt; + n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt; + } else { + n_disp = firstngroup * firstnval + + (nthr_used - 1 - firstngroup) * n_bandt; + n_band = nstl::max(0LL, n - n_disp); + } + m_disp = nstl::max(nstl::min(m_disp, m - 1), 0LL); + n_disp = nstl::max(nstl::min(n_disp, n - 1), 0LL); + } + + if (ithr < *nthrs) { + *p_m_disp = m_disp; + *p_n_disp = n_disp; + *p_m_band = m_band; + *p_n_band = n_band; + } else { + *p_m_disp = 0; + *p_n_disp = 0; + *p_m_band = 0; + *p_n_band = 0; + } + + return; +} + +static inline void decompose_matrices(const int ithr, int *nthrs, dim_t *m, + dim_t *n, dim_t *k, const int8_t **a, const uint8_t **b, int32_t **c, + const int32_t **co, const blas_thread_t *thread_info, const blas_t *arg) +{ + dim_t strideAm = (arg->transa == 0)? 1 : arg->lda; + dim_t strideBn = (arg->transb != 0)? 1 : arg->ldb; + int offsetc = arg->offsetc; + + switch (thread_info->partition) { + case PARTITION_1D_ROW: + { + dim_t offset = 0; + dim_t block = 0; + partition_1d(ithr, *nthrs, arg->m, &offset, &block); + + *m = block; + *n = arg->n; + *k = arg->k; + + // Set matrix A. + *a = arg->a + offset * strideAm; + + // Set matrix B. + *b = arg->b; + + // Set matrix C. + *c = arg->c + offset; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = 0; + } else if (offsetc == COL_OFFSET) { + co_stride = offset; + } + *co = arg->co + co_stride; + break; + } + + case PARTITION_1D_COL: + { + dim_t offset = 0; + dim_t block = 0; + partition_1d(ithr, *nthrs, arg->n, &offset, &block); + + *m = arg->m; + *n = block; + *k = arg->k; + + // Set matrix A. + *a = arg->a; + + // Set matrix B. + *b = arg->b + offset * strideBn; + + // Set matrix C. + *c = arg->c + offset * arg->ldc; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = offset; + } else if (offsetc == COL_OFFSET) { + co_stride = 0; + } + *co = arg->co + co_stride; + break; + } + + case PARTITION_2D_COL_MAJOR: + { + int nthrs_m = thread_info->nthrs_m; + int nthrs_n = thread_info->nthrs_n; + int ithr_i = ithr % nthrs_m; + int ithr_j = ithr / nthrs_m; + + dim_t m_disp = 0; + dim_t m_band = 0; + dim_t n_disp = 0; + dim_t n_band = 0; + + partition_2d(ithr, nthrs, ithr_i, ithr_j, nthrs_m, nthrs_n, + arg->m, arg->n, &m_disp, &m_band, &n_disp, &n_band); + + *m = m_band; + *n = n_band; + *k = arg->k; + + // Set matrix A. + *a = arg->a + m_disp * strideAm; + + // Set matrix B. + *b = arg->b + n_disp * strideBn; + + // Set matrix C. + *c = arg->c + m_disp + n_disp * arg->ldc; + + // Set offset vector for C matrix + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = n_disp; + } else if (offsetc == COL_OFFSET) { + co_stride = m_disp; + } + *co = arg->co + co_stride; + break; + } + } +} + +#define MULTIPLIER 10 +static int parallel_a_copy(const int ithr, const int nthrs, const dim_t m, + const dim_t n, const dim_t k, const int8_t *a, const uint8_t *b, + int32_t *c, const int32_t *co, const blas_t *arg, + char **p_shared_mem) +{ + const dim_t lda = arg->lda; + const dim_t ldb = arg->ldb; + const dim_t strideAm = (arg->transa == 0)? 1 : lda; + const dim_t strideAn = (arg->transa != 0)? 1 : lda; + const dim_t strideBm = (arg->transb == 0)? 1 : ldb; + + // Padding along M dimension. + dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm), + arg->um); + + // Padding along K dimension. + dim_t k_padd = 0; + if (k <= arg->bk_traditional) { + k_padd = utils::rnd_up(k, arg->uk); + k_padd = nstl::max(128LL, k_padd); + } else if (k < 2 * arg->bk) { + k_padd = utils::rnd_up(k / 2, arg->uk); + } else { + k_padd = arg->bk; + } + + m_padd *= nthrs > MULTIPLIER ? MULTIPLIER : nthrs; + if (m_padd > m) { + m_padd = utils::rnd_up(m, arg->um); + } + + size_t a_buf_nelems = m_padd * k_padd; + + // Allocate shared memory for A and its row sum buffers in master thread. + if (ithr == 0) { // If thread master + size_t a_row_sum_nelems = m_padd; + + size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K) + + a_row_sum_nelems * sizeof(*c) + PAGE_4K; + + *p_shared_mem = (char *) malloc(mem_size, 128); + + } + mkldnn_thr_barrier(); + + char *mem = *p_shared_mem; + int8_t *bufferA = (int8_t *) align(mem, PAGE_4K); + int32_t *a_row_sum = (int32_t *) align(bufferA + a_buf_nelems, PAGE_4K); + + if (!mem) { + return -1; + } + + int result = 0; // Return status + + dim_t sizeK = 0; + for (dim_t Bk = 0; Bk < k; Bk += sizeK) { + sizeK = k - Bk; + if (sizeK > k_padd) + sizeK = k_padd; + + // Scale C blocks by beta only for the first term of partial sum. + float beta = 1.0f; + if (Bk == 0) + beta = *(arg->beta); + + // Apply C offset for the last k-block of the partial sum. + int offsetc = NO_OFFSET; + if (Bk + sizeK == k) + offsetc = arg->offsetc; + + dim_t sizeM = 0; + for (dim_t Bm = 0; Bm < m; Bm += sizeM) { + sizeM = m - Bm; + if (sizeM > m_padd) + sizeM = m_padd; + + if (ithr < nthrs) { + dim_t band = (sizeM + nthrs - 1) / nthrs; + band = utils::rnd_up(band, arg->um); + + dim_t offset = band * ithr; + + // If offset is too large don't use that thread for copying. + if (offset >= sizeM) { + offset = 0; + band = 0; + } + + // Handle the tail of the copy. + if (offset + band > sizeM) { + band = sizeM - offset; + } + + if (band > 0) { + const int8_t *a_block = a + (Bm + offset) * strideAm + + Bk * strideAn; + arg->copyA(&sizeK, &band, a_block, &lda, NULL, + bufferA + offset * sizeK, NULL, NULL, + a_row_sum + offset); + } + } + mkldnn_thr_barrier(); // Wait for finishing parallel copy. + + const uint8_t *b_block = b + Bk * strideBm; + int32_t *c_block = c + Bm; + dim_t co_stride = 0; + if (offsetc == FIX_OFFSET) { + co_stride = 0; + } else if (offsetc == ROW_OFFSET) { + co_stride = 0; + } else if (offsetc == COL_OFFSET) { + co_stride = Bm; + } + + result = kernel_driver_parallel_acopiedbcopy(sizeM, n, sizeK, + bufferA, b_block, beta, c_block, offsetc, co + co_stride, + a_row_sum, arg); + + mkldnn_thr_barrier(); // Wait for kernel computations to finish. + } + } + + // Free memory allocated in master thread + if (ithr == 0) { + free(mem); + } + + return result; +} +#undef MULTIPLIER + +static inline void get_omp_thread_count(dim_t m, dim_t n, dim_t k, + double fp_per_cycle, int *nthrs) +{ + double omp_overhead_small_core = 3.0e+3; + double omp_intercept_big_core = 4.0e+3; + double omp_slope_big_core = 5.0e+2; + + double gemm_cycles = 8.0 * m * n * k / fp_per_cycle; + + int i = *nthrs; + + // Use a different model for omp overheads if nthrs is <= 4 + if (*nthrs <= 4 && omp_overhead_small_core > 0) { + double omp_cycles = omp_overhead_small_core; + if (gemm_cycles < omp_cycles) { + *nthrs = 1; + return; + } else { + while (i > 1) { + if (omp_cycles * i < gemm_cycles * (i - 1)) break; + --i; + } + } + } else { + if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) { + *nthrs = 1; + return; + } + + // adaptive decrement to march faster· + while (i > 1) { + double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core; + if (omp_cycles * i < gemm_cycles * (i - 1)) + break; + + if (i < 10) + i -= 2; + else if (i < 30) + i -= 4; + else + i -= 8; + } + } + + if (i < 1) + i = 1; + + *nthrs = i; +} + +#define CACHE_LINE_SIZE 64 +static int gemm_threading_driver(blas_t *arg) +{ + if ((arg->m <= 0) || (arg->n <= 0)) + return mkldnn_success; + + if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) { + return mkldnn_success; + } + + int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + get_omp_thread_count(arg->m, arg->n, arg->k, 64.0, &nthr); + + if (nthr == 1) { + return gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, arg->b, + arg->c, arg->co, arg); + } + + int *results = (int *) malloc(sizeof(*results) * nthr * CACHE_LINE_SIZE, + PAGE_4K); + + if (!results) { + return -1; + } + + for (int i = 0; i < nthr; i++) { + results[i * CACHE_LINE_SIZE] = 0; // Initialize to success + } + + char *shared_mem = NULL; + + parallel(nthr, [&](const int ithr, const int nthr) { + int nthrs = nthr; + if (nthrs == 1) { + results[0] = gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, + arg->b, arg->c, arg->co, arg); + } else { + blas_thread_t thread_info; + set_thread_opts_avx512(&nthrs, &thread_info, arg); + + const int8_t *a = NULL; + const uint8_t *b = NULL; + int32_t *c = NULL; + const int32_t *co = NULL; + dim_t m = -1; + dim_t n = -1; + dim_t k = -1; + decompose_matrices(ithr, &nthrs, &m, &n, &k, &a, &b, &c, &co, + &thread_info, arg); + + if (ithr < nthrs) { + switch (thread_info.copy_type) { + case COPY_A: + results[ithr * CACHE_LINE_SIZE] = + parallel_a_copy(ithr, nthrs, m, n, k, a, b, c, co, arg, + &shared_mem); + break; + + default: + case COPY_NONE: + results[ithr * CACHE_LINE_SIZE] = + gemm_kernel_driver(m, n, k, a, b, c, co, arg); + break; + } + } + } + }); + + int result = 0; // Initialize to success + for (int i = 0; i < nthr; i++) { + if (results[i] != 0) { + result = results[i * CACHE_LINE_SIZE]; + break; + } + } + + free(results); + + return result; +} +#undef CACHE_LINE_SIZE + +static jit_avx512_core_u8_copy_an_kern *copy_an; +static jit_avx512_core_u8_copy_at_kern *copy_at; +static jit_avx512_core_u8_copy_bn_kern *copy_bn; +static jit_avx512_core_u8_copy_bt_kern *copy_bt; +static jit_avx512_core_u8_copy_sum_an_kern *copy_sum_an; +static jit_avx512_core_u8_copy_sum_at_kern *copy_sum_at; +static jit_avx512_core_u8_copy_sum_bn_kern *copy_sum_bn; +static jit_avx512_core_u8_copy_sum_bt_kern *copy_sum_bt; +static jit_avx512_core_gemm_s8u8s32_kern *kernel; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_r; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_c; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_b; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_r; +static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_c; +static jit_avx512_core_gemv_s8u8s32_kern *gemv_s8u8s32_kernel; +static jit_avx512_core_gemv_s8u8s32_kern *gemv_u8s8s32_kernel; + +static void jit_init(blas_t *arg) +{ + static int (*copyAn)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyAt)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyBn)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copyBt)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumAn)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumAt)(const dim_t *m, const dim_t *n, const int8_t *a, + const dim_t *lda, const int8_t *alpha, int8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumBn)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*copySumBt)(const dim_t *m, const dim_t *n, const uint8_t *a, + const dim_t *lda, const uint8_t *alpha, uint8_t *b, + const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum); + + static int (*kern)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static int (*kern_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k, + const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c, + const dim_t ldc, const int32_t *col_offset, + const int32_t *row_offset); + + static void (*gemv_s8u8s32_kern)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + + static void (*gemv_u8s8s32_kern)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + if (mayiuse(avx512_core_vnni)) { + arg->um = AVX512_UNROLL_M; + arg->un = AVX512_UNROLL_N; + arg->uk = AVX512_UNROLL_K; + arg->bm = AVX512_BM; + arg->bn = AVX512_BN; + arg->bk = AVX512_BK_VNNI; + + arg->bk_traditional = AVX512_BK_TRADITIONAL; + arg->bn_small_k = AVX512_BN_SMALL_K; + arg->blocking_small_k = AVX512_BLOCKING_SMALL_K; + } else { + arg->um = AVX512_UNROLL_M; + arg->un = AVX512_UNROLL_N; + arg->uk = AVX512_UNROLL_K; + arg->bm = AVX512_BM; + arg->bn = AVX512_BN; + arg->bk = AVX512_BK; + + arg->bk_traditional = AVX512_BK_TRADITIONAL; + arg->bn_small_k = AVX512_BN_SMALL_K; + arg->blocking_small_k = AVX512_BLOCKING_SMALL_K; + } + + static std::once_flag initialized; + std::call_once(initialized, []{ + + copy_an = new jit_avx512_core_u8_copy_an_kern(); + copy_at = new jit_avx512_core_u8_copy_at_kern(); + copy_bn = new jit_avx512_core_u8_copy_bn_kern(); + copy_bt = new jit_avx512_core_u8_copy_bt_kern(); + + copy_sum_an = new jit_avx512_core_u8_copy_sum_an_kern(); + copy_sum_at = new jit_avx512_core_u8_copy_sum_at_kern(); + copy_sum_bn = new jit_avx512_core_u8_copy_sum_bn_kern(); + copy_sum_bt = new jit_avx512_core_u8_copy_sum_bt_kern(); + + kernel = new jit_avx512_core_gemm_s8u8s32_kern(false, false, false); + kernel_b = new jit_avx512_core_gemm_s8u8s32_kern(false, true, true); + kernel_r = new jit_avx512_core_gemm_s8u8s32_kern(false, false, true); + kernel_c = new jit_avx512_core_gemm_s8u8s32_kern(false, true, false); + kernel_b0 = new jit_avx512_core_gemm_s8u8s32_kern(true, false, false); + kernel_b0_b = new jit_avx512_core_gemm_s8u8s32_kern(true, true, true); + kernel_b0_r = new jit_avx512_core_gemm_s8u8s32_kern(true, false, true); + kernel_b0_c = new jit_avx512_core_gemm_s8u8s32_kern(true, true, false); + + gemv_s8u8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern(); + gemv_u8s8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern(); + + + copyAn = copy_an->getCode<int (*)(const dim_t *, const dim_t *, + const int8_t *, const dim_t *, const int8_t *, int8_t *, + const dim_t *, const dim_t *, int32_t *)>(); + + copyAt = copy_at->getCode<int (*)(const dim_t *, const dim_t *, + const int8_t *, const dim_t *, const int8_t *, int8_t *, + const dim_t *, const dim_t *, int32_t *)>(); + + copyBn = copy_bn->getCode<int (*)(const dim_t *, const dim_t *, + const uint8_t *, const dim_t *, const uint8_t *, uint8_t *, + const dim_t *, const dim_t *, int32_t *)>(); + + copyBt = copy_bt->getCode<int (*)(const dim_t *, const dim_t *, + const uint8_t *, const dim_t *, const uint8_t *, uint8_t *, + const dim_t *, const dim_t *, int32_t *)>(); + + copySumAn = copy_sum_an->getCode<int (*)(const dim_t *, const dim_t *, + const int8_t *, const dim_t *, const int8_t *, int8_t *, + const dim_t *, const dim_t *, int32_t *)>(); + + copySumAt = copy_sum_at->getCode<int (*)(const dim_t *, const dim_t *, + const int8_t *, const dim_t *, const int8_t *, int8_t *, + const dim_t *, const dim_t *, int32_t *)>(); + + copySumBn = copy_sum_bn->getCode<int (*)(const dim_t *, const dim_t *, + const uint8_t *, const dim_t *, const uint8_t *, uint8_t *, + const dim_t *, const dim_t *, int32_t *)>(); + + copySumBt = copy_sum_bt->getCode<int (*)(const dim_t *, const dim_t *, + const uint8_t *, const dim_t *, const uint8_t *, uint8_t *, + const dim_t *, const dim_t *, int32_t *)>(); + + kern = kernel->getCode<int (*)(const dim_t *, const dim_t *, + const dim_t *, const float *, const int8_t *, const uint8_t *, + int32_t *, const dim_t, const int32_t *, const int32_t *)>(); + + kern_b = kernel_b->getCode<int (*)(const dim_t *, const dim_t *, + const dim_t *, const float *, const int8_t *, const uint8_t *, + int32_t *, const dim_t, const int32_t *, const int32_t *)>(); + + kern_r = kernel_r->getCode<int (*)(const dim_t *, const dim_t *, + const dim_t *, const float *, const int8_t *, const uint8_t *, + int32_t *, const dim_t, const int32_t *, const int32_t *)>(); + + kern_c = kernel_c->getCode<int (*)(const dim_t *, const dim_t *, + const dim_t *, const float *, const int8_t *, const uint8_t *, + int32_t *, const dim_t, const int32_t *, const int32_t *)>(); + + kern_b0 = kernel_b0->getCode<int (*)(const dim_t *, const dim_t *, + const dim_t *, const float *, const int8_t *, const uint8_t *, + int32_t *, const dim_t, const int32_t *, const int32_t *)>(); + + kern_b0_b = kernel_b0_b->getCode<int (*)(const dim_t *, const dim_t *, + const dim_t *, const float *, const int8_t *, const uint8_t *, + int32_t *, const dim_t, const int32_t *, const int32_t *)>(); + + kern_b0_r = kernel_b0_r->getCode<int (*)(const dim_t *, const dim_t *, + const dim_t *, const float *, const int8_t *, const uint8_t *, + int32_t *, const dim_t, const int32_t *, const int32_t *)>(); + + kern_b0_c = kernel_b0_c->getCode<int (*)(const dim_t *, const dim_t *, + const dim_t *, const float *, const int8_t *, const uint8_t *, + int32_t *, const dim_t, const int32_t *, const int32_t *)>(); + + gemv_s8u8s32_kern = + gemv_s8u8s32_kernel -> generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t> + (mayiuse(avx512_core_vnni)); + gemv_u8s8s32_kern = + gemv_u8s8s32_kernel -> generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t> + (mayiuse(avx512_core_vnni)); + }); + + if (arg->bo == 0) { // No need to compute A row sum if bo is zero + if (arg->transa == 0) { + arg->copyA = copyAn; + } else { + arg->copyA = copyAt; + } + } else { + if (arg->transa == 0) { + arg->copyA = copySumAn; + } else { + arg->copyA = copySumAt; + } + } + + if (arg->ao == 0) { // No need to compute B column sum if ao is zero + if (arg->transb == 0) { + arg->copyB = copyBn; + } else { + arg->copyB = copyBt; + } + } else { + if (arg->transb == 0) { + arg->copyB = copySumBn; + } else { + arg->copyB = copySumBt; + } + } + + arg->kernel = kern; + arg->kernel_b = kern_b; + arg->kernel_r = kern_r; + arg->kernel_c = kern_c; + arg->kernel_b0 = kern_b0; + arg->kernel_b0_b = kern_b0_b; + arg->kernel_b0_r = kern_b0_r; + arg->kernel_b0_c = kern_b0_c; + arg -> gemv_s8u8s32_kernel = gemv_s8u8s32_kern; + arg -> gemv_u8s8s32_kernel = gemv_u8s8s32_kern; +} + +mkldnn_status_t jit_avx512_core_gemm_s8u8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const uint8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc) +{ + char transa = *transA; + char transb = *transB; + char offsetc = *offsetC; + + blas_t args; + + // Initialize blas structure + args.m = *m; + args.n = *n; + args.k = *k; + args.alpha = alpha; + args.a = a; + args.lda = *lda; + args.b = b; + args.ldb = *ldb; + args.beta = beta; + args.c = c; + args.ldc = *ldc; + args.transa = (transa == 'N' || transa == 'n') ? 0 : 1; + args.transb = (transb == 'N' || transb == 'n') ? 0 : 1; + args.um = 0; + args.un = 0; + args.bm = 0; + args.bn = 0; + args.bk = 0; + args.copyA = NULL; + args.copyB = NULL; + args.kernel = NULL; + args.kernel_b0 = NULL; + args.ao = *oa; + args.bo = *ob; + args.co = oc; + + if (offsetc == 'F' || offsetc == 'f') { + args.offsetc = FIX_OFFSET; + } else if (offsetc == 'R' || offsetc == 'r') { + args.offsetc = ROW_OFFSET; + } else { // offsetc == 'C' || offsetc == 'c' + args.offsetc = COL_OFFSET; + } + + jit_init(&args); + int result = gemm_threading_driver(&args); + + return (result < 0) ? mkldnn_out_of_memory : mkldnn_success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp new file mode 100644 index 0000000000..b2e2902a12 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_GEMM_S8U8S32_HPP +#define JIT_AVX512_CORE_GEMM_S8U8S32_HPP + +#include <cstdint> +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t jit_avx512_core_gemm_s8u8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const uint8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc); + +} +} +} + +#endif // JIT_AVX512_CORE_GEMM_S8U8S32_HPP diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp new file mode 100644 index 0000000000..57554a1852 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp @@ -0,0 +1,539 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_avx512_core_gemm_s8u8s32_kern.hpp" + + +#ifdef _WIN32 +static const bool is_windows = 1; +#else +static const bool is_windows = 0; +#endif + + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + + + + +// Convert between vector register lengths. +static inline Xmm make_xmm(const Xmm &v) { return Xmm(v.getIdx()); } +static inline Ymm make_ymm(const Xmm &v) { return Ymm(v.getIdx()); } + +// Load from or store to C. +void jit_avx512_core_gemm_s8u8s32_kern::c_load(const Xbyak::Xmm &dst, + const Xbyak::Address &src, int nelems) +{ + switch (nelems) { + default: vmovups(dst, src); break; + case 8: vmovups(make_ymm(dst), src); break; + case 4: vmovups(make_xmm(dst), src); break; + case 2: vmovlps(make_xmm(dst), src); break; + case 1: vmovss(make_xmm(dst), src); break; + } +} +void jit_avx512_core_gemm_s8u8s32_kern::c_store(const Xbyak::Address &dst, + const Xbyak::Xmm &src, int nelems) +{ + switch (nelems) { + default: vmovups(dst, src); break; + case 8: vmovups(dst, make_ymm(src)); break; + case 4: vmovups(dst, make_xmm(src)); break; + case 2: vmovsd(dst, make_xmm(src)); break; + case 1: vmovss(dst, make_xmm(src)); break; + } +} + +// Perform length-4 dot product accumulations of unsigned and signed bytes +// in parallel. +// Use vpdpbusd if VNNI available, otherwise emulate. +void jit_avx512_core_gemm_s8u8s32_kern::dot_product(const Xmm &dst, + const Xmm &src1, const Xmm &src2) +{ + if (vnni) + vpdpbusd(dst, src1, src2); + else { + vpmaddubsw(dp_scratch, src1, src2); + vpmaddwd(dp_scratch, ones, dp_scratch); + vpaddd(dst, dst, dp_scratch); + } +} + +// Inner kernel. +void jit_avx512_core_gemm_s8u8s32_kern::kernel_loop(int unroll_m, int unroll_n, + bool cfetch) +{ + int um_vecs = (unroll_m + 15) >> 4; + Label label_kernel_loop; + + L_aligned(label_kernel_loop); { + for (int h = 0; h < 4; h++) { + for (int j = 0; j < unroll_n; j++) { + const Zmm b = b_regs[j & 1]; + + vpbroadcastd(b, ptr[BO + isize * + (2 * j + 2 * h * unroll_n - offset_b)]); + dot_product(c_regs[0][j], b, a_regs[0]); + + if (j == 1 && !(h & 1)) + prefetch_b(ptr[BO + isize * (prefetch_size_b + + 2 * h * unroll_n - offset_b)]); + else if (j % 3 == 0) + prefetch_a(ptr[AO + isize * (prefetch_size_a + + 32 * (j / 3) + 2 * h * unroll_m - offset_a)]); + + for (int i = 1; i < um_vecs; i++) + dot_product(c_regs[i][j], b, a_regs[i]); + + if (cfetch && (j == std::min(1, unroll_n - 1))) { + if (h == 3) + lea(CO2, ptr[CO2 + LDC]); + else if (h < um_vecs) + prefetch_c(ptr[CO2 + (16 * h * size)]); + } + + if (h == 3 && j == std::min(3, unroll_n - 1)) + lea(AA, ptr[AA + (32 * isize)]); + } + + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * + (32 * i + 2 * (h + 1) * unroll_m - offset_a)]); + + if (h == 2) + prefetch_x(ptr[AA - (offset_a * isize)]); + } + + add(AO, 8 * isize * unroll_m); + add(BO, 8 * isize * unroll_n); + sub(LoopCount, 1); + jg(label_kernel_loop, T_NEAR); + } +} + +// k remainder loop for kernel. +void jit_avx512_core_gemm_s8u8s32_kern::remainder_kernel(int unroll_m, + int unroll_n, int unroll_k, int bwidth) +{ + if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N) + || (unroll_m < 0) || (unroll_n < 0)) + return; + + int um_vecs = (unroll_m + 15) >> 4; + + for (int h = 0; h < unroll_k; h++) { + for (int j = 0; j < unroll_n; j++) { + Zmm b = b_regs[j & 1]; + auto b_src = ptr[BO + (-isize * offset_b + + bwidth * (j + h * unroll_n))]; + + switch (bwidth) { + case 4: + vpbroadcastd(b, b_src); + break; + case 2: + vpbroadcastw(b, b_src); + break; + case 1: + vpbroadcastb(b, b_src); + break; + } + for (int i = 0; i < um_vecs; i++) + dot_product(c_regs[i][j], b, a_regs[i]); + } + + if (unroll_k > 1) { + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * (32 * i + + (h + 1) * 2 * unroll_m - offset_a)]); + } + } + + add(AO, unroll_k * unroll_m * bwidth); + add(BO, unroll_k * unroll_n * bwidth); +} + +// Inner loop. +void jit_avx512_core_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n) +{ + if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N) + || (unroll_m < 0) || (unroll_n < 0)) + return; + + int um_vecs = (unroll_m + 15) >> 4; + int stage1 = unroll_n, stage2 = unroll_n; + + Label label_kernel_loop_1, label_k_main_loop_2, label_kernel_loop_2; + Label label_k_main_loop_3, label_kernel_loop_3; + Label label_k_remainder_loop_begin, label_k_rem_4, label_k_rem_2; + Label label_k_rem_1, label_update_begin; + + mov(AO, A); + for (int i = 0; i < um_vecs; i++) + vmovups(a_regs[i], ptr[AO + isize * (32 * i - offset_a)]); + + mov(LoopCount, K); + sar(LoopCount, 4); + jle(label_k_remainder_loop_begin, T_NEAR); + + // Main k loops, broken into three parts to time C prefetching. + sub(LoopCount, stage1 + stage2); + jle(label_k_main_loop_2, T_NEAR); + + kernel_loop(unroll_m, unroll_n, false); + + L_aligned(label_k_main_loop_2); + lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]); + add(LoopCount, stage1); + jle(label_k_main_loop_3, T_NEAR); + + kernel_loop(unroll_m, unroll_n, true); + + L_aligned(label_k_main_loop_3); + lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]); + add(LoopCount, stage2); + jle(label_k_remainder_loop_begin, T_NEAR); + + kernel_loop(unroll_m, unroll_n, true); + + // k remainder handling + L_aligned(label_k_remainder_loop_begin); + mov(LoopCount, K); + test(LoopCount, 8); + je(label_k_rem_4, T_NEAR); + + remainder_kernel(unroll_m, unroll_n, 2, 4); + + L_aligned(label_k_rem_4); + mov(LoopCount, K); + test(LoopCount, 4); + je(label_k_rem_2, T_NEAR); + + remainder_kernel(unroll_m, unroll_n, 1, 4); + + L_aligned(label_k_rem_2); + mov(LoopCount, K); + test(LoopCount, 2); + je(label_k_rem_1, T_NEAR); + + Zmm zero = zmm6; + Zmm tmp = zmm5; + + vpxorq(zero, zero, zero); + for (int i = 0; i < um_vecs; i++) { + Zmm a = a_regs[i]; + vbroadcasti64x4(a, ptr[AO + isize * (16 * i - offset_a)]); + vpunpcklwd(tmp, a, zero); + vpunpckhwd(a, a, zero); + vshufi32x4(a, tmp, a, 0x44); + vshufi32x4(a, a, a, 0xD8); + } + + remainder_kernel(unroll_m, unroll_n, 1, 2); + + L_aligned(label_k_rem_1); + mov(LoopCount, K); + test(LoopCount, 1); + je(label_update_begin, T_NEAR); + + vpxorq(zero, zero, zero); + for (int i = 0; i < um_vecs; i++) { + Zmm a = a_regs[i]; + vbroadcasti32x4(a, ptr[AO + isize * (8 * i - offset_a)]); + vpunpcklbw(tmp, a, zero); + vpunpckhbw(a, a, zero); + vinsertf128(make_ymm(a), make_ymm(tmp), make_xmm(a), 1); + vpunpcklwd(tmp, a, zero); + vpunpckhwd(a, a, zero); + vshufi32x4(a, tmp, a, 0x44); + vshufi32x4(a, a, a, 0xD8); + } + + remainder_kernel(unroll_m, unroll_n, 1, 1); + + // Add offsets and update C. + L_aligned(label_update_begin); + + if (enable_offset_r) { + // Add row offsets. + mov(rax, coffset_ry); + for (int j = 0; j < unroll_n; j++) { + Zmm row_offset = zmm0; + + vbroadcastss(row_offset, ptr[rax + size * j]); + + for (int i = 0; i < um_vecs; i++) + vpaddd(c_regs[i][j], c_regs[i][j], row_offset); + } + add(coffset_ry, size * unroll_n); + } + + if (enable_offset_c) { + // Add column offsets. + mov(rax, coffset_cy); + for (int i = 0; i < um_vecs; i++) { + Zmm col_offset = zmm0; + + c_load(col_offset, ptr[rax + size * 16 * i], unroll_m); + + for (int j = 0; j < unroll_n; j++) + vpaddd(c_regs[i][j], c_regs[i][j], col_offset); + } + } + + Reg64 LDC3 = rax; + lea(LDC3, ptr[LDC + LDC * 2]); + + // C updates. + int c_off_j = 0; + for (int j = 0; j < unroll_n; j++) { + if (j > 0 && (j & 3) == 0) { + lea(CO1, ptr[CO1 + LDC * 4]); + c_off_j += 4; + } + + int jj = j - c_off_j; + + for (int i = 0; i < um_vecs; i++) { + Zmm c = c_regs[i][j]; + Zmm c_old = zmm0; + decltype(LDC * jj) ldc_mult = (jj == 3) ? LDC3 : LDC * jj; + + auto c_mem = ptr[CO1 + ldc_mult + size * 16 * i]; + + if (beta_zero) + c_store(c_mem, c, unroll_m); + else { + c_load(c_old, c_mem, unroll_m); + vpaddd(c_old, c, c_old); + c_store(c_mem, c_old, unroll_m); + } + + vpxorq(c, c, c); + } + } + + lea(CO1, ptr[CO1 + LDC * (unroll_n - c_off_j)]); +} + +// Outer loop. +void jit_avx512_core_gemm_s8u8s32_kern::outerloop(int unroll_x, int unroll_y, + Label *&cur_outerloop_label) +{ + Label label_m_loop, label_n_loop, label_n_remainder_loops[6]; + + L(*cur_outerloop_label); + cur_outerloop_label++; + if (unroll_x >= IGEMM_UNROLL_M) { + mov(J, M); + cmp(J, unroll_x); + jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label. + } else { + test(J, unroll_x); + jle(*cur_outerloop_label, T_NEAR); + } + + L_aligned(label_m_loop); { + mov(CO1, C); + add(C, unroll_x * size); + + mov(BO, B); + + mov(AA, K); + imul(AA, AA, unroll_x * isize); + lea(AA, ptr[A + AA + isize * prefetch_size_a]); + + if (enable_offset_c) { + mov(rax, coffset_cx); + mov(coffset_cy, rax); + add(rax, unroll_x * size); + mov(coffset_cx, rax); + } + + if (enable_offset_r) { + mov(rax, coffset_rx); + mov(coffset_ry, rax); + } + + mov(I, N); + cmp(I, unroll_y); + jl(label_n_remainder_loops[0], T_NEAR); + + L_aligned(label_n_loop); { + innerloop(unroll_x, unroll_y); + sub(I, unroll_y); + cmp(I, unroll_y); + jge(label_n_loop, T_NEAR); + } + + align(16); + + int label_idx = 0; + for (int uy = 16; uy > 0; uy >>= 1) { + L(label_n_remainder_loops[label_idx++]); + if (unroll_y > uy) { + test(I, uy); + jle(label_n_remainder_loops[label_idx], T_NEAR); + + innerloop(unroll_x, uy); + align(16); + } + } + L(label_n_remainder_loops[label_idx]); + + mov(A, AO); + if (unroll_x >= IGEMM_UNROLL_M) { + sub(J, unroll_x); + cmp(J, unroll_x); + jge(label_m_loop); + } + } + + align(16); +} + +void jit_avx512_core_gemm_s8u8s32_kern::generate() +{ + // Prologue + preamble(); + sub(rsp, stack_alloc_size); + + if (is_windows) { + mov(A, arg_a); + mov(B, arg_b); + } + + mov(C, arg_c); + mov(LDC, arg_ldc); + + sub(A, -offset_a * isize); + sub(B, -offset_b * isize); + + mov(M, qword[M]); + mov(N, qword[N]); + mov(K, qword[K]); + + lea(LDC, ptr[LDC * size]); + + if (enable_offset_c) { + mov(rax, arg_coffset_c); + mov(coffset_cx, rax); + } + if (enable_offset_r) { + mov(rax, arg_coffset_r); + mov(coffset_rx, rax); + } + + for (int i = 0; i < (max_unroll_m >> 4); i++) { + for (int j = 0; j < max_unroll_n; j++) { + auto &c = c_regs[i][j]; + vpxorq(c, c, c); + } + } + + if (!vnni) { + mov(rax, 1); + movq(make_xmm(ones), rax); + vpbroadcastw(ones, make_xmm(ones)); + } + + Label outerloop_labels[8]; + Label *cur_outerloop_label = &outerloop_labels[0]; + + // Main m loop. + outerloop(IGEMM_UNROLL_M, IGEMM_UNROLL_N, cur_outerloop_label); + + // m remainder loops. + for (int um = 32; um > 0; um >>= 1) + if (IGEMM_UNROLL_M > um) + outerloop(um, IGEMM_UNROLL_N, cur_outerloop_label); + + L(*cur_outerloop_label); + + // Epilogue. + add(rsp, stack_alloc_size); + postamble(); +} + + +jit_avx512_core_gemm_s8u8s32_kern::jit_avx512_core_gemm_s8u8s32_kern(bool + beta_zero_, bool enable_offset_c_, bool enable_offset_r_) : + jit_generator(nullptr, 100000), arg_a(0), arg_b(0), arg_c(0), arg_ldc(0), + arg_coffset_c(0), arg_coffset_r(0), coffset_cx(0), coffset_cy(0), + coffset_rx(0), coffset_ry(0) +{ + beta_zero = beta_zero_; + enable_offset_c = enable_offset_c_; + enable_offset_r = enable_offset_r_; + vnni = mayiuse(avx512_core_vnni); + + // Assign integer registers + M = is_windows ? rcx : rdi; + N = is_windows ? rdx : rsi; + K = is_windows ? r8 : rdx; + A = is_windows ? rsi : r8; + B = r9; + C = r10; + LDC = r11; + I = r12; + J = r13; + LoopCount = rax; + AO = r14; + BO = r15; + CO1 = rbx; + CO2 = rbp; + AA = is_windows ? rdi : rcx; + + // Assign vector registers + dp_scratch = zmm6; + ones = zmm7; + for (int i = 0; i < (max_unroll_m >> 4); i++) + a_regs[i] = Zmm(i); + b_regs[0] = zmm4; + b_regs[1] = zmm5; + + int rn = 0; + for (int i = 0; i < (max_unroll_m >> 4); i++) + for (int j = 0; j < max_unroll_n; j++) + c_regs[i][j] = Zmm(8 + rn++); + + // Assign stack variables. + stack_alloc_size = 32; + auto args_offset = stack_alloc_size + get_size_of_abi_save_regs() + + 8 + (is_windows ? 48 : 0); + + arg_a = ptr[rsp + (args_offset - 16)]; + arg_b = ptr[rsp + (args_offset - 8)]; + arg_c = ptr[rsp + (args_offset + 0)]; + arg_ldc = ptr[rsp + (args_offset + 8)]; + arg_coffset_c = ptr[rsp + (args_offset + 16)]; + arg_coffset_r = ptr[rsp + (args_offset + 24)]; + + coffset_cx = qword[rsp + 0]; + coffset_cy = qword[rsp + 8]; + coffset_rx = qword[rsp + 16]; + coffset_ry = qword[rsp + 24]; + + generate(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp new file mode 100644 index 0000000000..e8efcc1cc8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp @@ -0,0 +1,101 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef IGEMM_KERNEL_GENERATOR_HPP +#define IGEMM_KERNEL_GENERATOR_HPP + +#include "jit_generator.hpp" + + +namespace mkldnn { +namespace impl { +namespace cpu { + +class jit_avx512_core_gemm_s8u8s32_kern : public jit_generator { +public: + jit_avx512_core_gemm_s8u8s32_kern(bool beta_zero_, bool enable_offset_c_, + bool enable_offset_r_); + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_s8u8s32_kern); + +protected: + bool beta_zero; + bool enable_offset_c, enable_offset_r; + bool vnni; + + void prefetch_a(const Xbyak::Address &src) { + prefetcht0(src); + } + void prefetch_b(const Xbyak::Address &src) { + prefetcht0(src); + } + void prefetch_c(const Xbyak::Address &src) { + prefetchw(src); + } + void prefetch_x(const Xbyak::Address &src) { + prefetcht0(src); + } + + void c_load(const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems); + void c_store(const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems); + + void dot_product(const Xbyak::Xmm &dst, const Xbyak::Xmm &src1, + const Xbyak::Xmm &src2); + void kernel_loop(int unroll_m, int unroll_n, bool cfetch); + void remainder_kernel(int unroll_m, int unroll_n, int unroll_k, int bwidth); + void innerloop(int unroll_m, int unroll_n); + void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label); + + void generate(); + + +private: + static const int IGEMM_UNROLL_M = 48; + static const int IGEMM_UNROLL_N = 8; + + static const int isize = 2; + static const int size = 4; + + // Prefetch configuration + static const int prefetch_size_a = 32 * 5; + static const int prefetch_size_b = 32 * 4; + + static const int offset_a = 256, offset_b = 256; + static const int max_unroll_m = 48, max_unroll_n = 8; + + // Integer register assignments + Xbyak::Reg64 M, N, K, A, B, C, LDC, I, J, LoopCount; + Xbyak::Reg64 AO, BO, CO1, CO2, AA; + + // Vector register assignments + Xbyak::Zmm dp_scratch, ones, a_regs[max_unroll_m >> 4], b_regs[2]; + Xbyak::Zmm c_regs[max_unroll_m >> 4][max_unroll_n]; + + // Stack variable assignments + int stack_alloc_size; + Xbyak::Address arg_a, arg_b, arg_c, arg_ldc, arg_coffset_c, arg_coffset_r; + Xbyak::Address coffset_cx, coffset_cy, coffset_rx, coffset_ry; + + void L_aligned(Xbyak::Label &label, int alignment = 16) { + align(alignment); + L(label); + } +}; + +} +} +} + +#endif /* header guard */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp new file mode 100644 index 0000000000..4f0b10dadd --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp @@ -0,0 +1,290 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "gemv.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg) { + + blas_t arg_gemv = *arg; + + if ((arg -> offsetc == FIX_OFFSET) && // Fix offset + (arg -> ao == 0) && + (arg -> bo == 0) && + (arg -> co[0] == 0) && + (*(arg -> alpha) == 1.0f) && + ((*(arg -> beta) == 1.0f) || *(arg -> beta) == 0.0f)) { + + if (arg -> n == 1) { + + if (arg -> transa == 1) { // A transpose + arg_gemv.n = arg -> k; + arg_gemv.ldc = 1; + arg_gemv.swap = 0; + if (arg -> transb == 0) { // B non transpose + arg_gemv.ldb = 1; + } + // B transpose arg_gemv.ldb = arg -> ldb + gemv_threading_driver(&arg_gemv); + return 1; + } + } + + if (arg -> m == 1) { + + if (arg -> transb == 0) { // B non transpose + arg_gemv.transa = 1; + arg_gemv.m = arg -> n; + arg_gemv.n = arg -> k; + arg_gemv.a = (int8_t *) arg -> b; + arg_gemv.lda = arg -> ldb; + arg_gemv.b = (uint8_t *) arg -> a; + arg_gemv.swap = 1; + if (arg -> transa == 0) { // A non transpose + arg_gemv.ldb = arg -> lda; + } + else { // A transpose + arg_gemv.ldb = 1; + } + gemv_threading_driver(&arg_gemv); + return 1; + } + } + } + + return 0; +} + + +int gemv_kernel_driver(blas_t *arg) { + + dim_t m = arg -> m; + dim_t n = arg -> n; + uint8_t *a = (uint8_t *) arg -> a; + dim_t lda = arg -> lda; + int8_t *b = (int8_t *) arg -> b; + float beta = *(arg -> beta); + + if (arg -> swap) { + arg -> gemv_u8s8s32_kernel(m, n, 1.0f, a, lda, b, beta, arg -> c); + } + else { + arg -> gemv_s8u8s32_kernel(arg -> m, arg -> n, 1.0f, arg -> a, + arg -> lda, arg -> b, *(arg -> beta), arg -> c); + } + + return 0; +} + +int gemv_threading_driver(blas_t *arg) { + + dim_t nthr_m, nthr_n = 1; + dim_t MB, NB, UM = 16, UN = 64; + dim_t BLOCKM = 192, BLOCKN = 3072; + int status; + dim_t i; + + dim_t nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads(); + + uint8_t *new_x = NULL; + int32_t *tmp_y = NULL, *new_y = NULL; + + dim_t m = arg -> m, n = arg -> n; + + blas_t arg_seq = *arg; + float zero = 0.0f; + + nthr_m = std::min(std::max(m / BLOCKM, (dim_t) 1), nthr); + MB = m / nthr_m; + MB = (((MB / UM) * UM) == MB) ? MB : (MB / UM) * UM + UM; + nthr_m = (((m / MB) * MB) == m) ? m / MB : m / MB + 1; + nthr_m = std::min(std::max(nthr_m, (dim_t) 1), nthr); + + while ((nthr_m * (nthr_n + 1) <= nthr) && ((n / (nthr_n + 1)) >= BLOCKN)) { + nthr_n++; + } + + NB = n / nthr_n; + NB = (((NB / UN) * UN) == NB) ? NB : (NB / UN) * UN + UN; + nthr_n = (((n / NB) * NB) == n) ? n / NB : n / NB + 1; + nthr_n = std::min(std::max(nthr_n, (dim_t) 1), nthr / nthr_m); + + nthr = nthr_m * nthr_n; + + if (arg -> ldb != 1) { + new_x = (uint8_t *)malloc(n, 64); + if (new_x == NULL) + return 1; + for (i = 0; i < n; i++) { + new_x[i] = (arg -> b)[i * arg -> ldb]; + } + arg_seq.b = new_x; + arg_seq.ldb = 1; + } + else new_x = (uint8_t *) arg -> b; + + if (arg -> ldc != 1) { + new_y = (int32_t *) malloc(nthr_m * PADD_BYTESIZE_ONPAGE(MB, sizeof(int32_t)), 64); + if (new_y == NULL) { + if (arg -> ldb != 1) { + free(new_x); + } + return 1; + } + } + + // GEMV computation + if (nthr == 1) { + + if (arg -> ldc != 1) { + if (*(arg -> beta) != 0.0f) { + for (i = 0; i < m; i++) { + new_y[i] = arg -> c[i * arg -> ldc]; + } + } + } + + status = gemv_kernel_driver(&arg_seq); + + if (arg -> ldc != 1) { + for (i = 0; i < m; i++) { + arg -> c[i * arg -> ldc] = new_y[i]; + } + } + + if (arg -> ldb != 1) { + free(new_x); + } + if (arg -> ldc != 1) { + free(new_y); + } + return status; + } + + if (nthr_n > 1) { + tmp_y = (int32_t *) malloc((nthr_n - 1) * PADD_BYTESIZE_ONPAGE(m, sizeof(int32_t)), PAGESIZE); + if (tmp_y == NULL) { + if (arg -> ldb != 1) { + free(new_x); + } + return 1; + } + } + + parallel_nd((int) nthr, [&](const dim_t ithr) { + + dim_t m_from, m_to, myM; + dim_t n_from, n_to, myN; + + dim_t n_id, m_id; + dim_t loc_incy = 1; + int32_t *loc_y; + + blas_t arg_loc = arg_seq; + int j; + + m_id = ithr / nthr_n; + n_id = ithr % nthr_n; + + m_from = MB * m_id; + m_to = MB * (m_id + 1); + if ((m_to > m) || (m_id == nthr_m - 1)) + m_to = m; + + myM = m_to - m_from; + + n_from = NB * n_id; + n_to = NB * (n_id + 1); + if ((n_to > n) || (n_id == nthr_n - 1)) + n_to = n; + + myN = n_to - n_from; + + if (n_id != 0) { + arg_loc.beta = &zero; + loc_y = tmp_y + (NEXT_THR_STRIDE(m, sizeof(int32_t))) * (n_id - 1) + m_from; + } + else { + if (arg -> ldc == 1) { + loc_y = arg_seq.c + m_from; + } + else { + // need to copy the block of c in new_y + loc_y = new_y + m_id * NEXT_THR_STRIDE(MB, sizeof(int32_t)); + if (*(arg -> beta) != 0.0f) { + for (j = 0; j < myM; j++) { + loc_y[j] = arg -> c[(m_from + j) * arg -> ldc]; + } + } + } + } + + arg_loc.m = myM; + arg_loc.n = myN; + arg_loc.a = arg_seq.a + m_from * arg_seq.lda + n_from; + arg_loc.b = arg_seq.b + n_from; + arg_loc.c = loc_y; + arg_loc.ldc = loc_incy; + + gemv_kernel_driver(&arg_loc); + + if ((n_id == 0) && (arg -> ldc != 1)) { + for (j = 0; j < myM; j++) { + arg -> c[(m_from + j) * arg -> ldc] = loc_y[j]; + } + } + + }); + + if (nthr_n > 1) { + parallel_nd((int) nthr_m, [&](const dim_t ithr) { + + dim_t j, j_from, j_to, ii; + int32_t acc; + + j_from = MB * ithr; + j_to = MB * (ithr + 1); + if ((j_to > m) || (ithr == nthr - 1)) + j_to = m; + + for (j = j_from; j < j_to; j++) { + acc = 0; + for (ii = 0; ii < nthr_n - 1; ii++) { + acc += tmp_y[ii * NEXT_THR_STRIDE(m, sizeof(int32_t)) + j]; + } + (arg -> c)[j * arg -> ldc] += acc; + } + }); + free(tmp_y); + } + + if (arg -> ldb != 1) { + free(new_x); + } + + if (arg -> ldc != 1) { + free(new_y); + } + + return 0; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp new file mode 100644 index 0000000000..c57a8c1d12 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp @@ -0,0 +1,411 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp" + +#ifdef _WIN32 +#define is_windows 1 +#else +#define is_windows 0 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +void jit_avx512_core_gemv_s8u8s32_kern::vnni(Xbyak::Zmm acc, Xbyak::Zmm b, + Xbyak::Zmm a, Xbyak::Zmm tmp, + Xbyak::Zmm one, bool swap, + int use_vnni) { + + if (use_vnni) { + if (swap) + vpdpbusd(acc, a, b); + else + vpdpbusd(acc, b, a); + } + + else { + if (swap) + vpmaddubsw(tmp, a, b); + else + vpmaddubsw(tmp, b, a); + vpmaddwd(tmp, tmp, one); + vpaddd(acc, tmp, acc); + } + +} + +void jit_avx512_core_gemv_s8u8s32_kern::n_loop_body(int start_a_idx, int start_acc_idx, + int b_idx, int nreg_acc, + Xbyak::Reg64 A, Xbyak::Reg64 lda, + Xbyak::Reg64 X, Xbyak::Zmm tmp, + Xbyak::Zmm one, bool swap, int use_vnni, + int use_mask, Xbyak::Opmask mask_n) { + + int i; + int nreg_A = nreg_acc / 2 + (nreg_acc % 2); + + // load X + j + if (use_mask) + vmovdqu8(Xbyak::Zmm(b_idx) | mask_n | T_z, ptr[X]); + else + vmovdqu8(Xbyak::Zmm(b_idx), ptr[X]); + + xor_(r14, r14); + // load values of A + for (i = 0; i < nreg_A; i++) { + if (use_mask) + vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]); + else + vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]); + add(r14, lda); + } + + for (i = 0; i < nreg_A; i++) { + // vnni (acc, b, a, tmp, one, swap, use_vnni) + vnni(Xbyak::Zmm(start_acc_idx + i), Xbyak::Zmm(b_idx), + Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni); + } + + for (i = 0; i < nreg_A - (nreg_acc % 2); i++) { + if (use_mask) + vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]); + else + vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]); + add(r14, lda); + } + + for (i = 0; i < nreg_A - (nreg_acc % 2); i++) { + vnni(Xbyak::Zmm(start_acc_idx + i + nreg_A), Xbyak::Zmm(b_idx), + Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni); + } + +} + +void jit_avx512_core_gemv_s8u8s32_kern::shuffle_and_add(Xbyak::Zmm dest, Xbyak::Zmm A, + Xbyak::Zmm B, Xbyak::Zmm C, + Xbyak::Zmm D) { + + vshufi32x4(dest, A, C, 0x44); + vshufi32x4(A, A, C, 0xEE); + vpaddd(C, dest, A); // C = A0 + A2|A1 + A3|C0 + C2|C1 + C3 + + vshufi32x4(dest, B, D, 0x44); + vshufi32x4(B, B, D, 0xEE); + vpaddd(D, dest, B); // D = B0 + B2|B1 + B3|D0 + D2|D1 + D3 + + vshufi32x4(A, C, D, 0x88); + vshufi32x4(B, C, D, 0xDD); + vpaddd(dest, A, B); // dest = SAi|SBi|SCi|SDi + +} + +void jit_avx512_core_gemv_s8u8s32_kern::update_c(int nreg_acc, Xbyak::Reg64 Y, + int start_a_idx, int start_acc_idx, + Xbyak::Xmm beta, int use_mask, + Xbyak::Opmask mask_m) { + + int l, i, k, j, last_it; + Xbyak::Label store_label; + + l = 0; + for (k = 0; k < nreg_acc; k += 8) { + for (i = 0, j = k; i < 8; i += 4, j += 2) { + if (j < nreg_acc) { + // shuffle per block of 4 registers + shuffle_and_add(Xbyak::Zmm(start_a_idx + l), // dest + Xbyak::Zmm(start_acc_idx + j), // A = acc0 + Xbyak::Zmm(start_acc_idx + 1 + j), // B = acc1 + Xbyak::Zmm(start_acc_idx + 4 + j), // C = acc4 + Xbyak::Zmm(start_acc_idx + 5 + j)); // D = acc5 + + // extract low and high from dest and hadd + vextracti32x8(Xbyak::Ymm(start_a_idx + l + 1), Xbyak::Zmm(start_a_idx + l), 0); + vextracti32x8(Xbyak::Ymm(start_a_idx + l + 2), Xbyak::Zmm(start_a_idx + l), 1); + vphaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l + 1), + Xbyak::Ymm(start_a_idx + l + 2)); + } + l++; + } + + vphaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l - 2), + Xbyak::Ymm(start_a_idx + l - 1)); + + l++; + } + + // eventually add with C and store new value + vxorps(Xbyak::Ymm(start_a_idx), + Xbyak::Ymm(start_a_idx), + Xbyak::Ymm(start_a_idx)); + vucomiss(beta, Xbyak::Ymm(start_a_idx)); + je(store_label, T_NEAR); + + // beta = 1 + for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) { + // load Y and add + last_it = (k + 8) > nreg_acc; + if (use_mask && last_it) + vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8) | mask_m | T_z, ptr[Y + (k / 8) * 32]); + else + vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8), ptr[Y + (k / 8) * 32]); + + vpaddd(Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + l), + Xbyak::Ymm(start_a_idx + k / 8)); + } + + // store + aligned_label(store_label); + for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) { + last_it = (k + 8) > nreg_acc; + if (use_mask && last_it) + vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l) | mask_m); + else + vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l)); + } + +} + +template <typename T> +T jit_avx512_core_gemv_s8u8s32_kern::generate(int use_vnni) { + + Xbyak::Opmask mask_n = k1, mask_m = k2; + Xbyak::Label one_label, m_tail_label, m_loop_label, n_loop_label; + Xbyak::Label n_tail_label, update_c_label, end_label; + constexpr unsigned int n_labels = (1 << unroll_m) - 1; + Xbyak::Label m_tail_label_case[n_labels]; + Xbyak::Label n_loop_label_case[n_labels]; + Xbyak::Label n_tail_label_case[n_labels]; + Xbyak::Label update_c_label_case[n_labels]; + + int i, ii; + + Xbyak::Zmm one, tmp; + Xbyak::Reg64 n = abi_param2, m = abi_param1; + Xbyak::Reg64 A = is_windows ? abi_param4 : abi_param3; + Xbyak::Reg64 lda = is_windows ? abi_param3 : abi_param4; + Xbyak::Reg64 X = is_windows ? rdi : r8; + Xbyak::Xmm beta = xmm1; + Xbyak::Reg64 Y = is_windows ? rsi : r9; + + bool swap = !std::is_same<T, gemv_s8u8s32_kernel_t>::value; + + // Windows: read on the stack lda, X, beta, Y + + int zmm_idx = 1; + int nreg_acc = 1 << unroll_m; + int nreg_A = 1 << (unroll_m - 1); + int nreg_A_acc = nreg_acc + nreg_A; + + if (!use_vnni) { + // set a zmm register to one + tmp = Xbyak::Zmm(0); + one = Xbyak::Zmm(zmm_idx + 1); + zmm_idx += 2; // one + tmp + } + else { + beta = xmm0; + } + + preamble(); + + if (is_windows) { + mov(lda, ptr[rsp + get_size_of_abi_save_regs() + 40]); + mov(X, ptr[rsp + get_size_of_abi_save_regs() + 48]); + movss(beta, ptr[rsp + get_size_of_abi_save_regs() + 56]); + mov(Y, ptr[rsp + get_size_of_abi_save_regs() + 64]); + } + + if (use_vnni && !is_windows) { + movaps(beta, xmm1); + } + + mov(rax, (1 << unroll_n) - 1); + kmovq(k3, rax); + + and_(rax, n); // rax contains n & ((1 << unroll_n) - 1) + mov(rbx, 1); + shlx(rbx, rbx, rax); + sub(rbx, 1); + kmovq(mask_n, rbx); + // mask_n set (AVX512 only), can use rax and rbx again + + // set mask_m for update of the C matrix + // load/store on the C matrix use Ymm so tail according to Ymm size + mov(rax, 7); // 8 * 32 = 256 Ymm size + and_(rax, m); // rax contains m & 7 + mov(rbx, 1); + shlx(rbx, rbx, rax); + sub(rbx, 1); + kmovq(mask_m, rbx); + // mask_m set (AVX512 only), can use rax and rbx again + + // setup register of ones when VNNI instructions not available + if (!use_vnni) { + vmovdqu16(one, ptr[rip + one_label]); + } + + // M loop + // base pointer for A rax contains a + i * lda + // Loop stop when rax >= a + (m & mask_um) * lda = rbx + // loop increment r10 = um * lda + // rbp = Y + i + mov(rax, A); // i = 0 + mov(rbx, m); + and_(rbx, mask_um); + imul(rbx, lda); + add(rbx, A); + mov(r10, lda); + sal(r10, unroll_m); + mov(rbp, Y); + + // N loop + // base pointer for X r11 contains x + j + // Loop stop when r11 >= x + n & mask_un = r12 + // loop increment un + // r13 = rax + j = A + i * lda + j + mov(r12, n); + and_(r12, mask_un); + add(r12, X); + + // M loop + aligned_label(m_loop_label); + cmp(rax, rbx); + jge(m_tail_label, T_NEAR); + + // enter M loop + for(i = 0; i < nreg_acc; i++) { + vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A)); + } + + // N loop + mov(r11, X); // j = 0 + mov(r13, rax); + aligned_label(n_loop_label); + cmp(r11, r12); + jge(n_tail_label, T_NEAR); + + // enter N loop + + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc, + r13, lda, r11, tmp, one, swap, use_vnni, 0, mask_n); + + // increment rax with un + add(r11, 1 << unroll_n); + add(r13, 1 << unroll_n); + jmp(n_loop_label, T_NEAR); + // end N loop + + // N tail + aligned_label(n_tail_label); + + ktestq(mask_n, k3); + je(update_c_label, T_NEAR); + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc, + r13, lda, r11, tmp, one, swap, use_vnni, 1, mask_n); + + // update C matrix + aligned_label(update_c_label); + + update_c(nreg_acc, rbp, zmm_idx, zmm_idx + nreg_A, beta, 0, mask_m); + + // increment rax with um * lda + add(rax, r10); + add(rbp, 1 << (unroll_m + 2)); + jmp(m_loop_label, T_NEAR); + // end M loop + + // M tail + aligned_label(m_tail_label); + + // r10 will contain m_tail = m % unroll_m = m & (1 << unroll_m) - 1 + mov(r10, m); + and_(r10, (1 << unroll_m) - 1); + for (ii = 1; ii < 1 << unroll_m; ii++) { + aligned_label(m_tail_label_case[ii-1]); + cmp(r10, ii); + if (ii == (1 << unroll_m) - 1) + jne(end_label, T_NEAR); + else + jne(m_tail_label_case[ii], T_NEAR); + + // m_tail = i, use i accumulators + + for(i = 0; i < ii; i++) { + vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A), + Xbyak::Zmm(i + zmm_idx + nreg_A)); + } + + // N loop + mov(r11, X); // j = 0 + mov(r13, rax); + aligned_label(n_loop_label_case[ii - 1]); + cmp(r11, r12); + jge(n_tail_label_case[ii - 1], T_NEAR); + + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13, + lda, r11, tmp, one, swap, use_vnni, 0, mask_n); + + // increment rax with un + add(r11, 1 << unroll_n); + add(r13, 1 << unroll_n); + jmp(n_loop_label_case[ii - 1], T_NEAR); + // end N loop + + // N tail + aligned_label(n_tail_label_case[ii - 1]); + ktestq(mask_n, k3); + je(update_c_label_case[ii - 1], T_NEAR); + n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13, + lda, r11, tmp, one, swap, use_vnni, 1, mask_n); + + // update C matrix + aligned_label(update_c_label_case[ii - 1]); + update_c(ii, rbp, zmm_idx, zmm_idx + nreg_A, beta, 1, mask_m); + + if (ii < ((1 << unroll_m) - 1)) + jmp(end_label, T_NEAR); + } + + aligned_label(end_label); + + postamble(); + + if (!use_vnni) { + aligned_label(one_label); + for (i = 0; i < size_vec_reg/8; i++) + dq(0x0001000100010001); + } + + return (T) getCode(); +} + +template jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t +jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>(int); + +template jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t +jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>(int); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp new file mode 100644 index 0000000000..9ea23a5f56 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp @@ -0,0 +1,64 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +class jit_avx512_core_gemv_s8u8s32_kern : jit_generator { + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_s8u8s32_kern); + + // assumes untoll_{m,n} are a power of 2 + static constexpr unsigned int unroll_m = 4; // real unrolling factor is 2^unroll_m + const int mask_um = 0xFFFFFFF0; + static constexpr unsigned int unroll_n = 6; // real unrolling factor is 2^unroll_n + const int mask_un = 0xFFFFFFC0; + const int size_vec_reg = 64; // bytes + + void aligned_label(Xbyak::Label &label, int alignment = 16) { + align(alignment); + L(label); + } + + void vnni(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, bool, int); + void n_loop_body(int, int, int, int, Xbyak::Reg64, Xbyak::Reg64, + Xbyak::Reg64, Xbyak::Zmm, Xbyak::Zmm, bool, int, int, Xbyak::Opmask); + void shuffle_and_add(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm); + void update_c(int, Xbyak::Reg64, int, int, Xbyak::Xmm, int, Xbyak::Opmask); + +public: + jit_avx512_core_gemv_s8u8s32_kern() : jit_generator(nullptr, GEMM_CODE_SIZE) {}; + + // m, n, alpha, a, lda, x, beta, y + typedef void (*gemv_s8u8s32_kernel_t)(const dim_t, const dim_t, const float, + const int8_t*, const dim_t, const uint8_t*, + const float, int32_t*); + typedef void (*gemv_u8s8s32_kernel_t)(const dim_t, const dim_t, const float, + const uint8_t*, const dim_t, const int8_t*, + const float, int32_t*); + + template <typename T> + T generate(int use_vnni); + +}; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp new file mode 100644 index 0000000000..544cd2ff25 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp @@ -0,0 +1,819 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_an_kern::jit_avx512_core_u8_copy_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l170; +Xbyak::Label l1f0; +Xbyak::Label l20; +Xbyak::Label l224; +Xbyak::Label l234; +Xbyak::Label l240; +Xbyak::Label l254; +Xbyak::Label l32c; +Xbyak::Label l34; +Xbyak::Label l388; +Xbyak::Label l3b0; +Xbyak::Label l3c0; +Xbyak::Label l3cc; +Xbyak::Label l3dc; +Xbyak::Label l454; +Xbyak::Label l48c; +Xbyak::Label l4a8; +Xbyak::Label l4b8; +Xbyak::Label l4c4; +Xbyak::Label l4d8; +Xbyak::Label l570; +Xbyak::Label l5c4; +Xbyak::Label l5f0; +Xbyak::Label l60c; +Xbyak::Label l61c; +Xbyak::Label l628; +Xbyak::Label l638; +Xbyak::Label l6b0; +Xbyak::Label l6f4; +Xbyak::Label l720; +Xbyak::Label l73c; +Xbyak::Label l74c; +Xbyak::Label l758; +Xbyak::Label l76c; +Xbyak::Label l804; +Xbyak::Label l858; +Xbyak::Label l88c; +Xbyak::Label l8a4; +Xbyak::Label l8b2; +Xbyak::Label l8bc; +Xbyak::Label l8cc; +Xbyak::Label l944; +Xbyak::Label l98c; +Xbyak::Label l9b0; +Xbyak::Label l9c8; +Xbyak::Label l9d8; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x30); + jl(l234, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x30); + mov(I, M); + sar(I, 0x2); + jle(l170, T_NEAR); + align(4); + +L(l34); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B-0x20], xmm4); + movdqu(xword[B-0x10], xmm2); + movdqu(xmm0, xword[A1-0x60]); + movdqu(xmm1, xword[A1+LDA*1-0x60]); + movdqu(xmm2, xword[A1+LDA*2-0x60]); + movdqu(xmm3, xword[A1+LDA3*1-0x60]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B], xmm0); + movdqu(xword[B+0x10], xmm1); + movdqu(xword[B+0x20], xmm4); + movdqu(xword[B+0x30], xmm2); + sub(B, -192); + dec(I); + jg(l34, T_NEAR); + align(4); + +L(l170); + test(M, 0x2); + jle(l1f0, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + movdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + movdqu(xmm4, xword[A1-0x70]); + movdqu(xmm5, xword[A1-0x60]); + add(A1, LDA); + movdqa(xmm6, xmm0); + punpcklbw(xmm0, xmm3); + punpckhbw(xmm6, xmm3); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm6); + movdqa(xmm6, xmm1); + punpcklbw(xmm1, xmm4); + punpckhbw(xmm6, xmm4); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x50], xmm6); + movdqa(xmm6, xmm2); + punpcklbw(xmm2, xmm5); + punpckhbw(xmm6, xmm5); + movdqu(xword[B-0x40], xmm2); + movdqu(xword[B-0x30], xmm6); + sub(B, -96); + align(4); + +L(l1f0); + test(M, 0x1); + jle(l224, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + movdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm2); + sub(B, -48); + align(4); + +L(l224); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + align(4); + +L(l234); + cmp(N, 0x20); + jl(l3c0, T_NEAR); + align(4); + +L(l240); + mov(A1, A); + add(A, 0x20); + mov(I, M); + sar(I, 0x2); + jle(l32c, T_NEAR); + align(4); + +L(l254); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B-0x20], xmm4); + movdqu(xword[B-0x10], xmm2); + sub(B, -128); + dec(I); + jg(l254, T_NEAR); + align(4); + +L(l32c); + test(M, 0x2); + jle(l388, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + movdqu(xmm3, xword[A1-0x70]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm2); + punpckhbw(xmm4, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm4); + movdqa(xmm4, xmm1); + punpcklbw(xmm1, xmm3); + punpckhbw(xmm4, xmm3); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x50], xmm4); + sub(B, -64); + align(4); + +L(l388); + test(M, 0x1); + jle(l3b0, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3b0); + sub(N, 0x20); + cmp(N, 0x20); + jge(l240, T_NEAR); + align(4); + +L(l3c0); + cmp(N, 0x10); + jl(l4b8, T_NEAR); + align(4); + +L(l3cc); + mov(A1, A); + add(A, 0x10); + mov(I, M); + sar(I, 0x2); + jle(l454, T_NEAR); + align(4); + +L(l3dc); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm1, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm1, xmm3); + movdqa(xmm3, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm3, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm1); + punpckhwd(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm3); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + sub(B, -64); + dec(I); + jg(l3dc, T_NEAR); + align(4); + +L(l454); + test(M, 0x2); + jle(l48c, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm2, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + align(4); + +L(l48c); + test(M, 0x1); + jle(l4a8, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l4a8); + sub(N, 0x10); + cmp(N, 0x10); + jge(l3cc, T_NEAR); + align(4); + +L(l4b8); + cmp(N, 0x8); + jl(l61c, T_NEAR); + align(4); + +L(l4c4); + mov(A1, A); + add(A, 0x8); + mov(I, M); + sar(I, 0x3); + jle(l570, T_NEAR); + align(4); + +L(l4d8); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l4d8, T_NEAR); + align(4); + +L(l570); + test(M, 0x4); + jle(l5c4, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l5c4); + test(M, 0x2); + jle(l5f0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l5f0); + test(M, 0x1); + jle(l60c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l60c); + sub(N, 0x8); + cmp(N, 0x8); + jge(l4c4, T_NEAR); + align(4); + +L(l61c); + cmp(N, 0x4); + jl(l74c, T_NEAR); + align(4); + +L(l628); + mov(A1, A); + add(A, 0x4); + mov(I, M); + sar(I, 0x3); + jle(l6b0, T_NEAR); + align(4); + +L(l638); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l638, T_NEAR); + align(4); + +L(l6b0); + test(M, 0x4); + jle(l6f4, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l6f4); + test(M, 0x2); + jle(l720, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l720); + test(M, 0x1); + jle(l73c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l73c); + sub(N, 0x4); + cmp(N, 0x4); + jge(l628, T_NEAR); + align(4); + +L(l74c); + cmp(N, 0x2); + jl(l8b2, T_NEAR); + align(4); + +L(l758); + mov(A1, A); + add(A, 0x2); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l804, T_NEAR); + align(4); + +L(l76c); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l76c, T_NEAR); + align(4); + +L(l804); + test(M, 0x4); + jle(l858, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l858); + test(M, 0x2); + jle(l88c, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l88c); + test(M, 0x1); + jle(l8a4, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l8a4); + sub(N, 0x2); + cmp(N, 0x2); + jge(l758, T_NEAR); + align(4); + +L(l8b2); + cmp(N, 0x1); + jl(l9d8, T_NEAR); + align(4); + +L(l8bc); + mov(A1, A); + add(A, 0x1); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l944, T_NEAR); + align(4); + +L(l8cc); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l8cc, T_NEAR); + align(4); + +L(l944); + test(M, 0x4); + jle(l98c, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l98c); + test(M, 0x2); + jle(l9b0, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l9b0); + test(M, 0x1); + jle(l9c8, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l9c8); + sub(N, 0x1); + cmp(N, 0x1); + jge(l8bc, T_NEAR); + align(4); + +L(l9d8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp new file mode 100644 index 0000000000..1c11fc6cef --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp @@ -0,0 +1,2209 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_at_kern::jit_avx512_core_u8_copy_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1014; +Xbyak::Label l1390; +Xbyak::Label l159c; +Xbyak::Label l173c; +Xbyak::Label l18e4; +Xbyak::Label l1a7c; +Xbyak::Label l1a8c; +Xbyak::Label l1a98; +Xbyak::Label l1ab4; +Xbyak::Label l1c64; +Xbyak::Label l1d74; +Xbyak::Label l1e50; +Xbyak::Label l1f2c; +Xbyak::Label l1ffc; +Xbyak::Label l20; +Xbyak::Label l200c; +Xbyak::Label l2018; +Xbyak::Label l2034; +Xbyak::Label l2110; +Xbyak::Label l21a0; +Xbyak::Label l2210; +Xbyak::Label l2284; +Xbyak::Label l22f0; +Xbyak::Label l2300; +Xbyak::Label l230c; +Xbyak::Label l2324; +Xbyak::Label l2398; +Xbyak::Label l23e8; +Xbyak::Label l242c; +Xbyak::Label l2474; +Xbyak::Label l24b4; +Xbyak::Label l24c4; +Xbyak::Label l24d0; +Xbyak::Label l24e8; +Xbyak::Label l2520; +Xbyak::Label l254c; +Xbyak::Label l2578; +Xbyak::Label l25a8; +Xbyak::Label l25c8; +Xbyak::Label l25d6; +Xbyak::Label l25e0; +Xbyak::Label l25f0; +Xbyak::Label l260c; +Xbyak::Label l262c; +Xbyak::Label l264c; +Xbyak::Label l2668; +Xbyak::Label l2680; +Xbyak::Label l2690; +Xbyak::Label l44; +Xbyak::Label l58c; +Xbyak::Label l8b0; +Xbyak::Label lb14; +Xbyak::Label ld84; +Xbyak::Label lfdc; +Xbyak::Label lfec; +Xbyak::Label lff8; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x30); + jl(lfec, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + lea(I, ptr[I+LDA*8]); + lea(I, ptr[I+LDA*8]); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l58c, T_NEAR); + align(4); + +L(l44); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B+0x40], xmm1); + movdqu(xword[B+0x100], xmm4); + movdqu(xword[B+0x1c0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x50], xmm1); + movdqu(xword[B+0x110], xmm4); + movdqu(xword[B+0x1d0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x60], xmm1); + movdqu(xword[B+0x120], xmm4); + movdqu(xword[B+0x1e0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x70], xmm1); + movdqu(xword[B+0x130], xmm4); + movdqu(xword[B+0x1f0], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x80], xmm1); + movdqu(xword[B+0x140], xmm4); + movdqu(xword[B+0x200], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x90], xmm1); + movdqu(xword[B+0x150], xmm4); + movdqu(xword[B+0x210], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0xa0], xmm1); + movdqu(xword[B+0x160], xmm4); + movdqu(xword[B+0x220], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0xb0], xmm1); + movdqu(xword[B+0x170], xmm4); + movdqu(xword[B+0x230], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B], xmm0); + movdqu(xword[B+0xc0], xmm1); + movdqu(xword[B+0x180], xmm4); + movdqu(xword[B+0x240], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x10], xmm0); + movdqu(xword[B+0xd0], xmm1); + movdqu(xword[B+0x190], xmm4); + movdqu(xword[B+0x250], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x20], xmm0); + movdqu(xword[B+0xe0], xmm1); + movdqu(xword[B+0x1a0], xmm4); + movdqu(xword[B+0x260], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B+0x30], xmm0); + movdqu(xword[B+0xf0], xmm1); + movdqu(xword[B+0x1b0], xmm4); + movdqu(xword[B+0x270], xmm3); + sub(A1, -16); + sub(B, -768); + dec(I); + jg(l44, T_NEAR); + align(4); + +L(l58c); + test(M, 0x8); + jle(l8b0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x70], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x80], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x90], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0xa0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0xb0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B], xmm0); + movdqu(xword[B+0xc0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x10], xmm0); + movdqu(xword[B+0xd0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x20], xmm0); + movdqu(xword[B+0xe0], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B+0x30], xmm0); + movdqu(xword[B+0xf0], xmm1); + sub(A1, -8); + sub(B, -384); + align(4); + +L(l8b0); + test(M, 0x4); + jle(lb14, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x10], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x10], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B+0x30], xmm0); + sub(A1, -4); + sub(B, -192); + align(4); + +L(lb14); + test(M, 0x2); + jle(ld84, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x50], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x40], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x30], xmm0); + sub(A1, -2); + sub(B, -96); + align(4); + +L(ld84); + test(M, 0x1); + jle(lfdc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x70], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x60], xmm0); + sub(B, -48); + align(4); + +L(lfdc); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + align(4); + +L(lfec); + cmp(N, 0x20); + jl(l1a8c, T_NEAR); + align(4); + +L(lff8); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l1390, T_NEAR); + align(4); + +L(l1014); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B], xmm1); + movdqu(xword[B+0x80], xmm4); + movdqu(xword[B+0x100], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x10], xmm1); + movdqu(xword[B+0x90], xmm4); + movdqu(xword[B+0x110], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x20], xmm1); + movdqu(xword[B+0xa0], xmm4); + movdqu(xword[B+0x120], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x30], xmm1); + movdqu(xword[B+0xb0], xmm4); + movdqu(xword[B+0x130], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x40], xmm1); + movdqu(xword[B+0xc0], xmm4); + movdqu(xword[B+0x140], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x50], xmm1); + movdqu(xword[B+0xd0], xmm4); + movdqu(xword[B+0x150], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0x60], xmm1); + movdqu(xword[B+0xe0], xmm4); + movdqu(xword[B+0x160], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0x70], xmm1); + movdqu(xword[B+0xf0], xmm4); + movdqu(xword[B+0x170], xmm3); + sub(A1, -16); + sub(B, -512); + dec(I); + jg(l1014, T_NEAR); + align(4); + +L(l1390); + test(M, 0x8); + jle(l159c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B+0x10], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B+0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B+0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x40], xmm0); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x30], xmm0); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x20], xmm0); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x10], xmm0); + movdqu(xword[B+0x70], xmm1); + sub(A1, -8); + sub(B, -256); + align(4); + +L(l159c); + test(M, 0x4); + jle(l173c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x10], xmm0); + sub(A1, -4); + sub(B, -128); + align(4); + +L(l173c); + test(M, 0x2); + jle(l18e4, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + movdqu(xword[B-0x50], xmm0); + sub(A1, -2); + sub(B, -64); + align(4); + +L(l18e4); + test(M, 0x1); + jle(l1a7c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l1a7c); + sub(N, 0x20); + cmp(N, 0x20); + jge(lff8, T_NEAR); + align(4); + +L(l1a8c); + cmp(N, 0x10); + jl(l200c, T_NEAR); + align(4); + +L(l1a98); + mov(A1, A); + mov(I, LDA); + shl(I, 0x4); + add(A, I); + mov(I, M); + sar(I, 0x4); + jle(l1c64, T_NEAR); + align(4); + +L(l1ab4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x40], xmm1); + movdqu(xword[B], xmm4); + movdqu(xword[B+0x40], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x30], xmm1); + movdqu(xword[B+0x10], xmm4); + movdqu(xword[B+0x50], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x20], xmm1); + movdqu(xword[B+0x20], xmm4); + movdqu(xword[B+0x60], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B-0x10], xmm1); + movdqu(xword[B+0x30], xmm4); + movdqu(xword[B+0x70], xmm3); + sub(A1, -16); + sub(B, -256); + dec(I); + jg(l1ab4, T_NEAR); + align(4); + +L(l1c64); + test(M, 0x8); + jle(l1d74, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x50], xmm0); + movdqu(xword[B-0x10], xmm1); + sub(A1, -8); + sub(B, -128); + align(4); + +L(l1d74); + test(M, 0x4); + jle(l1e50, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x50], xmm0); + sub(A1, -4); + sub(B, -64); + align(4); + +L(l1e50); + test(M, 0x2); + jle(l1f2c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x70], xmm0); + sub(A1, -2); + sub(B, -32); + align(4); + +L(l1f2c); + test(M, 0x1); + jle(l1ffc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0xf); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l1ffc); + sub(N, 0x10); + cmp(N, 0x10); + jge(l1a98, T_NEAR); + align(4); + +L(l200c); + cmp(N, 0x8); + jl(l2300, T_NEAR); + align(4); + +L(l2018); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2110, T_NEAR); + align(4); + +L(l2034); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x40], xmm4); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + movdqu(xword[B-0x30], xmm4); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l2034, T_NEAR); + align(4); + +L(l2110); + test(M, 0x8); + jle(l21a0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l21a0); + test(M, 0x4); + jle(l2210, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l2210); + test(M, 0x2); + jle(l2284, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l2284); + test(M, 0x1); + jle(l22f0, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l22f0); + sub(N, 0x8); + cmp(N, 0x8); + jge(l2018, T_NEAR); + align(4); + +L(l2300); + cmp(N, 0x4); + jl(l24c4, T_NEAR); + align(4); + +L(l230c); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2398, T_NEAR); + align(4); + +L(l2324); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l2324, T_NEAR); + align(4); + +L(l2398); + test(M, 0x8); + jle(l23e8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l23e8); + test(M, 0x4); + jle(l242c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l242c); + test(M, 0x2); + jle(l2474, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2474); + test(M, 0x1); + jle(l24b4, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l24b4); + sub(N, 0x4); + cmp(N, 0x4); + jge(l230c, T_NEAR); + align(4); + +L(l24c4); + cmp(N, 0x2); + jl(l25d6, T_NEAR); + align(4); + +L(l24d0); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l2520, T_NEAR); + align(4); + +L(l24e8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l24e8, T_NEAR); + align(4); + +L(l2520); + test(M, 0x8); + jle(l254c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l254c); + test(M, 0x4); + jle(l2578, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2578); + test(M, 0x2); + jle(l25a8, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l25a8); + test(M, 0x1); + jle(l25c8, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l25c8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l24d0, T_NEAR); + align(4); + +L(l25d6); + cmp(N, 0x1); + jl(l2690, T_NEAR); + align(4); + +L(l25e0); + mov(A1, A); + add(A, LDA); + mov(I, M); + sar(I, 0x4); + jle(l260c, T_NEAR); + align(4); + +L(l25f0); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l25f0, T_NEAR); + align(4); + +L(l260c); + test(M, 0x8); + jle(l262c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l262c); + test(M, 0x4); + jle(l264c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l264c); + test(M, 0x2); + jle(l2668, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l2668); + test(M, 0x1); + jle(l2680, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l2680); + sub(N, 0x1); + cmp(N, 0x1); + jge(l25e0, T_NEAR); + align(4); + +L(l2690); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp new file mode 100644 index 0000000000..56c36ee14a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp @@ -0,0 +1,564 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_bn_kern::jit_avx512_core_u8_copy_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l118; +Xbyak::Label l1a8; +Xbyak::Label l20; +Xbyak::Label l218; +Xbyak::Label l28c; +Xbyak::Label l2f8; +Xbyak::Label l308; +Xbyak::Label l314; +Xbyak::Label l32c; +Xbyak::Label l3a0; +Xbyak::Label l3c; +Xbyak::Label l3f0; +Xbyak::Label l434; +Xbyak::Label l47c; +Xbyak::Label l4bc; +Xbyak::Label l4cc; +Xbyak::Label l4d8; +Xbyak::Label l4f0; +Xbyak::Label l528; +Xbyak::Label l554; +Xbyak::Label l580; +Xbyak::Label l5b0; +Xbyak::Label l5d0; +Xbyak::Label l5de; +Xbyak::Label l5e8; +Xbyak::Label l5f8; +Xbyak::Label l614; +Xbyak::Label l634; +Xbyak::Label l654; +Xbyak::Label l670; +Xbyak::Label l688; +Xbyak::Label l698; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x8); + jl(l308, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l118, T_NEAR); + align(4); + +L(l3c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movdqu(xword[B-0x40], xmm4); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + movdqu(xword[B-0x30], xmm4); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l3c, T_NEAR); + align(4); + +L(l118); + test(M, 0x8); + jle(l1a8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x70], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l1a8); + test(M, 0x4); + jle(l218, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l218); + test(M, 0x2); + jle(l28c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l28c); + test(M, 0x1); + jle(l2f8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l2f8); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l308); + cmp(N, 0x4); + jl(l4cc, T_NEAR); + align(4); + +L(l314); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l3a0, T_NEAR); + align(4); + +L(l32c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l32c, T_NEAR); + align(4); + +L(l3a0); + test(M, 0x8); + jle(l3f0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3f0); + test(M, 0x4); + jle(l434, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l434); + test(M, 0x2); + jle(l47c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l47c); + test(M, 0x1); + jle(l4bc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l4bc); + sub(N, 0x4); + cmp(N, 0x4); + jge(l314, T_NEAR); + align(4); + +L(l4cc); + cmp(N, 0x2); + jl(l5de, T_NEAR); + align(4); + +L(l4d8); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + mov(I, M); + sar(I, 0x4); + jle(l528, T_NEAR); + align(4); + +L(l4f0); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l4f0, T_NEAR); + align(4); + +L(l528); + test(M, 0x8); + jle(l554, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l554); + test(M, 0x4); + jle(l580, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l580); + test(M, 0x2); + jle(l5b0, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l5b0); + test(M, 0x1); + jle(l5d0, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l5d0); + sub(N, 0x2); + cmp(N, 0x2); + jge(l4d8, T_NEAR); + align(4); + +L(l5de); + cmp(N, 0x1); + jl(l698, T_NEAR); + align(4); + +L(l5e8); + mov(A1, A); + add(A, LDA); + mov(I, M); + sar(I, 0x4); + jle(l614, T_NEAR); + align(4); + +L(l5f8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l5f8, T_NEAR); + align(4); + +L(l614); + test(M, 0x8); + jle(l634, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l634); + test(M, 0x4); + jle(l654, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l654); + test(M, 0x2); + jle(l670, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l670); + test(M, 0x1); + jle(l688, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l688); + sub(N, 0x1); + cmp(N, 0x1); + jge(l5e8, T_NEAR); + align(4); + +L(l698); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp new file mode 100644 index 0000000000..53e99d94de --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp @@ -0,0 +1,501 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_bt_kern::jit_avx512_core_u8_copy_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l120; +Xbyak::Label l14c; +Xbyak::Label l168; +Xbyak::Label l178; +Xbyak::Label l184; +Xbyak::Label l194; +Xbyak::Label l20; +Xbyak::Label l20c; +Xbyak::Label l250; +Xbyak::Label l27c; +Xbyak::Label l298; +Xbyak::Label l2a8; +Xbyak::Label l2b4; +Xbyak::Label l2c8; +Xbyak::Label l34; +Xbyak::Label l360; +Xbyak::Label l3b4; +Xbyak::Label l3e8; +Xbyak::Label l400; +Xbyak::Label l40e; +Xbyak::Label l418; +Xbyak::Label l428; +Xbyak::Label l4a0; +Xbyak::Label l4e8; +Xbyak::Label l50c; +Xbyak::Label l524; +Xbyak::Label l534; +Xbyak::Label lcc; + + preamble(); +#ifdef _WIN32 + auto stacksize = get_size_of_abi_save_regs(); + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x8); + jl(l178, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x8); + mov(I, M); + sar(I, 0x3); + jle(lcc, T_NEAR); + align(4); + +L(l34); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l34, T_NEAR); + align(4); + +L(lcc); + test(M, 0x4); + jle(l120, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l120); + test(M, 0x2); + jle(l14c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l14c); + test(M, 0x1); + jle(l168, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l168); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l178); + cmp(N, 0x4); + jl(l2a8, T_NEAR); + align(4); + +L(l184); + mov(A1, A); + add(A, 0x4); + mov(I, M); + sar(I, 0x3); + jle(l20c, T_NEAR); + align(4); + +L(l194); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l194, T_NEAR); + align(4); + +L(l20c); + test(M, 0x4); + jle(l250, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l250); + test(M, 0x2); + jle(l27c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l27c); + test(M, 0x1); + jle(l298, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l298); + sub(N, 0x4); + cmp(N, 0x4); + jge(l184, T_NEAR); + align(4); + +L(l2a8); + cmp(N, 0x2); + jl(l40e, T_NEAR); + align(4); + +L(l2b4); + mov(A1, A); + add(A, 0x2); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l360, T_NEAR); + align(4); + +L(l2c8); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l2c8, T_NEAR); + align(4); + +L(l360); + test(M, 0x4); + jle(l3b4, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3b4); + test(M, 0x2); + jle(l3e8, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3e8); + test(M, 0x1); + jle(l400, T_NEAR); + mov(ax, word[A1-0x80]); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l400); + sub(N, 0x2); + cmp(N, 0x2); + jge(l2b4, T_NEAR); + align(4); + +L(l40e); + cmp(N, 0x1); + jl(l534, T_NEAR); + align(4); + +L(l418); + mov(A1, A); + add(A, 0x1); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l4a0, T_NEAR); + align(4); + +L(l428); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l428, T_NEAR); + align(4); + +L(l4a0); + test(M, 0x4); + jle(l4e8, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l4e8); + test(M, 0x2); + jle(l50c, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l50c); + test(M, 0x1); + jle(l524, T_NEAR); + mov(al, byte[A1-0x80]); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l524); + sub(N, 0x1); + cmp(N, 0x1); + jge(l418, T_NEAR); + align(4); + +L(l534); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp new file mode 100644 index 0000000000..49a312fc88 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp @@ -0,0 +1,1283 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_an_kern::jit_avx512_core_u8_copy_sum_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1024; +Xbyak::Label l1090; +Xbyak::Label l10d4; +Xbyak::Label l10fc; +Xbyak::Label l111a; +Xbyak::Label l1124; +Xbyak::Label l113c; +Xbyak::Label l11d4; +Xbyak::Label l1234; +Xbyak::Label l1278; +Xbyak::Label l129c; +Xbyak::Label l12bc; +Xbyak::Label l20; +Xbyak::Label l2a0; +Xbyak::Label l3c0; +Xbyak::Label l438; +Xbyak::Label l480; +Xbyak::Label l48c; +Xbyak::Label l4c8; +Xbyak::Label l5c; +Xbyak::Label l6a8; +Xbyak::Label l7b4; +Xbyak::Label l850; +Xbyak::Label l89c; +Xbyak::Label l8a8; +Xbyak::Label l8d0; +Xbyak::Label l9d0; +Xbyak::Label la64; +Xbyak::Label lab8; +Xbyak::Label lae8; +Xbyak::Label laf4; +Xbyak::Label lb14; +Xbyak::Label lc30; +Xbyak::Label lcc8; +Xbyak::Label ld1c; +Xbyak::Label ld54; +Xbyak::Label ld78; +Xbyak::Label ld84; +Xbyak::Label ld9c; +Xbyak::Label le58; +Xbyak::Label lebc; +Xbyak::Label lef8; +Xbyak::Label lf1c; +Xbyak::Label lf3c; +Xbyak::Label lf48; +Xbyak::Label lf60; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x30); + jl(l480, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x30); + vxorps(ymm8, ymm8, ymm8); + vxorps(ymm9, ymm9, ymm9); + vxorps(ymm10, ymm10, ymm10); + vxorps(ymm11, ymm11, ymm11); + vxorps(ymm12, ymm12, ymm12); + vxorps(ymm13, ymm13, ymm13); + vxorps(ymm14, ymm14, ymm14); + vxorps(ymm15, ymm15, ymm15); + mov(I, M); + sar(I, 0x2); + jle(l2a0, T_NEAR); + align(4); + +L(l5c); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1+LDA*1-0x80]); + vmovdqu(xmm2, xword[A1+LDA*2-0x80]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x80]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovdqu(xword[B-0x80], xmm0); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovdqu(xword[B-0x60], xmm2); + vmovdqu(xword[B-0x50], xmm3); + vmovdqu(xmm0, xword[A1-0x70]); + vmovdqu(xmm1, xword[A1+LDA*1-0x70]); + vmovdqu(xmm2, xword[A1+LDA*2-0x70]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x70]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovdqu(xword[B-0x40], xmm0); + vmovdqu(xword[B-0x30], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovdqu(xword[B-0x20], xmm2); + vmovdqu(xword[B-0x10], xmm3); + vmovdqu(xmm0, xword[A1-0x60]); + vmovdqu(xmm1, xword[A1+LDA*1-0x60]); + vmovdqu(xmm2, xword[A1+LDA*2-0x60]); + vmovdqu(xmm3, xword[A1+LDA3*1-0x60]); + lea(A1, ptr[A1+LDA*4]); + vpunpcklbw(xmm4, xmm0, xmm1); + vpunpckhbw(xmm5, xmm0, xmm1); + vpunpcklbw(xmm6, xmm2, xmm3); + vpunpckhbw(xmm7, xmm2, xmm3); + vpunpcklwd(xmm0, xmm4, xmm6); + vpunpckhwd(xmm1, xmm4, xmm6); + vpunpcklwd(xmm2, xmm5, xmm7); + vpunpckhwd(xmm3, xmm5, xmm7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovdqu(xword[B], xmm0); + vmovdqu(xword[B+0x10], xmm1); + vpmovsxbw(ymm5, xmm2); + vmovhlps(xmm6, xmm2, xmm2); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vmovdqu(xword[B+0x20], xmm2); + vmovdqu(xword[B+0x30], xmm3); + sub(B, -192); + dec(I); + jg(l5c, T_NEAR); + align(4); + +L(l2a0); + test(M, 0x2); + jle(l3c0, T_NEAR); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1-0x70]); + vmovdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + vmovdqu(xmm6, xword[A1-0x80]); + vmovdqu(xmm4, xword[A1-0x70]); + vmovdqu(xmm5, xword[A1-0x60]); + add(A1, LDA); + vpunpcklbw(xmm3, xmm0, xmm6); + vpunpckhbw(xmm0, xmm0, xmm6); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm8, ymm8, ymm7); + vmovdqu(xword[B-0x80], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x70], xmm0); + vpunpcklbw(xmm3, xmm1, xmm4); + vpunpckhbw(xmm0, xmm1, xmm4); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm10, ymm10, ymm7); + vmovdqu(xword[B-0x60], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x50], xmm0); + vpunpcklbw(xmm3, xmm2, xmm5); + vpunpckhbw(xmm0, xmm2, xmm5); + vpmovsxbw(ymm7, xmm3); + vmovhlps(xmm6, xmm3, xmm3); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm12, ymm12, ymm7); + vmovdqu(xword[B-0x40], xmm3); + vpmovsxbw(ymm7, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm7, ymm7, ymm6); + vpmovsxwd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x30], xmm0); + sub(B, -96); + align(4); + +L(l3c0); + test(M, 0x1); + jle(l438, T_NEAR); + vmovdqu(xmm0, xword[A1-0x80]); + vmovdqu(xmm1, xword[A1-0x70]); + vmovdqu(xmm2, xword[A1-0x60]); + add(A1, LDA); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm8, ymm8, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x80], xmm0); + vpmovsxbd(ymm7, xmm1); + vpaddd(ymm10, ymm10, ymm7); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbd(ymm7, xmm2); + vpaddd(ymm12, ymm12, ymm7); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x60], xmm2); + sub(B, -48); + align(4); + +L(l438); + mov(A1, qword[ARG_BIAS]); + vmovdqu(yword[A1], ymm8); + vmovdqu(yword[A1+0x20], ymm9); + vmovdqu(yword[A1+0x40], ymm10); + vmovdqu(yword[A1+0x60], ymm11); + vmovdqu(yword[A1+0x80], ymm12); + vmovdqu(yword[A1+0xa0], ymm13); + add(qword[ARG_BIAS], 0xc0); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + vzeroupper(); + align(4); + +L(l480); + cmp(N, 0x20); + jl(l89c, T_NEAR); + align(4); + +L(l48c); + mov(A1, A); + add(A, 0x20); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + pxor(xmm12, xmm12); + pxor(xmm13, xmm13); + pxor(xmm14, xmm14); + pxor(xmm15, xmm15); + mov(I, M); + sar(I, 0x2); + jle(l6a8, T_NEAR); + align(4); + +L(l4c8); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm2); + movdqu(xmm0, xword[A1-0x70]); + movdqu(xmm1, xword[A1+LDA*1-0x70]); + movdqu(xmm2, xword[A1+LDA*2-0x70]); + movdqu(xmm3, xword[A1+LDA3*1-0x70]); + lea(A1, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm5); + punpckhwd(xmm2, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm4); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm2); + sub(B, -128); + dec(I); + jg(l4c8, T_NEAR); + align(4); + +L(l6a8); + test(M, 0x2); + jle(l7b4, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + movdqu(xmm3, xword[A1-0x70]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm2); + punpckhbw(xmm4, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm4); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm4); + movdqa(xmm4, xmm1); + punpcklbw(xmm1, xmm3); + punpckhbw(xmm4, xmm3); + pmovsxbw(xmm5, xmm1); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm13, xmm6); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x50], xmm4); + sub(B, -64); + align(4); + +L(l7b4); + test(M, 0x1); + jle(l850, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1-0x70]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + pmovsxbd(xmm5, xmm1); + paddd(xmm12, xmm5); + pshufd(xmm6, xmm1, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm13, xmm6); + pshufd(xmm5, xmm1, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm14, xmm5); + pshufd(xmm6, xmm1, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l850); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + movdqu(xword[A1+0x40], xmm12); + movdqu(xword[A1+0x50], xmm13); + movdqu(xword[A1+0x60], xmm14); + movdqu(xword[A1+0x70], xmm15); + add(qword[ARG_BIAS], 0x80); + sub(N, 0x20); + cmp(N, 0x20); + jge(l48c, T_NEAR); + align(4); + +L(l89c); + cmp(N, 0x10); + jl(lae8, T_NEAR); + align(4); + +L(l8a8); + mov(A1, A); + add(A, 0x10); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + mov(I, M); + sar(I, 0x2); + jle(l9d0, T_NEAR); + align(4); + +L(l8d0); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm2, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm3, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm4, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm4, xmm1); + movdqa(xmm1, xmm2); + punpcklbw(xmm2, xmm3); + punpckhbw(xmm1, xmm3); + movdqa(xmm3, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm3, xmm2); + movdqa(xmm2, xmm4); + punpcklwd(xmm4, xmm1); + punpckhwd(xmm2, xmm1); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm3); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + pmovsxbw(xmm5, xmm2); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x60], xmm4); + movdqu(xword[B-0x50], xmm2); + sub(B, -64); + dec(I); + jg(l8d0, T_NEAR); + align(4); + +L(l9d0); + test(M, 0x2); + jle(la64, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + movdqu(xmm1, xword[A1-0x80]); + add(A1, LDA); + movdqa(xmm2, xmm0); + punpcklbw(xmm0, xmm1); + punpckhbw(xmm2, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + pmovsxbw(xmm5, xmm2); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm2); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + align(4); + +L(la64); + test(M, 0x1); + jle(lab8, T_NEAR); + movdqu(xmm0, xword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(lab8); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + add(qword[ARG_BIAS], 0x40); + sub(N, 0x10); + cmp(N, 0x10); + jge(l8a8, T_NEAR); + align(4); + +L(lae8); + cmp(N, 0x8); + jl(ld78, T_NEAR); + align(4); + +L(laf4); + mov(A1, A); + add(A, 0x8); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x3); + jle(lc30, T_NEAR); + align(4); + +L(lb14); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(lb14, T_NEAR); + align(4); + +L(lc30); + test(M, 0x4); + jle(lcc8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(lcc8); + test(M, 0x2); + jle(ld1c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(ld1c); + test(M, 0x1); + jle(ld54, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(ld54); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(laf4, T_NEAR); + align(4); + +L(ld78); + cmp(N, 0x4); + jl(lf3c, T_NEAR); + align(4); + +L(ld84); + mov(A1, A); + add(A, 0x4); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x3); + jle(le58, T_NEAR); + align(4); + +L(ld9c); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(ld9c, T_NEAR); + align(4); + +L(le58); + test(M, 0x4); + jle(lebc, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(lebc); + test(M, 0x2); + jle(lef8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(lef8); + test(M, 0x1); + jle(lf1c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(lf1c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(ld84, T_NEAR); + align(4); + +L(lf3c); + cmp(N, 0x2); + jl(l111a, T_NEAR); + align(4); + +L(lf48); + mov(A1, A); + add(A, 0x2); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l1024, T_NEAR); + align(4); + +L(lf60); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(lf60, T_NEAR); + align(4); + +L(l1024); + test(M, 0x4); + jle(l1090, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l1090); + test(M, 0x2); + jle(l10d4, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l10d4); + test(M, 0x1); + jle(l10fc, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l10fc); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(lf48, T_NEAR); + align(4); + +L(l111a); + cmp(N, 0x1); + jl(l12bc, T_NEAR); + align(4); + +L(l1124); + mov(A1, A); + add(A, 0x1); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l11d4, T_NEAR); + align(4); + +L(l113c); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l113c, T_NEAR); + align(4); + +L(l11d4); + test(M, 0x4); + jle(l1234, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l1234); + test(M, 0x2); + jle(l1278, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l1278); + test(M, 0x1); + jle(l129c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l129c); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l1124, T_NEAR); + align(4); + +L(l12bc); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp new file mode 100644 index 0000000000..a4f4ff09c6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp @@ -0,0 +1,3163 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_at_kern::jit_avx512_core_u8_copy_sum_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l1750; +Xbyak::Label l1b6c; +Xbyak::Label l1e14; +Xbyak::Label l20; +Xbyak::Label l2068; +Xbyak::Label l226c; +Xbyak::Label l22b8; +Xbyak::Label l22c4; +Xbyak::Label l22f4; +Xbyak::Label l26b4; +Xbyak::Label l28cc; +Xbyak::Label l2a2c; +Xbyak::Label l2b5c; +Xbyak::Label l2c64; +Xbyak::Label l2c94; +Xbyak::Label l2ca0; +Xbyak::Label l2cc8; +Xbyak::Label l2eac; +Xbyak::Label l2fc0; +Xbyak::Label l3078; +Xbyak::Label l3118; +Xbyak::Label l319c; +Xbyak::Label l31c0; +Xbyak::Label l31cc; +Xbyak::Label l31ec; +Xbyak::Label l32e4; +Xbyak::Label l3378; +Xbyak::Label l33dc; +Xbyak::Label l3434; +Xbyak::Label l347c; +Xbyak::Label l349c; +Xbyak::Label l34a8; +Xbyak::Label l34c8; +Xbyak::Label l3558; +Xbyak::Label l35b0; +Xbyak::Label l35f4; +Xbyak::Label l3638; +Xbyak::Label l366c; +Xbyak::Label l368a; +Xbyak::Label l3694; +Xbyak::Label l36a8; +Xbyak::Label l36ec; +Xbyak::Label l3728; +Xbyak::Label l3760; +Xbyak::Label l3794; +Xbyak::Label l37b8; +Xbyak::Label l37d8; +Xbyak::Label l5cc; +Xbyak::Label l6c; +Xbyak::Label l968; +Xbyak::Label lc80; +Xbyak::Label lf1c; +Xbyak::Label lf64; +Xbyak::Label lf70; +Xbyak::Label lfb4; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x30); + jl(lf64, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + lea(I, ptr[I+LDA*8]); + lea(I, ptr[I+LDA*8]); + add(A, I); + vxorps(ymm8, ymm8, ymm8); + vxorps(ymm9, ymm9, ymm9); + vxorps(ymm10, ymm10, ymm10); + vxorps(ymm11, ymm11, ymm11); + vxorps(ymm12, ymm12, ymm12); + vxorps(ymm13, ymm13, ymm13); + vxorps(ymm14, ymm14, ymm14); + vxorps(ymm15, ymm15, ymm15); + mov(I, M); + sar(I, 0x3); + jle(l5cc, T_NEAR); + align(4); + +L(l6c); + vmovq(xmm0, qword[A1-0x80]); + vmovq(xmm1, qword[A1+LDA*1-0x80]); + vmovq(xmm2, qword[A1+LDA*2-0x80]); + vmovq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x80], xmm0); + vmovdqu(xword[B+0x40], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x70], xmm2); + vmovdqu(xword[B+0x50], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x60], xmm0); + vmovdqu(xword[B+0x60], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x50], xmm2); + vmovdqu(xword[B+0x70], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x40], xmm0); + vmovdqu(xword[B+0x80], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x30], xmm2); + vmovdqu(xword[B+0x90], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x20], xmm0); + vmovdqu(xword[B+0xa0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B-0x10], xmm2); + vmovdqu(xword[B+0xb0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B], xmm0); + vmovdqu(xword[B+0xc0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B+0x10], xmm2); + vmovdqu(xword[B+0xd0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovq(xmm0, qword[A2-0x80]); + vmovq(xmm1, qword[A2+LDA*1-0x80]); + vmovq(xmm2, qword[A2+LDA*2-0x80]); + vmovq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm0, xmm1); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm1, xmm3); + vpunpckhqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x20], xmm0); + vmovdqu(xword[B+0xe0], xmm1); + vmovq(xmm2, qword[A2-0x80]); + vmovq(xmm3, qword[A2+LDA*1-0x80]); + vmovq(xmm4, qword[A2+LDA*2-0x80]); + vmovq(xmm5, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm3, xmm2, xmm3); + vpunpckldq(xmm5, xmm4, xmm5); + vpunpcklqdq(xmm2, xmm3, xmm5); + vpunpckhqdq(xmm3, xmm3, xmm5); + vmovdqu(xword[B+0x30], xmm2); + vmovdqu(xword[B+0xf0], xmm3); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm2); + vmovhlps(xmm7, xmm2, xmm2); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vpmovsxbw(ymm5, xmm1); + vmovhlps(xmm6, xmm1, xmm1); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm3); + vmovhlps(xmm7, xmm3, xmm3); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + sub(A1, -8); + sub(B, -384); + dec(I); + jg(l6c, T_NEAR); + align(4); + +L(l5cc); + test(M, 0x4); + jle(l968, T_NEAR); + vmovd(xmm0, dword[A1-0x80]); + vmovd(xmm1, dword[A1+LDA*1-0x80]); + vmovd(xmm2, dword[A1+LDA*2-0x80]); + vmovd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x80], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x70], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x60], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x50], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x40], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x30], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B-0x20], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B-0x10], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x10], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovd(xmm0, dword[A2-0x80]); + vmovd(xmm1, dword[A2+LDA*1-0x80]); + vmovd(xmm2, dword[A2+LDA*2-0x80]); + vmovd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm0, xmm0, xmm1); + vpunpckldq(xmm2, xmm2, xmm3); + vpunpcklqdq(xmm0, xmm0, xmm2); + vmovdqu(xword[B+0x20], xmm0); + vmovd(xmm1, dword[A2-0x80]); + vmovd(xmm2, dword[A2+LDA*1-0x80]); + vmovd(xmm3, dword[A2+LDA*2-0x80]); + vmovd(xmm4, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpunpckldq(xmm1, xmm1, xmm2); + vpunpckldq(xmm3, xmm3, xmm4); + vpunpcklqdq(xmm1, xmm1, xmm3); + vmovdqu(xword[B+0x30], xmm1); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxbw(ymm6, xmm1); + vmovhlps(xmm7, xmm1, xmm1); + vpmovsxbw(ymm7, xmm7); + vphaddw(ymm6, ymm6, ymm7); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + sub(A1, -4); + sub(B, -192); + align(4); + +L(l968); + test(M, 0x2); + jle(lc80, T_NEAR); + mov(ax, word[A1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x7); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm8, ymm8, ymm5); + vmovdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm9, ymm9, ymm5); + vmovdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm10, ymm10, ymm5); + vmovdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm11, ymm11, ymm5); + vmovdqu(xword[B-0x50], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm12, ymm12, ymm5); + vmovdqu(xword[B-0x40], xmm0); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrw(xmm0, xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + vpinsrw(xmm0, xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + vpmovsxbw(ymm5, xmm0); + vmovhlps(xmm6, xmm0, xmm0); + vpmovsxbw(ymm6, xmm6); + vphaddw(ymm5, ymm5, ymm6); + vpmovsxwd(ymm5, xmm5); + vpaddd(ymm13, ymm13, ymm5); + vmovdqu(xword[B-0x30], xmm0); + sub(A1, -2); + sub(B, -96); + align(4); + +L(lc80); + test(M, 0x1); + jle(lf1c, T_NEAR); + mov(al, byte[A1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm8, ymm8, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm9, ymm9, ymm7); + vmovdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm10, ymm10, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm11, ymm11, ymm7); + vmovdqu(xword[B-0x70], xmm0); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + vpinsrb(xmm0, xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + vpinsrb(xmm0, xmm0, eax, 0xf); + vpmovsxbd(ymm7, xmm0); + vpaddd(ymm12, ymm12, ymm7); + vmovhlps(xmm7, xmm0, xmm0); + vpmovsxbd(ymm7, xmm7); + vpaddd(ymm13, ymm13, ymm7); + vmovdqu(xword[B-0x60], xmm0); + sub(B, -48); + align(4); + +L(lf1c); + mov(A1, qword[ARG_BIAS]); + vmovdqu(yword[A1], ymm8); + vmovdqu(yword[A1+0x20], ymm9); + vmovdqu(yword[A1+0x40], ymm10); + vmovdqu(yword[A1+0x60], ymm11); + vmovdqu(yword[A1+0x80], ymm12); + vmovdqu(yword[A1+0xa0], ymm13); + add(qword[ARG_BIAS], 0xc0); + sub(N, 0x30); + cmp(N, 0x30); + jge(l20, T_NEAR); + vzeroupper(); + align(4); + +L(lf64); + cmp(N, 0x20); + jl(l22b8, T_NEAR); + align(4); + +L(lf70); + mov(A1, A); + mov(I, LDA); + shl(I, 0x5); + add(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + pxor(xmm12, xmm12); + pxor(xmm13, xmm13); + pxor(xmm14, xmm14); + pxor(xmm15, xmm15); + mov(I, M); + sar(I, 0x4); + jle(l1750, T_NEAR); + align(4); + +L(lfb4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x80], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x100], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x90], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x110], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0xa0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x120], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0xb0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x130], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x40], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0xc0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x140], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0xd0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x150], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0xe0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x160], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0xf0], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x170], xmm3); + sub(A1, -16); + sub(B, -512); + dec(I); + jg(lfb4, T_NEAR); + align(4); + +L(l1750); + test(M, 0x8); + jle(l1b6c, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B+0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B+0x50], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B+0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B+0x70], xmm1); + sub(A1, -8); + sub(B, -256); + align(4); + +L(l1b6c); + test(M, 0x4); + jle(l1e14, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movdqu(xword[B-0x40], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm13, xmm5); + movdqu(xword[B-0x30], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movdqu(xword[B-0x20], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm15, xmm5); + movdqu(xword[B-0x10], xmm0); + sub(A1, -4); + sub(B, -128); + align(4); + +L(l1e14); + test(M, 0x2); + jle(l2068, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm12, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm13, xmm6); + movdqu(xword[B-0x60], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + lea(A2, ptr[A2+LDA*4]); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm14, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x50], xmm0); + sub(A1, -2); + sub(B, -64); + align(4); + +L(l2068); + test(M, 0x1); + jle(l226c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm12, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm13, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm14, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm15, xmm6); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l226c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + movdqu(xword[A1+0x40], xmm12); + movdqu(xword[A1+0x50], xmm13); + movdqu(xword[A1+0x60], xmm14); + movdqu(xword[A1+0x70], xmm15); + add(qword[ARG_BIAS], 0x80); + sub(N, 0x20); + cmp(N, 0x20); + jge(lf70, T_NEAR); + align(4); + +L(l22b8); + cmp(N, 0x10); + jl(l2c94, T_NEAR); + align(4); + +L(l22c4); + mov(A1, A); + mov(I, LDA); + shl(I, 0x4); + add(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + pxor(xmm10, xmm10); + pxor(xmm11, xmm11); + mov(I, M); + sar(I, 0x4); + jle(l26b4, T_NEAR); + align(4); + +L(l22f4); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B+0x40], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x10], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B+0x50], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x20], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x20], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B+0x60], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x10], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B+0x70], xmm3); + sub(A1, -16); + sub(B, -256); + dec(I); + jg(l22f4, T_NEAR); + align(4); + +L(l26b4); + test(M, 0x8); + jle(l28cc, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x20], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x10], xmm1); + sub(A1, -8); + sub(B, -128); + align(4); + +L(l28cc); + test(M, 0x4); + jle(l2a2c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movdqu(xword[B-0x60], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm11, xmm5); + movdqu(xword[B-0x50], xmm0); + sub(A1, -4); + sub(B, -64); + align(4); + +L(l2a2c); + test(M, 0x2); + jle(l2b5c, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm10, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x70], xmm0); + sub(A1, -2); + sub(B, -32); + align(4); + +L(l2b5c); + test(M, 0x1); + jle(l2c64, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + lea(A2, ptr[A1+LDA*4]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0x7); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x8); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x9); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xa); + mov(al, byte[A2+LDA3*1-0x80]); + lea(A2, ptr[A2+LDA*4]); + pinsrb(xmm0, eax, 0xb); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0xc); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0xd); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0xe); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0xf); + pmovsxbd(xmm5, xmm0); + paddd(xmm8, xmm5); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm9, xmm6); + pshufd(xmm5, xmm0, 0xaa); + pmovsxbd(xmm5, xmm5); + paddd(xmm10, xmm5); + pshufd(xmm6, xmm0, 0xff); + pmovsxbd(xmm6, xmm6); + paddd(xmm11, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l2c64); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + movdqu(xword[A1+0x20], xmm10); + movdqu(xword[A1+0x30], xmm11); + add(qword[ARG_BIAS], 0x40); + sub(N, 0x10); + cmp(N, 0x10); + jge(l22c4, T_NEAR); + align(4); + +L(l2c94); + cmp(N, 0x8); + jl(l31c0, T_NEAR); + align(4); + +L(l2ca0); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x4); + jle(l2eac, T_NEAR); + align(4); + +L(l2cc8); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l2cc8, T_NEAR); + align(4); + +L(l2eac); + test(M, 0x8); + jle(l2fc0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l2fc0); + test(M, 0x4); + jle(l3078, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l3078); + test(M, 0x2); + jle(l3118, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l3118); + test(M, 0x1); + jle(l319c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l319c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l2ca0, T_NEAR); + align(4); + +L(l31c0); + cmp(N, 0x4); + jl(l349c, T_NEAR); + align(4); + +L(l31cc); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l32e4, T_NEAR); + align(4); + +L(l31ec); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l31ec, T_NEAR); + align(4); + +L(l32e4); + test(M, 0x8); + jle(l3378, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l3378); + test(M, 0x4); + jle(l33dc, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l33dc); + test(M, 0x2); + jle(l3434, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3434); + test(M, 0x1); + jle(l347c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l347c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l31cc, T_NEAR); + align(4); + +L(l349c); + cmp(N, 0x2); + jl(l368a, T_NEAR); + align(4); + +L(l34a8); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l3558, T_NEAR); + align(4); + +L(l34c8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pshufd(xmm6, xmm2, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l34c8, T_NEAR); + align(4); + +L(l3558); + test(M, 0x8); + jle(l35b0, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l35b0); + test(M, 0x4); + jle(l35f4, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l35f4); + test(M, 0x2); + jle(l3638, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3638); + test(M, 0x1); + jle(l366c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(byte[B-0x7f], al); + sub(B, -2); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + align(4); + +L(l366c); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l34a8, T_NEAR); + align(4); + +L(l368a); + cmp(N, 0x1); + jl(l37d8, T_NEAR); + align(4); + +L(l3694); + mov(A1, A); + add(A, LDA); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l36ec, T_NEAR); + align(4); + +L(l36a8); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(l36a8, T_NEAR); + align(4); + +L(l36ec); + test(M, 0x8); + jle(l3728, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l3728); + test(M, 0x4); + jle(l3760, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l3760); + test(M, 0x2); + jle(l3794, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(l3794); + test(M, 0x1); + jle(l37b8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l37b8); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l3694, T_NEAR); + align(4); + +L(l37d8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp new file mode 100644 index 0000000000..c7f1393c9d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp @@ -0,0 +1,821 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_bn_kern::jit_avx512_core_u8_copy_sum_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l20; +Xbyak::Label l22c; +Xbyak::Label l340; +Xbyak::Label l3f8; +Xbyak::Label l48; +Xbyak::Label l498; +Xbyak::Label l51c; +Xbyak::Label l540; +Xbyak::Label l54c; +Xbyak::Label l56c; +Xbyak::Label l664; +Xbyak::Label l6f8; +Xbyak::Label l75c; +Xbyak::Label l7b4; +Xbyak::Label l7fc; +Xbyak::Label l81c; +Xbyak::Label l828; +Xbyak::Label l848; +Xbyak::Label l8d8; +Xbyak::Label l930; +Xbyak::Label l974; +Xbyak::Label l9b8; +Xbyak::Label l9ec; +Xbyak::Label la0a; +Xbyak::Label la14; +Xbyak::Label la28; +Xbyak::Label la6c; +Xbyak::Label laa8; +Xbyak::Label lae0; +Xbyak::Label lb14; +Xbyak::Label lb38; +Xbyak::Label lb58; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(N, qword[N]); + mov(M, qword[M]); + mov(LDA, qword[LDA]); + sub(A, -128); + sub(B, -128); + lea(LDA3, ptr[LDA+LDA*2]); + cmp(N, 0x8); + jl(l540, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + lea(A2, ptr[A1+LDA*4]); + lea(I, ptr[A1+LDA*8]); + mov(A, I); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x4); + jle(l22c, T_NEAR); + align(4); + +L(l48); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + movdqu(xmm2, xword[A1+LDA*2-0x80]); + movdqu(xmm3, xword[A1+LDA3*1-0x80]); + sub(A1, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x40], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x20], xmm3); + movdqu(xmm0, xword[A2-0x80]); + movdqu(xmm1, xword[A2+LDA*1-0x80]); + movdqu(xmm2, xword[A2+LDA*2-0x80]); + movdqu(xmm3, xword[A2+LDA3*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x30], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x10], xmm3); + sub(B, -128); + dec(I); + jg(l48, T_NEAR); + align(4); + +L(l22c); + test(M, 0x8); + jle(l340, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + movq(xmm2, qword[A1+LDA*2-0x80]); + movq(xmm3, qword[A1+LDA3*1-0x80]); + sub(A1, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x60], xmm1); + movq(xmm0, qword[A2-0x80]); + movq(xmm1, qword[A2+LDA*1-0x80]); + movq(xmm2, qword[A2+LDA*2-0x80]); + movq(xmm3, qword[A2+LDA3*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + align(4); + +L(l340); + test(M, 0x4); + jle(l3f8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + movd(xmm2, dword[A1+LDA*2-0x80]); + movd(xmm3, dword[A1+LDA3*1-0x80]); + sub(A1, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A2-0x80]); + movd(xmm1, dword[A2+LDA*1-0x80]); + movd(xmm2, dword[A2+LDA*2-0x80]); + movd(xmm3, dword[A2+LDA3*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + align(4); + +L(l3f8); + test(M, 0x2); + jle(l498, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A1+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A1+LDA3*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x3); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x4); + mov(ax, word[A2+LDA*1-0x80]); + pinsrw(xmm0, eax, 0x5); + mov(ax, word[A2+LDA*2-0x80]); + pinsrw(xmm0, eax, 0x6); + mov(ax, word[A2+LDA3*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l498); + test(M, 0x1); + jle(l51c, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A2+LDA*2-0x80]); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A2+LDA3*1-0x80]); + pinsrb(xmm0, eax, 0x7); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l51c); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l540); + cmp(N, 0x4); + jl(l81c, T_NEAR); + align(4); + +L(l54c); + mov(A1, A); + lea(A2, ptr[A1+LDA*2]); + lea(I, ptr[A1+LDA*4]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l664, T_NEAR); + align(4); + +L(l56c); + movdqu(xmm0, xword[A1-0x80]); + movdqu(xmm1, xword[A1+LDA*1-0x80]); + sub(A1, -16); + movdqu(xmm2, xword[A2-0x80]); + movdqu(xmm3, xword[A2+LDA*1-0x80]); + sub(A2, -16); + movdqa(xmm4, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm4, xmm1); + movdqa(xmm5, xmm2); + punpckldq(xmm2, xmm3); + punpckhdq(xmm5, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + movdqa(xmm3, xmm4); + punpcklqdq(xmm4, xmm5); + punpckhqdq(xmm3, xmm5); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + pmovsxbw(xmm5, xmm4); + movhlps(xmm6, xmm4); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x60], xmm4); + pmovsxbw(xmm5, xmm3); + movhlps(xmm6, xmm3); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x50], xmm3); + sub(B, -64); + dec(I); + jg(l56c, T_NEAR); + align(4); + +L(l664); + test(M, 0x8); + jle(l6f8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + movq(xmm1, qword[A1+LDA*1-0x80]); + sub(A1, -8); + movq(xmm2, qword[A2-0x80]); + movq(xmm3, qword[A2+LDA*1-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklqdq(xmm0, xmm2); + punpckhqdq(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l6f8); + test(M, 0x4); + jle(l75c, T_NEAR); + movd(xmm0, dword[A1-0x80]); + movd(xmm1, dword[A1+LDA*1-0x80]); + sub(A1, -4); + movd(xmm2, dword[A2-0x80]); + movd(xmm3, dword[A2+LDA*1-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + punpckldq(xmm2, xmm3); + punpcklqdq(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l75c); + test(M, 0x2); + jle(l7b4, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1+LDA*1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x1); + mov(ax, word[A2-0x80]); + pinsrw(xmm0, eax, 0x2); + mov(ax, word[A2+LDA*1-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l7b4); + test(M, 0x1); + jle(l7fc, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A2+LDA*1-0x80]); + pinsrb(xmm0, eax, 0x3); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l7fc); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l54c, T_NEAR); + align(4); + +L(l81c); + cmp(N, 0x2); + jl(la0a, T_NEAR); + align(4); + +L(l828); + mov(A1, A); + lea(A2, ptr[A1+LDA*1]); + lea(I, ptr[A1+LDA*2]); + mov(A, I); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(l8d8, T_NEAR); + align(4); + +L(l848); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + movdqu(xmm1, xword[A2-0x80]); + sub(A2, -16); + movdqa(xmm2, xmm0); + punpckldq(xmm0, xmm1); + punpckhdq(xmm2, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + pshufd(xmm6, xmm2, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm2); + sub(B, -32); + dec(I); + jg(l848, T_NEAR); + align(4); + +L(l8d8); + test(M, 0x8); + jle(l930, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + movq(xmm1, qword[A2-0x80]); + sub(A2, -8); + punpckldq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l930); + test(M, 0x4); + jle(l974, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + movd(xmm1, dword[A2-0x80]); + sub(A2, -4); + punpckldq(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l974); + test(M, 0x2); + jle(l9b8, T_NEAR); + mov(ax, word[A1-0x80]); + sub(A1, -2); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A2-0x80]); + sub(A2, -2); + pinsrw(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l9b8); + test(M, 0x1); + jle(l9ec, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A2-0x80]); + pinsrb(xmm0, eax, 0x1); + mov(byte[B-0x7f], al); + sub(B, -2); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + align(4); + +L(l9ec); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l828, T_NEAR); + align(4); + +L(la0a); + cmp(N, 0x1); + jl(lb58, T_NEAR); + align(4); + +L(la14); + mov(A1, A); + add(A, LDA); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x4); + jle(la6c, T_NEAR); + align(4); + +L(la28); + movdqu(xmm0, xword[A1-0x80]); + sub(A1, -16); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(I); + jg(la28, T_NEAR); + align(4); + +L(la6c); + test(M, 0x8); + jle(laa8, T_NEAR); + movq(xmm0, qword[A1-0x80]); + sub(A1, -8); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(laa8); + test(M, 0x4); + jle(lae0, T_NEAR); + movd(xmm0, dword[A1-0x80]); + sub(A1, -4); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(lae0); + test(M, 0x2); + jle(lb14, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(A1, -2); + sub(B, -2); + align(4); + +L(lb14); + test(M, 0x1); + jle(lb38, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrb(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(lb38); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(la14, T_NEAR); + align(4); + +L(lb58); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp new file mode 100644 index 0000000000..afe4f1713e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp @@ -0,0 +1,647 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_generator.hpp" +#include "common.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +jit_avx512_core_u8_copy_sum_bt_kern::jit_avx512_core_u8_copy_sum_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE) +{ + +#ifndef _WIN32 +#define M rdi +#define N rsi +#define A rdx +#define LDA rcx +#define ALPHA r8 +#define B r9 + +#define I rax +#define A1 r10 +#define A2 r8 +#define LDA3 r11 + +#define ARG_BIAS 24+stacksize+rsp + +#else + +#define M rcx +#define N rdx +#define A r8 +#define LDA r9 +#define ALPHA rax +#define B rdi + +#define I rax +#define A1 rsi +#define A2 r10 +#define LDA3 r11 + +#define ARG_ALPHA 40+stacksize+rsp +#define ARG_B 48+stacksize+rsp +#define ARG_BIAS 72+stacksize+rsp + +#endif + +inLocalLabel(); +{ + +Xbyak::Label l15c; +Xbyak::Label l1f4; +Xbyak::Label l20; +Xbyak::Label l248; +Xbyak::Label l280; +Xbyak::Label l2a4; +Xbyak::Label l2b0; +Xbyak::Label l2c8; +Xbyak::Label l384; +Xbyak::Label l3e8; +Xbyak::Label l40; +Xbyak::Label l424; +Xbyak::Label l448; +Xbyak::Label l468; +Xbyak::Label l474; +Xbyak::Label l48c; +Xbyak::Label l550; +Xbyak::Label l5bc; +Xbyak::Label l600; +Xbyak::Label l628; +Xbyak::Label l646; +Xbyak::Label l650; +Xbyak::Label l668; +Xbyak::Label l700; +Xbyak::Label l760; +Xbyak::Label l7a4; +Xbyak::Label l7c8; +Xbyak::Label l7e8; + + preamble(); + auto stacksize = get_size_of_abi_save_regs(); +#ifdef _WIN32 + mov(ALPHA, ptr[ARG_ALPHA]); + mov(B, ptr[ARG_B]); +#endif + + mov(M, qword[M]); + mov(N, qword[N]); + mov(LDA, qword[LDA]); + lea(LDA3, ptr[LDA+LDA*2]); + sub(A, -128); + sub(B, -128); + cmp(N, 0x8); + jl(l2a4, T_NEAR); + align(4); + +L(l20); + mov(A1, A); + add(A, 0x8); + pxor(xmm8, xmm8); + pxor(xmm9, xmm9); + mov(I, M); + sar(I, 0x3); + jle(l15c, T_NEAR); + align(4); + +L(l40); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x60], xmm0); + movdqu(xword[B-0x50], xmm1); + sub(B, -64); + dec(I); + jg(l40, T_NEAR); + align(4); + +L(l15c); + test(M, 0x4); + jle(l1f4, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + movq(xmm2, qword[A1-0x80]); + add(A1, LDA); + movq(xmm3, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + movdqa(xmm1, xmm0); + punpcklwd(xmm0, xmm2); + punpckhwd(xmm1, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + pmovsxbw(xmm5, xmm1); + movhlps(xmm6, xmm1); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm9, xmm5); + movdqu(xword[B-0x80], xmm0); + movdqu(xword[B-0x70], xmm1); + sub(B, -32); + align(4); + +L(l1f4); + test(M, 0x2); + jle(l248, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + movq(xmm1, qword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm8, xmm5); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm6, xmm6); + pmovsxwd(xmm6, xmm6); + paddd(xmm9, xmm6); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l248); + test(M, 0x1); + jle(l280, T_NEAR); + movq(xmm0, qword[A1-0x80]); + add(A1, LDA); + pmovsxbd(xmm5, xmm0); + pshufd(xmm6, xmm0, 0x55); + pmovsxbd(xmm6, xmm6); + paddd(xmm8, xmm5); + paddd(xmm9, xmm6); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l280); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm8); + movdqu(xword[A1+0x10], xmm9); + add(qword[ARG_BIAS], 0x20); + sub(N, 0x8); + cmp(N, 0x8); + jge(l20, T_NEAR); + align(4); + +L(l2a4); + cmp(N, 0x4); + jl(l468, T_NEAR); + align(4); + +L(l2b0); + mov(A1, A); + add(A, 0x4); + pxor(xmm7, xmm7); + mov(I, M); + sar(I, 0x3); + jle(l384, T_NEAR); + align(4); + +L(l2c8); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x70], xmm0); + sub(B, -32); + dec(I); + jg(l2c8, T_NEAR); + align(4); + +L(l384); + test(M, 0x4); + jle(l3e8, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + movd(xmm2, dword[A1-0x80]); + add(A1, LDA); + movd(xmm3, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + movhlps(xmm6, xmm0); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + align(4); + +L(l3e8); + test(M, 0x2); + jle(l424, T_NEAR); + movd(xmm0, dword[A1-0x80]); + add(A1, LDA); + movd(xmm1, dword[A1-0x80]); + add(A1, LDA); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l424); + test(M, 0x1); + jle(l448, T_NEAR); + movd(xmm0, dword[A1-0x80]); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l448); + mov(A1, qword[ARG_BIAS]); + movdqu(xword[A1], xmm7); + add(qword[ARG_BIAS], 0x10); + sub(N, 0x4); + cmp(N, 0x4); + jge(l2b0, T_NEAR); + align(4); + +L(l468); + cmp(N, 0x2); + jl(l646, T_NEAR); + align(4); + +L(l474); + mov(A1, A); + add(A, 0x2); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l550, T_NEAR); + align(4); + +L(l48c); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm4, eax, 0x0); + punpcklbw(xmm1, xmm2); + punpcklbw(xmm3, xmm4); + punpcklwd(xmm1, xmm3); + punpcklqdq(xmm0, xmm1); + pshufd(xmm6, xmm0, 0xd8); + pmovsxbw(xmm5, xmm6); + movhlps(xmm6, xmm6); + pmovsxbw(xmm6, xmm6); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movdqu(xword[B-0x80], xmm0); + sub(B, -16); + dec(LDA3); + jg(l48c, T_NEAR); + align(4); + +L(l550); + test(M, 0x4); + jle(l5bc, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm2, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm3, eax, 0x0); + punpcklbw(xmm0, xmm1); + punpcklbw(xmm2, xmm3); + punpcklwd(xmm0, xmm2); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + align(4); + +L(l5bc); + test(M, 0x2); + jle(l600, T_NEAR); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm0, eax, 0x0); + mov(ax, word[A1-0x80]); + add(A1, LDA); + pinsrw(xmm1, eax, 0x0); + punpcklbw(xmm0, xmm1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l600); + test(M, 0x1); + jle(l628, T_NEAR); + mov(ax, word[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(word[B-0x80], ax); + sub(B, -2); + align(4); + +L(l628); + mov(A1, qword[ARG_BIAS]); + movq(qword[A1], xmm7); + add(qword[ARG_BIAS], 0x8); + sub(N, 0x2); + cmp(N, 0x2); + jge(l474, T_NEAR); + align(4); + +L(l646); + cmp(N, 0x1); + jl(l7e8, T_NEAR); + align(4); + +L(l650); + mov(A1, A); + add(A, 0x1); + pxor(xmm7, xmm7); + mov(LDA3, M); + sar(LDA3, 0x3); + jle(l700, T_NEAR); + align(4); + +L(l668); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x4); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x5); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x6); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x7); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm6); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movq(qword[B-0x80], xmm0); + sub(B, -8); + dec(LDA3); + jg(l668, T_NEAR); + align(4); + +L(l700); + test(M, 0x4); + jle(l760, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x2); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x3); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + movd(dword[B-0x80], xmm0); + sub(B, -4); + align(4); + +L(l760); + test(M, 0x2); + jle(l7a4, T_NEAR); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x0); + mov(byte[B-0x80], al); + mov(al, byte[A1-0x80]); + add(A1, LDA); + pinsrb(xmm0, eax, 0x1); + pmovsxbw(xmm5, xmm0); + phaddw(xmm5, xmm5); + pmovsxwd(xmm5, xmm5); + paddd(xmm7, xmm5); + mov(byte[B-0x7f], al); + sub(B, -2); + align(4); + +L(l7a4); + test(M, 0x1); + jle(l7c8, T_NEAR); + mov(al, byte[A1-0x80]); + pinsrw(xmm0, eax, 0x0); + pmovsxbd(xmm5, xmm0); + paddd(xmm7, xmm5); + mov(byte[B-0x80], al); + sub(B, -1); + align(4); + +L(l7c8); + mov(A1, qword[ARG_BIAS]); + movd(dword[A1], xmm7); + add(qword[ARG_BIAS], 0x4); + sub(N, 0x1); + cmp(N, 0x1); + jge(l650, T_NEAR); + align(4); + +L(l7e8); + + postamble(); +} +outLocalLabel(); + +#undef M +#undef N +#undef A +#undef LDA +#undef ALPHA +#undef B +#undef I +#undef A1 +#undef A2 +#undef LDA3 +#ifdef _WIN32 +#undef ARG_ALPHA +#undef ARG_B +#endif +#undef ARG_BIAS +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp new file mode 100644 index 0000000000..4fc11afcbc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp @@ -0,0 +1,116 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <cstdint> + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "../f32/ref_gemm_f32.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <typename b_dt> +mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co) { + + if (*M == 0 || *N == 0 || *K == 0) + return mkldnn_success; + + bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); + bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); + bool AisN = (*transa == 'N' || *transa == 'n'); + bool BisN = (*transb == 'N' || *transb == 'n'); + + int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC; + size_t sizeA = AisN ? lda * k : lda * m; + size_t sizeB = BisN ? ldb * n : ldb * k; + size_t sizeC = ldc * n; + + double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K); + double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K); + double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K); + + if (utils::any_null(dA, dB, dC)) { + free(dA); + free(dB); + free(dC); + return mkldnn_out_of_memory; + } + + auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; }; + auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; }; + + auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; }; + auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; }; + + const int a_rows = AisN ? m : k; + const int a_cols = AisN ? k : m; + mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) { + da_setter(i, j, + static_cast<double>(ia_accessor(i, j)) + static_cast<double>(ao[0])); + }); + + const int b_rows = BisN ? k : n; + const int b_cols = BisN ? n : k; + mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) { + db_setter(i, j, + static_cast<double>(ib_accessor(i, j)) + static_cast<double>(bo[0])); + }); + double one = 1.0, zero = 0.0; + ref_gemm<double>(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero, + dC, LDC, nullptr); + + auto i2d = [=] (int32_t v) { return static_cast<double>(v); }; + auto f2d = [=] (float v) { return static_cast<double>(v); }; + + mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) { + double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]); + double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc])) + + f2d(*alpha) * dC[i + j * ldc] + coffset; + C[i + j * ldc] = math::out_round<int32_t>(math::saturate<int32_t>(val)); + }); + + free(dA); + free(dB); + free(dC); + return mkldnn_success; +} + +template mkldnn_status_t ref_gemm_s8x8s32<uint8_t>( + const char *transa, const char *transb, const char *offsetc, + const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const uint8_t *B, const int *LDB, const int8_t *bo, + const float *beta, int32_t *C, const int *LDC, const int32_t *co); + +template mkldnn_status_t ref_gemm_s8x8s32<int8_t>( + const char *transa, const char *transb, const char *offsetc, + const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const int8_t *B, const int *LDB, const int8_t *bo, + const float *beta, int32_t *C, const int *LDC, const int32_t *co); + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp new file mode 100644 index 0000000000..6c0370ae99 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_GEMM_S8X8S32_HPP +#define REF_GEMM_S8X8S32_HPP + +#include <stdint.h> + +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <typename b_dt> +mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb, + const char *offsetc, const int *M, const int *N, const int *K, + const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, + const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, + int32_t *C, const int *LDC, const int32_t *co); + +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp new file mode 100644 index 0000000000..de1035f3b2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp @@ -0,0 +1,180 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common.hpp" +#include "nstl.hpp" +#include "math_utils.hpp" + +#include "../gemm.hpp" +#include "jit_avx512_core_gemm_s8u8s32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +void compensation_init(const char *offsetC, int32_t *compensation, int len, + const int32_t *oc) { + bool OCisC = (*offsetC == 'C' || *offsetC == 'c'); + bool OCisF = (*offsetC == 'F' || *offsetC == 'f'); + + if (OCisF && (*oc) != 0) { + for (int i = 0; i < len; i++) + compensation[i] = *oc; + } else if (OCisC) { + for (int i = 0; i < len; i++) + compensation[i] = oc[i]; + } else { + parallel_nd(len, [=](int i) { compensation[i] = 0; }); + } +} + +void compensation_compute(bool transa, int m, int k, float alpha, + const int8_t *a, int lda, int32_t *compensation) { + if (!transa) { + const int L2_cache_size = get_cache_size(2, true); + const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1); + const int npanels = k / blocking_factor; + const bool has_tile = k % blocking_factor > 0; + + parallel_nd(npanels, m, [&](int j, int i) { + int32_t val = 0; + for (int jb = 0; jb < blocking_factor; jb++) { + val += a[(i + (ptrdiff_t)j * blocking_factor * lda) + + (ptrdiff_t)jb * lda]; + } + if (alpha != 1.0f) { + val = math::out_round<int32_t>(math::saturate<int32_t>( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + fetch_and_add(&compensation[i], val); + }); + + if (has_tile) { + parallel_nd(m, [=](int i) { + int32_t val = 0; + for (int j = npanels * blocking_factor; j < k; j++) { + val += a[i + (ptrdiff_t)j * lda]; + } + if (alpha != 1.0f) { + val = math::out_round<int32_t>(math::saturate<int32_t>( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + fetch_and_add(&compensation[i], val); + }); + } + } else { + parallel_nd(m, [=](int i) { + int32_t val = 0; + for (int j = 0; j < k; j++) { + val += a[j + (ptrdiff_t)i * lda]; + } + if (alpha != 1.0f) { + val = math::out_round<int32_t>(math::saturate<int32_t>( + (double)val * alpha * -128.0)); + } else { + val *= -128; + } + compensation[i] += val; + }); + } +} + +void copy_and_shift_b(bool transb, int k, int n, uint8_t *b_u8, int ldb_u8, + const int8_t *b_s8, int ldb_s8) { + const int b_cols = transb ? k : n; + + parallel_nd(b_cols, [=](int j) { + const int b_rows = transb ? n : k; + + uint8_t *pb_u8 = b_u8 + j * ldb_u8; + const int8_t *pb_s8 = b_s8 + j * ldb_s8; + + for (int i = 0; i < b_rows; i++) { + (*pb_u8) = (*pb_s8) + 128; + pb_u8++; + pb_s8++; + } + }); +} + +/** + * gemm_s8s8s32 operation is defined as follows: + * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset + compensation + * + * where + * - compensation is a vector of length m that contains computed compensation + * that may contain C_offset if applicable. The compensation is applied inside + * gemm_s8u8s32 as a C_offset + * - B_shift is a k-by-n matrix, every element of B_shift is equal to 128 + * + * What is the compensation: + * In order to prepare the matrix B for gemm_s8u8s32 call the B_shift is applied: + * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset = + * alpha * op(A) * op(B) + alpha * op(A) * B_shift + beta * C + C_offset + * compensation = -alpha * op(A) * B_shift + * Since B_shift is a matrix, every element of which is equal to 128 then + * - if op(A) = A: compensation contains sum of the elements in each row + * scaled by -128 * alpha + * - if op(A) = A**T: compensation contains sum of the elements in each column + * scaled by -128 * alpha + * + * The rest of parameters is described in mkldnn.h + */ +mkldnn_status_t simple_gemm_s8s8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const int8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc) { + if (*oa != 0 || *ob != 0) return mkldnn_unimplemented; + + int M = *m, N = *n, K = *k; + bool transa = (*transA == 'T' || *transA == 't'); + bool transb = (*transB == 'T' || *transB == 't'); + int ld = transb ? N : K; + + uint8_t *b_u8 = (uint8_t *)malloc(sizeof(uint8_t) * K * N, 64); + int32_t *compensation = (int32_t *)malloc(sizeof(int32_t) * M, 64); + + if (utils::any_null(b_u8, compensation)) { + free(b_u8); + free(compensation); + return mkldnn_out_of_memory; + } + + compensation_init(offsetC, compensation, M, oc); + compensation_compute(transa, M, K, *alpha, a, *lda, compensation); + copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb); + + gemm_s8x8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa, b_u8, + &ld, ob, beta, c, ldc, compensation); + + if ((*offsetC == 'R' || *offsetC == 'r')) + parallel_nd(M, N, + [=](int i, int j) { c[i + (ptrdiff_t)j * *ldc] += oc[j]; }); + + free(b_u8); + free(compensation); + + return mkldnn_success; +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp new file mode 100644 index 0000000000..03a3d2f7e0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp @@ -0,0 +1,37 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SIMPLE_GEMM_S8S8S32_HPP +#define SIMPLE_GEMM_S8S8S32_HPP + +#include <stdint.h> +#include "mkldnn_types.h" + +namespace mkldnn { +namespace impl { +namespace cpu { + +mkldnn_status_t simple_gemm_s8s8s32( + const char *transA, const char *transB, const char *offsetC, + const int *m, const int *n, const int *k, + const float *alpha, const int8_t *a, const int *lda, const int8_t *oa, + const int8_t *b, const int *ldb, const int8_t *ob, + const float *beta, int32_t *c, const int *ldc, const int32_t *oc); +} +} +} + +#endif // SIMPLE_GEMM_S8S8S32_HPP diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp new file mode 100644 index 0000000000..604a728b47 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp @@ -0,0 +1,307 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "gemm_convolution.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "ref_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + auto col = scratchpad(ctx).get<data_t>(key_conv_gemm_col); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int M = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * M; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + assert(IMPLICATION( + jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); + assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); + + const int K = jcp.ic * jcp.ks; + const int N = jcp.oc; + + if (jcp.im2col_sz && jcp.id != 1) + parallel_nd(jcp.im2col_sz * jcp.nthr, + [&](ptrdiff_t i) { col[i] = (data_t)0; }); + + const int nb_oh = div_up(jcp.oh, jcp.oh_block); + const int nb_ow = div_up(jcp.ow, jcp.ow_block); + const size_t work_amount = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow; + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + + int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 }; + size_t start = 0, end = 0; + + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, + nb_oh, owb, nb_ow); + for (size_t iwork = start; iwork < end; ++iwork) { + int oh = ohb * jcp.oh_block; + int ow = owb * jcp.ow_block; + const data_t *_src = src + (n * jcp.ngroups + g) * src_step; + const data_t *_weights = weights + g * weights_g_size; + data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step; + const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh); + const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::im2col( + jcp, _src, _col, oh, h_step, ow, w_step); + else + jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od); + } + + const data_t one = 1.0; + + const int m = h_step * w_step; + const int LDA = jcp.im2col_sz ? m : M; + data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow; + + extended_sgemm("N", "N", &m, &N, &K, &one, + jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K, + &this->beta_, _dst, &M); + + data_t *d = _dst; + if (eltwise_) { + // fast branch for ReLU case + if (eltwise_->alg_ == alg_kind::eltwise_relu) { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + if (d_[oS] < 0) d_[oS] *= eltwise_->alpha_; + } + }); + } else { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + d_[oS] = eltwise_->compute_scalar(d_[oS]); + } + }); + } + } else if (jcp.with_bias) { + parallel_nd(jcp.oc, [&](const int oc) { + data_t b = bias[g * jcp.oc + oc]; + data_t *d_ = d + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + } + }); + } + nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, nb_oh, + owb, nb_ow); + } + }); +} + +void gemm_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + auto col = scratchpad(ctx).get<data_t>(key_conv_gemm_col); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int M = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * M; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + const int m = jcp.os; + const int K = jcp.oc; + const int N = jcp.ic * jcp.ks; + const int LDC = jcp.im2col_sz ? m : M; + + const size_t work_amount = (size_t)jcp.ngroups * jcp.mb; + + if (jcp.id > 1) { + const ptrdiff_t diff_src_sz = (ptrdiff_t)(work_amount * src_step); + parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t)0; }); + } + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + + int g{0}, n{0}; + size_t start = 0, end = 0; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb); + for (size_t iwork = start; iwork < end; ++iwork) { + + data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step; + const data_t *_weights = weights + g * weights_g_size; + for (int od = 0; od < jcp.od; ++od) { + const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g) + *dst_step + od * m; + + const data_t zero = 0.0, one = 1.0; + extended_sgemm("N", "T", &m, &N, &K, &one, _diff_dst, &M, + _weights, &N, &zero, + jcp.im2col_sz ? _col:_diff_src + od * m, &LDC); + + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::col2im(jcp, _col, + _diff_src); + else + jit_gemm_convolution_utils::col2im_3d(jcp, _col, + _diff_src, od); + } + } + nd_iterator_step(g, jcp.ngroups, n, jcp.mb); + } + }); +} + +void gemm_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto col = scratchpad(ctx).get<data_t>(key_conv_gemm_col); + auto wei_reduction = scratchpad(ctx).get<data_t>(key_conv_wei_reduction); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const int K = jcp.os * jcp.od; + const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; + const size_t dst_step = jcp.oc * K; + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + const int k = jcp.os; + const int N = jcp.oc; + const int M = jcp.ic * jcp.ks; + const int LDA = jcp.im2col_sz ? k : K; + + parallel_nd(jcp.im2col_sz * jcp.nthr, + [&](ptrdiff_t i) { col[i] = (data_t)0; }); + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + int ithr_g, nthr_g, ithr_mb, nthr_mb; + size_t g_start{0}, g_end{0}, mb_start{0}, mb_end{0}; + + const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1; + jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups, + mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb); + + assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1)); + const int need_reduction = nthr_mb != 1; + + if (ithr_g != -1 && ithr_mb != -1) { + balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end); + balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end); + + assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0)); + + data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; + data_t *weights_reduce_base = wei_reduction + + ithr_g * nthr_mb * weights_g_size; + data_t *weights_reduce = weights_reduce_base + + ithr_mb * weights_g_size; + + for (size_t g = g_start; g < g_end; ++g) { + data_t *_diff_weights = need_reduction + ? weights_reduce : (diff_weights + g * weights_g_size); + for (size_t mb = mb_start; mb < mb_end; ++mb) { + const data_t *_src = src + (mb*jcp.ngroups+g)*src_step; + for (int od = 0; od < jcp.od; ++od) { + const data_t *_diff_dst = diff_dst + + (mb*jcp.ngroups+g)*dst_step + od * k; + + if (jcp.im2col_sz) { + if (jcp.id == 1) + jit_gemm_convolution_utils::im2col( + jcp, _src, _col, 0, jcp.oh, 0, jcp.ow); + else + jit_gemm_convolution_utils::im2col_3d(jcp, _src, + _col, od); + } + + const data_t zero = 0.0, one = 1.0; + extended_sgemm( + "T", "N", &M, &N, &k, &one, + jcp.im2col_sz ? _col : _src + od * k, + &LDA, _diff_dst, &K, + mb == mb_start && od == 0 ? &zero : &one, + _diff_weights, &M); + } + } + } + if (need_reduction) { + mkldnn_thr_barrier(); + data_t *weights_base = diff_weights + g_start * weights_g_size; + jit_gemm_convolution_utils::bwd_weights_reduction_par( + ithr_mb, nthr_mb, jcp, weights_reduce_base, weights_base); + } + } else + if (need_reduction) { mkldnn_thr_barrier(); } + }); + + if (jcp.with_bias) { + parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) { + data_t db = 0; + size_t offset_ = (size_t)g * dst_step + (size_t)oc * K; + for (int mb = 0; mb < jcp.mb; ++mb) + { + size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step; + for (int od = 0; od < jcp.od; ++od) + for (int oh = 0; oh < jcp.oh; ++oh) + PRAGMA_OMP_SIMD(reduction(+:db)) + for (int ow = 0; ow < jcp.ow; ++ow) { + db += diff_dst[offset]; + offset++; + } + } + diff_bias[g*jcp.oc+oc] = db; + }); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp new file mode 100644 index 0000000000..302e46369a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp @@ -0,0 +1,250 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_GEMM_CONVOLUTION_HPP +#define CPU_JIT_GEMM_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "gemm_convolution_utils.hpp" +#include "gemm/gemm.hpp" +#include "ref_eltwise.hpp" + +#include "cpu_convolution_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct gemm_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && post_ops_ok() + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), weights_md(0), dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + + bool post_ops_ok() const { + auto const &po = attr()->post_ops_; + auto is_eltwise = [&](int idx) + { return po.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); }; + + switch (po.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + return false; + } + }; + + gemm_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + , eltwise_(nullptr) + { + const auto &post_ops = pd()->attr()->post_ops_; + const data_t one = 1.0, zero = 0.0; + beta_ = post_ops.find(primitive_kind::sum) >= 0 ? one : zero; + + const int entry_idx = post_ops.find(primitive_kind::eltwise); + if (entry_idx != -1) eltwise_ = new ref_eltwise_scalar_fwd_t( + post_ops.entry_[entry_idx].eltwise); + } + + ~gemm_convolution_fwd_t() { delete eltwise_; } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + data_t beta_; + + ref_eltwise_scalar_fwd_t* eltwise_; +}; + +struct gemm_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && memory_desc_matches_tag(*diff_src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), diff_src_md(), weights_md(0), diff_dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + }; + + gemm_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct gemm_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*diff_weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), diff_weights_md(0), diff_dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { + using namespace format_tag; + return utils::pick(ndims() - 3, ncw, nchw, ncdhw); + } + + format_tag_t wei_tag() const { + using namespace format_tag; + return with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + } + }; + + gemm_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp new file mode 100644 index 0000000000..f133b1e62b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp @@ -0,0 +1,771 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "cpu_isa_traits.hpp" + +#include "gemm_convolution_utils.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; +using namespace prop_kind; +using namespace data_type; + +namespace jit_gemm_convolution_utils { + +void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, + int od) +{ + const size_t OHW = jcp.oh * jcp.ow; + const size_t im_step = jcp.ih * jcp.iw * jcp.id; + const size_t col_step = jcp.ks * OHW; + + parallel_nd(jcp.ic, [&](int ic) { + const float *__restrict im_loc = im + ic * im_step; + float *__restrict col_loc = col + ic * col_step; + int id = od * jcp.stride_d - jcp.f_pad; + for (int kd = 0; kd < jcp.kd; ++kd) { + float *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW; + if (id < 0 || id >= jcp.id) { + int ih_ = -jcp.t_pad; + for (int kh = 0; kh < jcp.kh; ++kh) { + int ih = ih_; + for (int oh = 0; oh < jcp.oh; ++oh) { + if (ih < 0 || ih >= jcp.ih) { + ih += jcp.stride_h; + continue; + } + int iw_ = -jcp.l_pad; + for (int kw = 0; kw < jcp.kw; ++kw) { + int iw = iw_; + for (int ow = 0; ow < jcp.ow; ++ow) { + if (iw < 0 || iw >= jcp.iw) { + iw += jcp.stride_w; + continue; + } + + const size_t col_idx = kw * OHW + oh * jcp.ow + + ow; + + col_[col_idx] = 0; + iw += jcp.stride_w; + } + iw_ += (1 + jcp.dilate_w); + } + ih += jcp.stride_h; + } + ih_ += (1 + jcp.dilate_h); + col_ += jcp.kw * OHW; + } + } else { + const float *__restrict im_ = im_loc + id * jcp.ih * jcp.iw; + int ih_ = -jcp.t_pad; + for (int kh = 0; kh < jcp.kh; ++kh) { + int ih = ih_; + for (int oh = 0; oh < jcp.oh; ++oh) { + if (ih < 0 || ih >= jcp.ih) { + ih += jcp.stride_h; + continue; + } + int iw_ = -jcp.l_pad; + for (int kw = 0; kw < jcp.kw; ++kw) { + int iw = iw_; + for (int ow = 0; ow < jcp.ow; ++ow) { + if (iw < 0 || iw >= jcp.iw) { + iw += jcp.stride_w; + continue; + } + + const size_t col_idx = kw * OHW + oh * jcp.ow + + ow; + const size_t im_idx = ih * jcp.iw + iw; + + col_[col_idx] = im_[im_idx]; + iw += jcp.stride_w; + } + iw_ += (1 + jcp.dilate_w); + } + ih += jcp.stride_h; + } + ih_ += (1 + jcp.dilate_h); + col_ += jcp.kw * OHW; + } + } + id += (1 + jcp.dilate_d); + } + }); +} + +/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */ +void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, + float *__restrict col, int hs, int hb, int ws, int wb) { + const size_t im_step = jcp.is; + const size_t col_step = jcp.ks * hb * wb; + if (jcp.stride_w == 1) { + // Generated code is more optimized for stride_w == 1 + // because innermost loop is by width + auto ker = [&](int ic, int kh, int kw, int oh) { + const float *__restrict im_ = im + ic * im_step; + float *__restrict col_ + = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb; + + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) { + for (int ow = 0; ow < wb; ++ow) + col_[ow] = 0.f; + } else { + for (int ow = 0; ow < wb; ++ow) { + const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) + col_[ow] = 0.f; + else { + const size_t im_idx = ih * jcp.iw + iw; + col_[ow] = im_[im_idx]; + } + } + } + }; + + if (jcp.outer_threading) { + for (int ic = 0; ic < jcp.ic; ic++) + for (int kh = 0; kh < jcp.kh; kh++) + for (int kw = 0; kw < jcp.kw; kw++) + for (int oh = 0; oh < hb; oh++) + ker(ic, kh, kw, oh); + } + else { + parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker); + } + } else if (jcp.ic == 1) { + parallel_nd(jcp.kh, hb, [&](int kh, int oh) { + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < wb; ++ow) { + const size_t col_idx + = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; + col[col_idx] = 0; + } + } + else + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < wb; ++ow) { + const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + const size_t col_idx + = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; + const size_t im_idx = ih * jcp.iw + iw; + if (iw < 0 || iw >= jcp.iw) + col[col_idx] = 0; + else + col[col_idx] = im[im_idx]; + } + } + }); + } else { + + parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, + [&](int ic, int kh, int kw, int oh) { + const float *__restrict im_ = im + ic * im_step; + float *__restrict col_ = col + ic * col_step + + ((kh * jcp.kw + kw) * hb + oh) * wb; + + const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) { + for (int ow = 0; ow < wb; ++ow) + col_[ow] = 0.f; + } else { + for (int ow = 0; ow < wb; ++ow) { + const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + const size_t im_idx = ih * jcp.iw + iw; + if (iw < 0 || iw >= jcp.iw) + col_[ow] = 0.f; + else + col_[ow] = im_[im_idx]; + } + } + }); + } +} + +inline int limit(int low, int upper, int value) { + return nstl::max(low, nstl::min(upper, value)); +} + +/* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */ +template <typename T> +void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, + T *__restrict imtr, uint8_t *__restrict col, int hs, int hb, int ws, + int wb) { + uint8_t shift = jcp.signed_input ? 128 : 0; + const int dh = 1 + jcp.dilate_h; + const int dw = 1 + jcp.dilate_w; + const int sh = jcp.stride_h; + const int sw = jcp.stride_w; + const int im_iw_stride = jcp.ic * jcp.ngroups; + const int im_ih_stride = jcp.iw * im_iw_stride; + const int tp = jcp.t_pad; + const int lp = jcp.l_pad; + + if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) { + /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */ + const int hp = hs - tp; + const int wp = ws - lp; + const int ih_start = limit(0, jcp.ih, hp); + const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh); + const int iw_start = limit(0, jcp.iw, wp); + const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw); + + const int ihb = ih_end - ih_start; + const int iwb = iw_end - iw_start; + + const int imtr_ic_stride = ihb * iwb; + const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start; + for (int ic = 0; ic < jcp.ic; ic++) { + const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift; + for (int ih = ih_start; ih < ih_end; ih++) { + const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride; + const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb; + for (int iw = iw_start; iw < iw_end; iw++) + imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride]; + } + } + + const int col_ic_str = hb * wb; + const int col_kw_stride = jcp.ic * col_ic_str; + const int col_kh_stride = jcp.kw * col_kw_stride; + + const int oh_init = ih_start - hp; + const int ow_init = iw_start - wp; + for (int kh = 0; kh < jcp.kh; kh++) { + const ptrdiff_t col_idx_kh = kh * col_kh_stride; + const int oh_kh = oh_init - kh; + const int oh_start = limit(0, hb, oh_kh); + const int oh_end = limit(0, hb, oh_kh + ihb); + for (int kw = 0; kw < jcp.kw; kw++) { + const ptrdiff_t col_idx_kw + = col_idx_kh + kw * jcp.ic * col_ic_str; + const int ow_kw = ow_init - kw; + const int imtr_shift = oh_kh * iwb + ow_kw; + const int ow_start = limit(0, wb, ow_kw); + const int ow_end = limit(0, wb, ow_kw + iwb); + for (int ic = 0; ic < jcp.ic; ic++) { + const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str; + const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift; + for (int oh = 0; oh < oh_start; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + for (int ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + for (int oh = oh_start; oh < oh_end; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb; + for (int ow = 0; ow < ow_start; ++ow) + col[col_idx_oh + ow] = shift; + for (int ow = ow_start; ow < ow_end; ++ow) + col[col_idx_oh + ow] + = imtr[imtr_idx_oh + ow] + shift; + for (int ow = ow_end; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + for (int oh = oh_end; oh < hb; oh++) { + const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb; + for (int ow = 0; ow < wb; ++ow) + col[col_idx_oh + ow] = shift; + } + } + } + } + } else { + parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb, + [&](int kh, int kw, int ic, int oh) { + const int hp = tp - kh * dh; + const int ih = (oh + hs) * sh - hp; + const ptrdiff_t col_idx_base + = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh) * wb; + if (ih < 0 || ih >= jcp.ih) + for (int ow = 0; ow < wb; ow++) + col[col_idx_base + ow] = shift; + else { + const int wp = lp - kw * dw; + const int ow_start = limit(0, wb, div_up(wp, sw) - ws); + const int ow_end + = limit(0, wb, div_up(jcp.iw + wp, sw) - ws); + for (int ow = 0; ow < ow_start; ow++) + col[col_idx_base + ow] = shift; + const int iw_base = ws * sw - wp; + const ptrdiff_t im_idx_base = ih * im_ih_stride + ic; + for (int ow = ow_start; ow < ow_end; ow++) { + const int iw = iw_base + ow * sw; + const ptrdiff_t im_idx + = im_idx_base + iw * im_iw_stride; + col[col_idx_base + ow] = im[im_idx] + shift; + } + for (int ow = ow_end; ow < wb; ow++) + col[col_idx_base + ow] = shift; + } + }); + } +} + +template void im2col_u8<int8_t>(const jit_gemm_conv_conf_t &jcp, + const int8_t *__restrict im, int8_t *__restrict imtr, + uint8_t *__restrict col, int hs, int hb, int ws, int wb); +template void im2col_u8<uint8_t>(const jit_gemm_conv_conf_t &jcp, + const uint8_t *__restrict im, uint8_t *__restrict imtr, + uint8_t *__restrict col, int hs, int hb, int ws, int wb); + +/* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */ +void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col, + int32_t *__restrict im) +{ + parallel(0, [&](const int ithr, const int nthr) { + int h_nthr = nstl::min(jcp.ih, nthr); + int w_nthr = nstl::min(jcp.iw, nthr / h_nthr); + int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0; + if (ithr < h_nthr * w_nthr) { + h_ithr = ithr / w_nthr; + w_ithr = ithr % w_nthr; + balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e); + balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e); + } else { + h_ithr = w_ithr = -ithr; + h_s = h_e = w_s = w_e = -1; + } + + for (int ih = h_s; ih < h_e; ++ih) { + for (int iw = w_s; iw < w_e; ++iw) { + PRAGMA_OMP_SIMD() + for (int ic = 0; ic < jcp.ic; ++ic) { + im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0; + } + } + } + + // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh] + for (int oh = 0; oh < jcp.oh; ++oh) { + for (int ow = 0; ow < jcp.ow; ++ow) { + for (int kh = 0; kh < jcp.kh; ++kh) { + const int ih = oh * jcp.stride_h + - jcp.t_pad + kh * (1 + jcp.dilate_h); + if (ih < h_s || ih >= h_e) continue; + + for (int kw = 0; kw < jcp.kw; ++kw) { + const int iw = ow * jcp.stride_w + - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < w_s || iw >= w_e) continue; + + const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh + + kh) * jcp.kw + kw) * jcp.ic; + const size_t im_idx + = (ih * jcp.iw + iw) * jcp.ic; + PRAGMA_OMP_SIMD() + for (int ic = 0; ic < jcp.ic; ++ic) { + im[im_idx + ic] += col[col_idx + ic]; + } + } + } + } + } + }); +} + +void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im, + int od) +{ + parallel_nd(jcp.ic, [&](int ic) { + const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os; + float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id; + + int id = od * jcp.stride_d - jcp.f_pad; + for (int kd = 0; kd < jcp.kd; ++kd) { + if (id < 0 || id >= jcp.id) { + col_ += jcp.kh * jcp.kw * jcp.os; + id += (1 + jcp.dilate_d); + continue; + } + + float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw; + + for (int oh = 0; oh < jcp.oh; ++oh) { + for (int kh = 0; kh < jcp.kh; ++kh) { + const int ih = oh * jcp.stride_h - jcp.t_pad + + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) continue; + + for (int ow = 0; ow < jcp.ow; ++ow) { + for (int kw = 0; kw < jcp.kw; ++kw) { + const int iw = ow * jcp.stride_w - jcp.l_pad + + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) continue; + + const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow; + const size_t im_idx = ih*jcp.iw + iw; + im_[im_idx] += col_[col_idx]; + }} + }} + + col_ += jcp.kh * jcp.kw * jcp.os; + id += (1 + jcp.dilate_d); + } + }); +} + +void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) { + const size_t col_step = jcp.ks * jcp.os; + const size_t im_step = jcp.ih * jcp.iw; + const int iS = jcp.ih * jcp.iw; + + parallel_nd(jcp.ic, [&](int ic) { + float *__restrict im_ = im + ic * im_step; + const float *__restrict col_ = col + ic * col_step; + PRAGMA_OMP_SIMD() + for (int is = 0; is < iS; ++is) im_[is] = 0.; + + for (int kh = 0; kh < jcp.kh; ++kh) { + for (int oh = 0; oh < jcp.oh; ++oh) { + const int ih = + oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h); + if (ih < 0 || ih >= jcp.ih) continue; + + for (int kw = 0; kw < jcp.kw; ++kw) { + for (int ow = 0; ow < jcp.ow; ++ow) { + const int iw = + ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w); + if (iw < 0 || iw >= jcp.iw) continue; + + const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow; + const size_t im_idx = ih*jcp.iw + iw; + im_[im_idx] += col_[col_idx]; + } + } + } + } + }); +} + +status_t init_conf(jit_gemm_conv_conf_t &jcp, + memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, int max_threads) { + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + const int is_1d = ndims == 3; + const int is_3d = ndims == 5; + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.id = is_3d ? src_d.dims()[2] : 1; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.od = is_3d ? dst_d.dims()[2] : 1; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = is_3d ? cd.padding[0][0] : 0; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_d = is_3d ? cd.strides[0] : 1; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.dilate_d = is_3d ? cd.dilates[0] : 0; + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef + || cd.diff_bias_desc.format_kind != format_kind::undef; + + jcp.is = jcp.ih * jcp.iw; + jcp.os = jcp.oh * jcp.ow; + jcp.ks = jcp.kh * jcp.kw * jcp.kd; + + jcp.signed_input = src_d.data_type() == data_type::s8; + + jcp.im2col_sz = !everyone_is(true, + jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id, + jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1, + jcp.ks == 1, !jcp.signed_input) + ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0; + + jcp.outer_threading = false; + + bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8) + && weights_d.data_type() == s8; + + const int vlen = mayiuse(avx512_common) + ? cpu_isa_traits<avx512_common>::vlen + : mayiuse(avx) + ? cpu_isa_traits<avx>::vlen + : mayiuse(sse42) ? cpu_isa_traits<sse42>::vlen : 4; + const int simd_w = vlen / (is_int8_conv ? 1 : 4); + + const bool is_bwd_d = jcp.prop_kind == backward_data; + const bool is_bwd_w = jcp.prop_kind == backward_weights; + const bool is_fwd = !is_bwd_d && !is_bwd_w; + jcp.oh_block = is_fwd ? jcp.oh : jcp.ih; + jcp.ow_block = is_fwd ? jcp.ow : jcp.iw; + + using namespace memory_tracking::names; + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + + // TODO: maybe mitigate blocking restriction + const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw; + const int L2 = get_cache_size(2, true) + / (is_int8_conv ? sizeof(int8_t) : sizeof(float)); + bool is_blocking_applicable = true + && is_fwd && jcp.im2col_sz + && jcp.id == 1 && jcp.od == 1 + && jcp.dilate_h == 0 && jcp.dilate_w == 0 + && !is_depthwise + && wei_size < L2/2; + if (is_blocking_applicable) { + // looking for oh and ow blocking + int h_block{ jcp.oh_block }, w_block{ jcp.ow_block }; + const int ic = jcp.ic; + const int oc = jcp.oc; + const int iw = jcp.iw; + const int ow = jcp.ow; + const int oh = jcp.oh; + const int os = oh * ow; + + // 1. cache requirement + int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow); + if (is_int8_conv) { + // Heuristic rule: gemm needed a lot of memory for internal usage + row_size *= 5; + // memory for accumulators + row_size += oc * ow * sizeof(uint32_t); + // memory for transposition + row_size += ic * iw; + } + + h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size))); + if (h_block == 1) { + int col_size = ic * jcp.ks + 2 * (ic + oc); + if (is_int8_conv) { + col_size *= 5; + col_size += oc * sizeof(uint32_t); + col_size += ic; + } + w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size))); + } + + // 2. threading requirement + if (h_block != oh) + h_block = nstl::max(1, rnd_dn(h_block, 4)); + if (w_block != ow) + w_block = nstl::max(1, rnd_dn(w_block, simd_w)); + + float thr_eff = 0.f; + float thr_eff_treshold = 0.9f; + if (w_block == ow) { + do { + int nb_h = div_up(oh, h_block); + size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h; + float disb = (float)oh / rnd_up(oh, h_block); + thr_eff = (float)work / rnd_up(work, max_threads); + thr_eff = (thr_eff + disb) / 2.f; + if (thr_eff >= thr_eff_treshold) + break; + h_block = rnd_dn(h_block - 4, 4); + } while (h_block > 0); + } + if (thr_eff < thr_eff_treshold) // we didn't find suitable h_block + { + h_block = 1; + int nb_h = oh; + do { + int nb_w = div_up(ow, w_block); + size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w; + float disb = (float)ow / rnd_up(ow, w_block); + thr_eff = (float)work_amount / rnd_up(work_amount, max_threads); + thr_eff = (thr_eff + disb) / 2.f; + if (thr_eff > thr_eff_treshold) + break; + w_block = rnd_dn(w_block - simd_w, simd_w); + } while (w_block > 0); + } + h_block = nstl::max(1, h_block); + w_block = nstl::max(1, w_block); + const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) { + jcp.oh_block = h_block; + jcp.ow_block = w_block; + jcp.outer_threading = true; + } + // updating jcp.im2col_sz + if (jcp.oh_block != 1) + jcp.ow_block = ow; + jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block; + } + // For threading selection in bwd_d we do: + // 1. Rough estimation of efficiency for inner and outer threading. + // 2. Gemm size estimation in assumption that it does not work + // so effectively for small sizes. + // 64K - this is heuristic gemm size per thread threshold. + const int gemm_thrld = 64 * 1024; + + if (is_int8_conv) { + if (is_fwd) { + if (!jcp.outer_threading) { + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + const size_t outer_work = jcp.ngroups * jcp.mb; + const float outer_thr_eff + = (float)outer_work / rnd_up(outer_work, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + jcp.outer_threading = (is_depthwise + || (jcp.is / max_threads < 64 && jcp.mb != 1)) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(int8_t) * jcp.nthr * jcp.im2col_sz); + scratchpad.book(key_conv_int_dat_in_acc_dt, + sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc); + scratchpad.book(key_conv_gemm_imtr, + sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic); + } else if (is_bwd_d) { + bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + const size_t outer_work = jcp.ngroups * jcp.mb; + const float outer_thr_eff + = (float)outer_work / rnd_up(outer_work, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + jcp.outer_threading = (is_depthwise + || (jcp.is / max_threads < 64 && jcp.mb != 1)) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(int32_t) * jcp.nthr * jcp.im2col_sz); + scratchpad.book(key_conv_int_dat_in_acc_dt, + sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic); + } else if (is_bwd_w) { + assert(!"unimplemented prop_kind"); + return status::unimplemented; + } + } else { + if (is_fwd) { + if (!jcp.outer_threading) { + const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od; + const float outer_thr_eff = (float)outer_work_amount + / rnd_up(outer_work_amount, max_threads); + const size_t inner_work_amount + = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w); + const float inner_thr_eff = (float)inner_work_amount + / rnd_up(inner_work_amount, max_threads); + jcp.outer_threading = jcp.os / max_threads < 512 + && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } + } else if (is_bwd_d) { + const size_t outer_work_amount = jcp.ngroups * jcp.mb; + const float outer_thr_eff = (float)outer_work_amount + / rnd_up(outer_work_amount, max_threads); + const size_t inner_work + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff = (float)inner_work + / rnd_up(inner_work, max_threads); + jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64) + && (jcp.mb != 1 || jcp.ngroups > 2) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + } else if (is_bwd_w) + jcp.outer_threading = jcp.os / max_threads < 256 + && (jcp.mb != 1 || jcp.ngroups > 2); + + jcp.nthr = jcp.outer_threading ? max_threads : 1; + scratchpad.book(key_conv_gemm_col, + sizeof(float) * jcp.nthr * jcp.im2col_sz); + + if (is_bwd_w) { + jcp.need_wei_reduction = mkldnn_thr_syncable() + ? jcp.mb != 1 && jcp.nthr != 1 : false; + scratchpad.book(key_conv_wei_reduction, + sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size()); + } + } + + return status::success; +} + +void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g, + int &nthr_g, int &ithr_mb, int &nthr_mb) { + nthr_g = nstl::min(ngroups, nthr); + nthr_mb = nstl::min(mb, nthr / nthr_g); + if (ithr / nthr_mb >= ngroups) { + ithr_g = ithr_mb = -1; + } else { + ithr_g = ithr / nthr_mb; + ithr_mb = ithr % nthr_mb; + } +} + +void bwd_weights_reduction_par(int ithr, int nthr, + const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws, + float *weights) { + const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + + size_t weights_start{0}, weights_end{0}; + balance211(weights_g_size, nthr, ithr, weights_start, weights_end); + + for (int i = 0; i < nthr; ++i) { + const float *ws_i = weights_reduce_ws + i * weights_g_size; + for (size_t s = weights_start; s < weights_end; ++s) + weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s]; + } +} + +}; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp new file mode 100644 index 0000000000..e006789344 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp @@ -0,0 +1,66 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP +#define CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_engine.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace jit_gemm_convolution_utils { + +void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, + int od); +void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, + float *__restrict col, int hs, int hb, int ws, int wb); +template <typename T> +void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, + T* __restrict imtr, uint8_t *__restrict col, + int hs, int hb, int ws, int wb); + +void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col, + int32_t *__restrict im); +void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im, + int od); +void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im); + +status_t init_conf(jit_gemm_conv_conf_t &jcp, + memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, int max_threads); + +void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, + int &ithr_g, int &nthr_g, int &ithr_mb, int &nthr_mb); +void bwd_weights_reduction_par(int ithr, int nthr, + const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws, + float *weights); + +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp new file mode 100644 index 0000000000..2872122f0d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp @@ -0,0 +1,156 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" + +#include "gemm_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::data_type; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::primitive_kind; + +template <impl::data_type_t data_type> +void gemm_inner_product_fwd_t<data_type>::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = !memory_desc_matches_one_of_tag( + *pd()->weights_md(), hwio, dhwio, io); + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + + float alpha = 1.0, beta = 0.0; + extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights, + wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC, bias); + + if (do_relu) { + float nslope = post_ops.entry_[0].eltwise.alpha; + parallel_nd(MB, OC, [&](int mb, int oc) { + size_t dst_off = mb * OC + oc; + if (dst[dst_off] < 0) + dst[dst_off] *= nslope; + }); + } +} + +template <impl::data_type_t data_type> +void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->weights_md(), hwio, dhwio, io); + + float alpha = 1.0, beta = 0.0; + extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights, + wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src, &IC); +} + +template <impl::data_type_t data_type> +void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + diff_dst += diff_dst_d.offset0(); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC_total_padded(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->diff_weights_md(), hwio, dhwio, io); + + float alpha = 1.0, beta = 0.0; + if (wei_tr) + extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC, + &beta, diff_weights, &OC); + else + extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC, + &beta, diff_weights, &IC); + + if (diff_bias) { + diff_bias += diff_bias_d.offset0(); + constexpr int blksize = 8; + const int OC_blocks = OC / blksize; + const int rem_OC = OC % blksize; + parallel(0, [&](const int ithr, const int nthr) { + int oc_st{0}, oc_e{0}; + balance211(OC_blocks, nthr, ithr, oc_st, oc_e); + oc_st = oc_st * blksize; + oc_e = oc_e * blksize; + + PRAGMA_OMP_SIMD() + for (int oc = oc_st; oc < oc_e; ++oc) { + diff_bias[oc] = diff_dst[oc]; + } + + for (int mb = 1; mb < MB; ++mb) { + PRAGMA_OMP_SIMD() + for (int oc = oc_st; oc < oc_e; ++oc) { + diff_bias[oc] += diff_dst[mb * OC + oc]; + } + } + + if (rem_OC != 0 && ithr == nthr-1) { + for (int oc = OC_blocks * blksize; oc < OC; oc++) + diff_bias[oc] = diff_dst[oc]; + for (int mb = 1; mb < MB; ++mb) { + for (int oc = OC_blocks * blksize; oc < OC; oc++) { + diff_bias[oc] += diff_dst[mb * OC + oc]; + } + } + } + }); + } +} + +template struct gemm_inner_product_fwd_t<data_type::f32>; +template struct gemm_inner_product_bwd_data_t<data_type::f32>; +template struct gemm_inner_product_bwd_weights_t<data_type::f32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp new file mode 100644 index 0000000000..acf0a49b9a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp @@ -0,0 +1,157 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_GEMM_INNER_PRODUCT_HPP +#define CPU_GEMM_INNER_PRODUCT_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "gemm/gemm.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t data_type> +struct gemm_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_fwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type, + src_md()->data_type, + weights_md()->data_type, + dst_md()->data_type, + with_bias() ? weights_md(1)->data_type : data_type) + && attr()->output_scales_.has_default_values() + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_ == 1, + attr()->post_ops_.entry_[0].is_relu(true, false)) + && dense_gemm_consitency_check(src_md(), weights_md(), + dst_md()); + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t data_type> +struct gemm_inner_product_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_data_pd_t { + using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_data_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_data + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + diff_src_md()->data_type, + weights_md()->data_type, + diff_dst_md()->data_type) + && attr()->has_default_values() + && dense_gemm_consitency_check(diff_src_md(), weights_md(), + diff_dst_md()); + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t data_type> +struct gemm_inner_product_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_weights_pd_t { + using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_weights_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_weights + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + src_md()->data_type, + diff_weights_md()->data_type, + diff_dst_md()->data_type, + with_bias() ? diff_weights_md(1)->data_type : data_type) + && attr()->has_default_values() + && dense_gemm_consitency_check(src_md(), diff_weights_md(), + diff_dst_md()); + + return ok ? status::success : status::unimplemented; + } + }; + + gemm_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp new file mode 100644 index 0000000000..fed7e4d693 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp @@ -0,0 +1,740 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "math_utils.hpp" + +#include "simple_q10n.hpp" + +#include "gemm/gemm.hpp" +#include "gemm_x8s8s32x_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace mkldnn::impl::memory_tracking::names; + +template <data_type_t src_type, data_type_t dst_type> +void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>:: +execute_forward(const exec_ctx_t &ctx) const { + auto src_base = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst_base = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + assert(IMPLICATION( + jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); + assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base, + scratchpad); + }); +} + +template <data_type_t src_type, data_type_t dst_type> +_gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::pp_ker_t( + const pd_t *pd) + : ker_(nullptr) + , jcp_(pd->jcp_) + , OC_(pd->jcp_.oc) + , OS_(pd->jcp_.os) + , bias_data_type_(data_type::undef) + , bias_data_type_size_(0) + , scale_idx_mult_(0) + , do_bias_(false) + , do_relu_(false) + , do_sum_(false) +{ + using namespace types; + + const auto dst_md = memory_desc_wrapper(pd->dst_md()); + dst_os_stride_ = dst_md.blk_off(0, 0, 0, 1); + + scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); + + auto &post_ops = pd->attr()->post_ops_; + + int entry_idx = -1; + for (int idx = 0; idx < post_ops.len_; ++idx) { + const auto &e = post_ops.entry_[idx]; + if (e.is_relu(true, false)) { + entry_idx = idx; + break; + } + } + do_relu_ = entry_idx >= 0; + + do_signed_scaling_ = jcp_.signed_input; + + do_sum_ = post_ops.contain(primitive_kind::sum, 0); + do_bias_ = pd->with_bias(); + bias_data_type_ = pd->desc()->bias_desc.data_type; + if (do_bias_) { + assert(bias_data_type_ != data_type::undef); + bias_data_type_size_ = data_type_size(bias_data_type_); + } + const size_t vlen_start + = cpu_isa_traits<avx512_common>::vlen / sizeof(float); + + for (size_t i = vlen_start; i > 0; i--) { + if (OC_ % i == 0) { + vlen_ = i; + break; + } + } + + if (!mayiuse(avx512_core)) + // use fallback code for older CPUs + return; + else + generate(); +} + +template <data_type_t src_type, data_type_t dst_type> +void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::generate() +{ + using namespace Xbyak; + using namespace utils; + + // TODO: clean-up + Reg64 reg_param = abi_param1; + Reg64 reg_dst = rdx; + Reg64 reg_acc = rax; + Reg64 reg_bias = rbx; + Reg64 reg_scales = rsi; + + Reg64 reg_len = r8; + Reg64 reg_tmp = rcx; // intentional for shifting purposes + Reg64 reg_oc_offset = r9; + Reg64 reg_rem_mask_short = r10; + Reg64 reg_rem_mask_vlen = r11; + Opmask kreg_rem_mask_short = k1; + Opmask kreg_rem_mask_vlen = k3; + Opmask kreg_relu_cmp = k2; + + const size_t vlen = vlen_; + + Zmm vreg_zero = Zmm(0); + Zmm vreg_scale = Zmm(1); + Zmm vreg_nslope = Zmm(2); + Zmm vreg_sum_scale = Zmm(3); + Zmm vreg_signed_scale = Zmm(4); + + size_t def_unroll = 4; + size_t max_unroll = 12; + size_t zmm_step = 2; + if (do_sum_) { + max_unroll = 8; + zmm_step = 3; + } + + auto vreg_dst = [&](int idx) { + return Zmm(5 + idx * zmm_step + 0); + }; + auto vreg_bias = [&](int idx) { + return Zmm(5 + idx * zmm_step + 1); + }; + auto vreg_prev_dst = [&](int idx) { + return Zmm(5 + idx * zmm_step + 2); + }; + + preamble(); + +#define PARAM_OFF(x) offsetof(ker_args, x) + mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); + mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); + mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); + mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); + mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]); + vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]); + vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]); + if (scale_idx_mult_ == 0) + vbroadcastss(vreg_scale, dword[reg_scales]); + +#undef PARAM_OFF + + mov(reg_rem_mask_vlen, 1); + shl(reg_rem_mask_vlen, vlen); + sub(reg_rem_mask_vlen, 1); + kmovq(kreg_rem_mask_vlen, reg_rem_mask_vlen); + + if (do_relu_ || dst_type == data_type::u8) + vxorps(vreg_zero, vreg_zero, vreg_zero); + + // Load accumulated value, convert to float, apply sum (if any), + // bias (if any), scaling, and relu (if any); + // then convert to destination type and store + auto compute = [&](size_t offset, int idx, bool apply_mask) { + auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; + + if (scale_idx_mult_ > 0) { + assert(scale_idx_mult_ == 1); + auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; + auto vreg_scale_ = vreg_scale; + if (apply_mask) + vreg_scale_ = vreg_scale_ | kreg_rem_mask_short; + else + vreg_scale_ = vreg_scale_ | kreg_rem_mask_vlen; + vmovups(vreg_scale_, scale_addr); + } + + auto vreg_dst_ = vreg_dst(idx); + if (apply_mask) + vreg_dst_ = vreg_dst_ | kreg_rem_mask_short; + else + vreg_dst_ = vreg_dst_ | kreg_rem_mask_vlen; + vcvtdq2ps(vreg_dst_, acc_addr); + + if (do_signed_scaling_) + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale); + + if (do_bias_) { + auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_]; + auto vreg_bias_ = vreg_bias(idx); + if (apply_mask) + vreg_bias_ = vreg_bias_ | kreg_rem_mask_short; + else + vreg_bias_ = vreg_bias_ | kreg_rem_mask_vlen; + + switch (bias_data_type_) { + case data_type::s8: + vpmovsxbd(vreg_bias_, bias_addr); + break; + case data_type::u8: + vpmovzxbd(vreg_bias_, bias_addr); + break; + case data_type::s32: + case data_type::f32: + vmovups(vreg_bias_, bias_addr); + break; + default: assert(!"unimplemented"); + } + if (bias_data_type_ != data_type::f32) + vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); + vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); + } + + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); + + auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)]; + + if (do_sum_) + { + auto vreg_prev_dst_ = vreg_prev_dst(idx); + if (apply_mask) + vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_short; + else + vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_vlen; + + switch (dst_type) { + case data_type::f32: + case data_type::s32: vmovups(vreg_prev_dst_, dst_addr); break; + case data_type::s8: vpmovsxbd(vreg_prev_dst_, dst_addr); break; + case data_type::u8: vpmovzxbd(vreg_prev_dst_, dst_addr); break; + default: assert(!"unsupported data type"); + } + if (dst_type != data_type::f32) + vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx)); + + vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale); + } + + if (do_relu_) { + vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os); + vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope); + } + + if (dst_type != data_type::f32) { + vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); + } + + if (dst_type == data_type::u8) + vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero); + + switch (dst_type) { + case data_type::s8: + vpmovsdb(dst_addr, vreg_dst_); + break; + case data_type::u8: + vpmovusdb(dst_addr, vreg_dst_); + break; + case data_type::f32: + case data_type::s32: + vmovups(dst_addr, vreg_dst_); + break; + default: assert(!"unimplemented"); + } + }; + + // Advance all pointers by an immediate + auto advance_ptrs_imm = [&](size_t offset) { + add(reg_dst, offset * sizeof(dst_data_t)); + add(reg_acc, offset * sizeof(acc_data_t)); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + add(reg_scales, offset * sizeof(float)); + } + if (do_bias_) + add(reg_bias, offset * bias_data_type_size_); + }; + + // Advance all pointers by a value stored in a register + auto advance_ptrs_reg = [&](Reg64 offset) { + lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]); + lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); + } + if (do_bias_) + lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); + }; + + // Rewind pointers that point to data that is indexed by output channel + // (bias or per-oc scaling factors) + auto rewind_ptrs = [&]() { + if (do_bias_) + sub(reg_bias, OC_ * bias_data_type_size_); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + sub(reg_scales, OC_ * sizeof(float)); + } + add(reg_dst, (dst_os_stride_ - OC_) * sizeof(dst_data_t)); + }; + + // <--------- OC ---------------> + // + // ^ ................+..............+-------------+....................... + // | . : not accessed |Prologue loop| . + // | . +--------------+-------------+ . + // . | | . + // O . | Main loop (unrolled) | . + // S . | | . + // . +--------------+-------------+ . + // | . | Epilogue loop|not accessed : . + // v ................+--------------+.............+....................... + + Label prologue_end; + cmp(reg_oc_offset, 0); + je(prologue_end, T_NEAR); + + // Prologue loop + { + mov(reg_tmp, OC_); + sub(reg_tmp, reg_oc_offset); + cmp(reg_tmp, reg_len); + cmovg(reg_tmp, reg_len); + sub(reg_len, reg_tmp); + + Label prologue_loop, prologue_loop_tail, prologue_loop_end; + cmp(reg_tmp, vlen); + jle(prologue_loop_tail, T_NEAR); + L(prologue_loop); { + compute(0, 0, false); + advance_ptrs_imm(vlen); + sub(reg_tmp, vlen); + cmp(reg_tmp, vlen); + jge(prologue_loop, T_NEAR); + } + + L(prologue_loop_tail); + mov(reg_rem_mask_short, 1); + // cl == reg_tmp because reg_tmp <= vlen here + shl(reg_rem_mask_short, cl); + sub(reg_rem_mask_short, 1); + jz(prologue_loop_end, T_NEAR); + + kmovq(kreg_rem_mask_short, reg_rem_mask_short); + compute(0, 0, true); + advance_ptrs_reg(reg_tmp); + + L(prologue_loop_end); + rewind_ptrs(); + } + L(prologue_end); + + // Main loop + Label main_loop_end; + { + cmp(reg_len, OC_); + jle(main_loop_end, T_NEAR); + + Label main_loop; + L(main_loop); { + size_t OC_loop, OC_tail; + if (OC_ < max_unroll * vlen) { + // Fully unroll small loops + OC_loop = 0; + OC_tail = OC_; + } + else { + OC_loop = vlen * def_unroll; + OC_tail = OC_ % OC_loop; + } + + assert(!!OC_loop || !!OC_tail); + + if (OC_tail % vlen) { + int vlen_tail = OC_tail % vlen; + unsigned tail_mask = (1 << vlen_tail) - 1; + mov(reg_tmp, tail_mask); + kmovq(kreg_rem_mask_short, reg_tmp); + } + + if (OC_loop) { + mov(reg_tmp, rnd_dn(OC_, OC_loop)); + Label oc_loop; + L(oc_loop); { + for (size_t offset = 0; offset < OC_loop; offset += vlen) + compute(offset, offset / vlen, false); + advance_ptrs_imm(OC_loop); + sub(reg_tmp, OC_loop); + jnz(oc_loop); + } + } + + if (OC_tail) { + for (size_t offset = 0; offset < OC_tail; offset += vlen) { + bool use_mask = (offset + vlen) > OC_tail; + compute(offset, offset / vlen, use_mask); + } + advance_ptrs_imm(OC_tail); + } + + rewind_ptrs(); + sub(reg_len, OC_); + cmp(reg_len, OC_); + jge(main_loop, T_NEAR); + } + } + L(main_loop_end); + + // Epilogue loop + Label epilogue_end; + { + cmp(reg_len, 0); + je(epilogue_end, T_NEAR); + + Label epilogue_loop, epilogue_loop_tail; + cmp(reg_len, vlen); + jle(epilogue_loop_tail, T_NEAR); + L(epilogue_loop); { + compute(0, 0, false); + sub(reg_len, vlen); + advance_ptrs_imm(vlen); + cmp(reg_len, vlen); + jge(epilogue_loop, T_NEAR); + } + + L(epilogue_loop_tail); + mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift + mov(reg_rem_mask_short, 1); + shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen + sub(reg_rem_mask_short, 1); + jz(epilogue_end, T_NEAR); + kmovq(kreg_rem_mask_short, reg_rem_mask_short); + compute(0, 0, true); + } + + L(epilogue_end); + + postamble(); + + ker_ = getCode<decltype(ker_)>(); +} + +template <data_type_t src_type, data_type_t dst_type> +void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::operator () + (dst_data_t *dst, const acc_data_t *acc, const char *bias, + const float *scales, float nslope, float sum_scale, float signed_scale, + int g, size_t start, size_t end) +{ + using math::get_bias; + + if (end <= start) + return; + + if (ker_) { + // JIT + ker_args args; + size_t oc_offset = start % OC_; + size_t os_offset = start / OC_; + args.acc = acc + start; + args.dst = dst + os_offset * dst_os_stride_ + oc_offset; + args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_; + args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset); + args.nslope = nslope; + args.sum_scale = sum_scale; + args.signed_scale = signed_scale; + args.len = end - start; + args.oc_offset = oc_offset; + ker_(&args); + } + else { + // Fallback + const size_t first_oc = start % OC_; + const size_t last_oc = (end - 1) % OC_; + const size_t first_os = start / OC_; + const size_t last_os = (end - 1) / OC_; + for (size_t os = first_os; os <= last_os; os++) { + const size_t start_oc = (os == first_os) ? first_oc : 0; + const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1; + for (size_t oc = start_oc; oc <= end_oc; oc++) { + const size_t acc_off = os * jcp_.oc + oc; + const size_t dst_off = os * dst_os_stride_ + oc; + + float d = (float)(acc[acc_off]); + if (jcp_.signed_input) + d *= signed_scale; + + if (do_bias_) + d += get_bias(bias, g * jcp_.oc + oc, + bias_data_type_); + + d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_]; + if (do_sum_) + d += sum_scale * dst[dst_off]; + if (do_relu_ && d < 0) + d *= nslope; + dst[dst_off] = qz_a1b0<float, dst_data_t>()(d); + } + } + } +}; + +template <data_type_t src_type, data_type_t dst_type> +void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>:: +execute_forward_thr(const int ithr, const int nthr, const src_data_t *src_base, + const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base, + const memory_tracking::grantor_t &scratchpad) const { + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const auto src_md = memory_desc_wrapper(pd()->src_md()); + const size_t src_mb_stride = src_md.blk_off(1); + const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic; + + const auto wei_md = memory_desc_wrapper(pd()->weights_md(0)); + const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0; + + const auto dst_md = memory_desc_wrapper(pd()->dst_md()); + const size_t dst_mb_stride = dst_md.blk_off(1); + const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc; + + const float *scales = pd()->attr()->output_scales_.scales_; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_sum = post_ops.contain(primitive_kind::sum, 0); + const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0; + + float nslope = 0; + for (int idx = 0; idx < post_ops.len_; ++idx) { + const auto &e = post_ops.entry_[idx]; + if (e.is_relu(true, false)) { + nslope = e.eltwise.alpha; + break; + } + } + + auto col = scratchpad.get<uint8_t>(key_conv_gemm_col) + + (ptrdiff_t)ithr * jcp.im2col_sz; + src_data_t *__restrict imtr = scratchpad.get<src_data_t>(key_conv_gemm_imtr) + + (ptrdiff_t)ithr * jcp.is * jcp.ic; + auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt) + + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc; + + const ptrdiff_t offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc; + const int32_t *_wei_comp = (const int32_t *)(wei_base + offset); + + int g{ 0 }, n{ 0 }, ohb{ 0 }, owb{ 0 }; + size_t start = 0, end = 0; + + const int nb_oh = div_up(jcp.oh, jcp.oh_block); + const int nb_ow = div_up(jcp.ow, jcp.ow_block); + const size_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, + nb_oh, owb, nb_ow); + + for (size_t iwork = start; iwork < end; ++iwork) { + int oh = ohb * jcp.oh_block; + int ow = owb * jcp.ow_block; + const src_data_t *__restrict src = src_base + n * src_mb_stride + + g * src_g_stride; + const wei_data_t *__restrict wei = wei_base + g * wei_g_stride; + dst_data_t *__restrict dst = + dst_base + n * dst_mb_stride + g * dst_g_stride; + const int32_t *wei_comp = _wei_comp + g * jcp.oc; + const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh); + const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); + + if (jcp.im2col_sz) + jit_gemm_convolution_utils::im2col_u8<src_data_t>( + jcp, src, imtr, col, oh, h_step, ow, w_step); + + const int M = jcp.oc; + const int K = jcp.ks * jcp.ic; + const int N = h_step * w_step; + const int LDA = M * jcp.ngroups; + const int LDB = jcp.im2col_sz ? N : K; + const char *BT = jcp.im2col_sz ? "T" : "N"; + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + const float onef = 1.0, zerof = 0.0; + gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F", + &M, &N, &K, &onef, wei, &LDA, &off_a, + jcp.im2col_sz ? col : (uint8_t *)src, &LDB, &off_b, + &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c); + + auto wei_adj_scale = + (wei_md.extra().flags | memory_extra_flags::scale_adjust) + ? wei_md.extra().scale_adjust : 1.f; + + parallel(0, [&](int ithr, int nthr) { + size_t start, end; + balance211((size_t)N * jcp.oc, nthr, ithr, start, end); + (*pp_ker_)(dst + (oh * jcp.ow + ow) * pp_ker_->dst_os_stride_, + acc, bia_base, scales, nslope, sum_scale, + 1.f / wei_adj_scale, g, start, end); + }); + + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, + owb, nb_ow); + } +} + +template <data_type_t dst_type> +void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>:: +execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + auto scratchpad = this->scratchpad(ctx); + + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base, + bia_base, diff_src_base, scratchpad); + }); +} + +template <data_type_t dst_type> +void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>:: +execute_backward_data_thr(const int ithr, const int nthr, + const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base, + const char *bia_base, diff_src_data_t *diff_src_base, + const memory_tracking::grantor_t &scratchpad) const +{ + const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; + + const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md()); + const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1); + const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc; + + const auto wei_md = memory_desc_wrapper(pd()->weights_md(0)); + const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0; + + const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md()); + const size_t diff_src_mb_stride = diff_src_md.blk_off(1); + const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic; + const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1); + + /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */ + const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1); + const float *scales = pd()->attr()->output_scales_.scales_; + const size_t work_amount = jcp.ngroups * jcp.mb; + + auto col = scratchpad.get<acc_data_t>(key_conv_gemm_col) + + (ptrdiff_t)ithr * jcp.im2col_sz; + auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt) + + (ptrdiff_t)ithr * jcp.is * jcp.ic; + + int n{0}, g{0}; + size_t start = 0, end = 0; + + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups); + + for (size_t iwork = start; iwork < end; ++iwork) { + const diff_dst_data_t *diff_dst = diff_dst_base + + n * diff_dst_mb_stride + g * diff_dst_g_stride; + const wei_data_t *wei = wei_base + g * wei_g_stride; + diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride + + g * diff_src_g_stride; + + const int M = jcp.ks * jcp.ic; + const int N = jcp.os; + const int K = jcp.oc; + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + const float onef = 1.0, zerof = 0.0; + const int LD = K * jcp.ngroups; + + gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef, + wei, &LD, &off_a, diff_dst, &LD, &off_b, + &zerof, jcp.im2col_sz ? col : acc, &M, &off_c); + + if (jcp.im2col_sz) + jit_gemm_convolution_utils::col2im_s32(jcp, col, acc); + + parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) { + float d = (float)acc[is * jcp.ic + ic]; + if (jcp.with_bias) + d += get_bias(bia_base, g * jcp.ic + ic, + pd()->desc()->bias_desc.data_type); + d *= scales[(g * jcp.ic + ic) * scale_idx_mult]; + const size_t diff_src_off = is * diff_src_os_stride + ic; + diff_src[diff_src_off] = + qz_a1b0<float, diff_src_data_t>()(d); + }); + nd_iterator_step(n, jcp.mb, g, jcp.ngroups); + } +} + +using namespace data_type; + +template struct _gemm_x8s8s32x_convolution_fwd_t<u8, f32>; +template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s32>; +template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s8>; +template struct _gemm_x8s8s32x_convolution_fwd_t<u8, u8>; + +template struct _gemm_x8s8s32x_convolution_fwd_t<s8, f32>; +template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s32>; +template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s8>; +template struct _gemm_x8s8s32x_convolution_fwd_t<s8, u8>; + +template struct _gemm_u8s8s32x_convolution_bwd_data_t<f32>; +template struct _gemm_u8s8s32x_convolution_bwd_data_t<s32>; +template struct _gemm_u8s8s32x_convolution_bwd_data_t<s8>; +template struct _gemm_u8s8s32x_convolution_bwd_data_t<u8>; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp new file mode 100644 index 0000000000..9e77b890d5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp @@ -0,0 +1,266 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_X8S8S32X_CONVOLUTION_HPP +#define GEMM_X8S8S32X_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" +#include "gemm_convolution_utils.hpp" + +#include "gemm/gemm.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <data_type_t src_type, data_type_t dst_type> +struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR, + _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>); + + status_t init() { + using namespace data_type; + + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, s8, data_type::undef, dst_type, + s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, f32, s32, s8, u8)) + && !has_zero_dim_memory() + && set_default_formats_common( + dat_tag(), format_tag::any, dat_tag()) + && post_ops_ok() + && memory_desc_matches_tag(*src_md(), dat_tag()) + && memory_desc_matches_tag(*dst_md(), dat_tag()) + && set_or_check_wei_format(); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), src_md(), weights_md(0), dst_md(), + mkldnn_get_max_threads()); + } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + bool set_or_check_wei_format() { + using namespace format_tag; + + const bool is_src_s8 = src_md_.data_type == data_type::s8; + + memory_desc_t want_wei_md = weights_md_; + memory_desc_init_by_tag(want_wei_md, with_groups() ? hwigo : hwio); + + if (is_src_s8) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups() ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md_.format_kind == format_kind::any) { + weights_md_ = want_wei_md; + return true; + } + + return weights_md_ == want_wei_md; + } + + bool post_ops_ok() const { + using namespace mkldnn::impl::primitive_kind; + auto const &po = attr()->post_ops_; + auto is_relu = [&](int idx) { + return po.entry_[idx].is_relu(true, false); }; + + switch (po.len_) { + case 0: return true; + case 1: return is_relu(0) || po.contain(sum, 0); + case 2: return po.contain(sum, 0) && is_relu(1); + default: return false; + } + return false; + } + }; + + _gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true), pp_ker_(nullptr) + { pp_ker_ = new pp_ker_t(pd()); } + ~_gemm_x8s8s32x_convolution_fwd_t() { delete pp_ker_; } + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<data_type::s8>::type wei_data_t; + typedef typename prec_traits<dst_type>::type dst_data_t; + typedef typename prec_traits<data_type::s32>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + // XXX: this is throwaway code that will become unnecessary when we have a + // sufficiently advanced igemm jit generator that supports quantization, + // relu, and whatnot + class pp_ker_t : jit_generator { + public: + DECLARE_CPU_JIT_AUX_FUNCTIONS( + _gemm_x8s8s32x_convolution_fwd_t::pp_kernel); + pp_ker_t(const pd_t *pd); + + void operator()(dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, + float nslope, float sum_scale, float signed_scale, + int g, size_t start, size_t end); + + size_t dst_os_stride_; + + private: + void generate(); + + struct ker_args { + dst_data_t *dst; + const acc_data_t *acc; + const char *bias; + const float *scales; + float nslope; + float sum_scale; + float signed_scale; + size_t len; + size_t oc_offset; + }; + void(*ker_)(const ker_args *args); + + const jit_gemm_conv_conf_t &jcp_; + size_t OC_; + size_t OS_; + data_type_t bias_data_type_; + size_t bias_data_type_size_; + size_t scale_idx_mult_; + bool do_bias_; + bool do_relu_; + bool do_sum_; + bool do_signed_scaling_; + size_t vlen_; + }; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src_base, const wei_data_t *wei_base, + const char *bia_base, dst_data_t *dst_base, + const memory_tracking::grantor_t &scratchpad) const; + + int nthr_; + pp_ker_t *pp_ker_; + +}; + +template <data_type_t dst_type> +struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t{ + pd_t(engine_t *engine, + const convolution_desc_t *adesc, const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR, + _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>); + + status_t init() { + using namespace data_type; + + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(dst_type, s8, data_type::undef, u8, s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, f32, s32, s8, u8)) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), wei_tag(), dat_tag()) + && attr()->post_ops_.has_default_values() + && memory_desc_matches_tag(*diff_src_md(), dat_tag()) + && memory_desc_matches_tag(*diff_dst_md(), dat_tag()) + && memory_desc_matches_tag(*weights_md(), wei_tag()); + if (!ok) return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, + *desc(), diff_src_md(), weights_md(), diff_dst_md(), + mkldnn_get_max_threads()); + } + + virtual bool support_bias() const override { return true; } + + jit_gemm_conv_conf_t jcp_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + format_tag_t wei_tag() const { + return with_groups() ? format_tag::hwigo : format_tag::hwio; + } + }; + + _gemm_u8s8s32x_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd, true) {} + + typedef typename prec_traits<data_type::u8>::type diff_dst_data_t; + typedef typename prec_traits<data_type::s8>::type wei_data_t; + typedef typename prec_traits<dst_type>::type diff_src_data_t; + typedef typename prec_traits<data_type::s32>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + void execute_backward_data_thr(const int ithr, const int nthr, + const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base, + const char *bia_base, diff_src_data_t *diff_src_base, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp new file mode 100644 index 0000000000..1e435a233a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp @@ -0,0 +1,453 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "simple_q10n.hpp" + +#include "gemm/gemm.hpp" +#include "gemm_x8s8s32x_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace math; +using namespace format_tag; +using namespace memory_tracking::names; + +template<data_type_t src_type, data_type_t dst_type> +gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::pp_kernel_t( + const pd_t *pd, bool dst_is_acc) + : ker_(nullptr), OC_(pd->OC()) + , bias_data_type_(data_type::undef), bias_data_type_size_(0) + , scale_idx_mult_(0), do_bias_(false), do_relu_(false) +{ + using namespace types; + + scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1)); + + auto &post_ops = pd->attr()->post_ops_; + do_relu_ = post_ops.len_ == 1; + do_bias_ = pd->with_bias(); + bias_data_type_ = pd->desc()->bias_desc.data_type; + if (do_bias_) { + assert(bias_data_type_ != data_type::undef); + bias_data_type_size_ = data_type_size(bias_data_type_); + } + + if (!mayiuse(avx512_core)) + // use fallback code for older CPUs since they do not have optimized + // x8s8s32 GEMM anyways. The configuration variables above are used by + // the fallback code. + return; + else + generate(); +} + +template<data_type_t src_type, data_type_t dst_type> +void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::generate() +{ + using namespace Xbyak; + using namespace utils; + + // TODO: clean-up + Reg64 reg_param = abi_param1; + Reg64 reg_dst = rdx; + Reg64 reg_acc = rax; + Reg64 reg_bias = rbx; + Reg64 reg_scales = rsi; + + Reg64 reg_len = r8; + Reg64 reg_tmp = rcx; // intentional for shifting purposes + Reg64 reg_oc_offset = r9; + Reg64 reg_rem_mask = r10; + Opmask kreg_rem_mask = k1; + Opmask kreg_relu_cmp = k2; + + const size_t vlen = cpu_isa_traits<avx512_common>::vlen / sizeof(float); + + Zmm vreg_zero = Zmm(0); + Zmm vreg_scale = Zmm(1); + Zmm vreg_nslope = Zmm(2); + + auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); }; + auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); }; + + preamble(); + +#define PARAM_OFF(x) offsetof(ker_args, x) + mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]); + mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]); + mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]); + mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]); + mov(reg_len, ptr[reg_param + PARAM_OFF(len)]); + mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]); + vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]); + if (scale_idx_mult_ == 0) + vbroadcastss(vreg_scale, dword[reg_scales]); +#undef PARAM_OFF + + if (do_relu_ || dst_type == data_type::u8) + vxorps(vreg_zero, vreg_zero, vreg_zero); + + // Load accumulated value, convert to float, apply bias (if any), scaling, + // and relu (if any); then convert to destination type and store + auto compute = [&](size_t offset, int idx, bool apply_mask) { + auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)]; + + if (scale_idx_mult_ > 0) { + assert(scale_idx_mult_ == 1); + auto scale_addr = ptr[reg_scales + offset * sizeof(float)]; + auto vreg_scale_ = vreg_scale; + if (apply_mask) + vreg_scale_ = vreg_scale_ | kreg_rem_mask; + vmovups(vreg_scale, scale_addr); + } + + auto vreg_dst_ = vreg_dst(idx); + if (apply_mask) + vreg_dst_ = vreg_dst_ | kreg_rem_mask; + vcvtdq2ps(vreg_dst_, acc_addr); + + if (do_bias_) { + auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_]; + auto vreg_bias_ = vreg_bias(idx); + if (apply_mask) + vreg_bias_ = vreg_bias_ | kreg_rem_mask; + + switch (bias_data_type_) { + case data_type::s8: + vpmovsxbd(vreg_bias_, bias_addr); + break; + case data_type::u8: + vpmovzxbd(vreg_bias_, bias_addr); + break; + case data_type::s32: + case data_type::f32: + vmovups(vreg_bias_, bias_addr); + break; + default: assert(!"unimplemented"); + } + if (bias_data_type_ != data_type::f32) + vcvtdq2ps(vreg_bias(idx), vreg_bias(idx)); + vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx)); + } + + vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale); + if (do_relu_) { + vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os); + vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope); + } + + if (dst_type == data_type::u8) + vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero); + + if (dst_type != data_type::f32) { + vcvtps2dq(vreg_dst(idx), vreg_dst(idx)); + } + + auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)]; + switch (dst_type) { + case data_type::s8: + vpmovsdb(dst_addr, vreg_dst_); + break; + case data_type::u8: + vpmovusdb(dst_addr, vreg_dst_); + break; + case data_type::f32: + case data_type::s32: + vmovups(dst_addr, vreg_dst_); + break; + default: assert(!"unimplemented"); + } + }; + + // Advance all pointers by an immediate + auto advance_ptrs_imm = [&](size_t offset) { + add(reg_dst, offset * sizeof(dst_data_t)); + add(reg_acc, offset * sizeof(acc_data_t)); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + add(reg_scales, offset * sizeof(float)); + } + if (do_bias_) + add(reg_bias, offset * bias_data_type_size_); + }; + + // Advance all pointers by a value stored in a register + auto advance_ptrs_reg = [&](Reg64 offset) { + lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]); + lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); + } + if (do_bias_) + lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]); + }; + + // Rewind pointers that point to data that is indixed by output channel + // (bias or per-oc scaling factors) + auto rewind_ptrs = [&]() { + if (do_bias_) + sub(reg_bias, OC_ * bias_data_type_size_); + if (scale_idx_mult_) { + assert(scale_idx_mult_ == 1); + sub(reg_scales, OC_ * sizeof(float)); + } + }; + + // <-------------------- OC -------------------------------> + // + // ^ +....................+----------------------------------+ + // | : not accessed | Prologue loop | + // | +--------------------+----------------------------------+ + // | | + // M | Main loop (unrolled) | + // B | | + // +--------------------------------+----------------------+ + // | | Epilogue loop | not accessed : + // v +--------------------------------+......................+ + + Label prologue_end; + cmp(reg_oc_offset, 0); + je(prologue_end, T_NEAR); + + // Prologue loop + { + mov(reg_tmp, OC_); + sub(reg_tmp, reg_oc_offset); + cmp(reg_tmp, reg_len); + cmovg(reg_tmp, reg_len); + sub(reg_len, reg_tmp); + + Label prologue_loop, prologue_loop_tail, prologue_loop_end; + cmp(reg_tmp, vlen); + jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?) + L(prologue_loop); { + compute(0, 0, false); + advance_ptrs_imm(vlen); + sub(reg_tmp, vlen); + cmp(reg_tmp, vlen); + jge(prologue_loop, T_NEAR); + } + + L(prologue_loop_tail); + mov(reg_rem_mask, 1); + shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here + sub(reg_rem_mask, 1); + jz(prologue_loop_end, T_NEAR); + + kmovq(kreg_rem_mask, reg_rem_mask); + compute(0, 0, true); + advance_ptrs_reg(reg_tmp); + + L(prologue_loop_end); + rewind_ptrs(); + } + L(prologue_end); + + // Main loop + Label main_loop_end; + { + cmp(reg_len, OC_); + jle(main_loop_end, T_NEAR); + + Label main_loop; + L(main_loop); { + size_t def_unroll = 4; + size_t max_unroll = 13; + + size_t OC_loop, OC_tail; + if (OC_ < max_unroll * vlen) { + // Fully unroll small loops + OC_loop = 0; + OC_tail = OC_; + } else { + OC_loop = vlen * def_unroll; + OC_tail = OC_ % OC_loop; + } + + assert(!!OC_loop || !!OC_tail); + + if (OC_tail % vlen) { + int vlen_tail = OC_tail % vlen; + unsigned tail_mask = (1 << vlen_tail) - 1; + mov(reg_tmp, tail_mask); + kmovq(kreg_rem_mask, reg_tmp); + } + + if (OC_loop) { + mov(reg_tmp, rnd_dn(OC_, OC_loop)); + Label oc_loop; + L(oc_loop); { + for (size_t offset = 0; offset < OC_loop; offset += vlen) + compute(offset, offset / vlen, false); + advance_ptrs_imm(OC_loop); + sub(reg_tmp, OC_loop); + jnz(oc_loop); + } + } + + if (OC_tail) { + for (size_t offset = 0; offset < OC_tail; offset += vlen) { + bool use_mask = (offset + vlen) > OC_tail; + compute(offset, offset / vlen, use_mask); + } + advance_ptrs_imm(OC_tail); + } + + rewind_ptrs(); + sub(reg_len, OC_); + cmp(reg_len, OC_); + jge(main_loop, T_NEAR); + } + } + L(main_loop_end); + + // Epilogue loop + Label epilogue_end; + { + cmp(reg_len, 0); + je(epilogue_end, T_NEAR); + + Label epilogue_loop, epilogue_loop_tail; + cmp(reg_len, vlen); + jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?) + L(epilogue_loop); { + compute(0, 0, false); + sub(reg_len, vlen); + advance_ptrs_imm(vlen); + cmp(reg_len, vlen); + jge(epilogue_loop, T_NEAR); + } + + L(epilogue_loop_tail); + mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift + mov(reg_rem_mask, 1); + shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16 + sub(reg_rem_mask, 1); + jz(epilogue_end, T_NEAR); + kmovq(kreg_rem_mask, reg_rem_mask); + compute(0, 0, true); + } + + L(epilogue_end); + + postamble(); + + ker_ = getCode<decltype(ker_)>(); +} + +template<data_type_t src_type, data_type_t dst_type> +void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::operator ()( + dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, float nslope, + size_t start, size_t end) +{ + using math::get_bias; + + if (end <= start) + return; + + if (ker_) { + // JIT + ker_args args; + size_t oc_offset = start % OC_; + args.dst = dst + start; + args.acc = acc + start; + args.bias = bias + oc_offset * bias_data_type_size_; + args.scales = scales + scale_idx_mult_ * oc_offset; + args.nslope = nslope; + args.len = end - start; + args.oc_offset = oc_offset; + ker_(&args); + } else { + // Fallback + size_t oc = start % OC_; + for (size_t i = start; i < end; i++) { + float d = (float)acc[i]; + float b = get_bias(bias, oc, bias_data_type_); + d = d + b; + d *= scales[oc * scale_idx_mult_]; + if (do_relu_ && d < 0) + d *= nslope; + dst[i] = qz_a1b0<float, dst_data_t>()(d); + oc = (oc == OC_ - 1) ? 0 : oc + 1; + } + } +}; + +template <data_type_t src_type, data_type_t dst_type> +void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + + bool wei_tr = memory_desc_matches_one_of_tag( + *pd()->weights_md(), oiw, oihw, oidhw, oi); + + const int M = OC; + const int N = MB; + const int K = pd()->IC_total_padded(); + const int8_t off_a = 0, off_b = 0; + const int32_t off_c = 0; + + const float *scales = pd()->attr()->output_scales_.scales_; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f; + + acc_data_t *acc = pd()->dst_is_acc_ + ? (acc_data_t *)dst + : scratchpad(ctx).template get<acc_data_t>(key_iprod_int_dat_in_acc_dt); + + const float onef = 1.0, zerof = 0.0; + gemm_s8x8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef, weights, + wei_tr ? &K : &M, &off_a, src, &K, &off_b, &zerof, acc, &M, &off_c); + + if (!pd()->attr()->has_default_values() || !pd()->dst_is_acc_ + || pd()->with_bias()) { + const bool force_sequential = MB * OC < 2000; + parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) { + size_t start, end; + balance211((size_t)OC * MB, nthr, ithr, start, end); + (*pp_kernel_)(dst, acc, bias, scales, nslope, start, end); + }); + } +} + +using namespace data_type; + +template struct gemm_x8s8s32x_inner_product_fwd_t<u8, f32>; +template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s32>; +template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s8>; +template struct gemm_x8s8s32x_inner_product_fwd_t<u8, u8>; +template struct gemm_x8s8s32x_inner_product_fwd_t<s8, f32>; +template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s32>; +template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s8>; +template struct gemm_x8s8s32x_inner_product_fwd_t<s8, u8>; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp new file mode 100644 index 0000000000..ac6a5c8f85 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp @@ -0,0 +1,166 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GEMM_X8S8S32X_INNER_PRODUCT_HPP +#define GEMM_X8S8S32X_INNER_PRODUCT_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "gemm/gemm.hpp" +#include "jit_generator.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t src_type, impl::data_type_t dst_type> +struct gemm_x8s8s32x_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T(src_type == data_type::u8 + ? IGEMM_S8U8S32_IMPL_STR + : IGEMM_S8S8S32_IMPL_STR, + gemm_x8s8s32x_inner_product_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == src_type + && dst_md()->data_type == dst_type + && weights_md()->data_type == s8 + && IMPLICATION(with_bias(), utils::one_of( + weights_md(1)->data_type, f32, s32, s8, u8)) + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_, + attr()->post_ops_.entry_[0].is_relu(true, false)) + && dense_gemm_consitency_check(src_md(), weights_md(), + dst_md()); + if (!ok) return status::unimplemented; + + dst_is_acc_ = utils::one_of(dst_type, s32, f32); + + init_scratchpad(); + + return status::success; + } + + bool dst_is_acc_; + + protected: + status_t set_default_params() { + using namespace format_tag; + if (src_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md_, + utils::pick(ndims() - 2, nc, nwc, nhwc, ndhwc))); + } + if (dst_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_md_, nc)); + if (weights_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md_, + utils::pick(ndims() - 2, io, wio, hwio, dhwio))); + } + return inner_product_fwd_pd_t::set_default_params(); + } + + private: + void init_scratchpad() { + if (!dst_is_acc_) { + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book( + memory_tracking::names::key_iprod_int_dat_in_acc_dt, + sizeof(acc_data_t) * MB() * OC()); + } + } + }; + + gemm_x8s8s32x_inner_product_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + { pp_kernel_ = new pp_kernel_t(apd, pd()->dst_is_acc_); } + ~gemm_x8s8s32x_inner_product_fwd_t() { delete pp_kernel_; } + + typedef typename prec_traits<dst_type>::type data_t; + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<data_type::s8>::type wei_data_t; + typedef typename prec_traits<dst_type>::type dst_data_t; + typedef typename prec_traits<data_type::s32>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + // XXX: this is throwaway code that will become unnecessary when we have a + // sufficiently advanced igemm jit generator that supports quantization, + // relu, and whatnot + class pp_kernel_t: jit_generator { + public: + DECLARE_CPU_JIT_AUX_FUNCTIONS( + gemm_x8s8s32x_inner_product_fwd_t::pp_kernel); + pp_kernel_t(const pd_t *pd, bool dst_is_acc); + + void operator()(dst_data_t *dst, const acc_data_t *acc, + const char *bias, const float *scales, float nslope, + size_t start, size_t end); + private: + void generate(); + + struct ker_args { + dst_data_t *dst; + const acc_data_t *acc; + const char *bias; + const float *scales; + float nslope; + size_t len; + size_t oc_offset; + }; + void (*ker_)(const ker_args *args); + + size_t OC_; + data_type_t bias_data_type_; + size_t bias_data_type_size_; + size_t scale_idx_mult_; + bool do_bias_; + bool do_relu_; + }; + + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + pp_kernel_t *pp_kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp new file mode 100644 index 0000000000..6fa251d465 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp @@ -0,0 +1,674 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_avx2_1x1_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, reg_bcast_loop_work); + + Label bcast_loop, bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + generate_reduce_loop(load_loop_blk, jcp.ur); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + generate_reduce_loop(load_loop_blk, jcp.ur_tail); + L(bcast_loop_tail_out); + } +} + +void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop( + int load_loop_blk, int ur) +{ + auto vreg_load = [=](int i) { + return Ymm(ur * load_loop_blk + i); + }; + + auto vreg_accum = [=](int i, int j) { + return Ymm(j * load_loop_blk + i); + }; + + auto bias_ptr = [=](int i) { + return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i]; + }; + + auto bcast_ptr = [=](int u, int j) { + assert(j < jcp.ur); + assert(u <= jcp.reduce_loop_unroll); + size_t offt; + if (one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)) + { + assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) + ? jcp.oc_block : jcp.ic_block); + auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; + offt = (u == jcp.reduce_loop_unroll) + ? (height + j) * jcp.reduce_loop_unroll + : j * jcp.reduce_loop_unroll + u; + } else + offt = u * jcp.ic_block + j; + return ptr[aux_reg_bcast_data + sizeof(float) * offt]; + }; + + auto load_ptr = [=](int u, int i) { + size_t offt; + size_t u0 = u % jcp.reduce_loop_unroll; + size_t u1 = u / jcp.reduce_loop_unroll; + switch (jcp.prop_kind) { + case backward_data: + offt = (i * jcp.oc_block + u0) * jcp.ic_block; + break; + case backward_weights: + offt = (i * jcp.os + u0) * jcp.oc_block; + break; + default: + offt = (i * jcp.ic + u0) * jcp.oc_block; + } + return ptr[aux_reg_load_data + + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt]; + }; + + auto output_ptr = [=](int i, int j) { + switch (jcp.prop_kind) { + case backward_data: + return ptr[aux_reg_output_data + + (i * jcp.is + j) * jcp.ic_block * sizeof(float)]; + case backward_weights: + return ptr[aux_reg_output_data + + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale + + sizeof(float) * jcp.oc_block * j]; + default: + return ptr[aux_reg_output_data + + (i * jcp.os + j) * jcp.oc_block * sizeof(float)]; + } + }; + + auto init = [=]() { + Label init_done, init_zero; + + if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, + forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero); + + for (int i = 0; i < load_loop_blk; i++) + for (int j = 0; j < ur; ++j) + vmovups(vreg_accum(i, j), bias_ptr(i)); + jmp(init_done); + } + + L(init_zero); + for (int i = 0; i < load_loop_blk; ++i) + for (int j = 0; j < ur; ++j) { + auto r = vreg_accum(i, j); + vxorps(r, r, r); + } + + L(init_done); + for (int i = 0; i < load_loop_blk; ++i) + vmovups(vreg_load(i), load_ptr(0, i)); + vbroadcastss(vreg_bcast, bcast_ptr(0, 0)); + }; + + auto store = [=]() { + Label store_noadd; + + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + auto r = vreg_accum(i, j); + vaddps(r, r, output_ptr(i, j)); + } + + L(store_noadd); + + if (jcp.with_eltwise) { + assert(ur * load_loop_blk < 14); + + Label store_norelu; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_norelu, T_NEAR); + + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + L(store_norelu); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + vmovups(output_ptr(i, j), vreg_accum(i, j)); + } + }; + + auto fma_block = [=](bool last_block) { + for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { + for (int j = 0; j < ur; ++j) { + for (int i = 0; i < load_loop_blk; ++i) { + if (mayiuse(avx2)) + vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast); + else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support + vmulps(vtmp, vreg_bcast, vreg_load(i)); + vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp); + } + if (j == ur - 1 && !(last_block + && u == jcp.reduce_loop_unroll - 1)) + vmovups(vreg_load(i), load_ptr(u + 1, i)); + } + if (j < ur - 1) + vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1)); + } + if (!last_block || u < jcp.reduce_loop_unroll - 1) + vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0)); + } + }; + + Label reduce_loop, reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} + +void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) +{ + if (!jcp.with_bias || jcp.prop_kind != backward_weights) + return; + + Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; + Label diff_bias_load; + + auto diff_bias_ptr = [=](int i) { + return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)]; + }; + + auto load_ptr = [=](int u, int i) { + return ptr[aux_reg_load_data + + (i * jcp.os + u) * jcp.oc_block * sizeof(float)]; + }; + + auto diff_bias_reg = [=](int i) { return Ymm(i); }; + + mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); + cmp(reg_diff_bias_data, 0); + je(diff_bias_loop_out, T_NEAR); + + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(diff_bias_load, T_NEAR); + + for (int i = 0; i < load_loop_blk; ++i) { + auto r = diff_bias_reg(i); + vxorps(r, r, r); + } + jmp(diff_bias_init_out, T_NEAR); + + L(diff_bias_load); + for (int i = 0; i < load_loop_blk; ++i) + vmovups(diff_bias_reg(i), diff_bias_ptr(i)); + + L(diff_bias_init_out); + mov(aux_reg_load_data, reg_load_data); + mov(reduce_loop_iter, reg_reduce_loop_work); + L(diff_bias_loop); { + for(int u = 0; u < jcp.reduce_loop_unroll; ++u) + for (int i = 0; i < load_loop_blk; ++i) + vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i)); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jnz(diff_bias_loop, T_NEAR); + } + + for (int i = 0; i < load_loop_blk; i++) + vmovups(diff_bias_ptr(i), diff_bias_reg(i)); + add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + + L(diff_bias_loop_out); +} + +void jit_avx2_1x1_conv_kernel_f32::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + if (jcp.with_bias) { + if (jcp.prop_kind == backward_weights) { + sub(rsp, stack_space_needed); + mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + } else + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + } + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto generate_load_loop_body = [=] (int load_loop_blk) { + generate_bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + add(reg_output_data, + load_loop_blk * jcp.os * jcp.oc_block * sizeof(float)); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.is * jcp.ic_block * sizeof(float)); + break; + case backward_weights: + for (int i = 0; i < load_loop_blk; i++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + Label load_loop_blk_8; + Label load_loop_blk_16; + Label load_loop_blk_24; + Label load_loop_blk_end; + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16, T_NEAR); + + cmp(reg_load_loop_work, 16); + jle(load_loop_blk_16, T_NEAR); + + L(load_loop_blk_24); { + generate_diff_bias_loop(3); + generate_load_loop_body(3); + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16); + cmp(reg_load_loop_work, 24); + jge(load_loop_blk_24); + } + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + L(load_loop_blk_16); { + generate_diff_bias_loop(2); + generate_load_loop_body(2); + cmp(reg_load_loop_work, 16); + jge(load_loop_blk_16); + } + + L(load_loop_blk_8); { + cmp(reg_load_loop_work, 0); + je(load_loop_blk_end, T_NEAR); + generate_diff_bias_loop(1); + generate_load_loop_body(1); + } + + L(load_loop_blk_end); + + if (jcp.with_bias && jcp.prop_kind == backward_weights) + add(rsp, 8); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(avx)) return status::unimplemented; + + // TODO (Roma): this code is duplicated from the generic kernel; maybe the + // configuration struct could do some stuff below + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu) + return status::unimplemented; + } + + const int is_bwd_d = jcp.prop_kind == backward_data; + + format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c; + format_tag_t wei_tag = with_groups + ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, + gOIhw8o8i) + : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, + OIhw8o8i); + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + + const int simd_w = 8; + + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + args_ok = true + && jcp.ih == jcp.oh && jcp.iw == jcp.ow + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + // TODO: remove this restriction + // optimized 1x1 bwd_w does not support Intel AVX + if (jcp.prop_kind == backward_weights && !mayiuse(avx2)) + return status::unimplemented; + + jcp.ic_block = jcp.oc_block = simd_w; + + jcp.ur = mayiuse(avx2) ? 4 : 3; // Intel AVX support + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.is * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + load_blocking = 120; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 192; + reduce_blocking = 128; // affects L1$ utilization + } else if (jcp.prop_kind == backward_data) { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.os; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.os * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.load_loop_iter_step = jcp.ic_block; + + load_blocking = 96; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 196; + reduce_blocking = 64; // affects L1$ utilization + } else if (jcp.prop_kind == backward_weights) { + jcp.reduce_dim = jcp.os; + jcp.reduce_block = 1; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); + jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); + jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + while (true) { + if (load_blocking <= 32) break; + else if (load_blocking % 2 == 0) load_blocking /= 2; + else if (load_blocking % 3 == 0) load_blocking /= 3; + else break; + } + load_blocking *= jcp.load_block; + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + while (true) { + if (bcast_blocking <= 9) break; + else if (bcast_blocking % 2 == 0) bcast_blocking /= 2; + else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; + else break; + } + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + reduce_blocking = 128; // affects L1$ utilization + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + + assert(jcp.bcast_block % jcp.ur == 0); + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +void jit_avx2_1x1_conv_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp new file mode 100644 index 0000000000..bfdeb2b18d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp @@ -0,0 +1,110 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX2_1x1_CONV_KERNEL_F32_HPP +#define JIT_AVX2_1x1_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_1x1_conv_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32) + + jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx2>(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode(); + } + + ~jit_avx2_1x1_conv_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using ymm_t = const Xbyak::Ymm; + + reg64_t reg_bcast_data = rax; + reg64_t reg_load_data = rsi; + reg64_t reg_output_data = rbx; + reg64_t aux_reg_bcast_data = rdx; + reg64_t aux1_reg_bcast_data = abi_not_param1; + reg64_t aux_reg_load_data = abi_param1; + reg64_t aux_reg_output_data = rbp; + reg64_t reg_load_loop_work = r9; + reg64_t reg_bcast_loop_work = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t load_loop_iter = r13; + reg64_t bcast_loop_iter = r14; + reg64_t reduce_loop_iter = r15; + reg64_t imm_addr64 = reduce_loop_iter; + reg64_t reg_reduce_pos_flag = r8; + reg64_t reg_output_stride = r12; + reg64_t reg_bias_data = r12; + reg64_t reg_diff_bias_data = bcast_loop_iter; + + int reg_diff_bias_data_stack_offt = 0; + int stack_space_needed = 8; + + ymm_t vreg_bcast = ymm_t(15); + ymm_t vtmp = ymm_t(14); + + jit_uni_eltwise_injector_f32<avx2> *eltwise_injector_; + + void generate_bcast_loop(int load_loop_blk); + void generate_reduce_loop(int load_loop_blk, int ur); + void generate_diff_bias_loop(int load_loop_blk); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp new file mode 100644 index 0000000000..f116ac9056 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp @@ -0,0 +1,545 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx2_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + +/* convolution forward */ + +void jit_avx2_1x1_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).get<data_t>(key_conv_rtus_space); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + const int ndims = dst_d.ndims(); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto ker = [&](const int ithr, const int nthr) { + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t<avx2>::call_params_t(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int iwork = start; + while (iwork < end) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + int bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); + rp.os = p.bcast_dim; + + int ocb = 0; + while (ocb < jcp.nb_load) { + const int load_step = step(jcp.nb_load_blocking, + jcp.nb_load - ocb, jcp.nb_load_blocking_max); + + const int _ocb = g * nb_oc + ocb; + p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + load_step * jcp.oc_block); + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + + p.output_data = &dst[dst_off]; + + p.bias_data = &bias[_ocb * jcp.oc_block]; + + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + p.first_last_flag = 0 + | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking >= nb_ic + ? FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic, + nb_ic_blocking * jcp.ic_block); + rp.icb = p.reduce_dim / jcp.reduce_block; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int _icb = g * nb_ic + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + + if (ocb == 0) { + rp.src = src + data_blk_off(src_d, n, _icb, ih, iw); + rtus_driver_->ker_(&rp); + } + + p.bcast_data = rp.ws; + } else + p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw); + + kernel_->jit_ker(&p); + } + + ocb += load_step; + } + + iwork += bcast_step; + } + }; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad(ctx).get<data_t>(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, ker); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +/* convolution backward wtr data */ + +void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).get<data_t>(key_conv_rtus_space); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + const int ndims = diff_dst_d.ndims(); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int nb_ic = jcp.nb_load; + const int nb_oc = jcp.nb_reduce; + const int os_block = jcp.bcast_block; + const int nb_oc_blocking = jcp.nb_reduce_blocking; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto ker = [&](const int ithr, const int nthr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t<avx2>::call_params_t(); + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int load_step = 0; + for (int icb = 0; icb < jcp.nb_load; icb += load_step) { + load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, + jcp.nb_load_blocking_max); + + p.load_dim = this_block_size(icb * jcp.ic_block, jcp.ic, + load_step * jcp.ic_block); + rp.icb = p.load_dim / jcp.ic_block; + + int bcast_step; + for (int iwork = start; iwork < end; iwork += bcast_step) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, end - iwork); + + const int os = osb * os_block; + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + const int _icb = g * nb_ic + icb; + rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw); + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_; + p.output_data = rp.ws; + } else + p.output_data = rp.src; + + for (int ocb = 0; ocb < jcp.nb_reduce; + ocb += jcp.nb_reduce_blocking) { + const int _ocb = g * nb_oc + ocb; + size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, + ow); + p.bcast_data = &diff_dst[diff_dst_off]; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + + p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + nb_oc_blocking * jcp.oc_block); + + kernel_->jit_ker(&p); + } + + if (pd()->rtus_.reduce_src_) + rtus_driver_->ker_(&rp); + } + } + }; + + parallel(0, ker); +} + +/* convolution backward wtr weights */ + +jit_avx2_1x1_convolution_bwd_weights_t::jit_avx2_1x1_convolution_bwd_weights_t( + const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , rtus_driver_(nullptr) +{ + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + reducer_weights_ = + new cpu_reducer_2d_t<data_type::f32>(pd()->reducer_wei_conf_); + reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_); + init_rtus_driver<avx2>(this); +} + +void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto scratchpad = this->scratchpad(ctx); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space); + + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in; + + auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_wei); + auto rw = this->reducer_weights_; + rw->init(reducer_wei_scratchpad); + + const int ndims = diff_dst_d.ndims(); + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int nb_ic = jcp.nb_bcast; + const int nb_ic_blocking = jcp.nb_bcast_blocking; + const int bcast_work = div_up(nb_ic, nb_ic_blocking); + + const int nb_oc = jcp.nb_load; + const int nb_oc_blocking = jcp.nb_load_blocking; + const int load_work = div_up(nb_oc, nb_oc_blocking); + + const int sp_dim = jcp.reduce_dim; + const int mb_sp_work = jcp.mb * sp_dim; + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image, + data_t *store_to, size_t store_to_ld, const data_t *diff_dst, + const data_t *src, int ithr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t<avx2>::call_params_t(); + + p.output_stride = store_to_ld * sizeof(float); + const int sp_step_def = jcp.nb_reduce_blocking * jcp.reduce_block; + + int oc_b_step = 0; + for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) { + oc_b_step = step(12, nb_oc_blocking - oc_b, 18); + p.load_dim = oc_b_step * jcp.oc_block; + + int ic_b_step = 0; + for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) { + ic_b_step = step(12, nb_ic_blocking - ic_b, 18); + p.bcast_dim = ic_b_step * jcp.ic_block; + rp.icb = p.bcast_dim / jcp.ic_block; + + p.output_data = store_to + oc_b * store_to_ld + + ic_b * jcp.ic_block * jcp.oc_block; + + /* spatial reduction */ + int sp_step = 0; + for (int sp = sp_start; sp < sp_end; sp += sp_step) { + sp_step = step(sp_step_def, sp_end - sp, 192); + p.reduce_dim = sp_step; + rp.os = p.reduce_dim; + + p.first_last_flag = sp == sp_start && first_image + ? FLAG_REDUCE_FIRST : 0; + + p.load_data = diff_dst + + (oc_b * jcp.reduce_dim + sp) * jcp.oc_block; + + if (pd()->rtus_.reduce_src_) { + const int oh = sp / jcp.ow; + const int ow = sp % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + (ic_b * jcp.is + sp) * jcp.ic_block; + if (ndims == 3) + rp.src = src + + iw * src_d.blocking_desc().strides[2]; + else + rp.src = src + + ih * src_d.blocking_desc().strides[2] + + iw * src_d.blocking_desc().strides[3]; + + if (oc_b == 0) + rtus_driver_->ker_(&rp); + + p.bcast_data = rp.ws; + } else + p.bcast_data = src + + (ic_b * jcp.reduce_dim + sp) * jcp.ic_block; + + kernel_->jit_ker(&p); + } + } + } + }; + + auto ker = [&](const int ithr, const int nthr) { + assert(nthr == rw->balancer().nthr_); + + const int w_njobs = rw->balancer().ithr_njobs(ithr); + if (w_njobs == 0) return; + + /* setup: independent work (oc, ic) */ + const int w_job_start = rw->balancer().ithr_job_off(ithr); + int g{0}, load_i{0}, bcast_i{0}; + nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work, + bcast_i, bcast_work); + + /* setup: reduction work (mb, sp) */ + int mb_sp_start{0}, mb_sp_end{0}; + balance211(mb_sp_work, rw->balancer().nthr_per_group_, + rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end); + int img_start{0}, sp_start{0}; + nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim); + + /* independent work */ + for (int iwork = 0; iwork < w_njobs; ++iwork) { + const int oc_b = nb_oc_blocking * load_i; + const int ic_b = nb_ic_blocking * bcast_i; + + const int _ic_b = g * nb_ic + ic_b; + const int _oc_b = g * nb_oc + oc_b; + + data_t *store_to; + size_t store_to_ld; + + if (rw->balancer().nthr_per_group_ == 1) { + const size_t off = pd()->with_groups() + ? diff_weights_d.blk_off(g, oc_b, ic_b) + : diff_weights_d.blk_off(oc_b, ic_b); + store_to = &diff_weights[off]; + store_to_ld = jcp.ic * jcp.oc_block; + } else { + const size_t off = iwork * rw->balancer().job_size_; + store_to = + rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off; + store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block; + } + + /* reduction work */ + int img = img_start; + int sp = sp_start; + int sp_step = 0; + for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step) + { + sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp); + + const bool first_image = img == img_start; + oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to, + store_to_ld, &diff_dst[diff_dst_d.blk_off(img, _oc_b)], + &src[src_d.blk_off(img, _ic_b)], ithr); + + sp = 0; + img += 1; + } + + nd_iterator_step(g, jcp.ngroups, load_i, load_work, bcast_i, + bcast_work); + } + rw->reduce(ithr, diff_weights, reducer_wei_scratchpad); + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * nb_oc + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = + rb->get_local_ptr(ithr, diff_bias, reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 8; ++o) d_bias[o] = 0.; + + for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 8; ++o) + d_bias[o] += d_dst[o]; + d_dst += 8; + } + + nd_iterator_step(g, jcp.ngroups, ocb, nb_oc); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(0, [&](const int ithr, const int nthr) { + ker(ithr, nthr); + if (pd()->with_bias()) + ker_bias(ithr, nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + for (int oc = 0; oc < jcp.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp new file mode 100644 index 0000000000..9762242173 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp @@ -0,0 +1,344 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_1x1_CONVOLUTION_HPP +#define CPU_JIT_AVX2_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx2_1x1_conv_kernel_f32.hpp" +#include "jit_uni_1x1_conv_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_1x1_convolution_fwd_t: public cpu_primitive_t { + // TODO: (Roma) Code duplication duplication! Remove with templates + // (maybe...)! + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *src_d, *weights_md(), *dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template <cpu_isa_t isa, typename conv_t> + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + init_rtus_driver<avx2>(this); + } + + ~jit_avx2_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + rtus_driver_t<avx2> *rtus_driver_; +}; + +struct jit_avx2_1x1_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *diff_src_d = diff_src_md(); + rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), + *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i) + : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template <cpu_isa_t isa, typename conv_t> + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , rtus_driver_(nullptr) + { + kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + init_rtus_driver<avx2>(this); + } + + ~jit_avx2_1x1_convolution_bwd_data_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + rtus_driver_t<avx2> *rtus_driver_; +}; + +struct jit_avx2_1x1_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), + jit_avx2_1x1_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, diff_dst_md()); + + status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, + *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), + *attr()); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); + + rtus_prepare_space_info(this, scratchpad); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_wei); + reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_; + cpu_reducer_2d_t<data_type::f32>::conf_t reducer_wei_conf_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + + private: + void init_balancers() { + const int ic_block = jcp_.bcast_block; + const int nb_ic = jcp_.nb_bcast; + const int nb_ic_blocking = jcp_.nb_bcast_blocking; + const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking); + + const int oc_block = jcp_.load_block; + const int nb_oc = jcp_.nb_load; + const int nb_oc_blocking = jcp_.nb_load_blocking; + const int load_work = utils::div_up(nb_oc, nb_oc_blocking); + + const int job_size + = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block; + const int njobs_x = bcast_work; + const int njobs_y = jcp_.ngroups * load_work; + + const int max_threads = mkldnn_get_max_threads(); + const size_t max_buffer_size = max_threads * job_size * 8; + + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(max_threads, + oc_block, jcp_.ngroups * jcp_.oc / oc_block, + jcp_.mb, max_buffer_size)); + } + + reducer_wei_conf_.init( + reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x, + jcp_.mb * jcp_.nb_reduce, max_buffer_size), + job_size / nb_oc_blocking, nb_oc_blocking, ic_block, + nb_ic * ic_block * oc_block, nb_oc); + } + }; + + template <cpu_isa_t isa, typename conv_t> + friend void init_rtus_driver(conv_t *self); + + jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd); + + ~jit_avx2_1x1_convolution_bwd_weights_t() { + delete kernel_; + delete rtus_driver_; + delete reducer_weights_; + delete reducer_bias_; + } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_1x1_conv_kernel_f32 *kernel_; + cpu_reducer_2d_t<data_type::f32> *reducer_weights_; + cpu_reducer_t<data_type::f32> *reducer_bias_; + rtus_driver_t<avx2> *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp new file mode 100644 index 0000000000..e24770a2da --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp @@ -0,0 +1,1501 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_avx2_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); + int jj_end = ur_w + - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w)); + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t inp_off; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) + inp_off = sizeof(float)*((size_t)ifm2*id*ih*iw + + (ki*dilate_w + jj*stride_w - pad_l)); + else + inp_off = sizeof(float)*((ki*dilate_w + jj*stride_w + - pad_l)*ic_blk + ifm2); + vbroadcastss(Ymm(oc_blocks * ur_w + jj), + make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + + ki * ic_blk * oc_blk + ifm2 * oc_blk; + vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * ker_off]); + for (int jj = jj_start; jj < jj_end; jj++) + if (mayiuse(avx2)) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(oc_blocks * ur_w + jj), ymm15); + else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support + vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp); + } + } + } + } +} + +void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w, + int pad_l, int pad_r, char pad_tag, + int oc_blocks, char oc_blocks_tag) +{ + Label kw_loop; + + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + xor_(ki_iter, ki_iter); + L(kw_loop); + { + int jj_start = 0; + int jj_end = ur_w; + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t inp_off; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) + inp_off = sizeof(float)*((size_t)ifm2 * id * ih * iw + + (jj * stride_w - pad_l)); + else + inp_off = sizeof(float)*((jj * stride_w - pad_l) * ic_blk + + ifm2); + vbroadcastss(Ymm(oc_blocks * ur_w + jj), + make_safe_addr(aux_reg_input, inp_off, reg_long_offt)); + } + for (int ii = 0; ii < oc_blocks; ii++) { + int aux_kernel_offset = + ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ifm2 * oc_blk; + vmovups(ymm15, ptr[aux_reg_kernel + + sizeof(float) * aux_kernel_offset]); + for (int jj = jj_start; jj < jj_end; jj++) + if (mayiuse(avx2)) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(oc_blocks * ur_w + jj), ymm15); + else { // Intel AVX support + vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp); + } + } + } + add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? dilate_w : ic_blk * dilate_w)); + + inc(ki_iter); + cmp(ki_iter, kw); + jl(kw_loop, T_NEAR); + } +} + +void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w, + int pad_l, int pad_r, char pad_tag, + int oc_blocks, char oc_blocks_tag) +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ow = jcp.ow; + int oh = jcp.oh; + int od = jcp.od; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : ic_blk; + const int inp_off = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? dilate_w : ic_blk * dilate_w; + + Label init_done, init_first; + + if (!jcp.with_sum) { + test(reg_ci_flag, FLAG_IC_FIRST); + jne(init_first, T_NEAR); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk; + vmovups(Ymm(ur_w * ii + jj), + make_safe_addr(reg_output, offt, reg_long_offt)); + } + } + + if (jcp.with_sum && jcp.with_bias) { + test(reg_ci_flag, FLAG_IC_FIRST); + je(init_done, T_NEAR); + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + yword[reg_bias + sizeof(float) * ii * oc_blk]); + } + + jmp(init_done); + + L(init_first); + if (this->jcp.with_bias) { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + vmovups(Ymm(ur_w * ii + jj), + yword[reg_bias + sizeof(float) * ii * oc_blk]); + } else { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj)); + } + + L(init_done); + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + } + + Label skip_kh_loop, skip_kd_loop, kd_loop; + if (jcp.ndims == 5) { + push(reg_output); + push(oi_iter); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_input); + + if ((jcp.dilate_d >= jcp.id) + || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) { + cmp(reg_ki, 0); + je(skip_kd_loop, T_NEAR); + } + L(kd_loop); + mov(kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_input, aux_reg_inp_d); + mov(aux_reg_kernel, aux_reg_ker_d); + } + + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + cmp(kj, 0); + je(skip_kh_loop, T_NEAR); + } + Label kh_loop; + L(kh_loop); + { + if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { + oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks, + oc_blocks_tag); + sub(aux_reg_input, sizeof(float) * kw * inp_off); + add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult); + } else { + oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); + add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult); + } + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + L(skip_kh_loop); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + sizeof(float) * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mult); + add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_loop, T_NEAR); + L(skip_kd_loop); + + pop(oi_iter); + pop(reg_output); + } + + Label regular_store; + + if (jcp.with_eltwise) { + test(reg_ci_flag, FLAG_IC_LAST); + je(regular_store, T_NEAR); + + eltwise_injector_->compute_vector_range(0, oc_blocks * ur_w); + + L(regular_store); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + const size_t o_off + = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk; + Ymm reg_out = Ymm(ur_w * ii + jj); + vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out); + } + } +} + +inline void jit_avx2_conv_fwd_kernel_f32::solve_common( + int oc_blocks, char oc_blocks_tag) +{ + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int n_oi = jcp.ow / ur_w; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + int dilate_w = jcp.dilate_w + 1; + int str_w = jcp.stride_w; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_blk; + + int l_pad = jcp.l_pad; + int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1)); + int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) + width_blk_step(ur_w, l_pad, r_pad1, + 'l', oc_blocks, oc_blocks_tag); // "lrpad" + else + width_blk_step(ur_w, l_pad, 0, + 'l', oc_blocks, oc_blocks_tag); // "lpad" + add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + Label ow_loop; + xor_(oi_iter, oi_iter); + + if (n_oi > 0) { + L(ow_loop); + + width_blk_step(ur_w, 0, 0, + 'm', oc_blocks, oc_blocks_tag); // "middle" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + + if (r_pad1 > 0 && n_oi >=0) { + width_blk_step(ur_w, 0, r_pad1, + 'r', oc_blocks, oc_blocks_tag); // "rpad" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + if (ur_w_tail != 0) + width_blk_step(ur_w_tail, 0, r_pad, + 't', oc_blocks, oc_blocks_tag); // "tail" +} + +void jit_avx2_conv_fwd_kernel_f32::generate() +{ + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); + + int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; + Label tail, exit; + + if (jcp.nb_oc > jcp.nb_oc_blocking) { + cmp(reg_oc_blocks, jcp.nb_oc_blocking); + jne(nb_oc_tail ? tail : exit, T_NEAR); + + solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking); + jmp(exit, T_NEAR); + + if (nb_oc_tail) { + L(tail); + cmp(reg_oc_blocks, nb_oc_tail); + jne(exit, T_NEAR); + solve_common(nb_oc_tail, '0' + nb_oc_tail); + } + + L(exit); + } else if (jcp.nb_oc == jcp.nb_oc_blocking) { + solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking); + } else { + solve_common(nb_oc_tail, '0' + nb_oc_tail); + } + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(avx)) return status::unimplemented; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 :dst_d.dims()[ndims-2]; + jcp.ow = dst_d.dims()[ndims-1]; + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; + jcp.kw = weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 :cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCdhw8c); + } + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu) + return status::unimplemented; + } + + const int simd_w = 8; + const bool flat = jcp.ic < simd_w; + const bool mimo = !flat; + + + /* Grouped channel offset to support 'non-blocked data' format for + * convolution sizes with '(input_channel / ngroups) < simd' */ + jcp.nonblk_group_off = + one_of(jcp.src_tag, ncw, nchw, ncdhw) && jcp.ngroups > 1 ? jcp.ic : 1; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + if (mimo) + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + bool args_ok = true + && IMPLICATION(flat, true + && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o, + gOdhwi8o)) + && IMPLICATION(mimo, true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, + OIdhw8i8o, gOIdhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c); + if (!args_ok) return status::unimplemented; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.ur_w = 3; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */ + + // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively + // Thus, we can only assign 14 or 15 YMMs for data storage + const int num_avail_regs = mayiuse(avx2) ? 15 : 14; + if (!mayiuse(avx2)) { + if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) { + // current register assignment requires more YMMs than available + // adjust one of nb_oc_block, ur_w preserving to ur_w >= l_pad + if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1) + jcp.ur_w -= 1; + else + for (int b = 3; b > 1; b--) + if (jcp.nb_oc % b == 0) { + jcp.nb_oc_blocking = b; + break; + } + } + } + + if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + args_ok = true + && jcp.oc % simd_w == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) + || (jcp.stride_w == 1 && jcp.stride_h == 1)) + && IMPLICATION(mimo, jcp.ic % simd_w == 0); + if (!args_ok) return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + + if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { + /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ + jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, + nstl::min(jcp.ow, num_avail_regs / 2)); + jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + /* check again ... */ + r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) + return status::unimplemented; + } + assert(jcp.nb_oc_blocking > 0); + assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.nb_ic_blocking = 12; + jcp.nb_ic_blocking_max = 16; + } else { + jcp.nb_ic_blocking = 1; + jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; + } + + return status::success; +} + +void jit_avx2_conv_fwd_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow, + int r_overflow) +{ + int kw = jcp.kw; + int kh = jcp.kh; + int kd = jcp.kd; + int iw = jcp.iw; + int ih = jcp.ih; + int id = jcp.id; + int ow = jcp.ow; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_ic_block = jcp.nb_ic_blocking; + int stride_w = jcp.stride_w; + int stride_h = jcp.stride_h; + + Label kd_loop, skip_kd_loop; + Label oc_loop, skip_oc_loop; + + for (int ii = 0; ii < nb_ic_block; ii++) + for (int jj = 0; jj < ur_w; jj++) { + uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + Ymm(ur_w * ii + jj)); + } + + if (one_of(jcp.ndims, 3, 4)) { + cmp(reg_channel_work, 0); + jle(skip_oc_loop, T_NEAR); + xor_(reg_channel, reg_channel); + + mov(aux_reg_ddst_oc_loop, reg_ddst); + mov(aux_reg_kernel_oc_loop, reg_kernel); + + L(oc_loop); + mov(aux_reg_ddst, aux_reg_ddst_oc_loop); + mov(aux_reg_kernel, aux_reg_kernel_oc_loop); + } + + if (jcp.ndims == 5) { + assert(jcp.nb_oc_blocking == 1); + push(oi_iter); + + mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_ddst); + mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]); + + L(kd_loop); + mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]); + } else { + mov(kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_ddst, aux_reg_dst_d); + mov(aux_reg_kernel, aux_reg_ker_d); + } + + Label kh_loop, skip_kh_loop; + cmp(kj, 0); + jle(skip_kh_loop, T_NEAR); + L(kh_loop); { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_iw_start(ki, l_overflow); // 0; + int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) { + + for (int jj = jj_start ; jj < jj_end; jj += stride_w) { + int aux_output_offset + = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2; + vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w), + ptr[aux_reg_ddst + + sizeof(float) * aux_output_offset]); + } + + for (int ii = 0; ii < nb_ic_block; ii++) { + int aux_kernel_offset + = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block + + ki * jcp.ic_block * jcp.oc_block + + ofm2 * jcp.ic_block; + vmovups(ymm15, + ptr[aux_reg_kernel + + sizeof(float) * aux_kernel_offset]); + for (int jj = jj_start; jj < jj_end; jj += stride_w) + vfmadd231ps(Ymm(ur_w * ii + jj), + Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15); + } + } + } + add(aux_reg_kernel, sizeof(float) * stride_h * kw * oc_block + * ic_block); + sub(aux_reg_ddst, sizeof(float) * ow * oc_block); + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + L(skip_kh_loop); + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + sizeof(float) * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, + sizeof(float) * jcp.kw * jcp.kh * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_loop, T_NEAR); + L(skip_kd_loop); + + pop(oi_iter); + } + + if (one_of(jcp.ndims, 3, 4)) { + int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow + * jcp.oc_block; + int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw + * jcp.ic * jcp.oc_block; + + add(aux_reg_ddst_oc_loop, ddst_oc_shift); + add(aux_reg_kernel_oc_loop, kernel_oc_shift); + + inc(reg_channel); + cmp(reg_channel, reg_channel_work); + jl(oc_loop, T_NEAR); + + L(skip_oc_loop); + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + } + + Label no_update_label; + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + for (int ii = 0; ii < nb_ic_block; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block; + vmovups(Ymm(15), + make_safe_addr(reg_dsrc, offt, reg_long_offt)); + vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), + Ymm(15)); + + } + } + L(no_update_label); + + for (int ii = 0; ii < nb_ic_block; ii++) + for (int jj = 0; jj < ur_w; jj++) { + size_t offt = + sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block; + vmovups(make_safe_addr(reg_dsrc, offt, reg_long_offt), + Ymm(ur_w * ii + jj)); + } +} + +void jit_avx2_conv_bwd_data_kernel_f32::generate() { + preamble(); + + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); + mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]); + + int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block; + int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block; + + int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w); + int r_overflow = nstl::max(0, (jcp.kw - 1 + - nstl::max(0, jcp.r_pad)) / jcp.stride_w); + int r_overflow1 = nstl::max(0, (jcp.kw - 1 + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + + int n_oi = jcp.iw / jcp.ur_w; + if (r_overflow1 > 0) + n_oi--; + + if (jcp.ur_w == jcp.iw) { + compute_loop(jcp.ur_w, l_overflow, r_overflow); + } else if (n_oi == 0) { + compute_loop(jcp.ur_w, l_overflow, r_overflow1); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + if (jcp.ur_w_tail != 0) + compute_loop(jcp.ur_w_tail, 0, r_overflow); + } else { + xor_(oi_iter, oi_iter); + if (l_overflow > 0) { + compute_loop(jcp.ur_w, l_overflow, 0); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + inc(oi_iter); + } + + if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) { + Label ow_loop; + L(ow_loop); { + compute_loop(jcp.ur_w, 0, 0); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + inc(oi_iter); + cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR); + } + } + + if (r_overflow1 > 0 ) { + compute_loop(jcp.ur_w, 0, r_overflow1); + add(reg_dsrc, dsrc_shift); + add(reg_ddst, ddst_shift); + } + + if (jcp.ur_w_tail != 0) + compute_loop(jcp.ur_w_tail, 0, r_overflow); + } + + this->postamble(); +} + +status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + if (!mayiuse(avx2)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + + int ndims = diff_src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2]; + jcp.iw = diff_src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + const int simd_w = 8; + + /* derivatives */ + jcp.idp = jcp.id + 2 * jcp.f_pad; + jcp.ihp = jcp.ih + 2 * jcp.t_pad; + jcp.iwp = jcp.iw + 2 * jcp.l_pad; + jcp.ohp = jcp.oh; /* do we really need */ + jcp.owp = jcp.ow; /* padded output ??? */ + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + /* gemm-based convolution performs better in these cases */ + if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1) + return status::unimplemented; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + jcp.ic_block = (jcp.ic % simd_w) ? 1 : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + if (jcp.oc % jcp.oc_block) return status::unimplemented; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.nb_ic_blocking = 1; + jcp.nb_oc_blocking = 1; + jcp.ur_w = 1; + + if(one_of(ndims, 3, 4) && jcp.ow < 40) + jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2; + + if (ndims == 3) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIw8i8o, gOIw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIhw8o8i, gOIhw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = diff_src_d.matches_one_of_tag(nCdhw8c); + jcp.wei_tag = weights_d.matches_one_of_tag(OIdhw8o8i, gOIdhw8o8i); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c); + } + + bool args_ok = true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, gOIw8o8i, OIw8i8o, gOIhw8o8i, OIhw8o8i, + gOIdhw8o8i, OIdhw8o8i) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c) + && jcp.stride_w == jcp.stride_h + && jcp.stride_d == 1 + && jcp.dilate_d == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0 + && jcp.ic % simd_w == 0 + && jcp.oc % simd_w == 0 + && jcp.od == (jcp.idp - jcp.kd) / jcp.stride_d + 1 + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1; + if (!args_ok) return status::unimplemented; + jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad; + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad; + int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w); + + const int max_regs = 15; /* Maximun number of registers available for + result accumulation and delta dst data. + One additional register is reserved for weights + data. */ + + /* Find the best blocking with maximum number of fma instructions + per ur_w * nb_ic_blocking compute loops. Number of required registers + is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs. + ur_w must be divisible by stride_w */ + if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers + distribution exceeds max_regs */ + return status::unimplemented; + + int best_nfmas = 0; + for (int b = 1; b <= 4; b++) + { + if (jcp.nb_ic % b != 0) + continue; + + for (int u = jcp.stride_w; + u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w; + u += jcp.stride_w) + { + int ur_w = nstl::min(u, jcp.iw); + /* maximum 1 step with l_overflow so far */ + if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw) + continue; + int nfmas = utils::div_up(ur_w, jcp.stride_w) * b; + if (nfmas > best_nfmas + || (nfmas == best_nfmas && jcp.ur_w < ur_w)) { + jcp.ur_w = ur_w; + jcp.nb_ic_blocking = b; + best_nfmas = nfmas; + } + } + } + if (best_nfmas == 0) /* can't find appropriate blocking */ + return status::unimplemented; + + jcp.ur_w_tail = jcp.iw % jcp.ur_w; + + int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + /* maximum 1 ur_w block with r_overflow so far */ + if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + + if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) + return status::unimplemented; + + return status::success; +} + +void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +void jit_avx2_conv_bwd_weights_kernel_f32::generate() { + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + compute_oh_loop_common(); + this->postamble(); +} + +status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d) { + if (!mayiuse(avx2)) return status::unimplemented; + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; + jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c); + } else if (ndims == 5) { + jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c); + jcp.wei_tag = diff_weights_d.matches_one_of_tag( + Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c); + } + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + + const bool flat = jcp.ic == 3; + const bool mimo = !flat; + + const int simd_w = 8; + + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + + int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id + - jcp.f_pad); + if (ndims == 5) + if (jcp.f_pad != 0 || back_pad != 0) + return status::unimplemented; + + const int max_h_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1); + const int max_w_pad = ((jcp.kw - 1) * (jcp.dilate_w + 1) + 1); + const bool boundaries_ok = true + && jcp.t_pad < max_h_pad && jcp.b_pad < max_h_pad + && jcp.l_pad < max_w_pad && jcp.r_pad < max_w_pad; + if (!boundaries_ok) + return status::unimplemented; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + if (mimo) + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + bool args_ok = true + && IMPLICATION(flat, true + && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o, + gOdhwi8o)) + && IMPLICATION(mimo, true + && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o, + OIdhw8i8o, gOIdhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c) + && IMPLICATION(mimo, jcp.ic % simd_w == 0) + && jcp.oc % simd_w == 0 + && jcp.kw < 14 + && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */ + && jcp.kh <= jcp.ih /* [bwd_w:r2] */ + && jcp.kd <= jcp.f_pad + jcp.id + && jcp.kd <= jcp.id + && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */ + && jcp.dilate_d == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0; + if (!args_ok) return status::unimplemented; + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + + return status::success; +} + +void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() +{ + Label kd_comeback_loop; + mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0 + L(kd_comeback_loop); { + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + sub(aux_reg_input, sizeof(float) * jcp.iw * jcp.ih * inp_mult); + sub(aux_reg_kernel, sizeof(float) * jcp.kw * jcp.kh * jcp.ic_block + * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kd_comeback_loop, T_NEAR); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() +{ + mov(kj, reg_kh); + Label kh_comeback_loop; + L(kh_comeback_loop); { + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + sub(reg_input, sizeof(float) * jcp.iw * inp_mult); + sub(reg_kernel, sizeof(float) * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_comeback_loop, T_NEAR); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step( + int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset, + int kernel_offset, int output_offset) +{ + const int kw = jcp.kw; + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t off + = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block + + kernel_offset; + vmovups(Ymm(i_kw * ic_block_step + i_ic), yword[reg_kernel + off]); + } + + for (int i_ur = 0; i_ur < ur_w; i_ur++) { + vmovups(Ymm(kw * ic_block_step + 0), + yword[reg_output + + sizeof(float) * i_ur * oc_block + output_offset]); + + for (int i_kw = 0; i_kw < kw; i_kw++) { + int i_iw = i_ur * jcp.stride_w + i_kw; + if (i_iw - pad_l < 0 + || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r) + continue; + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t i_off = (size_t)input_offset + sizeof(float)*( + one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? (i_iw - pad_l) + i_ic + * ((size_t)jcp.id * jcp.ih * jcp.iw) + : (i_iw - pad_l) * ic_block + i_ic); + vbroadcastss(Ymm(kw * ic_block_step + 1), + make_safe_addr(reg_input, i_off, reg_long_offt)); + vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic), + Ymm(kw * ic_block_step + 0), + Ymm(kw * ic_block_step + 1)); + } + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + size_t off + = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block + + kernel_offset; + vmovups(yword[reg_kernel + off], + Ymm(i_kw * ic_block_step + i_ic)); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp() +{ + int ic_block_step; + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block; + } else { + ic_block_step = jcp.kw > 7 ? 1 + : jcp.kw > 3 ? 2 + : jcp.kw > 1 ? 4 : 8; + } + + const int max_ur_w = jcp.ow > 56 ? 14 : 28; + + if (jcp.ow <= max_ur_w) + compute_oh_step_unroll_ow(ic_block_step, max_ur_w); + else + compute_oh_step_common(ic_block_step, max_ur_w); + + if (jcp.ndims == 5) { + od_step_comeback_pointers(); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } else { + oh_step_comeback_pointers(); + } +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow( + int ic_block_step, int max_ur_w) +{ + UNUSED(max_ur_w); + + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block; + Label kd_loop; + + const int r_pad + = nstl::max(0, + (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + + if (jcp.ndims == 5) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + mov(ki, jcp.kd); + L(kd_loop); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + Label kh_loop; + L(kh_loop); { + xor_(b_ic, b_ic); + Label ic_block_loop; + L(ic_block_loop); { + compute_ic_block_step(jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0, + 0, 0); + size_t inp_icblk_stride = sizeof(float) * ic_block_step + * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? jcp.id*jcp.ih*jcp.iw : 1); + safe_add(reg_input, inp_icblk_stride, reg_long_offt); + add(reg_kernel, sizeof(float) * ic_block_step * oc_block); + add(b_ic, ic_block_step); + cmp(b_ic, ic_block); + jl(ic_block_loop, T_NEAR); + } + if(one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, offt, reg_long_offt); + add(reg_input, sizeof(float) * jcp.iw); + } else { + add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block); + } + add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_loop, T_NEAR); + } + +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common( + int ic_block_step, int max_ur_w) +{ + const int ic_block = jcp.ic_block; + const int oc_block = jcp.oc_block; + const int stride_w = jcp.stride_w; + int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block; + Label kd_loop; + + const int r_pad = jcp.r_pad; + + int ur_w = nstl::min(jcp.ow, max_ur_w); + int ur_w_trips = jcp.ow / ur_w; + int ur_w_tail = jcp.ow % ur_w; + if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) { + if (ur_w_trips > 1) { + ur_w_tail += ur_w; + ur_w_trips--; + } else { + ur_w_tail += (ur_w - ur_w / 2); + ur_w = ur_w / 2; + } + } + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_block; + + int input_comeback = (ur_w_trips * ur_w * stride_w - jcp.l_pad) * inp_mult; + int output_comeback = ur_w_trips * ur_w * oc_block; + + if (jcp.ndims == 5) { + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + mov(ki, jcp.kd); + L(kd_loop); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + Label kh_loop; + L(kh_loop); { + xor_(b_ic, b_ic); + Label ic_block_loop; + L(ic_block_loop); { + if (jcp.l_pad != 0) { + ur_w_trips--; + compute_ic_block_step(ur_w, + jcp.l_pad, 0, ic_block_step, 0, 0, 0); + add(reg_input, sizeof(float) + * (ur_w * stride_w - jcp.l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_block); + } + + if (ur_w_trips > 0) { + xor_(reg_ur_w_trips, reg_ur_w_trips); + Label ow_block_loop; + L(ow_block_loop); { + compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); + add(reg_input, sizeof(float) * ur_w * stride_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_block); + + inc(reg_ur_w_trips); + cmp(reg_ur_w_trips, ur_w_trips); + jl(ow_block_loop, T_NEAR); + } + } + + if (ur_w_tail > 0) + compute_ic_block_step(ur_w_tail, + 0, r_pad, ic_block_step, 0, 0, 0); + + sub(reg_input, sizeof(float) * input_comeback); + sub(reg_output, sizeof(float) * output_comeback); + + size_t inp_icblk_stride = sizeof(float) * ic_block_step + * (one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? jcp.id*jcp.ih*jcp.iw : 1); + safe_add(reg_input, inp_icblk_stride, reg_long_offt); + add(reg_kernel, sizeof(float) * ic_block_step * oc_block); + + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_loop, T_NEAR); + } + if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) { + size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, offt, reg_long_offt); + add(reg_input, sizeof(float) * jcp.iw); + } else { + add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block); + } + add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_loop, T_NEAR); + } + +} + +inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common() +{ + const int icoc_block = jcp.ic_block * jcp.oc_block; + const int t_pad = jcp.t_pad; + const int stride_h = jcp.stride_h; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) + ? 1 : jcp.ic_block; + int b_pad = jcp.b_pad; + + Label oh_tpad_loop, oh_loop, oh_loop_end; + + mov(reg_kh, jcp.kh); + xor_(reg_ih_count, reg_ih_count); + xor_(reg_oj, reg_oj); + if (t_pad > 0) { + assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */ + mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih); + add(reg_kernel, sizeof(float) * t_pad * jcp.kw * icoc_block); + + L(oh_tpad_loop); { + compute_oh_step_disp(); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + sub(reg_kernel, sizeof(float) * stride_h * jcp.kw * icoc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + add(reg_kh, stride_h); + + /* the overlap between input and kernel may not reach kernel size. + * so far we do not support that (until we put constant here) */ + const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */ + cmp(reg_kh, final_inp_ker_overlap); + jl(oh_tpad_loop, T_NEAR); + } + + if (t_pad % stride_h != 0) { + int inp_corr = stride_h - t_pad % stride_h; + add(reg_kernel, sizeof(float) * inp_corr * jcp.kw * icoc_block); + add(reg_input, sizeof(float) * inp_corr * jcp.iw * inp_mult); + } + } + cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); + jge(oh_loop_end, T_NEAR); + cmp(reg_oj, jcp.oh); + jge(oh_loop, T_NEAR); + + mov(reg_kh, jcp.kh); + L(oh_loop); { + compute_oh_step_disp(); + add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1); + jge(oh_loop_end, T_NEAR); + + cmp(reg_oj, jcp.oh); + jl(oh_loop, T_NEAR); + } + L(oh_loop_end); + if (b_pad > 0) { + Label oh_bpad_loop, oh_bpad_loop_end; + cmp(reg_oj, jcp.oh); + jge(oh_bpad_loop_end, T_NEAR); + + mov(reg_kh, jcp.ih + t_pad); + sub(reg_kh, reg_ih_count); + L(oh_bpad_loop); { + compute_oh_step_disp(); + add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult); + add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block); + + sub(reg_kh, stride_h); + cmp(reg_kh, 0); + jle(oh_bpad_loop_end, T_NEAR); + + inc(reg_oj); + cmp(reg_oj, jcp.oh); + jl(oh_bpad_loop, T_NEAR); + } + L(oh_bpad_loop_end); + } +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp new file mode 100644 index 0000000000..412c50c9ee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp @@ -0,0 +1,225 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX2_CONV_KERNEL_F32_HPP +#define JIT_AVX2_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_conv_fwd_kernel_f32: public jit_generator { + jit_avx2_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx2>(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_avx2_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_fwd_kernel_f32) + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t aux_reg_input = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r9; + reg64_t reg_output = rsi; + reg64_t reg_bias = rbx; + + reg64_t aux_reg_inp_d = r11; + reg64_t aux_reg_ker_d = abi_not_param1; + + reg64_t reg_ki = rsi; + reg64_t kj = r10; + reg64_t oi_iter = r11; + reg64_t ki_iter = r12; + reg64_t reg_kh = abi_not_param1; + reg64_t reg_oc_blocks = r14; + reg64_t imm_addr64 = r15; + reg64_t reg_long_offt = r15; + Xbyak::Reg32 reg_ci_flag = r13d; + + Xbyak::Ymm ytmp = Xbyak::Ymm(14); + + jit_uni_eltwise_injector_f32<avx2> *eltwise_injector_; + + inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, + int oc_blocks); + inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, + char pad_label, int oc_blocks, char oc_blocks_label); + inline void width_blk_step(int ur_w, int pad_l, int pad_r, + char pad_label, int oc_blocks, char oc_blocks_label); + inline void solve_common(int oc_blocks, char oc_blocks_label); + + void generate(); +}; + +struct jit_avx2_conv_bwd_data_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32) + + jit_avx2_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + + reg64_t reg_ddst = rax; + reg64_t aux_reg_ddst = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r10; + reg64_t reg_dsrc = rsi; + reg64_t aux_reg_ddst_oc_loop = rbx; // used in ndims < 5 case only + reg64_t aux_reg_kernel_oc_loop = abi_not_param1; /* used in ndims < 5 + case only */ + + reg64_t aux_reg_dst_d = r12; // used in ndims == 5 case only + reg64_t aux_reg_ker_d = r14; // used in ndims == 5 case only + + reg64_t reg_ki = abi_not_param1; // used in ndims == 5 case only + reg64_t kj = r11; + reg64_t oi_iter = r12; + reg64_t reg_kh = r14; + reg64_t reg_channel = r13; // used in ndims < 5 case only + reg64_t reg_channel_work = r9; // used in ndims < 5 case only + reg64_t reg_long_offt = r15; + + inline void compute_loop(int ur_w, int l_overflow, int r_overflow); + + void generate(); + + inline int get_iw_start(int ki, int l_overflow) + { + int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return res; + } + + inline int get_iw_end(int ur_w, int ki, int r_overflow) + { + if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return ur_w - res; + } +}; + +struct jit_avx2_conv_bwd_weights_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_weights_kernel_f32) + + jit_avx2_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t reg_kernel = rdx; + reg64_t reg_output = rsi; + reg64_t b_ic = abi_not_param1; + reg64_t kj = r8; + reg64_t reg_kh = r9; + reg64_t reg_ur_w_trips = r10; + reg64_t reg_tmp = r11; + reg64_t reg_oj = r15; + reg64_t reg_ih_count = rbx; + reg64_t aux_reg_input = r12; + reg64_t aux_reg_kernel = r13; + reg64_t ki = r14; + reg64_t reg_long_offt = r11; + + inline void od_step_comeback_pointers(); + inline void oh_step_comeback_pointers(); + inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset); + inline void compute_oh_step_disp(); + inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); + inline void compute_oh_step_common(int ic_block_step, int max_ur_w); + inline void compute_oh_loop_common(); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp new file mode 100644 index 0000000000..13f61e84fe --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp @@ -0,0 +1,410 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx2_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define src_blk_off(f, n, c, d, h, w) \ + (pd()->ndims() == 3) \ + ? (f).blk_off(n, c, w) \ + : (pd()->ndims() == 4) \ + ? (f).blk_off(n, c, h, w) \ + : (f).blk_off(n, c, d, h, w) + +#define wht_blk_off_(f, g, ...) \ + pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__) +#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \ + (pd()->ndims() == 3) \ + ? wht_blk_off_(f, g, oc, ic, kw) \ + : (pd()->ndims() == 4) \ + ? wht_blk_off_(f, g, oc, ic, kh, kw) \ + : wht_blk_off_(f, g, oc, ic, kd, kh, kw) + +void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); + const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.od + * jcp.oh; + + auto ker = [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int icbb = 0; + while (icbb < jcp.nb_ic) { + int icb_step = jcp.nb_ic_blocking; + int icb_step_rem = jcp.nb_ic - icbb; + if (icb_step_rem < jcp.nb_ic_blocking_max) + icb_step = icb_step_rem; + + size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + od, jcp.od, oh, jcp.oh); + for (size_t iwork = start; iwork < end; ++iwork) { + int ocb = ocbb * jcp.nb_oc_blocking; + int ocb_num = jcp.nb_oc_blocking; + + for (int icb = icbb; icb < icbb + icb_step; ++icb) { + auto par_conv = jit_conv_call_s(); + + const int ij = oh * jcp.stride_h; + const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); + const int i_b_overflow = nstl::max(jcp.ih, ij + + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; + + const int dj = od * jcp.stride_d; + const int d_t_overflow = nstl::max(0, jcp.f_pad - dj); + const int d_b_overflow = nstl::max(jcp.id, dj + + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id; + + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb; + + const int ih = nstl::max(ij - jcp.t_pad + + div_up(i_t_overflow, + (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); + + const int id = nstl::max(dj - jcp.f_pad + + div_up(d_t_overflow, + (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0); + + par_conv.src = &src[src_blk_off(src_d, n, + jcp.ic == 3 ? 0 : _ic, id, ih, 0)]; + + par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)]; + + const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); + const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1)); + par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, + jcp.ic == 3 ? 0 : icb, wd, wh, 0)]; + + if (icb == 0) { + if (bias) + par_conv.bias = + &bias[bias_d.blk_off(_oc * jcp.oc_block)]; + par_conv.flags |= FLAG_IC_FIRST; + } + + if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { + par_conv.flags |= FLAG_IC_LAST; + } + + par_conv.oc_blocks = + nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; + + par_conv.kw_padding = 0; + const int kh_padding = jcp.kh + - div_up(i_t_overflow, (jcp.dilate_h + 1)) + - div_up(i_b_overflow, (jcp.dilate_h + 1)); + par_conv.kh_padding = nstl::max(0, kh_padding); + + const int kd_padding = jcp.kd + - div_up(d_t_overflow, (jcp.dilate_d + 1)) + - div_up(d_b_overflow, (jcp.dilate_d + 1)); + par_conv.kd_padding = nstl::max(0, kd_padding); + + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + od, jcp.od, oh, jcp.oh); + } + icbb += icb_step; + } + }; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad(ctx).get<data_t>(key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, ker); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +void jit_avx2_convolution_bwd_data_t::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + int icb_work = jcp.nb_ic / jcp.nb_ic_blocking; + int ih_block_size = jcp.ih; + int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); + size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks; + if (work_amount < (size_t)2 * mkldnn_get_max_threads()) { + ih_block_size = 1; + num_ih_blocks = utils::div_up(jcp.ih, ih_block_size); + work_amount *= num_ih_blocks; + } + + auto ker = [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + size_t n{0}, g{0}, icbb{0}, ihb{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work, + ihb, num_ih_blocks); + for (size_t iwork = start; iwork < end; ++iwork) { + for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking) + for (int id = 0; id < jcp.id; ++id) { + auto par_conv = jit_conv_call_s(); + + const int idp = jcp.id + 2 * jcp.f_pad; + const int d_t_overflow = nstl::max(0, + jcp.kd - 1 - id - jcp.f_pad); + const int back_pad = idp - jcp.id - jcp.f_pad; + const int d_b_overflow = nstl::max(0, + jcp.kd - 1 - (jcp.id - 1 - id) - back_pad); + const int od = id + jcp.f_pad - d_b_overflow; + + int ih_start = ihb * ih_block_size; + int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size); + for (int ih = ih_start; ih < ih_end; ++ih) { + + const int i_t_overflow = nstl::max(0, (jcp.kh - 1 + - ih - jcp.t_pad) / jcp.stride_h); + const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih + + ih - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ih) % jcp.stride_h); + int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h; + + par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow; + par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow - i_b_overflow; + par_conv.kw_padding = 0; + + const int k_lo = overflow_kh_lo + + i_b_overflow * jcp.stride_h; + const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h; + + par_conv.src = &diff_src[src_blk_off(diff_src_d, n, + /*jcp.ic == 3 ? 0 :*/ + g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)]; + par_conv.dst = &diff_dst[src_blk_off(diff_dst_d, + n, g * jcp.nb_oc + oc, od, oh, 0)]; + par_conv.filt = &weights[wht_blk_off(weights_d, g, oc, + jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb, + d_b_overflow, k_lo, 0)]; + + par_conv.src_prf = nullptr; + par_conv.dst_prf = nullptr; + par_conv.filt_prf = nullptr; + par_conv.channel = oc; + par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc, + jcp.nb_oc_blocking); + + kernel_->jit_ker(&par_conv); + } + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb, + num_ih_blocks); + } + }; + + parallel(0, ker); +} + +void jit_avx2_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto scratchpad = this->scratchpad(ctx); + + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in; + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + + auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_wei); + auto rw = this->reducer_weights_; + rw->init(reducer_wei_scratchpad); + + auto ker = [&](int ithr, int nthr) { + assert(nthr == rw->balancer().nthr_); + + const int w_job_start = rw->balancer().ithr_job_off(ithr); + const int w_njobs = rw->balancer().ithr_njobs(ithr); + + if (w_njobs == 0) return; + + /* reduction dimension */ + int img_od_start{0}, img_od_end{0}, img{0}, od_s{0}; + balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_, + rw->balancer().id_in_group(ithr), img_od_start, img_od_end); + + int img_start = img_od_start, img_end = img_od_end; + nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); + const int img_first = img; + + /* jobs */ + int g_start{0}, ocb_start{0}, icb_start{0}; + nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start, + jcp.nb_oc, icb_start, jcp.nb_ic); + + while (img_start < img_end) { + int g = g_start, ocb = ocb_start, icb = icb_start; + + const int work_rem = img_end - img_start; + const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem; + const int id_s = od_s * jcp.stride_d; + const int idp = jcp.id + jcp.f_pad + jcp.back_pad; + + if (id_s < idp - jcp.back_pad - jcp.kd + 1) + for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic + icb; + + /* TODO: put dw <-- 0 in kernel */ + if (img == img_first) + array_set(rw->get_local_ptr(ithr, diff_weights, + reducer_wei_scratchpad) + + w_job_loc * rw->balancer().job_size_, 0, + rw->balancer().job_size_); + + for (int od = od_s; od < od_e; ++od) { + const int id = od * jcp.stride_d; + if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break; + + auto par_conv = jit_conv_call_s(); + par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)]; + par_conv.dst = + &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)]; + par_conv.filt = rw->get_local_ptr(ithr, diff_weights, + reducer_wei_scratchpad) + + w_job_loc * rw->balancer().job_size_; + + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb, + jcp.nb_ic); + } + nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); + } + rw->reduce(ithr, diff_weights, reducer_wei_scratchpad); + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, + jcp.nb_oc); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, + reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 8; ++o) + d_bias[o] = 0.; + + for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 8; ++o) + d_bias[o] += d_dst[o]; + d_dst += 8; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(0, [&](const int ithr, const int nthr) { + ker(ithr, nthr); + if (pd()->with_bias()) + ker_bias(ithr, nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + for (int oc = 0; oc < jcp.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp new file mode 100644 index 0000000000..bb65bce79c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp @@ -0,0 +1,302 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_CONVOLUTION_HPP +#define CPU_JIT_AVX2_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx2_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx2_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_fwd_kernel_f32::init_conf(jcp_, + *desc(), src_md(), weights_md(), dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + const bool flat = IC() < 8; + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + }; + + jit_avx2_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_avx2_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); } + ~jit_avx2_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_fwd_kernel_f32 *kernel_; +}; + +struct jit_avx2_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf( + jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i) + : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_avx2_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~jit_avx2_convolution_bwd_data_t() { delete kernel_; } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_bwd_data_kernel_f32 *kernel_; +}; + +struct jit_avx2_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx2, ""), + jit_avx2_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf( + jcp_, *desc(), *src_md(), *diff_weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(scratchpad, + jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + auto reducer_wei_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_wei); + reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); + + return status::success; + } + + jit_conv_conf_t jcp_; + cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_; + cpu_reducer_t<data_type::f32>::conf_t reducer_wei_conf_; + + protected: + bool set_default_formats() { + using namespace format_tag; + const bool flat = IC() == 3; + + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + + private: + void init_balancers() { + const int max_threads = mkldnn_get_max_threads(); + const size_t max_buffer_size = 1<<21; /* just a heuristic */ + + if(with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(max_threads, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb, + max_buffer_size)); + } + + reducer_wei_conf_.init(reduce_balancer_t(max_threads, + jcp_.kd * jcp_.kh * jcp_.kw + * jcp_.ic_block * jcp_.oc_block, + jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc, + jcp_.mb * jcp_.od, max_buffer_size)); + } + }; + + jit_avx2_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr) + , reducer_weights_(nullptr) + , reducer_bias_(nullptr) + { + kernel_ = new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_); + reducer_bias_ = + new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_); + reducer_weights_ = + new cpu_reducer_t<data_type::f32>(pd()->reducer_wei_conf_); + } + + ~jit_avx2_convolution_bwd_weights_t() { + delete kernel_; + delete reducer_weights_; + delete reducer_bias_; + } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx2_conv_bwd_weights_kernel_f32 *kernel_; + cpu_reducer_t<data_type::f32> *reducer_weights_, *reducer_bias_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp new file mode 100644 index 0000000000..635b83b2bf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp @@ -0,0 +1,1255 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <float.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" +#include "cpu_barrier.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_common_1x1_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_bcast_data, reg_bcast_data); + + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt)); + + if (jcp.ver == ver_4fma) + { + Label bcast_loop; + Label bcast_loop_wraparound; + Label bcast_loop_out; + Label bcast_loop_ur_full; + + cmp(bcast_loop_iter, jcp.ur); + jle(bcast_loop_wraparound, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jg(bcast_loop, T_NEAR); + } + + L(bcast_loop_wraparound); + if (jcp.ur_tail) { + je(bcast_loop_ur_full, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + jmp(bcast_loop_out, T_NEAR); + } + L(bcast_loop_ur_full); + reduce_loop(load_loop_blk, jcp.ur, 0, true); + L(bcast_loop_out); + } + else + { + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + L(bcast_loop_tail_out); + } + } +} + +void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk, + int ur, int substep, bool wraparound) +{ + auto vreg_load = [=](int i_load, int i_fma) { + return Zmm(utils::rnd_up(ur * load_loop_blk, jcp.fma_step) + + jcp.fma_step * i_load + i_fma); + }; + + auto vreg_accum = [=](int i_load, int i_ur) { + return Zmm(i_ur * load_loop_blk + i_load); + }; + + auto bias_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_bias_data, + jcp.typesize_out * jcp.oc_block * i_load); + }; + + auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { + assert(i_ur < jcp.ur); + assert(i_reduce <= jcp.reduce_loop_unroll); + int offt; + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + assert(jcp.reduce_loop_unroll == jcp.reduce_block); + offt = (i_reduce == jcp.reduce_loop_unroll) + ? (jcp.bcast_dim + i_ur) * jcp.reduce_loop_unroll + : i_ur * jcp.reduce_loop_unroll + i_reduce; + } else { + if (jcp.transpose_src) { + const int reduce_group = i_reduce / 4; + const int reduce_shift = i_reduce % 4; + offt = 4 * (reduce_group * jcp.ic_block + i_ur) + reduce_shift; + } + else + offt = i_reduce * jcp.ic_block + i_ur; + } + return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt, + bcast); + }; + + auto load_ptr = [=](int i_reduce, int i_load) { + int offt; + int u0 = i_reduce % jcp.reduce_loop_unroll; + int u1 = i_reduce / jcp.reduce_loop_unroll; + offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block; + return EVEX_compress_addr(aux_reg_load_data, + u1 * jcp.reduce_loop_load_step + + jcp.typesize_in * offt); + }; + + auto output_ptr = [=](int i_load, int i_ur) { + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) + return EVEX_compress_addr(aux_reg_output_data, + (i_load * jcp.bcast_dim + i_ur) * jcp.load_block + * jcp.typesize_out); + else + return ptr[aux_reg_output_data + + (i_load + ? reg_output_stride * i_load + : 0) // TODO: Xbyak should allow 0 scale + + jcp.typesize_out * jcp.load_block * i_ur]; + }; + + auto init = [=]() { + Label init_done; + Label init_zero; + + if (jcp.with_sum) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + for (int i_ur = 0; i_ur < ur; ++i_ur) { + mic_prefetcht1(output_ptr(i_load, i_ur)); + } + } + } + + if (jcp.with_bias + && one_of(jcp.prop_kind, forward_training, forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero, T_NEAR); + + for (int i_load = 0; i_load < load_loop_blk; i_load++) + for (int i_ur = 0; i_ur < ur; ++i_ur) + vmovups(vreg_accum(i_load, i_ur), bias_ptr(i_load)); + jmp(init_done, T_NEAR); + } + + L(init_zero); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vpxord(r, r, r); + } + L(init_done); + }; + + auto store = [=]() { + Label store_noadd; + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + auto r = vreg_accum(i_load, i_ur); + vaddps(r, r, output_ptr(i_load, i_ur)); + } + + L(store_noadd); + if (jcp.with_eltwise) { + Label store_noeltwise; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_noeltwise, T_NEAR); + + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + L(store_noeltwise); + } + + auto store_output = [=](bool output_is_aligned) { + for (int i_ur = 0; i_ur < ur; ++i_ur) + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + if (output_is_aligned && jcp.use_vmovntps) + vmovntps(output_ptr(i_load, i_ur), + vreg_accum(i_load, i_ur)); + else + vmovups(output_ptr(i_load, i_ur), + vreg_accum(i_load, i_ur)); + }; + + Label unaligned_store, end_store; + test(aux_reg_output_data, cpu_isa_traits<avx512_common>::vlen - 1); + jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + }; + + auto prefetch_callback = [=](int ur, int i_reduce, int i_ur, int i_load, + bool last_block, bool wraparound, int reduce_step) + { + bool pf_ker_l1 = true; + bool pf_ker_l2 = wraparound; + int n_ops = (jcp.reduce_loop_unroll / reduce_step) * ur * load_loop_blk; + int i_op = (i_reduce / reduce_step) * ur * load_loop_blk + + i_ur * load_loop_blk + i_load; + + int n_pf_ker_l1 = pf_ker_l1 ? jcp.reduce_block : 0; + int n_pf_ker_l2 = pf_ker_l2 && wraparound ? jcp.reduce_block : 0; + int n_pf_out_l1 = jcp.use_vmovntps ? 0 : ur; + + int pf_inp_ops = n_ops / 2; // # of operations during which to pf input + int pf_inp_trigger; + if (jcp.prop_kind == backward_weights) + pf_inp_trigger = nstl::max(1, pf_inp_ops / jcp.reduce_block); + else + pf_inp_trigger = nstl::max(1, pf_inp_ops / ur); + + int n_other_pf = + load_loop_blk * (n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1); + int n_other_pf_ops = n_ops - pf_inp_ops; + int other_pf_trigger + = n_other_pf ? nstl::max(1, n_other_pf_ops / n_other_pf) : 0; + + if (i_op < pf_inp_ops && i_op % pf_inp_trigger == 0) { + // input prefetches have the highest priority b/c the + // first iteration of the kernel block touches all the + // cache lines + int i_pf = i_op / pf_inp_trigger; + auto pf_reg = wraparound && last_block + ? reg_bcast_data + : (last_block ? aux1_reg_bcast_data + : aux_reg_bcast_data); + int offt = i_pf; + if (jcp.prop_kind == backward_weights) { + offt += wraparound && last_block + ? 0 + : (last_block ? jcp.is : jcp.reduce_block); + offt *= jcp.bcast_block; + } else { + offt += wraparound && last_block + ? 0 + : (last_block ? jcp.ur : jcp.bcast_dim); + offt *= jcp.reduce_block; + } + mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]); + } else if (i_op >= pf_inp_ops && n_other_pf) { + // remaining prefetches are spread among the rest of the + // operations; prefetches for output take priority + // TODO: spread L2 prefetches among L1 prefetches + i_op -= pf_inp_ops; + if (i_op % other_pf_trigger == 0) { + int i_pf = i_op / (load_loop_blk * other_pf_trigger); + if (i_pf < n_pf_ker_l2) { + int offt = (i_pf + (i_load + 1) * jcp.reduce_dim) + * jcp.load_block; + mic_prefetcht1(ptr[aux_reg_load_data + + offt * jcp.typesize_in]); + } else if (i_pf < n_pf_ker_l2 + n_pf_ker_l1) { + i_pf -= n_pf_ker_l2; + auto pf_reg = last_block ? reg_load_data + : aux_reg_load_data; + int offt = (i_pf + i_load * jcp.reduce_dim + + (last_block + ? (wraparound ? jcp.reduce_dim : 0) + : jcp.reduce_block)) + * jcp.load_block; + mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]); + } else if (i_pf < n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1) { + i_pf -= n_pf_ker_l1 + n_pf_ker_l2; + int offt = i_pf * jcp.load_block; + mic_prefetcht0(ptr[aux_reg_output_data + + offt * jcp.typesize_out]); + } + } + } + }; + + auto fma_block = [=](bool last_block) { + assert(jcp.reduce_loop_unroll % jcp.fma_step == 0); + + int reduce_step = jcp.fma_step; + + for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll; + i_reduce += reduce_step) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + // if transposed input data used and if spatial size is + // not divided by transpose step (4) then for last reduce step + // we should load only needed load_registers data + // and clear remaining + if (jcp.transpose_src && jcp.is % jcp.fma_step && last_block + && i_reduce == jcp.reduce_loop_unroll - reduce_step) { + Label load_all; + Label load_finish; + test(reg_reduce_pos_flag, FLAG_SP_LAST); + jz(load_all, T_NEAR); + + const int n_loads = jcp.is % jcp.fma_step; + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + if (i_fma < n_loads) + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + else + vpxord(vreg_load(i_load, i_fma), + vreg_load(i_load, i_fma), + vreg_load(i_load, i_fma)); + } + jmp(load_finish); + + L(load_all); + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + } + L(load_finish); + } else { + for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) { + vmovups(vreg_load(i_load, i_fma), + load_ptr(i_reduce + i_fma, i_load)); + } + } + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) { + if (jcp.ver == ver_avx512_core && jcp.expl_bcast + && load_loop_blk > 1) + vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false)); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + if (jcp.ver == ver_4fma) + v4fmaddps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), + bcast_ptr(i_reduce, i_ur, false)); + else if (jcp.ver == ver_avx512_core && jcp.expl_bcast + && load_loop_blk > 1) + vfmadd231ps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), vreg_bcast); + else + vfmadd231ps(vreg_accum(i_load, i_ur), + vreg_load(i_load, 0), + bcast_ptr(i_reduce, i_ur, true)); + prefetch_callback(ur, i_reduce, i_ur, i_load, + last_block, wraparound, reduce_step); + } + } + } + }; + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} + +void jit_avx512_common_1x1_conv_kernel::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + + sub(rsp, stack_space_needed); + + if (jcp.with_bias) + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (one_of(jcp.prop_kind, forward_training, forward_inference)) + mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha)); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto load_loop_body = [=](int load_loop_blk) { + bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, + load_loop_blk * jcp.load_block * jcp.typesize_out); + add(reg_output_data, + load_loop_blk * jcp.bcast_dim * jcp.load_block * + jcp.typesize_out); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.bcast_dim * jcp.load_block * + jcp.typesize_out); + break; + case backward_weights: + for (int i_load = 0; i_load < load_loop_blk; i_load++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + const int simd_w = 16; + + Label load_loop_blk[7]; + + static const int ur_cases_fma_embd_bcast[] = { 2, 4, 5, 8, 14, 32 }; + static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 }; + static const int ur_cases_4fma[] = { 2, 4, 6, 12, 32 }; + + const int size_ur_cases_fma + = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ? + sizeof(ur_cases_fma_expl_bcast) : + sizeof(ur_cases_fma_embd_bcast); + const int size_ur_cases_4fma = sizeof(ur_cases_4fma); + + const int *ur_cases_fma = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ? + ur_cases_fma_expl_bcast : + ur_cases_fma_embd_bcast; + const int *ur_cases = jcp.ver == ver_4fma ? ur_cases_4fma : ur_cases_fma; + const int num_ur_cases = + (jcp.ver == ver_4fma ? size_ur_cases_4fma : size_ur_cases_fma) + / sizeof(*ur_cases); + + for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { + int label_idx = num_ur_cases - ur_idx - 1; + if (jcp.ur <= ur_cases[ur_idx]) { + cmp(reg_load_loop_work, simd_w * (label_idx + 1)); + jle(load_loop_blk[label_idx], T_NEAR); + } + } + + for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { + if (jcp.ur <= ur_cases[ur_idx]) { + int label_idx = num_ur_cases - ur_idx - 1; + L(load_loop_blk[label_idx]); + { + if (label_idx == 0) { + cmp(reg_load_loop_work, 0); + je(load_loop_blk[num_ur_cases], T_NEAR); + } + load_loop_body(label_idx + 1); + if (label_idx - 1 > 0) { + cmp(reg_load_loop_work, 2 * label_idx * simd_w); + je(load_loop_blk[label_idx - 1], T_NEAR); + } + cmp(reg_load_loop_work, (label_idx + 1) * simd_w); + jge(load_loop_blk[label_idx]); + } + for (int idx = label_idx - 1; idx > 0; --idx) { + cmp(reg_load_loop_work, simd_w * (idx + 1)); + je(load_loop_blk[idx], T_NEAR); + } + if (ur_idx < num_ur_cases - 2) { + cmp(reg_load_loop_work, simd_w); + jle(load_loop_blk[0], T_NEAR); + } + } + } + L(load_loop_blk[num_ur_cases]); + + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_common_1x1_conv_kernel::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr, int nthreads, bool reduce_src) { + if (!mayiuse(avx512_common)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float); + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind, + format_kind::undef, cd.diff_bias_desc.format_kind) + != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + jcp.tr_is = rnd_up(jcp.is, 4); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (dst_d.data_type() == data_type::s32) return status::unimplemented; + } + + auto dat_tag = pick(ndims - 3, nCw16c, nChw16c); + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + args_ok = true + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.ic_block = jcp.oc_block = simd_w; + jcp.transpose_src = false; + + if (everyone_is(data_type::f32, src_d.data_type(), + weights_d.data_type(), dst_d.data_type())) + { + const int is_bwd_d = jcp.prop_kind == backward_data; + format_tag_t wei_tag = with_groups + ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i, + gOIhw16i16o, gIOhw16o16i) + : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i, + OIhw16i16o, IOhw16o16i); + + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (jcp.prop_kind != backward_weights && mayiuse(avx512_mic_4ops) && + ((jcp.prop_kind == backward_data) ? jcp.oc_block : jcp.ic_block) % 4 + == 0) { + jcp.ver = ver_4fma; + jcp.fma_step = 4; + } else if (jcp.prop_kind == backward_weights && mayiuse(avx512_mic_4ops) + && !reduce_src + /* Heuristic condition for relation of src size to oc. Otherwise + the src transposition overhead exceed the benefit from 4fma + */ + && ((jcp.is * jcp.ic) / jcp.oc <= 2048) + && mkldnn_thr_syncable() + ) + { + jcp.transpose_src = true; + jcp.ver = ver_4fma; + jcp.fma_step = 4; + } else { + jcp.ver = (mayiuse(avx512_core)) ? ver_avx512_core : ver_fma; + jcp.fma_step = 1; + } + jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type); + jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type); + } else { + return status::unimplemented; + } + + /* once all the formats are set, check the padding consistency */ + args_ok = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + const int SMALL_SPATIAL = 10; + const int BIG_SPATIAL = 28; + const int BIG_REDUCE_DIM = 1024; + const int BIG_LOAD_DIM = 256; + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + int reduce_blocking_max{ 0 }; + + jcp.load_grp_count = 1; + + const int L1_capacity = get_cache_size(1, true) / sizeof(float); + const int L2_size = get_cache_size(2, true) / sizeof(float); + const int L2_capacity = (L2_size * 3) / 4; + + if (one_of(jcp.prop_kind, forward_training, forward_inference, + backward_data)) { + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + } else { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.ic_block; + + jcp.bcast_dim = jcp.os; + } + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in; + + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; + jcp.load_loop_load_step + = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; + + // adjusting registry blocking + int max_regs, min_regs, size_treshold, ur_step; + const int spatial + = (one_of(jcp.prop_kind, forward_training, forward_inference)) ? + jcp.oh : + jcp.ih; + if (jcp.ver == ver_avx512_core && (8 * jcp.mb) / nthreads >= 1) { + max_regs = 9; + min_regs = 6; + size_treshold = 14; + ur_step = 1; + jcp.expl_bcast = true; + + if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM + && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) { + max_regs = 6; + min_regs = 5; + } + } else { + max_regs = jcp.ver == ver_4fma ? 28 : 30; + min_regs = 9; + size_treshold = jcp.ver == ver_4fma ? 28 : 14; + ur_step = jcp.ver == ver_4fma ? 4 : 1; + jcp.expl_bcast = false; + jcp.use_vmovntps = true; + } + jcp.ur = 1; + for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) { + if ((spatial >= size_treshold && spatial % ur_w == 0) + || (spatial < size_treshold && jcp.os % ur_w == 0)) { + jcp.ur = ur_w; + break; + } + } + if (jcp.ur == 1) { + jcp.ur = nstl::min(max_regs, jcp.os); + int os_tail = jcp.os % max_regs; + for (int i = max_regs; i >= min_regs; i -= ur_step) { + int i_tail = jcp.os % i; + if (i_tail > os_tail || i_tail == 0) { + jcp.ur = i; + os_tail = i_tail; + if (i_tail == 0) + break; + } + } + } + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in; + + jcp.bcast_block = jcp.ur; + + jcp.bcast_loop_output_step = jcp.ur * jcp.load_block * jcp.typesize_out; + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_block * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_iter_step = jcp.load_block; + + if (jcp.prop_kind == backward_data) + jcp.loop_order = loop_lbr; + else + jcp.loop_order = reduce_src ? loop_blr : loop_lbr; + + int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + int nb_load = div_up(jcp.load_dim, jcp.load_block); + + if (jcp.ver == ver_avx512_core && jcp.expl_bcast) { + if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL + && spatial < BIG_SPATIAL) + reduce_blocking = nstl::min(jcp.reduce_dim, 80); + else if (spatial > SMALL_SPATIAL) + reduce_blocking = nstl::min(jcp.reduce_dim, 512); + else + reduce_blocking = nstl::min(jcp.reduce_dim, 256); + + if ((jcp.mb > 28 && spatial >= 28) + || (jcp.mb > 112 && spatial >= 17)) + jcp.use_vmovntps = true; + else + jcp.use_vmovntps = false; + } else { + + reduce_blocking = nb_reduce; + if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 16; + else if (spatial > SMALL_SPATIAL + && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 8; + reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + } + + // Check input data cache aliasing. + // For other ISA constants may be updated. + // 64 * 1024 is chosen due to 1MB L2 16-way cache. + // 7 is empirical value. It is about half of 16. + // So we leave about half of the set for other data - weights, dst + int way_size = (64 * 1024) / jcp.typesize_in; + int max_hits = 7; + if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) { + int nrb = reduce_blocking / simd_w; + int sp = jcp.bcast_dim; + int wl = way_size / simd_w; + for (int start_off = 0; start_off < jcp.ur; start_off++) { + for (int off = start_off, hits = 0; off < sp * nrb; off += wl) { + if (off % sp >= jcp.ur || ++hits < max_hits) + continue; + int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp); + reduce_blocking + = nstl::min(reduce_blocking, max_r_blocking); + break; + } + } + } + + if (reduce_blocking < jcp.reduce_dim) { + jcp.use_vmovntps = false; + if (jcp.prop_kind == backward_data) + jcp.loop_order = reduce_src ? loop_lbr : loop_rlb; + else + jcp.loop_order = reduce_src ? loop_rbl : loop_rlb; + } + load_blocking = jcp.load_dim; + + int load_size = jcp.load_dim * jcp.reduce_dim; + int bcast_size = jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim; + + if (jcp.ver == ver_avx512_core && nthreads <= 28 && jcp.mb < nthreads + && nb_load * nb_bcast > nthreads) { + // Some heuristic here + float calc_koef = 0.01, best_cost = FLT_MAX; + int n_lgc = nthreads; + float ratio = (float)load_size / (float)bcast_size; + int best_lgc = ratio > 1 ? n_lgc : 1; + auto calc_job_cost = [&](int lb, int tg, float mem_k) { + int bb_size = jcp.mb * div_up(nb_bcast, tg); + float calc_size = (float)(bb_size * jcp.ur) + * (lb * jcp.load_block) * jcp.reduce_dim; + float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block) + * jcp.reduce_dim; + return calc_koef * calc_size + mem_k * mem_size; + }; + for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) { + lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1; + int min_lb = nb_load / lgc; + int max_lb = div_up(nb_load, lgc); + int min_tg = nthreads / lgc; + int max_tg = div_up(nthreads, lgc); + // Some heuristic here + float mem_koef = (max_tg == 1) ? 1.f : 1.3f; + float job_cost = 0.; + if (nthreads % lgc < nb_load % lgc) { + job_cost = calc_job_cost(max_lb, min_tg, mem_koef); + } else { + auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef); + auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef); + job_cost = nstl::max(job_cost1, job_cost2); + } + + if (job_cost < best_cost) { + best_lgc = lgc; + best_cost = job_cost; + } + } + jcp.load_grp_count = best_lgc; + load_blocking = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; + } else { + jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast); + jcp.load_grp_count = best_divider( + nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false); + } + + if (jcp.ver == ver_avx512_core && jcp.expl_bcast && jcp.bcast_dim <= 64 + && load_size >= L2_size) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); + } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads + && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); + load_blocking = jcp.load_block; + } + + if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim + && jcp.oh * jcp.ow > 64 + && IMPLICATION(reduce_src, jcp.load_dim < 1024)) { + /* Looking for best loading dimension blocking + * to get the best thread and data read/write efficiency + * by finding the optimal 'load_chunk' value + * Example: + * for 72 threads and convolution with mb=1, ih=iw=7, oc = 512 + * the 'best' load_chunk value should be 1 + * TODO: remove heuristic constants in above condition + * TODO: check this blocking for other ISA + */ + float best_eff = -1.f; + int best_lgc = 1; + + for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) { + int lgc = div_up(nb_load, load_chunk); + if (lgc > nthreads) + continue; + int thr_per_grp = div_up(nthreads, lgc); + int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp) + * jcp.bcast_block; + int load_per_thr = load_chunk * simd_w; + float data_norm = (bcast_per_thr + load_per_thr) / 2.f; + float data_eff = (bcast_per_thr * load_per_thr) + / (data_norm * data_norm); + float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc) + / div_up(nthreads, lgc); + float thr_eff_in_grp = ((float)jcp.mb * nb_bcast) + / rnd_up(jcp.mb * nb_bcast, thr_per_grp); + float thr_eff = thr_eff_over_grp * thr_eff_in_grp; + float load_eff = (float)nb_load / rnd_up(nb_load, lgc); + float overall_eff = data_eff + thr_eff + load_eff; + if (overall_eff > best_eff) { + best_eff = overall_eff; + best_lgc = lgc; + } + } + jcp.load_grp_count = best_lgc; + load_blocking + = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; + } + bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, + div_up(nthreads, jcp.load_grp_count)) + * jcp.bcast_block; + bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); + bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); + + int space_for_bcast + = (L2_capacity - /* kernel_size - */ + 2 * jcp.load_block * reduce_blocking + - jcp.ur * reduce_blocking - 3 * 1024); + if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) + space_for_bcast /= 2; + + int bcast_in_cache + = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); + bcast_blocking = nstl::min( + bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); + + load_blocking_max = load_blocking; + bcast_blocking_max = bcast_blocking * 3 / 2; + reduce_blocking_max = reduce_blocking; + + } else if (jcp.prop_kind == backward_weights) { + + jcp.use_vmovntps = false; + if (jcp.is > SMALL_SPATIAL * SMALL_SPATIAL && jcp.ver == ver_4fma) + jcp.use_vmovntps = true; + + if (jcp.transpose_src) + jcp.reduce_dim = jcp.tr_is; + else + jcp.reduce_dim = jcp.is; + + if (jcp.ver == ver_4fma) { + // reduce_block should be divided by fma_step + jcp.reduce_block = best_divider(jcp.reduce_dim, 4, 16, true, 4); + } else { + jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true); + if (jcp.reduce_dim % jcp.reduce_block != 0) + jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false); + if (jcp.reduce_block > 256) { + jcp.reduce_block = 1; + } + + } + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + if (jcp.ver == ver_avx512_core && jcp.reduce_block <= 19) { + // if reduce_block is big then generated JIT code may be big + // for small values of ur because reduce_loop_unroll = reduce_block + jcp.ur = jcp.bcast_block / 2; + jcp.expl_bcast = true; + } else { + jcp.ur = jcp.bcast_block; + jcp.expl_bcast = false; + } + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * jcp.typesize_in; + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * jcp.typesize_in; + + jcp.bcast_loop_output_step = + jcp.oc_block * jcp.ic_block * jcp.typesize_out; + jcp.bcast_loop_output_substep = + jcp.oc_block * jcp.ur * jcp.typesize_out; + jcp.bcast_loop_bcast_step = + jcp.ic_block * jcp.reduce_dim * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in; + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * jcp.typesize_in; + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + balance(jcp, nthreads); + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + load_blocking = best_divider(load_blocking, 16, load_blocking, false); + load_blocking *= jcp.load_block; + + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + int min_bcast_blocking = 5; + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + bcast_blocking = best_divider( + bcast_blocking, min_bcast_blocking, max_bcast_blocking, false); + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + // for reduction balance + if (jcp.ver == ver_avx512_core) { + int max_reduce_blocking + = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim); + int min_reduce_blocking = nstl::min( + L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih)); + reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking, + max_reduce_blocking, true); + reduce_blocking + = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block), + jcp.reduce_block); + } else { + int max_reduce_blocking = L2_capacity + / ((bcast_blocking + load_blocking) * jcp.reduce_block); + max_reduce_blocking = nstl::min(max_reduce_blocking, + (L1_capacity / (jcp.bcast_block)) / jcp.reduce_block); + + int num_jobs = div_up(jcp.load_dim, load_blocking) + * div_up(jcp.bcast_dim, bcast_blocking); + int threads_per_job = nstl::max(1, nthreads / num_jobs); + reduce_blocking = div_up(jcp.mb * jcp.reduce_dim, jcp.reduce_block); + reduce_blocking = div_up(reduce_blocking, threads_per_job); + + reduce_blocking = best_divider(reduce_blocking, + max_reduce_blocking - 2, max_reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + } + + reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block); + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + assert(reduce_blocking_max); + assert(load_blocking % jcp.load_block == 0); + assert(reduce_blocking % jcp.reduce_block == 0); + assert(load_blocking_max % jcp.load_block == 0); + assert(reduce_blocking_max % jcp.reduce_block == 0); + if (jcp.ver == ver_4fma) { + assert(jcp.reduce_loop_unroll % jcp.fma_step == 0); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + } + + assert(jcp.bcast_block % jcp.ur == 0); + assert(jcp.reduce_dim % jcp.reduce_block == 0); + + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +void jit_avx512_common_1x1_conv_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.prop_kind != backward_data && jcp.with_bias + && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); + + if (jcp.prop_kind == backward_weights) { + const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic; + scratchpad.book(key_conv_wei_reduction, + jcp.typesize_out * wei_size * (jcp.nthr_mb - 1)); + } + + if (jcp.transpose_src) { + const size_t tr_src_size = + (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is; + scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size); + scratchpad.book(key_conv_tr_src_bctx, + sizeof(simple_barrier::ctx_t) * jcp.nthr); + } +} + +void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp, + int nthreads) +{ + // initialize jcp reduction threading properties + jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1; + if (nthreads < jcp.ngroups) { + /* simplification... fortunately it doesn't hurt much */ + return; + } + const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + const int nb_load = div_up(jcp.load_dim, jcp.load_block); + const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + jcp.nthr_g = jcp.ngroups; + const int nthr = nthreads / jcp.nthr_g; + + auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + /* calculate per thread memory cost (read/write). high level + * optimizer tries to minimize memory consumption. few notes: (n1) + * unclear why, but that essentially helps first convolution... + * (n2) assuming the reduction over minibatch is always there: + * - instead of 8 it should be 5 here (write ~= 2 read): + * kernel: temporal workspace 1 write + * reduction: 1 read from workspace and 1 write to the diff_wei + * - but experiments showed 8 works better than 5 or 6... */ + int bcast_koeff = 1; + int load_koeff = 1; + int output_koeff = 12; + if (jcp.transpose_src) { + bcast_koeff = 5; + load_koeff = 1; + output_koeff = 8; + } + return 0 + + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) + * div_up(jcp.ngroups, jcp.nthr_g) + * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.reduce_block + / jcp.stride_h / jcp.stride_w /* (n1) */ + + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) + * div_up(jcp.ngroups, jcp.nthr_g) + * div_up(nb_load, nthr_oc_b) * jcp.oc_block * jcp.reduce_block + + (size_t)output_koeff /* (n2) */ + * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) + * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block + * jcp.oc_block; + }; + + int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1; + auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + + /* step 1: find the best thread distribution with lowest memory cost */ + const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce); + for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, nb_load); + for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast); + auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + if (mem_cost <= best_mem_cost) { + best_mem_cost = mem_cost; + jcp.nthr_mb = nthr_mb; + jcp.nthr_oc_b = nthr_oc_b; + jcp.nthr_ic_b = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads) + jcp.nthr_mb = nstl::min(jcp.mb, nthreads); + + jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b; + assert(jcp.nthr <= nthreads); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp new file mode 100644 index 0000000000..d2ae017943 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp @@ -0,0 +1,108 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP +#define JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_common_1x1_conv_kernel : public jit_generator { + jit_avx512_common_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>( + this, jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode(); + } + + ~jit_avx512_common_1x1_conv_kernel() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_1x1_conv_kernel) + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr, + int nthreads, bool reduce_src); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + + private: + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + + reg64_t reg_bcast_data = r8; + reg64_t reg_load_data = r10; + reg64_t reg_output_data = r9; + reg64_t aux_reg_bcast_data = r14; + reg64_t aux1_reg_bcast_data = rbx; + reg64_t aux_reg_load_data = r15; + reg64_t imm_addr64 = aux_reg_load_data; + reg64_t aux_reg_output_data = abi_not_param1; + reg64_t reg_load_loop_work = rsi; + reg64_t reg_reduce_loop_work = r11; + reg64_t bcast_loop_iter = rdx; + reg64_t reduce_loop_iter = abi_param1; + reg64_t reg_reduce_pos_flag = rax; + reg64_t reg_output_stride = r13; + reg64_t reg_bias_data = r12; + reg64_t reg_relu_ns = r13; + reg64_t reg_bcast_loop_work = aux1_reg_bcast_data; + + Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31); + + jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_; + + int bcast_loop_work_offt = 0; + int stack_space_needed = 16; + + void bcast_loop(int load_loop_blk); + void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); + + void generate(); + static void balance(jit_1x1_conv_conf_t &jcp, int nthreads); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp new file mode 100644 index 0000000000..54d58c8a39 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp @@ -0,0 +1,816 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx512_common_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + + +namespace { +template <typename T, typename U> +void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, + T nx, T &nx_start, T &nx_end, T nx_divider) +{ + const int grp_count = nstl::min(nx_divider, nthr); + const int grp_size_big = nthr / grp_count + 1; + const int grp_size_small = nthr / grp_count; + const int n_grp_big = nthr % grp_count; + const int threads_in_big_groups = n_grp_big * grp_size_big; + + const int ithr_bound_distance = ithr - threads_in_big_groups; + T grp, grp_ithr, grp_nthr; + if (ithr_bound_distance < 0) { // ithr in first groups + grp = ithr / grp_size_big; + grp_ithr = ithr % grp_size_big; + grp_nthr = grp_size_big; + } else { // ithr in last groups + grp = n_grp_big + ithr_bound_distance / grp_size_small; + grp_ithr = ithr_bound_distance % grp_size_small; + grp_nthr = grp_size_small; + } + + balance211(nx, grp_count, grp, nx_start, nx_end); + balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); +} +} +/* convolution forward */ + +template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type> +void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + const auto &jcp = kernel_->jcp; + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.template get<dst_data_t>( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + parallel(0, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad); + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type> +void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>:: +execute_forward_thr(const int ithr, const int nthr, const src_data_t *src, + const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space); + + const int ndims = src_d.ndims(); + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto p = jit_1x1_conv_call_s(); + + auto rp = rtus_driver_t<avx512_common>::call_params_t(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count); + + auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, + int &oh, int &ow, int &ih, int &iw) + { + int osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + oh = os / jcp.ow; + ow = os % jcp.ow; + + ih = nstl::max(oh * stride_h - pad_t, 0); + iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + }; + + auto init_load = [&](int ocb, int &load_step) + { + load_step = step(jcp.nb_load_blocking, ocb_end - ocb, + jcp.nb_load_blocking_max); + p.load_dim = this_block_size(ocb * jcp.oc_block, + ocb_end * jcp.oc_block, load_step * jcp.oc_block); + }; + + auto init_reduce = [&](int icb) + { + const int nb_ic_blocking_step = + nstl::min(icb + nb_ic_blocking, nb_ic) - icb; + p.first_last_flag = 0 + | (icb == 0 ? FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking_step >= nb_ic + ? FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size(icb * jcp.ic_block, + jcp.ic, nb_ic_blocking_step * jcp.ic_block); + rp.icb = p.reduce_dim / jcp.reduce_block; + }; + + auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow, + int ih, int iw) + { + + const int _ocb = g * nb_oc + ocb; + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + + p.output_data = &dst[dst_off]; + p.bias_data = &bias[_ocb * jcp.oc_block]; + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int _icb = g * nb_ic + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + if (ocb == ocb_start) { + rp.src = src + data_blk_off(src_d, n, _icb, ih, iw); + rtus_driver_->ker_(&rp); + } + p.bcast_data = rp.ws; + } else + p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw); + + kernel_->jit_ker(&p); + }; + + if (jcp.loop_order == loop_rlb) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } + } else if (jcp.loop_order == loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + } + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_rbl) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } + } else if (jcp.loop_order == loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + inner_ker(ocb, icb, n, g, oh, ow, ih, iw); + } + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); + } +} + + +template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>; +/* convolution backward wtr data */ + +template <data_type_t diff_dst_type, data_type_t wei_type, + data_type_t diff_src_type> +void jit_avx512_common_1x1_convolution_bwd_data_t<diff_dst_type, wei_type, + diff_src_type>::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad(ctx).template get<diff_src_data_t>( + key_conv_rtus_space); + + const int ndims = diff_src_d.ndims(); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + const int nb_ic = jcp.nb_load; + const int nb_oc = jcp.nb_reduce; + const int os_block = jcp.bcast_block; + const int nb_oc_blocking = jcp.nb_reduce_blocking; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + parallel(0, [&](const int ithr, const int nthr) { + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t<avx512_common>::call_params_t(); + + int bcast_start{0}, bcast_end{0}, icb_start{0}, icb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load, icb_start, icb_end, jcp.load_grp_count); + + bool reduce_outer = (jcp.loop_order == loop_rbl + || jcp.loop_order == loop_rlb); + int nboc_outer = reduce_outer ? nb_oc : 1; + int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1; + + int nboc_inner = reduce_outer ? 1 : nb_oc; + int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking; + + for (int ocb_outer = 0; ocb_outer < nboc_outer; + ocb_outer += ocb_outer_step) { + size_t cur_ocb_outer = + nstl::min(ocb_outer + ocb_outer_step, nboc_outer) - ocb_outer; + + int load_step = 0; + for (int icb = icb_start; icb < icb_end; icb += load_step) { + load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb, + jcp.nb_load_blocking_max); + + p.load_dim = this_block_size(icb * jcp.ic_block, + icb_end * jcp.ic_block, load_step * jcp.ic_block); + rp.icb = p.load_dim / jcp.ic_block; + + int bcast_step; + for (int iwork = bcast_start; iwork < bcast_end; + iwork += bcast_step) + { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + + const int oh = os / jcp.ow; + const int ow = os % jcp.ow; + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + const int _icb = g * nb_ic + icb; + rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw); + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_; + p.output_data = rp.ws; + } else + p.output_data = rp.src; + + for (int ocb_inner = 0; ocb_inner < nboc_inner; + ocb_inner += ocb_inner_step) { + int cur_ocb_inner = + nstl::min(ocb_inner + ocb_inner_step, nboc_inner) - + ocb_inner; + + int ocb = reduce_outer ? ocb_outer : ocb_inner; + int nb_oc_blocking_step = reduce_outer + ? cur_ocb_outer : cur_ocb_inner; + const int _ocb = g * nb_oc + ocb; + size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow); + p.bcast_data = &diff_dst[diff_dst_off]; + + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0; + + p.reduce_dim = this_block_size(ocb * jcp.oc_block, + jcp.oc, nb_oc_blocking_step * jcp.oc_block); + + kernel_->jit_ker(&p); + } + if (pd()->rtus_.reduce_src_) + rtus_driver_->ker_(&rp); + } + } + } + }); +} + +template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>; + +/* convolution backward wtr weights */ + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? (d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +jit_avx512_common_1x1_convolution_bwd_weights_t :: + jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr) + , trans_kernel_(nullptr), rtus_driver_(nullptr) +{ + kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr()); + acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>(); + reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_); + init_rtus_driver<avx512_common>(this); + + const auto &jcp = kernel_->jcp; + + if (jcp.transpose_src) { + auto tp = jit_transpose4x16_src_t(); + tp.src_pf0_distance = 4; + tp.tr_src_pf0_distance = 0; + tp.src_pf1 = true; + tp.tr_src_pf1 = false; + trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp); + } +} + +void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights( + const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + + const auto scratchpad = this->scratchpad(ctx); + + auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space); + data_t *diff_bias = pd()->wants_padded_bias() + ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in; + auto wei_reduction = scratchpad.get<data_t>(key_conv_wei_reduction); + + /* prepare src transposition barriers */ + auto tr_src = scratchpad.get<data_t>(key_conv_tr_src); + auto tr_src_bctx = scratchpad.get<simple_barrier::ctx_t>( + key_conv_tr_src_bctx); + if (jcp.transpose_src) { + for (int i = 0; i < jcp.nthr; ++i) + simple_barrier::ctx_init(&tr_src_bctx[i]); + } + + const int ndims = src_d.ndims(); + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic; + + simple_barrier::ctx_t reduction_barrier; + simple_barrier::ctx_init(&reduction_barrier); + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); + + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + const int nb_ic = jcp.nb_bcast; + const int nb_ic_blocking = jcp.nb_bcast_blocking; + + const int nb_oc = jcp.nb_load; + const int nb_oc_blocking = jcp.nb_load_blocking; + + const int sp_nb = jcp.nb_reduce; + const int mb_sp_work = jcp.mb * sp_nb; + + const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + // TODO: use memory descriptor with the same fmt as src + // (or use a macro :)) + auto tr_src_off = [&](int img, int icb, int is) { + const size_t tr_chn_size = jcp.tr_is * jcp.ic_block; + const size_t tr_img_size = tr_chn_size * nb_ic * jcp.ngroups; + return img * tr_img_size + icb * tr_chn_size + is * jcp.ic_block; + }; + + auto uker_trans = [&](int ithr_mb, int img, int sp_b_start, int sp_size, + int g_start, int g_work, int ic_b_start, int ic_b_work, + int ithr, int nthr, int first_ic_b) + { + const int work_amount = g_work * ic_b_work; + + int start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + int g{ 0 }, ic_b{ 0 }; + nd_iterator_init(start, g, g_work, ic_b, ic_b_work); + g += g_start; + const int ic_b_tr = g * nb_ic + first_ic_b + ic_b; + ic_b += ic_b_start; + + const int _ic = g * nb_ic + ic_b; + + const int is = sp_b_start * jcp.reduce_block; + const int ih = is / jcp.iw; + const int iw = is % jcp.iw; + + const int src1_off = data_blk_off(src_d, img, _ic, ih, iw); + data_t *src1 = (data_t *)&src[src1_off]; + data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)]; + + assert(jcp.ic_block == 16); + const int src_stride = jcp.is * jcp.ic_block; + const int tr_src_stride = jcp.tr_is * jcp.ic_block; + + const int my_work = end - start; + for (int iwork = 0; iwork < my_work; iwork++) { + auto par_trans = jit_src_transpose_s(); + assert(sp_size % 4 == 0 || sp_size % 4 == jcp.is % 4); + par_trans.size = sp_size; + par_trans.src = src1; + par_trans.tr_src = tr_src1; + par_trans.src_prf = src1 + 64 * 16; + par_trans.tr_src_prf = tr_src1 + 80 * 16; + trans_kernel_->jit_ker(&par_trans); + + src1 += src_stride; + tr_src1 += tr_src_stride; + } + }; + + auto ker = [&](const int ithr, const int nthr) { + assert(nthr == jcp.nthr); + assert(IMPLICATION(!mkldnn_thr_syncable(), jcp.nthr_mb == 1)); + + const int ithr_ic_b = ithr % jcp.nthr_ic_b; + const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; + const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; + const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / + jcp.nthr_g; + + const int ithr_but_oc + = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b; + + /* reduction dimension */ + int mb_sp_b_start{ 0 }, mb_sp_b_end{ 0 }; + if (jcp.transpose_src && jcp.nthr_mb < jcp.mb / 2) { + // it's preferable to parallelize by mb if possible + int img_start{ 0 }, img_end{ 0 }; + balance211(jcp.mb, jcp.nthr_mb, ithr_mb, img_start, img_end); + mb_sp_b_start = img_start * sp_nb; + mb_sp_b_end = img_end * sp_nb; + } + else { + balance211(mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, + mb_sp_b_end); + } + + /* independent dimensions */ + int g_start{ 0 }, oc_b_start{ 0 }, ic_b_start{ 0 }; + int g_end{ 0 }, oc_b_end{ 0 }, ic_b_end{ 0 }; + + balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); + balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, + oc_b_end); + balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, + ic_b_end); + + const int g_work = g_end - g_start; + const int oc_b_work = oc_b_end - oc_b_start; + const int ic_b_work = ic_b_end - ic_b_start; + + data_t *diff_wei = ithr_mb == 0 + ? diff_weights : wei_reduction + (ithr_mb - 1) * wei_size; + + int sp_b_step = 0; + for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end; + mb_sp_b += sp_b_step) { + int img{ 0 }, sp_b{ 0 }; + nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb); + sp_b_step = step(jcp.nb_reduce_blocking, + nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b), + jcp.nb_reduce_blocking_max); + + for (int g = g_start; g < g_end; ++g) { + int load_step = 0; + int bcast_step = 0; + for (int ic_b = ic_b_start; ic_b < ic_b_end; + ic_b += bcast_step) { + bcast_step = step(nb_ic_blocking, ic_b_end - ic_b, + jcp.nb_bcast_blocking_max); + if (jcp.transpose_src) { + if (jcp.nthr_oc_b > 1) + simple_barrier::barrier( + &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); + const int sp_size + = nstl::min(sp_b_step * jcp.reduce_block, + jcp.is - sp_b * jcp.reduce_block); + uker_trans(ithr_mb, img, sp_b, sp_size, g, 1, ic_b, + bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start); + if (jcp.nthr_oc_b > 1) + simple_barrier::barrier( + &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); + } + + for (int oc_b = oc_b_start; oc_b < oc_b_end; + oc_b += load_step) { + load_step = step(nb_oc_blocking, oc_b_end - oc_b, + jcp.nb_load_blocking_max); + const int _ic_b = g * nb_ic + ic_b; + const int _ic_b_tr = g * nb_ic + ic_b_start; + const int _oc_b = g * nb_oc + oc_b; + + data_t *store_to; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + store_to = diff_wei + off; + + const data_t *diff_src = jcp.transpose_src ? + &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] : + &src[src_d.blk_off(img, _ic_b)]; + + int sp_b_end = sp_b + sp_b_step; + const data_t *pdiff_dst + = &diff_dst[diff_dst_d.blk_off(img, _oc_b)]; + const data_t *local_src = diff_src; + + auto p = jit_1x1_conv_call_s(); + auto rp = rtus_driver_t<avx512_common>::call_params_t(); + + p.output_stride + = jcp.ic * jcp.oc_block * jcp.typesize_out; + + p.load_dim = load_step * jcp.oc_block; + + p.bcast_dim = bcast_step * jcp.ic_block; + rp.icb = bcast_step; + p.output_data = store_to; + + p.reduce_dim = sp_b_step * jcp.reduce_block; + rp.os = p.reduce_dim; + + p.first_last_flag = 0 + | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0) + | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0); + + int sp = sp_b * jcp.reduce_block; + p.load_data = pdiff_dst + sp * jcp.oc_block; + + if (pd()->rtus_.reduce_src_) { + const int oh = sp / jcp.ow; + const int ow = sp % jcp.ow; + + const int ih = nstl::max(oh * stride_h - pad_t, 0); + const int iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + rp.ws = rtus_space + + ithr * pd()->rtus_.space_per_thread_ + + sp * jcp.ic_block; + + if (ndims == 3) + rp.src = local_src + iw + * src_d.blocking_desc().strides[2]; + else + rp.src = local_src + ih + * src_d.blocking_desc().strides[2] + + iw * src_d.blocking_desc().strides[3]; + rtus_driver_->ker_(&rp); + + p.bcast_data = rp.ws; + } else + p.bcast_data = local_src + sp * jcp.ic_block; + + kernel_->jit_ker(&p); + } + } + } + } + + /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */ + if (jcp.nthr_mb > 1) { + simple_barrier::barrier(&reduction_barrier, jcp.nthr); + const int work = g_work * oc_b_work * ic_b_work; + int start{ 0 }, end{ 0 }; + balance211(work, jcp.nthr_mb, ithr_mb, start, end); + if (start == end) + return; + + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + int w = start; + int sub_g_start{ 0 }, sub_oc_b_start{ 0 }, + sub_ic_b_start{ 0 }; + nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, + oc_b_work, sub_ic_b_start, ic_b_work); + while (w < end) { + const int g = g_start + sub_g_start; + const int oc_b = oc_b_start + sub_oc_b_start; + const int ic_b = ic_b_start + sub_ic_b_start; + + const int acc_size + = nstl::min(end - w, ic_b_work - sub_ic_b_start) + * jcp.ic_block * jcp.oc_block; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b); + data_t *d = diff_weights + off; + data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off; + + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, g_work, + sub_oc_b_start, oc_b_work, sub_ic_b_start, + ic_b_work); + } + } + } + }; + + auto ker_bias = [&](int ithr, int nthr) { + assert(nthr == rb->balancer().nthr_); + + const int b_job_start = rb->balancer().ithr_job_off(ithr); + const int b_njobs = rb->balancer().ithr_njobs(ithr); + + if (b_njobs == 0) + return; + + /* reduction dimension */ + int img_start{ 0 }, img_end{ 0 }; + + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ithr), img_start, img_end); + + /* jobs */ + int g_start{ 0 }, ocb_start{ 0 }; + nd_iterator_init( + b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load); + + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_load + ocb; + + const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)]; + data_t *d_bias = rb->get_local_ptr(ithr, diff_bias, + reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 16; ++o) + d_bias[o] = 0.; + + for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 16; ++o) + d_bias[o] += d_dst[o]; + d_dst += 16; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load); + } + } + rb->reduce(ithr, diff_bias, reducer_bia_scratchpad); + }; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + ker(ithr, jcp.nthr); + if (pd()->with_bias()) + ker_bias(ithr, jcp.nthr); + }); + + /* TODO: put this in ker_bias */ + if (pd()->wants_padded_bias()) { + assert(jcp.ngroups == 1); + utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp new file mode 100644 index 0000000000..2e9fda76d6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp @@ -0,0 +1,344 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP +#define CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_avx512_common_1x1_conv_kernel.hpp" +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_transpose_src_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t src_type, + impl::data_type_t wei_type = src_type, + impl::data_type_t dst_type = src_type> +struct jit_avx512_common_1x1_convolution_fwd_t : public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_fwd_t); + + status_t init() { + using namespace utils; + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, dst_type, dst_type, + data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr(), + mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template <cpu_isa_t isa, typename conv_t> + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = + new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr()); + init_rtus_driver<avx512_common>(this); + } + + ~jit_avx512_common_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<wei_type>::type wei_data_t; + typedef typename prec_traits<dst_type>::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + + private: + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const dst_data_t *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + rtus_driver_t<avx512_common> *rtus_driver_; +}; + +using jit_avx512_common_1x1_convolution_fwd_f32_t + = jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>; + +template <impl::data_type_t diff_dst_type, + impl::data_type_t wei_type = diff_dst_type, + impl::data_type_t diff_src_type = diff_dst_type> +struct jit_avx512_common_1x1_convolution_bwd_data_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, data_type::undef, + diff_dst_type, data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *diff_src_d = diff_src_md(); + rtus_prepare(this, conv_d, diff_src_d, diff_dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), + *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + // TODO (Roma): structs conf header cleanup + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + template <cpu_isa_t isa, typename conv_t> + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, + *pd()->attr()); + init_rtus_driver<avx512_common>(this); + } + + ~jit_avx512_common_1x1_convolution_bwd_data_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t; + typedef typename prec_traits<wei_type>::type wei_data_t; + typedef typename prec_traits<diff_src_type>::type diff_src_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + + private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + rtus_driver_t<avx512_common> *rtus_driver_; +}; + +using jit_avx512_common_1x1_convolution_bwd_data_f32_t + = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>; + +struct jit_avx512_common_1x1_convolution_bwd_weights_t : public cpu_primitive_t +{ + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), + jit_avx512_common_1x1_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, diff_dst_md()); + + status_t status = jit_avx512_common_1x1_conv_kernel::init_conf( + jcp_, *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), + *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad, + jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + // TODO (Roma): structs conf header cleanup + jit_1x1_conv_conf_t jcp_; + cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_; + reduce_to_unit_stride_t rtus_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + + private: + void init_balancers() { + const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, + jcp_.mb, max_buffer_size)); + } + } + }; + + template <cpu_isa_t isa, typename conv_t> + friend void init_rtus_driver(conv_t *self); + + jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd); + + ~jit_avx512_common_1x1_convolution_bwd_weights_t() { + delete kernel_; + delete acc_ker_; + delete reducer_bias_; + delete rtus_driver_; + delete trans_kernel_; + } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + + private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_1x1_conv_kernel *kernel_; + cpu_accumulator_1d_t<data_type::f32> *acc_ker_; + cpu_reducer_t<data_type::f32> *reducer_bias_; + jit_transpose4x16_src *trans_kernel_; + rtus_driver_t<avx512_common> *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp new file mode 100644 index 0000000000..235fb02fef --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp @@ -0,0 +1,4539 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" + +#include "jit_avx512_common_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) +#define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { + +constexpr auto small_spatial = 14; +unsigned int L1_cache_size = get_cache_size(1, true); + +inline void pick_loop_order(jit_conv_conf_t &jcp) { + using namespace prop_kind; + assert(one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)); + auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow; + auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh; + + // ow-threading is currently implemented for forward only + // TODO: single code for fwd and bwd after ow-thr for bwd + // meaningless switch was removed + if (jcp.prop_kind == backward_data) { + jcp.loop_order = (w <= small_spatial && h <= small_spatial) + ? loop_cgn : loop_gnc; + } else { + jcp.loop_order = (w <= small_spatial && h <= small_spatial) + ? loop_cwgn : loop_gncw; + } +} + +inline bool is_1stconv(const jit_conv_conf_t &jcp) { + if (mayiuse(avx512_core)) + return (jcp.ic < 16 && jcp.ngroups == 1); + else + return one_of(jcp.ic, 1, 3); +} + +inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) { + return (jcp.nb_ow > 1); +} + +inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) { + return (jcp.ver == ver_4fma && is_ow_threading_on(jcp)); +} + +} + +template<typename Vmm> +void _jit_avx512_common_conv_fwd_kernel<Vmm>::prepare_output(int ur_w) +{ + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vpxord(vmm, vmm, vmm); + if (!is_owb_prefetching(jcp)) { + size_t aux_output_offset = get_output_offset(j, k); + mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf, + aux_output_offset, reg_out_long_offt)); + } + } +} + +template<typename Vmm> +void _jit_avx512_common_conv_fwd_kernel<Vmm>::store_output(int ur_w) +{ + Label no_update_label, store_label, eltwise_label; + + mov(reg_channel, ptr[param1 + GET_OFF(channel)]); + if (jcp.with_bias) { + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + } + + if (!jcp.with_sum) { + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + } + + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + size_t aux_output_offset = get_output_offset(j, k); + vaddps(vmm, + make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt)); + } + + if (!jcp.with_sum) { + jmp(eltwise_label, T_NEAR); + } else { + cmp(reg_channel, 0); + jne(eltwise_label, T_NEAR); + } + + L(no_update_label); + if (jcp.with_bias) { + for (int k = 0; k < jcp.nb_oc_blocking; k++) { + int bias_offset = jcp.typesize_out * k * jcp.oc_block; + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vaddps(vmm, EVEX_compress_addr(reg_bias, bias_offset)); + } + mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64)); + } + } + + L(eltwise_label); + if (jcp.with_eltwise) { + cmp(reg_channel, jcp.nb_ic - 1); + jl(store_label, T_NEAR); + + if (ur_w == jcp.ur_w) { + eltwise_injector_->compute_vector_range(0, + jcp.nb_oc_blocking * jcp.ur_w); + } else { + for (int k = 0; k < jcp.nb_oc_blocking; k++) + eltwise_injector_->compute_vector_range(k * jcp.ur_w, + k * jcp.ur_w + ur_w); + } + } + + L(store_label); + for (int k = 0; k < jcp.nb_oc_blocking; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + size_t aux_output_offset = (size_t)typesize * + ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block; + vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset, + reg_out_long_offt), vmm); + if (!is_owb_prefetching(jcp)) + mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf, + aux_output_offset, reg_out_long_offt)); + } +} + +template<typename Vmm> +void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma_1st(int ur_w, + int pad_l, int pad_r) +{ +} + +template<> +void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma_1st(int ur_w, + int pad_l, int pad_r) +{ + assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0); + + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + Label kh_label, kd_label; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_inp_prf, reg_inp_prf); + } + + size_t max_input_offset = (size_t)jcp.typesize_in + * ((size_t)(kw + ur_w * stride_w - pad_l) + + (size_t)ic_block * iw * ih * jcp.id); + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > INT_MAX) push(reg_inp_prf); + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + + L(kd_label); + } + mov(reg_kj, reg_kh); + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + L(kh_label); + for (int ki = 0; ki < kw; ki += 4) { + for (int ic = 0; ic < ic_block; ic++) { + for (int i = 0; i < 4; i++) { + int aux_ker_offset + = jcp.typesize_in + * ((ki + i) * oc_block + + ic * kw * jcp.kh * jcp.kd * oc_block); + if (ki + i < kw) + vmovups(vmm_ker(i), + EVEX_compress_addr(aux_reg_ker, aux_ker_offset)); + else + vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i)); + } + + int j_start = get_ow_start(ki, pad_l); + int j_end = get_ow_end(ur_w, ki, pad_r); + + for (int j = j_start, prf_count=0; j < j_end; j++) { + size_t aux_input_offset = (size_t)jcp.typesize_in + * ((size_t)(ki + j * stride_w + - pad_l) + (size_t)ic * iw * ih * jcp.id); + v4fmaddps(vmm_out(j, 0), vmm_ker(0), + EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset, + reg_long_offt)); + if (ki + prf_count < kw && prf_count < 4 + && ((ki < 2 && j % 4) || j % 2)) { + int aux_ker_offset = jcp.typesize_in + * ((ki + prf_count) * oc_block + + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_ker_offset)); + prf_count++; + } + if (ki == 0 + && j % (64 / (stride_w * jcp.typesize_in)) == 0) { + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf, + aux_input_offset, reg_long_offt)); + } + if (ki == 1 + && j % (64 / (stride_w * jcp.typesize_in)) == 0) { + mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset+jcp.typesize_in * iw, reg_long_offt)); + } + } + } + } + add(aux_reg_ker, jcp.typesize_in * kw * oc_block); + add(aux_reg_inp, jcp.typesize_in * iw); + add(aux_reg_inp_prf, jcp.typesize_in * iw); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block); + add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } + + if (max_input_offset > INT_MAX) pop(reg_inp_prf); +} + +template<typename Vmm> +void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma(int ur_w, + int pad_l, int pad_r) +{ +} + +template<> +void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma(int ur_w, + int pad_l, int pad_r) +{ + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + Label kh_label, last_iter_label, loop_end_label, kd_label; + int ker_load_number = 4; + int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block; + int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block; + + bool check_last_kh = (jcp.kh > 3); + bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28); + + int oi_ipref_t0 = get_ow_start(0, pad_l); + int ow_end_ipref = get_ow_end(ur_w, 0, pad_r); + + assert(jcp.oc % jcp.nb_oc_blocking == 0); + + auto kernel_offset = [=](int ocb, int ic, int ki) { + int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int ic_offset = ic * jcp.oc_block; + return typesize * (blk_offset + ic_offset); + }; + auto kernel_loads = [=](int ki, int ic, int kk) { + for (int ii = 0; ii < ker_load_number; ii++) { + int aux_kernel_offset = kernel_offset(kk, ic + ii, ki); + vmovups(vmm_ker(ii), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + }; + auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { + if (cnt1 >= ker_load_number && cnt0 >= ker_load_number + && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) { + int aux_inp_offset + = typesize + * ((oi_ipref_t0 * stride_w - pad_l) * ic_block + + (jcp.dilate_h + 1) * jcp.iw * ic_block); + prefetcht0(EVEX_compress_addr(aux_reg_inp, + aux_inp_offset)); + oi_ipref_t0++; + } + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_ker_prf, reg_ker_prf); + mov(aux_reg_inp_prf, reg_inp_prf); + } + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + align(16); + L(kh_label); + int kw = jcp.kw; + if (check_last_kh) { + for (int ki = 0; ki < kw; ki++) + for (int ic = 0; ic < ic_block; ic += 4) + for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { + bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1 + && ki == kw - 1 && (ic + 4) == ic_block); + + if (last_kernel_loads) { + cmp(reg_kj, 1); + je(last_iter_label, T_NEAR); + } + + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0, + prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + + if (oi % 2) { + if (prf_count_t0 < 4) { + int aux_kernel_prf; + if (last_kernel_loads) + aux_kernel_prf= kernel_offset(0, + prf_count_t0 + ic + 4 + - ic_block, 0) + typesize * kw + * oc_block * ic_block; + else + aux_kernel_prf = kernel_offset(kk, ic + 4 + + prf_count_t0, ki); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_kernel_prf)); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, ic + + prf_count_t1, ki))); + prf_count_t1++; + } + } else + prefetch_inp_next_kh(ki, 2, prf_count_t0, + prf_count_t1); + } + + if (last_kernel_loads) { + jmp(loop_end_label, T_NEAR); + + L(last_iter_label); + + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0, + prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + if (oi % 2) { + if (prf_count_t0 < 4) { + mic_prefetcht0(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(0, + prf_count_t0, 0))); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + ic + prf_count_t1, ki))); + prf_count_t1++; + } + } + } + L(loop_end_label); + } + } + } else { + for (int ki = 0; ki < kw; ki++) + for (int ic = 0; ic < ic_block; ic += 4) + for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) { + kernel_loads(ki, ic, kk); + for (int oi = get_ow_start(ki, pad_l), + prf_count_t1 = 0, prf_count_t0 = 0; + oi < get_ow_end(ur_w, ki, pad_r); oi++) { + int aux_input_offset = typesize + * ((ki * (jcp.dilate_w + 1) + oi * stride_w + - pad_l) * ic_block + ic); + v4fmaddps(vmm_out(oi, kk), vmm_ker(0), + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + + if (!is_owb_prefetching(jcp)) { + if ((oi % 2) && (prf_count_t1 < 4)) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + ic + prf_count_t1, ki))); + prf_count_t1++; + } + } else { + if (!(ki == 0 && ic == 0) + && !(ki == kw-1 && ic == 0) && + (oi % 2) && (prf_count_t1 < 4) + ) { + mic_prefetcht0(EVEX_compress_addr( + aux_reg_ker, kernel_offset(kk, + ic + 4 + prf_count_t0, ki))); + prf_count_t0++; + } + } + if (!is_owb_prefetching(jcp)) { + if (pref_current_inp) { + if (ki == 0 && ic == 0 && kk == 0) + mic_prefetcht0(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + shift_input_ptr)); + } else { + if (ki == 1 && ic == 0 && kk == 0) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_inp_prf, aux_input_offset)); + } + } else { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int inp_shift + = jcp.typesize_in * ur_w * stride_w * inp_mult; + bool kk_pref_slot = kk ? oi % 2 : !(oi % 2); + if (ki == 0 && ic == 0 && kk_pref_slot) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + inp_shift)); + + if (ki == kw - 1 && ic == 0 && kk_pref_slot) + mic_prefetcht0(EVEX_compress_addr( + aux_reg_inp, + aux_input_offset + inp_shift)); + } + } + } + } + + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + add(aux_reg_ker_prf, shift_kernel_ptr); + add(aux_reg_inp_prf, shift_input_ptr); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + add(aux_reg_inp_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } +} + +template<typename Vmm> +void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma(int ur_w, + int pad_l, int pad_r) +{ + bool prf_ker = true; + bool prf_inp = true; + int ih = jcp.ih; + int stride_w = jcp.stride_w; + int id = jcp.id; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_oc_block = jcp.nb_oc_blocking; + Label kh_label, kd_label; + + int ker_pipeline_depth = 4; + assert(ker_reg_base_idx + ker_pipeline_depth <= 32); + assert(oc_block >= ker_pipeline_depth); + + int num_ker_loads = ic_block * nb_oc_block * kw; + int num_ker_prfs = prf_ker ? num_ker_loads : 0; + int num_inp_prfs = prf_inp ? + ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) : + 0; + if (jcp.is_1stconv && prf_inp) { + num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block; + } + int num_prfs = num_ker_prfs + num_inp_prfs; + int num_fmas = num_ker_loads * ur_w; + int prf_inst_spacing + = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1; + int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_inp_prf, reg_inp_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id; + assert(reg_inp_prf == reg_long_offt); + if (max_input_offset > INT_MAX) push(reg_inp_prf); + + + if (jcp.ndims == 5) { + push(reg_out_prf); + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + mov(aux_reg_inp_d_prf, reg_inp_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + mov(aux_reg_inp_prf, aux_reg_inp_d_prf); + } + + align(16); + L(kh_label); + { + int step = 0; + int ker_prfs = 0; + for (int ki = 0; ki < kw; ki++) { + for (int ic = 0; ic < ic_block; ic++) { + int aux_kernel_offset = 0; + if (step == 0) { + for (int i = 0; i < ker_pipeline_depth; i++) { + aux_kernel_offset = get_kernel_offset(ki, ic, 0, i); + vmovups(vmm_ker(i), EVEX_compress_addr( + aux_reg_ker, aux_kernel_offset)); + } + } else if (step < num_ker_loads - ker_pipeline_depth + 1) { + int load_offset = ker_pipeline_depth - 1; + int ker_load_reg_idx + = (step + load_offset) % ker_pipeline_depth; + aux_kernel_offset + = get_kernel_offset(ki, ic, 0, load_offset); + vmovups(vmm_ker(ker_load_reg_idx), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + + bool ker_prf_inserted = false; + Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth); + int j_start = get_ow_start(ki, pad_l); + int j_end = get_ow_end(ur_w, ki, pad_r); + for (int j = j_start; j < j_end; j++) { + size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l); + auto addr = EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt, true); + vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr); + int fma_idx = step * ur_w + j; + int prf_slot_idx = fma_idx / prf_inst_spacing; + if (fma_idx % prf_inst_spacing == prf_inst_trigger) { + if (prf_ker && !ker_prf_inserted + && ker_prfs < num_ker_prfs) { + int ker_prf_offset + = jcp.typesize_in * ker_prfs * jcp.oc_block; + mic_prefetcht2(EVEX_compress_addr( + aux_reg_ker_prf, ker_prf_offset)); + ker_prf_inserted = true; + ker_prfs++; + } else if (prf_inp) { + int inp_prf_idx = prf_slot_idx - ker_prfs; + if (inp_prf_idx < num_inp_prfs) { + size_t inp_prf_stride = nstl::max(kw, stride_w); + size_t inp_prf_offset; + if (!jcp.is_1stconv) { + inp_prf_offset + = ic_block * jcp.typesize_in + * ((inp_prf_idx / kw) + * inp_prf_stride + + (inp_prf_idx % kw)); + } else { + size_t ic_prf_stride = + (size_t)jcp.typesize_in * iw * ih * id; + size_t iw_prf_stride + = jcp.typesize_in * jcp.simd_w; + inp_prf_offset = ((inp_prf_idx / ic_block) + * iw_prf_stride + + (inp_prf_idx % ic_block) + * ic_prf_stride); + } + mic_prefetcht0(EVEX_compress_addr_safe( + aux_reg_inp_prf, inp_prf_offset, + reg_long_offt)); + } + } + } + } + step++; + } + } + add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block); + if (prf_ker) + add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block); + add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + if (prf_inp) + add(aux_reg_inp_prf, + jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + add(aux_reg_inp_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + pop(reg_out_prf); + } + if (max_input_offset > INT_MAX) pop(reg_inp_prf); +} + +template<typename Vmm> +void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma_core(int ur_w, + int pad_l, int pad_r) +{ + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_oc_block = jcp.nb_oc_blocking; + Label kh_label, kd_label; + int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block + * jcp.ic_block; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * inp_mul; + + + auto input_offset = [=](int oi, int ic, int ki) { + return (size_t)jcp.typesize_in + * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) + * inp_mul + (size_t)ic + * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id)); + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + } + + if (jcp.ndims == 5) { + push(reg_out); + + mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); + mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]); + mov(aux_reg_inp_d, reg_inp); + + L(kd_label); + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_inp, aux_reg_inp_d); + mov(aux_reg_ker, aux_reg_ker_d); + } + + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_ow_start(ki, pad_l); + int jj_end = get_ow_end(ur_w, ki, pad_r); + for (int ic = 0; ic < ic_block; ic++) { + if (jcp.kernel_kind == expl_bcast) { + for (int jj = jj_start; jj < jj_end; jj++) { + size_t aux_input_offset = input_offset(jj, ic, ki); + vbroadcastss(vmm_inp(jj, nb_oc_block), + EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt)); + } + } + for (int ii = 0; ii < nb_oc_block; ii++) { + int aux_kernel_offset = jcp.typesize_in + * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block + * oc_block + ki * ic_block * oc_block + ic * oc_block); + if (jj_end - jj_start > 0) + vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker, + aux_kernel_offset)); + for (int jj = jj_start; jj < jj_end; jj++) + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(vmm_out(jj, ii), + vmm_inp(jj, nb_oc_block), vmm_wei); + else { + size_t aux_input_offset = input_offset(jj, ic, ki); + vfmadd231ps(vmm_out(jj, ii), vmm_wei, + EVEX_compress_addr_safe(aux_reg_inp, + aux_input_offset, reg_long_offt, true)); + } + } + } + } + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_inp_d, + typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block + * jcp.ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_out); + } +} + +template<typename Vmm> +void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop(int ur_w, + int pad_l, int pad_r) +{ + if (jcp.ndims == 5) push(reg_oi); + + prepare_output(ur_w); + + Label skip_compute_loop; + if (jcp.ndims == 5) { + if ((jcp.dilate_d >= jcp.id) + || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) { + mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + } + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + + if (jcp.ver == ver_4fma) + if(jcp.is_1stconv) + compute_loop_4fma_1st(ur_w, pad_l, pad_r); + else + compute_loop_4fma(ur_w, pad_l, pad_r); + else if (jcp.ver == ver_fma) + if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast) + || mayiuse(avx512_mic)) + compute_loop_fma(ur_w, pad_l, pad_r); + else + if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1) + compute_loop_fma(ur_w, pad_l, pad_r); + else + compute_loop_fma_core(ur_w, pad_l, pad_r); + else + assert(!"unknown convolution version"); + + L(skip_compute_loop); + store_output(ur_w); + if (jcp.ndims == 5) pop(reg_oi); +} + +template<typename Vmm> +void _jit_avx512_common_conv_fwd_kernel<Vmm>::generate() +{ + int iw = jcp.iw; + int ow = jcp.ow; + int ow_block = jcp.ow_block; + int nb_ow = jcp.nb_ow; + int kw = jcp.kw; + int l_pad = jcp.l_pad; + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult; + int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult; + int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult; + int out_shift = jcp.typesize_out * ur_w * jcp.oc_block; + + preamble(); + mov(reg_inp, ptr[param1 + GET_OFF(src)]); + mov(reg_out, ptr[param1 + GET_OFF(dst)]); + mov(reg_ker, ptr[param1 + GET_OFF(filt)]); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); + + int r_pad = nstl::max( + 0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1)); + int n_oi = ow / ur_w; + int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + + if (!is_ow_threading_on(jcp)) { + // ow is being processed as a whole - with left and right paddings + if (r_pad1 > 0) n_oi--; + + if (ow == ur_w) { + mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]); + mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]); + compute_loop(ur_w, l_pad, r_pad); + } else { + mov(reg_inp_prf, reg_inp); + mov(reg_out_prf, reg_out); + if (n_oi == 0) { + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, r_pad1); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + } else { + xor_(reg_oi, reg_oi); + if (l_pad > 0) { + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + inc(reg_oi); + } + if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) { + Label ow_loop_label; + L(ow_loop_label); + { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_pad1 > 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, r_pad1); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + } + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + } + } + } else { + // ow block is only processed. + // Number of block is passed as parameter owb, + // and padding processing depends on this number. + + Label end_label, last_oi_label, middle_ow_blocks_label, tail_label; + Label oi_loop_label, oi_loop_start_label, oi_loop_end_label; + + assert(ow_block % ur_w == 0); + int n_oi_not_last_ow_block = ow_block / ur_w; + // to simplify code (and general regs usage), + // size of ow block must be >= 2 * ur_w + assert(n_oi_not_last_ow_block > 1); + int n_oi_next_last_ow_block = n_oi_not_last_ow_block; + int n_oi_first_ow_block = n_oi_not_last_ow_block; + + int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w; + + // prepare right padding + bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; + bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2; + bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0; + + if (last_ow_block_padded) n_oi_last_ow_block--; + else if (first_ow_block_padded) n_oi_first_ow_block--; + else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // is that the first ow-block ? + jg(middle_ow_blocks_label, T_NEAR); + + // the first ow block, compute left padding + + mov(reg_oi, n_oi_first_ow_block); + mov(reg_inp_prf, reg_inp); + mov(reg_out_prf, reg_out); + + if (l_pad > 0) { + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + add(reg_inp_prf, inp_shift_pad); + add(reg_out_prf, out_shift); + compute_loop(ur_w, l_pad, 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + dec(reg_oi); + } + jmp(oi_loop_label, T_NEAR); + + // middle or last ow block entry + + L(middle_ow_blocks_label); + + if (l_pad > 0) { + // just to consider left padding, not compute + add(reg_inp, inp_shift_pad_second_block); + add(reg_inp_prf, inp_shift_pad_second_block); + } + + // set number of iteration for oi-loop + cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? + mov(reg_oi, n_oi_last_ow_block); + je(oi_loop_label, T_NEAR); + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + mov(reg_oi, n_oi_next_last_ow_block); + je(oi_loop_label, T_NEAR); + mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks + + // oi loop w/o padding + L(oi_loop_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + L(oi_loop_start_label); + cmp(reg_oi, 0); + jle(oi_loop_end_label, T_NEAR); + + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + dec(reg_oi); + jmp(oi_loop_start_label, T_NEAR); + L(oi_loop_end_label); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + + cmp(reg_owb, 0); // first ow-block ? + if (first_ow_block_padded) { + je(last_oi_label, T_NEAR); + } else { + je(end_label, T_NEAR); + } + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + jl(end_label, T_NEAR); + if (next_last_ow_block_padded) { + je(last_oi_label, T_NEAR); + } else { + je(end_label, T_NEAR); + } + // that is last block + if (!last_ow_block_padded) { + jmp(tail_label, T_NEAR); + } + + // last oi block with right padding + L(last_oi_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w, 0, r_pad1); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? + jl(end_label, T_NEAR); + + L(tail_label); + mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]); + if (ur_w_tail != 0) { + add(reg_inp_prf, inp_shift); + add(reg_out_prf, out_shift); + compute_loop(ur_w_tail, 0, r_pad); + } + L(end_label); + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_common_conv_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_avx512_common_conv_fwd_kernel::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, int nthreads) +{ + using namespace prop_kind; + + if (!mayiuse(avx512_common)) + return status::unimplemented; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const int regs = 28; + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + + jcp = zero<decltype(jcp)>(); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2]; + jcp.ow = dst_d.dims()[ndims-1]; + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2]; + jcp.kw = weights_d.dims()[with_groups + ndims-1]; + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + jcp.back_pad = (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); + + jcp.is_1stconv = is_1stconv(jcp); + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + + const int full_simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float); + jcp.simd_w = full_simd_w; + bool ok_to_try_xmm = true + && mayiuse(avx512_core) + && src_d.data_type() == data_type::f32 + && !jcp.is_1stconv + && !ok_to_pad_channels + && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0) + && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0); + if (ok_to_try_xmm) + jcp.simd_w = 4; + + jcp.oc_block = jcp.simd_w; + jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w; + jcp.aligned_threads = 0; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0; + if (!args_ok) + return status::unimplemented; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) { + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + if (dst_d.data_type() == data_type::s32) return status::unimplemented; + } + + auto src_tag = jcp.is_1stconv + ? pick(ndims - 3, ncw, nchw, ncdhw) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) + : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c)); + auto dst_tag = (jcp.simd_w == 4) + ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c) + : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o) + : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o) + : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o)); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + } + if (jcp.src_tag != src_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dst_tag)); + jcp.dst_tag = dst_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dst_tag); + } + if (jcp.dst_tag != dst_tag) + return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, x)); + } + + if (mayiuse(avx512_common) && + src_d.data_type() == data_type::f32 + && weights_d.data_type() == data_type::f32 + && dst_d.data_type() == data_type::f32) { + jcp.ver = ver_fma; + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + + if (jcp.is_1stconv) { + // TODO: fix & remove constraints below + bool not_for_4fma + = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad), + nstl::max(jcp.kw, jcp.kh) < 7); + bool is_dilated + = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w); + if (one_of(true, not_for_4fma, is_dilated)) + jcp.ver = ver_fma; + if (jcp.ver == ver_4fma) { + wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o) + : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o) + : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o)); + } else { + wei_tag = with_groups + ? ((jcp.simd_w == 4) + ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o) + : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)) + : ((jcp.simd_w == 4) + ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o) + : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o)); + } + } + } else { + return status::unimplemented; + } + + if (weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (jcp.is_1stconv) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } else { + // avx512_core guard - just to avoid possible regression for other archs + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } else { + for (int ur_w = regs; ur_w > 0; --ur_w) { + if (jcp.ow % ur_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + } + if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) { + jcp.ur_w = nstl::min(jcp.ow, regs); + } + } + // TODO (Tanya): currently applied to Segnet convolutions only. + // Need to try for other topologies + if (jcp.ow > 150 && jcp.ur_w < regs/2) + jcp.ur_w = regs; + + int n_oi = (jcp.ow / jcp.ur_w); + int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); + if (jcp.l_pad > 0 && r_pad > 0) + n_oi--; + + bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0 + && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)); + if (large_code_size) { + const int max_code_size = 24 * 1024; + const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw; + int mult = 1; + if (jcp.l_pad > 0) mult += 1; + if (r_pad > 0) mult += 1; + for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { + if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) { + jcp.ur_w = ur_w; + break; + } + } + } + + /* Grouped channel offset to support 'non-blocked data' format for + * convolution sizes with '(input_channel / ngroups) < simd' */ + jcp.nonblk_group_off + = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw)) ? + jcp.ic : + 1; + + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + + auto is_ow_threading_applicable = [=]() { + return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4) + && IMPLICATION(mayiuse(avx512_mic), + jcp.ver == ver_4fma + && IMPLICATION(jcp.mb != 1, + jcp.ih == 1 && jcp.kh == 1))); + }; + + if (jcp.ver == ver_4fma && !jcp.is_1stconv) { + if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8 + && jcp.oh <= 8 && jcp.ow == jcp.oh) + || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) { + if (jcp.nb_oc % 2 == 0) { + jcp.nb_oc_blocking = 2; + jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); + } + } else { + for (int i = jcp.nb_oc; i > 0; i--) + if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) { + jcp.nb_oc_blocking = i; + break; + } + } + if (jcp.ver == ver_4fma && is_ow_threading_applicable()) { + if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow + && jcp.ow != 2 * jcp.ur_w) { + jcp.nb_oc_blocking = 2; + jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking); + } + } + } + + jcp.ow_block = jcp.ow; + + auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) { + int nb_ow = div_up(jcp.ow, ow_block); + int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking); + int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow; + float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block); + float thr_eff = disbalance * (float)work_amount + / rnd_up(work_amount, nthreads); + return thr_eff; + }; + + auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) { + int res_ow_block = jcp.ow; + eff = get_thr_eff(nb_oc_blocking, res_ow_block); + if (!is_ow_threading_applicable()) + return res_ow_block; + + int L2_part = (get_cache_size(2) * 7 / 8) / typesize; + if (jcp.ver == ver_4fma) + L2_part /= 2; + int size_src_chunk = jcp.ic_block * ur_w * jcp.kh; + int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w; + int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block + * jcp.kw * jcp.kh; + int nurw_cache = (L2_part - 2 * size_wei_chunk) + / (2 * size_dst_chunk + 2 * size_src_chunk); + // current design of generate() requires ow_block >= 2 * ur_w + int ow_block_cache = ur_w * nstl::max(2, nurw_cache); + + int ow_block_thr = ow_block_cache; + eff = get_thr_eff(nb_oc_blocking, ow_block_thr); + + int max_nb_ow = div_up(jcp.ow, 2 * ur_w); + int start_nb_ow = div_up(jcp.ow, ow_block_thr); + for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) { + int ow_block + = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow); + float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f; + if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold) + break; + if (div_up(jcp.ow, ow_block) != nb_ow) + continue; + float thr_eff = get_thr_eff(nb_oc_blocking, ow_block); + float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f; + if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) { + ow_block_thr = ow_block; + eff = thr_eff; + } + eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f; + if (eff > eff_threshold) + break; + } + res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr)); + eff = get_thr_eff(nb_oc_blocking, res_ow_block); + return res_ow_block; + }; + + + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + int try_nb_oc_blocking = 2; + unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w) + * jcp.ic_block * jcp.kh * jcp.kd; + unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block + * try_nb_oc_blocking; + unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block + * jcp.oc_block * try_nb_oc_blocking * jcp.kd; + unsigned int ker_total_size = ker_inp_size + ker_out_size + + ker_wei_size; + + bool embd_bcast_condition = true + && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size) + && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192) + && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512); + + if (jcp.mb == 1) { + unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h) + * div_up(jcp.iw, jcp.stride_w) * jcp.ic; + unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw; + + // Estimate whether we need to limit the number of threads + // and calculate this number. Includes some heuristic. + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh; + int job_size_min = work_amount / nthreads; + int job_size_max = div_up(work_amount, nthreads); + int ch_max = rnd_up(jcp.oh, job_size_max); + int ch_min = (job_size_min == 0) + ? jcp.oh + : rnd_up(jcp.oh, job_size_min); + bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2 + && (jcp.oh != 8 || ch_max / jcp.oh > 1); + bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2 + && (jcp.oh != 8 || ch_min / jcp.oh > 1); + bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1) + || nthreads > oc_chunks; + if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1 + && wei_size / inp_size > 24 + && (not_aligned_max || not_aligned_min) + && eligible_case) { + // Try to find nthreads > mkldnn_get_max_threads() / 2 such + // that oc_chunks is a multiple of nthreads, or nthreads is a + // multiple of oc_chunks. Otherwise, keep default value. + // TODO: implement a task-based alternative without throttling. + jcp.aligned_threads = nthreads; + for (int i = nthreads; i > nthreads / 2; i--) { + if (oc_chunks % i == 0 || i % oc_chunks == 0) { + jcp.aligned_threads = i; + break; + } + } + } + } + + if (jcp.kw > 3 + || (jcp.stride_w == 1 && jcp.stride_h == 1 + && embd_bcast_condition) + || ((jcp.stride_w != 1 || jcp.stride_h != 1) + && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10) + && embd_bcast_condition))) + || (jcp.mb == 1 + && (jcp.ur_w >= jcp.ow || jcp.is_1stconv + || (jcp.ow <= 147 && jcp.oc <= 96)))) { + jcp.kernel_kind = embd_bcast; + jcp.ur_w = nstl::min(jcp.ow, regs); + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3 + && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0 + && IMPLICATION(jcp.is_1stconv, jcp.mb == 1) + && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) { + jcp.nb_oc_blocking = try_nb_oc_blocking; + jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); + } + } else { + jcp.kernel_kind = expl_bcast; + jcp.nb_ic_blocking = 1; + if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) { + float best_thr_eff = 0.f; + int best_nb_oc_blocking = 1; + for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) { + if (jcp.nb_oc % i == 0) { + float thr_eff; + int ur_w = nstl::min(jcp.ow, 31 / (i + 1)); + get_ow_block(i, ur_w, thr_eff); + if (thr_eff > 1.05f * best_thr_eff) { + best_nb_oc_blocking = i; + best_thr_eff = thr_eff; + } + } + } + jcp.nb_oc_blocking = best_nb_oc_blocking; + jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1)); + } + } + } + + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + args_ok = true + && jcp.l_pad <= jcp.ur_w + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) + return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + if (r_pad_no_tail > jcp.ur_w) + return status::unimplemented; + + pick_loop_order(jcp); + + jcp.nb_ic_L2 = jcp.nb_ic; + + float thr_eff; + jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff); + jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); + + const int L2_size = get_cache_size(2, true) / sizeof(float); + // Source and output data needs to fit in L2, + // leaving some space for weights and prefetching. + int h_L2 = int(((0.6f * L2_size) / jcp.simd_w + - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw) + / (jcp.stride_h * jcp.iw + jcp.ow)); + jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2)); + + if (jcp.ver == ver_4fma) { + if (!is_ow_threading_on(jcp)) { + for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic; + divf++) { + size_t l2_src + = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id; + size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking + * jcp.oh * jcp.od; + size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block + * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd; + if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { + if (jcp.kh == 3 && jcp.oh == 7) { + jcp.nb_ic_L2 = 1; + break; + } + temp_nb = (jcp.nb_ic_L2 % divf == 0 ? jcp.nb_ic_L2 / divf + : jcp.nb_ic_L2); + } else { + jcp.nb_ic_L2 = temp_nb; + break; + } + } + } else if (jcp.ic > 64) { + jcp.nb_ic_L2 = 2; /* according to performance data*/ + } + } + + return status::success; +} + +void jit_avx512_common_conv_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w) +{ + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + vpxord(zmm, zmm, zmm); + size_t aux_src_offset + = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) + * jcp.ic_block; + mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, + reg_long_offt)); + } + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w) +{ + Label no_update_label; + + mov(reg_channel, ptr[param + GET_OFF(channel)]); + cmp(reg_channel, 0); + je(no_update_label, T_NEAR); + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + size_t aux_src_offset = (size_t)typesize + * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; + vaddps(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset, + reg_long_offt)); + } + } + + L(no_update_label); + for (int k = 0; k < jcp.nb_ic_blocking; k++) { + for (int j = 0; j < ur_w; j++) { + Zmm zmm = zmm_out(j, k); + size_t aux_src_offset = (size_t)typesize + * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block; + vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset, + reg_long_offt), zmm); + mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset, + reg_long_offt)); + } + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma( + int ur_w, int l_overflow, int r_overflow) +{ + int ow = jcp.ow; + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + Label kh_label, last_iter_label, loop_end_label, kd_label; + int ker_load_number = 4; + int shift_ker_ptr = typesize * kw * oc_block * ic_block; + int shift_dst_ptr = typesize * ow * oc_block; + int ii_dpref_t0 = get_iw_start(0, l_overflow); + int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow); + + bool check_last_kh = (jcp.kh > 3); + auto kernel_offset = [=](int icb, int oc, int ki) { + int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int oc_offset = oc * jcp.oc_block; + return typesize * (blk_offset + oc_offset); + }; + auto kernel_loads = [=](int ki, int oc, int kk) { + for (int ii = 0; ii < ker_load_number; ii++) { + int aux_kernel_offset = kernel_offset(kk, oc + ii, ki); + vmovups(zmm_ker(ii), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + }; + auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) { + if (cnt1 >= ker_load_number && cnt0 >= ker_load_number + && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) { + int aux_dst_offset = typesize * ((ii_dpref_t0 + + jcp.l_pad) * oc_block + jcp.ow * oc_block); + prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + ii_dpref_t0++; + } + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + mov(aux_reg_dst_prf, reg_dst_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + mov(aux_reg_dst_d_prf, reg_dst_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_dst_prf, aux_reg_dst_d_prf); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + } + + align(16); + L(kh_label); + if (check_last_kh) { + for (int ki = 0; ki < kw; ki++) + for (int oc = 0; oc < oc_block; oc += 4) + for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { + bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1 + && ki == kw - 1 && (oc + 4) == oc_block); + + if (last_kernel_loads) { + cmp(reg_kj, 1); + je(last_iter_label, T_NEAR); + } + + kernel_loads(ki, oc, kk); + for (int ii = get_iw_start(ki, l_overflow), + prf_count_t0 = 0, prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + + if (ii % 2) { + if (prf_count_t0 < 4) { + int aux_kernel_prf; + if (last_kernel_loads) + aux_kernel_prf= kernel_offset(0, prf_count_t0 + + oc + 4 - oc_block, 0) + typesize * kw + * oc_block * ic_block; + else + aux_kernel_prf = kernel_offset(kk, oc + 4 + + prf_count_t0, ki); + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker, + aux_kernel_prf)); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(kk, oc + prf_count_t1, ki))); + prf_count_t1++; + } + } else + prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1); + } + if (last_kernel_loads) { + jmp(loop_end_label, T_NEAR); + + L(last_iter_label); + + kernel_loads(ki, oc, kk); + for (int ii = get_iw_start(ki, l_overflow), + prf_count_t0 = 0, prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + if (ii % 2) { + if (prf_count_t0 < 4) { + mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(0, prf_count_t0, 0))); + prf_count_t0++; + } else if (prf_count_t1 < 4) { + mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf, + kernel_offset(kk, oc + prf_count_t1, ki))); + prf_count_t1++; + } + } + } + L(loop_end_label); + } + } + } else { + for (int ki = 0; ki < kw; ki++) + for (int oc = 0; oc < oc_block; oc += 4) + for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) { + kernel_loads(ki, oc, kk); + + for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0; + ii < get_iw_end(ur_w, ki, r_overflow); ii++) { + int aux_dst_offset = typesize + * ((ii + jcp.l_pad - ki) * oc_block + oc); + v4fmaddps(zmm_out(ii, kk), zmm_ker(0), + EVEX_compress_addr(aux_reg_dst, aux_dst_offset)); + if ((ii % 2) && (prf_count_t1 < 4)) { + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, kernel_offset(kk, + oc + prf_count_t1, ki))); + prf_count_t1++; + } + if ( ki == 1 && oc == 0 && kk == 0) + mic_prefetcht1(EVEX_compress_addr( + aux_reg_dst_prf, aux_dst_offset)); + } + } + } + + add(aux_reg_ker, shift_ker_ptr); + sub(aux_reg_dst, shift_dst_ptr); + add(aux_reg_ker_prf, shift_ker_ptr); + sub(aux_reg_dst_prf, shift_dst_ptr); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block); + sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh *oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_src); + pop(reg_src_prf); + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma( + int ur_w, int l_overflow, int r_overflow) +{ + Label kh_label, kd_label; + int kw = jcp.kw; + int ow = jcp.ow; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int l_pad = jcp.l_pad; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + int stride_h = jcp.stride_h; + + int ker_pipeline_depth = 4; + assert(ker_reg_base_idx + ker_pipeline_depth <= 32); + assert(oc_block >= ker_pipeline_depth); + + int num_ker_loads = oc_block * kw; + int num_inp_prfs = ur_w * nstl::min(kw, stride_w) + + nstl::max(0, kw - stride_w); + int num_prfs = num_ker_loads + num_inp_prfs; + int num_fmas = num_ker_loads * ur_w / stride_w; + int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs); + int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + + mov(aux_reg_dst_prf, reg_dst_prf); + mov(aux_reg_ker_prf, reg_ker_prf); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + mov(aux_reg_dst_d_prf, reg_dst_prf); + mov(aux_reg_ker_d_prf, reg_ker_prf); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + mov(aux_reg_dst_prf, aux_reg_dst_d_prf); + mov(aux_reg_ker_prf, aux_reg_ker_d_prf); + } + + L(kh_label); { + int step = 0; + int ker_prfs = 0; + for (int ki = 0; ki < kw; ki++) { + for (int oc = 0; oc < oc_block; oc++) { + if (step == 0) { + for (int i = 0; i < ker_pipeline_depth; i++) { + int aux_kernel_offset = typesize * ((oc + i) * oc_block + + ki * ic_block * oc_block); + vmovups(zmm_ker(i), EVEX_compress_addr( + aux_reg_ker, aux_kernel_offset)); + } + } else if (step < num_ker_loads - ker_pipeline_depth + 1) { + int load_offset = ker_pipeline_depth - 1; + int ker_load_reg_idx + = (step + load_offset) % ker_pipeline_depth; + int aux_kernel_offset = typesize * ((oc + load_offset) + * oc_block + ki * ic_block * oc_block); + vmovups(zmm_ker(ker_load_reg_idx), + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + + bool ker_prf_inserted = false; + auto zmm_kernel = zmm_ker(step % ker_pipeline_depth); + + int jj_start = get_iw_start(ki, l_overflow); + int jj_end = get_iw_end(ur_w, ki, r_overflow); + assert(stride_w != 1 + || jj_start == nstl::max(0, + l_overflow - (kw - 1 - ki) * dilate_w)); + assert(stride_w != 1 + || jj_end == ur_w - nstl::max(0, + r_overflow - ki * dilate_w)); + + for (int jj = jj_start; jj < jj_end; jj += stride_w) { + assert((jj + l_pad - ki * dilate_w) % stride_w == 0); + int aux_dst_offset = typesize * + (((jj + l_pad - ki * dilate_w) + / stride_w) * jcp.oc_block + oc); + vfmadd231ps(zmm_out(jj, 0), zmm_kernel, + EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true)); + + int fma_idx = (step * ur_w + jj) / stride_w; + int prf_slot_idx = fma_idx / prf_inst_spacing; + if (fma_idx % prf_inst_spacing == prf_inst_trigger) { + if (!ker_prf_inserted && ker_prfs < num_ker_loads) { + int ker_prf_offset = typesize + * ker_prfs * jcp.oc_block; + mic_prefetcht1(EVEX_compress_addr( + aux_reg_ker_prf, ker_prf_offset)); + ker_prf_inserted = true; + ker_prfs++; + } else { + int inp_prf_idx = prf_slot_idx - ker_prfs; + if (inp_prf_idx < num_inp_prfs) { + int inp_prf_offset + = ic_block * typesize + * ((inp_prf_idx / kw) * kw + + (inp_prf_idx % kw)); + mic_prefetcht0(EVEX_compress_addr( + aux_reg_dst_prf, inp_prf_offset)); + } + } + } + } + step++; + } + } + + add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block); + sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block); + add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block); + sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block); + + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh + * oc_block * ic_block); + sub(aux_reg_dst_d_prf, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh + * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + } + + if (jcp.ndims == 5) + { + pop(reg_src); + pop(reg_src_prf); + } +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core( + int ur_w, int l_overflow, int r_overflow) +{ + int kw = jcp.kw; + int ow = jcp.ow; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int nb_ic_block = jcp.nb_ic_blocking; + Label kh_label, kd_label; + + int shift_ker_ptr = typesize * kw * oc_block * ic_block; + int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block; + + auto output_offset = [=](int oi, int oc, int ki) { + return typesize * + (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc); + }; + auto kernel_offset = [=](int icb, int oc, int ki) { + int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki; + int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block; + int oc_offset = oc * jcp.oc_block; + return typesize * (blk_offset + oc_offset); + }; + + if (one_of(jcp.ndims, 3, 4)) { + mov(aux_reg_dst, reg_dst); + mov(aux_reg_ker, reg_ker); + } + + if (jcp.ndims == 5) { + push(reg_src_prf); + push(reg_src); + + mov(reg_ki, ptr[param + GET_OFF(kd_padding)]); + mov(aux_reg_dst_d, reg_dst); + mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]); + + L(kd_label); + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + } else { + mov(reg_kj, reg_kh); + } + + if (jcp.ndims == 5) { + mov(aux_reg_dst, aux_reg_dst_d); + mov(aux_reg_ker, aux_reg_ker_d); + } + + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_iw_start(ki, l_overflow); + int jj_end = get_iw_end(ur_w, ki, r_overflow); + for (int oc = 0; oc < oc_block; oc++) { + if (jcp.kernel_kind == expl_bcast) { + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_output_offset = output_offset(jj, oc, ki); + vbroadcastss(zmm_inp(jj, nb_ic_block), + ptr[aux_reg_dst + aux_output_offset]); + } + } + for (int ii = 0; ii < nb_ic_block; ii++) { + int aux_kernel_offset = kernel_offset(ii, oc, ki); + if (jj_end - jj_start > 0) + vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker, + aux_kernel_offset)); + for (int jj = jj_start; jj < jj_end; jj += stride_w) + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(zmm_out(jj, ii), + zmm_inp(jj, nb_ic_block), zmm_wei); + else + vfmadd231ps(zmm_out(jj, ii), zmm_wei, + EVEX_compress_addr(aux_reg_dst, + output_offset(jj, oc, ki), true)); + } + } + } + add(aux_reg_ker, shift_ker_ptr); + sub(aux_reg_dst, shift_dst_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + sub(aux_reg_dst_d, + typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block); + add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block); + + dec(reg_ki); + cmp(reg_ki, 0); + jg(kd_label, T_NEAR); + + pop(reg_src); + pop(reg_src_prf); + } +} + +inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop( + int ur_w, int l_overflow, int r_overflow) +{ + if (jcp.ndims == 5) push(reg_oi); + + prepare_output(ur_w); + + Label skip_compute_loop; + if (jcp.ndims == 5) { + mov(reg_kj, ptr[param + GET_OFF(kd_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + } + mov(reg_kj, ptr[param + GET_OFF(kh_padding)]); + cmp(reg_kj, 0); + je(skip_compute_loop, T_NEAR); + + if (jcp.ver == ver_4fma) + compute_loop_4fma(ur_w, l_overflow, r_overflow); + else if (jcp.ver == ver_fma) + if (mayiuse(avx512_mic)) + compute_loop_fma(ur_w, l_overflow, r_overflow); + else + if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1) + compute_loop_fma(ur_w, l_overflow, r_overflow); + else + compute_loop_fma_core(ur_w, l_overflow, r_overflow); + else + assert("!unknown convolution version"); + + L(skip_compute_loop); + store_output(ur_w); + if (jcp.ndims == 5) pop(reg_oi); +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::generate() +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ur_w = jcp.ur_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int ur_w_tail = jcp.ur_w_tail; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block; + int src_shift = jcp.typesize_out * ur_w * oc_block; + + preamble(); + + mov(reg_src, ptr[param + GET_OFF(src)]); + mov(reg_dst, ptr[param + GET_OFF(dst)]); + mov(reg_ker, ptr[param + GET_OFF(filt)]); + + mov(reg_kh, ptr[param + GET_OFF(kh_padding)]); + mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]); + mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]); + mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]); + + int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w); + int r_overflow = nstl::max(0, ((kw - 1) * dilate_w + - nstl::max(0, jcp.r_pad)) / stride_w); + int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w + - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w); + + int n_oi = iw / ur_w; + if (r_overflow1 > 0) n_oi--; + + if (ur_w == iw) { + compute_loop(ur_w, l_overflow, r_overflow); + } else if (n_oi == 0) { + compute_loop(ur_w, l_overflow, r_overflow1); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + if (ur_w_tail != 0) + compute_loop(ur_w_tail, 0, r_overflow); + } else { + xor_(reg_oi, reg_oi); + if (l_overflow > 0) { + compute_loop(ur_w, l_overflow, 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + + inc(reg_oi); + } + if ((l_overflow <= 0 && n_oi > 0) + || (l_overflow > 0 && n_oi > 1)) { + Label ow_loop_label; + L(ow_loop_label); { + compute_loop(ur_w, 0, 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_overflow1 > 0) { + compute_loop(ur_w, 0, r_overflow1); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + add(reg_src_prf, src_shift); + add(reg_dst_prf, dst_shift); + } + if (ur_w_tail != 0) { + compute_loop(ur_w_tail, 0, r_overflow); + } + } + + postamble(); +} + +status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf( + jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + if (!mayiuse(avx512_common)) return status::unimplemented; + + jcp = zero<decltype(jcp)>(); + + jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float); + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + int ndims = diff_src_d.ndims(); + + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2]; + jcp.iw = diff_src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + if ((jcp.dilate_w != 0 && jcp.stride_w != 1) + || (jcp.dilate_d != 0 && jcp.stride_d != 1) + || (jcp.dilate_h != 0 && jcp.stride_h != 1)) + return status::unimplemented; + + jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1); + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + jcp.back_pad = (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1); + + jcp.aligned_threads = 0; + + jcp.is_1stconv = false; + + jcp.oc_block = jcp.simd_w; + jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && diff_src_d.data_type() == data_type::f32; + + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } + + auto dat_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i) + : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i); + jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) + return status::unimplemented; + + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + jcp.ur_w = jcp.stride_w; + + int regs = 28; + if (jcp.iw <= regs) + jcp.ur_w = jcp.iw; + else { + for (int ur_w = regs; ur_w > 0; --ur_w) + if (ur_w % jcp.stride_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - jcp.l_pad) / jcp.stride_w); + int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w); + int n_oi = jcp.iw / jcp.ur_w; + if (r_overflow1 > 0) n_oi--; + + if (mayiuse(avx512_common) + && diff_dst_d.data_type() == data_type::f32 + && weights_d.data_type() == data_type::f32 + && diff_src_d.data_type() == data_type::f32) { + jcp.ver = ver_fma; + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + if (mayiuse(avx512_mic_4ops) + && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) { + jcp.ver = ver_4fma; + } + } else { + return status::unimplemented; + } + + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) + && jcp.ver != ver_fma) + return status::unimplemented; + + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (jcp.ver == ver_4fma) { + if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) { + jcp.nb_ic_blocking = 2; + } else { + for (int i = jcp.nb_ic; i > 0; i--) + if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) { + jcp.nb_ic_blocking = i; + break; + } + } + } + + jcp.loop_order = loop_gnc; + + bool large_code_size = (jcp.ur_w != jcp.ow) + && ((l_overflow <= 0 && n_oi > 0) ||(l_overflow > 0 && n_oi > 1)) + && (r_overflow1 > 0) && (l_overflow > 0); + if (large_code_size) { + const int max_code_size = 24 * 1024; + const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw; + int mult = 1; + if (l_overflow > 0) mult += 1; + if (r_overflow1 > 0) mult += 1; + for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) { + if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2 + < max_code_size) { + if (ur_w % jcp.stride_w == 0) { + jcp.ur_w = ur_w; + break; + } + } + } + } + + if (jcp.ver == ver_fma && mayiuse(avx512_core)) { + int try_nb_ic_blocking = 2; + unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block + * try_nb_ic_blocking * jcp.kh; + unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block; + unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block + * jcp.oc_block * try_nb_ic_blocking; + unsigned int ker_total_size = ker_inp_size + ker_out_size + + ker_wei_size; + if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8) + || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13)) + || ker_total_size > L1_cache_size ))) + || jcp.stride_h > 1 || jcp.stride_d > 1) { + jcp.kernel_kind = embd_bcast; + jcp.ur_w = nstl::min(jcp.iw, regs); + jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1; + if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size + && jcp.ow > 8)) && jcp.stride_h == 1) + if (jcp.nb_ic % try_nb_ic_blocking == 0) { + jcp.nb_ic_blocking = try_nb_ic_blocking; + jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1); + if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw; + } + } else { + jcp.kernel_kind = expl_bcast; + jcp.nb_oc_blocking = 1; + jcp.nb_ic_blocking = 4; + if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic; + if (jcp.nb_ic % jcp.nb_ic_blocking != 0) + for (int i = jcp.nb_ic_blocking; i > 0; i--) + if (jcp.nb_ic % i == 0) { + jcp.nb_ic_blocking = i; + break; + } + jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1); + if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw; + } + } + jcp.ur_w_tail = jcp.iw % jcp.ur_w; + + if (l_overflow * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w); + if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w) + return status::unimplemented; + if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0)) + return status::unimplemented; + + pick_loop_order(jcp); + + jcp.nb_oc_L2 = jcp.nb_oc; + if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) { + for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc; + divf++) { + size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih + * jcp.id; + size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od; + size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh + * jcp.kd * jcp.nb_ic_blocking * temp_nb; + if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) { + if (jcp.kh == 3 && jcp.ih == 7) { + jcp.nb_oc_L2 = 1; + break; + } + temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf + : jcp.nb_oc_L2); + } else { + jcp.nb_oc_L2 = temp_nb; + break; + } + } + } + + args_ok = true + && jcp.ic <= diff_src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + return status::success; +} + +void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28; + +void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers() +{ + Label kd_comeback_label; + + /* 'depth' loop count bound by 'kd_work_size' */ + mov(kj, reg_kd_count); + L(kd_comeback_label); { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + sub(reg_input, + jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult); + sub(reg_kernel, + jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kd_comeback_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers() +{ + Label kh_comeback_label, kd_comeback_label; + mov(kj, reg_kh); + L(kh_comeback_label); { + int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult); + sub(reg_kernel, + jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_comeback_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) + vmovups(Zmm(i_kw * ic_block_step + i_ic), + EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block + + i_ic) * jcp.oc_block + kernel_offset)); + + for (int i_ur = 0; i_ur < ur_w; i_ur++) { + if (i_ur == 0) { + vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 0) + * oc_block + output_offset)); + if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block + + output_offset)); + if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block + + output_offset)); + if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block + + output_offset)); + } else if (i_ur + 3 < ur_w) + vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4), + EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block + + output_offset)); + + for (int i_kw = 0; i_kw < kw; i_kw++) { + int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1); + if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w + + (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue; + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + const size_t i_offset = (size_t)input_offset + + (size_t)typesize * (jcp.ver == ver_4fma + ? (i_iw - pad_l + i_ic * jcp.tr_iw) + : (jcp.is_1stconv + ? (i_iw - pad_l) + (size_t)i_ic + * ((size_t)jcp.ih*jcp.iw*jcp.id) + : (i_iw - pad_l) * ic_block + i_ic)); + vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic), + Zmm(kw * ic_block_step + i_ur % 4), + EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt, + true)); + } + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) + vmovups(EVEX_compress_addr(reg_kernel, typesize + * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset), + Zmm(i_kw * ic_block_step + i_ic)); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + // TODO: add prefetches to fma version as well + + assert(jcp.ver == ver_4fma); + + int kw = jcp.kw; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + auto zmm_ker = [=](int i_kw, int i_ic) { + return Zmm(i_kw * ic_block_step + i_ic); + }; + + auto ker_addr = [=](int i_kw, int i_ic) { + size_t local_offset + = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block; + return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset); + }; + + auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0) { + int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1); + int local_offset = jcp.typesize_in * (i_iw + i_ic * stride); + return EVEX_compress_addr(reg_input, + local_offset + input_offset + extra_offset); + }; + + auto zmm_out = [=](int i_iw) { + // TODO: move reg calc to global member funcs + const int out_zmm_base_idx = 28; + return Zmm(out_zmm_base_idx + i_iw % 4); + }; + + auto out_addr = [=](int i_ur) { + return EVEX_compress_addr(reg_output, + jcp.typesize_in * i_ur * oc_block + output_offset); + }; + + auto pf_callback = [=](int i_ur, int i_kw, int i_ic) { + assert(i_ur % 4 == 0); + if (i_ur == 0) + prefetcht1(ker_addr(i_kw, i_ic)); + if (i_ur + 4 >= ur_w) + prefetcht0(ker_addr(i_kw, i_ic)); + + const ptrdiff_t next_input_block_offset + = jcp.typesize_in * ic_block_step * jcp.tr_iw; + if (i_ur % 16 == 4 && i_kw == 0) { + if (i_ur + 16 < ur_w) + prefetcht0(inp_addr(i_ur + 16, i_ic)); + else + prefetcht0(inp_addr(0, i_ic, next_input_block_offset)); + } + if (i_ur % 16 == 4 && i_kw == 1) { + if (input_wraparound) + prefetcht1(inp_addr(i_ur, i_ic, -input_offset)); + else + prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset)); + } + }; + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + auto zmm = zmm_ker(i_kw, i_ic); + vpxord(zmm, zmm, zmm); + } + + for (int i_ur = 0; i_ur < ur_w; i_ur += 4) { + + for (int i = 0; i < 4; i++) { + auto zmm = zmm_out(i_ur + i); + if (i_ur + i < ur_w) + vmovups(zmm, out_addr(i_ur + i)); + else + vpxord(zmm, zmm, zmm); + prefetcht0(out_addr(i_ur + i + 4)); + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + int i_iw = i_ur + i_kw; + v4fmaddps(zmm_ker(i_kw, i_ic), + zmm_out(i_ur), inp_addr(i_iw, i_ic)); + pf_callback(i_ur, i_kw, i_ic); + } + } + + for (int i_kw = 0; i_kw < kw; i_kw++) + for (int i_ic = 0; i_ic < ic_block_step; i_ic++) { + auto addr = ker_addr(i_kw, i_ic); + auto zmm = zmm_ker(i_kw, i_ic); + vaddps(zmm, zmm, addr); + vmovups(addr, zmm); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step( + int ur_w, int pad_l, int pad_r, + int ic_block_step, int input_offset, int kernel_offset, + int output_offset, bool input_wraparound) +{ + if (jcp.ver == ver_4fma) + compute_ic_block_step_4fma(ur_w, pad_l, pad_r, + ic_block_step, input_offset, kernel_offset, output_offset, + input_wraparound); + else if (jcp.ver == ver_fma) + compute_ic_block_step_fma(ur_w, pad_l, pad_r, + ic_block_step, input_offset, kernel_offset, output_offset, + input_wraparound); + else + assert(!"unknown convolution version"); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_unroll_ow_icblock( + int ic_block_step, int max_ur_w) +{ + UNUSED(max_ur_w); + + Label kh_label, kd_label; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int inp_mul = !jcp.is_1stconv ? ic_block : 1; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + int ow = jcp.ow; + + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.l_pad; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); + { + for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) { + const int input_offset = jcp.typesize_in + * (jcp.ver == ver_4fma ? i_b_ic * iw : i_b_ic); + compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step, + input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0, + i_b_ic + ic_block_step >= jcp.ic_block); + } + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul); + add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + + if (jcp.ndims == 5) { + add(aux_reg_input, + jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_unroll_ow( + int ic_block_step, int max_ur_w) +{ + Label kh_label, ic_block_label, kd_label; + + UNUSED(max_ur_w); + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + int ow = jcp.ow; + + int r_pad = nstl::max(0, + (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.l_pad; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); + { + xor_(b_ic, b_ic); + L(ic_block_label); { + compute_ic_block_step(ow, l_pad, r_pad, ic_block_step, + 0, 0, 0); + size_t inp_icblk_stride = jcp.is_1stconv + ? (size_t)jcp.ih * jcp.iw * jcp.id + : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); + size_t input_offset + = inp_icblk_stride * jcp.typesize_in * ic_block_step; + safe_add(reg_input, input_offset, reg_long_offt); + add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_label, T_NEAR); + } + + if (jcp.is_1stconv) { + size_t input_offset + = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, input_offset, reg_long_offt); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); + } else if (jcp.ver != ver_4fma) { + add(reg_input, jcp.typesize_in + * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block); + } + add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih + * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_common( + int ic_block_step, int max_ur_w) +{ + Label kh_label, ic_block_label, ow_block_label, kd_label; + + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + + int ow = jcp.ow; + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + int l_pad = jcp.ver == ver_4fma ? 0 : jcp.l_pad; + + int ur_w = nstl::min(ow, max_ur_w); + int ur_w_trips = ow / ur_w; + int ur_w_tail = ow % ur_w; + if ((ur_w_tail == 0 && r_pad != 0) + || r_pad >= ur_w_tail) { + if (ur_w_trips > 1) { + ur_w_tail += ur_w; + ur_w_trips--; + } else { + ur_w_tail += (ur_w - ur_w / 2); + ur_w = ur_w / 2; + } + } + + int inp_mult = (jcp.is_1stconv || jcp.ver == ver_4fma) ? 1 : ic_block; + int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult; + int output_comeback = ur_w_trips * ur_w * oc_block; + + if (jcp.ndims == 5) { + L(kd_label); + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + } + + mov(kj, reg_kh); + L(kh_label); { + xor_(b_ic, b_ic); + L(ic_block_label); { + if (l_pad != 0) { + ur_w_trips--; + compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0); + add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad) + * inp_mult); + add(reg_output, jcp.typesize_in * ur_w * oc_block); + } + + if (ur_w_trips > 0) { + xor_(reg_ur_w_trips, reg_ur_w_trips); + L(ow_block_label); { + compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0); + add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w + * inp_mult); + add(reg_output, jcp.typesize_in * ur_w * oc_block); + + inc(reg_ur_w_trips); + cmp(reg_ur_w_trips, ur_w_trips); + jl(ow_block_label, T_NEAR); + } + } + + if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad, + ic_block_step, 0, 0, 0); + + sub(reg_input, jcp.typesize_in * input_comeback); + sub(reg_output, jcp.typesize_in * output_comeback); + int inp_icblk_stride = jcp.is_1stconv + ? jcp.ih * jcp.iw * jcp.id + : (jcp.ver == ver_4fma ? jcp.tr_iw : 1); + size_t input_offset + = inp_icblk_stride * jcp.typesize_in * ic_block_step; + safe_add(reg_input, input_offset, reg_long_offt); + add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block); + + add(b_ic, ic_block_step); + cmp(b_ic, jcp.ic_block); + jl(ic_block_label, T_NEAR); + } + if (jcp.is_1stconv) { + size_t input_offset + = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block; + safe_sub(reg_input, input_offset, reg_long_offt); + add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw); + } else if (jcp.ver != ver_4fma) { + add(reg_input, jcp.typesize_in + * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block); + } + add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block); + dec(kj); + cmp(kj, 0); + jg(kh_label, T_NEAR); + } + if (jcp.ndims == 5) { + add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih + * jcp.iw * (jcp.is_1stconv ? 1 : ic_block)); + add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block + * oc_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_step_disp() +{ + int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2); + if (jcp.is_1stconv) { + bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0); + ic_block_step + = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1; + } + + bool too_large_to_unroll + = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1) + && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1); + + int ow = jcp.ow; + if (jcp.ndims == 5) { + /* NOTE: reg_kd_count = aux_reg_input = r12. The following order of + * 'movs' must be guaranteed. */ + mov(ki, reg_kd_count); + push(reg_kd_count); + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + } + + if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) + compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w); + else if (ow <= max_ur_w) + compute_oh_step_unroll_ow(ic_block_step, max_ur_w); + else + compute_oh_step_common(ic_block_step, max_ur_w); + + if (jcp.ndims == 5) { + mov(reg_input, aux_reg_input); + mov(reg_kernel, aux_reg_kernel); + pop(reg_kd_count); + od_step_comeback_pointers(); + } else { + oh_step_comeback_pointers(); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel() +{ + Label skip_zeroing, zeroing_loop; + + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + cmp(reg_tmp, 0); + jz(skip_zeroing, T_NEAR); + + Zmm zero = Zmm(0); + vpxord(zero, zero, zero); + xor_(reg_tmp, reg_tmp); + L(zeroing_loop); { + assert(jcp.oc_block * jcp.typesize_out + == cpu_isa_traits<avx512_common>::vlen); + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block + * jcp.typesize_out], zero); + add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out); + cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd + * jcp.typesize_out); + jnz(zeroing_loop); + } + + L(skip_zeroing); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel() +{ + Label skip_bias, bias_loop, skip_load_bias; + + mov(reg_tmp, ptr[param + GET_OFF(flags)]); + test(reg_tmp,reg_tmp); + jne(skip_bias, T_NEAR); + + mov(reg_bias, ptr[param + GET_OFF(bias)]); + mov(reg_output, ptr[param + GET_OFF(dst)]); + vpxord(Zmm(1), Zmm(1), Zmm(1)); + + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + cmp(reg_tmp, 0); + jne(skip_load_bias, T_NEAR); + vmovups(Zmm(1), ptr[reg_bias]); + + L(skip_load_bias); + + mov(reg_oi, ptr[param + GET_OFF(d_worksize)]); + sub(reg_oi, ptr[param + GET_OFF(d_index)]); + mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out); + imul(reg_oi, reg_tmp); + + xor_(reg_tmp, reg_tmp); + L(bias_loop); { + vmovups(Zmm(0), ptr[reg_output + reg_tmp]); + vaddps(Zmm(1), Zmm(1), Zmm(0)); + add(reg_tmp, jcp.oc_block * jcp.typesize_out); + cmp(reg_tmp, reg_oi); + jl(bias_loop); + } + vmovups(EVEX_compress_addr(reg_bias,0), Zmm(1)); + + L(skip_bias); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32 + ::compute_oh_loop_common() +{ + int b_pad = jcp.b_pad; + int t_pad = jcp.t_pad; + bool is_dilated = jcp.dilate_h != 0; + int dilate_h = jcp.dilate_h + 1; + int stride_h = jcp.stride_h; + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label, + oh_bpad_label, oh_bpad_label_end, od_label, od_label_end, + oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end; + + int ow = jcp.ow; + + mov(reg_kh, jcp.kh); + xor_(reg_ih_count, reg_ih_count); + xor_(reg_oj, reg_oj); + /* Compute 'top' edge */ + if (t_pad > 0) { + const int kh_range = 1 + (jcp.kh - 1) * dilate_h; + const int overflow + = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h)); + const int underflow = div_up(t_pad, dilate_h); + const int initial_inp_ker_overlap = jcp.kh - overflow - underflow; + mov(reg_kh, initial_inp_ker_overlap); + add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block + * jcp.oc_block); + // generate loop to process kernel while it remains within t_pad + ih + if (kh_range < t_pad + jcp.ih) { + if (is_dilated) { + const int tail = t_pad % dilate_h; + const int shift = tail == 0 ? 0 : dilate_h - tail; + mov(reg_tmp, shift); + if (tail != 0) + add(reg_input, jcp.typesize_in * shift * iw * inp_mult); + } + L(oh_tpad_label); { + compute_oh_step_disp(); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + if (is_dilated) { + inc(reg_tmp); + cmp(reg_tmp, dilate_h); + jl(oh_dilate_label_shift, T_NEAR); + // unshift input as new kernel element enters + sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult); + xor_(reg_tmp, reg_tmp); + } + // kernel overlap only changes when (t_pad + oj) % dilate_h == 0 + sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw + * jcp.ic_block * jcp.oc_block); + add(reg_kh, stride_h); + if (is_dilated) { + jmp(oh_dilate_label_noshift, T_NEAR); + L(oh_dilate_label_shift); + // shift input as old kernel element progresses + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + L(oh_dilate_label_noshift); + } + inc(reg_oj); + add(reg_ih_count, stride_h); + + // final number of kernel elements that overlap with input + const int final_inp_ker_overlap + = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h)); + cmp(reg_kh, final_inp_ker_overlap); + jl(oh_tpad_label, T_NEAR); + } + } + // need second loop to process kernel if it is larger than the input + // (does not apply to dilations as they must have unit stride) + if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h : + t_pad % stride_h)) { + assert(!is_dilated); + mov(reg_kh, jcp.ih); + L(oh_tpad_tail_label); { + compute_oh_step_disp(); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw + * jcp.ic_block * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h)); + jl(oh_tpad_tail_label, T_NEAR); + } + } + // correct any excess shifts to kernel and input + // (does not apply to dilations as they must have unit stride, + // kernel must fit inside input, and padding is smaller than input) + if (t_pad <= jcp.oh * stride_h) { + // kernel has moved beyond padding (adjust for stride effects) + if (t_pad % stride_h != 0) { + assert(!is_dilated); + int inp_corr = stride_h - t_pad % stride_h; + add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw + * jcp.ic_block * jcp.oc_block); + add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult); + } + } else { + // kernel still overlaps padding (complete reset) + assert(!is_dilated); + sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h) + * jcp.kw * jcp.ic_block * jcp.oc_block); + } + } + + cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); + jge(oh_label_end, T_NEAR); + cmp(reg_oj, jcp.oh); + jge(oh_label, T_NEAR); + + /* Compute middle block(s) */ + mov(reg_kh, jcp.kh); + L(oh_label); { + compute_oh_step_disp(); + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + + inc(reg_oj); + add(reg_ih_count, stride_h); + + cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h); + jge(oh_label_end, T_NEAR); + + cmp(reg_oj, jcp.oh); + jl(oh_label, T_NEAR); + } + L(oh_label_end); + + /* Compute bottom edge */ + if (b_pad > 0) { + cmp(reg_oj, jcp.oh); + jge(oh_bpad_label_end, T_NEAR); + + if (is_dilated) { + mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations + mov(reg_tmp, 0); + } else { + mov(reg_kh, jcp.ihp - b_pad); + sub(reg_kh, reg_ih_count); + } + L(oh_bpad_label); + { + compute_oh_step_disp(); + add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult); + add(reg_output, jcp.typesize_in * ow * jcp.oc_block); + if (is_dilated) { + inc(reg_tmp); + cmp(reg_tmp, dilate_h); + jl(oh_dilate_label_end, T_NEAR); + xor_(reg_tmp, reg_tmp); + } + sub(reg_kh, stride_h); + cmp(reg_kh, 0); + jle(oh_bpad_label_end, T_NEAR); + if (is_dilated) + L(oh_dilate_label_end); + + inc(reg_oj); + cmp(reg_oj, jcp.oh); + jl(oh_bpad_label, T_NEAR); + } + L(oh_bpad_label_end); + } +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_d_loop_common() { + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw; + int ow = jcp.ow; + const int input_backpad_overlap + = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d); + + const size_t filter_shift + = jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block; + const size_t input_shift = jcp.typesize_in * jcp.ih * iw * inp_mult; + const size_t output_shift = jcp.typesize_in * jcp.oh * ow * jcp.oc_block; + + Label d_loop_label, loop_end_label, common_block_label, fpad_end_label, + backpad_end_label, backpad_label; + + if (jcp.with_bias) bias_kernel(); + + /* initially offset 'kd' by f_pad */ + add(reg_kernel, ptr[param + GET_OFF(kd_offset)]); + + mov(reg_input_d, ptr[param + GET_OFF(src)]); + mov(reg_output_d, ptr[param + GET_OFF(dst)]); + mov(reg_d_index, ptr[param + GET_OFF(d_index)]); + mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]); + + cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]); + jge(loop_end_label, T_NEAR); + + L(d_loop_label); + + mov(reg_input, reg_input_d); + mov(reg_output, reg_output_d); + + push(reg_input_d); + push(reg_output_d); + push(reg_d_index); + + compute_oh_loop_common(); + + pop(reg_d_index); + pop(reg_output_d); + pop(reg_input_d); + + /* Compute 'front' edge */ + if (jcp.f_pad > 0) { + + /* Check if within fpad region */ + cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d)); + jge(fpad_end_label, T_NEAR); + + /* Fpad steps */ + sub(reg_kernel, filter_shift * jcp.stride_d); + add(reg_kd_count, jcp.stride_d); + + /* Final number of kernel elements that overlap with input */ + const int inp_ker_overlap = nstl::min(jcp.kd, jcp.id); + cmp(reg_kd_count, inp_ker_overlap); + jl(common_block_label, T_NEAR); + + /* Correct any excess shifts to kernel and input */ + if (jcp.f_pad <= jcp.od * jcp.stride_d) { + /* Filter has moved beyond padding (adjust for stride effects) */ + if (jcp.f_pad % jcp.stride_d != 0) { + int inp_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d; + add(reg_kernel, filter_shift * inp_corr); + add(reg_input_d, input_shift * inp_corr); + } + } else { + /* Filter still overlaps padding (complete reset) */ + sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift); + } + + /* Apply correction */ + mov(reg_kd_count, jcp.kd); + jmp(common_block_label); + + L(fpad_end_label); + } + + /* Compute bottom edge */ + if (jcp.back_pad > 0) { + + /* Check if within back_pad region */ + cmp(reg_d_index, input_backpad_overlap - 1); + jl(backpad_end_label, T_NEAR); + jg(backpad_label, T_NEAR); + + /* Execute overlap correction between the filter and the initial + * back_pad region. */ + mov(reg_kd_count, + jcp.id + jcp.f_pad - input_backpad_overlap * jcp.stride_d); + jmp(backpad_end_label, T_NEAR); + + L(backpad_label); + sub(reg_kd_count, jcp.stride_d); + cmp(reg_kd_count, 0); + jle(loop_end_label, T_NEAR); + + L(backpad_end_label); + } + + /* Compute middle block */ + add(reg_input_d, input_shift * jcp.stride_d); + + /* Execute common block and loop */ + L(common_block_label); + add(reg_output_d, output_shift); + inc(reg_d_index); + cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]); + jl(d_loop_label, T_NEAR); + + L(loop_end_label); +} + +bool jit_avx512_common_conv_bwd_weights_kernel_f32::compute_full_spat_loop() { + // FIXME: use register mapping from the class declaration + bool ok = jcp.ver == ver_4fma + && everyone_is(0, jcp.dilate_h, jcp.dilate_w) + && everyone_is(1, jcp.stride_h, jcp.stride_w); + if (!ok) return false; + if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2) + return false; + + // General code layout: + // + // Blocking over OH -- top level + // (Reduces L2 pressure; not very useful right now) + // Loop over all KHxKW kernel -- emit_kh_kw_loop() + // Loop over OH block -- emit_h_loop() + // Loop over OW blocks -- emit_fma_block() + // (Supports both fully unrolled and partially unrolled versions to + // reduce code size) + // Loop over OW block -- emit_fma_step() + + int max_working_set_size = 128 * 1024; + int pad_ow = jcp.ow; + + int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in; + int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in; + int row_size = inp_row_size + out_row_size; + + int h_block_size = jcp.oh; + int working_set_size = row_size * h_block_size; + + if (working_set_size > max_working_set_size) { + int opt_working_set_size = 48 * 1024; + assert(opt_working_set_size < max_working_set_size); + + while (working_set_size > opt_working_set_size) { + for (int i = 2; i <= h_block_size; i++) + if (i == h_block_size) + h_block_size = h_block_size / 2; + else if (h_block_size % i == 0) { + h_block_size = h_block_size / i; + break; + } + working_set_size = row_size * h_block_size; + + if (h_block_size == 1 && working_set_size > opt_working_set_size) + return false; + } + } + + // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below) + if (h_block_size < nstl::max(1, jcp.t_pad) + || jcp.b_pad > (jcp.oh % h_block_size == 0 ? h_block_size + : jcp.oh % h_block_size)) + return false; + + // check that we can use simple arithmetic for prefetch address + // calculations + // TODO: we need some traits for this check (Roma) + int cache_line_size = 64; + assert(jcp.ic_block * typesize == 64); + assert(jcp.oc_block * typesize == 64); + + int num_inp_l2_pfs = jcp.tr_iw * h_block_size; + int avg_h_loop_len = h_block_size; + int num_inp_l2_pfs_per_fma_block + = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); + int num_out_l2_pfs = pad_ow * h_block_size; + int num_out_l2_pfs_per_fma_block + = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh); + + Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors + Reg64 reg_kh = rax; + Reg64 reg_kw = rbx; + Reg64 reg_tmp = abi_not_param1; + Reg32 reg_tmp_w = reg_tmp.cvt32(); + Reg64 reg_ohs = rdx; + Reg64 reg_ihs = rsi; + Reg64 reg_h = r8; + Reg64 reg_i = r9; + Reg64 reg_j = r10; + + Reg64 reg_inp = r13; + Reg64 reg_out = r14; + Reg64 reg_ker = r15; + + Reg64 reg_inp_pf_l1 = rbp; + + Reg64 reg_inp_pf_l2 = r11; + Reg64 reg_out_pf_l2 = r12; + + Xmm reg_inp_pf_save = xmm17; + Xmm reg_out_pf_save = xmm18; + + Reg64 reg_inp_save = abi_param1; + Reg64 reg_out_save = reg_tmp; + + auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); }; + auto zmm_ker = [&](int ic1) { return Zmm(ic1); }; + auto inp_addr = [&](int oi, int ic1) { + return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in]; + }; + auto out_addr = [&](int oi, int oj = 0) { + assert(jcp.ver == ver_4fma); + return ptr[reg_out + + ((oi + oj * jcp.ow) * jcp.oc_block) * jcp.typesize_in]; + }; + auto ker_addr = [&](int ic1) { + return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out]; + }; + + auto emit_block = [&](int h_block_size, + bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row) + { + // TODO: add an fma version (Roma) + auto pad_ow = jcp.ow; + + int ow4u = rnd_up(pad_ow, 4); + int def_step_size = 16; + + bool has_w_tail = (pad_ow % def_step_size != 0 + || pad_ow % 4 != 0); + bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail; + + auto emit_step = [&](int ur_ow, + int num_inp_l1_pfs_per_fma_step, + int num_inp_l2_pfs_per_fma_step, + int num_out_l2_pfs_per_fma_step, bool is_w_tail) + { + bool block_wraparound = is_w_tail && is_last_row; + + assert(ur_ow % 4 == 0); + int tail_size = ow4u % ur_ow; + int this_ur_ow + = (is_w_tail && tail_size) ? tail_size : ur_ow; + int ow_last_chunk4 = pad_ow % 4; + int ow_zero_tail4 = ow_last_chunk4 + ? 4 - ow_last_chunk4 : 0; + + auto emit_out_pf = [&](int oi) { +#if 1 + if (oi + def_step_size < ur_ow || !block_wraparound) + mic_prefetcht0(ptr[reg_out + + ((def_step_size + oi) + * jcp.oc_block * jcp.typesize_in)]); + else { + assert(block_wraparound); + assert(oi + def_step_size >= ur_ow); + mic_prefetcht0(ptr[reg_out_save + + ((oi + def_step_size - ur_ow) + * jcp.oc_block * jcp.typesize_in)]); + } +#else + // XXX: This is an alternative prefetching strategy that + // always prefetches the next row. Keeping it here for + // future experiments (Roma) + if (!block_wraparound) + mic_prefetcht0(ptr[reg_out + + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]); + else + mic_prefetcht0(ptr[reg_out + reg_ohs + - ((h_block_size - 1) * jcp.ow + - oi) * jcp.oc_block * jcp.typesize_in]); +#endif + if (oi < num_out_l2_pfs_per_fma_step) + mic_prefetcht1(ptr[reg_out_pf_l2 + + oi * jcp.oc_block * jcp.typesize_in]); + }; + + auto emit_inp_pf = [&](int oi4, int ic1) { + int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block; + int num_pf_slots = jcp.ic_block * ur_ow / 4; + + int num_pfs = num_inp_l1_pfs_per_fma_step + + num_inp_l2_pfs_per_fma_step; + int pf_freq = nstl::max(1, num_pf_slots / num_pfs); + + if (pf_slot_idx % pf_freq) + return; + + int pf_idx = pf_slot_idx / pf_freq; + + if (pf_idx < num_inp_l2_pfs_per_fma_step) + mic_prefetcht1(ptr[reg_inp_pf_l2 + + pf_idx * jcp.ic_block * jcp.typesize_in]); + else { + pf_idx -= num_inp_l2_pfs_per_fma_step; + // prefetch the 'tail' of the cache line because most of + // the accesses are not aligned + mic_prefetcht0(ptr[reg_inp_pf_l1 + + pf_idx * jcp.ic_block * jcp.typesize_in + + cache_line_size - jcp.typesize_in]); + } + }; + + auto numloads = 4; + + int steps = this_ur_ow; + for (int oi4 = 0; oi4 < steps; oi4 += numloads) { + for (int oi1 = 0; oi1 < numloads; oi1++) { + int oi = oi4 + oi1; + if (!is_w_tail || oi < (this_ur_ow - ow_zero_tail4)) { + vmovups(zmm_out(oi), out_addr(oi)); + emit_out_pf(oi); + } else { + auto zmm = zmm_out(oi); + vpxord(zmm, zmm, zmm); + } + } + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + if (jcp.ver == ver_4fma) { + v4fmaddps(zmm_ker(ic1), + zmm_out(oi4), inp_addr(oi4, ic1)); + } else { + assert(!"unknown convolution version"); + } + emit_inp_pf(oi4, ic1); + } + } + }; + + // Input is transposed and padded but we only access about jcp.iw + // elements so use that to compute the # of cache lines in each 'row' + int num_inp_l1_pfs + = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block; + + if (full_w_unroll) { + emit_step(ow4u, num_inp_l1_pfs, + num_inp_l2_pfs_per_fma_block, + num_out_l2_pfs_per_fma_block, true); + add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size); + add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size); + } else { + Label w_loop; + int num_w_iters = pad_ow / def_step_size; + int num_w_iters_full = num_w_iters + has_w_tail; + int num_inp_l1_pfs_per_fma_step + = div_up(num_inp_l1_pfs, num_w_iters_full); + int num_inp_l2_pfs_per_fma_step + = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full); + int num_out_l2_pfs_per_fma_step + = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full); + mov(reg_i, num_w_iters); + L(w_loop); { + emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, + num_inp_l2_pfs_per_fma_step, + num_out_l2_pfs_per_fma_step, false); + add(reg_inp, def_step_size * jcp.typesize_in); + add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in); + add(reg_inp_pf_l1, + num_inp_l1_pfs_per_fma_step * cache_line_size); + add(reg_inp_pf_l2, + num_inp_l2_pfs_per_fma_step * cache_line_size); + add(reg_out_pf_l2, + num_out_l2_pfs_per_fma_step * cache_line_size); + sub(reg_i, 1); + jnz(w_loop); + } + if (has_w_tail) { + emit_step(def_step_size, num_inp_l1_pfs_per_fma_step, + num_inp_l2_pfs_per_fma_step, + num_out_l2_pfs_per_fma_step, true); + add(reg_inp_pf_l2, + num_inp_l2_pfs_per_fma_step * cache_line_size); + add(reg_out_pf_l2, + num_out_l2_pfs_per_fma_step * cache_line_size); + } + // reset reg_inp and reg_out because emit_h_loop expects + // unmodified pointers + int w_offset = num_w_iters * def_step_size; + sub(reg_inp, w_offset * jcp.typesize_in); + sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in); + } + }; + + auto emit_h_loop = [&](int h_block_size, + bool is_last_block, bool is_last_kh_kw_iter) + { + Label h_loop, skip_h_loop; + mov(reg_j, 1); + cmp(reg_j, reg_h); + je(skip_h_loop, T_NEAR); + L(h_loop); { + + lea(reg_inp_pf_l1, + ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]); + emit_block(h_block_size, + is_last_block, is_last_kh_kw_iter, false); + + add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in); + add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in); + add(reg_j, 1); + cmp(reg_j, reg_h); + jb(h_loop); + } + + L(skip_h_loop); + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + mic_prefetcht0(ker_addr(ic1)); + + lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]); + emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true); + }; + + auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block, + int h_block_size) + { + xor_(reg_kh, reg_kh); + Label kh_loop, kh_loop_end; + + int last_oh_block_size + = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size); + int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size; + // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size + int ih_block_size = oh_block_size - 1 + jcp.kh + - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad; + + L(kh_loop); { + // determine starting indices for this block + if (is_first_block) { + xor_(reg_tmp, reg_tmp); + mov(reg_ohs, jcp.t_pad); + sub(reg_ohs, reg_kh); + cmovb(reg_ohs, reg_tmp); + + mov(reg_ihs, reg_ohs); + sub(reg_ihs, jcp.t_pad); + add(reg_ihs, reg_kh); + } else { + xor_(reg_ohs, reg_ohs); + mov(reg_ihs, reg_kh); + } + + // determine effective size of block based on padding + mov(reg_tmp, oh_block_size); + sub(reg_tmp, reg_ohs); + mov(reg_h, ih_block_size); + sub(reg_h, reg_ihs); + cmp(reg_tmp, reg_h); + cmovb(reg_h, reg_tmp); + + Label kh_loop_work; + cmp(reg_h, 0); + jg(kh_loop_work, T_NEAR); + + // empty h loop for this jcp.kh: + // - set the output to 0 if necessary + // - move ker pt + // - jump to the end + sub(reg_h, 1); + Label skip_ker_zeroing; + + // The reg_ker ptr has highest bit set if the output needs to be + // zeroed. Those who have byte-aligned their data will suffer the + // consiquences :( + // TODO: move the flag to a mask register? (Roma) + test(reg_ker, 1); + jz(skip_ker_zeroing, T_NEAR); + + Label zeroing_loop; + vpxord(zmm0, zmm0, zmm0); + and_(reg_ker, ~1); // temporarily clear the zeroing flag + mov(reg_tmp, jcp.kw); + L(zeroing_loop); { + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) + vmovups(ker_addr(ic1), zmm0); + add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out); + sub(reg_tmp, 1); + jnz(zeroing_loop, T_NEAR); + } + // restore the zeroing flag (it will be cleared after the end of + // emit_kh_kw_loop, but we may need it until then) + or_(reg_ker, 1); + jmp(kh_loop_end, T_NEAR); + + L(skip_ker_zeroing); + add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw + * jcp.typesize_out); + jmp(kh_loop_end, T_NEAR); + + L(kh_loop_work); + + mul_by_const(reg_ihs, reg_tmp, + jcp.tr_iw * jcp.ic_block * jcp.typesize_in); + mul_by_const(reg_ohs, reg_tmp, + pad_ow * jcp.oc_block * jcp.typesize_in); + + add(reg_inp, reg_ihs); + add(reg_out, reg_ohs); + + Label kw_loop; + xor_(reg_kw, reg_kw); + L(kw_loop); { + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + vpxord(zmm, zmm, zmm); + mic_prefetcht1(ker_addr(ic1)); + } + + mov(reg_out_save, reg_out); + mov(reg_inp_save, reg_inp); + lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]); + +#if 0 + // XXX: Generate code with special prefetches when switching + // blocks or at the end of the last block. Disabled to reduce + // code size and because there's no performance benefit (Roma) + Label regular_h_loop, end_h_loop; + cmp(reg_kw, jcp.kw - 1); + jne(regular_h_loop, T_NEAR); + cmp(reg_kh, jcp.kh - 1); + jne(regular_h_loop, T_NEAR); + + emit_h_loop(oh_block_size, is_last_block, true); + jmp(end_h_loop, T_NEAR); + + L(regular_h_loop); + emit_h_loop(oh_block_size, is_last_block, false); + + L(end_h_loop); +#else + emit_h_loop(oh_block_size, is_last_block, false); +#endif + + mov(reg_out, reg_out_save); + mov(reg_inp, reg_inp_save); + + Label do_store; + // The reg_ker ptr has highest bit set if the output needs to + // be zeroed. Those who have byte-aligned their data will + // suffer the consiquences :( + mov(reg_tmp, reg_ker); + and_(reg_ker, ~1); + test(reg_tmp, 1); + jnz(do_store, T_NEAR); + + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + if (jcp.ver == ver_4fma) { + vaddps(zmm, ker_addr(ic1)); + } else { + assert(!"unknown convolution version"); + } + } + + L(do_store); + for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) { + auto zmm = zmm_ker(ic1); + vmovups(ker_addr(ic1), zmm); + } + + mov(reg_ker, reg_tmp); + add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out); + add(reg_kw, 1); + cmp(reg_kw, jcp.kw); + jl(kw_loop); + } + + sub(reg_inp, reg_ihs); + sub(reg_out, reg_ohs); + + + L(kh_loop_end); + add(reg_kh, 1); + cmp(reg_kh, jcp.kh); + jl(kh_loop); + } + }; + + mov(reg_inp, ptr[param + GET_OFF(src)]); + mov(reg_out, ptr[param + GET_OFF(dst)]); + mov(reg_ker, ptr[param + GET_OFF(filt)]); + mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]); + mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]); + mov(reg_tmp, ptr[param + GET_OFF(channel)]); + or_(reg_ker, reg_tmp); + + bool single_kh_kw_loop = (h_block_size == jcp.oh); + + size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in; + size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad); + size_t inp_block_step = inp_row_step * h_block_size; + size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in + * h_block_size; + + if (!single_kh_kw_loop) { + // Save the original prefetch pointers from the OpenMP driver + vmovq(reg_inp_pf_save, reg_inp_pf_l2); + vmovq(reg_out_pf_save, reg_out_pf_l2); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, first_inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + } + emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size); + + if (!single_kh_kw_loop) { + size_t ker_reset_offset + = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh; + sub(reg_ker, ker_reset_offset); + and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates + + add(reg_inp, first_inp_block_step); + add(reg_out, out_block_step); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + + int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2; + if (num_innermost_iters > 0) { + Label h_block_loop; + + mov(reg_tmp_w, num_innermost_iters); + kmovw(reg_h_block, reg_tmp_w); + L(h_block_loop); { + emit_kh_kw_loop(false, false, h_block_size); + sub(reg_ker, ker_reset_offset); + add(reg_inp, inp_row_step * h_block_size); + add(reg_out, out_block_step); + mov(reg_inp_pf_l2, reg_inp); + add(reg_inp_pf_l2, inp_block_step); + mov(reg_out_pf_l2, reg_out); + add(reg_out_pf_l2, out_block_step); + kmovw(reg_tmp_w, reg_h_block); + sub(reg_tmp_w, 1); + kmovw(reg_h_block, reg_tmp_w); + jnz(h_block_loop); + } + } + + // Restore the original prefetch pointers that came from the OpenMP + // driver + vmovq(reg_inp_pf_l2, reg_inp_pf_save); + vmovq(reg_out_pf_l2, reg_out_pf_save); + emit_kh_kw_loop(false, true, h_block_size); + } + + return true; +} + +bool jit_avx512_common_conv_bwd_weights_kernel_f32 + ::flat_4ops_compute() { + const auto &j = jcp; + const bool ok = j.ver == ver_4fma && j.is_1stconv + && everyone_is(0, j.dilate_h, j.dilate_w); + if (!ok) return false; + + Reg64 reg_ptr_tr_src = r8; + Reg64 reg_ptr_dst = r9; + Reg64 reg_ptr_wei = r10; + Reg64 reg_ptr_bia = r11; + + Reg64 reg_kh_step = rax; + Reg64 reg_oh = abi_not_param1; + Reg64 reg_kh = rdx; + + Reg32 reg_flag_save = ebx; + Reg32 reg_flag = esi; + + Zmm vbia(31); + + auto zmm_wei = [&](int kh, int kw) { + return Zmm(8 + kh * j.kw + kw); + }; + auto zmm_dst = [&](int ow) { + return Zmm(ow % 8); + }; + + auto addr_tr_src = [&](int kh, int iw) { + return ptr[reg_ptr_tr_src + + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in]; + }; + auto addr_dst = [&](int ow) { + return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in]; + }; + auto addr_wei = [&](int kh, int kw) { + return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block + * jcp.typesize_out]; + }; + + auto emit_fma_block = [&](int kh_step) { + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) { + auto vwei = zmm_wei(kh, kw); + vpxord(vwei, vwei, vwei); + } + } + + for (int ow = 0; ow < j.ow; ow += 4) { + for (int _ow = ow; _ow < ow + 4; ++_ow) { + auto vdst = zmm_dst(_ow); + if (_ow < j.ow) + vmovups(vdst, addr_dst(_ow)); + else + vpxord(vdst, vdst, vdst); + } + + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) { + const int iw = ow + (kw % j.stride_w) * j.tr_ld + + (kw / j.stride_w); + v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow), + addr_tr_src(kh, iw)); + if (1 && kh == 0 && kw < 4) { + prefetcht1(ptr[reg_ptr_dst + + (j.ow + ow + kw) * jcp.oc_block + * jcp.typesize_in]); + } + if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */ + const int off = kw + 4 - j.kw; + if (off >= 0 && ow + off < j.ow) + vaddps(vbia, vbia, zmm_dst(ow + off)); + } + } + } + } + + Label l_store; + test(reg_flag, FLAG_MB_FIRST); + jnz(l_store, T_NEAR); + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) + vaddps(zmm_wei(kh, kw), addr_wei(kh, kw)); + } + L(l_store); + for (int kh = 0; kh < kh_step; ++kh) { + for (int kw = 0; kw < j.kw; ++kw) + vmovups(addr_wei(kh, kw), zmm_wei(kh, kw)); + } + }; + + auto emit_kh_loop = [&]() { + const int kh_step_rem = j.kh % j.kh_step; + xor_(reg_kh, reg_kh); + mov(reg_kh_step, j.kh_step); + + Label l_kh_loop; + L(l_kh_loop); { + Label l_done; + + if (kh_step_rem != 0) { + Label l_keep_kh_step; + cmp(reg_kh, j.kh - j.kh_step); + jle(l_keep_kh_step, T_NEAR); + + mov(reg_kh_step, kh_step_rem); + emit_fma_block(kh_step_rem); + jmp(l_done, T_NEAR); + + L(l_keep_kh_step); + } + + emit_fma_block(j.kh_step); + + L(l_done); + + add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld + * jcp.typesize_in); + add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out); + add(reg_kh, j.kh_step); + + cmp(reg_kh, j.kh); + jl(l_kh_loop, T_NEAR); + } + + const int kh_steps = rnd_up(j.kh, j.kh_step); + sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in); + sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out); + }; + + auto emit_oh_loop = [&]() { + mov(reg_oh, j.oh); + + Label l_oh_loop; + L(l_oh_loop); { + Label l_restore_mb_flag, l_jump; + + cmp(reg_oh, j.oh); + je(l_restore_mb_flag, T_NEAR); + + and_(reg_flag, ~FLAG_MB_FIRST); + jmp(l_jump, T_NEAR); + + L(l_restore_mb_flag); + mov(reg_flag, reg_flag_save); + + L(l_jump); + + emit_kh_loop(); + + add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld + * jcp.typesize_in); + add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in); + + dec(reg_oh); + jnz(l_oh_loop, T_NEAR); + } + }; + + auto emit_bia_store = [&]() { + if (!j.with_bias) return; + + Label l_bia_store, l_bia_skip; + test(reg_flag, FLAG_IC_FIRST); + jz(l_bia_skip); + + test(reg_flag, FLAG_MB_FIRST); + jnz(l_bia_store, T_NEAR); + vaddps(vbia, ptr[reg_ptr_bia]); + L(l_bia_store); + vmovups(ptr[reg_ptr_bia], vbia); + L(l_bia_skip); + }; + + mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]); + mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]); + mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]); + mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]); + mov(reg_flag_save, ptr[param + GET_OFF(flags)]); + + vpxord(vbia, vbia, vbia); + emit_oh_loop(); + emit_bia_store(); + + return true; +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop() +{ + if (flat_4ops_compute()) + return; + if (compute_full_spat_loop()) + return; + + maybe_zero_kernel(); + + if (jcp.ndims == 5) compute_d_loop_common(); + else compute_oh_loop_common(); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::generate() +{ + preamble(); + + mov(reg_input, ptr[param + GET_OFF(src)]); + mov(reg_output, ptr[param + GET_OFF(dst)]); + mov(reg_kernel, ptr[param + GET_OFF(filt)]); + + compute_loop(); + + postamble(); +} + +status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &diff_weights_md, + memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md) { + if (!mayiuse(avx512_common)) + return status::unimplemented; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper diff_weights_d(&diff_weights_md); + const memory_desc_wrapper diff_bias_d(&diff_bias_md); + const memory_desc_wrapper diff_dst_d(&diff_dst_md); + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + + jcp = zero<decltype(jcp)>(); + + jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2]; + jcp.iw = src_d.dims()[ndims-1]; + jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; + jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2]; + jcp.ow = diff_dst_d.dims()[ndims-1]; + + jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1; + jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2]; + jcp.kw = diff_weights_d.dims()[with_groups + ndims-1]; + + jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4]; + jcp.l_pad = cd.padding[0][ndims-3]; + + jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4]; + jcp.stride_w = cd.strides[ndims-3]; + + jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4]; + jcp.dilate_w = cd.dilates[ndims-3]; + + const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1); + bool ok = true + // general condition to simplify dilations + && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1) + && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1) + && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1) + // special condition to simplify dilations in compute_oh_loop_common + && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih); + if (!ok) + return status::unimplemented; + + jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h + + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1)); + jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1)); + + /* XXX: currently, does not support dilation_d > 0 */ + if (ndims == 5) + if (jcp.dilate_d > 0) + return status::unimplemented; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.aligned_threads = 0; + + /* check for the 1st convolution */ + jcp.is_1stconv = is_1stconv(jcp); + + jcp.oc_block = jcp.simd_w; + + bool ok_to_pad_channels = true + && jcp.ngroups == 1 + && src_d.data_type() == data_type::f32; + + if (ok_to_pad_channels) + jcp.oc = rnd_up(jcp.oc, jcp.simd_w); + + if (jcp.oc % jcp.oc_block) + return status::unimplemented; + + auto dst_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = with_groups + ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o) + : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o); + + if (diff_dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag)); + jcp.dst_tag = dst_tag; + } else { + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dst_tag); + } + if (jcp.dst_tag != dst_tag) + return status::unimplemented; + + /* conditions on bias memory */ + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + if (jcp.with_bias) { + if (diff_bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md, x)); + } + + jcp.nb_oc = jcp.oc / jcp.oc_block; + + /* kernel applicability check wrt boundaries + * the conditions are quite general across the kernels we have, + * but ideally the check should belong to a specific kernel... */ + const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2; + const bool boundaries_ok = true + && jcp.t_pad <= max_pad + && jcp.b_pad <= max_pad + && IMPLICATION(jcp.f_pad > 0, jcp.kd < jcp.id + jcp.f_pad) + && jcp.f_pad < jcp.kd; + if (!boundaries_ok) + return status::unimplemented; + + /* yet another common check */ + if (jcp.kw > 14) + return status::unimplemented; + + /* setting register strategy */ + for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) { + if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; } + } + + if (jcp.is_1stconv) { + auto src_tag = pick(ndims - 3, ncw, nchw, ncdhw); + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + if (jcp.ic == 1 && jcp.src_tag != src_tag) + jcp.src_tag = src_d.matches_one_of_tag( + pick(ndims - 3, nwc, nhwc, ndhwc)); + } + if (jcp.src_tag == format_tag::undef) + return status::unimplemented; + + const bool src_ok = true + && utils::everyone_is(data_type::f32, + src_d.data_type(), diff_weights_d.data_type(), + diff_dst_d.data_type()) + && one_of(jcp.ic, 1, 2, 3) + && jcp.ngroups == 1; + if (!src_ok) + return status::unimplemented; + + const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad, + jcp.stride_w), 16); + const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1); + const int kh_step_rem = jcp.kh % kh_step; + + const auto wei_4fma_tag = with_groups + ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o) + : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o); + + auto current_wei_tag = format_tag::undef; + if (diff_weights_d.format_kind() != format_kind::any) + current_wei_tag = diff_weights_d.matches_one_of_tag(wei_4fma_tag); + + const bool use_4fma = true + && one_of(ndims, 3, 4) + && mayiuse(avx512_mic_4ops) + && mkldnn_thr_syncable() + && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) + && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad) + && jcp.kw <= 28 - jcp.with_bias + && jcp.stride_w == 4 + && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */ + && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */ + && IMPLICATION(diff_weights_d.format_kind() != format_kind::any, + current_wei_tag == wei_4fma_tag); + + if (use_4fma) { + jcp.ver = ver_4fma; + jcp.kh_step = kh_step; + jcp.tr_ld = tr_ld; + jcp.ic_block = 1; + if (diff_weights_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_4fma_tag)); + jcp.wei_tag = wei_4fma_tag; + } else { + jcp.ver = ver_fma; + jcp.ic_block = jcp.ic; + + wei_tag = with_groups + ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o) + : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o); + + if (diff_weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + } + + jcp.nb_ic = jcp.ic / jcp.ic_block; + } else { + auto src_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, src_tag)); + jcp.src_tag = src_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(src_tag); + } + if (jcp.src_tag != src_tag) + return status::unimplemented; + + if (diff_weights_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); + jcp.wei_tag = wei_tag; + } else { + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + } + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + + jcp.ic_block = jcp.simd_w; + if (ok_to_pad_channels) + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + jcp.nb_ic = jcp.ic / jcp.ic_block; + if ((mayiuse(avx512_mic) || mayiuse(avx512_core)) + && utils::everyone_is(data_type::f32, + src_d.data_type(), diff_weights_d.data_type(), + diff_dst_d.data_type())) { + jcp.ver = ver_fma; + if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 && + everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) && + mkldnn_thr_syncable()) { + jcp.ver = ver_4fma; + } + } else { + return status::unimplemented; + } + if (jcp.ver == ver_4fma) { + jcp.ur_w = jcp.ow; + // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to + // cross the right boundary. The only requirement is not to have + // NaNs there because another multiplicand is always guaranteed to + // be zero. This also may require the top-level driver to allocate + // four extra guarding elements at the very end of the buffer. + // I'm not proud of this hack, but it improves performance by + // about 5-10% depending on the dimensions (Roma) + + const int tr_round = 4; + + jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round); + jcp.tr_src_num_guard_elems = tr_round; // upper bound + } + } + + if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) { + jcp.typesize_in = sizeof(float); + jcp.typesize_out = sizeof(float); + } else + return status::unimplemented; + + bool args_ok = true + && jcp.ic % jcp.ic_block == 0 + && jcp.oc % jcp.oc_block == 0 + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!args_ok) return status::unimplemented; + + { // balancing + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b); + jcp.nthr = nthr; + jcp.nthr_mb = nthr_mb; + jcp.nthr_g = nthr_g; + jcp.nthr_oc_b = nthr_oc_b; + jcp.nthr_ic_b = nthr_ic_b; + } + + return status::success; +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.ver == ver_4fma) { + if (jcp.is_1stconv) { + const size_t tr_src_size = + jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld; + scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); + } else { + // XXX: See the comment about tr_iw and guarding elements in + // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() + const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic; + const size_t min_tr_src_size_per_thr + = jcp.ih * jcp.ic_block * jcp.tr_iw; + const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr + + jcp.tr_src_num_guard_elems; + scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size); + } + + /* prepare synchronization contexts */ + if (jcp.nthr_oc_b > 1) { + const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; + scratchpad.book(key_conv_tr_src_bctx, + sizeof(simple_barrier::ctx_t) * tr_src_bctx_size); + } + } + + if (jcp.nthr_mb > 1) { + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic + * jcp.kh * jcp.kw * jcp.kd; + const int bia_size = jcp.ngroups * jcp.oc; + const size_t wei_bia_reduction_size = wei_size + bia_size; + + scratchpad.book(key_conv_wei_bia_reduction, + jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1)); + scratchpad.book(key_conv_wei_bia_reduction_bctx, + sizeof(simple_barrier::ctx_t)); + } + + if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) + scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc); +} + +void jit_avx512_common_conv_bwd_weights_kernel_f32::balance( + const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_, + int &nthr_oc_b_, int &nthr_ic_b_) +{ + nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1; + + const int max_threads = mkldnn_get_max_threads(); + + if (max_threads < j.ngroups) { + /* simplification... fortunately it doesn't hurt much */ + return; + } + + if (!mkldnn_thr_syncable() && j.ver == ver_4fma) { + // should not happen -- the driver is not ready + // for TBB-like non-synchronous threading yet + return; + } + + if (j.ver == ver_4fma && j.is_1stconv) { + nthr_g_ = 1; + nthr_oc_b_ = 1; + nthr_ic_b_ = nstl::min(j.nb_ic, max_threads); + nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb); + nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_; + return; + } + + nthr_g_ = j.ngroups; + const int nthr = max_threads / nthr_g_; + + auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + /* calculate per thread memory cost (read/write). high level optimizer + * tries to minimize memory consumption. few notes: + * (n1) unclear why, but that essentially helps first convolution... + * (n2) assuming the reduction over minibatch is always there: + * - instead of 8 it should be 5 here (write ~= 2 read): + * kernel: temporal workspace 1 write + * reduction: 1 read from workspace and 1 write to the diff_wei + * - but experiments showed 8 works better than 5 or 6... */ + + const int src_coef = j.ver == ver_4fma ? 4 : 1; + const int dst_coef = 1; + const int wei_coef = 8; + + return 0 + + src_coef + * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id + / j.stride_d / j.stride_h / j.stride_w /* (n1) */ + + dst_coef + * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od + + wei_coef /* (n2) */ + * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b) + * j.kh * j.kw * j.kd * j.ic_block * j.oc_block; + }; + + int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); + + /* step 1: find the best thread distribution with lowest memory cost */ + const int nthr_mb_max = nstl::min(nthr, j.mb * j.od); + for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); + for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); + + int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + if (mem_cost <= best_mem_cost) { + best_mem_cost = mem_cost; + nthr_mb_ = nthr_mb; + nthr_oc_b_ = nthr_oc_b; + nthr_ic_b_ = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + + if (!mayiuse(avx512_mic)) { + auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { + return 1 + * div_up(j.mb, nthr_mb) + * div_up(j.ngroups, nthr_g_) + * div_up(j.nb_oc, nthr_oc_b) + * div_up(j.nb_ic, nthr_ic_b); + }; + + /* step 2: search for a thread distribution with lower compute cost. + * the constrains: + * - memory cost cannot exceed 110% of the best found in the step 1 + * - unless compute cost is 133% lower than the current best case + * note: both constants were found empirically */ + int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_); + for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { + const int nthr_par = nthr / nthr_mb; + const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc); + for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { + int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic); + int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b); + + const bool opt1 = comp_cost <= best_comp_cost + && mem_cost < 1.1 * best_mem_cost; + const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost; + + if (opt1 || opt2) { + best_comp_cost = comp_cost; + nthr_mb_ = nthr_mb; + nthr_oc_b_ = nthr_oc_b; + nthr_ic_b_ = nthr_ic_b; + } + } + + if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; } + } + } + + if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads) + nthr_mb_ = nstl::min(j.mb * j.od, max_threads); + nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_; + + assert(nthr_ <= max_threads); + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1)); +} + +template struct _jit_avx512_common_conv_fwd_kernel<Zmm>; +template struct _jit_avx512_common_conv_fwd_kernel<Xmm>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp new file mode 100644 index 0000000000..f76770797a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp @@ -0,0 +1,423 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP +#define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template<typename Vmm> +struct _jit_avx512_common_conv_fwd_kernel : public jit_generator { + + _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>( + this, jcp.eltwise); + + generate(); + jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); + } + + ~_jit_avx512_common_conv_fwd_kernel() { + delete eltwise_injector_; + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel) + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker_)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + }; + + reg64_t param = abi_param1; + reg64_t reg_inp = r8; + reg64_t reg_ker = r9; + reg64_t reg_out = r10; + + reg64_t reg_inp_prf = r11; + reg64_t reg_ker_prf = r12; + reg64_t reg_out_prf = r13; + reg64_t reg_owb = r12; + + reg64_t aux_reg_inp = r14; + reg64_t aux_reg_ker = r15; + + reg64_t aux_reg_inp_prf = rsi; + reg64_t aux_reg_ker_prf = rdx; + + reg64_t reg_channel = rsi; + reg64_t reg_bias = rdx; + + reg64_t aux_reg_ker_d = r9; + reg64_t aux_reg_inp_d = rbx; + reg64_t aux_reg_inp_d_prf = r13; + reg64_t aux_reg_ker_d_prf = abi_not_param1; + reg64_t reg_ki = r10; + + reg64_t reg_kj = rax; + reg64_t reg_relu_ns = rax; + reg64_t reg_oi = rbx; + reg64_t reg_kh = abi_not_param1; + + reg64_t reg_tmp = rbp; + + reg64_t reg_ic_loop = rdx; + reg64_t reg_inp_loop = rsi; + + reg64_t reg_init_flag = r13; + reg64_t reg_bias_ptr = param; + + reg64_t aux_reg_ic = r12; + reg64_t reg_binp = rax; + reg64_t reg_bout = r11; + reg64_t aux1_reg_inp = rbx; + reg64_t aux_reg_out = abi_not_param1; + + reg64_t reg_long_offt = r11; + reg64_t reg_out_long_offt = r14; + + inline Vmm vmm_ker(int i_ic) { + assert(i_ic < 4); + return Vmm(ker_reg_base_idx + i_ic); + } + + inline Vmm vmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < ker_reg_base_idx); + return Vmm(idx); + } + + inline Vmm vmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Vmm(idx); + } + + Xbyak::Reg64 imm_addr64 = r15; + Vmm vmm_wei = Vmm(31); + + jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_; + + inline void prepare_output(int ur_w); + inline void store_output(int ur_w); + inline void compute_loop_fma(int ur_w, int pad_l, int pad_r); + inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r); + inline void compute_loop_4fma(int ur_w, int pad_l, int pad_r); + inline void compute_loop_4fma_1st(int ur_w, int pad_l, int pad_r); + inline void compute_loop(int ur_w, int pad_l, int pad_r); + + void generate(); + + inline size_t get_output_offset(int oi, int n_oc_block) { + return (size_t)jcp.typesize_out * ((size_t)n_oc_block * jcp.oh + * jcp.ow * jcp.od + oi) * jcp.oc_block; + } + + inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) { + size_t iw_str = !jcp.is_1stconv ? jcp.ic_block : 1; + size_t ic_str = !jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id; + return (size_t)jcp.typesize_in * ((size_t)(ki * (jcp.dilate_w + 1) + + oi * jcp.stride_w - pad_l) * iw_str + ic * ic_str); + } + + inline int get_kernel_offset(int ki,int ic,int n_oc_block,int ker_number) { + return jcp.typesize_in * jcp.oc_block + * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd + + (ic + ker_number) + ki * jcp.ic_block); + } + + inline int get_ow_start(int ki, int pad_l) { + return nstl::max(0, + utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); + } + + inline int get_ow_end(int ur_w, int ki, int pad_r) { + return ur_w - nstl::max(0, utils::div_up(pad_r + - (jcp.kw - 1 - ki) + * (jcp.dilate_w + 1), + jcp.stride_w)); + } +}; + +struct jit_avx512_common_conv_fwd_kernel { + + jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : + jit_ker(nullptr), + zmm_kernel_(nullptr), + xmm_kernel_(nullptr) { + int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block; + switch (ch_block) { + case 16: + zmm_kernel_ = + new _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm>( + ajcp, attr); + jit_ker = zmm_kernel_->jit_ker_; + return; + case 4: + xmm_kernel_ = + new _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm>( + ajcp, attr); + jit_ker = xmm_kernel_->jit_ker_; + return; + default: + assert(!"invalid channel blocking"); + } + } + + ~jit_avx512_common_conv_fwd_kernel() { + delete xmm_kernel_; + delete zmm_kernel_; + } + + enum { + typesize = sizeof(float) + }; + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_pd, + memory_desc_t &weights_pd, + memory_desc_t &dst_pd, + memory_desc_t &bias_pd, + const primitive_attr_t &attr, + int nthreads); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + void(*jit_ker)(jit_conv_call_s *); + _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm> *zmm_kernel_; + _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm> *xmm_kernel_; +}; + +struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator { + + jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) + { + generate(); + jit_ker = (void (*)(jit_conv_call_s *))getCode(); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_data_kernel_f32) + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + }; + + reg64_t param = abi_param1; + reg64_t reg_dst = r8; + reg64_t reg_ker = r9; + reg64_t reg_src = r10; + + reg64_t reg_dst_prf = r11; + reg64_t reg_ker_prf = r12; + reg64_t reg_src_prf = r13; + + reg64_t aux_reg_dst = r14; + reg64_t aux_reg_ker = r15; + + reg64_t aux_reg_dst_prf = rsi; + reg64_t aux_reg_ker_prf = rdx; + + reg64_t aux_reg_dst_d_prf = r13; + reg64_t aux_reg_dst_d = rbx; + reg64_t aux_reg_ker_d_prf = abi_not_param1; + reg64_t aux_reg_ker_d = r9; + reg64_t reg_ki = r10; + + reg64_t reg_kj = rax; + reg64_t reg_oi = rbx; + reg64_t reg_kh = abi_not_param1; + + reg64_t reg_channel = rsi; + + reg64_t reg_tmp = rbp; + reg64_t reg_long_offt = r14; + + inline Xbyak::Zmm zmm_ker(int i_ic) { + assert(i_ic < 4); + return Xbyak::Zmm(ker_reg_base_idx + i_ic); + } + inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Xbyak::Zmm(idx); + } + inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < ker_reg_base_idx); + return Xbyak::Zmm(idx); + } + + Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); + + inline void prepare_output(int ur_w); + inline void store_output(int ur_w); + inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop_fma_core(int ur_w, int l_overflow, int r_overflow); + inline void compute_loop(int ur_w, int l_overflow, int r_overflow); + void generate(); + + inline int get_iw_start(int ki, int l_overflow) + { + int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return res; + } + + inline int get_iw_end(int ur_w, int ki, int r_overflow) + { + if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + + return ur_w - res; + } +}; + +struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator { + + jit_avx512_common_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) + : jcp(ajcp) + { + generate(); + jit_ker = (void (*)(jit_conv_call_s *))getCode(); + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32) + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_md, + memory_desc_t &diff_weights_md, + memory_desc_t &diff_bias_md, + memory_desc_t &diff_dst_md); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + enum {typesize = sizeof(float)}; + static const int max_ur_w; + + reg64_t param = abi_param1; + reg64_t reg_input = rax; + reg64_t reg_kernel = rdx; + reg64_t reg_output = rsi; + reg64_t b_ic = abi_not_param1; + reg64_t kj = r8; + reg64_t reg_kh = r9; + reg64_t reg_ur_w_trips = r10; + reg64_t reg_oj = r15; + reg64_t reg_ih_count = rbx; + reg64_t reg_tmp = r14; + reg64_t reg_long_offt = r14; + + reg64_t ki = r11; + reg64_t reg_kd_count = r12; + reg64_t reg_oi = r12; + reg64_t reg_d_index = r13; + reg64_t reg_input_d = r15; + reg64_t reg_output_d = rbx; + reg64_t aux_reg_input = r12; + reg64_t aux_reg_kernel = r13; + reg64_t reg_bias = rbx; + + inline void bias_kernel(); + inline void maybe_zero_kernel(); + inline void compute_oh_step_unroll_ow_icblock(int ic_block_step, + int max_ur_w); + inline void od_step_comeback_pointers(); + inline void oh_step_comeback_pointers(); + inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); + inline void compute_ic_block_step(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound = false); + inline void compute_ic_block_step_fma(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound); + inline void compute_ic_block_step_4fma(int ur_w, + int pad_l, int pad_r, int ic_block_step, + int input_offset, int kernel_offset, int output_offset, + bool input_wraparound); + inline void compute_oh_step_common(int ic_block_step, int max_ur_w); + inline void compute_oh_step_disp(); + inline void compute_oh_loop_common(); + inline void compute_d_loop_common(); + + inline bool compute_full_spat_loop(); + inline bool flat_4ops_compute(); + + inline void compute_loop(); + + void generate(); + + static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb, + int &nthr_g, int &nthr_oc_b, int &nthr_ic_b); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp new file mode 100644 index 0000000000..1bdcd0d6a8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp @@ -0,0 +1,1163 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include <math.h> + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +#ifndef KERNEL_SIZE_THRESHOLD +#define KERNEL_SIZE_THRESHOLD 16 +#endif + +#define MIN_REQUIRED_DIMN_REG_BLOCK 14 + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace mkldnn::impl::utils; + +unsigned int L1_cache_size = get_cache_size(1, true); +unsigned int L2_cache_size = get_cache_size(2, true); +unsigned int LLC_data_size = get_cache_size(3, false); + +// the test funtion takes jcp, the candidate and the current best. +// it returns true if the new candidate is better +int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, + int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) +{ + int best_divisor = default_best; + auto test_num + = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { + if (test(jcp, num, best_divisor)) { + best_divisor = num; + } + }; + + for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { + if (number % divisor == 0) { + test_num(jcp, divisor); + test_num(jcp, number / divisor); + } + } + + return best_divisor; +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { + if (jcp.ver == ver_4fma) + return jcp.mb >= 32; + else + return jcp.mb >= 16; +} +} + +/* assumes 512 bits registers */ +/* TODO: add support for strides */ +/* TODO: handle the prefetch distance automatically */ +typedef enum cache_t_ { L1, L2, L3 } cache_t; + +template <typename data_t> +struct prefetcher_t { + prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, + cache_t cache_type, size_t block_size, /* in number of elements*/ + int nb_instructions_in_block, int fma_ipc) + : cg_(generator) + , reg_base_addr_(reg_base_addr) + , cache_type_(cache_type) + , cache_block_size_(block_size) + { + nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); + prefetch_spread_ + = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); + prefetch_blk_ + = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); + + /* assumption: when fetch in Li, data is already in L(i+1) */ + int cache_latency; + switch (cache_type_) { + case L1: cache_latency = 14; break; + case L2: + case L3: + default: cache_latency = 250; break; + } + + prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); + } + + void prefetch(int instruction_number) + { + if (instruction_number % prefetch_spread_ == 0) { + for (int i = 0; (i < prefetch_blk_) + && (prefetches_issued_ < nb_cache_lines_to_prefetch_); + i++, prefetches_issued_++) { + prefetch_inst_(cg_->EVEX_compress_addr( + reg_base_addr_, (cache_block_size_ * prefetch_distance_) + * sizeof(data_t) + + (prefetches_issued_ * 64))); + } + } + } + +private: + void prefetch_inst_(const Xbyak::Address &addr) + { + switch (cache_type_) { + case L1: cg_->prefetcht0(addr); break; + case L2: cg_->prefetcht1(addr); break; + case L3: cg_->prefetcht2(addr); break; + default: + break; // TODO: raise an exception or put an assert + } + } + + jit_generator *cg_; + Xbyak::Reg64 reg_base_addr_; + cache_t cache_type_; + int cache_block_size_ = 0; + int nb_cache_lines_to_prefetch_ = 0; + int prefetches_issued_ = 0; + int prefetch_spread_ = 0; + int prefetch_blk_ = 0; + int prefetch_distance_ = 0; +}; + +// utilities to support kernel parameter selection +bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimN_reg_block * dimM_simd_block + + dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} + +bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} + +bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, + int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block, + float C) +{ + float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block + + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block + + nb_dimN_reg_block * dimK_nb_block * dimK_block + * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L2_cache_size; + return (lhs < rhs); +} +} + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate( + bool is_beta_zero) +{ + // const int dimK_simd_block = jcp.dimK_reg_block; + + // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++) + // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++) + // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block; + // dimK_reg_block++) + // for (int tile =0; tile < jcp.dimN_reg_block; tile++) + // C[dimM_block][tile] += + // A[dimM_block][dimK_block][dimK_reg_block] * + // broadcast(B[dimK_block][tile][dimK_reg_block]); + // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block], + // so we load it before the loop on tile + // 2) the loop on tile must be fully unrolled. Don't know about the one on + // dimK_reg_block. I think it should be + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop; + const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1; + const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2; + + prefetcher_t<float> L1_pf(this, reg_srcB, L1, + jcp.dimN_reg_block * jcp.dimK_reg_block, + jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block, + fma_ipc); + prefetcher_t<float> L2_pf(this, reg_srcB, L2, + jcp.dimN_reg_block * jcp.dimK_reg_block, + jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block, + fma_ipc); + + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + { + // First, we zero the accumulators if first nb_ic iteration, + // otherwise we load them + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (is_beta_zero) + vpxord(zmm, zmm, zmm); + else + vmovups(zmm, zword[reg_dstC + 64 * tile]); + } + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + { + auto load_A = [=](int reg_idx, int offset) { + for (int i = 0; i < inc_dimK_reg_block; i++) + vmovups(Zmm(reg_idx + i), + zword[reg_srcA + 64 * (offset + i)]); + }; + + // Used when doing double buffering + int next = 0; + if (jcp.double_buffering) { + load_A(next, 0); + } + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block += inc_dimK_reg_block) { + int current; + /* Loading the next vector from A */ + current = next; + if (jcp.double_buffering) { + next = (dimK_reg_block + inc_dimK_reg_block) + % (2 * inc_dimK_reg_block); + load_A(next, dimK_reg_block + inc_dimK_reg_block); + } else { + next = 0; + load_A(next, dimK_reg_block); + } + /* Performing the fmas */ + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (jcp.ver != ver_avx512_core) + L1_pf.prefetch( + dimK_reg_block * jcp.dimN_reg_block + tile); + if (jcp.ver == ver_4fma) + v4fmaddps(zmm, Zmm(current), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4)); + else + vfmadd231ps(zmm, Zmm(current), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4, + true)); + if (jcp.ver != ver_avx512_core) + L2_pf.prefetch( + dimK_reg_block * jcp.dimN_reg_block + tile); + } + } + + add(reg_srcA, jcp.dimK_reg_block * 64); + add(reg_srcB, jcp.dimN_reg_block * 64); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + } + + + auto store_output = [=](bool output_is_aligned) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm(jcp.zmm_start + tile); + if (output_is_aligned + && jcp.dimK_nb_block == 1 + && (jcp.dimN * jcp.dimM * alpha * alpha + * sizeof(float) > 2 * LLC_data_size)) + vmovntps(zword[reg_dstC + 64 * tile], zmm); + else + vmovups(zword[reg_dstC + 64 * tile], zmm); + } + }; + + Label unaligned_store, end_store; + test(reg_dstC, cpu_isa_traits<avx512_common>::vlen - 1); + jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64); + add(reg_dstC, jcp.dimN_reg_block * 64); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + } + }; + + /* Preamble */ + preamble(); + + /* kernel */ + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) +{ + + if (mayiuse(avx512_core)) + return status::unimplemented; + else if (!mayiuse(avx512_common)) + return status::unimplemented; + else if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + else + jcp.ver = ver_fma; + + jcp.nthr = mkldnn_get_max_threads(); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + // Checking conditions not supported by these kernels + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + return status::success; +} + + +status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) { + + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK) + && (dimN_reg_block < jcp.nb_reg) + && (dimN_reg_block < current_best); + }; + jcp.dimN_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block); + + if (jcp.dimN_reg_block >= jcp.nb_reg) { + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return (dimN_reg_block < jcp.nb_reg) + && (dimN_reg_block > current_best); + }; + + jcp.dimN_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimN, 1, test_cond_dimN_reg_block); + } + + //********************* Choosing dimK_block **********************// + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, + 1, jcp.dimM_simd_block, .75f) + && (dimK_block > current_best); + }; + + auto test_cond1_bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, dimK_block, + jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); + // If we are not able to use streams, we fall back to condition [1] + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); + jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; + + //********************* Choosing dimM_block **********************// + jcp.dimM_simd_block = 16; + /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/ + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f) + && (dimM_block > current_best); + }; + + auto test_cond1_bis_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f) + && (dimM_block > current_best); + }; + + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); + else + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block); + jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; + + //******************* Choosing dimN_block *******************// + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, + jcp.dimM_simd_block, .5f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); + jcp.sched_policy = WSCHED_DATA_W_S_G_D; + return status::success; +} + +status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) +{ + jcp.dimK_reg_block = 16; + jcp.dimM_simd_block = 16; + + // TODO: replace double buffering with nuple buffering to maximize register + // usage. + // the choice of the number of buffers will then come after choosing + // dimN_reg_block + jcp.double_buffering = true; + if (jcp.double_buffering) + jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2); + else + jcp.zmm_start = 1; + jcp.nb_reg = 32 - jcp.zmm_start; + + jcp.dimN = dimN; + jcp.dimK = dimK; + jcp.dimM = dimM; + + jcp.sched_policy = WSCHED_INVALID; + set_wsched_DATA_W_S_G_D_avx512_common(jcp); + + assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D); + return status::success; +} + +bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_relu(0) || is_sum(0); // relu or sum + case 2: return (is_sum(0) && is_relu(1)) || + (is_relu(0) && is_sum(1)); // sum->relu or relu->sum + case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu + default: return false; + } + + return false; +} + +status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) { + status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d); + + if (st != status::success) + return st; + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise; + jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; + + status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); + jcp.ic_simd_block = jcp.dimK_reg_block; + jcp.ic_block = jcp.dimK_block; + jcp.nb_ic = jcp.dimK_nb_block; + jcp.oc_simd_block = jcp.dimM_simd_block; + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + jcp.tile_4fma_padding = 0; // only relevant for backward weights + + return res; +} + +status_t jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); + + if (st != status::success) + return st; + + jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); + jcp.oc_simd_block = jcp.dimK_reg_block; + jcp.oc_block = jcp.dimK_block; + jcp.nb_oc = jcp.dimK_nb_block; + jcp.ic_simd_block = jcp.dimM_simd_block; + jcp.ic_block = jcp.dimM_block; + jcp.nb_ic = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + jcp.tile_4fma_padding = 0; // only relevant for backward weights + + return res; +} + +void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate() +{ + auto load_B = [=](int reg_idx, int offset) { + for (int i = 0; i < 4; i++) { + vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]); + } + }; + + preamble(); + int curr = 0; + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + int origB_offset = (j * alpha + i) * jcp.dimK_4fma; + size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block * + jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block * + jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float); + mov(reg_transB_idx, transB_offset); + for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) { + /*double buffering to hide load latencies*/ + int next = (curr + 4) % 8; + if (i == 0 && tb == 0) { + load_B(0, origB_offset); + } + if (tb + 4 < (jcp.dimK_4fma -1)) { + load_B(next, origB_offset + 4); + } else if (i < alpha - 1) { + load_B(next, origB_offset + jcp.dimK_4fma); + } + + vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1)); + vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3)); + vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1)); + vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3)); + + vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9)); + vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9)); + + vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1)); + vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1)); + + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * tb * jcp.dimN_reg_block], + Zmm(curr+2)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 1) * jcp.dimN_reg_block], + Zmm(curr+3)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 2) * jcp.dimN_reg_block], + Zmm(8)); + vmovntps(zword[reg_transB + reg_transB_idx + + sizeof(float) * (tb + 3) * jcp.dimN_reg_block], + Zmm(9)); + curr = next; + + } + } + } + postamble(); + ret(); +} +void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate( + bool is_first_tile) +{ + // for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) + // for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++) + // for (int nb_tile_block_ur = 0; nb_tile_block_ur < + // jcp.nb_tile_block_ur; nb_tile_block_ur++) + // for (int tile_block_ur = 0; tile_block_ur < + // jcp.tile_block_ur; tile_block_ur++) + // for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3) + // U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] += + // M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block] + // * + // broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur]) + auto inner_loops = [=]() { + int inc_fma = jcp.ver == ver_4fma ? 4 : 1; + const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2; + prefetcher_t<float> L1_pf(this, reg_srcB, L1, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma + / inc_fma, + fma_ipc); + prefetcher_t<float> L2_pf(this, reg_srcB, L2, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma, + jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma + / inc_fma, + fma_ipc); + + auto load_A = [=](int reg_idx, int offset) { + for (int i = 0; i < inc_fma; i++) { + vmovups(Zmm(reg_idx + i), + zword[reg_srcA + + sizeof(float) * jcp.dimM_simd_block * (offset + i)]); + } + }; + + Label dimM_block_loop, dimK_block_loop, dimN_block_loop; + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + { /************* OC_block (M) loop ***********/ + if (jcp.dimN_block > 1) { + mov(reg_dimN_block_loop_cnt, jcp.dimN_block); + L(dimN_block_loop); + } + { /*************** IC_block (N) loop *********/ + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) { + Zmm zmm(jcp.zmm_start + dimN_reg_block); + if (is_first_tile) + vpxord(zmm, zmm, zmm); + else + vmovups(zmm, zword[reg_dstC + + dimN_reg_block * jcp.dimM_simd_block * + sizeof(float)]); + } + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + { /************* nb_tile_ur(K) loop ********/ + int next = 0; + if (jcp.double_buffering) { + load_A(next, 0); + } + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block++) { + int srcB_offset = dimK_reg_block * jcp.dimK_4fma + * jcp.dimN_reg_block; + for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma; + dimK_4fma += inc_fma) { + int current = next; + if (jcp.double_buffering) { + next = (dimK_reg_block * jcp.dimK_4fma + + dimK_4fma + inc_fma) + % (2 * inc_fma); + load_A(next, dimK_reg_block * jcp.dimK_4fma + + dimK_4fma + inc_fma); + } else { + next = 0; + load_A(next, dimK_reg_block * jcp.dimK_4fma + + dimK_4fma); + } + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; + ++dimN_reg_block) { + L1_pf.prefetch(srcB_offset / inc_fma + + dimK_4fma / inc_fma + * jcp.dimN_reg_block + + dimN_reg_block); + L2_pf.prefetch(srcB_offset / inc_fma + + dimK_4fma / inc_fma + * jcp.dimN_reg_block + + dimN_reg_block); + if (jcp.ver == ver_4fma) { + int srcB_trans_offset = (dimK_4fma / 4) * 64 + + dimK_4fma % 4; + v4fmaddps( + Zmm(jcp.zmm_start + dimN_reg_block), + Zmm(current), + EVEX_compress_addr(reg_srcB, + sizeof(float) * ( + srcB_offset + + srcB_trans_offset + + (dimN_reg_block % 4) * 16 + + (dimN_reg_block / 4) * 4))); + } else { + vfmadd231ps( + Zmm(jcp.zmm_start + dimN_reg_block), + Zmm(current), + EVEX_compress_addr(reg_srcB, + sizeof(float) * (srcB_offset + dimN_reg_block), + true)); + } + } + } + } + } + + add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma + * jcp.dimM_simd_block * sizeof(float)); + add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block + * jcp.dimK_4fma * sizeof(float)); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + /******** Write C back to memory *******/ + for (int dimN_reg_block = 0; + dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) { + Zmm zmm(jcp.zmm_start + dimN_reg_block); + vmovups(zword[reg_dstC + + dimN_reg_block * jcp.dimM_simd_block * sizeof(float)], + zmm); + } + + sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block * + jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float)); + add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block + * sizeof(float)); + if (jcp.dimN_block > 1) { + sub(reg_dimN_block_loop_cnt, 1); + jnz(dimN_block_loop); + } + } + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimN_block * jcp.dimK_block + * jcp.dimK_reg_block * jcp.dimN_reg_block + * jcp.dimK_4fma * sizeof(float)); + add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float)); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + } + }; + + /* Preamble */ + // register used to handle long fma encoding + preamble(); + mov(reg_srcA, reg_srcA_const); + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +namespace { +bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) +{ + float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw; + lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw; + lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; + lhs *= sizeof(float); + float rhs = C * L1_cache_size; + return (lhs <= rhs); +} + +bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C) +{ + float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma + * dimM_simdw; + lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma; + lhs *= sizeof(float); + float rhs = C * L1_cache_size; + return (lhs <= rhs); +} + +bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, + float C) +{ + float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block + * dimK_4fma; + lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block + * dimN_reg_block; + lhs *= sizeof(float); + float rhs = C * L2_cache_size; + return (lhs <= rhs); +} + +bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block, + int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block, + float C) +{ + float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block; + lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma; + lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block + * dimN_reg_block; + lhs *= sizeof(float); + float rhs = C * L2_cache_size; + return (lhs <= rhs); +} +} // namespace + +status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp) +{ + /*************** Choose dimN_reg_block (ic_simd_block) + * *******************************/ + jcp.dimN = jcp.ic; + /*Hardcoded to 16 because N = ic for bwd weights and + innermost dimension for ic is assumed 16 in src transforms. This + choice covers load latencies while maintaining simplicity of kernel + for POR topologies. FIXME in future??: Will not work for future topologies + when ic%16 != 0*/ + jcp.dimN_reg_block = jcp.ic_simd_block; + + /****************************** Choose dimK_block + * **************************/ + // No freedom for choosing dimM_simd_block because ic_simd_block + // is determined by input data format + jcp.dimM_simd_block = jcp.oc_simd_block; + + auto test_cond1bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) + && (dimK_block > current_best); + }; + + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f) + && (dimK_block > current_best); + }; + + auto test_cond2bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f) + && (dimK_block > current_best); + }; + + auto test_cond2_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1, + jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block); + if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block); + + jcp.dimK_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block); + if (jcp.dimK_reg_block < jcp.dimK_block) { + jcp.dimK_reg_block = get_divisor_satisfying_cond( + jcp, jcp.dimK_block, 1, test_cond1_dimK_block); + } + jcp.dimK_block /= jcp.dimK_reg_block; + jcp.dimK_nb_block + = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block; + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; + + /***************************** Chose dimN block + * ****************************/ + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block, + jcp.dimN_reg_block, 0.5f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.ic_block = jcp.dimN_block; + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + /********************************* Choose dimM block + * ************************/ + jcp.dimM = jcp.oc; + + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1, + jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block, + 1.0f) + && (dimM_block > current_best) + && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2; + }; + + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block); + jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block; + + jcp.sched_policy = WSCHED_WEI_S_D_G_W; + return status::success; +} + +status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d) +{ + jcp.nthr = mkldnn_get_max_threads(); + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + jcp.kh = diff_weights_d.dims()[with_groups + 2]; + jcp.kw = diff_weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + if (mayiuse(avx512_core)) + return status::unimplemented; + if (!mayiuse(avx512_common)) + return status::unimplemented; + else if (mayiuse(avx512_mic_4ops)) + jcp.ver = ver_4fma; + else + jcp.ver = ver_fma; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + // Winograd kernel works only for 3x3 convolution with stride 1 + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + /*************************** New Kernel Parameters + * *****************************/ + jcp.ic_simd_block = simd_w; + jcp.oc_simd_block = simd_w; + jcp.dimK_4fma = 1; + jcp.tile_4fma_padding = 0; + +#define MAX_4FMA_UR 8 + if (jcp.ver == ver_4fma) { + auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma, + int current_best) { + return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR) + && (dimK_4fma > current_best); + }; + jcp.dimK_4fma = get_divisor_satisfying_cond( + jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma); + if (jcp.dimK_4fma == 1) + jcp.dimK_4fma = 4; + if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0) + jcp.tile_4fma_padding = jcp.dimK_4fma + - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma); + } + + jcp.tile_4fma = jcp.dimK_4fma; + /*NOTE: When (itiles * jtiles) % dimK_4fma != 0, transpose in diff_src + * transform + * will not work correctly, this is solved by applying padding.*/ + jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + jcp.dimN = jcp.ic; + jcp.dimM = jcp.oc; + + jcp.double_buffering = true; + if (jcp.double_buffering) + jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2; + else + jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1; + jcp.nb_reg = 32 - jcp.zmm_start; + + jcp.sched_policy = WSCHED_INVALID; + status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp); + assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W); + + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; + + jcp.ic_block = jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + + return res; + +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp new file mode 100644 index 0000000000..6c117143f5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp @@ -0,0 +1,179 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP +#define JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +//alpha determines the output tile_size +constexpr int alpha = 6; +constexpr int tile_size = 4; +//simd length used for vectorization +constexpr int simd_w = 16; + +struct _jit_avx512_common_conv_winograd_data_kernel_f32 : public jit_generator { + _jit_avx512_common_conv_winograd_data_kernel_f32( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + //******************* First iter kernel ********************// + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter + = (decltype(gemm_loop_ker_first_iter)) this->getCode(); + + //************** Subsequent iterations kernel **************// + if (jcp.dimK_nb_block > 1) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_data_kernel_f32) + + static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d); + + static status_t init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + +protected: + using reg64_t = const Xbyak::Reg64; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(bool is_beta_zero); + + /* registers used for GEMM */ + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r11; +}; + +struct jit_avx512_common_conv_winograd_fwd_kernel_f32 + : _jit_avx512_common_conv_winograd_data_kernel_f32 { + using _jit_avx512_common_conv_winograd_data_kernel_f32:: + _jit_avx512_common_conv_winograd_data_kernel_f32; + + static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); +}; + +struct jit_avx512_common_conv_winograd_bwd_data_kernel_f32 + : public _jit_avx512_common_conv_winograd_data_kernel_f32 { + using _jit_avx512_common_conv_winograd_data_kernel_f32:: + _jit_avx512_common_conv_winograd_data_kernel_f32; + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); +}; + +struct jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 + : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_bwd_weights_kernel_f32) + + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + + //******************* First iter kernel ********************// + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))addr; + } + + if (jcp.tile_block > 1) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + + if (jcp.ver == ver_4fma) { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->transpose_ker_generate(); + transpose_4fma_ker = (decltype(transpose_4fma_ker))addr; + } + } + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + void (*transpose_4fma_ker)(float *, float *); + +private: + using reg64_t = const Xbyak::Reg64; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(bool is_first_tile); + void transpose_ker_generate(); + + reg64_t reg_origB = abi_param2; + reg64_t reg_transB = abi_param1; + + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA_const = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_sp = rsp; + reg64_t reg_srcA = r9; + reg64_t reg_nb_ic = r10; + reg64_t reg_loop_cpt = r11; + reg64_t reg_transB_idx = r13; + + /* Registers used by new kernel */ + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r12; + reg64_t reg_dimN_block_loop_cnt = r11; +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp new file mode 100644 index 0000000000..abddc19221 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp @@ -0,0 +1,1526 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace nstl; + +using jit_conv_ker_t = void (*)(jit_conv_call_s *); + +#define PIPELINE(field) \ + do { \ + p.field = p.field ## _prf; \ + p.field ## _prf = field; \ + } while (0) + +inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + + if (p.src) + ker(&p); +} +// The special case for the driver with ow-parallelization (FWD) +// TODO: implement it for BWD_D and BWD_W too +inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding, int owb) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(owb); + + if (p.src) + ker(&p); +} + +inline void jit_conv_3d_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int kh_padding, int kd_padding) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(kd_padding); + + if (p.src) + ker(&p); +} +// The special case for the driver with ow-parallelization (FWD) +// TODO: implement it for BWD_D and BWD_W too +inline void jit_conv_3d_ker_pipeline_ow_thr(jit_conv_ker_t ker, + jit_conv_call_s &p, const void *src, const void *dst, const void *filt, + const void *bias, int channel, int kh_padding, int kd_padding, int owb) +{ + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kh_padding); + PIPELINE(kd_padding); + PIPELINE(owb); + + if (p.src) + ker(&p); +} + +void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p, + const void *src, const void *dst, const void *filt, const void *bias, + int channel, int d_index, int d_worksize, + int kd_padding /* kd_work_size */, size_t kd_offset) { + PIPELINE(src); + PIPELINE(dst); + PIPELINE(filt); + PIPELINE(bias); + PIPELINE(channel); + PIPELINE(kd_padding); + PIPELINE(d_worksize); + PIPELINE(d_index); + PIPELINE(kd_offset); + + if (p.src) + ker(&p); +} +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? (d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type> +void jit_avx512_common_convolution_fwd_t<src_type, wei_type, + dst_type>::prepare_padded_bias(const dst_data_t *&bias, + const memory_tracking::grantor_t &scratchpad) const { + if (!pd()->wants_padded_bias()) return; + + auto padded_bias = scratchpad.template get<dst_data_t>( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding); + utils::array_set(padded_bias + pd()->jcp_.oc_without_padding, + (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding); + bias = padded_bias; +} + +template <data_type_t src_type, data_type_t wei_type, + data_type_t dst_type> +void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>:: +execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.nb_ow; + + int nthr; + if (jcp.aligned_threads) + nthr = jcp.aligned_threads; + else + nthr = mkldnn_get_max_threads(); + + parallel(nthr, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) { + int dummy{0}; + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gncw) { + int dummy{0}; + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, occ, + oc_chunks, owb, jcp.nb_ow, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + auto bias_w = bias ? bias + g_oc : nullptr; + auto dst_w = dst + dst_d.blk_off(n, g_ocb, ow_s); + auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2); + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src_w, dst_w, wht_w, bias_w, icb, 1, owb); + + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + if (jcp.loop_order == loop_cwgn) { + int dummy{0}; + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gncw) { + int dummy{0}; + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, + occ, oc_chunks, owb, jcp.nb_ow, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + } + } + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template <data_type_t src_type, data_type_t wei_type, + data_type_t dst_type> +void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>:: +execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh * jcp.nb_ow; + + int nthr; + if (jcp.aligned_threads) + nthr = jcp.aligned_threads; + else + nthr = mkldnn_get_max_threads(); + + parallel(nthr, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, oh_s{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, + occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int work_rem = end - start; + + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + auto bias_w = bias ? bias + g_oc : nullptr; + + for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) { + int ih_b = -jcp.t_pad + oh_b * jcp.stride_h; + + auto dst_w = dst + dst_d.blk_off(n, g_ocb, oh_b, ow_s); + auto src_w + = src + src_d.blk_off(n, g_icb + icb_l2, ih_b, iw_s); + auto wht_w + = weights + wht_blk_off(weights_d, g, ocb, icb_l2); + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); + ++icb) { + auto src_c = src_w; + auto dst_c = dst_w; + for (int oj = oh_b, ij = ih_b; + oj < min(oh_e, oh_b + jcp.h_blocking); + ++oj, ij += jcp.stride_h) { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = div_up(max(0, -ij), dilate_h); + int i_b_overflow = div_up(max(0, ij - jcp.ih + + (jcp.kh - 1) * dilate_h + 1), dilate_h); + int kh_padding = nstl::max( + 0, jcp.kh - i_t_overflow - i_b_overflow); + + auto aux_src = src_c + + i_t_overflow * dilate_h * src_h_stride; + auto aux_wht = wht_w + i_t_overflow * wht_h_stride; + + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, + par_conv, aux_src, dst_c, aux_wht, bias_w, icb, + kh_padding, owb); + + src_c += src_h_stride * jcp.stride_h; + dst_c += dst_h_stride; + } + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + } + + if (jcp.loop_order == loop_cwgn) + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, + g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, occ, + oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template <data_type_t src_type, data_type_t wei_type, + data_type_t dst_type> +void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>:: +execute_forward_3d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + prepare_padded_bias(bias, this->scratchpad(ctx)); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + + parallel(0, [&](const int ithr, const int nthr) { + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int start{0}, end{0}, start_copy; + int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.od * jcp.oh + * jcp.nb_ow; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t src_d_stride = src_d.blk_off(0, 0, 1); + size_t src_h_stride = src_d.blk_off(0, 0, 0, 1); + size_t src_c_stride = src_d.blk_off(0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1); + size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); + size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1); + + for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) { + start = start_copy; + int n{0}, g{0}, occ{0}, oh_s{0}, od_s{0}, owb{0}; + + if (jcp.loop_order == loop_cwgn) + nd_iterator_init(start, + occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb, + od_s, jcp.od, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow, + od_s, jcp.od, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int g_ocb = g * jcp.nb_oc + ocb; + int g_oc = g_ocb * jcp.oc_block; + int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off; + + int work_rem = end - start; + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + + int id_s = -jcp.f_pad + od_s * jcp.stride_d; + + int dilate_d = jcp.dilate_d + 1; + int d_t_overflow = div_up(max(0, -id_s), dilate_d); + int d_b_overflow = div_up( + max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1), + dilate_d); + int kd_padding = nstl::max(0, + jcp.kd - d_t_overflow - d_b_overflow); + + auto bias_w = bias ? bias + bias_d.blk_off(g_oc) : 0; + auto dst_w = dst + dst_d.blk_off(n, g_ocb, od_s, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, id_s, ih_s, + iw_s) + d_t_overflow * dilate_d * src_d_stride; + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2) + + d_t_overflow * wht_d_stride; + + for (int icb = icb_l2; + icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) { + auto src_c = src_w; + auto dst_c = dst_w; + for (int oj = oh_s, ij = ih_s; + oj < oh_e; ++oj, ij += jcp.stride_h) + { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = div_up(max(0, -ij), dilate_h); + int i_b_overflow = div_up( + max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + + 1), + dilate_h); + int kh_padding = nstl::max(0, + jcp.kh - i_t_overflow - i_b_overflow); + jit_conv_3d_ker_pipeline_ow_thr(kernel_->jit_ker, + par_conv, + src_c + i_t_overflow * dilate_h * src_h_stride, + dst_c, wht_w + i_t_overflow * wht_h_stride, + bias_w, icb, kh_padding, kd_padding, owb); + + src_c += src_h_stride * jcp.stride_h; + dst_c += dst_h_stride; + } + src_w += src_c_stride; + wht_w += wht_ic_stride; + } + + if (jcp.loop_order == loop_cwgn) + nd_iterator_jump(start, end, + occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb, + od_s, jcp.od, oh_s, jcp.oh); + else if (jcp.loop_order == loop_gncw) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow, + od_s, jcp.od, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + } + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + src, dst, weights, bias, 0, 0, 0); + }); +} + +template struct jit_avx512_common_convolution_fwd_t<data_type::f32>; + +template <data_type_t diff_dst_type, data_type_t wei_type, + data_type_t diff_src_type> +void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type, + diff_src_type>::execute_backward_data_1d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}; + if (jcp.loop_order == loop_cgn) { + int dummy{0}; + nd_iterator_init(start, icc, ic_chunks, g, jcp.ngroups, n, + jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gnc) { + int dummy{0}; + nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, icc, + ic_chunks, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb); + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb); + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w, diff_dst_w, wht_w, 0, ocb, 1); + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) { + int dummy{0}; + nd_iterator_jump(start, end, icc, ic_chunks, g, jcp.ngroups, + n, jcp.mb, dummy, 1); + } else if (jcp.loop_order == loop_gnc) { + int dummy{0}; + nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, icc, + ic_chunks, dummy, 1); + } else { + assert(!"unsupported loop order"); + } + } + } + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1); + }); +} + +template <data_type_t diff_dst_type, data_type_t wei_type, + data_type_t diff_src_type> +void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type, + diff_src_type>::execute_backward_data_2d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1); + size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1; + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}, ih_s{0}; + if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + int work_rem = end - start; + int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem; + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb); + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb); + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + for (int ij = ih_s; ij < ih_e; ++ij) { + int oj, k_len, k_lo; + if (is_fast_path) { // dilate == 0 && stride == 1 + int i_t_overflow = max(0, jcp.kh - 1 - ij + - jcp.t_pad); + int i_b_overflow = max(0, jcp.kh - jcp.ih + ij + - jcp.b_pad); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow; + } else if (jcp.dilate_h != 0) { // stride == 1 + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int i_t_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + - ij - jcp.t_pad), dilate_h); + int i_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 + - jcp.ih + ij - jcp.b_pad), dilate_h); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow * dilate_h; + } else { // dilate == 0 + int i_t_overflow = max(0, (jcp.kh - 1 - ij + - jcp.t_pad) / jcp.stride_h); + int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij + - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ij) % jcp.stride_h); + int overflow_kh_lo = (ij + jcp.t_pad) + % jcp.stride_h; + + k_len = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow + - i_b_overflow; + k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h; + oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h; + } + assert(k_len >= 0); + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w + ij * diff_src_h_stride, + diff_dst_w + oj * diff_dst_h_stride, + wht_w + k_lo * wht_h_stride, + 0, ocb, k_len); + } + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1); + }); +} + +template <data_type_t diff_dst_type, data_type_t wei_type, + data_type_t diff_src_type> +void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type, + diff_src_type>::execute_backward_data_3d(const exec_ctx_t &ctx) const +{ + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + parallel(0, [&](const int ithr, const int nthr) { + int start{0}, end{0}, start_copy; + int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking; + int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.id * jcp.ih; + balance211(work_amount, nthr, ithr, start, end); + start_copy = start; + + auto par_conv = jit_conv_call_s(); + size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1); + size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1); + size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1); + size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1); + size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); + size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1); + + bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1; + bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1; + + for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) { + start = start_copy; + int n{0}, g{0}, icc{0}, ih_s{0}, id_s{0}; + if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id, + ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_init(start, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id, + ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + + while (start < end) { + int icb = icc * jcp.nb_ic_blocking; + int g_icb = g * jcp.nb_ic + icb; + int g_ocb = g * jcp.nb_oc; + + int work_rem = end - start; + int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem; + int d_len = 0, d_lo = 0, d_oj = 0; + if (is_fast_path_d) { // dilate == 0 && stride == 1 + int d_t_overflow = max(0, jcp.kd - 1 - id_s + - jcp.f_pad); + int d_b_overflow = max(0, jcp.kd - jcp.id + id_s + - jcp.back_pad); + d_len = jcp.kd - d_t_overflow - d_b_overflow; + d_lo = d_b_overflow; + d_oj = id_s + jcp.f_pad - d_b_overflow; + } else if (jcp.dilate_d != 0) { // stride == 1 + int dilate_d = jcp.dilate_d + 1; + // Note: use div_up to account for "holes" in filter + int d_t_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + - id_s - jcp.f_pad), dilate_d); + int d_b_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + 1 + - jcp.id + id_s - jcp.back_pad), dilate_d); + d_len = jcp.kd - d_t_overflow - d_b_overflow; + d_lo = d_b_overflow; + d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d; + } else { // dilate == 0 + int d_t_overflow = max(0, (jcp.kd - 1 - id_s + - jcp.f_pad) / jcp.stride_d); + int d_b_overflow = max(0, (jcp.kd - jcp.id + id_s + - jcp.back_pad) / jcp.stride_d); + int overflow_kd_hi = jcp.kd - 1 - abs((jcp.id - 1 + + jcp.back_pad - id_s) % jcp.stride_d); + int overflow_kd_lo = (id_s + jcp.f_pad) + % jcp.stride_d; + + d_len = (overflow_kd_hi - overflow_kd_lo) + / jcp.stride_d + 1 - d_t_overflow + - d_b_overflow; + d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d; + d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d; + } + assert(d_len >= 0); + + auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb) + + id_s * diff_src_d_stride; + auto diff_dst_w = diff_dst + + diff_dst_d.blk_off(n, g_ocb + ocb_l2) + + d_oj * diff_dst_d_stride; + auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb) + + d_lo * wht_d_stride; + + for (int ocb = ocb_l2; + ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) { + for (int ij = ih_s; ij < ih_e; ++ij) { + int oj, k_len, k_lo; + if (is_fast_path_h) { // dilate == 0 && stride == 1 + int i_t_overflow = max(0, jcp.kh - 1 - ij + - jcp.t_pad); + int i_b_overflow = max(0, jcp.kh - jcp.ih + ij + - jcp.b_pad); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow; + } else if (jcp.dilate_h != 0) { // stride == 1 + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int i_t_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + - ij - jcp.t_pad), dilate_h); + int i_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 + - jcp.ih + ij - jcp.b_pad), dilate_h); + k_len = jcp.kh - i_t_overflow - i_b_overflow; + k_lo = i_b_overflow; + oj = ij + jcp.t_pad - i_b_overflow * dilate_h; + } else { // dilate == 0 + int i_t_overflow = max(0, (jcp.kh - 1 - ij + - jcp.t_pad) / jcp.stride_h); + int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij + - jcp.b_pad) / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1 + + jcp.b_pad - ij) % jcp.stride_h); + int overflow_kh_lo = (ij + jcp.t_pad) + % jcp.stride_h; + + k_len = (overflow_kh_hi - overflow_kh_lo) + / jcp.stride_h + 1 - i_t_overflow + - i_b_overflow; + k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h; + oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h; + } + assert(k_len >= 0); + + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src_w + ij * diff_src_h_stride, + diff_dst_w + oj * diff_dst_h_stride, + wht_w + k_lo * wht_h_stride, + 0, ocb, k_len, d_len); + } + diff_dst_w += diff_dst_c_stride; + wht_w += wht_oc_stride; + } + + if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, + icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id, + ih_s, jcp.ih); + else if (jcp.loop_order == loop_gnc) + nd_iterator_jump(start, end, + g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id, + ih_s, jcp.ih); + else + assert(!"unsupported loop order"); + } + } + + jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv, + diff_src, diff_dst, weights, 0, 0, 1, 1); + }); +} + +template struct jit_avx512_common_convolution_bwd_data_t<data_type::f32>; + +template <data_type_t src_type, data_type_t diff_dst_type, + data_type_t diff_weights_type> +jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, + diff_weights_type>:: +jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) + , trans_kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr) +{ + const auto &j = pd()->jcp_; + + nthr_ = j.nthr; + nthr_mb_ = j.nthr_mb; + nthr_g_ = j.nthr_g; + nthr_oc_b_ = j.nthr_oc_b; + nthr_ic_b_ = j.nthr_ic_b; + + kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j); + + if (j.ver == ver_4fma) + trans_kernel_ = create_trans_src(&j); + + if (nthr_mb_ > 1) + acc_ker_ = new cpu_accumulator_1d_t<diff_weights_type>(); + + reducer_bias_ = + new cpu_reducer_t<diff_weights_type>(pd()->reducer_bia_conf_); +} + +template <data_type_t src_type, data_type_t diff_dst_type, + data_type_t diff_weights_type> +struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, + diff_weights_type>::thread_info_t { + const src_data_t *src; + const diff_dst_data_t *diff_dst; + const diff_weights_data_t *diff_weights; + diff_weights_data_t *diff_bias; + + const memory_tracking::grantor_t scratchpad; + + src_data_t *tr_src; + simple_barrier::ctx_t *tr_src_bctx; + + diff_dst_data_t *tr_diff_dst; + simple_barrier::ctx_t *tr_diff_dst_bctx; + + diff_weights_data_t *wei_bia_reduction; + simple_barrier::ctx_t *wei_bia_reduction_bctx; + + int ithr; + int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb; + int ithr_but_oc; + int ithr_but_ic; + + int img_start = 0, img_end = 0, img_work; + int g_start = 0, g_end = 0, g_work; + int oc_b_start = 0, oc_b_end = 0, oc_b_work; + int ic_b_start = 0, ic_b_end = 0, ic_b_work; + + thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self, + const exec_ctx_t &ctx, int ithr) + : scratchpad(self->scratchpad(ctx)), ithr(ithr) + { + diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + diff_weights = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + diff_bias = self->pd()->wants_padded_bias() + ? scratchpad.template get<diff_weights_data_t>( + key_conv_padded_bias) + : CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS); + + tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src); + tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>( + key_conv_tr_src_bctx); + + tr_diff_dst = scratchpad.template get<diff_dst_data_t>( + key_conv_tr_diff_dst); + tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>( + key_conv_tr_diff_dst_bctx); + + wei_bia_reduction = scratchpad.template get<diff_weights_data_t>( + key_conv_wei_bia_reduction); + wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>( + key_conv_wei_bia_reduction_bctx); + + ithr_ic_b = ithr % self->nthr_ic_b_; + ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_; + ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_; + ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_; + + ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_ + + ithr_ic_b; + + ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_ + + ithr_oc_b; + + const auto &jcp = self->kernel_->jcp; + + /* reduction dimension */ + balance211(jcp.mb*jcp.od, self->nthr_mb_, ithr_mb, img_start, img_end); + img_work = img_end - img_start; + + /* independent dimensions */ + balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end); + g_work = g_end - g_start; + + balance211(jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, + oc_b_end); + oc_b_work = oc_b_end - oc_b_start; + + balance211(jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, + ic_b_end); + ic_b_work = ic_b_end - ic_b_start; + } +}; + +template <data_type_t src_type, data_type_t diff_dst_type, + data_type_t diff_weights_type> +void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, + diff_weights_type>::compute_diff_weights(const thread_info_t *ti) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd; + + diff_weights_data_t *diff_wei = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_weights + : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; + diff_weights_data_t *diff_bia = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_bias + : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size + + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; + + // TODO: use memory descriptor with the same fmt as src (or use a macro :)) + auto tr_src_off = [&](int ithr_mb, int ic, int ij) { + const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; + const size_t tr_chn_size = tr_row_size * jcp.ih; + const size_t tr_img_size = tr_chn_size * jcp.nb_ic * jcp.ngroups; + + return ti->ithr_mb * tr_img_size + ic * tr_chn_size + ij * tr_row_size; + }; + + auto uker_trans = [&](int img) { + const int work_amount = ti->g_work * ti->ic_b_work * jcp.ih; + + int start{0}, end{0}; + balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end); + const int my_work = end - start; + + int g{0}, ic_b{0}, j{0}; + nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, j, jcp.ih); + g += ti->g_start; + ic_b += ti->ic_b_start; + + const int _ic = g * jcp.nb_ic + ic_b; + src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)]; + src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)]; + + assert(jcp.ic_block == 16); + const int src_stride = jcp.iw * jcp.ic_block; + const int tr_src_stride = jcp.tr_iw * jcp.ic_block; + + const int pf_depth = 2; + struct { src_data_t *src, *tr_src; } pf_circ_buf[pf_depth]; + + for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) { + pf_circ_buf[iwork % pf_depth] = {src1, tr_src1}; + + if (iwork >= pf_depth - 1) { + int old_idx = (iwork - pf_depth + 1) % pf_depth; + auto ctx = jit_trans_src_t::ctx_t(); + ctx.src = pf_circ_buf[old_idx].src; + ctx.tr_src = pf_circ_buf[old_idx].tr_src; + ctx.src_prf = src1; + ctx.tr_src_prf = tr_src1; + (*trans_kernel_)(&ctx); + } + src1 += src_stride; + tr_src1 += tr_src_stride; + } +#if 0 + // reference transposition + const int l_pad = jcp.l_pad; + const int iwlp = l_pad + jcp.iw; + const int tr_iw = jcp.tr_iw; + + for (size_t iwork = start; iwork < end; iwork++) { + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = 0; i < l_pad; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0; + + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = l_pad; i < iwlp; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] + = (src_data_t)src1[(i - l_pad) * 16 + j]; + + PRAGMA_OMP_SIMD() +# pragma unroll + for (int i = iwlp; i < tr_iw; i++) + for (int j = 0; j < jcp.ic_block; j++) + tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0; + + src1 += src_stride; + tr_src1 += tr_src_stride; + } +#endif + }; + + if (jcp.is_1stconv && jcp.ver == ver_4fma) { + /* prepare contexts */ + auto tr_ctx = jit_trans_src_t::ctx_t(); + tr_ctx.tr_src = ti->tr_src + + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld; + + assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1)); + tr_ctx.nthr_oc_b = nthr_oc_b_; + int ih_start{0}, ih_end{0}; + balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end); + tr_ctx.tr_src_ih_start = ih_start; + tr_ctx.tr_src_ih_end = ih_end; + tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc; + + auto p = jit_conv_call_s(); + p.src = tr_ctx.tr_src; + + /* zero diff_bias if applicable */ + if (jcp.with_bias && ti->ithr_ic_b == 0) { + assert(jcp.oc_block == 16); + for (int oc_b = ti->ic_b_start; oc_b < ti->oc_b_end; ++oc_b) { + diff_weights_data_t *db = &diff_bia[oc_b * 16]; + for (int o = 0; o < 16; ++o) + db[o] = 0; + } + } + + for (int img = ti->img_start; img < ti->img_end; ++img) { + p.flags = (img == ti->img_start) * FLAG_MB_FIRST; + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _ic = g * jcp.nb_ic + ic_b; + tr_ctx.src = &ti->src[src_d.blk_off(img, _ic)]; + + (*trans_kernel_)(&tr_ctx); + + if (ic_b == 0) + p.flags |= FLAG_IC_FIRST; + else + p.flags &= ~FLAG_IC_FIRST; + + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + const int _oc = g * jcp.nb_oc + oc_b; + p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)]; + + const size_t off = + wht_blk_off(diff_weights_d, g, oc_b, ic_b); + p.filt = diff_wei + off; + p.bias = diff_bia + _oc * jcp.oc_block; + + kernel_->jit_ker(&p); + } + } + } + } + } else { + for (int img = ti->img_start; img < ti->img_end; ++img) { + auto p = jit_conv_call_s(); + + if (jcp.ver == ver_4fma) { + /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */ + using simple_barrier::barrier; + if (nthr_oc_b_ > 1) + barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); + uker_trans(img); + if (nthr_oc_b_ > 1) + barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); + } + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _oc = g * jcp.nb_oc + oc_b; + const int _ic = g * jcp.nb_ic + ic_b; + + jit_conv_ker_pipeline(kernel_->jit_ker, p, + jcp.ver == ver_4fma + ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)] + : &ti->src[src_d.blk_off(img, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img, _oc)], + diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), + 0, (img == ti->img_start), 0); + + } + } + } + + const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start; + const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start; + jit_conv_ker_pipeline(kernel_->jit_ker, p, + jcp.ver == ver_4fma + ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)] + : &ti->src[src_d.blk_off(img + 1, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)], + diff_wei + wht_blk_off( + diff_weights_d, ti->g_start, + ti->oc_b_start, ti->ic_b_start), + 0, 0, 0); + } + } +} + +template <data_type_t src_type, data_type_t diff_dst_type, + data_type_t diff_weights_type> +void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, + diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) const +{ + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size + = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd; + + diff_weights_data_t *diff_wei = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_weights + : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; + diff_weights_data_t *diff_bia = ti->ithr_mb == 0 + ? (diff_weights_data_t*)ti->diff_bias + : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size + + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc; + + const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block; + const int input_step = jcp.ih * jcp.iw * inp_mult; + const int output_step = jcp.ow * jcp.oh * jcp.oc_block; + int img{0}, od_s{0}; + int img_start = ti->img_start, img_end = ti->img_end; + nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); + const int img_first = img; + + while (img_start < img_end) { + auto p = jit_conv_call_s(); + + int work_rem = img_end - img_start; + const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem; + const int id_s = od_s * jcp.stride_d; + const int ik_overlap = nstl::max(0, id_s - jcp.f_pad); + const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s); + const int kd_back_pad + = nstl::max(0, id_s - jcp.f_pad - jcp.id + jcp.kd); + int kd_pad_off = nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw + * jcp.ic_block * jcp.oc_block * jcp.typesize_out; + + for (int g = ti->g_start; g < ti->g_end; ++g) { + for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) { + for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) { + const int _oc = g * jcp.nb_oc + oc_b; + const int _ic = g * jcp.nb_ic + ic_b; + + auto src = &ti->src[src_d.blk_off(img, _ic) + + ik_overlap * input_step]; + auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc) + + od_s * output_step]; + + jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, src, dst, + diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), + diff_bia + _oc * 16, (img == img_first), od_s, od_e, + jcp.kd - kd_front_pad - kd_back_pad, kd_pad_off); + + if (ic_b == 0) p.flags = 0; + else p.flags = 1; + } + } + } + + const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start; + const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start; + jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, + &ti->src[src_d.blk_off(img + 1, _ic)], + &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)], + diff_wei + wht_blk_off(diff_weights_d, ti->g_start, + ti->oc_b_start, ti->ic_b_start), + diff_bia, 0, 0, 0, 0, 0); + nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); + } +} + +template <data_type_t src_type, data_type_t diff_dst_type, + data_type_t diff_weights_type> +void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, + diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) const { + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + const int bia_size = jcp.ngroups * jcp.oc; + const diff_weights_data_t *diff_bias_ws + = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size; + + /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ + simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_); + + const int ic_b_kh_work = ti->ic_b_work * jcp.kh; + const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; + + int start{0}, end{0}; + balance211(work, nthr_mb_, ti->ithr_mb, start, end); + if (start == end) return; + + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + int w = start; + int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0}; + nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + while (w < end) { + const int g = ti->g_start + sub_g_start; + const int oc_b = ti->oc_b_start + sub_oc_b_start; + const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh; + const int kh = sub_ic_b_kh_start % jcp.kh; + + const int acc_size + = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start) + * jcp.kw * jcp.ic_block * jcp.oc_block; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh); + + diff_weights_data_t *d + = (diff_weights_data_t *)ti->diff_weights + off; + diff_weights_data_t *s + = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off; + + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + } + + if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) { + if (ti->ithr == 0) + acc_ker_->accumulate((diff_weights_data_t *)ti->diff_bias, + diff_bias_ws, bia_size); + diff_bias_ws += bia_size; + } + } +} + +template <data_type_t src_type, data_type_t diff_dst_type, + data_type_t diff_weights_type> +void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, + diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti) const { + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw + * jcp.kd; + + /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ + simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_); + + const int ic_b_kh_work = ti->ic_b_work * jcp.kd; + const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; + + int start{0}, end{0}; + balance211(work, nthr_mb_, ti->ithr_mb, start, end); + if (start == end) return; + + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + int w = start; + int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0}; + nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + while (w < end) { + const int g = ti->g_start + sub_g_start; + const int oc_b = ti->oc_b_start + sub_oc_b_start; + const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd; + const int kd = sub_ic_b_kh_start % jcp.kd; + + const int acc_size + = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start) + * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh; + + const size_t off + = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd); + diff_weights_data_t *d + = (diff_weights_data_t *)ti->diff_weights + off; + diff_weights_data_t *s + = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off; + acc_ker_->accumulate(d, s, acc_size); + + nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, + ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); + } + } +} + +template <data_type_t src_type, data_type_t diff_dst_type, + data_type_t diff_weights_type> +void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, + diff_weights_type>::compute_diff_bias(const thread_info_t *ti) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + auto rb = this->reducer_bias_; + assert(nthr_ == rb->balancer().nthr_); + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t( + ti->scratchpad, prefix_reducer_bia); + + const auto &jcp = kernel_->jcp; + + if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return; + + const int b_job_start = rb->balancer().ithr_job_off(ti->ithr); + const int b_njobs = rb->balancer().ithr_njobs(ti->ithr); + + if (b_njobs == 0) return; + + /* reduction dimension */ + int img_start{0}, img_end{0}; + balance211(jcp.mb, rb->balancer().nthr_per_group_, + rb->balancer().id_in_group(ti->ithr), img_start, img_end); + + /* jobs */ + int g_start{0}, ocb_start{0}; + nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc); + for (int img = img_start; img < img_end; ++img) { + int g = g_start, ocb = ocb_start; + for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { + const size_t _oc = g * jcp.nb_oc + ocb; + + const diff_dst_data_t *d_dst + = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)]; + diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr, + ti->diff_bias, reducer_bia_scratchpad) + + b_job_loc * rb->balancer().job_size_; + + if (img == img_start) + for (int o = 0; o < 16; ++o) + d_bias[o] = 0; + for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) { + PRAGMA_OMP_SIMD() + for (int o = 0; o < 16; ++o) + d_bias[o] += d_dst[o]; + d_dst += 16; + } + + nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); + } + } + + rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad); +} + +template <data_type_t src_type, data_type_t diff_dst_type, + data_type_t diff_weights_type> +void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, + diff_weights_type>::compute_diff_bias_3d(const thread_info_t *ti) const { + + const auto &jcp = kernel_->jcp; + + const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh + * jcp.kw * jcp.kd; + const int bia_size = jcp.ngroups * jcp.oc; + const diff_weights_data_t *diff_bias_ws + = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size; + + if (nthr_mb_ > 1) mkldnn_thr_barrier(); + + if (ti->ithr == 0) + { + for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { + acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size); + diff_bias_ws += bia_size; + } + } +} + +template <data_type_t src_type, data_type_t diff_dst_type, + data_type_t diff_weights_type> +void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, + diff_weights_type>::prepare_scratchpad_data(const exec_ctx_t &ctx) const +{ + const auto &j = pd()->jcp_; + auto scratchpad = this->scratchpad(ctx); + + if (j.ver == ver_4fma) { + if (!j.is_1stconv) { + // XXX: See the comment about tr_iw and guarding elements in + // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf() + const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic; + const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw; + + auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src); + /* to avoid NaNs in computations we zero tail num_guard_elems for + * each possible thread group */ + + for (int ithr = 1; ithr <= max_nthr; ++ithr) { + src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr]; + for (int i = 0; i < j.tr_src_num_guard_elems; ++i) + ts[i] = 0; + } + } + + if (j.nthr_oc_b > 1) { + const int tr_src_bctx_size = j.nthr / j.nthr_oc_b; + auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>( + key_conv_tr_src_bctx); + for (int i = 0; i < tr_src_bctx_size; ++i) + simple_barrier::ctx_init(&tr_src_bctx[i]); + } + } + + if (nthr_mb_ > 1) { + simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>( + key_conv_wei_bia_reduction_bctx)); + } + + const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad, + prefix_reducer_bia); + auto rb = this->reducer_bias_; + rb->init(reducer_bia_scratchpad); +} + +template <data_type_t src_type, data_type_t diff_dst_type, + data_type_t diff_weights_type> +void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, + diff_weights_type>::execute_backward_weights(const exec_ctx_t &ctx) const { + prepare_scratchpad_data(ctx); + + parallel(nthr_, [&](const int ithr, const int nthr) { + assert(nthr_ == nthr); + + thread_info_t thread_info(this, ctx, ithr); + + if (utils::one_of(pd()->ndims(), 3, 4)) { + compute_diff_weights(&thread_info); + if (nthr_mb_ > 1) reduce_diff_weights(&thread_info); + if (pd()->with_bias()) compute_diff_bias(&thread_info); + } else if (pd()->ndims() == 5) { + compute_diff_weights_3d(&thread_info); + if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info); + if (pd()->with_bias()) compute_diff_bias_3d(&thread_info); + } else { + assert(false); + } + }); + + /* TODO: put that into compute_diff_bias() */ + if (pd()->wants_padded_bias()) { + auto diff_bias = scratchpad(ctx).template get<const diff_weights_data_t>( + key_conv_padded_bias); + auto diff_bias_in = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS); + for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc) + diff_bias_in[oc] = diff_bias[oc]; + } +} + +template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp new file mode 100644 index 0000000000..3341c3ebe0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp @@ -0,0 +1,302 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP +#define CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_transpose_src_utils.hpp" +#include "jit_avx512_common_conv_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t src_type, + impl::data_type_t wei_type = src_type, + impl::data_type_t dst_type = src_type> +struct jit_avx512_common_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, dst_type, dst_type, + data_type::undef) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_fwd_kernel::init_conf( + jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, + *attr(), mkldnn_get_max_threads()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_fwd_kernel::init_scratchpad(scratchpad, + jcp_); + + return status; + } + + jit_conv_conf_t jcp_; + }; + + jit_avx512_common_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_common_conv_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + ~jit_avx512_common_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<wei_type>::type wei_data_t; + typedef typename prec_traits<dst_type>::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->ndims() == 3) + execute_forward_1d(ctx); + else if (pd()->ndims() == 4) + execute_forward_2d(ctx); + else if (pd()->ndims() == 5) + execute_forward_3d(ctx); + else + assert(false); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); + + return status::success; + } + +private: + void prepare_padded_bias(const dst_data_t *&bias, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + void execute_forward_3d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_conv_fwd_kernel *kernel_; +}; + +template <impl::data_type_t diff_dst_type, + impl::data_type_t wei_type = diff_dst_type, + impl::data_type_t diff_src_type = diff_dst_type> +struct jit_avx512_common_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, + data_type::undef, diff_dst_type, data_type::undef) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(jcp_, + *desc(), *diff_src_md(), *weights_md(), *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), + OIw16o16i, gOIw16o16i, OIhw16o16i, gOIhw16o16i, + OIdhw16o16i, gOIdhw16o16i); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_avx512_common_convolution_bwd_data_t(const pd_t *apd) + : cpu_primitive_t(apd) + { kernel_ = new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_); } + ~jit_avx512_common_convolution_bwd_data_t() { delete kernel_; }; + + typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t; + typedef typename prec_traits<wei_type>::type wei_data_t; + typedef typename prec_traits<diff_src_type>::type diff_src_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->ndims() == 3) + execute_backward_data_1d(ctx); + else if (pd()->ndims() == 4) + execute_backward_data_2d(ctx); + else if (pd()->ndims() == 5) + execute_backward_data_3d(ctx); + else + assert(false); + return status::success; + } + +private: + void execute_backward_data_1d(const exec_ctx_t &ctx) const; + void execute_backward_data_2d(const exec_ctx_t &ctx) const; + void execute_backward_data_3d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_common_conv_bwd_data_kernel_f32 *kernel_; +}; + +template <impl::data_type_t src_type, + impl::data_type_t diff_dst_type = src_type, + impl::data_type_t diff_weights_type = src_type> +struct jit_avx512_common_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, diff_weights_type, + diff_weights_type, diff_dst_type, data_type::undef) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), src_md_, diff_weights_md_, + diff_bias_md_, diff_dst_md_); + if (status != status::success) return status; + + init_balancers(); + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad( + scratchpad, jcp_); + + auto reducer_bia_scratchpad = memory_tracking::registrar_t( + scratchpad, memory_tracking::names::prefix_reducer_bia); + reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); + + return status; + } + + jit_conv_conf_t jcp_; + typename cpu_reducer_t<diff_weights_type>::conf_t reducer_bia_conf_; + + private: + void init_balancers() { + const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; + if (with_bias()) { + reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, + jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb, + max_buffer_size)); + } + } + }; + + jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd); + ~jit_avx512_common_convolution_bwd_weights_t() { + delete kernel_; + if (trans_kernel_) + delete trans_kernel_; + if (acc_ker_) + delete acc_ker_; + delete reducer_bias_; + } + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t; + typedef typename prec_traits<diff_weights_type>::type diff_weights_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + void prepare_scratchpad_data(const exec_ctx_t &ctx) const; + struct thread_info_t; + void compute_diff_weights(const thread_info_t *) const; + void compute_diff_weights_3d(const thread_info_t *) const; + void reduce_diff_weights(const thread_info_t *) const; + void reduce_diff_weights_3d(const thread_info_t *) const; + void compute_diff_bias(const thread_info_t *) const; + void compute_diff_bias_3d(const thread_info_t *) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_; + + jit_avx512_common_conv_bwd_weights_kernel_f32 *kernel_; + jit_trans_src_t *trans_kernel_; + cpu_accumulator_1d_t<diff_weights_type> *acc_ker_; + cpu_reducer_t<diff_weights_type> *reducer_bias_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp new file mode 100644 index 0000000000..62247c0264 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp @@ -0,0 +1,1215 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifdef __INTEL_COMPILER +#include <immintrin.h> +#endif + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_convolution_winograd.hpp" + +#ifndef _MSC_VER +#define pragma_unroll _Pragma("unroll") +#else +#define pragma_unroll +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +namespace { + +unsigned int LLC_cache_size = get_cache_size(3, false); + +void inline load_ps(float *dest, const float *src_mem) { +#ifdef __INTEL_COMPILER + __m512 *Iv512 = (__m512 *)dest; + Iv512[0] = _mm512_load_ps(src_mem); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) dest[v] = src_mem[v]; +#endif +} + +void inline store_output(float *dest, const float *data, bool streamout) { +#ifdef __INTEL_COMPILER + if (streamout) + _mm512_stream_ps(dest, *((__m512 *)data)); + else + _mm512_store_ps(dest, *((__m512 *)data)); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + dest[v] = data[v]; +#endif +} + +void inline accum_output( + float *dest, float *data, bool streamout, bool with_relu_postsum) { +#ifdef __INTEL_COMPILER + __m512 _data = _mm512_loadu_ps(data); + __m512 _dest = _mm512_loadu_ps(dest); + _data = _mm512_add_ps(_data, _dest); + if (with_relu_postsum) + _data = _mm512_max_ps(_data, _mm512_setzero_ps()); + if (streamout) + _mm512_stream_ps(dest, _data); + else + _mm512_store_ps(dest, _data); +#else + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + data[v] += dest[v]; + + if (with_relu_postsum) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + if (data[v] < 0.f) + data[v] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + dest[v] = data[v]; +#endif +} +} + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]) { + float Fw[6][16]; + float T[6][3][16]; + float t0[16]; + float t1[16]; + float t2[16]; + + for (int j = 0; j < 16; j++) { +#pragma unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < 16; k++) { + t0[k] = 0.26890756302521f * F[2][i][j][k]; + t1[k] = -t0[k] - 0.688403361344538f * F[0][i][j][k]; + t2[k] = t0[k] + 0.119514472455649f * F[0][i][j][k]; + + T[0][i][k] = 1.13777777777778f * F[0][i][j][k]; + T[1][i][k] = t1[k] - 0.430252100840336f * F[1][i][j][k]; + T[2][i][k] = t1[k] + 0.430252100840336f * F[1][i][j][k]; + T[3][i][k] = t2[k] + 0.179271708683473f * F[1][i][j][k]; + T[4][i][k] = t2[k] - 0.179271708683473f * F[1][i][j][k]; + T[5][i][k] = F[2][i][j][k]; + } + } +#pragma unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < 16; k++) { + t0[k] = 0.26890756302521f * T[i][2][k]; + t1[k] = -t0[k] - 0.688403361344538f * T[i][0][k]; + t2[k] = t0[k] + 0.119514472455649f * T[i][0][k]; + + Fw[0][k] = 1.13777777777778f * T[i][0][k]; + Fw[1][k] = t1[k] - 0.430252100840336f * T[i][1][k]; + Fw[2][k] = t1[k] + 0.430252100840336f * T[i][1][k]; + Fw[3][k] = t2[k] + 0.179271708683473f * T[i][1][k]; + Fw[4][k] = t2[k] - 0.179271708683473f * T[i][1][k]; + Fw[5][k] = T[i][2][k]; +#pragma unroll + for (int l = 0; l < 6; l++) { + Fw_[i][l][j][k] = Fw[l][k]; + } + } + } + } +} + +void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]) { + float T[4][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + +#pragma unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = Mw[1][i][v] + Mw[2][i][v]; + t1[v] = Mw[3][i][v] + Mw[4][i][v]; + t2[v] = Mw[1][i][v] - Mw[2][i][v]; + t3[v] = Mw[3][i][v] - Mw[4][i][v]; + + T[0][i][v] = t0[v] + t1[v] + Mw[0][i][v]; + T[1][i][v] = t2[v] * 0.625f + t3[v] * 1.5f; + T[2][i][v] = t0[v] * 0.390625f + t1[v] * 2.25f; + T[3][i][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + Mw[5][i][v]; + } + } +#pragma unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][1][v] + T[i][2][v]; + t1[v] = T[i][3][v] + T[i][4][v]; + t2[v] = T[i][1][v] - T[i][2][v]; + t3[v] = T[i][3][v] - T[i][4][v]; + + O[i][0][v] = t0[v] + t1[v] + T[i][0][v]; + O[i][1][v] = t2[v] * 0.625f + t3[v] * 1.5f; + O[i][2][v] = t0[v] * 0.390625f + t1[v] * 2.25f; + O[i][3][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + T[i][5][v]; + } + } +} + + +void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]) +{ + const float rcp3 = 1.0f / 3.0f; + const float rcp4 = 1.0f / 4.0f; + const float rcp6 = 1.0f / 6.0f; + const float rcp12 = 1.0f / 12.0f; + const float rcp24 = 1.0f / 24.0f; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + float T[6][4][16]; + +pragma_unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < 16; j++) { + t0[j] = F[2][i][j] * rcp6; + t1[j] = F[0][i][j] * -rcp6 - t0[j]; + t2[j] = F[0][i][j] * rcp24 + t0[j]; + t3[j] = (F[1][i][j] + F[3][i][j]) * rcp6; + t4[j] = F[1][i][j] * rcp12 + F[3][i][j] * rcp3; + + T[0][i][j] = F[0][i][j] * rcp4; + T[1][i][j] = t1[j] - t3[j]; + T[2][i][j] = t1[j] + t3[j]; + T[3][i][j] = t2[j] + t4[j]; + T[4][i][j] = t2[j] - t4[j]; + T[5][i][j] = F[3][i][j]; + } + } +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < 16; j++) { + t0[j] = T[i][2][j] * rcp6; + t1[j] = T[i][0][j] * -rcp6 - t0[j]; + t2[j] = T[i][0][j] * rcp24 + t0[j]; + t3[j] = (T[i][1][j] + T[i][3][j]) * rcp6; + t4[j] = T[i][1][j] * rcp12 + T[i][3][j] * rcp3; + + Fw[i][0][j] = T[i][0][j] * rcp4; + Fw[i][1][j] = t1[j] - t3[j]; + Fw[i][2][j] = t1[j] + t3[j]; + Fw[i][3][j] = t2[j] + t4[j]; + Fw[i][4][j] = t2[j] - t4[j]; + Fw[i][5][j] = T[i][3][j]; + } + } +} + +void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]) +{ + float T[4][6][16]; + float M_[3][16]; + float t0[16]; + float t1[16]; + float t2[16]; + + for (int j = 0; j < 16; j++) { +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int l = 0; l < 16; l++) { + t0[l] = Mw[1][i][j][l] + Mw[2][i][j][l]; + t1[l] = Mw[3][i][j][l] + Mw[4][i][j][l]; + t2[l] = t1[l] * 4.0f + Mw[5][i][j][l]; + + T[0][i][l] = Mw[0][i][j][l] + t0[l] + t1[l]; + T[1][i][l] = (Mw[1][i][j][l] - Mw[2][i][j][l]) + + 2.0f * (Mw[3][i][j][l] - Mw[4][i][j][l]); + T[2][i][l] = t0[l] + t2[l]; + } + } +pragma_unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int l = 0; l < 16; l++) { + t0[l] = T[i][1][l] + T[i][2][l]; + t1[l] = T[i][3][l] + T[i][4][l]; + t2[l] = t1[l] * 4.0f + T[i][5][l]; + + M_[0][l] = T[i][0][l] + t0[l] + t1[l]; + M_[1][l] = (T[i][1][l] - T[i][2][l]) + + 2.0f * (T[i][3][l] - T[i][4][l]); + M_[2][l] = t0[l] + t2[l]; + + for (int k = 0; k < 3; k++) { + M[i][k][j][l] = M_[k][l]; + } + } + } + } +} + +void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]) +{ + float T[6][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + float t5[16]; + +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = I[2][i][v] * -2.25f + I[4][i][v]; + t1[v] = I[1][i][v] * -2.25f + I[3][i][v]; + t2[v] = I[2][i][v] * -0.390625f + I[4][i][v]; + t3[v] = I[1][i][v] * -0.390625f + I[3][i][v]; + t4[v] = I[0][i][v] * 0.87890625f + I[4][i][v]; + t5[v] = I[1][i][v] * 0.87890625f + I[5][i][v]; + + T[0][i][v] = I[2][i][v] * -2.640625f + t4[v]; + T[1][i][v] = t1[v] * 0.625f + t0[v]; + T[2][i][v] = t1[v] * -0.625f + t0[v]; + T[3][i][v] = t3[v] * 1.5f + t2[v]; + T[4][i][v] = t3[v] * -1.5f + t2[v]; + T[5][i][v] = I[3][i][v] * -2.640625f + t5[v]; + } + } + +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][2][v] * -2.25f + T[i][4][v]; + t1[v] = T[i][1][v] * -2.25f + T[i][3][v]; + t2[v] = T[i][2][v] * -0.390625f + T[i][4][v]; + t3[v] = T[i][1][v] * -0.390625f + T[i][3][v]; + t4[v] = T[i][0][v] * 0.87890625f + T[i][4][v]; + t5[v] = T[i][1][v] * 0.87890625f + T[i][5][v]; + + Iw[i][0][v] = T[i][2][v] * -2.640625f + t4[v]; + Iw[i][1][v] = t1[v] * 0.625f + t0[v]; + Iw[i][2][v] = t1[v] * -0.625f + t0[v]; + Iw[i][3][v] = t3[v] * 1.5f + t2[v]; + Iw[i][4][v] = t3[v] * -1.5f + t2[v]; + Iw[i][5][v] = T[i][3][v] * -2.640625f + t5[v]; + } + } +} + +void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]) +{ + float T[6][4][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float t3[16]; + float t4[16]; + +pragma_unroll + for (int i = 0; i < 4; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = F[2][i][v] * 0.26890756302521f; + t1[v] = F[0][i][v] * -0.688403361344538f - t0[v]; + t2[v] = F[0][i][v] * 0.119514472455649f + t0[v]; + t3[v] = F[1][i][v] * 0.430252100840336f + + F[3][i][v] * 0.168067226890756f; + t4[v] = F[1][i][v] * 0.179271708683473f + + F[3][i][v] * 0.403361344537815f; + + T[0][i][v] = F[0][i][v] * 1.13777777777778f; + T[1][i][v] = t1[v] - t3[v]; + T[2][i][v] = t1[v] + t3[v]; + T[3][i][v] = t2[v] + t4[v]; + T[4][i][v] = t2[v] - t4[v]; + T[5][i][v] = F[3][i][v]; + } + } +pragma_unroll + for (int i = 0; i < 6; i++) { + for (int v = 0; v < 16; v++) { + t0[v] = T[i][2][v] * 0.26890756302521f; + t1[v] = T[i][0][v] * -0.688403361344538f - t0[v]; + t2[v] = T[i][0][v] * 0.119514472455649f + t0[v]; + t3[v] = T[i][1][v] * 0.430252100840336f + + T[i][3][v] * 0.168067226890756f; + t4[v] = T[i][1][v] * 0.179271708683473f + + T[i][3][v] * 0.403361344537815f; + + Fw[i][0][v] = T[i][0][v] * 1.13777777777778f; + Fw[i][1][v] = t1[v] - t3[v]; + Fw[i][2][v] = t1[v] + t3[v]; + Fw[i][3][v] = t2[v] + t4[v]; + Fw[i][4][v] = t2[v] - t4[v]; + Fw[i][5][v] = T[i][3][v]; + } + } +} + +void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]) +{ + float T[3][6][16]; + float t0[16]; + float t1[16]; + float t2[16]; + float M_[3][16]; + + for (int j = 0; j < 16; j++) { +pragma_unroll + for (int i = 0; i < 6; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = Mw[1][i][j][v] + Mw[2][i][j][v]; + t1[v] = Mw[3][i][j][v] + Mw[4][i][j][v]; + t2[v] = t1[v] * 2.25f + Mw[5][i][j][v]; + + T[0][i][v] = Mw[0][i][j][v] + t0[v] + t1[v]; + T[1][i][v] = 0.625f * (Mw[1][i][j][v] - Mw[2][i][j][v]) + + 1.5f * (Mw[3][i][j][v] - Mw[4][i][j][v]); + T[2][i][v] = t0[v] * 0.390625f + t2[v]; + } + } +pragma_unroll + for (int i = 0; i < 3; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + t0[v] = T[i][1][v] + T[i][2][v]; + t1[v] = T[i][3][v] + T[i][4][v]; + t2[v] = t1[v] * 2.25f + T[i][5][v]; + + M_[0][v] = T[i][0][v] + t0[v] + t1[v]; + M_[1][v] = 0.625f * (T[i][1][v] - T[i][2][v]) + + 1.5f * (T[i][3][v] - T[i][4][v]); + M_[2][v] = t0[v] * 0.390625f + t2[v]; + } + +pragma_unroll + for (int k = 0; k < 3; k++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < 16; v++) { + M[i][k][j][v] = M_[k][v]; + } + } + } + } +} + +template <bool is_fwd> +void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp, bool streamout = true) +{ + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; + const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; + const int wp_max = inpw + l_pad; + const int hp_max = inph + t_pad; + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + + array_offset_calculator<float, 5> input(inp, + jcp.mb, jcp.dimK/simd_w, inph, inpw, + simd_w); + array_offset_calculator<float, 8> output(tinp, + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if ((t_pad <= ydim) && (ydim < hp_max)) { + float *pinp_j = inp + (ydim - t_pad) * inpw * 16 ; + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if ((l_pad <= xdim) && (xdim < wp_max)) { + float *pinp_i = pinp_j + (xdim - l_pad) * 16; + load_ps(I[j][i], pinp_i); + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + + trans_I_4x4_3x3(Iw, I); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(tile_block, j, i, + nb_tile_block_ur, 0, 0, + tile_block_ur, 0)), + Iw[j][i], streamout); + } + } + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template <bool is_fwd> +void weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) +{ + const int kh = 3; + const int kw = 3; + array_offset_calculator<float, 6> input(wp, + jcp.oc/jcp.oc_simd_block, + jcp.ic/jcp.ic_simd_block, + jcp.kh, jcp.kw, + simd_w, simd_w); + array_offset_calculator<float, 8> output(twp, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block, jcp.dimK_block, + simd_w, simd_w); + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + for (int v1 = 0; v1 < simd_w; v1++) { + float *base_inp = is_fwd + ? &(input(0, 0, j, i, v1, 0)) + : &(input(0, 0, 2 - j, 2 - i, v1, 0)); + PRAGMA_OMP_SIMD() + for (int v2 = 0; v2 < simd_w; v2++) { + if (is_fwd) + F[j][i][v1][v2] = *(base_inp + v2); + else + F[j][i][v2][v1] = *(base_inp + v2); + } + } + } + } + + trans_W_4x4_3x3(Fw, F); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int v1 = 0; v1 < simd_w; v1++) { + PRAGMA_OMP_SIMD() + for (int v2 = 0; v2 < simd_w; v2++) { + output(0, j, i, 0, 0, 0, v1, v2) = Fw[j][i][v1][v2]; + } + } + } + } +} + +template <bool is_fwd, bool with_bias, bool with_relu_presum, bool with_sum> +void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias, + bool streamout = true) { + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? jcp.oh : jcp.ih; + + /* Prepare for PostOps */ + bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1; + + array_offset_calculator<float, 8> input(toutp, + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Ow[j][i][v] = input(tile_block, 0, + j, i, + nb_tile_block_ur, 0, + tile_block_ur, v); + } + } + } + + trans_O_4x4_3x3(Ow, O); + + for (int j = 0; j < tile_size; j++) { + int ydim = tj * tile_size + j; + if (ydim < outh) { + float *pout_j = pout_b + ydim * outw * simd_w; + for (int i = 0; i < tile_size; i++) { + int xdim = ti * tile_size + i; + if (xdim < outw) { + float *pout_i = pout_j + xdim * simd_w; + if (is_fwd) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + O[j][i][v] += with_bias ? bias[v] : 0.f; + O[j][i][v] = true + && with_relu_presum && O[j][i][v] < 0.f + ? O[j][i][v] + * jcp.eltwise.alpha + : O[j][i][v]; + } + } + if (with_sum) + accum_output(pout_i, O[j][i], streamout, + with_relu_postsum); + else + store_output(pout_i, O[j][i], streamout); + } + } + } + } + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template <bool ver_4fma> +void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv, + float *inp, float *tinp, float *Iw_temp, + void (*transpose_4fma_ker)(float *, float *)) +{ + + const int ifwp = conv.iw + conv.l_pad; + const int ifhp = conv.ih + conv.t_pad; + float I[alpha][alpha][simd_w]; + float Iw[alpha][alpha][simd_w]; + + array_offset_calculator<float, 4> Iw_trans_temp(Iw_temp, + alpha, alpha, conv.tile_4fma, simd_w); + array_offset_calculator<float, 5> input(inp, + conv.mb, conv.ic/simd_w, conv.ih, conv.iw, simd_w); + array_offset_calculator<float, 8> output(tinp, + conv.nb_ic, alpha, alpha, + conv.tile_block, conv.ic_block, + conv.nb_tile_block_ur, conv.tile_block_ur, + conv.ic_simd_block * conv.tile_4fma); + + int tile_base_index = + image * (conv.itiles * conv.jtiles + conv.tile_4fma_padding); + int tile_4fma = 0; + int tile_block_ur = (tile_base_index / conv.tile_4fma) % conv.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / conv.tile_4fma / conv.tile_block_ur) + % conv.nb_tile_block_ur; + int tile_block = (tile_base_index / conv.tile_4fma / conv.tile_block_ur) + / conv.nb_tile_block_ur; + + for (int tj = 0; tj < conv.jtiles; tj++) { + for (int ti = 0; ti < conv.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if ((conv.t_pad <= ydim) && ydim < ifhp) { + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if ((conv.l_pad <= xdim) && xdim < ifwp) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = input(0, 0, + ydim - conv.t_pad, + xdim - conv.l_pad, v); + } + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + trans_I_4x4_3x3(Iw, I); + + if (ver_4fma) { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + float *Iw_temp_base = &(Iw_trans_temp(j, i, + tile_4fma, 0)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Iw_temp_base[v] = Iw[j][i][v]; + } + } + } + tile_4fma++; + if (tile_4fma == conv.tile_4fma) { + float *outp = &(output(0, 0, 0, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)); + transpose_4fma_ker(outp, (float *)Iw_temp); + tile_4fma = 0; + tile_block_ur++; + } + } else { + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(0, j, i, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)), + Iw[j][i], true); + } + } + tile_block_ur++; + } + + if (tile_block_ur == conv.tile_block_ur) { + tile_block_ur = 0; + ++nb_tile_block_ur; + } + if (nb_tile_block_ur == conv.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } + + if (ver_4fma && tile_4fma < conv.tile_4fma && conv.tile_4fma_padding != 0) { + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int tb = tile_4fma; tb < conv.tile_4fma; tb++) { + float *Iw_temp_base = &(Iw_trans_temp(j, i, tb, 0)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + Iw_temp_base[v] = 0; + } + } + } + } + float *outp = &(output(0, 0, 0, + tile_block, 0, + nb_tile_block_ur, tile_block_ur, 0)); + transpose_4fma_ker(outp, (float *)Iw_temp); + } +} + +template <bool with_bias> +void diff_dst_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv, + float *inp, float *tinp, float *dbias) +{ + + const int total_tiles = conv.itiles * conv.jtiles + conv.tile_4fma_padding; + float I[alpha][alpha][simd_w]; + float Iw[alpha][alpha][simd_w]; + + array_offset_calculator<float, 5> input(inp, + conv.mb, conv.oc/simd_w, conv.oh, conv.ow, conv.oc_simd_block); + array_offset_calculator<float, 8> output(tinp, + conv.nb_oc, alpha, alpha, + conv.tile_block, conv.oc_block, + conv.nb_tile_block_ur, + conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block); + + int tile_base_index = image * total_tiles; + int tile_block_ur = tile_base_index % (conv.tile_block_ur * conv.tile_4fma); + int nb_tile_block_ur = + (tile_base_index / conv.tile_block_ur / conv.tile_4fma) + % conv.nb_tile_block_ur; + int tile_block = (tile_base_index / conv.tile_block_ur / conv.tile_4fma) + / conv.nb_tile_block_ur; + + for (int tj = 0; tj < conv.jtiles; tj++) { + for (int ti = 0; ti < conv.itiles; ti++) { + for (int j = 0; j < alpha; j++) { + int ydim = tj * tile_size + j; + if (ydim < conv.oh) { + for (int i = 0; i < alpha; i++) { + int xdim = ti * tile_size + i; + if (xdim < conv.ow) { + float *input_base = &(input(0, 0, ydim, xdim, 0)); + + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = input_base[v]; + } + if (with_bias && j < tile_size && i < tile_size) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + dbias[v] += input_base[v]; + } + } + } else { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } else { + for (int i = 0; i < alpha; i++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + I[j][i][v] = 0.0f; + } + } + } + } + + trans_W_3x3_4x4_wu(Iw, I); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + store_output(&(output(0, j, i, + tile_block, 0, + nb_tile_block_ur, + tile_block_ur, 0)), + Iw[j][i], true); + } + } + tile_block_ur++; + if (tile_block_ur >= conv.tile_block_ur * conv.tile_4fma) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= conv.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv, + float *wp, float *twp) +{ + const int kh = 3; + const int kw = 3; + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + + array_offset_calculator<float, 8> input(twp, + conv.nb_ic, conv.nb_oc, + alpha, alpha, + conv.oc_block, conv.ic_block, + conv.ic_simd_block, conv.oc_simd_block); + array_offset_calculator<float, 6> output(wp, + conv.oc/simd_w, conv.ic/simd_w, + conv.kh, conv.kw, + conv.ic_simd_block, conv.oc_simd_block); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + for (int v = 0; v < conv.ic_simd_block; v++) { + PRAGMA_OMP_SIMD() + for (int k = 0; k < conv.oc_simd_block; k++) { + Fw[j][i][v][k] = input(0, 0, j, i, 0, 0, v, k); + } + } + } + } + + trans_O_3x3_4x4_wu(Fw, F); + + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + for (int v = 0; v < conv.ic_simd_block; v++) { + store_output(&(output(0, 0, j, i, v, 0)), + F[j][i][v], true); + } + } + } +} + +template <bool is_fwd> +void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + /* Note that jcp.with_eltwise is true for both fused conv+relu primitive + * and conv primitive with PostOps with relu before sum + * (PostOps relu after sum is handled later) */ + auto output_transform = jcp.with_bias + ? (jcp.with_eltwise + ? (jcp.with_sum + ? output_transform_data<is_fwd, true, true, true> + : output_transform_data<is_fwd, true, true, false>) + : (jcp.with_sum + ? output_transform_data<is_fwd, true, false, true> + : output_transform_data<is_fwd, true, false, false>)) + : (jcp.with_eltwise + ? (jcp.with_sum + ? output_transform_data<is_fwd, false, true, true> + : output_transform_data<is_fwd, false, true, false>) + : (jcp.with_sum + ? output_transform_data<is_fwd, false, false, true> + : output_transform_data<is_fwd, false, false, false>)); + + /* Notation: + FWD: dimM:oc, dimN:ntiles, dimK:ic, + BWD: dimM:ic, dimN:ntiles, dimK:oc, + FWD/BWD: V: src/diff_dst transform, U:weight transform, + M:dst/diff_src transform */ + array_offset_calculator<float, 5> input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, + jcp.dimK_reg_block); + array_offset_calculator<float, 5> output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, + jcp.dimM_simd_block); + array_offset_calculator<float, 6> weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator<float, 2> bias(bias_ptr, + jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block); + + array_offset_calculator<float, 8> M(is_fwd + ? scratchpad.template get<float>(key_wino_M) + : scratchpad.template get<float>(key_wino_V), + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + array_offset_calculator<float, 8> U( + scratchpad.template get<float>(key_wino_U), + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + array_offset_calculator<float, 8> V(is_fwd + ? scratchpad.template get<float>(key_wino_V) + : scratchpad.template get<float>(key_wino_M), + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); + + bool V_streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) + > 2 * LLC_cache_size ? true : false; + + const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0; + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + { + parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, + [&](int img, int K_blk1, int K_blk2) { + input_transform_data<is_fwd>(img, jcp, + &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), + &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout); + }); + + parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block, + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data<is_fwd>(jcp, + &(weights(ofm1 * jcp.oc_block + ofm2, + ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr); + }); + + parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block, + [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) { + + kernel_->gemm_loop_ker_first_iter( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + 0, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, 0, 0, 0, 0))); + for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) { + kernel_->gemm_loop_ker( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + K_blk1, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, K_blk1, + 0, 0, 0))); + } + + }); + + parallel_nd(jcp.mb, jcp.dimM_nb_block, jcp.dimM_block, + [&](int img, int M_blk1, int M_blk2) { + + const int M_blk = M_blk1 * jcp.dimM_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + + output_transform(img, jcp, p_ops, + &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(img, M_blk, 0, 0, 0)), + bias_ptr, output_is_aligned); + + }); + + } +} + +template struct _jit_avx512_common_convolution_winograd_t<true>; +template struct _jit_avx512_common_convolution_winograd_t<false>; + +void jit_avx512_common_convolution_winograd_bwd_weights_t:: +_maybe_execute_diff_bias_copy(float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get<float>(key_conv_padded_bias); + for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc) + diff_bias[oc] = padded_bias[oc]; + } +} + +void jit_avx512_common_convolution_winograd_bwd_weights_t:: +_execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, + const memory_tracking::grantor_t &scratchpad) const { + auto ptr_diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto ptr_src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto ptr_diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS); + auto ptr_diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ? + diff_src_transform_bwd_weights<true> : + diff_src_transform_bwd_weights<false>; + auto diff_dst_transform_bwd_weights_ver = jcp.with_bias + ? diff_dst_transform_bwd_weights<true> + : diff_dst_transform_bwd_weights<false>; + + array_offset_calculator<float, 5> src((float *)ptr_src, + jcp.mb, jcp.ic/simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator<float, 6> diff_weights(ptr_diff_weights, + jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + array_offset_calculator<float, 2> diff_bias(pd()->wants_padded_bias() + ? scratchpad.get<float>(key_conv_padded_bias) : ptr_diff_bias, + jcp.oc/simd_w, simd_w); + + array_offset_calculator<float, 8> U( + scratchpad.get<float>(key_wino_U), + jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, jcp.oc_simd_block); + + array_offset_calculator<float, 8> M( + scratchpad.get<float>(key_wino_M), + jcp.nb_oc, alpha, alpha, + jcp.tile_block, jcp.oc_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma, + jcp.oc_simd_block); + array_offset_calculator<float, 8> V( + scratchpad.get<float>(key_wino_V), + jcp.nb_ic, alpha, alpha, + jcp.tile_block, jcp.ic_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur, + jcp.ic_simd_block * jcp.tile_4fma); + + const int trans_buffer_size = alpha * alpha * jcp.tile_4fma + * jcp.ic_simd_block; + array_offset_calculator<float, 2> trans_buffer( + scratchpad.get<float>(key_conv_tr_src), + nthreads, + trans_buffer_size); + + array_offset_calculator<float, 2> diff_bias_prv( + scratchpad.get<float>(key_conv_bia_reduction), + nthreads, + jcp.oc); + +PRAGMA_OMP(parallel num_threads(nthreads)) + { + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) { + diff_bias_prv(ithr, ofm) = 0.0f; + }); + +PRAGMA_OMP(for nowait) + for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) { + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) + diff_bias(bofm, v) = 0.0f; + } + } + + const int ithread = mkldnn_get_thread_num(); + + parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block, + [&](int img, int ifm1, int ifm2) { + float *transb = jcp.ver == ver_4fma + ? &(trans_buffer(ithread, 0)) + : NULL; + diff_src_transform_bwd_weights_ver(img, jcp, + &(src(img, ifm1 * jcp.ic_block + ifm2, + 0, 0, 0)), + &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)), + transb, + kernel_->transpose_4fma_ker); + }); + + parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block, + [&](int img, int ofm1, int ofm2) { + float *dbias = jcp.with_bias + ? &(diff_bias_prv(ithread, + simd_w * (ofm1 * jcp.oc_block + ofm2))) + : NULL; + diff_dst_transform_bwd_weights_ver(img, jcp, + &(diff_dst(img, ofm1 * jcp.oc_block + ofm2, + 0, 0, 0)), + &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)), + dbias); + }); + +PRAGMA_OMP(barrier) + + for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) { + parallel_nd_in_omp(alpha, alpha, jcp.nb_oc, + [&](int oj, int oi, int ofm1) { + kernel_->gemm_loop_ker_first_iter( + (float *)&(U(ifm1, ofm1, oj, oi, + 0, 0, 0, 0)), + (const float *)&(M(ofm1, oj, oi, + 0, 0, 0, 0, 0)), + (const float *)&(V(ifm1, oj, oi, + 0, 0, 0, 0, 0))); + for (int tile_block = 1; tile_block < jcp.tile_block; + tile_block++) { + kernel_->gemm_loop_ker((float *)&(U(ifm1, ofm1, + oj, oi, + 0, 0, 0, 0)), + (const float *)&(M(ofm1, oj, oi, tile_block, + 0, 0, 0, 0)), + (const float *)&(V(ifm1, oj, oi, tile_block, + 0, 0, 0, 0))); + } + }); + } + +PRAGMA_OMP(barrier) + + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, + [&](int ifm1, int ofm1, int ofm2, int ifm2) { + diff_weights_transform_bwd_weights(jcp, + &(diff_weights(ofm1 * jcp.oc_block + ofm2, + ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), + &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0))); + }); + + if (jcp.with_bias) { +PRAGMA_OMP(for) + for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) { + for (int ithr = 0; ithr < nthreads; ithr++) { + float* base_bias_ptr = &(diff_bias(ofm1, 0)); + float* base_bias_prv_ptr = &(diff_bias_prv( + ithr * jcp.oc + ofm1 * simd_w)); + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < simd_w; ofm2++) { + base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2]; + } + } + } + } + } + + _maybe_execute_diff_bias_copy(ptr_diff_bias, scratchpad); +} + +} +} +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp new file mode 100644 index 0000000000..6c76f37c72 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp @@ -0,0 +1,318 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP +#define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace winograd_avx512_common { +inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_winograd_conf_t &jcp) { + using namespace memory_tracking::names; + + size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; + size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic + * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc + * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding); + + scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); + + if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) { + const int nthr = mkldnn_get_max_threads(); + + size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr + * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block; + scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M); + + size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0; + scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); + + size_t padded_bias_sz = + jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0; + scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz); + } +} +} + +template <bool is_fwd> +struct _jit_avx512_common_convolution_winograd_t { + _jit_avx512_common_convolution_winograd_t( + const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) + : kernel_(nullptr), attr_(attr) { + kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp); + } + + ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; } + + protected: + void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + _jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_; + const primitive_attr_t *attr_; +}; + +struct jit_avx512_common_convolution_winograd_fwd_t + : _jit_avx512_common_convolution_winograd_t<true> + , public cpu_primitive_t + { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_common_conv_winograd_fwd_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(), + *attr()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd) + : _jit_avx512_common_convolution_winograd_t<true>(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) {} + + ~jit_avx512_common_convolution_winograd_fwd_t(){}; + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, + (float *)bias, this->scratchpad(ctx)); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_common_convolution_winograd_bwd_data_t + : _jit_avx512_common_convolution_winograd_t<false>, + public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && !has_zero_dim_memory() + && set_default_formats() + && mkldnn_thr_syncable(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf( + jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd) + : _jit_avx512_common_convolution_winograd_t<false>(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) {} + + ~jit_avx512_common_convolution_winograd_bwd_data_t(){}; + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); + this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, + (float *)weights, nullptr, this->scratchpad(ctx)); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_common_convolution_winograd_bwd_weights_t + : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, + hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""), + jit_avx512_common_convolution_winograd_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats() + && mkldnn_thr_syncable(); + if (!ok) return status::unimplemented; + + status_t status = + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32:: + init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), + *diff_weights_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_common::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_tag, nChw16c); + } + }; + + jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true), kernel_(nullptr) + { + kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32( + pd()->jcp_); + } + + ~jit_avx512_common_convolution_winograd_bwd_weights_t() + { delete kernel_; } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + _execute_backward_weights_S_D_G_W(ctx, scratchpad(ctx)); + return status::success; + } + +private: + void _execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx, + const memory_tracking::grantor_t &scratchpad) const; + void _maybe_execute_diff_bias_copy(float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_; +}; + +void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]); +void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]); +void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]); +void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]); +void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]); +void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]); +void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]); + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp new file mode 100644 index 0000000000..d4a451c021 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp @@ -0,0 +1,853 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_common_lrn.hpp" + +#include "jit_generator.hpp" + +#define FWD_RBC 4 +#define BWD_RBC 3 + +#define XMM_SIZE (4*sizeof(float)) +#define ZMM_SIZE (vlen) +#define BUFFER_BLOCK (XMM_SIZE + ZMM_SIZE + XMM_SIZE) +#define BUFFER_NEXT_OFFSET (XMM_SIZE + ZMM_SIZE) +#define SRC_PREV_OFFSET (vlen - XMM_SIZE) + +#define IRB_LOOP(statement) for(int irb = 0; irb < loop_size; irb++) { \ + statement;\ +} + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +enum params { vsize = 16, vlen = 64}; + +typedef struct { + const float *src; + float *dst, *ws0, *ws1; +} jit_args_fwd_t; + +typedef struct { + const float *src, *diff_dst, *ws0, *ws1; + float *diff_src; +} jit_args_bwd_t; + +struct nChw16c_across { +/* version: + * -1: channels 0..15, + * 1: channels C-16 .. C-1, + * 0: other channels + * 3: channels only for this kernel(without prev and next) + */ + int H, W, version; + nChw16c_across(int h, int w, int v) : H(h), W(w), version(v) {} +}; + +struct jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_kernel_f32: + public jit_generator { + int HW, W; + bool is_first; + bool is_last; + bool is_single; + + Reg64 src = rax; + Reg64 dst = r8; + Reg64 scratch0 = rdx; + Reg64 scratch1 = rsi; + Reg64 imm_addr64 = rbx; + + Zmm zalpha = zmm0; + Xmm xalpha = xmm0; + Zmm zk = zmm1; + Xmm xk = xmm1; + + Reg64 param = abi_param1; + Reg64 t = rsp; + Reg64 hw = r9; + + int xsrc_prev = 2; + int zsrc = 7; + int xsrc_next = 3; + int zc = 7; + + int za = 2; + int zb = 3; + int zd = 5; + int ze = 6; + int zsum = 4; + int zdst = 2; + int zbase = 3; + int zsum2 = 5; + + prop_kind_t pk; + int use_h_parallelism; + + float alpha, k; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32) + + void (*ker)(jit_args_fwd_t *); + void operator()(jit_args_fwd_t *arg) { ker(arg); } + + enum { + prf0_offt = 1*FWD_RBC, + prf2_offt = 8*FWD_RBC + }; + + inline void compute_loop(int loop_size_param) + { + // loop_size - param for IRB_LOOP macro + int loop_size = FWD_RBC; + + auto xreg = [=](int irb, int i) { + return Xmm(irb*3 + i); + }; + + auto zreg = [=](int irb, int i) { + return Zmm(irb*7 + i); + }; + + if (!is_first && !is_single) { + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt - HW)*vlen])); + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt - HW)*vlen])); + } + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(src, (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(src, (irb + prf2_offt)*vlen))); + if (!is_last && !is_single) { + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt + HW)*vlen])); + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt + HW)*vlen])); + } + if (pk != prop_kind::forward_inference) { + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch0, + (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch0, + (irb + prf2_offt)*vlen))); + } + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(dst, (irb + prf0_offt)*vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(dst, (irb + prf2_offt)*vlen))); + if (pk != prop_kind::forward_inference) { + IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch1, + (irb + prf0_offt) * vlen))); + IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch1, + (irb + prf2_offt) * vlen))); + } + + loop_size = loop_size_param; + if (loop_size == 0) + return; + if (!is_first && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xsrc_prev), + ptr[src + (irb - HW) * vlen + SRC_PREV_OFFSET])); + } + IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src,irb*vlen))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xsrc_next), + ptr[src + (irb + HW) * vlen])); + } + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK], + xreg(irb, xsrc_prev))); + } + IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE), + zreg(irb, zsrc))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xreg(irb, xsrc_next))); + } + + IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 2*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 2*sizeof(float)))); + + assert(zc == zsrc); + IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zc), zreg(irb, zc))); + + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, za), zreg(irb, za))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zb), zreg(irb, zb))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zd), zreg(irb, zd))); + IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, ze), zreg(irb, ze))); + + IRB_LOOP(vfmadd132ps(zreg(irb, zsum), zk, zalpha)); + + IRB_LOOP(vmovaps(zreg(irb, zbase), zreg(irb, zsum))); + + IRB_LOOP(vmulps(zreg(irb, zsum2), zreg(irb, zsum), zreg(irb, zsum))); + IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zsum), zreg(irb, zsum2))); + + IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum))); + IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum))); + + if (pk != prop_kind::forward_inference) { + IRB_LOOP(vmovups(EVEX_compress_addr(scratch0, irb*vlen), + zreg(irb, zsum))); + } + IRB_LOOP(vdivps(zreg(irb, zdst), zreg(irb, zsrc), zreg(irb, zsum))); + IRB_LOOP(vmovups(EVEX_compress_addr(dst, irb*vlen), zreg(irb, zdst))); + if (pk != prop_kind::forward_inference) { + /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */ + IRB_LOOP(vdivps(zreg(irb, zsum), zreg(irb, zdst), zreg(irb, zbase))); + IRB_LOOP(vmovups(EVEX_compress_addr(scratch1, irb*vlen), + zreg(irb, zsum))); + } + } + + jit_avx512_common_lrn_kernel_f32( + const struct nChw16c_across &J, + prop_kind_t prop_kind, + int use_h_parallel, + float A, + float K, + void *code_ptr = nullptr, + size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + , pk(prop_kind) + , use_h_parallelism(use_h_parallel) + , alpha(A) + , k(K) + { + this->preamble(); + + mov(src, ptr[param + 0]); + mov(dst, ptr[param + 8]); + if (pk != prop_kind::forward_inference) + { + mov(scratch0, ptr[param + 16]); + mov(scratch1, ptr[param + 24]); + } + is_first = J.version == -1 || J.version == -2; + is_last = J.version == +1 || J.version == -2; + is_single = J.version == 3; + + W = J.W; + HW = J.W*J.H; + int LSB = use_h_parallelism ? W : HW; + + sub(t, FWD_RBC*BUFFER_BLOCK); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(zalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(zk, xk); + + if (is_first || is_single) { + vxorps(xmm2, xmm2, xmm2); + for(int irb = 0; irb < FWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK], xmm2); + } + } + if (is_last || is_single) { + vxorps(xmm2, xmm2, xmm2); + for(int irb = 0; irb < FWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xmm2); + } + } + + int LSREST = LSB % FWD_RBC; + int LS = LSB - LSREST; + + Label lrn_loop; + + if (LS > 0) { + mov(hw, LS); + + L(lrn_loop); + { + compute_loop(FWD_RBC); + + add(src, FWD_RBC*vlen); + add(dst, FWD_RBC*vlen); + if (pk != prop_kind::forward_inference) + { + add(scratch0, FWD_RBC*vlen); + add(scratch1, FWD_RBC*vlen); + } + + for(int irb = 0; irb < FWD_RBC; irb++) + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + } + + compute_loop(LSREST); + + add(t, FWD_RBC*BUFFER_BLOCK); + this->postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); + } +}; + +status_t jit_avx512_common_lrn_fwd_t::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(avx512_common) + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, data_d.data_type()) + && data_d.ndims() == 4 + && data_d.dims()[1] % vsize == 0 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + if (desc()->prop_kind == forward_training) { + dims_t ws_dims = { MB(), C(), H(), 2*W() }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32, + format_tag::nChw16c); + } + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && desc()->lrn_beta == 0.75 + && data_d.matches_tag(format_tag::nChw16c); + + return args_ok_across ? success : unimplemented; +} + +jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr) + , ker_last_(nullptr) { + using namespace alg_kind; + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + const float alpha = pd()->desc()->lrn_alpha / ls; + const float k = pd()->desc()->lrn_k; + + auto pk = pd()->desc()->prop_kind; + + use_h_parallelism = H > 28 ? 1 : 0; + + if (C / vsize == 1) { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), pk, + use_h_parallelism, alpha, k); + } else { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), pk, + use_h_parallelism, alpha, k); + ker_first_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, -1), pk, use_h_parallelism, alpha, k); + ker_last_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, +1), pk, use_h_parallelism, alpha, k); + } +} + +jit_avx512_common_lrn_fwd_t::~jit_avx512_common_lrn_fwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +void jit_avx512_common_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + const int C16 = C / vsize; + const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16; + + balance211(work_amount, nthr, ithr, start, end); + if (use_h_parallelism) { + int n{0}, c16{0}, h{0}; + nd_iterator_init(start, n, N, c16, C16, h, H); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize + + h*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize + + h*2*W*vsize; + auto ws_offset1 = ws_offset0 + W*vsize; + + jit_args_fwd_t args; + args.src = &src[offset]; + args.dst = &dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + nd_iterator_step(n, N, c16, C16, h, H); + } + } else { + int n{0}, c16{0}; + nd_iterator_init(start, n, N, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize; + auto ws_offset1 = ws_offset0 + H*W*vsize; + + jit_args_fwd_t args; + args.src = &src[offset]; + args.dst = &dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + + nd_iterator_step(n, N, c16, C16); + } + } + }); +} + +struct jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_kernel_f32: + public jit_generator { + int HW, W; + bool is_first; + bool is_last; + bool is_single; + + Reg64 src = rax; + Reg64 diffsrc = r8; + Reg64 diffdst = r9; + Reg64 workspace0 = rdx; + Reg64 workspace1 = rsi; + Reg64 imm_addr64 = rbx; + + Zmm znalphabeta = zmm0; + Xmm xnalphabeta = xmm0; + + Reg64 param = abi_param1; + Reg64 t = rsp; + Reg64 hw = r10; + + int xws1_prev = 1; + int xdiffdst_prev = 2; + int zws1 = 1; + + int zsrc = 1; + int zdiffdst = 5; + int zdiffsrc = 6; + + int xws1_next = 1; + int xdiffdst_next = 3; + + int za = 1; + int zb = 2; + int zd = 3; + int ze = 4; + int zws0 = 2; + + float nalphabeta; + + int use_h_parallelism; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32) + + void (*ker)(jit_args_bwd_t *); + void operator()(jit_args_bwd_t *arg) { ker(arg); } + + enum { + prf0_offt = 1*BWD_RBC, + prf2_offt = 8*BWD_RBC + }; + + inline void compute_loop(int loop_size_param, int prefetchL1, + int prefetchL2) + { + // loop_size - param for IRB_LOOP macro + int loop_size = loop_size_param; + + auto xreg = [=](int irb, int i) { + return Xmm(irb*6 + i); + }; + + auto zreg = [=](int irb, int i) { + return Zmm(irb*6 + i); + }; + +// ---- prefetching ------------------------------------------- + if (!is_first && !is_single) { + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt + - 2 * HW) * vlen])); + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt + - HW) * vlen])); + } + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt)*vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt)*vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt)*vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt)*vlen])); + + if (!is_last && !is_single) { + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt + + 2 * HW) * vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[workspace1 + (irb + prf2_offt + + 2 * HW) * vlen])); + + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt + + HW) * vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[diffdst + (irb + prf2_offt + + HW) * vlen])); + } + if (prefetchL1) + IRB_LOOP(mic_prefetcht0(ptr[workspace0 + (irb + prf0_offt)*vlen])); + if (prefetchL2) + IRB_LOOP(mic_prefetcht2(ptr[workspace0 + (irb + prf2_offt)*vlen])); +// ----------------------------------------------------------- + + if (loop_size_param == 0) + return; + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xws1_prev), ptr[workspace1 + (irb + - 2 * HW) * vlen + SRC_PREV_OFFSET])); + IRB_LOOP(vmovups(xreg(irb, xdiffdst_prev), ptr[diffdst + (irb + - HW) * vlen + SRC_PREV_OFFSET])); + IRB_LOOP(vmulps(xreg(irb, xdiffdst_prev), xreg(irb, xdiffdst_prev), + xreg(irb, xws1_prev))); + } + + IRB_LOOP(vmovups(zreg(irb, zws1), + EVEX_compress_addr(workspace1, irb*vlen))); + IRB_LOOP(vmovups(zreg(irb, zdiffdst), + EVEX_compress_addr(diffdst, irb*vlen))); + IRB_LOOP(vmulps(zreg(irb, zdiffsrc), zreg(irb, zdiffdst), + zreg(irb, zws1))); + + if (!is_last && !is_single) { + IRB_LOOP(vmovups(xreg(irb, xws1_next), ptr[workspace1 + (irb + + 2 * HW) * vlen])); + IRB_LOOP(vmovups(xreg(irb, xdiffdst_next), ptr[diffdst + (irb + + HW) * vlen])); + IRB_LOOP(vmulps(xreg(irb, xdiffdst_next), xreg(irb, xdiffdst_next), + xreg(irb, xws1_next))); + } + + if (!is_first && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK], + xreg(irb, xdiffdst_prev))); + } + IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE), + zreg(irb, zdiffsrc))); + if (!is_last && !is_single) { + IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], + xreg(irb, xdiffdst_next))); + } + + IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 2*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE - 1*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 1*sizeof(float)))); + IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK + + XMM_SIZE + 2*sizeof(float)))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, za))); + assert(zsrc == za); + IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src, irb*vlen))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, zb))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, zd))); + IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc), + zreg(irb, ze))); + IRB_LOOP(vmulps(zreg(irb, zsrc), zreg(irb, zsrc), znalphabeta)); + + IRB_LOOP(vmovups(zreg(irb, zws0), + EVEX_compress_addr(workspace0, irb*vlen))); + IRB_LOOP(vdivps(zreg(irb, zdiffdst), zreg(irb, zdiffdst), + zreg(irb, zws0))); + IRB_LOOP(vfmadd213ps(zreg(irb, zdiffsrc), zreg(irb, zsrc), + zreg(irb, zdiffdst))); + + Label unaligned_store, end_store; + test(diffsrc, vlen - 1); + jnz(unaligned_store, T_NEAR); + IRB_LOOP(uni_vmovntps(EVEX_compress_addr(diffsrc, irb*vlen), + zreg(irb, zdiffsrc))); + jmp(end_store, T_NEAR); + L(unaligned_store); { + IRB_LOOP(uni_vmovups(EVEX_compress_addr(diffsrc, irb*vlen), + zreg(irb, zdiffsrc))); + } + L(end_store); + } + + jit_avx512_common_lrn_kernel_f32( + const struct nChw16c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE) + : jit_generator(code_ptr, code_size) + , nalphabeta(-2*A*B) + , use_h_parallelism(use_h_parallel) + { + this->preamble(); + + mov(src, ptr[param + 0]); + mov(diffdst, ptr[param + 8]); + mov(workspace0, ptr[param + 16]); + mov(workspace1, ptr[param + 24]); + mov(diffsrc, ptr[param + 32]); + + W = J.W; + HW = J.H*J.W; + int LSB = this->use_h_parallelism ? W : HW; + + sub(t, BWD_RBC*BUFFER_BLOCK); + mov(imm_addr64, float2int(this->nalphabeta)); + movq(xnalphabeta, imm_addr64); + vbroadcastss(znalphabeta, xnalphabeta); + + is_first = J.version == -1 || J.version == -2; + is_last = J.version == +1 || J.version == +2; + is_single = J.version == 3; + + if (is_first || is_single) { + vxorps(xmm1, xmm1, xmm1); + for(int irb = 0; irb < BWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK], xmm1); + } + } + if (is_last || is_single) { + vxorps(xmm1, xmm1, xmm1); + for(int irb = 0; irb < BWD_RBC; irb++) { + vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], xmm1); + } + } + + int LSREST = LSB % BWD_RBC; + int LS = LSB - LSREST; + + Label lrn_loop; + + if (LS > 0) { + mov(hw, LS); + + L(lrn_loop); + { + compute_loop(BWD_RBC, 1, 1); + + add(src, BWD_RBC*vlen); + add(diffsrc, BWD_RBC*vlen); + add(diffdst, BWD_RBC*vlen); + add(workspace0, BWD_RBC*vlen); + add(workspace1, BWD_RBC*vlen); + + for(int irb = 0; irb < BWD_RBC; irb++) + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + } + + compute_loop(LSREST, 1, this->use_h_parallelism ? 0 : 1); + + add(t, BWD_RBC*BUFFER_BLOCK); + this->postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); + } + +}; + +status_t jit_avx512_common_lrn_bwd_t::pd_t::init() { + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(avx512_common) + && !is_fwd() + && utils::everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % vsize == 0 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + dims_t ws_dims = { MB(), C(), H(), 2*W() }; + mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32, + format_tag::nChw16c); + + if (!compare_ws(hint_fwd_pd_)) return unimplemented; + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && desc()->lrn_beta == 0.75 + && data_d.matches_tag(format_tag::nChw16c); + + return args_ok_across ? success : unimplemented; +} + +jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr) + , ker_last_(nullptr) { + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + const float alpha = pd()->desc()->lrn_alpha / ls; + const float beta = pd()->desc()->lrn_beta; + + use_h_parallelism = H > 28 ? 1 : 0; + + if (C / vsize == 1) { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), + alpha, beta, use_h_parallelism); + } else { + ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), + alpha, beta, use_h_parallelism); + ker_first_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, -1), alpha, beta, use_h_parallelism); + ker_last_ = new jit_avx512_common_lrn_kernel_f32( + nChw16c_across(H, W, +1), alpha, beta, use_h_parallelism); + } +} + +jit_avx512_common_lrn_bwd_t::~jit_avx512_common_lrn_bwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +void jit_avx512_common_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + const int C16 = C / vsize; + const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16; + + balance211(work_amount, nthr, ithr, start, end); + if (use_h_parallelism) { + int n{0}, c16{0}, h{0}; + nd_iterator_init(start, n, N, h, H, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize + + h*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize + + h*2*W*vsize; + auto ws_offset1 = ws_offset0 + W*vsize; + + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + args.diff_src = &diff_src[offset]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + nd_iterator_step(n, N, h, H, c16, C16); + } + } else { + int n{0}, c16{0}; + nd_iterator_init(start, n, N, c16, C16); + for (size_t iwork = start; iwork < end; ++iwork) { + auto offset = n*C*H*W + c16*H*W*vsize; + auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize; + auto ws_offset1 = ws_offset0 + H*W*vsize; + + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.ws0 = &ws[ws_offset0]; + args.ws1 = &ws[ws_offset1]; + args.diff_src = &diff_src[offset]; + + if (C16 == 1) + (*ker_)(&args); + else if (c16 == 0) + (*ker_first_)(&args); + else if (c16 == C16 - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + + nd_iterator_step(n, N, c16, C16); + } + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp new file mode 100644 index 0000000000..37fbb9b3e5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp @@ -0,0 +1,96 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_COMMON_LRN_HPP +#define CPU_JIT_AVX512_COMMON_LRN_HPP + +#include "c_types_map.hpp" + +#include "cpu_isa_traits.hpp" +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_common_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_lrn_fwd_t); + + status_t init(); + }; + + jit_avx512_common_lrn_fwd_t(const pd_t *apd); + ~jit_avx512_common_lrn_fwd_t(); + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int use_h_parallelism; + struct jit_avx512_common_lrn_kernel_f32; + jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +struct jit_avx512_common_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""), + jit_avx512_common_lrn_bwd_t); + + status_t init(); + }; + + jit_avx512_common_lrn_bwd_t(const pd_t *apd); + ~jit_avx512_common_lrn_bwd_t(); + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + int use_h_parallelism; + struct jit_avx512_common_lrn_kernel_f32; + jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp new file mode 100644 index 0000000000..c58d3fa0a6 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp @@ -0,0 +1,1103 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include <assert.h> + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_fp32_wino_conv_2x3.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_kind; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +/// SRC TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + + struct call_params_t { + const void *src; + const void *wino_src; + const void *v_y_masks; + const void *v_x_masks; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp) { + generate(); + ker_ = + reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode())); + } + + void generate(); + + Zmm vreg_inp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Zmm vreg_tmp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(15 - i); + } + + Zmm vreg_out(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert (id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_ic_block = r8; + +}; + +void jit_avx512_core_fp32_wino_conv_2x3_src_trans_t::generate() { + Label ic_block_label; + + const int load_block = 16; + int out_offset = 0, inp_offset = 0; + preamble(); + +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_aux_ptr_src, src); + READ_PARAM(reg_aux_ptr_dst, wino_src); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); +#undef READ_PARAM + + for (int i = 0; i < jcp.alpha; i++) { + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); + } + mov(reg_ic_block, jcp.ic / load_block); + L(ic_block_label); + { + for (int y = 0; y < jcp.alpha; y++) { + kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]); + for (int x = 0; x < jcp.alpha; x++) { + Zmm zmm = vreg_inp(y * jcp.alpha + x); + + vxorps(zmm, zmm, zmm); + kandw(r_mask, y_mask, x_mask(x)); + inp_offset = sizeof(float) + * ((-jcp.t_pad + y) * jcp.iw * load_block + + (-jcp.l_pad + x) * load_block); + vmovups(zmm | r_mask, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } + } + for (int y = 0; y < jcp.alpha; y++) { + vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0), + vreg_inp(y * jcp.alpha + 2)); + vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1), + vreg_inp(y * jcp.alpha + 2)); + vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2), + vreg_inp(y * jcp.alpha + 1)); + vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1), + vreg_inp(y * jcp.alpha + 3)); + } + for (int x = 0; x < jcp.alpha; x++) { + vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0), + vreg_tmp(x + jcp.alpha * 2)); + vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1), + vreg_tmp(x + jcp.alpha * 2)); + vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2), + vreg_tmp(x + jcp.alpha * 1)); + vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1), + vreg_tmp(x + jcp.alpha * 3)); + } + + for (int i = 0; i < 16; i++) { + out_offset = sizeof(float) * (jcp.inp_stride * i); + vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), + vreg_out(i)); + } + + add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block); + add(reg_aux_ptr_dst, sizeof(float) * load_block); + } + dec(reg_ic_block); + cmp(reg_ic_block, 0); + jg(ic_block_label, T_NEAR); + postamble(); +} + +/// DST TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *wino_dst; + const void *dst; + const void *v_y_masks; + const void *v_x_masks; + + const void *bias; + const void *scales; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) { + generate(); + ker_ = reinterpret_cast<decltype(ker_)>( + const_cast<uint8_t *>(getCode())); + } + + void generate(); + bool maybe_relu(int position); + + Zmm vreg_inp(int i) { // 16 + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + + Zmm vreg_stg(int id) { // 8 + const int id_reg_stg = jcp.alpha * jcp.alpha + id; + assert(id_reg_stg < jcp.alpha * jcp.alpha + 8); + return Zmm(31 - id_reg_stg); + } + + Zmm vreg_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id_reg_out < jcp.alpha * jcp.alpha + 12); + return Zmm(31 - id_reg_out); + } + + Zmm vreg_tmp(int id) { // 2 + const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; + assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14); + return Zmm(31 - id_reg_tmp); + } + + Zmm vreg_zero = Zmm(0); + Zmm vreg_prev_dst = Zmm(0); + Zmm vreg_bias = Zmm(2); + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert (id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_oc_block = r8; + + Reg64 reg_ptr_bias = rbx; + Reg64 reg_ptr_scales = abi_not_param1; + Reg64 reg_ptr_sum_scale = rdx; +}; + +bool jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* relu before sum */ + return false + || p.contain(eltwise, 0); + } else if (position == 1) { + /* relu after sum */ + const int sum_idx = p.contain(sum, 0) + ? 0 : (p.contain(sum, 1) ? 1 : -1); + if (sum_idx == -1) + return false; + + return false + || p.contain(eltwise, sum_idx + 1); + } + + return false; +} + +void jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::generate() { + Label oc_block_label; + + const int load_block = 16; + + auto loop_body = [=]() { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? &p.entry_[sum_idx].sum.scale + : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + for (int i = 0; i < 16; i++) { + int internal_offset = sizeof(float) * jcp.out_stride * i; + vmovups(vreg_inp(i), + EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); + } + for (int y = 0; y < jcp.alpha; y++) { + vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1)); + vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2)); + + vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2)); + vsubps(vreg_stg(y * 2+1), vreg_tmp(1), vreg_inp(y * 4 + 3)); + } + for (int x = 0; x < jcp.m; x++) { + vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2 * 1)); + vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x+2 * 2)); + + vsubps(vreg_tmp(1), vreg_stg(x+2 * 1), vreg_stg(x+2 * 2)); + vsubps(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2 * 3)); + } + + + if (jcp.with_bias) { + auto bias_addr = ptr [ reg_ptr_bias ]; + vmovups(vreg_bias, bias_addr); + } + for (int y = 0; y < jcp.m; y++) { + kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]); + for (int x = 0; x < jcp.m; x++) { + kandw(r_mask, y_mask, x_mask(x)); + + int i = y * jcp.m + x; + int offset = sizeof(float) * + (y * jcp.ow * jcp.oc_block + x * jcp.oc_block); + Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); + + Zmm zmm = vreg_out(i); + if (jcp.with_bias) + vaddps(zmm, zmm, vreg_bias); + vmulps(zmm, zmm, ptr [reg_ptr_scales]); + + if (maybe_relu(0)) { + vxorps(vreg_zero, vreg_zero, vreg_zero); + vmaxps(zmm, vreg_zero, zmm); + } + if (p_sum_scale) { // post_op: sum + vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); + vmovups(vreg_prev_dst | r_mask, addr); + if (*p_sum_scale == 1.f) + vaddps(zmm, vreg_prev_dst); + else + vfmadd231ps(zmm, vreg_prev_dst, + zword_b[reg_ptr_sum_scale]); + } + if (maybe_relu(1)) { + vxorps(vreg_zero, vreg_zero, vreg_zero); + vmaxps(zmm, vreg_zero, zmm); + } + + vmovups(addr, zmm | r_mask); + } + } + }; + + preamble(); + +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_aux_ptr_src, wino_dst); + READ_PARAM(reg_aux_ptr_dst, dst); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); + READ_PARAM(reg_ptr_bias, bias); + READ_PARAM(reg_ptr_scales, scales); +#undef READ_PARAM + + for (int i = 0; i < jcp.alpha * jcp.alpha; i++) + vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i)); + + for (int i = 0; i < jcp.alpha; i++) + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); + + int oc_blocks = 1; + oc_blocks = jcp.oc / load_block; + mov(reg_oc_block, oc_blocks); + L(oc_block_label); + { + loop_body(); + add(reg_aux_ptr_src, sizeof(float) * load_block); + add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block); + + add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + add(reg_ptr_bias, jcp.typesize_bia * load_block); + } + dec(reg_oc_block); + cmp(reg_oc_block, 0); + jg(oc_block_label, T_NEAR); + + sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block); + + postamble(); + +} + +/// GEMM kernel //////////////////////////////////////////////////////////////// +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t) + jit_conv_conf_2x3_wino_t jcp; + + struct call_params_t { + const void *src; + const void *dst; + const void *wei; + const void *dst_b; + }; + void (*ker_)(const call_params_t *); + + void generate(); + static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp, + const primitive_attr_t &attr); + + jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp) { + generate(); + ker_ = reinterpret_cast<decltype(ker_)>( + const_cast<uint8_t *>(getCode())); + } + + static status_t init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, + memory_desc_t& expect_wei_md); + + Zmm vreg_out(int n, int m) { + const int id_reg_out = n * jcp.m_block + m; + assert(id_reg_out < jcp.n2_block * jcp.m_block); + return Zmm(31 - id_reg_out); + } + Zmm vreg_wei(int i) { + assert (31 - jcp.n2_block * jcp.m_block - i > 1); + return Zmm(31 - jcp.n2_block * jcp.m_block - i); + } + + Zmm vreg_src = Zmm(0); + Zmm vreg_one = Zmm(1); + Zmm vreg_tmp = Zmm(2); + + Reg64 reg_ptr_src = r15; + + Reg64 reg_aux_dst = r12; + Reg64 reg_aux_dst2 = r11; + Reg64 reg_aux_wei = r10; + Reg64 reg_aux_wei2 = r9; + Reg64 reg_aux_src = r8; + Reg64 reg_aux_src2 = rax; + + Reg64 reg_mb = rbx; + Reg64 reg_nnb = rdx; + Reg64 reg_K = rsi; + +}; + +bool jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::post_ops_ok( + jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_relu(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_relu(1)) || + (p.contain(sum, 1) && is_relu(0)); + case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); + default: return false; + } + + return false; +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::generate() { + Label nnb_loop_label, K_loop_label, mb_loop_label; + + preamble(); +#define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_aux_dst, dst); + READ_PARAM(reg_aux_wei, wei); +#undef READ_PARAM + + if (!jcp.small_mb) { + mov(reg_nnb, jcp.n_chunks); + L(nnb_loop_label); + } + mov(reg_aux_dst2, reg_aux_dst); + mov(reg_aux_src, reg_ptr_src); + mov(reg_mb, jcp.M / jcp.m_block); + L(mb_loop_label); + { + int nb2 = 0; + for (nb2 = 0; nb2 < jcp.n2_block; nb2++) { + for (int m = 0; m < jcp.m_block; m++) { + vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m)); + } + } + mov(reg_aux_src2, reg_aux_src); + mov(reg_aux_wei2, reg_aux_wei); + + mov(reg_K, jcp.k_chunks); + L(K_loop_label); { + int wei_offset = 0; + for (int _i = 0; _i < jcp.k2_block; _i++) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + if (jcp.small_mb) { + int wei_offset = sizeof(float) + * ((nb2 * jcp.nb_ic * jcp.ic_block + * jcp.oc_block) + + _i * jcp.oc_block); + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, wei_offset)); + } else { + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, + sizeof(float) * wei_offset)); + wei_offset += jcp.oc_block; + } + } + for (int m = 0; m < jcp.m_block; m++) { + int inp_offset = sizeof(float) * (m * jcp.K + _i); + if (jcp.n2_block > 1) { + vbroadcastss(vreg_src, + EVEX_compress_addr(reg_aux_src2, inp_offset)); + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) + vfmadd231ps(vreg_out(nb2, m), vreg_wei(nb2), + vreg_src); + } else { + vfmadd231ps(vreg_out(0, m), vreg_wei(0), + EVEX_compress_addr(reg_aux_src2, inp_offset, true)); + } + } + } + add(reg_aux_src2, sizeof(float) * jcp.ic_block); + if (jcp.small_mb) + add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block); + else + add(reg_aux_wei2, + sizeof(float) * jcp.k2_block * jcp.n2_block + * jcp.oc_block); + } + dec(reg_K); + cmp(reg_K, 0); + jg(K_loop_label, T_NEAR); + + for (int m = 0; m < jcp.m_block; m++) { + int nb2 = 0; + for (nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int offset = sizeof(float) * + (m * jcp.N + nb2 * jcp.oc_block); + vmovups(EVEX_compress_addr(reg_aux_dst2,offset), + vreg_out(nb2, m)); + } + } + add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K); + add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N); + } + dec(reg_mb); + cmp(reg_mb, 0); + jg(mb_loop_label, T_NEAR); + + if (!jcp.small_mb) { + add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block); + add(reg_aux_wei, + sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block + * jcp.oc_block); + + dec(reg_nnb); + cmp(reg_nnb, 0); + jg(nnb_loop_label, T_NEAR); + } + postamble(); +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { + return jcp.mb >= 4; +} +} + +status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &wei_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr, memory_desc_t &expect_wei_md) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = wei_d.dims()[with_groups + 2]; + jcp.kw = wei_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.m = 2; + jcp.r = 3; + jcp.alpha = jcp.m + jcp.r - 1; + int simdw = 16; + + format_tag_t dat_tag = format_tag::nChw16c; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simdw); + jcp.ic = rnd_up(jcp.ic, simdw); + } + + jcp.ver = ver_avx512_core; + if (!(mayiuse(avx512_core))) + return status::unimplemented; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (src_d.data_type() != data_type::f32) + return status::unimplemented; + if (wei_d.data_type() != data_type::f32) + return status::unimplemented; + if (dst_d.data_type() != data_type::f32) + return status::unimplemented; + + jcp.ic_block = simdw; + jcp.oc_block = simdw; + + bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1 + && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 + && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0 + && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad + && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0 + && jcp.l_pad < 2 && jcp.l_pad >= 0; + if (!ok) + return status::unimplemented; + + const int L2_cap = get_cache_size(2, true) / sizeof(float); + const int L3_capacity = get_cache_size(3, false) / sizeof(float); + int a = jcp.alpha; + int aa = a * a; + int mb = jcp.mb; + int ic = jcp.ic; + int oc = jcp.oc; + int ih = jcp.ih; + int iw = jcp.iw; + auto wei_sz = (float)aa * ic * oc; + auto inp_sz = (float)mb * ih * iw * ic; + auto sp_sz = (float)mb * ih * iw; + + /* Heuristics here. Numbers '28','196' is an observation from data. */ + if (wei_sz / inp_sz > 5) + jcp.small_mb = true; + else + jcp.small_mb = false; + + if (mb > nstl::min(jcp.nthr, 28) + || (!jcp.small_mb + && (wei_sz >= 0.9f * L2_cap + || inp_sz > L2_cap * jcp.nthr + L3_capacity)) + || (jcp.small_mb && sp_sz > 196)) + return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_bia + = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; + + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + const int skx_free_regs = 30; + + auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block, + int &n2_block, float ®_eff) { + M = (xb * yb) / jcp.alpha; + int max_m_block = m_block = nstl::min(M, skx_free_regs); + int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs); + reg_eff = 0; + for (int im = max_m_block; im > 0; im--) { + for (int in2 = max_n2_block; in2 > 0; in2--) { + int used_regs = in2 * im + in2; + float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f; + if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs + || cur_reg_eff <= reg_eff) + continue; + reg_eff = cur_reg_eff; + m_block = im; + n2_block = in2; + } + } + }; + + int oh = jcp.oh; + int ow = jcp.ow; + int nb_oc = jcp.nb_oc; + int Z = ic + oc; + int Y = ic * oc; + const int L3_cap_per_core = get_cache_size(3, true) / sizeof(float); + + /* Selecting xb and yb blocking */ + int min_yb = jcp.alpha; + int min_xb = jcp.alpha; + int max_yb = nstl::max(min_yb, rnd_up(ih, 2)); + int max_xb = nstl::max(min_xb, rnd_up(iw, 2)); + float best_eff = 0.f; + for (int ix = max_xb; ix >= min_xb; ix -= 2) { + if (rnd_up(ow, ix) < iw - 2) + continue; + for (int iy = max_yb; iy >= min_yb; iy -= 2) { + if (rnd_up(oh, iy) < ih - 2) + continue; + int ex_y = rnd_up(oh, iy); + int ex_x = rnd_up(ow, ix); + float work_eff = (float)(ih * iw) / (ex_y * ex_x); + + int M, m_block, n2_b; + float reg_eff, thr_eff, par_eff, mem_eff, req_mem; + + find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff); + + /* outer parallelization */ + int nblocks = mb * div_up(ih, iy) * div_up(iw, ix); + thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); + + mem_eff = 1.f; + req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y; + if (req_mem > L2_cap / 2) { + if (req_mem > ((L2_cap + L3_cap_per_core) * 4) / 7) + mem_eff /= (n2_b + 1) / 2.f; + else + mem_eff /= (n2_b + 1) / 3.f; + } + + float outer_eff = thr_eff + work_eff + reg_eff + mem_eff; + + /* inner parallelization */ + int bsz = iy * ix / a; + int gemmw = aa * (nb_oc / n2_b); + int bsz_r = rnd_up(bsz, jcp.nthr); + int gemmw_r = rnd_up(gemmw, jcp.nthr); + thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y); + + req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic; + mem_eff = nstl::min(1.f, L2_cap / req_mem); + int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr)); + int oc_per_thr = + nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr)); + req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z; + if (req_mem > L2_cap) + mem_eff = 0.1f; + par_eff = 1 / (2.f * nblocks); + + float inner_eff = thr_eff + work_eff + mem_eff + par_eff; + + float eff = jcp.small_mb ? inner_eff : outer_eff; + if (eff > best_eff) { + best_eff = eff; + jcp.yb = iy; + jcp.xb = ix; + jcp.M = M; + jcp.m_block = m_block; + jcp.n2_block = n2_b; + } + } + } + + assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); + + jcp.inp_stride = jcp.M * jcp.ic; + jcp.out_stride = jcp.M * jcp.oc; + jcp.wei_stride = jcp.ic * jcp.oc; + jcp.bia_stride = jcp.oc; + + jcp.N = jcp.oc; + jcp.K = jcp.ic; + + jcp.n_block = jcp.oc_block; + jcp.k_block = jcp.ic_block; + + assert(jcp.M % jcp.m_block == 0); + assert(jcp.nb_oc % jcp.n2_block == 0); + + jcp.n_chunks = jcp.nb_oc / jcp.n2_block; + jcp.k2_block = jcp.ic_block; + jcp.k_chunks = jcp.K / jcp.k2_block; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::f32; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format + = jcp.small_mb ? mkldnn_wino_wei_aaOio : mkldnn_wino_wei_aaOBiOo; + wd.r = jcp.r; + wd.alpha = jcp.alpha; + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.ic_block; + wd.oc_block = jcp.oc_block; + wd.oc2_block = jcp.n2_block; + wd.ic2_block = 1; + wd.adj_scale = 1.f; + size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; + wd.size = max_size; + + return status::success; +} +//////////////////////////////////////////////////////////////////////////////// + +status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t + ::pd_t::jit_conf(memory_desc_t& expect_wei_md) { + return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf( + jcp_, *this->desc(), this->src_md_, this->weights_md_, + this->dst_md_,this->bias_md_, *this->attr(), expect_wei_md); +} + +jit_avx512_core_fp32_wino_conv_2x3_fwd_t:: + jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) +{ + kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t( + pd()->jcp_, *pd()->attr()); + src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t( + pd()->jcp_, *pd()->attr()); + dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t( + pd()->jcp_, *pd()->attr()); +} + +jit_avx512_core_fp32_wino_conv_2x3_fwd_t + ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() { + delete kernel_; + delete src_trans_; + delete dst_trans_; +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN( + const float *src, const float *wei, const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const +{ + const auto &jcp = kernel_->jcp; + const auto &oscales = pd()->attr()->output_scales_; + + const size_t wino_size_offset = + (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb); + const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16; + const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get<float>(key_conv_padded_bias); + utils::array_copy(padded_bias, bia, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bia = padded_bias; + } + + auto ptr_V = scratchpad.get<float>(key_wino_V); + auto ptr_M = scratchpad.get<float>(key_wino_M); + + parallel_nd(jcp.mb, div_up(jcp.oh,jcp.yb), div_up(jcp.ow, jcp.xb), + [&](int mb, int tile_y_b, int tile_x_b) { + int tile_y = tile_y_b * jcp.yb; + int tile_x = tile_x_b * jcp.xb; + + int ithr = mkldnn_get_thread_num(); + auto wino_src = ptr_V + size_wino_src * ithr; + auto wino_dst = ptr_M + size_wino_dst * ithr; + + auto src_trans_p = + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t + ::call_params_t(); + auto dst_trans_p = + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t + ::call_params_t(); + auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t :: + call_params_t(); + + /* transformation of input tensor to winograd domain */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; + x_in_block += 2) { + + unsigned short v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min(jcp.alpha, + nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min(jcp.alpha, + nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; + v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; + } + auto local_s = src + + mb * jcp.nb_ic * jcp.ih * jcp.iw + * jcp.ic_block + + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + } + } + /* gemms */ + for (int tile_ij = 0; tile_ij < 16; tile_ij++) { + int offset = (tile_ij + ithr) % 16; + gemm_p.src = wino_src + jcp.inp_stride * offset; + gemm_p.dst = wino_dst + jcp.out_stride * offset; + gemm_p.wei = wei + jcp.wei_stride * offset; + + kernel_->ker_(&gemm_p); + } + + /* transformation from winograd domain to output tensor */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; + x_in_block += 2) { + unsigned short v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; + v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; + } + auto local_d = dst + + mb * jcp.nb_oc * jcp.oh * jcp.ow + * jcp.oc_block + + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales.scales_; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + } + } + }); +} + +void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb( + const float *src, const float *wei, const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const +{ + const auto &jcp = kernel_->jcp; + const auto &oscales = pd()->attr()->output_scales_; + + if (pd()->wants_padded_bias()) { + auto padded_bias = scratchpad.get<float>(key_conv_padded_bias); + utils::array_copy(padded_bias, bia, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bia = padded_bias; + } + + auto ptr_V = scratchpad.get<float>(key_wino_V); + auto ptr_M = scratchpad.get<float>(key_wino_M); + + for (int mb = 0; mb < jcp.mb; mb++) { + for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { + for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { + /* transformation of input tensor to winograd domain */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), + [&](int y_in_block_b, int x_in_block_b) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto src_trans_p = jit_avx512_core_fp32_wino_conv_2x3_src_trans_t :: + call_params_t(); + + unsigned short v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min( + jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min( + jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; + v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; + } + auto local_s = src + + mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block + + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; + auto local_w = ptr_V + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + }); + + /* gemms */ + parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { + auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t :: + call_params_t(); + + gemm_p.src = ptr_V + jcp.inp_stride * tile_ij; + gemm_p.dst = ptr_M + jcp.out_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + gemm_p.wei = wei + jcp.wei_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block * jcp.K; + + kernel_->ker_(&gemm_p); + }); + + /* transformation from winograd domain to output tensor */ + + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), + [&](int y_in_block_b, int x_in_block_b) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto dst_trans_p = jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t :: + call_params_t(); + + unsigned short v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; + v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; + } + auto local_d = dst + + mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block + + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; + auto local_w = ptr_M + m * jcp.oc; + + auto scales = oscales.scales_; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + }); + }}} +} + +} // namespace cpu +} // namespace impl +} // namespace mkldnn diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp new file mode 100644 index 0000000000..7e38b07f5a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp @@ -0,0 +1,144 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP +#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t; +struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t; +struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t; + +struct jit_avx512_core_fp32_wino_conv_2x3_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_fp32_wino_2x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_2x3_fwd_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::forward_inference + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + memory_desc_t expect_wei_md = *weights_md(); + status_t jit_conf_result = jit_conf(expect_wei_md); + if (jit_conf_result != status::success) return jit_conf_result; + set_default_alg_kind(alg_kind::convolution_winograd); + + if (weights_md_.format_kind == format_kind::any) + weights_md_ = expect_wei_md; + if (weights_md_ != expect_wei_md) + return status::unimplemented; + + init_scratchpad(); + + return status::success; + } + + jit_conv_conf_2x3_wino_t jcp_; + + protected: + status_t jit_conf(memory_desc_t& expect_wei_md); + + void init_scratchpad() { + using namespace memory_tracking::names; + + auto scratchpad = scratchpad_registry().registrar(); + + int wino_size_offset = (jcp_.yb / 2) * (jcp_.xb / 2) + jcp_.xb; + + size_t V_sz = (size_t)jcp_.ic * 16 * wino_size_offset * jcp_.nthr; + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_4K); + + size_t M_sz = (size_t)jcp_.oc * 16 * wino_size_offset * jcp_.nthr; + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_4K); + + if (wants_padded_bias()) { + assert(jcp_.ngroups == 1); + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp_.oc); + } + } + + bool set_default_formats() { + using namespace format_tag; + return set_default_formats_common(nChw16c, any, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd); + ~jit_avx512_core_fp32_wino_conv_2x3_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto wei = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bia = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + + if (pd()->jcp_.small_mb) + execute_forward_small_mb(src, wei, bia, dst, this->scratchpad(ctx)); + else + execute_forward_mbN(src, wei, bia, dst, this->scratchpad(ctx)); + + return status::success; + } + +private: + void execute_forward_small_mb(const float *src, const float *wei, + const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_mbN(const float *src, const float *wei, + const float *bia, float *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t *kernel_; + jit_avx512_core_fp32_wino_conv_2x3_src_trans_t *src_trans_; + jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t *dst_trans_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp new file mode 100644 index 0000000000..96325e3ade --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp @@ -0,0 +1,1020 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifdef __INTEL_COMPILER +#include <immintrin.h> +#endif + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_fp32_wino_conv_4x3.hpp" + +#ifndef _MSC_VER +#define pragma_unroll _Pragma("unroll") +#else +#define pragma_unroll +#endif + + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +template <bool is_fwd> +void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd> +::weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) const +{ + float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f, + 1.13777777777778f, 0.430252100840336f, 0.179271708683473f}; + const int kh = 3; + const int kw = 3; + float Fw[alpha][alpha][simd_w][simd_w]; + float F[kh][kw][simd_w][simd_w]; + float T[alpha][3][simd_w]; + auto p = jit_wino_transform_call_s(); + + p.src = wp; + p.dst = twp; + p.G = G; + p.M = F; + p.Mw = Fw; + p.T = T; + + kernel_->weights_transform_data_ker(&p); +} + +template<bool is_fwd> +void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::output_transform_data +(int image, const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const { + + float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + float T[tile_size][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + p.src = toutp; + p.dst = pout_b; + p.G = G; + p.M = O; + p.Mw = Ow; + p.T = T; + p.bias = bias; + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + + kernel_->output_transform_data_ker(&p); + + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template<bool is_fwd> +void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd> +::output_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, + float *toutp, float *outp, float *bias) const { + + float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f}; + float Ow[alpha][alpha][simd_w]; + float O[tile_size][tile_size][simd_w]; + float T[tile_size][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + p.src = toutp; + p.dst = outp; + p.G = G; + p.M = O; + p.Mw = Ow; + p.T = T; + p.bias = bias; + + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? jcp.oh : jcp.ih; + + int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; + + for (int nb_tile_block_ur = 0; + nb_tile_block_ur < jcp.nb_tile_block_ur; + nb_tile_block_ur++) { + + for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; + tile_block_ur++) { + int img = tile_index / (jcp.jtiles * jcp.itiles); + int ti = tile_index % jcp.itiles; + int tj = (tile_index / jcp.itiles) % jcp.jtiles; + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block) + * outh * outw * jcp.dimM_simd_block; + + kernel_->output_transform_data_ker(&p); + + tile_index++; + } + } +} + + +template<bool is_fwd> +void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd> + ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const +{ + float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + + auto p = jit_wino_transform_call_s(); + + p.src = inp; + p.dst = tinp; + p.G = G; + p.M = I; + p.Mw = Iw; + p.T = T; + + int tile_base_index = image * jcp.itiles * jcp.jtiles; + int tile_block_ur = tile_base_index % jcp.tile_block_ur; + int nb_tile_block_ur = + (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur; + int tile_block = + (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur; + + for (int tj = 0; tj < jcp.jtiles; tj++) { + for (int ti = 0; ti < jcp.itiles; ti++) { + + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tile_block = tile_block; + p.tj = tj; + p.ti = ti; + + kernel_->input_transform_data_ker(&p); + + tile_block_ur++; + if (tile_block_ur >= jcp.tile_block_ur) { + tile_block_ur = 0; + nb_tile_block_ur++; + } + if (nb_tile_block_ur >= jcp.nb_tile_block_ur) { + nb_tile_block_ur = 0; + tile_block++; + } + } + } +} + +template <bool is_fwd> +void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd> + ::input_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const +{ + float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float Iw[alpha][alpha][simd_w]; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + + array_offset_calculator<float, 5> input(inp, + jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w); + array_offset_calculator<float, 7> output(tinp, + alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + auto p = jit_wino_transform_call_s(); + + p.dst = tinp; + p.G = G; + p.M = I; + p.Mw = Iw; + p.T = T; + + + int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur; + + for (int nb_tile_block_ur = 0; + nb_tile_block_ur < jcp.nb_tile_block_ur; + nb_tile_block_ur++) { + + for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur; + tile_block_ur++) { + + int img = tile_index / (jcp.jtiles * jcp.itiles); + int ti = tile_index % jcp.itiles; + int tj = (tile_index / jcp.itiles) % jcp.jtiles; + float *pinp_b = &(input(img, 0, 0, 0, 0)); + + p.src = pinp_b; + p.tile_block_ur = tile_block_ur; + p.nb_tile_block_ur = nb_tile_block_ur; + p.tj = tj; + p.ti = ti; + + kernel_->input_transform_data_ker(&p); + + tile_index++; + } + } +} + +template <bool is_fwd> +void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + /* Notation: + FWD: dimM:oc, dimN:ntiles, dimK:ic, + BWD: dimM:ic, dimN:ntiles, dimK:oc, + FWD/BWD: V: src/diff_dst transform, U:weight transform, + M:dst/diff_src transform */ + array_offset_calculator<float, 5> input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, + jcp.dimK_reg_block); + array_offset_calculator<float, 5> output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, + jcp.dimM_simd_block); + array_offset_calculator<float, 6> weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator<float, 2> bias(bias_ptr, + jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block); + + array_offset_calculator<float, 8> M(is_fwd + ? scratchpad.template get<float>(key_wino_M) + : scratchpad.template get<float>(key_wino_V), + jcp.dimN_nb_block, jcp.dimM_nb_block, + alpha, alpha, + jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + + auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference) + ? wei_ptr + : scratchpad.template get<float>(key_wino_U); + + array_offset_calculator<float, 8> U(wino_wei, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + array_offset_calculator<float, 8> V(is_fwd + ? scratchpad.template get<float>(key_wino_V) + : scratchpad.template get<float>(key_wino_M), + jcp.dimN_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block); + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + { + + parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block, + [&](int img, int K_blk1, int K_blk2) { + input_transform_data(img, jcp, + &(input(img, K_blk1 * jcp.dimK_block + K_blk2, + 0, 0, 0)), + &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0))); + }); + + if (jcp.prop_kind != prop_kind::forward_inference) { + parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), + (jcp.ic_block * jcp.ic_reg_block), + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights( + ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2, + ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2, + 0, 0, 0, 0)), + U_base_ptr); + }); + } + + parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, + [&](int N_blk1, int oj, int oi, int M_blk1) { + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; + K_blk1++) + for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++) + kernel_->gemm_loop_ker( + (float *)&(M(N_blk1, M_blk1, oj, oi, + N_blk2, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, + K_blk1, 0, 0, 0, 0)), + (const float *)&(V(N_blk1, oj, oi, + N_blk2, K_blk1, 0, 0, 0)), K_blk1); + }); + + parallel_nd(jcp.mb, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block), + [&](int img, int M_blk1, int M_blk2) { + const int M_blk = + M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + output_transform_data(img, jcp, p_ops, + &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(img, M_blk, 0, 0, 0)), bias_ptr); + }); + + } +} + +template <bool is_fwd> +void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD( + float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const auto &p_ops = attr_->post_ops_; + + const int inph = is_fwd ? jcp.ih : jcp.oh; + const int inpw = is_fwd ? jcp.iw : jcp.ow; + const int outh = is_fwd ? jcp.oh : jcp.ih; + const int outw = is_fwd ? jcp.ow : jcp.iw; + + array_offset_calculator<float, 5> input(inp_ptr, + jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block); + array_offset_calculator<float, 5> output(out_ptr, + jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block); + array_offset_calculator<float, 6> weights(wei_ptr, + jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw, + jcp.ic_simd_block, jcp.oc_simd_block); + array_offset_calculator<float, 2> bias(bias_ptr, + jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block); + + auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference) + ? wei_ptr + : scratchpad.template get<float>(key_wino_U); + + array_offset_calculator<float, 8> U(wino_wei, + jcp.dimM_nb_block, + alpha, alpha, + jcp.dimK_nb_block, + jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, jcp.dimM_simd_block); + + array_offset_calculator<float, 8> M(is_fwd + ? scratchpad.template get<float>(key_wino_M) + : scratchpad.template get<float>(key_wino_V), + 0, jcp.dimM_nb_block, alpha, alpha, + jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block, + jcp.dimN_reg_block, jcp.dimM_simd_block); + array_offset_calculator<float, 8> V(is_fwd + ? scratchpad.template get<float>(key_wino_V) + : scratchpad.template get<float>(key_wino_M), + 0, alpha, alpha, jcp.dimN_block, + jcp.dimK_nb_block, jcp.dimK_block, + jcp.dimN_reg_block, jcp.dimK_reg_block); + + const bool wants_padded_bias = jcp.with_bias + && jcp.oc_without_padding != jcp.oc; + float last_slice_bias[simd_w] = {0}; + if (wants_padded_bias) { + for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc) + last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc); + } + + if (jcp.prop_kind != prop_kind::forward_inference) { + + parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block), + [&](int ofm1, int ifm1, int ofm2, int ifm2) { + float *U_base_ptr = is_fwd + ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0)) + : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0)); + weight_transform_data(jcp, + &(weights( + ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2, + ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2, + 0, 0, 0, 0)), + U_base_ptr); + }); + } + + parallel_nd(jcp.tile_block, [&](int tile_block) { + int ithr = mkldnn_get_thread_num(); + + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) { + for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) { + + input_transform_tileblock_data( + tile_block, jcp, + &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)), + &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0))); + } + } + + for (int oj = 0; oj < alpha; oj++) { + for (int oi = 0; oi < alpha; oi++) { + for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) + for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) + for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++) + kernel_->gemm_loop_ker( + (float *)&(M(ithr, M_blk1, oj, oi, + N_blk, 0, 0, 0)), + (const float *)&(U(M_blk1, oj, oi, K_blk1, + 0, 0, 0, 0)), + (const float *)&(V(ithr, oj, oi, + N_blk, K_blk1, 0, 0, 0)), K_blk1); + } + } + + for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) { + for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block; + M_blk2++) { + const int M_blk = + M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2; + + float *bias_ptr = wants_padded_bias + && M_blk == jcp.dimM / jcp.dimM_simd_block - 1 + ? last_slice_bias : &bias(M_blk, 0); + + output_transform_tileblock_data(tile_block, jcp, p_ops, + &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)), + &(output(0, M_blk, 0, 0, 0)), bias_ptr); + } + } + }); +} + +template struct _jit_avx512_core_fp32_wino_conv_4x3_t<true>; +template struct _jit_avx512_core_fp32_wino_conv_4x3_t<false>; + +namespace { + +void subarray_sum(size_t num_arrs, float *output, size_t nelems, + float *input_ptrs[], size_t input_starts[], size_t input_ends[]) { + using namespace nstl; + const size_t block_size = 16 * 1024 / sizeof(float); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + +PRAGMA_OMP(parallel) + { + const int ithr = mkldnn_get_thread_num(); + const int nthr = mkldnn_get_num_threads(); + size_t start{ 0 }, end{ 0 }; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + size_t input_start = max(start_e, min(input_starts[0], end_e)); + size_t input_end = max(start_e, min(input_ends[0], end_e)); + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < input_start; e++) { + output[e] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] = input_ptrs[0][e]; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_end; e < end_e; e++) { + output[e] = 0.f; + } + + for (size_t a = 1; a < num_arrs; a++) { + input_start = max(start_e, input_starts[a]); + input_end = min(input_ends[a], end_e); + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + size_t input_start = max(start_e, min(input_starts[0], end_e)); + size_t input_end = max(start_e, min(input_ends[0], end_e)); + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < input_start; e++) { + output[e] = 0.f; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] = input_ptrs[0][e]; + } + + PRAGMA_OMP_SIMD() + for (size_t e = input_end; e < end_e; e++) { + output[e] = 0.f; + } + + for (size_t a = 1; a < num_arrs; a++) { + input_start = max(start_e, input_starts[a]); + input_end = min(input_ends[a], end_e); + + PRAGMA_OMP_SIMD() + for (size_t e = input_start; e < input_end; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + } +} + +const int max_threads_number = 1024; + +// Sum to the first buffer array +void array_sum(size_t num_arrs, float *output, + size_t nelems, float *input_ptrs[], bool reduce_to_first = true) { + const size_t block_size = 16 * 1024 / sizeof(float); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + +PRAGMA_OMP(parallel) + { + const size_t ithr = mkldnn_get_thread_num(); + const size_t nthr = mkldnn_get_num_threads(); + size_t start{ 0 }, end{ 0 }; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + if (!reduce_to_first) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = input_ptrs[0][e]; + } + } + for (size_t a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + if (!reduce_to_first) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = input_ptrs[0][e]; + } + } + for (size_t a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += input_ptrs[a][e]; + } + } + } + } +} +} //bwdw namespace + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t:: +_execute_backward_weights_SDGtWo(const float *ptr_src, + const float *ptr_diff_dst, float *ptr_diff_weights, + float *ptr_diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + array_offset_calculator<float, 5> src((float *)ptr_src, + jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator<float, 6> diff_weights(ptr_diff_weights, + jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + + array_offset_calculator<float, 8> Us(scratchpad.get<float>(key_wino_U), + 0, alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc + * jcp.ic / jcp.nb_ic; + array_offset_calculator<float, 7>diff_weights_prv( + scratchpad.get<float>(key_wino_U) + U_sz, + 0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + + array_offset_calculator<float, 8> M(scratchpad.get<float>(key_wino_M), + 0, alpha, alpha, + jcp.oc_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur, + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator<float, 7> V(scratchpad.get<float>(key_wino_V), + 0, alpha, alpha, + jcp.ic_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur, + jcp.ic_simd_block); + + array_offset_calculator<float, 2> diff_bias_prv( + scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc); + + auto trans_ker_p = jit_wino_transform_call_s(); + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f, + 0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f, + 1.13777777777778f}; + float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f}; + +PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T)) +{ + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc / simd_w, + [&](int ithr, int ofm){ + float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w)); + PRAGMA_OMP_SIMD() + for (int v = 0; v < simd_w; v++) { + pdbias[v] = 0.0f; + } + }); + } + + int ithr = mkldnn_get_thread_num(); + for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) { + int first_tblk = 0; +PRAGMA_OMP(for) + for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) { + int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur; + int img = tile_index / (jcp.itiles * jcp.jtiles); + trans_ker_p.ti = tile_index % jcp.itiles; + trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles; + trans_ker_p.M = I; + trans_ker_p.T = T; + trans_ker_p.G = G_I_3x3_4x4; + for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) { + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0)); + kernel_->src_transform(&trans_ker_p); + } + + for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) { + trans_ker_p.G = G_W_3x3_4x4; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) { + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block; + trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0)); + if (jcp.with_bias && ifm1 == 0) { + trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w)); + kernel_->diff_dst_transform_wbias(&trans_ker_p); + } else { + kernel_->diff_dst_transform(&trans_ker_p); + } + } + + for (int oj = 0; oj < alpha; ++oj) { + for (int oi = 0; oi < alpha; ++oi) { + kernel_->gemm_loop_ker_first_iter( + &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)), + &(M(ithr, oj, oi, 0, 0, 0, 0, 0)), + &(V(ithr, oj, oi, 0, 0, 0, 0))); + } + } + trans_ker_p.G = G_O_3x3_4x4; + for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) { + for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) { + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block + + ofm3; + for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) { + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(Us(ithr, 0, 0, + ofm2, ifm2, 0, ofm3, 0)); + trans_ker_p.dst = (float *)&(diff_weights_prv(ithr, + ofm, ifm, 0, 0, 0, 0)); + if (first_tblk == 0) { + kernel_->diff_weights_transform(&trans_ker_p); + } else { + kernel_->diff_weights_transform_accum(&trans_ker_p); + } + } + } + } + } + ++first_tblk; + } + } +} + + // Reduce diff-weights + { + float *output = ptr_diff_weights; + float *input_base = scratchpad.get<float>(key_wino_U) + U_sz; + int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw; + float *input_ptrs[max_threads_number]; + for (int i = 0; i < nthreads; ++i) { + input_ptrs[i] = input_base + nelems * i; + } + array_sum(nthreads, output, nelems, input_ptrs, false); + + if (jcp.with_bias) { + output = ptr_diff_bias; + input_base = scratchpad.get<float>(key_conv_bia_reduction); + for (int i = 0; i < nthreads; ++i) { + input_ptrs[i] = input_base + jcp.oc * i; + } + array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs, + false); + } + } +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t:: +_execute_backward_weights_S_D_Giot_W(const float *ptr_src, + const float *ptr_diff_dst, float *ptr_diff_weights, + float *ptr_diff_bias, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const int nthreads = jcp.nthr; + + array_offset_calculator<float, 5> src((float *)ptr_src, + jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w); + array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst, + jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w); + array_offset_calculator<float, 6> diff_weights((float *)ptr_diff_weights, + jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w); + array_offset_calculator<float, 1> diff_bias((float *)ptr_diff_bias, jcp.oc); + + array_offset_calculator<float, 9> U(scratchpad.get<float>(key_wino_U), + jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + const int U_size = jcp.oc * jcp.ic * alpha * alpha; + array_offset_calculator<float, 10> Us( + scratchpad.get<float>(key_wino_U) + U_size, + 0, jcp.nb_ic, jcp.nb_oc, + alpha, alpha, + jcp.oc_block, jcp.ic_block, + jcp.ic_simd_block, + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator<float, 9> M(scratchpad.get<float>(key_wino_M), + jcp.nb_oc, + jcp.tile_block, + alpha, alpha, + jcp.oc_block, + jcp.nb_tile_block_ur, + jcp.tile_block_ur , + jcp.oc_reg_block, + jcp.oc_simd_block); + + array_offset_calculator<float, 8> V(scratchpad.get<float>(key_wino_V), + jcp.nb_ic, + jcp.tile_block, + alpha, alpha, + jcp.ic_block, + jcp.nb_tile_block_ur, jcp.tile_block_ur, + jcp.ic_simd_block); + + array_offset_calculator<float, 2> diff_bias_prv( + scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc); + + size_t input_starts[max_threads_number] = {0}; + size_t input_ends[max_threads_number] = {0}; + size_t first_tblk = 0; + + auto trans_ker_p = jit_wino_transform_call_s(); + float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, + 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f}; + float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, + 0.119514472455649f, 0.430252100840336f, 0.168067226890756f, + 0.179271708683473f, 0.403361344537815f, 1.13777777777778f}; + float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f}; + float I[alpha][alpha][simd_w]; + float T[alpha][alpha][simd_w]; + +PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T)) +{ + if (jcp.with_bias) { + parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) { + diff_bias_prv(ithr, ofm) = 0.0f; + }); + } + + trans_ker_p.G = G_I_3x3_4x4; + trans_ker_p.M = I; + trans_ker_p.T = T; + + parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb, + [&](int ifm1, int ifm2, int img){ + size_t ifm = ifm1 * jcp.ic_block + ifm2; + size_t tile_base_index = img * (jcp.itiles * jcp.jtiles); + size_t tblk3 = tile_base_index % jcp.tile_block_ur; + size_t tblk2 = (tile_base_index / jcp.tile_block_ur) + % jcp.nb_tile_block_ur; + size_t tblk1 = (tile_base_index / jcp.tile_block_ur) + / jcp.nb_tile_block_ur; + trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3; + trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0)); + kernel_->src_transform(&trans_ker_p); + }); + + int ithr = mkldnn_get_thread_num(); + trans_ker_p.G = G_W_3x3_4x4; + parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb, + [&](int ofm1, int ofm2, int img){ + int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block; + size_t tile_base_index = img * (jcp.itiles * jcp.jtiles); + size_t tblk3 = tile_base_index % jcp.tile_block_ur; + size_t tblk2 = (tile_base_index / jcp.tile_block_ur) + % jcp.nb_tile_block_ur; + size_t tblk1 = (tile_base_index / jcp.tile_block_ur) + / jcp.nb_tile_block_ur; + trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3; + trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0)); + trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0)); + if (jcp.with_bias) { + trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w)); + kernel_->diff_dst_transform_wbias(&trans_ker_p); + } else { + kernel_->diff_dst_transform(&trans_ker_p); + } + }); + + PRAGMA_OMP(barrier) + + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block, + [&](int ifm1, int ofm1, int oj, int oi, int tblk1){ + if (first_tblk == 0) { + input_starts[ithr] = + (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, + 0, 0)) + - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0, + 0, 0, 0)); + input_ends[ithr] = input_starts[ithr] + + jcp.oc_block * jcp.ic_block + * jcp.ic_simd_block * jcp.oc_reg_block + * jcp.oc_simd_block; + } + else if (tblk1 == 0) { + input_ends[ithr] += jcp.oc_block * jcp.ic_block + * jcp.ic_simd_block * jcp.oc_reg_block + * jcp.oc_simd_block; + } + + if (first_tblk == 0 || tblk1 == 0) { + kernel_->gemm_loop_ker_first_iter( + &(Us(ithr, ifm1, ofm1, oj, oi, + 0, 0, 0, 0, 0)), + &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)), + &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0))); + } else { + kernel_->gemm_loop_ker( + &(Us(ithr, ifm1, ofm1, oj, oi, + 0, 0, 0, 0, 0)), + &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)), + &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0))); + } + ++first_tblk; + }); +} + + // Reduce diff-weights + { + float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0)); + size_t nelems = jcp.ic * jcp.oc * alpha * alpha; + float *input_ptrs[max_threads_number]; + for (int i = 0; i < nthreads; ++i) + input_ptrs[i] = output + nelems * (i + 1); + subarray_sum(nthreads, output, nelems, input_ptrs, + input_starts, input_ends); + } + + trans_ker_p.G = G_O_3x3_4x4; +PRAGMA_OMP(parallel firstprivate(trans_ker_p)) + { + parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block, + [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){ + int ofm = (ofm1 * jcp.oc_block + ofm2) + * jcp.oc_reg_block + ofm3; + int ifm = ifm1 * jcp.ic_block + ifm2; + trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0, + ofm2, ifm2, 0, ofm3, 0)); + trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm, + 0, 0, 0, 0)); + kernel_->diff_weights_transform(&trans_ker_p); + }); + } + + if (jcp.with_bias) { + parallel_nd(jcp.oc / simd_w, [&](int ofm1) { + float* pbias = &(diff_bias(ofm1 * simd_w)); + float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w)); + + const int blk_sz = ofm1 == jcp.oc / simd_w - 1 + ? jcp.oc_without_padding - ofm1 * simd_w : simd_w; + + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) { + pbias[ofm2] = pbias_prv[ofm2]; + } + + for (int ithr = 1; ithr < nthreads; ++ithr) { + pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w)); + PRAGMA_OMP_SIMD() + for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) { + pbias[ofm2] += pbias_prv[ofm2]; + } + } + }); + } +} + +} +} +} +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp new file mode 100644 index 0000000000..f1a56aac70 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp @@ -0,0 +1,386 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP +#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace winograd_avx512_core { +inline void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_winograd_conf_t &jcp) { + using namespace utils; + using namespace memory_tracking::names; + + size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc; + size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic * jcp.itiles + * jcp.jtiles; + size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc * jcp.itiles + * jcp.jtiles; + + switch (jcp.sched_policy) { + case WSCHED_DATA_W_SGD: + V_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur + * jcp.tile_block_ur * jcp.ic; + M_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur + * jcp.tile_block_ur * jcp.oc; + break; + case WSCHED_WEI_SDGtWo: + U_sz = (size_t)jcp.nthr * (alpha * alpha * jcp.oc + * (jcp.ic / jcp.nb_ic) + jcp.ic * jcp.oc * jcp.kh * jcp.kw); + M_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block) + * (jcp.oc / jcp.nb_oc); + V_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block) + * (jcp.ic / jcp.nb_ic); + break; + case WSCHED_WEI_S_D_Giot_W: + U_sz = (size_t)(jcp.nthr + 1) * alpha * alpha * jcp.ic * jcp.oc; + M_sz = (size_t)alpha * alpha * jcp.oc * jcp.ntiles; + V_sz = (size_t)alpha * alpha * jcp.ic * jcp.ntiles; + break; + default: break; + } + + scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M); + scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M); + scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M); + + if (one_of(jcp.sched_policy, WSCHED_WEI_SDGtWo, WSCHED_WEI_S_D_Giot_W)) { + size_t br_sz = (size_t)jcp.nthr * jcp.oc; + scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M); + } +} +} + +template <bool is_fwd> +struct _jit_avx512_core_fp32_wino_conv_4x3_t { + + _jit_avx512_core_fp32_wino_conv_4x3_t( + const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr) + : kernel_(nullptr), attr_(attr) { + kernel_ = new _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(jcp); + } + + ~_jit_avx512_core_fp32_wino_conv_4x3_t() { delete kernel_; } + + protected: + void weight_transform_data(const jit_conv_winograd_conf_t &jcp, + float *wp, float *twp) const; + void input_transform_data(int image, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const; + void input_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, + float *inp, float *tinp) const; + void output_transform_data(int image, + const jit_conv_winograd_conf_t &jcp, + const post_ops_t &p_ops, float *toutp, float *pout_b, + float *bias) const; + void output_transform_tileblock_data(int tile_block, + const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops, + float *toutp, float *outp, float *bias) const; + void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + void _execute_data_W_SGD(float *inp_ptr, float *out_ptr, + float *wei_ptr, float *bias_ptr, + const memory_tracking::grantor_t &scratchpad) const; + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel *kernel_; + const primitive_attr_t *attr_; +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_fwd_t + : _jit_avx512_core_fp32_wino_conv_4x3_t<true> + , public cpu_primitive_t + { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel:: + init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, + *attr()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = desc()->prop_kind == prop_kind::forward_training + ? (with_groups() ? gOIhw16i16o : OIhw16i16o) : any; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_fwd_t(const pd_t *apd) + : _jit_avx512_core_fp32_wino_conv_4x3_t<true>(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) + {} + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + switch ((pd()->jcp_).sched_policy) { + case WSCHED_DATA_W_S_G_D: + this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights, + (float *)bias, scratchpad); + break; + case WSCHED_DATA_W_SGD: + this->_execute_data_W_SGD((float *)src, dst, (float *)weights, + (float *)bias, scratchpad); + break; + default: + break; + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t + : _jit_avx512_core_fp32_wino_conv_4x3_t<false>, + public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t); + + status_t init() { + bool ok = true + && mkldnn_thr_syncable() + && desc()->prop_kind == prop_kind::backward_data + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel + ::init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t(const pd_t *apd) + : _jit_avx512_core_fp32_wino_conv_4x3_t<false>(apd->jcp_, apd->attr()) + , cpu_primitive_t(apd, true) + {} + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC); + + auto scratchpad = this->scratchpad(ctx); + + switch ((pd()->jcp_).sched_policy) { + case WSCHED_DATA_W_S_G_D: + this->_execute_data_W_S_G_D((float *)diff_dst, diff_src, + (float *)weights, NULL, scratchpad); + break; + + case WSCHED_DATA_W_SGD: + this->_execute_data_W_SGD((float *)diff_dst, diff_src, + (float *)weights, NULL, scratchpad); + break; + + default: + break; + } + + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t + : public cpu_primitive_t { + struct pd_t : public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""), + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t); + + status_t init() { + bool ok = true + && mkldnn_thr_syncable() + && desc()->prop_kind == prop_kind::backward_weights + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && set_default_formats(); + if (!ok) + return status::unimplemented; + + status_t status = + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: + init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(), + *diff_weights_md()); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + auto scratchpad = scratchpad_registry().registrar(); + winograd_avx512_core::init_scratchpad(scratchpad, jcp_); + + return status; + } + + jit_conv_winograd_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o; + return set_default_formats_common(nChw16c, wei_fmt, nChw16c); + } + }; + + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd, true) + , kernel_(nullptr) + { + kernel_ = new jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel( + pd()->jcp_); + } + + ~jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t() + { + delete kernel_; + } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + switch (kernel_->jcp.sched_policy) { + case WSCHED_WEI_SDGtWo: + _execute_backward_weights_SDGtWo(src, diff_dst, diff_weights, + diff_bias, scratchpad(ctx)); + break; + case WSCHED_WEI_S_D_Giot_W: + _execute_backward_weights_S_D_Giot_W(src, diff_dst, diff_weights, + diff_bias, scratchpad(ctx)); + break; + default: + assert(kernel_->jcp.sched_policy != WSCHED_INVALID); + break; + } + return status::success; + } + +private: + void _execute_backward_weights_SDGtWo(const float *src, + const float *diff_dst, float *diff_weights, float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + void _execute_backward_weights_S_D_Giot_W(const float *src, + const float *diff_dst, float *diff_weights, float *diff_bias, + const memory_tracking::grantor_t &scratchpad) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp new file mode 100644 index 0000000000..0d64a2d13a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp @@ -0,0 +1,2596 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include <math.h> + +#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_wino_transform_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace mkldnn::impl::utils; + +unsigned int L1_cache_size = get_cache_size(1, true); +unsigned int L2_cache_size = get_cache_size(2, true); +unsigned int LLC_data_size = get_cache_size(3, false); + +// the test funtion takes jcp, the candidate and the current best. +// it returns true if the new candidate is better +int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, + int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) +{ + int best_divisor = default_best; + auto test_num + = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { + if (test(jcp, num, best_divisor)) { + best_divisor = num; + } + }; + + for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { + if (number % divisor == 0) { + test_num(jcp, divisor); + test_num(jcp, number / divisor); + } + } + + return best_divisor; +} + +namespace { +bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { + /* Determines if current winograd implementation is faster than direct. + Following conditions are empirical and based on performance data */ + unsigned int ncores_per_socket = + cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); + unsigned int nthreads = mkldnn_get_max_threads(); + + if (jcp.prop_kind == prop_kind::forward_inference) { + return jcp.mb >= 4; + } else if (nthreads > ncores_per_socket) { + double src_dst_transforms_per_core = alpha * alpha + * (jcp.ic + jcp.oc) + * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size) + * ((jcp.ow + tile_size - 1) / tile_size) + * sizeof(float) / 1024. / 1024. / nthreads; + double wei_transform = alpha * alpha + * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.; + + if (jcp.prop_kind == prop_kind::backward_weights) { + if (src_dst_transforms_per_core < 0.3 + || (src_dst_transforms_per_core <= 28 && wei_transform < 4)) + return false; + else + return true; + } else { + if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02) + return false; + } + } + + return jcp.mb > 8; +} +} + +/* assumes 512 bits registers */ +/* TODO: add support for strides */ +/* TODO: handle the prefetch distance automatically */ +typedef enum cache_t_ { L1, L2, L3 } cache_t; + +template <typename data_t> +struct prefetcher_t { + prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr, + cache_t cache_type, size_t block_size, /* in number of elements*/ + int nb_instructions_in_block, int fma_ipc) + : cg_(generator) + , reg_base_addr_(reg_base_addr) + , cache_type_(cache_type) + , cache_block_size_(block_size) + { + nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t)); + prefetch_spread_ + = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_); + prefetch_blk_ + = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block); + + /* assumption: when fetch in Li, data is already in L(i+1) */ + int cache_latency; + switch (cache_type_) { + case L1: cache_latency = 14; break; + case L2: cache_latency = 250; break; + case L3: cache_latency = 250; break; + } + + prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_); + } + + void prefetch(int instruction_number) + { + if (instruction_number % prefetch_spread_ == 0) { + for (int i = 0; (i < prefetch_blk_) + && (prefetches_issued_ < nb_cache_lines_to_prefetch_); + i++, prefetches_issued_++) { + prefetch_inst_(cg_->EVEX_compress_addr( + reg_base_addr_, (cache_block_size_ * prefetch_distance_) + * sizeof(data_t) + + (prefetches_issued_ * 64))); + } + } + } + +private: + void prefetch_inst_(const Xbyak::Address &addr) + { + switch (cache_type_) { + case L1: cg_->prefetcht0(addr); break; + case L2: cg_->prefetcht1(addr); break; + case L3: cg_->prefetcht2(addr); break; + default: + break; // TODO: raise an exception or put an assert + } + } + + jit_generator *cg_; + Xbyak::Reg64 reg_base_addr_; + cache_t cache_type_; + int cache_block_size_ = 0; + int nb_cache_lines_to_prefetch_ = 0; + int prefetches_issued_ = 0; + int prefetch_spread_ = 0; + int prefetch_blk_ = 0; + int prefetch_distance_ = 0; +}; + +// utilities to support kernel parameter selection +bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp, + int dimN_block, float C2_min, float C2_max) { + float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic) + * dimN_block * jcp.dimN_reg_block + + div_up(jcp.ic * jcp.oc,mkldnn_get_max_threads())) * (float)sizeof(float); + float L2_lb = C2_min * L2_cache_size; + float L2_ub = C2_max * L2_cache_size; + return (block_size > L2_lb && block_size < L2_ub); +} + +bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block, + int dimM_block, float C1_min, float C1_max) { + float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block + * jcp.dimK_reg_block * jcp.dimM_reg_block + + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block + + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block) + * (float)sizeof(float); + float L1_lb = C1_min * L1_cache_size; + float L1_ub = C1_max * L1_cache_size; + return (gemm_block_size > L1_lb && gemm_block_size < L1_ub); +} +bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block + + dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block * dimM_reg_block + + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} +bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, + int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) +{ + float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block + * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L1_cache_size; + return (lhs < rhs); +} +bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, + int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block, + int dimM_simd_block, float C) +{ + float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block + * dimM_simd_block * dimM_reg_block + + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block + * dimM_simd_block * dimM_reg_block + + nb_dimN_reg_block * dimK_nb_block * dimK_block + * dimN_reg_block * dimK_reg_block) + * (float)sizeof(float); + float rhs = C * L2_cache_size; + return (lhs < rhs); +} + +bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block, + int dimN_block, int dimN_reg_block, int dimK, float C1, float C2) +{ + float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK + * (float)sizeof(float); + float B_size = dimN_block * dimN_reg_block * dimK + * (float)sizeof(float); + return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size); +} +} + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::gemm_loop_generate() +{ + // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++) + // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block; + // dimM_reg_block++) // unrolled + // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++) + // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block; + // dimK_reg_block++) // unrolled + // for (int tile =0; tile < jcp.dimN_reg_block; tile++) + // C[dimM_block][dimM_reg_block][tile] += + // A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block] + // * broadcast(B[dimK_block][tile][dimK_reg_block]); + // Notes: + // jcp.kernel_kind defines embedded or explicit broadcast + // dimM_reg_block=1 for embedded bcast kernel + + auto zmm_srcA = [=]() { + return Xbyak::Zmm(0); + }; + auto zmm_srcB = [=](int tile) { + int idx = 1 + tile; + assert(idx < 1 + jcp.dimN_reg_block); + return Xbyak::Zmm(idx); + }; + auto zmm_dstC = [=](int dimM_reg_block, int tile) { + int idx{0}; + if (jcp.kernel_kind == embd_bcast) + idx = 1 + tile; + else + idx = 1 + jcp.dimN_reg_block + + dimM_reg_block * jcp.dimN_reg_block + tile; + assert(idx < 32); + return Xbyak::Zmm(idx); + }; + + auto prepare_output = [=]() { + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block, tile); + vpxord(zmm, zmm, zmm); + } + } + }; + auto store_output = [=](bool output_is_aligned) { + Label save; + cmp(reg_is_beta_zero, 0); + je(save, T_NEAR); + + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block,tile); + int output_offset + = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64; + vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset)); + } + } + + L(save); + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + Zmm zmm = zmm_dstC(dimM_reg_block,tile); + int output_offset + = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64; + + // In W_SGD, output will be reused. + if (output_is_aligned + && jcp.dimK_nb_block == 1 + && jcp.sched_policy == WSCHED_DATA_W_S_G_D + && (jcp.dimN * jcp.dimM * alpha * alpha + * sizeof(float) > 2 * LLC_data_size)) + vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm); + else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm); + } + } + }; + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop; + + if (jcp.dimM_block > 1) { + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + } + + prepare_output(); + + if (jcp.dimK_block > 1) { + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + } + + for (int dimK_reg_block = 0; + dimK_reg_block < jcp.dimK_reg_block; + dimK_reg_block ++) { + + if (jcp.kernel_kind == expl_bcast) { + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + vbroadcastss(zmm_srcB(tile), + ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]); + } + } + + /* Performing the fmas */ + + for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block; + dimM_reg_block++) { + + vmovups(zmm_srcA(), + zword[reg_srcA + + jcp.dimK_reg_block * jcp.dimK_block * 64 + * dimM_reg_block + + dimK_reg_block * 64] + ); + + for (int tile = 0; tile < jcp.dimN_reg_block; tile++) { + if (jcp.kernel_kind == expl_bcast) + vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(), + zmm_srcB(tile)); + else + vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(), + EVEX_compress_addr(reg_srcB, + 64 * tile + dimK_reg_block * 4, true)); + } + } + } + add(reg_srcA, jcp.dimK_reg_block * 64); + add(reg_srcB, jcp.dimN_reg_block * 64); + if (jcp.dimK_block > 1) { + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + Label unaligned_store, end_store; + test(reg_dstC, cpu_isa_traits<avx512_core>::vlen - 1); + jnz(unaligned_store, T_NEAR); + store_output(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + store_output(false); + } + L(end_store); + + if (jcp.dimM_block > 1) { + sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64); + add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64); + if (jcp.kernel_kind == expl_bcast) { + add(reg_srcA, + (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64 + * jcp.dimK_block); + } + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + }; + + /* Preamble */ + preamble(); + + /* kernel */ + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::weights_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int kh = jcp.kh; + int kw = jcp.kw; + + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_zero = Xbyak::Zmm(30); + + auto zmm_M = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_MT = [=](int i) { + return Xbyak::Zmm(i + simd_w); + }; + + auto zmm_G = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_F = [=](int i) { + return Xbyak::Zmm(alpha + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(alpha + 3 + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(2 * alpha + 3 + i); + }; + + auto zmm_load = [=](int i) { + return Xbyak::Zmm(i); + }; + + auto init_G = [=]() { + mov(wreg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < alpha; i++) { + vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]); + } + vpxord(zmm_zero, zmm_zero, zmm_zero); + }; + + auto trans16x16 = [=]() { + for (int i = 0; i < simd_w; i+=2 ) { + vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]); + vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]); + vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1)); + vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1)); + } + for (int i = 0; i < simd_w; i+=4 ) { + vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2)); + vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2)); + vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3)); + vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3)); + } + for (int i = 0; i < simd_w; i += 8) { + vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88); + vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88); + vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88); + vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88); + vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd); + vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd); + vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd); + vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd); + } + { + int i = 0; + int mask = 0x88; + vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask); + vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0)); + vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask); + vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1)); + vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask); + vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2)); + vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask); + vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3)); + vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask); + vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4)); + vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask); + vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5)); + vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask); + vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6)); + vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask); + vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7)); + mask = 0xdd; + vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask); + vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8)); + vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask); + vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9)); + vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask); + vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10)); + vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask); + vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11)); + vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask); + vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12)); + vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask); + vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13)); + vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask); + vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14)); + vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask); + vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15)); + } + }; + + auto load_src = [=]() { + mov(wreg_src, ptr[param1 + GET_OFF(src)]); + mov(wreg_F, ptr[param1 + GET_OFF(M)]); + for (int j = 0; j < kh; j++) { + for (int i = 0; i < kw; i++) { + if (is_fwd) { + for (int v1 = 0; v1 < simd_w; v1++) { + int offset_src = (j * kw * simd_w * simd_w + + i * simd_w * simd_w + v1 * simd_w) * typesize; + int offset_F = (j * kw * simd_w * simd_w + + i * simd_w * simd_w + v1 * simd_w) * typesize; + vmovups(zmm_temp, ptr[wreg_src + offset_src]); + vmovups(ptr[wreg_F + offset_F], zmm_temp); + } + } else { + int offset_src = ((2 - j) * kw * simd_w * simd_w + + (2 - i) * simd_w * simd_w) * typesize; + int offset_F = (j * kw * simd_w * simd_w + + i * simd_w * simd_w) * typesize; + lea(wreg_M, ptr[wreg_src + offset_src]); + lea(wreg_MT, ptr[wreg_F + offset_F]); + trans16x16(); + } + } + } + }; + + auto store_dst = [=]() { + mov(wreg_dst, ptr[param1 + GET_OFF(dst)]); + mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]); + + Label Loop_j; + mov(wreg_cnt_j, 0); + mov(wreg_dst_aux, wreg_dst); + mov(wreg_Fw_aux, wreg_Fw); + + int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block) + * jcp.dimK_block * simd_w * simd_w; + + L(Loop_j); + { + for (int i = 0; i < alpha; i++) { + // touch pages + vmovups(zmm_load(0), ptr[wreg_Fw_aux + + (i * simd_w * simd_w) * typesize]); + mov(wreg_dst_idx, i * dim5 * typesize); + vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0)); + } + for (int i = 0; i < alpha; i++) { + for (int v1 = 1; v1 < simd_w; v1++) { + int offset_Fw = (i * simd_w * simd_w + v1 * simd_w) + * typesize; + vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]); + } + mov(wreg_dst_idx, i * dim5 * typesize); + for (int v1 = 1; v1 < simd_w; v1++) { + int offset_dst = v1 * simd_w * typesize; + vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst], + zmm_load(v1)); + } + } + add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize); + add(wreg_dst_aux, alpha * dim5 * typesize); + add(wreg_cnt_j, 1); + cmp(wreg_cnt_j, alpha); + jl(Loop_j, T_NEAR); + } + }; + + auto trans_W_4x4_3x3 = [=]() { + auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmovups(dst, a); + vfmadd231ps(dst, b, c); + }; + auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmulps(zmm_temp, b, c); + vsubps(dst, a, zmm_temp); + }; + auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vsubps(dst, zmm_zero, a); + vfnmadd231ps(dst, b, c); + }; + + mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]); + mov(wreg_F, ptr[param1 + GET_OFF(M)]); + mov(wreg_T, ptr[param1 + GET_OFF(T)]); + + Label Loop_j; + mov(wreg_cnt_j, 0); + L(Loop_j); + mov(wreg_F_aux, wreg_F); + mov(wreg_Fw_aux, wreg_Fw); + mov(wreg_temp, wreg_cnt_j); + shl(wreg_temp, 4 + 2); + lea(wreg_F_aux, ptr[wreg_F + wreg_temp]); + lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]); + + for (int i = 0; i < 3; i++) { + for (int idx = 0; idx < 3; idx ++) { + vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w + * simd_w + i * simd_w * simd_w) * typesize]); + } + vmulps(zmm_t(0), zmm_G(0), zmm_F(2)); + fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0)); + fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0)); + + vmulps(zmm_T(0), zmm_G(3), zmm_F(0)); + fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1)); + fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1)); + fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1)); + fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1)); + vmovaps(zmm_T(5), zmm_F(2)); + + for (int idx = 0; idx < 6; idx ++) { + vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w) + * typesize], zmm_T(idx)); + } + } + for (int i = 0; i < 6; i++) { + + for (int idx = 0; idx < 3; idx ++) { + vmovups(zmm_T(idx), ptr[wreg_T + + (i * 3 * simd_w + idx * simd_w) * typesize]); + } + vmulps(zmm_t(0), zmm_G(0), zmm_T(2)); + fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0)); + fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0)); + + vmulps(zmm_F(0), zmm_G(3), zmm_T(0)); + fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1)); + fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1)); + fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1)); + fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1)); + vmovaps(zmm_F(5), zmm_T(2)); + + for (int l = 0; l < 6; l++) { + vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w + + l * simd_w * simd_w) * typesize], zmm_F(l)); + } + } + add(wreg_cnt_j, 1); + cmp(wreg_cnt_j, 16); + jl(Loop_j, T_NEAR); + }; + + auto inner_loops = [=]() { + load_src(); + init_G(); + trans_W_4x4_3x3(); + store_dst(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::output_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int outw = is_fwd ? jcp.ow : jcp.iw; + int outh = is_fwd ? jcp.oh : jcp.ih; + bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; + bool with_bias = jcp.with_bias; + bool with_relu = jcp.with_eltwise; + bool with_relu_postsum = jcp.with_relu_postsum; + bool with_sum = jcp.with_sum; + + auto zmm_zero = Xbyak::Zmm(0); + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_G = [=](int i) { + return Xbyak::Zmm(1 + i); + }; + auto zmm_O = [=](int i) { + return Xbyak::Zmm(1 + alpha + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(1 + 2 * alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(1 + 3 * alpha + i); + }; + + auto init_G = [=]() { + mov(oreg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < 6; i++) { + vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]); + mov(oreg_src, ptr[param1 + GET_OFF(src)]); + + mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); + imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur, + (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block + * jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_nb_tile_block_ur); + + mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); + imul(oreg_tile_block_ur, oreg_tile_block_ur, + jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_tile_block_ur); + + if (not_tiled) { + mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]); + imul(oreg_tile_block, oreg_tile_block, + jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block + * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block + * jcp.dimM_simd_block * typesize); + add(oreg_src, oreg_tile_block); + } + + int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block) + * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize; + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + int j_base_offset = j * alpha * last4dim; + int i_base_offset = i * last4dim; + vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]); + vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w) + * typesize], zmm_temp); + } + } + }; + + auto store_dst = [=]() { + vpxord(zmm_zero, zmm_zero, zmm_zero); + mov(oreg_dst, ptr[param1 + GET_OFF(dst)]); + mov(oreg_O, ptr[param1 + GET_OFF(M)]); + mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]); + shl(oreg_ydim, 2); // tj * tile_size (==4) + mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]); + shl(oreg_xdim, 2); // ti * tilesize (==4) + + if (with_bias) + mov(oreg_bias, ptr[param1 + GET_OFF(bias)]); + + auto store_one = [=](int j, int i, bool is_aligned) { + auto zmm_O = Xbyak::Zmm(31); + auto zmm_relu_ns = Xbyak::Zmm(30); + auto xmm_relu_ns = Xbyak::Xmm(30); + int offset = (j * tile_size * simd_w + i * simd_w) * typesize; + + vmovups(zmm_O, ptr[oreg_O + offset]); + if (is_fwd) { + if (with_bias) { + vaddps(zmm_O, zmm_O, ptr[oreg_bias]); + } + if (with_relu) { + if (jcp.eltwise.alpha == 0) { + vmaxps(zmm_O, zmm_O, zmm_zero); + } else { + Opmask kmask = Opmask(7); + mov(imm_addr64, float2int(jcp.eltwise.alpha)); + vmovq(xmm_relu_ns, imm_addr64); + vbroadcastss(zmm_relu_ns, xmm_relu_ns); + vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os); + vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns); + } + } + } + if (with_sum) { + vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]); + if (with_relu_postsum) // orig: with_relu_postsum + vmaxps(zmm_O, zmm_O, zmm_zero); + } + if (is_aligned) + vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O); + else + vmovups(ptr[oreg_out_j + oreg_temp], zmm_O); + }; + + auto i_loop = [=](int j, bool is_aligned) { + for (int i = 0; i < tile_size; i++) { + Label next; + mov(oreg_temp, oreg_xdim); + add(oreg_temp, i); + cmp(oreg_temp, outw); + jge(next, T_NEAR); + shl(oreg_temp, 4 + 2); // * 16 * 4 + + store_one(j, i, is_aligned); + + L(next); + } + }; + + + for (int j = 0; j < tile_size; j++) { + Label next, unaligned; + mov(oreg_temp, oreg_ydim); + add(oreg_temp, j); + cmp(oreg_temp, outh); + jge(next, T_NEAR); + + mov(oreg_out_j, oreg_dst); + imul(oreg_temp, oreg_temp, outw * simd_w * typesize); + add(oreg_out_j, oreg_temp); + + test(oreg_dst, 63); + jnz(unaligned, T_NEAR); + + i_loop(j, true); + jmp(next, T_NEAR); + + L(unaligned); + i_loop(j, false); + + L(next); + } + }; + + auto trans_O_4x4_3x3 = [=]() { + auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){ + vmulps(dst, v1, u1); + vfmadd231ps(dst, v2, u2); + }; + mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]); + mov(oreg_T, ptr[param1 + GET_OFF(T)]); + mov(oreg_O, ptr[param1 + GET_OFF(M)]); + + for (int i = 0; i < alpha; i++) { + for (int j = 0; j < alpha; j++) { + vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w + + i * simd_w) * typesize]); + } + + vaddps(zmm_t(0), zmm_O(1), zmm_O(2)); + vaddps(zmm_t(1), zmm_O(3), zmm_O(4)); + vsubps(zmm_t(2), zmm_O(1), zmm_O(2)); + vsubps(zmm_t(3), zmm_O(3), zmm_O(4)); + + vaddps(zmm_T(0), zmm_t(0), zmm_t(1)); + vaddps(zmm_T(0), zmm_T(0), zmm_O(0)); + fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1)); + fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3)); + fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5)); + vaddps(zmm_T(3), zmm_T(3), zmm_O(5)); + + for (int j = 0; j < tile_size; j++) { + vmovups(ptr[oreg_T + (j * alpha * simd_w + + i * simd_w) * typesize], zmm_T(j)); + } + } + for (int j = 0; j < tile_size; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w + + i * simd_w) * typesize]); + } + vaddps(zmm_t(0), zmm_T(1), zmm_T(2)); + vaddps(zmm_t(1), zmm_T(3), zmm_T(4)); + vsubps(zmm_t(2), zmm_T(1), zmm_T(2)); + vsubps(zmm_t(3), zmm_T(3), zmm_T(4)); + + vaddps(zmm_O(0), zmm_t(0), zmm_t(1)); + vaddps(zmm_O(0), zmm_O(0), zmm_T(0)); + fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1)); + fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3)); + fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5)); + vaddps(zmm_O(3), zmm_O(3), zmm_T(5)); + + for (int i = 0; i < tile_size; i++) { + vmovups(ptr[oreg_O + (j * tile_size * simd_w + + i * simd_w) * typesize], zmm_O(i)); + } + } + }; + + auto inner_loops = [=]() { + init_G(); + load_src(); + trans_O_4x4_3x3(); + store_dst(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + ::input_transform_data_ker_generate() +{ + bool is_fwd = one_of(jcp.prop_kind, + mkldnn_forward_training, mkldnn_forward_inference); + int inpw = is_fwd ? jcp.iw : jcp.ow; + int inph = is_fwd ? jcp.ih : jcp.oh; + int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; + int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; + int wp_max = inpw + l_pad; + int hp_max = inph + t_pad; + bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; + int G_size = 9; + + auto zmm_zero = Xbyak::Zmm(0); + auto zmm_temp = Xbyak::Zmm(31); + auto zmm_G = [=](int i) { + return Xbyak::Zmm(1 + i); + }; + auto zmm_I = [=](int i) { + return Xbyak::Zmm(1 + G_size + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(1 + G_size + alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(1 + G_size + 2 * alpha + i); + }; + + auto init_G = [=]() { + mov(ireg_temp, ptr[param1 + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) { + vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp + mov(ireg_I, ptr[param1 + GET_OFF(M)]); + + xor_(ireg_zero, ireg_zero); + vpxord(zmm_zero, zmm_zero, zmm_zero); + + mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]); + shl(ireg_ydim, 2); // tj * tile_size (==4) + mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]); + shl(ireg_xdim, 2); // ti * tilesize (==4) + + for (int j = 0; j < alpha; j++) { + mov(ireg_temp, ireg_ydim); + add(ireg_temp, j); + + mov(ireg_mask_j, 0xffff); + cmp(ireg_temp, t_pad); + cmovl(ireg_mask_j, ireg_zero); + cmp(ireg_temp, hp_max); + cmovge(ireg_mask_j, ireg_zero); + + sub(ireg_temp, t_pad); + imul(ireg_temp, ireg_temp, inpw * simd_w * typesize); + mov(ireg_inp_j, ireg_src); + add(ireg_inp_j, ireg_temp); + + for (int i = 0; i < alpha; i++) { + + mov(ireg_temp, ireg_xdim); + add(ireg_temp, i); + + mov(ireg_mask, 0xffff); + cmp(ireg_temp, l_pad); + cmovl(ireg_mask, ireg_zero); + cmp(ireg_temp, wp_max); + cmovge(ireg_mask, ireg_zero); + and_(ireg_mask, ireg_mask_j); + + sub(ireg_temp, l_pad); + shl(ireg_temp, 4 + 2); + + vpxord(zmm_temp, zmm_temp, zmm_temp); + Opmask kmask = Opmask(7); + kmovw(kmask, ireg_mask_32); + vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]); + vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w) + * typesize], zmm_temp); + } + } + }; + + auto store_Iw = [=]() { + + mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); + mov(ireg_output, ptr[param1 + GET_OFF(dst)]); + + bool streamout + = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) + > 2 * LLC_data_size + ? true : false; + + if (not_tiled) { + mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]); + imul(ireg_tile_block, ireg_tile_block, + alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize); + } + + mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); + imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur, + jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block + * jcp.dimK_reg_block * typesize); + + mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); + imul(ireg_tile_block_ur, ireg_tile_block_ur, + jcp.dimK_reg_block * typesize); + + add(ireg_output, ireg_nb_tile_block_ur); + add(ireg_output, ireg_tile_block_ur); + if (not_tiled) + add(ireg_output, ireg_tile_block); + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_temp,ptr[ireg_Iw + (j * alpha * simd_w + + i * simd_w) * typesize]); + + int j_base_offset = + j * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize; + int i_base_offset = + i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block + * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; + + if (not_tiled && streamout) + vmovntps(ptr[ireg_output + j_base_offset + i_base_offset], + zmm_temp); + else + vmovups(ptr[ireg_output + j_base_offset + i_base_offset], + zmm_temp); + } + } + }; + + auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { + vmulps(zmm_temp, a, b); + vaddps(dst, zmm_temp, c); + }; + + auto trans_I_4x4_3x3 = [=]() { + mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); + mov(ireg_T, ptr[param1 + GET_OFF(T)]); + mov(ireg_I, ptr[param1 + GET_OFF(M)]); + + mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch + for (int i = 0; i < alpha; i++) { + for (int idx = 0; idx < alpha; idx++) { + vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w + + i * simd_w) * typesize]); + int j_base_offset = + i * alpha * jcp.dimN_block * jcp.dimK_nb_block + * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block + * typesize; + int idx_base_offset = + idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block + * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; + prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]); + } + + fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4)); + fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); + fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); + fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); + fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); + fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); + + fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); + fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); + + for (int idx = 0; idx < alpha; idx++) { + vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w) + * typesize],zmm_T(idx)); + } + } + for (int i = 0; i < alpha; i++) { + for (int idx = 0; idx < alpha; idx++) { + vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx + * simd_w) * typesize]); + } + + fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); + fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); + fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); + fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); + fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); + fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); + + fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); + fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); + + for (int idx = 0; idx < alpha; idx++) { + vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w) + * typesize],zmm_I(idx)); + } + } + }; + + auto inner_loops = [=]() { + init_G(); + load_src(); + trans_I_4x4_3x3(); + store_Iw(); + }; + + preamble(); + inner_loops(); + postamble(); +} + +status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d) +{ + if (!mayiuse(avx512_core)) { + return status::unimplemented; + } + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ver = ver_avx512_core; + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + // Checking conditions not supported by these kernels + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + if (!one_of(weights_d.format_kind(), format_kind::any, format_kind::wino)) { + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + if (jcp.wei_tag != wei_tag) + return status::unimplemented; + } + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && (one_of(weights_d.format_kind(), + format_kind::any, format_kind::wino) + || (jcp.ic <= weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= weights_d.padded_dims()[with_groups + 0])); + if (!layout_consistency) + return status::unimplemented; + + return status::success; +} + +void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) { + + /* ----------- dimM reg block ---------------------*/ + auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimM_reg_block, int current_best) { + int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4; + return (dimM_reg_block >= 1) + && (dimM_reg_block <= max_dimM_reg_block ) + && (dimM_reg_block > current_best); + }; + jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp, + jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block); + + /* ----------- dimN reg block ---------------------*/ + + auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_reg_block, int current_best) { + return jcp.kernel_kind == embd_bcast + ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best + : dimN_reg_block >= 1 + && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block) + < jcp.nb_reg + && dimN_reg_block > current_best; + }; + jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp, + jcp.dimN, 1, test_cond_dimN_reg_block); +} + +status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) { + if (jcp.ver != ver_avx512_core) + return status::unimplemented; + + jcp.kernel_kind = embd_bcast; + + set_kernel_dims_reg_block(jcp); + + /*-------------- L2 blocking for dimN block ---------*/ + + auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp, + int dimN_block, int current_best) { + return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0) + && (dimN_block > current_best) + && ((jcp.dimN / dimN_block / jcp.dimN_reg_block) + >= 1.5 * mkldnn_get_max_threads()); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block); + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; + + if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2) + && (jcp.dimN_nb_block >= 1.5 * mkldnn_get_max_threads())) { + + /* ------------------- L1 blocking for GEMM --------------*/ + /* -------------------- Choose dimK block ----------------*/ + + auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp, + int dimK_block, int current_best) { + return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block); + + if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) { + jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; + + /* -------------- Choose dimM block -------------------*/ + auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp, + int dimM_block, int current_best) { + return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block, + 0.2, 0.5) && (dimM_block > current_best); + }; + + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1, + test_cond_dimM_block); + jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block + / jcp.dimM_simd_block; + + jcp.sched_policy = WSCHED_DATA_W_SGD; + return status::success; + } + + } + return status::unimplemented; +} + +void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) { + + set_kernel_dims_reg_block(jcp); + + //********************* Choosing dimK_block **********************// + auto test_cond1_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, + 1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f) + && (dimK_block > current_best); + }; + + auto test_cond1_bis_dimK_block = []( + jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, dimK_block, + jcp.dimK_reg_block, 1, jcp.dimM_reg_block, + jcp.dimM_simd_block, .9f) + && (dimK_block > current_best); + }; + + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); + // If we are not able to use streams, we fall back to condition [1] + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimK_block = get_divisor_satisfying_cond( + jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); + jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; + + //********************* Choosing dimM_block **********************// + auto test_cond1_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, .5f) + && (dimM_block > current_best); + }; + + auto test_cond1_bis_dimM_block = []( + jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) { + return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, + jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, .3f) + && (dimM_block > current_best); + }; + + if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) + jcp.dimM_block = get_divisor_satisfying_cond( + jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1, + test_cond1_dimM_block); + else + jcp.dimM_block = get_divisor_satisfying_cond(jcp, + jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1, + test_cond1_bis_dimM_block); + jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block + * jcp.dimM_reg_block); + + //******************* Choosing dimN_block *******************// + auto test_cond2_dimN_block = []( + jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) { + return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, + jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, + jcp.dimM_reg_block, jcp.dimM_simd_block, .9f) + && (dimN_block > current_best); + }; + + jcp.dimN_block = get_divisor_satisfying_cond( + jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); + jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); +} + +status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) { + + jcp.kernel_kind = expl_bcast; + set_kernel_blocking_DATA_W_S_G_D(jcp); + if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block, + jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK, + .1f, .35f))) { + jcp.kernel_kind = embd_bcast; + set_kernel_blocking_DATA_W_S_G_D(jcp); + } + jcp.sched_policy = WSCHED_DATA_W_S_G_D; + return status::success; +} + +status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) +{ + jcp.nb_reg = 32; + jcp.dimN = dimN; + jcp.dimK = dimK; + jcp.dimM = dimM; + jcp.sched_policy = WSCHED_INVALID; + + jcp.dimK_reg_block = 16; + jcp.dimM_simd_block = 16; + + if (jcp.kernel_kind == embd_bcast) { + jcp.dimM_reg_block = 1; + } + + if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success)) + set_wsched_DATA_W_S_G_D_avx512_core(jcp); + + assert(jcp.sched_policy != WSCHED_INVALID); + return status::success; +} + +bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_relu(0) || is_sum(0); // relu or sum + case 2: return (is_sum(0) && is_relu(1)) + || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum + case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu + default: return false; + } + + return false; +} + +status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_t &src_md, memory_desc_t &weights_md, + const memory_desc_t &dst_md, const primitive_attr_t &attr) { + + status_t st = init_conf_common(jcp, cd, src_md, weights_md, dst_md); + + if (st != status::success) + return st; + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.with_sum = p.find(primitive_kind::sum, 0) != -1; + jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1; + + status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic); + + jcp.ic_simd_block = jcp.dimK_reg_block; + jcp.ic_block = jcp.dimK_block; + jcp.nb_ic = jcp.dimK_nb_block; + jcp.oc_simd_block = jcp.dimM_simd_block; + jcp.oc_block = jcp.dimM_block; + jcp.oc_reg_block = jcp.dimM_reg_block; + jcp.ic_reg_block = 1; + jcp.nb_oc = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + if (cd.prop_kind == mkldnn_forward_inference) { + memory_desc_t expect_wei_md = weights_md; + + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::f32; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format = mkldnn_wino_wei_OBaaIBOIio; + wd.r = 3; + wd.alpha = 6; + + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.dimK_reg_block; + wd.oc_block = jcp.dimM_simd_block; + wd.ic2_block = jcp.dimK_block; + wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block; + size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc; + wd.size = max_size; + wd.adj_scale = 1.f; + + if (weights_md.format_kind == format_kind::any) + weights_md = expect_wei_md; + if (weights_md != expect_wei_md) + return status::unimplemented; + } + + return res; +} + +status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) +{ + status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); + + if (st != status::success) + return st; + + jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); + + jcp.oc_simd_block = jcp.dimK_reg_block; + jcp.oc_block = jcp.dimK_block; + jcp.nb_oc = jcp.dimK_nb_block; + jcp.ic_simd_block = jcp.dimM_simd_block; + jcp.ic_block = jcp.dimM_block; + jcp.ic_reg_block = jcp.dimM_reg_block; + jcp.oc_reg_block = 1; + jcp.nb_ic = jcp.dimM_nb_block; + jcp.tile_block_ur = jcp.dimN_reg_block; + jcp.nb_tile_block_ur = jcp.dimN_block; + jcp.tile_block = jcp.dimN_nb_block; + + return res; +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +src_transform_generate() { + constexpr int G_size = 9; + const size_t ifwp = jcp.iw + jcp.l_pad; + const size_t ifhp = jcp.ih + jcp.t_pad; + + auto zmm_G = [=](int i) { + return Xbyak::Zmm(i); + }; + auto zmm_I = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + auto zmm_T = [=](int i) { + return Xbyak::Zmm(G_size + alpha + i); + }; + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 2 * alpha + i); + }; + + auto init_G = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) { + vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]); + } + }; + + auto load_src = [=]() { + mov(reg_I, ptr[reg_transp + GET_OFF(M)]); + xor_(reg_zero, reg_zero); + + mov(reg_ydim, reg_tj); + shl(reg_ydim, 2); //tj * tile_size(=4) + + for (int j = 0; j < alpha; j++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maskj, 0xffff); + cmp(reg_ydim, jcp.t_pad); + cmovl(reg_maskj, reg_zero); + cmp(reg_ydim, ifhp); + cmovge(reg_maskj, reg_zero); + + /*address offset for tile in src*/ + mov(reg_src_offset, reg_ydim); + sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad + imul(reg_src_offset, reg_src_offset, jcp.iw); + + mov(reg_xdim, reg_ti); + shl(reg_xdim, 2); // xdim = ti * tile_size + + add(reg_src_offset, reg_xdim); + sub(reg_src_offset, jcp.l_pad); + imul(reg_src_offset, reg_src_offset, simd_w * typesize); + for (int i = 0; i < alpha; i++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maski, 0xffff); + cmp(reg_xdim, jcp.l_pad); + cmovl(reg_maski, reg_zero); + cmp(reg_xdim, ifwp); + cmovge(reg_maski, reg_zero); + and_(reg_maski, reg_maskj); + + Opmask kmask_src = Xbyak::Opmask(7); + auto zmm_src = Xbyak::Zmm(31); + kmovw(kmask_src, reg_maski_32); + vpxord(zmm_src, zmm_src, zmm_src); + vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]); + vmovups(ptr[reg_I], zmm_src); + + add(reg_xdim, 1); //xdim = ti * tile_size + i + add(reg_src_offset, simd_w * typesize); + add(reg_I, simd_w * typesize); + } + add(reg_ydim, 1); + } + }; + + auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) { + vmovups(dst, c); + vfmadd231ps(dst, a, b); + }; + + auto trans_I_3x3_4x4 = [=]() { + //Use 24 registers + mov(reg_I, ptr[reg_transp + GET_OFF(M)]); + mov(reg_T, ptr[reg_transp + GET_OFF(T)]); + for (int i = 0; i < alpha; i++) { + for (int j = 0; j < alpha; j++) { + size_t I_off = (j * alpha + i) * simd_w * typesize; + vmovups(zmm_I(j), ptr[reg_I + I_off]); + } + + fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4)); + fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); + fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); + fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); + fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); + fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); + + fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); + fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); + + for (int j = 0; j < alpha; j++) { + vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize], + zmm_T(j)); + } + + } + + for (int j = 0; j < alpha; j++) { + for (int i = 0; i < alpha; i++) { + vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]); + } + + fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); + fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); + fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); + fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); + fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); + fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); + + fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); + fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); + fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); + fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); + fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); + fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); + + for (int i = 0; i < alpha; i++) { + size_t dst_off = (j * alpha * jcp.ic_block + * jcp.nb_tile_block_ur * jcp.tile_block_ur + + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur) + * simd_w * typesize; + vmovups(ptr[reg_dst + dst_off], zmm_I(i)); + } + } + }; + + auto compute_transform_SDGtWo = [=]() { + mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]); + mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + xor_(reg_tile_count, reg_tile_count); + Label loop_mb, loop_jtiles, loop_itiles, done; + L(loop_mb); + { + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_I_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(done); + + add(reg_dst, simd_w * typesize); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + xor_(reg_tj, reg_tj); + add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize); + jmp(loop_mb); + } + L(done); + }; + + auto compute_transform = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + xor_(reg_ti, reg_ti); + xor_(reg_tj, reg_tj); + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, simd_w * typesize); + add(reg_dst, reg_temp); + + Label loop_jtiles, loop_itiles, next_tile_block, next_tile; + L(loop_jtiles); + + { + L(loop_itiles); + { + load_src(); + + trans_I_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(next_tile_block); + add(reg_dst, simd_w * typesize); + jmp(next_tile); + + L(next_tile_block); + sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) + * simd_w * typesize); + size_t tblk_off = alpha * alpha * jcp.ic_block + * jcp.nb_tile_block_ur * jcp.tile_block_ur + * simd_w * typesize; + add(reg_dst, tblk_off); + xor_(reg_tile_count, reg_tile_count); + + L(next_tile); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + }; + + preamble(); + init_G(); + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) + compute_transform_SDGtWo(); + else + compute_transform(); + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +diff_dst_transform_generate(bool with_bias) { + + constexpr int G_size = 8; + auto zmm_G = [](int i) { + return Xbyak::Zmm(31); + }; + + auto zmm_src = [=](int j, int i) { + return Xbyak::Zmm(G_size + j * 4 + i); + }; + + auto zmm_bias = Xbyak::Zmm(31); + + auto load_src = [=]() { + if (with_bias) vmovups(zmm_bias, ptr[reg_bias]); + mov(reg_ydim, reg_tj); + shl(reg_ydim, 2); //tj * tile_size(=4) + for (int j = 0; j < tile_size; j++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maskj, 0xffff); + cmp(reg_ydim, jcp.oh); + cmovge(reg_maskj, reg_zero); + + /*address offset for tile in src*/ + mov(reg_src_offset, reg_ydim); + imul(reg_src_offset, reg_src_offset, jcp.ow); + + mov(reg_xdim, reg_ti); + shl(reg_xdim, 2); // xdim = ti * tile_size + + add(reg_src_offset, reg_xdim); + imul(reg_src_offset, reg_src_offset, simd_w * typesize); + for (int i = 0; i < tile_size; i++) { + /* check if tile index is within physical spatial boundaries*/ + mov(reg_maski, 0xffff); + cmp(reg_xdim, jcp.ow); + cmovge(reg_maski, reg_zero); + and_(reg_maski, reg_maskj); + + Opmask kmask_src = Xbyak::Opmask(7); + kmovw(kmask_src, reg_maski_32); + vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i)); + vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]); + if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias, + ptr[reg_src + reg_src_offset]); + + add(reg_xdim, 1); //xdim = ti * tile_size + i + add(reg_src_offset, simd_w * typesize); + } + add(reg_ydim, 1); + } + if(with_bias) vmovups(ptr[reg_bias], zmm_bias); + }; + + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 16 + i); + }; + + auto zmm_T = [=](int j, int i) { + return Xbyak::Zmm(j * 4 + i); + }; + + auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) { + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) + vmovups(ptr[reg_dst + dst_off], a); + else + vmovntps(ptr[reg_dst + dst_off], a); + }; + + auto trans_W_3x3_4x4 = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < tile_size; i++) { + vbroadcastss(zmm_G(0), ptr[reg_G]); + vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0)); + + vbroadcastss(zmm_G(1), ptr[reg_G + typesize]); + vmovups(zmm_t(1), zmm_t(0)); + vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1)); + + vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]); + vmovups(zmm_t(2), zmm_t(0)); + vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2)); + + vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]); + vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3)); + + vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]); + vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4)); + + vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]); + vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5)); + + vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]); + vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6)); + + vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]); + vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7)); + vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3)); + vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3)); + vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4)); + vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4)); + vmovups(zmm_T(5, i), zmm_src(3, i)); + } + + for (int j = 0; j < alpha; j++) { + vbroadcastss(zmm_G(0), ptr[reg_G]); + vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0)); + + vbroadcastss(zmm_G(1), ptr[reg_G + typesize]); + vmovups(zmm_t(1), zmm_t(0)); + vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1)); + + vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]); + vmovups(zmm_t(2), zmm_t(0)); + vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2)); + + vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]); + vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3)); + + vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]); + vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4)); + + vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]); + vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5)); + + vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]); + vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6)); + + vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]); + vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7)); + vsubps(zmm_t(5), zmm_t(1), zmm_t(3)); + vaddps(zmm_t(1), zmm_t(1), zmm_t(3)); + vaddps(zmm_t(6), zmm_t(2), zmm_t(4)); + vsubps(zmm_t(2), zmm_t(2), zmm_t(4)); + vmovups(zmm_t(3), zmm_T(j, 3)); + + int alpha_offset = (jcp.oc / jcp.nb_oc) + * (jcp.ntiles / jcp.tile_block) * typesize; + int dst_off = j * alpha * alpha_offset; + movps(reg_dst, dst_off, zmm_t(0)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(5)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(1)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(6)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(2)); + dst_off += alpha_offset; + movps(reg_dst, dst_off, zmm_t(3)); + } + + }; + auto compute_transform_SDGtWo = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]); + + xor_(reg_zero, reg_zero); + xor_(reg_oc_ur, reg_oc_ur); + Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done; + + L(loop_oc_ur); + { + mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]); + mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]); + xor_(reg_tile_count, reg_tile_count); + L(loop_mb); + { + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_W_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(tiles_done); + + add(reg_dst, jcp.oc_reg_block * simd_w * typesize); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + xor_(reg_tj, reg_tj); + add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize); + jmp(loop_mb); + } + + L(tiles_done); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + add(reg_dst, simd_w * typesize); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + add(reg_src, jcp.oh * jcp.ow * simd_w * typesize); + + if (with_bias) add(reg_bias, simd_w * typesize); + add(reg_oc_ur, 1); + cmp(reg_oc_ur, jcp.oc_reg_block); + jl(loop_oc_ur); + } + }; + + auto compute_transform = [=]() { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]); + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, reg_temp); + + xor_(reg_zero, reg_zero); + xor_(reg_oc_ur, reg_oc_ur); + Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile; + + L(loop_oc_ur); + { + xor_(reg_ti, reg_ti); + xor_(reg_tj, reg_tj); + + L(loop_jtiles); + { + L(loop_itiles); + { + load_src(); + + trans_W_3x3_4x4(); + + add(reg_tile_count, 1); + cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur); + jge(next_tile_block); + add(reg_dst, jcp.oc_reg_block * simd_w * typesize); + jmp(next_tile); + + L(next_tile_block); + sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) + * jcp.oc_reg_block * simd_w * typesize); + int tblk_off = alpha * alpha * (jcp.oc/jcp.nb_oc) + * (jcp.ntiles/jcp.tile_block) * typesize; + add(reg_dst, tblk_off); + xor_(reg_tile_count, reg_tile_count); + + L(next_tile); + add(reg_ti, 1); + cmp(reg_ti, jcp.itiles); + jl(loop_itiles); + } + xor_(reg_ti, reg_ti); + add(reg_tj, 1); + cmp(reg_tj, jcp.jtiles); + jl(loop_jtiles); + } + + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]); + imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, reg_temp); + add(reg_dst, simd_w * typesize); + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + add(reg_src, jcp.oh * jcp.ow * simd_w * typesize); + + if (with_bias) add(reg_bias, simd_w * typesize); + add(reg_oc_ur, 1); + cmp(reg_oc_ur, jcp.oc_reg_block); + jl(loop_oc_ur); + } + }; + + preamble(); + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { + compute_transform_SDGtWo(); + } else { + compute_transform(); + } + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel:: +diff_weights_transform_generate(bool first_tile) { + int G_size = 4; + + auto zmm_G = [](int i) { + return Xbyak::Zmm(i); + }; + + auto init_G = [=]() { + mov(reg_G, ptr[reg_transp + GET_OFF(G)]); + for (int i = 0; i < G_size; i++) + vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]); + }; + + auto zmm_src = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + + auto load_src = [=](int i) { + for (int j = 0; j < alpha; j++) { + size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block + * jcp.ic_block * simd_w * simd_w * typesize; + size_t src_off = (j * alpha + i) * alpha_offset; + vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off)); + } + }; + + auto zmm_t = [=](int i) { + return Xbyak::Zmm(G_size + 6 + i); + }; + + auto zmm_T = [=](int j, int i) { + return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i); + }; + + auto zmm_dst = [=](int i) { + return Xbyak::Zmm(G_size + i); + }; + + auto zmm_temp = Xbyak::Zmm(31); + + auto store_dst = [=](int j) { + for (int i = 0; i < jcp.kw; i++) { + size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize; + + if (!first_tile) { + vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off)); + vaddps(zmm_dst(i), zmm_dst(i), zmm_temp); + } + vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i)); + } + }; + + auto compute_transform = [=] () { + mov(reg_src, ptr[reg_transp + GET_OFF(src)]); + mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]); + + xor_(reg_ic_simd, reg_ic_simd); + Label loop_ic_simd; + L(loop_ic_simd); + { + for (int i = 0; i < alpha; i++) { + load_src(i); + + vaddps(zmm_t(0), zmm_src(1), zmm_src(2)); + vaddps(zmm_t(1), zmm_src(3), zmm_src(4)); + vmovups(zmm_t(2), zmm_src(5)); + vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0)); + + vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0)); + vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1)); + vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2)); + vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1)); + vsubps(zmm_temp, zmm_src(3), zmm_src(4)); + vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2)); + vmovups(zmm_T(2, i), zmm_t(2)); + vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3)); + } + + for (int j = 0; j < jcp.kh; j++) { + vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2)); + vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4)); + vmovups(zmm_t(2), zmm_T(j, 5)); + vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0)); + + vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0)); + vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1)); + vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2)); + vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1)); + vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4)); + vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2)); + vmovups(zmm_dst(2), zmm_t(2)); + vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3)); + + store_dst(j); + } + + add(reg_src, jcp.oc_reg_block * simd_w * typesize); + add(reg_dst, simd_w * typesize); + add(reg_ic_simd, 1); + cmp(reg_ic_simd, simd_w); + jl(loop_ic_simd); + } + }; + preamble(); + push(reg_EVEX_max_8b_offt); + mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); + init_G(); + compute_transform(); + pop(reg_EVEX_max_8b_offt); + postamble(); +} + +void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate( + bool is_first_tile) +{ + auto zmm_srcA = [=]() { + return Xbyak::Zmm(0); + }; + + auto zmm_srcB = [=] (size_t N_ur){ + return Xbyak::Zmm(N_ur + 1); + }; + + auto broadcastB = [=](size_t K_ur) { + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) { + size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast) + * sizeof(float); + vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off)); + } + }; + + auto load_srcA = [=] (size_t K_ur, int M_ur) { + size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block + + M_ur * jcp.dimM_simd_block) * sizeof(float); + vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off)); + }; + + auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){ + size_t idx = 1 // zmm_srcA + + jcp.dimN_bcast_ur // zmm_srcB + + M_reg_ur * jcp.dimN_bcast_ur + N_bcast; + assert(idx < 32); + return Xbyak::Zmm(idx); + }; + auto prepare_accumm = [=](){ + for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) { + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) { + Zmm zmm = zmm_dstC(M_reg_ur, N_bcast); + vpxord(zmm, zmm, zmm); + } + } + }; + + auto store_dstC = [=](){ + /******** Write C back to memory *******/ + for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) { + for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) { + Zmm zmm = zmm_dstC(M_reg, N_ur); + size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block + + M_reg * jcp.dimM_simd_block) * sizeof(float); + if (!is_first_tile) { + vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off)); + vaddps(zmm, zmm, Xbyak::Zmm(0)); + } + vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm); + } + } + }; + + auto inner_loops = [=]() { + Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur; + + mov(reg_dimM_block_loop_cnt, jcp.dimM_block); + L(dimM_block_loop); + { /************* OC_block (M) loop ***********/ + mov(reg_dimN_block_loop_cnt, jcp.dimN_block); + L(dimN_block_loop); + { /*************** IC_block (N) loop *********/ + + mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur); + L(dimN_bcast_ur); + { + prepare_accumm(); + + mov(reg_dimK_block_loop_cnt, jcp.dimK_block); + L(dimK_block_loop); + { + /************* nb_tile_ur(K) loop ********/ + for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) { + + broadcastB(K_ur); + + for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) { + load_srcA(K_ur, M_reg_ur); + for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) { + vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(), + zmm_srcB(N_bcast)); + } + } + } + add(reg_srcA, jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + add(reg_srcB, jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + sub(reg_dimK_block_loop_cnt, 1); + jnz(dimK_block_loop); + } + + store_dstC(); + + sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float)); + add(reg_dstC, jcp.dimN_bcast_ur + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_nb_dimN_bcast_ur, 1); + jnz(dimN_bcast_ur); + } + + sub(reg_srcB, jcp.dimN_reg_block * sizeof(float)); + add(reg_srcB, jcp.dimK_block + * jcp.dimK_reg_block + * jcp.dimN_reg_block * sizeof(float)); + sub(reg_dimN_block_loop_cnt, 1); + jnz(dimN_block_loop); + } + + sub(reg_srcB, jcp.dimN_block + * jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimN_reg_block + * sizeof(float)); + add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block + * jcp.dimM_reg_block * jcp.dimM_simd_block + * sizeof(float)); + sub(reg_dimM_block_loop_cnt, 1); + jnz(dimM_block_loop); + } + }; + + /* Preamble */ + preamble(); + + inner_loops(); + + /* Postamble */ + postamble(); + ret(); +} + +namespace { + +void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) { +/*M params*/ + jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block + / jcp.dimM_simd_block; + jcp.oc_reg_block = jcp.dimM_reg_block; + jcp.oc_block = jcp.dimM_block; + jcp.nb_oc = jcp.dimM_nb_block; + /*N params*/ + jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; + jcp.ic_block = jcp.dimN_block; + jcp.nb_ic = jcp.dimN_nb_block; + + /*K params*/ + jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; + jcp.tile_block_ur = jcp.dimK_reg_block; + jcp.nb_tile_block_ur = jcp.dimK_block; + jcp.tile_block = jcp.dimK_nb_block; +} + +status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) { + + size_t K_blk_ur, N_blk, M_blk; + /* IS this strategy feasible? */ + auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) { + size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float); + size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float); + size_t nthreads = mkldnn_get_max_threads(); + return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size) + && (jcp.dimK / nthreads >= 1.0); + }; + + auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur, + int max_block=1) { + size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float); + size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float); + size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float); + size_t nthreads = mkldnn_get_max_threads(); + bool load_balance=true; + if (!(jcp.dimK % nthreads)) { + load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0); + } + return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size) + && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size) + && load_balance + && (M_L2_block < L2_cache_size); + }; + + auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur, + int useless_arg=0) { + return (dimK_ur >= 2) && (dimK_ur <= 8); + }; + + auto blocking_ok = [&](){ + size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block + * K_blk_ur * sizeof(float); + size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block + * K_blk_ur * sizeof(float); + size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block + * N_blk * jcp.dimN_reg_block * sizeof(float); + size_t L2_block = M_L2_block + V_L2_block + U_L2_block; + /*Replace 2.375 with L2+L3 cache size*/ + return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size); + }; + + if (test_MV_large_enough(jcp)) { + if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) { + jcp.dimM_reg_block = 2; + } else { + jcp.dimM_reg_block = 1; + } + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.dimN_bcast_ur = 8; + /*dimK_block and dimK_ur*/ + size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1); + + jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block; + jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block; + for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) { + if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) { + for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) { + if (!(jcp.dimN_block % N_blk)) { + for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) { + if (!(jcp.dimM_block % M_blk) && blocking_ok()) { + jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur); + if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented; + jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block; + jcp.dimN_block = N_blk; + jcp.dimM_block = M_blk; + jcp.sched_policy = WSCHED_WEI_SDGtWo; + set_jcp_WEI_params(jcp); + jcp.nthr = nstl::min(mkldnn_get_max_threads(), + jcp.tile_block); + return status::success; + } + } + } + } + } + } + } + return status::unimplemented; +} + +status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) { + if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) { + jcp.dimM_reg_block = 2; + } else { + jcp.dimM_reg_block = 1; + } + jcp.dimN_bcast_ur = 8; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block; + jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block; + float C1 = 0.0, C2 = 0.0; + float C1_max = 0.5, C2_max = 1.4; + int N_blk, M_blk, K_blk_ur; + + auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur, + int useless_arg=0) { + return (dimK_ur >= 2) && (dimK_ur <= 8); + }; + + auto blocking_ok = [&]() -> bool { + size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float); + size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float); + bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size) + && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size); + + size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block; + size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block; + size_t nb_K_blk = jcp.dimK / K_blk_ur; + size_t nthreads = mkldnn_get_max_threads(); + bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads; + if (!(nb_K_blk % nthreads)) { + load_balance = load_balance && (nb_K_blk % nthreads == 0); + } + + size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float); + + size_t L2_block = V_L2_block; + /*Replace 2.375 with L2+L3 cache size*/ + bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size); + return L1_cond && load_balance && L2_cond; + }; + + for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) { + if (jcp.dimK % K_blk_ur == 0) { + for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) { + if (jcp.dimN_block % N_blk == 0) { + for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) { + if (jcp.dimM_block % M_blk == 0) { + if (blocking_ok()) { + jcp.dimN_block = N_blk; + jcp.dimM_block = M_blk; + jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur); + jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block; + jcp.sched_policy = WSCHED_WEI_S_D_Giot_W; + set_jcp_WEI_params(jcp); + return status::success; + } + } + } + } + } + } + } + jcp.dimK_reg_block = 1; + jcp.dimK_block = 1; + jcp.sched_policy = WSCHED_WEI_S_D_Giot_W; + set_jcp_WEI_params(jcp); + return status::success; +} +} // namespace +status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf( + jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d) { + if (!mayiuse(avx512_core)) + return status::unimplemented; + else + jcp.ver = ver_avx512_core; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.prop_kind = cd.prop_kind; + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + jcp.mb = src_d.dims()[0]; + jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + jcp.kh = diff_weights_d.dims()[with_groups + 2]; + jcp.kw = diff_weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.r_pad = nstl::max( + 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad); + jcp.b_pad = nstl::max( + 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad); + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + jcp.ohp = jcp.oh; + jcp.owp = jcp.ow; + jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef); + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + bool ok_to_pad_channels = jcp.ngroups == 1; + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + } + + // Winograd specific initialization + jcp.itiles = (jcp.ow + tile_size - 1) / tile_size; + jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size; + jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; + + // Winograd kernel works only for 3x3 convolution with stride 1 + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + if (jcp.ngroups != 1) + return status::unimplemented; + if ((jcp.kh != 3) || (jcp.kw != 3)) + return status::unimplemented; + if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0)) + return status::unimplemented; + if ((jcp.stride_h != 1) || (jcp.stride_w != 1)) + return status::unimplemented; + if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0) + return status::unimplemented; + + format_tag_t dat_tag = nChw16c; + format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + if (jcp.src_tag != dat_tag) return status::unimplemented; + if (jcp.wei_tag != wei_tag) return status::unimplemented; + if (jcp.dst_tag != dat_tag) return status::unimplemented; + + bool layout_consistency = true + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] + && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; + if (!layout_consistency) return status::unimplemented; + + /******************Kernel blocking Parameters ***********/ + jcp.ic_simd_block = simd_w; + jcp.oc_simd_block = simd_w; + + jcp.dimK = jcp.ntiles; + jcp.dimN = jcp.ic; + jcp.dimM = jcp.oc; + jcp.dimM_simd_block = jcp.oc_simd_block; + jcp.dimN_reg_block = jcp.ic_simd_block; + jcp.sched_policy = WSCHED_INVALID; + status_t res = set_wsched_WEI_SDGtWo(jcp); + if (res == status::unimplemented) { + res = set_wsched_WEI_S_D_Giot_W(jcp); + assert(res == status::success); + } + return res; +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp new file mode 100644 index 0000000000..025a554d92 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp @@ -0,0 +1,291 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP +#define JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP + +#include "c_types_map.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +#include "jit_avx512_common_conv_winograd_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct _jit_avx512_core_fp32_wino_conv_4x3_data_kernel + : public jit_generator { + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) { + { + this->weights_transform_data_ker_generate(); + weights_transform_data_ker + = (decltype(weights_transform_data_ker)) this->getCode(); + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->input_transform_data_ker_generate(); + input_transform_data_ker = (decltype(input_transform_data_ker))addr; + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->output_transform_data_ker_generate(); + output_transform_data_ker + = (decltype(output_transform_data_ker))addr; + } + { + align(); + const Xbyak::uint8 *addr = getCurr(); + this->gemm_loop_generate(); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + } + + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_fp32_wino_conv_4x3_data_kernel) + + static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d); + + static status_t init_conf_kernel( + jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *, const int); + void (*input_transform_data_ker)(jit_wino_transform_call_s *); + void (*output_transform_data_ker)(jit_wino_transform_call_s *); + void (*weights_transform_data_ker)(jit_wino_transform_call_s *); + +protected: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + enum { typesize = sizeof(float) }; + + void gemm_loop_generate(); + void input_transform_data_ker_generate(); + void output_transform_data_ker_generate(); + void weights_transform_data_ker_generate(); + + /* registers used for GEMM */ + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + reg64_t reg_is_beta_zero = abi_param4; + + reg64_t reg_dimM_block_loop_cnt = r10; + reg64_t reg_dimK_block_loop_cnt = r11; + + /* registers used for transforms*/ + reg64_t param = abi_param1; + + /* registers used for output_transform_data_ker */ + reg64_t oreg_temp = abi_not_param1; + reg64_t oreg_Ow = r9; + reg64_t oreg_src = r11; + reg64_t oreg_tile_block = r12; + reg64_t oreg_tile_block_ur = r13; + reg64_t oreg_nb_tile_block_ur = r14; + reg64_t oreg_O = r8; + reg64_t oreg_T = r10; + reg64_t oreg_dst = r11; + reg64_t oreg_ydim = r14; + reg64_t oreg_xdim = r15; + reg64_t oreg_out_j = r12; + reg64_t oreg_bias = rbx; + reg64_t imm_addr64 = rax; + + /* registers used for input_transform_data_ker */ + reg64_t ireg_temp = abi_not_param1; + reg64_t ireg_jtiles = rax; + reg64_t ireg_itiles = rbx; + reg64_t ireg_I = r8; + reg64_t ireg_src = r13; + reg64_t ireg_ydim = r14; + reg64_t ireg_xdim = r15; + reg64_t ireg_inp_j = r12; + reg64_t ireg_inp_i = rdx; + reg64_t ireg_mask_j = r11; + reg64_t ireg_mask = rsi; + reg32_t ireg_mask_32 = esi; + reg64_t ireg_zero = r9; + reg64_t ireg_Iw = r9; + reg64_t ireg_T = r10; + reg64_t ireg_tile_block = r12; + reg64_t ireg_tile_block_ur = r13; + reg64_t ireg_nb_tile_block_ur = r14; + reg64_t ireg_output = r15; + + /* registers used for wei transform */ + reg64_t wreg_temp = abi_not_param1; + reg64_t wreg_F = r8; + reg64_t wreg_src = r9; + reg64_t wreg_MT = r15; + reg64_t wreg_M = r14; + reg64_t wreg_dst = r10; + reg64_t wreg_dst_aux = r9; + reg64_t wreg_dst_idx = r8; + reg64_t wreg_Fw = r11; + reg64_t wreg_T = r12; + reg64_t wreg_cnt_j = rdx; + reg64_t wreg_F_aux = r14; + reg64_t wreg_Fw_aux = r15; +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel + : _jit_avx512_core_fp32_wino_conv_4x3_data_kernel { + using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel:: + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel; + + static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_t &src_md, + memory_desc_t &weights_md, const memory_desc_t &dst_md, + const primitive_attr_t &attr); +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel + : public _jit_avx512_core_fp32_wino_conv_4x3_data_kernel { + using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel:: + _jit_avx512_core_fp32_wino_conv_4x3_data_kernel; + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); +}; + +struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel + : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32) + + jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel( + jit_conv_winograd_conf_t ajcp) + : jcp(ajcp) + { + //******************* First iter kernel ********************// + this->gemm_loop_generate(true); + gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))this->getCode(); + + align(); + const Xbyak::uint8 *addr = getCurr(); + this->src_transform_generate(); + src_transform = (decltype(src_transform))addr; + + if (jcp.with_bias) { + align(); + addr = getCurr(); + this->diff_dst_transform_generate(true); + diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr; + } + + align(); + addr = getCurr(); + this->diff_dst_transform_generate(false); + diff_dst_transform = (decltype(diff_dst_transform))addr; + + if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) { + align(); + addr = getCurr(); + this->gemm_loop_generate(false); + gemm_loop_ker = (decltype(gemm_loop_ker))addr; + } + + align(); + addr = getCurr(); + this->diff_weights_transform_generate(true); + diff_weights_transform = (decltype(diff_weights_transform))addr; + + if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { + align(); + addr = getCurr(); + this->diff_weights_transform_generate(false); + diff_weights_transform_accum = + (decltype(diff_weights_transform_accum))addr; + }; + } + + static status_t init_conf(jit_conv_winograd_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_dst_d, + const memory_desc_wrapper &diff_weights_d); + + jit_conv_winograd_conf_t jcp; + void (*gemm_loop_ker)(float *, const float *, const float *); + void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); + void (*src_transform)(jit_wino_transform_call_s *); + void (*diff_dst_transform)(jit_wino_transform_call_s *); + void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *); + void (*diff_weights_transform)(jit_wino_transform_call_s *); + void (*diff_weights_transform_accum)(jit_wino_transform_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + enum { typesize = sizeof(float) }; + + void src_transform_generate(); + void diff_dst_transform_generate(bool with_bias); + void diff_weights_transform_generate(bool first_tile); + + /*registers common to transforms*/ + reg64_t reg_transp = abi_param1; + reg64_t reg_ti = rbx; + reg64_t reg_tj = abi_not_param1; + reg64_t reg_src = r8; + reg64_t reg_dst = r9; + reg64_t reg_G = rsi; /*TODO: check if this is ok*/ + reg64_t reg_temp = rsi; + + /*registers common to src/diff_dst transform*/ + reg64_t reg_I = r10; + reg64_t reg_ydim = r11; + reg64_t reg_xdim = r12; + reg64_t reg_src_offset = r13; + reg64_t reg_zero = r14; + reg64_t reg_tile_count = r15; + reg64_t reg_maski = rsi; + reg32_t reg_maski_32 = esi; + reg64_t reg_maskj = rdx; + + reg64_t reg_T = rax; + reg64_t reg_oc_ur = rax; + reg64_t reg_ic_simd = r14; + reg64_t reg_bias = r10; + + void gemm_loop_generate(bool is_first_tile); + + reg64_t reg_dstC = abi_param1; + reg64_t reg_srcA = abi_param2; + reg64_t reg_srcB = abi_param3; + + reg64_t reg_dimM_block_loop_cnt = r9; + reg64_t reg_dimN_block_loop_cnt = r10; + reg64_t reg_nb_dimN_bcast_ur = r11; + reg64_t reg_dimK_block_loop_cnt = r12; +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp new file mode 100644 index 0000000000..002010ffa2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp @@ -0,0 +1,1284 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_u8s8s32x_wino_convolution.hpp" +#include "jit_generator.hpp" + +#include <string.h> + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { + // Below scales are applied to source and weights data accordingly + // because this winograd implementation + // transforms source which may increase values up to 4x + // and transforms weights which may increase values up to 9/4x + const float adj_src_scale = 1.f / 4.f; + const float adj_wei_scale = 4.f / 9.f; + // Winograd transforms need ic and oc to be multiples of 16 + const int load_block = 16; +} + +/// SRC TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *src; + const void *wino_src; + const void *v_y_masks; + const void *v_x_masks; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) { + generate(); + ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode())); + } + void generate(); + + int reg_inp_ind(int i) { + assert(i < jcp.alpha * jcp.alpha); + return (31 - i); + } + + Xmm vreg_inp(int i) { + return Xmm(reg_inp_ind(i)); + } + + Zmm zmm_inp(int i) { + return Zmm(reg_inp_ind(i)); + } + + Xmm vreg_tmp(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Xmm(15 - i); + } + Xmm vreg_out(int i) { + assert(i < jcp.alpha * jcp.alpha); + return Xmm(31 - i); + } + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert(id < 4); + return Opmask(3 + id); + } + + Reg64 reg_ptr_src = r14; + Reg64 reg_ptr_dst = r13; + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_ic_block = r8; + + int unsign_val_in_wino_domain; + + Reg64 reg_scratch_src_alpha = rdx; + Xmm xmm_src_alpha = Xmm(0); + Zmm zmm_src_alpha = Zmm(0); + + Reg64 reg_shift = rax; + Xmm xmm_shift = Xmm(1); + Xmm xmm_zero = Xmm(0); + + Reg64 reg_maskx = rbx; + Reg64 reg_masky = rsi; + Reg64 reg_nomask = reg_maskx; +}; + +void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() { + Label ic_block_label; + Label end_label; + Label mask_label; + Label nomask_label; + + auto load_src = [=](bool mask) { + for (int y = 0; y < jcp.alpha; y++) { + if (mask) + kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]); + for (int x = 0; x < jcp.alpha; x++) { + Zmm zmm_i = zmm_inp(y * jcp.alpha + x); + Xmm vreg_i = vreg_inp(y * jcp.alpha + x); + int inp_offset = sizeof(uint8_t) + * ((-jcp.t_pad + y) * jcp.iw * jcp.ic + + (-jcp.l_pad + x) * jcp.ic); + if (mask) { + kandw(r_mask, y_mask, x_mask(x)); + vmovdqu8(vreg_i | r_mask | T_z, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } else { + vmovdqu8(vreg_i, + EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); + } + vpmovzxbd(zmm_i, vreg_i); // to int32 + vcvtdq2ps(zmm_i, zmm_i); // to fp32 + vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha + vcvtps2dq(zmm_i, zmm_i); // to int32 + vpmovusdb(vreg_i, zmm_i); // to u8 + } + } + }; + + preamble(); + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_ptr_dst, wino_src); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); +# undef READ_PARAM + + mov(reg_maskx, ptr[reg_ptr_v_x_masks]); + mov(reg_masky, ptr[reg_ptr_v_y_masks]); + test(reg_maskx, reg_maskx); + jz(end_label, T_NEAR); // skip kernel if x mask is all 0's + test(reg_masky, reg_masky); + jz(end_label, T_NEAR); // skip kernel if y mask is all 0's + and_(reg_maskx, reg_masky); + mov(reg_nomask, reg_maskx); + not_(reg_nomask); // zero if x and y masks are all 1's + + xor_(reg_shift, reg_shift); + mov(reg_shift.cvt8(), (int8_t)-128); + + mov(reg_aux_ptr_src, reg_ptr_src); + mov(reg_aux_ptr_dst, reg_ptr_dst); + + for (int i = 0; i < jcp.alpha; i++) { + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]); + } + + mov(reg_scratch_src_alpha, float2int(adj_src_scale)); + + mov(reg_ic_block, jcp.ic / load_block); + L(ic_block_label); + { + vmovq(xmm_src_alpha, reg_scratch_src_alpha); + vbroadcastss(zmm_src_alpha, xmm_src_alpha); + + test(reg_nomask, reg_nomask); + jz(nomask_label, T_NEAR); + load_src(true); + jmp(mask_label, T_NEAR); + L(nomask_label); + load_src(false); + L(mask_label); + + for(int y = 0; y < 4; y++) { + vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2)); + vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2)); + vpsubb(vreg_tmp(y*4+2), vreg_inp(y*4+2), vreg_inp(y*4+1)); + vpsubb(vreg_tmp(y*4+3), vreg_inp(y*4+1), vreg_inp(y*4+3)); + } + for(int x = 0;x < 4; x++) { + vpsubb(vreg_out(x+0*4), vreg_tmp(x+4*0), vreg_tmp(x+4*2)); + vpaddb(vreg_out(x+1*4), vreg_tmp(x+4*1), vreg_tmp(x+4*2)); + vpsubb(vreg_out(x+2*4), vreg_tmp(x+4*2), vreg_tmp(x+4*1)); + vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3)); + } + + vmovd(xmm_shift, reg_shift.cvt32()); + vpxor(xmm_zero, xmm_zero, xmm_zero); + vpshufb(xmm_shift, xmm_shift, xmm_zero); + + for (int i = 0; i < 16; i++) { + int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i); + if (i != unsign_val_in_wino_domain) + vpsubb(vreg_out(i), vreg_out(i), Xmm(1)); + vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i)); + } + + add(reg_aux_ptr_src, sizeof(uint8_t) * load_block); + add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block); + } + dec(reg_ic_block); + jnz(ic_block_label, T_NEAR); + + L(end_label); + postamble(); +} + +/// DST TRANSFORMS ///////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS( + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t) + + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *wino_dst; + const void *dst; + const void *v_y_masks; + const void *v_x_masks; + + const void *bias; + const void *scales; + }; + void (*ker_)(const call_params_t *); + + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) { + generate(); + ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode())); + } + + void generate(); + bool maybe_relu(int position); + + Zmm vreg_inp(int i) { // 16 + assert(i < jcp.alpha * jcp.alpha); + return Zmm(31 - i); + } + Zmm vreg_stg(int id) { // 8 + const int id_reg_stg = jcp.alpha * jcp.alpha + id; + assert(id < 8); + return Zmm(31 - id_reg_stg); + } + Zmm vreg_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id < 4); + return Zmm(31 - id_reg_out); + } + Xmm xmm_out(int id) { // 4 + const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; + assert(id < 4); + return Xmm(31 - id_reg_out); + } + Zmm vreg_tmp(int id) { // 2 + const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; + assert(id < 2); + return Zmm(31 - id_reg_tmp); + } + + Zmm vreg_zero = Zmm(0); + Zmm vreg_bias = Zmm(1); + Zmm vreg_prev_dst = Zmm(2); + Zmm zmm_bias_alpha = Zmm(2); + Xmm xmm_bias_alpha = Xmm(2); + + Opmask y_mask = Opmask(1); + Opmask r_mask = Opmask(2); + Opmask x_mask(int id) { + assert(id < 4); + return Opmask(3 + id); + } + + Reg64 reg_scratch_bias_alpha = r15; + + Reg64 reg_ptr_src = r14; + Reg64 reg_ptr_dst = r13; + + Reg64 reg_ptr_v_y_masks = r12; + Reg64 reg_ptr_v_x_masks = r11; + + Reg64 reg_aux_ptr_src = r10; + Reg64 reg_aux_ptr_dst = r9; + + Reg64 reg_oc_block = r8; + + Reg64 reg_ptr_bias = rbx; + Reg64 reg_ptr_scales = abi_not_param1; + Reg64 reg_ptr_sum_scale = rdx; +}; + +bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* relu before sum */ + return false + || p.contain(eltwise, 0) + || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0)); + } else if (position == 1) { + /* relu after sum */ + const int sum_idx = p.contain(sum, 0) + ? 0 : (p.contain(sum, 1) ? 1 : -1); + if (sum_idx == -1) + return false; + + return false + || p.contain(eltwise, sum_idx + 1) + || jcp.dst_dt == data_type::u8; + } + + return false; +} + +void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() { + Label oc_block_label; + + auto loop_body = [=]() { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? &p.entry_[sum_idx].sum.scale + : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + for(int i = 0; i < 16; i++) { + int internal_offset = sizeof(int32_t) * jcp.out_stride * i; + vmovups(vreg_inp(i), + EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); + } + for(int y = 0; y < jcp.alpha; y++) { + vpaddd(vreg_tmp(0), vreg_inp(y*4 + 0), vreg_inp(y*4 + 1)); + vpaddd(vreg_stg(y*2), vreg_tmp(0), vreg_inp(y*4 + 2)); + + vpsubd(vreg_tmp(1), vreg_inp(y*4 + 1), vreg_inp(y*4 + 2)); + vpsubd(vreg_stg(y*2+1), vreg_tmp(1), vreg_inp(y*4 + 3)); + } + for(int x = 0; x < jcp.m; x++) { + vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2*1)); + vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x+2*2)); + + vpsubd(vreg_tmp(1), vreg_stg(x+2*1), vreg_stg(x+2*2)); + vpsubd(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2*3)); + } + + + if (jcp.with_bias) { + vmovq(xmm_bias_alpha, reg_scratch_bias_alpha); + vbroadcastss(zmm_bias_alpha, xmm_bias_alpha); + + auto bias_addr = ptr [ reg_ptr_bias ]; + switch (jcp.bia_dt) { + case data_type::f32: + case data_type::s32: vmovups(vreg_bias, bias_addr); break; + case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break; + case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break; + default: assert(!"unsupported dst data type"); + } + if (jcp.bia_dt != data_type::f32) + vcvtdq2ps(vreg_bias, vreg_bias); + vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha + } + for(int y = 0; y < jcp.m; y++) { + kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(uint16_t) * y ]); + for(int x = 0; x < jcp.m; x++) { + kandw(r_mask, y_mask, x_mask(x)); + + int i = y * jcp.m + x; + int offset = jcp.typesize_out * + (y * jcp.ow * jcp.oc + x * jcp.oc); + Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); + + Zmm zmm = vreg_out(i); + Xmm xmm = xmm_out(i); + vcvtdq2ps(zmm, zmm); + if (jcp.with_bias) + vaddps(zmm, zmm, vreg_bias); + vmulps(zmm, zmm, ptr [reg_ptr_scales]); + if (maybe_relu(0)) + vmaxps(zmm, vreg_zero, zmm); + if (p_sum_scale) { // post_op: sum + vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(vreg_prev_dst | r_mask, addr); break; + case data_type::s8: + vpmovsxbd(vreg_prev_dst | r_mask, addr); break; + case data_type::u8: + vpmovzxbd(vreg_prev_dst | r_mask, addr); break; + default: assert(!"unknown dst_dt"); + } + if (jcp.dst_dt != data_type::f32) + vcvtdq2ps(vreg_prev_dst, vreg_prev_dst); + if (*p_sum_scale == 1.f) + vaddps(zmm, vreg_prev_dst); + else + vfmadd231ps(zmm, vreg_prev_dst, + zword_b[reg_ptr_sum_scale]); + } + if (maybe_relu(1)) + vmaxps(zmm, vreg_zero, zmm); + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(zmm, zmm); + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(addr, zmm | r_mask); break; + case data_type::s8: + vpmovsdb(xmm, zmm); vmovups(addr, xmm | r_mask); break; + case data_type::u8: + vpmovusdb(xmm, zmm); vmovups(addr, xmm | r_mask); break; + default: assert(!"unknown dst_dt"); + } + } + } + }; + + preamble(); + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, wino_dst); + READ_PARAM(reg_ptr_dst, dst); + READ_PARAM(reg_ptr_v_y_masks, v_y_masks); + READ_PARAM(reg_ptr_v_x_masks, v_x_masks); + READ_PARAM(reg_ptr_bias, bias); + READ_PARAM(reg_ptr_scales, scales); +# undef READ_PARAM + + if (jcp.with_bias) + mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale)); + + mov(reg_aux_ptr_src, reg_ptr_src); + mov(reg_aux_ptr_dst, reg_ptr_dst); + + vpxord(vreg_zero, vreg_zero, vreg_zero); + + for (int i = 0; i < jcp.m; i++) + kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]); + + int oc_blocks = jcp.oc / load_block; + mov(reg_oc_block, oc_blocks); + L(oc_block_label); { + loop_body(); + add(reg_aux_ptr_src, sizeof(int32_t) * load_block); + add(reg_aux_ptr_dst, jcp.typesize_out * load_block); + + add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block); + add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block); + } + dec(reg_oc_block); + jnz(oc_block_label, T_NEAR); + + postamble(); + +} + +/// GEMM kernel //////////////////////////////////////////////////////////////// +struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t) + jit_conv_conf_2x3_wino_t jcp; + const primitive_attr_t &attr_; + + struct call_params_t { + const void *src; + const void *dst; + const void *wei; + const void *dst_b; + }; + void (*ker_)(const call_params_t *); + + void generate(); + static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp, + const primitive_attr_t &attr); + + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( + jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr) + { + generate(); + ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode())); + } + + static status_t init_conf( + jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, memory_desc_t &bias_md, + const primitive_attr_t &attr); + + Zmm vreg_out(int n, int m) { + const int id_reg_out = n * jcp.m_block + m; + assert(id_reg_out < jcp.n2_block * jcp.m_block); + return Zmm(31 - id_reg_out); + } + Zmm vreg_wei(int i) { + assert(31 - jcp.n2_block * jcp.m_block - i + > (jcp.ver == ver_vnni ? 0 : 2)); + return Zmm(31 - jcp.n2_block * jcp.m_block - i); + } + + Zmm vreg_src = Zmm(0); + Zmm vreg_one = Zmm(1); + Zmm vreg_tmp = Zmm(2); + + Reg64 reg_ptr_src = r15; + + Reg64 reg_aux_dst_b = r13; + Reg64 reg_aux_dst = r12; + Reg64 reg_aux_dst2 = r11; + Reg64 reg_aux_wei = r10; + Reg64 reg_aux_wei2 = r9; + Reg64 reg_aux_src = r8; + Reg64 reg_aux_src2 = rax; + Reg64 reg_mb = rbx; + Reg64 reg_nnb = abi_not_param1; + Reg64 reg_scratch = rdx; + Reg64 reg_K = rsi; +}; + +bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok( + jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_relu(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_relu(1)) || + (p.contain(sum, 1) && is_relu(0)); + case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); + default: return false; + } + + return false; +} + +void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() { + Label nnb_loop_label, K_loop_label, mb_loop_label; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(vreg_tmp, vreg_src, vreg_wei); + vpmaddwd(vreg_tmp, vreg_tmp, vreg_one); + vpaddd(vreg_acc, vreg_acc, vreg_tmp); + } + }; + + preamble(); +# define READ_PARAM(reg, field) \ + mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src, src); + READ_PARAM(reg_aux_dst, dst); + READ_PARAM(reg_aux_wei, wei); + READ_PARAM(reg_aux_dst_b, dst_b); +# undef READ_PARAM + + if (jcp.ver != ver_vnni) { + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(vreg_one, _t); + } + + if (!jcp.small_mb) { + mov(reg_nnb, jcp.n_chunks); + L(nnb_loop_label); + } + mov(reg_aux_dst2, reg_aux_dst); + mov(reg_aux_src, reg_ptr_src); + mov(reg_mb, jcp.M / jcp.m_block); + L(mb_loop_label); + { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + for (int m = 0; m < jcp.m_block; m++) { + int offset = jcp.typesize_acc * nb2 * jcp.n_block; + vmovups(vreg_out(nb2, m), + EVEX_compress_addr(reg_aux_dst_b, offset)); + } + } + mov(reg_aux_src2, reg_aux_src); + mov(reg_aux_wei2, reg_aux_wei); + mov(reg_K, jcp.k_chunks); + L(K_loop_label); + { + for (int k = 0; k < jcp.k2_block; k += 4) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int wei_offset + = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K); + vmovups(vreg_wei(nb2), + EVEX_compress_addr(reg_aux_wei2, wei_offset)); + } + for (int m = 0; m < jcp.m_block; m++) { + int inp_offset = jcp.typesize_in * m * jcp.K; + vpbroadcastd(vreg_src, + EVEX_compress_addr(reg_aux_src2, inp_offset)); + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) + compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src); + } + add(reg_aux_src2, jcp.typesize_in * 4); + add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block); + } + } + dec(reg_K); + jnz(K_loop_label, T_NEAR); + + for (int m = 0; m < jcp.m_block; m++) { + for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { + int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block); + vmovups(EVEX_compress_addr(reg_aux_dst2, offset), + vreg_out(nb2, m)); + } + } + add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K); + add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N); + } + dec(reg_mb); + jnz(mb_loop_label, T_NEAR); + + if (!jcp.small_mb) { + add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block); + add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block); + add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K); + + dec(reg_nnb); + jnz(nnb_loop_label, T_NEAR); + } + + postamble(); +} +namespace { +bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { + if (jcp.ver == ver_vnni) { + return (jcp.mb <= mkldnn_get_max_threads() + && (jcp.mb > 4 + && jcp.ic > 64 + && !(jcp.oc > 128 && jcp.ih < 14))) + || jcp.mb > mkldnn_get_max_threads(); + } + return true; +} +} + +status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t +::init_conf(jit_conv_conf_2x3_wino_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &wei_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const primitive_attr_t &attr) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper wei_d(&wei_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; + + jcp.nthr = mkldnn_get_max_threads(); + + jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = wei_d.dims()[with_groups + 2]; + jcp.kw = wei_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ver = ver_avx512_core; + if (!(mayiuse(avx512_core) && + src_d.data_type() == data_type::u8 + && wei_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + + if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, + is_winograd_faster_than_direct(jcp))) + return status::unimplemented; + + // block sizes needed for GEMM kernel + jcp.ic_block = 4; + jcp.oc_block = 16; + + bool ok = true + && jcp.ngroups == 1 + && jcp.oc % load_block == 0 && jcp.ic % load_block == 0 + && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 + && everyone_is(3, jcp.kh, jcp.kw) + && everyone_is(1, jcp.stride_h, jcp.stride_w) + && everyone_is(0, jcp.dilate_h, jcp.dilate_w) + && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad + && one_of(jcp.t_pad, 0, 1) + && one_of(jcp.l_pad, 0, 1); + if (!ok) return status::unimplemented; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_acc = sizeof(int32_t); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.m = 2; + jcp.r = 3; + jcp.alpha = jcp.m + jcp.r - 1; + + int aa = jcp.alpha * jcp.alpha; + int L1_cap = get_cache_size(1, true); + int L2_cap = get_cache_size(2, true); + // need 1 extra reg for bcast, and 2 tmp regs for non-vnni + int free_regs = jcp.ver == ver_vnni ? 31 : 29; + + auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) { + float thr_eff; + float Z = (float)jcp.ic + jcp.oc; + float Y = (float)jcp.ic * jcp.oc; + if (small_mb == 0) { // outer par + int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix); + thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); + } else { // inner par + int tranw = iy * ix / jcp.alpha; + int gemmw = aa * (jcp.nb_oc / n2_b); + int tranw_r = rnd_up(tranw, jcp.nthr); + int gemmw_r = rnd_up(gemmw, jcp.nthr); + thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y); + } + return thr_eff; + }; + + auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) { + float mem_eff, req_mem; + int M = ix * iy / jcp.alpha; + if (small_mb == 0) { // outer parallelization strategy + // memory for wino transforms (other memory has poor reuse) + req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc); + mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f; + } else { // inner parallelization strategy + // memory used during gemm + int N = jcp.oc_block * n2_b; + req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N; + mem_eff = nstl::min(1.f, L2_cap / req_mem); + // memory used during wino transforms + int M_per_thr = div_up(M, jcp.nthr); + req_mem = (float)aa * M_per_thr + * (jcp.ic + jcp.typesize_acc * jcp.oc); + if (req_mem > L2_cap) + mem_eff = 0.1f; + } + return mem_eff; + }; + + auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff, + float mem_eff, float reg_eff) { + // these coefficients are chosen empirically + float mem_fac = 0.1f, reg_fac = 0.2f; + // normalized overhead relative to memory and register components + float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff; + // thread and work components affect all others + tot_eff *= thr_eff * work_eff; + return tot_eff; + }; + + auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff, + int &m_block, int &n2_block, float &tot_eff) { + int M = (ix * iy) / jcp.alpha; + int max_m_block = nstl::min(M, free_regs); + int max_n2_block = nstl::min(jcp.nb_oc, free_regs); + tot_eff = 0.f; + for (int im = max_m_block; im > 0; im--) { + if (M % im) + continue; + for (int in2 = max_n2_block; in2 > 0; in2--) { + int used_regs = (im + 1) * in2; + float mem_eff = get_mem_eff(small_mb, ix, iy, in2); + float reg_eff = (float)(im * in2) / (im + in2); + float thr_eff = get_thr_eff(small_mb, ix, iy, in2); + float cur_tot_eff = get_tot_eff( + small_mb, thr_eff, work_eff, mem_eff, reg_eff); + if (jcp.nb_oc % in2 || used_regs > free_regs + || cur_tot_eff <= tot_eff) + continue; + tot_eff = cur_tot_eff; + m_block = im; + n2_block = in2; + } + } + }; + + /* Selecting xb and yb blocking */ + int min_yb = jcp.m; + int min_xb = jcp.m; + int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2)); + int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2)); + float best_eff = 0.f; + for (int ix = min_xb; ix <= max_xb; ix += 2) { + assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2); + for (int iy = max_yb; iy >= min_yb; iy -= 2) { + assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2); + + int m_b[2]; + int n2_b[2]; + bool small_mb; + float inner_eff, outer_eff, work_eff; + + int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix); + work_eff = (float)jcp.oh * jcp.ow / tiled_area; + if (best_eff > 0.f && work_eff < 4.f / 9.f) + continue; // no gain from Winograd transformation + + /* outer parallelization */ + find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff); + + /* inner parallelization */ + find_m_n2_blocks(1, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff); + + small_mb = inner_eff > outer_eff; + float eff = small_mb ? inner_eff : outer_eff; + if (eff > best_eff) { + best_eff = eff; + jcp.yb = iy; + jcp.xb = ix; + jcp.m_block = m_b[small_mb]; + jcp.n2_block = n2_b[small_mb]; + jcp.small_mb = small_mb; + } + } + } + + assert((jcp.m_block + 1) * jcp.n2_block <= free_regs); + assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); + + jcp.mb_block = 1; + if (jcp.small_mb) { + // For small mb harness, set mb_block as large as possible subject to + // the constraint that winograd activations fit into available L3 cache + int L3_cap = get_cache_size(3, true); + int M = jcp.xb * jcp.yb / 4; + int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in; + int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc; + int max_mb_block = nstl::min( + jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size)); + for (int i = max_mb_block; i > 1; i--) { + if (jcp.mb % i == 0) { + jcp.mb_block = i; + break; + } + } + } + jcp.nb_mb = jcp.mb / jcp.mb_block; + + jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4; + jcp.N = jcp.oc; + jcp.K = jcp.ic; + + jcp.inp_stride = jcp.M * jcp.ic; + jcp.out_stride = jcp.M * jcp.oc; + jcp.wei_stride = jcp.ic * jcp.oc; + jcp.bia_stride = jcp.oc; + + jcp.n_block = jcp.oc_block; + jcp.k_block = jcp.ic_block; + + jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block; + + // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4 + // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is + // a multiple of load_block = 16, we just use that for now. + jcp.k2_block = load_block; + jcp.k_chunks = jcp.K / jcp.k2_block; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + /* re-create weights primitive descriptor + and set weights wino_blocking */ + memory_desc_t expect_wei_md = wei_md; + + expect_wei_md.format_kind = format_kind::wino; + expect_wei_md.data_type = data_type::s8; + mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; + wd.wino_format = mkldnn_wino_wei_aaOIoi; + wd.r = jcp.r; + wd.alpha = jcp.alpha; + wd.ic = jcp.ic; + wd.oc = jcp.oc; + wd.ic_block = jcp.ic_block; + wd.oc_block = jcp.oc_block; + wd.oc2_block = jcp.n2_block; + wd.ic2_block = 1; + wd.adj_scale = adj_wei_scale; + + size_t max_size = types::data_type_size(data_type::s8) * + jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; + max_size += types::data_type_size(data_type::s32) * + jcp.alpha * jcp.alpha * jcp.oc; + wd.size = max_size; + + if (wei_md.format_kind == format_kind::any) + wei_md = expect_wei_md; + if (wei_md != expect_wei_md) + return status::unimplemented; + + const int tilesize = jcp.alpha * jcp.alpha; + const int numtiles = jcp.M; + const int alltiles = numtiles * tilesize; + + jcp.size_wino_src + = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K) + / jcp.typesize_in; + jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic; + jcp.size_wino_dst = alltiles * jcp.oc; + + return status::success; +} +//////////////////////////////////////////////////////////////////////////////// + +template <data_type_t dst_data_type> +status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>:: + pd_t::jit_conf() { + return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf( + jcp_, *this->desc(), this->src_md_, this->weights_md_, + this->dst_md_,this->bias_md_, *this->attr()); +} + +template <data_type_t dst_data_type> +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::pd_t:: +init_scratchpad() { + auto scratchpad = this->scratchpad_registry().registrar(); + + int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr; + scratchpad.book(key_wino_V, + sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K); + scratchpad.book(key_wino_M, + sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K); + + dim_t scale_count = attr()->output_scales_.count_; + scratchpad.book(key_conv_adjusted_scales, + sizeof(float) * nstl::max<dim_t>(scale_count, 16)); +} + +template <data_type_t dst_data_type> +jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>:: + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) +{ + kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t( + pd()->jcp_, *pd()->attr()); + src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t( + pd()->jcp_, *pd()->attr()); + dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t( + pd()->jcp_, *pd()->attr()); +} + +template <data_type_t dst_data_type> +jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>:: + ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() { + delete kernel_; + delete src_trans_; + delete dst_trans_; +} + +template <data_type_t dst_data_type> +const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>:: +adjust_oscales(const memory_tracking::grantor_t &scratchpad) const { + const float *oscales = pd()->attr()->output_scales_.scales_; + auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / (adj_src_scale * adj_wei_scale); + if (count == 1) + utils::array_set(loc_scales, oscales[0] * factor, 16); + else + for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor; + return loc_scales; +} + +template <data_type_t dst_data_type> +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const auto &jcp = kernel_->jcp; + if (jcp.small_mb) + execute_forward_small_mb(src, weights, bias, dst, this->scratchpad(ctx)); + else + execute_forward_mbN(src, weights, bias, dst, this->scratchpad(ctx)); +} + +template <data_type_t dst_data_type> +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>:: +execute_forward_mbN(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const float *oscales = adjust_oscales(scratchpad); + + auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei); + auto wino_src_base = scratchpad.template get<src_data_t>(key_wino_V); + auto wino_dst_base = scratchpad.template get<acc_data_t>(key_wino_M); + + parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb), + [&](int mb, int tile_y_b, int tile_x_b) { + + int tile_y = tile_y_b * jcp.yb; + int tile_x = tile_x_b * jcp.xb; + + int ithr = mkldnn_get_thread_num(); + auto wino_src = wino_src_base + jcp.size_wino_src * ithr; + auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr; + + auto src_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); + auto dst_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); + auto gemm_p = + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t(); + + /* transformation of input tensor to winograd domain */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { + uint16_t v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min(jcp.alpha, + nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min(jcp.alpha, + nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff); + v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff); + } + auto local_s = src + + mb * jcp.ih * jcp.iw * jcp.ic + + y * jcp.iw * jcp.ic + x * jcp.ic; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + } + } + /* gemms */ + for (int tile_ij = 0; tile_ij < 16; tile_ij++) { + // start threads at different GEMMs to help bring weights into LLC + int offset = (tile_ij + ithr) % 16; + gemm_p.src = wino_src + jcp.inp_stride * offset; + gemm_p.dst = wino_dst + jcp.out_stride * offset; + gemm_p.wei = wei + jcp.wei_stride * offset; + gemm_p.dst_b = dst_bias + jcp.bia_stride * offset; + + kernel_->ker_(&gemm_p); + } + + /* transformation from winograd domain to output tensor */ + for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { + for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) { + uint16_t v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0); + v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0); + } + auto local_d = dst + + mb * jcp.oh * jcp.ow * jcp.oc + + y * jcp.ow * jcp.oc + x * jcp.oc; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + } + } + }); +} + +template <data_type_t dst_data_type> +void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>:: +execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const auto &jcp = kernel_->jcp; + const float *oscales = adjust_oscales(scratchpad); + + auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei); + auto wino_src = scratchpad.template get<src_data_t>(key_wino_V); + auto wino_dst = scratchpad.template get<acc_data_t>(key_wino_M); + + for (int mbb = 0; mbb < jcp.nb_mb; mbb++) { + for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) { + for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { + /* transformation of input tensor to winograd domain */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block, + [&](int y_in_block_b, int x_in_block_b, int mb) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto src_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t(); + + uint16_t v_y_masks[4], v_x_masks[4]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2) + + (x_in_block / 2); + + int v_ys = nstl::max(0, jcp.t_pad - y); + int v_ye = nstl::min( + jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); + + int v_xs = nstl::max(0, jcp.l_pad - x); + int v_xe = nstl::min( + jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); + +#pragma unroll(4) + for (int i = 0; i < jcp.alpha; i++) { + v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff); + v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff); + } + auto local_s = src + + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic + + y * jcp.iw * jcp.ic + x * jcp.ic; + auto local_w = wino_src + m * jcp.ic; + + src_trans_p.src = local_s; + src_trans_p.wino_src = local_w; + src_trans_p.v_y_masks = v_y_masks; + src_trans_p.v_x_masks = v_x_masks; + + src_trans_->ker_(&src_trans_p); + }); + + /* gemms */ + parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) { + auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t:: + call_params_t(); + + gemm_p.src = wino_src + jcp.inp_stride * tile_ij; + gemm_p.dst = wino_dst + jcp.out_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + gemm_p.wei = wei + jcp.wei_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block * jcp.K; + gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij + + nnb * jcp.n2_block * jcp.n_block; + + kernel_->ker_(&gemm_p); + }); + + /* transformation from winograd domain to output tensor */ + parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block, + [&](int y_in_block_b, int x_in_block_b, int mb) { + int y_in_block = y_in_block_b * 2; + int x_in_block = x_in_block_b * 2; + + auto dst_trans_p = + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t(); + + uint16_t v_y_masks[2], v_x_masks[2]; + + int y = y_in_block + tile_y; + int x = x_in_block + tile_x; + int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2) + + (x_in_block / 2); + +#pragma unroll(2) + for (int i = 0; i < jcp.m; i++) { + v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0); + v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0); + } + auto local_d = dst + + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc + + y * jcp.ow * jcp.oc + x * jcp.oc; + auto local_w = wino_dst + m * jcp.oc; + + auto scales = oscales; + dst_trans_p.dst = local_d; + dst_trans_p.wino_dst = local_w; + dst_trans_p.v_y_masks = v_y_masks; + dst_trans_p.v_x_masks = v_x_masks; + + dst_trans_p.scales = scales; + dst_trans_p.bias = bia; + + dst_trans_->ker_(&dst_trans_p); + }); + }}} +} + +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s8>; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::u8>; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s32>; +template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::f32>; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp new file mode 100644 index 0000000000..9e6e57b051 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp @@ -0,0 +1,128 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t; +struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t; +struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t; + +template <data_type_t dst_data_type> +struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8_wino:", avx512_core, ""), + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>); + + status_t init() { + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && expect_data_types(data_type::u8, data_type::s8, + data_type::undef, dst_data_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && !has_zero_dim_memory() + && set_default_formats(); + + if (!ok) return status::unimplemented; + + status_t status = jit_conf(); + if (status != status::success) return status; + set_default_alg_kind(alg_kind::convolution_winograd); + + init_scratchpad(); + + return status; + } + + jit_conv_conf_2x3_wino_t jcp_; + + protected: + status_t jit_conf(); + void init_scratchpad(); + + bool set_default_formats() { + using namespace format_tag; + return set_default_formats_common(nhwc, any, nhwc); + } + }; + + typedef typename prec_traits<data_type::u8>::type src_data_t; + typedef typename prec_traits<data_type::s8>::type wei_data_t; + typedef typename prec_traits<data_type::s32>::type acc_data_t; + typedef typename prec_traits<dst_data_type>::type dst_data_t; + + jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd); + ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad) + const; + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + void execute_forward_mbN(const src_data_t *src, const wei_data_t *wei, + const char *bia, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t *kernel_; + jit_avx512_core_u8s8s32x_wino_conv_src_trans_t *src_trans_; + jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t *dst_trans_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp new file mode 100644 index 0000000000..f4ec29ab00 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp @@ -0,0 +1,820 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_eltwise(int position) +{ + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + + return false; +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_bcast_data, reg_bcast_data); + + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off)); + + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + reduce_loop(load_loop_blk, jcp.ur, i, false); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } + else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + int output_offset = jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep; + + add(aux_reg_output_data, output_offset); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + reduce_loop(load_loop_blk, jcp.ur_tail, 0, true); + L(bcast_loop_tail_out); + } +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::cvt2ps(data_type_t type_in, + zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) { + zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in; + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(zmm, op); break; + case data_type::s8: vpmovsxbd(zmm, op); break; + case data_type::u8: vpmovzxbd(zmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(zmm_in, zmm_in); +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk, + int ur, int substep, bool wraparound) +{ + auto vreg_load = [=](int i_load) { + return Zmm(ur * load_loop_blk + i_load); + }; + + auto vreg_accum = [=](int i_load, int i_ur) { + return Zmm(i_ur * load_loop_blk + i_load); + }; + + auto zmm_bias_alpha = [=]() { + return Zmm(ur * load_loop_blk); + }; + + auto xmm_bias_alpha = [=]() { + return Xmm(ur * load_loop_blk); + }; + auto bias_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_bias_data, + jcp.typesize_bia * jcp.oc_block * i_load); + }; + + auto comp_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_comp_data, + sizeof(int32_t) * jcp.oc_block * i_load); + }; + + auto scale_ptr = [=](int i_load) { + return EVEX_compress_addr(reg_ptr_scales, + jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load)); + }; + + auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { + assert(i_ur < jcp.ur); + assert(i_reduce <= jcp.reduce_loop_unroll); + assert(jcp.reduce_loop_unroll == jcp.reduce_block); + + int offt = (jcp.ic_without_padding * i_ur + i_reduce); + + return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt, + bcast); + }; + + auto load_ptr = [=](int i_reduce, int i_load) { + int u0 = i_reduce % jcp.reduce_loop_unroll; + int u1 = i_reduce / jcp.reduce_loop_unroll; + + int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block; + + return EVEX_compress_addr(aux_reg_load_data, + u1 * jcp.reduce_loop_load_step + + jcp.typesize_in * offt); + }; + + auto output_ptr = [=](int i_load, int i_ur) { + return EVEX_compress_addr(aux_reg_output_data, + jcp.typesize_out * (jcp.oc_without_padding * i_ur + + i_load * jcp.load_block)); + }; + + auto init = [=]() { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vpxord(r, r, r); + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)-128); + vpbroadcastb(zmm_shift, _t8); + } + }; + + auto store = [=](const bool mask_flag_in) { + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = (sum_idx != -1) + ? &p.entry_[sum_idx].sum.scale + : nullptr; + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off)); + if (p_sum_scale && *p_sum_scale != 1.f) { + mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data); + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + } + if (jcp.signed_input && jcp.ver != ver_vnni) { + mov(reg_scratch, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_scratch); + vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha()); + } + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1; + auto zmm_bias = zmm_tmp; + auto zmm_comp = zmm_bcast; + if (jcp.with_bias) { + if (jcp.signed_input) + mov(reg_bias_data, + EVEX_compress_addr(rsp,reg_bias_data_off)); + cvt2ps(jcp.bia_dt, zmm_bias, bias_ptr(i_load), mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + vmulps(zmm_bias, zmm_bias, zmm_bias_alpha()); + } + if (jcp.signed_input) { + mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off)); + cvt2ps(data_type::s32, zmm_comp, comp_ptr(i_load), mask_flag); + } + + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + vcvtdq2ps(r, r); + if (jcp.signed_input) + vaddps(r, r, zmm_comp); + if (jcp.with_bias) + vaddps(r, r, zmm_bias); + + zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r; + vmulps(mask_zmm, r, scale_ptr(i_load)); + } + } + + if (maybe_eltwise(0)) + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + if (p_sum_scale) { // post_op: sum + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && + i_load == load_loop_blk - 1; + for (int i_ur = 0; i_ur < ur; ++i_ur) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto zmm_prev_dst = zmm_zero; + + auto r = vreg_accum(i_load, i_ur); + cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur), + mask_flag); + + if (*p_sum_scale == 1.f) + vaddps(r, zmm_prev_dst); + else + vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + + if (maybe_eltwise(1)) + eltwise_injector_->compute_vector_range(0, ur * load_loop_blk); + + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + const bool mask_flag = mask_flag_in && + i_load == load_loop_blk - 1; + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + if (jcp.dst_dt == data_type::u8) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + vmaxps(r, zmm_zero, r); + } + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(r, r); + } + for (int i_ur = 0; i_ur < ur; ++i_ur) { + auto r = vreg_accum(i_load, i_ur); + zmm_t r_zmm = mask_flag ? r | ktail_mask : r; + + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: + vmovups(output_ptr(i_load, i_ur), r_zmm); break; + case data_type::s8: + vpmovsdb(output_ptr(i_load, i_ur), r_zmm); break; + case data_type::u8: + vpmovusdb(output_ptr(i_load, i_ur), r_zmm); break; + default: assert(!"unknown dst_dt"); + } + } + } + mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off)); + }; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); + vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + auto fma_block = [=](bool last_block) { + int reduce_step = 4; + int tail_size = jcp.ic_without_padding % reduce_step; + int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding + ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step) + : jcp.reduce_loop_unroll; + for (int i_reduce = 0; i_reduce < loop_unroll; + i_reduce += reduce_step) { + for (int i_load = 0; i_load < load_loop_blk; ++i_load) + vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load)); + for (int i_ur = 0; i_ur < ur; ++i_ur) { + if (last_block && tail_size != 0 + && i_reduce == loop_unroll - reduce_step) { + Xmm xmm_bcast = Xmm(zmm_bcast.getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data + + jcp.ic_without_padding * i_ur + i_reduce + r], r); + vpbroadcastd(zmm_bcast, xmm_bcast); + } else { + vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false)); + } + if (jcp.signed_input) + vpsubb(zmm_bcast, zmm_bcast, zmm_shift); + for (int i_load = 0; i_load < load_loop_blk; ++i_load) { + compute(vreg_accum(i_load, i_ur), + vreg_load(i_load), zmm_bcast); + } + } + } + }; + + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + if (jcp.ic != jcp.ic_without_padding) { + fma_block(true); + } else { + fma_block(false); + } + + if (jcp.oc_without_padding != jcp.oc) { + Label end_store, common_store; + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + + /*Check if it is the last load_loop_blk*/ + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + cmp(reg_load_loop_work, 0); + jg(common_store, T_NEAR); + + /*Check if it is the last ocb*/ + test(reg_reduce_pos_flag, FLAG_OC_LAST); + jz(common_store, T_NEAR); + + store(true); + jmp(end_store, T_NEAR); + + L(common_store); + store(false); + + L(end_store); + + add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + } else { + store(false); + } +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate() +{ + preamble(); + + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(zmm_one, _t); + + sub(rsp, stack_space_needed); + + if (jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + Reg32 regw_tmp = reg_last_load.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + if (jcp.with_bias) + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + if (jcp.signed_input) { + mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data); + mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]); + mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data); + } + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales); + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + + + auto load_loop_body = [=](int load_loop_blk) { + bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + if (jcp.with_bias) { + if (jcp.signed_input) + mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off)); + add(reg_bias_data, + load_loop_blk * jcp.load_block * jcp.typesize_bia); + if (jcp.signed_input) + mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data); + } + if (jcp.signed_input) { + mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off)); + add(reg_comp_data, + load_loop_blk * jcp.load_block * sizeof(int32_t)); + mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data); + } + mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); + mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off)); + add(reg_ptr_scales, + jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float)); + mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales); + mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); + add(reg_output_data, + load_loop_blk * jcp.load_block * jcp.typesize_out); + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + const int simd_w = 16; + + Label load_loop_blk[7]; + + static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 }; + const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast); + const int *ur_cases_fma = ur_cases_fma_expl_bcast; + const int *ur_cases = ur_cases_fma; + const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases); + + for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { + int label_idx = num_ur_cases - ur_idx - 1; + if (jcp.ur <= ur_cases[ur_idx]) { + cmp(reg_load_loop_work, simd_w * (label_idx + 1)); + jle(load_loop_blk[label_idx], T_NEAR); + } + } + + for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { + if (jcp.ur <= ur_cases[ur_idx]) { + int label_idx = num_ur_cases - ur_idx - 1; + L(load_loop_blk[label_idx]); + { + if (label_idx == 0) { + cmp(reg_load_loop_work, 0); + je(load_loop_blk[num_ur_cases], T_NEAR); + } + + for (int _i = 1; _i <= label_idx + 1; _i++) { + prefetcht0(ptr [ reg_load_data + _i * jcp.ic * jcp.oc_block ]); + prefetcht1(ptr [ reg_output_data + _i * jcp.oc_block ]); + } + + load_loop_body(label_idx + 1); + if (label_idx - 1 > 0) { + cmp(reg_load_loop_work, 2 * label_idx * simd_w); + je(load_loop_blk[label_idx - 1], T_NEAR); + } + cmp(reg_load_loop_work, (label_idx + 1) * simd_w); + jge(load_loop_blk[label_idx]); + } + for (int idx = label_idx - 1; idx > 0; --idx) { + cmp(reg_load_loop_work, simd_w * (idx + 1)); + je(load_loop_blk[idx], T_NEAR); + } + if (ur_idx < num_ur_cases - 2) { + cmp(reg_load_loop_work, simd_w); + jle(load_loop_blk[0], T_NEAR); + } + } + } + L(load_loop_blk[num_ur_cases]); + + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_eltwise(1)) + || (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf( + jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d, + const primitive_attr_t &attr, int nthreads, bool reduce_src) { + if (!mayiuse(avx512_core)) return status::unimplemented; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + if (!one_of(src_d.data_type(), data_type::u8, data_type::s8) + || weights_d.data_type() != data_type::s8 + || !one_of(dst_d.data_type(), + data_type::f32, data_type::s32, data_type::s8, data_type::u8)) + return status::unimplemented; + jcp.ver = ver_avx512_core; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = jcp.ic; + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + jcp.kh = weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + 3]; + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + jcp.tr_is = rnd_up(jcp.is, 4); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + format_tag_t dat_tag = format_tag::nhwc; + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + const int simd_w = 16; + + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.ic, simd_w); + + args_ok = true + && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.ic_block = jcp.oc_block = simd_w; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + const int SMALL_SPATIAL = 7 * 7; + const int BIG_REDUCE_DIM = 1024; + + int load_blocking = 0; + int load_blocking_max = 0; + int bcast_blocking = 0; + int bcast_blocking_max = 0; + int reduce_blocking = 0; + int reduce_blocking_max = 0; + jcp.load_grp_count = 1; + jcp.use_vmovntps = false; + + const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in); + const int L2_capacity = (L2_size * 3) / 4; + + int size_treshold = 28; + int max_regs = 0; + int min_regs = 6; + if (jcp.ver == ver_vnni) + max_regs = ((jcp.oh > size_treshold && jcp.ow > size_treshold) + && (jcp.oc < 128 || jcp.ic < 128)) ? min_regs : 9; + else + max_regs = 8; + jcp.expl_bcast = true; + + if (jcp.mb == 1 && jcp.ic > 128 + && (jcp.oh <= size_treshold && jcp.ow <= size_treshold)) { + if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size) + max_regs = min_regs; // mobilenet_v2 performance improvement + jcp.ur = nstl::min(max_regs, jcp.os); + } else { + const int spatial = jcp.oh; + jcp.ur = 1; + for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) { + if ((spatial >= size_treshold && spatial % ur_w == 0) + || (spatial < size_treshold && jcp.os % ur_w == 0)) { + jcp.ur = ur_w; + break; + } + } + if (jcp.ur == 1) { + jcp.ur = nstl::min(max_regs, jcp.os); + int os_tail = jcp.os % max_regs; + for (int i = max_regs; i >= min_regs; i--) { + int i_tail = jcp.os % i; + if (i_tail > os_tail || i_tail == 0) { + jcp.ur = i; + os_tail = i_tail; + if (i_tail == 0) + break; + } + } + } + } + + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.typesize_in; + + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_without_padding * jcp.typesize_out; + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_without_padding * jcp.typesize_in; + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step + = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; + + jcp.load_loop_iter_step = jcp.load_block; + + jcp.loop_order = reduce_src ? loop_blr : loop_lbr; + + int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + reduce_blocking = nb_reduce; + if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 64; + else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) + reduce_blocking = 16; + reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); + reduce_blocking *= jcp.reduce_block; + + bool cmp_reduce = reduce_blocking <= jcp.reduce_dim; + if (cmp_reduce) + jcp.loop_order = reduce_src ? loop_rbl : loop_rlb; + load_blocking = jcp.load_dim; + + jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast); + jcp.load_grp_count = best_divider( + nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false); + + if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.load_dim * jcp.reduce_dim >= L2_size) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); + } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= nthreads + && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { + jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); // + load_blocking = jcp.load_block; + } + + bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, + div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block; + bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking); + bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); + + int space_for_bcast + = (L2_capacity - /* kernel_size - */ + 2 * jcp.load_block * reduce_blocking + - jcp.ur * reduce_blocking - 3 * 1024); + if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) + space_for_bcast /= 2; + + int bcast_in_cache + = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); + bcast_blocking = nstl::min( + bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); + + load_blocking_max = load_blocking; + bcast_blocking_max = bcast_blocking * 3 / 2; + reduce_blocking_max = reduce_blocking; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + assert(reduce_blocking_max); + assert(load_blocking % jcp.load_block == 0); + assert(reduce_blocking % jcp.reduce_block == 0); + assert(load_blocking_max % jcp.load_block == 0); + assert(reduce_blocking_max % jcp.reduce_block == 0); + + assert(jcp.reduce_loop_unroll % 4 == 0); + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + + assert(jcp.bcast_block % jcp.ur == 0); + assert(jcp.reduce_dim % jcp.reduce_block == 0); + + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + // miniumum size of load dim chunk for work distribution within threads + jcp.nb_load_chunk = 1; + // peformance improvements for googlenet_v3, mb=1; + // TODO: generalize this condition and rewrite it in appropriate manner + if (jcp.mb == 1 && jcp.nb_load % 4 == 0 && jcp.ic / jcp.oc >= 4 + && jcp.ic * jcp.oc <= L2_size) { + jcp.nb_load_chunk = 4; + jcp.load_grp_count = nstl::max(jcp.nb_load / 4, jcp.load_grp_count); + } + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + return status::success; +} + +void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace mkldnn::impl::memory_tracking::names; + + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max<dim_t>(attr.output_scales_.count_, 16); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp new file mode 100644 index 0000000000..22e9732a1f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP +#define JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_1x1_conv_fwd_ker_t) + jit_avx512_core_x8s8s32x_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) : jcp(ajcp), attr_(attr), + eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>( + this, jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode(); + } + + ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const memory_desc_wrapper &bias_d, + const primitive_attr_t &attr, + int nthreads, bool reduce_src); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr); + + bool maybe_eltwise(int position); + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + + private: + jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_; + + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + using mask_t = const Xbyak::Opmask; + + reg64_t reg_bcast_data = r8; + reg64_t reg_ptr_scales = r8; + reg64_t reg_output_data = r9; + reg64_t reg_load_data = r10; + reg64_t reg_ptr_sum_scale = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t reg_bias_data = r12; + reg64_t reg_comp_data = r12; + reg64_t reg_scratch = r13; + reg64_t aux_reg_bcast_data = r14; + reg64_t aux_reg_load_data = r15; + reg64_t imm_addr64 = r15; + reg64_t reg_reduce_pos_flag = rax; + reg64_t aux1_reg_bcast_data = rbx; + reg64_t reg_bcast_loop_work = rbx; + reg64_t bcast_loop_iter = rdx; // Note: Fix me + reg64_t reg_load_loop_work = rsi; + reg64_t aux_reg_output_data = abi_not_param1; + reg64_t reduce_loop_iter = abi_param1; + + reg64_t reg_last_load = r8; + mask_t ktail_mask = k6; + + mask_t vmask = k7; + + Xbyak::Zmm zmm_tmp = Xbyak::Zmm(28); + Xbyak::Zmm zmm_one = Xbyak::Zmm(29); + Xbyak::Zmm zmm_zero = Xbyak::Zmm(30); + Xbyak::Zmm zmm_bcast = Xbyak::Zmm(31); + Xbyak::Zmm zmm_shift = Xbyak::Zmm(30); + + Xbyak::Zmm zmm_bias_alpha = Xbyak::Zmm(31); + Xbyak::Xmm xmm_bias_alpha = Xbyak::Xmm(31); + + int bcast_loop_work_off = 0; + int reg_bias_data_off = 8; + int reg_bcast_data_off = 16; + int reg_load_data_off = 24; + int reg_ptr_sum_scale_off = 32; + int reg_comp_data_off = 40; + int stack_space_needed = 48; + + void bcast_loop(int load_loop_blk); + void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); + + void generate(); + void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, + bool mask_flag); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp new file mode 100644 index 0000000000..0bf09fc677 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp @@ -0,0 +1,292 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +namespace { +template <typename T, typename U> +void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, + T nx, T &nx_start, T &nx_end, T nx_divider) +{ + const T grp_size = utils::div_up(nthr, nx_divider); + const T grp_count = utils::div_up(nthr, grp_size); + + T grp = ithr / grp_size; + T grp_ithr = ithr % grp_size; + T grp_nthr = grp_size; + T first_grps = nthr % grp_count; + if (first_grps > 0 && grp >= first_grps) { + ithr -= first_grps * grp_size; + grp_nthr--; + grp = ithr / grp_nthr + first_grps; + grp_ithr = ithr % grp_nthr; + } + balance211(nx, grp_count, grp, nx_start, nx_end); + balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); +} +} + +/* convolution forward */ +template <data_type_t src_type, data_type_t dst_type> +void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>:: +execute_forward(const exec_ctx_t &ctx) const +{ + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + auto scratchpad = this->scratchpad(ctx); + + if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) { + auto local_scales = scratchpad.template get<float>( + key_conv_adjusted_scales); + auto scales = pd()->attr()->output_scales_.scales_; + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, scales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = scales[c] * factor; + } + } + + parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) { + execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad); + }); +} + +template <data_type_t src_type, data_type_t dst_type> +void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type> +::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src, + const wei_data_t *weights, const char *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = kernel_->jcp; + auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space); + auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + const int stride_h = pd()->desc()->strides[0]; + const int stride_w = pd()->desc()->strides[1]; + const int pad_t = pd()->desc()->padding[0][0]; + const int pad_l = pd()->desc()->padding[0][1]; + + const auto &oscales = pd()->attr()->output_scales_; + + int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block) + * jcp.oc_block * jcp.ic_block; + wei_data_t *w = const_cast<wei_data_t *>(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast<int32_t *>(w + offset) : 0; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + + auto p = jit_1x1_conv_call_s(); + + auto rp = rtus_driver_t<avx512_common>::call_params_t(); + const int nb_oc = jcp.nb_load; + const int os_block = jcp.bcast_block; + + int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, + jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end, + jcp.load_grp_count); + if (jcp.nb_load_chunk > 1) { + ocb_start *= jcp.nb_load_chunk; + ocb_end *= jcp.nb_load_chunk; + } + + auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, + int &oh, int &ow, int &ih, int &iw) + { + int osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, + jcp.nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + oh = os / jcp.ow; + ow = os % jcp.ow; + + ih = nstl::max(oh * stride_h - pad_t, 0); + iw = nstl::max(ow * stride_w - pad_l, 0); + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + rp.os = p.bcast_dim; + }; + + auto init_load = [&](int ocb, int &load_step) + { + load_step = step(jcp.nb_load_blocking, ocb_end - ocb, + jcp.nb_load_blocking_max); + p.load_dim = this_block_size(ocb * jcp.oc_block, + ocb_end * jcp.oc_block, load_step * jcp.oc_block); + + if (ocb + load_step >= nb_oc) + p.first_last_flag |= FLAG_OC_LAST; + else + p.first_last_flag &= ~FLAG_OC_LAST; + + }; + + auto init_reduce = [&]() + { + p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic); + rp.icb = p.reduce_dim / jcp.reduce_block; + }; + + auto inner_ker = [&](int ocb, int n, int g, int oh, int ow, + int ih, int iw) + { + const int icb = 0; // Start from the first IC block + const int _ocb = g * nb_oc + ocb; + const int _icb = g; + + const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow); + + p.output_data = &dst[dst_off]; + p.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size]; + p.compensation = (jcp.signed_input) + ? &compensation[_ocb * jcp.oc_block] : 0; + p.scales = (jcp.signed_input && jcp.ver != ver_vnni) + ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block] + : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block]; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ + + _icb * jcp.is * jcp.ic_block; + if (ocb == ocb_start) { + rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw); + rtus_driver_->ker_(&rp); + } + p.bcast_data = rp.ws; + } else + p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw); + + kernel_->jit_ker(&p); + }; + + if (jcp.loop_order == loop_rlb) { + init_reduce(); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + inner_ker(ocb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + init_reduce(); + inner_ker(ocb, n, g, oh, ow, ih, iw); + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == loop_rbl) { + init_reduce(); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + inner_ker(ocb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } else if (jcp.loop_order == loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n, g, bcast_step, oh, ow, ih, iw; + init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, load_step); + init_reduce(); + inner_ker(ocb, n, g, oh, ow, ih, iw); + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); + } +} + +using namespace data_type; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>; +template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp new file mode 100644 index 0000000000..ad9027ac17 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp @@ -0,0 +1,159 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" +#include "jit_uni_1x1_conv_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template<impl::data_type_t src_type, impl::data_type_t dst_type> +struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_(), rtus_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8_1x1:", avx512_core, ""), + jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t< + src_type, dst_type>); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, data_type::s8, data_type::undef, + dst_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && !has_zero_dim_memory() + && set_default_formats_common(dat_tag(), format_tag::any, + dat_tag()) + && set_or_check_wei_format(); + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + rtus_prepare(this, conv_d, src_d, dst_md()); + + status_t status = jit_avx512_core_x8s8s32x_1x1_conv_kernel:: + init_conf(jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), + *weights_md(1), *attr(), mkldnn_get_max_threads(), + rtus_.reduce_src_); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( + scratchpad, jcp_, *attr()); + + rtus_prepare_space_info(this, scratchpad); + + return status::success; + } + + jit_1x1_conv_conf_t jcp_; + reduce_to_unit_stride_t rtus_; + + protected: + format_tag_t dat_tag() const { return format_tag::nhwc; } + + bool set_or_check_wei_format() { + using namespace format_tag; + + const bool is_src_s8 = src_md_.data_type == data_type::s8; + format_tag_t wei_tag = with_groups() ? gOIhw4i16o4i : OIhw4i16o4i; + + memory_desc_t want_wei_md = weights_md_; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (is_src_s8) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups() ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md_.format_kind == format_kind::any) { + weights_md_ = want_wei_md; + return true; + } + + return weights_md_ == want_wei_md; + } + }; + + template <cpu_isa_t isa, typename conv_t> + friend void init_rtus_driver(conv_t *self); + + jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), rtus_driver_(nullptr) + { + kernel_ = new jit_avx512_core_x8s8s32x_1x1_conv_kernel(pd()->jcp_, + *pd()->attr()); + init_rtus_driver<avx512_common>(this); + } + + ~jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t() { + delete kernel_; + delete rtus_driver_; + } + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<data_type::s8>::type wei_data_t; + typedef typename prec_traits<dst_type>::type dst_data_t; + typedef typename prec_traits<data_type::s32>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + + private: + void execute_forward(const exec_ctx_t &ctx) const; + void execute_forward_thr(const int ithr, const int nthr, + const src_data_t *src, const wei_data_t *weights, + const char *bias, dst_data_t *dst, + const memory_tracking::grantor_t &scratchpad) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_x8s8s32x_1x1_conv_kernel *kernel_; + rtus_driver_t<avx512_common> *rtus_driver_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp new file mode 100644 index 0000000000..e89d068302 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp @@ -0,0 +1,140 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "type_helpers.hpp" +#include "primitive_iterator.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_deconvolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_uni_1x1_conv_utils.hpp" +#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t src_type, impl::data_type_t dst_type> +struct jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t + : public cpu_primitive_t { + struct pd_t : public cpu_deconvolution_fwd_pd_t { + pd_t(engine_t *engine, const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) {} + + pd_t(const pd_t &other) + : cpu_deconvolution_fwd_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), + jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<src_type, dst_type>); + + status_t init_convolution() { + convolution_desc_t cd; + status_t status; + + auto dd = desc(); + status = conv_desc_init(&cd, prop_kind::forward_training, + alg_kind::convolution_direct, &(dd->src_desc), + &(dd->weights_desc), &(dd->bias_desc), &(dd->dst_desc), + dd->strides, dd->dilates, dd->padding[0], dd->padding[1], + dd->padding_kind); + + if (status == status::success) { + status = mkldnn_primitive_desc::create<conv_pd_t>( + &conv_pd_, (op_desc_t *)&cd, &attr_, engine_, nullptr); + } + + if (status == status::success) + status = set_default_params(); + + return status; + }; + + status_t init() { + bool ok = true + && is_fwd() + && desc()->alg_kind == alg_kind::deconvolution_direct + && !has_zero_dim_memory() + && desc()->src_desc.data_type == src_type + && desc()->dst_desc.data_type == dst_type + && desc()->weights_desc.data_type == data_type::s8 + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && desc()->accum_data_type == data_type::s32; + if (!ok) return status::unimplemented; + + CHECK(init_convolution()); + + return status::success; + } + + virtual void init_scratchpad_md() override { + const auto conv_1x1_pd = static_cast<conv_pd_t *>(conv_pd_); + scratchpad_md_ = *conv_1x1_pd->scratchpad_md(); + } + + protected: + status_t set_default_params() { + auto conv_1x1_pd_ = static_cast<conv_pd_t *>(conv_pd_); + src_md_ = *conv_1x1_pd_->src_md(); + dst_md_ = *conv_1x1_pd_->dst_md(); + weights_md_ = *conv_1x1_pd_->weights_md(); + if (with_bias()) + bias_md_ = *conv_1x1_pd_->weights_md(1); + return status::success; + } + + using conv_pd_t = typename jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t + <src_type, dst_type>::pd_t; + friend jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t; + primitive_desc_t *conv_pd_; + }; + + jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + + ~jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t() + { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + return conv_p_->execute(ctx); + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +} +} +} + +#endif /* CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp new file mode 100644 index 0000000000..10e98a00c4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp @@ -0,0 +1,1182 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_memory.hpp" + +#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +namespace { +void pick_loop_order(jit_conv_conf_t &jcp, int nthr) +{ + jcp.loop_order = loop_cwgn; + if (jcp.ngroups > 1) { + jcp.loop_order = loop_ngcw; + if (jcp.mb < nthr) + jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg; + } +} +} + +template<typename Vmm> +bool _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::maybe_eltwise(int position) +{ + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + + return false; +} + +template<typename Vmm> +void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::prepare_output(int ur_w) +{ + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + for (int k = 0; k < nb_oc_block; k++) + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + vpxord(vmm, vmm, vmm); + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + if (jcp.is_depthwise && !jcp.is_fast_depthwise) { + Reg32 _t32 = reg_scratch.cvt32(); + mov(_t32, (uint32_t)128); + vpbroadcastd(vmm_shift, _t32); + } else { + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)128); + vpbroadcastb(vmm_shift, _t8); + } + } +} + +template<typename Vmm> +const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>:: + vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) { + return vmm_in; +} + +template<> +const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>:: + vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) { + return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) + : zmm_in; +} + + +template<typename Vmm> +void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::cvt2ps(data_type_t type_in, + const Vmm vmm_in, const Operand &op, bool mask_flag) { + //const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in; + const Vmm vmm = vmm_mask(vmm_in, mask_flag); + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(vmm, op); break; + case data_type::s8: vpmovsxbd(vmm, op); break; + case data_type::u8: vpmovzxbd(vmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(vmm_in, vmm_in); +} + +template<typename Vmm> +void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_eltwise(int ur_w) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + if (ur_w == jcp.ur_w) + eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w); + else + for (int k = 0; k < nb_oc_block; k++) + eltwise_injector_->compute_vector_range(k * jcp.ur_w, + k * jcp.ur_w + ur_w); +} + +template<typename Vmm> +void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output( + int ur_w, bool last_oc_block_flag) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block; + + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + if (jcp.signed_input) + mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); + + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale = nullptr; + if (sum_idx != -1) { + const auto &p_entry = p.entry_[sum_idx]; + p_sum_scale = &p_entry.sum.scale; + } + + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + if (jcp.signed_input && jcp.ver != ver_vnni) { + /* put 'wei_adj_scale = 0.5' for bias calculation */ + mov(reg_bias_alpha, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_bias_alpha); + vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha()); + } + + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block); + if (jcp.with_bias) { + int bias_offset = jcp.typesize_bia * k * oc_block; + auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); + + cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + /* bias *= 0.5 */ + vmulps(vmm_bias, vmm_bias, vmm_bias_alpha()); + } + if (jcp.signed_input) { + int comp_offset = sizeof(int32_t) * k * oc_block; + auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); + + cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag); + } + /* add to zmm_accum: compensation, bias and permute */ + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + if (jcp.is_fast_depthwise) + vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k)); + vcvtdq2ps(vmm, vmm); + if (jcp.signed_input) + vaddps(vmm, vmm, vmm_comp); + if (jcp.with_bias) + vaddps(vmm, vmm, vmm_bias); + + const Vmm vmm_k = vmm_mask(vmm, mask_flag); + vmulps(vmm_k, vmm, + EVEX_compress_addr(reg_ptr_scales, scale_offset)); + } + } + + /* Do post-ops */ + if (maybe_eltwise(0)) compute_eltwise(ur_w); + if (p_sum_scale) { // post_op: sum + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + for (int j = 0; j < ur_w; j++) { + int aux_output_offset + = jcp.typesize_out + * (k * oc_block + + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_out, aux_output_offset); + Vmm vmm = vmm_out(j, k); + cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag); + if (*p_sum_scale == 1.f) + vaddps(vmm, vmm_prev_dst); + else + vfmadd231ps(vmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + if (maybe_eltwise(1)) compute_eltwise(ur_w); + + /* write out register to output_addr */ + for (int k = 0; k < nb_oc_block; k++) { + const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1; + for (int j = 0; j < ur_w; j++) { + Vmm vmm = vmm_out(j, k); + if (jcp.dst_dt == data_type::u8) { + vpxord(vmm_zero, vmm_zero, vmm_zero); + vmaxps(vmm, vmm_zero, vmm); + } + + if (jcp.dst_dt != data_type::f32) { + /* Note: using Zmm for rounding in Xmm/Ymm kernel + because there is no instruction to do rounding + from Xmm/Ymm -> Xmm/Ymm. + Embedded rounding is not supported for Xmm. + TODO: maybe avoid Zmm if it helps performance.*/ + Zmm zmm = zmm_out(j, k); + vcvtps2dq(zmm, zmm); + } + } + + for (int j = 0; j < ur_w; j++) { + int aux_output_offset = jcp.typesize_out + * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_out, aux_output_offset); + + Vmm vmm = vmm_out(j, k); + const Vmm r_vmm = vmm_mask(vmm, mask_flag, true); + + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: vmovups(addr, r_vmm); break; + case data_type::s8: vpmovsdb(addr, r_vmm); break; + case data_type::u8: vpmovusdb(addr, r_vmm); break; + default: assert(!"unknown dst_dt"); + } + } + } + +} + +template <typename Vmm> +void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + assert(!"invalid group blocking for depthwise convolution"); +} + +template <> +void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + + auto input_spatial_index = [=](int oi, int ki) { + return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l); + }; + + auto input_offset2 = [=](int ii, int ci) { + return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block); + }; + + auto input_offset3 = [=](int oi, int ci, int ki) { + return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci); + }; + + auto kernel_offset = [=](int ci, int ki) { + return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block); + }; + + auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { + // okay for depthwise since src is zero-extended + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddwd(zmm_tmp, vreg_src, vreg_wei); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + int ii_start = 0; + int ii_end = -1; + if (jcp.is_resrc_depthwise && !h_padded) { + // find bounds of input spatial indices + bool first = true; + for (int ki = 0; ki < jcp.kw; ki++) { + int oi_start = get_ow_start(ki, pad_l); + int oi_end = get_ow_end(ur_w, ki, pad_r); + for (int oi = oi_start; oi < oi_end; oi++) { + int ii = input_spatial_index(oi, ki); + if (first || ii < ii_start) + ii_start = ii; + if (first || ii > ii_end) + ii_end = ii; + first = false; + } + } + } + + if (jcp.signed_input) { + vpxord(zmm_shifted_zero, zmm_shifted_zero, zmm_shifted_zero); + vpaddb(zmm_shifted_zero, zmm_shifted_zero, vmm_shift); + } + for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) { + const bool mask_flag = last_ic_block_flag != no_last_block + && ci == jcp.nb_ch_blocking - 1; + if (jcp.is_resrc_depthwise && !h_padded) { + // now we can load input once and reuse up to jcp.kw times + for (int ii = ii_start; ii <= ii_end; ii++) { + int aux_input_offset = input_offset2(ii, ci); + const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking); + const Zmm zmm_inp_msk = mask_flag + ? zmm_inp_tmp | ktail_mask | T_z + : zmm_inp_tmp; + if (jcp.is_fast_depthwise) { + assert(!mask_flag); + vbroadcasti32x4(zmm_inp_msk, + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + } else { + vpmovzxbd(zmm_inp_msk, + EVEX_compress_addr(aux_reg_inp, aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift); + } + } + for (int ki = 0; ki < jcp.kw; ki++) { + int aux_kernel_offset = kernel_offset(ci, ki); + if (jcp.is_fast_depthwise) { + vbroadcasti32x4(zmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei); + } else { + vpmovsxbd(zmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + } + if (h_padded) { + assert(jcp.signed_input); + for (int oi = 0; oi < ur_w; oi++) + compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero); + } else { + const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src; + int oi_start = get_ow_start(ki, pad_l); + int oi_end = get_ow_end(ur_w, ki, pad_r); + int start_ = jcp.signed_input ? 0 : oi_start; + int end_ = jcp.signed_input ? ur_w : oi_end; + for (int oi = start_; oi < end_; oi++) { + if (oi >= oi_start && oi < oi_end) { + if (jcp.is_resrc_depthwise) { + int ii = input_spatial_index(oi, ki); + zmm_src = zmm_inp(ii, jcp.nb_ch_blocking); + } else { + int aux_input_offset = input_offset3(oi, ci, ki); + if (jcp.is_fast_depthwise) { + assert(!mask_flag); + vbroadcasti32x4(r_zmm_src, + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + } else { + vpmovzxbd(r_zmm_src, + EVEX_compress_addr(aux_reg_inp, + aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(zmm_src, zmm_src, vmm_shift); + } + } else if (jcp.signed_input) { + zmm_src = zmm_shifted_zero; + } + compute(zmm_out(oi, ci), zmm_wei, zmm_src); + } + } + } + } +} + +template<typename Vmm> +void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l, + int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { + if (jcp.is_depthwise) + return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded); + + int kw = jcp.kw; + int stride_w = jcp.stride_w; + int ic_block = jcp.ic_block; + int oc_block = jcp.oc_block; + int ch_block_all = jcp.ch_block * ic_block * oc_block; + + int nb_oc_block = jcp.nb_oc_blocking; + + auto input_offset = [=](int oi, int ic, int ki) { + return jcp.typesize_in + * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l) + * jcp.ic_without_padding * jcp.ngroups + 4 * ic); + }; + auto kernel_offset = [=](int ii, int ic, int ki) { + return jcp.typesize_in + * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all + + 4 * ic * oc_block); + }; + auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else { + vpmaddubsw(vmm_tmp, vreg_src, vreg_wei); + vpmaddwd(vmm_tmp, vmm_tmp, vmm_one); + vpaddd(vreg_acc, vreg_acc, vmm_tmp); + } + }; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = get_ow_start(ki, pad_l); + int jj_end = get_ow_end(ur_w, ki, pad_r); + int tail_size = jcp.ic_without_padding % 4; + int _start = (jcp.signed_input) ? 0 : jj_start; + int _end = (jcp.signed_input) ? ur_w : jj_end; + /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */ + int icb = (last_ic_block_flag != no_last_block) + ? div_up((jcp.ic_without_padding % ic_block), 4) + : ic_block / 4; + for (int ic = 0; ic < icb; ic++) { + if (h_padded == true) { + /* fill padded area with shifted values */ + Vmm inp = vmm_inp(0,nb_oc_block); + vpxord(inp, inp, inp); + vpaddb(inp, inp, vmm_shift); + } else { + for (int jj = _start; jj < _end; jj++) { + int aux_input_offset = input_offset(jj, ic, ki); + if (jj >= jj_start && jj < jj_end) { + if (last_ic_block_flag == last_sp_block + && tail_size != 0 && ic == icb - 1) { + Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_tmp, xmm_tmp, + ptr[aux_reg_inp + aux_input_offset + r], r); + vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp); + } else { + vpbroadcastd(vmm_inp(jj, nb_oc_block), + EVEX_compress_addr( + aux_reg_inp, aux_input_offset)); + } + if (jcp.signed_input) + vpaddb(vmm_inp(jj, nb_oc_block), + vmm_inp(jj, nb_oc_block), vmm_shift); + } else { + /* fill padded area with shifted values */ + if (jcp.signed_input) { + Vmm inp = vmm_inp(jj, nb_oc_block); + vpxord(inp, inp, inp); + vpaddb(inp, inp, vmm_shift); + } + } + } + } + for (int ii = 0; ii < nb_oc_block; ii++) { + int aux_kernel_offset = kernel_offset(ii, ic, ki); + vmovups(vmm_wei, + EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); + for (int jj = _start; jj < _end; jj++) { + Vmm inp = (h_padded == true) + ? vmm_inp(0,nb_oc_block) : vmm_inp(jj, nb_oc_block); + compute(vmm_out(jj, ii), vmm_wei, inp); + } + } + } + } +} + +template<typename Vmm> +void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) { + Label kh_label, skip_kh_loop; + Label t_overflow_label, no_t_overflow_label, + b_overflow_label, no_b_overflow_label; + + int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all; + int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * jcp.ic_without_padding * jcp.ngroups; + + mov(aux_reg_inp, reg_inp); + mov(aux_reg_ker, reg_ker); + + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); + cmp(reg_overflow, 0); + je(no_t_overflow_label, T_NEAR); + L(t_overflow_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); + + add(aux_reg_ker, shift_kernel_ptr); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(t_overflow_label, T_NEAR); + } + L(no_t_overflow_label); + } + mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]); + if ((jcp.signed_input) || (!jcp.signed_input && + (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) { + cmp(reg_kj, 0); + je(skip_kh_loop, T_NEAR); + } + L(kh_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false); + + add(aux_reg_ker, shift_kernel_ptr); + add(aux_reg_inp, shift_input_ptr); + dec(reg_kj); + cmp(reg_kj, 0); + jg(kh_label, T_NEAR); + } + L(skip_kh_loop); + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); + cmp(reg_overflow, 0); + je(no_b_overflow_label, T_NEAR); + L(b_overflow_label); { + compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true); + + add(aux_reg_ker, shift_kernel_ptr); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(b_overflow_label, T_NEAR); + } + L(no_b_overflow_label); + } +} + +template<typename Vmm> +void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop( + int ur_w, int pad_l, int pad_r, bool is_last_sp_block) +{ + prepare_output(ur_w); + + // IC loop + Label icb_label; + mov(reg_icb, jcp.nb_ic); + L(icb_label); + if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) { + Label common_ker, end_ker; + + cmp(reg_icb, 1); // The last IC block + jne(common_ker, T_NEAR); + + kh_loop(ur_w, pad_l, pad_r, + is_last_sp_block ? last_sp_block : last_ic_block); + jmp(end_ker, T_NEAR); + + L(common_ker); + kh_loop(ur_w, pad_l, pad_r, no_last_block); + + L(end_ker); + } else { + kh_loop(ur_w, pad_l, pad_r, no_last_block); + } + // End of IC Loop + int inp_step = jcp.ic_block; + int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block; + add(reg_inp, jcp.typesize_in * inp_step); + add(reg_ker, jcp.typesize_in * ker_step); + + dec(reg_icb); + cmp(reg_icb, 0); + jg(icb_label, T_NEAR); + + sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic); + sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + Label common_store, end_store; + + if (jcp.is_depthwise) + cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking); + else + cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); + + jne(common_store, T_NEAR); + + store_output(ur_w, true); // last oc block + jmp(end_store, T_NEAR); + + L(common_store); + store_output(ur_w, false); + + L(end_store); + } else { + store_output(ur_w, false); + } +} + +template<typename Vmm> +void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate() +{ + Label permute_index_table; + int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad) + * jcp.ic_without_padding * jcp.ngroups; + int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad + * jcp.ic_without_padding * jcp.ngroups; + int inp_shift = jcp.typesize_in * + (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding + * jcp.ngroups); + int out_shift = jcp.typesize_out * + (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups); + preamble(); + + if (jcp.is_depthwise) { + int idx = jcp.max_regs_ur - 1; + if (!jcp.is_resrc_depthwise) + zmm_src = Zmm(++idx); + if (jcp.ver != ver_vnni) + zmm_tmp = Zmm(++idx); + if (jcp.is_fast_depthwise) + zmm_permute = Zmm(++idx); + if (jcp.signed_input) { + zmm_shifted_zero = Zmm(++idx); + ++idx; // due to extra register used for shifts and compensations + } + assert(idx == ker_dw_reg_base_idx); + } + + if (!jcp.is_depthwise && jcp.ver != ver_vnni) { + xor_(reg_scratch, reg_scratch); + Reg16 _t16 = reg_scratch.cvt16(); + mov(_t16, 0x1); + vpbroadcastw(vmm_one, _t16); + } + + mov(reg_inp, ptr[param1 + GET_OFF(src)]); + mov(reg_out, ptr[param1 + GET_OFF(dst)]); + mov(reg_ker, ptr[param1 + GET_OFF(filt)]); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.is_depthwise + ? jcp.ngroups % jcp.ch_block + : jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); + Reg32 regw_tmp = reg_oi.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + if (jcp.is_fast_depthwise) { + // prepare mask register for blending weights + mov(reg_scratch, 0x8888444422221111); + kmovq(kblend_mask, reg_scratch); + // load permute indices from data section + mov(reg_scratch, permute_index_table); + vmovdqu32(zmm_permute, ptr[reg_scratch]); + } + + int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + int n_oi = jcp.ow / jcp.ur_w; + int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1); + + if (jcp.nb_ow == 1) { + if (r_pad1 > 0 || jcp.ur_w_tail == 0) + n_oi--; + + xor_(reg_oi, reg_oi); + if (jcp.ow == jcp.ur_w) { + icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true); + } else { + if (n_oi == 0) { + icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + } else { + if (jcp.l_pad > 0) { + icb_loop(jcp.ur_w, jcp.l_pad, 0, false); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + + inc(reg_oi); + } + if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)) + { + Label ow_loop_label; + L(ow_loop_label); { + icb_loop(jcp.ur_w, 0, 0, false); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + inc(reg_oi); + cmp(reg_oi, n_oi); + jl(ow_loop_label, T_NEAR); + } + } + if (r_pad1 > 0 || jcp.ur_w_tail == 0) { + icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + } + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + } + } + } else { + // ow block is only processed. + // Number of block is passed as parameter owb, + // and padding processing depends on this number. + Label end_label, last_oi_label, middle_ow_blocks_label, tail_label, + oi_loop_label, oi_loop_end_label; + + assert(jcp.ow_block % jcp.ur_w == 0); + int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w; + // to simplify code (and general regs usage), + // size of ow block must be >= 2 * ur_w + assert(n_oi_not_last_ow_block > 1); + int n_oi_next_last_ow_block = n_oi_not_last_ow_block; + int n_oi_first_ow_block = n_oi_not_last_ow_block; + int n_oi_last_ow_block + = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w; + // prepare right padding + bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0; + bool first_ow_block_padded + = next_last_ow_block_padded && jcp.nb_ow == 2; + bool last_ow_block_padded + = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0; + + if (last_ow_block_padded) n_oi_last_ow_block--; + else if (first_ow_block_padded) n_oi_first_ow_block--; + else if (next_last_ow_block_padded) n_oi_next_last_ow_block--; + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // is that the first ow-block ? + jg(middle_ow_blocks_label, T_NEAR); + + // the first ow block, compute left padding + mov(reg_oi, n_oi_first_ow_block); + if (jcp.l_pad > 0) { + icb_loop(jcp.ur_w, jcp.l_pad, 0, false); + add(reg_inp, inp_shift_pad); + add(reg_out, out_shift); + + dec(reg_oi); + } + jmp(oi_loop_label, T_NEAR); + + // middle or last ow block entry + L(middle_ow_blocks_label); + + if (jcp.l_pad > 0) { + // just to consider left padding, not compute + add(reg_inp, inp_shift_pad_second_block); + } + + // set number of iteration for oi-loop + if (n_oi_last_ow_block != n_oi_not_last_ow_block) { + cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ? + mov(reg_oi, n_oi_last_ow_block); + je(oi_loop_label, T_NEAR); + } + + if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) { + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + + mov(reg_oi, n_oi_next_last_ow_block); + je(oi_loop_label, T_NEAR); + } + mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks + + // oi loop w/o padding + L(oi_loop_label); { + cmp(reg_oi, 0); + jle(oi_loop_end_label, T_NEAR); + + icb_loop(jcp.ur_w, 0, 0, false); + + add(reg_inp, inp_shift); + add(reg_out, out_shift); + dec(reg_oi); + + jmp(oi_loop_label, T_NEAR); + } + L(oi_loop_end_label); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, 0); // first ow-block ? + if (first_ow_block_padded) + je(last_oi_label, T_NEAR); + else + je(end_label, T_NEAR); + + cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ? + jl(end_label, T_NEAR); + if (next_last_ow_block_padded) + je(last_oi_label, T_NEAR); + else + je(end_label, T_NEAR); + + // that is last block + if (!last_ow_block_padded) + jmp(tail_label, T_NEAR); + + // last oi block with right padding + L(last_oi_label); + icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0); + add(reg_inp, inp_shift); + add(reg_out, out_shift); + + mov(reg_owb, ptr[param1 + GET_OFF(owb)]); + cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? + jl(end_label, T_NEAR); + + // ur_w tail + L(tail_label); + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_pad, true); + } + L(end_label); + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); + + if (jcp.is_fast_depthwise) { + align(64); + L(permute_index_table); + const uint32_t _idx[] + = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i) + dd(_idx[i]); + } +} + +bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) +{ + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: return (p.contain(sum, 0) && is_eltwise(1)) || + (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const primitive_attr_t &attr, + int nthreads) +{ + using namespace prop_kind; + + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bias_d(&bias_md); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + int ndims = src_d.ndims(); + bool is_1d = ndims == 3; + + if (!(mayiuse(avx512_core) + && one_of(src_d.data_type(), data_type::u8, data_type::s8) + && weights_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + + jcp = zero<decltype(jcp)>(); + jcp.ndims = ndims; + jcp.prop_kind = cd.prop_kind; + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = jcp.ic; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; + jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc); + + if (jcp.is_depthwise) { + jcp.ch_block = 16; + jcp.ic_block = 1; + jcp.oc_block = 1; + } else { + jcp.ch_block = 1; + jcp.ic_block = 16; + jcp.oc_block = 16; + + if (jcp.ngroups == 1) { + /* For non grouped convolutions, pad channels by 16 if needed */ + jcp.oc = rnd_up(jcp.oc, jcp.oc_block); + jcp.ic = rnd_up(jcp.ic, jcp.ic_block); + } else if (!is_1d && jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) { + /* For grouped convolutions, MKL-DNN doesn't support padding. + Use Ymm when channels per group is multiple of 8, + Xmm when channels per group is multiple of 4 */ + jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4; + jcp.oc_block = jcp.ic_block; + } + if (jcp.ic % jcp.ic_block !=0 || jcp.oc % jcp.oc_block != 0) + return status::unimplemented; + } + + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core; + jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni + && jcp.ngroups % jcp.ch_block == 0; // for groups not multiple of 16 + // would require byte masking + // for load from src + jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw + && jcp.kw < 4 && jcp.dilate_w == 0; + if (jcp.is_depthwise) { + jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise + - 2 * jcp.signed_input - (jcp.ver != ver_vnni); + } else { + jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28; + } + + auto set_or_check_wei_format = [&]() { + using namespace format_tag; + format_tag_t wei_tag; + if (jcp.ic_block == 16 || jcp.ch_block == 16) { + if (is_1d) { + wei_tag = with_groups + ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i + : OIw4i16o4i; + } else { + wei_tag = with_groups + ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i + : OIhw4i16o4i; + } + } else if (with_groups && jcp.ic_block == 8) { + wei_tag = gOIhw2i8o4i; + } else + wei_tag = gOIhw4o4i; + + memory_desc_t want_wei_md = weights_md; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (jcp.signed_input) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md.format_kind == format_kind::any) { + weights_md = want_wei_md; + return true; + } + + return weights_md == want_wei_md; + }; + + if (!set_or_check_wei_format()) + return status::unimplemented; + + format_tag_t dat_tag = utils::pick(ndims - 3, + format_tag::nwc, format_tag::nhwc); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, dat_tag)); + jcp.src_tag = dat_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + } + if (jcp.src_tag != dat_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); + jcp.dst_tag = dat_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + } + if (jcp.dst_tag != dat_tag) + return status::unimplemented; + + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + } + + jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; + jcp.dst_dt = cd.dst_desc.data_type; + + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + jcp.typesize_bia = jcp.with_bias + ? types::data_type_size(bias_d.data_type()) + : 0; + + jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); + jcp.nb_ic = jcp.ic / jcp.ic_block; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + // Try to use 4 channel-groups at a time to avoid false sharing (depthwise) + int nb_ch_blocking = 4; + for ( /* init above */ ; nb_ch_blocking > 1; nb_ch_blocking--) + if (jcp.nb_ch % nb_ch_blocking == 0) + break; + jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1; + + // If OC blocking is incommensurate with the number of OC blocks (general + // requirement for all convolutions), or if it results in an unrolling + // factor smaller than the left padding (special requirement for SSD:fc6), + // then search for a smaller OC blocking that satisfies both constraints. + auto is_oc_blocking_ok = [&](int block) { + int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1)); + return jcp.nb_oc % block == 0 + && jcp.l_pad <= ur_w && jcp.ow % ur_w != 1; + }; + + // choose nb_oc work chunk size for distribution within threads + int max_threading_nb_oc_chunk = 4; + // Performance improvements for googlenet_v3 and resnet_50 with mb = 1; + // TODO: generalize this condition and rewrite it in appropriate manner + if (jcp.ver == ver_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3 + && jcp.stride_w == 1 && jcp.ic % 64 == 0) + max_threading_nb_oc_chunk = 2; + jcp.nb_oc_blocking_thr_chunk = + nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc); + for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) { + if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) + break; + } + + // choose oc blocking for computational kernel + jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk; + // Performance improvements for googlenet_v3 with mb = 1; + // TODO: generalize this condition and rewrite it in appropriate manner + const int size_treshold_for_nb_oc_blocking_reduction = 17; + if (jcp.mb == 1 && jcp.ow <= size_treshold_for_nb_oc_blocking_reduction + && jcp.stride_w == 1 + && !(jcp.kh == 1 && jcp.kw == 3) + && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) { + const int max_nb_oc_blocking = 2; + jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc); + for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) + if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0 + && is_oc_blocking_ok(jcp.nb_oc_blocking)) + break; + } + + if (jcp.is_resrc_depthwise) + jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w) + / (jcp.nb_ch_blocking + jcp.stride_w); + else + jcp.ur_w + = jcp.max_regs_ur / (jcp.is_depthwise ? jcp.nb_ch_blocking + : jcp.nb_oc_blocking + 1); + if (jcp.ow < jcp.ur_w) + jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + jcp.ow_block = jcp.ow; + int base_work_amount = jcp.mb * jcp.nb_ch * jcp.oh + * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk); + float best_thr_eff + = (float)base_work_amount / rnd_up(base_work_amount, nthreads); + int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w); + for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) { + int ow_block + = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow); + if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block + && best_thr_eff > 0.8f) + break; + if (div_up(jcp.ow, ow_block) != nb_ow) + continue; + auto work_amount = base_work_amount * nb_ow; + float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads); + if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) { + jcp.ow_block = ow_block; + best_thr_eff = thr_eff; + } + if (best_thr_eff > 0.9f) + break; + } + jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); + + bool args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0); + if (!args_ok) + return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + if (r_pad_no_tail > jcp.ur_w) + return status::unimplemented; + + pick_loop_order(jcp, nthreads); + + jcp.nb_ic_L2 = jcp.nb_ic; + + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + return status::success; +} + +void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, + const primitive_attr_t &attr) { + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max(attr.output_scales_.count_, (dim_t)jcp.ic_block); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>; +template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Ymm>; +template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Xmm>; +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp new file mode 100644 index 0000000000..d8a05ad53e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp @@ -0,0 +1,239 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template<typename Vmm> +struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t) + + enum { STATE_FIRST_DST_LOAD = 0x1U }; + + _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : jcp(ajcp), attr_(attr), + eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>( + this, jcp.eltwise); + + generate(); + jit_ker_ = (void (*)(jit_conv_call_s *))getCode(); + } + + ~_jit_avx512_core_x8s8s32x_fwd_kernel() { + delete eltwise_injector_; + } + + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker_)(jit_conv_call_s *); + +private: + jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_; + + enum { + typesize = sizeof(float), + ker_reg_base_idx = 28, + ker_dw_reg_base_idx = 30, + }; + typedef enum { + no_last_block, + last_ic_block, + last_sp_block, + } ic_block_t; + + /* data regs */ + const Xbyak::Reg64 reg_ptr_scales = rax; + const Xbyak::Reg64 reg_inp = r8; + const Xbyak::Reg64 reg_ker = r9; + const Xbyak::Reg64 reg_out = r10; + const Xbyak::Reg64 aux_reg_inp = r11; + const Xbyak::Reg64 reg_ptr_sum_scale = r11; + const Xbyak::Reg64 aux_reg_ker = r12; + const Xbyak::Reg64 reg_compensation = r14; + /* counter regs */ + const Xbyak::Reg64 reg_bias_alpha = abi_not_param1; + const Xbyak::Reg64 reg_oi = rbx; + const Xbyak::Reg64 reg_bias = rdx; + const Xbyak::Reg64 reg_oc_blocks = rsi; + const Xbyak::Reg64 reg_owb = aux_reg_ker; + const Xbyak::Reg64 reg_scratch = reg_compensation; + const Xbyak::Reg64 reg_kj = reg_ptr_scales; + const Xbyak::Reg64 reg_overflow = reg_ptr_scales; + const Xbyak::Reg64 reg_icb = reg_bias; + + const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); + const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3); + + const Vmm vmm_wei = Vmm(31); + /* used during bias section of store_output */ + const Vmm vmm_comp = Vmm(30); // only for signed input + const Vmm vmm_bias = Vmm(31); + /* used during post_op sum section of store_output */ + const Vmm vmm_prev_dst = Vmm(31); + /* used during write-out section of store_output */ + const Vmm vmm_zero = Vmm(31); + + /* used in compute_ker (but set during prepare_output) */ + const Vmm vmm_shift = vmm_comp; // only for signed input + /* used in compute_ker (but only for pre-VNNI machines) */ + const Vmm vmm_tmp = Vmm(28); // not used for depthwise + const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise. + + /* registers use only for depthwise + groups are always blocked by 16(padded if needed), + hence use only Zmm registers */ + const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31); + Xbyak::Zmm zmm_tmp; + Xbyak::Zmm zmm_src; + Xbyak::Zmm zmm_shifted_zero; + Xbyak::Zmm zmm_permute; + + Vmm vmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < (jcp.is_depthwise + ? ker_dw_reg_base_idx : ker_reg_base_idx)); + return Vmm(idx); + } + Xbyak::Zmm zmm_out(int i_ur, int i_oc) { + int idx = i_ur + i_oc * jcp.ur_w; + assert(idx < (jcp.is_depthwise + ? ker_dw_reg_base_idx : ker_reg_base_idx)); + return Xbyak::Zmm(idx); + } + Vmm vmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Vmm(idx); + } + Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return Xbyak::Zmm(idx); + } + Vmm vmm_bias_alpha() { + int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + return Vmm(nb_c_block * jcp.ur_w); + } + Xbyak::Xmm xmm_bias_alpha() { + int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + return Xbyak::Xmm(nb_c_block * jcp.ur_w); + } + int get_ow_start(int ki, int pad_l) { + return nstl::max(0, + utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); + } + int get_ow_end(int ur_w, int ki, int pad_r) { + return ur_w - nstl::max(0, utils::div_up(pad_r + - (jcp.kw - 1 - ki) + * (jcp.dilate_w + 1), + jcp.stride_w)); + } + + bool maybe_eltwise(int position); + void prepare_output(int ur_w); + void store_output(int ur_w, bool last_oc_block_flag); + void compute_ker_dw( + int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded); + void compute_ker(int ur_w, int pad_l, int pad_r, + ic_block_t last_ic_block_flag, bool h_padded = false); + void compute_eltwise(int ur_w); + void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag); + void icb_loop( + int ur_w, int pad_l, int pad_r, bool is_last_spatial_block); + void generate(); + void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op, + bool mask_flag); + const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false); +}; + +struct jit_avx512_core_x8s8s32x_fwd_kernel { + + jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) : + jit_ker(nullptr), + zmm_kernel_(nullptr), + ymm_kernel_(nullptr), + xmm_kernel_(nullptr) { + int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block; + switch (ch_block) { + case 16: + zmm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>( + ajcp, attr); + jit_ker = zmm_kernel_->jit_ker_; + return; + case 8: + ymm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm>( + ajcp, attr); + jit_ker = ymm_kernel_->jit_ker_; + return; + case 4: + xmm_kernel_ = + new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm>( + ajcp, attr); + jit_ker = xmm_kernel_->jit_ker_; + return; + default: + assert(!"invalid channel blocking"); + } + } + + ~jit_avx512_core_x8s8s32x_fwd_kernel() { + delete xmm_kernel_; + delete ymm_kernel_; + delete zmm_kernel_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + memory_desc_t &src_pd, + memory_desc_t &weights_pd, + memory_desc_t &dst_pd, + memory_desc_t &bias_pd, + const primitive_attr_t &attr, + int nthreads); + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + void (*jit_ker)(jit_conv_call_s *); + _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm> *zmm_kernel_; + _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm> *ymm_kernel_; + _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm> *xmm_kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp new file mode 100644 index 0000000000..cdbf333d5e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp @@ -0,0 +1,423 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_avx512_core_x8s8s32x_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace nstl; + +using jit_conv_ker_t = void (*)(jit_conv_call_s *); + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() \ + ? (d).blk_off((g), __VA_ARGS__) \ + : (d).blk_off(__VA_ARGS__)) + +template <data_type_t src_type, data_type_t dst_type> +void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type, + dst_type>::execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get<float>( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast<wei_data_t *>(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast<int32_t *>(&w[offset]) : 0; + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; + int group_block = jcp.ch_block; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow; + + parallel(0, [&](const int ithr, const int nthr) { + + int start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_conv_call_s(); + + int n{ 0 }, gg{ 0 }, occ{ 0 }, owb{ 0 }; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, + nb_groups, n, jcp.mb); + break; + case loop_gncw: + nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks, + owb, jcp.nb_ow); + break; + case loop_ngcw: + nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks, + owb, jcp.nb_ow); + break; + case loop_nwcg: + nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, + gg, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + while (start < end) { + int ocb = occ * jcp.nb_oc_blocking; + int gb = gg * jcp.nb_ch_blocking; + int g = gb * group_block; + int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.nb_ic * jcp.ic_block; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) : 0; + p.compensation = (jcp.signed_input) ? compensation + g_oc : 0; + p.dst = dst + dst_d.blk_off(n, g_oc, ow_s); + p.src = src + src_d.blk_off(n, g_ic, iw_s); + p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0); + p.scales = &oscales[jcp.is_oc_scale * g_oc]; + p.oc_blocks = jcp.is_depthwise ? gb : ocb; + p.kh_padding = jcp.kh; + p.t_overflow = 0; + p.b_overflow = 0; + p.owb = owb; + + kernel_->jit_ker(&p); + + ++start; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, nb_groups, + n, jcp.mb); + break; + case loop_gncw: + nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, owb, + jcp.nb_ow); + break; + case loop_ngcw: + nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, owb, + jcp.nb_ow); + break; + case loop_nwcg: + nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg, + nb_groups); + break; + default: assert(!"unsupported loop order"); + } + } + }); +} + +template <data_type_t src_type, data_type_t dst_type> +void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type, + dst_type>::execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.ch_block == 1); + assert(jcp.nb_ch_blocking == 1); + assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get<float>( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast<wei_data_t *>(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast<int32_t *>(&w[offset]) : 0; + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk; + int nb_groups = jcp.nb_ch; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow; + + parallel(0, [&](const int ithr, const int nthr) { + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_conv_call_s(); + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 }; + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g, + nb_groups, n, jcp.mb, oh_s, jcp.oh); + break; + case loop_ngcw: + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, + owb, jcp.nb_ow, oh_s, jcp.oh); + break; + case loop_nhwcg: + nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, + occ, oc_chunks, g, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + while (start < end) { + for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk; + occ1 += jcp.nb_oc_blocking) { + int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1; + int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; + + int g_ic = g * jcp.nb_ic * jcp.ic_block; + + int work_rem = end - start; + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + if (jcp.loop_order == loop_nhwcg) + oh_e = oh_s + 1; // step instead + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + auto bias_w = bias + ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) + : 0; + int32_t *compensation_w = (jcp.signed_input) + ? compensation + g_oc : 0; + + auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); + + auto scales = &oscales[jcp.is_oc_scale * g_oc]; + + for (int oj = oh_s, ij = ih_s; oj < oh_e; + ++oj, ij += jcp.stride_h) { + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = nstl::min(jcp.kh, + div_up(max(0, -ij), dilate_h)); + int i_b_overflow = nstl::min(jcp.kh, div_up( + max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1), + dilate_h)); + int kh_padding = nstl::max(0, + jcp.kh - i_t_overflow - i_b_overflow); + + size_t wei_stride = (!jcp.signed_input) + ? i_t_overflow * wht_h_stride : 0; + p.src = src_w + i_t_overflow * dilate_h * src_h_stride; + p.dst = dst_w; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.oc_blocks = ocb; + p.kh_padding = kh_padding; + p.scales = scales; + p.t_overflow = i_t_overflow; + p.b_overflow = i_b_overflow; + p.owb = owb; + + kernel_->jit_ker(&p); + src_w += src_h_stride * jcp.stride_h; + dst_w += dst_h_stride; + } + } + switch (jcp.loop_order) { + case loop_cwgn: + nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, g, + nb_groups, n, jcp.mb, oh_s, jcp.oh); + break; + case loop_ngcw: + nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, + oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh); + break; + case loop_nhwcg: + ++start; + nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ, + oc_chunks, g, nb_groups); + break; + default: assert(!"unsupported loop order"); + } + } + }); +} + +template <data_type_t src_type, data_type_t dst_type> +void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type, + dst_type>::execute_forward_2d_dw(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const size_t bia_dt_size = pd()->with_bias() + ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0; + + const auto &jcp = pd()->jcp_; + assert(jcp.ic_block == 1); + assert(jcp.oc_block == 1); + assert(jcp.nb_ic == 1); + assert(jcp.nb_oc == 1); + assert(jcp.nb_oc_blocking == 1); + assert(jcp.nb_ch % jcp.nb_ch_blocking == 0); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales = scratchpad(ctx).template get<float>( + key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + + size_t offset = weights_d.size() - weights_d.additional_buffer_size(); + auto w = const_cast<wei_data_t *>(weights); + int32_t* compensation = (jcp.signed_input) + ? reinterpret_cast<int32_t *>(&w[offset]) : 0; + int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking; + int group_block = jcp.ch_block; + + parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups, + [&](int n, int oh_s, int owb, int gg) { + + auto p = jit_conv_call_s(); + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + int gb = gg * jcp.nb_ch_blocking; + int g = gb * group_block; + + int ih_s = -jcp.t_pad + oh_s * jcp.stride_h; + int ow_s = owb * jcp.ow_block; + int iw_s = ow_s * jcp.stride_w; + + auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0; + int32_t *compensation_w = jcp.signed_input ? compensation + g : 0; + + auto dst_w = dst + dst_d.blk_off(n, g, oh_s, ow_s); + auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s); + auto wht_w = weights + wht_blk_off(weights_d, gb, 0); + + auto scales = &oscales[jcp.is_oc_scale * g]; + + int dilate_h = jcp.dilate_h + 1; + int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h)); + int i_b_overflow = nstl::min(jcp.kh, + div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1), + dilate_h)); + int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow); + + size_t wei_stride = jcp.signed_input ? 0 : i_t_overflow * wht_h_stride; + p.src = src_w + i_t_overflow * dilate_h * src_h_stride; + p.dst = dst_w; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.oc_blocks = gb; + p.kh_padding = kh_padding; + p.scales = scales; + p.t_overflow = i_t_overflow; + p.b_overflow = i_b_overflow; + p.owb = owb; + + kernel_->jit_ker(&p); + }); +} + +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::u8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::u8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::s8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::s8>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::s32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::s32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::s8, data_type::f32>; +template struct jit_avx512_core_x8s8s32x_convolution_fwd_t< + data_type::u8, data_type::f32>; +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp new file mode 100644 index 0000000000..203ebdf942 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t src_type, impl::data_type_t dst_type> +struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_int8:", avx512_core, ""), + jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type, dst_type>); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, data_type::s8, data_type::undef, + dst_type, data_type::s32) + && IMPLICATION(with_bias(), utils::one_of(bias_md_.data_type, + data_type::f32, data_type::s32, data_type::s8, + data_type::u8)) + && !has_zero_dim_memory(); + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_x8s8s32x_fwd_kernel::init_conf( + jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_, + *attr(), mkldnn_get_max_threads()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(scratchpad, + jcp_, *attr()); + + return status; + } + + jit_conv_conf_t jcp_; + }; + + jit_avx512_core_x8s8s32x_convolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_core_x8s8s32x_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + + ~jit_avx512_core_x8s8s32x_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<data_type::s8>::type wei_data_t; + typedef typename prec_traits<dst_type>::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override + { + const auto &_pd = pd(); + if (_pd->ndims() == 3) + execute_forward_1d(ctx); + else if (_pd->jcp_.is_depthwise) + execute_forward_2d_dw(ctx); + else + execute_forward_2d(ctx); + return status::success; + } + +private: + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + void execute_forward_2d_dw(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_avx512_core_x8s8s32x_fwd_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp new file mode 100644 index 0000000000..142af1f541 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp @@ -0,0 +1,1034 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_avx512_core_x8s8s32x_deconvolution.hpp" + +#define GET_OFF(field) offsetof(jit_deconv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; +using namespace Xbyak; + +using namespace nstl; + +#define wht_blk_off(d, g, ...) \ + (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \ + (d).blk_off(__VA_ARGS__)) + +status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( + jit_conv_conf_t &jcp, const deconvolution_desc_t &cd, + memory_desc_t &src_md, memory_desc_t &weights_md, + memory_desc_t &dst_md, const bool with_bias, + memory_desc_t &bias_md, const primitive_attr_t &attr) { + const memory_desc_wrapper src_d(&src_md); + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper weights_d(&weights_md); + const memory_desc_wrapper bias_d(&bias_md); + + if (!(mayiuse(avx512_core) + && one_of(src_d.data_type(), data_type::u8, data_type::s8) + && weights_d.data_type() == data_type::s8 + && one_of(dst_d.data_type(), data_type::f32, data_type::s32, + data_type::s8, data_type::u8))) + return status::unimplemented; + + jcp = zero<decltype(jcp)>(); + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + jcp.signed_input = src_d.data_type() == data_type::s8; + const int ndims = jcp.ndims = dst_d.ndims(); + const bool is_1d = ndims == 3; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; + jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; + jcp.is_depthwise = true && with_groups + && utils::everyone_is(1, jcp.ic_without_padding, + jcp.oc_without_padding); + + /* TODO: future work, on hold until depthwise specialized kernel is + * implemented. */ + if (jcp.is_depthwise && jcp.signed_input) + return status::unimplemented; + + format_tag_t dat_tag = utils::pick(ndims - 3, + format_tag::nwc, format_tag::nhwc); + + if (src_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(src_md, dat_tag)); + jcp.src_tag = dat_tag; + } else { + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + } + if (jcp.src_tag != dat_tag) + return status::unimplemented; + + if (dst_d.format_kind() == format_kind::any) { + CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); + jcp.dst_tag = dat_tag; + } else { + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + } + if (jcp.dst_tag != dat_tag) + return status::unimplemented; + + auto set_or_check_wei_format = [&]() { + using namespace format_tag; + + format_tag_t wei_tag = is_1d + ? (jcp.is_depthwise + ? Goiw16g : (with_groups ? gOIw4i16o4i : OIw4i16o4i)) + : (jcp.is_depthwise + ? Goihw16g : (with_groups ? gOIhw4i16o4i : OIhw4i16o4i)); + + memory_desc_t want_wei_md = weights_md; + memory_desc_init_by_tag(want_wei_md, wei_tag); + if (jcp.signed_input && !jcp.is_depthwise) { + want_wei_md.extra.flags = 0 + | memory_extra_flags::compensation_conv_s8s8 + | memory_extra_flags::scale_adjust; + want_wei_md.extra.compensation_mask = (1 << 0) + + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); + want_wei_md.extra.scale_adjust = + mayiuse(avx512_core_vnni) ? 1.f : 0.5f; + } + + if (weights_md.format_kind == format_kind::any) { + weights_md = want_wei_md; + return true; + } + + return weights_md == want_wei_md; + }; + + if (!set_or_check_wei_format()) + return status::unimplemented; + + jcp.with_bias = with_bias; + if (jcp.with_bias) { + if (bias_d.format_kind() == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + } + + jcp.prop_kind = cd.prop_kind; + jcp.mb = src_d.dims()[0]; + jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; + jcp.ow = dst_d.dims()[ndims - 1]; + jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; + jcp.l_pad = cd.padding[0][ndims - 3]; + jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; + jcp.stride_w = cd.strides[ndims - 3]; + + if (jcp.is_depthwise) { + jcp.ch_block = 16; + jcp.oc_block = 1; + jcp.ic_block = 1; + } else { + jcp.ch_block = 1; + jcp.oc_block = 16; + jcp.ic_block = 16; + + if (jcp.ngroups == 1) { + jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block); + jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block); + } + if (jcp.ic % jcp.ic_block != 0) + return status::unimplemented; + } + + jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; + jcp.dilate_w = cd.dilates[ndims - 3]; + + if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1) + || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1)) + return status::unimplemented; + + /* padding: bottom and right */ + jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.oh + jcp.t_pad - 1); + jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.ow + jcp.l_pad - 1); + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + jcp.ver = ver_avx512_core; + if (mayiuse(avx512_core_vnni)) + jcp.ver = ver_vnni; + const auto &oscales = attr.output_scales_; + jcp.is_oc_scale = oscales.mask_ == 1 << 1; + + assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0)); + + jcp.dst_dt = dst_d.data_type(); + jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef; + jcp.typesize_bia + = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; + jcp.typesize_in = types::data_type_size(src_d.data_type()); + jcp.typesize_out = types::data_type_size(dst_d.data_type()); + + jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); + jcp.nb_oc = jcp.oc / jcp.oc_block; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + /* kernel blocking params */ + const int regs = jcp.ver == ver_vnni ? 30 : 28; + jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc); + for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) + if (jcp.nb_oc % jcp.nb_oc_blocking == 0 + && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1)) + break; + + jcp.ur_w = regs / (jcp.nb_oc_blocking + 1); + int l_overflow = max( + 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); + + if (jcp.ow < jcp.ur_w) { + jcp.ur_w = jcp.ow; + jcp.ur_w_tail = 0; + } else { + for (; jcp.ur_w >= 1; jcp.ur_w--) { + /* ur_w should be multiple of stride_w in order + to simplify logic for get_ow_start and get_ow_end */ + bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0; + + /* boundary conditions: + These conditions ensure all elements close to boundary + are computed in a single call of compute loop */ + bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + int r_overflow_no_tail + = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - max(0, jcp.r_pad) - jcp.ur_w_tail) + / jcp.stride_w); + bool right_boundary_covered + = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w; + + if (is_multiple_of_stride && left_boundary_covered + && right_boundary_covered) + break; + else if (jcp.ur_w == 1) + /* The boundary conditions above are also important + to maintain simplicity of calls to icb_loop, + if those conditions are not satisfied, + then special cases will need to be added + to use correct l_overflow/r_overflow values + when different iterations of compute loop + work on the locations close to boundary. + So to keep code simple, return unimplemented + for extreme case when a good ur_w cannot be found. + */ + return status::unimplemented; + } + } + + jcp.wei_adj_scale = + (weights_d.extra().flags | memory_extra_flags::scale_adjust) + ? weights_d.extra().scale_adjust : 1.f; + + jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn; + return status::success; +} + +bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) { + using namespace primitive_kind; + const auto &p = attr_.post_ops_; + + if (position == 0) { + /* eltwise before sum */ + return p.contain(eltwise, 0); + } else if (position == 1) { + /* eltwise after sum */ + return p.contain(sum, 0) && p.contain(eltwise, 1); + } + return false; +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) { + int nb_oc_block + = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; + eltwise_injector_->compute_vector_range(0, nb_oc_block * ur_w); +} + +bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + using namespace primitive_kind; + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + + switch (p.len_) { + case 0: return true; + case 1: return is_eltwise(0) || p.contain(sum, 0); + case 2: + return (p.contain(sum, 0) && is_eltwise(1)) + || (p.contain(sum, 1) && is_eltwise(0)); + default: return false; + } + + return false; +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, + const primitive_attr_t &attr) { + if (jcp.signed_input && jcp.ver != ver_vnni) { + dim_t count = nstl::max<dim_t>(attr.output_scales_.count_, 16); + scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w, + int l_overflow, int r_overflow, ker_block_t last_ic_block_flag, + bool h_padded) { + + const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w; + + auto src_offset = [=](int oj, int icb, int ki) { + return jcp.typesize_in + * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w) + * jcp.ngroups * jcp.ic_without_padding + + icb * 4); + }; + + auto kernel_offset = [=](int ocb, int icb, int ki) { + return jcp.typesize_in + * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all + + icb * jcp.oc_block * jcp.ic_block / 4 + + ki * ch_block_all); + }; + + auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) { + if (jcp.ver == ver_vnni) { + vpdpbusd(vreg_acc, vreg_src, vreg_wei); + } else if (jcp.is_depthwise) { + vpmulld(zmm_tmp, vreg_src, vreg_wei); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } else { + vpmaddubsw(zmm_tmp, vreg_src, vreg_wei); + vpmaddwd(zmm_tmp, zmm_tmp, zmm_one); + vpaddd(vreg_acc, vreg_acc, zmm_tmp); + } + }; + + for (int ki = 0; ki < jcp.kw; ki++) { + + int jj_start = get_ow_start(ki, l_overflow); + int jj_end = get_ow_end(ur_w, ki, r_overflow); + + int _start = (jcp.signed_input) ? 0 : jj_start; + int _end = (jcp.signed_input) ? ur_w : jj_end; + + int tail_size = jcp.ic_without_padding % 4; + int n_ic_blocks = jcp.is_depthwise ? + 1 : + (last_ic_block_flag & ~no_last_block ? + div_up(jcp.ic_without_padding % jcp.ic_block, + 4) : + jcp.ic_block / 4); + + for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) { + if (h_padded == true) { + /* fill padded area with shifted values */ + Zmm inp = zmm_inp(0, jcp.nb_oc_blocking); + vpxord(inp, inp, inp); + vpsubb(inp, inp, zmm_shift); + } else { + + for (int jj = _start; jj < _end; jj += ur_w_stride) { + + int aux_src_off = src_offset(jj, icb1, ki); + + if (jj >= jj_start && jj < jj_end + && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) { + if (jcp.is_depthwise) { + vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking), + EVEX_compress_addr( + aux_reg_src, aux_src_off)); + } else if ((last_ic_block_flag & last_sp_block) + && tail_size != 0 && icb1 == n_ic_blocks - 1) { + xmm_t xmm_tmp = xmm_t( + zmm_inp(jj, jcp.nb_oc_blocking).getIdx()); + for (int r = 0; r < tail_size; ++r) + vpinsrb(xmm_tmp, xmm_tmp, + ptr[aux_reg_src + aux_src_off + r], r); + vpbroadcastd( + zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp); + } else { + vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking), + EVEX_compress_addr( + aux_reg_src, aux_src_off)); + } + if (jcp.signed_input) + vpsubb(zmm_inp(jj, jcp.nb_oc_blocking), + zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift); + } else { + /* fill padded area with shifted values */ + if (jcp.signed_input) { + Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking); + vpxord(inp, inp, inp); + vpsubb(inp, inp, zmm_shift); + } + } + } + } + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + int aux_filt_off = kernel_offset(ocb, icb1, ki); + + if (_end - _start > 0) { + if (jcp.is_depthwise) + vpmovsxbd(zmm_wei, + EVEX_compress_addr(aux_reg_filt, aux_filt_off)); + else + vmovups(zmm_wei, + EVEX_compress_addr(aux_reg_filt, aux_filt_off)); + } + for (int jj = _start; jj < _end; jj += ur_w_stride) { + Zmm inp = (h_padded == true) ? + zmm_inp(0, jcp.nb_oc_blocking) : + zmm_inp(jj, jcp.nb_oc_blocking); + compute(zmm_out(jj, ocb), zmm_wei, inp); + } + } + } + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w, + int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) { + + int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; + int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw + * jcp.ngroups * jcp.ic_without_padding; + const int stride_h = jcp.signed_input ? 1 : jcp.stride_h; + int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h; + + Label kh_loop_label, skip_kh_loop; + Label t_overflow_label, no_t_overflow_label, b_overflow_label, + no_b_overflow_label; + + mov(aux_reg_src, reg_src); + mov(aux_reg_filt, reg_filt); + + if (jcp.signed_input && jcp.ndims > 3) { + /* Weights are transposed, so first compute 'bottom' padding. */ + mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); + cmp(reg_overflow, 0); + je(no_b_overflow_label, T_NEAR); + L(b_overflow_label); { + compute_ker(ur_w, 0, 0, last_ic_block_flag, true); + + add(aux_reg_filt, shift_filt_kh); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(b_overflow_label, T_NEAR); + } + L(no_b_overflow_label); + } + + mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); + + if (jcp.signed_input || ((!jcp.signed_input) + && ((min(jcp.t_pad, jcp.b_pad) < 0) + || ((jcp.kh - 1) * (jcp.dilate_h + 1) + < nstl::max(jcp.t_pad, jcp.b_pad))))) { + cmp(reg_kh, 0); + je(skip_kh_loop, T_NEAR); + } + + L(kh_loop_label); { + compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false); + sub(aux_reg_src, shift_src_ih); + add(aux_reg_filt, shift_filt_kh); + dec(reg_kh); + + /* Insert weight compensation in stride 'holes' */ + if (jcp.signed_input && jcp.stride_h > 1) { + Label kh_comp_loop; + + cmp(reg_kh, 0); + je(skip_kh_loop, T_NEAR); + mov(reg_comp_strides, jcp.stride_h - 1); + L(kh_comp_loop); + { + compute_ker( + ur_w, 0, 0, last_ic_block_flag, true); + add(aux_reg_filt, shift_filt_kh); + dec(reg_comp_strides); + cmp(reg_comp_strides, 0); + jg(kh_comp_loop, T_NEAR); + } + } + cmp(reg_kh, 0); + jg(kh_loop_label, T_NEAR); + } + L(skip_kh_loop); + if (jcp.signed_input && jcp.ndims > 3) { + mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); + cmp(reg_overflow, 0); + je(no_t_overflow_label, T_NEAR); + L(t_overflow_label); { + compute_ker(ur_w, 0, 0, last_ic_block_flag, true); + + add(aux_reg_filt, shift_filt_kh); + dec(reg_overflow); + cmp(reg_overflow, 0); + jg(t_overflow_label, T_NEAR); + } + L(no_t_overflow_label); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) { + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + vpxord(zmm, zmm, zmm); + } + } + if (jcp.signed_input) { + xor_(reg_scratch, reg_scratch); + Reg8 _t8 = reg_scratch.cvt8(); + mov(_t8, (int8_t)-128); + vpbroadcastb(zmm_shift, _t8); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps( + data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) { + zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in; + switch (type_in) { + case data_type::f32: + case data_type::s32: vmovups(zmm, op); break; + case data_type::s8: vpmovsxbd(zmm, op); break; + case data_type::u8: vpmovzxbd(zmm, op); break; + default: assert(!"unsupported data type"); + } + if (type_in != data_type::f32) + vcvtdq2ps(zmm_in, zmm_in); +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output( + int ur_w, bool last_oc_block) { + mov(reg_bias, ptr[param1 + GET_OFF(bias)]); + mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); + + if (jcp.signed_input) + mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); + + const auto &p = attr_.post_ops_; + const int sum_idx = p.find(primitive_kind::sum); + const float *p_sum_scale + = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr; + if (p_sum_scale && *p_sum_scale != 1.f) + mov(reg_ptr_sum_scale, (size_t)p_sum_scale); + + if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) { + mov(reg_bias_alpha, float2int(jcp.wei_adj_scale)); + vmovq(xmm_bias_alpha(), reg_bias_alpha); + vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha()); + } + + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; + int scale_offset + = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block); + + auto zmm_bias = zmm_tmp; + if (jcp.with_bias) { + int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block; + auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); + cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag); + if (jcp.signed_input && jcp.ver != ver_vnni) + vmulps(zmm_bias, zmm_bias, zmm_bias_alpha()); + } + if (jcp.signed_input) { + int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block; + auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); + cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag); + } + + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + vcvtdq2ps(zmm, zmm); + if (jcp.signed_input) + vaddps(zmm, zmm, zmm_comp); + if (jcp.with_bias) + vaddps(zmm, zmm, zmm_bias); + zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm; + vmulps(mask_zmm, zmm, + EVEX_compress_addr(reg_ptr_scales, scale_offset)); + } + } + if (maybe_eltwise(0)) + compute_eltwise(ur_w); + if (p_sum_scale) { // post_op: sum + for (int k = 0; k < jcp.nb_oc_blocking; k++) { + const bool mask_flag + = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1; + for (int j = 0; j < ur_w; j++) { + int aux_output_offset + = jcp.typesize_out + * (k * jcp.oc_block + + j * jcp.oc_without_padding * jcp.ngroups); + auto addr = EVEX_compress_addr(reg_dst, aux_output_offset); + Zmm zmm = zmm_out(j, k); + cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag); + if (*p_sum_scale == 1.f) + vaddps(zmm, zmm_prev_dst); + else + vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); + } + } + } + if (maybe_eltwise(1)) + compute_eltwise(ur_w); + + for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { + const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; + for (int ur = 0; ur < ur_w; ur++) { + zmm_t zmm = zmm_out(ur, ocb); + if (jcp.dst_dt == data_type::u8) { + vpxord(zmm_zero, zmm_zero, zmm_zero); + vmaxps(zmm, zmm_zero, zmm); + } + if (jcp.dst_dt != data_type::f32) + vcvtps2dq(zmm, zmm); + } + for (int ur = 0; ur < ur_w; ur++) { + int aux_dst_off = jcp.typesize_out + * (ur * jcp.ngroups * jcp.oc_without_padding + + ocb * jcp.oc_block); + auto addr = EVEX_compress_addr(reg_dst, aux_dst_off); + + zmm_t zmm = zmm_out(ur, ocb); + zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm; + switch (jcp.dst_dt) { + case data_type::f32: + case data_type::s32: vmovups(addr, r_zmm); break; + case data_type::s8: vpmovsdb(addr, r_zmm); break; + case data_type::u8: vpmovusdb(addr, r_zmm); break; + default: assert(!"unknown dst_dt"); + } + } + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop( + int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) { + + int shift_src_icb = jcp.typesize_in * jcp.ic_block; + int shift_filt_icb + = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block; + + prepare_output(ur_w); + + Label skip_icb_loop, icb_loop_label; + + mov(reg_icb, jcp.nb_ic); + L(icb_loop_label); { + + if (jcp.ic_without_padding != jcp.ic) { + Label common_ker, end_ker; + cmp(reg_icb, 1); + jg(common_ker, T_NEAR); + + kh_loop(ur_w, l_overflow, r_overflow, + is_last_sp_block ? last_sp_block : last_ic_block); + jmp(end_ker, T_NEAR); + + L(common_ker); + kh_loop(ur_w, l_overflow, r_overflow, no_last_block); + + L(end_ker); + } else { + kh_loop(ur_w, l_overflow, r_overflow, no_last_block); + } + + add(reg_src, shift_src_icb); + add(reg_filt, shift_filt_icb); + dec(reg_icb); + cmp(reg_icb, 0); + jg(icb_loop_label, T_NEAR); + } + + /* come-back pointers */ + sub(reg_src, jcp.nb_ic * shift_src_icb); + sub(reg_filt, jcp.nb_ic * shift_filt_icb); + L(skip_icb_loop); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + Label common_store, end_store; + mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); + if (jcp.is_depthwise) + cmp(reg_oc_blocks, jcp.nb_ch - 1); + else + cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); + jne(common_store, T_NEAR); + + store_output(ur_w, true); + jmp(end_store, T_NEAR); + + L(common_store); + store_output(ur_w, false); + + L(end_store); + + } else { + store_output(ur_w, false); + } +} + +void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() { + preamble(); + + xor_(reg_scratch, reg_scratch); + Reg16 _t = reg_scratch.cvt16(); + mov(_t, 0x1); + vpbroadcastw(zmm_one, _t); + + if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { + int tail_size = jcp.is_depthwise ? + jcp.ngroups % jcp.ch_block : + jcp.oc_without_padding % jcp.oc_block; + int mask = (1 << tail_size) - 1; + Reg32 regw_tmp = reg_nur_w.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + mov(reg_src, ptr[param1 + GET_OFF(src)]); + mov(reg_filt, ptr[param1 + GET_OFF(filt)]); + mov(reg_dst, ptr[param1 + GET_OFF(dst)]); + + int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups + * jcp.oc_without_padding; + int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups + * jcp.ic_without_padding; + + int l_overflow = max( + 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); + int r_overflow + = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad)) + / jcp.stride_w); + + int r_overflow1 + = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) + - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) + / jcp.stride_w); + int nur_w = jcp.ow / jcp.ur_w; + if (r_overflow1 > 0) + nur_w--; + + if (jcp.ur_w == jcp.ow) { + icb_loop(jcp.ur_w, l_overflow, r_overflow, true); + } else if (nur_w == 0) { + icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + if (jcp.ur_w_tail != 0) + icb_loop(jcp.ur_w_tail, 0, r_overflow, true); + } else { + xor_(reg_nur_w, reg_nur_w); + if (l_overflow > 0) { + icb_loop(jcp.ur_w, l_overflow, 0, false); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + inc(reg_nur_w); + } + if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) { + Label ow_loop_label; + L(ow_loop_label); + { + icb_loop(jcp.ur_w, 0, 0, false); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + inc(reg_nur_w); + cmp(reg_nur_w, nur_w); + jl(ow_loop_label, T_NEAR); + } + } + if (r_overflow1 > 0) { + icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0); + add(reg_src, src_shift); + add(reg_dst, dst_shift); + } + if (jcp.ur_w_tail != 0) { + icb_loop(jcp.ur_w_tail, 0, r_overflow, true); + } + } + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +template <data_type_t src_type, data_type_t dst_type> +void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type, + dst_type>::execute_forward_1d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + auto &jcp = kernel_->jcp; + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch; + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales + = scratchpad(ctx).template get<float>(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + auto w = const_cast<wei_data_t *>(weights); + int32_t *compensation + = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0; + + parallel(0, [&](const int ithr, const int nthr) { + int start{ 0 }, end{ 0 }; + int work_amount = jcp.mb * nb_groups * oc_chunks; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_deconv_call_s(); + + int n{ 0 }, g{ 0 }, occ{ 0 }; + if (jcp.loop_order == loop_ngc) + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks); + else if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb); + else + assert(!"unsupported loop order"); + while (start < end) { + + int ocb = occ * jcp.nb_oc_blocking; + int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.ch_block * jcp.ic; + + p.dst = dst + dst_d.blk_off(n, g_oc); + p.src = src + src_d.blk_off(n, g_ic); + p.filt = weights + wht_blk_off(weights_d, g, ocb, 0); + p.bias = jcp.with_bias ? + bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) : + 0; + p.compensation = (jcp.signed_input) ? compensation + g_oc : 0; + p.scales = &oscales[jcp.is_oc_scale * g_oc]; + p.t_overflow = 0; + p.b_overflow = 0; + p.kh_padding = jcp.kh; + p.oc_blocks = jcp.is_depthwise ? g : ocb; + + kernel_->jit_ker(&p); + + ++start; + if (jcp.loop_order == loop_ngc) + nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks); + else if (jcp.loop_order == loop_cgn) + nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb); + else + assert(!"unsupported loop order"); + } + }); +} + +template <data_type_t src_type, data_type_t dst_type> +void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type, + dst_type>::execute_forward_2d(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + auto &jcp = kernel_->jcp; + + int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; + int nb_groups = jcp.nb_ch; + + size_t src_h_stride = src_d.blk_off(0, 0, 1); + size_t dst_h_stride = dst_d.blk_off(0, 0, 1); + size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1); + + const float *oscales = pd()->attr()->output_scales_.scales_; + if (jcp.signed_input && jcp.ver != ver_vnni) { + auto local_scales + = scratchpad(ctx).template get<float>(key_conv_adjusted_scales); + size_t count = pd()->attr()->output_scales_.count_; + float factor = 1.f / pd()->jcp_.wei_adj_scale; + if (count == 1) { + utils::array_set(local_scales, oscales[0] * factor, 16); + } else { + for (size_t c = 0; c < count; c++) + local_scales[c] = oscales[c] * factor; + } + oscales = local_scales; + } + size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw; + auto w = const_cast<wei_data_t *>(weights); + int32_t *compensation + = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0; + + parallel(0, [&](const int ithr, const int nthr) { + int start{ 0 }, end{ 0 }; + int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh; + balance211(work_amount, nthr, ithr, start, end); + + auto p = jit_deconv_call_s(); + + /*loop order = cgn*/ + int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }; + if (jcp.loop_order == loop_ngc) + nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, + oh_s, jcp.oh); + else if (jcp.loop_order == loop_cgn) + nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb, + oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + while (start < end) { + + int ocb = occ * jcp.nb_oc_blocking; + int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; + int g_ic = g * jcp.ch_block * jcp.ic; + int work_rem = end - start; + int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; + + auto dst_w = dst + dst_d.blk_off(n, g_oc); + auto src_w = src + src_d.blk_off(n, g_ic); + auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); + auto bias_w = jcp.with_bias ? + bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) : + 0; + int32_t *compensation_w + = (jcp.signed_input) ? compensation + g_oc : 0; + + auto scales = &oscales[jcp.is_oc_scale * g_oc]; + for (int oj = oh_s; oj < oh_e; oj++) { + int ih_max = 0, kh_lo = 0, kh_len = 0; + if (jcp.dilate_h != 0 && jcp.stride_h == 1) { + /* dilation */ + int dilate_h = jcp.dilate_h + 1; + // Note: use div_up to account for "holes" in filter + int o_t_overflow = div_up( + max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad), + dilate_h); + int o_b_overflow + = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh + + oj - jcp.b_pad), + dilate_h); + kh_len = jcp.kh - o_t_overflow - o_b_overflow; + kh_lo = o_b_overflow; + ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h; + } else { + int o_t_overflow = max( + 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h); + int o_b_overflow + = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad)) + / jcp.stride_h); + int overflow_kh_hi = jcp.kh - 1 + - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h; + int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h; + + kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h + + 1 - o_t_overflow - o_b_overflow; + kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h; + ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h; + } + + int wei_stride + = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0; + p.src = src_w + ih_max * src_h_stride; + p.dst = dst_w + oj * dst_h_stride; + p.filt = wht_w + wei_stride; + p.bias = bias_w; + p.compensation = compensation_w; + p.t_overflow = max( + 0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h + + 1)); + p.b_overflow = kh_lo; + p.kh_padding = kh_len; + p.scales = scales; + p.oc_blocks = jcp.is_depthwise ? g : ocb; + kernel_->jit_ker(&p); + } + if (jcp.loop_order == loop_ngc) + nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, + oc_chunks, oh_s, jcp.oh); + else if (jcp.loop_order == loop_cgn) + nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n, + jcp.mb, oh_s, jcp.oh); + else + assert(!"unsupported loop order"); + } + }); +} + +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8, + data_type::u8>; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8, + data_type::s8>; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8, + data_type::f32>; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8, + data_type::s32>; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8, + data_type::u8>; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8, + data_type::s8>; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8, + data_type::f32>; +template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8, + data_type::s32>; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp new file mode 100644 index 0000000000..901038fa48 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp @@ -0,0 +1,237 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP +#define CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "cpu_primitive.hpp" +#include "cpu_memory.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "nstl.hpp" + +#include "cpu_deconvolution_pd.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +typedef enum { + no_last_block = 0x1U, + last_ic_block = 0x2U, + last_sp_block = 0x4U, +} ker_block_t; + +struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t); + + jit_avx512_core_x8s8s32x_deconv_fwd_kernel( + const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>( + this, jcp.eltwise); + generate(); + jit_ker = (void (*)(jit_deconv_call_s *))getCode(); + } + + ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const deconvolution_desc_t &cd, + memory_desc_t &src_md, + memory_desc_t &weights_md, + memory_desc_t &dst_md, + const bool with_bias, + memory_desc_t &bias_md, + const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp, const primitive_attr_t &attr); + + const jit_conv_conf_t &jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_deconv_call_s *); +private: + jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_; + using reg64_t = const Xbyak::Reg64; + using zmm_t = const Xbyak::Zmm; + using xmm_t = const Xbyak::Xmm; + + reg64_t reg_src = r8; + reg64_t reg_filt = r9; + reg64_t reg_dst = r10; + reg64_t param1 = abi_param1; + reg64_t reg_kh = abi_not_param1; + reg64_t reg_nur_w = rbx; + reg64_t reg_bias = rdx; + reg64_t reg_icb = reg_bias; + reg64_t reg_ptr_scales = rax; + reg64_t reg_oc_blocks = rsi; + + reg64_t aux_reg_src = r11; + reg64_t aux_reg_filt = r12; + + reg64_t reg_compensation = r14; + reg64_t reg_scratch = r14; + reg64_t reg_ptr_sum_scale = r11; + reg64_t reg_bias_alpha = abi_not_param1; + reg64_t reg_overflow = rax; + reg64_t reg_comp_strides = reg_overflow; + + Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); + zmm_t zmm_tmp = zmm_t(28); + zmm_t zmm_one = zmm_t(29); + /* used during write-out section of store_output */ + zmm_t zmm_zero = zmm_t(31); + zmm_t zmm_wei = zmm_t(31); + + /* signed input */ + zmm_t zmm_shift = zmm_t(30); + zmm_t zmm_comp = zmm_t(30); + zmm_t zmm_bias = zmm_t(31); + zmm_t zmm_prev_dst = zmm_t(31); + + zmm_t zmm_out(int i_ur, int i_oc) { + int idx = i_ur * jcp.nb_oc_blocking + i_oc; + assert(idx < 31); + return zmm_t(idx); + } + zmm_t zmm_inp(int i_ic, int nb_x_blocking) { + int idx = i_ic + nb_x_blocking * jcp.ur_w; + assert(idx < 31); + return zmm_t(idx); + } + zmm_t zmm_bias_alpha() { + return zmm_t(jcp.nb_oc_blocking * jcp.ur_w); + } + xmm_t xmm_bias_alpha() { + return xmm_t(jcp.nb_oc_blocking * jcp.ur_w); + } + + int get_ow_start(int ki, int l_overflow) { + int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w + + l_overflow * jcp.stride_w + - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + return res; + } + + int get_ow_end(int ur_w, int ki, int r_overflow) { + if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail)) + ur_w += nstl::min(0, jcp.r_pad); // remove negative padding + int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w + + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); + while (res < 0) + res += jcp.stride_w; + return ur_w - res; + } + bool maybe_eltwise(int position); + void compute_eltwise(int ur_w); + void prepare_output(int ur_w); + void store_output(int ur_w, bool last_oc_block); + void compute_ker(int ur_w, int l_overflow, int r_overflow, + ker_block_t last_ic_block_flag, bool h_padded = false); + void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block); + void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block); + void generate(); + void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op, + bool mask_flag); +}; + +template <impl::data_type_t src_type, impl::data_type_t dst_type> +struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_deconvolution_fwd_pd_t { + using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_deconvolution:", avx512_core, ""), + _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type, dst_type>); + + status_t init() { + bool ok = true + && is_fwd() + && (desc()->alg_kind & alg_kind::deconvolution_direct) + && desc()->src_desc.data_type == src_type + && desc()->dst_desc.data_type == dst_type + && IMPLICATION(with_bias(), utils::one_of( + desc()->bias_desc.data_type, data_type::f32, + data_type::s32, data_type::s8, data_type::u8)) + && desc()->accum_data_type == data_type::s32; + if (!ok) return status::unimplemented; + + status_t status = jit_avx512_core_x8s8s32x_deconv_fwd_kernel:: + init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_, + with_bias(), bias_md_, *attr()); + + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(scratchpad, + jcp_, *attr()); + + return status::success; + } + + jit_conv_conf_t jcp_; + }; + + _jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + { + kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel(pd()->jcp_, + *pd()->attr()); + } + + ~_jit_avx512_core_x8s8s32x_deconvolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<data_type::s8>::type wei_data_t; + typedef typename prec_traits<dst_type>::type dst_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if(pd()->ndims() == 3) + execute_forward_1d(ctx); + else + execute_forward_2d(ctx); + return status::success; + } + +private: + void execute_forward_1d(const exec_ctx_t &ctx) const; + void execute_forward_2d(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_avx512_core_x8s8s32x_deconv_fwd_kernel *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp new file mode 100644 index 0000000000..c09592d5c9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp @@ -0,0 +1,773 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_AVX2_GENERATOR_HPP +#define CPU_JIT_AVX2_GENERATOR_HPP + +#include <limits.h> + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_isa_traits.hpp" +#include "jit_utils/jit_utils.hpp" + +#if defined(_WIN32) && !defined(__GNUC__) +# define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__ +#else +# define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al))) +#endif + +#if defined(_WIN32) +# define OFFSET_SHADOWSPACE 0x28 +#endif + +#define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \ + const char *name() const override { return STRINGIFY(jit_name); } \ + const char *source_file() const override { return __FILE__; } + +namespace mkldnn { +namespace impl { +namespace cpu { + +// TODO: move this to jit_generator class? +namespace { + +typedef enum { + PAGE_4K = 4096, + PAGE_2M = 2097152, +} cpu_page_size_t; + +// TODO: move this somewhere else? Although this is only used by jit kernels +// (Roma) +static inline int float2int(float x) { + union { + float vfloat; + int vint; + } cvt; + cvt.vfloat = x; + return cvt.vint; +} + +// TODO: A GPR class that hides ABI details from the JIT kernels and allows +// numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and +// stack register (sr). +// +// This will allow using syntax like this: +// +// param = gpr0; +// reg_input = gpr0; +// reg_output = gpr1; +// ... +// +// #ifndef XBYAK64 +// mov(param, ptr[sr]) +// #endif +// +// (Roma) + +#ifdef XBYAK64 +constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = { + Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, + Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15, +#ifdef _WIN32 + Xbyak::Operand::RDI, Xbyak::Operand::RSI, +#endif +}; + +#ifdef _WIN32 +static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX), + abi_param2(Xbyak::Operand::RDX), + abi_param3(Xbyak::Operand::R8), + abi_param4(Xbyak::Operand::R9), + abi_not_param1(Xbyak::Operand::RDI); +#else +static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI), + abi_param2(Xbyak::Operand::RSI), + abi_param3(Xbyak::Operand::RDX), + abi_param4(Xbyak::Operand::RCX), + abi_param5(Xbyak::Operand::R8), + abi_param6(Xbyak::Operand::R9), + abi_not_param1(Xbyak::Operand::RCX); +#endif +#endif + +inline unsigned int get_cache_size(int level, bool per_core = true){ + unsigned int l = level - 1; + // Currently, if XByak is not able to fetch the cache topology + // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core. + if (cpu.getDataCacheLevels() == 0){ + const int L1_cache_per_core = 32000; + const int L2_cache_per_core = 512000; + const int L3_cache_per_core = 1024000; + int num_cores = per_core ? 1 : mkldnn_get_max_threads(); + switch(l){ + case(0): return L1_cache_per_core * num_cores; + case(1): return L2_cache_per_core * num_cores; + case(2): return L3_cache_per_core * num_cores; + default: return 0; + } + } + if (l < cpu.getDataCacheLevels()) { + return cpu.getDataCacheSize(l) + / (per_core ? cpu.getCoresSharingDataCache(l) : 1); + } else + return 0; +} + +} + +class jit_generator : public Xbyak::CodeGenerator +{ +private: + const size_t xmm_len = 16; +#ifdef _WIN32 + const size_t xmm_to_preserve_start = 6; + const size_t xmm_to_preserve = 10; +#else + const size_t xmm_to_preserve_start = 0; + const size_t xmm_to_preserve = 0; +#endif + + const size_t num_abi_save_gpr_regs + = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]); + + const size_t size_of_abi_save_regs + = num_abi_save_gpr_regs * rax.getBit() / 8 + + xmm_to_preserve * xmm_len; + +public: + enum { + _cmp_eq_oq = 0u, + _cmp_lt_os = 1u, + _cmp_le_os = 2u, + _cmp_neq_uq = 4u, + _cmp_nlt_us = 5u, + _cmp_nle_us = 6u, + + _op_floor = 1u, + _op_mxcsr = 4u, + }; + + Xbyak::Reg64 param1 = abi_param1; + const int EVEX_max_8b_offt = 0x200; + const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp; + + inline size_t get_size_of_abi_save_regs() { + return size_of_abi_save_regs; + } + + void preamble() { + if (xmm_to_preserve) { + sub(rsp, xmm_to_preserve * xmm_len); + for (size_t i = 0; i < xmm_to_preserve; ++i) + movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i)); + } + for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) + push(Xbyak::Reg64(abi_save_gpr_regs[i])); + if (mayiuse(avx512_common)) { + mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt); + } + } + + void mic_prefetcht0(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht0(a); + } + + void mic_prefetcht1(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht1(a); + } + + void mic_prefetcht2(Xbyak::Address a) { + if (mayiuse(avx512_mic)) + prefetcht2(a); + } + + void uni_vzeroupper() { + if (mayiuse(avx) && !mayiuse(avx512_mic)) + vzeroupper(); + } + + void postamble() { + for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) + pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i])); + if (xmm_to_preserve) { + for (size_t i = 0; i < xmm_to_preserve; ++i) + movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]); + add(rsp, xmm_to_preserve * xmm_len); + } + uni_vzeroupper(); + ret(); + } + + template<typename T> + Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, + T raw_offt, bool bcast = false) + { + using Xbyak::Zmm; + using Xbyak::Reg64; + using Xbyak::Address; + using Xbyak::RegExp; + + assert(raw_offt <= INT_MAX); + auto offt = static_cast<int>(raw_offt); + + int scale = 0; + + if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) { + offt = offt - 2 * EVEX_max_8b_offt; + scale = 1; + } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) { + offt = offt - 4 * EVEX_max_8b_offt; + scale = 2; + } + + auto re = RegExp() + base + offt; + if (scale) + re = re + reg_EVEX_max_8b_offt * scale; + + if (bcast) + return zword_b [re]; + else + return zword [re]; + } + + Xbyak::Address make_safe_addr(const Xbyak::Reg64 ®_out, size_t offt, + const Xbyak::Reg64 &tmp_reg, bool bcast = false) { + if (offt > INT_MAX) { + mov(tmp_reg, offt); + return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg]; + } else { + return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt]; + } + } + + Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base, + size_t raw_offt, const Xbyak::Reg64 ®_offt, bool bcast = false) { + if (raw_offt > INT_MAX) { + return make_safe_addr(base, raw_offt, reg_offt, bcast); + } else { + return EVEX_compress_addr(base, raw_offt, bcast); + } + } + + void safe_add(const Xbyak::Reg64 &base, size_t raw_offt, + const Xbyak::Reg64 ®_offt) { + if (raw_offt > INT_MAX) { + mov(reg_offt, raw_offt); + add(base, reg_offt); + } else { + add(base, raw_offt); + } + } + + void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt, + const Xbyak::Reg64 ®_offt) { + if (raw_offt > INT_MAX) { + mov(reg_offt, raw_offt); + sub(base, reg_offt); + } else { + sub(base, raw_offt); + } + } + + // Disallow char-based labels completely + void L(const char *label) = delete; + void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); } + + void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + pxor(x2, op); + } + void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + if (mayiuse(avx2)) { + vpxor(x1, x2, op); + } else { + vxorps(x1, x2, op); + } + } + void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2, + const Xbyak::Operand &op) { + vpxord(x1, x2, op); + } + + void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) { + movss(addr, x); + } + void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) { + vmovss(addr, x); + } + void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) { + movss(x, addr); + } + void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) { + vmovss(x, addr); + } + + void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) { + movsd(addr, x); + } + void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) { + vmovsd(addr, x); + } + void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) { + movsd(x, addr); + } + void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) { + vmovsd(x, addr); + } + + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movdqu(addr, x); + } + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovdqu(addr, x); + } + void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) { + vmovdqu32(addr, x); + } + + void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) { + movdqu(x, addr); + } + void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) { + vmovdqu(x, addr); + } + void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) { + vmovdqu32(x, addr); + } + + void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movups(addr, x); + } + void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovups(addr, x); + } + + void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movups(x, op); + } + void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vmovups(x, op); + } + + void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) { + movntps(addr, x); + } + void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) { + vmovntps(addr, x); + } + + void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movss(x, op); + shufps(x, x, 0x0); + } + void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + if (op.isMEM() || mayiuse(avx2)) { + vbroadcastss(x, op); + } else { + Xbyak::Xmm t(x.getIdx()); + if (t.getIdx() != op.getIdx()) movss(t, op); + vinsertf128(x, x, t, 1); + vshufps(x, x, x, 0); + } + } + + void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + movsd(x, op); + pshufd(x, x, 0x0); + } + void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + if (mayiuse(avx2)) { + vpbroadcastd(x, op); + } else { + Xbyak::Xmm t(x.getIdx()); + if (t.getIdx() != op.getIdx()) movsd(t, op); + vinsertf128(x, x, t, 1); + vshufps(x, x, x, 0); + } + } + + void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + rcpss(x, op); + } + void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) { + Xbyak::Xmm x1_(x1.getIdx()); + Xbyak::Xmm x2_(x2.getIdx()); + vrcpss(x1_, x1_, x2_); + } + void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) { + Xbyak::Xmm x_(x.getIdx()); + vrcpss(x_, x_, op); + } + + void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + rcpps(x, op); + } + void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vrcpps(x, op); + } + void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) { + vrcp14ps(x, op); + } + + void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + divps(x, op2); + } + void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vdivps(x, op1, op2); + } + + void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { + movups(buf, op1); + divps(buf, op2); + if (x.getIdx() != buf.getIdx()) { + movups(x, buf); + } + } + + void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Ymm &buf) { + vdivps(x, op1, op2); + } + + void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + addps(x, op2); + } + void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vaddps(x, op1, op2); + } + + void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, + const Xbyak::Operand& op) { + assert(x1.getIdx() == x2.getIdx()); + psignd(x1, op); + } + void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2, + const Xbyak::Operand& op) { + vpsignd(x1, x2, op); + } + + void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + subps(x, op2); + } + void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vsubps(x, op1, op2); + } + + void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Xmm &buf) { + movups(buf, op1); + subps(buf, op2); + if (x.getIdx() != buf.getIdx()) { + movups(x, buf); + } + } + + void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2, const Xbyak::Ymm &buf) { + vsubps(x, op1, op2); + } + + void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + mulps(x, op2); + } + void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vmulps(x, op1, op2); + } + + void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x1, x2); + addps(x1, op); + } + void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfmadd213ps(x1, x2, op); + } + + void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x2, op); + addps(x1, x2); + } + void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfmadd231ps(x1, x2, op); + } + + void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + mulps(x2, op); + subps(x1, x2); + } + + void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vfnmadd231ps(x1, x2, op); + } + + void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + sqrtps(x, op); + } + void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vsqrtps(x, op); + } + + void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + paddd(x2, op); + } + void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + vpaddd(x1, x2, op); + } + + void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + assert(x1.getIdx() == x2.getIdx()); + andps(x1, op); + } + void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + if (!mayiuse(avx512_common) || x1.getBit() < 512) + vandps(x1, x2, op); + else + vpandd(x1, x2, op); + } + + void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + assert(x1.getIdx() == x2.getIdx()); + orps(x1, op); + } + void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op = Xbyak::Operand()) { + if (!mayiuse(avx512_common) || x1.getBit() < 512) + vorps(x1, x2, op); + else + vpord(x1, x2, op); + } + + void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + assert(x.getIdx() == op.getIdx()); + pslld(x, imm); + } + void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vpslld(x, op, imm); + } + + void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + assert(x.getIdx() == op.getIdx()); + psrld(x, imm); + } + void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vpsrld(x, op, imm); + } + + void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + maxps(x, op2); + } + void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vmaxps(x, op1, op2); + } + + void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + assert(x.getIdx() == op1.getIdx()); + minps(x, op2); + } + void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1, + const Xbyak::Operand &op2 = Xbyak::Operand()) { + vminps(x, op1, op2); + } + + void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + cmpps(x1, op, _cmp_nle_us); + } + + void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vcmpgtps(x1, x2, op); + } + + void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op) { + assert(x1.getIdx() == x2.getIdx()); + cmpps(x1, op, _cmp_nlt_us); + } + + void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op) { + vcmpps(x1, x2, op, _cmp_nlt_us); + } + + void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) { + ptest(x1, op); + } + + void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) { + assert(!(x1.isZMM() || op.isZMM())); + vtestps(x1, op); + } + + void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, + const Xbyak::Operand &op, const Xbyak::Xmm &msk) { + assert(x1.getIdx() == x2.getIdx()); + assert(msk.getIdx() == 0); + blendvps(x1, op); + } + void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, + const Xbyak::Operand &op, const Xbyak::Ymm &msk) { + vblendvps(x1, x2, op, msk); + } + + void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op, + const int imm) { + roundps(x, op, imm); + } + void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op, + const int imm) { + vroundps(x, op, imm); + } + + void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + cvtps2dq(x, op); + } + void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vcvtps2dq(x, op); + } + + void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) { + cvtdq2ps(x, op); + } + void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) { + vcvtdq2ps(x, op); + } + + void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) { + movmskps(x1.cvt64(), x2); + } + void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) { + vmovmskps(x1, x2); + } + + void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){ + assert(x1.getIdx() == x1.getIdx()); + packssdw(x1, op); + } + void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){ + vpackssdw(x1, x2, op); + } + + void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){ + assert(x1.getIdx() == x1.getIdx()); + packuswb(x1, op); + } + void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){ + vpackuswb(x1, x2, op); + } + + + void mul_by_const(const Xbyak::Reg &out, + const Xbyak::Reg64 &tmp, int value) { + // Generates a shift + add sequence for multiplicating contents of the + // out register by a known JIT-time value. Clobbers the tmp register. + // + // Pros compared to mul/imul: + // - does not require using known registers + // - not microcoded on Intel(R) Xeon Phi(TM) processors + // Still, there are probably a lot of cases when mul/imul is faster on + // Intel(R) Core(TM) processors. Not intended for critical path. + + // TODO: detect when overflow is emminent (Roma) + // TODO: detect when using mul/imul is a better option (Roma) + + int p = 0; // the current power of 2 + int old_p = 0; // the last seen power of 2 such that value[old_p] != 0 + + xor_(tmp, tmp); + while (value) { + if (value & 1) { + int shift = p - old_p; + if (shift) { + shl(out, shift); + old_p = p; + } + add(tmp, out); + } + value >>= 1; + p++; + } + mov(out, tmp); + } + +public: + jit_generator( + void *code_ptr = nullptr, + size_t code_size = 256 * 1024 + ) : Xbyak::CodeGenerator(code_size, code_ptr) + { + } + virtual ~jit_generator() {} + + virtual const char *name() const = 0; + virtual const char *source_file() const = 0; + + const Xbyak::uint8 *getCode() { + const Xbyak::uint8 *code = CodeGenerator::getCode(); + size_t code_size = getSize(); + jit_utils::register_jit_code(code, code_size, name(), source_file()); + return code; + } + + template<typename F> const F getCode() { + return (const F)getCode(); + } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp new file mode 100644 index 0000000000..56d7f592e2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp @@ -0,0 +1,481 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_PRIMITIVE_CONF_HPP +#define JIT_PRIMITIVE_CONF_HPP + +#include <stdint.h> + +#include "common/primitive_attr.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +/* convolution */ +enum conv_version_t {ver_unused, ver_fma, ver_avx512_core, ver_4fma, ver_vnni}; +enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn, + loop_ngcw, loop_nhwcg, loop_nwcg}; +enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr, + loop_brl}; +enum conv_kernel_kind_t {embd_bcast, expl_bcast}; + +enum { + FLAG_MB_FIRST = 1 << 0, FLAG_MB_LAST = 1 << 1, + FLAG_OC_FIRST = 1 << 2, FLAG_OC_LAST = 1 << 3, + FLAG_IC_FIRST = 1 << 4, FLAG_IC_LAST = 1 << 5, + FLAG_SP_FIRST = 1 << 6, FLAG_SP_LAST = 1 << 7, + FLAG_REDUCE_FIRST = 1<<8, FLAG_REDUCE_LAST = 1<<9, + FLAG_ZERO_FILTER = 1 << 0, /* Controls whether the inner kernel skips + loading weights-data from memory; this + needs to happen on the first Group/16 + iteration. */ + FLAG_ZERO_BIAS = 1 << 1, /* Controls whether the inner kernel skip + loading bias data from memory */ + FLAG_COMPUTE_BIAS = 1 << 2, /* Controls bias computation during execution + pass */ +}; + +struct jit_conv_conf_t { + prop_kind_t prop_kind; + conv_version_t ver; + conv_loop_order_t loop_order; + + int simd_w; + int ndims; + int mb; + int ngroups, ic, oc, oc_without_padding, ic_without_padding; + int id, ih, iw, od, oh, ow; + int f_pad, l_pad, t_pad; + int back_pad, r_pad, b_pad; + int kd, kh, kw; + int stride_d, stride_h, stride_w; + int dilate_d, dilate_h, dilate_w; + format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround + bool with_bias; + bool with_sum; + bool with_eltwise; + + post_ops_t::entry_t::eltwise_t eltwise; + + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + + int idp, ihp, iwp, ohp, owp; + int nb_ic, ic_block; + int nb_oc, oc_block; + int nb_ow, ow_block; + int nb_oc_blocking; /* used in jit kernels for nb_oc work bloking taking + into account vector registers distribution */ + int nb_oc_blocking_thr_chunk; /* used for distibution of nb_oc work + within threads */ + int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work + int nb_ic_L2; + int h_blocking; + int nb_oc_L2; + int ur_h, ur_w; + int ur_w_tail; + bool is_1stconv; + int nonblk_group_off; + /* fma avx512_core */ + conv_kernel_kind_t kernel_kind; + /* 4fma */ + int tr_iw; + int tr_src_num_guard_elems; + /* 1st conv: 4fma */ + int tr_ld; + int kh_step; + /* 4vnni */ + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + /* avx512_u8s8u8 */ + int ic_nb1, ic_nb2; + int oc_nb1; + int ur_ow_max, ur_ow, ur_ow_tail; + int ur_ow_nsteps; + data_type_t bia_dt; + data_type_t dst_dt; + /* avx512: max possible value is nregs(32) - aux_regs(4) */ + int src_offsets[28]; + int src_count; + bool expl_bcast; + bool large_spatial; + int is_oc_scale; + int max_regs_ur; // maximum accumulation registers + // dw conv + int nb_ch, ch_block, nb_ch_blocking; + bool is_depthwise, is_fast_depthwise, is_resrc_depthwise; + int aligned_threads; + // large spatial + int oh_blk_size; + // s8s8 convolution + bool signed_input; + float wei_adj_scale; +}; + +struct jit_conv_conf_2x3_wino_t { + conv_version_t ver; + + int m; + int r; + int alpha; + int tile_h, tile_w; + + int mb; + int ngroups, ic, oc, oc_without_padding; + int ih, iw, oh, ow; + int l_pad, t_pad; + int r_pad, b_pad; + int kh, kw; + int stride_h, stride_w; + int dilate_h, dilate_w; + + int nb_ic, ic_block; + int nb_oc, oc_block; + + int w_block_size, h_block_size; + + data_type_t bia_dt; + data_type_t dst_dt; + + int is_oc_scale; + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + + format_tag_t src_tag, dst_tag; // temporary workaround + bool with_bias; + bool small_mb; + + int xb, yb; + int inp_stride; + int out_stride; + int wei_stride; + int bia_stride; + + int M, N, K; + int m_block, n_block, k_block; + int n2_block, n_chunks; + int k2_block, k_chunks; + + int mb_block, nb_mb; + + size_t size_wino_src, size_wino_wei, size_wino_dst; + + int nthr; +}; + +/* + Winograd sched policy: + + Computation Unit: + W: weights transform + S: src transform + D: dst transform + G: gemm + + Thread grouping by: + i: nb_ic + o: nb_oc + t: tile_block + e: element in tile + + Note: 'i' and 'o' are omited if + i. not comblined with t or + ii. with discrete transforms + + Current policies supported: +*/ +enum winograd_sched_t { + WSCHED_INVALID = 0, + + /* Forward & backward-data */ + /* W_S_G_D implements discrete transforms */ + WSCHED_DATA_W_S_G_D, + /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2*/ + WSCHED_DATA_W_SGD, + + /* Backward-weights */ + WSCHED_WEI_S_D_G_W, + WSCHED_WEI_SDGtWo, + WSCHED_WEI_S_D_Giot_W, + WSCHED_WEI_SDGt_W, +}; + +struct jit_conv_winograd_conf_t : public jit_conv_conf_t { + int itiles; + int jtiles; + int ntiles; + int ic_simd_block=16; + int tile_4fma_padding; + int tile_4fma; + int oc_simd_block=16; + int oc_reg_block; + int ic_reg_block; + int tile_block; + int tile_block_ur; + int nb_tile_block_ur; + + bool double_buffering; + bool with_relu_postsum; + int zmm_start; + int nb_reg; + + int dimK; + int dimK_4fma; + int dimK_reg_block; + int dimK_block; + int dimK_nb_block; + + int dimM; + int dimM_reg_block; + int dimM_simd_block; + int dimM_block; + int dimM_nb_block; + + int dimN; + int dimN_reg_block; + int dimN_bcast_ur; + int dimN_block; + int dimN_nb_block; + + winograd_sched_t sched_policy; +}; + +struct jit_conv_call_s { + const void *src; /* hack, non-const for backward_data */ + const void *dst; /* hack, non-const for forward */ + const void *filt; /* hack, non-const for backward_weights */ + const void *bias; /* hack, non-const for backward_bias */ + const void *src_prf; + const void *dst_prf; + const void *filt_prf; + const void *bias_prf; + const void *scales; + const void *acc_s32; + const void *compensation; + size_t kd_offset; + size_t kd_offset_prf; + size_t d_index; + size_t d_index_prf; + size_t d_worksize; + size_t d_worksize_prf; + size_t kd_padding; + size_t kd_padding_prf; + size_t kh_padding; + size_t kh_padding_prf; + size_t owb; + size_t owb_prf; + size_t kw_padding; + size_t channel; + size_t channel_prf; + size_t oc_blocks; + size_t ur_w; + size_t ur_str_w; + size_t ch_blocks; + size_t t_overflow; + size_t b_overflow; + int flags; +}; + +struct jit_deconv_call_s { + const void *src; /* hack, non-const for backward_data */ + const void *dst; /* hack, non-const for forward */ + const void *filt; /* hack, non-const for backward_weights */ + const void *bias; /* hack, non-const for backward_bias */ + const void *scales; + const void *compensation; + size_t t_overflow; + size_t b_overflow; + size_t kh_padding; + size_t oc_blocks; +}; + +struct jit_dw_conv_call_s { + const void *input; + const void *output; + const void *filter; + const void *bias; + size_t kh_count; + size_t oh_count; + size_t oh_index; + size_t filter_pad_off; + unsigned char + exec_flags; /* Flags passed by driver execution to inner kernel */ +}; + +struct jit_wino_transform_call_s { + size_t tile_block; + size_t tile_block_ur; + size_t nb_tile_block_ur; + size_t tile_count; + size_t tj; + size_t ti; + void *src; + void *dst; + void *Mw; + void *M; + void *T; + void *G; + void *bias; +}; + +struct jit_1x1_conv_conf_t { + prop_kind_t prop_kind; + conv_version_t ver; + + int mb; + int ngroups, ic, oc, oc_without_padding, ic_without_padding; + int iw, ih, ow, oh; + int l_pad, t_pad; + int kh, kw; + int stride_h, stride_w; + format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround + bool with_bias; + bool with_sum; + bool with_eltwise; + + post_ops_t::entry_t::eltwise_t eltwise; + + int is, os; + int ic_block, oc_block; + + int ur, ur_tail; + + int reduce_dim, reduce_block, nb_reduce, + nb_reduce_blocking, nb_reduce_blocking_max; + int load_dim, load_block, nb_load, + nb_load_blocking, nb_load_blocking_max, nb_load_chunk; + int bcast_dim, bcast_block, nb_bcast, + nb_bcast_blocking, nb_bcast_blocking_max; + + int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step; + int load_loop_load_step, load_loop_iter_step; + int bcast_loop_output_step, bcast_loop_output_substep; + int bcast_loop_bcast_step, bcast_loop_bcast_substep; + int fma_step; + int load_grp_count; + conv_1x1_loop_order_t loop_order; + bool use_vmovntps; + /* avx512 core */ + bool expl_bcast; + /* 4vnni */ + int typesize_in; + int typesize_out; + int typesize_bia; + int typesize_acc; + /* 4fma */ + bool transpose_src; + int tr_is; + int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; + int is_oc_scale; + data_type_t bia_dt; + data_type_t dst_dt; + bool signed_input; + float wei_adj_scale; +}; + +struct jit_gemm_conv_conf_t { + prop_kind_t prop_kind; + + int mb; + int ngroups, ic, oc; + int iw, ih, id, ow, oh, od; + int l_pad, t_pad, f_pad; + int kh, kw, kd; + int stride_h, stride_w, stride_d; + int dilate_h, dilate_w, dilate_d; + bool with_bias; + + int is, os, ks; + int ic_block, oc_block; + + int nthr; + ptrdiff_t im2col_sz; + bool need_wei_reduction; + bool signed_input; + int oh_block; + int ow_block; + bool outer_threading; +}; + +struct jit_1x1_conv_call_s { + const void *bcast_data; + const void *load_data; + const void *output_data; + const void *bias_data; // used in forward and backward_weights only + const void *acc_s32; + const void *scales; + const void *compensation; + + size_t load_dim; + size_t bcast_dim; + size_t reduce_dim; + + size_t output_stride; // used in backward_weights only + + size_t first_last_flag; +}; + +/* pooling */ +struct jit_pool_conf_t { + int ndims; + int mb, c; + int id, ih, iw, od, oh, ow; + int stride_d, stride_h, stride_w; + int kd, kh, kw; + int f_pad, t_pad, l_pad; + alg_kind_t alg; + bool is_training; + bool pad_w_is_null; + bool is_backward; + bool simple_alg; + data_type_t ind_dt; + + int c_block, c_tail, nb_c; + int ur_c, ur_c_tail; + int ur_w; + int ur_w_tail; + size_t tail[4]; + data_type_t src_dt; + data_type_t dst_dt; +}; + +struct jit_pool_call_s { + const float *src; + const float *dst; + const void *indices; + const float *src_prf; + const float *dst_prf; + const void *indices_prf; + size_t oh; + size_t kd_padding; + size_t kh_padding; + size_t kh_padding_shift; + size_t kd_padding_shift; + size_t kw_padding; + const float* init_value; + float ker_area_h; +}; + + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp new file mode 100644 index 0000000000..94d2101d6e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp @@ -0,0 +1,677 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_sse42_1x1_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_sse42_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk) +{ + mov(aux1_reg_bcast_data, reg_bcast_data); + mov(aux_reg_output_data, reg_output_data); + mov(bcast_loop_iter, reg_bcast_loop_work); + + Label bcast_loop; + Label bcast_loop_tail; + + cmp(bcast_loop_iter, jcp.ur); + jl(bcast_loop_tail, T_NEAR); + + L(bcast_loop); { + assert(jcp.bcast_block % jcp.ur == 0); + int num_substeps = jcp.bcast_block / jcp.ur; + assert(num_substeps > 0 && num_substeps < 10); + for (int i = 0; i < num_substeps; i++) { + generate_reduce_loop(load_loop_blk, jcp.ur); + if (i < num_substeps - 1) { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_substep); + } else { + add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step + - (num_substeps - 1) * jcp.bcast_loop_bcast_substep); + add(aux_reg_output_data, jcp.bcast_loop_output_step + - (num_substeps - 1) * jcp.bcast_loop_output_substep); + } + } + sub(bcast_loop_iter, jcp.bcast_block); + cmp(bcast_loop_iter, jcp.bcast_block); + jge(bcast_loop, T_NEAR); + } + + L(bcast_loop_tail); + if (jcp.ur_tail) { + Label bcast_loop_tail_out; + cmp(bcast_loop_iter, 0); + jz(bcast_loop_tail_out, T_NEAR); + generate_reduce_loop(load_loop_blk, jcp.ur_tail); + L(bcast_loop_tail_out); + } +} + +void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop( + int load_loop_blk, int ur) +{ + auto reg_load = [=](int i, int n) { + return Xmm(2*ur * load_loop_blk + 2*i + n + 1); + }; + + auto reg_accum = [=](int i, int j, int n) { + return Xmm(2*j * load_loop_blk + 2*i + n + 1); + }; + + auto bias_ptr = [=](int i, int n) { + return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)]; + }; + + auto bcast_ptr = [=](int u, int j) { + assert(j < jcp.ur); + assert(u <= jcp.reduce_loop_unroll); + size_t offt; + if (one_of(jcp.prop_kind, + forward_training, forward_inference, backward_data)) { + assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data) + ? jcp.oc_block : jcp.ic_block); + auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is; + offt = (u == jcp.reduce_loop_unroll) + ? (height + j) * jcp.reduce_loop_unroll + : j * jcp.reduce_loop_unroll + u; + } else + offt = u * jcp.ic_block + j; + return ptr[aux_reg_bcast_data + sizeof(float) * offt]; + }; + + auto load_ptr = [=](int u, int i, int n) { + size_t offt; + size_t u0 = u % jcp.reduce_loop_unroll; + size_t u1 = u / jcp.reduce_loop_unroll; + switch (jcp.prop_kind) { + case backward_data: + offt = (i * jcp.oc_block + u0) * jcp.ic_block; + break; + case backward_weights: + offt = (i * jcp.os + u0) * jcp.oc_block; + break; + default: + offt = (i * jcp.ic + u0) * jcp.oc_block; + } + return ptr[aux_reg_load_data + + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)]; + }; + + auto output_ptr = [=](int i, int j, int n) { + switch (jcp.prop_kind) { + case backward_data: + return ptr[aux_reg_output_data + + (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)]; + case backward_weights: + return ptr[aux_reg_output_data + + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale + + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)]; + default: + return ptr[aux_reg_output_data + + (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)]; + } + }; + + auto init = [=]() { + Label init_done; + Label init_zero; + + if (jcp.with_bias && one_of(jcp.prop_kind, forward_training, + forward_inference)) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(init_zero); + + for (int i = 0; i < load_loop_blk; i++) + for (int j = 0; j < ur; ++j) { + movups(reg_accum(i, j, 0), bias_ptr(i, 0)); + movups(reg_accum(i, j, 1), bias_ptr(i, 1)); + } + jmp(init_done); + } + + L(init_zero); + for (int i = 0; i < load_loop_blk; ++i) + for (int j = 0; j < ur; ++j) { + auto r0 = reg_accum(i, j, 0); + auto r1 = reg_accum(i, j, 1); + xorps(r0, r0); + xorps(r1, r1); + } + + L(init_done); + + // load weights + for (int i = 0; i < load_loop_blk; ++i) { + movups(reg_load(i, 0), load_ptr(0, i, 0)); + movups(reg_load(i, 1), load_ptr(0, i, 1)); + } + + movss(reg_bcast, bcast_ptr(0, 0)); + shufps(reg_bcast, reg_bcast, 0); + }; // init() + + auto store = [=]() { + Label store_noadd; + + if (!jcp.with_sum) { + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jnz(store_noadd, T_NEAR); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + auto r0 = reg_accum(i, j, 0); + auto r1 = reg_accum(i, j, 1); + addps(r0, output_ptr(i, j, 0)); + addps(r1, output_ptr(i, j, 1)); + } + + L(store_noadd); + + if (jcp.with_eltwise) { + assert(ur * load_loop_blk < 14); + + Label store_norelu; + test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); + jz(store_norelu, T_NEAR); + + eltwise_injector_->compute_vector_range(1, + 2 * ur * load_loop_blk + 1); + + L(store_norelu); + } + + for (int j = 0; j < ur; ++j) + for (int i = 0; i < load_loop_blk; ++i) { + movups(output_ptr(i, j, 0), reg_accum(i, j, 0)); + movups(output_ptr(i, j, 1), reg_accum(i, j, 1)); + } + }; + + auto fma_block = [=](bool last_block) { + for (int u = 0; u < jcp.reduce_loop_unroll; ++u) { + for (int j = 0; j < ur; ++j) { + for (int i = 0; i < load_loop_blk; ++i) { + mulps(reg_load(i, 0), reg_bcast); + mulps(reg_load(i, 1), reg_bcast); + addps(reg_accum(i, j, 0), reg_load(i, 0)); + addps(reg_accum(i, j, 1), reg_load(i, 1)); + + if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) { + movups(reg_load(i, 0), load_ptr(u + 1, i, 0)); + movups(reg_load(i, 1), load_ptr(u + 1, i, 1)); + } + } + if (j < ur - 1) { + movss(reg_bcast, bcast_ptr(u, j + 1)); + shufps(reg_bcast, reg_bcast, 0); + } + } // for ur + if (!last_block || u < jcp.reduce_loop_unroll - 1) { + movss(reg_bcast, bcast_ptr(u + 1, 0)); + shufps(reg_bcast, reg_bcast, 0); + } + } // for reduce_loop_unroll + }; + + Label reduce_loop; + Label reduce_loop_tail; + + mov(aux_reg_load_data, reg_load_data); + mov(aux_reg_bcast_data, aux1_reg_bcast_data); + + init(); + + mov(reduce_loop_iter, reg_reduce_loop_work); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jle(reduce_loop_tail, T_NEAR); + + L(reduce_loop); { + fma_block(false); + add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jg(reduce_loop, T_NEAR); + } + + L(reduce_loop_tail); + fma_block(true); + + store(); +} // reduce_loop() + +void jit_sse42_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk) +{ + if (!jcp.with_bias || jcp.prop_kind != backward_weights) + return; + + Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out; + Label diff_bias_load; + + auto diff_bias_ptr = [=](int i, int n) { + return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)]; + }; + + auto load_ptr = [=](int u, int i, int n) { + return ptr[aux_reg_load_data + + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)]; + }; + + auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); }; + + mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]); + cmp(reg_diff_bias_data, 0); + je(diff_bias_loop_out, T_NEAR); + + test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); + jz(diff_bias_load, T_NEAR); + + for (int i = 0; i < load_loop_blk; ++i) { + auto r0 = diff_bias_reg(i, 0); + auto r1 = diff_bias_reg(i, 1); + xorps(r0, r0); + xorps(r1, r1); + } + jmp(diff_bias_init_out, T_NEAR); + + L(diff_bias_load); + for (int i = 0; i < load_loop_blk; ++i) { + movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0)); + movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1)); + } + + L(diff_bias_init_out); + mov(aux_reg_load_data, reg_load_data); + mov(reduce_loop_iter, reg_reduce_loop_work); + L(diff_bias_loop); { + for(int u = 0; u < jcp.reduce_loop_unroll; ++u) + for (int i = 0; i < load_loop_blk; ++i) { + addps(diff_bias_reg(i, 0), load_ptr(u, i, 0)); + addps(diff_bias_reg(i, 1), load_ptr(u, i, 1)); + } + assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); + add(aux_reg_load_data, jcp.reduce_loop_load_step); + sub(reduce_loop_iter, jcp.reduce_loop_unroll); + jnz(diff_bias_loop, T_NEAR); + } + + for (int i = 0; i < load_loop_blk; i++) { + movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0)); + movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1)); + } + + add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + + L(diff_bias_loop_out); +} + +void jit_sse42_1x1_conv_kernel_f32::generate() +{ + preamble(); + + mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); + mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); + mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); + if (jcp.with_bias) { + if (jcp.prop_kind == backward_weights) { + sub(rsp, stack_space_needed); + mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]); + mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data); + } else + mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); + } + + mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); + mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); + mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); + mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); + if (jcp.prop_kind == backward_weights) + mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); + + auto generate_load_loop_body = [=] (int load_loop_blk) { + generate_bcast_loop(load_loop_blk); + add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); + switch (jcp.prop_kind) { + case forward_training: + case forward_inference: + add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float)); + add(reg_output_data, + load_loop_blk * jcp.os * jcp.oc_block * sizeof(float)); + break; + case backward_data: + add(reg_output_data, + load_loop_blk * jcp.is * jcp.ic_block * sizeof(float)); + break; + case backward_weights: + for (int i = 0; i < load_loop_blk; i++) + add(reg_output_data, reg_output_stride); + break; + default: + assert(!"invalid prop_kind"); + } + sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); + }; + + Label load_loop_blk_8; + Label load_loop_blk_16; + Label load_loop_blk_24; + Label load_loop_blk_end; + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16, T_NEAR); + + cmp(reg_load_loop_work, 16); + jle(load_loop_blk_16, T_NEAR); + + L(load_loop_blk_24); { + generate_diff_bias_loop(3); + generate_load_loop_body(3); + cmp(reg_load_loop_work, 32); + je(load_loop_blk_16); + cmp(reg_load_loop_work, 24); + jge(load_loop_blk_24); + } + + cmp(reg_load_loop_work, 8); + jle(load_loop_blk_8, T_NEAR); + + L(load_loop_blk_16); { + generate_diff_bias_loop(2); + generate_load_loop_body(2); + cmp(reg_load_loop_work, 16); + jge(load_loop_blk_16); + } + + L(load_loop_blk_8); { + cmp(reg_load_loop_work, 0); + je(load_loop_blk_end, T_NEAR); + generate_diff_bias_loop(1); + generate_load_loop_body(1); + } + + L(load_loop_blk_end); + + if (jcp.with_bias && jcp.prop_kind == backward_weights) + add(rsp, stack_space_needed); + + postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok( + jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(sse42)) + return status::unimplemented; + + // TODO (Roma): this code is duplicated from the generic kernel; maybe the + // configuration struct could do some stuff below + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + + jcp.prop_kind = cd.prop_kind; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + jcp.os = jcp.oh * jcp.ow; + jcp.is = jcp.ih * jcp.iw; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + const int is_bwd_d = jcp.prop_kind == backward_data; + + format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c; + format_tag_t wei_tag = with_groups + ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o, + gOIhw8o8i) + : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o, + OIhw8o8i); + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.ngroups == 1 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag; + if (!args_ok) return status::unimplemented; + + const int simd_w = 4; + jcp.ic_block = jcp.oc_block = simd_w*2; + + args_ok = true + && jcp.oc % jcp.oc_block == 0 + && jcp.ic % jcp.ic_block == 0 + && jcp.t_pad == 0 && jcp.l_pad == 0 + && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides + && jcp.kh == 1 && jcp.kw == 1; + if (!args_ok) return status::unimplemented; + + jcp.ur = 1; + + int load_blocking{ 0 }; + int load_blocking_max{ 0 }; + int bcast_blocking{ 0 }; + int bcast_blocking_max{ 0 }; + int reduce_blocking{ 0 }; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.reduce_dim = jcp.ic; + jcp.reduce_block = jcp.ic_block; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.is; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.is * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + load_blocking = 120; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 192; + reduce_blocking = 128; // affects L1$ utilization + } else if (jcp.prop_kind == backward_data) { + jcp.reduce_dim = jcp.oc; + jcp.reduce_block = jcp.oc_block; + + jcp.load_dim = jcp.ic; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.os; + jcp.bcast_block = jcp.ur; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.os * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.ic * sizeof(float); + + jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = -1; // unused + jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float); + jcp.bcast_loop_bcast_substep = -1; // unused + + jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.load_loop_iter_step = jcp.ic_block; + + load_blocking = 96; // assumes the kernel is jcp.ur x 3 + load_blocking_max = 144; + bcast_blocking = 128; // affects load balancing across threads + bcast_blocking_max = 196; + reduce_blocking = 64; // affects L1$ utilization + } else if (jcp.prop_kind == backward_weights) { + jcp.reduce_dim = jcp.os; + jcp.reduce_block = 1; + + jcp.load_dim = jcp.oc; + jcp.load_block = jcp.oc_block; + + jcp.bcast_dim = jcp.ic; + jcp.bcast_block = jcp.ic_block; + + jcp.reduce_loop_unroll = jcp.reduce_block; + jcp.reduce_loop_bcast_step + = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float); + jcp.reduce_loop_load_step + = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float); + + jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float); + jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float); + jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float); + jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float); + + jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float); + jcp.load_loop_iter_step = jcp.oc_block; + + /* --- */ + + load_blocking = div_up(jcp.load_dim, jcp.load_block); + while (true) { + if (load_blocking <= 32) break; + else if (load_blocking % 2 == 0) load_blocking /= 2; + else if (load_blocking % 3 == 0) load_blocking /= 3; + else break; + } + load_blocking *= jcp.load_block; + load_blocking_max = load_blocking; + assert(jcp.load_dim % load_blocking == 0); + + bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); + while (true) { + if (bcast_blocking <= 9) break; + else if (bcast_blocking % 2 == 0) bcast_blocking /= 2; + else if (bcast_blocking % 3 == 0) bcast_blocking /= 3; + else break; + } + bcast_blocking *= jcp.bcast_block; + bcast_blocking_max = bcast_blocking; + assert(jcp.bcast_dim % bcast_blocking == 0); + + reduce_blocking = 128; // affects L1$ utilization + } else + return status::unimplemented; + + assert(load_blocking); + assert(load_blocking_max); + assert(bcast_blocking); + assert(bcast_blocking_max); + assert(reduce_blocking); + + assert(jcp.bcast_block % jcp.ur == 0); + jcp.ur_tail = jcp.bcast_dim % jcp.ur; + + jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; + jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; + jcp.nb_load_blocking = load_blocking / jcp.load_block; + jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; + jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; + + jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); + jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); + jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); + + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp new file mode 100644 index 0000000000..b314a5098c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp @@ -0,0 +1,104 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SSE42_1x1_CONV_KERNEL_F32_HPP +#define JIT_SSE42_1x1_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_1x1_conv_kernel_f32: public jit_generator { + jit_sse42_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32<sse42>(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode(); + } + + ~jit_sse42_1x1_conv_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_1x1_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_1x1_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_1x1_conv_kernel_f32) + + jit_1x1_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_1x1_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + using xmm_t = const Xbyak::Xmm; + + reg64_t reg_bcast_data = rax; + reg64_t reg_load_data = rsi; + reg64_t reg_output_data = rbx; + reg64_t aux_reg_bcast_data = rdx; + reg64_t aux1_reg_bcast_data = abi_not_param1; + reg64_t aux_reg_load_data = abi_param1; + reg64_t aux_reg_output_data = rbp; + reg64_t reg_load_loop_work = r9; + reg64_t reg_bcast_loop_work = r10; + reg64_t reg_reduce_loop_work = r11; + reg64_t load_loop_iter = r13; + reg64_t imm_addr64 = load_loop_iter; + reg64_t bcast_loop_iter = r14; + reg64_t reduce_loop_iter = r15; + reg64_t reg_reduce_pos_flag = r8; + reg64_t reg_output_stride = r12; + reg64_t reg_bias_data = r12; + reg64_t reg_diff_bias_data = bcast_loop_iter; + + int reg_diff_bias_data_stack_offt = 0; + int stack_space_needed = 8; + + xmm_t reg_bcast = xmm_t(15); + + jit_uni_eltwise_injector_f32<sse42> *eltwise_injector_; + + void generate_bcast_loop(int load_loop_blk); + void generate_reduce_loop(int load_loop_blk, int ur); + void generate_diff_bias_loop(int load_loop_blk); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp new file mode 100644 index 0000000000..30c137641e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp @@ -0,0 +1,134 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "jit_sse42_1x1_convolution.hpp" +#include "utils.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define data_blk_off(f, n, c, h, w) \ + ((ndims == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w)) + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +void jit_sse42_1x1_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + const int ndims = src_d.ndims(); + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + + parallel(0, [&](const int ithr, const int nthr) { + // TODO (Roma): remove this restriction + assert(jcp.stride_w == 1 && jcp.stride_h == 1); + + auto par_conv = jit_1x1_conv_call_s(); + + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + const int os_block = jcp.bcast_block; + + int start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + + int iwork = start; + while (iwork < end) { + int n{0}, g{0}, osb{0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, + jcp.nb_bcast); + + const int bcast_step_rem = jcp.nb_bcast - osb; + int bcast_step = bcast_step_rem <= jcp.nb_bcast_blocking_max + ? bcast_step_rem : jcp.nb_bcast_blocking; + bcast_step = nstl::min<int>(bcast_step, end - iwork); + + const int os = osb * os_block; + const int ow = os % jcp.ow; + const int oh = os / jcp.ow; + const int iw = nstl::max<int>(ow * jcp.stride_w - jcp.l_pad, 0); + const int ih = nstl::max<int>(oh * jcp.stride_h - jcp.t_pad, 0); + + par_conv.bcast_dim = this_block_size(os, jcp.os, + bcast_step * os_block); + + int ocb = 0; + while (ocb < jcp.nb_load) { + const int load_step_rem = jcp.nb_load - ocb; + const int load_step = load_step_rem < jcp.nb_load_blocking_max + ? load_step_rem : jcp.nb_load_blocking; + + const size_t _ocb = g * nb_oc + ocb; + par_conv.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, + load_step * jcp.oc_block); + + const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow); + par_conv.output_data = &dst[dst_off]; + + par_conv.bias_data = &bias[_ocb * jcp.oc_block]; + + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + par_conv.first_last_flag = 0 + | (icb == 0) * FLAG_REDUCE_FIRST + | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST; + + par_conv.reduce_dim = this_block_size(icb * jcp.ic_block, + jcp.ic, nb_ic_blocking * jcp.ic_block); + + const size_t _icb = g * nb_ic + icb; + const size_t src_off = data_blk_off(src_d, n, _icb, ih, iw); + par_conv.bcast_data = &src[src_off]; + + par_conv.load_data = &weights[pd()->with_groups() + ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + kernel_->jit_ker(&par_conv); + } + + ocb += load_step; + } + + iwork += bcast_step; + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp new file mode 100644 index 0000000000..b32b1e4784 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp @@ -0,0 +1,96 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_SSE42_1x1_CONVOLUTION_HPP +#define CPU_JIT_SSE42_1x1_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "jit_sse42_1x1_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_1x1_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_1x1:", sse42, ""), + jit_sse42_1x1_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + return jit_sse42_1x1_conv_kernel_f32::init_conf(jcp_, *desc(), + *src_md(), *weights_md(), *dst_md(), *attr()); + } + + jit_1x1_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o) + : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o); + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + jit_sse42_1x1_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) { + kernel_ = new jit_sse42_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr()); + } + ~jit_sse42_1x1_convolution_fwd_t() { delete kernel_; }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_sse42_1x1_conv_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp new file mode 100644 index 0000000000..17cabc1186 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp @@ -0,0 +1,497 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "cpu_memory.hpp" + +#include "jit_sse42_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int kh = jcp.kh; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); + int jj_end = ur_w + - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w)); + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + int inp_off; + if (one_of(jcp.src_tag, ncw, nchw)) + inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l); + else + inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2; + + movss(Xmm(oc_blocks * ur_w + jj + 1), + ptr[aux_reg_input + sizeof(float) * inp_off]); + shufps(Xmm(oc_blocks * ur_w + jj + 1), + Xmm(oc_blocks * ur_w + jj + 1), 0x0); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk + + ki * ic_blk * oc_blk + ifm2 * oc_blk; + + for (int jj = jj_start; jj < jj_end; jj++) + { + movups(xmm0, + ptr[aux_reg_kernel + sizeof(float) * ker_off]); + mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); + addps(Xmm(ur_w * ii + jj + 1), xmm0); + } + } + } + } +} + +void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + Label kw_loop; + + int iw = jcp.iw; + int ih = jcp.ih; + int kw = jcp.kw; + int kh = jcp.kh; + int nb_ic = jcp.nb_ic; + int stride_w = jcp.stride_w; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + + xor_(ki_iter, ki_iter); + L(kw_loop); + { + int jj_start = 0; + int jj_end = ur_w; + for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) { + for (int jj = jj_start; jj < jj_end; jj++) { + int inp_off; + if (one_of(jcp.src_tag, ncw, nchw)) + inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l); + else + inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2; + + movss(Xmm(oc_blocks * ur_w + jj + 1), + ptr[aux_reg_input + sizeof(float) * inp_off]); + shufps(Xmm(oc_blocks * ur_w + jj + 1), + Xmm(oc_blocks * ur_w + jj + 1), 0x0); + } + for (int ii = 0; ii < oc_blocks; ii++) { + int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk + + ifm2 * oc_blk; + for (int jj = jj_start; jj < jj_end; jj++) { + movups(xmm0, + ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]); + mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1)); + addps(Xmm(ur_w * ii + jj + 1), xmm0); + } + } + } + add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw) ? + dilate_w : ic_blk * dilate_w)); + + inc(ki_iter); + cmp(ki_iter, kw); + jl(kw_loop, T_NEAR); + } +} + +void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w, + int pad_l, int pad_r, int oc_blocks) +{ + int iw = jcp.iw; + int kw = jcp.kw; + int ow = jcp.ow; + int oh = jcp.oh; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw) + ? dilate_h : ic_blk * dilate_h; + const int inp_off = one_of(jcp.src_tag, ncw, nchw) + ? dilate_w : ic_blk * dilate_w; + + xor_(simd_iter, simd_iter); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + Label init_simd_iter_loop; + Label init_done; + Label init_first; + + L(init_simd_iter_loop); + + if (!jcp.with_sum) { + test(reg_ci_flag, FLAG_IC_FIRST); + jne(init_first, T_NEAR); + } + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + movups(Xmm(ur_w * ii + jj + 1), xword[reg_output + + sizeof(float) * (ii * oh * ow + jj) * oc_blk]); + + if (jcp.with_sum && jcp.with_bias) { + test(reg_ci_flag, FLAG_IC_FIRST); + je(init_done, T_NEAR); + + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + addps(Xmm(ur_w * ii + jj + 1), + xword[reg_bias + sizeof(float) * ii * oc_blk]); + } + + jmp(init_done); + + L(init_first); + if (this->jcp.with_bias) { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + movups(Xmm(ur_w * ii + jj + 1), + xword[reg_bias + sizeof(float) * ii * oc_blk]); + } else { + for (int ii = 0; ii < oc_blocks; ii++) + for (int jj = 0; jj < ur_w; jj++) + pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1)); + } + + L(init_done); + + Label skip_kh_loop; + mov(kj, reg_kh); + if ((jcp.dilate_h >= jcp.ih) + || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) { + cmp(kj, 0); + je(skip_kh_loop, T_NEAR); + } + Label kh_loop; + L(kh_loop); + { + if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) { + oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks); + sub(aux_reg_input, sizeof(float) * kw * inp_off); + add(aux_reg_input, sizeof(float) * iw * inp_mult); + } else { + oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks); + add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk); + add(aux_reg_input, sizeof(float) * iw * inp_mult); + } + + dec(kj); + cmp(kj, 0); + jg(kh_loop, T_NEAR); + } + + L(skip_kh_loop); + + if (jcp.with_eltwise) { + Label regular_store; + test(reg_ci_flag, FLAG_IC_LAST); + je(regular_store, T_NEAR); + + eltwise_injector_->compute_vector_range(1, oc_blocks * ur_w + 1); + + L(regular_store); + } + + for (int ii = 0; ii < oc_blocks; ii++) { + for (int jj = 0; jj < ur_w; jj++) { + const size_t o_off = (ii * oh * ow + jj) * oc_blk; + + Xmm reg_out = Xmm(ur_w * ii + jj + 1); + movups(xword[reg_output + sizeof(float) * o_off], reg_out); + } + } + + mov(aux_reg_kernel, reg_kernel); + mov(aux_reg_input, reg_input); + add(aux_reg_kernel, sizeof(float) * 4); + add(reg_output, sizeof(float) * 4); + add(reg_bias, sizeof(float) * 4); + + inc(simd_iter); + cmp(simd_iter, 2); + jl(init_simd_iter_loop, T_NEAR); + + sub(reg_output, sizeof(float) * 8); + sub(reg_bias, sizeof(float) * 8); +} + +inline void jit_sse42_conv_fwd_kernel_f32::solve_common(int oc_blocks) +{ + int ur_w = jcp.ur_w; + int ur_w_tail = jcp.ur_w_tail; + int n_oi = jcp.ow / ur_w; + int iw = jcp.iw; + int kw = jcp.kw; + int ic_blk = jcp.ic_block; + int oc_blk = jcp.oc_block; + int dilate_w = jcp.dilate_w + 1; + int str_w = jcp.stride_w; + const int inp_mult = one_of(jcp.src_tag, ncw, nchw) ? 1 : ic_blk; + + int l_pad = jcp.l_pad; + int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1)); + int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w + - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) + width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad" + else + width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad" + add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + Label ow_loop; + xor_(oi_iter, oi_iter); + + if (n_oi > 0) { + L(ow_loop); + + width_blk_step(ur_w, 0, 0, oc_blocks); // "middle" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + + if (r_pad1 > 0 && n_oi >=0) { + width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad" + add(reg_input, sizeof(float) * ur_w * str_w * inp_mult); + add(reg_output, sizeof(float) * ur_w * oc_blk); + } + + if (ur_w_tail != 0) + width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail" +} + +void jit_sse42_conv_fwd_kernel_f32::generate() +{ + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]); + mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]); + + int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking; + Label tail, exit; + + cmp(reg_oc_blocks, jcp.nb_oc_blocking); + jne(nb_oc_tail ? tail : exit, T_NEAR); + + solve_common(jcp.nb_oc_blocking); + jmp(exit, T_NEAR); + + if (nb_oc_tail) { + L(tail); + cmp(reg_oc_blocks, nb_oc_tail); + jne(exit, T_NEAR); + solve_common(nb_oc_tail); + } + + L(exit); + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(sse42)) return status::unimplemented; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); + jcp.ndims = ndims; + + jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2]; + jcp.iw = src_d.dims()[ndims - 1]; + jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2]; + jcp.ow = dst_d.dims()[ndims - 1]; + + jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2]; + jcp.kw = weights_d.dims()[with_groups + ndims - 1]; + + jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0]; + jcp.l_pad = cd.padding[0][ndims - 3]; + + jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0]; + jcp.stride_w = cd.strides[ndims - 3]; + + jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0]; + jcp.dilate_w = cd.dilates[ndims - 3]; + jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1) + - (jcp.ih + jcp.t_pad - 1); + + if (ndims == 3) { + jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Owi8o, gOwi8o, OIw8i8o, gOIw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c); + } else if (ndims == 4) { + jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c); + jcp.wei_tag = weights_d.matches_one_of_tag( + Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o); + jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c); + } + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + const bool flat = jcp.ic == 3; + const bool mimo = !flat; + + bool args_ok = true + && IMPLICATION(flat, one_of(jcp.src_tag, ncw, nwc, nchw, nhwc) + && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o)) + && IMPLICATION(mimo, one_of(jcp.src_tag, nCw8c, nChw8c) + && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o)) + && one_of(jcp.dst_tag, nCw8c, nChw8c); + if (!args_ok) return status::unimplemented; + + const int simd_w = 8; // 2 SSE vectors processing at once + + jcp.ur_h = 1; /* no code-unrolling by h so far */ + jcp.ur_w = 3; + if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + + jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */ + + args_ok = true + && jcp.oc % simd_w == 0 + && jcp.l_pad <= jcp.ur_w + && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) + || (jcp.stride_w == 1 && jcp.stride_h == 1)) + && IMPLICATION(mimo, jcp.ic % simd_w == 0); + if (!args_ok) return status::unimplemented; + + int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + + // kernel needs 1 temporary YMM register + const int num_avail_regs = 15; + if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) { + /* recalculate ur_w, nb_oc_blocking and ur_w_tail */ + jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail, + nstl::min(jcp.ow, num_avail_regs / 2)); + jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w; + jcp.ur_w_tail = jcp.ow % jcp.ur_w; + /* check again ... */ + r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail)) + return status::unimplemented; + } + assert(jcp.nb_oc_blocking > 0); + assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs); + + jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w; + jcp.nb_ic = jcp.ic / jcp.ic_block; + + jcp.oc_block = simd_w; + jcp.nb_oc = jcp.oc / jcp.oc_block; + + if (one_of(jcp.prop_kind, forward_training, forward_inference)) { + jcp.nb_ic_blocking = 12; + jcp.nb_ic_blocking_max = 16; + } else { + jcp.nb_ic_blocking = 1; + jcp.nb_ic_blocking_max = jcp.nb_ic_blocking; + } + + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp new file mode 100644 index 0000000000..33c26ef081 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp @@ -0,0 +1,93 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SSE42_CONV_KERNEL_F32_HPP +#define JIT_SSE42_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "cpu_memory.hpp" +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_conv_fwd_kernel_f32: public jit_generator { + jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, + const primitive_attr_t &attr) + : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32<sse42>(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_sse42_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32) + jit_conv_conf_t jcp; + const primitive_attr_t &attr_; + void (*jit_ker)(jit_conv_call_s *); + +private: + using reg64_t = const Xbyak::Reg64; + reg64_t reg_input = rax; + reg64_t aux_reg_input = r8; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r9; + reg64_t reg_output = rsi; + reg64_t reg_bias = rbx; + + reg64_t kj = r10; + reg64_t oi_iter = r11; + reg64_t ki_iter = r12; + reg64_t reg_kh = abi_not_param1; + reg64_t simd_iter = r15; + reg64_t reg_oc_blocks = r14; + reg64_t imm_addr64 = reg_oc_blocks; + Xbyak::Reg32 reg_ci_flag = r13d; + + jit_uni_eltwise_injector_f32<sse42> *eltwise_injector_; + + inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, + int oc_blocks); + inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); + inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks); + inline void solve_common(int oc_blocks); + + void generate(); +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp new file mode 100644 index 0000000000..5f77d692f5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp @@ -0,0 +1,136 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "jit_sse42_convolution.hpp" +#include "mkldnn_thread.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +#define src_blk_off(f, n, c, h, w) \ + (pd()->ndims() == 3) \ + ? (f).blk_off(n, c, w) \ + : (f).blk_off(n, c, h, w) + +#define wht_blk_off_(f, g, ...) \ + pd()->with_groups() \ + ? (f).blk_off(g, __VA_ARGS__) \ + : (f).blk_off(__VA_ARGS__) +#define wht_blk_off(f, g, oc, ic, kh, kw) \ + pd()->ndims() == 3 \ + ? wht_blk_off_(f, g, oc, ic, kw) \ + : wht_blk_off_(f, g, oc, ic, kh, kw) + +void jit_sse42_convolution_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking); + const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh; + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{ 0 }, end{ 0 }; + balance211(work_amount, nthr, ithr, start, end); + + int icbb = 0; + while (icbb < jcp.nb_ic) { + int icb_step = jcp.nb_ic_blocking; + int icb_step_rem = jcp.nb_ic - icbb; + if (icb_step_rem < jcp.nb_ic_blocking_max) + icb_step = icb_step_rem; + + size_t n{0}, g{0}, ocbb{0}, oh{0}; + nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + oh, jcp.oh); + for (size_t iwork = start; iwork < end; ++iwork) { + int ocb = ocbb * jcp.nb_oc_blocking; + int ocb_num = jcp.nb_oc_blocking; + + for (int icb = icbb; icb < icbb + icb_step; ++icb) { + auto par_conv = jit_conv_call_s(); + + const int ij = oh * jcp.stride_h; + const int i_t_overflow = nstl::max(0, jcp.t_pad - ij); + const int i_b_overflow = nstl::max(jcp.ih, ij + + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih; + + const size_t _oc = g * jcp.nb_oc + ocb; + const size_t _ic = g * jcp.nb_ic + icb; + + const int ih = nstl::max(ij - jcp.t_pad + + div_up(i_t_overflow, + (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0); + par_conv.src = &src[src_blk_off(src_d, n, + jcp.ic == 3 ? 0 : _ic, ih, 0)]; + + par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)]; + + const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1)); + par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb, + jcp.ic == 3 ? 0 : icb, wh, 0)]; + + if (icb == 0) { + if (bias) + par_conv.bias = + &bias[bias_d.blk_off(_oc * jcp.oc_block)]; + par_conv.flags |= FLAG_IC_FIRST; + } + + if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) { + par_conv.flags |= FLAG_IC_LAST; + } + + par_conv.oc_blocks = + nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb; + + par_conv.kw_padding = 0; + const int kh_padding = jcp.kh + - div_up(i_t_overflow, (jcp.dilate_h + 1)) + - div_up(i_b_overflow, (jcp.dilate_h + 1)); + par_conv.kh_padding = nstl::max(0, kh_padding); + kernel_->jit_ker(&par_conv); + } + nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, + oh, jcp.oh); + } + icbb += icb_step; + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp new file mode 100644 index 0000000000..d2f0a38c5c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_SSE42_CONVOLUTION_HPP +#define CPU_JIT_SSE42_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_primitive_conf.hpp" +#include "jit_sse42_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_sse42_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", sse42, ""), + jit_sse42_convolution_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + return jit_sse42_conv_fwd_kernel_f32::init_conf(jcp_, *desc(), + *src_md(), *weights_md(), *dst_md(), *attr()); + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + const bool flat = IC() == 3; + auto src_tag = flat + ? utils::pick(ndims() - 3, ncw, nchw, ncdhw) + : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto dst_tag = + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); + auto wei_tag = with_groups() + ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o, + gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o) + : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o, + OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o); + + return set_default_formats_common(src_tag, wei_tag, dst_tag); + } + }; + + jit_sse42_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_sse42_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); } + ~jit_sse42_convolution_fwd_t() { delete kernel_; }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_sse42_conv_fwd_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp new file mode 100644 index 0000000000..0e734f7265 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp @@ -0,0 +1,1192 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "jit_generator.hpp" +#include "cpu_barrier.hpp" + +#include "jit_transpose_src_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +#define GET_OFF(x) offsetof(ctx_t, x) + +struct jit_trans_iw_ic_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_t) + + jit_trans_iw_ic_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + + enum { typesize = sizeof(float), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t k3333 = k1; + opmask_t k5555 = k2; + opmask_t kAAAA = k3; + opmask_t kCCCC = k4; + opmask_t k0F0F = k5; + opmask_t kF0F0 = k6; + opmask_t kTail = k7; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_iw_ic_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto pf_src_t0 = [=](int i) { + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_src, + (transpose_size + i) * src_stride)); + }; + + auto pf_tr_src_t0 = [=](int i) { + int offset = (transpose_size) * typesize + i * tr_src_stride; + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, offset)); + if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, + offset + 64)); + }; + + auto pf_src_t1 = [=](int i) { + if(enable_prefetch) prefetcht1(EVEX_compress_addr(reg_src_prf, + i * src_stride)); + }; + + auto pf_tr_src_t1 = [=](int i) { + if(enable_prefetch) prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, + i * tr_src_stride)); + }; + + auto src_zmm = [=](int i) { + assert(i >= 0 && i < 16); + return Zmm(i); + }; + + auto tmp_zmm = [=](int i) { + assert(i >= 0 && i < 16); + return Zmm(16 + i); + }; + + auto load = [=](int i) { + vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto store = [=](Zmm r, int i) { + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto padding = [=] (Reg64 reg, int pad) { + kmovw(kTail, (1 << pad) - 1); + auto k = kTail; + auto base = reg; + base.setOpmaskIdx(k.getIdx(), true); + + auto zmm_zero = r; + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, zmm_zero); + }; + + mov(reg_tr_src_tmp, reg_tr_src); + if (l_pad > 0) + add(reg_tr_src_tmp, l_pad * typesize); + + if (tail != transpose_size) + kmovw(kTail, (1 << tail) - 1); + + // Xbyak does not allow k0 to be specified explicitly via the '|' + // operator, so we have to do this via a method call (implicitly + // EVEX encoding uses k0 to mean 'no mask') + bool partial_store = nrows < 16; + auto k = partial_store ? kTail : k0; + auto base = reg_tr_src_tmp; + base.setOpmaskIdx(k.getIdx(), true); + + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + if (nontemporal_stores && !partial_store) + vmovntps(addr, r); + else + vmovups(addr, r); + + if (r_pad > 0) { + add(reg_tr_src_tmp, tail * typesize); + padding(reg_tr_src_tmp, r_pad); + } + + if (l_pad > 0) { + padding(reg_tr_src, l_pad); + } + }; + + auto transpose16x8 = [=](int base_idx) { + assert(base_idx == 0 || base_idx == 8); + + // swap 1 + for (int i = 0; i < 4; i++) { + int src_idx0 = base_idx + i * 2; + int src_idx1 = src_idx0 + 1; + + int next_src_idx0 = src_idx0 + 2; + int next_src_idx1 = src_idx1 + 2; + bool load_next = base_idx == 0 || i < 3; + + if (base_idx == 0 && i == 0) { + load(src_idx0); + load(src_idx1); + } + + auto tmp0 = tmp_zmm(src_idx0); + auto tmp1 = tmp_zmm(src_idx1); + auto src0 = src_zmm(src_idx0); + auto src1 = src_zmm(src_idx1); + + if (next_src_idx0 < nrows && load_next) + load(next_src_idx0); + valignd(tmp0, src0, src0, 0x1); + pf_src_t1(base_idx + i); + + if (next_src_idx1 < nrows && load_next) + load(next_src_idx1); + valignd(tmp1, src1, src1, 0xf); + pf_src_t0(base_idx + i); + + vmovaps(src0 | kAAAA, tmp1); + vmovaps(src1 | k5555, tmp0); + } + // swap 2 + for (int i = 0; i < 4; i++) { + int select_half = (i < 2) ? 0 : 2; + int src_idx0 = base_idx + i + select_half + 0; + int src_idx2 = src_idx0 + 2; + + auto tmp0 = tmp_zmm(src_idx0); + auto tmp1 = tmp_zmm(src_idx2); + auto src0 = src_zmm(src_idx0); + auto src2 = src_zmm(src_idx2); + + valignd(tmp0, src0, src0, 0x2); + pf_src_t1(base_idx + 4 + i); + valignd(tmp1, src2, src2, 0xe); + pf_src_t0(base_idx + 4 + i); + vmovaps(src2 | k3333, tmp0); + vmovaps(src0 | kCCCC, tmp1); + } + + // swap 4 + for (int i = 0; i < 4; i++) { + int src_idx0 = base_idx + i; + int src_idx4 = src_idx0 + 4; + + auto tmp0 = tmp_zmm(src_idx0); + auto src0 = src_zmm(src_idx0); + auto src4 = src_zmm(src_idx4); + + vmovaps(tmp0, src0); + vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); + pf_tr_src_t1(base_idx / 2 + i); + vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); + pf_tr_src_t0(base_idx / 2 + i); + } + }; + + auto fixup16x16 = [=]() { + // swap 8 + for (int i = 0; i < 8; i++) { + auto tmp = tmp_zmm(i); + auto src0 = src_zmm(i); + auto src8 = src_zmm(8 + i); + vshuff64x2(tmp, src0, src8, 0x44); + store(tmp, i); + if (i % 2 == 0) { + pf_tr_src_t1(8 + i / 2); + pf_tr_src_t0(8 + i / 2); + } + } + + for (int i = 0; i < 8; i++) { + auto tmp = tmp_zmm(8 + i); + auto src0 = src_zmm(i); + auto src8 = src_zmm(8 + i); + vshuff64x2(tmp, src0, src8, 0xee); + store(tmp, 8 + i); + if (i % 2 == 0) { + pf_tr_src_t1(12 + i / 2); + pf_tr_src_t0(12 + i / 2); + } + } + }; + + transpose16x8(0); + transpose16x8(8); + fixup16x16(); +} + +void jit_trans_iw_ic_t::generate() { + preamble(); + + const int ic_block = conf_->ic_block; + const int iw = conf_->iw; + const int tr_iw = conf_->tr_iw; + const int transposes = utils::div_up(iw, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = iw - loop_iters * transpose_size; + + src_stride = ic_block * typesize; + assert(src_stride == 64); + tr_src_stride = tr_iw * typesize; + + bool nontemporal_stores = false; + enable_prefetch = iw > small_spatial ? 1 : 0; + + assert(transpose_size == ic_block); + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * typesize; + + const int left_pad = conf_->l_pad; + const int right_pad = tr_iw - iw - left_pad; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(k3333, 0x3333); // 0011001100110011 + kmovw(k5555, 0x5555); // 0101010101010101 + kmovw(kAAAA, 0xaaaa); // 1010101010101010 + kmovw(kCCCC, 0xcccc); // 1100110011001100 + kmovw(k0F0F, 0x0f0f); // 0000111100001111 + kmovw(kF0F0, 0xf0f0); // 1111000011110000 + + if (left_pad > 0 && loop_iters > 0) { + loop_iters--; + transpose(transpose_size, left_pad, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step + left_pad * typesize); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step + left_pad * typesize); + } + + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + if (transposes > 1) + transpose(tail, 0, right_pad, nontemporal_stores); + else + transpose(tail, left_pad, right_pad, nontemporal_stores); + + postamble(); +} + +struct jit_trans_iw_ic_int16_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_int16_t) + jit_trans_iw_ic_int16_t(const jit_conv_conf_t *conf): + jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + + enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t kFFFF = k1; + opmask_t k5555 = k2; + opmask_t kAAAA = k3; + opmask_t kAA = k4; + opmask_t k55 = k5; + opmask_t kCC = k6; + opmask_t k33 = k7; + opmask_t kTail = k1; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + reg64_t imm_addr64 = rbx; + + Xbyak::Zmm vidx1 = zmm31; + Xbyak::Zmm vidx2 = zmm30; + Xbyak::Zmm vidx3 = zmm29; + Xbyak::Zmm vidx4 = zmm28; + Xbyak::Zmm vidx5 = zmm27; + Xbyak::Zmm zmm_tmp = zmm26; + + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_iw_ic_int16_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto src_zmm = [=](int i) { + return Zmm(i); + }; + + auto src_ymm = [=](int i) { + assert(i >= 0 && i < 16); + return Ymm(i); + }; + + auto load_ymm = [=](int i) { + vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto store = [=](Zmm r, int i) { + + auto padding = [=] (Reg64 reg, int pad) { + kmovw(kTail, (1 << pad) - 1); + auto k = kTail; + auto base = reg; + base.setOpmaskIdx(k.getIdx(), true); + + auto zmm_zero = zmm_tmp; + vpxord(zmm_zero, zmm_zero, zmm_zero); + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, zmm_zero); + }; + + int store_tail = (nrows%2) ? nrows+1 : nrows; + + int store_pad = (l_pad%2) ? l_pad/2 + 1 : l_pad/2; + mov(reg_tr_src_tmp, reg_tr_src); + if (l_pad > 0) { + padding(reg_tr_src, store_pad); + add(reg_tr_src_tmp, l_pad * typesize); + } + if (r_pad > 0) { + store_pad = (r_pad%2) ? r_pad/2 + 1 : r_pad/2; + int addr_shift = (r_pad%2) ? 1 : 0; + add(reg_tr_src_tmp, (nrows - addr_shift) * typesize); + padding(reg_tr_src_tmp, store_pad); + } + + mov(reg_tr_src_tmp, reg_tr_src); + add(reg_tr_src_tmp, l_pad * typesize); + + kmovw(kTail, (1 << store_tail/2) - 1); + auto k = kTail; + auto base = reg_tr_src_tmp; + base.setOpmaskIdx(k.getIdx(), true); + + auto addr = EVEX_compress_addr(base, i * tr_src_stride); + vmovups(addr, r); + + }; + + kmovw(kFFFF, 0xffff); + //all loads + for (int i=0; i<16; i++){ + vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); + } + + for (int i = 0; i < nrows/2; i++) { + auto src0 = src_ymm(2*i); + auto src1 = src_ymm(2*i+1); + auto zmm_src0 = src_zmm(2*i); + load_ymm(2*i); + + vpunpcklwd(src1, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vpunpckhwd(src0, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vinserti64x4(zmm_src0, zmm_src0, src1, 1); + vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); + } + + // for odd numbers we need to mix row with zeroes + if (nrows%2) { + int i = nrows-1; + auto src0 = src_ymm(i); + auto src1 = src_ymm(i+1); //zero + + auto zmm_src0 = src_zmm(i); + vpxor(src1, src1, src1); + + load_ymm(i); + vpunpckhwd(src0, src0, src1); + vinserti64x4(zmm_tmp, zmm_tmp, src0, 0); + vpxor(src0, src0, src0); + load_ymm(i); + vpunpcklwd(src1, src0, src1); + vinserti64x4(zmm_tmp, zmm_tmp, src1, 1); + vpxord(zmm_src0, zmm_src0, zmm_src0); + vmovups(zmm_src0, zmm_tmp); + vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); + } + + // swap 1 + for (int i=0; i<4; i++) { + auto zmm0 = src_zmm(4*i); + auto zmm1 = src_zmm(4*i+2); + auto tmp0 = src_zmm(4*i+1); + auto tmp1 = src_zmm(4*i+3); + + vmovups(tmp0, zmm0); + vmovups(tmp1, zmm1); + + vpermps(tmp0 | kAAAA, vidx3, zmm1); + vpermps(tmp1 | k5555, vidx3, zmm0); + } + // swap 2 + int base_idx; + base_idx=0; + for (int i=0; i<2; i++) { + auto zmm0 = src_zmm(base_idx+2*i+1); + auto zmm1 = src_zmm(base_idx+2*i+5); + + auto tmp0 = src_zmm(base_idx+2*i); + auto tmp1 = src_zmm(base_idx+2*i+4); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kAA, vidx2, zmm1); + vpermpd(tmp1 | k55, vidx2, zmm0); + } + base_idx=8; + for (int i=0; i<2; i++) { + auto zmm0 = src_zmm(base_idx+2*i+1); + auto zmm1 = src_zmm(base_idx+2*i+5); + + auto tmp0 = src_zmm(base_idx+2*i); + auto tmp1 = src_zmm(base_idx+2*i+4); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kAA, vidx2, zmm1); + vpermpd(tmp1 | k55, vidx2, zmm0); + } + + // swap 3 + for (int i=0; i<4; i++) { + auto zmm0 = src_zmm(2*i); + auto zmm1 = src_zmm(2*i+8); + + auto tmp0 = src_zmm(2*i+1); + auto tmp1 = src_zmm(2*i+9); + + vmovupd(tmp0, zmm0); + vmovupd(tmp1, zmm1); + + vpermpd(tmp0 | kCC, vidx1, zmm1); + vpermpd(tmp1 | k33, vidx1, zmm0); + } + + // all stores + for (int i=0; i<8; i++) + vextracti64x4(src_ymm(2*i), src_zmm(2*i+1), 1); + + store(src_zmm(1), 0); + store(src_zmm(0), 1); + store(src_zmm(3), 2); + store(src_zmm(2), 3); + store(src_zmm(9), 4); + store(src_zmm(8), 5); + store(src_zmm(11), 6); + store(src_zmm(10), 7); + store(src_zmm(5), 8); + store(src_zmm(4), 9); + store(src_zmm(7), 10); + store(src_zmm(6), 11); + store(src_zmm(13), 12); + store(src_zmm(12), 13); + store(src_zmm(15), 14); + store(src_zmm(14), 15); + +} + +void jit_trans_iw_ic_int16_t::generate() { + preamble(); + + alignas(64) static constexpr const int64_t idx1[8] + = { 2, 3, 0, 1, 6, 7, 4, 5 }; + alignas(64) static constexpr const int64_t idx2[8] + = { 1, 0, 3, 2, 5, 4, 7, 6 }; + alignas(64) static constexpr const int32_t idx3[16] + = { 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14 }; + alignas(64) static constexpr const int32_t idx4[16] + = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 }; + alignas(64) static constexpr const int32_t idx5[16] + = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 }; + + const int ic_block = conf_->ic_block; + const int iw = conf_->iw; + const int tr_iw = conf_->tr_iw; + const int transposes = utils::div_up(iw, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = iw - loop_iters * transpose_size; + + src_stride = ic_block * typesize; + tr_src_stride = tr_iw * typesize; + + bool nontemporal_stores = false; + enable_prefetch = iw > small_spatial ? 1 : 0; + + assert(transpose_size == ic_block); + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * typesize; + + const int left_pad = conf_->l_pad; + const int right_pad = tr_iw - iw - left_pad; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(kFFFF, 0xffff); + kmovw(k5555, 0x5555); + kmovw(kAAAA, 0xaaaa); + kmovw(kAA, 0xaa); + kmovw(k55, 0x55); + kmovw(kCC, 0xcc); + kmovw(k33, 0x33); + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast<size_t>(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { + mov(imm_addr64, reinterpret_cast<size_t>(addr)); + jit_generator::vmovdqa32(z, ptr[imm_addr64]); + }; + + vmovdqa64(vidx1, idx1); + vmovdqa64(vidx2, idx2); + vmovdqa32(vidx3, idx3); + vmovdqa32(vidx4, idx4); + vmovdqa32(vidx5, idx5); + + if (left_pad > 0 && loop_iters > 0) { + loop_iters--; + transpose(transpose_size, left_pad, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step + left_pad * typesize); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step + left_pad * typesize); + } + + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + if (transposes > 1) + transpose(tail, 0, right_pad, nontemporal_stores); + else + transpose(tail, left_pad, right_pad, nontemporal_stores); + + postamble(); + +} + +struct jit_trans_ow_oc_t: public jit_trans_dst_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_ow_oc_t) + jit_trans_ow_oc_t(const jit_conv_conf_t *conf): jit_trans_dst_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using reg64_t = const Xbyak::Reg64; + using reg32_t = const Xbyak::Reg32; + using opmask_t = const Xbyak::Opmask; + using zmm = const Xbyak::Zmm; + + enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 }; + int src_stride, tr_src_stride; + int tail; + bool enable_prefetch; + + opmask_t kFF = k1; + + zmm vidx1 = zmm31; + + reg64_t reg_src = r8; + reg64_t reg_tr_src = r9; + reg64_t reg_src_prf = r10; + reg64_t reg_tr_src_prf = r11; + reg64_t reg_loop = r12; + reg64_t reg_tr_src_tmp = r13; + reg32_t regw_tmp = r14d; + reg64_t imm_addr64 = rbx; + + void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); + void generate(); +}; + +void jit_trans_ow_oc_t::transpose(int nrows, int l_pad, int r_pad, + bool nontemporal_stores) { + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 16, "Unsupported transpose size"); + if (!nrows) + return; + + auto src_zmm = [=](int i) { + return Zmm(i); + }; + + auto src_ymm = [=](int i) { + assert(i >= 0 && i < 16); + return Ymm(i); + }; + + auto load_ymm = [=](int i) { + vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + + auto store = [=](Zmm r, int i) { + auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride); + if (nontemporal_stores) + vmovntps(addr, r); + else + vmovups(addr, r); + }; + + for (int i = 0; i < nrows/2; i++) { + auto src0 = src_ymm(2*i); + auto src1 = src_ymm(2*i+1); + auto zmm_src0 = src_zmm(2*i); + load_ymm(2*i); + vpunpcklwd(src1, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vpunpckhwd(src0, src0, + EVEX_compress_addr(reg_src, (2*i+1) * src_stride)); + vinserti64x4(zmm_src0, zmm_src0, src1, 1); + vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); + store(zmm_src0, 2*i); + } + if (r_pad > 0) { + auto src0 = src_ymm(nrows-1); + auto src1 = src_ymm(nrows); + auto zmm_src0 = src_zmm(30); + load_ymm(nrows-1); + + vpxor(src1, src1, src1); + vpunpckhwd(src1, src0, src1); + vinserti64x4(zmm_src0, zmm_src0, src1, 0); + vpxor(src1, src1, src1); + vpunpcklwd(src0, src0, src1); + vinserti64x4(zmm_src0, zmm_src0, src0, 1); + vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); + store(zmm_src0, nrows-1); + } +} + +void jit_trans_ow_oc_t::generate() { + preamble(); + + alignas(64) static constexpr const int64_t idx1[8] + = { 4, 5, 0, 1, 6, 7, 2, 3 }; + + const int oc_block = conf_->oc_block; + const int ow = conf_->ow; + const int transposes = utils::div_up(ow, transpose_size); + int loop_iters = nstl::max(0, transposes - 1); + tail = ow - loop_iters * transpose_size; + + src_stride = oc_block * typesize; + tr_src_stride = oc_block * typesize; + + bool nontemporal_stores = false; + enable_prefetch = ow > small_spatial ? 1 : 0; + + const int src_step = oc_block * transpose_size * typesize; + const int tr_src_step = oc_block * transpose_size * typesize; + const int right_pad = ow % 2; + + mov(reg_src, ptr [param1 + GET_OFF(src)]); + mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]); + mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]); + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + kmovw(kFF, 0xFF); + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast<size_t>(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + vmovdqa64(vidx1, idx1); + if (loop_iters) { + mov(reg_loop, loop_iters); + Label loop; + L(loop); { + transpose(transpose_size, 0, 0, nontemporal_stores); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, 1); + jnz(loop); + } + } + transpose(tail, 0, right_pad, nontemporal_stores); + + postamble(); +} + +struct jit_trans_iw_x4_4x_t: public jit_trans_src_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_x4_4x_t) + + jit_trans_iw_x4_4x_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) { + generate(); + ker_ = (decltype(ker_))this->getCode(); + } + + void generate(); + enum { typesize = (int)sizeof(float) }; +}; + +/** @brief transposition of the form [:][iw/4][4] -> [:][4][iw/4] + * required for 1st 4fma backward by weights convolution */ +void jit_trans_iw_x4_4x_t::generate() { + using namespace utils; + + /* TODO: put into code */ + static int mask[16] = { + 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, }; + + const auto &c = *conf_; + const int simd_w = cpu_isa_traits<avx512_common>::vlen / typesize; + const int niters = c.tr_ld / simd_w; + + assert(niters <= 4); /* [bwd_w:tr_src:r1] */ + + Reg64 reg_ptr_src = r8; + Reg64 reg_ptr_tr_src = r9; + + Reg64 reg_ih = rax; + Reg64 reg_ih_end = rbx; + + Reg64 reg_nthr_oc_b = rsi; + Reg64 reg_ptr_tr_src_bctx = abi_not_param1; + + Reg64 reg_tmp = rdx; + + Zmm vmsk = Zmm(31); + Opmask kmsk = k7; + + auto emit_tr_sync = [&]() { + simple_barrier::generate(*this, reg_ptr_tr_src_bctx, reg_nthr_oc_b); + }; + + auto emit_tr_iw = [&]() { + auto vreg = [](int iter, int i) { + assert(4 * iter + i < 24); + return Zmm(4 * iter + i); + }; + auto vtmp = [](int i) { return Zmm(24 + i); }; + + auto emit_load = [&](int iter) { + for (int i = 0; i < 4; ++i) { + auto v = vreg(iter, i); + const int off = (iter * 4 + i) * simd_w; + + if (off + simd_w <= c.iw) + vmovups(v, ptr[reg_ptr_src + off * typesize]); + else if (off < c.iw) + vmovups(v | kmsk | T_z, ptr[reg_ptr_src + off * typesize]); + else + vpxord(v, v, v); + } + }; + + auto emit_tr = [&](int iter) { + for (int i = 0; i < 4; ++i) + vpermps(vreg(iter, i), vmsk, vreg(iter, i)); + + vshuff32x4(vtmp(0), vreg(iter, 0), vreg(iter, 1), 0x88); + vshuff32x4(vtmp(1), vreg(iter, 0), vreg(iter, 1), 0xdd); + vshuff32x4(vtmp(2), vreg(iter, 2), vreg(iter, 3), 0x88); + vshuff32x4(vtmp(3), vreg(iter, 2), vreg(iter, 3), 0xdd); + + vshuff32x4(vreg(iter, 0), vtmp(0), vtmp(2), 0x88); + vshuff32x4(vreg(iter, 2), vtmp(0), vtmp(2), 0xdd); + vshuff32x4(vreg(iter, 1), vtmp(1), vtmp(3), 0x88); + vshuff32x4(vreg(iter, 3), vtmp(1), vtmp(3), 0xdd); + }; + + auto emit_store = [&]() { + for (int i = 0; i < 4; ++i) { + for (int iter = 0; iter < niters; ++iter) { + const size_t off = i * c.tr_ld + iter * simd_w; + vmovups(ptr[reg_ptr_tr_src + off * typesize], vreg(iter, i)); + } + } + }; + + for (int iter = 0; iter < niters; ++iter) + emit_load(iter); + + for (int iter = 0; iter < niters; ++iter) + emit_tr(iter); + + emit_store(); + }; + + preamble(); + + mov(reg_ptr_src, ptr[abi_param1 + GET_OFF(src)]); + mov(reg_ptr_tr_src, ptr[abi_param1 + GET_OFF(tr_src)]); + + mov(reg_nthr_oc_b.cvt32(), ptr[abi_param1 + GET_OFF(nthr_oc_b)]); + mov(reg_ih.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_start)]); + mov(reg_ih_end.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_end)]); + mov(reg_ptr_tr_src_bctx, ptr[abi_param1 + GET_OFF(tr_src_bctx)]); + + emit_tr_sync(); + + Label l_ih_loop, l_tr_done; + cmp(reg_ih, reg_ih_end); + je(l_tr_done, T_NEAR); + + mov(reg_tmp, (size_t)&mask[0]); + vmovups(vmsk, ptr[reg_tmp]); + + if (c.iw % simd_w) { + const char load_mask = (1 << (c.iw % simd_w)) - 1; + mov(reg_tmp, load_mask); + kmovw(kmsk, reg_tmp.cvt32()); + } + + /* src += ih_start * c.iw; */ + imul(reg_tmp, reg_ih, c.iw * typesize); + add(reg_ptr_src, reg_tmp); + /* tr_src += ih_start * c.stride_w * c.tr_ld; */ + imul(reg_tmp, reg_ih, c.stride_w * c.tr_ld * typesize); + add(reg_ptr_tr_src, reg_tmp); + + L(l_ih_loop); { + emit_tr_iw(); + + add(reg_ptr_src, c.iw * typesize); + add(reg_ptr_tr_src, c.stride_w * c.tr_ld * typesize); + + inc(reg_ih); + cmp(reg_ih, reg_ih_end); + jl(l_ih_loop, T_NEAR); + } + + L(l_tr_done); + + emit_tr_sync(); + + postamble(); +} + +/* +// ------------------------------------------------- +// jit_transpose4x16_src +// ------------------------------------------------- +*/ + +void jit_transpose4x16_src::transpose(int nrows) +{ + assert(nrows >= 0 && nrows <= transpose_size); + static_assert(transpose_size == 4, "Unsupported transpose size"); + if (!nrows) + return; + + auto pf_src_t0 = [=](int i) { + if (tparams->src_pf0_distance) + prefetcht0(EVEX_compress_addr( + reg_src, (tparams->src_pf0_distance + i) * src_stride)); + }; + + auto pf_tr_src_t0 = [=](int i) { + if (tparams->tr_src_pf0_distance) + prefetcht0(EVEX_compress_addr(reg_tr_src, + (tparams->tr_src_pf0_distance + i) * src_stride)); + }; + + auto pf_src_t1 = [=](int i) { + if (tparams->src_pf1) + prefetcht1(EVEX_compress_addr(reg_src_prf, i * src_stride)); + }; + + auto pf_tr_src_t1 = [=](int i) { + if (tparams->tr_src_pf1) + prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, i * tr_src_stride)); + }; + + auto src_zmm = [=](int i) { + assert(i >= 0 && i < 4); + return Zmm(i); + }; + + auto tmp_zmm = [=](int i) { + assert(i >= 0 && i < 4); + return Zmm(4 + i); + }; + + auto load = [=](int i) { + vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); + }; + + auto store = [=](Zmm r, int i) { + vmovups(EVEX_compress_addr(reg_tr_src, i * tr_src_stride), r); + }; + + auto tmp0 = tmp_zmm(0); + auto tmp1 = tmp_zmm(1); + auto tmp2 = tmp_zmm(2); + auto tmp3 = tmp_zmm(3); + + auto src0 = src_zmm(0); + auto src1 = src_zmm(1); + auto src2 = src_zmm(2); + auto src3 = src_zmm(3); + for (int i = 0; i < nrows; i++) { + load(i); + } + + for (size_t i = nrows; i < 4; i++) { + vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); + } + + vmovupd(tmp0, src0); + vmovupd(tmp1, src1); + pf_src_t0(0); + vpermpd(tmp0 | kF0, vidx01, src2); + vpermpd(tmp1 | kF0, vidx01, src3); + + valignd(src0, src0, src0, 8); + valignd(src1, src1, src1, 8); + pf_src_t0(1); + vmovupd(tmp2, src0); + vmovupd(tmp3, src1); + pf_src_t0(2); + vpermpd(tmp2 | kF0, vidx10, src2); + vpermpd(tmp3 | kF0, vidx10, src3); + pf_src_t0(3); + + vmovupd(src0, tmp0); + pf_src_t1(0); + vmovupd(src1, tmp2); + pf_src_t1(1); + vmovupd(src2, tmp1); + pf_src_t1(2); + vmovupd(src3, tmp3); + pf_src_t1(3); + vpermpd(src0 | kCC, vidx1, tmp1); + vpermpd(src1 | kCC, vidx1, tmp3); + pf_tr_src_t0(0); + vpermpd(src2 | k33, vidx1, tmp0); + vpermpd(src3 | k33, vidx1, tmp2); + pf_tr_src_t0(1); + + vmovupd(tmp0, src0); + vmovupd(tmp1, src2); + pf_tr_src_t0(2); + vmovupd(tmp2, src1); + vmovupd(tmp3, src3); + pf_tr_src_t0(3); + vpermps(tmp0 | kFFFF, vidxP, src0); + pf_tr_src_t1(0); + vpermps(tmp1 | kFFFF, vidxP, src2); + pf_tr_src_t1(1); + vpermps(tmp2 | kFFFF, vidxP, src1); + pf_tr_src_t1(3); + vpermps(tmp3 | kFFFF, vidxP, src3); + pf_tr_src_t1(4); + + store(tmp0, 0); + store(tmp1, 1); + store(tmp2, 2); + store(tmp3, 3); +} + +alignas(64) static constexpr const int64_t idx01[8] + = { 0, 0, 0, 0, 0, 1, 2, 3 }; +alignas(64) static constexpr const int64_t idx10[8] + = { 0, 0, 0, 0, 4, 5, 6, 7 }; +alignas(64) static constexpr const int64_t idx1[8] = { 2, 3, 0, 1, 6, 7, 4, 5 }; +alignas(64) static constexpr const int32_t idxP[16] + = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + +void jit_transpose4x16_src::generate() +{ + preamble(); + + const int ic_block = params->ic_block; + const int is = params->is; + int tail = is % transpose_size; + + src_stride = ic_block * typesize; + assert(src_stride == 64); + tr_src_stride = ic_block * typesize; + + const int src_step = ic_block * transpose_size * typesize; + const int tr_src_step = ic_block * transpose_size * typesize; + +#define GET_TR_OFF(x) offsetof(jit_src_transpose_s, x) + mov(reg_loop, ptr[param1 + GET_TR_OFF(size)]); + mov(reg_src, ptr[param1 + GET_TR_OFF(src)]); + mov(reg_tr_src, ptr[param1 + GET_TR_OFF(tr_src)]); + mov(reg_src_prf, ptr[param1 + GET_TR_OFF(src_prf)]); + mov(reg_tr_src_prf, ptr[param1 + GET_TR_OFF(tr_src_prf)]); +#undef GET_TR_OFF + + auto kmovw = [=](Opmask k, unsigned w) { + mov(regw_tmp, w); + jit_generator::kmovw(k, regw_tmp); + }; + + auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { + mov(imm_addr64, reinterpret_cast<size_t>(addr)); + jit_generator::vmovdqa64(z, ptr[imm_addr64]); + }; + + auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { + mov(imm_addr64, reinterpret_cast<size_t>(addr)); + jit_generator::vmovdqa32(z, ptr[imm_addr64]); + }; + + kmovw(kF0, 0xf0); // 11110000 + kmovw(kCC, 0xcc); // 11001100 + kmovw(k33, 0x33); // 00110011 + kmovw(kFFFF, 0xffff); // 1111111111111111 + + vmovdqa64(vidx01, idx01); + vmovdqa64(vidx10, idx10); + vmovdqa64(vidx1, idx1); + vmovdqa32(vidxP, idxP); + + Label loop_label; + Label tail_label; + + cmp(reg_loop, transpose_size); + jl(tail_label, T_NEAR); + + L(loop_label); + { + transpose(transpose_size); + add(reg_src, src_step); + add(reg_tr_src, tr_src_step); + add(reg_src_prf, src_step); + add(reg_tr_src_prf, tr_src_step); + sub(reg_loop, transpose_size); + cmp(reg_loop, transpose_size); + jge(loop_label, T_NEAR); + } + L(tail_label); + transpose(tail); + + postamble(); +} + +jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf) { + if (conf->ver == ver_4fma && !conf->is_1stconv) + return new jit_trans_iw_ic_t(conf); + if (conf->ver == ver_4fma && conf->is_1stconv) + return new jit_trans_iw_x4_4x_t(conf); + assert(!"unsupported configuration"); + return nullptr; +} + +jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf) { + assert(!"unsupported configuration"); + return nullptr; +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp new file mode 100644 index 0000000000..565e97e4fc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp @@ -0,0 +1,145 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_TRANSPOSE_SRC_HPP +#define CPU_JIT_TRANSPOSE_SRC_HPP + +#include "cpu_barrier.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_trans_src_t { + struct ctx_t { + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; + + /* 1st conv 4fma: backward by weights */ + int nthr_oc_b; /* number of threads process given src image */ + int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ + simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ + }; + + jit_trans_src_t(const jit_conv_conf_t *conf) + : conf_(conf), ker_(nullptr) {} + virtual ~jit_trans_src_t() {} + + void operator()(const ctx_t *ctx) + { assert(ker_); ker_(ctx); } + + const jit_conv_conf_t *conf_; + void (*ker_)(const ctx_t *); +}; + +struct jit_src_transpose_s { + size_t size; + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; +}; + +struct jit_trans_dst_t { + struct ctx_t { + const void *src; + const void *tr_src; + const void *src_prf; + const void *tr_src_prf; + + /* 1st conv 4fma: backward by weights */ + int nthr_oc_b; /* number of threads process given src image */ + int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */ + simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */ + }; + + jit_trans_dst_t(const jit_conv_conf_t *conf) + : conf_(conf), ker_(nullptr) {} + virtual ~jit_trans_dst_t() {} + + void operator()(const ctx_t *ctx) + { assert(ker_); ker_(ctx); } + + const jit_conv_conf_t *conf_; + void (*ker_)(const ctx_t *); +}; + +struct jit_transpose4x16_src_t { + int src_pf0_distance; + int tr_src_pf0_distance; + bool src_pf1; + bool tr_src_pf1; +}; + +struct jit_transpose4x16_src : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src) + + jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams, + jit_transpose4x16_src_t *tparams_) + : params(aparams), tparams(tparams_) + { + this->generate(); + jit_ker = (decltype(jit_ker))this->getCode(); + } + + const jit_1x1_conv_conf_t *params; + const jit_transpose4x16_src_t *tparams; + void (*jit_ker)(jit_src_transpose_s *); + + void operator()(jit_src_transpose_s *arg) { jit_ker(arg); } + + static const int transpose_size = 4; +private: + static const int typesize = sizeof(float); + + int src_stride, tr_src_stride; + + Xbyak::Reg64 imm_addr64 = rbx; + + Xbyak::Opmask kF0 = k1; + Xbyak::Opmask kCC = k2; + Xbyak::Opmask k33 = k3; + Xbyak::Opmask kFFFF = k4; + + Xbyak::Zmm vidx01 = zmm31; + Xbyak::Zmm vidx10 = zmm30; + Xbyak::Zmm vidx1 = zmm29; + Xbyak::Zmm vidxP = zmm28; + + Xbyak::Reg64 reg_src = r8; + Xbyak::Reg64 reg_tr_src = r9; + Xbyak::Reg64 reg_src_prf = r10; + Xbyak::Reg64 reg_tr_src_prf = r11; + Xbyak::Reg64 reg_loop = r12; + Xbyak::Reg64 reg_tr_src_tmp = r13; + Xbyak::Reg32 regw_tmp = r14d; + + void transpose_block(int ur, int nrows); + void transpose(int nrows); + void generate(); +}; + +jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf); +jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp new file mode 100644 index 0000000000..53313f9f01 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp @@ -0,0 +1,327 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_1x1_CONV_UTILS_HPP +#define JIT_UNI_1x1_CONV_UTILS_HPP + +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; + +struct reduce_to_unit_stride_t { + convolution_desc_t conv_d_; + bool reduce_src_; + size_t space_per_thread_; +}; + +/* 1x1-kernel does not support non-unit strides so far, so the idea is: + * - for fwd or bwd_weights: to copy src to a scratch memory (with strides + * equal to 1) and then call the kernel + * - for bwd_data: reduce the problem to the one with unit stride by + * performing computations in a scratch memory (with strides equal to 1) + * and then copy the result to diff_src */ +template <typename conv_pd_t> +inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d, + const memory_desc_t *&src_d, const memory_desc_t *dst_d) { + const bool is_bwd_data = self->desc()->prop_kind + == prop_kind::backward_data; + + const int ndims = src_d->ndims; + const auto dat_tag = ndims == 3 + ? memory_desc_wrapper(dst_d).matches_one_of_tag( + format_tag::nCw8c, format_tag::nCw16c) + : memory_desc_wrapper(dst_d).matches_one_of_tag( + format_tag::nChw8c, format_tag::nChw16c); + + bool rtus_applicable = true + && utils::pick(ndims - 3, + (conv_d->strides[0] != 1 && !one_of(conv_d->src_desc.data_type, + data_type::s32)), + (conv_d->strides[0] != 1 || conv_d->strides[1] != 1)) + && dat_tag != format_tag::undef; + for (int d = 2; d < ndims; ++d) { + /* TODO: relax these conditions (by improving reducer) */ + rtus_applicable = rtus_applicable + && conv_d->padding[0][d - 2] == 0 + && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d]; + } + + if (rtus_applicable) { + self->rtus_.reduce_src_ = true; + conv_d = &(self->rtus_.conv_d_ = *conv_d); + self->rtus_.conv_d_.strides[0] = 1; + if (ndims == 4) + self->rtus_.conv_d_.strides[1] = 1; + utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2); + if (ndims == 4) + utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2); + const int ic = src_d->dims[1]; + if (is_bwd_data) { + src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d); + self->rtus_.conv_d_.diff_src_desc.dims[1] = ic; + memory_desc_wrapper::compute_blocking( + self->rtus_.conv_d_.diff_src_desc, dat_tag); + } else { + data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type; + src_d = &(self->rtus_.conv_d_.src_desc = *dst_d); + self->rtus_.conv_d_.src_desc.dims[1] = ic; + self->rtus_.conv_d_.src_desc.data_type = data_type; + memory_desc_wrapper::compute_blocking( + self->rtus_.conv_d_.src_desc, dat_tag); + } + } +} + +template <typename conv_pd_t> +inline void rtus_prepare_space_info(conv_pd_t *self, + memory_tracking::registrar_t &scratchpad) { + const auto &jcp = self->jcp_; + + const int max_threads = mkldnn_get_max_threads(); + const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind, + jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking); + size_t typesize = types::data_type_size( + conv_prop_invariant_src_d(self->desc())->data_type); + + self->rtus_.space_per_thread_ = factor * jcp.is * jcp.ic_block; + scratchpad.book(memory_tracking::names::key_conv_rtus_space, + typesize * max_threads * self->rtus_.space_per_thread_); +} + +template <cpu_isa_t isa> +struct rtus_driver_t: public jit_generator { + + struct call_params_t { + const void *ws; /* reduced image (w/ strides = 1) */ + const void *src; /* source image (w/ non-unit strides) */ + size_t icb; + size_t os; + size_t iw_start; + }; + + void (*ker_)(const call_params_t *p); + + DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional<isa == avx2, Xbyak::Ymm, + Xbyak::Zmm>::type; + + Xbyak::Reg64 reg_ws = abi_param1; + Xbyak::Reg64 reg_src = abi_not_param1; + Xbyak::Reg64 reg_icb = rdx; + Xbyak::Reg64 reg_os = r11; + Xbyak::Reg64 reg_iw_start = r8; + + Xbyak::Reg64 reg_cur_os = rax; + Xbyak::Reg64 reg_cur_iw = r9; + Xbyak::Reg64 reg_cur_src = r10; + + int iw_, stride_w_; + int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_; + bool src_to_ws_; + size_t typesize_; + Vmm reg_zero; + Vmm reg_v; + + rtus_driver_t(int iw, int stride_w, int src_step_h, + int src_step_icb, int ws_step_icb, bool src_to_ws, size_t typesize) + : iw_(iw), stride_w_(stride_w), src_step_h_(src_step_h) + , src_step_icb_(src_step_icb), ws_step_icb_(ws_step_icb) + , src_to_ws_(src_to_ws), typesize_(typesize) + { + using namespace Xbyak; + vlen_ = cpu_isa_traits<isa>::vlen; + vlen_shift_ = cpu_isa_traits<isa>::vlen_shift; + if (typesize_ == 2) { + vlen_ /= 2; + vlen_shift_--; + } + + reg_zero = Vmm(0); + reg_v = Vmm(1); + + generate(); + } + + void loop_is() { + using namespace Xbyak; + + mov(reg_cur_src, reg_src); + mov(reg_cur_iw, reg_iw_start); + mov(reg_cur_os, reg_os); + + Label is_loop, skip_h_step; + L(is_loop); + + if (src_to_ws_) { + vmovups(reg_v, ptr[reg_cur_src]); + vmovups(ptr[reg_ws], reg_v); + } else { + vmovups(reg_v, ptr[reg_ws]); + vmovups(ptr[reg_cur_src], reg_v); + for (int w = 1; w < stride_w_; ++w) + vmovups(ptr[reg_cur_src + w * vlen_], reg_zero); + } + + add(reg_ws, vlen_); + + add(reg_cur_iw, stride_w_); + add(reg_cur_src, stride_w_ * vlen_); + + cmp(reg_cur_iw, iw_); + jl(skip_h_step); + /* for 1d convolution the loop over h should be skipped */ + if (src_step_icb_ == iw_) jmp(skip_h_step); + + if (src_to_ws_) { + add(reg_cur_src, (src_step_h_ - iw_) * vlen_); + } else { + Xbyak::Reg64 reg_cur_src_fin = reg_cur_iw; /* just reuse */ + mov(reg_cur_src_fin, reg_cur_src); + add(reg_cur_src_fin, (src_step_h_ - iw_) * vlen_); + Label ih_loop; + L(ih_loop); + + for (int w = 0; w < stride_w_; ++w) + vmovups(ptr[reg_cur_src + w * vlen_], reg_zero); + + add(reg_cur_src, stride_w_ * vlen_); + cmp(reg_cur_src, reg_cur_src_fin); + jl(ih_loop); + } + xor_(reg_cur_iw, reg_cur_iw); + + L(skip_h_step); + + sub(reg_cur_os, vlen_); + jnz(is_loop); + + /* restore dst */ + sub(reg_ws, reg_os); + } + + void generate() { + using namespace Xbyak; + assert(isa == avx2 || isa == avx512_common + || isa == avx512_core || isa == avx512_mic); + +#if defined(_WIN32) + assert(reg_src == abi_not_param1 && abi_not_param1 == rdi); + push(rdi); +#endif + +#define READ_PARAM(what) \ + mov(reg_ ## what, ptr[abi_param1 + offsetof(call_params_t, what)]) + READ_PARAM(src); + READ_PARAM(icb); + READ_PARAM(os); + READ_PARAM(iw_start); + + assert(reg_ws == abi_param1); + READ_PARAM(ws); /* reg_ws should always be read the last */ +#undef READ_PARAM + + shl(reg_os, vlen_shift_); + + if (!src_to_ws_) + uni_vpxor(reg_zero, reg_zero, reg_zero); + + Label icb_loop; + L(icb_loop); + + loop_is(); + + add(reg_ws, ws_step_icb_ * vlen_); + add(reg_src, src_step_icb_ * vlen_); + + dec(reg_icb); + jnz(icb_loop, T_NEAR); + +#if defined(_WIN32) + pop(rdi); +#endif + + uni_vzeroupper(); + ret(); + this->ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>( + this->getCode())); + } +}; + +template <cpu_isa_t isa, typename conv_t> +inline void init_rtus_driver(conv_t *self) { + const auto &conf = *self->pd(); + if (!conf.rtus_.reduce_src_) return; + + const auto &cd = *conf.desc(); + const int ndims = conf.ndims(); + const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0]; + const int stride_w = cd.strides[ndims - 3]; + + const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data; + const auto &src_d = is_bwd_data ? *conf.diff_src_md() : *conf.src_md(); + + const int ih = ndims == 3 ? 1 : src_d.dims[2]; + const int iw = src_d.dims[ndims - 1]; + + const int src_step_h = stride_h * iw; + const int src_step_icb = ih * iw; + const int ws_step_icb = conf.jcp_.is; + const bool src_to_ws = !is_bwd_data; + const size_t typesize = types::data_type_size( + conv_prop_invariant_src_d(self->pd()->desc())->data_type); + + self->rtus_driver_ = new rtus_driver_t<isa>(iw, stride_w, src_step_h, + src_step_icb, ws_step_icb, src_to_ws, typesize); +} + +inline int best_divider(int value, int min_divider, int max_divider, + bool find_max, int step = 1) +{ + max_divider = nstl::max(1, nstl::min(max_divider, value)); + min_divider = nstl::max(1, nstl::min(min_divider, max_divider)); + + auto loss_ratio = [](int total, int chunk) + { return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk); }; + + float min_loss = FLT_MAX; + int x_divider = max_divider; + for (int divider = max_divider; divider >= min_divider; divider -= step) { + const float loss = loss_ratio(value, divider); + if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) { + min_loss = loss; + x_divider = divider; + } + } + return x_divider; +} + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp new file mode 100644 index 0000000000..72fe3a8109 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp @@ -0,0 +1,1407 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "jit_uni_batch_normalization.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { + +using namespace memory_tracking::names; + +using namespace Xbyak; +namespace barrier = simple_barrier; + +typedef float data_t; + +template <cpu_isa_t isa> +struct jit_bnorm_t: public jit_generator { + struct call_params_t { + // keep all sizes at 8 bytes -- jit code expects this + size_t N_ithr, N_nthr; + size_t coff_max, soff_max; + size_t mb_stride_Bc, spat_size, spat_size_loc; + size_t S_s, S_tail; + size_t is_cblk_tail; + data_t chan_size, eps, one; + const data_t *scale_shift; + const data_t *mean, *var; + const data_t *diff_scale_shift; + const data_t *src, *dst; + const data_t *diff_src, *diff_dst; + const data_t *rbuf1, *rbuf2; + const uint8_t *ws; + barrier::ctx_t *barrier; + }; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t) + + /* cpu specific part */ + using Vmm = typename utils::conditional3<isa == sse42, Xmm, + isa == avx2, Ymm, Zmm>::type; + const AddressFrame &vmmword = (isa == sse42) ? xword : + (isa == avx2) ? yword : zword; + + const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen; + + const batch_normalization_pd_t *bdesc_; + bool is_spatial_thr_; + + void (*ker)(const call_params_t *); + void operator()(const call_params_t *p) { (*ker)(p); } + + Reg64 reg_param = abi_param1; + + Reg64 reg_scale_shift = rbx; + Reg64 reg_rbuf1 = abi_not_param1; + Reg64 reg_rbuf2 = rdx; + + Reg64 reg_mean = rbp; + Reg64 reg_var = reg_param; + Reg64 reg_diff_scale_shift = rax; + + Reg64 reg_coff = r8; + Reg64 reg_coff_max = r9; + Reg64 reg_soff = r10; + Reg64 reg_soff_max = r11; + Reg64 reg_ctr = r12; + Reg64 reg_roff = r13; + + Reg64 reg_mb_stride_Bc = r14; + + Reg64 reg_src = r15; + Reg64 reg_diff_src = reg_rbuf1; + Reg64 reg_dst = rsi; + Reg64 reg_diff_dst = reg_dst; + + Reg64 reg_tmp_off = reg_roff; + + // Reuse loop counters + Reg64 reg_bar = reg_coff; + Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff + Reg64 reg_tmp = reg_ctr; + + // Relu section + bool with_relu, with_relu_inf_only; + Vmm vzero; // is_fwd() ? vdiff_beta : vbeta + Reg64 reg_ws = reg_roff; + Label l_relu_mask_avx2; + Opmask kstore_mask = Opmask(1); + + // channel tail processing + Opmask ktail_mask = Opmask(2); + + size_t unroll_blocks; + size_t unroll_regs; + Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5); + Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6); + Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7); + Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8); + Vmm vone = Vmm(isa == avx512_common ? 24 : 9); + Vmm vmean = Vmm(isa == avx512_common ? 25 : 10); + Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11); + Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12); + Vmm veps = Vmm(isa == avx512_common ? 28 : 13); + Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14); + Vmm vtail_mask = Vmm(isa == avx512_common ? 30 : 15); + + size_t t0_pf_offt; + size_t t1_pf_offt; + size_t spat_size; + size_t chan_data_offt; + + enum { + stack_off_N_nthr = 0, + stack_off_N_ithr = 8, + stack_off_src = 16, + stack_off_dst = 24, + stack_off_diff_src = 32, + stack_off_diff_dst = 40, + stack_off_diff_scale_shift = 48, + stack_off_ws = 56, + stack_off_barrier = 64, + stack_off_spat_size_loc = 72, + stack_off_s_s = 80, + stack_off_s_tail = 88, + stack_off_is_cblk_tail = 96, + stack_size_required = 104, + }; + + bool is_c_padded() const { + const memory_desc_wrapper data_d(bdesc_->src_md()); + return bdesc_->C() != data_d.padded_dims()[1]; + } + + void compute_static_strides() { + spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H(); + chan_data_offt = bdesc_->C() * sizeof(data_t); + + if (isa == avx512_mic) { + t0_pf_offt = 4096; + t1_pf_offt = 0; + } else { + t0_pf_offt = 0; + t1_pf_offt = 0; + } + } + + void load_common_params() { +# define PARAM_OFF(x) offsetof(call_params_t, x) + mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]); + if (bdesc_->is_bwd()) + mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]); + mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]); + mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]); + mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]); + shl(reg_coff_max, 2); + shl(reg_soff_max, 2); + shl(reg_mb_stride_Bc, 2); + + mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]); + mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]); + + uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]); + uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]); + uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]); + + mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]); + mov(ptr[rsp + stack_off_N_nthr], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]); + mov(ptr[rsp + stack_off_N_ithr], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]); + mov(ptr[rsp + stack_off_src], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]); + mov(ptr[rsp + stack_off_dst], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]); + mov(ptr[rsp + stack_off_diff_src], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]); + mov(ptr[rsp + stack_off_diff_dst], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]); + mov(ptr[rsp + stack_off_ws], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]); + mov(ptr[rsp + stack_off_barrier], reg_tmp); + if (is_spatial_thr_) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]); + mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]); + mov(ptr[rsp + stack_off_s_s], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]); + mov(ptr[rsp + stack_off_s_tail], reg_tmp); + } + if (is_c_padded()) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]); + mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp); + } + + if (bdesc_->is_fwd()) { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); + mov(reg_var, reg_tmp); + } else { + mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]); + mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp); + mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]); + mov(reg_var, reg_tmp); + } +# undef PARAM_OFF + } + + void prepare_tail_mask_avx512_common() { + if (!is_c_padded()) return; + + const int tail = bdesc_->C() % (int)(vlen / sizeof(float)); + const int mask = (1 << tail) - 1; + + Reg32 regw_tmp = reg_tmp.cvt32(); + mov(regw_tmp, mask); + kmovw(ktail_mask, regw_tmp); + } + + void prepare_tail_mask_avx2_common() { + if (!is_c_padded()) return; + + const int tail = bdesc_->C() % (int)(vlen / sizeof(float)); + static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0, 0, 0, 0, 0, 0, 0, 0}; + + mov(reg_tmp, reinterpret_cast<size_t>(&mask[8 - tail])); + vmovups(vtail_mask, ptr[reg_tmp]); + } + + void prepare_relu() { + with_relu = bdesc_->is_fwd() + ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu() + : bdesc_->fuse_bn_relu(); + with_relu_inf_only = with_relu && bdesc_->is_fwd() + && !(bdesc_->fuse_bn_relu() && bdesc_->is_training()); + + vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta; + if (with_relu) { + uni_vpxor(vzero, vzero, vzero); + if (!bdesc_->is_fwd() && isa == avx2) + prepare_l_relu_mask_avx2(); + } + } + + void prepare_l_relu_mask_avx2() { + Label l_mask_after; + jmp(l_mask_after); + align(32); + L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */ + for (int i = 0; i < 8; ++i) dd(1<<i); + L(l_mask_after); + } + + void fwd_process_relu_avx2(Vmm vdst, int offt, Vmm vstore_mask) { + Reg64 reg_store_mask = reg_diff_scale_shift; + shr(reg_soff, 5); + vcmpps(vstore_mask, vzero, vdst, _cmp_lt_os); + vmovmskps(reg_store_mask, vstore_mask); + mov(ptr[reg_ws + reg_soff + offt / (1 << 5)], reg_store_mask.cvt8()); + vblendvps(vdst, vzero, vdst, vstore_mask); + shl(reg_soff, 5); + } + + void fwd_process_relu_avx512_common(Vmm vdst, int offt) { + shr(reg_soff, 5); + vcmpps(kstore_mask, vzero, vdst, _cmp_lt_os); + kmovw(ptr[reg_ws + reg_soff + offt / (1 << 5)], kstore_mask); + vblendmps(vdst | kstore_mask, vzero, vdst); + shl(reg_soff, 5); + } + + void bwd_process_relu_avx2(Vmm vdiff_dst, int offt, Vmm vstore_mask) { + shr(reg_soff, 5); + vpbroadcastb(vstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]); + vpand(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]); + vpcmpeqd(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]); + vblendvps(vdiff_dst, vzero, vdiff_dst, vstore_mask); + shl(reg_soff, 5); + } + + void bwd_process_relu_avx512_common(Vmm vdiff_dst, int offt) { + shr(reg_soff, 5); + kmovw(kstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]); + vmovups(vdiff_dst | kstore_mask | T_z, vdiff_dst); + shl(reg_soff, 5); + } + + void uni_vmovups_tail_avx2_common(const Operand &dst, + const Operand &src, Label &l_ret) { + if (dst.isMEM()) { + vmaskmovps(dst.getAddress(), vtail_mask, Vmm(src.getIdx())); + } else { + vmaskmovps(Vmm(dst.getIdx()), vtail_mask, src.getAddress()); + } + jmp(l_ret); + } + + void uni_vmovups_tail_avx512_common(const Operand &dst, + const Operand &src, Label &l_ret) { + if (dst.isMEM()) + uni_vmovups(dst.getAddress() | ktail_mask | T_z, Vmm(src.getIdx())); + else + uni_vmovups(Vmm(dst.getIdx()) | ktail_mask | T_z, src.getAddress()); + + jmp(l_ret); + } + + void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) { + Label l_no_mask, l_ret; + + if (is_c_padded()) { + mov(reg_tmp, ptr[rsp + stack_off_is_cblk_tail]); + cmp(reg_tmp, 0); + jz(l_no_mask); + + lea(reg_tmp, ptr[reg_coff + vlen]); + cmp(reg_tmp, reg_coff_max); + jl(l_no_mask); + assert(isa == avx512_common || isa == avx2); + if (isa == avx512_common) + uni_vmovups_tail_avx512_common(dst, src, l_ret); + else if (isa == avx2) + uni_vmovups_tail_avx2_common(dst, src, l_ret); + } + L(l_no_mask); + if (dst.isMEM()) + uni_vmovups(dst.getAddress(), Vmm(src.getIdx())); + else + uni_vmovups(Vmm(dst.getIdx()), src.getAddress()); + + L(l_ret); + } + + void barrier() { + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + mov(reg_bar, ptr[rsp + stack_off_barrier]); + simple_barrier::generate(*this, reg_bar, reg_nnthr); + } + + Address mean_ptr(size_t offt = 0) { + return vmmword[reg_mean + reg_coff + offt + 0 * chan_data_offt]; + } + + Address var_ptr(size_t offt = 0) { + return vmmword[reg_var + reg_coff + offt + 0 * chan_data_offt]; + } + + Address diff_gamma_ptr(size_t offt = 0) { + return vmmword[reg_diff_scale_shift + reg_coff + offt + + 0 * chan_data_offt]; + } + + Address diff_beta_ptr(size_t offt = 0) { + return vmmword[reg_diff_scale_shift + reg_coff + offt + + 1 * chan_data_offt]; + } + + Address gamma_ptr(size_t offt = 0) { + return vmmword[reg_scale_shift + reg_coff + offt + 0 * chan_data_offt]; + } + + Address beta_ptr(size_t offt = 0) { + return vmmword[reg_scale_shift + reg_coff + offt + 1 * chan_data_offt]; + } + + template <typename init_t, typename body_t, typename fini_t> + void spat_loop(size_t len, size_t blocks, size_t regs, + init_t init, body_t body, fini_t fini) { + size_t factor = regs * blocks; + size_t loop_unroll = len / factor * factor; + size_t loop_tail = len - loop_unroll; + size_t num_active_regs = (len < regs) ? len : regs; + for (size_t i = 0; i < num_active_regs; i++) + init(i); + if (loop_unroll) { + if (is_spatial_thr_) { + mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]); + add(reg_soff, ptr[rsp + stack_off_s_s]); + } else { + mov(reg_ctr, loop_unroll); + } + Label label; + L(label); { + for (size_t i = 0; i < factor; i++) { + size_t base_reg = i % regs; + body(base_reg, i); + } + add(reg_soff, factor * vlen); + sub(reg_ctr, factor); + jnz(label); + } + if (is_spatial_thr_) { + add(reg_soff, ptr[rsp + stack_off_s_tail]); + } + } + + for (size_t i = 0; i < loop_tail; i++) { + size_t base_reg = i % regs; + body(base_reg, i); + } + if (loop_tail) + add(reg_soff, loop_tail * vlen); + + for (size_t i = 0; i < num_active_regs; i++) + fini(i); + } + + void mean_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + spat_loop(spat_size, unroll_blocks, + unroll_regs, + [=](size_t base_reg) { + Vmm v = Vmm(base_reg * 2); + if (base_reg) + uni_vpxor(v, v, v); + }, + [=](size_t base_reg, size_t i) { + Vmm v0 = Vmm(base_reg * 2 + 0); + Vmm v1 = Vmm(base_reg * 2 + 1); + size_t offt = i * vlen; + uni_vmovups(v1, + vmmword[reg_src + reg_soff + offt]); + uni_vaddps(v0, v0, v1); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b = Vmm(0); + Vmm v = Vmm(base_reg * 2); + if (base_reg) + uni_vaddps(b, b, v); + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void var_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + spat_loop(spat_size, unroll_blocks, unroll_regs, + [=](size_t base_reg) { + Vmm v = Vmm(base_reg * 3); + if (base_reg > 0) + uni_vpxor(v, v, v); + }, + [=](size_t base_reg, size_t i) { + Vmm v = Vmm(3 * base_reg); + Vmm vtmp0 = Vmm(3 * base_reg + 1); + Vmm vtmp1 = Vmm(3 * base_reg + 2); + size_t offt = i * vlen; + uni_vmovups(vtmp0, + vmmword[reg_src + reg_soff + offt]); + if (isa == sse42) { + movups(vtmp1, vmean); + subps(vtmp1, vtmp0); + } else { + vsubps(vtmp1, vmean, vtmp0); + } + uni_vfmadd231ps(v, vtmp1, vtmp1); + + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b = Vmm(0); + Vmm v = Vmm(base_reg * 3); + if (base_reg) + uni_vaddps(b, b, v); + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void compute_mean_variance() { + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + xor_(reg_coff, reg_coff); + Label zero_rbuf; + L(zero_rbuf); { + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(zero_rbuf); + } + + mov(reg_src, ptr[rsp + stack_off_src]); + + xor_(reg_soff, reg_soff); + Label mean_spatial; + L(mean_spatial); { + xor_(reg_coff, reg_coff); + + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + mean_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + + mean_channels(); + + sub(reg_src, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(mean_spatial); + } + + Label no_mean_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + jne(no_mean_reduction); + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + Label mean_reduction_channels; + L(mean_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + mov(reg_ctr, reg_nnthr); + Label mean_reduction_thrs; + L(mean_reduction_thrs); { + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); + uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0)); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(mean_reduction_thrs); + } + uni_vdivps(Vmm(1), Vmm(1), vchan_size); + uni_vmovups_maybe_tail(mean_ptr(), Vmm(1)); + + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + + cmp(reg_coff, reg_coff_max); + jne(mean_reduction_channels); + } + } + L(no_mean_reduction); + barrier(); + + xor_(reg_soff, reg_soff); + Label var_spatial; + L(var_spatial); { + xor_(reg_coff, reg_coff); + + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + var_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + + var_channels(); + + sub(reg_src, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(var_spatial); + } + + Label no_var_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + jne(no_var_reduction); + + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + Label var_reduction_channels; + L(var_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + mov(reg_ctr, reg_nnthr); + Label var_reduction_thrs; + L(var_reduction_thrs); { // TODO: unroll (?) + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(var_reduction_thrs); + } + uni_vdivps(Vmm(1), Vmm(1), vchan_size); + uni_vmovups_maybe_tail(var_ptr(), Vmm(1)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + + cmp(reg_coff, reg_coff_max); + jne(var_reduction_channels); + } + } + L(no_var_reduction); + barrier(); + } + + void forward_channels() { + Label ch_label; + L(ch_label); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + + if (bdesc_->use_scaleshift()) { + uni_vmovups_maybe_tail(vgamma, gamma_ptr()); + uni_vmovups_maybe_tail(vbeta, beta_ptr()); + } + + Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone; + Vmm vdiv = bdesc_->use_scaleshift() ? vgamma : vsqrtvar; + + if (isa == sse42) { + movups(vbuf, vscale); + divps(vbuf, vsqrtvar); + movups(vdiv, vbuf); + } else { + vdivps(vdiv, vscale, vsqrtvar); + } + + auto compute = [=](bool output_is_aligned) { + spat_loop(spat_size, unroll_blocks, unroll_regs, + [](size_t base_reg) {UNUSED(base_reg);}, + [=](size_t base_reg, size_t i) { + Vmm v = Vmm(base_reg); + size_t offt = i * vlen; + uni_vmovups(v, + vmmword[reg_src + reg_soff + offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + uni_vsubps(v, v, vmean); + if (bdesc_->use_scaleshift()) { + uni_vfmadd213ps(v, vgamma, vbeta); + } else { + uni_vmulps(v, v, vsqrtvar); + } + if (with_relu_inf_only) { + uni_vmaxps(v, v, vzero); + } else if (with_relu) { + if (isa == avx512_common) + fwd_process_relu_avx512_common(v, offt); + else + fwd_process_relu_avx2(v, offt, Vmm(3)); + } + if (output_is_aligned) { + uni_vmovntps( + vmmword[reg_dst + reg_soff + offt], v); + } else { + uni_vmovups( + vmmword[reg_dst + reg_soff + offt], v); + } + }, + [](size_t base_reg) {UNUSED(base_reg);}); + }; + + Label unaligned_store, end_store; + test(reg_dst, vlen - 1); + jnz(unaligned_store, T_NEAR); + compute(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + compute(false); + } + L(end_store); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(ch_label); + } + } + + void forward() { + mov(reg_src, ptr[rsp + stack_off_src]); + mov(reg_dst, ptr[rsp + stack_off_dst]); + mov(reg_ws, ptr[rsp + stack_off_ws]); + + xor_(reg_soff, reg_soff); + Label dst_spatial; + L(dst_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) + mov(reg_tmp_off, reg_soff); + + forward_channels(); + + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_src, vlen / 2); + add(reg_dst, vlen / 2); + mov(reg_coff, vlen / 2); + + forward_channels(); + + sub(reg_src, vlen / 2); + sub(reg_dst, vlen / 2); + } + + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jnz(dst_spatial); + } + } + + void backward_sh_channels() { + Label sh_channels; + L(sh_channels); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]); + uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]); + spat_loop(spat_size, 1, 1, + [=](size_t base_reg) { + if (base_reg > 0) { + for (int i = 0; i < 2; i++) { + Vmm v(base_reg * 5 + i); + uni_vpxor(v, v, v); + } + } + }, + [=](size_t base_reg, size_t i) { + Vmm o0 = Vmm(base_reg * 5 + 0); + Vmm o1 = Vmm(base_reg * 5 + 1); + Vmm t1 = Vmm(base_reg * 5 + 2); + Vmm t2 = Vmm(base_reg * 5 + 3); + Vmm t3 = Vmm(base_reg * 5 + 4); + size_t offt = i * vlen; + uni_vmovups(t1, vmmword[reg_src + reg_soff + offt]); + uni_vmovups(t2, vmmword[reg_diff_dst + reg_soff + + offt]); + if (with_relu) { + if (isa == avx512_common) + bwd_process_relu_avx512_common(t2, offt); + else if (isa == avx2) + bwd_process_relu_avx2(t2, offt, t3); + else + assert(false); + } + uni_vsubps(t3, vmean, t1, t3); + if (isa == sse42) { + mulps(t3, t2); + subps(o0, t3); + } else { + vfnmadd231ps(o0, t3, t2); + } + uni_vaddps(o1, o1, t2); + mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt + + t1_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) { + Vmm b0 = Vmm(0); + Vmm b1 = Vmm(1); + if (base_reg) { + uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0)); + uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1)); + } + }); + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1)); + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(sh_channels); + } + } + + void backward_diff_channels() { + Label diff_channels; + L(diff_channels); { + uni_vmovups_maybe_tail(vmean, mean_ptr()); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf); + if (bdesc_->use_scaleshift()) + uni_vmovups_maybe_tail(vgamma, gamma_ptr()); + uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr()); + uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr()); + uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar); + uni_vdivps(vdiff_beta, vdiff_beta, vchan_size); + uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size); + + auto compute = [=](bool output_is_aligned) { + spat_loop(spat_size, unroll_blocks, unroll_regs, + [=](size_t base_reg) {UNUSED(base_reg);}, + [=](size_t base_reg, size_t i) { + Vmm v(base_reg * 2 + 0); + Vmm t(base_reg * 2 + 1); + Vmm t1(base_reg * 2 + 2); + size_t offt = i * vlen; + uni_vmovups(v, vmmword[reg_diff_dst + reg_soff + + offt]); + if (with_relu) { + if (isa == avx512_common) + bwd_process_relu_avx512_common(v, offt); + else if (isa == avx2) + bwd_process_relu_avx2(v, offt, t); + else + assert(false); + } + if (!bdesc_->use_global_stats()) { + uni_vsubps(v, v, vdiff_beta); + uni_vmovups(t, vmmword[reg_src + reg_soff + + offt]); + uni_vsubps(t, vmean, t, t1); + uni_vmulps(t, t, vdiff_gamma); + uni_vaddps(v, v, t); + } + uni_vmulps(v, v, vsqrtvar); + if (bdesc_->use_scaleshift()) { + uni_vmulps(v, v, vgamma); + } + if (output_is_aligned) { + uni_vmovntps( + vmmword[reg_diff_src + reg_soff + offt], + v); + } else { + uni_vmovups( + vmmword[reg_diff_src + reg_soff + offt], + v); + } + mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht0(ptr[reg_src + reg_soff + offt + + t0_pf_offt]); + mic_prefetcht1(ptr[reg_diff_dst + reg_soff + + offt + t1_pf_offt]); + mic_prefetcht1(ptr[reg_src + reg_soff + offt + + t1_pf_offt]); + }, + [=](size_t base_reg) {UNUSED(base_reg);}); + }; + + Label unaligned_store, end_store; + test(reg_diff_src, vlen - 1); + jnz(unaligned_store, T_NEAR); + compute(true); + jmp(end_store, T_NEAR); + L(unaligned_store); { + compute(false); + } + L(end_store); + + add(reg_coff, vlen); + cmp(reg_coff, reg_coff_max); + jl(diff_channels); + } + } + + void backward() { + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + xor_(reg_coff, reg_coff); + Label zero_rbuf, sh_spatial; + + L(zero_rbuf); { + uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0)); + uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(zero_rbuf); + } + + mov(reg_src, ptr[rsp + stack_off_src]); + mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]); + if (with_relu) { + assert(isa == avx2 || isa == avx512_common); + mov(reg_ws, ptr[rsp + stack_off_ws]); + } + + xor_(reg_soff, reg_soff); + L(sh_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) { + mov(reg_tmp_off, reg_soff); + } + backward_sh_channels(); + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_diff_dst, vlen / 2); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + backward_sh_channels(); + sub(reg_diff_dst, vlen / 2); + sub(reg_src, vlen / 2); + } + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(sh_spatial); + } + + mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]); + + Label no_sh_reduction; + barrier(); { + mov(reg_tmp, ptr[rsp + stack_off_N_ithr]); + cmp(reg_tmp, 0); + Label sh_reduction_channels; + jne(no_sh_reduction, T_NEAR); + + mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]); + xor_(reg_coff, reg_coff); + L(sh_reduction_channels); { + mov(reg_roff, reg_coff); + uni_vpxor(Vmm(0), Vmm(0), Vmm(0)); + uni_vpxor(Vmm(1), Vmm(1), Vmm(1)); + uni_vmovups_maybe_tail(vsqrtvar, var_ptr()); + uni_vaddps(vsqrtvar, vsqrtvar, veps); + uni_vsqrtps(vsqrtvar, vsqrtvar); + uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf); + mov(reg_ctr, reg_nnthr); + Label sh_reduction_thrs; + L(sh_reduction_thrs); { // TODO: unroll (?) + uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]); + uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]); + add(reg_roff, reg_coff_max); + sub(reg_ctr, 1); + jnz(sh_reduction_thrs); + } + uni_vmulps(Vmm(0), Vmm(0), vsqrtvar); + uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0)); + uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1)); + add(reg_coff, isa == sse42 ? vlen / 2 : vlen); + cmp(reg_coff, reg_coff_max); + jne(sh_reduction_channels); + } + } + L(no_sh_reduction); + barrier(); + + mov(reg_diff_src, ptr[rsp + stack_off_diff_src]); + if (with_relu) { + assert(isa == avx2 || isa == avx512_common); + mov(reg_ws, ptr[rsp + stack_off_ws]); + } + + xor_(reg_soff, reg_soff); + Label diff_spatial; + L(diff_spatial); { + xor_(reg_coff, reg_coff); + if (isa == sse42) { + mov(reg_tmp_off, reg_soff); + } + backward_diff_channels(); + if (isa == sse42) { + mov(reg_soff, reg_tmp_off); + add(reg_diff_dst, vlen / 2); + add(reg_diff_src, vlen / 2); + add(reg_src, vlen / 2); + mov(reg_coff, vlen / 2); + backward_diff_channels(); + sub(reg_diff_dst, vlen / 2); + sub(reg_diff_src, vlen / 2); + sub(reg_src, vlen / 2); + } + add(reg_soff, reg_mb_stride_Bc); + cmp(reg_soff, reg_soff_max); + jne(diff_spatial); + } + } + + jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) { + static_assert(isa == sse42 || isa == avx2 || isa == avx512_common + || isa == avx512_mic, "unsupported isa"); + + const int simd_w = isa == sse42 ? 8 : + cpu_isa_traits<isa>::vlen / sizeof(data_t); + is_spatial_thr_ = + bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t)); + + unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1; + unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1; + + preamble(); + + if (isa == avx512_common) + prepare_tail_mask_avx512_common(); + else if (isa == avx2) + prepare_tail_mask_avx2_common(); + + compute_static_strides(); + sub(rsp, stack_size_required); + load_common_params(); + prepare_relu(); + + if (bdesc_->is_fwd()) { + if (!bdesc_->stats_is_src()) { + compute_mean_variance(); + } + forward(); + } else { + backward(); + } + add(rsp, stack_size_required); + postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); + } +}; + +template <cpu_isa_t isa> +struct uni_bnorm_driver_t: public c_compatible { + uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc) + : bdesc_(bdesc), ker_(bdesc_) + { + const int nthrs = mkldnn_get_max_threads(); + const dim_t C_PADDED = get_c_padded(bdesc_); + + size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED + * bdesc_->D() * bdesc_->H() * bdesc_->W(); + l3_size_ = get_cache_size(3, true) * nthrs / 2; + do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0); + } + + ~uni_bnorm_driver_t() {} + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const batch_normalization_pd_t *bdesc) { + int nthrs = mkldnn_get_max_threads(); + dim_t C_PADDED = get_c_padded(bdesc); + + int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED; + int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED; + int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs; + + scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz); + scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz); + scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz); + + if (mkldnn_thr_syncable()) { + int n_barriers = C_PADDED / simd_w; + scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers); + } + } + + void exec(int ithr, int nthr, const data_t *src, data_t *diff_src, + data_t *dst, const data_t *diff_dst, const data_t *scale_shift, + data_t *diff_scale_shift, const data_t *mean, const data_t *var, + const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) { + auto sbuf = scratchpad.get<data_t>(key_bnorm_tmp_stats); + auto pbuf = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss); + auto rbuf = scratchpad.get<data_t>(key_bnorm_reduction); + auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier); + + dim_t N = bdesc_->MB(); + dim_t C = bdesc_->C(); + dim_t C_PADDED = get_c_padded(bdesc_); + dim_t D = bdesc_->D(); + dim_t H = bdesc_->H(); + dim_t W = bdesc_->W(); + dim_t SP = D * H * W; + dim_t img_size = C_PADDED * D * H * W; + const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen; + + typename jit_bnorm_t<isa>::call_params_t p; + + p.eps = bdesc_->desc()->batch_norm_epsilon; + p.one = 1.0f; + p.spat_size = D * H * W; + p.chan_size = 1.0f * N * p.spat_size; + + dim_t C_blks = C_PADDED / simd_w; + + int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0}; + dim_t C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0}; + + dim_t C_blks_per_iter{ 1 }; + int64_t iters{ 1 }; + if (do_blocking_) { + int num_tensors = bdesc_->is_fwd() ? 1 : 2; + size_t working_set_size + = (N * D * H * W * simd_w * sizeof(data_t)) * num_tensors; + bnorm_utils::cache_balance(working_set_size, C_blks, + C_blks_per_iter, iters); + } + + bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_, + true, ithr, nthr, N, do_blocking_ ? C_blks_per_iter : C_blks, + SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e, + S_ithr, S_nthr, S_s, S_e); + + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1)); + + p.N_ithr = SP_N_ithr; + p.N_nthr = SP_N_nthr; + + int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter; + int global_C_blk_s; + int global_barriers_per_iter = C_nthr; + + for (int64_t it = 0; it < iters; it++) { + if (it == iters - 1 && iters > 1) { + C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + + // Update call parameters for JIT, last iteration + p.N_ithr = N_ithr * S_nthr + S_ithr; + p.N_nthr = N_nthr * S_nthr; + } + + global_C_blk_s = do_blocking_ ? + (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s : + C_blk_s; + + int C_blks_thr = C_blk_e - C_blk_s; + int N_thr = N_e - N_s; + + size_t coff_base = global_C_blk_s * simd_w; + size_t soff_base + = global_C_blk_s * p.spat_size * simd_w + N_s * img_size; + + p.spat_size_loc = S_e - S_s; + p.S_s = S_s * vlen; + p.S_tail = (p.spat_size - S_e) * vlen; + p.coff_max = C_blks_thr * simd_w; + p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base; + p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base; + p.scale_shift = scale_shift + coff_base; + p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_) + ? pbuf : diff_scale_shift) + coff_base; + + p.soff_max = N_thr * img_size; + p.src = src + soff_base; + p.dst = dst + soff_base; + p.diff_src = diff_src + soff_base; + p.diff_dst = diff_dst + soff_base; + p.ws = ws + soff_base / 8; + + p.mb_stride_Bc = img_size - p.coff_max * p.spat_size; + + // use SP_N_nthr which is the same as p.N_nthr except maybe for + // the last iteration. + p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr + + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w; + // rbuf1 and rbuf2 have to be disjoint + p.rbuf2 = p.rbuf1 + C_PADDED * nthr; + p.is_cblk_tail = (it * C_blks_per_iter + C_blk_e) * simd_w > C; + + size_t iter_bariers + = do_blocking_ ? it * global_barriers_per_iter : 0; + p.barrier = barriers + C_ithr + iter_bariers; + if (p.soff_max != 0 && p.coff_max != 0) + ker_(&p); + } + } + + void init_barriers(const memory_tracking::grantor_t &scratchpad) { + auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier); + if (barriers) { + const int n_barriers = get_c_padded(bdesc_) / simd_w; + for (int i = 0; i < n_barriers; ++i) + barrier::ctx_init(&barriers[i]); + } + } + +private: + enum { + simd_w = isa == sse42 ? 8 : cpu_isa_traits<isa>::vlen / sizeof(data_t) + }; + + static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) { + return true + && !bdesc->stats_is_src() + && bdesc->desc()->prop_kind == prop_kind::forward_inference; + } + + static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc) + { + return false + || (bdesc->is_bwd() && !bdesc->use_scaleshift()) + || bdesc->desc()->prop_kind == prop_kind::backward_data; + } + + static dim_t get_c_padded(const batch_normalization_pd_t *bdesc) + { return bdesc->src_md()->padded_dims[1]; } + + const batch_normalization_pd_t *bdesc_; + bool do_blocking_; + size_t l3_size_; + + jit_bnorm_t<isa> ker_; +}; + +} + +using namespace data_type; +using namespace format_tag; +using namespace utils; + +/* fwd */ + +template <cpu_isa_t isa> +status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init() { + auto desired_fmt_tag = (ndims() == 4) + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? nCdhw16c : nCdhw8c; + + bool ok = true + && mayiuse(isa) + && is_fwd() + && !has_zero_dim_memory() + && one_of(ndims(), 4, 5) + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) { + if (isa < avx2) return status::unimplemented; + init_default_ws(1); + } + + if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() + && isa < avx2) + return status::unimplemented; + + auto scratchpad = scratchpad_registry().registrar(); + uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this); + + return status::success; +} + +template <cpu_isa_t isa> +jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t( + const pd_t *apd): cpu_primitive_t(apd) +{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); } + +template <cpu_isa_t isa> +status_t jit_uni_batch_normalization_fwd_t<isa>::execute( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + auto mean = pd()->stats_is_src() + ? const_cast<data_t *>(CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)) + : CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + auto var = pd()->stats_is_src() + ? const_cast<data_t *>(CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)) + : CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto scratchpad = this->scratchpad(ctx); + + bnorm_driver_->init_barriers(scratchpad); + + parallel(0, [&](const int ithr, const int nthr) { + bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr, + scale_shift, nullptr, mean, var, ws, scratchpad); + }); + + return status::success; +} + +template <cpu_isa_t isa> +jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t() +{ delete bnorm_driver_; } + +/* bwd */ + +template <cpu_isa_t isa> +status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init() { + auto desired_fmt_tag = (ndims() == 4) + ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c + : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c; + + bool ok = true + && mayiuse(isa) + && is_bwd() + && !has_zero_dim_memory() + && one_of(ndims(), 4, 5) + && everyone_is(f32, src_md()->data_type, diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() + && isa < avx2) + return status::unimplemented; + + if (fuse_bn_relu()) { + if (isa < avx2) return status::unimplemented; + init_default_ws(1); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + /* TODO: extra checks required */ + + auto scratchpad = scratchpad_registry().registrar(); + uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this); + + return status::success; +} + +template <cpu_isa_t isa> +jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t( + const pd_t *apd): cpu_primitive_t(apd) +{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); } + +template <cpu_isa_t isa> +status_t jit_uni_batch_normalization_bwd_t<isa>::execute( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto var = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scale_shift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + + bnorm_driver_->init_barriers(scratchpad); + + parallel(0, [&](const int ithr, const int nthr) { + bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst, + scale_shift, diff_scale_shift, mean, var, ws, scratchpad); + }); + + return status::success; +} + +template <cpu_isa_t isa> +jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t() +{ delete bnorm_driver_; } + +/* struct instantiation */ +template struct jit_uni_batch_normalization_fwd_t<sse42>; +template struct jit_uni_batch_normalization_bwd_t<sse42>; +template struct jit_uni_batch_normalization_fwd_t<avx2>; +template struct jit_uni_batch_normalization_bwd_t<avx2>; +template struct jit_uni_batch_normalization_fwd_t<avx512_common>; +template struct jit_uni_batch_normalization_bwd_t<avx512_common>; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp new file mode 100644 index 0000000000..96410ec84e --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp @@ -0,0 +1,100 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_BATCH_NORMALIZATION_HPP +#define JIT_UNI_BATCH_NORMALIZATION_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_isa_traits.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace { template <cpu_isa_t isa> struct uni_bnorm_driver_t; } + +template <cpu_isa_t isa> +struct jit_uni_batch_normalization_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_batch_normalization_fwd_t<isa>); + + status_t init(); + }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + jit_uni_batch_normalization_fwd_t(const pd_t *apd); + ~jit_uni_batch_normalization_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + uni_bnorm_driver_t<isa> *bnorm_driver_; +}; + +template <cpu_isa_t isa> +struct jit_uni_batch_normalization_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_batch_normalization_bwd_t<isa>); + + status_t init(); + }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + jit_uni_batch_normalization_bwd_t(const pd_t *apd); + ~jit_uni_batch_normalization_bwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + uni_bnorm_driver_t<isa> *bnorm_driver_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp new file mode 100644 index 0000000000..b7dc5f85c5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp @@ -0,0 +1,1302 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "cpu_memory.hpp" + +#include "jit_uni_dw_conv_kernel_f32.hpp" + +#define GET_OFF(field) offsetof(jit_conv_call_s, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::prop_kind; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +using namespace Xbyak; + +template <cpu_isa_t isa> +void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_src(int ur_ch_blocks, int ur_w) { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int ow = 0; ow < ur_w; ow++) { + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); + + int b_off = ch*jcp.ch_block + i*4; + if (this->jcp.with_bias) + uni_vmovups(vmm_acc, + vmmword[reg_bias + b_off*sizeof(float)]); + else + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + + int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block + + ow*jcp.ch_block + i*4; + if (this->jcp.with_sum) + uni_vaddps(vmm_acc, vmm_acc, + vmmword[reg_output + o_off*sizeof(float)]); + } + } + } +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + mov(iter_kw, reg_kw); + mov(aux1_reg_input, aux_reg_input); + mov(aux1_reg_kernel, aux_reg_kernel); + + Label kw_label; + L(kw_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4; + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux1_reg_kernel + + ker_off*sizeof(float)]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ch*jcp.ih*jcp.iw*ch_blk + + ow*stride_w*ch_blk + i*4; + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux1_reg_input + + inp_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + + ch*ur_w + ow); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + add(aux1_reg_kernel, ch_blk*sizeof(float)); + add(aux1_reg_input, ch_blk*dilate_w*sizeof(float)); + + dec(iter_kw); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); + add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float)); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + int dilate_h = jcp.dilate_h + 1; + int dilate_w = jcp.dilate_w + 1; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int kw = 0; kw < jcp.kw; kw++) { + int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4; + + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux_reg_kernel + + ker_off*sizeof(float)]); + + for (int ow = 0; ow < ur_w; ow++) { + int inp_off = ch*jcp.ih*jcp.iw*ch_blk + + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4; + + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux_reg_input + + inp_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + + ch*ur_w + ow); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + } + + add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float)); + add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float)); + + dec(iter_kh); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_activation( + int ur_ch_blocks, int ur_w) { + if (this->jcp.with_eltwise) { + int repeats = isa == sse42 ? 2 : 1; + eltwise_injector_->compute_vector_range(4, repeats * ur_w * ur_ch_blocks + 4); + } +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_dst( + int ur_ch_blocks, int ur_w) { + int ch_blk = jcp.ch_block; + + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int ow = 0; ow < ur_w; ow++) { + int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4; + Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow); + + uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst); + } + } + } +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + cmp(reg_ur_w, ur_w); + jl(tail_w_label, T_NEAR); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + load_src(ur_ch_blocks, ur_w); + apply_filter_unrolled(ur_ch_blocks, ur_w); + apply_activation(ur_ch_blocks, ur_w); + store_dst(ur_ch_blocks, ur_w); + + add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_output, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + cmp(reg_ur_w, ur_w); + jl(exit_label, T_NEAR); + + mov(aux_reg_input, reg_input); + mov(aux_reg_kernel, reg_kernel); + + load_src(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + apply_activation(ur_ch_blocks, ur_w); + store_dst(ur_ch_blocks, ur_w); + + add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_output, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() { + this->preamble(); + + mov(reg_input, ptr[this->param1 + GET_OFF(src)]); + mov(reg_output, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + if (jcp.with_bias) + mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); + mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + + cmp(reg_ch_blocks, jcp.nb_ch_blocking); + jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); + + loop_body(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + L(ch_blocks_tail_label); + + cmp(reg_ch_blocks, ch_blocks_tail); + jne(exit_label, T_NEAR); + + loop_body(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + + this->postamble(); + + if (jcp.with_eltwise) + eltwise_injector_->prepare_table(); +} + +template <cpu_isa_t isa> +bool jit_uni_dw_conv_fwd_kernel_f32<isa>::post_ops_ok( + jit_conv_conf_t &jcp, const primitive_attr_t &attr) { + const auto &p = attr.post_ops_; + + auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); }; + auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; + + switch (p.len_) { + case 0: return true; // no post_ops + case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise + case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise + default: return false; + } + + return false; +} + +template <cpu_isa_t isa> +status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, + const primitive_attr_t &attr) +{ + if (!mayiuse(isa)) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + + jcp.prop_kind = cd.prop_kind; + + const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; + if (!with_groups) return status::unimplemented; + + jcp.ngroups = weights_d.dims()[0]; + jcp.mb = src_d.dims()[0]; + + jcp.oc = dst_d.dims()[1]; + jcp.oc_without_padding = jcp.oc; + jcp.ic = src_d.dims()[1]; + + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = dst_d.dims()[2]; + jcp.ow = dst_d.dims()[3]; + + jcp.kh = weights_d.dims()[3]; + jcp.kw = weights_d.dims()[4]; + + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.b_pad = cd.padding[1][0]; + jcp.r_pad = cd.padding[1][1]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + + if (!post_ops_ok(jcp, attr)) + return status::unimplemented; + + const auto &p = attr.post_ops_; + jcp.with_sum = p.find(primitive_kind::sum) != -1; + const int eltwise_ind = p.find(primitive_kind::eltwise); + jcp.with_eltwise = eltwise_ind != -1; + if (jcp.with_eltwise) + jcp.eltwise = p.entry_[eltwise_ind].eltwise; + + bool ok_to_pad_channels = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && one_of(isa, avx512_common, avx2); + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.oc, simd_w); + jcp.ngroups = rnd_up(jcp.ngroups, simd_w); + } + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && jcp.ngroups % simd_w == 0 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.ic <= src_d.padded_dims()[1] + && jcp.oc <= dst_d.padded_dims()[1] + && jcp.ngroups <= weights_d.padded_dims()[0]; + if (!args_ok) return status::unimplemented; + + jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; + + jcp.ch_block = simd_w; + jcp.nb_ch = jcp.oc / jcp.ch_block; + jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2; + if (jcp.nb_ch < jcp.nb_ch_blocking) + jcp.nb_ch_blocking = jcp.nb_ch; + + return status::success; +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_fwd_kernel_f32<isa>::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + if (jcp.with_bias && jcp.oc_without_padding != jcp.oc) + scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); +} + +template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_common>; +template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>; +template struct jit_uni_dw_conv_fwd_kernel_f32<sse42>; + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_ddst( + int ur_ch_blocks, int ur_str_w) { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + } + } +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::apply_filter( + int ur_ch_blocks, int ur_str_w) { + int kw = jcp.kw; + int kh = jcp.kh; + int ow = jcp.ow; + int oh = jcp.oh; + + int ch_blk = jcp.ch_block; + int stride_h = jcp.stride_h; + int stride_w = jcp.stride_w; + + Label iter_exit_label; + + cmp(reg_kh, 0); + je(iter_exit_label, T_NEAR); + + cmp(reg_kw, 0); + je(iter_exit_label, T_NEAR); + + mov(iter_kh, reg_kh); + Label kh_label; + L(kh_label); { + mov(aux1_reg_ddst, aux_reg_ddst); + mov(aux1_reg_kernel, aux_reg_kernel); + + mov(iter_kw, reg_kw); + Label kw_label; + L(kw_label); { + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + int ker_off = ch*kh*kw*ch_blk + i*4; + Vmm vmm_ker = get_ker_reg(0); + uni_vmovups(vmm_ker, ptr[aux1_reg_kernel + + ker_off*sizeof(float)]); + + for (int w = 0; w < ur_str_w; w++) { + int ddst_off = (ch*oh*ow + w)*ch_blk + i*4; + + Vmm vmm_src = get_src_reg(0); + uni_vmovups(vmm_src, ptr[aux1_reg_ddst + + ddst_off*sizeof(float)]); + + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker); + } + } + } + + add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float)); + sub(aux1_reg_ddst, ch_blk*sizeof(float)); + + sub(iter_kw, stride_w); + cmp(iter_kw, 0); + jg(kw_label, T_NEAR); + } + + add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float)); + sub(aux_reg_ddst, ow*ch_blk*sizeof(float)); + + sub(iter_kh, stride_h); + cmp(iter_kh, 0); + jg(kh_label, T_NEAR); + } + + L(iter_exit_label); +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_dsrc( + int ur_ch_blocks, int ur_str_w) { + int ch_blk = jcp.ch_block; + int iw = jcp.iw; + int ih = jcp.ih; + int stride_w = jcp.stride_w; + + int repeats = isa == sse42 ? 2 : 1; + for (int i = 0; i < repeats; i++) { + for (int ch = 0; ch < ur_ch_blocks; ch++) { + for (int w = 0; w < ur_str_w; w++) { + int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4; + Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w + + ch*ur_str_w + w); + + uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc); + } + } + } +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::loop_body( + int ur_ch_blocks) { + Label unrolled_w_label; + Label tail_w_label; + Label exit_label; + + L(unrolled_w_label); { + int ur_w = jcp.ur_w; + + cmp(reg_ur_str_w, ur_w); + jl(tail_w_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(unrolled_w_label); + } + + L(tail_w_label); { + int ur_w = 1; + + cmp(reg_ur_str_w, ur_w); + jl(exit_label, T_NEAR); + + mov(aux_reg_ddst, reg_ddst); + mov(aux_reg_kernel, reg_kernel); + + load_ddst(ur_ch_blocks, ur_w); + apply_filter(ur_ch_blocks, ur_w); + store_dsrc(ur_ch_blocks, ur_w); + + add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w); + add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block); + + sub(reg_ur_str_w, ur_w); + jmp(tail_w_label); + } + + L(exit_label); +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::generate() { + preamble(); + + mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]); + mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]); + mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]); + mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); + mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]); + mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]); + mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]); + + Label ch_blocks_tail_label; + Label exit_label; + + int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking; + + cmp(reg_ch_blocks, jcp.nb_ch_blocking); + jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR); + + loop_body(jcp.nb_ch_blocking); // channel main loop + + if (ch_blocks_tail) { + L(ch_blocks_tail_label); + + cmp(reg_ch_blocks, ch_blocks_tail); + jne(exit_label, T_NEAR); + + loop_body(ch_blocks_tail); // channel tail loop + } + + L(exit_label); + + this->postamble(); +} + +template <cpu_isa_t isa> +status_t jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d) { + if (!mayiuse(isa)) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + + const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; + if (!with_groups) return status::unimplemented; + + jcp.ngroups = weights_d.dims()[0]; + jcp.mb = diff_src_d.dims()[0]; + + jcp.oc = diff_dst_d.dims()[1]; + jcp.oc_without_padding = jcp.oc; + jcp.ic = diff_src_d.dims()[1]; + + jcp.ih = diff_src_d.dims()[2]; + jcp.iw = diff_src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + + jcp.kh = weights_d.dims()[3]; + jcp.kw = weights_d.dims()[4]; + + jcp.t_pad = cd.padding[0][0]; + jcp.l_pad = cd.padding[0][1]; + jcp.b_pad = cd.padding[1][0]; + jcp.r_pad = cd.padding[1][1]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + + bool ok_to_pad_channels = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && one_of(isa, avx512_common, avx2); + if (ok_to_pad_channels) { + jcp.oc = rnd_up(jcp.oc, simd_w); + jcp.ic = rnd_up(jcp.oc, simd_w); + jcp.ngroups = rnd_up(jcp.ngroups, simd_w); + } + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.oc == jcp.ngroups + && jcp.ic == jcp.ngroups + && jcp.ngroups % simd_w == 0 + && jcp.dilate_h == 0 + && jcp.dilate_w == 0 + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1 + && jcp.ic <= diff_src_d.padded_dims()[1] + && jcp.oc <= diff_dst_d.padded_dims()[1] + && jcp.ngroups <= weights_d.padded_dims()[0]; + if (!args_ok) return status::unimplemented; + + jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3; + + jcp.ch_block = simd_w; + jcp.nb_ch = jcp.ic / jcp.ch_block; + jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2; + if (jcp.nb_ch < jcp.nb_ch_blocking) + jcp.nb_ch_blocking = jcp.nb_ch; + + return status::success; +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + UNUSED(scratchpad); + UNUSED(jcp); +} + +template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_common>; +template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>; +template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse42>; + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter() { + for (int r = 0; r < reg_repeats; ++r) { + for (int i = 0; i < jcp.kw; ++i) { + Vmm vmm_acc = get_acc_reg(r * jcp.kw + i); + uni_vpxor(vmm_acc, vmm_acc, vmm_acc); + } + } +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_filter() { + for (int r = 0; r < reg_repeats; ++r) { + const int reg_set = r * jcp.kw; + for (int i = 0; i < jcp.kw; ++i) { + int off_filter = (reg_set + i) * simd_w; + Vmm vmm_acc = get_acc_reg(reg_set + i); + uni_vmovups(vmm_acc, + vmmword[reg_tmp_filter + off_filter * sizeof(float)]); + } + } +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vpxor(vmm_bias, vmm_bias, vmm_bias); + } +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vmovups( + vmm_bias, vmmword[reg_bias_baddr + r * simd_w * sizeof(float)]); + } +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const int iw_block = ow_block * jcp.stride_w; + const int right_border = jcp.iw - iw_block; + + const int cascade_input = nstl::min(jcp.stride_w, jcp.kw); + + /* preamble count for number of cascaded LOAD + FMA operation */ + const int input_overlap = nstl::max(jcp.kw - l_pad, 0); + + /* LOAD initial input registers, then cascade LOADs and FMAs*/ + for (int r = 0; r < reg_repeats; ++r) { + for (int i_ur = 0; i_ur < unroll_w; ++i_ur) { + int off_output = (i_ur * reg_repeats + r) * simd_w; + Vmm vmm_output = get_output_reg(r); + uni_vmovups(vmm_output, + ptr[reg_tmp_output + off_output * sizeof(float)]); + if (i_ur == 0) { + for (int c = 0; c < input_overlap; ++c) { + int off_input + = ((c - pad_offset) * reg_repeats + r) * simd_w; + Vmm vmm_input + = get_input_reg((c % jcp.kw) * reg_repeats + r); + uni_vmovups(vmm_input, + ptr[reg_tmp_input + off_input * sizeof(float)]); + } + } else { + for (int c = 0; c < cascade_input; ++c) { + int overlap = (i_ur - 1) * jcp.stride_w + input_overlap; + int off_input + = ((overlap + c - pad_offset) * reg_repeats + r) + * simd_w; + Vmm vmm_input = get_input_reg( + ((overlap + c) % jcp.kw) * reg_repeats + r); + uni_vmovups(vmm_input, + ptr[reg_tmp_input + off_input * sizeof(float)]); + } + } + + for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) { + int io_overlap = i_kw + (i_ur * jcp.stride_w); + + /* Don't apply FMAs that fall into the padded region */ + if (io_overlap - l_pad < 0 + || io_overlap - jcp.l_pad >= right_border) + continue; + + Vmm vmm_input = get_input_reg( + ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r); + Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r); + Vmm vmm_aux = isa == sse42 ? get_aux_reg() : vmm_input; + if (isa == sse42) + uni_vmovups(vmm_aux, vmm_input); + uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output); + } + } + } +} + +template <cpu_isa_t isa> +inline void +jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_step_unroll( + const int unroll_w) { + for (int r = 0; r < reg_repeats; ++r) { + for (int i = 0; i < unroll_w; ++i) { + Vmm vmm_bias = get_bias_reg(r); + int off_output = (i * reg_repeats + r) * simd_w; + if (isa == sse42) { + /* Need to support unaligned address loads for SSE42*/ + Vmm vmm_output = get_output_reg(1 + r); + uni_vmovups(vmm_output, + ptr[reg_tmp_output + off_output * sizeof(float)]); + uni_vaddps(vmm_bias, vmm_bias, vmm_output); + } else { + uni_vaddps(vmm_bias, vmm_bias, + vmmword[reg_tmp_output + off_output * sizeof(float)]); + } + } + } +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_filter() { + for (int r = 0; r < reg_repeats; ++r) { + const int reg_set = r * jcp.kw; + for (int i = 0; i < jcp.kw; ++i) { + int off_filter = (i + reg_set) * simd_w; + Vmm vmm_acc = get_acc_reg(i + reg_set); + uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)], + vmm_acc); + } + } +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_bias() { + for (int r = 0; r < reg_repeats; ++r) { + Vmm vmm_bias = get_bias_reg(r); + uni_vmovups( + vmmword[reg_bias_baddr + r * simd_w * sizeof(float)], vmm_bias); + } +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_loop( + const int block_size) { + Label oh_label; + Label ow_blk_label; + + const int unroll_w = nstl::min(block_size, jcp.ow); + const int unroll_w_trips = jcp.ow / unroll_w; + const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0; + + const int ch_offset = jcp.ch_block; + + mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]); + mov(reg_oh_worksize, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]); + + mov(reg_tmp_output, reg_output_baddr); + L(oh_label); + { + + mov(iter_ow_blk, unroll_w_trips); + L(ow_blk_label); + { + + compute_bias_step_unroll(unroll_w); + add(reg_tmp_output, unroll_w * ch_offset * sizeof(float)); + + dec(iter_ow_blk); + cmp(iter_ow_blk, 0); + jg(ow_blk_label, T_NEAR); + } + + if (tail_w > 0) { + compute_bias_step_unroll(tail_w); + add(reg_tmp_output, tail_w * ch_offset * sizeof(float)); + } + + inc(reg_oh); + cmp(reg_oh, reg_oh_worksize); + jl(oh_label, T_NEAR); + } +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_zero_filter() { + + const int ch_offset = jcp.ch_block; + + Label kh_loop_label, skip_zeroing_label; + + mov(reg_exec_flags, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]); + and_(reg_exec_flags, FLAG_ZERO_FILTER); + test(reg_exec_flags, reg_exec_flags); + je(skip_zeroing_label); + + zero_filter(); + + mov(reg_tmp_filter, reg_filter_baddr); + mov(reg_kh, jcp.kh); + L(kh_loop_label); + { + store_filter(); + + add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_loop_label); + } + + /* Comeback pointers */ + sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float)); + + L(skip_zeroing_label); +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_step( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const int ch_offset = jcp.ch_block; + + Label kh_loop_label, skip_loop_label; + + cmp(reg_kh_count, 0); + je(skip_loop_label, T_NEAR); + + mov(reg_kh, reg_kh_count); + L(kh_loop_label); + { + load_filter(); + compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block); + store_filter(); + + add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_loop_label); + } + + /* Comeback pointers */ + Label kh_comeback_label; + mov(reg_kh, reg_kh_count); + L(kh_comeback_label); + { + sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float)); + sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float)); + dec(reg_kh); + cmp(reg_kh, 0); + jg(kh_comeback_label, T_NEAR); + } + + L(skip_loop_label); +} + +template <cpu_isa_t isa> +inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_loop( + int unroll_w, int l_pad, int pad_offset, int ow_block) { + + const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ? + jcp.ih / jcp.stride_h - 1 : + jcp.oh - jcp.b_pad - 1; + const int ch_offset = jcp.ch_block; + const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1; + const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1; + + Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label, + end_h_loop_label; + + mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]); + mov(reg_oh_worksize, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]); + mov(reg_kh_count, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]); + + mov(reg_tmp_output, reg_output_baddr); + mov(reg_tmp_input, reg_input_baddr); + mov(reg_tmp_filter, reg_filter_baddr); + + L(h_loop_label); + { + + compute_h_step(unroll_w, l_pad, pad_offset, ow_block); + + add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float)); + + /* If within the top_pad region */ + if (jcp.t_pad > 0) { + /* Skip t_pad area if no longer in initial h_block */ + cmp(reg_oh, jcp.t_pad); + jg(skip_tpad_label, T_NEAR); + + cmp(reg_kh_count, jcp.kh); + jge(skip_tpad_label, T_NEAR); + + add(reg_kh_count, t_overlap_off); + sub(reg_tmp_filter, + t_overlap_off * jcp.kw * ch_offset * sizeof(float)); + + /* kernel has moved beyond padding (adjust for stride effects) */ + if (jcp.t_pad % jcp.stride_h != 0) { + int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h; + add(reg_tmp_input, + inp_corr * jcp.iw * ch_offset * sizeof(float)); + } + jmp(tpad_loop_label, T_NEAR); + } + + L(skip_tpad_label); + + cmp(reg_oh, io_overlap); + jl(skip_bpad_label, T_NEAR); + sub(reg_kh_count, b_overlap_off); + + L(skip_bpad_label); + add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float)); + + L(tpad_loop_label); + + cmp(reg_oh, jcp.ih / jcp.stride_h); + jge(end_h_loop_label, T_NEAR); + + inc(reg_oh); + + cmp(reg_oh, reg_oh_worksize); + jl(h_loop_label, T_NEAR); + } + L(end_h_loop_label); +} + +template <cpu_isa_t isa> +inline void +jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() { + + const int ch_offset = jcp.ch_block; + int ow = jcp.ow; + int pad_offset = 0; + int l_pad = jcp.l_pad; + + /* Calculate effective padding */ + int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w + + (jcp.kw - 1) * (jcp.dilate_w + 1) + - (jcp.iw + jcp.l_pad - 1)); + + /* Is this strictly defined by: + * -code-size (?) + * -address size (?) */ + const int max_unroll_w = 30; + const int block_size = 15; + + int unroll_w_tail = 0; + int unroll_w = 0; + int unroll_w_trips = 0; + + if (jcp.ow > max_unroll_w) { + unroll_w = nstl::min(block_size, jcp.ow); + unroll_w_trips = ow / unroll_w; + /* calculate tail */ + unroll_w_tail = ow % unroll_w; + /* Perform some rebalancing if tail too small*/ + if ((unroll_w_tail == 0 && r_pad != 0) + || (r_pad > 0 && r_pad >= unroll_w_tail)) { + if (unroll_w_trips > 1) { + unroll_w_tail += unroll_w; + unroll_w_trips--; + } else { + /* Idealy, this case shouldn't happen */ + unroll_w_tail += (unroll_w - unroll_w / 2); + unroll_w = unroll_w / 2; + } + } + } else { + unroll_w = jcp.ow; + unroll_w_trips = nstl::max(1, ow / unroll_w); + } + if (jcp.with_bias) { + Label skip_load_bias; + mov(reg_bias_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]); + + zero_bias(); + + mov(reg_exec_flags, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]); + and_(reg_exec_flags, FLAG_ZERO_BIAS); + test(reg_exec_flags, reg_exec_flags); + jne(skip_load_bias); + + load_bias(); + + L(skip_load_bias); + compute_bias_loop(block_size); + + store_bias(); + } + + /* Pass filter address, then offset for h_padding. */ + compute_zero_filter(); + mov(reg_kh_offset, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]); + add(reg_filter_baddr, reg_kh_offset); + + /* compute left padded block */ + if (l_pad) { + compute_h_loop(unroll_w, l_pad, 0, 0); + add(reg_output_baddr, unroll_w * ch_offset * sizeof(float)); + add(reg_input_baddr, + unroll_w * jcp.stride_w * ch_offset * sizeof(float)); + unroll_w_trips--; + pad_offset = l_pad; + l_pad = 0; + } + + /* compute middle block */ + Label ow_blk_label; + + /* Insert loop for 'ow' block when middle block needs to execute more + * than once */ + bool do_ow_blk_loop = unroll_w_trips > 1; + if (do_ow_blk_loop) { + mov(iter_ow_blk, unroll_w_trips); + L(ow_blk_label); + } + if (unroll_w_trips > 0) { + compute_h_loop(unroll_w, l_pad, pad_offset, 0); + add(reg_output_baddr, unroll_w * ch_offset * sizeof(float)); + add(reg_input_baddr, + unroll_w * jcp.stride_w * ch_offset * sizeof(float)); + } + if (do_ow_blk_loop) { + dec(iter_ow_blk); + cmp(iter_ow_blk, 0); + jg(ow_blk_label, T_NEAR); + } + + /* compute right padded block */ + if (unroll_w_tail) { + compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail); + } +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::generate() { + preamble(); + + mov(reg_input_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, input)]); + mov(reg_output_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]); + mov(reg_filter_baddr, + ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]); + + compute_ow_block_unroll(); + + this->postamble(); +} + +template <cpu_isa_t isa> +status_t jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_conf( + jit_conv_conf_t &jcp, const convolution_desc_t &cd, + const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d, int nthreads) { + if (!mayiuse(isa)) + return status::unimplemented; + + jcp.ngroups = diff_weights_d.dims()[0]; + jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups; + jcp.ic = src_d.dims()[1] / jcp.ngroups; + + const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; + + jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic); + + if (!jcp.is_depthwise) + return status::unimplemented; + + jcp.ch_block = isa == avx512_common ? 16 : 8; + + jcp.mb = src_d.dims()[0]; + + jcp.ih = src_d.dims()[2]; + jcp.iw = src_d.dims()[3]; + jcp.oh = diff_dst_d.dims()[2]; + jcp.ow = diff_dst_d.dims()[3]; + + jcp.kh = diff_weights_d.dims()[3]; + jcp.kw = diff_weights_d.dims()[4]; + + jcp.stride_h = cd.strides[0]; + jcp.stride_w = cd.strides[1]; + + jcp.t_pad = cd.padding[0][0]; + jcp.b_pad = cd.padding[1][0]; + + jcp.l_pad = cd.padding[0][1]; + jcp.r_pad = cd.padding[1][1]; + + jcp.dilate_h = cd.dilates[0]; + jcp.dilate_w = cd.dilates[1]; + + jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; + jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; + + jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + jcp.src_tag = src_d.matches_one_of_tag(dat_tag); + jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); + jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag); + + bool args_ok = true + && jcp.src_tag == dat_tag + && jcp.wei_tag == wei_tag + && jcp.dst_tag == dat_tag + && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0 + && jcp.dilate_w == 0 && jcp.kw <= 3 + && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1 + && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1; + if (!args_ok) + return status::unimplemented; + + jcp.nb_ch = jcp.ngroups / jcp.ch_block; + + /* kernel applicability check wrt boundaries + * the conditions are quite general across the kernels we have, + * but ideally the check should belong to a specific kernel... */ + const int max_hpad = (jcp.kh - 1 + 1) / 2; + const int max_wpad = (jcp.kw - 1 + 1) / 2; + const bool boundaries_ok = true && jcp.t_pad <= max_hpad + && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad + && jcp.r_pad <= max_wpad; + if (!boundaries_ok) + return status::unimplemented; + + balance(jcp, nthreads); + + return status::success; +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_scratchpad( + memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) { + /* Notes: if splitting thread work on 'mb', then a reduction has to take + * place. Hence, book a per-thread, local weights-buffer for the + * reduction */ + if (jcp.nthr_mb > 1) { + const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; + scratchpad.book(key_conv_wei_reduction, + sizeof(float) * wei_size * (jcp.nthr_mb - 1)); + + if (jcp.with_bias) + scratchpad.book(key_conv_bia_reduction, + sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1)); + } +} + +template <cpu_isa_t isa> +void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::balance(jit_conv_conf_t &jcp, + int nthreads) { + jcp.nthr = nthreads; + jcp.nthr_g = jcp.nthr_mb = 1; + + /* Basic-Heuristics for parallel strategy: + * 1) Tries to parallel on the number of Groups (g) where tasks are + * independent. Otherwise, + * 2) Tries to split the work across g and MiniBatch (mb). + * Parallelizing on mb requires computing a reduction for weights. + * + * NOTE: because of 'task partitioning' scheme, there will be unbalanced + * per-thread load when the number of threads is high (e.g. > 16). + */ + jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr); + jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb); + + jcp.nthr = jcp.nthr_g * jcp.nthr_mb; +} + +template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_common>; +template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>; +template struct jit_uni_dw_conv_bwd_weights_kernel_f32<sse42>; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp new file mode 100644 index 0000000000..9c08fc4a09 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp @@ -0,0 +1,253 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_DW_CONV_KERNEL_F32_HPP +#define JIT_UNI_DW_CONV_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" +#include "jit_uni_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <cpu_isa_t isa> +struct jit_uni_dw_conv_fwd_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32) + + jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp) + : jcp(ajcp), eltwise_injector_(nullptr) + { + if (jcp.with_eltwise) + eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this, + jcp.eltwise); + + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + ~jit_uni_dw_conv_fwd_kernel_f32() { + delete eltwise_injector_; + } + + static bool post_ops_ok(jit_conv_conf_t &jcp, + const primitive_attr_t &attr); + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm, + isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type; + using reg64_t = const Xbyak::Reg64; + const Xbyak::AddressFrame &vmmword = (isa == sse42) + ? xword : (isa == avx2) ? yword : zword; + const int vlen = cpu_isa_traits<isa>::vlen; + + // dw convolution + reg64_t reg_input = r8; + reg64_t aux_reg_input = r9; + reg64_t aux1_reg_input = r10; + reg64_t reg_kernel = r11; + reg64_t aux_reg_kernel = r12; + reg64_t aux1_reg_kernel = r13; + reg64_t reg_output = r14; + reg64_t reg_bias = r15; + reg64_t reg_kh = rax; + reg64_t reg_kw = rbx; + reg64_t iter_kh = rdx; + reg64_t iter_kw = rsi; + reg64_t reg_ur_w = rbp; + reg64_t reg_ch_blocks = aux1_reg_input; + reg64_t imm_addr64 = aux1_reg_input; + + inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } + inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } + + inline void load_src(int ur_ch_blocks, int ur_w); + inline void apply_filter(int ur_ch_blocks, int ur_w); + inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w); + inline void apply_activation(int ur_ch_blocks, int ur_w); + inline void store_dst(int ur_ch_blocks, int ur_w); + inline void loop_body(int ur_ch_blocks); + + jit_uni_eltwise_injector_f32<isa> *eltwise_injector_; + + void generate(); +}; + +template <cpu_isa_t isa> +struct jit_uni_dw_conv_bwd_data_kernel_f32: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32) + + jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) { + this->generate(); + jit_ker = (void (*)(jit_conv_call_s *))this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, + const memory_desc_wrapper &diff_src_d, + const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &diff_dst_d); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_conv_call_s *); + +private: + using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm, + isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type; + using reg64_t = const Xbyak::Reg64; + + inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); } + inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); } + + reg64_t reg_ddst = rax; + reg64_t aux_reg_ddst = r8; + reg64_t aux1_reg_ddst = abi_not_param1; + reg64_t reg_kernel = rdx; + reg64_t aux_reg_kernel = r10; + reg64_t aux1_reg_kernel = rbp; + reg64_t reg_dsrc = rsi; + + reg64_t reg_ur_str_w = r9; + reg64_t reg_ch_blocks = rbx; + + reg64_t iter_kh = r11; + reg64_t iter_kw = r12; + reg64_t reg_kh = r13; + reg64_t reg_kw = r14; + + inline void loop_body(int ur_ch_blocks); + inline void load_ddst(int ur_ch_blocks, int ur_str_w); + inline void apply_filter(int ur_ch_blocks, int ur_str_w); + inline void store_dsrc(int ur_ch_blocks, int ur_str_w); + + void generate(); +}; + +template <cpu_isa_t isa> +struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator { + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32) + + jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) { + this->generate(); + jit_ker = (void (*)(jit_dw_conv_call_s *)) this->getCode(); + } + + static status_t init_conf(jit_conv_conf_t &jcp, + const convolution_desc_t &cd, const memory_desc_wrapper &src_d, + const memory_desc_wrapper &diff_weights_d, + const memory_desc_wrapper &diff_dst_d, int nthreads); + + static void init_scratchpad(memory_tracking::registrar_t &scratchpad, + const jit_conv_conf_t &jcp); + + static void balance(jit_conv_conf_t &jcp, int nthreads); + + jit_conv_conf_t jcp; + void (*jit_ker)(jit_dw_conv_call_s *); + +private: + using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm, + isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type; + using reg64_t = const Xbyak::Reg64; + const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float); + const int reg_repeats = (isa == sse42) ? 2 : 1; + + const Xbyak::AddressFrame &vmmword + = (isa == sse42) ? xword : (isa == avx2) ? yword : zword; + + /* XXX: offset between input and accummulators is 3, therefore, assume 'kw' + * is no larger than 3*/ + inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); } + inline Vmm get_output_reg(int idx) { return Vmm(idx + 1); } + inline Vmm get_input_reg(int idx) { return Vmm(idx + 4 * reg_repeats + 1); } + inline Vmm get_acc_reg(int idx) { return Vmm(idx + 1 * reg_repeats + 1); } + inline Vmm get_aux_reg() { return Vmm(0); } + + reg64_t reg_tmp_input = r9; + reg64_t reg_tmp_output = r10; + reg64_t reg_tmp_filter = r13; + reg64_t reg_kh_offset = rax; + + /* parameter passed by driver into kernel */ + Xbyak::Reg8 reg_exec_flags = bl; + + reg64_t reg_oh_worksize = r14; + reg64_t reg_oh = rax; + + reg64_t iter_ow_blk = r11; + + reg64_t reg_kh = rsi; + reg64_t reg_kh_count = rdx; + + /* Base addresses for convolution parameters. */ + reg64_t reg_input_baddr = r15; + reg64_t reg_output_baddr = r12; + reg64_t reg_filter_baddr = abi_not_param1; + reg64_t reg_bias_baddr = r13; + + /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs + */ + inline void compute_ow_step_unroll( + int unroll_w, int l_pad, int pad_offset, int ow_block); + + /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */ + inline void compute_h_step( + int unroll_w, int l_pad, int pad_offset, int ow_block); + inline void compute_h_loop( + int unroll_w, int l_pad, int pad_offset, int ow_block); + + /* Write 'width' micro-kernel JITs; depending on the padding and convolution + * size, write a micro-kernel for the left ow-block, middle ow-block(s), and + * right ow-block.*/ + inline void compute_ow_block_unroll(); + + inline void compute_zero_filter(); + inline void load_filter(); + inline void zero_filter(); + inline void load_bias(); + inline void zero_bias(); + inline void compute_bias_step_unroll(const int unroll_w); + inline void compute_bias_loop(const int block_size); + inline void store_filter(); + inline void store_bias(); + + void generate(); +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp new file mode 100644 index 0000000000..58449601a3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp @@ -0,0 +1,427 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "mkldnn_thread.hpp" + +#include "jit_uni_dw_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::memory_tracking::names; +using namespace mkldnn::impl::utils; + +template <cpu_isa_t isa> +void _jit_uni_dw_convolution_fwd_t<isa>::execute_forward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const auto &jcp = kernel_->jcp; + + if (pd()->wants_padded_bias()) { + auto padded_bias = this->scratchpad(ctx).template get<data_t>( + key_conv_padded_bias); + utils::array_copy(padded_bias, bias, jcp.oc_without_padding); + utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, + jcp.oc - jcp.oc_without_padding); + bias = padded_bias; + } + + int dil_h = jcp.dilate_h + 1; + int dil_w = jcp.dilate_w + 1; + int str_h = jcp.stride_h; + int str_w = jcp.stride_w; + + auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh, + int kh_padding, int ch, int ch_num, int n) { + auto par_conv = jit_conv_call_s(); + + const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w)); + const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w + + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw; + + const int iw = nstl::max((ow*str_w - jcp.l_pad + + div_up(i_l_overflow, dil_w)*dil_w), 0); + const int kw = div_up(i_l_overflow, dil_w); + + const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w) + - div_up(i_r_overflow, dil_w); + + par_conv.src = &src[src_d.blk_off(n, ch, ih, iw)]; + par_conv.dst = &dst[dst_d.blk_off(n, ch, oh, ow)]; + + par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)]; + if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block)]; + + par_conv.kh_padding = (size_t)nstl::max(0, kh_padding); + par_conv.kw_padding = (size_t)nstl::max(0, kw_padding); + + par_conv.ur_w = (size_t)ur_w_step; + + par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; + + return par_conv; + }; + + const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); + parallel_nd(jcp.mb, chb_work, jcp.oh, + [&](int n, int chb, int oh) { + int ch = chb * jcp.nb_ch_blocking; + int ch_num = jcp.nb_ch_blocking; + + const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h)); + const int i_b_overflow = nstl::max(jcp.ih, + (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih; + + const int ih = nstl::max((int)(oh*str_h - jcp.t_pad + + div_up(i_t_overflow, dil_h)*dil_h), 0); + const int kh = div_up(i_t_overflow, dil_h); + const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h) + - div_up(i_b_overflow, dil_h); + + // left border + int ow = 0; + int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow); + int ur_w_step = 1; + for (; ow < l_border; ow++) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + + // main loop + ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1) + / jcp.stride_w - ow + 1; + if (ur_w_step > 0) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + + ow += ur_w_step; + } + + // right border + ur_w_step = 1; + for (; ow < jcp.ow; ow++) { + jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih, + kh, kh_padding, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + }); + + if (pd()->wants_zero_pad_dst()) + ctx.memory(MKLDNN_ARG_DST)->zero_pad(); +} + +template struct _jit_uni_dw_convolution_fwd_t<avx512_common>; +template struct _jit_uni_dw_convolution_fwd_t<avx2>; +template struct _jit_uni_dw_convolution_fwd_t<sse42>; + +template <cpu_isa_t isa> +void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + + const auto &jcp = kernel_->jcp; + + auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih, + int i_t_overflow, int i_b_overflow, int stride_off_h, + int ch, int ch_num, int n) { + auto par_conv = jit_conv_call_s(); + + const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad)); + const int i_r_overflow = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw) + - jcp.r_pad)); + + int ow = iw + jcp.l_pad - i_r_overflow; + int stride_off_w = ow % jcp.stride_w; + ow /= jcp.stride_w; + + par_conv.src = &diff_src[diff_src_d.blk_off(n, ch, ih, iw)]; + par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, ch, oh, ow)]; + par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, i_b_overflow + + stride_off_h, i_r_overflow + stride_off_w)]; + + par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow + - stride_off_h); + par_conv.kw_padding = nstl::max(0, jcp.kw - i_l_overflow - i_r_overflow + - stride_off_w); + + par_conv.ur_str_w = ur_str_w; + + par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch; + + return par_conv; + }; + + const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); + parallel_nd(jcp.mb, chb_work, jcp.ih, + [&](int n, int chb, int ih) { + int ch = chb * jcp.nb_ch_blocking; + int ch_num = jcp.nb_ch_blocking; + + const int i_t_overflow = nstl::max(0, (int)(jcp.kh - 1 - ih + - jcp.t_pad)); + const int i_b_overflow = nstl::max(0, (int)(jcp.kh - 1 + - (jcp.ih - 1 - ih) - jcp.b_pad)); + + int oh = ih + jcp.t_pad - i_b_overflow; + int stride_off_h = oh % jcp.stride_h; + oh /= jcp.stride_h; + + for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) { + // left border + int iw = i_str_w; + int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw); + int ur_str_w = 1; + for (; iw < l_border; iw += jcp.stride_w) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + + // main loop + ur_str_w = nstl::min((jcp.iw - jcp.kw + jcp.r_pad - iw) + / jcp.stride_w, jcp.iw); + if (ur_str_w > 0) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + + iw += ur_str_w * jcp.stride_w; + } + + // right border + ur_str_w = 1; + for (; iw < jcp.iw; iw += jcp.stride_w) { + jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, + ih, i_t_overflow, i_b_overflow, + stride_off_h, ch, ch_num, n); + + kernel_->jit_ker(&par_conv); + } + } + }); +} + +template struct _jit_uni_dw_convolution_bwd_data_t<avx512_common>; +template struct _jit_uni_dw_convolution_bwd_data_t<avx2>; +template struct _jit_uni_dw_convolution_bwd_data_t<sse42>; + +template <cpu_isa_t isa> +_jit_uni_dw_convolution_bwd_weights_t<isa>:: +_jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd) + : cpu_primitive_t(apd) + , kernel_(nullptr), acc_ker_(nullptr) +{ + kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32<isa>(pd()->jcp_); + if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction()) + acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>(); +} + +template <cpu_isa_t isa> +void _jit_uni_dw_convolution_bwd_weights_t<isa>::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + auto diff_wei_reduction_buf = + scratchpad(ctx).template get<data_t>(key_conv_wei_reduction); + auto diff_bia_reduction_buf = + scratchpad(ctx).template get<data_t>(key_conv_bia_reduction); + + const auto &jcp = kernel_->jcp; + + /* Used when executing a parallel reduction */ + simple_barrier::ctx_t reduction_bctx; + simple_barrier::ctx_init(&reduction_bctx); + + const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; + const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0; + + const int ch_block = jcp.ch_block; + + auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params, + const int batch, const int group, const int oh_start, + const int work_size, const unsigned char exec_flag, + const size_t kh_padding, const size_t filter_off) { + const int tpad_underflow_off = jcp.t_pad - filter_off; + + conv_params->exec_flags = exec_flag; + conv_params->kh_count = jcp.kh - kh_padding; + + const int oh_s = oh_start; + const int oh_e = oh_start + work_size; + const int ih_s = oh_s * jcp.stride_h; + + conv_params->filter_pad_off + = filter_off * jcp.kw * ch_block * sizeof(float); + conv_params->oh_index = oh_s; + conv_params->oh_count = oh_e; + + size_t diff_dst_off + = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh + + oh_start) + * jcp.ow; + + size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih + + ih_s - tpad_underflow_off) * jcp.iw; + + conv_params->output = &diff_dst[diff_dst_off * ch_block]; + conv_params->input = &src[src_off * ch_block]; + }; + + parallel(jcp.nthr, [&](const int ithr, const int nthr) { + assert(nthr == jcp.nthr); + + auto conv_params = jit_dw_conv_call_s(); + const int h_block_size = 15; + + /* assign iteration space to thread */ + const int ithr_g = ithr % jcp.nthr_g; + const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb; + + /* split dimensions */ + int g_start{ 0 }, g_end{ 0 }; + balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end); + + int mb_start{ 0 }, mb_end{ 0 }; + balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end); + + auto diff_wei = ithr_mb == 0 + ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size; + auto diff_bia = ithr_mb == 0 + ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size; + + for (int g = g_start; g < g_end; ++g) { + unsigned char zero_filter_flag = FLAG_ZERO_FILTER; + unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0; + + size_t diff_wei_off = g * jcp.kh * jcp.kw; + conv_params.filter = &diff_wei[diff_wei_off * ch_block]; + + if (jcp.with_bias) + conv_params.bias = &diff_bia[g * ch_block]; + + for (int mb = mb_start; mb < mb_end; ++mb) { + int oh = 0; + while (oh < jcp.oh) { + const int h_work = nstl::min(h_block_size, jcp.oh - oh); + auto kh_t_padding = nstl::max(0, jcp.t_pad - oh); + auto kh_b_padding + = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ? + jcp.b_pad - (h_work - 1) : + 0; + + set_kernel_params(&conv_params, mb, g, oh, h_work, + zero_filter_flag | zero_bias_flag, + kh_t_padding + kh_b_padding, kh_t_padding); + kernel_->jit_ker(&conv_params); + + zero_bias_flag &= ~FLAG_ZERO_BIAS; + zero_filter_flag &= ~FLAG_ZERO_FILTER; + oh += h_work; + } + } + } + + if (do_parallel_reduction() && jcp.nthr_mb > 1) { + size_t reduct_start{ 0 }, reduct_end{ 0 }; + balance211(wei_size, nthr, ithr, reduct_start, reduct_end); + + const int acc_size = reduct_end - reduct_start; + const size_t reduct_off = reduct_start; + auto *acc_data = diff_weights + reduct_off; + + simple_barrier::barrier(&reduction_bctx, nthr); + + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + auto *src_data = diff_wei_reduction_buf + + (thr_mb - 1) * wei_size + reduct_off; + acc_ker_->accumulate(acc_data, src_data, acc_size); + } + } + }); + + if (jcp.nthr_mb <= 1) return; + + /* Apply single-threaded 'mb' reduction */ + for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { + size_t mb_accum_offset = (thr_mb - 1) * wei_size; + size_t b_accum_offset = (thr_mb - 1) * bias_size; + + for (int g = 0; g < jcp.nb_ch; ++g) { + /* Reduction on Bias */ + if (jcp.with_bias) { + PRAGMA_OMP_SIMD() + for (int g_block = 0; g_block < ch_block; ++g_block) { + size_t bias_offset = g * ch_block + g_block; + diff_bias[bias_offset] += diff_bia_reduction_buf[ + b_accum_offset + bias_offset]; + } + } + + if (do_parallel_reduction()) continue; + + for (int kh = 0; kh < jcp.kh; ++kh) + for (int kw = 0; kw < jcp.kw; ++kw) + { + size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw; + PRAGMA_OMP_SIMD() + for (int g_block = 0; g_block < ch_block; ++g_block) { + const size_t off = wei_offset * ch_block + g_block; + diff_weights[off] += + diff_wei_reduction_buf[mb_accum_offset + off]; + } + } + } + } +} + +template struct _jit_uni_dw_convolution_bwd_weights_t<avx512_common>; +template struct _jit_uni_dw_convolution_bwd_weights_t<avx2>; +template struct _jit_uni_dw_convolution_bwd_weights_t<sse42>; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp new file mode 100644 index 0000000000..ca53749ec2 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp @@ -0,0 +1,266 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_DW_CONVOLUTION_HPP +#define CPU_JIT_UNI_DW_CONVOLUTION_HPP + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" + +#include "cpu_barrier.hpp" +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" +#include "cpu_reducer.hpp" + +#include "jit_uni_dw_conv_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <cpu_isa_t isa> +struct _jit_uni_dw_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + pd_t(engine_t *engine, const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_fwd_t<isa>); + + status_t init() { + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + status_t status = jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf( + jcp_, *desc(), src_md(), *weights_md(), *dst_md(), *attr()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_fwd_kernel_f32<isa>::init_scratchpad(scratchpad, + jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_dw_conv_fwd_kernel_f32<isa>(pd()->jcp_); } + + ~_jit_uni_dw_convolution_fwd_t() { delete kernel_; } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_fwd_kernel_f32<isa> *kernel_; +}; + +using jit_avx512_common_dw_convolution_fwd_t = + _jit_uni_dw_convolution_fwd_t<avx512_common>; +using jit_avx2_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t<avx2>; +using jit_sse42_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t<sse42>; + +template <cpu_isa_t isa> +struct _jit_uni_dw_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() + {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::undef, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + + if (!ok) return status::unimplemented; + + status_t status = jit_uni_dw_conv_bwd_data_kernel_f32<isa>:: + init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(), + *diff_dst_md()); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_dw_conv_bwd_data_kernel_f32<isa>(pd()->jcp_); } + ~_jit_uni_dw_convolution_bwd_data_t() { delete kernel_; }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_bwd_data_kernel_f32<isa> *kernel_; +}; + +using jit_avx512_common_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t<avx512_common>; +using jit_avx2_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t<avx2>; +using jit_sse42_dw_convolution_bwd_data_t = + _jit_uni_dw_convolution_bwd_data_t<sse42>; + +template <cpu_isa_t isa> +struct _jit_uni_dw_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const convolution_desc_t *adesc, + const primitive_attr_t *attr, + const convolution_fwd_pd_t *hint_fwd_pd) + : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , jcp_() {} + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""), + _jit_uni_dw_convolution_bwd_weights_t<isa>); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(data_type::f32, data_type::f32, + data_type::f32, data_type::f32, data_type::f32) + && !has_zero_dim_memory() + && set_default_formats(); + if (!ok) return status::unimplemented; + + const int max_threads = mkldnn_in_parallel() + ? 1 : mkldnn_get_max_threads(); + + status_t status = jit_uni_dw_conv_bwd_weights_kernel_f32<isa>:: + init_conf(jcp_, *desc(), *src_md(), *diff_weights_md(), + *diff_dst_md(), max_threads); + if (status != status::success) return status; + + auto scratchpad = scratchpad_registry().registrar(); + jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_scratchpad( + scratchpad, jcp_); + + return status::success; + } + + jit_conv_conf_t jcp_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + auto dat_tag = isa == avx512_common ? nChw16c : nChw8c; + auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g; + + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + _jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd); + ~_jit_uni_dw_convolution_bwd_weights_t() { + delete kernel_; + delete acc_ker_; + }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + bool do_parallel_reduction() const { return false; } + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_dw_conv_bwd_weights_kernel_f32<isa> *kernel_; + cpu_accumulator_1d_t<data_type::f32> *acc_ker_; +}; + +using jit_avx512_common_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t<avx512_common>; +using jit_avx2_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t<avx2>; +using jit_sse42_dw_convolution_bwd_weights_t = + _jit_uni_dw_convolution_bwd_weights_t<sse42>; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp new file mode 100644 index 0000000000..2af6435871 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp @@ -0,0 +1,1142 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_uni_eltwise.hpp" + +#define GET_OFF(field) offsetof(jit_args, field) + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx, + size_t end_idx) { + preserved_vecs_count = 0; + vecs_to_preserve = (size_t)aux_vecs_count(alg_); + start_idx_tail = start_idx; + + // For sse42 mask register has to be Xmm(0) + if (isa == sse42 && vecs_to_preserve > 0) { + size_t idx = 0; + assert(idx < start_idx); + preserved_vec_idxs[preserved_vecs_count++] = idx; + } + + for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) { + if (preserved_vecs_count >= vecs_to_preserve) break; + if (start_idx <= idx && idx < end_idx) continue; + + preserved_vec_idxs[preserved_vecs_count++] = idx; + } + + size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count; + for (size_t i = 0; i < preserved_vecs_count_tail; i++) { + preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++; + } + + assert(preserved_vecs_count == vecs_to_preserve); + + if (save_state_) { + h->push(p_table); + + if (preserved_vecs_count) + h->sub(h->rsp, preserved_vecs_count * vlen); + + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], + Vmm(preserved_vec_idxs[i])); + + load_table_addr(); + } + + assign_regs(); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx) +{ + size_t tail_vecs_to_preserve = start_idx_tail - start_idx; + if (tail_vecs_to_preserve == 0) return; + + const int idx_off = vecs_to_preserve - tail_vecs_to_preserve; + + if (save_state_) { + if (idx_off) + h->add(h->rsp, idx_off * vlen); + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), + h->ptr[h->rsp + i * vlen]); + } + + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve; + + if (save_state_) { + for (size_t i = 0; i < tail_vecs_to_preserve; ++i) + h->uni_vmovups(h->ptr[h->rsp + i * vlen], + Vmm(preserved_vec_idxs[idx_off + i])); + + if (idx_off) + h->sub(h->rsp, idx_off * vlen); + } + + assign_regs(); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::injector_postamble() { + if (!save_state_) return; + + for (size_t i = 0; i < preserved_vecs_count; ++i) + h->uni_vmovups(Vmm(preserved_vec_idxs[i]), + h->ptr[h->rsp + i * vlen]); + + if (preserved_vecs_count) + h->add(h->rsp, preserved_vecs_count * vlen); + + h->pop(p_table); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::assign_regs() { + vmm_mask = Vmm(preserved_vec_idxs[0]); + vmm_aux0 = Vmm(preserved_vec_idxs[0]); + vmm_aux1 = Vmm(preserved_vec_idxs[1]); + vmm_aux2 = Vmm(preserved_vec_idxs[2]); + vmm_aux3 = Vmm(preserved_vec_idxs[3]); + vmm_aux4 = Vmm(preserved_vec_idxs[4]); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) { + h->uni_vminps(vmm_src, vmm_src, table_val(10)); + h->uni_vmaxps(vmm_src, vmm_src, table_val(11)); + h->uni_vmovups(vmm_aux0, vmm_src); + //calculate exp(x) + // fx = x * log2ef + 0.5 + h->uni_vmulps(vmm_src, vmm_src, table_val(2)); + h->uni_vaddps(vmm_src, vmm_src, table_val(1)); + + // tmp = floorf(fx) + if (isa == avx512_common) { + h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src); + h->vcvtdq2ps(vmm_aux1, vmm_aux1); + + h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us); + h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0)); + + h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3); + } else { + h->uni_vroundps(vmm_aux1, vmm_src, _op_floor); + } + + //keep fx for further computations + h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx + + //x = x - fx * ln2 + h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3)); + + // compute 2^n + h->uni_vcvtps2dq(vmm_aux1, vmm_src); + h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); + h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx + + // y = p5 + h->uni_vmovups(vmm_src, table_val(9)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0)); + // y = y * x + p0 + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5)); //exp(q) + // y = y * 2^n + h->uni_vmulps(vmm_src, vmm_src, vmm_aux1); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(const Vmm &vmm_src) +{ + const int alpha_off = 0, zero_off = 1; + + h->uni_vmovups(vmm_aux1, vmm_src); + if (isa == sse42) { + h->movups(vmm_mask, vmm_src); + h->mulps(vmm_src, table_val(alpha_off)); + h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us); + h->blendvps(vmm_src, vmm_aux1); + } else if (isa == avx2) { + h->vmulps(vmm_src, vmm_src, table_val(alpha_off)); + h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off)); + h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); + } else if (isa == avx512_common) { + h->vmulps(vmm_src, vmm_src, table_val(alpha_off)); + h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); + } +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector( + const Vmm &vmm_src) { + const int zero_off = 1; + h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off)); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) { + const int alpha_off = 23, zero_off = 24; + + // compute exponent + h->uni_vmovups(vmm_aux2, vmm_src); + exp_compute_vector(vmm_src); + + // alpha * (exp(x) - 1) + h->uni_vsubps(vmm_src, vmm_src, table_val(0)); + h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off)); + + // combine with mask + if (isa == sse42) { + h->pxor(vmm_mask, vmm_mask); + h->cmpps(vmm_mask, vmm_aux2, _cmp_le_os); + h->blendvps(vmm_src, vmm_aux2); + } else if (isa == avx2) { + h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off)); + h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask); + } else if (isa == avx512_common) { + h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2); + } +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(const Vmm &vmm_src) +{ + // # comes from Taylor expansion error bound + // > linear_sat_point = single(sqrt(3) * 1b-12); + // # comes from the exp formula cancellation + // > exp_bound_point = (single(log(3)/2)); + // # comes from rounding accuracy in float + // > one_sat_point = round(atanh(1 - 1b-25), single, RU); + // > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |], + // [linear_sat_point, exp_bound_point], relative, floating); + // > err_bound = D(sup(supnorm(P, tanh(x), + // [linear_sat_point, exp_bound_point], relative, theta))); + // 0x1.fffd6f00b9539p-25 + // > P; + // x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 * + // (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5 + // + x^0x1p1 * 0x1.09fa1p-6)))) + + // register mapping + // vmm_src contains input + // vmm_aux0 contains mask of currently valid results. + // 1 is need computation, 0 is already computed + // vmm_aux1 contains current output + // vmm_aux2, vmm_aux3 contains auxiliary values + // vmm_aux4 contains the original sign of inputs + + Label end_tanh_label; + + auto test_exit =[&](Xbyak::Address threshold){ + // is not necessary for >AVX, but should not matter on perf + h->uni_vmovups(vmm_aux0, vmm_src); + if (isa == avx512_common){ + h->vcmpps(k_mask, vmm_aux0, threshold, 0x5); + h->kortestw(k_mask, k_mask); + } else { + h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold); + h->uni_vtestps(vmm_aux0, vmm_aux0); + } + h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR); + }; + + auto blend_results=[&](Vmm vmm_partial_res){ + if (isa == avx512_common) + h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res); + else + h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0); + }; + + // because tanh(x) = -tanh(-x), we extract sign to make x postive + // and reapply sign at the end + // mov is not necessary for >AVX, but should not matter for performance + h->uni_vmovups(vmm_aux4, vmm_src); + h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12)); + h->uni_vandps(vmm_src, vmm_src, table_val(17)); + + // if x < linear_sat_point for all inputs, we just return the input + h->uni_vmovups(vmm_aux1, vmm_src); + test_exit(table_val(13)); + + // if one of the mask is one, we have to compute an better approx + h->uni_vmovups(vmm_aux2, vmm_src); + h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2); + h->uni_vmovups(vmm_aux3, table_val(22)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19)); + h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18)); + h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src); + + // we blend only the result that need update + blend_results(vmm_aux3); + + // if x < exp_bound_point, we go to return point + test_exit(table_val(14)); + + // if not we use a better approx 1 - 2 / (1 + exp(2x)) + // compute 2x + h->uni_vmovups(vmm_aux3, vmm_src); + h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3); + + // Compute exp(2x) + // We need to save kmask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them + // vmm_src is not more read afterwards, so we do not have to save it + auto stack_size = 3 * vlen + (isa == avx512_common) * 4; + h->sub(h->rsp, stack_size); + h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0); + h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1); + h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src); + if (isa == avx512_common) + h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask); + + exp_compute_vector(vmm_aux3); + + h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]); + h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]); + h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]); + if (isa == avx512_common) + h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]); + h->add(h->rsp, stack_size); + + // 1 + exp(2x) + h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0)); + + // 1 - 2 / (1 + exp(2x)) + h->uni_vmovups(vmm_aux2, table_val(16)); + h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3); + h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0)); + + // we blend only the result that need update + blend_results(vmm_aux2); + + // finally, we saturate to 1 if needed + // TODO: maybe move that up if most inputs saturate in practice + if (isa == avx512_common) + h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5); + else { + h->uni_vmovups(vmm_aux0, vmm_src); + h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15)); + } + h->uni_vmovups(vmm_aux2, table_val(0)); + blend_results(vmm_aux2); + + h->L(end_tanh_label); + { + // we apply the sign of x to the result and we are done + h->uni_vmovups(vmm_src, vmm_aux1); + h->uni_vpxor(vmm_src, vmm_src, vmm_aux4); + } +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::square_compute_vector( + const Vmm &vmm_src) { + h->uni_vmulps(vmm_src, vmm_src, vmm_src); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) { + // compute abs(x) = _mm_and_ps(x, 01111..111)); + h->uni_vandps(vmm_src, vmm_src, table_val(0)); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(const Vmm &vmm_src) +{ + if (isa == avx512_common) { + h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us); + h->uni_vsqrtps(vmm_aux1, vmm_src); + h->uni_vmovups(vmm_src, table_val(0)); + h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1); + } else { + h->uni_vmovups(vmm_mask, vmm_src); + h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0)); + h->uni_vsqrtps(vmm_aux1, vmm_src); + h->uni_vmovups(vmm_src, table_val(0)); + h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask); + } +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector( + const Vmm &vmm_src) { + // compute x = alpha * x + beta; + h->uni_vmovups(vmm_aux0, table_val(0)); + h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1)); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector( + const Vmm &vmm_src) { + // compute bounded relu */ + h->uni_vmaxps(vmm_src, vmm_src, table_val(1)); + h->uni_vminps(vmm_src, vmm_src, table_val(0)); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector( + const Vmm &vmm_src) { + // duplicate src + h->uni_vmovups(vmm_aux2, vmm_src); + + h->uni_vminps(vmm_src, vmm_src, table_val(24)); + h->uni_vmaxps(vmm_src, vmm_src, table_val(25)); + h->uni_vmovups(vmm_aux1, vmm_src); + // calculate exp(x) + // fx = x * log2ef + 0.5 + h->uni_vmulps(vmm_src, vmm_src, table_val(2)); + h->uni_vaddps(vmm_src, vmm_src, table_val(1)); + + // tmp = floorf(fx) + if (isa == avx512_common) { + h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src); + h->vcvtdq2ps(vmm_aux0, vmm_aux0); + + h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us); + h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0)); + + h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3); + } else { + h->uni_vroundps(vmm_aux0, vmm_src, _op_floor); + } + + // keep fx for further computations + h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx + // calculation fx * ln2 + h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3)); + // x = x - fx * ln2 + h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0); + // y = p5 + h->uni_vmovups(vmm_aux3, table_val(22)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0)); + // y = y * x + p0 + h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17)); + + // compute 2^(-n) + if (isa == avx512_common) { + h->vmulps(vmm_aux1, vmm_src, table_val(23)); + h->vcvtps2dq(vmm_aux1, vmm_aux1); + } else { + h->uni_vcvtps2dq(vmm_aux1, vmm_src); + h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23)); + } + + h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4)); + h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx + // calculate ln(1 + y) + h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1); + // x = y; y is free; keep x for further computations + h->uni_vmovups(vmm_src, vmm_aux3); + // frexp() + h->uni_vpsrld(vmm_src, vmm_src, 23); + h->uni_vcvtdq2ps(vmm_src, vmm_src); + // got n. where n is x = 2^n * y. y = 0.5 .. 1 + h->uni_vsubps(vmm_src, vmm_src, table_val(5)); + + h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6)); + // got y. (mantisa) 0.5 < y < 1 + h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7)); + // y = y - 1 + h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0)); + // y = p8 + h->uni_vmovups(vmm_aux1, table_val(16)); + // y = y * x + p7 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15)); + // y = y * x + p6 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14)); + // y = y * x + p5 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13)); + // y = y * x + p4 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12)); + // y = y * x + p3 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11)); + // y = y * x + p2 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10)); + // y = y * x + p1 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9)); + // y = y * x + p0 ; p0 = 0 + h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8)); + //calculate ln(2) * n + h->uni_vmulps(vmm_src, vmm_src, table_val(3)); + h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src); + h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0); + + // get vmm_mask = src > max logf + h->uni_vmovups(vmm_mask, vmm_aux2); + if (isa == avx512_common) { + // y = (x < max log f) ? soft_relu(x) : x + h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us); + h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2); + } else { + // y = (x < max log f) ? soft_relu(x) : x + h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24)); + h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask); + } + + h->uni_vmovups(vmm_src, vmm_aux1); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector( + const Vmm &vmm_src) { + // we store the original sign and make x negative + // IMPORTANT: we assume vmm_aux0 to be xmm0, as for sse4.2 path it is required + // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it. + h->uni_vmovups(vmm_aux2, vmm_src); + h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12)); + h->uni_vorps(vmm_src, vmm_src, table_val(12)); + + exp_compute_vector(vmm_src); + // dup exp(x) + h->uni_vmovups(vmm_aux1, vmm_src); + // (exp(x) + 1) + h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0)); + // y = exp(x) / (exp(x) + 1) + h->uni_vdivps(vmm_src, vmm_src, vmm_aux1); + + // Now we have to apply the "symmetry" based on original sign + h->uni_vmovups(vmm_aux3, table_val(0)); + h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src); + if (isa == avx512_common) { + h->vptestmd(k_mask, vmm_aux2, vmm_aux2); + h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src); + } else { + h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2 + h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0); + } + h->uni_vmovups(vmm_src, vmm_aux3); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() { + const unsigned int cvals[] = { + 0x3f800000, // [0] 1.0f + 0x3f000000, // [1] 0.5f + 0x3fb8aa3b, // [2] log2ef = 1.44269502f + 0x3f317218, // [3] ln2f = 0.69314718f + 0x0000007f, // [4] 0x7f + // exp(x) polynom + 0x3f800001, // [5] p0 = 1.0000001f + 0x3efffe85, // [6] p2 = 0.4999887f + 0x3e2aaa3e, // [7] p3 = 0.16666505f + 0x3d2bb1b1, // [8] p4 = 0.041917507f + 0x3c091ec1, // [9] p5 = 0.008369149f + 0x42b0c0a5, //[10] max logf = 88.3762589f + 0xc1766666, //[11] min logf = -14.5f + // tanh(x) constants, + 0x80000000, //[12] mask to extract sign + 0x39ddb3d7, //[13] arg below which tanh(x) = x + 0x3f0c9f54, //[14] arg below which pol approx is valid + 0x41102cb4, //[15] arg after which tanh(x) = 1 + 0xc0000000, //[16] -2.0f + 0x7fffffff, //[17] mask to make positive + // tanh pol approx + 0x3f7fffff, //[18] p0 + 0xbeaaa9cf, //[19] p1 + 0x3e085f1f, //[20] p2 + 0xbd572bda, //[21] p3 + 0x3c84fd08, //[22] p4 + }; + + for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]); + } + + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() { + const unsigned int cvals[] = { + 0x3f800000, // [0] 1.0f + 0x3f000000, // [1] 0.5f + 0x3fb8aa3b, // [2] log2ef = 1.44269502f + 0x3f317218, // [3] ln2f = 0.69314718f + 0x0000007f, // [4] 0x7f + 0x42fc0000, // [5] 126 + 0x807fffff, // [6] and with (to get 0.5 * mantissa) + 0x3f000000, // [7] or with (to get 0.5 * mantissa) + // ln(1 + x) polynomial + 0xb2b4637d, // [8] p0 = 0.0000000244f + 0x3f7fff8e, // [9] p1 = 0.9999976971f + 0xbf001759, //[10] p2 = -0.5002478215f + 0x3ea70608, //[11] p3 = 0.3272714505f + 0xbea3d7bf, //[12] p4 = -0.3153830071f + 0xbe361d04, //[13] p5 = -0.1701777461f + 0xbfa8f1e6, //[14] p6 = -1.3254635147f + 0xbfe1e812, //[15] p7 = -1.7971917960f + 0xbfc4d30e, //[16] p8 = -1.5652673123f + // exp(x) polynomial + 0x3f800001, //[17] p0 = 1.0000001f + 0x3f800000, //[18] p1 = 1.0f + 0x3efffe85, //[19] p2 = 0.4999887f + 0x3e2aaa3e, //[20] p3 = 0.16666505f + 0x3d2bb1b1, //[21] p4 = 0.041917507f + 0x3c091ec1, //[22] p5 = 0.008369149f + 0xbf800000, //[23] is required for sign changing + 0x42b0c0a5, //[24] max logf = 88.3762589f + 0xc1766666 //[25] min logf = -14.5f + }; + + for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) { + for (size_t d = 0; d < vlen / sizeof(float); ++d) { + h->dd(cvals[i]); + } + } +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_)); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() { + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_)); + for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0); +} + +template <cpu_isa_t isa> +int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) { + switch (alg_) { + case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2; + case alg_kind::eltwise_elu: return 4; + case alg_kind::eltwise_tanh: return 5; + case alg_kind::eltwise_square: return 0; + case alg_kind::eltwise_abs: return 0; + case alg_kind::eltwise_sqrt: return 2; + case alg_kind::eltwise_linear: return 1; + case alg_kind::eltwise_bounded_relu: return 0; + case alg_kind::eltwise_soft_relu: return 4; + case alg_kind::eltwise_logistic: return 4; + default: assert(!"unsupported eltwise algorithm"); + } + + return 0; +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx, + size_t end_idx) { + using namespace alg_kind; + for (size_t idx = start_idx; idx < end_idx; idx++) { + switch (alg_) { + case eltwise_relu: + if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx)); + else relu_compute_vector(Vmm(idx)); + break; + case eltwise_elu: elu_compute_vector(Vmm(idx)); break; + case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break; + case eltwise_square: square_compute_vector(Vmm(idx)); break; + case eltwise_abs: abs_compute_vector(Vmm(idx)); break; + case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break; + case eltwise_linear: linear_compute_vector(Vmm(idx)); break; + case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break; + case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break; + case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break; + default: assert(!"unsupported eltwise algorithm"); + } + } +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx, + size_t end_idx) { + assert(start_idx < end_idx && end_idx <= vecs_count); + + injector_preamble(start_idx, end_idx); + compute_body(start_idx_tail, end_idx); + injector_preamble_tail(start_idx); + compute_body(start_idx, start_idx_tail); + injector_postamble(); +} + +template <cpu_isa_t isa> +void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) { + using namespace alg_kind; + + h->align(64); + h->L(l_table); + + if (gen_table) { + switch (alg_) { + case eltwise_relu: relu_prepare_table(); break; + case eltwise_elu: + case eltwise_tanh: + case eltwise_logistic: + elu_prepare_table(); break; + case eltwise_soft_relu: soft_relu_prepare_table(); break; + case eltwise_abs: abs_prepare_table(); break; + case eltwise_sqrt: sqrt_prepare_table(); break; + case eltwise_linear: linear_prepare_table(); break; + case eltwise_bounded_relu: bounded_relu_prepare_table(); break; + case eltwise_square: break; + default: assert(!"unsupported eltwise algorithm"); + } + } +} + +template struct jit_uni_eltwise_injector_f32<avx512_common>; +template struct jit_uni_eltwise_injector_f32<avx2>; +template struct jit_uni_eltwise_injector_f32<sse42>; + + +struct jit_args { + const float *from; + const float *for_comparison; + const float *to; + size_t work_amount; +}; + +struct jit_uni_eltwise_kernel_f32 : public c_compatible { + const eltwise_desc_t &desc_; + + void (*ker_)(const jit_args *); + void operator()(const jit_args *args) { assert(ker_); ker_(args); } + + jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc) + : desc_(desc), ker_(nullptr) {} + virtual ~jit_uni_eltwise_kernel_f32() {} + +protected: + bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; } +}; + +/* jit kernels */ +namespace { + +template <cpu_isa_t isa> +struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32, + public jit_generator +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32) + + void compute_step(bool vectorize, const int uf, const int shift) { + for (int i = 0; i < uf; i++) { + if (vectorize) { + uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]); + if (is_bwd()) + uni_vmovups(Vmm(uf + i + 1), + ptr[reg_for_comparison + i * shift]); + } else { + movss(Xmm(i + 1), ptr[reg_from + i * shift]); + if (is_bwd()) + movss(Xmm(uf + i + 1), + ptr[reg_for_comparison + i * shift]); + } + } + + if (isa == sse42) { + for (int i = 0; i < uf; i++) { + movups(Vmm(2 * uf + i + 1), Vmm(i + 1)); + mulps(Vmm(2 * uf + i + 1), vmm_ns); + + Vmm mask = Vmm(0); + if (is_bwd()) { + movups(mask, Vmm(uf + i + 1)); + cmpps(mask, vmm_zero, _cmp_nle_us); + } else { + movups(mask, Vmm(i + 1)); + cmpps(mask, vmm_zero, _cmp_nle_us); + } + blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1)); + } + } else { + for (int i = 0; i < uf; i++) { + vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns); + if (isa == avx2) { + if (is_bwd()) + vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero); + else + vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero); + + vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1), + Vmm(i + 1), vmm_mask); + + } else { + if (is_bwd()) + vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us); + else + vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us); + vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1), + Vmm(i + 1)); + } + } + } + + for (int i = 0; i < uf; i++) { + if (vectorize) { + uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1)); + } else { + movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1)); + } + } + } + + jit_uni_relu_kernel_f32(const eltwise_desc_t &desc) + : jit_uni_eltwise_kernel_f32(desc), jit_generator() { + assert(desc.alg_kind == alg_kind::eltwise_relu); + assert(isa == sse42 || isa == avx2 || isa == avx512_common); + + Reg64 param = abi_param1; + + const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float); + const int loop_dec[] = {simd_w, 1}; + const int uf[] = {1, 1}; + const int shift[] = {cpu_isa_traits<isa>::vlen, sizeof(float)}; + const bool loop_vectorize[] = {true, false}; + + this->preamble(); + + mov(reg_from, ptr[param + GET_OFF(from)]); + if (is_bwd()) + mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]); + mov(reg_to, ptr[param + GET_OFF(to)]); + mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); + + mov(imm_addr64, float2int(desc.alpha)); + movq(xmm_ns, imm_addr64); + uni_vbroadcastss(vmm_ns, xmm_ns); + + uni_vpxor(vmm_zero, vmm_zero, vmm_zero); + + Label loop_label[3]; + + for (int id = 0; id < 2; id++) { + L(loop_label[id]); + cmp(reg_work_amount, uf[id] * loop_dec[id] - 1); + jle(loop_label[id + 1], T_NEAR); + + compute_step(loop_vectorize[id], uf[id], shift[id]); + + add(reg_from, uf[id] * shift[id]); + add(reg_to, uf[id] * shift[id]); + if (is_bwd()) + add(reg_for_comparison, uf[id] * shift[id]); + + sub(reg_work_amount, uf[id] * loop_dec[id]); + jmp(loop_label[id]); + } + + L(loop_label[2]); + this->postamble(); + + ker_ = (decltype(ker_))this->getCode(); + } + +private: + using Vmm = typename utils::conditional3<isa == sse42, Xmm, + isa == avx2, Ymm, Zmm>::type; + + Reg64 reg_from = rax; + Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from; + Reg64 reg_to = r8; + Reg64 reg_work_amount = rsi; + Reg64 imm_addr64 = rbx; + + Xmm xmm_ns = Xmm(14); + + Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14); + Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15); + + Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12); + Opmask k_mask = Opmask(1); +}; + +template <cpu_isa_t isa> +struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32, + public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32) + + jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc) + : jit_uni_eltwise_kernel_f32(desc), jit_generator() { + + eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this, + desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1)); + + using namespace alg_kind; + + assert(is_bwd() == false); + assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); + + preamble(); + + Reg64 param = abi_param1; + mov(reg_from, ptr[param + GET_OFF(from)]); + mov(reg_to, ptr[param + GET_OFF(to)]); + mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); + eltwise_injector_->load_table_addr(); + + Label reminder_loop_start, reminder_loop_end; + Label vectorized_loop_start, vectorized_loop_end; + + cmp(reg_work_amount, simd_w); + jl(reminder_loop_start, T_NEAR); + + L(vectorized_loop_start); + + uni_vmovups(vmm_src, ptr[reg_from]); + eltwise_injector_->compute_vector(vmm_src.getIdx()); + uni_vmovups(ptr[reg_to], vmm_src); + + add(reg_from, vlen); + add(reg_to, vlen); + + sub(reg_work_amount, simd_w); + cmp(reg_work_amount, simd_w); + jge(vectorized_loop_start, T_NEAR); + + L(vectorized_loop_end); + + L(reminder_loop_start); + + cmp(reg_work_amount, 0); + jle(reminder_loop_end, T_NEAR); + + movss(xmm_src, ptr[reg_from]); + eltwise_injector_->compute_vector(xmm_src.getIdx()); + movss(ptr[reg_to], xmm_src); + + add(reg_from, sizeof(float)); + add(reg_to, sizeof(float)); + + dec(reg_work_amount); + jmp(reminder_loop_start, T_NEAR); + + L(reminder_loop_end); + + postamble(); + + eltwise_injector_->prepare_table(); + + ker_ = (decltype(ker_))this->getCode(); + } + + ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; } + +private: + using Vmm = typename utils::conditional3<isa == sse42, Xmm, + isa == avx2, Ymm, Zmm>::type; + + const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float); + const int vlen = cpu_isa_traits<isa>::vlen; + + Reg64 reg_from = rax; + Reg64 reg_to = r8; + Reg64 reg_work_amount = rsi; + Reg64 imm_addr64 = rbx; + + Xmm xmm_src = Xmm(1); + Vmm vmm_src = Vmm(1); + + jit_uni_eltwise_injector_f32<isa> *eltwise_injector_; +}; + +} /* namespace */ + +template <cpu_isa_t isa> +status_t jit_uni_eltwise_fwd_t<isa>::pd_t::init() { + using namespace alg_kind; + + bool ok = true + && mayiuse(isa) + && is_fwd() + && utils::everyone_is(data_type::f32, desc()->data_desc.data_type) + && !has_zero_dim_memory() + && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh, + eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt, + eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu, + eltwise_logistic) + && memory_desc_wrapper(src_md()).is_dense(true) + && IMPLICATION(!memory_desc_wrapper(src_md()).is_dense(false), + math::eltwise_fwd_preserves_zero(desc()->alg_kind, true)) + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; +} + +template <cpu_isa_t isa> +jit_uni_eltwise_fwd_t<isa>::jit_uni_eltwise_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) { + const auto &desc = *pd()->desc(); + switch (desc.alg_kind) { + case alg_kind::eltwise_relu: + kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break; + default: + kernel_ = new jit_uni_kernel_fwd_f32<isa>(desc); + } +} + +template <cpu_isa_t isa> +jit_uni_eltwise_fwd_t<isa>::~jit_uni_eltwise_fwd_t() +{ delete kernel_; } + +template <cpu_isa_t isa> +void jit_uni_eltwise_fwd_t<isa>::execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const size_t nelems = data_d.nelems(true); + + src += data_d.offset0(); + dst += data_d.offset0(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + + const int cache_line = 16; + + balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end); + start = nstl::min(nelems, start * cache_line); + end = nstl::min(nelems, end * cache_line); + + auto arg = jit_args(); + arg.from = &src[start]; + arg.for_comparison = &src[start]; + arg.to = &dst[start]; + arg.work_amount = end - start; + if (arg.work_amount) + (*kernel_)(&arg); + }); +} + +template <cpu_isa_t isa> +status_t jit_uni_eltwise_bwd_t<isa>::pd_t::init() { + bool ok = true + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu) + && src_md()->data_type == data_type::f32 + && !has_zero_dim_memory() + && mayiuse(isa) + && memory_desc_wrapper(src_md()).is_dense() + && memory_desc_wrapper(diff_dst_md()) == memory_desc_wrapper(src_md()) + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; +} + +template <cpu_isa_t isa> +jit_uni_eltwise_bwd_t<isa>::jit_uni_eltwise_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd), kernel_(nullptr) { + const auto &desc = *pd()->desc(); + switch (desc.alg_kind) { + case alg_kind::eltwise_relu: + kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break; + default: assert(!"unknown eltwise alg_kind"); + } +} + +template <cpu_isa_t isa> +jit_uni_eltwise_bwd_t<isa>::~jit_uni_eltwise_bwd_t() +{ delete kernel_; } + +template <cpu_isa_t isa> +void jit_uni_eltwise_bwd_t<isa>::execute_backward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const size_t nelems = data_d.nelems(); + + src += data_d.offset0(); + diff_dst += diff_data_d.offset0(); + diff_src += diff_data_d.offset0(); + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + + const int cache_line = 16; + + balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end); + start = nstl::min(nelems, start * cache_line); + end = nstl::min(nelems, end * cache_line); + + auto arg = jit_args(); + arg.from = &diff_dst[start]; + arg.to = &diff_src[start]; + arg.for_comparison = &src[start]; + arg.work_amount = end - start; + if (arg.work_amount) + (*kernel_)(&arg); + }); +} + +template struct jit_uni_eltwise_fwd_t<sse42>; +template struct jit_uni_eltwise_bwd_t<sse42>; +template struct jit_uni_eltwise_fwd_t<avx2>; +template struct jit_uni_eltwise_bwd_t<avx2>; +template struct jit_uni_eltwise_fwd_t<avx512_common>; +template struct jit_uni_eltwise_bwd_t<avx512_common>; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp new file mode 100644 index 0000000000..45436b9f46 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp @@ -0,0 +1,193 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_ELTWISE_HPP +#define CPU_JIT_UNI_ELTWISE_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_eltwise_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <cpu_isa_t isa> +struct jit_uni_eltwise_injector_f32 { + using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm, + isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type; + + jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg, + float alpha, float beta, bool save_state = true, + Xbyak::Reg64 p_table = Xbyak::util::rax, + Xbyak::Opmask k_mask = Xbyak::Opmask(1)) + : alg_(alg), alpha_(alpha), beta_(beta), h(host) + , save_state_(save_state), p_table(p_table), k_mask(k_mask) + { + using namespace alg_kind; + assert(utils::one_of(isa, sse42, avx2, avx512_common)); + assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); + } + + // note that eltwise.scale is ignored + jit_uni_eltwise_injector_f32(jit_generator *host, + const post_ops_t::entry_t::eltwise_t &eltwise, + bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax, + Xbyak::Opmask k_mask = Xbyak::Opmask(1)) + : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha, + eltwise.beta, save_state, p_table, k_mask) {} + + void compute_vector_range(size_t start_idx, size_t end_idx); + void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); } + void prepare_table(bool gen_table=true); + void load_table_addr() { h->mov(p_table, l_table); } + + const alg_kind_t alg_; + const float alpha_; + const float beta_; + + jit_generator * const h; + + const bool save_state_; + const Xbyak::Reg64 p_table; + const Xbyak::Opmask k_mask; + Xbyak::Label l_table; + +private: + // if only the injector was inherited from jit_generator... + enum { + _cmp_le_os = jit_generator::_cmp_le_os, + _cmp_nle_us = jit_generator::_cmp_nle_us, + _op_floor = jit_generator::_op_floor, + }; + + size_t vlen = cpu_isa_traits<isa>::vlen; + + const static size_t preserved_vecs_max = 5; + + size_t vecs_to_preserve = 0; + size_t vecs_count = isa == avx512_common ? 32 : 16; + size_t preserved_vecs_count = 0; + size_t preserved_vec_idxs[preserved_vecs_max] = {0}; + size_t start_idx_tail = 0; + + Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4; + + Xbyak::Address table_val(int index) + { return h->ptr[p_table + index * vlen]; } + + int aux_vecs_count(alg_kind_t alg); + + void compute_body(size_t start_idx, size_t end_idx); + void injector_preamble(size_t start_idx, size_t end_idx); + void injector_preamble_tail(size_t start_idx); + void injector_postamble(); + void assign_regs(); + + void exp_compute_vector(const Vmm &vmm_src); + void relu_compute_vector(const Vmm &vmm_src); + void relu_zero_ns_compute_vector(const Vmm &vmm_src); + void elu_compute_vector(const Vmm &vmm_src); + void tanh_compute_vector(const Vmm &vmm_src); + void square_compute_vector(const Vmm &vmm_src); + void abs_compute_vector(const Vmm &vmm_src); + void sqrt_compute_vector(const Vmm &vmm_src); + void linear_compute_vector(const Vmm &vmm_src); + void bounded_relu_compute_vector(const Vmm &vmm_src); + void soft_relu_compute_vector(const Vmm &vmm_src); + void logistic_compute_vector(const Vmm &vmm_src); + + void relu_prepare_table(); + void elu_prepare_table(); + void soft_relu_prepare_table(); + void abs_prepare_table(); + void sqrt_prepare_table(); + void linear_prepare_table(); + void bounded_relu_prepare_table(); +}; + +struct jit_uni_eltwise_kernel_f32; + +template <cpu_isa_t isa> +struct jit_uni_eltwise_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_eltwise_fwd_pd_t { + using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_eltwise_fwd_t<isa>); + + status_t init(); + }; + + jit_uni_eltwise_fwd_t(const pd_t *apd); + ~jit_uni_eltwise_fwd_t(); + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_eltwise_kernel_f32 *kernel_; +}; + +template <cpu_isa_t isa> +struct jit_uni_eltwise_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_eltwise_bwd_pd_t { + using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_eltwise_bwd_t<isa>); + + status_t init(); + }; + + jit_uni_eltwise_bwd_t(const pd_t *apd); + ~jit_uni_eltwise_bwd_t(); + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_eltwise_kernel_f32 *kernel_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp new file mode 100644 index 0000000000..a3ca6273a0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp @@ -0,0 +1,949 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "jit_uni_i8i8_pooling.hpp" + +#include <math.h> + +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::types; +using namespace alg_kind; + +template <cpu_isa_t isa> +struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t) + + struct call_params_t { + const char *src_i8; + const char *dst_i8; + size_t kw_range; + size_t kh_range; + float idivider; + }; + + using Vmm = typename cpu_isa_traits<isa>::Vmm; + Xmm xreg(int idx) const { return Xmm(idx); } + Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); } + Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); } + + // In case of avx2 with data type i8 we need to use + // maskmovdqu instruction which has its destination hardcoded in rdi. + // Windows ABI: abi_param1 is rcx - nothing to do else + // Unix ABI: abi_param1 is rdi - copy it to rcx and use it as abi_param1 + Reg64 reg_param = rcx; // Our "unified abi_param1" + Reg64 reg_ptr_src_i8 = r8; + Reg64 reg_ptr_dst_i8 = r9; + Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi + + Reg64 ki = r10; + Reg64 kj = r11; + Reg64 reg_kw = r12; + Reg64 reg_kh = r13; + Reg64 c_iter = r14; + + Reg64 aux_reg_src_h = rax; + Reg64 aux_reg_src_w = rbx; + + Reg64 reg_tmp = rdx; + + Reg64 reg_mask = r15; + + Opmask k_cmp_mask = Opmask(7); + + Opmask mask(int idx) { + return Opmask(6 - idx); + } + + // ref to any of XYZ-regs via xreg/yreg/vreg functions + Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp + Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type + Vmm vreg_zeros = vreg(1); + + // only in case of <isa> == avx2 + Vmm vreg_mask = vreg(2); // full byte-mask + Xmm xreg_mask_lo = xreg(2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask) + Xmm xreg_mask_hi = xreg(3); // "max" - high 128-bits part of byte-mask (stored separately) + Xmm xreg_mask_q = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations + Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails + + enum:int {vidx_base = isa == avx2 ? 4 : 2}; + Vmm base_vr(int idx) const { return vreg(vidx_base + idx); } + + size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); } + size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); } + + /* max pooling */ + Vmm vreg_src(int idx) const { return base_vr(idx); } // [0 .. ur_c-1] + Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 2*ur_c-1] + + /* avg pooling */ + // s32 used for processing of s8/u8 data + // thus we need to take into account ratio of sizes s32/i8 = 4 + static constexpr data_type_t avg_proc_dt = data_type::s32; + enum:int { + s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type) + / sizeof(typename prec_traits<data_type::u8>::type), + max_num_ll = s32_to_i8_ratio + }; + Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); } // ll: 0..4 [0..3] + Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); } // ll: 0..4 [4..7] + Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); } // ll: 0..4 [8..11] + + void (*ker_)(const call_params_t *); + jit_pool_conf_t jpp; + + void init_tmp_reg(); + void init_mask(); + + void load_vreg_mask_q(int ll) {}; + + void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void load_src(int jj, int ll, int c_tail); + + void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk); + void store_dst(int jj, int ll, int c_tail); + + void compute_avg_step(int ur_c, int c_tail); + void compute_max_op(const int jj); + void compute_max_step(int ur_c, int c_tail); + void compute_step(int ur_c, int c_tail); + + void compute_c_block(); + void generate(); + + static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd); + + jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_) + : jpp(jpp_) { + generate(); + ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>( + getCode())); + } +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) { + + // extract ll-th part of mask (ll-th QWORD) + vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << ll); // 0x3 - mask for 2 x DWORD + + // Move mask from ll-th pos to 0-th pos + if (ll>0) + vpermq(vreg_mask_q, vreg_mask_q, ll); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + if (jpp.src_dt == s32) { + vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], static_cast<uint8_t>(msk)); + } else { + vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], vreg_mask); + } + } else + vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + if (jpp.src_dt == s32) + vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); + else + vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]); + } else + vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]); +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + auto load_i8 = [&](bool is_signed, const Vmm& vr_src) { + + // Need to use mask of tail? + if (masked) { + + // load ll-th part of mask into vreg_mask_q + load_vreg_mask_q(ll); + + // Load by mask from mem into register vr_src + vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q); + + // Conversion s8/u8 -> s32 + if (is_signed) + vpmovsxbd(vr_src, vr_src); + else + vpmovzxbd(vr_src, vr_src); + } else { + + // Load from mem into vr_src with conversion + if (is_signed) + vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); + else + vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); + } + }; + + switch (jpp.src_dt) { + case s32: + if (masked) + vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset], + static_cast<uint8_t>(msk)); + else + vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]); + break; + case s8: + load_i8(true, vreg_src_s32(jj, ll)); + break; + case u8: + load_i8(false, vreg_src_s32(jj, ll)); + break; + default: assert(!"unsupported src data type"); + } +}; + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + const Vmm& vr_src = masked ? + vreg_src_s32(jj, ll) | mask(ll) : + vreg_src_s32(jj, ll); + + switch (jpp.src_dt) { + case s32: + vmovups(vr_src, ptr[aux_reg_src_w + offset]); + break; + case s8: + vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); + break; + case u8: + vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); + break; + default: assert(!"unsupported src data type"); + } +}; + +template <cpu_isa_t isa> +void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) { + using namespace data_type; + + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + + switch (jpp.alg) { + case pooling_max: { + auto offset = jj*c_block*sizeof_src_dt(); + bool masked = jj == ur_c - 1 && c_tail; + load_src_max_op(jj, ll, offset, masked, jpp.tail[0]); + break; + } + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt(); + bool masked = jj == ur_c - 1 && c_tail; + load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + default: assert(!"unsupported algorithm"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + int c_block = jpp.c_block; + + if (masked) { + switch (jpp.src_dt) { + case s32: + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj)); + break; + case s8: + case u8: { + // Store low half by mask (bytes 0...15) + lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); + maskmovdqu(vreg_dst(jj), xreg_mask_lo); + + // Do we need to store high half (bytes 16...31) ? + const uint64_t low_mask = (1ULL << (c_block/2))-1; + if (msk & ~low_mask) { + vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1); + add(reg_ptr_maskmovdqu_dst, c_block / 2); + maskmovdqu(vreg_dst(jj), xreg_mask_hi); + } + } break; + default: assert(!"unsupported src data type"); + } + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + if (masked) { + switch (jpp.src_dt) { + case s32: + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); + break; + case s8: + case u8: + vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0)); + break; + default: assert(!"unsupported src data type"); + } + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj)); +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk){ + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) { + + // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16} + // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0} + if (is_signed) + vpackssdw(vr_dst, vr_dst, vreg_zeros); + else + vpackusdw(vr_dst, vr_dst, vreg_zeros); + + // Permute qwords to restore original order + // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0} + vpermq(vr_dst, vr_dst, 0x58); + + // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8} + // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx} + if (is_signed) + vpacksswb(vr_dst, vr_dst, vreg_zeros); + else + vpackuswb(vr_dst, vr_dst, vreg_zeros); + + }; + + auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) { + + // Conversion s32 -> s8/u8 + s32_to_i8(is_signed, vr_dst); + + // Need to use mask of tail? + if (is_masked) { + // load ll-th part of mask into vreg_mask_q + load_vreg_mask_q(ll); + } + + // store 8 bytes + lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]); + maskmovdqu(vr_dst, xreg_mask_q); + }; + + switch (jpp.dst_dt) { + case s32: + if (masked) { + vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll)); + } else + vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll)); + break; + case s8: + store_i8(true, masked, vreg_dst_s32(jj, ll)); + break; + case u8: + store_i8(false, masked, vreg_dst_s32(jj, ll)); + break; + default: assert(!"unsuppotred dst data_type"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op(int jj, int ll, + size_t offset, bool masked, uint64_t msk) { + using namespace data_type; + + // Don't generate useless code + if (masked && !msk) + return; + + const Vmm& vr_dst = masked ? + vreg_dst_s32(jj, ll) | mask(ll) : + vreg_dst_s32(jj, ll); + + switch (jpp.dst_dt) { + case s32: + vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + case s8: + vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + case u8: + vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); + break; + default: assert(!"unsupported dst data_type"); + } +} + + +template <cpu_isa_t isa> +void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst(int jj, int ll, + int c_tail) { + using namespace data_type; + + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + + switch(jpp.alg) { + case pooling_max: { + auto offset = jj*c_block*sizeof_dst_dt(); + bool masked = jj == ur_c - 1 && c_tail; + store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt(); + bool masked = jj == ur_c - 1 && c_tail; + store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]); + break; + } + default: assert(!"unsupported pooling algorithm"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj) +{ + using namespace data_type; + switch (jpp.src_dt) { + case s32: + vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + case s8: + vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + case u8: + vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); + break; + default: assert(!"unsupported src data type"); + } +} + +template <> +void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj) +{ + using namespace data_type; + + // Compare + switch (jpp.src_dt) { + case s32: + vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + case s8: + vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + case u8: + vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os); + break; + default: assert(!"unsupported src data type"); + } + + // move max values into vreg_dst + if (jpp.src_dt == s32) + vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); + else + vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj)); +} + + +template <cpu_isa_t isa> +void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_tail) +{ + Label l_kw, l_kh; + + int iw = jpp.iw; + int c = jpp.c; + + for (int jj = 0; jj < ur_c; jj++) + vmovups(vreg_dst(jj), vreg_tmp); + + mov(aux_reg_src_h, reg_ptr_src_i8); + + xor_(kj, kj); + L(l_kh); + { + mov(aux_reg_src_w, aux_reg_src_h); + xor_(ki, ki); + L(l_kw); + { + for (int jj = 0; jj < ur_c; jj++) { + load_src(jj, 0, c_tail); + compute_max_op(jj); + } + add(aux_reg_src_w, c * sizeof_src_dt()); + inc(ki); + cmp(ki, reg_kw); + jl(l_kw, T_NEAR); + } + add(aux_reg_src_h, iw * c * sizeof_src_dt()); + inc(kj); + cmp(kj, reg_kh); + jl(l_kh, T_NEAR); + } + + for (int jj = 0; jj < ur_c; jj++) + store_dst(jj, 0, c_tail); +} + +template <cpu_isa_t isa> +void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_tail) +{ + using namespace data_type; + + Label l_kw, l_kh; + + int iw = jpp.iw; + int c = jpp.c; + + const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt); + + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll)); + uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll)); + } + } + } + + mov(aux_reg_src_h, reg_ptr_src_i8); + + xor_(kj, kj); + L(l_kh); + { + mov(aux_reg_src_w, aux_reg_src_h); + xor_(ki, ki); + L(l_kw); + { + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + load_src(jj, ll, c_tail); + vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), + vreg_src_s32(jj, ll)); + } + } + } + add(aux_reg_src_w, c * sizeof_src_dt()); + inc(ki); + cmp(ki, reg_kw); + jl(l_kw, T_NEAR); + } + add(aux_reg_src_h, iw * c * sizeof_src_dt()); + inc(kj); + cmp(kj, reg_kh); + jl(l_kh, T_NEAR); + } + + for (int jj = 0; jj < ur_c; jj++) { + for (int ll = 0; ll < num_ll; ll++) { + bool masked = jj == ur_c - 1 && c_tail; + size_t msk = jpp.tail[ll]; + if (!(masked && !msk)) { + vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll)); + vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp); + vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll)); + store_dst(jj, ll, c_tail); + } + } + } +} + +template <cpu_isa_t isa> +void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) { + switch (jpp.alg) { + case pooling_max: + compute_max_step(ur_c, c_tail); break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + compute_avg_step(ur_c, c_tail); break; + default: assert(!"unsupported pooling algorithm"); + } +} + +template <cpu_isa_t isa> +void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block(){ + Label l_main_loop; + + int nb_c = jpp.nb_c; + int c_block = jpp.c_block; + int ur_c = jpp.ur_c; + int ur_c_tail = jpp.ur_c_tail; + int c_steps = nb_c / ur_c; + int c_tail = jpp.c_tail; + + xor_(c_iter, c_iter); + if (c_steps > 0) { + L(l_main_loop); { + compute_step(ur_c, 0); + add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt()); + add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt()); + inc(c_iter); + cmp(c_iter, c_steps); + jl(l_main_loop, T_NEAR); + } + } + + if (ur_c_tail != 0) { + compute_step(ur_c_tail, c_tail); + } +} + +template<> +void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() { + using namespace data_type; + using cpu_isa = cpu_isa_traits<avx2>; + + // AVX2 mask initialization: mask stored in Ymm-regs + auto init = [&](uint64_t bit_mask, bool init_mask_q) { + const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t); + + uint64_t vmask[QW_PER_VREG]; + for (size_t i = 0; i < QW_PER_VREG; i++){ + + uint64_t qw_vmask=0ULL; + const size_t DBITS = 8*sizeof_src_dt(); + const uint64_t VMSK = 1ULL << (DBITS-1); + const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS; + for (size_t j = 0; j < D_PER_QW; j++) { + if (bit_mask & 1) + qw_vmask |= VMSK << DBITS * j; + bit_mask >>= 1; + } + vmask[i] = qw_vmask; + } + + // Put QWORDS with target mask into xmm regs + const int xdst_i[QW_PER_VREG] = { + xreg_mask_lo.getIdx(), + xreg_mask_lo.getIdx(), + xreg_mask_hi.getIdx(), + xreg_mask_hi.getIdx() + }; + const int xsrc_i[QW_PER_VREG] = { + vreg_zeros.getIdx(), // 0-th qword insert in zeros -> {qw0, 0} + xreg_mask_lo.getIdx(), // 1-st and 0-th merge -> {qw0,qw1} + vreg_zeros.getIdx(), + xreg_mask_hi.getIdx() + }; + const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg + + for (size_t i = 0; i < QW_PER_VREG; i++) { + mov(reg_mask, vmask[i]); + vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]); + } + + // Merge Low (xreg_mask_lo alias for vreg_mask.xreg) + // and High (xreg_mask_hi) into full vreg_mask + // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg} + vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1); + + // Keep only low qword of mask in xreg_mask_q + if (init_mask_q) { + mov(reg_mask, vmask[0]); + vpinsrq(xreg_mask_q, Xmm(vreg_zeros.getIdx()), reg_mask, 0); + } + }; + + uint64_t tail_mask = (1ULL << jpp.c_tail) - 1; + switch (jpp.alg) { + case pooling_max: + // For "max" we need mask only in case of non-zero tail + if (tail_mask) + init(tail_mask, false); + break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + // For "avg" we need mask: + // - s32 - in case of the non-zero tail + // - s8/u8 - irrespective of the tail + switch (jpp.src_dt) { + case s32: + if (tail_mask) + init(tail_mask, false); + break; + case s8: + case u8: + init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0); + break; + default: assert(!"unsupported src data type"); + } + break; + default: assert(!"unsupported pooling algorithm"); + } +} + +template<> +void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() { + + for (int ll = 0; ll < max_num_ll; ll++) { + mov(reg_mask, jpp.tail[ll]); + kmovq(mask(ll), reg_mask); + } +} + +template <cpu_isa_t isa> +void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() { + using namespace data_type; + + switch (jpp.alg) { + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: + mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]); + movq(xmm_tmp, reg_tmp); + vpbroadcastd(vreg_tmp, xmm_tmp); + break; + case pooling_max: + switch (jpp.src_dt) { + case s32: + mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest()); + break; + case s8: + mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest()); + break; + case u8: + mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest()); + break; + default: assert(!"unsupported src data_type"); + } + + movq(xmm_tmp, reg_tmp); + if (jpp.src_dt == s32) + vpbroadcastd(vreg_tmp, xmm_tmp); + else + vpbroadcastb(vreg_tmp, xmm_tmp); + break; + default: assert(!"unsupported pooling algorithm"); + } + +} + +template <cpu_isa_t isa> +void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() { + preamble(); + +#if !defined(_WIN32) + // Always use rcx as abi_param1 - + // see the note about maskmovdqu near reg_param. + mov(rcx, rdi); +#endif + +# define READ_PARAM(reg, field) \ + mov(reg, ptr[reg_param + offsetof(call_params_t, field)]) + READ_PARAM(reg_ptr_src_i8, src_i8); + READ_PARAM(reg_ptr_dst_i8, dst_i8); + READ_PARAM(reg_kw, kw_range); + READ_PARAM(reg_kh, kh_range); + +# undef READ_PARAM + + uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros); + + init_mask(); + + init_tmp_reg(); + + compute_c_block(); + + postamble(); +} + +template <cpu_isa_t isa> +status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jit_pool_conf_t &jpp, + const pooling_pd_t *ppd) { + if (!mayiuse(isa)) + return status::unimplemented; + + const auto &pd = *ppd->desc(); + const memory_desc_wrapper src_d(ppd->src_md()); + const memory_desc_wrapper dst_d(ppd->dst_md()); + + jpp.mb = src_d.dims()[0]; + jpp.c = src_d.dims()[1]; + jpp.ih = src_d.dims()[2]; + jpp.iw = src_d.dims()[3]; + jpp.oh = dst_d.dims()[2]; + jpp.ow = dst_d.dims()[3]; + + jpp.stride_h = pd.strides[0]; + jpp.stride_w = pd.strides[1]; + jpp.kh = pd.kernel[0]; + jpp.kw = pd.kernel[1]; + + jpp.t_pad = pd.padding[0][0]; + jpp.l_pad = pd.padding[0][1]; + + jpp.alg = pd.alg_kind; + + jpp.src_dt = pd.src_desc.data_type; + jpp.dst_dt = pd.dst_desc.data_type; + + // data_type items per one vreg on the <isa> + // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32 + // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32 + int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt); + + jpp.c_block = simd_w; + jpp.c_tail = jpp.c % jpp.c_block; + jpp.nb_c = jpp.c / jpp.c_block; + jpp.ur_c = 1; + jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c + + (jpp.c_tail != 0); + + size_t tail_mask = (1ULL << jpp.c_tail) - 1; + + switch (jpp.alg) { + case pooling_max: + jpp.tail[0] = tail_mask; + jpp.tail[1] = 0; + jpp.tail[2] = 0; + jpp.tail[3] = 0; + break; + case pooling_avg_include_padding: + case pooling_avg_exclude_padding: { + // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32) + // avx2 : 8, avx512 : 16 + const size_t msk_gran = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt); + const size_t msk_msk = (1ULL << msk_gran) - 1; + size_t m = tail_mask; + for (size_t ll = 0; ll < max_num_ll; ll++) { + jpp.tail[ll] = m & msk_msk; + m = m >> msk_gran; + } + break; + } + default: return status::unimplemented; + } + + return status::success; +} + +template <cpu_isa_t isa> +status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() { + return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_, this); +} + +template <cpu_isa_t isa> +jit_uni_i8i8_pooling_fwd_t<isa>:: +jit_uni_i8i8_pooling_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), ker_(nullptr) +{ ker_ = new jit_uni_i8i8_pooling_fwd_ker_t<isa>(pd()->jpp_); } + +template <cpu_isa_t isa> +jit_uni_i8i8_pooling_fwd_t<isa>:: +~jit_uni_i8i8_pooling_fwd_t() { delete ker_; } + +template <cpu_isa_t isa> +void jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward( + const exec_ctx_t &ctx) const { + auto src_i8 = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC); + auto dst_i8 = CTX_OUT_MEM(char *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const auto &jpp = pd()->jpp_; + + parallel_nd(jpp.mb, jpp.oh, jpp.ow, + [&](int n, int oh, int ow) { + const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0); + const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0); + + const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h); + const int kh_end = nstl::min(jpp.kh, + jpp.ih + jpp.t_pad - oh * jpp.stride_h); + const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w); + const int kw_end = nstl::min(jpp.kw, + jpp.iw + jpp.l_pad - ow * jpp.stride_w); + + auto p = typename jit_uni_i8i8_pooling_fwd_ker_t<isa>::call_params_t(); + p.src_i8 = &src_i8[ + src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()]; + p.dst_i8 = &dst_i8[ + dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()]; + p.kw_range = (size_t)(kw_end - kw_start); + p.kh_range = (size_t)(kh_end - kh_start); + p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ? + p.kh_range*p.kw_range : jpp.kw*jpp.kh); + + ker_->ker_(&p); + }); +} + +// Explicit instantiation only for supported <isa> values. +// +template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>; +template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>; + +template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>; +template struct jit_uni_i8i8_pooling_fwd_t<avx2>; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp new file mode 100644 index 0000000000..d757679df5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_I8I8_POOLING_HPP +#define CPU_JIT_UNI_I8I8_POOLING_HPP + +#include "c_types_map.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +#include "cpu_isa_traits.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <cpu_isa_t isa> +struct jit_uni_i8i8_pooling_fwd_ker_t; + +template <cpu_isa_t isa> +struct jit_uni_i8i8_pooling_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_i8i8_pooling_fwd_t<isa>); + + status_t init() { + bool ok = true + && mayiuse(isa) + && ndims() == 4 + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::forward_inference + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::one_of(src_md()->data_type, data_type::s32, + data_type::s8, data_type::u8) + && src_md()->data_type == dst_md()->data_type + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && memory_desc_matches_tag(*dst_md(), format_tag::nhwc); + if (!ok) return status::unimplemented; + + return jit_conf(); + } + + jit_pool_conf_t jpp_; + + protected: + status_t jit_conf(); + }; + + jit_uni_i8i8_pooling_fwd_t(const pd_t *apd); + ~jit_uni_i8i8_pooling_fwd_t(); + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_i8i8_pooling_fwd_ker_t<isa> *ker_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp new file mode 100644 index 0000000000..2c5a8e8973 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp @@ -0,0 +1,305 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_uni_lrn_kernel_f32.hpp" +#include "jit_uni_lrn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::utils; + +template <cpu_isa_t isa> +jit_uni_lrn_fwd_t<isa>::jit_uni_lrn_fwd_t(const pd_t *apd) + : cpu_primitive_t(apd), ker_(nullptr) + , ker_first_(nullptr), ker_last_(nullptr) +{ + using namespace alg_kind; + + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + float A = pd()->desc()->lrn_alpha / ls; + float K = pd()->desc()->lrn_k; + + auto pk = pd()->desc()->prop_kind; + auto ak = pd()->desc()->alg_kind; + auto dat_tag = pd()->dat_tag_; + + if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { + ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>( + nchw8c_across(H, W, 0), A, K, pk); + ker_first_ = new jit_uni_lrn_fwd_kernel_f32<isa>( + nchw8c_across(H, W, -1), A, K, pk); + ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>( + nchw8c_across(H, W, +1), A, K, pk); + } else if (dat_tag == nChw8c && ak == lrn_within_channel) { + /* within channel, local_size (x) local_size */ + A /= ls; /* XXX: why? */ + ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>( + nchw8c_within(H, W, ls), A, K, pk); + } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { + ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>( + nchw_across(C, H*W, 0), A, K, pk); + int remind = (H*W) % VECTOR_LENGTH; + if (remind != 0) { + ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>( + nchw_across(C, H*W, remind), A, K, pk); + } + } else if (true /* XXX: why */) { + ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(nhwc_across(C), A, K, pk); + } +} + +template <cpu_isa_t isa> +jit_uni_lrn_fwd_t<isa>::~jit_uni_lrn_fwd_t() +{ delete ker_; delete ker_first_; delete ker_last_; } + +template <cpu_isa_t isa> +void jit_uni_lrn_fwd_t<isa>::execute_forward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int HW = pd()->H() * pd()->W(); + const int ls = pd()->desc()->local_size; + + auto ak = pd()->desc()->alg_kind; + auto dat_tag = pd()->dat_tag_; + + if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH]; + if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else if (dat_tag == nChw8c && ak == lrn_within_channel) { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH]; + (*ker_)(&args); + }); + } + else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) { + parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH, + [&](int n, int hw8) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + hw8 * VECTOR_LENGTH]; + args.dst = &dst[n*HW*C + hw8 * VECTOR_LENGTH]; + args.scratch = &ws[n*HW*C + hw8 * VECTOR_LENGTH]; + if ((hw8 + 1)*VECTOR_LENGTH > HW) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else { // nhwc + parallel_nd(N, HW, [&](int n, int hw) { + jit_args_fwd_t args; + args.src = &src[n*HW*C + hw * C]; + args.dst = &dst[n*HW*C + hw * C]; + args.scratch = &ws[n*HW*C + hw * C]; + (*ker_)(&args); + }); + } +} + +template <cpu_isa_t isa> +status_t jit_uni_lrn_fwd_t<isa>::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(isa) + && is_fwd() + && everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % VECTOR_LENGTH == 0 + && data_d.dims()[1] >= 2 * VECTOR_LENGTH + && desc()->lrn_beta == 0.75 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + if (desc_.prop_kind == forward_training) ws_md_ = *src_md(); + + dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c, nchw, nhwc); + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && one_of(dat_tag_, nChw8c, nchw, nhwc); + + const int jit_max_local_size = 5; // bigger size triggers too big code size + bool args_ok_within = true + && desc()->alg_kind == lrn_within_channel + && desc()->local_size <= ( jit_max_local_size <= MAX_LOCAL_SIZE + ? jit_max_local_size : MAX_LOCAL_SIZE) + && data_d.dims()[2] >= desc()->local_size + && data_d.dims()[3] >= desc()->local_size + && one_of(dat_tag_, nChw8c); + + return args_ok_across || args_ok_within ? success : unimplemented; +} + +template <cpu_isa_t isa> +jit_uni_lrn_bwd_t<isa>::jit_uni_lrn_bwd_t(const pd_t *apd) + : cpu_primitive_t(apd) + , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr) +{ + using namespace alg_kind; + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const int ls = pd()->desc()->local_size; + float A = pd()->desc()->lrn_alpha / ls; + float B = pd()->desc()->lrn_beta; + + int use_h_parallelizm = 0;// XXX + if (C / VECTOR_LENGTH == 1) { + ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>( + nchw8c_across(H, W, 3), A, B, use_h_parallelizm); + } + else { + ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>( + nchw8c_across(H, W, 0), A, B, use_h_parallelizm); + ker_first_ = new jit_uni_lrn_bwd_kernel_f32<isa>( + nchw8c_across(H, W, -1), A, B, use_h_parallelizm); + ker_last_ = new jit_uni_lrn_bwd_kernel_f32<isa>( + nchw8c_across(H, W, +1), A, B, use_h_parallelizm); + } +} + +template <cpu_isa_t isa> +jit_uni_lrn_bwd_t<isa>::~jit_uni_lrn_bwd_t() +{ + delete ker_; delete ker_first_; delete ker_last_; +} + +template <cpu_isa_t isa> +void jit_uni_lrn_bwd_t<isa>::execute_backward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const int N = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + + int use_h_parallelizm = 0; // XXX + if (use_h_parallelizm) { + parallel_nd(N, C / VECTOR_LENGTH, H, [&](int n, int c8, int h) { + auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH + + h*W*VECTOR_LENGTH; + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.scratch = &ws[offset]; + args.diff_src = &diff_src[offset]; + if (C / VECTOR_LENGTH == 1) + (*ker_)(&args); + else if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } + else { + parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) { + auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH; + jit_args_bwd_t args; + args.src = &src[offset]; + args.diff_dst = &diff_dst[offset]; + args.scratch = &ws[offset]; + args.diff_src = &diff_src[offset]; + if (C / VECTOR_LENGTH == 1) + (*ker_)(&args); + else if (c8 == 0) + (*ker_first_)(&args); + else if (c8 == C / VECTOR_LENGTH - 1) + (*ker_last_)(&args); + else + (*ker_)(&args); + }); + } +} + +template <cpu_isa_t isa> +status_t jit_uni_lrn_bwd_t<isa>::pd_t::init() { + using namespace prop_kind; + using namespace alg_kind; + + const memory_desc_wrapper data_d(src_md()); + bool ok = true + && mayiuse(isa) + && !is_fwd() + && utils::everyone_is(data_type::f32, data_d.data_type()) + && !has_zero_dim_memory() + && data_d.ndims() == 4 + && data_d.dims()[1] % VECTOR_LENGTH == 0 + && desc()->lrn_beta == 0.75 + && attr()->has_default_values(); + if (!ok) return unimplemented; + + ws_md_ = *src_md(); + if (!compare_ws(hint_fwd_pd_)) return unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c); + + bool args_ok_across = true + && desc()->alg_kind == lrn_across_channels + && desc()->local_size == 5 + && utils::one_of(dat_tag_, nChw8c); + + return args_ok_across ? success : unimplemented; +} + +template struct jit_uni_lrn_fwd_t<sse42>; +template struct jit_uni_lrn_fwd_t<avx2>; +template struct jit_uni_lrn_bwd_t<avx2>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp new file mode 100644 index 0000000000..333cd3396d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_LRN_HPP +#define CPU_JIT_UNI_LRN_HPP + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_isa_traits.hpp" +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <cpu_isa_t isa> struct jit_uni_lrn_fwd_kernel_f32; +template <cpu_isa_t isa> struct jit_uni_lrn_bwd_kernel_f32; + +template <cpu_isa_t isa> +struct jit_uni_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_lrn_fwd_t<isa>); + + status_t init(); + + format_tag_t dat_tag_; + }; + + jit_uni_lrn_fwd_t(const pd_t *apd); + ~jit_uni_lrn_fwd_t(); + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_lrn_fwd_kernel_f32<isa> *ker_, *ker_first_, *ker_last_; +}; + +template <cpu_isa_t isa> +struct jit_uni_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_lrn_bwd_t<isa>); + + status_t init(); + + format_tag_t dat_tag_; + }; + + jit_uni_lrn_bwd_t(const pd_t *apd); + ~jit_uni_lrn_bwd_t(); + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + jit_uni_lrn_bwd_kernel_f32<isa> *ker_, *ker_first_, *ker_last_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp new file mode 100644 index 0000000000..89af47272c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp @@ -0,0 +1,1487 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" + +#include "jit_uni_lrn_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +////////////////////////////////////////////////////////////////////////////// +// forward kernel +template<cpu_isa_t isa> +void jit_uni_lrn_fwd_kernel_f32<isa>::within_body( + int hoff, int Hoff, int woff, int Woff, int stride, + Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2, + prop_kind_t pk) +{ + vxorps(ysum, ysum, ysum); + for (int i = hoff; i <= Hoff; ++i) + { + for (int j = woff; j <= Woff; ++j) + { + if (i == 0 && j == 0) + { + vmovups(ydst, ptr[src]); + vfmadd231ps(ysum, ydst, ydst); + } + else + { + vmovups(ytmp, ptr[src + (i*stride + j)*VECTOR_LENGTH*4]); + vfmadd231ps(ysum, ytmp, ytmp); + } + } + } + vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk + vmovaps(ytmp, ysum); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ytmp); + vmulps(ysum2, ysum, ysum); + vmulps(ysum, ysum, ysum2); // ysum = (ysum*yalpha+yk)^3; + vsqrtps(ysum, ysum); + vsqrtps(ysum, ysum); // ysum = (ysum*yalpha+yk)^0.75 + vdivps(ydst, ydst, ysum); // ydst <- ydst / ysum + vmovups(ptr[dst], ydst); + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); +} + +template<cpu_isa_t isa> +void jit_uni_lrn_fwd_kernel_f32<isa>::within_body_sse42( + int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk) +{ + Xbyak::Xmm xtmp_lo = xmm12; + Xbyak::Xmm xtmp_hi = xmm13; + Xbyak::Xmm xsum_lo = xmm8; + Xbyak::Xmm xsum_hi = xmm9; + Xbyak::Xmm xdst_lo = xmm10; + Xbyak::Xmm xdst_hi = xmm11; + Xbyak::Xmm xsum2_lo = xmm14; + Xbyak::Xmm xsum2_hi = xmm15; + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + for (int i = hoff; i <= Hoff; ++i) + { + for (int j = woff; j <= Woff; ++j) + { + if (i == 0 && j == 0) + { + movups(xdst_lo, ptr[src]); + movups(xdst_hi, ptr[src + 4 * sizeof(float)]); + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + addps(xsum_lo, xdst_lo); + addps(xsum_hi, xdst_hi); + } + else + { + movups(xtmp_lo, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4]); + movups(xtmp_hi, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4 + 4 * sizeof(float)]); + mulps(xtmp_lo, xtmp_lo); + mulps(xtmp_hi, xtmp_hi); + addps(xsum_lo, xtmp_lo); + addps(xsum_hi, xtmp_hi); + } + } + } + mulps(xsum_lo, xalpha); + mulps(xsum_hi, xalpha); + addps(xsum_lo, xk); + addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk + movaps(xtmp_lo, xsum_lo); + movaps(xtmp_hi, xsum_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xtmp_lo); + movups(ptr[scratch + 4 * sizeof(float)], xtmp_hi); + } + movaps(xsum2_lo, xsum_lo); + movaps(xsum2_hi, xsum_hi); + mulps(xsum2_lo, xsum_lo); + mulps(xsum2_hi, xsum_hi); + mulps(xsum_lo, xsum2_lo); + mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha+xk)^3; + + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha+xk)^0.75 + + movups(xdst_lo, ptr[src]); + movups(xdst_hi, ptr[src + 4 * sizeof(float)]); + divps(xdst_lo, xsum_lo); + divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum + + movups(ptr[dst], xdst_lo); + movups(ptr[dst + 4 * sizeof(float)], xdst_hi); + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); +} + +template <cpu_isa_t isa> +jit_uni_lrn_fwd_kernel_f32<isa>::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_within &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 h = r9; + Xbyak::Reg64 w = r10; + Vmm ysum = Vmm(isa == avx2 ? 9 : 9); + Vmm ysum2 = Vmm(isa == avx2 ? 10 : 10); + Vmm ydst = Vmm(isa == avx2 ? 11 : 11); + Vmm ytmp = Vmm(isa == avx2 ? 12 : 12); + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + if (isa == avx2) { + vbroadcastss(yalpha, xalpha); + } else { + shufps(xalpha, xalpha, 0); + } + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + if (isa == avx2) { + vbroadcastss(yk, xk); + } else { + shufps(xk, xk, 0); + } + + int s2 = (J.size - 1) / 2, S2 = J.size - s2 - 1; + + for (int i = 0; i < s2; ++i) + { + Label label_t; + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-i, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } + else { + within_body_sse42(-i, S2, -j, S2, J.W, pk); + } + } + mov(w, J.W - J.size + 1); + L(label_t); + if (isa == avx2) { + within_body(-i, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-i, S2, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(label_t, T_NEAR); + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-i, S2, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-i, S2, -s2, J.W - 1 - j, J.W, pk); + } + } + } + + mov(h, J.H - J.size + 1); + Label lrn_loop_h; + L(lrn_loop_h); + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-s2, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -j, S2, J.W, pk); + } + } + mov(w, J.W - J.size + 1); + Label lrn_loop_w; + L(lrn_loop_w); + if (isa == avx2) { + within_body(-s2, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(lrn_loop_w, T_NEAR); + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-s2, S2, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, S2, -s2, J.W - 1 - j, J.W, pk); + } + } + dec(h); + cmp(h, 0); + jne(lrn_loop_h, T_NEAR); + + for (int i = J.H - S2; i < J.H; ++i) + { + for (int j = 0; j < s2; ++j) { + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -j, S2, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -j, S2, J.W, pk); + } + } + + mov(w, J.W - J.size + 1); + Label label_b; + L(label_b); + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -s2, S2, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -s2, S2, J.W, pk); + } + dec(w); + cmp(w, 0); + jne(label_b, T_NEAR); + + for (int j = J.W - S2; j < J.W; ++j) { + if (isa == avx2) { + within_body(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, + ysum, ydst, ytmp, ysum2, pk); + } else { + within_body_sse42(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, pk); + } + } + } + + this->postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r9; + Xbyak::Xmm xsrc_prev = xmm2; + Xbyak::Ymm ysrc = ymm3; + Xbyak::Ymm yc = ymm3; + Xbyak::Xmm xsrc_next = xmm4; + Xbyak::Ymm ya = ymm5; + Xbyak::Ymm yb = ymm6; + Xbyak::Ymm yd = ymm7; + Xbyak::Ymm ye = ymm8; + Xbyak::Ymm ysum = ymm9; + Xbyak::Ymm ysum2 = ymm10; + Xbyak::Ymm ydst = ymm11; + Xbyak::Ymm ybase = ymm12; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + sub(t, 64); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + if (J.version == -1) + { + vxorps(xsrc_prev, xsrc_prev, xsrc_prev); + vmovups(ptr[t + 0], xsrc_prev); + } + if (J.version == +1) + { + vxorps(xsrc_next, xsrc_next, xsrc_next); + vmovups(ptr[t + 48], xsrc_next); + } + + mov(hw, J.H*J.W); + + Label lrn_loop; + L(lrn_loop); + + if (J.version != -1) vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + vmovups(ysrc, ptr[src]); + if (J.version != +1) vmovups(xsrc_next, ptr[src + J.H*J.W * 32]); + + if (J.version != -1) vmovups(ptr[t + 0], xsrc_prev); + vmovups(ptr[t + 16], ysrc); + if (J.version != +1) vmovups(ptr[t + 48], xsrc_next); + + vmovups(ya, ptr[t + 16 - 8]); + vmovups(yb, ptr[t + 16 - 4]); + vmovups(yd, ptr[t + 16 + 4]); + vmovups(ye, ptr[t + 16 + 8]); + vmulps(ysum, yc, yc); + vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya + vfmadd231ps(ysum, yb, yb); + vfmadd231ps(ysum, yd, yd); + vfmadd231ps(ysum, ye, ye); + vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk + + vmovaps(ybase, ysum); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ysum2, ysum, ysum); + vmulps(ysum, ysum, ysum2); // ysum = ybase^3; + vsqrtps(ysum, ysum); + vsqrtps(ysum, ysum); // ysum = ybase^0.75 + vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum + vmovups(ptr[dst], ydst); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r9; + + Xbyak::Xmm xsrc_lo = xmm2; + Xbyak::Xmm xsrc_hi = xmm3; + Xbyak::Xmm xc_lo = xmm4; + Xbyak::Xmm xc_hi = xmm5; + Xbyak::Xmm xsum_lo = xc_lo; + Xbyak::Xmm xsum_hi = xc_hi; + Xbyak::Xmm xsrc_prev = xmm6; + Xbyak::Xmm xsrc_next = xmm7; + Xbyak::Xmm xa_lo = xmm8; + Xbyak::Xmm xa_hi = xmm9; + Xbyak::Xmm xb_lo = xmm10; + Xbyak::Xmm xb_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + Xbyak::Xmm xe_lo = xmm14; + Xbyak::Xmm xe_hi = xmm15; + Xbyak::Xmm xbase_lo = xmm14; + Xbyak::Xmm xbase_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + sub(t, 64); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + if (J.version == -1) + { + xorps(xsrc_prev, xsrc_prev); + movups(ptr[t + 0], xsrc_prev); + } + if (J.version == +1) + { + xorps(xsrc_next, xsrc_next); + movups(ptr[t + 48], xsrc_next); + } + + mov(hw, J.H*J.W); + Label lrn_loop; + L(lrn_loop); + + if (J.version != -1) movups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + movups(xsrc_lo, ptr[src]); + movups(xsrc_hi, ptr[src + 4 * sizeof(float)]); + if (J.version != +1) movups(xsrc_next, ptr[src + J.H*J.W * 32]); + + if (J.version != -1) movups(ptr[t + 0], xsrc_prev); + movups(ptr[t + 16], xsrc_lo); + movups(ptr[t + 16 + 4 * sizeof(float)], xsrc_hi); + if (J.version != +1) movups(ptr[t + 48], xsrc_next); + + movups(xa_lo, ptr[t + 16 - 8]); + movups(xa_hi, ptr[t + 16 - 8 + 4 * sizeof(float)]); + movups(xb_lo, ptr[t + 16 - 4]); + movups(xb_hi, ptr[t + 16 - 4 + 4 * sizeof(float)]); + movups(xd_lo, ptr[t + 16 + 4]); + movups(xd_hi, ptr[t + 16 + 4 + 4 * sizeof(float)]); + movups(xe_lo, ptr[t + 16 + 8]); + movups(xe_hi, ptr[t + 16 + 8 + 4 * sizeof(float)]); + movaps(xc_lo, xsrc_lo); + movaps(xc_hi, xsrc_hi); + mulps(xsum_lo, xc_lo); + mulps(xsum_hi, xc_hi); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + mulps(xsum_lo, xalpha); + mulps(xsum_hi, xalpha); + addps(xsum_lo, xk); + addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk + + movaps(xbase_lo, xsum_lo); + movaps(xbase_hi, xsum_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xsum_lo, xsum_lo); + mulps(xsum_hi, xsum_hi); + mulps(xsum_lo, xbase_lo); + mulps(xsum_hi, xbase_hi); // xsum = xbase^3; + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); + sqrtps(xsum_lo, xsum_lo); + sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75 + divps(xsrc_lo, xsum_lo); + divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum + movups(ptr[dst], xsrc_lo); + movups(ptr[dst + 4 * sizeof(float)], xsrc_hi); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0, 0, 0x80000000, 0x80000000, 0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0x80000000, 0, 0 + }; + + Xbyak::Reg64 c = r9; + Xbyak::Ymm ya = ymm2; + Xbyak::Ymm yb = ymm3; + Xbyak::Ymm yc = ymm4; + Xbyak::Ymm yd = ymm5; + Xbyak::Ymm ye = ymm6; + Xbyak::Ymm ysum = ymm7; + Xbyak::Ymm ydst = ymm8; + Xbyak::Ymm ybase = ymm9; + Xbyak::Ymm ymask = ymm10; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + vxorps(ysum, ysum, ysum); + + mov(imm_addr64, reinterpret_cast<size_t>(&mask[0])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(ya, ymask, ptr[src - 8]); + vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + + mov(imm_addr64, reinterpret_cast<size_t>(&mask[1])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(yb, ymask, ptr[src - 4]); + vfmadd231ps(ysum, yb, yb); + + mov(c, J.C / 8 - 1); + Label lrn_loop; + L(lrn_loop); + + vmovups(yc, ptr[src]); + vmovups(yd, ptr[src + 4]); + vmovups(ye, ptr[src + 8]); + vfmadd231ps(ysum, yc, yc); + vfmadd231ps(ysum, yd, yd); + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + vmovups(ptr[dst], ydst); + + vxorps(ysum, ysum, ysum); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + + vmovups(ya, ptr[src - 8]); + vfmadd231ps(ysum, ya, ya); + vmovups(yb, ptr[src - 4]); + vfmadd231ps(ysum, yb, yb); + + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + vmovups(yc, ptr[src]); + vfmadd231ps(ysum, yc, yc); + + mov(imm_addr64, reinterpret_cast<size_t>(&mask[2])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(yd, ymask, ptr[src + 4]); + vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + + mov(imm_addr64, reinterpret_cast<size_t>(&mask[3])); + vmovups(ymask, ptr[imm_addr64]); + vmaskmovps(ye, ymask, ptr[src + 8]); + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + vmovups(ptr[scratch], ybase); + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + + vmovups(ptr[dst], ydst); + + this->postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0xffffffff, 0, 0 + }; + + static uint32_t store[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; + Xbyak::Reg64 c = r9; + + Xbyak::Xmm xdst_lo = xmm0; + Xbyak::Xmm xdst_hi = xmm1; + Xbyak::Xmm xa_lo = xmm2; + Xbyak::Xmm xa_hi = xmm3; + Xbyak::Xmm xb_lo = xmm2; + Xbyak::Xmm xb_hi = xmm3; + Xbyak::Xmm xc_lo = xmm4; + Xbyak::Xmm xc_hi = xmm5; + Xbyak::Xmm xd_lo = xmm6; + Xbyak::Xmm xd_hi = xmm7; + Xbyak::Xmm xe_lo = xmm8; + Xbyak::Xmm xe_hi = xmm9; + Xbyak::Xmm xsum_lo = xmm10; + Xbyak::Xmm xsum_hi = xmm11; + Xbyak::Xmm xmask_lo = xmm12; + Xbyak::Xmm xmask_hi = xmm13; + Xbyak::Xmm xbase_lo = xmm14; + Xbyak::Xmm xbase_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + mov(store_addr, reinterpret_cast<size_t>(&store[0])); + and_(store_addr, -15); + movups(ptr[store_addr], xalpha); + movups(ptr[store_addr + 4 * sizeof(float)], xk); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + + mov(imm_addr64, reinterpret_cast<size_t>(&mask[0])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xa_lo, ptr[src - 8]); + movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]); + andps(xa_lo, xmask_lo); + andps(xa_hi, xmask_hi); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(imm_addr64, reinterpret_cast<size_t>(&mask[1])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xb_lo, ptr[src - 4]); + movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]); + andps(xb_lo, xmask_lo); + andps(xb_hi, xmask_hi); + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + + mov(c, J.C / 8 - 1); + Label lrn_loop; + L(lrn_loop); + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + movups(xd_lo, ptr[src + 4]); + movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]); + movups(xe_lo, ptr[src + 8]); + movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + movaps(xdst_lo, xsum_lo); + movaps(xdst_hi, xsum_hi); + // xdst <- xsum*xalpha+xk + mulps(xdst_lo, ptr[store_addr]); + mulps(xdst_hi, ptr[store_addr]); + addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + divps(xc_lo, xdst_lo); + divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + movups(ptr[dst], xc_lo); + movups(ptr[dst + 4 * sizeof(float)], xc_hi); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + + add(src, 32); + add(dst, 32); + if (pk != prop_kind::forward_inference) + add(scratch, 32); + + movups(xa_lo, ptr[src - 8]); + movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + addps(xsum_lo, xa_lo); + addps(xsum_hi, xa_hi); + movups(xb_lo, ptr[src - 4]); + movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]); + mulps(xb_lo, xb_lo); + mulps(xb_hi, xb_hi); + addps(xsum_lo, xb_lo); + addps(xsum_hi, xb_hi); + + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + + mov(imm_addr64, reinterpret_cast<size_t>(&mask[2])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xd_lo, ptr[src + 4]); + movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]); + andps(xd_lo, xmask_lo); + andps(xd_hi, xmask_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(imm_addr64, reinterpret_cast<size_t>(&mask[3])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + movups(xe_lo, ptr[src + 8]); + movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]); + andps(xe_lo, xmask_lo); + andps(xe_hi, xmask_hi); + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + movups(xdst_lo, xsum_lo); + movups(xdst_hi, xsum_hi); + // xdst <- xsum*xalpha+xk + mulps(xdst_lo, ptr[store_addr]); + mulps(xdst_hi, ptr[store_addr]); + addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + movups(xc_lo, ptr[src]); + movups(xc_hi, ptr[src + 4 * sizeof(float)]); + divps(xc_lo, xdst_lo); + divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + + movups(ptr[dst], xc_lo); + movups(ptr[dst + 4 * sizeof(float)], xc_hi); + + this->postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body( + int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum) {} + +template<> +void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body( + int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum) +{ + Xbyak::Ymm ydst = ymm14; + Xbyak::Ymm ybase = ymm15; + + vfmadd231ps(ysum, ye, ye); + + vmovups(ydst, ysum); + vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk + + vmovaps(ybase, ydst); + if (pk != prop_kind::forward_inference) + { + if (tail != 0) + vmaskmovps(ptr[scratch], ymask, ybase); + else + vmovups(ptr[scratch], ybase); + } + vmulps(ydst, ydst, ydst); + vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3; + vsqrtps(ydst, ydst); + vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75 + vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75 + + if (tail != 0) + vmaskmovps(ptr[dst], ymask, ydst); + else + vmovups(ptr[dst], ydst); + + + vfnmadd231ps(ysum, ya, ya); + vmovups(ya, yb); + vmovups(yb, yc); + vmovups(yc, yd); + vmovups(yd, ye); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_tail_sse42( + int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) +{} + +template<> +void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_tail_sse42( + int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi) +{ + Xbyak::Xmm xmm_tmp = xmm10; + movaps(xmm_tmp, xtail_lo); + size_t offset = 0; + + if (tail > 4) { + movups(ptr[reg_dst], xtail_lo); + movaps(xmm_tmp, xtail_hi); + offset += 4 * sizeof(float); + tail -= 4; + } + movss(ptr[reg_dst + offset], xmm_tmp); + for (int i = 1; i < tail; i++) + { + psrldq(xmm_tmp, 4); + movss(ptr[reg_dst + offset + i * sizeof(float)], xmm_tmp); + } +} + +template<> +void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body_sse42( + int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) +{ + Xbyak::Xmm xdst_lo = xmm0; + Xbyak::Xmm xdst_hi = xmm1; + Xbyak::Xmm xbase_lo = xmm6; + Xbyak::Xmm xbase_hi = xmm7; + Xbyak::Xmm xtmp_lo = xmm8; + Xbyak::Xmm xtmp_hi = xmm9; + Xbyak::Xmm xa_lo = xmm6; + Xbyak::Xmm xa_hi = xmm7; + Xbyak::Xmm xb_lo = xmm8; + Xbyak::Xmm xb_hi = xmm9; + Xbyak::Xmm xc_lo = xmm10; + Xbyak::Xmm xc_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + + // store xe + movaps(ptr[store_addr + 10 * 4 * sizeof(float)], xe_lo); + movaps(ptr[store_addr + 11 * 4 * sizeof(float)], xe_hi); + + mulps(xe_lo, xe_lo); + mulps(xe_hi, xe_hi); + addps(xsum_lo, xe_lo); + addps(xsum_hi, xe_hi); + + // xdst <- xsum*xalpha+xk + movaps(xdst_lo, xsum_lo); + movaps(xdst_hi, xsum_hi); + mulps(xdst_lo, ptr[store_addr + 0 * 4 * sizeof(float)]); + mulps(xdst_hi, ptr[store_addr + 0 * 4 * sizeof(float)]); + addps(xdst_lo, ptr[store_addr + 1 * 4 * sizeof(float)]); + addps(xdst_hi, ptr[store_addr + 1 * 4 * sizeof(float)]); + + movaps(xbase_lo, xdst_lo); + movaps(xbase_hi, xdst_hi); + if (pk != prop_kind::forward_inference) + { + if (tail != 0) { + nchw_tail_sse42(tail, scratch, xbase_lo, xbase_hi); + } + else { + movups(ptr[scratch], xbase_lo); + movups(ptr[scratch + 4 * sizeof(float)], xbase_hi); + } + } + mulps(xdst_lo, xdst_lo); + mulps(xdst_hi, xdst_hi); + mulps(xdst_lo, xbase_lo); + mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3; + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); + sqrtps(xdst_lo, xdst_lo); + sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75 + movaps(xtmp_lo, ptr[store_addr + 6 * 4 * sizeof(float)]); + movaps(xtmp_hi, ptr[store_addr + 7 * 4 * sizeof(float)]); + divps(xtmp_lo, xdst_lo); + divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75 + movaps(xdst_lo, xtmp_lo); + movaps(xdst_hi, xtmp_hi); + + if (tail != 0) { + nchw_tail_sse42(tail, dst, xdst_lo, xdst_hi); + } + else { + movups(ptr[dst], xdst_lo); + movups(ptr[dst + 4 * sizeof(float)], xdst_hi); + } + + movaps(xa_lo, ptr[store_addr + 2 * 4 * sizeof(float)]); + movaps(xa_hi, ptr[store_addr + 3 * 4 * sizeof(float)]); + mulps(xa_lo, xa_lo); + mulps(xa_hi, xa_hi); + subps(xsum_lo, xa_lo); + subps(xsum_hi, xa_hi); + + // xa <- xb + movaps(xb_lo, ptr[store_addr + 4 * 4 * sizeof(float)]); + movaps(xb_hi, ptr[store_addr + 5 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xb_lo); + movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xb_hi); + + // xb <- xc + movaps(xc_lo, ptr[store_addr + 6 * 4 * sizeof(float)]); + movaps(xc_hi, ptr[store_addr + 7 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xc_lo); + movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xc_hi); + + // xc <- xd + movaps(xd_lo, ptr[store_addr + 8 * 4 * sizeof(float)]); + movaps(xd_hi, ptr[store_addr + 9 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xd_lo); + movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xd_hi); + + // xd <- xe + movaps(xe_lo, ptr[store_addr + 10 * 4 * sizeof(float)]); + movaps(xe_hi, ptr[store_addr + 11 * 4 * sizeof(float)]); + movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xe_lo); + movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xe_hi); +} + +template<> +void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body_sse42( + int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) {} + +template<> +jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, + 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0 + }; + Xbyak::Reg64 c = r10; + Xbyak::Ymm ymask = ymm2; + Xbyak::Ymm ye = ymm3; + Xbyak::Ymm ya = ymm4; + Xbyak::Ymm yb = ymm5; + Xbyak::Ymm yc = ymm6; + Xbyak::Ymm yd = ymm7; + Xbyak::Ymm ysum = ymm8; + + this->preamble(); + + if (J.tail != 0) + { + mov(imm_addr64, reinterpret_cast<size_t>(&mask[7 - J.tail])); + vmovups(ymask, ptr[imm_addr64]); + } + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + vbroadcastss(yalpha, xalpha); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + vbroadcastss(yk, xk); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + vxorps(ya, ya, ya); + vxorps(yb, yb, yb); + if (J.tail != 0) + vmaskmovps(yc, ymask, ptr[src + J.HW * 0]); + else + vmovups(yc, ptr[src + J.HW * 0]); + if (J.tail != 0) + vmaskmovps(yd, ymask, ptr[src + J.HW * 4]); + else + vmovups(yd, ptr[src + J.HW * 4]); + + vxorps(ysum, ysum, ysum); + vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2 + vfmadd231ps(ysum, yd, yd); + + mov(c, J.C - 2); + Label lrn_loop; + L(lrn_loop); + + if (J.tail != 0) + vmaskmovps(ye, ymask, ptr[src + J.HW * 8]); + else + vmovups(ye, ptr[src + J.HW * 8]); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + vxorps(ye, ye, ye); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + + nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum); + + this->postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); +} + +template<> +jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , alpha(A), k(K) +{ + static const uint32_t mask[] = { + 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, + 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0 + }; + + Xbyak::Reg64 c = r10; + + Xbyak::Xmm xmask_lo = xmm2; + Xbyak::Xmm xmask_hi = xmm3; + Xbyak::Xmm xsum_lo = xmm4; + Xbyak::Xmm xsum_hi = xmm5; + Xbyak::Xmm xa_lo = xmm6; + Xbyak::Xmm xa_hi = xmm7; + Xbyak::Xmm xb_lo = xmm8; + Xbyak::Xmm xb_hi = xmm9; + Xbyak::Xmm xc_lo = xmm10; + Xbyak::Xmm xc_hi = xmm11; + Xbyak::Xmm xd_lo = xmm12; + Xbyak::Xmm xd_hi = xmm13; + Xbyak::Xmm xe_lo = xmm14; + Xbyak::Xmm xe_hi = xmm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(dst, ptr[this->param1 + 8]); + if (pk != prop_kind::forward_inference) + mov(scratch, ptr[this->param1 + 16]); + + sub(rsp, stack_space_needed); + mov(store_addr, rsp); + and_(store_addr, -15); + + mov(imm_addr64, float2int(this->alpha)); + movq(xalpha, imm_addr64); + shufps(xalpha, xalpha, 0); + + mov(imm_addr64, float2int(this->k)); + movq(xk, imm_addr64); + shufps(xk, xk, 0); + + // put alpha and k into store (free up regs) + movaps(ptr[store_addr + 0 * 4 * sizeof(float)], xalpha); + movaps(ptr[store_addr + 1 * 4 * sizeof(float)], xk); + + if (J.tail != 0) + { + mov(imm_addr64, reinterpret_cast<size_t>(&mask[7 - J.tail])); + movups(xmask_lo, ptr[imm_addr64]); + movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]); + } + // init xa, xb + xorps(xa_lo, xa_lo); + xorps(xa_hi, xa_hi); + xorps(xb_lo, xb_lo); + xorps(xb_hi, xb_hi); + + // read xc, xd + if (J.tail != 0) { + movups(xc_lo, ptr[src + J.HW * 0]); + movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]); + andps(xc_lo, xmask_lo); + andps(xc_hi, xmask_hi); + } + else { + movups(xc_lo, ptr[src + J.HW * 0]); + movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]); + } + if (J.tail != 0) { + movups(xd_lo, ptr[src + J.HW * 4]); + movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]); + andps(xd_lo, xmask_lo); + andps(xd_hi, xmask_hi); + } + else { + movups(xd_lo, ptr[src + J.HW * 4]); + movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]); + } + + // put xa, xb, xc, xd into store to free-up regs + movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xa_lo); + movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xa_hi); + movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xb_lo); + movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xb_hi); + movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xc_lo); + movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xc_hi); + movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xd_lo); + movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xd_hi); + + xorps(xsum_lo, xsum_lo); + xorps(xsum_hi, xsum_hi); + mulps(xc_lo, xc_lo); + mulps(xc_hi, xc_hi); + addps(xsum_lo, xc_lo); + addps(xsum_hi, xc_hi); + mulps(xd_lo, xd_lo); + mulps(xd_hi, xd_hi); + addps(xsum_lo, xd_lo); + addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2 + + mov(c, J.C - 2); + Label lrn_loop; + L(lrn_loop); + + if (J.tail != 0) { + movups(xe_lo, ptr[src + J.HW * 8]); + movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]); + andps(xe_lo, xmask_lo); + andps(xe_hi, xmask_hi); + } + else { + movups(xe_lo, ptr[src + J.HW * 8]); + movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]); + } + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + dec(c); + cmp(c, 0); + jne(lrn_loop, T_NEAR); + + xorps(xe_lo, xe_lo); + xorps(xe_hi, xe_hi); + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + add(src, J.HW * 4); + add(dst, J.HW * 4); + if (pk != prop_kind::forward_inference) + add(scratch, J.HW * 4); + + nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi, + xe_lo, xe_hi, + xsum_lo, xsum_hi); + + add(rsp, stack_space_needed); + + this->postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); +} + +////////////////////////////////////////////////////////////////////////////// +// backward kernel +template <cpu_isa_t isa> +jit_uni_lrn_bwd_kernel_f32<isa>::jit_uni_lrn_bwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr, + size_t code_size) + : jit_generator(code_ptr, code_size) + , nalphabeta(-2 * A*B) + , use_h_parallelizm(use_h_parallel) +{ + Xbyak::Reg64 t = rsp; + Xbyak::Reg64 hw = r10; + + Xbyak::Xmm xsrc_prev = xmm1; + Xbyak::Xmm xws_prev = xmm2; + Xbyak::Xmm xdiffdst_prev = xmm3; + Xbyak::Ymm ysrc = ymm4; + Xbyak::Ymm yws = ymm5; + Xbyak::Ymm ydiffdst = ymm6; + Xbyak::Xmm xsrc_next = xmm7; + Xbyak::Xmm xws_next = xmm8; + Xbyak::Xmm xdiffdst_next = xmm9; + Xbyak::Ymm ya = ymm10; + Xbyak::Xmm xa = xmm10; + Xbyak::Ymm yb = ymm11; + Xbyak::Ymm yd = ymm12; + Xbyak::Ymm ye = ymm13; + Xbyak::Ymm ysum = ymm14; + Xbyak::Ymm ydiffsrc = ymm15; + + this->preamble(); + + mov(src, ptr[this->param1 + 0]); + mov(diffdst, ptr[this->param1 + 8]); + mov(workspace, ptr[this->param1 + 16]); + mov(diffsrc, ptr[this->param1 + 24]); + + sub(t, 64); + mov(imm_addr64, float2int(this->nalphabeta)); + movq(xnalphabeta, imm_addr64); + vbroadcastss(ynalphabeta, xnalphabeta); + + bool is_single = J.version == 3; + bool is_first = J.version == -1 || J.version == -2; + bool is_last = J.version == +1 || J.version == -2; + + if (is_first || is_single) { + vxorps(xsrc_prev, xsrc_prev, xsrc_prev); + vmovups(ptr[t + 0], xsrc_prev); + } + if (is_last || is_single) { + vxorps(xsrc_next, xsrc_next, xsrc_next); + vmovups(ptr[t + 48], xsrc_next); + } + mov(hw, this->use_h_parallelizm ? J.W : J.H*J.W); + Label lrn_loop; + L(lrn_loop); + { + if (!is_first && !is_single) { + vmovups(xws_prev, ptr[workspace - J.H*J.W * 32 + 16]); + vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]); + vmovups(xdiffdst_prev, ptr[diffdst - J.H*J.W * 32 + 16]); + vmulps(xa, xws_prev, xws_prev); + vmulps(xa, xa, xws_prev); + vsqrtps(xa, xa); + vsqrtps(xa, xa); + vmulps(xa, xa, xws_prev); + vdivps(xsrc_prev, xsrc_prev, xa); + vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev); + } + + vmovups(ysrc, ptr[src]); + vmovups(yws, ptr[workspace]); + vmovups(ydiffdst, ptr[diffdst]); + vmulps(ya, yws, yws); + vmulps(ya, ya, yws); + vsqrtps(ya, ya); + vsqrtps(ya, ya); + vdivps(ydiffsrc, ydiffdst, ya); + vdivps(ysum, ydiffsrc, yws); + vmulps(ysum, ysum, ysrc); + + if (!is_last && !is_single) { + vmovups(xws_next, ptr[workspace + J.H*J.W * 32]); + vmovups(xsrc_next, ptr[src + J.H*J.W * 32]); + vmovups(xdiffdst_next, ptr[diffdst + J.H*J.W * 32]); + vmulps(xa, xws_next, xws_next); + vmulps(xa, xa, xws_next); + vsqrtps(xa, xa); + vsqrtps(xa, xa); + vmulps(xa, xa, xws_next); + vdivps(xsrc_next, xsrc_next, xa); + vdivps(xsrc_next, xsrc_next, xws_next); + vmulps(xdiffdst_next, xdiffdst_next, xsrc_next); + } + + if (!is_first && !is_single) vmovups(ptr[t + 0], xdiffdst_prev); + vmovups(ptr[t + 16], ysum); + if (!is_last && !is_single) vmovups(ptr[t + 48], xdiffdst_next); + + vmovups(ya, ptr[t + 16 - 8]); + vmovups(yb, ptr[t + 16 - 4]); + vaddps(ysum, ysum, ya); + vmulps(ysrc, ysrc, ynalphabeta); + vaddps(ysum, ysum, yb); + + vmovups(yd, ptr[t + 16 + 4]); + vmovups(ye, ptr[t + 16 + 8]); + vaddps(ysum, ysum, yd); + vaddps(ysum, ysum, ye); + + vfmadd231ps(ydiffsrc, ysum, ysrc); + + vmovups(ptr[diffsrc], ydiffsrc); + + add(src, 32); + add(diffsrc, 32); + add(diffdst, 32); + add(workspace, 32); + + dec(hw); + cmp(hw, 0); + jne(lrn_loop, T_NEAR); + } + + add(t, 64); + this->postamble(); + + ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>( + this->getCode())); +} + +template struct jit_uni_lrn_fwd_kernel_f32<sse42>; +template struct jit_uni_lrn_fwd_kernel_f32<avx2>; +template struct jit_uni_lrn_bwd_kernel_f32<avx2>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp new file mode 100644 index 0000000000..2b3ed43cd4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp @@ -0,0 +1,183 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_LRN_KERNEL_F32_HPP +#define CPU_JIT_UNI_LRN_KERNEL_F32_HPP + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "jit_generator.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +enum params { VECTOR_LENGTH = 8, MAX_LOCAL_SIZE = 32 }; + +typedef struct { + const float *src; + float *dst, *scratch; +} jit_args_fwd_t; + +typedef struct { + const float *src, *diff_dst, *scratch; + float *diff_src; +} jit_args_bwd_t; + +struct nchw8c_across { + /* version: + * -1: channels 0..7, + * 1: channels C-8 .. C-1, + * 0: other channels + * 3: channels only for this kernel(without prev and next) + */ + int H, W, version; + nchw8c_across(int h, int w, int v) : H(h), W(w), version(v) {} +}; + +struct nchw8c_within { + int H, W, size; + nchw8c_within(int h, int w, int s) : H(h), W(w), size(s) {} +}; + +struct nchw_across { + int C, HW, tail; + nchw_across(int c, int hw, int t) : C(c), HW(hw), tail(t) {} +}; + +struct nhwc_across { + int C; + nhwc_across(int c) : C(c) {} +}; + +template <cpu_isa_t isa> +struct jit_uni_lrn_fwd_kernel_f32 : public jit_generator { + Xbyak::Reg64 src = rax; + Xbyak::Reg64 dst = r8; + Xbyak::Reg64 scratch = rdx; + Xbyak::Reg64 imm_addr64 = rbx; + Xbyak::Reg64 store_addr = rbp; + + Xbyak::Xmm xalpha = xmm0; + Xbyak::Ymm yalpha = ymm0; + Xbyak::Xmm xk = xmm1; + Xbyak::Ymm yk = ymm1; + + float alpha; + float k; + + int stack_space_needed = 11 * 4 * sizeof(float) + 16; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_f32) + + /* cpu specific part */ + using Vmm = typename utils::conditional<isa == avx2, Ymm, Zmm>::type; + + jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_within &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + const struct nhwc_across &J, + float A, + float K, + prop_kind_t pk, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + jit_uni_lrn_fwd_kernel_f32( + struct nchw_across J, + float A, + float K, + prop_kind_t pk, + void* code_ptr = nullptr, + size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE); + + void within_body( + int hoff, int Hoff, int woff, int Woff, int stride, + Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2, + prop_kind_t pk); + void within_body_sse42( + int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk); + + + void nchw_body(int tail, int HW, prop_kind_t pk, + Xbyak::Ymm ymask, + Xbyak::Ymm ya, + Xbyak::Ymm yb, + Xbyak::Ymm yc, + Xbyak::Ymm yd, + Xbyak::Ymm ye, + Xbyak::Ymm ysum); + void nchw_body_sse42(int tail, int HW, prop_kind_t pk, + Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi, + Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi, + Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi); + void nchw_tail_sse42(int tail, Xbyak::Reg64 reg_dst, + Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi); + + void operator()(jit_args_fwd_t *arg) { ker(arg); } + void(*ker)(jit_args_fwd_t *); +}; + +template <cpu_isa_t isa> +struct jit_uni_lrn_bwd_kernel_f32 : public jit_generator { + Xbyak::Reg64 src = rax; + Xbyak::Reg64 diffsrc = r8; + Xbyak::Reg64 diffdst = r9; + Xbyak::Reg64 workspace = rdx; + Xbyak::Reg64 imm_addr64 = rsi; + + Xbyak::Xmm xnalphabeta = xmm0; + Xbyak::Ymm ynalphabeta = ymm0; + + float nalphabeta; + + int use_h_parallelizm; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_f32) + + jit_uni_lrn_bwd_kernel_f32( + const struct nchw8c_across &J, + float A, + float B, + int use_h_parallel, + void *code_ptr = nullptr, + size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); + + void operator()(jit_args_bwd_t *arg) { ker(arg); } + void(*ker)(jit_args_bwd_t *); +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp new file mode 100644 index 0000000000..bf8e609d23 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp @@ -0,0 +1,699 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "utils.hpp" +#include "cpu_pooling_pd.hpp" + +#include "jit_uni_pool_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; +using namespace alg_kind; + +#define GET_OFF(field) offsetof(jit_pool_call_s, field) + +template <cpu_isa_t isa> +status_t jit_uni_pool_kernel_f32<isa>::init_conf(jit_pool_conf_t &jpp, + const pooling_pd_t *ppd) { + const auto &pd = *ppd->desc(); + const memory_desc_wrapper src_d( + ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md()); + const memory_desc_wrapper dst_d( + ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md()); + + bool args_ok = true + && mayiuse(isa) + && utils::one_of(pd.alg_kind, pooling_max, + pooling_avg_include_padding, + pooling_avg_exclude_padding); + if (!args_ok) return status::unimplemented; + + const int simd_w = isa == avx512_common ? 16 : 8; + const int ndims = src_d.ndims(); + + jpp.ndims = ndims; + jpp.mb = src_d.dims()[0]; + + jpp.c = utils::rnd_up(src_d.dims()[1], simd_w); + if (jpp.c > src_d.padded_dims()[1]) + return status::unimplemented; + + jpp.id = (ndims == 5) ? src_d.dims()[2] : 1; + jpp.ih = src_d.dims()[ndims-2]; + jpp.iw = src_d.dims()[ndims-1]; + jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1; + jpp.oh = dst_d.dims()[ndims-2]; + jpp.ow = dst_d.dims()[ndims-1]; + + jpp.stride_d = (ndims == 5 ) ? pd.strides[0] : 1; + jpp.stride_h = pd.strides[ndims-4]; + jpp.stride_w = pd.strides[ndims-3]; + jpp.kd = (ndims == 5) ? pd.kernel[0] : 1; + jpp.kh = pd.kernel[ndims-4]; + jpp.kw = pd.kernel[ndims-3]; + + jpp.f_pad = (ndims == 5 ) ? pd.padding[0][0] : 0; + jpp.t_pad = pd.padding[0][ndims-4]; + jpp.l_pad = pd.padding[0][ndims-3]; + + jpp.alg = pd.alg_kind; + + jpp.is_training = pd.prop_kind == prop_kind::forward_training; + jpp.is_backward = pd.prop_kind == prop_kind::backward_data; + jpp.ind_dt = ppd->workspace_md() + ? ppd->workspace_md()->data_type : data_type::undef; + + jpp.simple_alg = jpp.is_training + || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d); + + jpp.c_block = simd_w; + + jpp.nb_c = jpp.c / jpp.c_block; + if (jpp.alg == pooling_max) { + jpp.ur_w = isa == avx512_common ? 16 : 4; + if (jpp.is_training) + jpp.ur_w = isa == avx512_common ? 9 : 3; + else if (jpp.is_backward) + jpp.ur_w = isa == avx512_common ? 6 : 3; + } else { + if (jpp.is_backward) + jpp.ur_w = isa == avx512_common ? 12 : 6; + else + jpp.ur_w = isa == avx512_common ? 24 : 12; + } + if (jpp.ow < jpp.ur_w) jpp.ur_w = jpp.ow; + if (jpp.l_pad > jpp.ur_w) return status::unimplemented; + + jpp.ur_w_tail = jpp.ow % jpp.ur_w; + + return status::success; +} + +template <cpu_isa_t isa> +inline void jit_uni_pool_kernel_f32<isa>::maybe_recalculate_divisor(int jj, + int ur_w, int pad_l, int pad_r) { + if (jpp.alg == pooling_avg_exclude_padding) { + int kw = jpp.kw; + int stride_w = jpp.stride_w; + + int non_zero_kw = kw; + non_zero_kw -= nstl::max(0, pad_l - jj*stride_w); + non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj)*stride_w); + + if (non_zero_kw != prev_kw) { + mov(tmp_gpr, float2int((float)non_zero_kw)); + movq(xmm_tmp, tmp_gpr); + uni_vbroadcastss(vmm_tmp, xmm_tmp); + uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h); + prev_kw = non_zero_kw; + } + } +} + +template <cpu_isa_t isa> +inline void jit_uni_pool_kernel_f32<isa>::avg_step(int ur_w, int pad_l, + int pad_r) { + + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + for (int jj = 0; jj < ur_w; jj++) { + if (jpp.is_backward) { + uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]); + maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r); + uni_vdivps(vreg(jj), vreg(jj), vmm_tmp); + } else { + uni_vpxor(vreg(jj), vreg(jj), vreg(jj)); + } + } + + if (jpp.simple_alg && jpp.ndims == 5) { + push(reg_input); + push(reg_output); + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + if (jpp.is_backward) { + uni_vmovups(vreg(ur_w+jj), + ptr[aux_reg_input + input_offset]); + uni_vaddps(vreg(ur_w+jj), vreg(ur_w+jj), vreg(jj)); + uni_vmovups(vmmword[aux_reg_input + input_offset], + vreg(ur_w+jj)); + } else { + uni_vaddps(vreg(jj), vreg(jj), + ptr[aux_reg_input + input_offset]); + } + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + + if (jpp.simple_alg && jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + pop(reg_output); + pop(reg_input); + } + + if (!jpp.is_backward) { + for (int jj = 0; jj < ur_w; jj++) { + maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r); + uni_vdivps(vreg(jj), vreg(jj), vmm_tmp); + uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], + vreg(jj)); + } + } +} + +template <cpu_isa_t isa> +inline void jit_uni_pool_kernel_f32<isa>::max_step_fwd(int ur_w, int pad_l, + int pad_r) { + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + mov(tmp_gpr, float2int(nstl::numeric_limits<float>::lowest())); + movq(xmm_tmp, tmp_gpr); + uni_vbroadcastss(vmm_tmp, xmm_tmp); + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vreg(jj), vmm_tmp); + if (jpp.is_training) + uni_vpxor(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(2*ur_w+jj)); + } + if (jpp.is_training) + { + movq(xmm_tmp, reg_k_shift); + uni_vpbroadcastd(vmm_k_offset, xmm_tmp); + } + + if (jpp.ndims == 5) { + push(reg_input); + push(reg_output); + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + uni_vmovups(vreg(ur_w+jj), ptr[aux_reg_input + input_offset]); + if (isa == sse42) { + movups(vmm_mask, vreg(jj)); + cmpps(vmm_mask, vreg(ur_w+jj), _cmp_lt_os); + blendvps(vreg(jj), vreg(ur_w+jj)); + if (jpp.is_training) + blendvps(vreg(2*ur_w+jj), vmm_k_offset); + } else if (isa == avx) { + vcmpps(vreg(3*ur_w+jj), vreg(jj), vreg(ur_w+jj), + _cmp_lt_os); + vblendvps(vreg(jj), vreg(jj), vreg(ur_w+jj), + vreg(3*ur_w+jj)); + if (jpp.is_training) + vblendvps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), + vmm_k_offset, vreg(3*ur_w+jj)); + } else { + vcmpps(k_store_mask, vreg(jj), vreg(ur_w+jj), _cmp_lt_os); + vblendmps(vreg(jj) | k_store_mask, vreg(jj), vreg(ur_w+jj)); + if (jpp.is_training) + vblendmps(vreg(2*ur_w+jj) | k_store_mask, + vreg(2*ur_w+jj), vmm_k_offset); + } + } + if (jpp.is_training) { + if (isa == avx && !mayiuse(avx2)) { + avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one); + } + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + + if (jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + if (jpp.is_training) { + mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + if (isa == avx && !mayiuse(avx2)) { + Xmm t(vmm_mask.getIdx()); + avx_vpadd1(vmm_k_offset, xmm_tmp, t); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp); + } + } + + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + pop(reg_output); + pop(reg_input); + } + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], vreg(jj)); + if (jpp.is_training) { + const size_t step_index + = jj * c_block * types::data_type_size(jpp.ind_dt); + + auto x = xreg(2 * ur_w + jj); + if (jpp.ind_dt == data_type::u8) { + if (isa == sse42) { + for (int i = 0; i < 4; ++i) + pextrb(ptr[reg_index + step_index + i], x, 4*i); + } else if (isa == avx) { + auto y = yreg(2 * ur_w + jj); + if (jj == 0) { + movd(xmm_tmp, reg_shuf_mask); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + } + if (mayiuse(avx2)) { + vpshufb(y, y, vmm_tmp); + movd(ptr[reg_index + step_index], x); + vperm2i128(y, y, y, 0x1u); + movd(ptr[reg_index + step_index + 4], x); + } else { + Xmm t(vmm_mask.getIdx()); + vextractf128(t, y, 0); + vpshufb(t, t, xmm_tmp); + movd(ptr[reg_index + step_index], t); + vextractf128(t, y, 1); + vpshufb(t, t, xmm_tmp); // ymm_tmp[:128]==ymm_tmp[127:0] + movd(ptr[reg_index + step_index + 4], t); + } + } else { + auto v = vreg(2 * ur_w + jj); + vpmovusdb(x, v); + vmovups(ptr[reg_index + step_index], v | k_index_mask); + } + } else { + uni_vmovups(ptr[reg_index + step_index], vreg(2*ur_w+jj)); + } + } + } +} + +template <cpu_isa_t isa> +inline void jit_uni_pool_kernel_f32<isa>::max_step_bwd(int ur_w, int pad_l, + int pad_r) { + + int iw = jpp.iw; + int kw = jpp.kw; + int stride_w = jpp.stride_w; + int c_block = jpp.c_block; + Label kd_label, kh_label; + + for (int jj = 0; jj < ur_w; jj++) { + uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]); + + const size_t step_index + = jj * c_block * types::data_type_size(jpp.ind_dt); + if (jpp.ind_dt == data_type::u8) { + if (isa == sse42) { + movd(xreg(ur_w+jj), ptr[reg_index + step_index]); + pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } else if (isa == avx) { + movq(xreg(ur_w+jj), ptr[reg_index + step_index]); + if (!mayiuse(avx2)) { + avx_pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj), xmm_tmp); + } else { + vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } + } else { + vmovups(vreg(ur_w+jj) | k_index_mask, + ptr[reg_index + step_index]); + vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj)); + } + } else { + uni_vmovups(vreg(ur_w+jj), ptr[reg_index + step_index]); + } + } + movq(xmm_tmp, reg_k_shift); + uni_vpbroadcastd(vmm_k_offset, xmm_tmp); + + if (jpp.simple_alg && jpp.ndims == 5) { + push(reg_input); + push(reg_output); + if (isa == sse42) { + // Save rdi since it is used in maskmovdqu + assert(dst_ptr == rdi); + push(dst_ptr); + } + mov(aux_reg_input_d, reg_input); + mov(ki, ptr[reg_param + GET_OFF(kd_padding)]); + mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]); + L(kd_label); + mov(aux_reg_input, aux_reg_input_d); + } else { + mov(aux_reg_input, reg_input); + } + + xor_(kj, kj); + L(kh_label); + { + for (int ki = 0; ki < kw; ki++) { + int jj_start = nstl::max(0, pad_l - ki); + int jj_end = ur_w + - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w); + for (int jj = jj_start; jj < jj_end; jj++) { + int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block; + if (aux_input_offset > iw * c_block) + continue; + int input_offset = sizeof(float)*aux_input_offset; + uni_vmovups(vreg(2*ur_w+jj), ptr[aux_reg_input + input_offset]); + if (isa == sse42) { + mov(dst_ptr, aux_reg_input); + add(dst_ptr, input_offset); + + movups(vreg(3*ur_w+jj), vreg(ur_w+jj)); + pcmpeqd(vreg(3*ur_w+jj), vmm_k_offset); + addps(vreg(2*ur_w+jj), vreg(jj)); + maskmovdqu(vreg(2*ur_w+jj), vreg(3*ur_w+jj)); + } else if (isa == avx) { + if (mayiuse(avx2)) { + vpcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset); + } else { + avx_pcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset, xmm_tmp); + } + vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(jj)); + vmaskmovps(vmmword[aux_reg_input + input_offset], + vreg(3*ur_w+jj), vreg(2*ur_w+jj)); + } else { + vpcmpeqd(k_store_mask, vreg(ur_w+jj), vmm_k_offset); + vblendmps(vmm_tmp | k_store_mask | T_z, vreg(jj), vreg(jj)); + vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vmm_tmp); + vmovups(vmmword[aux_reg_input + + sizeof(float)*aux_input_offset], vreg(2*ur_w+jj)); + } + } + if (isa == avx && !mayiuse(avx2)) { + avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one); + } + } + add(aux_reg_input, sizeof(float) * iw * c_block); + inc(kj); + cmp(kj, reg_kh); + jl(kh_label, T_NEAR); + } + if (jpp.simple_alg && jpp.ndims == 5) + { + add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block); + + mov(tmp_gpr, reg_kd_pad_shift); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + if (isa == avx && !mayiuse(avx2)) { + Xmm t(vmm_mask.getIdx()); + avx_vpadd1(vmm_k_offset, vmm_tmp, t); + } else { + uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp); + } + + dec(ki); + cmp(ki, 0); + jg(kd_label, T_NEAR); + if (isa == sse42) { + // Save rdi since it is used in maskmovdqu + assert(dst_ptr == rdi); + pop(dst_ptr); + } + pop(reg_output); + pop(reg_input); + } +} + +template <cpu_isa_t isa> +void jit_uni_pool_kernel_f32<isa>::maybe_zero_diff_src() { + assert(jpp.c_block * sizeof(float) % cpu_isa_traits<isa>::vlen == 0); + Label l_skip, l_zero; + + auto reg_oh = tmp_gpr; + mov(reg_oh, ptr[reg_param + GET_OFF(oh)]); + cmp(reg_oh, 0); + jz(l_skip, T_NEAR); + + if (jpp.ndims == 5) { + mov(zero_size, ptr[reg_param + GET_OFF(oh)]); + mov(tmp_gpr, jpp.ih * jpp.iw * jpp.c_block * sizeof(float)); + imul(zero_size, tmp_gpr); + } + + auto vzero = vmm_tmp; + uni_vpxor(vzero, vzero, vzero); + + auto reg_off = tmp_gpr; + xor_(reg_off, reg_off); + + L(l_zero); + { + const int dim = jpp.iw * jpp.c_block * sizeof(float); + for (int i = 0; i < dim; i += cpu_isa_traits<isa>::vlen) + uni_vmovups(ptr[reg_input + reg_off + i], vzero); + add(reg_off, dim); + if (jpp.ndims == 5) cmp(reg_off, zero_size); + else cmp(reg_off, jpp.ih * dim); + jl(l_zero, T_NEAR); + } + + L(l_skip); +} + +template <cpu_isa_t isa> +void jit_uni_pool_kernel_f32<isa>::generate() { + + this->preamble(); + + int ow = jpp.ow; + int iw = jpp.iw; + int kw = jpp.kw; + int kh = jpp.kh; + int ur_w = jpp.ur_w; + int c_block = jpp.c_block; + int stride_w = jpp.stride_w; + int l_pad = jpp.l_pad; + int ur_w_tail = jpp.ur_w_tail; + + int n_oi = ow / ur_w; + + prev_kw = 0; + + int vlen = cpu_isa_traits<isa>::vlen; + +#if defined(_WIN32) + // Always mimic the Unix ABI (see the note about maskmovdqu in the header + // file). + xor_(rdi, rcx); + xor_(rcx, rdi); + xor_(rdi, rcx); +#endif + + mov(reg_input, ptr[reg_param + GET_OFF(src)]); + mov(reg_output, ptr[reg_param + GET_OFF(dst)]); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + mov(reg_index, ptr[reg_param + GET_OFF(indices)]); + mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]); + mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]); + mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]); + + if (jpp.is_backward) + maybe_zero_diff_src(); + + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) { + mov(tmp_gpr, 1); + movq(xmm_one, tmp_gpr); + uni_vpbroadcastd(vmm_one, xmm_one); + + if (isa == avx) { + mov(reg_shuf_mask, 0x0c080400); + } else if (isa >= avx512_common) { + mov(tmp_gpr.cvt32(), 0x000f); + kmovw(k_index_mask, tmp_gpr.cvt32()); + } + } + + int r_pad = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1)); + int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1); + if (r_pad1 > 0) n_oi--; + + if (jpp.alg == pooling_avg_exclude_padding) { + movq(xmm_ker_area_h, reg_ker_area_h); + uni_vpbroadcastd(vmm_ker_area_h, xmm_ker_area_h); + } + + if (jpp.alg == pooling_avg_include_padding) { + mov(tmp_gpr, float2int((float)(kw * kh * jpp.kd))); + movq(xmm_tmp, tmp_gpr); + uni_vpbroadcastd(vmm_tmp, xmm_tmp); + } + if (l_pad > 0) { + n_oi--; + if (n_oi < 0 && r_pad1 > 0) { + step(ur_w, l_pad, r_pad1); + } else { + step(ur_w, l_pad, 0); + } + + if (isa == sse42) { + if (n_oi < 0 && r_pad1 > 0) { + step_high_half(ur_w, l_pad, r_pad1); + } else { + step_high_half(ur_w, l_pad, 0); + } + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*(ur_w*stride_w-l_pad)*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*(ur_w*stride_w - l_pad)*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + } + + xor_(oi_iter, oi_iter); + if (n_oi > 0) { + Label ow_loop; + L(ow_loop); { + step(ur_w, 0, 0); + + if (isa == sse42) { + step_high_half(ur_w, 0, 0); + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + + inc(oi_iter); + cmp(oi_iter, n_oi); + jl(ow_loop, T_NEAR); + } + } + + if (r_pad1 > 0 && n_oi >= 0) { + step(ur_w, 0, r_pad1); + + if (isa == sse42) { + step_high_half(ur_w, 0, r_pad1); + } + + if (isa == sse42) { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen); + add(reg_output, sizeof(float)*ur_w*c_block - vlen); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, (2 * ur_w - 1) * c_block / 2 + * types::data_type_size(jpp.ind_dt)); + } else { + add(reg_input, sizeof(float)*ur_w*stride_w*c_block); + add(reg_output, sizeof(float)*ur_w*c_block); + if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) + add(reg_index, ur_w * c_block + * types::data_type_size(jpp.ind_dt)); + } + } + + if (ur_w_tail != 0) { + step(ur_w_tail, 0, r_pad); + + if (isa == sse42) { + step_high_half(ur_w_tail, 0, r_pad); + } + } + + this->postamble(); +} + +template struct jit_uni_pool_kernel_f32<sse42>; +template struct jit_uni_pool_kernel_f32<avx>; // implements both <avx> and <avx2> +template struct jit_uni_pool_kernel_f32<avx512_common>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp new file mode 100644 index 0000000000..992b526587 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp @@ -0,0 +1,192 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* Copyright 2018 YANDEX LLC +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_UNI_POOL_KERNEL_F32_HPP +#define JIT_UNI_POOL_KERNEL_F32_HPP + +#include <cfloat> + +#include "c_types_map.hpp" +#include "pooling_pd.hpp" +#include "type_helpers.hpp" + +#include "jit_generator.hpp" +#include "jit_primitive_conf.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace Xbyak; + +template <cpu_isa_t isa> +struct jit_uni_pool_kernel_f32: public jit_generator { + jit_uni_pool_kernel_f32(jit_pool_conf_t ajpp): jpp(ajpp) + { + this->generate(); + jit_ker = (decltype(jit_ker))this->getCode(); + } + + jit_pool_conf_t jpp; + + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel_f32) + + void operator()(jit_pool_call_s *arg) { jit_ker(arg); } + static status_t init_conf(jit_pool_conf_t &jbp, const pooling_pd_t *ppd); + +private: + using Vmm = typename utils::conditional3<isa == sse42, Xmm, isa == avx, + Ymm, Zmm>::type; + Xmm xreg(int idx) { return Xmm((isa == avx512_common ? 31 : 15) - idx); } + Ymm yreg(int idx) { return Ymm(xreg(idx).getIdx()); } + Vmm vreg(int idx) { return Vmm(xreg(idx).getIdx()); } + + const AddressFrame &vmmword = (isa == sse42) ? xword : + (isa == avx) ? yword : zword; + + Xmm vmm_mask = Xmm(0); + Xmm xmm_ker_area_h = Xmm(2); + Xmm xmm_one = Xmm(2); + Xmm xmm_tmp = Xmm(3); + + Vmm vmm_ker_area_h = Vmm(2); + Vmm vmm_one = Vmm(2); + Vmm vmm_tmp = Vmm(3); + + Vmm vmm_k_offset = Vmm(1); + + Opmask k_index_mask = Opmask(6); + Opmask k_store_mask = Opmask(7); + + // Here be some (tame) dragons. This kernel does not follow the regular + // OS-agnostic ABI pattern because when isa is sse42 it uses maskmovdqu + // instruction which has its destination hardcoded in rdi. Therefore: + // - all registers are hardcoded + // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI + // + // While this is only required by the backward pass, the quirk above + // is applied to the forward pass as well to keep things simpler. + + using reg64_t = const Xbyak::Reg64; + reg64_t reg_param = rdi; // Always mimic the Unix ABI + reg64_t reg_input = r8; + reg64_t aux_reg_input = r9; + reg64_t reg_index = r10; + reg64_t reg_output = r12; + reg64_t reg_kd_pad_shift = r13; + reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu + + reg64_t kj = r14; + reg64_t oi_iter = r15; + reg64_t reg_kh = rax; + reg64_t reg_k_shift = rbx; + reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above + reg64_t reg_ker_area_h = rdx; + + reg64_t zero_size = r15; + reg64_t ki = r12; + reg64_t aux_reg_input_d = r8; + + Xbyak::Reg32 reg_shuf_mask = esi; + + int prev_kw; + void (*jit_ker)(jit_pool_call_s *); + + void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r); + void avg_step(int ur_w, int pad_l, int pad_r); + void max_step_fwd(int ur_w, int pad_l, int pad_r); + void max_step_bwd(int ur_w, int pad_l, int pad_r); + + void maybe_zero_diff_src(); + + void step(int ur_w, int pad_l, int pad_r) { + if (jpp.alg == alg_kind::pooling_max) { + if(jpp.is_backward) + max_step_bwd(ur_w, pad_l, pad_r); + else + max_step_fwd(ur_w, pad_l, pad_r); + } + else + avg_step(ur_w, pad_l, pad_r); + } + + void step_high_half(int ur_w, int pad_l, int pad_r) { + add(reg_input, sizeof(float) * 4); + add(reg_output, sizeof(float) * 4); + if (jpp.alg == alg_kind::pooling_max && + (jpp.is_training || jpp.is_backward)) + add(reg_index, types::data_type_size(jpp.ind_dt) * 4); + + step(ur_w, pad_l, pad_r); + } + + void generate(); + + void avx_vpadd1(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { + assert(y0.getIdx() != x1.getIdx()); + vextractf128(xtmp, y0, 0); + vpaddd(xtmp, xtmp, x1); + vinsertf128(y0, y0, xtmp, 0); + vextractf128(xtmp, y0, 1); + vpaddd(xtmp, xtmp, x1); + vinsertf128(y0, y0, xtmp, 1); + } + + void avx_vpadd1(const Xmm& x0, const Xmm& x1, const Xmm&) { + assert(false /*function should not be used*/); + paddd(x0, x1); + } + + void avx_pmovzxbd(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) { + Xmm x0(y0.getIdx()); + pshufd(xmm_tmp, x1, 1); + pmovzxbd(x0, x1); + pmovzxbd(xmm_tmp, xmm_tmp); + vinsertf128(y0, y0, xmm_tmp, 1); + } + + void avx_pmovzxbd(const Xmm& x0, const Xmm& x1, const Xmm&) { + assert(false /*function should not be used*/); + pmovzxbd(x0, x1); + } + + void avx_pcmpeqd(const Ymm& y0, const Ymm& y1, const Ymm& y2, const Xmm& xtmp) { + assert(y0.getIdx() != y1.getIdx()); + assert(y0.getIdx() != y2.getIdx()); + Xmm x0(y0.getIdx()); + Xmm x2(y2.getIdx()); + vextractf128(x0, y1, 1); + vextractf128(xtmp, y2, 1); + pcmpeqd(xtmp, x0); + vextractf128(x0, y1, 0); + pcmpeqd(x0, x2); + vinsertf128(y0, y0, xtmp, 1); + } + + void avx_pcmpeqd(const Xmm& x0, const Xmm& x1, const Xmm&, const Xmm&) { + assert(false /*function should not be used*/); + pcmpeqd(x0, x1); + } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp new file mode 100644 index 0000000000..afbcf996d8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_types.h" + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "nstl.hpp" + +#include "jit_uni_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <cpu_isa_t isa> +void jit_uni_pooling_fwd_t<isa>::execute_forward(const data_t *src, + data_t *dst, char *indices) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int oh) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &src[src_d.blk_off(n, b_c, ih)]; + arg.dst = &dst[dst_d.blk_off(n, b_c, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = oh == 0; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)); + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, jpp.oh, + [&](int n, int b_c, int oh) { + ker(n, b_c, oh); + }); +} + +template <cpu_isa_t isa> +void jit_uni_pooling_fwd_t<isa>::execute_forward_3d(const data_t *src, + data_t *dst, char *indices) const { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow, + int d_b_overflow) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &src[src_d.blk_off(n, b_c, id, ih)]; + arg.dst = &dst[dst_d.blk_off(n, b_c, od, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, od, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = (oh + od == 0); + arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh; + arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd - + nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) - + nstl::max(0, jpp.f_pad - od*jpp.stride_d)); + + + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, jpp.od, + [&](int n, int b_c, int od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad-ik); + const int d_b_overflow = nstl::max(jpp.id, ik+jpp.kd-jpp.f_pad) + -jpp.id; + const int id = nstl::max(ik - jpp.f_pad, 0); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow); + } + }); +} + +template <cpu_isa_t isa> +void jit_uni_pooling_bwd_t<isa>::execute_backward(const data_t *diff_dst, + const char *indices, data_t *diff_src) const { + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int oh) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &diff_src[diff_src_d.blk_off(n, b_c, ih)]; + arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = (oh == 0); + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)); + + (*kernel_)(&arg); + }; + + parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) { + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, oh); + } + }); +} + +template <cpu_isa_t isa> +void jit_uni_pooling_bwd_t<isa>::execute_backward_3d(const data_t *diff_dst, + const char *indices, data_t *diff_src) const { + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper indices_d(pd()->workspace_md()); + const size_t ind_dt_size = indices + ? types::data_type_size(indices_d.data_type()) : 0; + + const auto &jpp = pd()->jpp_; + + auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow, + int d_b_overflow, int zero_size, int kd) { + auto arg = jit_pool_call_s(); + + const int ij = oh * jpp.stride_h; + const int i_t_overflow = nstl::max(0, jpp.t_pad-ij); + const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih; + const int ih = nstl::max(ij - jpp.t_pad, 0); + + arg.src = &diff_src[diff_src_d.blk_off(n, b_c, id + kd, ih)]; + arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, od, oh)]; + if (indices) { + const size_t ind_off = indices_d.blk_off(n, b_c, od, oh); + arg.indices = &indices[ind_off * ind_dt_size]; + } + arg.oh = zero_size; + arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow; + arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow; + arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh + + kd * jpp.kw * jpp.kh; + arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw; + arg.kw_padding = 0; + arg.ker_area_h = (float)(jpp.kh - + nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) - + nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd - + nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) - + nstl::max(0, jpp.f_pad - od*jpp.stride_d)); + + (*kernel_)(&arg); + }; + + if (jpp.simple_alg) { + + parallel_nd(jpp.mb, jpp.nb_c, jpp.od, + [&](int n, int b_c, int od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad - ik); + const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd + - jpp.f_pad) - jpp.id; + const int id = nstl::max(ik - jpp.f_pad, 0); + int zero_s = jpp.stride_d - d_t_overflow - (nstl::max( + jpp.id, ik + jpp.stride_d - jpp.f_pad) - jpp.id); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, + (oh == 0) ? zero_s : 0, 0); + } + }); + } else { + ptrdiff_t nelems = (ptrdiff_t)jpp.mb * (ptrdiff_t)jpp.c + * (ptrdiff_t)jpp.id * (ptrdiff_t)jpp.ih * (ptrdiff_t)jpp.iw; + + parallel_nd(nelems, [&](ptrdiff_t i) { diff_src[i] = 0.f; }); + + for (int kd = 0; kd < jpp.kd; ++kd) { + parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) { + for (int od = 0; od < jpp.od; ++od) { + const int ik = od * jpp.stride_d; + const int d_t_overflow = nstl::max(0, jpp.f_pad-ik); + const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd + - jpp.f_pad) - jpp.id; + if (kd >= jpp.kd - d_t_overflow - d_b_overflow) + continue; + const int id = nstl::max(ik - jpp.f_pad, 0); + for (int oh = 0; oh < jpp.oh; ++oh) { + ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow, + 0, kd); + } + } + }); + } + } +} + + +template struct jit_uni_pooling_fwd_t<sse42>; +template struct jit_uni_pooling_bwd_t<sse42>; +template struct jit_uni_pooling_fwd_t<avx>; +template struct jit_uni_pooling_bwd_t<avx>; +template struct jit_uni_pooling_fwd_t<avx512_common>; +template struct jit_uni_pooling_bwd_t<avx512_common>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp new file mode 100644 index 0000000000..57bebacdee --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp @@ -0,0 +1,182 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_JIT_UNI_POOLING_HPP +#define CPU_JIT_UNI_POOLING_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +#include "jit_uni_pool_kernel_f32.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <cpu_isa_t isa> +struct jit_uni_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_pooling_fwd_t<isa>); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, + src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag()) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag()); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return jit_uni_pool_kernel_f32<isa>::init_conf(jpp_, this); + } + + format_tag_t desired_fmt_tag() { + using namespace format_tag; + return ndims() == 4 + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? nCdhw16c : nCdhw8c; + } + + jit_pool_conf_t jpp_; + }; + + jit_uni_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_pool_kernel_f32<isa>(pd()->jpp_); } + + ~jit_uni_pooling_fwd_t() { delete kernel_; } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE); + + if (pd()->ndims() == 5) + execute_forward_3d(src, dst, ws); + else + execute_forward(src, dst, ws); + + return status::success; + } + +private: + void execute_forward(const data_t *src, data_t *dst, char *indices) const; + void execute_forward_3d(const data_t *src, data_t *dst, + char *indices) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_pool_kernel_f32<isa> *kernel_; +}; + +template <cpu_isa_t isa> +struct jit_uni_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T( + JIT_IMPL_NAME_HELPER("jit:", isa, ""), + jit_uni_pooling_bwd_t<isa>); + + status_t init() { + using namespace utils; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && !has_zero_dim_memory() + && everyone_is(data_type::f32, + diff_src_md()->data_type, + diff_dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag()) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag()); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return jit_uni_pool_kernel_f32<isa>::init_conf(jpp_, this); + } + + format_tag_t desired_fmt_tag() { + using namespace format_tag; + return ndims() + ? isa == avx512_common ? nChw16c : nChw8c + : isa == avx512_common ? nCdhw16c : nCdhw8c; + } + + jit_pool_conf_t jpp_; + }; + + jit_uni_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) + { kernel_ = new jit_uni_pool_kernel_f32<isa>(pd()->jpp_); } + + ~jit_uni_pooling_bwd_t() { delete kernel_; } + + typedef typename prec_traits<data_type::f32>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + if (pd()->ndims() == 5) + execute_backward_3d(diff_dst, ws, diff_src); + else + execute_backward(diff_dst, ws, diff_src); + + return status::success; + } + +private: + void execute_backward(const data_t *diff_dst, const char *indices, + data_t *diff_src) const; + void execute_backward_3d(const data_t *diff_dst, const char *indices, + data_t *diff_src) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + jit_uni_pool_kernel_f32<isa> *kernel_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp new file mode 100644 index 0000000000..98796503b7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp @@ -0,0 +1,1006 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "mkldnn_debug.h" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" +#include "jit_uni_reorder.hpp" + +#include "jit_generator.hpp" + +// #define TR_DEBUG +#if defined(TR_DEBUG) +#define DEBUg(...) do { __VA_ARGS__ } while (0) +#else +#define DEBUg(...) +#endif +#define DEBUG(...) DEBUg(__VA_ARGS__) + +#ifdef _WIN32 +/* seems like s_addr is a reserved macro on Windows */ +#undef s_addr +#endif + +using namespace Xbyak; +using namespace mkldnn::impl::types; + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +/** Minimal reasonable/desirable kernel size. + * The constant might be used to determine how a problem should be split + * between kernel and threading driver. */ +const size_t ker_prb_size_min = 64; + +/* kernel */ +struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32) + + enum { + len_unroll_max = 256, + ndims_jit_loop_max = 3, + }; + + struct simple_impl_desc_t { + int ndims_full_unroll; + int len_last_dim_unroll; + int len_unroll; + }; + + static bool simple_impl_desc_init(const prb_t &prb, + simple_impl_desc_t *desc) { + const int ndims = prb.ndims; + + int ndims_full_unroll = 0; + int len_last_dim_unroll = 1; + int len_unroll = 1; + + for (int d = 0; d < ndims; ++d) { + auto &node = prb.nodes[d]; + if (len_unroll * node.n <= len_unroll_max) { + ndims_full_unroll++; + len_unroll *= node.n; + } else { + len_last_dim_unroll = len_unroll_max / len_unroll; + while (node.n % len_last_dim_unroll) + --len_last_dim_unroll; + len_unroll *= len_last_dim_unroll; + break; + } + } + + if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) + return false; + + if (desc) { + desc->ndims_full_unroll = ndims_full_unroll; + desc->len_last_dim_unroll = len_last_dim_unroll; + desc->len_unroll = len_unroll; + } + + return true; + } + + static bool applicable(const prb_t &p) { + using namespace data_type; + + bool ok = true + && p.ndims > 0 + && utils::one_of(p.itype, f32, s32, s8, u8) + && utils::one_of(p.otype, f32, s32, s8, u8) + && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ + && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ + && simple_impl_desc_init(p, nullptr) + && mayiuse(sse42) + && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype), + mayiuse(avx)); + if (!ok) return false; + + const ptrdiff_t max_stride = (1LL<<31) - 1; + for (int d = 0; d < p.ndims; ++d) { + const ptrdiff_t cms = max_stride / p.nodes[d].n; + bool strides_ok = true + && p.nodes[d].is < cms / (int)data_type_size(p.itype) + && p.nodes[d].os < cms / (int)data_type_size(p.otype); + if (!strides_ok) return false; + } + + return true; + } + + int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; } + int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; } + int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; } + int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; } + + Address i_addr(int i_off) + { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; } + + Address o_addr(int o_off) + { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; } + + Address s_addr(int s_off) + { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; } + + void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, + int &i_off, int &o_off, int &s_off, int step_size = 1) { + i_off = prev_i_off; + o_off = prev_o_off; + s_off = prev_s_off; + + if (off == 0) return; + + int start_dim = 0, dims_prod = 1; + for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) + dims_prod *= n(start_dim); + assert(start_dim < prb_.ndims); + off /= step_size; + + for (int d = start_dim; d < prb_.ndims; ++d) { + i_off += is(d); + o_off += os(d); + s_off += ss(d); + + if (off % n(d)) break; + + i_off += - n(d) * is(d); + o_off += - n(d) * os(d); + s_off += - n(d) * ss(d); + off /= n(d); + + if (off == 0) break; /* FIXME: is it really required? */ + } + } + + void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, + int step_size = 1) { + int dummy = 0; + step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy, + step_size); + } + + void tr8x8_avx2(int i_off, int o_off) { + for (int i = 0; i < 8; i++) + vmovups(Ymm(i), i_addr(i_off + i * 8)); + + for (int i = 0; i < 8 / 2; i++) { + vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1)); + vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1)); + } + + const unsigned int lfloat = 0x44; + const unsigned int ufloat = 0xee; + for (int i = 0; i < 8 / 2; i++) { + int j = i % 2 == 0 ? 8 + i : i - 1; + vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat); + vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat); + } + + const unsigned int lquad = 0x20; + for (int i = 0; i < 8 / 2; i++) + vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad); + + const unsigned int uquad = 0x31; + for (int i = 8 / 2; i < 8; i++) + vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad); + + for (int i = 0; i < 8; i++) + vmovups(o_addr(o_off + i * 8), Ymm(i)); + } + + bool process_unroll_tr8x8(int len) { + bool can_do = true + && mayiuse(avx2) + && prb_.ndims >= 2 + && utils::everyone_is(4, itype_sz, otype_sz) + && utils::everyone_is(8, n(0), n(1)) + && utils::everyone_is(1, os(0), is(1)) + && utils::everyone_is(8, os(1), is(0)) + && prb_.scale_type == scale_type_t::NONE + && prb_.beta == 0.f; + if (!can_do) return false; + + const int step_size = n(0) * n(1); + int i_off = 0, o_off = 0; + for (int off = 0; off < len; off += step_size) { + step(off, i_off, o_off, i_off, o_off, step_size); + tr8x8_avx2(i_off, o_off); + } + + return true; + } + + template <cpu_isa_t isa> + bool process_direct_copy(int len) { + using namespace data_type; + + using Vmm = typename cpu_isa_traits<isa>::Vmm; + const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz; + + bool can_do = true + && mayiuse(isa) + && utils::everyone_is(1, os(0), is(0)) + && (false + || prb_.itype == prb_.otype + || (prb_.itype == s32 && prb_.otype == f32) + || (prb_.itype == f32 && prb_.otype == s32) + ) + && len % simd_w == 0 + && n(0) % len == 0 + && prb_.scale_type == scale_type_t::NONE + && prb_.beta == 0.f; + if (!can_do) return false; + + for (int off = 0; off < len;) { + const int unroll = nstl::min(16, (len - off) / simd_w); + + for (int ur = 0; ur < unroll; ++ur) + uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w)); + + if (prb_.itype != prb_.otype) { + for (int ur = 0; ur < unroll; ++ur) { + if (prb_.itype == s32 && prb_.otype == f32) + uni_vcvtdq2ps(Vmm(ur), Vmm(ur)); + else if (prb_.itype == f32 && prb_.otype == s32) + uni_vcvtps2dq(Vmm(ur), Vmm(ur)); + else assert(!"unreachable"); + } + } + + for (int ur = 0; ur < unroll; ++ur) + uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur)); + + off += unroll * simd_w; + } + + return true; + } + + void process_unroll_generic_step(int reg_unroll, const int *i_off, + const int *o_off, const int *s_off) { + using namespace data_type; + + auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) { + Xmm dst_pure = Xmm(dst.getIdx()); + switch (idt) { + case f32: + if (src.isMEM() || src.getIdx() != dst.getIdx()) + vmovups(dst, src); + break; + case s32: vcvtdq2ps(dst, src); break; + case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; + case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break; + default: assert(!"unreachable"); + } + }; + + auto cvt2int = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) { + switch (odt) { + case s32: + if (idt == f32) vcvtps2dq(xmm, xmm); + else if (idt == s8) vpmovsxbd(xmm, xmm); + else if (idt == u8) vpmovzxbd(xmm, xmm); + break; + case s8: + if (idt == f32) vcvtps2dq(xmm, xmm); + if (idt == f32 || idt == s32) { + if (mayiuse(avx512_core)) { + vpmovsdb(xmm, xmm); + } else { + vpackssdw(xmm, xmm, xmm_zero); + vpacksswb(xmm, xmm, xmm_zero); + } + } + if (idt == u8) vpminub(xmm, xmm, xmm_4x127b); + break; + case u8: + if (idt == f32) vcvtps2dq(xmm, xmm); + if (idt == f32 || idt == s32) { + if (mayiuse(avx512_core)) { + vpmaxsd(xmm, xmm, xmm_zero); + vpmovusdb(xmm, xmm); + } else { + vpackssdw(xmm, xmm, xmm_zero); + vpackuswb(xmm, xmm, xmm_zero); + } + } + if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero); + break; + default: assert(!"unreachable"); + } + }; + + auto load = [=](const Xmm &xmm, const Address &addr, int size) { + switch (size) { + case 16: movups(xmm, addr); break; + case 4: movss(xmm, addr); break; + case 1: pinsrb(xmm, addr, 0x0); break; + default: assert(!"unreachable"); + } + }; + + auto store = [=](const Address &addr, const Xmm &xmm, int size) { + switch (size) { + case 16: movups(addr, xmm); break; + case 4: movss(addr, xmm); break; + case 1: pextrb(addr, xmm, 0x0); break; + default: assert(!"unreachable"); + } + }; + + /* check whether loading 4 values at once is possible */ + bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0; + for (int ur = 1; ur < reg_unroll; ++ur) + if (i_off[ur] != i_off[ur - 1] + 1) + can_load_xmm = false; + const int load_step = can_load_xmm ? 4 : 1; + + /* check whether storing 4 values at once is possible */ + bool can_store_xmm = reg_unroll % 4 == 0; + for (int ur = 1; ur < reg_unroll; ++ur) + if (o_off[ur] != o_off[ur - 1] + 1) + can_store_xmm = false; + const int ur_step = can_store_xmm ? 4 : 1; + + const bool interim_f32 = false + || utils::one_of(f32, prb_.itype, prb_.otype) + || prb_.scale_type != scale_type_t::NONE + || prb_.beta != 0.f; + + if (!can_load_xmm && can_store_xmm) { + assert(ur_step == 4); + /* load with stride */ + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + for (int r = 0; r < ur_step; ++r) { + if (itype_sz == 4) + pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r); + else + pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r); + } + } + } else { + for (int ur = 0; ur < reg_unroll; ur += load_step) + load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz); + } + + /* xmm[:] <-- (f32)xmm[:] */ + if (interim_f32) { + const int cvt_step = nstl::max(load_step, ur_step); + for (int ur = 0; ur < reg_unroll; ur += cvt_step) + cvt2ps(Xmm(ur), Xmm(ur), prb_.itype); + } + + if (can_load_xmm && !can_store_xmm) { + const bool fast_return = true // transposition on the fly + && prb_.scale_type != scale_type_t::MANY + && prb_.beta == 0.f; + if (fast_return) { + for (int ur = 0; ur < reg_unroll; ur += load_step) { + if (prb_.scale_type == scale_type_t::COMMON) + mulps(Xmm(ur), xmm_scale); + if (prb_.otype != f32) + cvt2int(Xmm(ur), prb_.otype, + interim_f32 ? f32 : prb_.itype); + for (int r = 0; r < load_step; ++r) { + if (otype_sz == 4) + pextrd(o_addr(o_off[ur + r]), Xmm(ur), r); + else + pextrb(o_addr(o_off[ur + r]), Xmm(ur), r); + } + } + return; + } + + /* scatter elements of xmm into 4 xmms */ + if (itype_sz == 4 || interim_f32) { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) + vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r); + } else { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) + vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r); + } + } + + /* scale and beta processing */ + if (can_store_xmm) { + /* xmm <-- scale * xmm[:] */ + if (prb_.scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) + mulps(Xmm(ur), xmm_scale); + } else if (prb_.scale_type == scale_type_t::MANY) { + enum class scale_load_type_t { bcast, load, gather }; + + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + scale_load_type_t scale_load_type = + scale_load_type_t::bcast; // the best case + + for (int r = ur + 1; r < ur + ur_step; ++r) + if (s_off[r] != s_off[r - 1] + 0) + scale_load_type = scale_load_type_t::load; + + if (scale_load_type == scale_load_type_t::bcast) { + movss(xmm_scale, s_addr(s_off[ur])); + shufps(xmm_scale, xmm_scale, 0x0); + mulps(Xmm(ur), xmm_scale); + continue; + } + + // bcast doesn't work, the next try -- load + for (int r = ur + 1; r < ur + ur_step; ++r) + if (s_off[r] != s_off[r - 1] + 1) + scale_load_type = scale_load_type_t::gather; + + if (scale_load_type == scale_load_type_t::load) { + movups(xmm_scale, s_addr(s_off[ur])); + mulps(Xmm(ur), xmm_scale); + continue; + } + + // load doesn't work as well + // so gather the scale factors one by one + for (int r = ur; r < ur + ur_step; ++r) + pinsrd(xmm_scale, s_addr(s_off[r]), r - ur); + mulps(Xmm(ur), xmm_scale); + } + } + + /* dst <-- beta * dst + xmm[:] */ + assert(prb_.beta == 0.f || prb_.beta == 1.f); + if (prb_.beta == 1.f) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype == f32) { + /* non VEX instructions do not support unaligned + * memory for instructions other than movups. */ + if (mayiuse(avx)) { + vaddps(Xmm(ur), o_addr(o_off[ur])); + } else { + /* register xmm(1) is unused */ + movups(Xmm(1), o_addr(o_off[ur])); + addps(Xmm(ur), Xmm(1)); + } + } else { + cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype); + vaddps(Xmm(ur), Xmm(1)); + } + } + } + } else { + /* xmm[0] <-- scale * xmm[0] */ + if (prb_.scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) + mulss(Xmm(ur), xmm_scale); + } else if (prb_.scale_type == scale_type_t::MANY) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + mulss(Xmm(ur), s_addr(s_off[ur])); + } + } + + /* dst <-- beta * dst + xmm[0] */ + assert(prb_.beta == 0.f || prb_.beta == 1.f); + if (prb_.beta == 1.f) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype == f32) { + addss(Xmm(ur), o_addr(o_off[ur])); + } else { + if (prb_.otype == s32) { + vmovss(xmm_tmp, o_addr(o_off[ur])); + } else if (utils::one_of(prb_.otype, s8, u8)) { + pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0); + } else { + assert(!"unsupported o_type"); + } + cvt2ps(xmm_tmp, xmm_tmp, prb_.otype); + addps(Xmm(ur), xmm_tmp); + } + } + } + } + + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.otype != f32) + cvt2int(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype); + store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz); + } + } + + void process_unroll_generic(int len) { + const int blk = 8; + + int i_off[2 * blk] = {0}; + int o_off[2 * blk] = {0}; + int s_off[2 * blk] = {0}; + + int curr = 0; // will switch between 0 and 1 + + for (int off = 0; off < len; off += blk) { + const int reg_unroll = nstl::min(off + blk, len) - off; + + /* compute offsets */ + for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { + const int ur_c = curr * blk + ur; + const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur + step(off + ur, + i_off[ur_p], o_off[ur_p], s_off[ur_p], + i_off[ur_c], o_off[ur_c], s_off[ur_c]); + } + + process_unroll_generic_step(reg_unroll, i_off + curr * blk, + o_off + curr * blk, s_off + curr * blk); + + curr = 1 - curr; + } + } + + void loop_begin(Label &l, Reg64 reg_cnt, int len) { + mov(reg_cnt, len); + L(l); + } + + void loop_end(Label &l, Reg64 reg_cnt, int len, + int i_step, int o_step, int s_step) { + add(reg_off_in, i_step * itype_sz); + add(reg_off_out, o_step * otype_sz); + if (prb_.scale_type == scale_type_t::MANY) + add(reg_off_scale, s_step * stype_sz); + dec(reg_cnt); + jnz(l); + + sub(reg_off_in, len * i_step * itype_sz); + sub(reg_off_out, len * o_step * otype_sz); + if (prb_.scale_type == scale_type_t::MANY) + sub(reg_off_scale, len * s_step * stype_sz); + } + + bool simple_impl() { + simple_impl_desc_t d; + if (!simple_impl_desc_init(prb_, &d)) return false; + + const int nfu = d.ndims_full_unroll; + const int ldu = d.len_last_dim_unroll; + const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; + assert(n_jit_loops <= ndims_jit_loop_max); + + xor_(reg_off_in, reg_off_in); + xor_(reg_off_out, reg_off_out); + if (prb_.scale_type == scale_type_t::MANY) + xor_(reg_off_scale, reg_off_scale); + + Label l_loop[3]; + Reg64 reg_cnt[3] = {r15, r14, r13}; + + if (n_jit_loops > 2) + loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2)); + + if (n_jit_loops > 1) + loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1)); + + if (n_jit_loops > 0) + loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu); + + const bool optimized = false + || process_direct_copy<avx>(d.len_unroll) + || process_direct_copy<sse42>(d.len_unroll) + || process_unroll_tr8x8(d.len_unroll); + if (!optimized) + process_unroll_generic(d.len_unroll); + + if (n_jit_loops > 0) + loop_end(l_loop[0], reg_cnt[0], + n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu, + ss(nfu + 0) * ldu); + + if (n_jit_loops > 1) + loop_end(l_loop[1], reg_cnt[1], + n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1)); + + if (n_jit_loops > 2) + loop_end(l_loop[2], reg_cnt[2], + n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2)); + + return true; + } + + void impl() { + if (simple_impl()) return; + assert(!"no implementation available"); + } + + jit_uni_reorder_kernel_f32(const desc_t &desc) + : kernel_t(desc), jit_generator() { + itype_sz = data_type_size(prb_.itype); + otype_sz = data_type_size(prb_.otype); + stype_sz = sizeof(float); + + preamble(); +# define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)] + if (prb_.scale_type == scale_type_t::COMMON) { + auto reg_ptr_scale_tmp = reg_ptr_in; + mov(reg_ptr_scale_tmp, PARAM(scale)); + movups(xmm_scale, ptr[reg_ptr_scale_tmp]); + } else if (prb_.scale_type == scale_type_t::MANY) { + mov(reg_ptr_scale, PARAM(scale)); + } + mov(reg_ptr_in, PARAM(in)); + mov(reg_ptr_out, PARAM(out)); +# undef PARAM + + if (mayiuse(avx)) { + vxorps(xmm_zero, xmm_zero, xmm_zero); + + if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { + mov(reg_tmp.cvt32(), 0x7f7f7f7f); + movd(xmm_4x127b, reg_tmp.cvt32()); + } + } + + impl(); + postamble(); + ker_ = (void (*)(const call_param_t *))getCode(); + } + +private: + int itype_sz; + int otype_sz; + int stype_sz; + + Reg64 reg_ptr_in = rsi; + Reg64 reg_ptr_out = rdx; + Reg64 reg_ptr_scale = abi_not_param1; + + Reg64 reg_off_in = r8; + Reg64 reg_off_out = r9; + Reg64 reg_off_scale = r10; + + Reg64 reg_tmp = rax; + + Xmm xmm_scale = xmm15; + Xmm xmm_zero = xmm14; + Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero + Xmm xmm_tmp = xmm12; +}; + +status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb, + int ndims_ker_max) { + desc.prb = prb; + desc.prb.ioff = desc.prb.ooff = 0; + + if (ndims_ker_max > prb.ndims) + return status::invalid_arguments; + + auto ndims_ker_max_f = [&]() { + size_t cur_size = 1; + for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n) + if (cur_size >= ker_prb_size_min) return d; + return prb.ndims; + }; + + if (ndims_ker_max <= 0) + ndims_ker_max = ndims_ker_max_f(); + + /* traverse through kernel implementations */ + /* TODO: find a better way to do that... */ + desc.id = 0; + for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) { + desc.prb.ndims = ndims_ker; + if (jit_uni_reorder_kernel_f32::applicable(desc.prb)) + return status::success; + } + + return status::unimplemented; +} + +kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { + switch (desc.id) { + case 0: return new jit_uni_reorder_kernel_f32(desc); + default: assert(!"unknown kernel id"); return nullptr; + } + + return nullptr; +} + +} + +static void prb_block_for_cache(tr::prb_t &prb) { + if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) { + /** an attempt to use caches more efficient and + * address the 4K-aliasing issue */ + /* TODO: improve the logic around here */ + int j = 1; + for (; j < prb.ndims && prb.nodes[j].is != 1; ++j); + if (j == prb.ndims) return; + + /* it makes sense to re-prioritize sequential read over + * sequential write if the former would not trash the + * cache, i.e. is == 1 and os % 2^smth != 0. Smth is + * set to 2 at the moment */ + const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1; + if (j == move_to) return; + + if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0) + prb_node_split(prb, j, 16); + + prb_node_move(prb, j, move_to); + DEBUG({ printf("cache: "); prb_dump(prb); }); + } +} + +/** finds the maximum number of dimension the kernel should process and + * optionally splits one of the dimension to achieve better balance between + * parallel driver and the kernel. */ +static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) { + size_t sz_total = 1; + for (int d = 0; d < prb.ndims; ++d) + sz_total *= prb.nodes[d].n; + + /* sz_drv_min is the minimal size for the parallel + * driver required for good parallelization */ + const size_t sz_drv_min = nstl::min<size_t>( + 16 * mkldnn_get_max_threads(), + utils::div_up(sz_total, 1024)); + + /* kdims -- # of dimensions processed by a kernel + * sz_ker_cur -- product of the dimension processed by a kernel + * sz_drv_cur -- product of the dimension processed by a driver */ + + int kdims = prb.ndims; + size_t sz_drv_cur = 1; + for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims) + sz_drv_cur *= prb.nodes[kdims - 1].n; + + size_t sz_ker_cur = 1; + for (int d = 0; d < kdims; ++d) + sz_ker_cur *= prb.nodes[d].n; + + /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min. + * + * It might happen that for chosen kdims the sz_ker_cur is too small + * (less than tr::ker_prb_size_min). In that case try to split the + * innermost driver dimension into two, to increase sz_ker_cur. */ + bool want_borrow_ker_from_drv = true + && kdims < prb.ndims + && sz_ker_cur < tr::ker_prb_size_min + && sz_drv_cur > sz_drv_min; + if (want_borrow_ker_from_drv) { + /* sz_want_borrow is the minimal sz, so that: + * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min + * o) current innermost driver dimension is divisible by + * sz_want_borrow (so that we can evenly split that + * dimension into two) + * + * In the worst case the minimal sz_want_borrow is equal + * to the innermost driver dimension itself. In that case + * we will sacrifice it in favor of kernel (is it fine?). */ + size_t sz_want_borrow + = utils::div_up(tr::ker_prb_size_min, sz_ker_cur); + for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow); + if (sz_want_borrow != prb.nodes[kdims].n) + prb_node_split(prb, kdims, sz_want_borrow); + kdims += 1; + } + + /* On the other hand it might happen that for chosen kdims + * the sz_drv_cur is too small (less than sz_drv_min). In that case + * try to split the outermost kernel dimension into two, to increase + * sz_drv_cur. */ + bool want_borrow_drv_from_ker = true + && sz_ker_cur > tr::ker_prb_size_min + && sz_drv_cur < sz_drv_min; + if (want_borrow_drv_from_ker) { + size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur); + for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow); + if (sz_want_borrow != prb.nodes[kdims - 1].n) + prb_node_split(prb, kdims - 1, + prb.nodes[kdims - 1].n / sz_want_borrow); + } + + ndims_ker_max = kdims; + + if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) { + DEBUG({ printf("split: "); prb_dump(prb); + printf("ndims_ker_max = %d\n", ndims_ker_max); }); + } +} + +struct jit_uni_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + auto prb = tr::prb_t(); + + status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); + if (prb_init_status != status::success) return prb_init_status; + + DEBUG({ printf("init : "); prb_dump(prb); }); + prb_normalize(prb); + DEBUG({ printf("norm : "); prb_dump(prb); }); + prb_simplify(prb); + DEBUG({ printf("smpl : "); prb_dump(prb); }); + + prb_block_for_cache(prb); + + int ndims_ker_max; + prb_thread_kernel_balance(prb, ndims_ker_max); + + tr::kernel_t::desc_t ker_desc; + status_t ker_init_status + = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); + if (ker_init_status != status::success) return ker_init_status; + + const int ndims_driver = prb.ndims - ker_desc.prb.ndims; + if (ndims_driver > jit_uni_reorder_t::ndims_driver_max) + return status::unimplemented; + + DEBUG({ printf("ker : "); prb_dump(ker_desc.prb); }); + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + _pd->prb_ = prb; + _pd->ker_desc_ = ker_desc; + return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); + } + + tr::prb_t prb_; + tr::kernel_t::desc_t ker_desc_; + }; + + jit_uni_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { + kernel_ = tr::kernel_t::create(pd()->ker_desc_); + assert(kernel_); + } + ~jit_uni_reorder_t() { delete kernel_; } + + void omp_driver_0d(int off, const char *in, char *out, + const float *scale) const { + tr::call_param_t c{in, out, scale}; + (*kernel_)(&c); + } + + void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype); + c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is) + * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os) + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, + (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) + * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) + * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; + (*kernel_)(&c); + }); + } + + void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, + const float *scale) const { + const tr::node_t *ns = pd()->prb_.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, + (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + auto c = tr::call_param_t(); + c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is + + d3 * ns[3].is) * data_type_size(pd()->prb_.itype); + c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os + + d3 * ns[3].os) * data_type_size(pd()->prb_.otype); + c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss + + d3 * ns[3].ss; + (*kernel_)(&c); + }); + } + + void omp_driver(const char *in, char *out, const float *scale) const { + in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); + out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); + + DEBUG({ printf("prb : "); tr::prb_dump(pd()->prb_); }); + DEBUG({ printf("ker : "); tr::prb_dump(pd()->ker_desc_.prb); }); + + int ndims = pd()->prb_.ndims; + int ndims_ker = pd()->ker_desc_.prb.ndims; + assert(ndims - ndims_ker <= ndims_driver_max); + + if (ndims - ndims_ker == 0) { + omp_driver_0d(ndims_ker, in, out, scale); + } else { + parallel(0, [&](const int ithr, const int nthr) { + switch (ndims - ndims_ker) { + case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break; + case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break; + case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break; + case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break; + default: assert(!"unimplemented"); + } + }); + } + } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto in = CTX_IN_MEM(const char *, MKLDNN_ARG_FROM); + auto out = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); + + omp_driver(in, out, pd()->attr()->output_scales_.scales_); + + return status::success; + } + + enum { ndims_driver_max = 4 }; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + tr::kernel_t *kernel_; +}; + +status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + return jit_uni_reorder_t::pd_t::create(reorder_pd, engine, attr, + src_engine, src_md, dst_engine, dst_md); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp new file mode 100644 index 0000000000..0746ea61d3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp @@ -0,0 +1,127 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef _JIT_UNI_REORDER_HPP +#define _JIT_UNI_REORDER_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_primitive.hpp" +#include "cpu_reorder_pd.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +constexpr int max_ndims = MKLDNN_MAX_NDIMS; + +struct node_t { + size_t n; + ptrdiff_t is; // input stride + ptrdiff_t os; // output stride + ptrdiff_t ss; // scale stride +}; + +enum class scale_type_t { NONE, COMMON, MANY }; + +struct prb_t { + data_type_t itype; + data_type_t otype; + int ndims; + node_t nodes[max_ndims]; + ptrdiff_t ioff; + ptrdiff_t ooff; + scale_type_t scale_type; + float beta; +}; + +status_t prb_init(prb_t &prb, const memory_desc_t &imd, + const memory_desc_t &omd, const primitive_attr_t *attr); + +/** sorts the problem nodes so that output strides come in ascending order */ +void prb_normalize(prb_t &p); + +/** folds nodes together if possible */ +void prb_simplify(prb_t &p); + +/** splits the node dim into two of sizes n1 and n / n1 + * @warning n must be multiple of n1 */ +void prb_node_split(prb_t &p, int dim, size_t n1); + +/** swaps d0 and d1 nodes */ +void prb_node_swap(prb_t &p, int d0, int d1); + +/** moves node d0 to the d1 position. + * nodes (d0, d1] are shifted to the left if d0 < d1 or + * to the right if d0 > d1 */ +void prb_node_move(prb_t &p, int d0, int d1); + +/** dumps the problem to stdout */ +void prb_dump(const prb_t &p); + +struct call_param_t { + const void *in; + void *out; + const float *scale; +}; + +struct kernel_t { + struct desc_t { + int id; + prb_t prb; + }; + + kernel_t(const desc_t &desc): desc_(desc), ker_(nullptr) {} + void operator()(const call_param_t *c) const { assert(ker_); ker_(c); } + virtual ~kernel_t() {} + + /** inits kernel descriptor: + * desc -- kernel descriptor (output) + * prb -- transposition problem (input) + * ndims_ker_max -- limit the maximum number of dimensions kernel + * will process (optional, 0 -- no limitation) */ + static status_t desc_init(desc_t &desc, const prb_t &prb, + int ndims_ker_max = 0); + + /** creates kernel for the problem described in desc */ + static kernel_t *create(const desc_t &desc); + +protected: + const desc_t desc_; + const prb_t &prb_ = desc_.prb; + void (*ker_)(const call_param_t *); +}; + +/* TODO: add trans_t class */ + +} + +/* for cpu reorder list */ +status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md); + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp new file mode 100644 index 0000000000..69b7a33604 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp @@ -0,0 +1,313 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_desc_wrapper.hpp" +#include "mkldnn_debug.h" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "jit_uni_reorder.hpp" + +using namespace mkldnn::impl::types; +using namespace mkldnn::impl::status; + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace tr { + +/** ad-hoc structure to describe blocked memory layout */ +struct layout_desc_t { + data_type_t dt; + int ndims; + dims_t id; + dims_t dims; + strides_t strides; +}; + +status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_, + layout_desc_t &ld) { + const auto md = memory_desc_wrapper(md_); + + bool ok = true + && md.is_blocking_desc() + && md.extra().flags == 0; + if (!ok) return invalid_arguments; + + const auto &bd = md.blocking_desc(); + + ld.ndims = 0; + ld.dt = md.data_type(); + + auto P = [&ld](int id, int dim, ptrdiff_t stride) { + assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0])); + ld.id[ld.ndims] = id; + ld.dims[ld.ndims] = dim; + ld.strides[ld.ndims] = stride; + ++ld.ndims; + }; + + dims_t blocks; + md.compute_blocks(blocks); + + for (int d = 0; d < md.ndims(); ++d) { + const int ld_ndims_start = ld.ndims; + if (blocks[d] != 1) { + stride_t stride = 1; + for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) { + if (bd.inner_idxs[iblk] == d) + P(d, bd.inner_blks[iblk], stride); + stride *= bd.inner_blks[iblk]; + } + } + P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]); + + // TODO: NOW: revisit, do we need a reverse? + // TODO: NOW: consider using strides instead of block sizes in md + // reverse the order of dims + for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) { + const int idx0 = ld_ndims_start + ld_d; + const int idx1 = ld.ndims - 1 - ld_d; + nstl::swap(ld.dims[idx0], ld.dims[idx1]); + nstl::swap(ld.strides[idx0], ld.strides[idx1]); + } + } + + return success; +} + +status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, + const primitive_attr_t *attr) { + auto im_d = memory_desc_wrapper(imd); + auto om_d = memory_desc_wrapper(omd); + + bool ok = true + && im_d.is_blocking_desc() + && om_d.is_blocking_desc() + && !im_d.has_zero_dim() + && !om_d.has_zero_dim(); + if (!ok) + return unimplemented; + + dims_t iblocks, oblocks; + im_d.compute_blocks(iblocks); + om_d.compute_blocks(oblocks); + + /* padding_dim consistency check */ + for (int d = 0; d < im_d.ndims(); ++d) { + const auto pdim = im_d.padded_dims()[d]; + bool ok = true + && pdim == om_d.padded_dims()[d] + && pdim % iblocks[d] == 0 + && pdim % oblocks[d] == 0; + if (!ok) return unimplemented; + } + + layout_desc_t ild, old; + status_t status = cvt_mem_desc_to_layout_desc(imd, ild); + if (status != success) return status; + status = cvt_mem_desc_to_layout_desc(omd, old); + if (status != success) return status; + + p.itype = ild.dt; + p.otype = old.dt; + + p.scale_type = attr->output_scales_.has_default_values() + ? scale_type_t::NONE + : (attr->output_scales_.mask_ == 0 + ? scale_type_t::COMMON + : scale_type_t::MANY); + + ptrdiff_t ss[max_ndims] = {0}; + if (p.scale_type == scale_type_t::MANY) { + ptrdiff_t last_ss = 1; + for (int d = old.ndims - 1; d >=0; --d) { + assert((d == 0 || old.id[d - 1] <= old.id[d]) + && "logical dimensions should be in ascending order"); + if (attr->output_scales_.mask_ & (1 << old.id[d])) { + ss[d] = last_ss; + last_ss *= old.dims[d]; + } + } + } + + int ndims = 0; + + int i_pos = 0; /* state for input -- current dimension */ + int o_pos = 0; /* state for output -- current dimension */ + + while (i_pos < ild.ndims && o_pos < old.ndims) { + assert(ild.id[i_pos] == old.id[o_pos]); + if (ild.id[i_pos] != old.id[o_pos]) + return runtime_error; + + assert(ndims < max_ndims); + if (ndims == max_ndims) + return runtime_error; + + if (ild.dims[i_pos] == old.dims[o_pos]) { + p.nodes[ndims].n = ild.dims[i_pos]; + p.nodes[ndims].is = ild.strides[i_pos]; + p.nodes[ndims].os = old.strides[o_pos]; + p.nodes[ndims].ss = ss[o_pos]; + ++ndims; + ++i_pos; + ++o_pos; + } else if (ild.dims[i_pos] < old.dims[o_pos]) { + assert(old.dims[o_pos] % ild.dims[i_pos] == 0); + int factor = old.dims[o_pos] / ild.dims[i_pos]; + p.nodes[ndims].n = ild.dims[i_pos]; + p.nodes[ndims].is = ild.strides[i_pos]; + p.nodes[ndims].os = old.strides[o_pos] * factor; + p.nodes[ndims].ss = ss[o_pos] * factor; + ++ndims; + ++i_pos; + old.dims[o_pos] = factor; + } else if (ild.dims[i_pos] > old.dims[o_pos]) { + assert(ild.dims[i_pos] % old.dims[o_pos] == 0); + int factor = ild.dims[i_pos] / old.dims[o_pos]; + p.nodes[ndims].n = old.dims[o_pos]; + p.nodes[ndims].is = ild.strides[i_pos] * factor; + p.nodes[ndims].os = old.strides[o_pos]; + p.nodes[ndims].ss = ss[o_pos]; + ++ndims; + ++o_pos; + ild.dims[i_pos] = factor; + } + } + p.ndims = ndims; + + dims_t zero_pos = {0}; + p.ioff = memory_desc_wrapper(imd).off_v(zero_pos); + p.ooff = memory_desc_wrapper(omd).off_v(zero_pos); + + const int sum_idx = attr->post_ops_.find(primitive_kind::sum); + p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale; + + return success; +} + +void prb_normalize(prb_t &p) { + for (int d = 0; d < p.ndims; ++d) { + int min_pos = d; + for (int j = d + 1; j < p.ndims; ++j) { + bool new_min = false + || p.nodes[j].os < p.nodes[min_pos].os + || (true + && p.nodes[j].os == p.nodes[min_pos].os + && p.nodes[j].n < p.nodes[min_pos].n); + if (new_min) min_pos = j; + } + if (min_pos != d) + nstl::swap(p.nodes[d], p.nodes[min_pos]); + } +} + +void prb_simplify(prb_t &p) { +#if defined(__GNUC__) && __GNUC__ >= 4 +/* GCC produces bogus array subscript is above array bounds warning for + * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" +#endif + for (int d = 0; d < p.ndims - 1; ++d) { + auto &this_node = p.nodes[d + 0]; + auto &next_node = p.nodes[d + 1]; + const bool fold = false + || next_node.n == (size_t)1 // trivial case, just drop next node + || (true // or real folding if possible + && next_node.is == (ptrdiff_t)this_node.n * this_node.is + && next_node.os == (ptrdiff_t)this_node.n * this_node.os + && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss); + if (fold) { + this_node.n *= next_node.n; + for (int j = d + 2; j < p.ndims; ++j) + p.nodes[j - 1] = p.nodes[j]; + --p.ndims; + --d; // make another try + } + } +#if defined(__GNUC__) && __GNUC__ >= 4 +#pragma GCC diagnostic pop +#endif +} + +void prb_node_split(prb_t &p, int dim, size_t n1) { + assert(dim < p.ndims); + assert(p.ndims < max_ndims); + assert(p.nodes[dim].n % n1 == 0); + + p.ndims += 1; + + for (int d = p.ndims; d > dim + 1; --d) + p.nodes[d] = p.nodes[d - 1]; + + p.nodes[dim + 1].n = p.nodes[dim].n / n1; + p.nodes[dim + 1].is = p.nodes[dim].is * n1; + p.nodes[dim + 1].os = p.nodes[dim].os * n1; + p.nodes[dim + 1].ss = p.nodes[dim].ss * n1; + + p.nodes[dim].n = n1; +} + +void prb_node_swap(prb_t &p, int d0, int d1) { + assert(d0 < p.ndims); + assert(d1 < p.ndims); + assert(p.ndims < max_ndims); + + if (d0 == d1) return; + + nstl::swap(p.nodes[d0], p.nodes[d1]); +} + +void prb_node_move(prb_t &p, int d0, int d1) { + assert(d0 < p.ndims); + assert(d1 < p.ndims); + assert(p.ndims < max_ndims); + + if (d0 == d1) return; + + node_t node = p.nodes[d0]; + + if (d0 < d1) + for (int d = d0; d < d1; ++d) + p.nodes[d] = p.nodes[d + 1]; + else + for (int d = d0; d > d1; --d) + p.nodes[d] = p.nodes[d - 1]; + + p.nodes[d1] = node; +} + +void prb_dump(const prb_t &p) { + printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype), + mkldnn_dt2str(p.otype), p.ndims); + for (int d = 0; d < p.ndims; ++d) + printf("[%zu:%td:%td:%td]", + p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss); + printf(" off:%zu:%zu\n", p.ioff, p.ooff); +} + +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp new file mode 100644 index 0000000000..08747aa89c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp @@ -0,0 +1,115 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <mutex> + +#include "utils.hpp" + +#ifndef MKLDNN_ENABLE_JIT_PROFILING +#define MKLDNN_ENABLE_JIT_PROFILING 1 +#endif + +#ifndef MKLDNN_ENABLE_JIT_DUMP +#define MKLDNN_ENABLE_JIT_DUMP 1 +#endif + +#if MKLDNN_ENABLE_JIT_PROFILING +#include "jitprofiling/jitprofiling.h" +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace jit_utils { + +// WARNING: These functions are not thread safe and must be protected by a +// mutex + +void dump_jit_code(const void *code, size_t code_size, const char *code_name) +{ +#if MKLDNN_ENABLE_JIT_DUMP + if (code && jit_dump_enabled()) { + static int counter = 0; +#define MAX_FNAME_LEN 256 + char fname[MAX_FNAME_LEN + 1]; + // TODO (Roma): support prefix for code / linux perf dumps + snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", code_name, + counter); + counter++; + + FILE *fp = fopen(fname, "w+"); + // Failure to dump code is not fatal + if (fp) { + size_t unused = fwrite(code, code_size, 1, fp); + UNUSED(unused); + fclose(fp); + } + } +#undef MAX_FNAME_LEN +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); +#endif +} + +void register_jit_code_vtune(const void *code, size_t code_size, + const char *code_name, const char *source_file_name) +{ +#if MKLDNN_ENABLE_JIT_PROFILING + if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) { + auto jmethod = iJIT_Method_Load(); + jmethod.method_id = iJIT_GetNewMethodID(); // XXX: not thread-safe + jmethod.method_name = (char *)code_name; // XXX: dropping const + jmethod.class_file_name = NULL; + jmethod.source_file_name = (char *)source_file_name; // XXX: dropping const + jmethod.method_load_address = (void *)code; + jmethod.method_size = (unsigned int)code_size; + + iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, + (void*)&jmethod); + } +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); + UNUSED(source_file_name); +#endif +} + +void register_jit_code(const void *code, size_t code_size, + const char *code_name, const char *source_file_name) +{ + // The #ifdef guards are required to avoid generating a function that only + // consists of lock and unlock code +#if MKLDNN_ENABLE_JIT_PROFILING || MKLDNN_ENABLE_JIT_DUMP + static std::mutex m; + std::lock_guard<std::mutex> guard(m); + + dump_jit_code(code, code_size, code_name); + register_jit_code_vtune(code, code_size, code_name, source_file_name); +#else + UNUSED(code); + UNUSED(code_size); + UNUSED(code_name); + UNUSED(source_file_name); +#endif +} + +} +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp new file mode 100644 index 0000000000..2f52dba4ac --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef JIT_SUPPORT_HPP +#define JIT_SUPPORT_HPP + +namespace mkldnn { +namespace impl { +namespace cpu { +namespace jit_utils { + +void register_jit_code(const void *code, size_t code_size, + const char *code_name, const char *source_file_name); + +} +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD new file mode 100644 index 0000000000..4fd21cea57 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD @@ -0,0 +1,27 @@ +Copyright (c) 2011, Intel Corporation +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md new file mode 100644 index 0000000000..fc67c4f134 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md @@ -0,0 +1 @@ +This code is from [Intel SEAPI library](https://github.com/intel/IntelSEAPI) diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h new file mode 100644 index 0000000000..edbf4a15f0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h @@ -0,0 +1,595 @@ +/* <copyright> + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +</copyright> */ +#ifndef _ITTNOTIFY_CONFIG_H_ +#define _ITTNOTIFY_CONFIG_H_ + +/** @cond exclude_from_documentation */ +#ifndef ITT_OS_WIN +# define ITT_OS_WIN 1 +#endif /* ITT_OS_WIN */ + +#ifndef ITT_OS_LINUX +# define ITT_OS_LINUX 2 +#endif /* ITT_OS_LINUX */ + +#ifndef ITT_OS_MAC +# define ITT_OS_MAC 3 +#endif /* ITT_OS_MAC */ + +#ifndef ITT_OS_FREEBSD +# define ITT_OS_FREEBSD 4 +#endif /* ITT_OS_FREEBSD */ + +#ifndef ITT_OS +# if defined WIN32 || defined _WIN32 +# define ITT_OS ITT_OS_WIN +# elif defined( __APPLE__ ) && defined( __MACH__ ) +# define ITT_OS ITT_OS_MAC +# elif defined( __FreeBSD__ ) +# define ITT_OS ITT_OS_FREEBSD +# else +# define ITT_OS ITT_OS_LINUX +# endif +#endif /* ITT_OS */ + +#ifndef ITT_PLATFORM_WIN +# define ITT_PLATFORM_WIN 1 +#endif /* ITT_PLATFORM_WIN */ + +#ifndef ITT_PLATFORM_POSIX +# define ITT_PLATFORM_POSIX 2 +#endif /* ITT_PLATFORM_POSIX */ + +#ifndef ITT_PLATFORM_MAC +# define ITT_PLATFORM_MAC 3 +#endif /* ITT_PLATFORM_MAC */ + +#ifndef ITT_PLATFORM_FREEBSD +# define ITT_PLATFORM_FREEBSD 4 +#endif /* ITT_PLATFORM_FREEBSD */ + +#ifndef ITT_PLATFORM +# if ITT_OS==ITT_OS_WIN +# define ITT_PLATFORM ITT_PLATFORM_WIN +# elif ITT_OS==ITT_OS_MAC +# define ITT_PLATFORM ITT_PLATFORM_MAC +# elif ITT_OS==ITT_OS_FREEBSD +# define ITT_PLATFORM ITT_PLATFORM_FREEBSD +# else +# define ITT_PLATFORM ITT_PLATFORM_POSIX +# endif +#endif /* ITT_PLATFORM */ + +#if defined(_UNICODE) && !defined(UNICODE) +#define UNICODE +#endif + +#include <stddef.h> +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include <tchar.h> +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#include <stdint.h> +#if defined(UNICODE) || defined(_UNICODE) +#include <wchar.h> +#endif /* UNICODE || _UNICODE */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#ifndef ITTAPI_CDECL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define ITTAPI_CDECL __cdecl +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define ITTAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define ITTAPI_CDECL /* actual only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* ITTAPI_CDECL */ + +#ifndef STDCALL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define STDCALL __stdcall +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define STDCALL __attribute__ ((stdcall)) +# else /* _M_IX86 || __i386__ */ +# define STDCALL /* supported only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* STDCALL */ + +#define ITTAPI ITTAPI_CDECL +#define LIBITTAPI ITTAPI_CDECL + +/* TODO: Temporary for compatibility! */ +#define ITTAPI_CALL ITTAPI_CDECL +#define LIBITTAPI_CALL ITTAPI_CDECL + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +/* use __forceinline (VC++ specific) */ +#define ITT_INLINE __forceinline +#define ITT_INLINE_ATTRIBUTE /* nothing */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/* + * Generally, functions are not inlined unless optimization is specified. + * For functions declared inline, this attribute inlines the function even + * if no optimization level was specified. + */ +#ifdef __STRICT_ANSI__ +#define ITT_INLINE static +#define ITT_INLINE_ATTRIBUTE __attribute__((unused)) +#else /* __STRICT_ANSI__ */ +#define ITT_INLINE static inline +#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused)) +#endif /* __STRICT_ANSI__ */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/** @endcond */ + +#ifndef ITT_ARCH_IA32 +# define ITT_ARCH_IA32 1 +#endif /* ITT_ARCH_IA32 */ + +#ifndef ITT_ARCH_IA32E +# define ITT_ARCH_IA32E 2 +#endif /* ITT_ARCH_IA32E */ + +#ifndef ITT_ARCH_ARM +# define ITT_ARCH_ARM 4 +#endif /* ITT_ARCH_ARM */ + +#ifndef ITT_ARCH_PPC64 +# define ITT_ARCH_PPC64 5 +#endif /* ITT_ARCH_PPC64 */ + +#ifndef ITT_ARCH +# if defined _M_IX86 || defined __i386__ +# define ITT_ARCH ITT_ARCH_IA32 +# elif defined _M_X64 || defined _M_AMD64 || defined __x86_64__ +# define ITT_ARCH ITT_ARCH_IA32E +# elif defined _M_IA64 || defined __ia64__ +# define ITT_ARCH ITT_ARCH_IA64 +# elif defined _M_ARM || defined __arm__ +# define ITT_ARCH ITT_ARCH_ARM +# elif defined __powerpc64__ +# define ITT_ARCH ITT_ARCH_PPC64 +# endif +#endif + +#ifdef __cplusplus +# define ITT_EXTERN_C extern "C" +# define ITT_EXTERN_C_BEGIN extern "C" { +# define ITT_EXTERN_C_END } +#else +# define ITT_EXTERN_C /* nothing */ +# define ITT_EXTERN_C_BEGIN /* nothing */ +# define ITT_EXTERN_C_END /* nothing */ +#endif /* __cplusplus */ + +#define ITT_TO_STR_AUX(x) #x +#define ITT_TO_STR(x) ITT_TO_STR_AUX(x) + +#define __ITT_BUILD_ASSERT(expr, suffix) do { \ + static char __itt_build_check_##suffix[(expr) ? 1 : -1]; \ + __itt_build_check_##suffix[0] = 0; \ +} while(0) +#define _ITT_BUILD_ASSERT(expr, suffix) __ITT_BUILD_ASSERT((expr), suffix) +#define ITT_BUILD_ASSERT(expr) _ITT_BUILD_ASSERT((expr), __LINE__) + +#define ITT_MAGIC { 0xED, 0xAB, 0xAB, 0xEC, 0x0D, 0xEE, 0xDA, 0x30 } + +/* Replace with snapshot date YYYYMMDD for promotion build. */ +#define API_VERSION_BUILD 20151119 + +#ifndef API_VERSION_NUM +#define API_VERSION_NUM 0.0.0 +#endif /* API_VERSION_NUM */ + +#define API_VERSION "ITT-API-Version " ITT_TO_STR(API_VERSION_NUM) \ + " (" ITT_TO_STR(API_VERSION_BUILD) ")" + +/* OS communication functions */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include <windows.h> +typedef HMODULE lib_t; +typedef DWORD TIDT; +typedef CRITICAL_SECTION mutex_t; +#define MUTEX_INITIALIZER { 0 } +#define strong_alias(name, aliasname) /* empty for Windows */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#include <dlfcn.h> +#if defined(UNICODE) || defined(_UNICODE) +#include <wchar.h> +#endif /* UNICODE */ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE 1 /* need for PTHREAD_MUTEX_RECURSIVE */ +#endif /* _GNU_SOURCE */ +#ifndef __USE_UNIX98 +#define __USE_UNIX98 1 /* need for PTHREAD_MUTEX_RECURSIVE, on SLES11.1 with gcc 4.3.4 wherein pthread.h missing dependency on __USE_XOPEN2K8 */ +#endif /*__USE_UNIX98*/ +#include <pthread.h> +typedef void* lib_t; +typedef pthread_t TIDT; +typedef pthread_mutex_t mutex_t; +#define MUTEX_INITIALIZER PTHREAD_MUTEX_INITIALIZER +#define _strong_alias(name, aliasname) \ + extern __typeof (name) aliasname __attribute__ ((alias (#name))); +#define strong_alias(name, aliasname) _strong_alias(name, aliasname) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_get_proc(lib, name) GetProcAddress(lib, name) +#define __itt_mutex_init(mutex) InitializeCriticalSection(mutex) +#define __itt_mutex_lock(mutex) EnterCriticalSection(mutex) +#define __itt_mutex_unlock(mutex) LeaveCriticalSection(mutex) +#define __itt_load_lib(name) LoadLibraryA(name) +#define __itt_unload_lib(handle) FreeLibrary(handle) +#define __itt_system_error() (int)GetLastError() +#define __itt_fstrcmp(s1, s2) lstrcmpA(s1, s2) +#define __itt_fstrnlen(s, l) strnlen_s(s, l) +#define __itt_fstrcpyn(s1, b, s2, l) strncpy_s(s1, b, s2, l) +#define __itt_fstrdup(s) _strdup(s) +#define __itt_thread_id() GetCurrentThreadId() +#define __itt_thread_yield() SwitchToThread() +#ifndef ITT_SIMPLE_INIT +ITT_INLINE long +__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __itt_interlocked_increment(volatile long* ptr) +{ + return InterlockedIncrement(ptr); +} +#endif /* ITT_SIMPLE_INIT */ + +#define DL_SYMBOLS (1) +#define PTHREAD_SYMBOLS (1) + +#else /* ITT_PLATFORM!=ITT_PLATFORM_WIN */ +#define __itt_get_proc(lib, name) dlsym(lib, name) +#define __itt_mutex_init(mutex) {\ + pthread_mutexattr_t mutex_attr; \ + int error_code = pthread_mutexattr_init(&mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_init", \ + error_code); \ + error_code = pthread_mutexattr_settype(&mutex_attr, \ + PTHREAD_MUTEX_RECURSIVE); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_settype", \ + error_code); \ + error_code = pthread_mutex_init(mutex, &mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutex_init", \ + error_code); \ + error_code = pthread_mutexattr_destroy(&mutex_attr); \ + if (error_code) \ + __itt_report_error(__itt_error_system, "pthread_mutexattr_destroy", \ + error_code); \ +} +#define __itt_mutex_lock(mutex) pthread_mutex_lock(mutex) +#define __itt_mutex_unlock(mutex) pthread_mutex_unlock(mutex) +#define __itt_load_lib(name) dlopen(name, RTLD_LAZY) +#define __itt_unload_lib(handle) dlclose(handle) +#define __itt_system_error() errno +#define __itt_fstrcmp(s1, s2) strcmp(s1, s2) + +/* makes customer code define safe APIs for SDL_STRNLEN_S and SDL_STRNCPY_S */ +#ifdef SDL_STRNLEN_S +#define __itt_fstrnlen(s, l) SDL_STRNLEN_S(s, l) +#else +#define __itt_fstrnlen(s, l) strlen(s) +#endif /* SDL_STRNLEN_S */ +#ifdef SDL_STRNCPY_S +#define __itt_fstrcpyn(s1, b, s2, l) SDL_STRNCPY_S(s1, b, s2, l) +#else +#define __itt_fstrcpyn(s1, b, s2, l) strncpy(s1, s2, l) +#endif /* SDL_STRNCPY_S */ + +#define __itt_fstrdup(s) strdup(s) +#define __itt_thread_id() pthread_self() +#define __itt_thread_yield() sched_yield() +#if ITT_ARCH==ITT_ARCH_IA64 +#ifdef __INTEL_COMPILER +#define __TBB_machine_fetchadd4(addr, val) __fetchadd4_acq((void *)addr, val) +#else /* __INTEL_COMPILER */ +/* TODO: Add Support for not Intel compilers for IA-64 architecture */ +#endif /* __INTEL_COMPILER */ +#elif ITT_ARCH==ITT_ARCH_IA32 || ITT_ARCH==ITT_ARCH_IA32E /* ITT_ARCH!=ITT_ARCH_IA64 */ +ITT_INLINE long +__TBB_machine_fetchadd4(volatile void* ptr, long addend) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __TBB_machine_fetchadd4(volatile void* ptr, long addend) +{ + long result; + __asm__ __volatile__("lock\nxadd %0,%1" + : "=r"(result),"=m"(*(int*)ptr) + : "0"(addend), "m"(*(int*)ptr) + : "memory"); + return result; +} +#elif ITT_ARCH==ITT_ARCH_ARM || ITT_ARCH==ITT_ARCH_PPC64 +#define __TBB_machine_fetchadd4(addr, val) __sync_fetch_and_add(addr, val) +#endif /* ITT_ARCH==ITT_ARCH_IA64 */ +#ifndef ITT_SIMPLE_INIT +ITT_INLINE long +__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE; +ITT_INLINE long __itt_interlocked_increment(volatile long* ptr) +{ + return __TBB_machine_fetchadd4(ptr, 1) + 1L; +} +#endif /* ITT_SIMPLE_INIT */ + +void* dlopen(const char*, int) __attribute__((weak)); +void* dlsym(void*, const char*) __attribute__((weak)); +int dlclose(void*) __attribute__((weak)); +#define DL_SYMBOLS (dlopen && dlsym && dlclose) + +int pthread_mutex_init(pthread_mutex_t*, const pthread_mutexattr_t*) __attribute__((weak)); +int pthread_mutex_lock(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutex_unlock(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutex_destroy(pthread_mutex_t*) __attribute__((weak)); +int pthread_mutexattr_init(pthread_mutexattr_t*) __attribute__((weak)); +int pthread_mutexattr_settype(pthread_mutexattr_t*, int) __attribute__((weak)); +int pthread_mutexattr_destroy(pthread_mutexattr_t*) __attribute__((weak)); +pthread_t pthread_self(void) __attribute__((weak)); +#define PTHREAD_SYMBOLS (pthread_mutex_init && pthread_mutex_lock && pthread_mutex_unlock && pthread_mutex_destroy && pthread_mutexattr_init && pthread_mutexattr_settype && pthread_mutexattr_destroy && pthread_self) + +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +typedef enum { + __itt_collection_normal = 0, + __itt_collection_paused = 1 +} __itt_collection_state; + +typedef enum { + __itt_thread_normal = 0, + __itt_thread_ignored = 1 +} __itt_thread_state; + +#pragma pack(push, 8) + +typedef struct ___itt_thread_info +{ + const char* nameA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* nameW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* nameW; +#endif /* UNICODE || _UNICODE */ + TIDT tid; + __itt_thread_state state; /*!< Thread state (paused or normal) */ + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct ___itt_thread_info* next; +} __itt_thread_info; + +#include "ittnotify_types.h" /* For __itt_group_id definition */ + +typedef struct ___itt_api_info_20101001 +{ + const char* name; + void** func_ptr; + void* init_func; + __itt_group_id group; +} __itt_api_info_20101001; + +typedef struct ___itt_api_info +{ + const char* name; + void** func_ptr; + void* init_func; + void* null_func; + __itt_group_id group; +} __itt_api_info; + +typedef struct __itt_counter_info +{ + const char* nameA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* nameW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* nameW; +#endif /* UNICODE || _UNICODE */ + const char* domainA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* domainW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* domainW; +#endif /* UNICODE || _UNICODE */ + int type; + long index; + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct __itt_counter_info* next; +} __itt_counter_info_t; + +struct ___itt_domain; +struct ___itt_string_handle; + +typedef struct ___itt_global +{ + unsigned char magic[8]; + unsigned long version_major; + unsigned long version_minor; + unsigned long version_build; + volatile long api_initialized; + volatile long mutex_initialized; + volatile long atomic_counter; + mutex_t mutex; + lib_t lib; + void* error_handler; + const char** dll_path_ptr; + __itt_api_info* api_list_ptr; + struct ___itt_global* next; + /* Joinable structures below */ + __itt_thread_info* thread_list; + struct ___itt_domain* domain_list; + struct ___itt_string_handle* string_list; + __itt_collection_state state; + __itt_counter_info_t* counter_list; +} __itt_global; + +#pragma pack(pop) + +#define NEW_THREAD_INFO_W(gptr,h,h_tail,t,s,n) { \ + h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \ + if (h != NULL) { \ + h->tid = t; \ + h->nameA = NULL; \ + h->nameW = n ? _wcsdup(n) : NULL; \ + h->state = s; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->thread_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_THREAD_INFO_A(gptr,h,h_tail,t,s,n) { \ + h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \ + if (h != NULL) { \ + h->tid = t; \ + h->nameA = n ? __itt_fstrdup(n) : NULL; \ + h->nameW = NULL; \ + h->state = s; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->thread_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_DOMAIN_W(gptr,h,h_tail,name) { \ + h = (__itt_domain*)malloc(sizeof(__itt_domain)); \ + if (h != NULL) { \ + h->flags = 1; /* domain is enabled by default */ \ + h->nameA = NULL; \ + h->nameW = name ? _wcsdup(name) : NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->domain_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_DOMAIN_A(gptr,h,h_tail,name) { \ + h = (__itt_domain*)malloc(sizeof(__itt_domain)); \ + if (h != NULL) { \ + h->flags = 1; /* domain is enabled by default */ \ + h->nameA = name ? __itt_fstrdup(name) : NULL; \ + h->nameW = NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->domain_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_STRING_HANDLE_W(gptr,h,h_tail,name) { \ + h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \ + if (h != NULL) { \ + h->strA = NULL; \ + h->strW = name ? _wcsdup(name) : NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->string_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_STRING_HANDLE_A(gptr,h,h_tail,name) { \ + h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \ + if (h != NULL) { \ + h->strA = name ? __itt_fstrdup(name) : NULL; \ + h->strW = NULL; \ + h->extra1 = 0; /* reserved */ \ + h->extra2 = NULL; /* reserved */ \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->string_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_COUNTER_W(gptr,h,h_tail,name,domain,type) { \ + h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \ + if (h != NULL) { \ + h->nameA = NULL; \ + h->nameW = name ? _wcsdup(name) : NULL; \ + h->domainA = NULL; \ + h->domainW = name ? _wcsdup(domain) : NULL; \ + h->type = type; \ + h->index = 0; \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->counter_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#define NEW_COUNTER_A(gptr,h,h_tail,name,domain,type) { \ + h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \ + if (h != NULL) { \ + h->nameA = name ? __itt_fstrdup(name) : NULL; \ + h->nameW = NULL; \ + h->domainA = domain ? __itt_fstrdup(domain) : NULL; \ + h->domainW = NULL; \ + h->type = type; \ + h->index = 0; \ + h->next = NULL; \ + if (h_tail == NULL) \ + (gptr)->counter_list = h; \ + else \ + h_tail->next = h; \ + } \ +} + +#endif /* _ITTNOTIFY_CONFIG_H_ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h new file mode 100644 index 0000000000..99fbc24054 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h @@ -0,0 +1,94 @@ +/* <copyright> + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +</copyright> */ + +#ifndef _ITTNOTIFY_TYPES_H_ +#define _ITTNOTIFY_TYPES_H_ + +typedef enum ___itt_group_id +{ + __itt_group_none = 0, + __itt_group_legacy = 1<<0, + __itt_group_control = 1<<1, + __itt_group_thread = 1<<2, + __itt_group_mark = 1<<3, + __itt_group_sync = 1<<4, + __itt_group_fsync = 1<<5, + __itt_group_jit = 1<<6, + __itt_group_model = 1<<7, + __itt_group_splitter_min = 1<<7, + __itt_group_counter = 1<<8, + __itt_group_frame = 1<<9, + __itt_group_stitch = 1<<10, + __itt_group_heap = 1<<11, + __itt_group_splitter_max = 1<<12, + __itt_group_structure = 1<<12, + __itt_group_suppress = 1<<13, + __itt_group_arrays = 1<<14, + __itt_group_all = -1 +} __itt_group_id; + +#pragma pack(push, 8) + +typedef struct ___itt_group_list +{ + __itt_group_id id; + const char* name; +} __itt_group_list; + +#pragma pack(pop) + +#define ITT_GROUP_LIST(varname) \ + static __itt_group_list varname[] = { \ + { __itt_group_all, "all" }, \ + { __itt_group_control, "control" }, \ + { __itt_group_thread, "thread" }, \ + { __itt_group_mark, "mark" }, \ + { __itt_group_sync, "sync" }, \ + { __itt_group_fsync, "fsync" }, \ + { __itt_group_jit, "jit" }, \ + { __itt_group_model, "model" }, \ + { __itt_group_counter, "counter" }, \ + { __itt_group_frame, "frame" }, \ + { __itt_group_stitch, "stitch" }, \ + { __itt_group_heap, "heap" }, \ + { __itt_group_structure, "structure" }, \ + { __itt_group_suppress, "suppress" }, \ + { __itt_group_arrays, "arrays" }, \ + { __itt_group_none, NULL } \ + } + +#endif /* _ITTNOTIFY_TYPES_H_ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c new file mode 100644 index 0000000000..15f4b9929b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c @@ -0,0 +1,293 @@ +/* <copyright> + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +</copyright> */ + +#include "ittnotify_config.h" + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include <windows.h> +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM != ITT_PLATFORM_MAC && ITT_PLATFORM != ITT_PLATFORM_FREEBSD +#include <malloc.h> +#endif +#include <stdlib.h> + +#include "jitprofiling.h" + +static const char rcsid[] = "\n@(#) $Revision: 471937 $\n"; + +#define DLL_ENVIRONMENT_VAR "VS_PROFILER" + +#ifndef NEW_DLL_ENVIRONMENT_VAR +#if ITT_ARCH==ITT_ARCH_IA32 +#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER32" +#else +#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER64" +#endif +#endif /* NEW_DLL_ENVIRONMENT_VAR */ + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define DEFAULT_DLLNAME "JitPI.dll" +HINSTANCE m_libHandle = NULL; +#elif ITT_PLATFORM==ITT_PLATFORM_MAC +#define DEFAULT_DLLNAME "libJitPI.dylib" +void* m_libHandle = NULL; +#else +#define DEFAULT_DLLNAME "libJitPI.so" +void* m_libHandle = NULL; +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/* default location of JIT profiling agent on Android */ +#define ANDROID_JIT_AGENT_PATH "/data/intel/libittnotify.so" + +/* the function pointers */ +typedef unsigned int(JITAPI *TPInitialize)(void); +static TPInitialize FUNC_Initialize=NULL; + +typedef unsigned int(JITAPI *TPNotify)(unsigned int, void*); +static TPNotify FUNC_NotifyEvent=NULL; + +static iJIT_IsProfilingActiveFlags executionMode = iJIT_NOTHING_RUNNING; + +/* end collector dll part. */ + +/* loadiJIT_Funcs() : this function is called just in the beginning + * and is responsible to load the functions from BistroJavaCollector.dll + * result: + * on success: the functions loads, iJIT_DLL_is_missing=0, return value = 1 + * on failure: the functions are NULL, iJIT_DLL_is_missing=1, return value = 0 + */ +static int loadiJIT_Funcs(void); + +/* global representing whether the collector can't be loaded */ +static int iJIT_DLL_is_missing = 0; + +ITT_EXTERN_C int JITAPI +iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData) +{ + int ReturnValue = 0; + + /* initialization part - the collector has not been loaded yet. */ + if (!FUNC_NotifyEvent) + { + if (iJIT_DLL_is_missing) + return 0; + + if (!loadiJIT_Funcs()) + return 0; + } + + if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED || + event_type == iJVM_EVENT_TYPE_METHOD_UPDATE) + { + if (((piJIT_Method_Load)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2) + { + if (((piJIT_Method_Load_V2)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3) + { + if (((piJIT_Method_Load_V3)EventSpecificData)->method_id == 0) + return 0; + } + else if (event_type == iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED) + { + if (((piJIT_Method_Inline_Load)EventSpecificData)->method_id == 0 || + ((piJIT_Method_Inline_Load)EventSpecificData)->parent_method_id == 0) + return 0; + } + + ReturnValue = (int)FUNC_NotifyEvent(event_type, EventSpecificData); + + return ReturnValue; +} + +ITT_EXTERN_C iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive() +{ + if (!iJIT_DLL_is_missing) + { + loadiJIT_Funcs(); + } + + return executionMode; +} + +/* This function loads the collector dll and the relevant functions. + * on success: all functions load, iJIT_DLL_is_missing = 0, return value = 1 + * on failure: all functions are NULL, iJIT_DLL_is_missing = 1, return value = 0 + */ +static int loadiJIT_Funcs() +{ + static int bDllWasLoaded = 0; + char *dllName = (char*)rcsid; /* !! Just to avoid unused code elimination */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN + DWORD dNameLength = 0; +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + + if(bDllWasLoaded) + { + /* dll was already loaded, no need to do it for the second time */ + return 1; + } + + /* Assumes that the DLL will not be found */ + iJIT_DLL_is_missing = 1; + FUNC_NotifyEvent = NULL; + + if (m_libHandle) + { +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FreeLibrary(m_libHandle); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + dlclose(m_libHandle); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + m_libHandle = NULL; + } + + /* Try to get the dll name from the environment */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN + dNameLength = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, NULL, 0); + if (dNameLength) + { + DWORD envret = 0; + dllName = (char*)malloc(sizeof(char) * (dNameLength + 1)); + if(dllName != NULL) + { + envret = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, + dllName, dNameLength); + if (envret) + { + /* Try to load the dll from the PATH... */ + m_libHandle = LoadLibraryExA(dllName, + NULL, LOAD_WITH_ALTERED_SEARCH_PATH); + } + free(dllName); + } + } else { + /* Try to use old VS_PROFILER variable */ + dNameLength = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, NULL, 0); + if (dNameLength) + { + DWORD envret = 0; + dllName = (char*)malloc(sizeof(char) * (dNameLength + 1)); + if(dllName != NULL) + { + envret = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, + dllName, dNameLength); + if (envret) + { + /* Try to load the dll from the PATH... */ + m_libHandle = LoadLibraryA(dllName); + } + free(dllName); + } + } + } +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + dllName = getenv(NEW_DLL_ENVIRONMENT_VAR); + if (!dllName) + dllName = getenv(DLL_ENVIRONMENT_VAR); +#if defined(__ANDROID__) || defined(ANDROID) + if (!dllName) + dllName = ANDROID_JIT_AGENT_PATH; +#endif + if (dllName) + { + /* Try to load the dll from the PATH... */ + m_libHandle = dlopen(dllName, RTLD_LAZY); + } +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + + if (!m_libHandle) + { +#if ITT_PLATFORM==ITT_PLATFORM_WIN + m_libHandle = LoadLibraryA(DEFAULT_DLLNAME); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + m_libHandle = dlopen(DEFAULT_DLLNAME, RTLD_LAZY); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + } + + /* if the dll wasn't loaded - exit. */ + if (!m_libHandle) + { + iJIT_DLL_is_missing = 1; /* don't try to initialize + * JIT agent the second time + */ + return 0; + } + +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FUNC_NotifyEvent = (TPNotify)GetProcAddress(m_libHandle, "NotifyEvent"); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + FUNC_NotifyEvent = (TPNotify)dlsym(m_libHandle, "NotifyEvent"); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + if (!FUNC_NotifyEvent) + { + FUNC_Initialize = NULL; + return 0; + } + +#if ITT_PLATFORM==ITT_PLATFORM_WIN + FUNC_Initialize = (TPInitialize)GetProcAddress(m_libHandle, "Initialize"); +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + FUNC_Initialize = (TPInitialize)dlsym(m_libHandle, "Initialize"); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + if (!FUNC_Initialize) + { + FUNC_NotifyEvent = NULL; + return 0; + } + + executionMode = (iJIT_IsProfilingActiveFlags)FUNC_Initialize(); + + bDllWasLoaded = 1; + iJIT_DLL_is_missing = 0; /* DLL is ok. */ + + return 1; +} + +ITT_EXTERN_C unsigned int JITAPI iJIT_GetNewMethodID() +{ + static unsigned int methodID = 1; + + if (methodID == 0) + return 0; /* ERROR : this is not a valid value */ + + return methodID++; +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h new file mode 100644 index 0000000000..bf0489b1a1 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h @@ -0,0 +1,673 @@ +/* <copyright> + + Contact Information: + http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/ + + BSD LICENSE + + Copyright (c) 2005-2014 Intel Corporation. All rights reserved. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + * Neither the name of Intel Corporation nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +</copyright> */ + +#ifndef __JITPROFILING_H__ +#define __JITPROFILING_H__ + +/** + * @brief JIT Profiling APIs + * + * The JIT Profiling API is used to report information about just-in-time + * generated code that can be used by performance tools. The user inserts + * calls in the code generator to report information before JIT-compiled + * code goes to execution. This information is collected at runtime and used + * by tools like Intel(R) VTune(TM) Amplifier to display performance metrics + * associated with JIT-compiled code. + * + * These APIs can be used to\n + * - **Profile trace-based and method-based JIT-compiled + * code**. Some examples of environments that you can profile with these APIs: + * dynamic JIT compilation of JavaScript code traces, JIT execution in OpenCL(TM) + * software technology, Java/.NET managed execution environments, and custom + * ISV JIT engines. + * @code + * #include <jitprofiling.h> + * + * if (iJIT_IsProfilingActive != iJIT_SAMPLING_ON) { + * return; + * } + * + * iJIT_Method_Load jmethod = {0}; + * jmethod.method_id = iJIT_GetNewMethodID(); + * jmethod.method_name = "method_name"; + * jmethod.class_file_name = "class_name"; + * jmethod.source_file_name = "source_file_name"; + * jmethod.method_load_address = code_addr; + * jmethod.method_size = code_size; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&jmethod); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_SHUTDOWN, NULL); + * @endcode + * + * * Expected behavior: + * * If any iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an + * already reported method, then such a method becomes invalid and its + * memory region is treated as unloaded. VTune Amplifier displays the metrics + * collected by the method until it is overwritten. + * * If supplied line number information contains multiple source lines for + * the same assembly instruction (code location), then VTune Amplifier picks up + * the first line number. + * * Dynamically generated code can be associated with a module name. + * Use the iJIT_Method_Load_V2 structure.\n + * Clarification of some cases: + * * If you register a function with the same method ID multiple times, + * specifying different module names, then the VTune Amplifier picks up + * the module name registered first. If you want to distinguish the same + * function between different JIT engines, supply different method IDs for + * each function. Other symbolic information (for example, source file) + * can be identical. + * + * - **Analyze split functions** (multiple joint or disjoint code regions + * belonging to the same function) **including re-JIT** + * with potential overlapping of code regions in time, which is common in + * resource-limited environments. + * @code + * #include <jitprofiling.h> + * + * unsigned int method_id = iJIT_GetNewMethodID(); + * + * iJIT_Method_Load a = {0}; + * a.method_id = method_id; + * a.method_load_address = 0x100; + * a.method_size = 0x20; + * + * iJIT_Method_Load b = {0}; + * b.method_id = method_id; + * b.method_load_address = 0x200; + * b.method_size = 0x30; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&b); + * @endcode + * + * * Expected behaviour: + * * If a iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an + * already reported method, then such a method becomes invalid and + * its memory region is treated as unloaded. + * * All code regions reported with the same method ID are considered as + * belonging to the same method. Symbolic information (method name, + * source file name) will be taken from the first notification, and all + * subsequent notifications with the same method ID will be processed + * only for line number table information. So, the VTune Amplifier will map + * samples to a source line using the line number table from the current + * notification while taking the source file name from the very first one.\n + * Clarification of some cases:\n + * * If you register a second code region with a different source file + * name and the same method ID, then this information will be saved and + * will not be considered as an extension of the first code region, but + * VTune Amplifier will use the source file of the first code region and map + * performance metrics incorrectly. + * * If you register a second code region with the same source file as + * for the first region and the same method ID, then the source file will be + * discarded but VTune Amplifier will map metrics to the source file correctly. + * * If you register a second code region with a null source file and + * the same method ID, then provided line number info will be associated + * with the source file of the first code region. + * + * - **Explore inline functions** including multi-level hierarchy of + * nested inline methods which shows how performance metrics are distributed through them. + * @code + * #include <jitprofiling.h> + * + * // method_id parent_id + * // [-- c --] 3000 2000 + * // [---- d -----] 2001 1000 + * // [---- b ----] 2000 1000 + * // [------------ a ----------------] 1000 n/a + * + * iJIT_Method_Load a = {0}; + * a.method_id = 1000; + * + * iJIT_Method_Inline_Load b = {0}; + * b.method_id = 2000; + * b.parent_method_id = 1000; + * + * iJIT_Method_Inline_Load c = {0}; + * c.method_id = 3000; + * c.parent_method_id = 2000; + * + * iJIT_Method_Inline_Load d = {0}; + * d.method_id = 2001; + * d.parent_method_id = 1000; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&b); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&c); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&d); + * @endcode + * + * * Requirements: + * * Each inline (iJIT_Method_Inline_Load) method should be associated + * with two method IDs: one for itself; one for its immediate parent. + * * Address regions of inline methods of the same parent method cannot + * overlap each other. + * * Execution of the parent method must not be started until it and all + * its inline methods are reported. + * * Expected behaviour: + * * In case of nested inline methods an order of + * iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED events is not important. + * * If any event overwrites either inline method or top parent method, + * then the parent, including inline methods, becomes invalid and its memory + * region is treated as unloaded. + * + * **Life time of allocated data**\n + * The client sends an event notification to the agent with event-specific + * data, which is a structure. The pointers in the structure refer to memory + * allocated by the client, which responsible for releasing it. The pointers are + * used by the iJIT_NotifyEvent method to copy client's data in a trace file, + * and they are not used after the iJIT_NotifyEvent method returns. + */ + +/** + * @defgroup jitapi JIT Profiling + * @ingroup internal + * @{ + */ + +/** + * @brief Enumerator for the types of notifications + */ +typedef enum iJIT_jvm_event +{ + iJVM_EVENT_TYPE_SHUTDOWN = 2, /**<\brief Send this to shutdown the agent. + * Use NULL for event data. */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED = 13, /**<\brief Send when dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load as event + * data. */ +/** @cond exclude_from_documentation */ + iJVM_EVENT_TYPE_METHOD_UNLOAD_START, /**<\brief Send when compiled dynamic + * code is being unloaded from memory. + * Use iJIT_Method_Load as event data.*/ +/** @endcond */ + + iJVM_EVENT_TYPE_METHOD_UPDATE, /**<\brief Send to provide new content for + * a previously reported dynamic code. + * The previous content will be invalidated + * starting from the time of the notification. + * Use iJIT_Method_Load as event data but + * required fields are following: + * - method_id identify the code to update. + * - method_load_address specify start address + * within identified code range + * where update should be started. + * - method_size specify length of updated code + * range. */ + + + iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, /**<\brief Send when an inline dynamic + * code is JIT compiled and loaded + * into memory by the JIT engine, + * but before the parent code region + * starts executing. + * Use iJIT_Method_Inline_Load as event data.*/ + +/** @cond exclude_from_documentation */ + iJVM_EVENT_TYPE_METHOD_UPDATE_V2, +/** @endcond */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 = 21, /**<\brief Send when a dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load_V2 as event data. */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 /**<\brief Send when a dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load_V3 as event data. */ +} iJIT_JVM_EVENT; + +/** + * @brief Enumerator for the agent's mode + */ +typedef enum _iJIT_IsProfilingActiveFlags +{ + iJIT_NOTHING_RUNNING = 0x0000, /**<\brief The agent is not running; + * iJIT_NotifyEvent calls will + * not be processed. */ + iJIT_SAMPLING_ON = 0x0001, /**<\brief The agent is running and + * ready to process notifications. */ +} iJIT_IsProfilingActiveFlags; + +/** + * @brief Description of a single entry in the line number information of a code region. + * @details A table of line number entries gives information about how the reported code region + * is mapped to source file. + * Intel(R) VTune(TM) Amplifier uses line number information to attribute + * the samples (virtual address) to a line number. \n + * It is acceptable to report different code addresses for the same source line: + * @code + * Offset LineNumber + * 1 2 + * 12 4 + * 15 2 + * 18 1 + * 21 30 + * + * VTune Amplifier constructs the following table using the client data + * + * Code subrange Line number + * 0-1 2 + * 1-12 4 + * 12-15 2 + * 15-18 1 + * 18-21 30 + * @endcode + */ +typedef struct _LineNumberInfo +{ + unsigned int Offset; /**<\brief Offset from the begining of the code region. */ + unsigned int LineNumber; /**<\brief Matching source line number offset (from beginning of source file). */ + +} *pLineNumberInfo, LineNumberInfo; + +/** + * @brief Enumerator for the code architecture. + */ +typedef enum _iJIT_CodeArchitecture +{ + iJIT_CA_NATIVE = 0, /**<\brief Native to the process architecture that is calling it. */ + + iJIT_CA_32, /**<\brief 32-bit machine code. */ + + iJIT_CA_64 /**<\brief 64-bit machine code. */ + +} iJIT_CodeArchitecture; + +#pragma pack(push, 8) + +/** + * @brief Description of a JIT-compiled method + * @details When you use the iJIT_Method_Load structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise different + * method IDs specify different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, data provided with + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table.0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array */ + + unsigned int class_id; /**<\brief This field is obsolete. */ + + char* class_file_name; /**<\brief Class name. Can be NULL.*/ + + char* source_file_name; /**<\brief Source file name. Can be NULL.*/ + +} *piJIT_Method_Load, iJIT_Method_Load; + +/** + * @brief Description of a JIT-compiled method + * @details When you use the iJIT_Method_Load_V2 structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load_V2 +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise different + * method IDs specify different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, then data provided with the + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array. */ + + char* class_file_name; /**<\brief Class name. Can be NULL. */ + + char* source_file_name; /**<\brief Source file name. Can be NULL. */ + + char* module_name; /**<\brief Module name. Can be NULL. + The module name can be useful for distinguishing among + different JIT engines. VTune Amplifier will display + reported methods grouped by specific module. */ + +} *piJIT_Method_Load_V2, iJIT_Method_Load_V2; + +/** + * @brief Description of a JIT-compiled method + * @details The iJIT_Method_Load_V3 structure is the same as iJIT_Method_Load_V2 + * with a newly introduced 'arch' field that specifies architecture of the code region. + * When you use the iJIT_Method_Load_V3 structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load_V3 +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise they are + * treated as regions of different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Cannot be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, then data provided with the + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array. */ + + char* class_file_name; /**<\brief Class name. Can be NULL. */ + + char* source_file_name; /**<\brief Source file name. Can be NULL. */ + + char* module_name; /**<\brief Module name. Can be NULL. + * The module name can be useful for distinguishing among + * different JIT engines. VTune Amplifier will display + * reported methods grouped by specific module. */ + + iJIT_CodeArchitecture module_arch; /**<\brief Architecture of the method's code region. + * By default, it is the same as the process + * architecture that is calling it. + * For example, you can use it if your 32-bit JIT + * engine generates 64-bit code. + * + * If JIT engine reports both 32-bit and 64-bit types + * of methods then VTune Amplifier splits the methods + * with the same module name but with different + * architectures in two different modules. VTune Amplifier + * modifies the original name provided with a 64-bit method + * version by ending it with '(64)' */ + +} *piJIT_Method_Load_V3, iJIT_Method_Load_V3; + +/** + * @brief Description of an inline JIT-compiled method + * @details When you use the_iJIT_Method_Inline_Load structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED + * as an event type to report it. + */ +typedef struct _iJIT_Method_Inline_Load +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself. */ + + unsigned int parent_method_id; /**<\brief Unique immediate parent's method ID. + * Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /** <\brief The virtual address on which the method + * is inlined. If NULL, then data provided with + * the event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array */ + + char* class_file_name; /**<\brief Class name. Can be NULL.*/ + + char* source_file_name; /**<\brief Source file name. Can be NULL.*/ + +} *piJIT_Method_Inline_Load, iJIT_Method_Inline_Load; + +/** @cond exclude_from_documentation */ +/** + * @brief Description of a segment type + * @details Use the segment type to specify a type of data supplied + * with the iJVM_EVENT_TYPE_METHOD_UPDATE_V2 event to be applied to + * a certain code trace. + */ +typedef enum _iJIT_SegmentType +{ + iJIT_CT_UNKNOWN = 0, + + iJIT_CT_CODE, /**<\brief Executable code. */ + + iJIT_CT_DATA, /**<\brief Data (not executable code). + * VTune Amplifier uses the format string + * (see iJIT_Method_Update) to represent + * this data in the VTune Amplifier GUI */ + + iJIT_CT_KEEP, /**<\brief Use the previous markup for the trace. + * Can be used for the following + * iJVM_EVENT_TYPE_METHOD_UPDATE_V2 events, + * if the type of the previously reported segment + * type is the same. */ + iJIT_CT_EOF +} iJIT_SegmentType; + +/** + * @brief Description of a dynamic update of the content within JIT-compiled method + * @details The JIT engine may generate the methods that are updated at runtime + * partially by mixed (data + executable code) content. When you use the iJIT_Method_Update + * structure to describe the update of the content within a JIT-compiled method, + * use iJVM_EVENT_TYPE_METHOD_UPDATE_V2 as an event type to report it. + * + * On the first Update event, VTune Amplifier copies the original code range reported by + * the iJVM_EVENT_TYPE_METHOD_LOAD event, then modifies it with the supplied bytes and + * adds the modified range to the original method. For next update events, VTune Amplifier + * does the same but it uses the latest modified version of a code region for update. + * Eventually, VTune Amplifier GUI displays multiple code ranges for the method reported by + * the iJVM_EVENT_TYPE_METHOD_LOAD event. + * Notes: + * - Multiple update events with different types for the same trace are allowed + * but they must be reported for the same code ranges. + * Example, + * @code + * [-- data---] Allowed + * [-- code --] Allowed + * [code] Ignored + * [-- data---] Allowed + * [-- code --] Allowed + * [------------ trace ---------] + * @endcode + * - The types of previously reported events can be changed but they must be reported + * for the same code ranges. + * Example, + * @code + * [-- data---] Allowed + * [-- code --] Allowed + * [-- data---] Allowed + * [-- code --] Allowed + * [------------ trace ---------] + * @endcode + */ + +typedef struct _iJIT_Method_Update +{ + void* load_address; /**<\brief Start address of the update within a method */ + + unsigned int size; /**<\brief The update size */ + + iJIT_SegmentType type; /**<\brief Type of the update */ + + const char* data_format; /**<\brief C string that contains a format string + * that follows the same specifications as format in printf. + * The format string is used for iJIT_CT_CODE only + * and cannot be NULL. + * Format can be changed on the fly. */ +} *piJIT_Method_Update, iJIT_Method_Update; + +/** @endcond */ + +#pragma pack(pop) + +/** @cond exclude_from_documentation */ +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#ifndef JITAPI_CDECL +# if defined WIN32 || defined _WIN32 +# define JITAPI_CDECL __cdecl +# else /* defined WIN32 || defined _WIN32 */ +# if defined _M_IX86 || defined __i386__ +# define JITAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define JITAPI_CDECL /* actual only on x86_64 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* defined WIN32 || defined _WIN32 */ +#endif /* JITAPI_CDECL */ + +#define JITAPI JITAPI_CDECL +/** @endcond */ + +/** + * @brief Generates a new unique method ID. + * + * You must use this API to obtain unique and valid method IDs for methods or + * traces reported to the agent if you don't have your own mechanism to generate + * unique method IDs. + * + * @return a new unique method ID. When out of unique method IDs, this API + * returns 0, which is not an accepted value. + */ +unsigned int JITAPI iJIT_GetNewMethodID(void); + +/** + * @brief Returns the current mode of the agent. + * + * @return iJIT_SAMPLING_ON, indicating that agent is running, or + * iJIT_NOTHING_RUNNING if no agent is running. + */ +iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive(void); + +/** + * @brief Reports infomation about JIT-compiled code to the agent. + * + * The reported information is used to attribute samples obtained from any + * Intel(R) VTune(TM) Amplifier collector. This API needs to be called + * after JIT compilation and before the first entry into the JIT-compiled + * code. + * + * @param[in] event_type - type of the data sent to the agent + * @param[in] EventSpecificData - pointer to event-specific data + * + * @returns 1 on success, otherwise 0. + */ +int JITAPI iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData); + +#ifdef __cplusplus +} +#endif /* __cplusplus */ +/** @endcond */ + +/** @} jitapi group */ + +#endif /* __JITPROFILING_H__ */ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp new file mode 100644 index 0000000000..ef4c42bacf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp @@ -0,0 +1,317 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <math.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" + +#include "nchw_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t data_type> +void nchw_pooling_fwd_t<data_type>::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper ws_d(pd()->workspace_md()); + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) { + if (ws) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + size_t ws_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + if (ws_dt == data_type::u8) { + assert(0 <= value && value <= 255); + ws[ws_offset] = value; + } else + reinterpret_cast<int *>(ws)[ws_offset] = value; + } + }; + + auto ker_max = [=](data_t *d, int mb, int c, int od, int oh, int ow) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto src_offset + = (size_t)IW * IH * ID * C * mb + + (size_t)IW * IH * ID * c + + (size_t)IW * IH * id + + (size_t)IW * ih + + (size_t)iw; + auto s = src[src_offset]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw); + } + } + } + } + }; + + auto ker_avg = [=](data_t *d, int mb, int c, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KD*KW*KH + : (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start); + + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + auto src_offset + = (size_t)IW * IH * ID * C * mb + + (size_t)IW * IH * ID * c + + (size_t)IW * IH * id + + (size_t)IW * ih + + (size_t)iw; + d[0] += src[src_offset]; + } + } + } + + d[0] = math::out_round<data_t>((float)d[0] / num_summands); + }; + + + if (pd()->desc()->alg_kind == pooling_max) { + parallel_nd(MB, C, OD, OH, OW, + [&](int mb, int c, int od, int oh, int ow) { + size_t dst_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = nstl::numeric_limits<data_t>::lowest(); + set_ws(mb, c, od, oh, ow, 0); + ker_max(d, mb, c, od, oh, ow); + }); + } else { + parallel_nd(MB, C, OD, OH, OW, + [&](int mb, int c, int od, int oh, int ow) { + size_t dst_offset + = (size_t)OW * OH * OD * C * mb + + (size_t)OW * OH * OD * c + + (size_t)OW * OH * od + + (size_t)OW * oh + + (size_t)ow; + data_t *d = &dst[dst_offset]; + d[0] = 0; + ker_avg(d, mb, c, od, oh, ow); + }); + } +} + +template <impl::data_type_t data_type> +void nchw_pooling_bwd_t<data_type>::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper ws_d(pd()->workspace_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto ker_zero = [=](int mb, int c) { + size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW; + for (int id = 0; id < ID; ++id) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_offset++] = 0; + } + } + } + }; + + auto ker_max = [=](const data_t *d, int mb, int c, int od, int oh, int ow) { + auto b_c = ws_d.blocking_desc().inner_nblks == 0 + ? 1 : ws_d.blocking_desc().inner_blks[0]; + auto ws_offset = is_3d + ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c + : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c; + + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset]; + const int kw = index % KW; + const int kh = (index / KW) % KH; + const int kd = (index / KW) / KH; + + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (id < 0 || id >= ID) + return; + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + size_t diff_src_offset = + (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW + (size_t)id*IH*IW + + (size_t)ih*IW + (size_t)iw; + diff_src[diff_src_offset] += d[0]; + }; + + auto ker_avg = [=](const data_t *d, int mb, int c, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + size_t num_summands = (alg == pooling_avg_include_padding) + ? (size_t)KW*KH*KD + : (size_t)(id_end - id_start)*(ih_end - ih_start) + *(iw_end - iw_start); + + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + + (size_t)c*ID*IH*IW + (size_t)id*IH*IW + + (size_t)ih*IW + (size_t)iw; + diff_src[diff_src_offset] += d[0] / num_summands; + } + } + } + }; + + if (pd()->desc()->alg_kind == pooling_max) { + parallel_nd(MB, C, [&](int mb, int c) { + size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + + (size_t)c*OD*OH*OW; + ker_zero(mb, c); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = &diff_dst[diff_dst_offset++]; + ker_max(d, mb, c, od, oh, ow); + } + } + } + }); + } else { + parallel_nd(MB, C, [&](int mb, int c) { + size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW + + (size_t)c*OD*OH*OW; + ker_zero(mb, c); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = &diff_dst[diff_dst_offset++]; + ker_avg(d, mb, c, od, oh, ow); + } + } + } + }); + } +} + +template struct nchw_pooling_fwd_t<data_type::f32>; +template struct nchw_pooling_bwd_t<data_type::f32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp new file mode 100644 index 0000000000..bbdd04f6b9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp @@ -0,0 +1,147 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NCHW_POOLING_HPP +#define CPU_NCHW_POOLING_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t data_type> +struct nchw_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("nchw_pooling:any", nchw_pooling_fwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && !has_zero_dim_memory() + && utils::everyone_is(data_type, src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + nchw_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t data_type> +struct nchw_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("nchw:any", nchw_pooling_bwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && !has_zero_dim_memory() + && utils::everyone_is(data_type, + diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + bool ws_ok = true + && hint_fwd_pd_ + && hint_fwd_pd_->workspace_md(); + if (!ws_ok) + return status::unimplemented; + + const auto &ws_blk = + hint_fwd_pd_->workspace_md()->format_desc.blocking; + ws_ok = ws_ok + && ws_blk.inner_nblks < 1 + && IMPLICATION(ws_blk.inner_nblks == 1, + ws_blk.inner_idxs[0] == 1); + if (!ws_ok) + return status::unimplemented; + + ws_md_ = *hint_fwd_pd_->workspace_md(); + } + + return status::success; + } + }; + + nchw_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp new file mode 100644 index 0000000000..c0e93fefe4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp @@ -0,0 +1,382 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <math.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "ncsp_batch_normalization.hpp" + +// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases +#if (defined __clang_major__) && (__clang_major__ >= 6) +#define SAFE_TO_USE_OMP_SIMD 0 +#else +#define SAFE_TO_USE_OMP_SIMD 1 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void ncsp_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + const bool calculate_stats = !pd()->stats_is_src(); + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction); + + data_t *mean, *variance; + if (!calculate_stats) { + mean = const_cast<data_t *>( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); + variance = const_cast<data_t *>( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); + } else { + if (save_stats) { + mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + } else { + mean = scratchpad.get<data_t>(key_bnorm_tmp_mean); + variance = scratchpad.get<data_t>(key_bnorm_tmp_var); + } + } + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool with_relu = pd()->with_relu_post_op(); + auto maybe_post_op + = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; }; + const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5); + dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1; + dim_t N = pd()->MB(); + dim_t C = pd()->C(); + + int nthr = mkldnn_get_max_threads(); + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + size_t data_size = N * C * SP * sizeof(data_t); + bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0); + + parallel(0, [&](const int ithr, const int nthr) { + int C_ithr = 0, C_nthr = 0; + int N_ithr = 0, N_nthr = 0; + int S_ithr = 0, S_nthr = 0; + + dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0; + dim_t N_s = 0, N_e = 0; + dim_t S_s = 0, S_e = 0; + + dim_t C_blks_per_iter = 1; + int64_t iters = 1; + + if (do_blocking) { + size_t working_set_size = N * SP * sizeof(data_t); + bnorm_utils::cache_balance( + working_set_size, C, C_blks_per_iter, iters); + } else + C_blks_per_iter = C; + int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter; + bool spatial_thr_allowed + = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N, + C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e, + N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e); + balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + for (int64_t it = 0; it < iters; ++it) { + if (it == iters - 1 && iters > 1) { + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). So sync the + // threads if they are not synced by the algorithm. + if (SP_N_nthr == 1 && mkldnn_thr_syncable()) + mkldnn_thr_barrier(); + + S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + SP_N_ithr = N_ithr * S_nthr + S_ithr; + SP_N_nthr = N_nthr * S_nthr; + } + size_t C_off = it * C_blks_per_iter; + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). Since sync is not always + // possible (in case of TBB) use different parts of ws for each + // iteration if threads are not synced by the algorithm. + size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * C_off; + + if (calculate_stats) { + data_t *mean_blk = mean + C_off; + data_t *variance_blk = variance + C_off; + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = (c + C_off) * SP; + data_t sum = 0; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : sum)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + sum += src[off + n * C * SP + sp]; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = sum; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + mean_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) + mean_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + mean_blk[c] /= (N * SP); + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t sum = 0.; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : sum)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + data_t m = src[off * SP + n * C * SP + sp] + - mean[off]; + sum += m * m; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = sum; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + variance_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) + variance_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + variance_blk[c] /= (N * SP); + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + } + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t sqrt_variance + = static_cast<data_t>(sqrtf(variance[off] + eps)); + data_t sm = (use_scaleshift ? scaleshift[off] : 1.0f) / sqrt_variance; + data_t sv = use_scaleshift ? scaleshift[C + off] : 0; + for (dim_t n = N_s; n < N_e; ++n) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t sp = S_s; sp < S_e; ++sp) { + size_t d_off = off * SP + n * C * SP + sp; + data_t bn_res + = sm * (src[d_off] - mean[off]) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + dst[d_off] = maybe_post_op(bn_res); + } + } + } + }); +} + +void ncsp_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction); + + if (diff_scaleshift == nullptr) + diff_scaleshift = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss); + + const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5); + dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1; + dim_t C = pd()->C(), N = pd()->MB(); + const bool use_scaleshift = pd()->use_scaleshift(); + const float eps = pd()->desc()->batch_norm_epsilon; + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + int nthr = mkldnn_get_max_threads(); + size_t l3_size_ = get_cache_size(3, true) * nthr / 2; + size_t data_size = N * C * SP * sizeof(data_t); + bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0); + + parallel(0, [&](const int ithr, const int nthr) { + int C_ithr = 0, C_nthr = 0; + int N_ithr = 0, N_nthr = 0; + int S_ithr = 0, S_nthr = 0; + + dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0; + dim_t N_s = 0, N_e = 0; + dim_t S_s = 0, S_e = 0; + + dim_t C_blks_per_iter = 1; + int64_t iters = 1; + + if (do_blocking) { + size_t working_set_size = 2 * N * SP * sizeof(data_t); + bnorm_utils::cache_balance( + working_set_size, C, C_blks_per_iter, iters); + } else + C_blks_per_iter = C; + int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter; + bool spatial_thr_allowed + = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N, + C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e, + N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e); + balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + int SP_N_ithr = N_ithr * S_nthr + S_ithr; + int SP_N_nthr = N_nthr * S_nthr; + + for (int64_t it = 0; it < iters; ++it) { + if (it == iters - 1 && iters > 1) { + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). So sync the + // threads if they are not synced by the algorithm. + if (SP_N_nthr == 1 && mkldnn_thr_syncable()) + mkldnn_thr_barrier(); + + C_blk_s = C_blk_e = N_s = N_e = 0; + spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking, + spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP, + C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, + N_e, S_ithr, S_nthr, S_s, S_e); + balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e); + SP_N_ithr = N_ithr * S_nthr + S_ithr; + SP_N_nthr = N_nthr * S_nthr; + } + size_t C_off = it * C_blks_per_iter; + // On the last iteration the access pattern to ws_reduce + // might change (due to re-balance on C). Since sync is not always + // possible (in case of TBB) use different parts of ws for each + // iteration if threads are not synced by the algorithm. + size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * 2 * C_off; + + data_t *diff_gamma_blk = diff_scaleshift + C_off; + data_t *diff_beta_blk = diff_scaleshift + C + C_off; + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t diff_gamma = 0.0, diff_beta = 0.0; + data_t v_mean = mean[off]; + for (dim_t n = N_s; n < N_e; ++n) + PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta)) + for (dim_t sp = S_s; sp < S_e; ++sp) { + const size_t d_off = off * SP + n * C * SP + sp; + data_t dd; + if (fuse_bn_relu) + dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + dd = diff_dst[d_off]; + diff_gamma += (src[d_off] - v_mean) * dd; + diff_beta += dd; + } + ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c] + = diff_gamma; + ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter + + SP_N_ithr * C_blks_per_iter + c] = diff_beta; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) { + data_t sqrt_variance = static_cast<data_t>( + 1.0f / sqrtf(variance[c + C_off] + eps)); + diff_gamma_blk[c] = 0.; + diff_beta_blk[c] = 0.; + for (dim_t n = 0; n < SP_N_nthr; n++) { + diff_gamma_blk[c] += ws_reduce[ws_iter_off + + n * C_blks_per_iter + c]; + diff_beta_blk[c] += ws_reduce[ws_iter_off + + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter + + c]; + } + diff_gamma_blk[c] *= sqrt_variance; + } + + if (SP_N_nthr > 1) mkldnn_thr_barrier(); + + for (dim_t c = C_blk_s; c < C_blk_e; c++) { + size_t off = c + C_off; + data_t gamma = use_scaleshift ? scaleshift[off] : 1; + data_t sqrt_variance + = static_cast<data_t>(1.0f / sqrtf(variance[off] + eps)); + data_t v_mean = mean[off]; + for (dim_t n = N_s; n < N_e; ++n) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t sp = S_s; sp < S_e; ++sp) { + const size_t d_off = off * SP + n * C * SP + sp; + + data_t v_diff_src; + if (fuse_bn_relu) + v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + v_diff_src = diff_dst[d_off]; + if (calculate_diff_stats) { + v_diff_src -= diff_beta_blk[c] / (SP * N) + + (src[d_off] - v_mean) * diff_gamma_blk[c] + * sqrt_variance / (SP * N); + } + v_diff_src *= gamma * sqrt_variance; + diff_src[d_off] = v_diff_src; + } + } + } + }); +} +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp new file mode 100644 index 0000000000..97ca3b003f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp @@ -0,0 +1,160 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NCSP_BATCH_NORMALIZATION_HPP +#define CPU_NCSP_BATCH_NORMALIZATION_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ncsp_batch_normalization_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_fwd_pd_t { + using cpu_batch_normalization_fwd_pd_t::cpu_batch_normalization_fwd_pd_t; + + DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_fwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + using namespace format_tag; + + bool ok = true + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + if (!stats_is_src()) { + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * C() * mkldnn_get_max_threads()); + + if (!is_training()) { + scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * C()); + scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * C()); + } + } + } + }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + ncsp_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~ncsp_batch_normalization_fwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct ncsp_batch_normalization_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_bwd_pd_t { + using cpu_batch_normalization_bwd_pd_t::cpu_batch_normalization_bwd_pd_t; + + DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_bwd_t); + + status_t init() { + using namespace data_type; + using namespace format_tag; + + bool ok = true + && is_bwd() + && !has_zero_dim_memory() + && utils::everyone_is(f32, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc) + && memory_desc_matches_one_of_tag(*diff_src_md(), ncdhw, nchw, nc) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()); + if (!(use_scaleshift() && desc()->prop_kind == prop_kind::backward)) + scratchpad.book(key_bnorm_tmp_diff_ss, + sizeof(data_t) * 2 * C()); + } + }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + ncsp_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~ncsp_batch_normalization_bwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp new file mode 100644 index 0000000000..38cfb28dce --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp @@ -0,0 +1,392 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <math.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" + +#include "nhwc_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +#define MEM_D(name) name##_d + +#define DECLARE_READ_STRIDES(name) \ + const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \ + const size_t name##_d_stride = (!is_3d) \ + ? 0 \ + : MEM_D(name).blocking_desc().strides[2]; \ + const size_t name##_h_stride = (!is_3d) \ + ? MEM_D(name).blocking_desc().strides[2] \ + : MEM_D(name).blocking_desc().strides[3]; \ + const size_t name##_w_stride = (!is_3d) \ + ? MEM_D(name).blocking_desc().strides[3] \ + : MEM_D(name).blocking_desc().strides[4]; + +namespace nhwc_pooling { + size_t strided_offset(const int _n, const size_t _sn, + const int _d, const size_t _sd, + const int _h, const size_t _sh, + const int _w, const size_t _sw) + { + return _n * _sn + + _d * _sd + + _h * _sh + + _w * _sw; + } +} + +template <impl::data_type_t data_type> +void nhwc_pooling_fwd_t<data_type>::array_div_by_const(const int n, + const data_t *src, const size_t num, data_t *dst) const +{ + for (int i = 0; i < n; ++i) + { + float ftmp = (float)src[i]; + ftmp = ftmp / num; + dst[i] = math::out_round<data_t>(ftmp); + } +} + +template <impl::data_type_t data_type> +void nhwc_pooling_fwd_t<data_type>::array_add(const int n, const data_t *src, + data_t *dst) const +{ + for (int i = 0; i < n; ++i) + { + dst[i] += src[i]; + } +} + +template <impl::data_type_t data_type> +void nhwc_pooling_fwd_t<data_type>::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace prop_kind; + using namespace nhwc_pooling; + + auto alg = pd()->desc()->alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper MEM_D(src)(pd()->src_md()); + const memory_desc_wrapper MEM_D(dst)(pd()->dst_md()); + const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + const bool is_3d = pd()->desc()->src_desc.ndims == 5; + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + DECLARE_READ_STRIDES(src); + DECLARE_READ_STRIDES(dst); + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + parallel_nd(MB, OD, OH, OW, + [&](int mb, int od, int oh, int ow) { + size_t dst_offset_init = strided_offset(mb, dst_n_stride, + od, dst_d_stride, + oh, dst_h_stride, + ow, dst_w_stride); + if (alg == pooling_max) { + size_t ws_offset_init = 0; + if (ws) + { + DECLARE_READ_STRIDES(ws); + ws_offset_init = strided_offset(mb, ws_n_stride, + od, ws_d_stride, + oh, ws_h_stride, + ow, ws_w_stride); + } + // Note: GCC 4.8.5 won't vectorize below + // simple loops unless they are singled out + // into separate helper routines: + // array_nhwc_initialize, array_nhwc_max + if (!ws) + array_nhwc_initialize<false>(OC, dst + dst_offset_init, + ws, ws_offset_init, ws_dt); + else + array_nhwc_initialize<true>(OC, dst + dst_offset_init, + ws, ws_offset_init, ws_dt); + + + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) + continue; + if (ih < 0 || ih >= IH) + continue; + if (iw < 0 || iw >= IW) + continue; + + size_t src_offset_init = strided_offset(mb, src_n_stride, + id, src_d_stride, + ih, src_h_stride, + iw, src_w_stride); + + if (!ws) + array_nhwc_max<false>(OC, + dst + dst_offset_init, + src + src_offset_init, + ws, ws_offset_init, + ws_dt, + kd * KH * KW + kh * KW + kw + ); + else + array_nhwc_max<true>(OC, + dst + dst_offset_init, + src + src_offset_init, + ws, ws_offset_init, + ws_dt, + kd * KH * KW + kh * KW + kw + ); + } + } else { + // pooling_avg + auto d = dst + dst_offset_init; + + utils::array_set(d, 0, OC); + + auto id_start = apply_offset(od * SD, padF); + auto ih_start = apply_offset(oh * SH, padT); + auto iw_start = apply_offset(ow * SW, padL); + auto id_end = nstl::min(od * SD - padF + KD, ID); + auto ih_end = nstl::min(oh * SH - padT + KH, IH); + auto iw_end = nstl::min(ow * SW - padL + KW, IW); + + // it is cheaper to actually count this in a loop + // as the typical kernel is small + size_t num_summands = 0; + + for (int id = id_start; id < id_end; ++id) + for (int ih = ih_start; ih < ih_end; ++ih) + for (int iw = iw_start; iw < iw_end; ++iw) { + size_t src_offset_init = strided_offset(mb, src_n_stride, + id, src_d_stride, + ih, src_h_stride, + iw, src_w_stride); + auto s = src + src_offset_init; + + // need to move the loop to separate function + // for GCC 4.8.5 to vectorize + array_add(OC, s, d); + + num_summands++; + } + + num_summands = (alg == pooling_avg_include_padding) ? + KW * KH * KD : num_summands; + + // need to move the loop to separate function + // for GCC 4.8.5 to vectorize + array_div_by_const(OC, d, num_summands, d); + } + }); +} + +template <impl::data_type_t data_type> +void nhwc_pooling_bwd_t<data_type>::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace nhwc_pooling; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md()); + const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md()); + const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int OC = pd()->C(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + auto alg = pd()->desc()->alg_kind; + + DECLARE_READ_STRIDES(diff_src); + DECLARE_READ_STRIDES(diff_dst); + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + const int MB = pd()->MB(); + + parallel_nd(MB, ID, IH, IW, + [&](int mb, int id, int ih, int iw) { + size_t src_offset_init = strided_offset(mb, diff_src_n_stride, + id, diff_src_d_stride, + ih, diff_src_h_stride, + iw, diff_src_w_stride); + + // check if kernel windows are disjoint, in this case there's no + // update needed and we just write there once, no initialization + // required. + if (!(KD == SD && KH == SH && KW == SW)) + for (int oc = 0; oc < OC; ++oc) + diff_src[src_offset_init + oc] = data_type_t(0); + + // Find out which output cells may correspond to current + // input position. Current input postition divided by + // stride, with integer divide rounding down, is the + // right-most output. + // Left-most output may be computed if we decrement input + // by (kernel_size - 1) and then do the same division by + // stride. + int od_left = nstl::max((id + padF - KD + 1) / SD, 0); + int oh_left = nstl::max((ih + padT - KH + 1) / SH, 0); + int ow_left = nstl::max((iw + padL - KW + 1) / SW, 0); + // Notice +1 here to preserve the C loop "less than" + // condition for continuing the for loop. + int od_right = nstl::min((id + padF) / SD + 1 , OD); + int oh_right = nstl::min((ih + padT) / SH + 1 , OH); + int ow_right = nstl::min((iw + padL) / SW + 1 , OW); + + for (int od = od_left; od < od_right; ++od) + for (int oh = oh_left; oh < oh_right; ++oh) + for (int ow = ow_left; ow < ow_right; ++ow) { + const int kd = id - od*SD + padF; + const int kh = ih - oh*SH + padT; + const int kw = iw - ow*SW + padL; + + if (kd < 0 || kd >= KD) + continue; + if (kh < 0 || kh >= KH) + continue; + if (kw < 0 || kw >= KW) + continue; + + size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride, + od, diff_dst_d_stride, + oh, diff_dst_h_stride, + ow, diff_dst_w_stride); + + if (alg == pooling_max) { + DECLARE_READ_STRIDES(ws); + size_t ws_offset_init = strided_offset(mb, ws_n_stride, + od, ws_d_stride, + oh, ws_h_stride, + ow, ws_w_stride); + const int index = kd * KH * KW + kh * KW + kw; + + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < OC; ++oc) { + const int index_from_ws = + (MEM_D(ws).data_type() == data_type::u8) + ? (int)ws[ws_offset_init + oc] + : ((int *)ws)[ws_offset_init + oc]; + + const data_t d = diff_dst[dst_offset_init + oc]; + + // Check if kernel windows are disjoint, in this case + // there's no update needed and we just write there once + // otherwise we add value to the contents. + if (!(KD == SD && KH == SH && KW == SW)) + diff_src[src_offset_init + oc] += + (index_from_ws == index) + ? d + : data_type_t(0); + else + diff_src[src_offset_init + oc] = + (index_from_ws == index) + ? d + : data_type_t(0); + } + } else { + // pooling_avg + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) + ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < OC; ++oc) { + const data_t d = diff_dst[dst_offset_init + oc]; + // Check if kernel windows are disjoint, in this case + // there's no update needed and we just write there once + // otherwise we add value to the contents. + if (!(KD == SD && KH == SH && KW == SW)) + diff_src[src_offset_init + oc] += d / num_summands; + else + diff_src[src_offset_init + oc] = d / num_summands; + } + } + } + }); +} + +template struct nhwc_pooling_fwd_t<data_type::f32>; +template struct nhwc_pooling_bwd_t<data_type::f32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp new file mode 100644 index 0000000000..7e33b6869f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp @@ -0,0 +1,210 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NHWC_POOLING_HPP +#define CPU_NHWC_POOLING_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace nhwc_pooling { +size_t strided_offset(const int _n, const size_t _sn, const int _d, + const size_t _sd, const int _h, const size_t _sh, const int _w, + const size_t _sw); +} + +template <impl::data_type_t data_type> +struct nhwc_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("nhwc_pooling:any", nhwc_pooling_fwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nhwc : format_tag::ndhwc; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::everyone_is(data_type, + src_md()->data_type, + dst_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*src_md(), desired_fmt_tag) + && memory_desc_matches_tag(*dst_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + nhwc_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + void array_div_by_const(const int n, const data_t *src, const size_t num, + data_t *dst) const; + void array_add(const int n, const data_t *src, data_t *dst) const; + + template <bool use_workspace> + void array_nhwc_max(const int n, data_t *dst, const data_t *src, + unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt, + const int index) const { + assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists + PRAGMA_OMP_SIMD() + for (int oc = 0; oc < n; ++oc) { + auto s = src[oc]; + data_t mv = dst[oc]; + + // update index of maximum +#if defined __INTEL_COMPILER + if ((use_workspace) && (s > mv)) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + if (ws_dt == data_type::u8) { + assert(0 <= index && index <= 255); + ws[ws_offset + oc] = index; + } else + reinterpret_cast<int *>(ws)[ws_offset + oc] = index; + } +#else + // Need to add explicit predicates for GCC to vectorize this. + // And although the resulting code is ugly, it is still 4 times + // faster than scalar + if (use_workspace) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + + if (ws_dt == data_type::u8) { + assert(0 <= index && index <= 255); + unsigned char predicate = (s > mv) ? 0xff : 0; + unsigned char current_value = ws[ws_offset + oc]; + current_value = (predicate & (unsigned char)index) + | ((~predicate) & current_value); + ws[ws_offset + oc] = current_value; + } else { + auto wint = reinterpret_cast<int *>(ws); + unsigned int predicate = (s > mv) ? 0xffffffff : 0; + unsigned int current_value = wint[ws_offset + oc]; + current_value = (predicate & (unsigned int)index) + | ((~predicate) & current_value); + wint[ws_offset + oc] = current_value; + } + } +#endif + // update maximum + dst[oc] = nstl::max(s, mv); + } + } + + template <bool use_workspace> + void array_nhwc_initialize(const int n, data_t *dst, unsigned char *ws, + const size_t ws_offset, const data_type_t ws_dt) const { + assert(!((use_workspace == false) ^ (!ws))); // ensure ws pointer exists + for (int oc = 0; oc < n; ++oc) { + if (use_workspace) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + if (ws_dt == data_type::u8) { + ws[ws_offset + oc] = 0; + } else + reinterpret_cast<int *>(ws)[ws_offset + oc] = 0; + } + dst[oc] = nstl::numeric_limits<data_t>::lowest(); + } + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t data_type> +struct nhwc_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("nhwc:any", nhwc_pooling_bwd_t); + + status_t init() { + const format_tag_t desired_fmt_tag = + ndims() == 4 ? format_tag::nchw : format_tag::ncdhw; + + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, + alg_kind::pooling_avg_include_padding, + alg_kind::pooling_avg_exclude_padding) + && utils::everyone_is(data_type, + diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values() + && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) + && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + nhwc_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +}// namespace cpu +}// namespace impl +}// namespace mkldnn + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp new file mode 100644 index 0000000000..e20333e66f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp @@ -0,0 +1,288 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <math.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "nspc_batch_normalization.hpp" + +// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases +#if (defined __clang_major__) && (__clang_major__ >= 6) +#define SAFE_TO_USE_OMP_SIMD 0 +#else +#define SAFE_TO_USE_OMP_SIMD 1 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void nspc_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + const bool calculate_stats = !pd()->stats_is_src(); + const bool with_relu = pd()->with_relu_post_op(); + + auto scratchpad = this->scratchpad(ctx); + auto tmp_mean = scratchpad.get<data_t>(key_bnorm_tmp_mean); + auto tmp_var = scratchpad.get<data_t>(key_bnorm_tmp_var); + auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction); + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + data_t *mean, *variance; + if (!calculate_stats) { + mean = const_cast<data_t *>( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); + variance = const_cast<data_t *>( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); + } else { + if (save_stats) { + mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + } else { + mean = tmp_mean; + variance = tmp_var; + } + } + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + const dim_t SP = pd()->H() * pd()->W() * pd()->D(); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + auto maybe_post_op + = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; }; + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { + dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; + balance211(N, nthr, ithr, N_s, N_e); + balance211(C, nthr, ithr, C_s, C_e); + data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr; + data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr; + + if (calculate_stats) { + for (dim_t c = 0; c < C; c++) + ws_reduce[C * ithr + c] = 0.; + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) + PRAGMA_OMP_SIMD() + for (dim_t c = 0; c < C; c++) + ws_reduce[C * ithr + c] += src[(size_t)n * SP * C + + sp * C + c]; + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + mean[c] = 0; + for (dim_t n = 0; n < nthr; n++) + mean[c] += ws_reduce[C * n + c]; + mean[c] /= SP * N; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) { + mean_loc[c] = mean[c]; + ws_reduce[C * ithr + c] = 0.; + } + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) + PRAGMA_OMP_SIMD() + for (dim_t c = 0; c < C; c++) { + data_t m = src[(size_t)n * SP * C + sp * C + c] + - mean_loc[c]; + ws_reduce[C * ithr + c] += m * m; + } + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + variance[c] = 0; + for (dim_t n = 0; n < nthr; n++) + variance[c] += ws_reduce[C * n + c]; + variance[c] /= SP * N; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) + variance_loc[c] = variance[c]; + } else { + variance_loc = variance; + mean_loc = mean; + } + + for (dim_t n = N_s; n < N_e; n++) { + for (dim_t sp = 0; sp < SP; sp++) { +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + data_t sqrt_variance = static_cast<data_t>( + sqrtf(variance_loc[c] + eps)); + data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance; + data_t sv = use_scaleshift ? scaleshift[C + c] : 0; + size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + dst[d_off] = maybe_post_op(bn_res); + } + } + } + }); +} + +void nspc_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto tmp_diff_ss = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss); + + if (diff_scaleshift == nullptr) + diff_scaleshift = tmp_diff_ss; + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + const dim_t SP = pd()->D() * pd()->H() * pd()->W(); + data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C; + auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { + dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; + balance211(N, nthr, ithr, N_s, N_e); + balance211(C, nthr, ithr, C_s, C_e); + + data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr; + data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr); + + for (dim_t c = 0; c < C; c++) { + ws_reduce[C * ithr + c] = 0.; + ws_reduce[C * nthr + C * ithr + c] = 0.; + } + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + const size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t dd; + if (fuse_bn_relu) + dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + dd = diff_dst[d_off]; + ws_reduce[C * ithr + c] += (src[d_off] - mean[c]) * dd; + ws_reduce[C * nthr + C * ithr + c] += dd; + } + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + data_t sqrt_variance + = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps)); + diff_gamma[c] = 0; + diff_beta[c] = 0; + for (dim_t n = 0; n < nthr; n++) { + diff_gamma[c] += ws_reduce[C * n + c]; + diff_beta[c] += ws_reduce[C * nthr + C * n + c]; + } + diff_gamma[c] *= sqrt_variance; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) { + diff_gamma_loc[c] = diff_gamma[c]; + diff_beta_loc[c] = diff_beta[c]; + } + + for (dim_t n = N_s; n < N_e; n++) { + for (dim_t sp = 0; sp < SP; sp++) { +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + const size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t gamma = use_scaleshift ? scaleshift[c] : 1; + data_t sqrt_variance + = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps)); + data_t v_diff_src; + if (fuse_bn_relu) + v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + v_diff_src = diff_dst[d_off]; + if (calculate_diff_stats) { + v_diff_src -= diff_beta_loc[c] / (SP * N) + + (src[d_off] - mean[c]) * diff_gamma_loc[c] + * sqrt_variance / (SP * N); + } + v_diff_src *= gamma * sqrt_variance; + diff_src[d_off] = v_diff_src; + } + } + } + }); +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp new file mode 100644 index 0000000000..aad86b05a7 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp @@ -0,0 +1,169 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_NSPC_BATCH_NORMALIZATION_HPP +#define CPU_NSPC_BATCH_NORMALIZATION_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct nspc_batch_normalization_fwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + + bool ok = true + /* the algorithm requires barriers while switching + * between parallelization over N and C dimensions */ + && mkldnn_thr_syncable() + && is_fwd() + && !has_zero_dim_memory() + && src_md()->data_type == f32 + && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32) + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && (attr()->has_default_values() || this->with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + if (!stats_is_src()) { + dim_t sz = nstl::max<dim_t>(C(), 16) * mkldnn_get_max_threads(); + scratchpad.book(key_bnorm_reduction, sizeof(data_t) * sz); + scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * sz); + scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * sz); + } + } + }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + nspc_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~nspc_batch_normalization_fwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +struct nspc_batch_normalization_bwd_t : public cpu_primitive_t { + struct pd_t : public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t); + + status_t init() { + using namespace data_type; + using namespace prop_kind; + + bool ok = true + /* the algorithm requires barriers while switching + * between parallelization over N and C dimensions */ + && mkldnn_thr_syncable() + && is_bwd() + && !has_zero_dim_memory() + && utils::everyone_is(f32, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), + utils::everyone_is(f32, + weights_md()->data_type, + diff_weights_md()->data_type)) + && memory_desc_matches_tag(*src_md(), format_tag::nhwc) + && memory_desc_matches_tag(*diff_src_md(), format_tag::nhwc) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_bnorm_reduction, + sizeof(data_t) * 2 * C() * mkldnn_get_max_threads()); + scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * 2 * C() + * (mkldnn_get_max_threads() + 1)); + } + }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + nspc_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + ~nspc_batch_normalization_bwd_t() {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp new file mode 100644 index 0000000000..d79b1a034b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp @@ -0,0 +1,265 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <math.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "simple_q10n.hpp" + +#include "ref_batch_normalization.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t data_type> +void ref_batch_normalization_fwd_t<data_type>::execute_forward( + const exec_ctx_t &ctx) const { + /* fast return */ + if (this->pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const float *, MKLDNN_ARG_SCALE_SHIFT); + + auto mean = pd()->stats_is_src() + ? const_cast<float *>(CTX_IN_MEM(const float *, MKLDNN_ARG_MEAN)) + : CTX_OUT_MEM(float *, MKLDNN_ARG_MEAN); + auto variance = pd()->stats_is_src() + ? const_cast<float *>(CTX_IN_MEM(const float *, MKLDNN_ARG_VARIANCE)) + : CTX_OUT_MEM(float *, MKLDNN_ARG_VARIANCE); + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper scaleshift_d(pd()->weights_md()); + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + dim_t H = 1, W = 1, D = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5); + if (has_spatial) { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + } + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift();; + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + const bool calculate_stats = !pd()->stats_is_src(); + + const bool with_relu = pd()->with_relu_post_op(); + auto maybe_post_op = [&](float res) { + return (with_relu && res < 0.0f) ? 0.0f : res; + }; + const bool is_3d = data_d.ndims() == 5; + + auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c, + dim_t d, dim_t h, dim_t w) { + if (has_spatial) { + if (is_3d) + return data_d.off(n, c, d, h, w); + else + return data_d.off(n, c, h, w); + } else + return data_d.off(n, c); + }; + + parallel_nd(C, [&](dim_t c) { + float v_mean = calculate_stats ? 0 : mean[c]; + float v_variance = calculate_stats ? 0 : variance[c]; + + if (calculate_stats) { + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) + v_mean += src[data_offset(data_d, n, c, d, h, w)]; + v_mean /= W*N*H*D; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + float m = src[data_offset(data_d, n, c, d, h, w)] - v_mean; + v_variance += m*m; + } + v_variance /= W*H*N*D; + } + + float sqrt_variance = sqrtf(v_variance + eps); + float sm = (use_scaleshift + ? scaleshift[scaleshift_d.off(0, c)] + : 1.0f) / sqrt_variance; + float sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + auto d_off = data_offset(data_d,n,c,d,h,w); + float bn_res = sm * ((float)src[d_off] - v_mean) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + if (data_type == data_type::s8) { + dst[d_off] = qz_a1b0<float, data_t>()(maybe_post_op(bn_res)); + } else { + dst[d_off] = static_cast<data_t>(maybe_post_op(bn_res)); + } + } + + if (calculate_stats) { + if (save_stats) { + mean[c] = v_mean; + variance[c] = v_variance; + } + } + }); +} + +template struct ref_batch_normalization_fwd_t<data_type::f32>; +template struct ref_batch_normalization_fwd_t<data_type::s8>; + +template <impl::data_type_t data_type> +void ref_batch_normalization_bwd_t<data_type>::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + const memory_desc_wrapper scaleshift_d(pd()->weights_md()); + const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_md()); + + const dim_t C = pd()->C(); + + /* fast return */ + if (this->pd()->has_zero_dim_memory()) { + if (diff_scaleshift) { + for (dim_t c = 0; c < C; ++c) { + diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0; + diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0; + } + } + return; + } + + const dim_t N = pd()->MB(); + dim_t H = 1, W = 1, D = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5); + if (has_spatial) { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + } + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + const bool is_3d = data_d.ndims() == 5; + + auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c, + dim_t d, dim_t h, dim_t w) { + if (has_spatial) { + if (is_3d) + return data_d.off(n, c, d, h, w); + else + return data_d.off(n, c, h, w); + } else + return data_d.off(n, c); + }; + + parallel_nd(C, [&](dim_t c) { + data_t v_mean = mean[c]; + data_t v_variance = variance[c]; + data_t sqrt_variance = static_cast<data_t>(1.0f / sqrtf(v_variance + eps)); + data_t gamma = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1; + data_t diff_gamma = data_t(0); + data_t diff_beta = data_t(0); + diff_gamma = 0.0; + diff_beta = 0.0; + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + const size_t s_off = data_offset(data_d, n, c, d, h, w); + data_t dd = diff_dst[data_offset(diff_data_d, n, c, d, h, w)]; + if (fuse_bn_relu && !ws[s_off]) + dd = 0; + + diff_gamma += (src[s_off] - v_mean) * dd; + diff_beta += dd; + } + diff_gamma *= sqrt_variance; + + if (diff_scaleshift) { + diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma; + diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta; + } + + for (dim_t n = 0; n < N; ++n) + for (dim_t d = 0; d < D; ++d) + for (dim_t h = 0; h < H; ++h) + for (dim_t w = 0; w < W; ++w) { + const size_t s_off = data_offset(data_d, n, c, d, h, w); + const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w); + data_t dd = diff_dst[dd_off]; + if (fuse_bn_relu && !ws[s_off]) + dd = 0; + + data_t v_diff_src = dd; + if (calculate_diff_stats) { + v_diff_src -= diff_beta/(D*W*H*N) + + (src[s_off] - v_mean) * + diff_gamma*sqrt_variance/(D*W*H*N); + } + v_diff_src *= gamma*sqrt_variance; + diff_src[dd_off] = v_diff_src; + } + }); +} + +template struct ref_batch_normalization_bwd_t<data_type::f32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp new file mode 100644 index 0000000000..aa9f74125a --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp @@ -0,0 +1,127 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_BATCH_NORMALIZATION_HPP +#define CPU_REF_BATCH_NORMALIZATION_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_batch_normalization_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t data_type> +struct ref_batch_normalization_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_fwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && IMPLICATION(use_scaleshift(), + weights_md()->data_type == data_type::f32) + && (attr()->has_default_values() || with_relu_post_op()); + if (!ok) return status::unimplemented; + + if (src_md()->data_type == data_type::s8 && !stats_is_src()) + return status::unimplemented; + + if (is_training() && fuse_bn_relu()) init_default_ws(8); + + return status::success; + } + }; + + ref_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t data_type> +struct ref_batch_normalization_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_batch_normalization_bwd_pd_t { + pd_t(engine_t *engine, const batch_normalization_desc_t *adesc, + const primitive_attr_t *attr, + const batch_normalization_fwd_pd_t *hint_fwd_pd) + : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) + {} + + DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_bwd_t); + + status_t init() { + bool ok = true + && is_bwd() + && utils::everyone_is(data_type, src_md()->data_type, + diff_src_md()->data_type) + && IMPLICATION(use_scaleshift(), utils::everyone_is(data_type, + weights_md()->data_type, + diff_weights_md()->data_type)) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (fuse_bn_relu()) { + init_default_ws(8); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + ref_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp new file mode 100644 index 0000000000..4c534b5508 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp @@ -0,0 +1,97 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_CONCAT_HPP +#define REF_CONCAT_HPP + +#include "reorder_pd.hpp" + +#include "cpu_concat_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_concat_t: public cpu_primitive_t { + struct pd_t: public cpu_concat_pd_t { + using cpu_concat_pd_t::cpu_concat_pd_t; + + pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { + for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) + reorder_pds_.push_back( + (const reorder_pd_t *)rhs.reorder_pds_[i]->clone()); + } + ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; } + + DECLARE_CONCAT_PD_T("ref:any", ref_concat_t); + + status_t init() { + bool ok = cpu_concat_pd_t::init() == status::success; + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + auto r_impls = engine_->get_reorder_implementation_list(); + for (auto r = r_impls; *r; ++r) { + const primitive_attr_t attr; /* alpha == 1. */ + reorder_pd_t *r_pd = nullptr; + if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i), + engine_, src_image_md(i)) == status::success) { + r_pd->init_info(); + reorder_pds_.push_back(r_pd); + break; + } + } + } + + ok = reorder_pds_.size() == (size_t)n_; + return ok ? status::success : status::unimplemented; + } + + nstl::vector<const reorder_pd_t *> reorder_pds_; + }; + + ref_concat_t(const pd_t *apd): cpu_primitive_t(apd) { + const int n = pd()->n_inputs(); + reorders_.resize(n); + for (int i = 0; i < n; ++i) + pd()->reorder_pds_[i]->create_primitive(&reorders_[i]); + } + + ~ref_concat_t() { for (auto &r: reorders_) delete r; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto n = pd()->n_inputs(); + for (int i = 0; i < n; ++i) { + exec_args_t r_args; + r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i); + r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST); + exec_ctx_t r_ctx(ctx.stream(), std::move(r_args)); + reorders_[i]->execute(r_ctx); + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + nstl::vector<primitive_t *> reorders_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp new file mode 100644 index 0000000000..c0a979c4cf --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp @@ -0,0 +1,395 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "type_helpers.hpp" + +#include "ref_convolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using math::saturate; +using math::get_bias; + +template <data_type_t src_type, data_type_t wei_type, + data_type_t dst_type, data_type_t acc_type> +void ref_convolution_fwd_t<src_type, wei_type, dst_type, acc_type>:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool with_relu = 0; // TODO: change if support post_ops + const float nslope = 0.f; + + const int ndims = pd()->desc()->src_desc.ndims; + + auto ker = [=](int g, int mb, int oc, int od, int oh, + int ow) { + acc_data_t d = 0; + for (int ic = 0; ic < IC; ++ic) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + const int id = od * KSD - padFront + kd * (1 + KDD); + const int ih = oh * KSH - padT + kh * (1 + KDH); + const int iw = ow * KSW - padL + kw * (1 + KDW); + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + if (ndims == 5) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, id, ih, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kd, kh, kw)] + : weights[weights_d.off(oc, ic, kd, kh, kw)]); + else if (ndims == 4) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, ih, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kh, kw)] + : weights[weights_d.off(oc, ic, kh, kw)]); + else if (ndims == 3) + d += (acc_data_t)src[src_d.off(mb, g*IC + ic, iw)] + * (with_groups + ? weights[weights_d.off(g, oc, ic, kw)] + : weights[weights_d.off(oc, ic, kw)]); + else + assert(false); + + } + return d; + }; + + parallel_nd(G, MB, OC, OD, OH, OW, + [&](int g, int mb, int oc, int od, int oh, int ow) { + float a = bias + ? get_bias(bias, bias_d.off(g * OC + oc), + pd()->desc()->bias_desc.data_type) + : 0; + a += ker(g, mb, oc, od, oh, ow); + if (with_relu && a < 0) + a = a * nslope; + if (ndims == 5) + dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = saturate<dst_data_t>(a); + else if (ndims == 4) + dst[dst_d.off(mb, g*OC + oc, oh, ow)] = saturate<dst_data_t>(a); + else if (ndims == 3) + dst[dst_d.off(mb, g*OC + oc, ow)] = saturate<dst_data_t>(a); + else + assert(false); + }); +} + +template <data_type_t diff_src_type, data_type_t wei_type, + data_type_t diff_dst_type, data_type_t acc_type> +void ref_convolution_bwd_data_t<diff_src_type, wei_type, diff_dst_type, + acc_type>::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const int ndims = pd()->desc()->diff_src_desc.ndims; + + auto ker = [=](int g, int mb, int ic, int id, int ih, + int iw) { + acc_data_t d = 0; + for (int oc = 0; oc < OC; ++oc) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + if (iw + padL < kw * (1 + KDW) + || ih + padT < kh * (1 + KDH) + || id + padFront < kd * (1 + KDD)) + continue; + int ow = iw - kw * (1 + KDW) + padL; + int oh = ih - kh * (1 + KDH) + padT; + int od = id - kd * (1 + KDD) + padFront; + if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0) + continue; + + ow /= KSW; + oh /= KSH; + od /= KSD; + + if (od < OD && oh < OH && ow < OW) { + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, od, oh, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kd, kh, kw)] + : weights[weights_d.off(oc, ic, kd, kh, kw)]); + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, oh, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kh, kw)] + : weights[weights_d.off(oc, ic, kh, kw)]); + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + + oc, ow)] * (with_groups + ? weights[weights_d.off(g, oc, ic, kw)] + : weights[weights_d.off(oc, ic, kw)]); + else + assert(false); + } + } + return d; + }; + + parallel_nd(G, MB, IC, ID, IH, IW, + [&](int g, int mb, int ic, int id, int ih, int iw) { + auto ds_idx = (ndims == 5) + ? diff_src_d.off(mb, g*IC + ic, id, ih, iw) + : (ndims == 4) + ? diff_src_d.off(mb, g*IC + ic, ih, iw) + : diff_src_d.off(mb, g*IC + ic, iw); + float a = bias + ? get_bias(bias, bias_d.off(g * IC + ic), + pd()->desc()->bias_desc.data_type) + : 0; + a += ker(g, mb, ic, id, ih, iw); + diff_src[ds_idx] = saturate<diff_src_data_t>(a); + }); +} + +template <data_type_t src_type, data_type_t diff_wei_type, + data_type_t diff_dst_type, data_type_t acc_type> +void ref_convolution_bwd_weights_t<src_type, diff_wei_type, diff_dst_type, + acc_type>::execute_backward_weights(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const bool with_groups = pd()->with_groups(); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + + const int OC = pd()->OC() / G; + const int IC = pd()->IC() / G; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + + const int KSD = pd()->KSD(); + const int KSH = pd()->KSH(); + const int KSW = pd()->KSW(); + + const int KDD = pd()->KDD(); + const int KDH = pd()->KDH(); + const int KDW = pd()->KDW(); + + const int padFront = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const int ndims = pd()->desc()->src_desc.ndims; + +auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) { + for (int mb = 0; mb < MB; ++mb) + for (int od = 0; od < OD; ++od) + for (int oh = 0; oh < OH; ++oh) + for (int ow = 0; ow < OW; ++ow) { + if (ow*KSW + kw * (1 + KDW) < padL + || oh*KSH + kh * (1 + KDH) < padT + || od*KSD + kd * (1 + KDD) < padFront + || ow*KSW + kw * (1 + KDW) >= IW + padL + || oh*KSH + kh * (1 + KDH) >= IH + padT + || od*KSD + kd * (1 + KDD) >= ID + padFront) + continue; + + int id = od*KSD - padFront + kd * (1 + KDD); + int ih = oh*KSH - padT + kh * (1 + KDH); + int iw = ow*KSW - padL + kw * (1 + KDW); + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, + oh, ow)] * src[src_d.off(mb, g*IC + ic, id, ih, iw)]; + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, ow)] + * src[src_d.off(mb, g*IC + ic, ih, iw)]; + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)] + * src[src_d.off(mb, g*IC + ic, iw)]; + else + assert(false); + } + }; + + auto ker_bias = [=](acc_data_t &d, int g, int oc) { + for (int mb = 0; mb < MB; ++mb) + for (int od = 0; od < OD; ++od) + for (int oh = 0; oh < OH; ++oh) + for (int ow = 0; ow < OW; ++ow) { + if (ndims == 5) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, oh, + ow)]; + else if (ndims == 4) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, + ow)]; + else if (ndims == 3) + d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)]; + else + assert(false); + } + }; + + parallel_nd(G, OC, [&](int g, int oc) { + if (diff_bias) { + // XXX: loss of precision when bias is a float... + acc_data_t db = 0; + ker_bias(db, g, oc); + diff_bias[diff_bias_d.off(g*OC+oc)] + = saturate<diff_wei_data_t>(db); + } + + for (int ic = 0; ic < IC; ++ic) + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + acc_data_t dw = 0; + ker(dw, g, oc, ic, kd, kh, kw); + + if (ndims == 5) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kd, kh, kw) + : diff_weights_d.off(oc, ic, kd, kh, kw); + diff_weights[idx] = saturate<diff_wei_data_t>(dw); + } else if (ndims == 4) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kh, kw) + : diff_weights_d.off(oc, ic, kh, kw); + diff_weights[idx] = saturate<diff_wei_data_t>(dw); + } else if (ndims == 3) { + auto idx = with_groups + ? diff_weights_d.off(g, oc, ic, kw) + : diff_weights_d.off(oc, ic, kw); + diff_weights[idx] = saturate<diff_wei_data_t>(dw); + } else { + assert(false); + } + } + }); +} + +using namespace data_type; + +template struct ref_convolution_fwd_t<f32>; + +template struct ref_convolution_fwd_t<u8, s8, f32, s32>; +template struct ref_convolution_fwd_t<u8, s8, s32, s32>; +template struct ref_convolution_fwd_t<u8, s8, s8, s32>; +template struct ref_convolution_fwd_t<u8, s8, u8, s32>; + +template struct ref_convolution_bwd_data_t<f32, f32, f32, f32>; + +template struct ref_convolution_bwd_data_t<f32, s8, u8, s32>; +template struct ref_convolution_bwd_data_t<s32, s8, u8, s32>; +template struct ref_convolution_bwd_data_t<s8, s8, u8, s32>; +template struct ref_convolution_bwd_data_t<u8, s8, u8, s32>; + +template struct ref_convolution_bwd_weights_t<f32, f32, f32, f32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp new file mode 100644 index 0000000000..7c83d0c6d4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp @@ -0,0 +1,194 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_CONVOLUTION_HPP +#define CPU_REF_CONVOLUTION_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t src_type, + impl::data_type_t wei_type = src_type, + impl::data_type_t dst_type = src_type, + impl::data_type_t acc_type = dst_type> +struct ref_convolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_fwd_pd_t { + using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, data_type::undef, + dst_type, acc_type) + && IMPLICATION(with_bias(), true + && IMPLICATION(src_type == u8, + utils::one_of(bias_md_.data_type, f32, s32, s8, u8)) + && IMPLICATION(src_type == f32, + bias_md_.data_type == f32)) + && set_default_formats() + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<wei_type>::type wei_data_t; + typedef typename prec_traits<dst_type>::type dst_data_t; + typedef typename prec_traits<acc_type>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t diff_src_type, impl::data_type_t wei_type, + impl::data_type_t diff_dst_type, + impl::data_type_t acc_type = diff_src_type> +struct ref_convolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_data_pd_t { + using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_data_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(diff_src_type, wei_type, data_type::undef, + diff_dst_type, acc_type) + && set_default_formats() + && attr()->has_default_values(); + + return ok ? status::success : status::unimplemented; + } + + virtual bool support_bias() const override { return true; } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits<diff_src_type>::type diff_src_data_t; + typedef typename prec_traits<wei_type>::type wei_data_t; + typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t; + typedef typename prec_traits<acc_type>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t src_type, impl::data_type_t diff_wei_type, + impl::data_type_t diff_dst_type, + impl::data_type_t acc_type = diff_wei_type> +struct ref_convolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_convolution_bwd_weights_pd_t { + using cpu_convolution_bwd_weights_pd_t::cpu_convolution_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_weights_t); + + status_t init() { + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, diff_wei_type, diff_wei_type, + diff_dst_type, acc_type) + && set_default_formats() + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + + protected: + bool set_default_formats() { + using namespace format_tag; + auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + auto wei_tag = with_groups() + ? utils::pick(ndims() - 3, goiw, goihw, goidhw) + : utils::pick(ndims() - 3, oiw, oihw, oidhw); + return set_default_formats_common(dat_tag, wei_tag, dat_tag); + } + }; + + ref_convolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<diff_wei_type>::type diff_wei_data_t; + typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t; + typedef typename prec_traits<acc_type>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp new file mode 100644 index 0000000000..541a303aab --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp @@ -0,0 +1,199 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "math_utils.hpp" + +#include "ref_deconvolution.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +void ref_deconvolution_fwd_t::compute_fwd_bias(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int OD = pd()->OD(); + const int OC = pd()->OC() / G; + const int ndims = pd()->desc()->src_desc.ndims; + + parallel_nd(MB, G, OC, OD, OH, OW, + [&](int mb, int g, int oc, int od, int oh, int ow) { + auto b = bias[g * OC + oc]; + switch (ndims) { + case 5: dst[dst_d.off(mb, g * OC + oc, od, oh, ow)] += b; break; + case 4: dst[dst_d.off(mb, g * OC + oc, oh, ow)] += b; break; + case 3: dst[dst_d.off(mb, g * OC + oc, ow)] += b; break; + default: assert(!"invalid dimension size"); + } + }); +} + +void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int SP = pd()->OW()*pd()->OH()*pd()->OD(); + + parallel_nd(MB, OC, [&](int mb, int oc) { + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + auto offset = (size_t)(mb * OC + oc) * SP + sp; + dst[offset] += bias[oc]; + } + }); +} + +template <int blksize> +void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const data_t *bias, + data_t *dst) const { + const memory_desc_wrapper dst_d(pd()->dst_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int SP = pd()->OW() * pd()->OH() * pd()->OD(); + + const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0]; + + parallel_nd(MB, utils::div_up(OC, blksize), SP, + [&](int mb, int oc_blk, int sp) { + int oc = oc_blk * blksize; + auto offset = mb * stride_mb + oc * SP + sp * blksize; + const int blk = nstl::min(blksize, OC - oc); + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blk; ++i) + dst[offset + i] += bias[oc + i]; + }); +} + +void ref_deconvolution_bwd_weights_t::compute_bwd_bias(const data_t *diff_dst, + data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int G = pd()->G(); + const int MB = pd()->MB(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + const int OC = pd()->OC() / G; + const int OD = pd()->OD(); + const int ndims = pd()->desc()->src_desc.ndims; + + parallel_nd(G, OC, [&](int g, int oc) { + data_t db = 0; + for (int mb = 0; mb < MB; ++mb) { + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + switch (ndims) { + case 5: + db += diff_dst[diff_dst_d.off( + mb, g * OC + oc, od, oh, ow)]; + break; + case 4: + db += diff_dst[diff_dst_d.off( + mb, g * OC + oc, oh, ow)]; + break; + case 3: + db += diff_dst[diff_dst_d.off(mb, g * OC + oc, ow)]; + break; + default: assert(!"invalid dimension size"); + } + } + } + } + } + diff_bias[g * OC + oc] = db; + }); +} + +void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw( + const data_t *diff_dst, data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int OC = pd()->OC(); + const int MB = pd()->MB(); + const int SP = pd()->OH()*pd()->OW()*pd()->OD(); + + parallel_nd(OC, [&](int oc) { + data_t db = 0; + for (int mb = 0; mb < MB; ++mb) { + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + auto offset = (size_t)(mb * OC + oc) * SP + sp; + db += diff_dst[offset]; + } + } + diff_bias[oc] = db; + }); +} + +template <int blksize> +void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc( + const data_t *diff_dst, data_t *diff_bias) const { + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + + const int OC = pd()->OC(); + const int MB = pd()->MB(); + const int SP = pd()->OH() * pd()->OW() * pd()->OD(); + + const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0]; + + parallel_nd(utils::div_up(OC, blksize), [&](int ocb) { + data_t db[blksize] = {0}; + + for (int mb = 0; mb < MB; ++mb) { + for (int sp = 0; sp < SP; ++sp) { + auto offset = mb * stride_mb + (ocb * SP + sp) * blksize; + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blksize; ++i) + db[i] += diff_dst[offset+i]; + } + } + + const int blk = nstl::min(blksize, OC - ocb * blksize); + + PRAGMA_OMP_SIMD() + for (int i = 0; i < blk; ++i) + diff_bias[ocb * blksize + i] = db[i]; + }); +} + +template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<8>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<16>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<8>( + const data_t *diff_dst, data_t *diff_bias) const; +template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<16>( + const data_t *diff_dst, data_t *diff_bias) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp new file mode 100644 index 0000000000..d61903c32d --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp @@ -0,0 +1,502 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_DECONVOLUTION_HPP +#define CPU_REF_DECONVOLUTION_HPP + +#include <assert.h> +#include <string.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "primitive_iterator.hpp" + +#include "cpu_convolution_pd.hpp" +#include "cpu_deconvolution_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +static status_t compute_blocked_format(bool with_groups, + const memory_desc_t *oi_md, memory_desc_t *io_md) +{ + /* Computes blocking for *i*o* format from *o*i* format */ + + bool sanity_check_ok = true + && oi_md->ndims == io_md->ndims + && oi_md->format_kind == format_kind::blocked; + if (!sanity_check_ok) return status::invalid_arguments; + + const blocking_desc_t &oi_blk = oi_md->format_desc.blocking; + blocking_desc_t io_blk = io_md->format_desc.blocking; + + io_md->format_kind = format_kind::blocked; + io_blk = oi_blk; + + const int ID_OC = 0 + with_groups; + const int ID_IC = 1 + with_groups; + + nstl::swap(io_blk.strides[ID_OC], io_blk.strides[ID_IC]); + for (int i_blk = 0; i_blk < io_blk.inner_nblks; ++i_blk) { + if (utils::one_of(io_blk.inner_idxs[i_blk], ID_OC, ID_IC)) { + io_blk.inner_idxs[i_blk] = + (io_blk.inner_idxs[i_blk] == ID_OC ? ID_IC : ID_OC); + } + } + + return memory_desc_init_by_blocking_desc(*io_md, io_blk); +} + +static status_t conv_descr_create(const deconvolution_desc_t *dd, + convolution_desc_t *cd) +{ + using namespace prop_kind; + alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct + ? alg_kind::convolution_direct : alg_kind::convolution_winograd; + + const memory_desc_t *src_md, *dst_md, *d_weights_d; + prop_kind_t prop_kind; + memory_desc_t c_weights_d; + if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) { + prop_kind = backward_data; + src_md = &dd->dst_desc; + dst_md = &dd->src_desc; + d_weights_d = &dd->weights_desc; + } else if (dd->prop_kind == backward_data) { + prop_kind = forward_training; + src_md = &dd->diff_dst_desc; + dst_md = &dd->diff_src_desc; + d_weights_d = &dd->weights_desc; + } else { + prop_kind = dd->prop_kind; + src_md = &dd->diff_dst_desc; + dst_md = &dd->src_desc; + d_weights_d = &dd->diff_weights_desc; + } + + const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; + + /* create weights desc for convolution */ + c_weights_d = *d_weights_d; + + const int ID_OC = 0 + with_groups; + const int ID_IC = 1 + with_groups; + + nstl::swap(c_weights_d.dims[ID_OC], c_weights_d.dims[ID_IC]); + nstl::swap(c_weights_d.padded_dims[ID_OC], c_weights_d.padded_dims[ID_IC]); + nstl::swap(c_weights_d.padded_offsets[ID_OC], c_weights_d.padded_offsets[ID_IC]); + + if (c_weights_d.format_kind != format_kind::any) + CHECK(compute_blocked_format(with_groups, d_weights_d, &c_weights_d)); + + return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, + prop_kind != backward_weights ? &dd->bias_desc : nullptr, + dst_md, dd->strides, dd->dilates, + dd->padding[0], dd->padding[1], dd->padding_kind); +} + +struct ref_deconvolution_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_fwd_pd_t { + pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_fwd_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + , conv_supports_bias_(other.conv_supports_bias_) + , dst_tag_(other.dst_tag_) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + CHECK(conv_descr_create(desc(), &cd)); + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + conv_supports_bias_ = + static_cast<cpu_convolution_bwd_data_pd_t *>(conv_pd_) + ->support_bias(); + bool output_f32 = utils::everyone_is(data_type::f32, + desc()->accum_data_type, desc()->dst_desc.data_type); + + bool ok = true + && conv_pd_->weights_md()->extra.flags == 0 + /* deconv reference code can process only f32 bias */ + && IMPLICATION(with_bias(), + conv_supports_bias_ || output_f32); + if (ok) return status::success; + + delete conv_pd_; + } + conv_pd_ = nullptr; + return status::unimplemented; + } + + status_t init() { + using namespace format_tag; + bool ok = true + && is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd) + && attr()->post_ops_.has_default_values(); + + if (ok) { + CHECK(init_convolution()); + if (weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->weights_md(), &desc_.weights_desc)); + weights_md_ = desc_.weights_desc; + } + if (src_md_.format_kind == format_kind::any) + src_md_ = *conv_pd_->diff_dst_md(); + if (dst_md_.format_kind == format_kind::any) + dst_md_ = *conv_pd_->diff_src_md(); + if (bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, x)); + + dst_tag_ = memory_desc_matches_one_of_tag(dst_md_, + utils::pick(ndims() - 3, ncw, nchw, ncdhw), + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), + utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + bool conv_supports_bias_; + format_tag_t dst_tag_; + }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + ref_deconvolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_fwd_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); + conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); + if (pd()->with_bias() && pd()->conv_supports_bias_) + conv_args[MKLDNN_ARG_BIAS] = args.at(MKLDNN_ARG_BIAS); + conv_args[MKLDNN_ARG_DIFF_SRC] = args.at(MKLDNN_ARG_DST); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + conv_p_->execute(conv_ctx); + + if (pd()->with_bias() && !pd()->conv_supports_bias_) { + using namespace format_tag; + + auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + switch (pd()->dst_tag_) { + case ncdhw: case nchw: case ncw: + compute_fwd_bias_ncdhw(bias, dst); + break; + case nCdhw8c: case nChw8c: case nCw8c: + compute_fwd_bias_nCdhwXc<8>(bias, dst); + break; + case nCdhw16c: case nChw16c: case nCw16c: + compute_fwd_bias_nCdhwXc<16>(bias, dst); + break; + default: + compute_fwd_bias(bias, dst); + break; + } + } + return status::success; + } + +private: + void compute_fwd_bias(const data_t *bias, data_t *dst) const; + void compute_fwd_bias_ncdhw(const data_t *bias, data_t *dst) const; + template <int blksize> void compute_fwd_bias_nCdhwXc(const data_t *bias, + data_t *dst) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +struct ref_deconvolution_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_bwd_data_pd_t { + pd_t(engine_t *engine, const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_bwd_data_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + status_t status = conv_descr_create(desc(), &cd); + if (status != status::success) return status; + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + if (conv_pd_->weights_md()->extra.flags == 0) + return status::success; + delete conv_pd_; + } + + return status::unimplemented; + } + + status_t init() { + using namespace data_type; + bool ok = true + && desc()->prop_kind == prop_kind::backward_data + && utils::everyone_is(data_type::f32, + desc()->diff_src_desc.data_type, + desc()->weights_desc.data_type, + desc()->diff_dst_desc.data_type) + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd); + + if (ok) { + CHECK(init_convolution()); + if (weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->weights_md(), &desc_.weights_desc)); + weights_md_ = desc_.weights_desc; + } + if (diff_src_md_.format_kind == format_kind::any) + diff_src_md_ = *conv_pd_->dst_md(); + if (diff_dst_md_.format_kind == format_kind::any) + diff_dst_md_ = *conv_pd_->src_md(); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + ref_deconvolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_bwd_data_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); + conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); + conv_args[MKLDNN_ARG_DST] = args.at(MKLDNN_ARG_DIFF_SRC); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + conv_p_->execute(conv_ctx); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + primitive_t *conv_p_; +}; + +struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_deconvolution_bwd_weights_pd_t { + pd_t(engine_t *engine, + const deconvolution_desc_t *adesc, + const primitive_attr_t *attr, + const deconvolution_fwd_pd_t *hint_fwd_pd) + : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) + , conv_pd_(nullptr) + {} + + pd_t(const pd_t &other) + : cpu_deconvolution_bwd_weights_pd_t(other) + , conv_pd_(other.conv_pd_->clone()) + , dst_tag_(other.dst_tag_) + {} + + ~pd_t() { delete conv_pd_; } + + DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t); + + status_t init_convolution() { + using namespace types; + + convolution_desc_t cd; + status_t status = conv_descr_create(desc(), &cd); + if (status != status::success) return status; + + mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, + &attr_, nullptr); + while (++it != it.end()) { + conv_pd_ = *it; + if (conv_pd_->diff_weights_md()->extra.flags == 0) + return status::success; + delete conv_pd_; + } + return status::unimplemented; + } + + status_t init() { + using namespace format_tag; + bool ok = true + && desc()->prop_kind == prop_kind::backward_weights + && utils::everyone_is(data_type::f32, + desc()->src_desc.data_type, + desc()->diff_weights_desc.data_type, + desc()->diff_dst_desc.data_type) + && utils::one_of(desc()->alg_kind, + alg_kind::deconvolution_direct, + alg_kind::deconvolution_winograd) + && attr()->has_default_values(); + if (ok) { + CHECK(init_convolution()); + if (diff_weights_md_.format_kind == format_kind::any) { + CHECK(compute_blocked_format(with_groups(), + conv_pd_->diff_weights_md(), + &desc_.diff_weights_desc)); + diff_weights_md_ = desc_.diff_weights_desc; + } + if (src_md_.format_kind == format_kind::any) + src_md_ = *conv_pd_->diff_dst_md(); + if (diff_dst_md_.format_kind == format_kind::any) + diff_dst_md_ = *conv_pd_->src_md(); + if (diff_bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md_, x)); + + dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_, + utils::pick(ndims() - 3, ncw, nchw, ncdhw), + utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), + utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); + + return status::success; + } + + return status::unimplemented; + } + + virtual void init_scratchpad_md() override { + scratchpad_md_ = *conv_pd_->scratchpad_md(); + } + + primitive_desc_t *conv_pd_; + format_tag_t dst_tag_; + }; + + typedef typename prec_traits<data_type::f32>::type data_t; + + ref_deconvolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) + { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } + ~ref_deconvolution_bwd_weights_t() { delete conv_p_; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto &args = ctx.args(); + exec_args_t conv_args; + conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); + conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); + conv_args[MKLDNN_ARG_DIFF_WEIGHTS] = args.at(MKLDNN_ARG_DIFF_WEIGHTS); + if (!types::is_zero_md(pd()->scratchpad_md())) + conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); + const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); + + status_t status = conv_p_->execute(conv_ctx); + if (status != status::success) return status; + + if (pd()->with_bias()) { + using namespace format_tag; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + switch (pd()->dst_tag_) { + case ncdhw: case nchw: case ncw: + compute_bwd_bias_ncdhw(diff_dst, diff_bias); + break; + case nCdhw8c: case nChw8c: case nCw8c: + compute_bwd_bias_nCdhwXc<8>(diff_dst, diff_bias); + break; + case nCdhw16c: case nChw16c: case nCw16c: + compute_bwd_bias_nCdhwXc<16>(diff_dst, diff_bias); + break; + default: + compute_bwd_bias(diff_dst, diff_bias); + break; + } + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + void compute_bwd_bias(const data_t *diff_dst, data_t *diff_bias) const; + void compute_bwd_bias_ncdhw(const data_t *diff_dst, + data_t *diff_bias) const; + template <int blksize> void compute_bwd_bias_nCdhwXc( + const data_t *diff_dst, data_t *diff_bias) const; + + primitive_t *conv_p_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp new file mode 100644 index 0000000000..7beee8d323 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp @@ -0,0 +1,297 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_eltwise.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace alg_kind; +using namespace math; + +ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, + float beta): alg_(alg), alpha_(alpha), beta_(beta) { + assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, + eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, + eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)); +} + +ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t( + const post_ops_t::entry_t::eltwise_t &eltwise) + : ref_eltwise_scalar_fwd_t(eltwise.alg, eltwise.alpha, eltwise.beta) {} + +float ref_eltwise_scalar_fwd_t::compute_scalar(float s) { + switch (alg_) { + case eltwise_relu: return relu_fwd(s, alpha_); + case eltwise_tanh: return tanh_fwd(s); + case eltwise_elu: return elu_fwd(s, alpha_); + case eltwise_square: return square_fwd(s); + case eltwise_abs: return abs_fwd(s); + case eltwise_sqrt: return sqrt_fwd(s); + case eltwise_linear: return linear_fwd(s, alpha_, beta_); + case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha_); + case eltwise_soft_relu: return soft_relu_fwd(s); + case eltwise_logistic: return logistic_fwd(s); + default: assert(!"unknown eltwise alg_kind"); + } + + return 0.f; +} + +template <impl::data_type_t data_type> +void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + const blocking_desc_t &blk = data_d.blocking_desc(); + const int block = blk.inner_blks[0]; + + const int MB = pd()->MB(); + const int C = pd()->C() / block; + const int C_PADDED = data_d.padded_dims()[1] / block; + const int tail = pd()->C() % block; + const int SP = pd()->D() * pd()->H() * pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + auto ker = [=] (data_t &d, data_t s) { + switch (alg_kind) { + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: + d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }; + + // FIXME: integer overflow? + + parallel_nd(MB, C_PADDED, SP, + [&](int n, int c, int sp) { + auto d_off = (n*C_PADDED*SP + c*SP + sp) * block; + if (c < C) { + for (int v = 0; v < block; v++) + ker(dst[d_off + v], src[d_off + v]); + } else { + for (int v = 0; v < tail; v++) + ker(dst[d_off + v], src[d_off + v]); + } + }); +} + +template <impl::data_type_t data_type> +void ref_eltwise_fwd_t<data_type>::execute_forward_generic( + const exec_ctx_t &ctx) const { + /* fast return */ + if (pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int D = pd()->D(); + const int H = pd()->H(); + const int W = pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + const bool is_3d = pd()->desc()->data_desc.ndims == 5; + + parallel_nd(MB, C, D, H, W, + [&](int n, int c, int id, int h, int w) { + auto d_off = is_3d + ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w); + data_t s = src[d_off]; + data_t &d = dst[d_off]; + switch (alg_kind) { + case eltwise_relu: d = relu_fwd(s, alpha); break; + case eltwise_tanh: d = tanh_fwd(s); break; + case eltwise_elu: d = elu_fwd(s, alpha); break; + case eltwise_square: d = square_fwd(s); break; + case eltwise_abs: d = abs_fwd(s); break; + case eltwise_sqrt: d = sqrt_fwd(s); break; + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: + d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template <impl::data_type_t data_type> +void ref_eltwise_fwd_t<data_type>::execute_forward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true)); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + src += data_d.offset0(); + dst += data_d.offset0(); + + if (alg_kind == eltwise_relu) { + // a fast path for relu as the most popular activation + parallel_nd(nelems, [&](ptrdiff_t e) { + dst[e] = relu_fwd(src[e], alpha); + }); + return; + } + + parallel_nd(nelems, [&](ptrdiff_t e) { + const data_t s = src[e]; + data_t &d = dst[e]; + + switch (alg_kind) { + case eltwise_tanh: d = tanh_fwd(s); break; + case eltwise_elu: d = elu_fwd(s, alpha); break; + case eltwise_square: d = square_fwd(s); break; + case eltwise_abs: d = abs_fwd(s); break; + case eltwise_sqrt: d = sqrt_fwd(s); break; + case eltwise_linear: d = linear_fwd(s, alpha, beta); break; + case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break; + case eltwise_soft_relu: d = soft_relu_fwd(s); break; + case eltwise_logistic: d = logistic_fwd(s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template <impl::data_type_t data_type> +void ref_eltwise_bwd_t<data_type>::execute_backward_generic( + const exec_ctx_t &ctx) const { + /* fast return */ + if (pd()->has_zero_dim_memory()) return; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int D = pd()->D(); + const int H = pd()->H(); + const int W = pd()->W(); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + const bool is_3d = pd()->desc()->data_desc.ndims == 5; + + parallel_nd(MB, C, D, H, W, + [&](int n, int c, int d, int h, int w) { + auto data_off = is_3d + ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w); + auto diff_data_off = is_3d + ? diff_data_d.off(n, c, d, h, w) + : diff_data_d.off(n, c, h, w); + data_t s = src[data_off]; + data_t dd = diff_dst[diff_data_off]; + data_t &ds = diff_src[diff_data_off]; + switch (alg_kind) { + case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; + case eltwise_tanh: ds = tanh_bwd(dd, s); break; + case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; + case eltwise_square: ds = square_bwd(dd, s); break; + case eltwise_abs: ds = abs_bwd(dd, s); break; + case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; + case eltwise_linear: + ds = linear_bwd(dd, s, alpha, beta); break; + case eltwise_bounded_relu: + ds = bounded_relu_bwd(dd, s, alpha); break; + case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break; + case eltwise_logistic: ds = logistic_bwd(dd, s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template <impl::data_type_t data_type> +void ref_eltwise_bwd_t<data_type>::execute_backward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); + + const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true)); + const auto alg_kind = pd()->desc()->alg_kind; + const float alpha = pd()->desc()->alpha; + const float beta = pd()->desc()->beta; + + src += data_d.offset0(); + diff_dst += diff_data_d.offset0(); + diff_src += diff_data_d.offset0(); + + parallel_nd(nelems, [&](ptrdiff_t e) { + const data_t dd = diff_dst[e]; + const data_t s = src[e]; + data_t &ds = diff_src[e]; + + switch (alg_kind) { + case eltwise_relu: ds = relu_bwd(dd, s, alpha); break; + case eltwise_tanh: ds = tanh_bwd(dd, s); break; + case eltwise_elu: ds = elu_bwd(dd, s, alpha); break; + case eltwise_square: ds = square_bwd(dd, s); break; + case eltwise_abs: ds = abs_bwd(dd, s); break; + case eltwise_sqrt: ds = sqrt_bwd(dd, s); break; + case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break; + case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break; + case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break; + case eltwise_logistic: ds = logistic_bwd(dd, s); break; + default: assert(!"unknown eltwise alg_kind"); + } + }); +} + +template struct ref_eltwise_fwd_t<data_type::f32>; +template struct ref_eltwise_fwd_t<data_type::s32>; +template struct ref_eltwise_fwd_t<data_type::s8>; +template struct ref_eltwise_fwd_t<data_type::u8>; + +template struct ref_eltwise_bwd_t<data_type::f32>; +template struct ref_eltwise_bwd_t<data_type::s32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp new file mode 100644 index 0000000000..8f4ab35413 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp @@ -0,0 +1,168 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_ELTWISE_HPP +#define CPU_REF_ELTWISE_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_eltwise_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_eltwise_scalar_fwd_t { +public: + ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, float beta); + + // note that eltwise.scale is ignored + ref_eltwise_scalar_fwd_t(const post_ops_t::entry_t::eltwise_t &eltwise); + + float compute_scalar(float s); + + const alg_kind_t alg_; + const float alpha_; + const float beta_; +}; + +template <impl::data_type_t data_type> +struct ref_eltwise_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_eltwise_fwd_pd_t { + using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_eltwise_fwd_t); + + status_t init() { + using namespace utils; + + auto src_d = memory_desc_wrapper(src_md()); + + use_dense_ = false + || src_d.is_dense() + || (src_d.is_dense(true) && is_zero_preserved()); + + use_nCspBc_padded_ = !use_dense_ + && src_d.blocking_desc().inner_nblks == 1 + && one_of(src_d.blocking_desc().inner_blks[0], 8, 16) + && src_d.blocking_desc().inner_idxs[0] == 1 + && src_d.only_padded_dim(1) + && src_d.is_dense(true); + + if (has_zero_dim_memory()) + use_dense_ = use_nCspBc_padded_ = false; + + const bool use_generic = !use_dense_ && !use_nCspBc_padded_; + + bool ok = true + && is_fwd() + && everyone_is(data_type, desc()->data_desc.data_type) + && IMPLICATION(use_generic, one_of(src_d.ndims(), 4, 5)) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + return status::success; + } + + bool use_dense_, use_nCspBc_padded_; + }; + + ref_eltwise_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->use_dense_) + execute_forward_dense(ctx); + else if (pd()->use_nCspBc_padded_) + execute_forward_nCspBc_padded(ctx); + else + execute_forward_generic(ctx); + return status::success; + } + +private: + void execute_forward_nCspBc_padded(const exec_ctx_t &ctx) const; + void execute_forward_dense(const exec_ctx_t &ctx) const; + void execute_forward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t data_type> +struct ref_eltwise_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_eltwise_bwd_pd_t { + using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_eltwise_bwd_t); + + status_t init() { + using namespace utils; + + bool ok = true + && !is_fwd() + && everyone_is(data_type, + desc()->data_desc.data_type, + desc()->diff_data_desc.data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + auto diff_dst_d = memory_desc_wrapper(diff_dst_md()); + const bool same_fmt_ = diff_dst_d == memory_desc_wrapper(src_md()); + + use_dense_ = true + && same_fmt_ + && diff_dst_d.is_dense(true) + && is_zero_preserved() + && !has_zero_dim_memory(); + const bool use_generic = !use_dense_; + + if (use_generic && !one_of(diff_dst_d.ndims(), 4, 5)) + return status::unimplemented; + + return status::success; + } + + bool use_dense_; + }; + + ref_eltwise_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (pd()->use_dense_) + execute_backward_dense(ctx); + else + execute_backward_generic(ctx); + return status::success; + } + +private: + void execute_backward_dense(const exec_ctx_t &ctx) const; + void execute_backward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp new file mode 100644 index 0000000000..c807a9ffd0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp @@ -0,0 +1,285 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "mkldnn_traits.hpp" +#include "math_utils.hpp" + +#include "ref_inner_product.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using math::saturate; +using math::get_bias; + +template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type, + data_type_t acc_type> +void ref_inner_product_fwd_t<src_type, wei_type, dst_type, acc_type>:: +execute_forward(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS); + auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper bias_d(pd()->weights_md(1)); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5); + const int ndims = src_d.ndims() - 2; + + const auto &post_ops = pd()->attr()->post_ops_; + const bool do_relu = post_ops.len_ == 1; + const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f; + + auto ker_has_spatial = [=](int mb, int oc) { + acc_data_t d = 0; + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int ic = 0; ic < IC; ++ic) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + switch (ndims) { + case 3: + d += (acc_data_t)src[src_d.off(mb, ic, kd, kh, kw)] + * weights[weights_d.off( + oc, ic, kd, kh, kw)]; + break; + case 2: + d += (acc_data_t)src[src_d.off(mb, ic, kh, kw)] + * weights[weights_d.off(oc, ic, kh, kw)]; + break; + case 1: + d += (acc_data_t)src[src_d.off(mb, ic, kw)] + * weights[weights_d.off(oc, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + } + } + } + } + return d; + }; + + auto ker_no_spatial = [=](int mb, int oc) { + acc_data_t d = 0; + for (int ic = 0; ic < IC; ++ic) { + d += (acc_data_t)src[src_d.off(mb, ic)] + * weights[weights_d.off(oc, ic)]; + } + return d; + }; + + parallel_nd(MB, OC, [&](int mb, int oc) { + float a = bias + ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type) + : 0; + if (src_has_spatial) + a += ker_has_spatial(mb, oc); + else + a += ker_no_spatial(mb, oc); + if (do_relu && a < (acc_data_t)0) + a *= nslope; + dst[dst_d.off(mb, oc)] = saturate<dst_data_t>(a); + }); +} + +using namespace data_type; +template struct ref_inner_product_fwd_t<f32>; +template struct ref_inner_product_fwd_t<u8, s8, f32, s32>; +template struct ref_inner_product_fwd_t<u8, s8, s32, s32>; +template struct ref_inner_product_fwd_t<u8, s8, s8, s32>; +template struct ref_inner_product_fwd_t<u8, s8, u8, s32>; + +template <data_type_t diff_src_type, data_type_t wei_type, + data_type_t diff_dst_type, data_type_t acc_type> +void ref_inner_product_bwd_data_t<diff_src_type, wei_type, diff_dst_type, + acc_type>::execute_backward_data(const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST); + auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS); + auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper weights_d(pd()->weights_md(0)); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool diff_src_has_spatial + = utils::one_of(diff_src_d.ndims(), 3, 4, 5); + const int ndims = diff_src_d.ndims() - 2; + + parallel_nd(MB, IC, [&](int mb, int ic) { + if (diff_src_has_spatial) { + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int kd = 0; kd < KD; ++kd) + for (int kh = 0; kh < KH; ++kh) + for (int kw = 0; kw < KW; ++kw) { + acc_data_t ds = acc_data_t(0); + for (int oc = 0; oc < OC; ++oc) { + switch (ndims) { + case 3: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kd, kh, kw)]); + break; + case 2: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kh, kw)]); + break; + case 1: + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] + * weights[weights_d.off(oc, ic, kw)]); + break; + default: assert(!"unsupported ndims size"); + } + } + switch (ndims) { + case 3: + diff_src[diff_src_d.off(mb, ic, kd, kh, kw)] + = (diff_src_data_t)ds; + break; + case 2: + diff_src[diff_src_d.off(mb, ic, kh, kw)] + = (diff_src_data_t)ds; + break; + case 1: + diff_src[diff_src_d.off(mb, ic, kw)] = (diff_src_data_t)ds; + break; + default: assert(!"unsupported ndims size"); + } + } + } else { + acc_data_t ds = acc_data_t(0); + for (int oc = 0; oc < OC; ++oc) { + ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] * + weights[weights_d.off(oc, ic)]); + } + diff_src[diff_src_d.off(mb, ic)] = (diff_src_data_t)ds; + } + }); +} + +template struct ref_inner_product_bwd_data_t<f32, f32, f32, f32>; + +template <impl::data_type_t data_type> +void ref_inner_product_bwd_weights_t<data_type>::execute_backward_weights( + const exec_ctx_t &ctx) const { + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS); + auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); + const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); + + const int MB = pd()->MB(); + const int OC = pd()->OC(); + const int IC = pd()->IC(); + + const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4 ,5); + const int ndims = src_d.ndims() - 2; + + parallel_nd(OC, IC, [&](int oc, int ic) { + if (src_has_spatial) { + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + data_t *dw(nullptr); + switch (ndims) { + case 3: + dw = &diff_weights[diff_weights_d.off( + oc, ic, kd, kh, kw)]; + break; + case 2: + dw = &diff_weights[diff_weights_d.off( + oc, ic, kh, kw)]; + break; + case 1: + dw = &diff_weights[diff_weights_d.off(oc, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + *dw = data_t(0); + for (int mb = 0; mb < MB; ++mb) { + switch (ndims) { + case 3: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kd, kh, kw)]; + break; + case 2: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kh, kw)]; + break; + case 1: + *dw += diff_dst[diff_dst_d.off(mb, oc)] + * src[src_d.off(mb, ic, kw)]; + break; + default: assert(!"unsupported ndims size"); + } + } + } + } + } + } else { + data_t *dw = &diff_weights[diff_weights_d.off(oc, ic)]; + *dw = data_t(0); + for (int mb = 0; mb < MB; ++mb) { + *dw += diff_dst[diff_dst_d.off(mb, oc)] * + src[src_d.off(mb, ic)]; + } + } + }); + + if (diff_bias) { + diff_bias += diff_bias_d.offset0(); + + parallel_nd(OC, [&](int oc) { + data_t *db = &diff_bias[oc]; + *db = data_t(0); + for (int mb = 0; mb < MB; ++mb) + *db += diff_dst[diff_dst_d.off(mb, oc)]; + }); + } +} + +template struct ref_inner_product_bwd_weights_t<data_type::f32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp new file mode 100644 index 0000000000..bf87dbd514 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp @@ -0,0 +1,159 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_INNER_PRODUCT_HPP +#define CPU_REF_INNER_PRODUCT_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_inner_product_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t src_type, impl::data_type_t wei_type = src_type, + impl::data_type_t dst_type = src_type, + impl::data_type_t acc_type = dst_type> +struct ref_inner_product_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_fwd_pd_t { + using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_fwd_t); + + status_t init() { + using namespace data_type; + + bool ok = true + && set_default_params() == status::success + && is_fwd() + && src_md()->data_type == src_type + && weights_md()->data_type == wei_type + && desc()->accum_data_type == acc_type + && dst_md()->data_type == dst_type + && IMPLICATION(with_bias(), utils::one_of( + weights_md(1)->data_type, f32, s32, s8, u8)) + && attr()->output_scales_.has_default_values() + && attr()->post_ops_.len_ <= 1 + && IMPLICATION(attr()->post_ops_.len_ == 1, + attr()->post_ops_.entry_[0].is_relu(true, false)); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<wei_type>::type wei_data_t; + typedef typename prec_traits<dst_type>::type dst_data_t; + typedef typename prec_traits<acc_type>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t diff_src_type, impl::data_type_t wei_type, + impl::data_type_t diff_dst_type, + impl::data_type_t acc_type = diff_src_type> +struct ref_inner_product_bwd_data_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_data_pd_t { + using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_data + && diff_src_md()->data_type == diff_src_type + && weights_md()->data_type == wei_type + && desc()->accum_data_type == acc_type + && diff_dst_md()->data_type == diff_dst_type + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits<diff_src_type>::type diff_src_data_t; + typedef typename prec_traits<wei_type>::type wei_data_t; + typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t; + typedef typename prec_traits<acc_type>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_data(ctx); + return status::success; + } + +private: + void execute_backward_data(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t data_type> +struct ref_inner_product_bwd_weights_t: public cpu_primitive_t { + struct pd_t: public cpu_inner_product_bwd_weights_pd_t { + using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && desc()->prop_kind == prop_kind::backward_weights + && utils::everyone_is(data_type, + src_md()->data_type, + diff_dst_md()->data_type, + diff_weights_md()->data_type) + && IMPLICATION(with_bias(), + data_type == diff_weights_md(1)->data_type) + && attr()->has_default_values(); + return ok ? status::success : status::unimplemented; + } + }; + + ref_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward_weights(ctx); + return status::success; + } + +private: + void execute_backward_weights(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp new file mode 100644 index 0000000000..325e97963b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp @@ -0,0 +1,252 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <math.h> + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_lrn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +static inline float fast_negative_powf(float omega, float beta) { + float Y; +/* + * Y = omega^(-3/4) = + * = 1.0f / sqrtf(omega) * sqrtf(1.0f / sqrtf(omega)) + * = sqrtf(1.0f / sqrtf(omega)) * 1.0f / sqrtf(omega) + * = sqrtf(1.0f / sqrtf(omega)) / sqrtf(omega) + * = sqrtf(1.0f / sqrtf(omega) / omega) + * = sqrtf(1.0f / (sqrtf(omega) * omega)) + */ + if (beta == 0.75f) { + Y = sqrtf(1.0f / (sqrtf(omega) * omega)); + } else { + Y = 1.0f / powf(omega, beta); + } + return Y; +}; + +template <impl::data_type_t data_type> +template <impl::format_tag_t tag> +void ref_lrn_fwd_t<data_type>::execute_forward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace format_tag; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const size_t stride_mb = data_d.blocking_desc().strides[0]; + const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels; + constexpr int blksize = tag == nChw16c ? 16 : 8; + + auto data_off = [&](int mb, int c, int h, int w) -> size_t { + switch (tag) { + case nChw16c: + case nChw8c: return mb * stride_mb + c / blksize * H * W * blksize + + h * W * blksize + w * blksize + c % blksize; + case nchw: return mb * stride_mb + c * H * W + h * W + w; + case nhwc: return mb * stride_mb + h * W * C + w * C + c; + default: return data_d.off(mb, c, h, w); + } + }; + + auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) { + const float alpha = static_cast<float>(pd()->desc()->lrn_alpha); + const float beta = static_cast<float>(pd()->desc()->lrn_beta); + const float k = static_cast<float>(pd()->desc()->lrn_k); + + const int size = pd()->desc()->local_size; + const int half_size = (size - 1) / 2; + + float sum = 0; + if (across_channels) { + const int c_st = nstl::max(oc - half_size + 0, 0); + const int c_en = nstl::min(oc + half_size + 1, C); + + for (int c = c_st; c < c_en; ++c) { + const float s = src[data_off(mb, c, oh, ow)]; + sum += s * s; + } + } else { + int h_st = nstl::max(oh - half_size + 0, 0); + int h_en = nstl::min(oh + half_size + 1, H); + int w_st = nstl::max(ow - half_size + 0, 0); + int w_en = nstl::min(ow + half_size + 1, W); + for (int h = h_st; h < h_en; ++h) { + for (int w = w_st; w < w_en; ++w) { + const float s = src[data_off(mb, oc, h, w)]; + sum += s * s; + } + } + } + const int summands = across_channels ? size : size * size; + sum = k + alpha * sum / summands; + size_t off = data_off(mb, oc, oh, ow); + d[0] = static_cast<data_t>(src[off] * fast_negative_powf(sum, beta)); + }; + + const int MB = pd()->MB(); + if (tag == nChw16c || tag == nChw8c) { + parallel_nd(MB, utils::div_up(C, blksize), H, W, + [&](int mb, int c_blk, int h, int w) { + int c = c_blk * blksize; + const size_t off = mb * stride_mb + c * H * W + + (h * W + w) * blksize; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc) + ker(&dst[off + cc], mb, c + cc, h, w); + }); + } else if (tag == nhwc) { + parallel_nd(MB, H, W, C, + [&](int mb, int h, int w, int c) { + const size_t off = mb * stride_mb + h * W * C + w * C + c; + ker(&dst[off], mb, c, h, w); + }); + } else { + parallel_nd(MB, C, H, W, + [&](int mb, int c, int h, int w) { + const size_t off = data_off(mb, c, h, w); + ker(&dst[off], mb, c, h, w); + }); + } +} + +template <impl::data_type_t data_type> +template <mkldnn_format_tag_t tag> +void ref_lrn_bwd_t<data_type>::execute_backward(const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace format_tag; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper data_d(pd()->src_md()); + + const int MB = pd()->MB(); + const int C = pd()->C(); + const int H = pd()->H(); + const int W = pd()->W(); + const size_t stride_mb = data_d.blocking_desc().strides[0]; + constexpr int blksize = tag == nChw16c ? 16 : 8; + + const float alpha = static_cast<float>(pd()->desc()->lrn_alpha); + const float beta = static_cast<float>(pd()->desc()->lrn_beta); + const float k = static_cast<float>(pd()->desc()->lrn_k); + const int kernel_size = pd()->desc()->local_size; + const int half_ksize = (kernel_size - 1) / 2; + + auto data_off = [&](int mb, int c, int h, int w) -> size_t { + switch (tag) { + case nChw16c: + case nChw8c: return mb * stride_mb + c/blksize * H * W * blksize + + h * W * blksize + w * blksize + c%blksize; + case nchw: return mb * stride_mb + c * H * W + h * W + w; + case nhwc: return mb * stride_mb + h * W * C + w * C + c; + default: return data_d.off(mb, c, h, w); + } + }; + + auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) { + const int c_st = nstl::max(oc - half_ksize + 0, 0); + const int c_en = nstl::min(oc + half_ksize + 1, C); + + float A = 0, B = 0, omega_mid = 0; + for (int c = c_st; c < c_en; c++) { + float sum = 0.0; + const int i_st = nstl::max(c - half_ksize, 0); + const int i_en = nstl::min(c + kernel_size - half_ksize, C); + + for (int i = i_st; i < i_en; ++i) { + const float value = src[data_off(mb, i, oh, ow)]; + sum += value * value; + } + const float omega = static_cast<float>(k + sum * alpha / kernel_size); + if (c == oc) omega_mid = omega; + float t = src[data_off(mb, c, oh, ow)] + * fast_negative_powf(omega, beta); + B += 1.0f / omega * t * diff_dst[data_off(mb, c, oh, ow)]; + } + + const size_t off = data_off(mb, oc, oh, ow); + A = fast_negative_powf(omega_mid, beta) * diff_dst[off]; + B *= src[off]; + B *= (2.0f * alpha * beta) / kernel_size; + *d = static_cast<data_t>(A - B); // final cast down to data_t + }; + + if (tag == nChw16c || tag == nChw8c) { + parallel_nd(MB, utils::div_up(C, blksize), H, W, + [&](int mb, int c_blk, int h, int w) { + int c = c_blk * blksize; + const size_t off = mb * stride_mb + c * H * W + + (h * W + w) * blksize; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc) + ker(&diff_src[off + cc], mb, c + cc, h, w); + }); + } else if (tag == nhwc) { + parallel_nd(MB, H, W, C, + [&](int mb, int h, int w, int c) { + const size_t off = mb * stride_mb + h * W * C + w * C + c; + ker(&diff_src[off], mb, c, h, w); + }); + } else { + parallel_nd(MB, C, H, W, + [&](int mb, int c, int h, int w) { + const size_t off = data_off(mb, c, h, w); + ker(&diff_src[off], mb, c, h, w); + }); + } +} + +template void ref_lrn_fwd_t<data_type::f32>:: +execute_forward<format_tag::nChw16c>(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t<data_type::f32>:: +execute_forward<format_tag::nChw8c>(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t<data_type::f32>:: +execute_forward<format_tag::nchw>(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t<data_type::f32>:: +execute_forward<format_tag::nhwc>(const exec_ctx_t &ctx) const; +template void ref_lrn_fwd_t<data_type::f32>:: +execute_forward<format_tag::any>(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t<data_type::f32>:: +execute_backward<format_tag::nChw16c>(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t<data_type::f32>:: +execute_backward<format_tag::nChw8c>(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t<data_type::f32>:: +execute_backward<format_tag::nchw>(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t<data_type::f32>:: +execute_backward<format_tag::nhwc>(const exec_ctx_t &ctx) const; +template void ref_lrn_bwd_t<data_type::f32>:: +execute_backward<format_tag::any>(const exec_ctx_t &ctx) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp new file mode 100644 index 0000000000..f25cfb7fae --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp @@ -0,0 +1,136 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_LRN_HPP +#define CPU_REF_LRN_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_lrn_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t data_type> +struct ref_lrn_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_fwd_pd_t { + using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_lrn_fwd_t); + + status_t init() { + using namespace format_tag; + + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag( + *src_md(), nChw16c, nChw8c, nchw, nhwc); + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_lrn_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nChw16c: execute_forward<nChw16c>(ctx); break; + case nChw8c: execute_forward<nChw8c>(ctx); break; + case nchw: execute_forward<nchw>(ctx); break; + case nhwc: execute_forward<nhwc>(ctx); break; + default: execute_forward<any>(ctx); + } + return status::success; + } + +private: + template<format_tag_t tag> + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t data_type> +struct ref_lrn_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_lrn_bwd_pd_t { + using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_lrn_bwd_t); + + status_t init() { + using namespace format_tag; + using namespace alg_kind; + + bool ok = true + && !is_fwd() + && utils::one_of(desc()->alg_kind, lrn_across_channels + /*, lrn_within_channel */) // not supported yet + && utils::everyone_is(data_type, + src_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + dat_tag_ = memory_desc_matches_one_of_tag( + *src_md(), nChw16c, nChw8c, nchw, nhwc); + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_lrn_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nChw16c: execute_backward<nChw16c>(ctx); break; + case nChw8c: execute_backward<nChw8c>(ctx); break; + case nchw: execute_backward<nchw>(ctx); break; + case nhwc: execute_backward<nhwc>(ctx); break; + default: execute_backward<any>(ctx); + } + return status::success; + } + +private: + template<format_tag_t tag> + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp new file mode 100644 index 0000000000..65b934e123 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp @@ -0,0 +1,381 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <math.h> + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" + +#include "ref_pooling.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <data_type_t data_type, data_type_t acc_type> +void ref_pooling_fwd_t<data_type, acc_type>::execute_forward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + using namespace prop_kind; + + auto alg = pd()->desc()->alg_kind; + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper ws_d(pd()->workspace_md()); + const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef; + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->src_desc.ndims == 5; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto set_ws = [=](int mb, int oc, int od, int oh, int ow, int value) { + if (ws) { + assert(ws_dt == data_type::u8 || ws_dt == data_type::s32); + size_t offset = is_3d + ? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);; + if (ws_dt == data_type::u8) { + assert(0 <= value && value <= 255); + ws[offset] = value; + } else + reinterpret_cast<int *>(ws)[offset] = value; + } + }; + + auto ker_max = [=](data_t *d, int mb, int oc, int oh, int ow) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto s = src[src_d.off(mb, oc, ih, iw)]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, oc, 1, oh, ow, kh*KW + kw); + } + } + } + }; + + auto ker_avg = [=](data_t *d, int mb, int oc, int oh, int ow) { + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH + : (ih_end - ih_start)*(iw_end - iw_start); + + acc_data_t dst = 0; + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + dst += src[src_d.off(mb, oc, ih, iw)]; + } + } + + d[0] = math::out_round<data_t>((float)dst / num_summands); + }; + + auto ker_max_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) { + for (int kd = 0; kd < KD; ++kd) { + for (int kh = 0; kh < KH; ++kh) { + for (int kw = 0; kw < KW; ++kw) { + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + if (id < 0 || id >= ID) continue; + if (ih < 0 || ih >= IH) continue; + if (iw < 0 || iw >= IW) continue; + + auto s = src[src_d.off(mb, oc, id, ih, iw)]; + if (s > d[0]) { + d[0] = s; + set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw); + } + } + } + } + }; + + auto ker_avg_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + acc_data_t dst = 0; + for (int id = id_start; id < id_end; ++id) { + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + dst += src[src_d.off(mb, oc, id, ih, iw)]; + } + } + } + + d[0] = math::out_round<data_t>((float)dst / num_summands); + }; + + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + if (alg == pooling_max) { + parallel_nd(MB, OC, OD, OH, OW, + [&](int mb, int oc, int od, int oh, int ow) { + data_t *d = is_3d + ? &dst[dst_d.off(mb, oc, od, oh, ow)] + : &dst[dst_d.off(mb, oc, oh, ow)]; + d[0] = nstl::numeric_limits<data_t>::lowest(); + set_ws(mb, oc, od, oh, ow, 0); + if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow); + else ker_max(d, mb, oc, oh, ow); + }); + } else { + parallel_nd(MB, OC, OD, OH, OW, + [&](int mb, int oc, int od, int oh, int ow) { + data_t *d = is_3d + ? &dst[dst_d.off(mb, oc, od, oh, ow)] + : &dst[dst_d.off(mb, oc, oh, ow)]; + d[0] = 0; + if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow); + else ker_avg(d, mb, oc, oh, ow); + }); + } +} + +template <data_type_t data_type, data_type_t acc_type> +void ref_pooling_bwd_t<data_type, acc_type>::execute_backward( + const exec_ctx_t &ctx) const { + using namespace alg_kind; + + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); + const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); + const memory_desc_wrapper ws_d(pd()->workspace_md()); + + const int ID = pd()->ID(); + const int IH = pd()->IH(); + const int IW = pd()->IW(); + const int KD = pd()->KD(); + const int KH = pd()->KH(); + const int KW = pd()->KW(); + const int SD = pd()->KSD(); + const int SH = pd()->KSH(); + const int SW = pd()->KSW(); + const int padF = pd()->padFront(); + const int padT = pd()->padT(); + const int padL = pd()->padL(); + + const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5; + + auto alg = pd()->desc()->alg_kind; + + auto apply_offset = [=](int index, int offset) { + return (index > offset) ? index - offset : 0; + }; + + auto ker_zero = [=](int _mb, int _oc) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_d.off(_mb, _oc, ih, iw)] = data_type_t(0); + } + } + }; + + auto ker_max = [=](const data_t *d, int mb, int oc, int oh, int ow) { + const size_t ws_off = ws_d.off(mb, oc, oh, ow); + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_off] : ((int *)ws)[ws_off]; + const int kw = index % KW; + const int kh = index / KW; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0]; + }; + + auto ker_avg = [=](const data_t *d, int mb, int oc, int oh, int ow) { + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH + : (ih_end - ih_start)*(iw_end - iw_start); + + for (int ih = ih_start; ih < ih_end; ++ih) { + for (int iw = iw_start; iw < iw_end; ++iw) { + diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0] / num_summands; + } + } + }; + + auto ker_zero_3d = [=](int _mb, int _oc) { + for (int id = 0; id < ID; ++id) { + for (int ih = 0; ih < IH; ++ih) { + for (int iw = 0; iw < IW; ++iw) { + diff_src[diff_src_d.off(_mb, _oc, id, ih, iw)] = + data_type_t(0); + } + } + } + }; + + auto ker_max_3d = [=](const data_t *d, int mb, int oc, int od, int oh, + int ow) { + const size_t ws_off = ws_d.off(mb, oc, od, oh, ow); + const int index = ws_d.data_type() == data_type::u8 + ? (int)ws[ws_off] : ((int *)ws)[ws_off]; + const int kw = index % KW; + const int kh = (index / KW) % KH; + const int kd = (index / KW) / KH; + const int id = od * SD - padF + kd; + const int ih = oh * SH - padT + kh; + const int iw = ow * SW - padL + kw; + + // If padding area could fit the kernel, + // then input displacement would be out of bounds. + // No need to back propagate there as padding is + // virtual in pooling_max case. + if (id < 0 || id >= ID) + return; + if (ih < 0 || ih >= IH) + return; + if (iw < 0 || iw >= IW) + return; + + diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0]; + }; + + auto ker_avg_3d = [=](const data_t *d, int mb, int oc, int od, int oh, + int ow) { + auto id_start = apply_offset(od*SD, padF); + auto ih_start = apply_offset(oh*SH, padT); + auto iw_start = apply_offset(ow*SW, padL); + auto id_end = nstl::min(od*SD - padF + KD, ID); + auto ih_end = nstl::min(oh*SH - padT + KH, IH); + auto iw_end = nstl::min(ow*SW - padL + KW, IW); + + auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD + : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start); + + for (int id = id_start; id < id_end; ++id) + for (int ih = ih_start; ih < ih_end; ++ih) + for (int iw = iw_start; iw < iw_end; ++iw) { + diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0] / num_summands; + } + }; + + const int MB = pd()->MB(); + const int OC = pd()->C(); + const int OD = pd()->OD(); + const int OH = pd()->OH(); + const int OW = pd()->OW(); + + if (pd()->desc()->alg_kind == alg_kind::pooling_max) { + parallel_nd(MB, OC, [&](int mb, int oc) { + if (is_3d) ker_zero_3d(mb, oc); + else ker_zero(mb, oc); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = is_3d + ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)] + : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)]; + if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow); + else ker_max(d, mb, oc, oh, ow); + } + } + } + }); + } else { + parallel_nd(MB, OC, [&](int mb, int oc) { + if (is_3d) ker_zero_3d(mb, oc); + else ker_zero(mb, oc); + for (int od = 0; od < OD; ++od) { + for (int oh = 0; oh < OH; ++oh) { + for (int ow = 0; ow < OW; ++ow) { + const data_t *d = is_3d + ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)] + : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)]; + if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow); + else ker_avg(d, mb, oc, oh, ow); + } + } + } + }); + } +} + +template struct ref_pooling_fwd_t<data_type::f32>; +template struct ref_pooling_fwd_t<data_type::s32>; +template struct ref_pooling_fwd_t<data_type::s8, data_type::s32>; +template struct ref_pooling_fwd_t<data_type::u8, data_type::s32>; + +template struct ref_pooling_bwd_t<data_type::f32>; +template struct ref_pooling_bwd_t<data_type::s32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp new file mode 100644 index 0000000000..e43ceaa82b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp @@ -0,0 +1,119 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_POOLING_HPP +#define CPU_REF_POOLING_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_pooling_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t data_type, impl::data_type_t acc_type = data_type> +struct ref_pooling_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_fwd_pd_t { + using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_pooling_fwd_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && is_fwd() + && utils::everyone_is(data_type, src_md()->data_type, + dst_md()->data_type) + && desc()->accum_data_type == acc_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + bool is_training = desc_.prop_kind == prop_kind::forward_training; + if (desc()->alg_kind == alg_kind::pooling_max && is_training) + init_default_ws(); + + return status::success; + } + }; + + ref_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + + typedef typename prec_traits<data_type>::type data_t; + typedef typename prec_traits<acc_type>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_forward(ctx); + return status::success; + } + +private: + void execute_forward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <impl::data_type_t data_type, impl::data_type_t acc_type = data_type> +struct ref_pooling_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_pooling_bwd_pd_t { + using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_pooling_bwd_t); + + status_t init() { + bool ok = true + && set_default_params() == status::success + && !is_fwd() + && utils::everyone_is(data_type, diff_dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + if (desc()->alg_kind == alg_kind::pooling_max) { + init_default_ws(); + if (!compare_ws(hint_fwd_pd_)) + return status::unimplemented; + } + + return status::success; + } + }; + + ref_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {} + typedef typename prec_traits<data_type>::type data_t; + typedef typename prec_traits<acc_type>::type acc_data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_backward(ctx); + return status::success; + } + +private: + void execute_backward(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp new file mode 100644 index 0000000000..af27743110 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp @@ -0,0 +1,153 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <math.h> + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_shuffle.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace format_tag; + +template <int data_type_size> +template <mkldnn_format_tag_t tag> +void ref_shuffle_t<data_type_size>::execute_(const exec_ctx_t &ctx) const { + using namespace prop_kind; + using namespace utils; + + const memory_desc_wrapper data_d(pd()->data_md()); + + auto i_arg = pd()->is_fwd() ? MKLDNN_ARG_SRC : MKLDNN_ARG_DIFF_DST; + auto o_arg = pd()->is_fwd() ? MKLDNN_ARG_DST : MKLDNN_ARG_DIFF_SRC; + auto input = CTX_IN_MEM(const data_t *, i_arg); + auto output = CTX_OUT_MEM(data_t *, o_arg); + + const int axis = pd()->axis(); + const int axis_size = pd()->axis_size(); + + const int MB = pd()->MB(); + const int C = pd()->C(); + int H = 1, W = 1, D = 1, HW = 1, SP = 1; + const bool has_spatial = utils::one_of(data_d.ndims(), 3, 4 ,5); + if (has_spatial) + { + D = pd()->D(); + H = pd()->H(); + W = pd()->W(); + HW = H * W; + SP = D * HW; + } + const size_t stride_mb = data_d.blocking_desc().strides[0]; + constexpr int blksize = one_of(tag, nChw16c, nCdhw16c) ? 16 : 8; + + if (axis == 1 && one_of(tag, nChw16c, nChw8c, nCdhw16c, nCdhw16c)) { +#if MKLDNN_THR == MKLDNN_THR_OMP +# pragma omp parallel for collapse(3) schedule(static) + for (int mb = 0; mb < MB; ++mb) + for (int cb = 0; cb < C; cb += blksize) + for (int sp = 0; sp < SP; ++sp) { + const size_t off = mb * stride_mb + sp * blksize; + const size_t output_off = off + cb * SP; + PRAGMA_OMP_SIMD() + for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc) + { + int input_c = rev_transposed_[cb + cc]; + const size_t input_off = off + input_c / blksize * SP * blksize + + input_c % blksize; + output[output_off + cc] = input[input_off]; + } + } +#else + parallel_nd(MB, utils::div_up(C, blksize), SP, [&](int mb, int c, + int sp) { + const size_t off = mb * stride_mb + sp * blksize; + const int cb = c * blksize; + const size_t output_off = off + cb * SP; + for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc) + { + int input_c = rev_transposed_[cb + cc]; + const size_t input_off = off + input_c / blksize * SP * blksize + + input_c % blksize; + output[output_off + cc] = input[input_off]; + } + }); +#endif + } else if (axis == 1 && one_of(tag, nhwc, ndhwc)) { + parallel_nd(MB, SP, [&](int mb, int sp) { + const size_t off = mb * stride_mb + sp * C; + PRAGMA_OMP_SIMD() + for (int c = 0; c < C; ++c) + output[off + c] = input[off + rev_transposed_[c]]; + }); + } else if (axis == 1 && one_of(tag, nchw, ncdhw)) { + parallel_nd(MB, C, [&](int mb, int c) { + const size_t output_off = mb * stride_mb + c * SP; + const size_t input_off = mb * stride_mb + rev_transposed_[c] * SP; + PRAGMA_OMP_SIMD() + for (int sp = 0; sp < SP; ++sp) { + output[output_off + sp] = input[input_off + sp]; + } + }); + } else { + auto dims = pd()->desc()->data_desc.dims; + auto ndims = pd()->desc()->data_desc.ndims; + const size_t outer_size = utils::array_product(dims, axis); + const size_t inner_size = utils::array_product(dims + axis + 1, + ndims - axis - 1); + const size_t dim = axis_size * inner_size; + + parallel_nd(outer_size, axis_size, inner_size, [&](size_t ou, int a, + size_t in) + { + const size_t off = ou * dim + in; + auto &o = output[data_d.off_l(off + a * inner_size)]; + o = input[data_d.off_l(off + rev_transposed_[a] * inner_size)]; + }); + } +} + +template void ref_shuffle_t<4>::execute_<nCdhw16c>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_<nChw16c>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_<nCdhw8c>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_<nChw8c>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_<ncdhw>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_<nchw>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_<ndhwc>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_<nhwc>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<4>::execute_<any>(const exec_ctx_t &ctx) const; + +template void ref_shuffle_t<1>::execute_<nCdhw16c>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_<nChw16c>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_<nCdhw8c>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_<nChw8c>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_<ncdhw>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_<nchw>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_<ndhwc>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_<nhwc>(const exec_ctx_t &ctx) const; +template void ref_shuffle_t<1>::execute_<any>(const exec_ctx_t &ctx) const; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp new file mode 100644 index 0000000000..5e09a1a69b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp @@ -0,0 +1,111 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_SHUFFLE_HPP +#define CPU_REF_SHUFFLE_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_shuffle_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template<int data_type_size> +struct ref_shuffle_t : public cpu_primitive_t { + using shuffle_class = ref_shuffle_t<data_type_size>; + + struct pd_t: public cpu_shuffle_pd_t { + using cpu_shuffle_pd_t::cpu_shuffle_pd_t; + + DECLARE_COMMON_PD_T("ref:any", shuffle_class); + + status_t init() { + using namespace format_tag; + + bool ok = true + && data_type_size + == types::data_type_size(data_md()->data_type); + if (!ok) return status::unimplemented; + + if (ndims() == 5) { + dat_tag_ = memory_desc_matches_one_of_tag( + *data_md(), nCdhw16c, nCdhw8c, ncdhw, ndhwc); + } else if (ndims() == 4) { + dat_tag_ = memory_desc_matches_one_of_tag( + *data_md(), nChw16c, nChw8c, nchw, nhwc); + } else + dat_tag_ = any; + + return status::success; + } + + format_tag_t dat_tag_; + }; + + ref_shuffle_t(const pd_t *apd): cpu_primitive_t(apd) { + const int axis_size = pd()->axis_size(); + const int group_size = pd()->group_size(); + const int transpose_row = pd()->is_fwd() ? group_size + : axis_size / group_size; + const int transpose_col = pd()->is_fwd() ? axis_size / group_size + : group_size; + rev_transposed_ = (int *)malloc(axis_size * sizeof(int), 64); + parallel_nd(transpose_col, transpose_row, [&](int i, int j) { + rev_transposed_[j * transpose_col + i] = i * transpose_row + j; + }); + } + + ~ref_shuffle_t() { free(rev_transposed_); } + + typedef typename typesize_traits<data_type_size>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + using namespace format_tag; + switch (pd()->dat_tag_) { + case nCdhw16c: execute_<nCdhw16c>(ctx); break; + case nChw16c: execute_<nChw16c>(ctx); break; + case nCdhw8c: execute_<nCdhw8c>(ctx); break; + case nChw8c: execute_<nChw8c>(ctx); break; + case ncdhw: execute_<ncdhw>(ctx); break; + case nchw: execute_<nchw>(ctx); break; + case ndhwc: execute_<ndhwc>(ctx); break; + case nhwc: execute_<nhwc>(ctx); break; + default: execute_<any>(ctx); break; + } + return status::success; + } + +private: + template<format_tag_t tag> + void execute_(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + int *rev_transposed_; +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp new file mode 100644 index 0000000000..36d5237f56 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <float.h> +#include <math.h> + +#include "c_types_map.hpp" +#include "mkldnn_thread.hpp" +#include "type_helpers.hpp" + +#include "ref_softmax.hpp" +#include "gemm/os_blas.hpp" + +#ifdef USE_MKL +#include "mkl_vml_functions.h" +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t data_type> +void ref_softmax_fwd_t<data_type>::execute_forward_dense( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + parallel_nd(outer_size_, [&](int ou) { + const data_t *src_data = src + ou * channels_; + data_t *dst_data = dst + ou * channels_; + data_t scalar = 0; + + _max(channels_, src_data, &scalar); + _sub(channels_, scalar, src_data, dst_data); + _exp(channels_, dst_data, dst_data); + _sum(channels_, dst_data, &scalar); + _scal(channels_, data_t(1)/scalar, dst_data); + }); +} + +template <impl::data_type_t data_type> +void ref_softmax_fwd_t<data_type>::execute_forward_generic( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + data_t space_max_val = 0, space_denom_val = 0; + data_t *space_max = &space_max_val, *space_denom = &space_denom_val; + if (inner_size_ > 1) { + using namespace memory_tracking::names; + space_max = scratchpad(ctx).template get<data_t>(key_softmax_reduction); + space_denom = space_max + inner_size_; + } + + const memory_desc_wrapper data_d(pd()->src_md()); + const size_t dim = channels_ * inner_size_; + + for (int ou = 0; ou < outer_size_; ou++) { + utils::array_set(space_max, -FLT_MAX, inner_size_); + utils::array_set(space_denom, 0, inner_size_); + + for (int c = 0; c < channels_; c++) { + for(int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + space_max[in] = nstl::max(space_max[in], src[off]); + } + } + + for (int c = 0; c < channels_; c++) { + for(int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + space_denom[in] += dst[off] = exp(src[off] - space_max[in]); + } + } + + for (int c = 0; c < channels_; c++) { + for (int in = 0; in < inner_size_; in++) { + size_t off = data_d.off_l(ou * dim + c * inner_size_ + in); + dst[off] /= space_denom[in]; + } + } + } +} + +template <impl::data_type_t data_type> +void ref_softmax_fwd_t<data_type>::_max(int n, const data_t *x, + data_t *max_data) const { +// Intel(R) C++ Compiler generates the maxps + shuffle pattern +// for the max search which works faster +#if !defined(__INTEL_COMPILER) + // The code below makes a compiler to generate maxps instruction + // rather than maxss, which is generated for the 'else' code path + auto max_wrapper = [](data_t a, data_t b) { return nstl::max(a, b); }; + auto min_wrapper = [](int a, int b) { return nstl::min(a, b); }; + + constexpr int unroll_factor = 32; + data_t max_values[unroll_factor]; + + if (n < unroll_factor) { + data_t max_val = x[0]; + for (int i = 1; i < n; i++) { + max_val = max_wrapper(max_val, x[i]); + } + max_data[0] = max_val; + return; + } + for (int i = 0; i < unroll_factor; i++) { + max_values[i] = x[i]; + } + for (int i = unroll_factor; i < n; i += unroll_factor) { + int offset = min_wrapper(i, n - unroll_factor); + for (int j = 0; j < unroll_factor; j++) { + max_values[j] = max_wrapper(max_values[j], x[offset + j]); + } + } + data_t max_val = max_values[0]; + for (int i = 1; i < unroll_factor; i++) { + max_val = max_wrapper(max_val, max_values[i]); + } + max_data[0] = max_val; +#else + max_data[0] = x[0]; + for (int c = 1; c < n; ++c) + max_data[0] = nstl::max(max_data[0], x[c]); +#endif +} + +template <impl::data_type_t data_type> +void ref_softmax_fwd_t<data_type>::_sub(int n, data_t alpha, const data_t *x, + data_t *y) const { + constexpr int unroll_factor = 32; + int tail = n % unroll_factor; + for (int i = 0; i < n - tail; i += unroll_factor) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < unroll_factor; j++) { + y[i + j] = x[i + j] - alpha; + } + } + PRAGMA_OMP_SIMD() + for (int i = n - tail; i < n; i++) { + y[i] = x[i] - alpha; + } +} + +template <impl::data_type_t data_type> +void ref_softmax_fwd_t<data_type>::_exp(int n, const data_t *a, + data_t *r) const { +#ifdef USE_MKL + if (data_type == data_type::f32) { + vsExp(n, a, r); + return; + } +#endif + parallel_nd(n, [&](int c) { r[c] = expf(a[c]); }); +} + +template <impl::data_type_t data_type> +void ref_softmax_fwd_t<data_type>::_sum(int n, const data_t *x, + data_t *sum_data) const { +#ifdef USE_CBLAS + // Here we are summing x's eg. e^z , which are positives + // so we can use BLAS ASUM + if (data_type == data_type::f32) { + sum_data[0] = cblas_sasum(n, x, 1); + return; + } +#endif + data_t tsum = static_cast<data_t>(0); + PRAGMA_OMP_SIMD(reduction(+ : tsum)) + for (int c = 0; c < n; ++c) + tsum += x[c]; + sum_data[0] = tsum; +} + +template <impl::data_type_t data_type> +void ref_softmax_fwd_t<data_type>::_scal(int n, data_t alpha, data_t *x) const { +#ifdef USE_CBLAS + if (data_type == data_type::f32) { + cblas_sscal(n, alpha, x, 1); + return; + } +#endif + parallel_nd(n, [&](int c) { x[c] *= alpha; }); +} + +template struct ref_softmax_fwd_t<data_type::f32>; + + +// NC/NCHW softmax for along final axe (1 for NC, 3 for NCHW) +template <impl::data_type_t data_type> +void ref_softmax_bwd_t<data_type>::execute_backward_dense( + const exec_ctx_t &ctx) const { + auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + parallel_nd(outer_size_, [&](int ou) { + data_t sbr = 0; + size_t off = channels_*ou; + for (int c = 0; c < channels_; c++) { + size_t loff = off + c; + data_t ldata = dst[loff]; + sbr += diff_dst[loff]*ldata; + diff_src[loff] = ldata; + } + + for(int c=0; c < channels_ ; ++c) { + size_t loff = off + c; + diff_src[loff] *= (diff_dst[loff] - sbr); + } + }); +} + +template <impl::data_type_t data_type> +void ref_softmax_bwd_t<data_type>::execute_backward_generic( + const exec_ctx_t &ctx) const { + auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + + const memory_desc_wrapper diff_d(pd()->diff_src_md()); + const memory_desc_wrapper data_d(pd()->dst_md()); + + const size_t dim = channels_ * inner_size_; + + parallel_nd(outer_size_, [&](int ou) { + for (int in = 0; in < inner_size_; in++) { + data_t sbr = 0; + for (int c = 0; c < channels_; c++) { + size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in); + size_t off_data = diff_d.off_l(ou * dim + c * inner_size_ + in); + sbr += diff_dst[off_diff] * dst[off_data]; + } + + for(int c=0; c < channels_ ; ++c) { + size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in); + size_t off_data = data_d.off_l(ou * dim + c * inner_size_ + in); + diff_src[off_diff] = dst[off_data] * (diff_dst[off_diff] - sbr); + } + } + }); +} + +template struct ref_softmax_bwd_t<data_type::f32>; + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp new file mode 100644 index 0000000000..5cb74d8007 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp @@ -0,0 +1,186 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_SOFTMAX_HPP +#define CPU_REF_SOFTMAX_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "cpu_softmax_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <impl::data_type_t data_type> +struct ref_softmax_fwd_t: public cpu_primitive_t { + struct pd_t: public cpu_softmax_fwd_pd_t { + using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_softmax_fwd_t); + + status_t init() { + bool ok = true + && is_fwd() + && src_md()->data_type == data_type + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + const int inner_size = utils::array_product( + desc()->data_desc.dims + desc()->softmax_axis + 1, + desc()->data_desc.ndims - desc()->softmax_axis - 1); + + if (inner_size > 1) { + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(memory_tracking::names::key_softmax_reduction, + sizeof(data_t) * 2 * inner_size); + } + } + }; + + ref_softmax_fwd_t(const pd_t *apd): cpu_primitive_t(apd) + { + auto ndims = pd()->desc()->data_desc.ndims; + auto dims = pd()->desc()->data_desc.dims; + auto axis = pd()->desc()->softmax_axis; + + outer_size_ = utils::array_product(dims, axis); + channels_ = dims[axis]; + inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1); + + const memory_desc_wrapper data_d(pd()->src_md()); + + bool no_axis_blocking = true; + for (int iblk = 0; iblk < data_d.blocking_desc().inner_nblks; ++iblk) + if (data_d.blocking_desc().inner_idxs[iblk] == axis) + no_axis_blocking = false; + + use_dense_ = inner_size_ == 1 && data_d.is_dense() + && no_axis_blocking + && data_d.blocking_desc().strides[axis] == 1; + } + + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (use_dense_) + execute_forward_dense(ctx); + else + execute_forward_generic(ctx); + return status::success; + } + +private: + void execute_forward_dense(const exec_ctx_t &ctx) const; + void execute_forward_generic(const exec_ctx_t &ctx) const; + + void _max(int n, const data_t *x, data_t *max_data) const; + void _sub(int n, data_t alpha, const data_t *x, data_t *y) const; + void _exp(int n, const data_t *a, data_t *r) const; + void _sum(int n, const data_t *x, data_t *sum_data) const; + void _scal(int n, data_t alpha, data_t *x) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + bool use_dense_; + int outer_size_, channels_, inner_size_; +}; + +template <impl::data_type_t data_type> +struct ref_softmax_bwd_t: public cpu_primitive_t { + struct pd_t: public cpu_softmax_bwd_pd_t { + using cpu_softmax_bwd_pd_t::cpu_softmax_bwd_pd_t; + + DECLARE_COMMON_PD_T("ref:any", ref_softmax_bwd_t); + + status_t init() { + bool ok = true + && !is_fwd() + && utils::everyone_is(data_type, + dst_md()->data_type, + diff_src_md()->data_type) + && attr()->has_default_values(); + if (!ok) return status::unimplemented; + + return status::success; + } + }; + + ref_softmax_bwd_t(const pd_t *apd): cpu_primitive_t(apd) { + auto dims = pd()->desc()->diff_desc.dims; + auto axis = pd()->desc()->softmax_axis; + auto ndims = pd()->desc()->diff_desc.ndims; + + outer_size_ = utils::array_product(dims, axis); + channels_ = dims[axis]; + inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1); + + const memory_desc_wrapper data_d(pd()->dst_md()); + const memory_desc_wrapper diff_d(pd()->diff_dst_md()); + + bool no_axis_blocking = true; + for (int iblk = 0; iblk < diff_d.blocking_desc().inner_nblks; ++iblk) + if (diff_d.blocking_desc().inner_idxs[iblk] == axis) + no_axis_blocking = false; + + use_dense_ = true + && inner_size_ == 1 + && diff_d == data_d + && diff_d.is_dense() + && no_axis_blocking + && diff_d.blocking_desc().strides[axis] == 1; + } + + typedef typename prec_traits<data_type>::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + if (use_dense_) + execute_backward_dense(ctx); + else + execute_backward_generic(ctx); + return status::success; + } + +private: + void execute_backward_dense(const exec_ctx_t &ctx) const; + void execute_backward_generic(const exec_ctx_t &ctx) const; + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + bool use_dense_; + int outer_size_, channels_, inner_size_; +}; + + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp new file mode 100644 index 0000000000..3b2a75d99b --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp @@ -0,0 +1,101 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef REF_SUM_HPP +#define REF_SUM_HPP + +#include "reorder_pd.hpp" + +#include "cpu_sum_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct ref_sum_t: public cpu_primitive_t { + struct pd_t: public cpu_sum_pd_t { + using cpu_sum_pd_t::cpu_sum_pd_t; + + pd_t(const pd_t &rhs): cpu_sum_pd_t(rhs) { + for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i) + reorder_pds_.push_back( + (const reorder_pd_t *)rhs.reorder_pds_[i]->clone()); + } + + ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; } + + DECLARE_SUM_PD_T("ref:any", ref_sum_t); + + status_t init() { + bool ok = cpu_sum_pd_t::init() == status::success; + if (!ok) return status::unimplemented; + + for (int i = 0; i < n_; ++i) { + auto r_impls = engine_->get_reorder_implementation_list(); + for (auto r = r_impls; *r; ++r) { + primitive_attr_t attr; + attr.output_scales_.set(scales_[i]); + if (i != 0) attr.post_ops_.append_sum(1.0); + + reorder_pd_t *r_pd; + if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i), + engine_, dst_md()) == status::success) { + r_pd->init_info(); + reorder_pds_.push_back(r_pd); + break; + } + } + } + + ok = reorder_pds_.size() == (size_t)n_; + return ok ? status::success : status::unimplemented; + } + + nstl::vector<const reorder_pd_t *> reorder_pds_; + }; + + ref_sum_t(const pd_t *apd): cpu_primitive_t(apd) { + const int n = pd()->n_inputs(); + reorders_.resize(n); + for (int i = 0; i < n; ++i) + pd()->reorder_pds_[i]->create_primitive(&reorders_[i]); + } + + ~ref_sum_t() { for (auto &r: reorders_) delete r; } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + const auto n = pd()->n_inputs(); + for (int i = 0; i < n; ++i) { + exec_args_t r_args; + r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i); + r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST); + exec_ctx_t r_ctx(ctx.stream(), std::move(r_args)); + reorders_[i]->execute(r_ctx); + } + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + nstl::vector<primitive_t *> reorders_; +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp new file mode 100644 index 0000000000..537084db91 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Common for RNN and LSTM cell execution + */ +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { +using namespace rnn_utils; + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_cell_execution_sig( + (_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execution)) { + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, + 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); + + if (rnn_postgemm_ != nullptr) + rnn_postgemm_->execute<src_data_t, acc_data_t>(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + else + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); +} +template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); +template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution) { + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + + /// bwd by data on the cell + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, + 1.0, w_iter_[0], rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, + 0.0, diff_states_t_l_, rnn.states_ws_ld); + + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); + + /// bwd by weights on the cell + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + } + + if (!rnn.merge_gemm_iter) + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, + diff_w_iter_, rnn.diff_weights_iter_ld); + + /// bwd by bias we just accumulate diffs from the gates + gates_reduction(rnn, ws_gates_, diff_bias_); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp new file mode 100644 index 0000000000..e1a61d4c62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp @@ -0,0 +1,180 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution GRU + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +#define AOC array_offset_calculator +template <> +rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_[0]); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + + // 1. gemm Wx[0-2],x + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + + // 2. gemm Wh[0-1],h + (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb, + rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld); + + // 3. activation zt and rt + elemwise multiplication rt,ht-1 + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j); + } + }); + + // 4. gemm Wh[2],h~t + (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1], + rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0, + &(ws_gates(0, 2, 0)), rnn.gates_ws_ld); + + // 5. activation h~t + calculate ht + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); + } + }); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) { + assert(!"GRU int8 is not supported"); +} + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + // use state memory for intermediate computations + // TODO: use cell ws for that + float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0)); + float *hG1_ = dhG1_; + AOC<float, 2> dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld); + AOC<float, 2> hG1(hG1_, rnn.states_nld, rnn.states_ws_ld); + + // 1. calculate dG2, dG1, and part of dht-1 + // dG2^ = dh * (1 - G0) * (1 - G2^2) + // dG0^ = dh * (ht-1 - G2) * u * (1 - G0) + // dht-1 (part) = dh * G0 + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt + * one_m_square(ws_gates(i, 2, j)); + float dG0 = (h - ws_gates(i, 2, j)) * dHt + * x_m_square(ws_gates(i, 0, j)); + + diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); + ws_gates(i, 0, j) = dG0; + ws_gates(i, 2, j) = dG2; + } + }); + + // 2. calculate intermediate d(hG1) + // d(hG1) = dG2 * W2h^t + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1], + rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0, + dhG1_, rnn.states_ws_ld); + + // 3. calculate dG1^ and part of dht-1 + // dG1^ = d(hG1) * h * G1 * (1 - G1) + // dht-1 (part) += d(hG1) * G1 + // h * G1 (required for dWh) + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float G1 = ws_gates(i, 1, j); + diff_states_t_l(0, i, j) += dhG1(i, j) * G1; + ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1); + hG1(i, j) = G1 * h; + } + }); + + // 4. calculate diff weights + // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h) + gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, + rnn.diff_weights_iter_ld); + gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)), + rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0, + &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld); + + // 5. calculate diff states + // dht-1 += dG1 * W1h + dG0 * W0h + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, + (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0], + rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0, + diff_states_t_l_, rnn.states_ws_ld); + + if (!rnn.merge_gemm_layer) { + // dWx += [dG0 dG1 dG2] * [x] + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld); + } + + // 6. calculate diff bias + gates_reduction(rnn, ws_gates_, diff_bias_); +} +#undef AOC + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp new file mode 100644 index 0000000000..8dea8c90a4 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution GRU with linear before reset + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; +#define AOC array_offset_calculator + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_gates_aoc_t ws_gemm_state(rnn, ws_cell_); + AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j); + ws_gates(i, 0, j) = logistic_fwd( + ws_gates(i, 0, j) + ws_gemm_state(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd( + ws_gates(i, 1, j) + ws_gemm_state(i, 1, j) + bias(1, j)); + ws_gates(i, 2, j) = tanh_fwd( + ws_gates(i, 2, j) + ws_gates(i, 1, j) * Wh_b + bias(2, j)); + states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j) + + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j); + if (rnn.is_training) + ws_Wh_b(i, j) = Wh_b; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise) { + assert(!"GRU LBR int8 is not supported"); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr) { + if (!rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, + rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, + states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_, + rnn.gates_ws_ld); + } + (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic, + 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_, + rnn.states_ws_ld, 0.0, ws_cell_, rnn.gates_ws_ld); + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); +} + +template <> +rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) { + assert(!"GRU LBR int8 is not supported"); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); + AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic); + + // 1. calculate dG1 dG2 dG3 + // dG0 = (dht - G2) * dht * (1 - G0) * G0 + // dG1 = (W*h + b) * dG2 * (1 - G1) * G1 + // dG2 = (1 - G0) * dht * (1 - G2*G2) + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float h = states_tm1_l(i, j); + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dG0 = (h - ws_gates(i, 2, j)) * dHt + * x_m_square(ws_gates(i, 0, j)); + float dG2 = (1.0f - ws_gates(i, 0, j)) + * one_m_square(ws_gates(i, 2, j)) * dHt; + float dG1 = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j)); + + diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j); + ws_gates(i, 2, j) = dG2; + ws_gates_r(i, 2, j) = dG2 * ws_gates(i, 1, j); + ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0; + ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1; + } + }); +} + +template <> +rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) { + ws_gates_aoc_t ws_gates_r(rnn, ws_cell_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + + (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_, + states_tm1_l_, c_states_tm1_l_, diff_states_t_l_, + diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_, + ws_cell_); + + if (!rnn.merge_gemm_layer) { + // dx = dG * Wx^t + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, + rnn.n_gates * rnn.dic, 1.0, w_layer_[0], + rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0, + &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld); + // dWx += dG^t * x + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_, + rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0, + diff_w_layer_, rnn.diff_weights_layer_ld); + } + // dh += dGr * Wh^t + (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic, + 1.0, w_iter_[0], rnn.weights_iter_ld, ws_cell_, rnn.gates_ws_ld, + 1.0, diff_states_t_l_, rnn.states_ws_ld); + + // dWh += dGr^t * h + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_cell_, + rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_, + rnn.diff_weights_layer_ld); + + // db1-3 += e * dG + // db4 += e * (r * dG2) + gates_reduction(rnn, ws_gates_, diff_bias_); + + parallel_nd(rnn.dic, [&](int j) { + for (int i = 0; i < rnn.mb; i++) { + diff_bias_[3 * rnn.dic + j] += ws_gates_r(i, 2, j); + } + }); +} + +#undef AOC + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp new file mode 100644 index 0000000000..a15ba00d4c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp @@ -0,0 +1,143 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution LSTM + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "../simple_q10n.hpp" +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j)); + ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j)); + ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j)); + ws_gates(i, 3, j) = logistic_fwd(ws_gates(i, 3, j) + bias(3, j)); + + float tmp = ws_gates(i, 1, j) * c_states_tm1_l(i, j) + + ws_gates(i, 0, j) * ws_gates(i, 2, j); + states_t_l(i, j) = ws_gates(i, 3, j) * tanh_fwd(tmp); + c_states_t_l(i, j) = tmp; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise) { + ws_gates_aoc_s32_t ws_gates_s32(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_u8_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + auto q_d = [&](float f) { + float qf = f * data_scale + data_shift; + return qz_a1b0<float, src_data_t>()(qf); + }; + + auto deq_w = [&](acc_data_t s, int gate, int j) { + return pd()->attr()->rnn_weights_qparams_.mask_ == 0 ? + saturate<float>(s) * (1.f / (weights_scales[0] * data_scale)) : + saturate<float>(s) * (1.f / (weights_scales[gate * rnn.dic + j] + * data_scale)); + }; + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float G0 = logistic_fwd<float>( + deq_w(ws_gates_s32(i, 0, j), 0, j) + bias(0, j)); + float G1 = logistic_fwd<float>( + deq_w(ws_gates_s32(i, 1, j), 1, j) + bias(1, j)); + float G2 = tanh_fwd<float>( + deq_w(ws_gates_s32(i, 2, j), 2, j) + bias(2, j)); + float G3 = logistic_fwd<float>( + deq_w(ws_gates_s32(i, 3, j), 3, j) + bias(3, j)); + float tmp = G1 * c_states_tm1_l(i, j) + G0 * G2; + states_t_l(i, j) = q_d(G3 * tanh_fwd(tmp)); + c_states_t_l(i, j) = tmp; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + parallel_nd(rnn.mb, [&](int i) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < rnn.dic; j++) { + float Ct = c_states_t_l(i, j); + /// @todo save it in the workspace in fwd pass or recompute it to + /// save bw + float tanhCt = tanh_fwd(Ct); + // we have 2 incoming diffs on Ht + float dHt = diff_states_tp1_l(0, i, j) + + diff_states_t_lp1(rnn.n_states, i, j); + float dCt = diff_states_tp1_l(1, i, j) + + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt; + + float dG1 = c_states_tm1_l(i, j) * dCt + * x_m_square(ws_gates(i, 1, j)); + float dG0 = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j)); + float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j)); + float dG2 + = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j)); + + diff_states_t_l(1, i, j) = dCt * ws_gates(i, 1, j); + + ws_gates(i, 0, j) = dG0; + ws_gates(i, 1, j) = dG1; + ws_gates(i, 2, j) = dG2; + ws_gates(i, 3, j) = dG3; + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp new file mode 100644 index 0000000000..4536e8dfad --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution of Vanilla RNN + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::math; +using namespace rnn_utils; + +template <> +float activation<alg_kind::eltwise_relu, prop_kind::forward>( + float dd, float s, float alpha, float cliping) { + return relu_fwd<float>(s, alpha); +} + +template <> +float activation<alg_kind::eltwise_relu, prop_kind::backward>( + float dd, float s, float alpha, float cliping) { + return relu_bwd<float>(dd, s, alpha); +} + +template <> +float activation<alg_kind::eltwise_tanh, prop_kind::forward>( + float dd, float s, float alpha, float cliping) { + return tanh_fwd<float>(s); +} + +template <> +float activation<alg_kind::eltwise_tanh, prop_kind::backward>( + float dd, float s, float alpha, float cliping) { + return dd * one_m_square<float>(s); +} + +template <> +float activation<alg_kind::eltwise_logistic, prop_kind::forward>( + float dd, float s, float alpha, float cliping) { + return logistic_fwd<float>(s); +} + +template <> +float activation<alg_kind::eltwise_logistic, prop_kind::backward>( + float dd, float s, float alpha, float cliping) { + return dd * x_m_square<float>(s); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + + parallel_nd(rnn.mb, [&](int i) { + for (int j = 0; j < rnn.dic; j++) { + const float h + = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0); + ws_gates(i, 0, j) = states_t_l(i, j) = h; + } + }); +} + +template <> +rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise) { + assert(!"VANILLA RNN int8 is not supported"); +} + +template <> +rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise) { + ws_gates_aoc_t ws_gates(rnn, ws_gates_); + bias_aoc_t bias(rnn, bias_); + ws_states_aoc_t states_t_l(rnn, states_t_l_); + ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_); + ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_); + ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_); + ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_); + + parallel_nd(rnn.mb, [&](int i) { + for (int j = 0; j < rnn.dic; ++j) { + const float dH = diff_states_t_lp1(rnn.n_states, i, j) + + diff_states_tp1_l(0, i, j); + auto g = ws_gates(i, 0, j); + ws_gates(i, 0, j) = activation_func(dH, g, 0, 0); + } + }); +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp new file mode 100644 index 0000000000..b39427caf9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp @@ -0,0 +1,191 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_RNN_PD_HPP +#define CPU_RNN_PD_HPP + +#include "c_types_map.hpp" +#include "nstl.hpp" +#include "rnn_pd.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" +#include "rnn_utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t { + using rnn_fwd_pd_t::rnn_fwd_pd_t; + +protected: + status_t set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + // Optional parameters + if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); + if (with_bias() && bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); + + return status::success; + } + + status_t check_layout_consistency() { + using namespace format_tag; + using namespace data_type; + using namespace types; + + auto is_blocked = [&](memory_desc_t md, int ndims) { + return md.format_kind == format_kind::blocked && md.ndims == ndims; + }; + + bool ok = true; + ok = ok && is_blocked(src_layer_md_, 3) + && is_blocked(dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), + is_blocked(src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&dst_iter_md_), + is_blocked(dst_iter_md_, 5)); + + if (weights_layer_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldigo_p); + else + ok = ok && rnn_utils::is_ldigo(&weights_layer_md_); + + if (weights_iter_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldigo_p); + else + ok = ok && rnn_utils::is_ldigo(&weights_iter_md_); + + ok = ok && IMPLICATION(!is_zero_md(&bias_md_), + memory_desc_matches_tag(bias_md_, ldgo)); + + /* Int8 is supported only for packed weights */ + data_type_t weights_iter_dt = weights_iter_md_.data_type; + data_type_t weights_layer_dt = weights_layer_md_.data_type; + ok = ok && IMPLICATION( + weights_iter_dt == s8, weights_iter_md_.format_kind + == format_kind::rnn_packed); + ok = ok && IMPLICATION( + weights_layer_dt == s8, weights_layer_md_.format_kind + == format_kind::rnn_packed); + + return ok ? status::success : status::unimplemented; + } +}; + +struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t { + using rnn_bwd_pd_t::rnn_bwd_pd_t; + +protected: + status_t set_default_params() { + using namespace format_tag; + if (src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); + if (dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); + + if (diff_src_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc)); + if (diff_weights_layer_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo)); + CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo)); + } + if (diff_weights_iter_md_.format_kind == format_kind::any) { + CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo)); + CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo)); + } + if (diff_dst_layer_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc)); + + // Optional parameters + if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc)); + if (with_bias() && bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); + if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc)); + + if (with_src_iter() && diff_src_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldsnc)); + if (with_bias() && diff_bias_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo)); + if (with_dst_iter() && diff_dst_iter_md_.format_kind == format_kind::any) + CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldsnc)); + + return status::success; + } + + status_t check_layout_consistency() { + using namespace format_tag; + using namespace types; + + auto is_blocked = [&](memory_desc_t md, int ndims) { + return md.format_kind == format_kind::blocked && md.ndims == ndims; + }; + + bool ok = true; + ok = ok && is_blocked(src_layer_md_, 3) + && is_blocked(dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_), + is_blocked(src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&dst_iter_md_), + is_blocked(dst_iter_md_, 5)); + + if (weights_layer_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldgoi_p); + else + ok = ok && rnn_utils::is_ldgoi(&weights_layer_md_); + + if (weights_iter_md_.format_kind == format_kind::rnn_packed) + ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format + == mkldnn_ldgoi_p); + else + ok = ok && rnn_utils::is_ldgoi(&weights_iter_md_); + + ok = ok && IMPLICATION(!is_zero_md(&bias_md_), + memory_desc_matches_tag(bias_md_, ldgo)); + + ok = ok && is_blocked(diff_src_layer_md_, 3) + && is_blocked(diff_dst_layer_md_, 3); + ok = ok && IMPLICATION(!is_zero_md(&diff_src_iter_md_), + is_blocked(diff_src_iter_md_, 5)) + && IMPLICATION(!is_zero_md(&diff_dst_iter_md_), + is_blocked(diff_dst_iter_md_, 5)); + + ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_) + && rnn_utils::is_ldigo(&diff_weights_iter_md_); + ok = ok && IMPLICATION(!is_zero_md(&diff_bias_md_), + memory_desc_matches_tag(diff_bias_md_, ldgo)); + + return ok ? status::success : status::unimplemented; + } +}; +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp new file mode 100644 index 0000000000..09445648aa --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp @@ -0,0 +1,401 @@ +/******************************************************************************* +* Copyright 2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + * Cell execution LSTM + */ + +#include "rnn_utils.hpp" +#include "../jit_generator.hpp" +#include "../jit_uni_eltwise.hpp" +#include "c_types_map.hpp" +#include "utils.hpp" + +#include "mkldnn_thread.hpp" + + +namespace mkldnn { +namespace impl { +namespace cpu { + +struct jit_uni_rnn_postgemm_kernel : public jit_generator { + + typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_, + void *c_states_t_l_, void *c_states_tm1_l_); + + jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr): rnn_(rnn), attr_(attr){} + + virtual void init() = 0; + +template <typename src_data_t, typename acc_data_t> + rnn_elemwise_sig(execute) { + rnn_utils::ws_gates_aoc<acc_data_t> ws_gates(rnn, ws_gates_); + rnn_utils::bias_aoc_t bias(rnn, bias_); + rnn_utils::ws_states_aoc<src_data_t> states_t_l(rnn, states_t_l_); + rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_); + rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_); + + // Todo: add parallelization on dic for the batch 1 case + // Assumption: the kernel runs a loop on dic elements + parallel_nd(rnn.mb, [&](int i) { + auto b_ = &bias(0, 0); + auto g_ = &ws_gates(i, 0, 0); + auto s_tl_ = &states_t_l(i, 0); + auto c_tl_ = &c_states_t_l(i, 0); + auto c_tm1l_ = &c_states_tm1_l(i, 0); + kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_); + }); + } + +protected: + kernel_t kernel_; + const rnn_utils::rnn_conf_t &rnn_; + const primitive_attr_t *attr_; +}; + +template <cpu_isa_t isa, impl::data_type_t src_data_t> +struct jit_uni_lstm_postgemm_kernel_fwd: public jit_uni_rnn_postgemm_kernel +{ + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd) + + typedef typename utils::conditional<src_data_t == data_type::u8, int32_t, + float>::type acc_data_t; + typedef typename utils::conditional<isa == avx512_core, + jit_uni_eltwise_injector_f32<avx512_common>, + jit_uni_eltwise_injector_f32<isa>>::type injector_t; + + jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr) + : jit_uni_rnn_postgemm_kernel(rnn, attr){} + + void init() override { + // we use rax for both constant tables as they use the same table + sigmoid_injector_ = new injector_t(this, + alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax); + tanh_injector_ = new injector_t(this, + alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax); + generate(); + kernel_ = (kernel_t) this->getCode(); + } + +protected: + injector_t *sigmoid_injector_; + injector_t *tanh_injector_; + + // register size in bytes + using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm; + size_t vlen = cpu_isa_traits<isa>::vlen; + size_t vlen_dst = (src_data_t == data_type::u8) ? vlen/4 : vlen; + size_t cstate_dt_size = sizeof(float); + size_t hstate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float); + size_t gate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float); + size_t qscale_dt_size = sizeof(float); + size_t bias_dt_size = sizeof(float); + + void generate() { + using namespace Xbyak; + + int mask = attr_->rnn_weights_qparams_.mask_; + float *weights_scales = attr_->rnn_weights_qparams_.scales_; + float data_scale = attr_->rnn_data_qparams_.scale_; + float data_shift = attr_->rnn_data_qparams_.shift_; + + // Labels declaration + Label vector_loop_start_label, vector_loop_end_label; + Label rem_loop_start_label, rem_loop_end_label; + Label table_label; + + // Register map + Reg64 loop_cnt(r11); // loop counter + Reg64 table_reg(rbx); // table is used for data scale and shifts + Reg64 weights_scales_reg(r13); + // We skip vmm0 as it can be used by the injector for masks on sse4.2 + Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7); + + // constant table map + Address dscale_off_addr = ptr[table_reg]; + Address dshift_off_addr = ptr[table_reg + vlen]; + Address ymm_perm_mask_addr = ptr[table_reg + 2*vlen]; + Address zmm_perm_mask_addr = ptr[table_reg + 2*vlen + cpu_isa_traits<avx>::vlen]; + + // quantize from float to u8 + auto q_d = [&](Vmm f, Vmm tmp_vmm) { + uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm); + uni_vmulps(f, f, dscale_off_addr); // apply scale + uni_vaddps(f, f, dshift_off_addr); // apply shift + uni_vcvtps2dq(f, f); // convert to int32 + uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16 + uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation + // Note that the results are interleaved by 128 bit chunks, so we need to merge them together + switch (vlen) { + case 64: { //avx512 + Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx()); + uni_vmovups(tmpz, zmm_perm_mask_addr); + vpermd(fz, tmpz, fz); + break; } + case 32: { //avx + Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx()); + uni_vmovups(tmpy, ymm_perm_mask_addr); + vpermd(fy, tmpy, fy); + break; } + case 16: // sse: nothing to do + break; + default: assert(!"Unsupported case"); + }; + }; + + auto fast_recip =[&](Vmm s, Vmm tmp, bool packed) { + if (packed) + uni_vrcpps(tmp, s); + else + uni_vrcpss(tmp, s); // prevent divide by zero + // we add one Newton iteration + uni_vmulps(s, s, tmp); + uni_vmulps(s, s, tmp); // s <- s * tmp^2 + uni_vaddps(tmp, tmp, tmp); + uni_vsubps(tmp, tmp, s); + uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2 + }; + + // dequantize from s32 to float + auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) { + // TODO: if mask is 0 precompute mul and inverse + if (mask == 0) + uni_vbroadcastss(tmp1, ptr[weights_scales_reg]); + else + uni_vmovups(tmp1, ptr[weights_scales_reg + gate * rnn_.dic * qscale_dt_size]); + uni_vcvtdq2ps(s, s); + uni_vmulps(tmp1, tmp1, dscale_off_addr); + fast_recip(tmp1, tmp2, packed); + uni_vmulps(s, s, tmp1); + }; + + // We start code generations here + preamble(); + + // extract addresses passed as parameter +#ifdef _WIN32 + auto addr_ws_gates_reg = abi_param1; + auto addr_bias_reg = abi_param2; + auto addr_states_t_l_reg = abi_param3; + auto addr_c_states_tm1_l_reg = abi_param4; + auto addr_c_states_t_l_reg = r10; + // Here we cannot use rbp to have initial stack pointer so we + // use rsp and offset it with the size of pushed registers in + // preamble + mov(addr_c_states_t_l_reg, ptr[rsp + get_size_of_abi_save_regs() + 40]); +#else + auto addr_ws_gates_reg = abi_param1; + auto addr_bias_reg = abi_param2; + auto addr_states_t_l_reg = abi_param3; + auto addr_c_states_tm1_l_reg = abi_param4; + auto addr_c_states_t_l_reg = abi_param5; +#endif + + // initialize registers with addresses and constants + mov(table_reg, table_label); + mov(weights_scales_reg, size_t(weights_scales)); + // both sigmoid and tanh use the same table so load address just once in rax + sigmoid_injector_->load_table_addr(); + + mov(loop_cnt, rnn_.dic * gate_dt_size); + cmp(loop_cnt, vlen); + jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR); + + L(vector_loop_start_label); + { + // load G0 G1 G2 G3 + uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); + uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); + + // dequantize the gates from s32 to f32 if needed + if (src_data_t == data_type::u8){ + deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true); + deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true); + deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true); + deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true); + } + + // add biases + uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + + // inject eltwise code + sigmoid_injector_->compute_vector(G0.getIdx()); + sigmoid_injector_->compute_vector(G1.getIdx()); + tanh_injector_->compute_vector(G2.getIdx()); + sigmoid_injector_->compute_vector(G3.getIdx()); + + // compute c_states_t_l = G1 * c_tm1_l + G0 * G2 + uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]); + uni_vmulps(tmp1_vmm, tmp1_vmm, G1); + uni_vfmadd231ps(tmp1_vmm, G0, G2); + uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm); + + // states_t_l = G3 * tanh(c_states_t_l) + tanh_injector_->compute_vector(tmp1_vmm.getIdx()); + uni_vmulps(tmp1_vmm, tmp1_vmm, G3); + + // if int8, we quantize the resulting state + if (src_data_t == data_type::u8) + q_d(tmp1_vmm, tmp2_vmm); + + // write back the result + if(vlen_dst == vlen) + uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm); + else + // we write only 1/4 of the register + switch(vlen_dst){ + case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break; + default: + assert(!"Unsuported vector length for quantization"); + } + + // increment address pointers + add(addr_ws_gates_reg, vlen); + add(addr_bias_reg, vlen); + add(addr_states_t_l_reg, vlen_dst); + add(addr_c_states_tm1_l_reg, vlen); + add(addr_c_states_t_l_reg, vlen); + if (mask != 0) + add(weights_scales_reg, vlen); + + // increment loop counter + sub(loop_cnt, vlen); + cmp(loop_cnt, vlen); + jge(vector_loop_start_label); + } + L(vector_loop_end_label); + + cmp(loop_cnt, 0); + je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR); + // Same code as above, we just use movuss for accessing inputs + // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar + L(rem_loop_start_label); + { + // remaping registers to Xmms + Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx()); + Xmm tmp1s_vmm(tmp1_vmm.getIdx()); + + // load G0 G1 G2 G3 + uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]); + uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]); + uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]); + uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]); + + // dequantize the gates from s32 to f32 if needed + if (src_data_t == data_type::u8){ + deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false); + deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false); + deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false); + deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false); + } + + // add biases + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]); + uni_vaddps(G0s, G0s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]); + uni_vaddps(G1s, G1s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]); + uni_vaddps(G2s, G2s, tmp1s_vmm); + uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]); + uni_vaddps(G3s, G3s, tmp1s_vmm); + + // inject eltwise code + sigmoid_injector_->compute_vector(G0s.getIdx()); + sigmoid_injector_->compute_vector(G1s.getIdx()); + tanh_injector_->compute_vector(G2s.getIdx()); + sigmoid_injector_->compute_vector(G3s.getIdx()); + + // compute c_states_t_l = G1 * c_tm1_l + G0s * G2 + uni_vmovups(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]); + uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s); + uni_vfmadd231ps(tmp1s_vmm, G0s, G2s); + uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm); + + // states_t_l = G3 * tanh(c_states_t_l) + tanh_injector_->compute_vector(tmp1s_vmm.getIdx()); + uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s); + + // if int8, we quantize the resulting state + if (src_data_t == data_type::u8) + q_d(tmp1_vmm, tmp2_vmm); + + // write back the result + if(vlen_dst == vlen) + uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm); + else + // we write only 1/4 of the register + switch(vlen_dst){ + case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break; + default: + assert(!"Unsuported vector length for quantization"); + } + + // increment address pointers + add(addr_ws_gates_reg, gate_dt_size); + add(addr_bias_reg, bias_dt_size); + add(addr_states_t_l_reg, hstate_dt_size); + add(addr_c_states_tm1_l_reg, cstate_dt_size); + add(addr_c_states_t_l_reg, cstate_dt_size); + if (mask != 0) + add(weights_scales_reg, qscale_dt_size); + + // increment loop counter + sub(loop_cnt, gate_dt_size); + cmp(loop_cnt, 0); + jg(rem_loop_start_label); + + } + L(rem_loop_end_label); + + postamble(); + + // Again, only one table is needed and shared between sigmoid and tanh + sigmoid_injector_->prepare_table(false); + tanh_injector_->prepare_table(true); + + L(table_label); + { + for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_scale)); + for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_shift)); + // perm mask for ymm + dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7); + // perm mask for zmm + dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7); + dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14); + } + } + +}; + +template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::f32>; +template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::f32>; +template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::f32>; + +template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::u8>; +template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::u8>; +template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::u8>; +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp new file mode 100644 index 0000000000..ead536816c --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp @@ -0,0 +1,788 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* + General architecture + + for diff states, we have n_states + 1 as we have n_states diff + to propagate to the previous iteration and 1 states to propagate + to the previous layer + index 0 is dh for cell(t-1, l) to consume + index 1 is dc for cell(t-1, l) to consume + index 2 is dh for cell(t, l-1) to consume + this indexing enables to have the same indexing for states in elemwise + function + only the cell execution function should be impacted + + */ + +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" +#include "../gemm/gemm.hpp" +#include "../simple_q10n.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace mkldnn::impl::memory_tracking::names; +using namespace rnn_utils; +#define AOC array_offset_calculator + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +void _ref_rnn_common_t<aprop, src_type, weights_type>::gates_reduction( + const rnn_conf_t &rnn, const acc_data_t *ws_gates_, + float *diff_bias_) const { + auto body = [&](int i, int k) { + for (int j = 0; j < rnn.mb; j++) + diff_bias_[i * rnn.dic + k] + += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k]; + }; + + // @todo block k on simd-width +#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \ + /* icc 17.0 has a problem with simd collapse */ \ + && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700)) +#pragma omp parallel for simd collapse(2) + for (int i = 0; i < rnn.n_gates; i++) + for (int k = 0; k < rnn.dic; k++) + body(i, k); +#else + parallel_nd(rnn.n_gates, rnn.dic, body); +#endif +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::gemm)) { + assert(ldA * ldB * ldC != 0); + extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB, + &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm); +} + +template <> +rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) { + assert(!"non packed gemm is disabled for int8"); +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::packed_gemm)) { +#if (USE_MKL_PACKED_GEMM) + assert(transA == 'N'); + cblas_sgemm_compute(CblasColMajor, CblasPacked, + (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_, + ldB, beta, c_, ldC); +#else + UNUSED(transA); + UNUSED(transB); + UNUSED(m); + UNUSED(n); + UNUSED(k); + UNUSED(alpha); + UNUSED(ldA); + UNUSED(b_); + UNUSED(ldB); + UNUSED(beta); + UNUSED(c_); + UNUSED(ldC); + assert(!"packed gemm is disabled"); +#endif +} + +template <> +rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) { +#if (USE_MKL_PACKED_GEMM) + int8_t offseta = 0, offsetb = 0; + int32_t offsetc = 0; + cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked, + CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_, + ldB, offsetb, beta, c_, ldC, &offsetc); +#else + UNUSED(transA); + UNUSED(transB); + UNUSED(m); + UNUSED(n); + UNUSED(k); + UNUSED(alpha); + UNUSED(ldA); + UNUSED(b_); + UNUSED(ldB); + UNUSED(beta); + UNUSED(c_); + UNUSED(ldC); + assert(!"packed gemm is disabled"); +#endif +} + +//*************** Grid computations strategy: linear ***************// +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_grid_execution_sig( + (_ref_rnn_common_t<aprop, src_type, weights_type>::linear_execution)) { + AOC<src_data_t, 4> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); + AOC<float, 4> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld); + AOC<float, 5> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + (rnn.n_states + 1), rnn.n_iter + 1, + rnn.states_nld * rnn.states_ws_ld); + AOC<acc_data_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter, + rnn.gates_nld * rnn.gates_ws_ld); + AOC<weights_data_t *, 3> weights_input( + weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer); + AOC<weights_data_t *, 3> weights_states( + weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter); + AOC<float*, 3> bias( + bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); + AOC<float, 3> diff_weights_layer(diff_weights_layer_, rnn.n_layer, + rnn.n_dir, + rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld); + AOC<float, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir, + rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld); + AOC<float, 3> diff_bias( + diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + AOC<float, 4> ws_grid( + ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell); + + // We run the grid of computation + for (int dir = 0; dir < rnn.n_dir; dir++) { + for (int j = 0; j < rnn.n_layer; j++) { + int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1; + + if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, + rnn.mb * rnn.n_iter, rnn.slc, 1.0, + weights_input(lay, dir, 0), rnn.weights_iter_ld, + &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0, + &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld); + } + + for (int i = 0; i < rnn.n_iter; i++) { + int iter = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1; + (this->*cell_func)(rnn, + &(ws_states(lay + 1, dir, iter + 1, 0)), + &(ws_c_states(lay + 1, dir, iter + 1, 0)), + &(ws_diff_states(lay, dir, 0, iter, 0)), + &(weights_input(lay, dir, 0)), + &(weights_states(lay, dir, 0)), + &(bias(lay, dir, 0)), + &(ws_states(lay, dir, iter + 1, 0)), + &(ws_states(lay + 1, dir, iter, 0)), + &(ws_c_states(lay + 1, dir, iter, 0)), + &(ws_diff_states(lay + 1, dir, 0, iter, 0)), + &(ws_diff_states(lay, dir, 0, iter + 1, 0)), + &(diff_weights_layer(lay, dir, 0)), + &(diff_weights_iter(lay, dir, 0)), + &(diff_bias(lay, dir, 0)), + &(ws_gates(lay, dir, iter, 0)), + &(ws_grid(lay, dir, iter, 0)), + ws_cell_); + } + + if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) { + (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter, + rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0), + rnn.weights_layer_ld, + (src_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, 0.0, + (acc_data_t *)(&(ws_diff_states( + lay, dir, rnn.n_states, 0, 0))), + rnn.states_ws_ld); + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, + rnn.mb * rnn.n_iter, 1.0, + (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, + (src_data_t *)(&(ws_states(lay, dir, 1, 0))), + rnn.states_ws_ld, 1.0, + (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))), + rnn.diff_weights_layer_ld); + } + if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) { + gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, + rnn.mb * rnn.n_iter, 1.0, + (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))), + rnn.gates_ws_ld, + (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))), + rnn.states_ws_ld, 1.0, + (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))), + rnn.diff_weights_iter_ld); + } + } + } +} + +//********* GRID computations strategy: utility functions **********// + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_layer( + const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, + float *__restrict ws_diff_states_, const src_data_t *__restrict xt_, + const float *__restrict diff_dst_layer_) const { + + AOC<src_data_t, 4> ws_states( + ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto xt_d = memory_desc_wrapper(pd()->src_md(0)); + + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto xxt = xt_ + xt_d.blk_off(it, b); + src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0)); + src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0)); + if (rnn.exec_dir != r2l) + for (int c = 0; c < rnn.slc; c++) + ws_l2r_ptr[c] = xxt[c]; + if (rnn.exec_dir != l2r) + for (int c = 0; c < rnn.slc; c++) + ws_r2l_ptr[c] = xxt[c]; + }); +} + +template <> +void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_, + const float *diff_dst_layer_) const { + AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0)); + + switch (rnn.exec_dir) { + case bi_concat: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + ws_diff_states( + rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) + = diff_dst_layer_x[rnn.dic + s]; + } + }); + break; + case bi_sum: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + ws_diff_states( + rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + case l2r: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x + = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + case r2l: + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + auto diff_dst_layer_x = diff_dst_layer_ + + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b); + for (int s = 0; s < rnn.dic; s++) { + ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s) + = diff_dst_layer_x[s]; + } + }); + break; + default: assert(!"Unsupported direction"); break; + } +} + +/* For int8 configuration, input iteration states may be of types f32 or u8 + * Internally h_state is always stored in u8 and c_state is always stored in f32 + * If input states are of type u8 then h state is copied and c state is dequantized + * If input states are of type f32 then h state is quantized and c_state is copied + * */ +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +template <typename input_data_t> +void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_iter( + const rnn_conf_t &rnn, src_data_t *__restrict ws_states_, + float *__restrict ws_c_states_, float *__restrict ws_diff_states_, + const input_data_t *__restrict firstit_states_, + const float *__restrict diff_dst_iter_) const { + AOC<src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + const bool quantize = pd()->with_src_iter() + && pd()->src_md(1)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_q = [&](input_data_t f) { + if (quantize) { + float qf = f * data_scale + data_shift; + return qz_a1b0<float, src_data_t>()(qf); + } else + return (src_data_t)f; + }; + + const bool dequantize = pd()->with_src_iter() + && pd()->src_md(1)->data_type == data_type::u8; + auto maybe_deq = [&](input_data_t s) { + if (dequantize) + return (((float)s - data_shift) / data_scale); + else + return (float)s; + }; + auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1)); + if (firstit_states_) { + parallel_nd( + rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { + for (int s = 0; s < rnn.sic; s++) + ws_states(lay + 1, dir, 0, b, s) = maybe_q( + firstit_states_[firstit_states_d.blk_off( + lay, dir, 0, b, s)]); + if (pd()->cell_kind() == alg_kind::vanilla_lstm) + for (int s = 0; s < rnn.sic; s++) + ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq( + firstit_states_[firstit_states_d.blk_off( + lay, dir, 1, b, s)]); + }); + } else { + parallel_nd( + rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) { + for (int j = 0; j < rnn.sic; j++) { + ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0; + ws_c_states(lay + 1, dir, 0, b, j) = 0.0f; + } + }); + } +} + +template <> +template <typename input_data_t> +void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_, + const input_data_t *firstit_states_, + const float *diff_dst_iter_) const { + AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1)); + if (diff_dst_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int b) { + array_copy(&(ws_diff_states( + lay, dir, state, rnn.n_iter, b, 0)), + diff_dst_iter_ + + diff_dst_iter_d.blk_off( + lay, dir, state, b), + rnn.dic); + }); + } else { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int i) { + for (int j = 0; j < rnn.dic; j++) + ws_diff_states(lay, dir, state, rnn.n_iter, i, j) + = 0.0f; + }); + } +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +template <typename dst_data_t> +void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_layer( + const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer, + const src_data_t *ws_states_, const float *ws_diff_states_) const { + + auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0)); + AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float shift = (pd()->attr()->rnn_data_qparams_.shift_); + float scale = (pd()->attr()->rnn_data_qparams_.scale_); + + const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_deq = [&](src_data_t s) { + if (dequantize) + return (dst_data_t)(((float)s - shift) / scale); + else + return (dst_data_t)s; + }; + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + int dir = 0; + if (rnn.exec_dir != r2l) { + for (int s = 0; s < rnn.dic; s++) { + dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] + = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s)); + } + dir = 1; + } + if (rnn.exec_dir != l2r) { + for (int s = 0; s < rnn.dic; s++) + switch (rnn.exec_dir) { + case bi_sum: + dst_layer_[dst_layer_d.blk_off(it, b, s)] + += maybe_deq(ws_states( + rnn.n_layer, dir, rnn.n_iter - it, b, s)); + break; + default: + dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)] + = maybe_deq(ws_states( + rnn.n_layer, dir, rnn.n_iter - it, b, s)); + } + } + }); +} + +template <> +template <typename dst_data_t> +void ref_rnn_bwd_f32_t::copy_res_layer( + const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_, + const src_data_t *ws_states_, const float *ws_diff_states_) const { + auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0)); + AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, + rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, + rnn.states_ws_ld); + + parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) { + int dir = 0; + for (int s = 0; s < rnn.slc; s++) { + float *dst_addr = diff_src_layer_ + + diff_src_layer_d.blk_off( + (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it, + b, dir * rnn.slc + s); + float res = ws_diff_states(0, 0, rnn.n_states, it, b, s); + if (rnn.n_dir - 1) + res += ws_diff_states( + 0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s); + dst_addr[0] = res; + } + }); +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +template <typename output_data_t> +void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_iter( + const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, + const src_data_t *ws_states_, float *ws_c_states_, + const float *ws_diff_states_) const { + auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1)); + AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + AOC<const float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir, + rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld); + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + + const bool quantize = pd()->with_dst_iter() + && pd()->dst_md(1)->data_type == data_type::u8 + && rnn.dt_conf != all_f32; + auto maybe_q = [&](float f) { + if (quantize) { + float qf = f * data_scale + data_shift; + return qz_a1b0<float, output_data_t>()(qf); + } else + return (output_data_t)f; + }; + + const bool dequantize = pd()->with_dst_iter() + && pd()->dst_md(1)->data_type == data_type::f32 + && rnn.dt_conf != all_f32; + auto maybe_deq = [&](src_data_t s) { + if (dequantize) + return (output_data_t)(((float)s - data_shift) / data_scale); + else + return (output_data_t)s; + }; + if (dst_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, + [&](int lay, int dir, int b) { + for (int s = 0; s < rnn.dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)] + = maybe_deq(ws_states(lay + 1, dir, rnn.n_iter, b, s)); + } + if (pd()->cell_kind() == alg_kind::vanilla_lstm) + for (int s = 0; s < rnn.dic; s++) { + dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)] + = maybe_q(ws_c_states( + lay + 1, dir, rnn.n_iter, b, s)); + } + }); + } +} + +template <> +template <typename output_data_t> +void ref_rnn_bwd_f32_t::copy_res_iter( + const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_, + const src_data_t *ws_states_, float *ws_c_states_, + const float *ws_diff_states_) const { + auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1)); + AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, + rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, + rnn.states_ws_ld); + if (diff_src_iter_) { + parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb, + [&](int lay, int dir, int state, int b) { + for (int s = 0; s < rnn.sic; s++) { + diff_src_iter_[diff_src_iter_d.blk_off( + lay, dir, state, b, s)] + = ws_diff_states(lay, dir, state, 0, b, s); + } + }); + } +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_bias_prepare_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::bias_prepare)) { + /* Original set of bias provided by the user */ + AOC<const float, 5> b( + b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + /* Array of pointers initialized in packing */ + AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); + AOC<float, 3> scratch_bias( + scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic); + + if (rnn.copy_bias) { + parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic, + [&](size_t i) { scratch_bias_[i] = b_[i]; }); + } + + for (int i = 0; i < rnn.n_layer; i++) { + for (int d = 0; d < rnn.n_dir; d++) { + int offset_bias = 0; + for (int p = 0; p < rnn.n_parts_bias; p++) { + bias(i, d, p) = rnn.copy_bias + ? (float *) &scratch_bias(i, d, offset_bias) + : (float *) &b(i, d, offset_bias); + offset_bias += rnn.parts_bias[p] * rnn.dic; + } + } + } + +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_bias_finalize_sig( + (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_finalize)) { + if (rnn.dt_conf != all_f32) { + float data_shift = pd()->attr()->rnn_data_qparams_.shift_; + float data_scale = pd()->attr()->rnn_data_qparams_.scale_; + float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_; + bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0; + for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++) + for (int j = 0; j < rnn.n_bias * rnn.dic; j++) { + size_t off = i * rnn.n_bias * rnn.dic + j; + float weights_scale + = scale_per_oc ? weights_scales[j] : weights_scales[0]; + scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off]) + * data_shift / (weights_scale * data_scale); + } + } +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type, + weights_type>::assign_packed_weights)) { + assert(md->format_kind == format_kind::rnn_packed); + const auto packed_desc = md->format_desc.rnn_packed_desc; + AOC<weights_data_t *, 3> weights(weights_, + rnn.n_layer, rnn.n_dir, packed_desc.n_parts); + + size_t offset_packed = 0; + for (int l = 0; l < rnn.n_layer; l++) + for (int d = 0; d < rnn.n_dir; d++) { + for (int p = 0; p < packed_desc.n_parts; p++) { + weights(l, d, p) = (weights_data_t *)&w_[offset_packed]; + offset_packed + += packed_desc.part_pack_size[p] / sizeof(weights_data_t); + } + } +} + +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +rnn_weights_assign_sig( + (_ref_rnn_common_t<aprop, src_type, weights_type>::assign_weights)) { + assert(md->format_kind == format_kind::blocked); + const auto &blk = md->format_desc.blocking; + /* Original set of weights provided by the user */ + AOC<const weights_data_t, 3> w(w_, + rnn.n_layer, rnn.n_dir, (int)blk.strides[1]); + /* Array of pointers for each part of weights */ + AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts); + + for (int i = 0; i < rnn.n_layer; i++) + for (int d = 0; d < rnn.n_dir; d++) { + size_t offset_weights = 0; + for (int p = 0; p < n_parts; p++) { + weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights); + offset_weights += gates_per_part[p] * blk.strides[3]; + } + } +} + +//********************* Execution function *********************// +template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type> +void _ref_rnn_common_t<aprop, src_type, weights_type>::execute_( + const exec_ctx_t &ctx) const { + const rnn_conf_t &rnn = this->pd()->rnn_; + auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER); + auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER); + auto layer_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER); + auto iter_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER); + auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS); + + auto dst_last_layer = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER) + : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER)); + auto dst_last_iter = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER) + : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER)); + + auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER); + auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER); + + auto w_layer = reinterpret_cast<const weights_data_t *>(layer_weights_n_comp); + auto w_iter = reinterpret_cast<const weights_data_t *>(iter_weights_n_comp); + auto w_iter_comp = reinterpret_cast<const float *>( + iter_weights_n_comp + rnn.weights_iter_comp_offset); + auto w_layer_comp = reinterpret_cast<const float *>( + layer_weights_n_comp + rnn.weights_layer_comp_offset); + + auto scratchpad = this->scratchpad(ctx); + + auto ptr_wei_layer + = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_layer); + auto ptr_wei_iter + = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_iter); + auto ptr_bias = + scratchpad.template get<float *>(key_rnn_ptrs_bia); + + // fetchihg buffers from the workspace + // if no workspace was provided we use the scratchpad + char *scratch_ptr = scratchpad.template get<char>(key_rnn_space); + char *ws_ptr = nullptr; + if (rnn.use_workspace) + ws_ptr = rnn.is_fwd + ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE) + : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE)); + + char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr; + acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_); + src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_); + float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_); + float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_); + float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_); + float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_); + + auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER); + auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER); + + auto diff_weights_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER); + auto diff_weights_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER); + auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS); + + // Fetching extra buffers from scratchpad + float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_); + + // initialize diff_states to 0 + if (aprop == prop_kind::backward) + array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float)); + + /* Pack(if using packed gemm API) or copy(if input arrays have bad leading + * dimension */ + (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias); + + (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1), + rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic, + rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter, + rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter, + ptr_bias, bias, ws_bias); + (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0), + rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc, + rnn.n_parts_weights_layer, rnn.parts_weights_layer, + rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias, + bias, ws_bias); + + (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp); + + // we first need to copy the initial states and input into ws + copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer); + if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, + (const float *)states, diff_dst_iter); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) + copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states, + (const uint8_t *)states, diff_dst_iter); + else + assert(!"unimplemented"); + + // run the execution on the grid + (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias, + ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid, + diff_weights_layer, diff_weights_iter, diff_bias); + + // Finally we copy the results to the result buffers + if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states, + ws_diff_states); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8) + copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer, + ws_states, ws_diff_states); + else + assert(!"unimplemented"); + + if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32 + || rnn.dt_conf == all_f32) + copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states, + ws_c_states, ws_diff_states); + else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32) + copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states, + ws_c_states, ws_diff_states); + else + assert(!"unimplemented"); +}; + +/* Fix for MSVS warning C4661 */ +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution); +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru); +template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr); +template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr); +template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise); +template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise); +template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise); + +template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>; +template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>; +template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>; + +#undef AOC +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp new file mode 100644 index 0000000000..6f449a9016 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp @@ -0,0 +1,328 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_REF_RNN_HPP +#define CPU_REF_RNN_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "memory_tracking.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +#include "../cpu_isa_traits.hpp" +#include "../gemm/os_blas.hpp" + +#include "cpu_rnn_pd.hpp" +#include "../cpu_primitive.hpp" +#include "rnn_utils.hpp" +#include "jit_uni_rnn_postgemm.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <alg_kind_t alg_kind, prop_kind_t prop_kind> +float activation(float s, float alpha, float cliping, float dd); + +template <prop_kind_t aprop, impl::data_type_t src_type, + impl::data_type_t weights_type> +struct _ref_rnn_common_t : public cpu_primitive_t { + typedef typename prec_traits<src_type>::type src_data_t; + typedef typename prec_traits<weights_type>::type weights_data_t; + typedef typename utils::conditional<src_type == data_type::u8, int32_t, + float>::type acc_data_t; + + using class_name = _ref_rnn_common_t<aprop, src_type, weights_type>; + + typedef rnn_elemwise_sig((class_name::*elemwise_f)); + typedef rnn_cell_execution_sig((class_name::*cell_execution_f)); + typedef rnn_grid_execution_sig((class_name::*grid_execution_f)); + + typedef rnn_gemm_sig((class_name::*gemm_t)); + typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t)); + typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t)); + typedef rnn_weights_assign_sig((class_name::*weights_assign_t)); + + using base_pd_t = + typename utils::conditional<false || aprop == prop_kind::forward, + cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type; + + struct pd_t : public base_pd_t { + using base_pd_t::base_pd_t; + + DECLARE_COMMON_PD_T("ref:any", class_name); + + status_t init() { + using namespace prop_kind; + using namespace utils; + using namespace format_tag; + using namespace rnn_utils; + const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind; + + data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type; + data_type_t weights_iter_dt + = this->desc()->weights_iter_desc.data_type; + data_type_t weights_layer_dt + = this->desc()->weights_layer_desc.data_type; + + bool ok = true + && one_of(cell_kind, alg_kind::vanilla_rnn, + alg_kind::vanilla_lstm, alg_kind::vanilla_gru, + alg_kind::gru_linear_before_reset) + && IMPLICATION(aprop == prop_kind::forward, + one_of(this->desc()->prop_kind, forward_training, + forward_inference)) + && IMPLICATION(aprop == backward, + one_of(this->desc()->prop_kind, backward)) + && src_layer_dt == src_type + && everyone_is( + weights_type, weights_iter_dt, weights_layer_dt) + && this->set_default_params() == status::success + && this->with_bias(); + if (!ok) + return status::unimplemented; + + init_conf(rnn_, *this->desc(), this->src_md(0), this->src_md(1), + this->weights_md(0), this->weights_md(1), this->dst_md(0)); + + if (rnn_.dt_conf == all_f32) + ok = ok && this->attr()->has_default_values(); + + // Set weights descriptors to desired format + memory_desc_t new_weights_layer_md = *this->weights_md(0); + CHECK(set_expected_desc(rnn_, new_weights_layer_md, false)); + if (this->weights_layer_md_.format_kind == format_kind::any) { + this->weights_layer_md_ = new_weights_layer_md; + } else if (this->weights_layer_md_.format_kind + == format_kind::rnn_packed) { + if (this->weights_layer_md_ != new_weights_layer_md) + return status::unimplemented; + } + + memory_desc_t new_weights_iter_md = *this->weights_md(1); + CHECK(set_expected_desc(rnn_, new_weights_iter_md, true)); + if (this->weights_iter_md_.format_kind == format_kind::any) { + this->weights_iter_md_ = new_weights_iter_md; + } else if (this->weights_iter_md_.format_kind + == format_kind::rnn_packed) { + if (this->weights_iter_md_ != new_weights_iter_md) + return status::unimplemented; + } + + CHECK(this->check_layout_consistency()); + + set_conf(rnn_, *this->desc(), this->weights_md(0), + this->weights_md(1), this->diff_weights_md(0), + this->diff_weights_md(1)); + + size_t scratchpad_sz{0}, ws_sz{0}; + get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz); + + // initialize the workspace if needed + if (rnn_.is_training) { + dims_t ws_dims = { (int)ws_sz }; + mkldnn_memory_desc_init_by_tag(&this->ws_md_, 1, ws_dims, + data_type::u8, format_tag::x); + } + + init_scratchpad(scratchpad_sz); + + return status::success; + } + + rnn_utils::rnn_conf_t rnn_; + + private: + void init_scratchpad(size_t scratchpad_sz) { + using namespace memory_tracking::names; + auto scratchpad = this->scratchpad_registry().registrar(); + scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096); + + int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1; + int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts; + scratchpad.book(key_rnn_ptrs_wei_layer, + sizeof(float *) * ptr_wei_sz); + scratchpad.book(key_rnn_ptrs_wei_iter, + sizeof(float *) * ptr_wei_sz); + scratchpad.book(key_rnn_ptrs_bia, + sizeof(float *) * ptr_wei_sz); + } + }; + + _ref_rnn_common_t(const pd_t *apd) + : cpu_primitive_t(apd, true), rnn_postgemm_(nullptr) { + /// @todo set max_feature_size assuming that we limit the number of + /// iterations and layer to one if slc != dic and sic != dic + /// respectively + + bias_preparation_func = &class_name::bias_prepare; + bias_finalization_func = &class_name::bias_finalize; + + auto set_gemm_funcs + = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) { + if (packed_gemm) { + g = &class_name::packed_gemm; + a = &class_name::assign_packed_weights; + } else { + g = &class_name::gemm; + a = &class_name::assign_weights; + } + }; + set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func, + weights_iter_assign_func); + + set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func, + weights_layer_assign_func); + + switch (pd()->cell_kind()) { + case alg_kind::vanilla_lstm: + cell_func = &class_name::cell_execution; + if (aprop == prop_kind::forward) { + if (mayiuse(avx512_core)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx512_core, src_type>( + pd()->rnn_, pd()->attr()); + else if (mayiuse(avx2)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx2, src_type>( + pd()->rnn_, pd()->attr()); + else if (mayiuse(sse42)) + rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<sse42, src_type>( + pd()->rnn_, pd()->attr()); + assert(rnn_postgemm_ != nullptr); + rnn_postgemm_->init(); + } + elemwise_func = &class_name::lstm_elemwise; + break; + case alg_kind::vanilla_rnn: // @todo switch on cell kind + cell_func = &class_name::cell_execution; + elemwise_func = &class_name::rnn_elemwise; + switch (pd()->activation_kind()) { + case alg_kind::eltwise_relu: + activation_func = &activation<alg_kind::eltwise_relu, aprop>; + break; + case alg_kind::eltwise_tanh: + activation_func = &activation<alg_kind::eltwise_tanh, aprop>; + break; + case alg_kind::eltwise_logistic: + activation_func = &activation<alg_kind::eltwise_logistic, aprop>; + break; + default: break; + } + break; + case alg_kind::vanilla_gru: + cell_func = &class_name::cell_execution_gru; + break; + case alg_kind::gru_linear_before_reset: + cell_func = &class_name::cell_execution_gru_lbr; + elemwise_func = &class_name::gru_lbr_elemwise; + break; + default: break; + } + + grid_computation = &class_name::linear_execution; + + size_t scratchpad_size, workspace_size; + rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_, + ws_c_states_offset_, ws_diff_states_offset_, + ws_grid_comp_offset_, ws_cell_comp_offset_, + ws_bias_offset_, scratchpad_size, workspace_size); + } + + ~_ref_rnn_common_t() {} + + // typedef typename prec_traits::type data_t; + + virtual status_t execute(const exec_ctx_t &ctx) const override { + execute_(ctx); + return status::success; + } + +private: + void execute_(const exec_ctx_t &ctx) const; + rnn_grid_execution_sig(linear_execution); + rnn_cell_execution_sig(cell_execution); + rnn_cell_execution_sig(cell_execution_gru); + rnn_cell_execution_sig(cell_execution_gru_lbr); + rnn_elemwise_sig(rnn_elemwise); + rnn_elemwise_sig(lstm_elemwise); + rnn_elemwise_sig(gru_lbr_elemwise); + rnn_gemm_sig(gemm); + rnn_gemm_sig(packed_gemm); + rnn_bias_prepare_sig(bias_prepare); + rnn_bias_finalize_sig(bias_finalize); + rnn_weights_assign_sig(assign_weights); + rnn_weights_assign_sig(assign_packed_weights); + + float (*activation_func)(float dd, float s, float alpha, float cliping); + + void copy_init_layer(const rnn_utils::rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_diff_states_, + const src_data_t *xt_, const float *diff_dst_layer) const; + + template <typename input_data_t> + void copy_init_iter(const rnn_utils::rnn_conf_t &rnn, + src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_, + const input_data_t *firstit_states_, + const float *diff_dst_iter) const; + + template <typename dst_data_t> + void copy_res_layer(const rnn_utils::rnn_conf_t &rnn, + dst_data_t *dst_layer_, float *diff_src_layer, + const src_data_t *ws_states_, const float *ws_diff_states_) const; + + template <typename output_data_t> + void copy_res_iter(const rnn_utils::rnn_conf_t &rnn, + output_data_t *dst_iter_, float *diff_src_iter, + const src_data_t *ws_states_, float *ws_c_states, + const float *ws_diff_states_) const; + + void gates_reduction(const rnn_utils::rnn_conf_t &rnn, + const acc_data_t *ws_gates_, float *diff_bias_) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + + size_t ws_gates_offset_; + size_t ws_states_offset_; + size_t ws_c_states_offset_; + size_t ws_bias_offset_; + size_t ws_diff_states_offset_; + size_t ws_grid_comp_offset_; + size_t ws_cell_comp_offset_; + jit_uni_rnn_postgemm_kernel *rnn_postgemm_; + + grid_execution_f grid_computation; + cell_execution_f cell_func; + + bias_prepare_t bias_preparation_func; + bias_finalize_t bias_finalization_func; + weights_assign_t weights_layer_assign_func; + weights_assign_t weights_iter_assign_func; + + gemm_t gemm_layer_func; + gemm_t gemm_iter_func; + elemwise_f elemwise_func; +}; + +using ref_rnn_fwd_f32_t = _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>; +using ref_rnn_bwd_f32_t = _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>; +using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>; +} +} +} +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp new file mode 100644 index 0000000000..597c63e3f8 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp @@ -0,0 +1,380 @@ +/******************************************************************************* + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef CPU_RNN_REORDERS_HPP +#define CPU_RNN_REORDERS_HPP + +#include <assert.h> + +#include "type_helpers.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" +#include "simple_q10n.hpp" +#include "cpu_reorder_pd.hpp" +#include "../gemm/os_blas.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <data_type_t type_i, data_type_t type_o> +struct rnn_data_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc) + && od == id; + if (!args_ok) return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); + } + }; + +private: + typedef typename prec_traits<type_i>::type in_data_t; + typedef typename prec_traits<type_o>::type out_data_t; + + rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const size_t nelems = input_d.nelems(); + const float scale = pd()->attr()->rnn_data_qparams_.scale_; + const float shift = pd()->attr()->rnn_data_qparams_.shift_; + + parallel_nd(nelems, [&](size_t i) { + float in = (float)input[input_d.off_l(i)] * scale + shift; + output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in); + }); + + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <data_type_t type_i, data_type_t type_o> +struct rnn_weights_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { +#if !USE_MKL_PACKED_GEMM + return status::unimplemented; +#endif + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && od.format_kind() == format_kind::rnn_packed + && od.rnn_packed_desc().format == mkldnn_ldigo_p + && od.rnn_packed_desc().n_parts == 1 + && attr != nullptr; + if (!args_ok) return status::invalid_arguments; + + format_tag_t itag = id.matches_one_of_tag( + format_tag::ldigo, format_tag::ldgoi); + if (itag == format_tag::undef) return status::invalid_arguments; + + const int mask = attr->rnn_weights_qparams_.mask_; + if (!utils::one_of(mask, 0, 3)) return status::unimplemented; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + _pd->itag_ = itag; + if (_pd->init() != success) { delete _pd; return unimplemented; } + return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); + } + + status_t init() { + status_t status = cpu_reorder_pd_t::init(); + if (status != status::success) return status; + + init_scratchpad(); + + return status::success; + } + + format_tag_t itag_; + + private: + void init_scratchpad() { + const memory_desc_wrapper id(src_md()); + const size_t nelems = id.nelems(); + const auto &dims = id.dims(); + + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + size_t quantization_size = sizeof(int8_t) * nelems; + size_t reduction_size = itag_ == ldigo + ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0] + * dims[1] * dims[3] * dims[4] + : 0; + scratchpad.book( + key_reorder_rnn_weights_quantization, quantization_size); + scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size); + } + }; + +private: + typedef typename prec_traits<type_i>::type in_data_t; + typedef typename prec_traits<type_o>::type out_data_t; + + rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { +#if USE_MKL_PACKED_GEMM + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const auto &dims = input_d.dims(); + + const int L = dims[0]; + const int D = dims[1]; + const int I = dims[2]; + const int G = dims[3]; + const int O = dims[4]; + + const bool is_igo = pd()->itag_ == format_tag::ldigo; + + /* Quantize input & compute compensation */ + auto quantized = (int8_t * __restrict)scratchpad(ctx).template get<void>( + memory_tracking::names::key_reorder_rnn_weights_quantization); + auto reduction = (int32_t * __restrict)scratchpad(ctx).template get<void>( + memory_tracking::names::key_reorder_rnn_weights_reduction); + float *comp = reinterpret_cast<float *>( + output + output_d.rnn_packed_desc().offset_compensation); + const float *scales = pd()->attr()->rnn_weights_qparams_.scales_; + const int mask = pd()->attr()->rnn_weights_qparams_.mask_; + + if (is_igo) { + int nthr = mkldnn_get_max_threads(); + int LD_nthr = nstl::min(L * D, nthr); + int I_nthr = nstl::min(I, nthr / LD_nthr); + parallel(nthr, [&](const int ithr, const int nthr) { + int LD_ithr = -1, LD_s = -1, LD_e = -1; + int I_ithr = -1, I_s = -1, I_e = -1; + if (ithr < LD_nthr * I_nthr) { + LD_ithr = ithr % LD_nthr; + I_ithr = ithr / LD_nthr; + balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e); + balance211(I, I_nthr, I_ithr, I_s, I_e); + } + int32_t *comp_ithr = reduction + I_ithr * L * D * G * O; + for (int ld = LD_s; ld < LD_e; ld++) { + for (int go = 0; go < G * O; go++) + comp_ithr[ld * G * O + go] = 0; + for (int i = I_s; i < I_e; i++) { + PRAGMA_OMP_SIMD() + for (int go = 0; go < G * O; go++) { + const float s = scales[(mask == 0) ? 0 : go]; + int8_t q = qz_b0<in_data_t, out_data_t>()( + input[ld * I * G * O + i * G * O + go], s); + quantized[ld * I * G * O + i * G * O + go] + = (int32_t)q; + comp_ithr[ld * G * O + go] += (int32_t)q; + } + } + } + }); + parallel_nd(L * D * G * O, + [&](int s) { comp[s] = saturate<float>(reduction[s]); }); + for (int i = 1; i < I_nthr; i++) { + parallel_nd(L * D * G * O, [&](int s) { + comp[s] += saturate<float>( + reduction[i * L * D * G * O + s]); + }); + } + } else { + parallel_nd(L * D, G * O, [&](int ld, int go) { + int32_t compensation = 0; + const float s = scales[(mask == 0) ? 0 : go]; + PRAGMA_OMP_SIMD() + for (int i = 0; i < I; i++) { + int8_t q = qz_b0<in_data_t, out_data_t>()( + input[ld * G * O * I + go * I + i], s); + compensation += (int32_t)q; + quantized[ld * G * O * I + go * I + i] = q; + } + comp[ld * G * O + go] = saturate<float>(compensation); + }); + } + + /* Pack */ + auto off_igo = [&](int l, int d, int i, int g, int o) { + return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; + }; + auto off_goi = [&](int l, int d, int i, int g, int o) { + return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; + }; + int n_parts = output_d.rnn_packed_desc().n_parts; + const size_t *size_packed_cell + = output_d.rnn_packed_desc().part_pack_size; + const int *parts = output_d.rnn_packed_desc().parts; + const int n = output_d.rnn_packed_desc().n; + char *to_pack = output; + for (int l = 0; l < L; l++) { + for (int d = 0; d < D; d++) { + for (int p = 0; p < n_parts; p++) { + int g = (p > 0) ? parts[p - 1] : 0; + int m_p = parts[p] * O; + int k_p = I; + cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix, + is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p, + &quantized[is_igo ? off_igo(l, d, 0, g, 0) : + off_goi(l, d, g, 0, 0)], + is_igo ? G * O : I, to_pack); + to_pack += size_packed_cell[p]; + } + } + } +#endif + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +template <> +struct rnn_weights_reorder_t<data_type::f32, data_type::f32> + : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { +#if !USE_MKL_PACKED_GEMM + return status::unimplemented; +#endif + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == data_type::f32 + && od.data_type() == data_type::f32 + && od.format_kind() == format_kind::rnn_packed + && utils::one_of(od.rnn_packed_desc().format, + mkldnn_ldigo_p, mkldnn_ldgoi_p) + && attr->has_default_values(); + if (!args_ok) return status::invalid_arguments; + + format_tag_t itag = id.matches_one_of_tag( + format_tag::ldigo, format_tag::ldgoi); + if (itag == format_tag::undef) return status::invalid_arguments; + + const int mask = attr->rnn_weights_qparams_.mask_; + if (!utils::one_of(mask, 0, 3)) return status::unimplemented; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return out_of_memory; + if (_pd->init() != success) { delete _pd; return unimplemented; } + _pd->itag_ = itag; + return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); + } + + format_tag_t itag_; + }; + +private: + rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { +#if USE_MKL_PACKED_GEMM + auto input = CTX_IN_MEM(const float *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO); + const memory_desc_wrapper &input_d = pd()->src_md(); + const memory_desc_wrapper &output_d = pd()->dst_md(); + const auto &dims = input_d.dims(); + const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc(); + const int L = dims[0]; + const int D = dims[1]; + const int I = dims[2]; + const int G = dims[3]; + const int O = dims[4]; + + /* Pack */ + bool cross_case = false + || (pd()->itag_ == format_tag::ldigo + && rnn_pdata.format == mkldnn_ldgoi_p) + || (pd()->itag_ == format_tag::ldgoi + && rnn_pdata.format == mkldnn_ldigo_p); + auto trans = cross_case ? CblasTrans : CblasNoTrans; + int n_parts = rnn_pdata.n_parts; + const size_t *size_packed_cell = rnn_pdata.part_pack_size; + const int *parts = rnn_pdata.parts; + const int n = rnn_pdata.n; + + const bool is_igo = pd()->itag_ == format_tag::ldigo; + auto off_igo = [&](int l, int d, int i, int g, int o) { + return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o; + }; + auto off_goi = [&](int l, int d, int i, int g, int o) { + return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i; + }; + for (int l = 0; l < L; l++) { + for (int d = 0; d < D; d++) { + for (int p = 0; p < n_parts; p++) { + int g = (p > 0) ? parts[p - 1] : 0; + int m_p = is_igo ? parts[p] * O : I; + int k_p = is_igo ? I : parts[p] * O; + int ld = is_igo ? G * O : I; + cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n, + k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) : + off_goi(l, d, 0, g, 0)], + ld, output); + output += size_packed_cell[p] / sizeof(float); + } + } + } +#endif + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp new file mode 100644 index 0000000000..1d60415cbc --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp @@ -0,0 +1,426 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" + +#include "ref_rnn.hpp" +#include "rnn_utils.hpp" +#include "type_helpers.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::utils; +using namespace rnn_utils; +using namespace format_tag; +using namespace rnn_packed_format; +using namespace data_type; + +bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) { + if (md.format_kind() != format_kind::blocked) + return false; + + auto blk = md.blocking_desc(); + auto str = blk.strides; + auto dims = md.dims(); + return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1 + && str[3] == dims[4] && str[1] == str[2] * dims[2] + && str[0] == str[1] * dims[1]; +}; + +bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) { + if (md.format_kind() != format_kind::blocked) + return false; + + auto blk = md.blocking_desc(); + auto str = blk.strides; + auto dims = md.dims(); + return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1 + && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3] + && str[0] == str[1] * dims[1]; +}; + +void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &src_layer_d, + const memory_desc_wrapper &src_iter_d, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &dst_layer_d) { + rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); + rnn.is_training = utils::one_of( + rd.prop_kind, prop_kind::forward_training, prop_kind::backward); + rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset; + + switch (rd.direction) { + case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break; + case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break; + case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break; + case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break; + default: break; + } + + if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(), + weights_layer_d.data_type())) + rnn.dt_conf = all_f32; + else if (dst_layer_d.data_type() == u8) { + if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) + rnn.dt_conf = u8u8u8u8; + else + rnn.dt_conf = f32u8f32u8; + } else { + if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8)) + rnn.dt_conf = u8u8u8f32; + else + rnn.dt_conf = f32u8f32f32; + } + + rnn.n_layer = weights_layer_d.dims()[0]; + rnn.n_iter = src_layer_d.dims()[0]; + rnn.n_dir = weights_layer_d.dims()[1]; + rnn.n_gates = weights_layer_d.dims()[3]; + rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc); + rnn.n_bias = rnn.n_gates + rnn.is_lbr; + rnn.mb = src_layer_d.dims()[1]; + rnn.sic = weights_iter_d.dims()[2]; + rnn.slc = weights_layer_d.dims()[2]; + rnn.dic = weights_layer_d.dims()[4]; + rnn.dlc = dst_layer_d.dims()[2]; + + rnn.gates_ld = rnn.dic * rnn.n_gates; + rnn.gates_nld = rnn.mb; + rnn.states_nld = rnn.mb; + + /* Set the correct number of weights parts */ + bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru; + rnn.n_parts_weights_layer = 1; + rnn.parts_weights_layer[0] = rnn.n_gates; + rnn.parts_weights_layer[1] = 0; + + rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1; + rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates; + rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0; + + rnn.n_parts_bias = 1; + rnn.parts_bias[0] = rnn.n_bias; + rnn.parts_bias[1] = 0; + + /* Decide wich gemm implementation to use: packed/nonpacked jit/cblas + * and if to mergre gemm across iterations */ + bool is_int8 = rnn.dt_conf != all_f32; + rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd) + || is_int8; + bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru, + alg_kind::gru_linear_before_reset); + rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8; + bool is_inference = !rnn.is_training; + + rnn.use_jit_gemm = !mayiuse(avx512_mic) + && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100)) + || (rnn.is_training && rnn.dic < 500)); + + /* Decide to copy bias */ + rnn.copy_bias = rnn.dt_conf != all_f32; + +#if USE_MKL_PACKED_GEMM + rnn.use_layer_packed_gemm + = (weights_layer_d.format_kind() == format_kind::any + && rnn.slc > 760 && rnn.dic > 760 && is_inference) + || is_int8; // packed gemm is the only supported option for int8 + rnn.use_iter_packed_gemm + = (weights_iter_d.format_kind() == format_kind::any && rnn.sic > 760 + && rnn.dic > 760 && is_inference) + || is_int8; +#else + rnn.use_layer_packed_gemm = false; + rnn.use_iter_packed_gemm = false; +#endif + + /* Set packed gemm sizes */ + if (rnn.use_layer_packed_gemm) { + rnn.weights_layer_pack_size = 0; + for (int p = 0; p < rnn.n_parts_weights_layer; p++) { + int m_p = rnn.is_fwd + ? (rnn.parts_weights_layer[p] * rnn.dic) + : rnn.slc; + int k_p = rnn.is_fwd + ? rnn.slc + : (rnn.parts_weights_layer[p] * rnn.dic); + int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb; + +#if USE_MKL_PACKED_GEMM + if (rnn.dt_conf == all_f32) + rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); + else + rnn.part_weights_layer_pack_size[p] + = cblas_gemm_s8u8s32_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); +#else + UNUSED(m_p); + UNUSED(k_p); + UNUSED(n_p); + rnn.part_weights_layer_pack_size[p] = 0; +#endif + rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir + * rnn.part_weights_layer_pack_size[p]; + } + rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size; + rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer + * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float); + } + + if (rnn.use_iter_packed_gemm) { + rnn.weights_iter_pack_size = 0; + for (int p = 0; p < rnn.n_parts_weights_iter; p++) { + int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) : + rnn.sic; + int k_p = rnn.is_fwd ? rnn.sic : + (rnn.parts_weights_iter[p] * rnn.dic); + int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb; + +#if USE_MKL_PACKED_GEMM + if (rnn.dt_conf == all_f32) + rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); + else + rnn.part_weights_iter_pack_size[p] + = cblas_gemm_s8u8s32_pack_get_size( + CblasAMatrix, m_p, n_p, k_p); +#else + UNUSED(m_p); + UNUSED(k_p); + UNUSED(n_p); + rnn.part_weights_iter_pack_size[p] = 0; +#endif + rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir + * rnn.part_weights_iter_pack_size[p]; + } + rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size; + rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer + * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float); + } + +} + +void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &diff_weights_layer_d, + const memory_desc_wrapper &diff_weights_iter_d) { + + /* Set leading dimensions for input weights arrays depending on input format + */ + rnn.weights_layer_is_packed + = weights_layer_d.format_kind() == format_kind::rnn_packed; + rnn.weights_iter_is_packed + = weights_iter_d.format_kind() == format_kind::rnn_packed; + + auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) { + ld = 0; nld = 0; + if (md.is_blocking_desc()) { + if (is_ldigo(md)) { + ld = (int)md.blocking_desc().strides[2]; + nld = md.dims()[2]; + } else if (is_ldgoi(md)) { + ld = (int)md.blocking_desc().strides[4]; + nld = md.dims()[3] * md.dims()[4]; + } else + assert(!"unsupported weights format"); + } + }; + set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld); + set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld); + if (!rnn.is_fwd) { + set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld, + rnn.diff_weights_layer_nld); + set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld, + rnn.diff_weights_iter_nld); + } + + int sizeof_states_dt + = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t); + rnn.states_ws_ld + = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)), + sizeof_states_dt); + rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float)); + + /* Set workspace sizes to store: + * states to copmute a pass + * diff states to copmute bwd pass (training only) + * intermediate results from the gates + */ + rnn.use_workspace = rnn.is_training; + rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir + * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt; + bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm; + rnn.ws_c_states_size = is_lstm + ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb + * rnn.states_ws_ld * sizeof(float) + : 0; + rnn.ws_diff_states_size = rnn.is_training + ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) + * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld + * sizeof(float) + : (size_t)0; + rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb + * rnn.gates_ws_ld * sizeof(float); + + /* set other sizes */ + rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float); + rnn.ws_cell_comp_size + = rnn.is_lbr || rnn.dt_conf != all_f32 + ? (size_t) rnn.gates_nld * rnn.gates_ws_ld * sizeof(float) + : 0; + rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer + * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float); + rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic + * sizeof(float); +} + +int rnn_utils::get_good_ld(int dim, int sizeof_dt) { + // we want matrices leading dimentions to be 64-byte aligned, + // and not divisible by 256 to avoid 4K aliasing effects + int ld = rnd_up(dim, 64 / sizeof_dt); + return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld; +} + +void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, + size_t &ws_states_offset, size_t &ws_c_states_offset, + size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, + size_t &ws_cell_comp_offset, size_t &ws_bias_offset, + size_t &scratchpad_size, size_t &workspace_size) { + + const size_t page_size = 4096; // 2097152; + size_t current_offset; + /* Mandatory workspaces: go to workspace if use_workspace, scratchpad + * otherwise */ + current_offset = 0; // assumes the workspace base pointer is page aligned + ws_gates_offset = current_offset; + current_offset += rnn.ws_gates_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_states_offset = current_offset; + current_offset += rnn.ws_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_c_states_offset = current_offset; + current_offset += rnn.ws_c_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_diff_states_offset = current_offset; + current_offset += rnn.ws_diff_states_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_grid_comp_offset = current_offset; + current_offset += rnn.ws_grid_comp_size; + + current_offset = utils::rnd_up(current_offset, page_size); + ws_cell_comp_offset = current_offset; + current_offset += rnn.ws_cell_comp_size; + + workspace_size = rnn.use_workspace ? current_offset : 0; + + /* Optional scratchpads */ + // Assumes the scratchpad base pointer is page aligned. + // If use_workspace, the following goes to scratchpad alone, + // otherwise, all goes to scratchpad and continue incrementing offset + current_offset = rnn.use_workspace ? 0 : current_offset; + + if (rnn.copy_bias) { + current_offset = utils::rnd_up(current_offset, page_size); + ws_bias_offset = current_offset; + current_offset += rnn.ws_bias_size; + } + + scratchpad_size = current_offset; +} + +void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, + size_t &scratchpad_size, size_t &workspace_size) { + size_t ws_gates_offset, ws_states_offset, ws_c_states_offset, + ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, + ws_bias_offset; + set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_diff_states_offset, + ws_c_states_offset, ws_grid_comp_offset, ws_cell_comp_offset, + ws_bias_offset, scratchpad_size, workspace_size); +} + +status_t rnn_utils::set_good_strides( + memory_desc_t &weights_md, format_tag_t tag) { + auto &strides = weights_md.format_desc.blocking.strides; + auto dims = weights_md.dims; + + if (tag == ldigo) { + strides[2] = rnn_utils::get_good_ld((int)strides[2], + (int)types::data_type_size(weights_md.data_type)); + strides[1] = dims[2] * strides[2]; + strides[0] = dims[1] * strides[1]; + } else if (tag == ldgoi) { + strides[4] = rnn_utils::get_good_ld((int)strides[4], + (int)types::data_type_size(weights_md.data_type)); + strides[3] = dims[4] * strides[4]; + strides[1] = dims[3] * strides[3]; + strides[0] = dims[1] * strides[1]; + } else + return status::unimplemented; + + return status::success; +} + +status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn, + memory_desc_t &weights_md, bool is_iter) { + using namespace format_tag; + bool use_packed_gemm = is_iter + ? rnn.use_iter_packed_gemm + : rnn.use_layer_packed_gemm; + if (use_packed_gemm) { + weights_md.format_kind = format_kind::rnn_packed; + rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc; + rnn_pdata.format = rnn.is_fwd ? mkldnn_ldigo_p : mkldnn_ldgoi_p; + if (is_iter) { + rnn_pdata.n = rnn.mb; + rnn_pdata.n_parts = rnn.n_parts_weights_iter; + array_copy(rnn_pdata.parts, rnn.parts_weights_iter, + MKLDNN_RNN_MAX_N_PARTS); + array_copy(rnn_pdata.part_pack_size, + rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS); + rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset; + rnn_pdata.size = rnn.weights_iter_pack_size; + } else { + rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb; + rnn_pdata.n_parts = rnn.n_parts_weights_layer; + array_copy(rnn_pdata.parts, rnn.parts_weights_layer, + MKLDNN_RNN_MAX_N_PARTS); + array_copy(rnn_pdata.part_pack_size, + rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS); + rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset; + rnn_pdata.size = rnn.weights_layer_pack_size; + } + } else { + CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + // Adjust strides for good leading dimension in GEMM + CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi)); + } + return status::success; +} + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp new file mode 100644 index 0000000000..99eb787a64 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp @@ -0,0 +1,225 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef RNN_UTILS_HPP +#define RNN_UTILS_HPP + +#include "mkldnn.h" + +#include "cpu_rnn_pd.hpp" + + +#define rnn_elemwise_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, acc_data_t *ws_gates_, \ + src_data_t *states_t_l_, float *c_states_t_l_, \ + src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ + float *diff_states_t_l_, float *diff_states_t_lp1_, \ + float *diff_states_tp1_l_, float *bias_, float *ws_grid_, \ + float *ws_cell_) const + +#define rnn_cell_execution_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, src_data_t *states_t_l_, \ + float *c_states_t_l_, float *diff_states_t_l_, \ + weights_data_t **w_layer_, weights_data_t **w_iter_, \ + float **bias_, src_data_t *states_t_lm1_, \ + src_data_t *states_tm1_l_, float *c_states_tm1_l_, \ + float *diff_states_t_lp1_, float *diff_states_tp1_l_, \ + float *diff_w_layer_, float *diff_w_iter_, float *diff_bias_, \ + acc_data_t *ws_gates_, float *ws_grid_, float *ws_cell_) const + +#define rnn_grid_execution_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, weights_data_t **weights_layer_, \ + weights_data_t **weights_states_, float **bias_, \ + src_data_t *ws_states_, float *ws_c_states_, \ + float *ws_diff_states_, acc_data_t *ws_gates_, float *ws_cell_, \ + float *ws_grid_, float *diff_weights_layer_, \ + float *diff_weights_iter_, float *diff_bias_) const + +#define rnn_gemm_sig(f) \ + void f(const char transA, const char transB, int m, int n, int k, \ + const float alpha, const weights_data_t *a_, const int ldA, \ + const src_data_t *b_, const int ldB, const float beta, \ + acc_data_t *c_, const int ldC) const + +#define rnn_bias_prepare_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \ + float *scratch_bias_) const + +#define rnn_bias_finalize_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \ + const float *w_iter_comp, const float *w_layer_comp) const + +#define rnn_weights_assign_sig(f) \ + void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, int nld, \ + int ld, int OC_size, int IC_size, const int n_parts, \ + const int *gates_per_part, const size_t *part_weights_pack_size, \ + weights_data_t **weights_, const weights_data_t *w_, \ + float **bias_, const float *b_, float *scratch_bias_) const + + +namespace mkldnn { +namespace impl { +namespace cpu { + +namespace rnn_utils { + +using namespace mkldnn::impl::utils; + +enum execution_direction_t { + l2r, + r2l, + bi_concat, + bi_sum, +}; + +enum data_type_conf_t { + all_f32, + u8u8u8f32, + f32u8f32f32, + u8u8u8u8, + f32u8f32u8 +}; + +struct rnn_conf_t { + execution_direction_t exec_dir; + data_type_conf_t dt_conf; + int n_layer, n_iter, n_dir, n_gates, n_states; + int mb; + int slc, sic, dic, dlc; + int gates_ld, gates_nld, gates_ws_ld; + int n_parts_weights_layer, parts_weights_layer[MKLDNN_RNN_MAX_N_PARTS]; + int n_parts_weights_iter, parts_weights_iter[MKLDNN_RNN_MAX_N_PARTS]; + int n_bias, n_parts_bias, parts_bias[MKLDNN_RNN_MAX_N_PARTS]; + size_t part_weights_iter_pack_size[MKLDNN_RNN_MAX_N_PARTS], + part_weights_layer_pack_size[MKLDNN_RNN_MAX_N_PARTS]; + bool weights_layer_is_packed, weights_iter_is_packed; + /* Size of packed data in bytes */ + size_t weights_layer_comp_offset, weights_layer_pack_size, + weights_iter_comp_offset, weights_iter_pack_size; + + bool copy_bias; + int weights_layer_ld, weights_layer_nld; + int diff_weights_layer_ld, diff_weights_layer_nld; + int weights_iter_ld, weights_iter_nld; + int diff_weights_iter_ld, diff_weights_iter_nld; + int states_nld, states_ws_ld; + int weights_iter_compensation_size, weights_layer_compensation_size; + bool is_fwd, is_training, is_lbr; + bool use_workspace; + + /* Size of workspace for each tensor in bytes */ + size_t ws_gates_size, ws_states_size, ws_c_states_size, ws_diff_states_size, + ws_cell_comp_size, ws_grid_comp_size, ws_per_cell, ws_bias_size; + bool merge_gemm_iter, merge_gemm_layer, use_jit_gemm, use_layer_packed_gemm, + use_iter_packed_gemm; +}; + +bool is_ldigo(const memory_desc_wrapper &md); +bool is_ldgoi(const memory_desc_wrapper &md); + +int get_good_ld(int dim, int sizeof_dt); + +void init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &src_layer_d, + const memory_desc_wrapper &src_iter_d, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &dst_layer_d); + +void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, + const memory_desc_wrapper &weights_layer_d, + const memory_desc_wrapper &weights_iter_d, + const memory_desc_wrapper &diff_weights_layer_d, + const memory_desc_wrapper &diff_weights_iter_d); + +void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, + size_t &ws_h_state_offset, size_t &ws_c_state_offset, + size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset, + size_t &ws_cell_comp_offset, size_t &ws_bias_offset, + size_t &scratchpad_size, size_t &workspace_size); + +void get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, + size_t &scratchpad_size, size_t &workspace_size); +status_t set_expected_desc( + rnn_conf_t &rnn, memory_desc_t &weights_md, bool is_iter); +status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag); + +template <typename T> +struct ws_gates_aoc { + ws_gates_aoc(const rnn_conf_t &rnn, T *data) + : gates_(data, rnn.gates_nld, rnn.gates_ws_ld), DIC_(rnn.dic) {} + T &operator()(int batch, int gate, int dic) { + return gates_(batch, gate * DIC_ + dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator<T, 2> gates_; + int DIC_; +}; +using ws_gates_aoc_t = ws_gates_aoc<float>; +using ws_gates_aoc_s32_t = ws_gates_aoc<int32_t>; + +struct bias_aoc_t { + bias_aoc_t(const rnn_conf_t &rnn, const float *data) + : bias_(data, rnn.n_bias, rnn.dic) {} + const float &operator()(int bias_n, int dic) { return bias_(bias_n, dic); } + +private: + mkldnn::impl::utils::array_offset_calculator<const float, 2> bias_; +}; + +template <typename T> +struct ws_states_aoc { + ws_states_aoc(const rnn_conf_t &rnn, T *data) + : state_(data, rnn.states_nld, rnn.states_ws_ld) {} + T &operator()(int batch, int dic) { return state_(batch, dic); } + +private: + mkldnn::impl::utils::array_offset_calculator<T, 2> state_; +}; +using ws_states_aoc_t = ws_states_aoc<float>; +using ws_states_aoc_u8_t = ws_states_aoc<uint8_t>; + +struct ws_diff_states_aoc_t { + ws_diff_states_aoc_t(const rnn_conf_t &rnn, float *data) + : diff_states_(data, rnn.n_states + 1, rnn.n_iter + 1, rnn.states_nld, + rnn.states_ws_ld) {} + float &operator()(int state_n, int batch, int dic) { + return diff_states_(state_n, 0, batch, dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator<float, 4> diff_states_; +}; + +struct ws_diff_w_iter_aoc_t { + ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data) + : diff_weights_iter_( + data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld) + , DIC_(rnn.dic) {} + float &operator()(int sic, int gate, int dic) { + return diff_weights_iter_(sic, gate * DIC_ + dic); + } + +private: + mkldnn::impl::utils::array_offset_calculator<float, 2> diff_weights_iter_; + int DIC_; +}; +} +} +} +} +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp new file mode 100644 index 0000000000..0420f87aa5 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp @@ -0,0 +1,126 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" + +#include "simple_concat.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +template <data_type_t data_type> +status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const { + auto scratchpad = this->scratchpad(ctx); + auto iptrs = scratchpad.template get<const data_t *>(key_concat_iptrs); + auto optrs = scratchpad.template get<data_t *>(key_concat_optrs); + auto nelems_to_copy = scratchpad.template get<dim_t>(key_concat_nelems); + auto is = scratchpad.template get<strides_t>(key_concat_istrides); + + const int num_arrs = pd()->n_inputs(); + const int *perm = pd()->perm_, *iperm = pd()->iperm_; + const int concat_dim = pd()->concat_dim(); + auto o_base_ptr = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + for (int a = 0; a < num_arrs; ++a) { + const memory_desc_wrapper i_d(pd()->src_md(a)); + const memory_desc_wrapper o_d(pd()->src_image_md(a)); + + iptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a) + + i_d.blk_off(0); + optrs[a] = o_base_ptr + o_d.blk_off(0); + nelems_to_copy[a] = pd()->nelems_to_concat(i_d); + for (int i = 0; i < MKLDNN_MAX_NDIMS; i++) { + if (i < perm[concat_dim]) + is[a][i] = size_t(i_d.blocking_desc().strides[iperm[i]]); + else + is[a][i] = 0; + } + } + + const memory_desc_wrapper o_d(pd()->src_image_md(0)); + + strides_t os = { 0 }; + for (int i = 0; i < perm[concat_dim]; i++) + os[i] = o_d.blocking_desc().strides[iperm[i]]; + + dims_t phys_dims; + for (size_t i = 0; i < sizeof(phys_dims)/sizeof(phys_dims[0]); i++) + phys_dims[i] = (i < (size_t)perm[concat_dim]) + ? o_d.dims()[iperm[i]] / pd()->blocks_[iperm[i]] : 1; + + if (perm[concat_dim] == 0) { + for (int a = 0; a < num_arrs; ++a) { + const data_t *i = &iptrs[a][0]; + data_t *o = &optrs[a][0]; + parallel_nd((ptrdiff_t)nelems_to_copy[a], + [&](ptrdiff_t e) { o[e] = i[e]; }); + } + } else { + parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3], + phys_dims[4], num_arrs, + [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, int a) { + // XXX: this code may access uninitialized values in is[*][0-4] -- + // that's why we have to set them to zero although this is + // probably benign + size_t in_off = is[a][0] * n0 + is[a][1] * n1 + is[a][2] * n2 + + is[a][3] * n3 + is[a][4] * n4; + size_t out_off = os[0] * n0 + os[1] * n1 + os[2] * n2 + + os[3] * n3 + os[4] * n4; + const data_t *i = &iptrs[a][in_off]; + data_t *o = &optrs[a][out_off]; +#if defined(__GNUC__) && !defined(__INTEL_COMPILER) + // The code below performs data copying: o[e] = i[e] + // and uses a workaround to make GNU compilers optimize it + uint8_t *ptro = reinterpret_cast<uint8_t *>(o); + const uint8_t *ptri = reinterpret_cast<const uint8_t *>(i); + const dim_t main_part = + nelems_to_copy[a] * sizeof(data_t) / sizeof(uint32_t); + const dim_t tail_part = + nelems_to_copy[a] % sizeof(data_t) / sizeof(uint32_t); + + PRAGMA_OMP_SIMD() + for (dim_t e = 0; e < main_part; ++e) { + *(reinterpret_cast<uint32_t *>(ptro)) + = *(reinterpret_cast<const uint32_t *>(ptri)); + ptro += sizeof(uint32_t); + ptri += sizeof(uint32_t); + } + for (dim_t e = 0; e < tail_part; ++e) { + *ptro = *ptri; + ++ptro; + ++ptri; + } +#else + PRAGMA_OMP_SIMD() + for (dim_t e = 0; e < nelems_to_copy[a]; ++e) o[e] = i[e]; +#endif + }); + } + + return status::success; +} + +template struct simple_concat_t<data_type::f32>; +template struct simple_concat_t<data_type::u8>; +template struct simple_concat_t<data_type::s8>; +template struct simple_concat_t<data_type::s32>; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp new file mode 100644 index 0000000000..5177275452 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp @@ -0,0 +1,155 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SIMPLE_CONCAT_HPP +#define SIMPLE_CONCAT_HPP + +#include "memory_tracking.hpp" + +#include "cpu_concat_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <data_type_t data_type> +struct simple_concat_t: public cpu_primitive_t { + struct pd_t: public cpu_concat_pd_t { + using cpu_concat_pd_t::cpu_concat_pd_t; + + pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { + int ndims = rhs.dst_md_.ndims; + utils::array_copy(perm_, rhs.perm_, ndims); + utils::array_copy(iperm_, rhs.iperm_, ndims); + utils::array_copy(blocks_, rhs.blocks_, ndims); + } + + DECLARE_CONCAT_PD_T("simple:any", simple_concat_t); + + status_t init() { + const memory_desc_wrapper dst_d(dst_md()); + bool ok = true + && cpu_concat_pd_t::init() == status::success + && dst_d.ndims() <= 6; + if (!ok) return status::unimplemented; + + for (size_t i = 0; i < src_mds_.size(); ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + const memory_desc_wrapper o_d(&src_image_mds_[i]); + + const int ignore_strides = 0; + + ok = ok + && utils::everyone_is(data_type, i_d.data_type(), + o_d.data_type()) + && utils::everyone_is(format_kind::blocked, + i_d.format_kind(), o_d.format_kind()) + && types::blocking_desc_is_equal(i_d.blocking_desc(), + o_d.blocking_desc(), ignore_strides) + && types::blocking_desc_is_equal(i_d.blocking_desc(), + dst_d.blocking_desc(), ignore_strides) + && !i_d.is_additional_buffer(); + if (!ok) return status::unimplemented; + } + + dst_d.compute_blocks(blocks_); + format_perm(); + + // start dim is the first dimension after which the concatenation + // would happen contiguously + const int start_dim = perm_[concat_dim()]; + + // check that contiguous part is indeed contiguous (i.e. dense) + if (nelems_to_concat(dst_d) != + dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()] + * dst_d.blocking_desc().strides[concat_dim()]) + return status::unimplemented; + + // check that all inputs have the same strides for the + // contiguous part [concat_dim .. ndims] for the *major* dims. + // the block part is already checked above + for (size_t i = 0; i < src_mds_.size(); ++i) { + const memory_desc_wrapper i_d(&src_mds_[i]); + for (int d = start_dim; d < dst_d.ndims(); ++d) { + if (dst_d.blocking_desc().strides[iperm_[d]] + != i_d.blocking_desc().strides[iperm_[d]]) + return status::unimplemented; + } + } + + init_scratchpad(); + + return status::success; + } + + int perm_[MKLDNN_MAX_NDIMS]; + int iperm_[MKLDNN_MAX_NDIMS]; + dims_t blocks_; + + dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const { + const int ndims = data_d.ndims(); + + dim_t nelems = 1; + for (int i = perm_[concat_dim()]; i < ndims; i++) + nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]]; + for (int i = 0; i < ndims; i++) + nelems *= blocks_[i]; + + return nelems; + } + + private: + void format_perm() { + const memory_desc_wrapper dst_d(dst_md()); + const int ndims = dst_d.ndims(); + + strides_t strides; + utils::array_copy(strides, dst_d.blocking_desc().strides, ndims); + for (int i = 0; i < ndims; i++) iperm_[i] = i; + + utils::simultaneous_sort(strides, iperm_, ndims, + [](stride_t a, stride_t b) { return b - a; }); + + for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i; + } + + void init_scratchpad() { + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs()); + scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs()); + scratchpad.book(key_concat_nelems, sizeof(dim_t) * n_inputs()); + scratchpad.book(key_concat_istrides, + sizeof(strides_t) * n_inputs()); + } + }; + + simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override; + + typedef typename prec_traits<data_type>::type data_t; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp new file mode 100644 index 0000000000..e6c3b8d7af --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp @@ -0,0 +1,98 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SIMPLE_Q10N_HPP +#define CPU_SIMPLE_Q10N_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "math_utils.hpp" +#include "nstl.hpp" +#include "type_helpers.hpp" +#include "utils.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::math; + +template <typename out_t> +inline out_t round_and_saturate(float f) +{ return math::saturate<out_t>(out_round<int>(f)); } + +/* Quantization with alpha == 1 and beta == 0 */ +template <typename in_t, typename out_t, typename enabled = void> +struct qz_a1b0 { + out_t operator()(in_t in) + { return round_and_saturate<out_t>((float)in); } +}; + +template <typename in_t, typename out_t> +struct qz_a1b0<in_t, out_t, + typename utils::enable_if<true + && nstl::is_integral<in_t>::value + && !is_subset<in_t, out_t>::value + >::type> { + out_t operator()(in_t in) { return math::saturate<out_t>(in); } +}; + +template <typename in_t, typename out_t> +struct qz_a1b0<in_t, out_t, + typename utils::enable_if<is_subset<in_t, out_t>::value>::type> { + out_t operator()(in_t in) { return (out_t)in; } +}; + +/* Quantization with alpha == 1 */ +template <typename in_t, typename out_t> struct qz_a1 { + out_t operator()(in_t in, out_t out, float beta) + { return round_and_saturate<out_t>((float)in + beta * out); } +}; + +template <typename in_t> struct qz_a1<in_t, float> { + float operator()(in_t in, float out, float beta) + { return (float)in + beta * out; } +}; + +/* Quantization with beta == 0 */ +template <typename in_t, typename out_t> struct qz_b0 { + out_t operator()(in_t in, float alpha) + { return round_and_saturate<out_t>(alpha * in); } +}; + +template <typename in_t> struct qz_b0<in_t, float> { + float operator()(in_t in, float alpha) { return alpha * in; } +}; + +/* Quantization */ +template <typename in_t, typename out_t> struct qz { + out_t operator()(in_t in, out_t out, float alpha, float beta) { + return round_and_saturate<out_t>( + alpha * in + (beta ? beta * out : 0)); + } +}; + +template <typename in_t> struct qz<in_t, float> { + float operator()(in_t in, float out, float alpha, float beta) + { return alpha * in + (beta ? beta * out : 0); } +}; + +} +} +} + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp new file mode 100644 index 0000000000..ff845f5bd3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp @@ -0,0 +1,1022 @@ +/******************************************************************************* +* Copyright 2016-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_SIMPLE_REORDER_HPP +#define CPU_SIMPLE_REORDER_HPP + +#include <assert.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" +#include "math_utils.hpp" +#include "mkldnn_thread.hpp" +#include "utils.hpp" + +#include "tag_traits.hpp" +#include "cpu_reorder_pd.hpp" +#include "cpu_primitive.hpp" + +#include "simple_q10n.hpp" +#include "cpu_isa_traits.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace mkldnn::impl::status; +using namespace mkldnn::impl::format_tag; +using namespace mkldnn::impl::data_type; + +using bd = block_dim_t; +using ib = inner_blk_t; + +using namespace mkldnn::impl::utils; +using math::saturate; + +template<impl::data_type_t type> +using data_t = typename prec_traits<type>::type; + +template<impl::data_type_t type_i, impl::data_type_t type_o> +using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>; + +template<impl::data_type_t type_i, impl::data_type_t type_o> +using _qz = qz<data_t<type_i>, data_t<type_o>>; + +namespace fmt_order { + const bool keep = true; + const bool reverse = false; + const bool any = keep; +} + +namespace spec { +struct direct_copy {}; +struct direct_copy_except_dim_0 {}; +struct reference {}; +struct conv_s8s8 {}; +} + +#define SIMPLE_REORDER_TEMPL_DECL \ + impl::data_type_t type_i, impl::format_tag_t tag_i, \ + impl::data_type_t type_o, impl::format_tag_t tag_o, bool order_keep +#define SIMPLE_REORDER_TEMPL_CALL \ + type_i, tag_i, type_o, tag_o, order_keep + +#define DECLARE_COMMON_PARAMS() \ + const memory_desc_wrapper &input_d = pd->src_md(); \ + const memory_desc_wrapper &output_d = pd->dst_md(); \ + const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \ + const float beta = pd->beta(); MAYBE_UNUSED(beta); + +/* specific reorders: common template */ +template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void> +struct simple_reorder_impl {}; + +namespace { +inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i, + impl::format_tag_t tag_o, const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d) { + return input_d.matches_tag(order_keep ? tag_i : tag_o) + && output_d.matches_tag(order_keep ? tag_o : tag_i); +} +inline bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) { + if (many_scales_support) + return true; + return IMPLICATION(attr, attr->output_scales_.mask_ == 0); +} +} + +/* specific reorders: implementation */ +template <SIMPLE_REORDER_TEMPL_DECL> +struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, +typename utils::enable_if<tag_i == any && (false + || tag_o == hwio + || tag_o == hwigo) + , spec::conv_s8s8>::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const int oc = (input_d.dims()[tag_o == hwigo + 0]); + const int g = (tag_o == hwigo) ? (input_d.dims()[0]) : 1; + + return output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t<type_i> *input, data_t<type_o> *output) { + DECLARE_COMMON_PARAMS(); + + static constexpr bool w_groups = tag_o == hwigo; + + const auto &dims = input_d.dims(); + const auto &pdims = output_d.padded_dims(); + + const int G = w_groups ? dims[0] : 1; + const int OC = dims[w_groups + 0]; + const int IC = dims[w_groups + 1]; + const int H = dims[w_groups + 2]; + const int W = dims[w_groups + 3]; + + const float *scales = pd->attr()->output_scales_.scales_; + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? output_d.extra().scale_adjust : 1.f; + + size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W; + int32_t *cp = reinterpret_cast<int32_t *>(output + offset); + + parallel_nd(G, OC, [&](int g, int oc) { + cp[g * OC + oc] = 0; + for (int ic = 0; ic < IC; ic++) + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) { + auto i = input[input_d.blk_off<!w_groups>(g, oc, ic, h, w)]; + auto &o = output[output_d.blk_off<!w_groups>(g, oc, ic, h, w)]; + const float s = scales[(D_mask == 1) ? 0 : g * OC + oc]; + + o = qz_b0<data_t<type_i>, data_t<type_o>>()( + i, s * adj_scale); + cp[g * OC + oc] -= (int32_t)o; + } + cp [g * OC + oc] *= 128; + }); + return success; + } +}; + +template <SIMPLE_REORDER_TEMPL_DECL> +struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, + typename utils::enable_if< + (tag_i == goiw && tag_o == gOIw4i16o4i) + || (tag_i == oiw && tag_o == OIw4i16o4i) + || (tag_i == goihw && tag_o == gOIhw4i16o4i) + || (tag_i == oihw && tag_o == OIhw4i16o4i) + || (tag_i == goihw && tag_o == gOIhw2i8o4i) + || (tag_i == goihw && tag_o == gOIhw4o4i) + , spec::conv_s8s8>::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const bool w_groups = !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i); + const int oc = (input_d.dims()[w_groups ? 1 : 0]); + const int g = w_groups ? input_d.dims()[0] : 1; + + return input_d.matches_tag(tag_i) + && output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t<type_i> *input, data_t<type_o> *output) { + DECLARE_COMMON_PARAMS(); + + static constexpr bool w_groups = + !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i); + constexpr int is_1d = + utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i); + constexpr int blksize = tag_traits<tag_o>::inner_blks == ib::_4b4c + ? 4 + : tag_traits<tag_o>::inner_blks == ib::_2c8b4c + ? 8 + : 16; + + const auto &_g_oihw_d = order_keep ? input_d : output_d; + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + const int G = w_groups ? dims[0] : 1; + const int OC = dims[w_groups + 0]; + const int NB_OC = pdims[w_groups + 0] / blksize; + const int IC = dims[w_groups + 1]; + const int NB_IC = pdims[w_groups + 1] / blksize; + const int H = is_1d ? 1 : dims[w_groups + 2]; + const int W = dims[w_groups + 3 - is_1d]; + + const float *scales = pd->attr()->output_scales_.scales_; + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? output_d.extra().scale_adjust : 1.f; + + auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out, + int32_t *c, const float *s, const int oc_block, const int ic_block) { +# define index AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks> + + for (int ic = 0; ic < ic_block; ++ic) { + for (int oc = 0; oc < oc_block; ++oc) { + const auto _g_oihw_off = + oc * _g_oihw_d.blocking_desc().strides[w_groups + 0] + + ic * _g_oihw_d.blocking_desc().strides[w_groups + 1]; + out[index(oc, ic)] + = qz_b0<data_t<type_i>, data_t<type_o>>()( + inp[_g_oihw_off], s[oc] * adj_scale); + c[oc] -= (128 * (int32_t)(out[index(oc, ic)])); + } + } +# undef index + }; + + constexpr int i_mult = blksize; + constexpr int o_mult = 1; + + size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W; + int32_t *cp = reinterpret_cast<int32_t *>(output + offset); + parallel_nd(G * NB_OC * blksize, [&](int i) { + cp[i] = 0; + }); + +# define wei_blk_off(md, g, o, i, h, w) \ + (is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \ + : (md).blk_off<!w_groups>(g, o, i, h, w)) + + parallel_nd(G, NB_OC, [&](int g, int O) { + for (int I = 0; I < NB_IC; I++) + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) { + auto i = &input[wei_blk_off( + input_d, g, i_mult * O, i_mult * I, h, w)]; + auto o = &output[wei_blk_off( + output_d, g, o_mult * O, o_mult * I, h, w)]; + const int oc_block = nstl::min(blksize, OC - O * blksize); + const int ic_block = nstl::min(blksize, IC - I * blksize); + + int _offset = (g * NB_OC + O) * blksize; + ker(i, o, (order_keep) ? &cp[_offset] : nullptr, + &scales[(D_mask == 1) ? 0 : _offset], + oc_block, ic_block); + } + }); + +# undef wei_blk_off + + return success; + } +}; + +template <SIMPLE_REORDER_TEMPL_DECL> +struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, + typename utils::enable_if<false + ||(tag_i == goiw && tag_o == Goiw16g) + ||(tag_i == goihw && tag_o == Goihw16g) + , spec::conv_s8s8>::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(attr->output_scales_.mask_ + 1)); + const int oc = input_d.dims()[1]; + const int g = input_d.dims()[0]; + + return true + && order_keep + && input_d.matches_tag(tag_i) + && output_d.matches_tag(tag_o) + && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8) + && (input_d.data_type() == f32 || input_d.data_type() == s8) + && output_d.data_type() == s8 + && (D_mask == 1 || D_mask == (size_t)g * oc); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t<type_i> *input, data_t<type_o> *output) { + DECLARE_COMMON_PARAMS(); + + constexpr bool is_1d = tag_i == goiw; + constexpr int blksize = 16; + + const auto &dims = input_d.dims(); + const auto &pdims = output_d.padded_dims(); + const int G = dims[0]; + const int Gp = pdims[0]; + const int OC = dims[1]; + const int IC = dims[2]; + const int H = is_1d ? 1 : dims[3]; + const int W = dims[4 - is_1d]; + + const size_t D_mask = utils::array_product(input_d.dims(), + math::ilog2q(pd->attr()->output_scales_.mask_ + 1)); + const float *scales = pd->attr()->output_scales_.scales_; + + assert(output_d.extra().flags + & memory_extra_flags::compensation_conv_s8s8); + float adj_scale = + (output_d.extra().flags & memory_extra_flags::scale_adjust) + ? output_d.extra().scale_adjust : 1.f; + + auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out, + int32_t *cp, const float *s, const int g_block) { + PRAGMA_OMP_SIMD() + for (int g = 0; g < g_block; g++) { + const auto i_off = g * input_d.blocking_desc().strides[0]; + out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()( + inp[i_off], s[g * OC] * adj_scale); + cp[g * OC] -= 128 * (int32_t)(out[g]); + } + }; + + size_t cp_offset = output_d.size() - output_d.additional_buffer_size(); + int32_t *cp = reinterpret_cast<int32_t *>(output + cp_offset); + parallel_nd((Gp/blksize) * OC, [&](int ib) { + PRAGMA_OMP_SIMD() + for (int i = 0; i < blksize; i++) + cp[ib * blksize + i] = 0; + }); + +# define wei_blk_off(md, g, o, i, h, w) \ + (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w)) + + parallel_nd(Gp/blksize, OC, [&](int gb, int O) { + for (int I = 0; I < IC; I++) { + for (int h = 0; h < H; h++) + for (int w = 0; w < W; w++) + { + const int g_block = nstl::min(G - gb * blksize, blksize); + const auto inp = &input[wei_blk_off( + input_d, gb * blksize, O, I, h, w)]; + const auto out = &output[wei_blk_off( + output_d, gb, O, I, h, w)]; + int offset = gb * blksize + O; + ker(inp, out, &cp[offset], + &scales[(D_mask == 1) ? 0 : offset], g_block); + } + } + }); + +# undef wei_blk_off + + return success; + } +}; + +/* reorders with tail support */ + +template <SIMPLE_REORDER_TEMPL_DECL> +struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, +typename utils::enable_if<false + || (tag_i == nCdhw8c && tag_o == nCdhw16c) + || (tag_i == nChw8c && tag_o == nChw16c) + || (tag_i == nCw8c && tag_o == nCw16c) + >::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) + { + return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d) + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t<type_i> *input, data_t<type_o> *output) { + DECLARE_COMMON_PARAMS(); + + constexpr int is_1d = tag_i == nCw8c; + constexpr int is_3d = tag_i == nCdhw8c; + constexpr int blksize_16 = 16; + constexpr int blksize_8 = 8; + constexpr int ic_mult = order_keep ? 2 : 1; + constexpr int oc_mult = order_keep ? 1 : 2; + + const auto &dims = input_d.dims(); + const auto &pdims = order_keep ? output_d.padded_dims() + : input_d.padded_dims(); + + const int C = dims[1]; + const int D = is_3d ? dims[2] : 1; + const int H = is_1d ? 1 : dims[2 + is_3d]; + const int W = dims[3 + is_3d - is_1d]; + + auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o, + const int block_16) { + const int nb = (block_16 - 1) / blksize_8 + 1; + if (alpha == 1.0 && beta == 0.0) { + for (int b = 0; b < nb; ++b) { + const ptrdiff_t i_off = order_keep ? b : b * blksize_8; + const ptrdiff_t o_off = order_keep ? b * blksize_8 : b; + const int block_8 = nstl::min(blksize_8, + block_16 - b * blksize_8); + for (int c = 0; c < block_8; ++c) { + o[o_off + c] = _qz_a1b0<type_i, type_o>()( + i[i_off + c]); + } + } + } else { + for (int b = 0; b < nb; ++b) { + const ptrdiff_t i_off = order_keep ? b : b * blksize_8; + const ptrdiff_t o_off = order_keep ? b * blksize_8 : b; + const int block_8 = nstl::min(blksize_8, + block_16 - b * blksize_8); + for (int c = 0; c < block_8; ++c) { + o[o_off + c] = _qz<type_i, type_o>()(i[i_off + c], + o[o_off + c], alpha, beta); + } + } + } + }; + +# define data_blk_off(md, n, c, d, h, w) \ + ( is_1d ? (md).blk_off(n, c, w) \ + : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w)) + + parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W, + [&](int n, int nb_c, int d, int h, int w) { + auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)]; + auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)]; + const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16); + ker(i, o, block_16); + }); + +# undef data_blk_off + + return success; + } +}; + +#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \ + static bool is_applicable(const memory_desc_wrapper &input_d, \ + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \ + return simple_attr_check(attr, false) && (order_keep \ + ? output_d.matches_tag(tag_o) && input_d.is_plain() \ + : input_d.matches_tag(tag_o) && output_d.is_plain()); \ + } + +template <SIMPLE_REORDER_TEMPL_DECL> +struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, +typename utils::enable_if<tag_i == any + && (tag_traits<tag_o>::block_dims == bd::_A + || tag_traits<tag_o>::block_dims == bd::_B) + && tag_traits<tag_o>::ndims >= 3 + && tag_traits<tag_o>::ndims <= 6 + >::type> +{ + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t<type_i> *input, data_t<type_o> *output) { + DECLARE_COMMON_PARAMS(); + + const auto &flat_d = order_keep ? input_d : output_d; + const auto &block_d = order_keep ? output_d : input_d; + const auto &dims = input_d.dims(); + const auto &pdims = block_d.padded_dims(); + + constexpr int ndims = tag_traits<tag_o>::ndims; + constexpr int blk_idx = tag_traits<tag_o>::block_dims == bd::_A ? 0 : 1; + + const dim_t H0 = dims[0]; + const dim_t H1 = dims[1]; + const dim_t M0 = ndims >= 6 ? dims[ndims - 4] : 1; + const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1; + const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1; + const dim_t L = dims[ndims - 1]; + const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1]; + + constexpr int blksize = false ? 0 + : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_4a, ib::_4b) ? 4 + : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_8a, ib::_8b) ? 8 + : 16; + + auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o, int block) { + if (alpha == 1.0 && beta == 0.0) { + for (int l = 0; l < L; ++l) + for (int blk = 0; blk < block; ++blk) { + const dim_t flat_off = 0 + + blk * flat_d.blocking_desc().strides[blk_idx] + + l * flat_d.blocking_desc().strides[ndims - 1]; + if (order_keep) { + o[l * l_blk_stride + blk] = _qz_a1b0<type_i, type_o>()( + i[flat_off]); + } else { + o[flat_off] = _qz_a1b0<type_i, type_o>()( + i[l * l_blk_stride + blk]); + } + } + } else { + for (int l = 0; l < L; ++l) + for (int blk = 0; blk < block; ++blk) { + const dim_t flat_off = 0 + + blk * flat_d.blocking_desc().strides[blk_idx] + + l * flat_d.blocking_desc().strides[ndims - 1]; + if (order_keep) { + o[l * l_blk_stride + blk] = _qz<type_i, type_o>()( + i[flat_off], o[l * blksize + blk], + alpha, beta); + } else { + o[flat_off] = _qz<type_i, type_o>()( + i[l * l_blk_stride + blk], o[flat_off], + alpha, beta); + } + } + } + }; + +# define off(md, h0, h1, m0, m1, m2) \ + (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \ + : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \ + : ndims >= 4 ? (md).blk_off(h0, h1, m2) \ + : /* ndims >= 3 ? */ (md).blk_off(h0, h1)) + + constexpr int i_mult = order_keep ? blksize : 1; + constexpr int o_mult = order_keep ? 1 : blksize; + + if (blk_idx == 0) { + const dim_t BH0 = pdims[0] / blksize; + parallel_nd(BH0, H1, M0, M1, M2, + [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, bh0 * i_mult, h1, m0, m1, m2)]; + auto o = &output[off(output_d, bh0 * o_mult, h1, m0, m1, m2)]; + const int block = nstl::min<int>(blksize, H0 - bh0 * blksize); + ker(i, o, block); + }); + } else if (blk_idx == 1) { + const dim_t BH1 = pdims[1] / blksize; + parallel_nd(H0, BH1, M0, M1, M2, + [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, h0, bh1 * i_mult, m0, m1, m2)]; + auto o = &output[off(output_d, h0, bh1 * o_mult, m0, m1, m2)]; + const int block = nstl::min<int>(blksize, H1 - bh1 * blksize); + ker(i, o, block); + }); + } else { + assert(!"unimplemented"); + } + +# undef off + + return success; + } +}; + +template <SIMPLE_REORDER_TEMPL_DECL> +struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, +typename utils::enable_if<tag_i == any + && (tag_traits<tag_o>::block_dims == bd::_AB + || tag_traits<tag_o>::block_dims == bd::_BC) + && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_AB, + tag_traits<tag_o>::ndims >= 3 && tag_traits<tag_o>::ndims <= 5) + && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_BC, + tag_traits<tag_o>::ndims >= 4 && tag_traits<tag_o>::ndims <= 6) + >::type> +{ + PLAIN_TO_BLOCKED_IS_APPLICABLE(); + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t<type_i> *input, data_t<type_o> *output) { + DECLARE_COMMON_PARAMS(); + + const auto &flat_d = order_keep ? input_d : output_d; + const auto &dims = input_d.dims(); + const auto &pdims = order_keep + ? output_d.padded_dims() + : input_d.padded_dims(); + + constexpr int ndims = tag_traits<tag_o>::ndims; + + static constexpr bool with_g = tag_traits<tag_o>::block_dims == bd::_BC; + const dim_t G = with_g ? dims[0] : 1; + + const dim_t H0 = dims[0 + with_g]; + const dim_t H1 = dims[1 + with_g]; + + const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1; + const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1; + const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1; + + constexpr int blksize_0 = false ? 0 + : utils::one_of(tag_traits<tag_o>::inner_blks, + ib::_4b4a, ib::_4b4c, ib::_4c4b) + ? 4 + : utils::one_of(tag_traits<tag_o>::inner_blks, + ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) + ? 8 + : utils::one_of(tag_traits<tag_o>::inner_blks, + ib::_16a16b, ib::_16a4b, ib::_16b16a, ib::_16b4c, + ib::_16b16c, ib::_16c16b, ib::_8a16b2a, ib::_4b16a4b, + ib::_8b16a2b, ib::_8b16c2b, ib::_4c16b4c, ib::_8c16b2c) + ? 16 : INT_MIN; + + constexpr int blksize_1 = utils::one_of(tag_traits<tag_o>::inner_blks, + ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c) + ? 8 + : utils::one_of(tag_traits<tag_o>::inner_blks, + ib::_16a16b, ib::_16b16a, ib::_16b16c, ib::_16c16b, + ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b, + ib::_4c16b4c, ib::_8c16b2c) + ? 16 + : utils::one_of(tag_traits<tag_o>::inner_blks, + ib::_4b4a, ib::_4b4c, ib::_4c4b, + ib::_16a4b, ib::_16b4c) + ? 4 + : INT_MIN; + + const dim_t NB_H0 = pdims[0 + with_g] / blksize_0; + const dim_t NB_H1 = pdims[1 + with_g] / blksize_1; + + auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o, + const int block_h0, const int block_h1) { +# define blk_off AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks> + + if (alpha == 1.0 && beta == 0.0) { + for (int h0 = 0; h0 < block_h0; ++h0) + for (int h1 = 0; h1 < block_h1; ++h1) { + const dim_t flat_off = 0 + + h0 * flat_d.blocking_desc().strides[with_g + 0] + + h1 * flat_d.blocking_desc().strides[with_g + 1]; + if (order_keep) { + o[blk_off(h0, h1)] = _qz_a1b0<type_i, type_o>()( + i[flat_off]); + } else { + o[flat_off] = _qz_a1b0<type_i, type_o>()( + i[blk_off(h0, h1)]); + } + } + } else { + for (int h0 = 0; h0 < block_h0; ++h0) + for (int h1 = 0; h1 < block_h1; ++h1) { + const dim_t flat_off = 0 + + h0 * flat_d.blocking_desc().strides[with_g + 0] + + h1 * flat_d.blocking_desc().strides[with_g + 1]; + if (order_keep) { + o[blk_off(h0, h1)] = _qz<type_i, type_o>()(i[flat_off], + o[blk_off(h0, h1)], alpha, beta); + } else { + o[flat_off] = _qz<type_i, type_o>()(i[blk_off(h0, h1)], + o[flat_off], alpha, beta); + } + } + } + +# undef blk_off + }; + + constexpr int i_mult_0 = order_keep ? blksize_0 : 1; + constexpr int o_mult_0 = order_keep ? 1 : blksize_0; + + constexpr int i_mult_1 = order_keep ? blksize_1 : 1; + constexpr int o_mult_1 = order_keep ? 1 : blksize_1; + +# define off(md, g, h0, h1, m0, m1, m2) \ + (ndims >= 5 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m0, m1, m2) \ + : ndims >= 4 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m1, m2) \ + : /* ndims >= 3 + with_g ? */ (md).blk_off<!with_g>(g, h0, h1, m2)) + + parallel_nd(G, NB_H0, NB_H1, M0, M1, M2, + [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1, dim_t m2) { + auto i = &input[off(input_d, + g, i_mult_0 * nb_h0, i_mult_1 * nb_h1, m0, m1, m2)]; + auto o = &output[off(output_d, + g, o_mult_0 * nb_h0, o_mult_1 * nb_h1, m0, m1, m2)]; + const int block_h0 = nstl::min<int>(blksize_0, H0 - nb_h0 * blksize_0); + const int block_h1 = nstl::min<int>(blksize_1, H1 - nb_h1 * blksize_1); + ker(i, o, block_h0, block_h1); + }); + +# undef off + + return success; + } +}; + +/* generic and direct-copy reorders */ + +template <SIMPLE_REORDER_TEMPL_DECL> +struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, + typename utils::enable_if< + tag_i == any && tag_o == any && order_keep == fmt_order::any, + spec::direct_copy>::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + /* FIXME: is the formula correct? */ + return input_d.similar_to(output_d, true, false, 0) + && input_d.is_dense() && output_d.is_dense() + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t<type_i> *input, data_t<type_o> *output) { + DECLARE_COMMON_PARAMS(); + + assert(input_d.is_dense()); + + input += input_d.blk_off(0); + output += output_d.blk_off(0); + + const size_t nelems = input_d.nelems(); + + constexpr int block_size = 16; + const auto num_blocks = nelems / block_size; + const auto rem_elems = nelems % block_size; + + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(num_blocks, nthr, ithr, start, end); + start = start * block_size; + end = end * block_size; + + if (alpha == 1.0 && beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>() + (input[e]); + } + } else if (alpha == 1.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_a1<data_t<type_i>, data_t<type_o>>() + (input[e], output[e], beta); + } + } else if (beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz_b0<data_t<type_i>, data_t<type_o>>() + (input[e], alpha); + } + } else { + PRAGMA_OMP_SIMD() + for (size_t e = start; e < end; ++e) { + output[e] = qz<data_t<type_i>, data_t<type_o>>() + (input[e], output[e], alpha, beta); + } + } + + if (rem_elems != 0 && ithr == nthr - 1){ + if (alpha == 1.0 && beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_a1b0<data_t<type_i>, + data_t<type_o>>()(input[e]); + } + } else if (alpha == 1.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_a1<data_t<type_i>, + data_t<type_o>>()(input[e], output[e], beta); + } + } else if (beta == 0.0) { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz_b0<data_t<type_i>, + data_t<type_o>>()(input[e], alpha); + } + } else { + PRAGMA_OMP_SIMD() + for (size_t e = nelems - rem_elems; e < nelems; ++e) { + output[e] = qz<data_t<type_i>, data_t<type_o>>() + (input[e], output[e], alpha, beta); + } + } + } + }); + return success; + } +}; + +template <SIMPLE_REORDER_TEMPL_DECL> +struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, + typename utils::enable_if< + tag_i == any && tag_o == any && order_keep == fmt_order::any, + spec::direct_copy_except_dim_0>::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) { + return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d); + }; + /* FIXME: is the formula correct? */ + return input_d.similar_to(output_d, true, false, 1) + && is_dense_no_0(input_d) && is_dense_no_0(output_d) + && simple_attr_check(attr, false); + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t<type_i> *input, data_t<type_o> *output) { + DECLARE_COMMON_PARAMS(); + + input += input_d.blk_off(0); + output += output_d.blk_off(0); + + const int N = input_d.dims()[0]; + const dim_t is = input_d.blocking_desc().strides[0]; + const dim_t os = output_d.blocking_desc().strides[0]; + const dim_t nelems_no_d0 = nelems_no_dim_0(input_d); + const dim_t work_amount = N * nelems_no_d0; + + if (alpha == 1.0 && beta == 0.0) { + parallel(0, [&](const int ithr, const int nthr) { + dim_t n{0}, dim1_s{0}; + dim_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, N, dim1_s, nelems_no_d0); + while(start < end) { + dim_t work_rem = end - start; + dim_t dim1_e = dim1_s + work_rem > nelems_no_d0 + ? nelems_no_d0 : dim1_s + work_rem; + PRAGMA_OMP_SIMD() + for (dim_t e = dim1_s; e < dim1_e; ++e) { + output[os * n + e] = _qz_a1b0<type_i, type_o>()( + input[is * n + e]); + } + nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0); + } + }); + } else { + parallel(0, [&](const int ithr, const int nthr) { + dim_t n{0}, dim1_s{0}; + dim_t start{0}, end{0}; + balance211(work_amount, nthr, ithr, start, end); + nd_iterator_init(start, n, N, dim1_s, nelems_no_d0); + while(start < end) { + dim_t work_rem = end - start; + dim_t dim1_e = + dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0 + : dim1_s + work_rem; + PRAGMA_OMP_SIMD() + for (dim_t e = dim1_s; e < dim1_e; ++e){ + output[os * n + e] = _qz<type_i, type_o>()( + input[is * n + e], output[os * n + e], alpha, + beta); + } + nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0); + } + }); + } + + return success; + } + +private: + static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) { + const int ndims = data_d.ndims(); + if (ndims <= 1) return 1; + return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1); + } + + static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) { + dims_t blocks; + data_d.compute_blocks(blocks); + + const auto &blk = data_d.blocking_desc(); + + dim_t blk_size = 1; + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) + blk_size *= blk.inner_blks[iblk]; + + dim_t max_size = blk_size; + for (int d = 1; d < data_d.ndims(); ++d) { + max_size = nstl::max(max_size, + data_d.padded_dims()[d] / blocks[d] * blk.strides[d]); + } + + return max_size; + } +}; + +template <SIMPLE_REORDER_TEMPL_DECL> +struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, + typename utils::enable_if< + tag_i == any && tag_o == any && order_keep == fmt_order::any, + spec::reference>::type> +{ + static bool is_applicable(const memory_desc_wrapper &input_d, + const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { + /* supported smask: 0x0...011..10...0, + * i.e. 1 should be contiguous */ + int smask = attr ? attr->output_scales_.mask_ : 0; + for (; smask > 0 && !(smask & 0x1); smask >>= 1); + for (; smask > 0 && smask & 0x1; smask >>= 1); + return true + && input_d.is_blocking_desc() + && output_d.is_blocking_desc() + && !output_d.is_additional_buffer() + && !input_d.is_additional_buffer() + && smask == 0; + } + + static status_t execute(const cpu_reorder_pd_t *pd, + const data_t<type_i> *input, data_t<type_o> *output) { + DECLARE_COMMON_PARAMS(); + + const size_t nelems = input_d.nelems(); + + int ndims_start = 0, ndims_mask = 0; + int smask = pd->attr()->output_scales_.mask_; + for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start; + for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask; + assert(smask == 0); + + const ptrdiff_t D_start + = utils::array_product(input_d.dims(), ndims_start); + const ptrdiff_t D_mask + = utils::array_product(input_d.dims() + ndims_start, ndims_mask); + const ptrdiff_t D_rest = nelems / D_start / D_mask; + + const float *scales = pd->attr()->output_scales_.scales_; + + parallel_nd(D_start, D_mask, D_rest, + [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) { + const float scale = scales[dm]; + + const size_t e = (ds * D_mask + dm) * D_rest + dr; + const auto &i = input[input_d.off_l(e)]; + auto &o = output[output_d.off_l(e)]; + + o = _qz<type_i, type_o>()(i, o, scale, beta); + }); + + return success; + } +}; + + +/* high level class declaration */ + +template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void> +struct simple_reorder_t: public cpu_primitive_t { + struct pd_t: public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("simple:any", simple_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + bool args_ok = true + && src_md->data_type == type_i + && dst_md->data_type == type_o + && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>:: + is_applicable(src_md, dst_md, attr); + if (!args_ok) + return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); + } + }; + + simple_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const data_t<type_i> *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(data_t<type_o> *, MKLDNN_ARG_TO); + simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute( + pd(), input, output); + return status::success; + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +#undef SIMPLE_REORDER_TEMPL_DECL +#undef SIMPLE_REORDER_TEMPL_CALL + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp new file mode 100644 index 0000000000..f0947573a9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "mkldnn_thread.hpp" + +#include "simple_sum.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <data_type_t data_type> +status_t simple_sum_t<data_type>::execute(const exec_ctx_t &ctx) const { + auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + + const memory_desc_wrapper o_d(pd()->dst_md()); + output += o_d.blk_off(0); + + const int num_arrs = pd()->n_inputs(); + const data_t *input_ptrs[max_num_arrs]; + const size_t nelems = o_d.nelems(); + + for (int a = 0; a < num_arrs; ++a) { + const memory_desc_wrapper i_d(pd()->src_md(a)); + input_ptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a) + + i_d.blk_off(0); + } + + const size_t block_size = 16 * 1024 / sizeof(data_type); + const size_t blocks_number = nelems / block_size; + const size_t tail = nelems % block_size; + + const auto scales = pd()->scales(); + parallel(0, [&](const int ithr, const int nthr) { + size_t start{0}, end{0}; + balance211(blocks_number, nthr, ithr, start, end); + + for (size_t nb = start; nb < end; ++nb) { + size_t start_e = nb * block_size; + size_t end_e = start_e + block_size; + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = data_t(scales[0] * input_ptrs[0][e]); + } + for (int a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += data_t(scales[a] * input_ptrs[a][e]); + } + } + } + + if (tail != 0 && ithr == nthr - 1) { + size_t start_e = nelems - tail; + size_t end_e = nelems; + + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] = data_t(scales[0] * input_ptrs[0][e]); + } + for (int a = 1; a < num_arrs; a++) { + PRAGMA_OMP_SIMD() + for (size_t e = start_e; e < end_e; e++) { + output[e] += data_t(scales[a] * input_ptrs[a][e]); + } + } + } + }); + + return status::success; +} + +template struct simple_sum_t<data_type::f32>; + +} +} +} diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp new file mode 100644 index 0000000000..2a0187a184 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* +* Copyright 2017-2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef SIMPLE_SUM_HPP +#define SIMPLE_SUM_HPP + +#include "cpu_sum_pd.hpp" +#include "cpu_primitive.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <data_type_t data_type> +struct simple_sum_t: public cpu_primitive_t { + struct pd_t: public cpu_sum_pd_t { + using cpu_sum_pd_t::cpu_sum_pd_t; + + DECLARE_SUM_PD_T("simple:any", simple_sum_t); + + status_t init() { + const int n = n_inputs(); + + bool ok = true + && cpu_sum_pd_t::init() == status::success + && n <= max_num_arrs; + if (!ok) return status::unimplemented; + + const memory_desc_wrapper o_d(dst_md()); + ok = ok + && o_d.data_type() == data_type + && o_d.is_dense(); + if (!ok) return status::unimplemented; + + for (int i = 0; i < n; ++i) { + const memory_desc_wrapper i_d(src_md(i)); + if (i_d != o_d) return status::unimplemented; + } + + return status::success; + } + }; + + simple_sum_t(const pd_t *apd): cpu_primitive_t(apd) {} + + virtual status_t execute(const exec_ctx_t &ctx) const override; + + enum {max_num_arrs = 16 }; + typedef typename prec_traits<data_type>::type data_t; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } +}; + +} +} +} + +#endif + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp new file mode 100644 index 0000000000..c2082d7d62 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp @@ -0,0 +1,376 @@ +/******************************************************************************* + * Copyright 2017-2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef CPU_WINO_REORDER_HPP +#define CPU_WINO_REORDER_HPP + +#include "mkldnn_thread.hpp" + +#include "simple_q10n.hpp" + +namespace mkldnn { +namespace impl { +namespace cpu { + +template <data_type_t type_i, data_type_t type_o> +struct wino_reorder_t : public cpu_primitive_t { + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("wino_reorder", wino_reorder_t); + + static status_t create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, + engine_t *src_engine, const memory_desc_t *src_md, + engine_t *dst_engine, const memory_desc_t *dst_md) { + const memory_desc_wrapper id(src_md), od(dst_md); + bool args_ok = true + && id.data_type() == type_i + && od.data_type() == type_o + && id.matches_tag(utils::pick(id.ndims() - 4, + format_tag::oihw, format_tag::goihw)) + && od.format_kind() == format_kind::wino + && utils::one_of(od.wino_desc().wino_format, + mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio, + mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio); + if (!args_ok) return status::invalid_arguments; + + auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine, + dst_md); + if (_pd == nullptr) return status::out_of_memory; + if (_pd->init() != status::success) { + delete _pd; + return status::unimplemented; + } + return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); + } + + status_t init() { + status_t status = cpu_reorder_pd_t::init(); + if (status != status::success) return status; + + init_scratchpad(); + + return status::success; + } + + private: + void init_scratchpad() { + auto &o = memory_desc_wrapper(dst_md()).wino_desc(); + size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block; + size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic; + + using namespace memory_tracking::names; + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(key_reorder_wino_transform_space, + sizeof(in_data_t) * transform_space_size); + scratchpad.book(key_reorder_wino_plain, + sizeof(out_data_t) * plain_size); + } + }; + +private: + typedef typename prec_traits<type_i>::type in_data_t; + typedef typename prec_traits<type_o>::type out_data_t; + const int unsign_val_in_wino_domain_ = 5; + + wino_reorder_t(const pd_t *apd): cpu_primitive_t(apd) { + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + + r_ = dst_d.wino_desc().r; + w_alpha_ = dst_d.wino_desc().alpha; + wino_format_ = dst_d.wino_desc().wino_format; + + const auto &in_dims = src_d.dims(); + int groups; + int groups_offset; + if (src_d.ndims() == 5) { + groups = in_dims[0]; + groups_offset = 1; + } else { + groups = 1; + groups_offset = 0; + } + assert(groups == 1); // groups are not supported now + MAYBE_UNUSED(groups); + + or_oc_ = in_dims[0 + groups_offset]; + or_ic_ = in_dims[1 + groups_offset]; + kh_ = in_dims[2 + groups_offset]; + kw_ = in_dims[3 + groups_offset]; + + oc_ = dst_d.wino_desc().oc; + ic_ = dst_d.wino_desc().ic; + oc_block_ = dst_d.wino_desc().oc_block; + ic_block_ = dst_d.wino_desc().ic_block; + assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0); + nb_oc_ = oc_ / oc_block_; + nb_ic_ = ic_ / ic_block_; + ic2_block_ = 1; + if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) + ic2_block_ = dst_d.wino_desc().ic2_block; + oc2_block_ = dst_d.wino_desc().oc2_block; + assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0); + + adj_scale_ = dst_d.wino_desc().adj_scale; + + size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_; + size_wspace_ = r_ * w_alpha_ * oc_block_; + } + + void transform(out_data_t *__restrict tmp_wei, + const in_data_t *__restrict input, + in_data_t *__restrict wspace) const { + const memory_desc_wrapper src_d(pd()->src_md()); + + const int smask = pd()->attr()->output_scales_.mask_; + const int ndims_mask = math::ilog2q(smask + 1); + const size_t D_mask = utils::array_product(src_d.dims(), ndims_mask); + const float *__restrict scales = pd()->attr()->output_scales_.scales_; + assert(D_mask == 1 || D_mask == (size_t)oc_); + + /* transform weights to winograd domain */ + const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 }, + { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } }; + + const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f }, + { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f }, + { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f }, + { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f }, + { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f }, + { 0.f, 0.f, 1.f } }; + + float *__restrict g; + if (utils::one_of(wino_format_, mkldnn_wino_wei_aaOIoi, + mkldnn_wino_wei_aaOio, mkldnn_wino_wei_aaOBiOo)) + g = (float *)G_2x2_3x3; + else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio) + g = (float *)G_4x4_3x3; + else { + assert("Unknown winograd weights target layout"); + return; + } + + int Z = oc_ * ic_; + assert(r_ == kh_ && r_ == kw_); + + for (int iic = 0; iic < ic_; iic++) { + for (int ob = 0; ob < nb_oc_; ob++) { + const in_data_t *__restrict _inp + = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_; + out_data_t *__restrict _out + = tmp_wei + (iic * nb_oc_ + ob) * oc_block_; + + for_nd(0, 1, size_wspace_, [&](int i) { wspace[i] = 0.f; }); + + for_nd(0, 1, r_, w_alpha_, oc_block_, + [&](int ih, int j, int ioc) { + for (int iw = 0; iw < r_; ++iw) { + int inp_oc = ob * oc_block_ + ioc; + int inp_ic = iic; + in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_) + ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw] + : 0.f; + wspace[(ih * w_alpha_ + j) * oc_block_ + ioc] + += inp_v * g[j * r_ + iw]; + } + }); + + for_nd(0, 1, w_alpha_, w_alpha_, oc_block_, + [&](int i, int j, int ioc) { + float t = 0; + for (int k = 0; k < r_; ++k) + t += g[i * r_ + k] + * wspace[(k * w_alpha_ + j) * oc_block_ + ioc]; + if (type_o == data_type::s8) { + const float scale = (D_mask == 1) + ? scales[0] + : scales[ob * oc_block_ + ioc]; + _out[(i * w_alpha_ + j) * Z + ioc] + = qz_b0<in_data_t, out_data_t>()( + (in_data_t)t, scale * adj_scale_); + } else { + _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t; + } + }); + }} + } + + void reorder_to_aaOIoi(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int32_t *__restrict dst_bias = nullptr; + if (type_o == data_type::s8) { + const auto bias_shift = sizeof(out_data_t) * size_wino_wei_; + const size_t bias_size = w_alpha_ * w_alpha_ * oc_; + + dst_bias = (int32_t *)(output + bias_shift); + utils::array_set((int32_t *)dst_bias, 0, bias_size); + } + int index = 0; + for (int u_h = 0; u_h < w_alpha_; u_h++) { + for (int u_w = 0; u_w < w_alpha_; u_w++) { + for_nd(0, 1, nb_oc_, oc_block_, [&](int ob, int o) { + int u_h_shift = u_h * w_alpha_ * ic_ * oc_; + int u_w_shift = u_w * ic_ * oc_; + int u_h_shift_b = u_h * w_alpha_ * oc_; + int u_w_shift_b = u_w * oc_; + int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_; + for (int ib = 0; ib < nb_ic_; ib++) { + for (int i = 0; i < ic_block_; i++) { + int _i = ib * ic_block_; + int _o = ob * oc_block_; + int ic_shift = (_i + i) * oc_; + int oc_shift = (_o + o); + int ic_block_shift = ib * oc_block_ * ic_block_ + i; + int src_offset = + u_h_shift + u_w_shift + ic_shift + oc_shift; + int dst_offset = u_h_shift + u_w_shift + oc_block_shift + + ic_block_shift; + + output[dst_offset] = tmp_wei[src_offset]; + if (type_o == data_type::s8) { + int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift; + if (index != unsign_val_in_wino_domain_) + dst_bias[bias_offset] + -= (128 * (int32_t)output[dst_offset]); + else + dst_bias[bias_offset] = 0; + } + }} + }); + index++; + }} + } + + void reorder_to_aaOio(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + for_nd(0, 1, w_alpha_, w_alpha_, nb_oc_, + [&](int u_h, int u_w, int ob) { + for (int ib = 0; ib < nb_ic_; ib++) { + for (int i = 0; i < ic_block_; i++) { + for (int o = 0; o < oc_block_; o++) { + int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_ + + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o); + + int dst_offset + = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ + + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_ + + ob * nb_ic_ * ic_block_ * oc_block_ + + ib * ic_block_ * oc_block_ + i * oc_block_ + o; + output[dst_offset] = tmp_wei[src_offset]; + }}} + }); + } + + void reorder_to_aaOBiOo(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int oc_chunks = nb_oc_ / oc2_block_; + + for_nd(0, 1, w_alpha_, w_alpha_, oc_chunks, + [&](int u_h, int u_w, int occ) { + for (int ib = 0; ib < nb_ic_; ib++) { + out_data_t *__restrict wei_ptr = output + + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib) + * oc2_block_ * ic_block_ * oc_block_; + int wei_offset = 0; + for (int i = 0; i < ic_block_; i++) { + for (int ob2 = 0; ob2 < oc2_block_; ob2++) { + for (int o = 0; o < oc_block_; o++) { + int icp = ib * ic_block_ + i; + int ocp = + occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o; + + int src_offset = u_h * w_alpha_ * ic_ * oc_ + + u_w * ic_ * oc_ + icp * oc_ + ocp; + wei_ptr[wei_offset + o] = tmp_wei[src_offset]; + } + wei_offset += oc_block_; + }} + } + }); + } + + void reorder_to_OBaaIBOIio(out_data_t *__restrict output, + const out_data_t *__restrict tmp_wei) const { + int ic_chunks = nb_ic_ / ic2_block_; + int oc_chunks = nb_oc_ / oc2_block_; + + for_nd(0, 1, oc_chunks, w_alpha_, w_alpha_, + [&](int occ, int u_h, int u_w) { + for (int icc = 0; icc < ic_chunks; icc++) { + for (int ob = 0; ob < oc2_block_; ob++) { + int ocp = (occ * oc2_block_ + ob) * oc_block_; + for (int ib = 0; ib < ic2_block_; ib++) { + for (int i = 0; i < ic_block_; i++) { + int icp = (icc * ic2_block_ + ib) * ic_block_ + i; + + int src_offset = u_h * w_alpha_ * ic_ * oc_ + + u_w * ic_ * oc_ + icp * oc_ + ocp; + int wei_offset + = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w) + * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_ + + ib) * ic_block_ + i) * oc_block_; + for (int o = 0; o < oc_block_; o++) + output[wei_offset + o] = tmp_wei[src_offset + o]; + }} + }} + }); + } + + virtual status_t execute(const exec_ctx_t &ctx) const override { + auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM); + auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO); + + auto wspace = (in_data_t *__restrict)scratchpad(ctx).template get<void>( + memory_tracking::names::key_reorder_wino_transform_space); + auto tmp_wei = (out_data_t *__restrict)scratchpad(ctx).template get<void>( + memory_tracking::names::key_reorder_wino_plain); + + transform(tmp_wei, input, wspace); + + /* reorder to winograd domain */ + switch (wino_format_) { + case mkldnn_wino_wei_aaOIoi: + reorder_to_aaOIoi(output, tmp_wei); break; + case mkldnn_wino_wei_aaOio: + reorder_to_aaOio(output, tmp_wei); break; + case mkldnn_wino_wei_aaOBiOo: + reorder_to_aaOBiOo(output, tmp_wei); break; + case mkldnn_wino_wei_OBaaIBOIio: + reorder_to_OBaaIBOIio(output, tmp_wei); break; + default: assert("Unknown wino format"); break; + } + + return status::success; + } + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } + int r_, w_alpha_; + int ic_, oc_, or_ic_, or_oc_, kh_, kw_; + int oc_block_, ic_block_, oc2_block_, ic2_block_; + float adj_scale_; + int nb_oc_, nb_ic_; + mkldnn_wino_memory_format_t wino_format_; + int size_wino_wei_; + int size_wspace_; +}; + +} // namespace cpu +} // namespace impl +} // namespace mkldnn + +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT new file mode 100644 index 0000000000..66b6ea55d0 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT @@ -0,0 +1,47 @@ + +Copyright (c) 2007 MITSUNARI Shigeo +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. +Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. +Neither the name of the copyright owner nor the names of its contributors may +be used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +THE POSSIBILITY OF SUCH DAMAGE. +----------------------------------------------------------------------------- +ソースコード形式かバイナリ形式か、変更するかしないかを問わず、以下の条件を満た +す場合に限り、再頒布および使用が許可されます。 + +ソースコードを再頒布する場合、上記の著作権表示、本条件一覧、および下記免責条項 +を含めること。 +バイナリ形式で再頒布する場合、頒布物に付属のドキュメント等の資料に、上記の著作 +権表示、本条件一覧、および下記免責条項を含めること。 +書面による特別の許可なしに、本ソフトウェアから派生した製品の宣伝または販売促進 +に、著作権者の名前またはコントリビューターの名前を使用してはならない。 +本ソフトウェアは、著作権者およびコントリビューターによって「現状のまま」提供さ +れており、明示黙示を問わず、商業的な使用可能性、および特定の目的に対する適合性 +に関する暗黙の保証も含め、またそれに限定されない、いかなる保証もありません。 +著作権者もコントリビューターも、事由のいかんを問わず、 損害発生の原因いかんを +問わず、かつ責任の根拠が契約であるか厳格責任であるか(過失その他の)不法行為で +あるかを問わず、仮にそのような損害が発生する可能性を知らされていたとしても、 +本ソフトウェアの使用によって発生した(代替品または代用サービスの調達、使用の +喪失、データの喪失、利益の喪失、業務の中断も含め、またそれに限定されない)直接 +損害、間接損害、偶発的な損害、特別損害、懲罰的損害、または結果損害について、 +一切責任を負わないものとします。 diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h new file mode 100644 index 0000000000..cf5771332f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h @@ -0,0 +1,2658 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +#pragma once +#ifndef XBYAK_XBYAK_H_ +#define XBYAK_XBYAK_H_ +/*! + @file xbyak.h + @brief Xbyak ; JIT assembler for x86(IA32)/x64 by C++ + @author herumi + @url https://github.com/herumi/xbyak + @note modified new BSD license + http://opensource.org/licenses/BSD-3-Clause +*/ +#ifndef XBYAK_NO_OP_NAMES + #if not +0 // trick to detect whether 'not' is operator or not + #error "use -fno-operator-names option if you want to use and(), or(), xor(), not() as function names, Or define XBYAK_NO_OP_NAMES and use and_(), or_(), xor_(), not_()." + #endif +#endif + +#include <stdio.h> // for debug print +#include <assert.h> +#include <list> +#include <string> +#include <algorithm> +#ifndef NDEBUG +#include <iostream> +#endif + +// #define XBYAK_DISABLE_AVX512 + +//#define XBYAK_USE_MMAP_ALLOCATOR +#if !defined(__GNUC__) || defined(__MINGW32__) + #undef XBYAK_USE_MMAP_ALLOCATOR +#endif + +#ifdef __GNUC__ + #define XBYAK_GNUC_PREREQ(major, minor) ((__GNUC__) * 100 + (__GNUC_MINOR__) >= (major) * 100 + (minor)) +#else + #define XBYAK_GNUC_PREREQ(major, minor) 0 +#endif + +// This covers -std=(gnu|c)++(0x|11|1y), -stdlib=libc++, and modern Microsoft. +#if ((defined(_MSC_VER) && (_MSC_VER >= 1600)) || defined(_LIBCPP_VERSION) ||\ + ((__cplusplus >= 201103) || defined(__GXX_EXPERIMENTAL_CXX0X__))) + #include <unordered_set> + #define XBYAK_STD_UNORDERED_SET std::unordered_set + #include <unordered_map> + #define XBYAK_STD_UNORDERED_MAP std::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::unordered_multimap + +/* + Clang/llvm-gcc and ICC-EDG in 'GCC-mode' always claim to be GCC 4.2, using + libstdcxx 20070719 (from GCC 4.2.1, the last GPL 2 version). +*/ +#elif XBYAK_GNUC_PREREQ(4, 5) || (XBYAK_GNUC_PREREQ(4, 2) && __GLIBCXX__ >= 20070719) || defined(__INTEL_COMPILER) || defined(__llvm__) + #include <tr1/unordered_set> + #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set + #include <tr1/unordered_map> + #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap + +#elif defined(_MSC_VER) && (_MSC_VER >= 1500) && (_MSC_VER < 1600) + #include <unordered_set> + #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set + #include <unordered_map> + #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map + #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap + +#else + #include <set> + #define XBYAK_STD_UNORDERED_SET std::set + #include <map> + #define XBYAK_STD_UNORDERED_MAP std::map + #define XBYAK_STD_UNORDERED_MULTIMAP std::multimap +#endif +#ifdef _WIN32 + #include <winsock2.h> + #include <windows.h> + #include <malloc.h> +#elif defined(__GNUC__) + #include <unistd.h> + #include <sys/mman.h> + #include <stdlib.h> +#endif +#if !defined(_MSC_VER) || (_MSC_VER >= 1600) + #include <stdint.h> +#endif + +#if defined(_WIN64) || defined(__MINGW64__) || (defined(__CYGWIN__) && defined(__x86_64__)) + #define XBYAK64_WIN +#elif defined(__x86_64__) + #define XBYAK64_GCC +#endif +#if !defined(XBYAK64) && !defined(XBYAK32) + #if defined(XBYAK64_GCC) || defined(XBYAK64_WIN) + #define XBYAK64 + #else + #define XBYAK32 + #endif +#endif + +#if (__cplusplus >= 201103) || (_MSC_VER >= 1800) + #define XBYAK_VARIADIC_TEMPLATE +#endif + +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) /* remove inline function */ + #pragma warning(disable : 4786) /* identifier is too long */ + #pragma warning(disable : 4503) /* name is too long */ + #pragma warning(disable : 4127) /* constant expresison */ +#endif + +namespace Xbyak { + +enum { + DEFAULT_MAX_CODE_SIZE = 4096, + VERSION = 0x5760 /* 0xABCD = A.BC(D) */ +}; + +#ifndef MIE_INTEGER_TYPE_DEFINED +#define MIE_INTEGER_TYPE_DEFINED +#ifdef _MSC_VER + typedef unsigned __int64 uint64; + typedef __int64 sint64; +#else + typedef uint64_t uint64; + typedef int64_t sint64; +#endif +typedef unsigned int uint32; +typedef unsigned short uint16; +typedef unsigned char uint8; +#endif + +#ifndef MIE_ALIGN + #ifdef _MSC_VER + #define MIE_ALIGN(x) __declspec(align(x)) + #else + #define MIE_ALIGN(x) __attribute__((aligned(x))) + #endif +#endif +#ifndef MIE_PACK // for shufps + #define MIE_PACK(x, y, z, w) ((x) * 64 + (y) * 16 + (z) * 4 + (w)) +#endif + +enum { + ERR_NONE = 0, + ERR_BAD_ADDRESSING, + ERR_CODE_IS_TOO_BIG, + ERR_BAD_SCALE, + ERR_ESP_CANT_BE_INDEX, + ERR_BAD_COMBINATION, + ERR_BAD_SIZE_OF_REGISTER, + ERR_IMM_IS_TOO_BIG, + ERR_BAD_ALIGN, + ERR_LABEL_IS_REDEFINED, + ERR_LABEL_IS_TOO_FAR, + ERR_LABEL_IS_NOT_FOUND, + ERR_CODE_ISNOT_COPYABLE, + ERR_BAD_PARAMETER, + ERR_CANT_PROTECT, + ERR_CANT_USE_64BIT_DISP, + ERR_OFFSET_IS_TOO_BIG, + ERR_MEM_SIZE_IS_NOT_SPECIFIED, + ERR_BAD_MEM_SIZE, + ERR_BAD_ST_COMBINATION, + ERR_OVER_LOCAL_LABEL, // not used + ERR_UNDER_LOCAL_LABEL, + ERR_CANT_ALLOC, + ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW, + ERR_BAD_PROTECT_MODE, + ERR_BAD_PNUM, + ERR_BAD_TNUM, + ERR_BAD_VSIB_ADDRESSING, + ERR_CANT_CONVERT, + ERR_LABEL_ISNOT_SET_BY_L, + ERR_LABEL_IS_ALREADY_SET_BY_L, + ERR_BAD_LABEL_STR, + ERR_MUNMAP, + ERR_OPMASK_IS_ALREADY_SET, + ERR_ROUNDING_IS_ALREADY_SET, + ERR_K0_IS_INVALID, + ERR_EVEX_IS_INVALID, + ERR_SAE_IS_INVALID, + ERR_ER_IS_INVALID, + ERR_INVALID_BROADCAST, + ERR_INVALID_OPMASK_WITH_MEMORY, + ERR_INVALID_ZERO, + ERR_INVALID_RIP_IN_AUTO_GROW, + ERR_INVALID_MIB_ADDRESS, + ERR_INTERNAL, + ERR_X2APIC_IS_NOT_SUPPORTED +}; + +class Error : public std::exception { + int err_; +public: + explicit Error(int err) : err_(err) + { + if (err_ < 0 || err_ > ERR_INTERNAL) { + fprintf(stderr, "bad err=%d in Xbyak::Error\n", err_); + //exit(1); + } + } + operator int() const { return err_; } + const char *what() const throw() + { + static const char *errTbl[] = { + "none", + "bad addressing", + "code is too big", + "bad scale", + "esp can't be index", + "bad combination", + "bad size of register", + "imm is too big", + "bad align", + "label is redefined", + "label is too far", + "label is not found", + "code is not copyable", + "bad parameter", + "can't protect", + "can't use 64bit disp(use (void*))", + "offset is too big", + "MEM size is not specified", + "bad mem size", + "bad st combination", + "over local label", + "under local label", + "can't alloc", + "T_SHORT is not supported in AutoGrow", + "bad protect mode", + "bad pNum", + "bad tNum", + "bad vsib addressing", + "can't convert", + "label is not set by L()", + "label is already set by L()", + "bad label string", + "err munmap", + "opmask is already set", + "rounding is already set", + "k0 is invalid", + "evex is invalid", + "sae(suppress all exceptions) is invalid", + "er(embedded rounding) is invalid", + "invalid broadcast", + "invalid opmask with memory", + "invalid zero", + "invalid rip in AutoGrow", + "invalid mib address", + "internal error", + "x2APIC is not supported" + }; + assert((size_t)err_ < sizeof(errTbl) / sizeof(*errTbl)); + return errTbl[err_]; + } +}; + +inline const char *ConvertErrorToString(const Error& err) +{ + return err.what(); +} + +inline void *AlignedMalloc(size_t size, size_t alignment) +{ +#ifdef __MINGW32__ + return __mingw_aligned_malloc(size, alignment); +#elif defined(_WIN32) + return _aligned_malloc(size, alignment); +#else + void *p; + int ret = posix_memalign(&p, alignment, size); + return (ret == 0) ? p : 0; +#endif +} + +inline void AlignedFree(void *p) +{ +#ifdef __MINGW32__ + __mingw_aligned_free(p); +#elif defined(_MSC_VER) + _aligned_free(p); +#else + free(p); +#endif +} + +template<class To, class From> +inline const To CastTo(From p) throw() +{ + return (const To)(size_t)(p); +} +namespace inner { + +static const size_t ALIGN_PAGE_SIZE = 4096; + +inline bool IsInDisp8(uint32 x) { return 0xFFFFFF80 <= x || x <= 0x7F; } +inline bool IsInInt32(uint64 x) { return ~uint64(0x7fffffffu) <= x || x <= 0x7FFFFFFFU; } + +inline uint32 VerifyInInt32(uint64 x) +{ +#ifdef XBYAK64 + if (!IsInInt32(x)) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + return static_cast<uint32>(x); +} + +enum LabelMode { + LasIs, // as is + Labs, // absolute + LaddTop // (addr + top) for mov(reg, label) with AutoGrow +}; + +} // inner + +/* + custom allocator +*/ +struct Allocator { + virtual uint8 *alloc(size_t size) { return reinterpret_cast<uint8*>(AlignedMalloc(size, inner::ALIGN_PAGE_SIZE)); } + virtual void free(uint8 *p) { AlignedFree(p); } + virtual ~Allocator() {} + /* override to return false if you call protect() manually */ + virtual bool useProtect() const { return true; } +}; + +#ifdef XBYAK_USE_MMAP_ALLOCATOR +class MmapAllocator : Allocator { + typedef XBYAK_STD_UNORDERED_MAP<uintptr_t, size_t> SizeList; + SizeList sizeList_; +public: + uint8 *alloc(size_t size) + { + const size_t alignedSizeM1 = inner::ALIGN_PAGE_SIZE - 1; + size = (size + alignedSizeM1) & ~alignedSizeM1; +#ifdef MAP_ANONYMOUS + const int mode = MAP_PRIVATE | MAP_ANONYMOUS; +#elif defined(MAP_ANON) + const int mode = MAP_PRIVATE | MAP_ANON; +#else + #error "not supported" +#endif + void *p = mmap(NULL, size, PROT_READ | PROT_WRITE, mode, -1, 0); + if (p == MAP_FAILED) throw Error(ERR_CANT_ALLOC); + assert(p); + sizeList_[(uintptr_t)p] = size; + return (uint8*)p; + } + void free(uint8 *p) + { + if (p == 0) return; + SizeList::iterator i = sizeList_.find((uintptr_t)p); + if (i == sizeList_.end()) throw Error(ERR_BAD_PARAMETER); + if (munmap((void*)i->first, i->second) < 0) throw Error(ERR_MUNMAP); + sizeList_.erase(i); + } +}; +#endif + +class Address; +class Reg; + +class Operand { + static const uint8 EXT8BIT = 0x20; + unsigned int idx_:6; // 0..31 + EXT8BIT = 1 if spl/bpl/sil/dil + unsigned int kind_:9; + unsigned int bit_:10; +protected: + unsigned int zero_:1; + unsigned int mask_:3; + unsigned int rounding_:3; + void setIdx(int idx) { idx_ = idx; } +public: + enum Kind { + NONE = 0, + MEM = 1 << 0, + REG = 1 << 1, + MMX = 1 << 2, + FPU = 1 << 3, + XMM = 1 << 4, + YMM = 1 << 5, + ZMM = 1 << 6, + OPMASK = 1 << 7, + BNDREG = 1 << 8 + }; + enum Code { +#ifdef XBYAK64 + RAX = 0, RCX, RDX, RBX, RSP, RBP, RSI, RDI, R8, R9, R10, R11, R12, R13, R14, R15, + R8D = 8, R9D, R10D, R11D, R12D, R13D, R14D, R15D, + R8W = 8, R9W, R10W, R11W, R12W, R13W, R14W, R15W, + R8B = 8, R9B, R10B, R11B, R12B, R13B, R14B, R15B, + SPL = 4, BPL, SIL, DIL, +#endif + EAX = 0, ECX, EDX, EBX, ESP, EBP, ESI, EDI, + AX = 0, CX, DX, BX, SP, BP, SI, DI, + AL = 0, CL, DL, BL, AH, CH, DH, BH + }; + Operand() : idx_(0), kind_(0), bit_(0), zero_(0), mask_(0), rounding_(0) { } + Operand(int idx, Kind kind, int bit, bool ext8bit = 0) + : idx_(static_cast<uint8>(idx | (ext8bit ? EXT8BIT : 0))) + , kind_(kind) + , bit_(bit) + , zero_(0), mask_(0), rounding_(0) + { + assert((bit_ & (bit_ - 1)) == 0); // bit must be power of two + } + Kind getKind() const { return static_cast<Kind>(kind_); } + int getIdx() const { return idx_ & (EXT8BIT - 1); } + bool isNone() const { return kind_ == 0; } + bool isMMX() const { return is(MMX); } + bool isXMM() const { return is(XMM); } + bool isYMM() const { return is(YMM); } + bool isZMM() const { return is(ZMM); } + bool isXMEM() const { return is(XMM | MEM); } + bool isYMEM() const { return is(YMM | MEM); } + bool isZMEM() const { return is(ZMM | MEM); } + bool isOPMASK() const { return is(OPMASK); } + bool isBNDREG() const { return is(BNDREG); } + bool isREG(int bit = 0) const { return is(REG, bit); } + bool isMEM(int bit = 0) const { return is(MEM, bit); } + bool isFPU() const { return is(FPU); } + bool isExt8bit() const { return (idx_ & EXT8BIT) != 0; } + bool isExtIdx() const { return (getIdx() & 8) != 0; } + bool isExtIdx2() const { return (getIdx() & 16) != 0; } + bool hasEvex() const { return isZMM() || isExtIdx2() || getOpmaskIdx() || getRounding(); } + bool hasRex() const { return isExt8bit() || isREG(64) || isExtIdx(); } + bool hasZero() const { return zero_; } + int getOpmaskIdx() const { return mask_; } + int getRounding() const { return rounding_; } + void setKind(Kind kind) + { + if ((kind & (XMM|YMM|ZMM)) == 0) return; + kind_ = kind; + bit_ = kind == XMM ? 128 : kind == YMM ? 256 : 512; + } + void setBit(int bit) { bit_ = bit; } + void setOpmaskIdx(int idx, bool ignore_idx0 = false) + { + if (!ignore_idx0 && idx == 0) throw Error(ERR_K0_IS_INVALID); + if (mask_) throw Error(ERR_OPMASK_IS_ALREADY_SET); + mask_ = idx; + } + void setRounding(int idx) + { + if (rounding_) throw Error(ERR_ROUNDING_IS_ALREADY_SET); + rounding_ = idx; + } + void setZero() { zero_ = true; } + // ah, ch, dh, bh? + bool isHigh8bit() const + { + if (!isBit(8)) return false; + if (isExt8bit()) return false; + const int idx = getIdx(); + return AH <= idx && idx <= BH; + } + // any bit is accetable if bit == 0 + bool is(int kind, uint32 bit = 0) const + { + return (kind == 0 || (kind_ & kind)) && (bit == 0 || (bit_ & bit)); // cf. you can set (8|16) + } + bool isBit(uint32 bit) const { return (bit_ & bit) != 0; } + uint32 getBit() const { return bit_; } + const char *toString() const + { + const int idx = getIdx(); + if (kind_ == REG) { + if (isExt8bit()) { + static const char *tbl[4] = { "spl", "bpl", "sil", "dil" }; + return tbl[idx - 4]; + } + static const char *tbl[4][16] = { + { "al", "cl", "dl", "bl", "ah", "ch", "dh", "bh", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b" }, + { "ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w", "r11w", "r12w", "r13w", "r14w", "r15w" }, + { "eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d" }, + { "rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15" }, + }; + return tbl[bit_ == 8 ? 0 : bit_ == 16 ? 1 : bit_ == 32 ? 2 : 3][idx]; + } else if (isOPMASK()) { + static const char *tbl[8] = { "k0", "k1", "k2", "k3", "k4", "k5", "k6", "k7" }; + return tbl[idx]; + } else if (isZMM()) { + static const char *tbl[32] = { + "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", + "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31" + }; + return tbl[idx]; + } else if (isYMM()) { + static const char *tbl[32] = { + "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15", + "ymm16", "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", "ymm31" + }; + return tbl[idx]; + } else if (isXMM()) { + static const char *tbl[32] = { + "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15", + "xmm16", "xmm17", "xmm18", "xmm19", "xmm20", "xmm21", "xmm22", "xmm23", "xmm24", "xmm25", "xmm26", "xmm27", "xmm28", "xmm29", "xmm30", "xmm31" + }; + return tbl[idx]; + } else if (isMMX()) { + static const char *tbl[8] = { "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7" }; + return tbl[idx]; + } else if (isFPU()) { + static const char *tbl[8] = { "st0", "st1", "st2", "st3", "st4", "st5", "st6", "st7" }; + return tbl[idx]; + } else if (isBNDREG()) { + static const char *tbl[4] = { "bnd0", "bnd1", "bnd2", "bnd3" }; + return tbl[idx]; + } + throw Error(ERR_INTERNAL); + } + bool isEqualIfNotInherited(const Operand& rhs) const { return idx_ == rhs.idx_ && kind_ == rhs.kind_ && bit_ == rhs.bit_ && zero_ == rhs.zero_ && mask_ == rhs.mask_ && rounding_ == rhs.rounding_; } + bool operator==(const Operand& rhs) const; + bool operator!=(const Operand& rhs) const { return !operator==(rhs); } + const Address& getAddress() const; + const Reg& getReg() const; +}; + +class Label; + +struct Reg8; +struct Reg16; +struct Reg32; +#ifdef XBYAK64 +struct Reg64; +#endif +class Reg : public Operand { +public: + Reg() { } + Reg(int idx, Kind kind, int bit = 0, bool ext8bit = false) : Operand(idx, kind, bit, ext8bit) { } + Reg changeBit(int bit) const { return Reg(getIdx(), getKind(), bit, isExt8bit()); } + uint8 getRexW() const { return isREG(64) ? 8 : 0; } + uint8 getRexR() const { return isExtIdx() ? 4 : 0; } + uint8 getRexX() const { return isExtIdx() ? 2 : 0; } + uint8 getRexB() const { return isExtIdx() ? 1 : 0; } + uint8 getRex(const Reg& base = Reg()) const + { + uint8 rex = getRexW() | getRexR() | base.getRexW() | base.getRexB(); + if (rex || isExt8bit() || base.isExt8bit()) rex |= 0x40; + return rex; + } + Reg8 cvt8() const; + Reg16 cvt16() const; + Reg32 cvt32() const; +#ifdef XBYAK64 + Reg64 cvt64() const; +#endif +}; + +inline const Reg& Operand::getReg() const +{ + assert(!isMEM()); + return static_cast<const Reg&>(*this); +} + +struct Reg8 : public Reg { + explicit Reg8(int idx = 0, bool ext8bit = false) : Reg(idx, Operand::REG, 8, ext8bit) { } +}; + +struct Reg16 : public Reg { + explicit Reg16(int idx = 0) : Reg(idx, Operand::REG, 16) { } +}; + +struct Mmx : public Reg { + explicit Mmx(int idx = 0, Kind kind = Operand::MMX, int bit = 64) : Reg(idx, kind, bit) { } +}; + +struct EvexModifierRounding { + enum { + T_RN_SAE = 1, + T_RD_SAE = 2, + T_RU_SAE = 3, + T_RZ_SAE = 4, + T_SAE = 5 + }; + explicit EvexModifierRounding(int rounding) : rounding(rounding) {} + int rounding; +}; +struct EvexModifierZero{EvexModifierZero() {}}; + +struct Xmm : public Mmx { + explicit Xmm(int idx = 0, Kind kind = Operand::XMM, int bit = 128) : Mmx(idx, kind, bit) { } + Xmm(Kind kind, int idx) : Mmx(idx, kind, kind == XMM ? 128 : kind == YMM ? 256 : 512) { } + Xmm operator|(const EvexModifierRounding& emr) const { Xmm r(*this); r.setRounding(emr.rounding); return r; } + Xmm copyAndSetIdx(int idx) const { Xmm ret(*this); ret.setIdx(idx); return ret; } + Xmm copyAndSetKind(Operand::Kind kind) const { Xmm ret(*this); ret.setKind(kind); return ret; } +}; + +struct Ymm : public Xmm { + explicit Ymm(int idx = 0, Kind kind = Operand::YMM, int bit = 256) : Xmm(idx, kind, bit) { } + Ymm operator|(const EvexModifierRounding& emr) const { Ymm r(*this); r.setRounding(emr.rounding); return r; } +}; + +struct Zmm : public Ymm { + explicit Zmm(int idx = 0) : Ymm(idx, Operand::ZMM, 512) { } + Zmm operator|(const EvexModifierRounding& emr) const { Zmm r(*this); r.setRounding(emr.rounding); return r; } +}; + +struct Opmask : public Reg { + explicit Opmask(int idx = 0) : Reg(idx, Operand::OPMASK, 64) {} +}; + +struct BoundsReg : public Reg { + explicit BoundsReg(int idx = 0) : Reg(idx, Operand::BNDREG, 128) {} +}; + +template<class T>T operator|(const T& x, const Opmask& k) { T r(x); r.setOpmaskIdx(k.getIdx()); return r; } +template<class T>T operator|(const T& x, const EvexModifierZero&) { T r(x); r.setZero(); return r; } +template<class T>T operator|(const T& x, const EvexModifierRounding& emr) { T r(x); r.setRounding(emr.rounding); return r; } + +struct Fpu : public Reg { + explicit Fpu(int idx = 0) : Reg(idx, Operand::FPU, 32) { } +}; + +struct Reg32e : public Reg { + explicit Reg32e(int idx, int bit) : Reg(idx, Operand::REG, bit) {} +}; +struct Reg32 : public Reg32e { + explicit Reg32(int idx = 0) : Reg32e(idx, 32) {} +}; +#ifdef XBYAK64 +struct Reg64 : public Reg32e { + explicit Reg64(int idx = 0) : Reg32e(idx, 64) {} +}; +struct RegRip { + sint64 disp_; + const Label* label_; + bool isAddr_; + explicit RegRip(sint64 disp = 0, const Label* label = 0, bool isAddr = false) : disp_(disp), label_(label), isAddr_(isAddr) {} + friend const RegRip operator+(const RegRip& r, int disp) { + return RegRip(r.disp_ + disp, r.label_, r.isAddr_); + } + friend const RegRip operator-(const RegRip& r, int disp) { + return RegRip(r.disp_ - disp, r.label_, r.isAddr_); + } + friend const RegRip operator+(const RegRip& r, sint64 disp) { + return RegRip(r.disp_ + disp, r.label_, r.isAddr_); + } + friend const RegRip operator-(const RegRip& r, sint64 disp) { + return RegRip(r.disp_ - disp, r.label_, r.isAddr_); + } + friend const RegRip operator+(const RegRip& r, const Label& label) { + if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING); + return RegRip(r.disp_, &label); + } + friend const RegRip operator+(const RegRip& r, const void *addr) { + if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING); + return RegRip(r.disp_ + (sint64)addr, 0, true); + } +}; +#endif + +inline Reg8 Reg::cvt8() const +{ + const int idx = getIdx(); + if (isBit(8)) return Reg8(idx, isExt8bit()); +#ifdef XBYAK32 + if (idx >= 4) throw Error(ERR_CANT_CONVERT); +#endif + return Reg8(idx, 4 <= idx && idx < 8); +} + +inline Reg16 Reg::cvt16() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg16(idx); +} + +inline Reg32 Reg::cvt32() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg32(idx); +} + +#ifdef XBYAK64 +inline Reg64 Reg::cvt64() const +{ + const int idx = getIdx(); + if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT); + return Reg64(idx); +} +#endif + +#ifndef XBYAK_DISABLE_SEGMENT +// not derived from Reg +class Segment { + int idx_; +public: + enum { + es, cs, ss, ds, fs, gs + }; + explicit Segment(int idx) : idx_(idx) { assert(0 <= idx_ && idx_ < 6); } + int getIdx() const { return idx_; } + const char *toString() const + { + static const char tbl[][3] = { + "es", "cs", "ss", "ds", "fs", "gs" + }; + return tbl[idx_]; + } +}; +#endif + +class RegExp { +public: +#ifdef XBYAK64 + enum { i32e = 32 | 64 }; +#else + enum { i32e = 32 }; +#endif + RegExp(size_t disp = 0) : scale_(0), disp_(disp) { } + RegExp(const Reg& r, int scale = 1) + : scale_(scale) + , disp_(0) + { + if (!r.isREG(i32e) && !r.is(Reg::XMM|Reg::YMM|Reg::ZMM)) throw Error(ERR_BAD_SIZE_OF_REGISTER); + if (scale == 0) return; + if (scale != 1 && scale != 2 && scale != 4 && scale != 8) throw Error(ERR_BAD_SCALE); + if (r.getBit() >= 128 || scale != 1) { // xmm/ymm is always index + index_ = r; + } else { + base_ = r; + } + } + bool isVsib(int bit = 128 | 256 | 512) const { return index_.isBit(bit); } + RegExp optimize() const + { + RegExp exp = *this; + // [reg * 2] => [reg + reg] + if (index_.isBit(i32e) && !base_.getBit() && scale_ == 2) { + exp.base_ = index_; + exp.scale_ = 1; + } + return exp; + } + bool operator==(const RegExp& rhs) const + { + return base_ == rhs.base_ && index_ == rhs.index_ && disp_ == rhs.disp_ && scale_ == rhs.scale_; + } + const Reg& getBase() const { return base_; } + const Reg& getIndex() const { return index_; } + int getScale() const { return scale_; } + size_t getDisp() const { return disp_; } + void verify() const + { + if (base_.getBit() >= 128) throw Error(ERR_BAD_SIZE_OF_REGISTER); + if (index_.getBit() && index_.getBit() <= 64) { + if (index_.getIdx() == Operand::ESP) throw Error(ERR_ESP_CANT_BE_INDEX); + if (base_.getBit() && base_.getBit() != index_.getBit()) throw Error(ERR_BAD_SIZE_OF_REGISTER); + } + } + friend RegExp operator+(const RegExp& a, const RegExp& b); + friend RegExp operator-(const RegExp& e, size_t disp); + uint8 getRex() const + { + uint8 rex = index_.getRexX() | base_.getRexB(); + return rex ? uint8(rex | 0x40) : 0; + } +private: + /* + [base_ + index_ * scale_ + disp_] + base : Reg32e, index : Reg32e(w/o esp), Xmm, Ymm + */ + Reg base_; + Reg index_; + int scale_; + size_t disp_; +}; + +inline RegExp operator+(const RegExp& a, const RegExp& b) +{ + if (a.index_.getBit() && b.index_.getBit()) throw Error(ERR_BAD_ADDRESSING); + RegExp ret = a; + if (!ret.index_.getBit()) { ret.index_ = b.index_; ret.scale_ = b.scale_; } + if (b.base_.getBit()) { + if (ret.base_.getBit()) { + if (ret.index_.getBit()) throw Error(ERR_BAD_ADDRESSING); + // base + base => base + index * 1 + ret.index_ = b.base_; + // [reg + esp] => [esp + reg] + if (ret.index_.getIdx() == Operand::ESP) std::swap(ret.base_, ret.index_); + ret.scale_ = 1; + } else { + ret.base_ = b.base_; + } + } + ret.disp_ += b.disp_; + return ret; +} +inline RegExp operator*(const Reg& r, int scale) +{ + return RegExp(r, scale); +} +inline RegExp operator-(const RegExp& e, size_t disp) +{ + RegExp ret = e; + ret.disp_ -= disp; + return ret; +} + +// 2nd parameter for constructor of CodeArray(maxSize, userPtr, alloc) +void *const AutoGrow = (void*)1; //-V566 +void *const DontSetProtectRWE = (void*)2; //-V566 + +class CodeArray { + enum Type { + USER_BUF = 1, // use userPtr(non alignment, non protect) + ALLOC_BUF, // use new(alignment, protect) + AUTO_GROW // automatically move and grow memory if necessary + }; + CodeArray(const CodeArray& rhs); + void operator=(const CodeArray&); + bool isAllocType() const { return type_ == ALLOC_BUF || type_ == AUTO_GROW; } + struct AddrInfo { + size_t codeOffset; // position to write + size_t jmpAddr; // value to write + int jmpSize; // size of jmpAddr + inner::LabelMode mode; + AddrInfo(size_t _codeOffset, size_t _jmpAddr, int _jmpSize, inner::LabelMode _mode) + : codeOffset(_codeOffset), jmpAddr(_jmpAddr), jmpSize(_jmpSize), mode(_mode) {} + uint64 getVal(const uint8 *top) const + { + uint64 disp = (mode == inner::LaddTop) ? jmpAddr + size_t(top) : (mode == inner::LasIs) ? jmpAddr : jmpAddr - size_t(top); + if (jmpSize == 4) disp = inner::VerifyInInt32(disp); + return disp; + } + }; + typedef std::list<AddrInfo> AddrInfoList; + AddrInfoList addrInfoList_; + const Type type_; +#ifdef XBYAK_USE_MMAP_ALLOCATOR + MmapAllocator defaultAllocator_; +#else + Allocator defaultAllocator_; +#endif + Allocator *alloc_; +protected: + size_t maxSize_; + uint8 *top_; + size_t size_; + bool isCalledCalcJmpAddress_; + + bool useProtect() const { return alloc_->useProtect(); } + /* + allocate new memory and copy old data to the new area + */ + void growMemory() + { + const size_t newSize = (std::max<size_t>)(DEFAULT_MAX_CODE_SIZE, maxSize_ * 2); + uint8 *newTop = alloc_->alloc(newSize); + if (newTop == 0) throw Error(ERR_CANT_ALLOC); + for (size_t i = 0; i < size_; i++) newTop[i] = top_[i]; + alloc_->free(top_); + top_ = newTop; + maxSize_ = newSize; + } + /* + calc jmp address for AutoGrow mode + */ + void calcJmpAddress() + { + if (isCalledCalcJmpAddress_) return; + for (AddrInfoList::const_iterator i = addrInfoList_.begin(), ie = addrInfoList_.end(); i != ie; ++i) { + uint64 disp = i->getVal(top_); + rewrite(i->codeOffset, disp, i->jmpSize); + } + isCalledCalcJmpAddress_ = true; + } +public: + enum ProtectMode { + PROTECT_RW = 0, // read/write + PROTECT_RWE = 1, // read/write/exec + PROTECT_RE = 2 // read/exec + }; + explicit CodeArray(size_t maxSize, void *userPtr = 0, Allocator *allocator = 0) + : type_(userPtr == AutoGrow ? AUTO_GROW : (userPtr == 0 || userPtr == DontSetProtectRWE) ? ALLOC_BUF : USER_BUF) + , alloc_(allocator ? allocator : (Allocator*)&defaultAllocator_) + , maxSize_(maxSize) + , top_(type_ == USER_BUF ? reinterpret_cast<uint8*>(userPtr) : alloc_->alloc((std::max<size_t>)(maxSize, 1))) + , size_(0) + , isCalledCalcJmpAddress_(false) + { + if (maxSize_ > 0 && top_ == 0) throw Error(ERR_CANT_ALLOC); + if ((type_ == ALLOC_BUF && userPtr != DontSetProtectRWE && useProtect()) && !setProtectMode(PROTECT_RWE, false)) { + alloc_->free(top_); + throw Error(ERR_CANT_PROTECT); + } + } + virtual ~CodeArray() + { + if (isAllocType()) { + if (useProtect()) setProtectModeRW(false); + alloc_->free(top_); + } + } + bool setProtectMode(ProtectMode mode, bool throwException = true) + { + bool isOK = protect(top_, maxSize_, mode); + if (isOK) return true; + if (throwException) throw Error(ERR_CANT_PROTECT); + return false; + } + bool setProtectModeRE(bool throwException = true) { return setProtectMode(PROTECT_RE, throwException); } + bool setProtectModeRW(bool throwException = true) { return setProtectMode(PROTECT_RW, throwException); } + void resetSize() + { + size_ = 0; + addrInfoList_.clear(); + isCalledCalcJmpAddress_ = false; + } + void db(int code) + { + if (size_ >= maxSize_) { + if (type_ == AUTO_GROW) { + growMemory(); + } else { + throw Error(ERR_CODE_IS_TOO_BIG); + } + } + top_[size_++] = static_cast<uint8>(code); + } + void db(const uint8 *code, size_t codeSize) + { + for (size_t i = 0; i < codeSize; i++) db(code[i]); + } + void db(uint64 code, size_t codeSize) + { + if (codeSize > 8) throw Error(ERR_BAD_PARAMETER); + for (size_t i = 0; i < codeSize; i++) db(static_cast<uint8>(code >> (i * 8))); + } + void dw(uint32 code) { db(code, 2); } + void dd(uint32 code) { db(code, 4); } + void dq(uint64 code) { db(code, 8); } + const uint8 *getCode() const { return top_; } + template<class F> + const F getCode() const { return reinterpret_cast<F>(top_); } + const uint8 *getCurr() const { return &top_[size_]; } + template<class F> + const F getCurr() const { return reinterpret_cast<F>(&top_[size_]); } + size_t getSize() const { return size_; } + void setSize(size_t size) + { + if (size > maxSize_) throw Error(ERR_OFFSET_IS_TOO_BIG); + size_ = size; + } + void dump() const + { + const uint8 *p = getCode(); + size_t bufSize = getSize(); + size_t remain = bufSize; + for (int i = 0; i < 4; i++) { + size_t disp = 16; + if (remain < 16) { + disp = remain; + } + for (size_t j = 0; j < 16; j++) { + if (j < disp) { + printf("%02X", p[i * 16 + j]); + } + } + putchar('\n'); + remain -= disp; + if (remain == 0) { + break; + } + } + } + /* + @param offset [in] offset from top + @param disp [in] offset from the next of jmp + @param size [in] write size(1, 2, 4, 8) + */ + void rewrite(size_t offset, uint64 disp, size_t size) + { + assert(offset < maxSize_); + if (size != 1 && size != 2 && size != 4 && size != 8) throw Error(ERR_BAD_PARAMETER); + uint8 *const data = top_ + offset; + for (size_t i = 0; i < size; i++) { + data[i] = static_cast<uint8>(disp >> (i * 8)); + } + } + void save(size_t offset, size_t val, int size, inner::LabelMode mode) + { + addrInfoList_.push_back(AddrInfo(offset, val, size, mode)); + } + bool isAutoGrow() const { return type_ == AUTO_GROW; } + bool isCalledCalcJmpAddress() const { return isCalledCalcJmpAddress_; } + /** + change exec permission of memory + @param addr [in] buffer address + @param size [in] buffer size + @param protectMode [in] mode(RW/RWE/RE) + @return true(success), false(failure) + */ + static inline bool protect(const void *addr, size_t size, int protectMode) + { +#if defined(_WIN32) + const DWORD c_rw = PAGE_READWRITE; + const DWORD c_rwe = PAGE_EXECUTE_READWRITE; + const DWORD c_re = PAGE_EXECUTE_READ; + DWORD mode; +#else + const int c_rw = PROT_READ | PROT_WRITE; + const int c_rwe = PROT_READ | PROT_WRITE | PROT_EXEC; + const int c_re = PROT_READ | PROT_EXEC; + int mode; +#endif + switch (protectMode) { + case PROTECT_RW: mode = c_rw; break; + case PROTECT_RWE: mode = c_rwe; break; + case PROTECT_RE: mode = c_re; break; + default: + return false; + } +#if defined(_WIN32) + DWORD oldProtect; + return VirtualProtect(const_cast<void*>(addr), size, mode, &oldProtect) != 0; +#elif defined(__GNUC__) + size_t pageSize = sysconf(_SC_PAGESIZE); + size_t iaddr = reinterpret_cast<size_t>(addr); + size_t roundAddr = iaddr & ~(pageSize - static_cast<size_t>(1)); +#ifndef NDEBUG + if (pageSize != 4096) fprintf(stderr, "large page(%zd) is used. not tested enough.\n", pageSize); +#endif + return mprotect(reinterpret_cast<void*>(roundAddr), size + (iaddr - roundAddr), mode) == 0; +#else + return true; +#endif + } + /** + get aligned memory pointer + @param addr [in] address + @param alignedSize [in] power of two + @return aligned addr by alingedSize + */ + static inline uint8 *getAlignedAddress(uint8 *addr, size_t alignedSize = 16) + { + return reinterpret_cast<uint8*>((reinterpret_cast<size_t>(addr) + alignedSize - 1) & ~(alignedSize - static_cast<size_t>(1))); + } +}; + +class Address : public Operand { +public: + enum Mode { + M_ModRM, + M_64bitDisp, + M_rip, + M_ripAddr + }; + Address(uint32 sizeBit, bool broadcast, const RegExp& e) + : Operand(0, MEM, sizeBit), e_(e), label_(0), mode_(M_ModRM), broadcast_(broadcast) + { + e_.verify(); + } +#ifdef XBYAK64 + explicit Address(size_t disp) + : Operand(0, MEM, 64), e_(disp), label_(0), mode_(M_64bitDisp), broadcast_(false){ } + Address(uint32 sizeBit, bool broadcast, const RegRip& addr) + : Operand(0, MEM, sizeBit), e_(addr.disp_), label_(addr.label_), mode_(addr.isAddr_ ? M_ripAddr : M_rip), broadcast_(broadcast) { } +#endif + RegExp getRegExp(bool optimize = true) const + { + return optimize ? e_.optimize() : e_; + } + Mode getMode() const { return mode_; } + bool is32bit() const { return e_.getBase().getBit() == 32 || e_.getIndex().getBit() == 32; } + bool isOnlyDisp() const { return !e_.getBase().getBit() && !e_.getIndex().getBit(); } // for mov eax + size_t getDisp() const { return e_.getDisp(); } + uint8 getRex() const + { + if (mode_ != M_ModRM) return 0; + return getRegExp().getRex(); + } + bool is64bitDisp() const { return mode_ == M_64bitDisp; } // for moffset + bool isBroadcast() const { return broadcast_; } + const Label* getLabel() const { return label_; } + bool operator==(const Address& rhs) const + { + return getBit() == rhs.getBit() && e_ == rhs.e_ && label_ == rhs.label_ && mode_ == rhs.mode_ && broadcast_ == rhs.broadcast_; + } + bool operator!=(const Address& rhs) const { return !operator==(rhs); } + bool isVsib() const { return e_.isVsib(); } +private: + RegExp e_; + const Label* label_; + Mode mode_; + bool broadcast_; +}; + +inline const Address& Operand::getAddress() const +{ + assert(isMEM()); + return static_cast<const Address&>(*this); +} + +inline bool Operand::operator==(const Operand& rhs) const +{ + if (isMEM() && rhs.isMEM()) return this->getAddress() == rhs.getAddress(); + return isEqualIfNotInherited(rhs); +} + +class AddressFrame { + void operator=(const AddressFrame&); + AddressFrame(const AddressFrame&); +public: + const uint32 bit_; + const bool broadcast_; + explicit AddressFrame(uint32 bit, bool broadcast = false) : bit_(bit), broadcast_(broadcast) { } + Address operator[](const RegExp& e) const + { + return Address(bit_, broadcast_, e); + } + Address operator[](const void *disp) const + { + return Address(bit_, broadcast_, RegExp(reinterpret_cast<size_t>(disp))); + } +#ifdef XBYAK64 + Address operator[](uint64 disp) const { return Address(disp); } + Address operator[](const RegRip& addr) const { return Address(bit_, broadcast_, addr); } +#endif +}; + +struct JmpLabel { + size_t endOfJmp; /* offset from top to the end address of jmp */ + int jmpSize; + inner::LabelMode mode; + size_t disp; // disp for [rip + disp] + explicit JmpLabel(size_t endOfJmp = 0, int jmpSize = 0, inner::LabelMode mode = inner::LasIs, size_t disp = 0) + : endOfJmp(endOfJmp), jmpSize(jmpSize), mode(mode), disp(disp) + { + } +}; + +class LabelManager; + +class Label { + mutable LabelManager *mgr; + mutable int id; + friend class LabelManager; +public: + Label() : mgr(0), id(0) {} + Label(const Label& rhs); + Label& operator=(const Label& rhs); + ~Label(); + void clear() { mgr = 0; id = 0; } + int getId() const { return id; } + const uint8 *getAddress() const; + + // backward compatibility + static inline std::string toStr(int num) + { + char buf[16]; +#if defined(_MSC_VER) && (_MSC_VER < 1900) + _snprintf_s +#else + snprintf +#endif + (buf, sizeof(buf), ".%08x", num); + return buf; + } +}; + +class LabelManager { + // for string label + struct SlabelVal { + size_t offset; + SlabelVal(size_t offset) : offset(offset) {} + }; + typedef XBYAK_STD_UNORDERED_MAP<std::string, SlabelVal> SlabelDefList; + typedef XBYAK_STD_UNORDERED_MULTIMAP<std::string, const JmpLabel> SlabelUndefList; + struct SlabelState { + SlabelDefList defList; + SlabelUndefList undefList; + }; + typedef std::list<SlabelState> StateList; + // for Label class + struct ClabelVal { + ClabelVal(size_t offset = 0) : offset(offset), refCount(1) {} + size_t offset; + int refCount; + }; + typedef XBYAK_STD_UNORDERED_MAP<int, ClabelVal> ClabelDefList; + typedef XBYAK_STD_UNORDERED_MULTIMAP<int, const JmpLabel> ClabelUndefList; + typedef XBYAK_STD_UNORDERED_SET<Label*> LabelPtrList; + + CodeArray *base_; + // global : stateList_.front(), local : stateList_.back() + StateList stateList_; + mutable int labelId_; + ClabelDefList clabelDefList_; + ClabelUndefList clabelUndefList_; + LabelPtrList labelPtrList_; + + int getId(const Label& label) const + { + if (label.id == 0) label.id = labelId_++; + return label.id; + } + template<class DefList, class UndefList, class T> + void define_inner(DefList& defList, UndefList& undefList, const T& labelId, size_t addrOffset) + { + // add label + typename DefList::value_type item(labelId, addrOffset); + std::pair<typename DefList::iterator, bool> ret = defList.insert(item); + if (!ret.second) throw Error(ERR_LABEL_IS_REDEFINED); + // search undefined label + for (;;) { + typename UndefList::iterator itr = undefList.find(labelId); + if (itr == undefList.end()) break; + const JmpLabel *jmp = &itr->second; + const size_t offset = jmp->endOfJmp - jmp->jmpSize; + size_t disp; + if (jmp->mode == inner::LaddTop) { + disp = addrOffset; + } else if (jmp->mode == inner::Labs) { + disp = size_t(base_->getCurr()); + } else { + disp = addrOffset - jmp->endOfJmp + jmp->disp; +#ifdef XBYAK64 + if (jmp->jmpSize <= 4 && !inner::IsInInt32(disp)) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + if (jmp->jmpSize == 1 && !inner::IsInDisp8((uint32)disp)) throw Error(ERR_LABEL_IS_TOO_FAR); + } + if (base_->isAutoGrow()) { + base_->save(offset, disp, jmp->jmpSize, jmp->mode); + } else { + base_->rewrite(offset, disp, jmp->jmpSize); + } + undefList.erase(itr); + } + } + template<class DefList, class T> + bool getOffset_inner(const DefList& defList, size_t *offset, const T& label) const + { + typename DefList::const_iterator i = defList.find(label); + if (i == defList.end()) return false; + *offset = i->second.offset; + return true; + } + friend class Label; + void incRefCount(int id, Label *label) + { + clabelDefList_[id].refCount++; + labelPtrList_.insert(label); + } + void decRefCount(int id, Label *label) + { + labelPtrList_.erase(label); + ClabelDefList::iterator i = clabelDefList_.find(id); + if (i == clabelDefList_.end()) return; + if (i->second.refCount == 1) { + clabelDefList_.erase(id); + } else { + --i->second.refCount; + } + } + template<class T> + bool hasUndefinedLabel_inner(const T& list) const + { +#ifndef NDEBUG + for (typename T::const_iterator i = list.begin(); i != list.end(); ++i) { + std::cerr << "undefined label:" << i->first << std::endl; + } +#endif + return !list.empty(); + } + // detach all labels linked to LabelManager + void resetLabelPtrList() + { + for (LabelPtrList::iterator i = labelPtrList_.begin(), ie = labelPtrList_.end(); i != ie; ++i) { + (*i)->clear(); + } + labelPtrList_.clear(); + } +public: + LabelManager() + { + reset(); + } + ~LabelManager() + { + resetLabelPtrList(); + } + void reset() + { + base_ = 0; + labelId_ = 1; + stateList_.clear(); + stateList_.push_back(SlabelState()); + stateList_.push_back(SlabelState()); + clabelDefList_.clear(); + clabelUndefList_.clear(); + resetLabelPtrList(); + } + void enterLocal() + { + stateList_.push_back(SlabelState()); + } + void leaveLocal() + { + if (stateList_.size() <= 2) throw Error(ERR_UNDER_LOCAL_LABEL); + if (hasUndefinedLabel_inner(stateList_.back().undefList)) throw Error(ERR_LABEL_IS_NOT_FOUND); + stateList_.pop_back(); + } + void set(CodeArray *base) { base_ = base; } + void defineSlabel(std::string label) + { + if (label == "@b" || label == "@f") throw Error(ERR_BAD_LABEL_STR); + if (label == "@@") { + SlabelDefList& defList = stateList_.front().defList; + SlabelDefList::iterator i = defList.find("@f"); + if (i != defList.end()) { + defList.erase(i); + label = "@b"; + } else { + i = defList.find("@b"); + if (i != defList.end()) { + defList.erase(i); + } + label = "@f"; + } + } + SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + define_inner(st.defList, st.undefList, label, base_->getSize()); + } + void defineClabel(Label& label) + { + define_inner(clabelDefList_, clabelUndefList_, getId(label), base_->getSize()); + label.mgr = this; + labelPtrList_.insert(&label); + } + void assign(Label& dst, const Label& src) + { + ClabelDefList::const_iterator i = clabelDefList_.find(src.id); + if (i == clabelDefList_.end()) throw Error(ERR_LABEL_ISNOT_SET_BY_L); + define_inner(clabelDefList_, clabelUndefList_, dst.id, i->second.offset); + dst.mgr = this; + labelPtrList_.insert(&dst); + } + bool getOffset(size_t *offset, std::string& label) const + { + const SlabelDefList& defList = stateList_.front().defList; + if (label == "@b") { + if (defList.find("@f") != defList.end()) { + label = "@f"; + } else if (defList.find("@b") == defList.end()) { + throw Error(ERR_LABEL_IS_NOT_FOUND); + } + } else if (label == "@f") { + if (defList.find("@f") != defList.end()) { + label = "@b"; + } + } + const SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + return getOffset_inner(st.defList, offset, label); + } + bool getOffset(size_t *offset, const Label& label) const + { + return getOffset_inner(clabelDefList_, offset, getId(label)); + } + void addUndefinedLabel(const std::string& label, const JmpLabel& jmp) + { + SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front(); + st.undefList.insert(SlabelUndefList::value_type(label, jmp)); + } + void addUndefinedLabel(const Label& label, const JmpLabel& jmp) + { + clabelUndefList_.insert(ClabelUndefList::value_type(label.id, jmp)); + } + bool hasUndefSlabel() const + { + for (StateList::const_iterator i = stateList_.begin(), ie = stateList_.end(); i != ie; ++i) { + if (hasUndefinedLabel_inner(i->undefList)) return true; + } + return false; + } + bool hasUndefClabel() const { return hasUndefinedLabel_inner(clabelUndefList_); } + const uint8 *getCode() const { return base_->getCode(); } + bool isReady() const { return !base_->isAutoGrow() || base_->isCalledCalcJmpAddress(); } +}; + +inline Label::Label(const Label& rhs) +{ + id = rhs.id; + mgr = rhs.mgr; + if (mgr) mgr->incRefCount(id, this); +} +inline Label& Label::operator=(const Label& rhs) +{ + if (id) throw Error(ERR_LABEL_IS_ALREADY_SET_BY_L); + id = rhs.id; + mgr = rhs.mgr; + if (mgr) mgr->incRefCount(id, this); + return *this; +} +inline Label::~Label() +{ + if (id && mgr) mgr->decRefCount(id, this); +} +inline const uint8* Label::getAddress() const +{ + if (mgr == 0 || !mgr->isReady()) return 0; + size_t offset; + if (!mgr->getOffset(&offset, *this)) return 0; + return mgr->getCode() + offset; +} + +class CodeGenerator : public CodeArray { +public: + enum LabelType { + T_SHORT, + T_NEAR, + T_AUTO // T_SHORT if possible + }; +private: + CodeGenerator operator=(const CodeGenerator&); // don't call +#ifdef XBYAK64 + enum { i32e = 32 | 64, BIT = 64 }; + static const size_t dummyAddr = (size_t(0x11223344) << 32) | 55667788; + typedef Reg64 NativeReg; +#else + enum { i32e = 32, BIT = 32 }; + static const size_t dummyAddr = 0x12345678; + typedef Reg32 NativeReg; +#endif + // (XMM, XMM|MEM) + static inline bool isXMM_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isXMM() || op2.isMEM()); + } + // (MMX, MMX|MEM) or (XMM, XMM|MEM) + static inline bool isXMMorMMX_MEM(const Operand& op1, const Operand& op2) + { + return (op1.isMMX() && (op2.isMMX() || op2.isMEM())) || isXMM_XMMorMEM(op1, op2); + } + // (XMM, MMX|MEM) + static inline bool isXMM_MMXorMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isMMX() || op2.isMEM()); + } + // (MMX, XMM|MEM) + static inline bool isMMX_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isMMX() && (op2.isXMM() || op2.isMEM()); + } + // (XMM, REG32|MEM) + static inline bool isXMM_REG32orMEM(const Operand& op1, const Operand& op2) + { + return op1.isXMM() && (op2.isREG(i32e) || op2.isMEM()); + } + // (REG32, XMM|MEM) + static inline bool isREG32_XMMorMEM(const Operand& op1, const Operand& op2) + { + return op1.isREG(i32e) && (op2.isXMM() || op2.isMEM()); + } + // (REG32, REG32|MEM) + static inline bool isREG32_REG32orMEM(const Operand& op1, const Operand& op2) + { + return op1.isREG(i32e) && ((op2.isREG(i32e) && op1.getBit() == op2.getBit()) || op2.isMEM()); + } + void rex(const Operand& op1, const Operand& op2 = Operand()) + { + uint8 rex = 0; + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM()) std::swap(p1, p2); + if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION); + if (p2->isMEM()) { + const Address& addr = p2->getAddress(); + if (BIT == 64 && addr.is32bit()) db(0x67); + rex = addr.getRex() | p1->getReg().getRex(); + } else { + // ModRM(reg, base); + rex = op2.getReg().getRex(op1.getReg()); + } + // except movsx(16bit, 32/64bit) + if ((op1.isBit(16) && !op2.isBit(i32e)) || (op2.isBit(16) && !op1.isBit(i32e))) db(0x66); + if (rex) db(rex); + } + enum AVXtype { + // low 3 bit + T_N1 = 1, + T_N2 = 2, + T_N4 = 3, + T_N8 = 4, + T_N16 = 5, + T_N32 = 6, + T_NX_MASK = 7, + // + T_N_VL = 1 << 3, // N * (1, 2, 4) for VL + T_DUP = 1 << 4, // N = (8, 32, 64) + T_66 = 1 << 5, + T_F3 = 1 << 6, + T_F2 = 1 << 7, + T_0F = 1 << 8, + T_0F38 = 1 << 9, + T_0F3A = 1 << 10, + T_L0 = 1 << 11, + T_L1 = 1 << 12, + T_W0 = 1 << 13, + T_W1 = 1 << 14, + T_EW0 = 1 << 15, + T_EW1 = 1 << 16, + T_YMM = 1 << 17, // support YMM, ZMM + T_EVEX = 1 << 18, + T_ER_X = 1 << 19, // xmm{er} + T_ER_Y = 1 << 20, // ymm{er} + T_ER_Z = 1 << 21, // zmm{er} + T_SAE_X = 1 << 22, // xmm{sae} + T_SAE_Y = 1 << 23, // ymm{sae} + T_SAE_Z = 1 << 24, // zmm{sae} + T_MUST_EVEX = 1 << 25, // contains T_EVEX + T_B32 = 1 << 26, // m32bcst + T_B64 = 1 << 27, // m64bcst + T_M_K = 1 << 28, // mem{k} + T_VSIB = 1 << 29, + T_MEM_EVEX = 1 << 30, // use evex if mem + T_XXX + }; + void vex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false) + { + int w = (type & T_W1) ? 1 : 0; + bool is256 = (type & T_L1) ? true : (type & T_L0) ? false : reg.isYMM(); + bool r = reg.isExtIdx(); + bool b = base.isExtIdx(); + int idx = v ? v->getIdx() : 0; + if ((idx | reg.getIdx() | base.getIdx()) >= 16) throw Error(ERR_BAD_COMBINATION); + uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; + uint32 vvvv = (((~idx) & 15) << 3) | (is256 ? 4 : 0) | pp; + if (!b && !x && !w && (type & T_0F)) { + db(0xC5); db((r ? 0 : 0x80) | vvvv); + } else { + uint32 mmmm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; + db(0xC4); db((r ? 0 : 0x80) | (x ? 0 : 0x40) | (b ? 0 : 0x20) | mmmm); db((w << 7) | vvvv); + } + db(code); + } + void verifySAE(const Reg& r, int type) const + { + if (((type & T_SAE_X) && r.isXMM()) || ((type & T_SAE_Y) && r.isYMM()) || ((type & T_SAE_Z) && r.isZMM())) return; + throw Error(ERR_SAE_IS_INVALID); + } + void verifyER(const Reg& r, int type) const + { + if (((type & T_ER_X) && r.isXMM()) || ((type & T_ER_Y) && r.isYMM()) || ((type & T_ER_Z) && r.isZMM())) return; + throw Error(ERR_ER_IS_INVALID); + } + // (a, b, c) contains non zero two or three values then err + int verifyDuplicate(int a, int b, int c, int err) + { + int v = a | b | c; + if ((a > 0 && a != v) + (b > 0 && b != v) + (c > 0 && c != v) > 0) return Error(err); + return v; + } + int evex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false, bool b = false, int aaa = 0, uint32 VL = 0, bool Hi16Vidx = false) + { + if (!(type & (T_EVEX | T_MUST_EVEX))) throw Error(ERR_EVEX_IS_INVALID); + int w = (type & T_EW1) ? 1 : 0; + uint32 mm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0; + uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0; + + int idx = v ? v->getIdx() : 0; + uint32 vvvv = ~idx; + + bool R = !reg.isExtIdx(); + bool X = x ? false : !base.isExtIdx2(); + bool B = !base.isExtIdx(); + bool Rp = !reg.isExtIdx2(); + int LL; + int rounding = verifyDuplicate(reg.getRounding(), base.getRounding(), v ? v->getRounding() : 0, ERR_ROUNDING_IS_ALREADY_SET); + int disp8N = 1; + if (rounding) { + if (rounding == EvexModifierRounding::T_SAE) { + verifySAE(base, type); LL = 0; + } else { + verifyER(base, type); LL = rounding - 1; + } + b = true; + } else { + if (v) VL = (std::max)(VL, v->getBit()); + VL = (std::max)((std::max)(reg.getBit(), base.getBit()), VL); + LL = (VL == 512) ? 2 : (VL == 256) ? 1 : 0; + if (b) { + disp8N = (type & T_B32) ? 4 : 8; + } else if (type & T_DUP) { + disp8N = VL == 128 ? 8 : VL == 256 ? 32 : 64; + } else { + if ((type & (T_NX_MASK | T_N_VL)) == 0) { + type |= T_N16 | T_N_VL; // default + } + int low = type & T_NX_MASK; + if (low > 0) { + disp8N = 1 << (low - 1); + if (type & T_N_VL) disp8N *= (VL == 512 ? 4 : VL == 256 ? 2 : 1); + } + } + } + bool Vp = !((v ? v->isExtIdx2() : 0) | Hi16Vidx); + bool z = reg.hasZero() || base.hasZero() || (v ? v->hasZero() : false); + if (aaa == 0) aaa = verifyDuplicate(base.getOpmaskIdx(), reg.getOpmaskIdx(), (v ? v->getOpmaskIdx() : 0), ERR_OPMASK_IS_ALREADY_SET); + db(0x62); + db((R ? 0x80 : 0) | (X ? 0x40 : 0) | (B ? 0x20 : 0) | (Rp ? 0x10 : 0) | (mm & 3)); + db((w == 1 ? 0x80 : 0) | ((vvvv & 15) << 3) | 4 | (pp & 3)); + db((z ? 0x80 : 0) | ((LL & 3) << 5) | (b ? 0x10 : 0) | (Vp ? 8 : 0) | (aaa & 7)); + db(code); + return disp8N; + } + void setModRM(int mod, int r1, int r2) + { + db(static_cast<uint8>((mod << 6) | ((r1 & 7) << 3) | (r2 & 7))); + } + void setSIB(const RegExp& e, int reg, int disp8N = 0) + { + size_t disp64 = e.getDisp(); +#ifdef XBYAK64 + size_t high = disp64 >> 32; + if (high != 0 && high != 0xFFFFFFFF) throw Error(ERR_OFFSET_IS_TOO_BIG); +#endif + uint32 disp = static_cast<uint32>(disp64); + const Reg& base = e.getBase(); + const Reg& index = e.getIndex(); + const int baseIdx = base.getIdx(); + const int baseBit = base.getBit(); + const int indexBit = index.getBit(); + enum { + mod00 = 0, mod01 = 1, mod10 = 2 + }; + int mod = mod10; // disp32 + if (!baseBit || ((baseIdx & 7) != Operand::EBP && disp == 0)) { + mod = mod00; + } else { + if (disp8N == 0) { + if (inner::IsInDisp8(disp)) { + mod = mod01; + } + } else { + // disp must be casted to signed + uint32 t = static_cast<uint32>(static_cast<int>(disp) / disp8N); + if ((disp % disp8N) == 0 && inner::IsInDisp8(t)) { + disp = t; + mod = mod01; + } + } + } + const int newBaseIdx = baseBit ? (baseIdx & 7) : Operand::EBP; + /* ModR/M = [2:3:3] = [Mod:reg/code:R/M] */ + bool hasSIB = indexBit || (baseIdx & 7) == Operand::ESP; +#ifdef XBYAK64 + if (!baseBit && !indexBit) hasSIB = true; +#endif + if (hasSIB) { + setModRM(mod, reg, Operand::ESP); + /* SIB = [2:3:3] = [SS:index:base(=rm)] */ + const int idx = indexBit ? (index.getIdx() & 7) : Operand::ESP; + const int scale = e.getScale(); + const int SS = (scale == 8) ? 3 : (scale == 4) ? 2 : (scale == 2) ? 1 : 0; + setModRM(SS, idx, newBaseIdx); + } else { + setModRM(mod, reg, newBaseIdx); + } + if (mod == mod01) { + db(disp); + } else if (mod == mod10 || (mod == mod00 && !baseBit)) { + dd(disp); + } + } + LabelManager labelMgr_; + bool isInDisp16(uint32 x) const { return 0xFFFF8000 <= x || x <= 0x7FFF; } + void opModR(const Reg& reg1, const Reg& reg2, int code0, int code1 = NONE, int code2 = NONE) + { + rex(reg2, reg1); + db(code0 | (reg1.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2); + setModRM(3, reg1.getIdx(), reg2.getIdx()); + } + void opModM(const Address& addr, const Reg& reg, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + rex(addr, reg); + db(code0 | (reg.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2); + opAddr(addr, reg.getIdx(), immSize); + } + void opMIB(const Address& addr, const Reg& reg, int code0, int code1) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + if (addr.getMode() != Address::M_ModRM) throw Error(ERR_INVALID_MIB_ADDRESS); + if (BIT == 64 && addr.is32bit()) db(0x67); + const RegExp& regExp = addr.getRegExp(false); + uint8 rex = regExp.getRex(); + if (rex) db(rex); + db(code0); db(code1); + setSIB(regExp, reg.getIdx()); + } + void makeJmp(uint32 disp, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref) + { + const int shortJmpSize = 2; + const int longHeaderSize = longPref ? 2 : 1; + const int longJmpSize = longHeaderSize + 4; + if (type != T_NEAR && inner::IsInDisp8(disp - shortJmpSize)) { + db(shortCode); db(disp - shortJmpSize); + } else { + if (type == T_SHORT) throw Error(ERR_LABEL_IS_TOO_FAR); + if (longPref) db(longPref); + db(longCode); dd(disp - longJmpSize); + } + } + template<class T> + void opJmp(T& label, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref) + { + if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); /* avoid splitting code of jmp */ + size_t offset = 0; + if (labelMgr_.getOffset(&offset, label)) { /* label exists */ + makeJmp(inner::VerifyInInt32(offset - size_), type, shortCode, longCode, longPref); + } else { + int jmpSize = 0; + if (type == T_NEAR) { + jmpSize = 4; + if (longPref) db(longPref); + db(longCode); dd(0); + } else { + jmpSize = 1; + db(shortCode); db(0); + } + JmpLabel jmp(size_, jmpSize, inner::LasIs); + labelMgr_.addUndefinedLabel(label, jmp); + } + } + void opJmpAbs(const void *addr, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref = 0) + { + if (isAutoGrow()) { + if (type != T_NEAR) throw Error(ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW); + if (size_ + 16 >= maxSize_) growMemory(); + if (longPref) db(longPref); + db(longCode); + dd(0); + save(size_ - 4, size_t(addr) - size_, 4, inner::Labs); + } else { + makeJmp(inner::VerifyInInt32(reinterpret_cast<const uint8*>(addr) - getCurr()), type, shortCode, longCode, longPref); + } + + } + // reg is reg field of ModRM + // immSize is the size for immediate value + // disp8N = 0(normal), disp8N = 1(force disp32), disp8N = {2, 4, 8} ; compressed displacement + void opAddr(const Address &addr, int reg, int immSize = 0, int disp8N = 0, bool permitVisb = false) + { + if (!permitVisb && addr.isVsib()) throw Error(ERR_BAD_VSIB_ADDRESSING); + if (addr.getMode() == Address::M_ModRM) { + setSIB(addr.getRegExp(), reg, disp8N); + } else if (addr.getMode() == Address::M_rip || addr.getMode() == Address::M_ripAddr) { + setModRM(0, reg, 5); + if (addr.getLabel()) { // [rip + Label] + putL_inner(*addr.getLabel(), true, addr.getDisp() - immSize); + } else { + size_t disp = addr.getDisp(); + if (addr.getMode() == Address::M_ripAddr) { + if (isAutoGrow()) throw Error(ERR_INVALID_RIP_IN_AUTO_GROW); + disp -= (size_t)getCurr() + 4 + immSize; + } + dd(inner::VerifyInInt32(disp)); + } + } + } + /* preCode is for SSSE3/SSE4 */ + void opGen(const Operand& reg, const Operand& op, int code, int pref, bool isValid(const Operand&, const Operand&), int imm8 = NONE, int preCode = NONE) + { + if (isValid && !isValid(reg, op)) throw Error(ERR_BAD_COMBINATION); + if (pref != NONE) db(pref); + if (op.isMEM()) { + opModM(op.getAddress(), reg.getReg(), 0x0F, preCode, code, (imm8 != NONE) ? 1 : 0); + } else { + opModR(reg.getReg(), op.getReg(), 0x0F, preCode, code); + } + if (imm8 != NONE) db(imm8); + } + void opMMX_IMM(const Mmx& mmx, int imm8, int code, int ext) + { + if (mmx.isXMM()) db(0x66); + opModR(Reg32(ext), mmx, 0x0F, code); + db(imm8); + } + void opMMX(const Mmx& mmx, const Operand& op, int code, int pref = 0x66, int imm8 = NONE, int preCode = NONE) + { + opGen(mmx, op, code, mmx.isXMM() ? pref : NONE, isXMMorMMX_MEM, imm8, preCode); + } + void opMovXMM(const Operand& op1, const Operand& op2, int code, int pref) + { + if (pref != NONE) db(pref); + if (op1.isXMM() && op2.isMEM()) { + opModM(op2.getAddress(), op1.getReg(), 0x0F, code); + } else if (op1.isMEM() && op2.isXMM()) { + opModM(op1.getAddress(), op2.getReg(), 0x0F, code | 1); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opExt(const Operand& op, const Mmx& mmx, int code, int imm, bool hasMMX2 = false) + { + if (hasMMX2 && op.isREG(i32e)) { /* pextrw is special */ + if (mmx.isXMM()) db(0x66); + opModR(op.getReg(), mmx, 0x0F, 0xC5); db(imm); + } else { + opGen(mmx, op, code, 0x66, isXMM_REG32orMEM, imm, 0x3A); + } + } + void opR_ModM(const Operand& op, int bit, int ext, int code0, int code1 = NONE, int code2 = NONE, bool disableRex = false, int immSize = 0) + { + int opBit = op.getBit(); + if (disableRex && opBit == 64) opBit = 32; + if (op.isREG(bit)) { + opModR(Reg(ext, Operand::REG, opBit), op.getReg().changeBit(opBit), code0, code1, code2); + } else if (op.isMEM()) { + opModM(op.getAddress(), Reg(ext, Operand::REG, opBit), code0, code1, code2, immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opShift(const Operand& op, int imm, int ext) + { + verifyMemHasSize(op); + opR_ModM(op, 0, ext, (0xC0 | ((imm == 1 ? 1 : 0) << 4)), NONE, NONE, false, (imm != 1) ? 1 : 0); + if (imm != 1) db(imm); + } + void opShift(const Operand& op, const Reg8& _cl, int ext) + { + if (_cl.getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION); + opR_ModM(op, 0, ext, 0xD2); + } + void opModRM(const Operand& op1, const Operand& op2, bool condR, bool condM, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0) + { + if (condR) { + opModR(op1.getReg(), op2.getReg(), code0, code1, code2); + } else if (condM) { + opModM(op2.getAddress(), op1.getReg(), code0, code1, code2, immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void opShxd(const Operand& op, const Reg& reg, uint8 imm, int code, const Reg8 *_cl = 0) + { + if (_cl && _cl->getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION); + opModRM(reg, op, (op.isREG(16 | i32e) && op.getBit() == reg.getBit()), op.isMEM() && (reg.isREG(16 | i32e)), 0x0F, code | (_cl ? 1 : 0), NONE, _cl ? 0 : 1); + if (!_cl) db(imm); + } + // (REG, REG|MEM), (MEM, REG) + void opRM_RM(const Operand& op1, const Operand& op2, int code) + { + if (op1.isREG() && op2.isMEM()) { + opModM(op2.getAddress(), op1.getReg(), code | 2); + } else { + opModRM(op2, op1, op1.isREG() && op1.getKind() == op2.getKind(), op1.isMEM() && op2.isREG(), code); + } + } + // (REG|MEM, IMM) + void opRM_I(const Operand& op, uint32 imm, int code, int ext) + { + verifyMemHasSize(op); + uint32 immBit = inner::IsInDisp8(imm) ? 8 : isInDisp16(imm) ? 16 : 32; + if (op.isBit(8)) immBit = 8; + if (op.getBit() < immBit) throw Error(ERR_IMM_IS_TOO_BIG); + if (op.isBit(32|64) && immBit == 16) immBit = 32; /* don't use MEM16 if 32/64bit mode */ + if (op.isREG() && op.getIdx() == 0 && (op.getBit() == immBit || (op.isBit(64) && immBit == 32))) { // rax, eax, ax, al + rex(op); + db(code | 4 | (immBit == 8 ? 0 : 1)); + } else { + int tmp = immBit < (std::min)(op.getBit(), 32U) ? 2 : 0; + opR_ModM(op, 0, ext, 0x80 | tmp, NONE, NONE, false, immBit / 8); + } + db(imm, immBit / 8); + } + void opIncDec(const Operand& op, int code, int ext) + { + verifyMemHasSize(op); +#ifndef XBYAK64 + if (op.isREG() && !op.isBit(8)) { + rex(op); db(code | op.getIdx()); + return; + } +#endif + code = 0xFE; + if (op.isREG()) { + opModR(Reg(ext, Operand::REG, op.getBit()), op.getReg(), code); + } else { + opModM(op.getAddress(), Reg(ext, Operand::REG, op.getBit()), code); + } + } + void opPushPop(const Operand& op, int code, int ext, int alt) + { + int bit = op.getBit(); + if (bit == 16 || bit == BIT) { + if (bit == 16) db(0x66); + if (op.isREG()) { + if (op.getReg().getIdx() >= 8) db(0x41); + db(alt | (op.getIdx() & 7)); + return; + } + if (op.isMEM()) { + opModM(op.getAddress(), Reg(ext, Operand::REG, 32), code); + return; + } + } + throw Error(ERR_BAD_COMBINATION); + } + void verifyMemHasSize(const Operand& op) const + { + if (op.isMEM() && op.getBit() == 0) throw Error(ERR_MEM_SIZE_IS_NOT_SPECIFIED); + } + /* + mov(r, imm) = db(imm, mov_imm(r, imm)) + */ + int mov_imm(const Reg& reg, size_t imm) + { + int bit = reg.getBit(); + const int idx = reg.getIdx(); + int code = 0xB0 | ((bit == 8 ? 0 : 1) << 3); + if (bit == 64 && (imm & ~size_t(0xffffffffu)) == 0) { + rex(Reg32(idx)); + bit = 32; + } else { + rex(reg); + if (bit == 64 && inner::IsInInt32(imm)) { + db(0xC7); + code = 0xC0; + bit = 32; + } + } + db(code | (idx & 7)); + return bit / 8; + } + template<class T> + void putL_inner(T& label, bool relative = false, size_t disp = 0) + { + const int jmpSize = relative ? 4 : (int)sizeof(size_t); + if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); + size_t offset = 0; + if (labelMgr_.getOffset(&offset, label)) { + if (relative) { + db(inner::VerifyInInt32(offset + disp - size_ - jmpSize), jmpSize); + } else if (isAutoGrow()) { + db(uint64(0), jmpSize); + save(size_ - jmpSize, offset, jmpSize, inner::LaddTop); + } else { + db(size_t(top_) + offset, jmpSize); + } + return; + } + db(uint64(0), jmpSize); + JmpLabel jmp(size_, jmpSize, (relative ? inner::LasIs : isAutoGrow() ? inner::LaddTop : inner::Labs), disp); + labelMgr_.addUndefinedLabel(label, jmp); + } + void opMovxx(const Reg& reg, const Operand& op, uint8 code) + { + if (op.isBit(32)) throw Error(ERR_BAD_COMBINATION); + int w = op.isBit(16); +#ifdef XBYAK64 + if (op.isHigh8bit()) throw Error(ERR_BAD_COMBINATION); +#endif + bool cond = reg.isREG() && (reg.getBit() > op.getBit()); + opModRM(reg, op, cond && op.isREG(), cond && op.isMEM(), 0x0F, code | w); + } + void opFpuMem(const Address& addr, uint8 m16, uint8 m32, uint8 m64, uint8 ext, uint8 m64ext) + { + if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP); + uint8 code = addr.isBit(16) ? m16 : addr.isBit(32) ? m32 : addr.isBit(64) ? m64 : 0; + if (!code) throw Error(ERR_BAD_MEM_SIZE); + if (m64ext && addr.isBit(64)) ext = m64ext; + + rex(addr, st0); + db(code); + opAddr(addr, ext); + } + // use code1 if reg1 == st0 + // use code2 if reg1 != st0 && reg2 == st0 + void opFpuFpu(const Fpu& reg1, const Fpu& reg2, uint32 code1, uint32 code2) + { + uint32 code = reg1.getIdx() == 0 ? code1 : reg2.getIdx() == 0 ? code2 : 0; + if (!code) throw Error(ERR_BAD_ST_COMBINATION); + db(uint8(code >> 8)); + db(uint8(code | (reg1.getIdx() | reg2.getIdx()))); + } + void opFpu(const Fpu& reg, uint8 code1, uint8 code2) + { + db(code1); db(code2 | reg.getIdx()); + } + void opVex(const Reg& r, const Operand *p1, const Operand& op2, int type, int code, int imm8 = NONE) + { + if (op2.isMEM()) { + const Address& addr = op2.getAddress(); + const RegExp& regExp = addr.getRegExp(); + const Reg& base = regExp.getBase(); + const Reg& index = regExp.getIndex(); + if (BIT == 64 && addr.is32bit()) db(0x67); + int disp8N = 0; + bool x = index.isExtIdx(); + if ((type & (T_MUST_EVEX|T_MEM_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || addr.isBroadcast() || addr.getOpmaskIdx()) { + int aaa = addr.getOpmaskIdx(); + if (aaa && !(type & T_M_K)) throw Error(ERR_INVALID_OPMASK_WITH_MEMORY); + bool b = false; + if (addr.isBroadcast()) { + if (!(type & (T_B32 | T_B64))) throw Error(ERR_INVALID_BROADCAST); + b = true; + } + int VL = regExp.isVsib() ? index.getBit() : 0; + disp8N = evex(r, base, p1, type, code, x, b, aaa, VL, index.isExtIdx2()); + } else { + vex(r, base, p1, type, code, x); + } + opAddr(addr, r.getIdx(), (imm8 != NONE) ? 1 : 0, disp8N, (type & T_VSIB) != 0); + } else { + const Reg& base = op2.getReg(); + if ((type & T_MUST_EVEX) || r.hasEvex() || (p1 && p1->hasEvex()) || base.hasEvex()) { + evex(r, base, p1, type, code); + } else { + vex(r, base, p1, type, code); + } + setModRM(3, r.getIdx(), base.getIdx()); + } + if (imm8 != NONE) db(imm8); + } + // (r, r, r/m) if isR_R_RM + // (r, r/m, r) + void opGpr(const Reg32e& r, const Operand& op1, const Operand& op2, int type, uint8 code, bool isR_R_RM, int imm8 = NONE) + { + const Operand *p1 = &op1; + const Operand *p2 = &op2; + if (!isR_R_RM) std::swap(p1, p2); + const unsigned int bit = r.getBit(); + if (p1->getBit() != bit || (p2->isREG() && p2->getBit() != bit)) throw Error(ERR_BAD_COMBINATION); + type |= (bit == 64) ? T_W1 : T_W0; + opVex(r, p1, *p2, type, code, imm8); + } + void opAVX_X_X_XM(const Xmm& x1, const Operand& op1, const Operand& op2, int type, int code0, int imm8 = NONE) + { + const Xmm *x2 = static_cast<const Xmm*>(&op1); + const Operand *op = &op2; + if (op2.isNone()) { // (x1, op1) -> (x1, x1, op1) + x2 = &x1; + op = &op1; + } + // (x1, x2, op) + if (!((x1.isXMM() && x2->isXMM()) || ((type & T_YMM) && ((x1.isYMM() && x2->isYMM()) || (x1.isZMM() && x2->isZMM()))))) throw Error(ERR_BAD_COMBINATION); + opVex(x1, x2, *op, type, code0, imm8); + } + void opAVX_K_X_XM(const Opmask& k, const Xmm& x2, const Operand& op3, int type, int code0, int imm8 = NONE) + { + if (!op3.isMEM() && (x2.getKind() != op3.getKind())) throw Error(ERR_BAD_COMBINATION); + opVex(k, &x2, op3, type, code0, imm8); + } + // (x, x/m), (y, x/m256), (z, y/m) + void checkCvt1(const Operand& x, const Operand& op) const + { + if (!op.isMEM() && !(x.is(Operand::XMM | Operand::YMM) && op.isXMM()) && !(x.isZMM() && op.isYMM())) throw Error(ERR_BAD_COMBINATION); + } + // (x, x/m), (x, y/m256), (y, z/m) + void checkCvt2(const Xmm& x, const Operand& op) const + { + if (!(x.isXMM() && op.is(Operand::XMM | Operand::YMM | Operand::MEM)) && !(x.isYMM() && op.is(Operand::ZMM | Operand::MEM))) throw Error(ERR_BAD_COMBINATION); + } + void opCvt2(const Xmm& x, const Operand& op, int type, int code) + { + checkCvt2(x, op); + Operand::Kind kind = x.isXMM() ? (op.isBit(256) ? Operand::YMM : Operand::XMM) : Operand::ZMM; + opVex(x.copyAndSetKind(kind), &xm0, op, type, code); + } + void opCvt3(const Xmm& x1, const Xmm& x2, const Operand& op, int type, int type64, int type32, uint8 code) + { + if (!(x1.isXMM() && x2.isXMM() && (op.isREG(i32e) || op.isMEM()))) throw Error(ERR_BAD_SIZE_OF_REGISTER); + Xmm x(op.getIdx()); + const Operand *p = op.isREG() ? &x : &op; + opVex(x1, &x2, *p, type | (op.isBit(64) ? type64 : type32), code); + } + const Xmm& cvtIdx0(const Operand& x) const + { + return x.isZMM() ? zm0 : x.isYMM() ? ym0 : xm0; + } + // support (x, x/m, imm), (y, y/m, imm) + void opAVX_X_XM_IMM(const Xmm& x, const Operand& op, int type, int code, int imm8 = NONE) + { + opAVX_X_X_XM(x, cvtIdx0(x), op, type, code, imm8); + } + // QQQ:need to refactor + void opSp1(const Reg& reg, const Operand& op, uint8 pref, uint8 code0, uint8 code1) + { + if (reg.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); + bool is16bit = reg.isREG(16) && (op.isREG(16) || op.isMEM()); + if (!is16bit && !(reg.isREG(i32e) && (op.isREG(reg.getBit()) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); + if (is16bit) db(0x66); + db(pref); opModRM(reg.changeBit(i32e == 32 ? 32 : reg.getBit()), op, op.isREG(), true, code0, code1); + } + void opGather(const Xmm& x1, const Address& addr, const Xmm& x2, int type, uint8 code, int mode) + { + const RegExp& regExp = addr.getRegExp(); + if (!regExp.isVsib(128 | 256)) throw Error(ERR_BAD_VSIB_ADDRESSING); + const int y_vx_y = 0; + const int y_vy_y = 1; +// const int x_vy_x = 2; + const bool isAddrYMM = regExp.getIndex().getBit() == 256; + if (!x1.isXMM() || isAddrYMM || !x2.isXMM()) { + bool isOK = false; + if (mode == y_vx_y) { + isOK = x1.isYMM() && !isAddrYMM && x2.isYMM(); + } else if (mode == y_vy_y) { + isOK = x1.isYMM() && isAddrYMM && x2.isYMM(); + } else { // x_vy_x + isOK = !x1.isYMM() && isAddrYMM && !x2.isYMM(); + } + if (!isOK) throw Error(ERR_BAD_VSIB_ADDRESSING); + } + opAVX_X_X_XM(isAddrYMM ? Ymm(x1.getIdx()) : x1, isAddrYMM ? Ymm(x2.getIdx()) : x2, addr, type, code); + } + enum { + xx_yy_zz = 0, + xx_yx_zy = 1, + xx_xy_yz = 2 + }; + void checkGather2(const Xmm& x1, const Reg& x2, int mode) const + { + if (x1.isXMM() && x2.isXMM()) return; + switch (mode) { + case xx_yy_zz: if ((x1.isYMM() && x2.isYMM()) || (x1.isZMM() && x2.isZMM())) return; + break; + case xx_yx_zy: if ((x1.isYMM() && x2.isXMM()) || (x1.isZMM() && x2.isYMM())) return; + break; + case xx_xy_yz: if ((x1.isXMM() && x2.isYMM()) || (x1.isYMM() && x2.isZMM())) return; + break; + } + throw Error(ERR_BAD_VSIB_ADDRESSING); + } + void opGather2(const Xmm& x, const Address& addr, int type, uint8 code, int mode) + { + if (x.hasZero()) throw Error(ERR_INVALID_ZERO); + checkGather2(x, addr.getRegExp().getIndex(), mode); + opVex(x, 0, addr, type, code); + } + /* + xx_xy_yz ; mode = true + xx_xy_xz ; mode = false + */ + void opVmov(const Operand& op, const Xmm& x, int type, uint8 code, bool mode) + { + if (mode) { + if (!op.isMEM() && !((op.isXMM() && x.isXMM()) || (op.isXMM() && x.isYMM()) || (op.isYMM() && x.isZMM()))) throw Error(ERR_BAD_COMBINATION); + } else { + if (!op.isMEM() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); + } + opVex(x, 0, op, type, code); + } + void opGatherFetch(const Address& addr, const Xmm& x, int type, uint8 code, Operand::Kind kind) + { + if (addr.hasZero()) throw Error(ERR_INVALID_ZERO); + if (addr.getRegExp().getIndex().getKind() != kind) throw Error(ERR_BAD_VSIB_ADDRESSING); + opVex(x, 0, addr, type, code); + } +public: + unsigned int getVersion() const { return VERSION; } + using CodeArray::db; + const Mmx mm0, mm1, mm2, mm3, mm4, mm5, mm6, mm7; + const Xmm xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7; + const Ymm ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7; + const Zmm zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + const Xmm &xm0, &xm1, &xm2, &xm3, &xm4, &xm5, &xm6, &xm7; + const Ymm &ym0, &ym1, &ym2, &ym3, &ym4, &ym5, &ym6, &ym7; + const Ymm &zm0, &zm1, &zm2, &zm3, &zm4, &zm5, &zm6, &zm7; + const Reg32 eax, ecx, edx, ebx, esp, ebp, esi, edi; + const Reg16 ax, cx, dx, bx, sp, bp, si, di; + const Reg8 al, cl, dl, bl, ah, ch, dh, bh; + const AddressFrame ptr, byte, word, dword, qword, xword, yword, zword; // xword is same as oword of NASM + const AddressFrame ptr_b, xword_b, yword_b, zword_b; // broadcast such as {1to2}, {1to4}, {1to8}, {1to16}, {b} + const Fpu st0, st1, st2, st3, st4, st5, st6, st7; + const Opmask k0, k1, k2, k3, k4, k5, k6, k7; + const BoundsReg bnd0, bnd1, bnd2, bnd3; + const EvexModifierRounding T_sae, T_rn_sae, T_rd_sae, T_ru_sae, T_rz_sae; // {sae}, {rn-sae}, {rd-sae}, {ru-sae}, {rz-sae} + const EvexModifierZero T_z; // {z} +#ifdef XBYAK64 + const Reg64 rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi, r8, r9, r10, r11, r12, r13, r14, r15; + const Reg32 r8d, r9d, r10d, r11d, r12d, r13d, r14d, r15d; + const Reg16 r8w, r9w, r10w, r11w, r12w, r13w, r14w, r15w; + const Reg8 r8b, r9b, r10b, r11b, r12b, r13b, r14b, r15b; + const Reg8 spl, bpl, sil, dil; + const Xmm xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15; + const Xmm xmm16, xmm17, xmm18, xmm19, xmm20, xmm21, xmm22, xmm23; + const Xmm xmm24, xmm25, xmm26, xmm27, xmm28, xmm29, xmm30, xmm31; + const Ymm ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15; + const Ymm ymm16, ymm17, ymm18, ymm19, ymm20, ymm21, ymm22, ymm23; + const Ymm ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31; + const Zmm zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15; + const Zmm zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23; + const Zmm zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; + const Xmm &xm8, &xm9, &xm10, &xm11, &xm12, &xm13, &xm14, &xm15; // for my convenience + const Xmm &xm16, &xm17, &xm18, &xm19, &xm20, &xm21, &xm22, &xm23; + const Xmm &xm24, &xm25, &xm26, &xm27, &xm28, &xm29, &xm30, &xm31; + const Ymm &ym8, &ym9, &ym10, &ym11, &ym12, &ym13, &ym14, &ym15; + const Ymm &ym16, &ym17, &ym18, &ym19, &ym20, &ym21, &ym22, &ym23; + const Ymm &ym24, &ym25, &ym26, &ym27, &ym28, &ym29, &ym30, &ym31; + const Zmm &zm8, &zm9, &zm10, &zm11, &zm12, &zm13, &zm14, &zm15; + const Zmm &zm16, &zm17, &zm18, &zm19, &zm20, &zm21, &zm22, &zm23; + const Zmm &zm24, &zm25, &zm26, &zm27, &zm28, &zm29, &zm30, &zm31; + const RegRip rip; +#endif +#ifndef XBYAK_DISABLE_SEGMENT + const Segment es, cs, ss, ds, fs, gs; +#endif + void L(const std::string& label) { labelMgr_.defineSlabel(label); } + void L(Label& label) { labelMgr_.defineClabel(label); } + Label L() { Label label; L(label); return label; } + void inLocalLabel() { labelMgr_.enterLocal(); } + void outLocalLabel() { labelMgr_.leaveLocal(); } + /* + assign src to dst + require + dst : does not used by L() + src : used by L() + */ + void assignL(Label& dst, const Label& src) { labelMgr_.assign(dst, src); } + /* + put address of label to buffer + @note the put size is 4(32-bit), 8(64-bit) + */ + void putL(std::string label) { putL_inner(label); } + void putL(const Label& label) { putL_inner(label); } + + void jmp(const Operand& op) { opR_ModM(op, BIT, 4, 0xFF, NONE, NONE, true); } + void jmp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } + void jmp(const char *label, LabelType type = T_AUTO) { jmp(std::string(label), type); } + void jmp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); } + void jmp(const void *addr, LabelType type = T_AUTO) { opJmpAbs(addr, type, 0xEB, 0xE9); } + + void call(const Operand& op) { opR_ModM(op, 16 | i32e, 2, 0xFF, NONE, NONE, true); } + // call(string label), not const std::string& + void call(std::string label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } + void call(const char *label) { call(std::string(label)); } + void call(const Label& label) { opJmp(label, T_NEAR, 0, 0xE8, 0); } + // call(function pointer) +#ifdef XBYAK_VARIADIC_TEMPLATE + template<class Ret, class... Params> + void call(Ret(*func)(Params...)) { call(reinterpret_cast<const void*>(func)); } +#endif + void call(const void *addr) { opJmpAbs(addr, T_NEAR, 0, 0xE8); } + + void test(const Operand& op, const Reg& reg) + { + opModRM(reg, op, op.isREG() && (op.getKind() == reg.getKind()), op.isMEM(), 0x84); + } + void test(const Operand& op, uint32 imm) + { + verifyMemHasSize(op); + int immSize = (std::min)(op.getBit() / 8, 4U); + if (op.isREG() && op.getIdx() == 0) { // al, ax, eax + rex(op); + db(0xA8 | (op.isBit(8) ? 0 : 1)); + } else { + opR_ModM(op, 0, 0, 0xF6, NONE, NONE, false, immSize); + } + db(imm, immSize); + } + void imul(const Reg& reg, const Operand& op) + { + opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x0F, 0xAF); + } + void imul(const Reg& reg, const Operand& op, int imm) + { + int s = inner::IsInDisp8(imm) ? 1 : 0; + int immSize = s ? 1 : reg.isREG(16) ? 2 : 4; + opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x69 | (s << 1), NONE, NONE, immSize); + db(imm, immSize); + } + void push(const Operand& op) { opPushPop(op, 0xFF, 6, 0x50); } + void pop(const Operand& op) { opPushPop(op, 0x8F, 0, 0x58); } + void push(const AddressFrame& af, uint32 imm) + { + if (af.bit_ == 8 && inner::IsInDisp8(imm)) { + db(0x6A); db(imm); + } else if (af.bit_ == 16 && isInDisp16(imm)) { + db(0x66); db(0x68); dw(imm); + } else { + db(0x68); dd(imm); + } + } + /* use "push(word, 4)" if you want "push word 4" */ + void push(uint32 imm) + { + if (inner::IsInDisp8(imm)) { + push(byte, imm); + } else { + push(dword, imm); + } + } + void mov(const Operand& reg1, const Operand& reg2) + { + const Reg *reg = 0; + const Address *addr = 0; + uint8 code = 0; + if (reg1.isREG() && reg1.getIdx() == 0 && reg2.isMEM()) { // mov eax|ax|al, [disp] + reg = ®1.getReg(); + addr= ®2.getAddress(); + code = 0xA0; + } else + if (reg1.isMEM() && reg2.isREG() && reg2.getIdx() == 0) { // mov [disp], eax|ax|al + reg = ®2.getReg(); + addr= ®1.getAddress(); + code = 0xA2; + } +#ifdef XBYAK64 + if (addr && addr->is64bitDisp()) { + if (code) { + rex(*reg); + db(reg1.isREG(8) ? 0xA0 : reg1.isREG() ? 0xA1 : reg2.isREG(8) ? 0xA2 : 0xA3); + db(addr->getDisp(), 8); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } else +#else + if (code && addr->isOnlyDisp()) { + rex(*reg, *addr); + db(code | (reg->isBit(8) ? 0 : 1)); + dd(static_cast<uint32>(addr->getDisp())); + } else +#endif + { + opRM_RM(reg1, reg2, 0x88); + } + } + void mov(const Operand& op, size_t imm) + { + if (op.isREG()) { + const int size = mov_imm(op.getReg(), imm); + db(imm, size); + } else if (op.isMEM()) { + verifyMemHasSize(op); + int immSize = op.getBit() / 8; + if (immSize <= 4) { + sint64 s = sint64(imm) >> (immSize * 8); + if (s != 0 && s != -1) throw Error(ERR_IMM_IS_TOO_BIG); + } else { + if (!inner::IsInInt32(imm)) throw Error(ERR_IMM_IS_TOO_BIG); + immSize = 4; + } + opModM(op.getAddress(), Reg(0, Operand::REG, op.getBit()), 0xC6, NONE, NONE, immSize); + db(static_cast<uint32>(imm), immSize); + } else { + throw Error(ERR_BAD_COMBINATION); + } + } + void mov(const NativeReg& reg, const char *label) // can't use std::string + { + if (label == 0) { + mov(static_cast<const Operand&>(reg), 0); // call imm + return; + } + mov_imm(reg, dummyAddr); + putL(label); + } + void mov(const NativeReg& reg, const Label& label) + { + mov_imm(reg, dummyAddr); + putL(label); + } + void xchg(const Operand& op1, const Operand& op2) + { + const Operand *p1 = &op1, *p2 = &op2; + if (p1->isMEM() || (p2->isREG(16 | i32e) && p2->getIdx() == 0)) { + p1 = &op2; p2 = &op1; + } + if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION); + if (p2->isREG() && (p1->isREG(16 | i32e) && p1->getIdx() == 0) +#ifdef XBYAK64 + && (p2->getIdx() != 0 || !p1->isREG(32)) +#endif + ) { + rex(*p2, *p1); db(0x90 | (p2->getIdx() & 7)); + return; + } + opModRM(*p1, *p2, (p1->isREG() && p2->isREG() && (p1->getBit() == p2->getBit())), p2->isMEM(), 0x86 | (p1->isBit(8) ? 0 : 1)); + } + +#ifndef XBYAK_DISABLE_SEGMENT + void push(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x06); break; + case Segment::cs: db(0x0E); break; + case Segment::ss: db(0x16); break; + case Segment::ds: db(0x1E); break; + case Segment::fs: db(0x0F); db(0xA0); break; + case Segment::gs: db(0x0F); db(0xA8); break; + default: + assert(0); + } + } + void pop(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x07); break; + case Segment::cs: throw Error(ERR_BAD_COMBINATION); + case Segment::ss: db(0x17); break; + case Segment::ds: db(0x1F); break; + case Segment::fs: db(0x0F); db(0xA1); break; + case Segment::gs: db(0x0F); db(0xA9); break; + default: + assert(0); + } + } + void putSeg(const Segment& seg) + { + switch (seg.getIdx()) { + case Segment::es: db(0x2E); break; + case Segment::cs: db(0x36); break; + case Segment::ss: db(0x3E); break; + case Segment::ds: db(0x26); break; + case Segment::fs: db(0x64); break; + case Segment::gs: db(0x65); break; + default: + assert(0); + } + } + void mov(const Operand& op, const Segment& seg) + { + opModRM(Reg8(seg.getIdx()), op, op.isREG(16|i32e), op.isMEM(), 0x8C); + } + void mov(const Segment& seg, const Operand& op) + { + opModRM(Reg8(seg.getIdx()), op.isREG(16|i32e) ? static_cast<const Operand&>(op.getReg().cvt32()) : op, op.isREG(16|i32e), op.isMEM(), 0x8E); + } +#endif + + enum { NONE = 256 }; + // constructor + CodeGenerator(size_t maxSize = DEFAULT_MAX_CODE_SIZE, void *userPtr = 0, Allocator *allocator = 0) + : CodeArray(maxSize, userPtr, allocator) + , mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7) + , xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7) + , ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7) + , zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7) + // for my convenience + , xm0(xmm0), xm1(xmm1), xm2(xmm2), xm3(xmm3), xm4(xmm4), xm5(xmm5), xm6(xmm6), xm7(xmm7) + , ym0(ymm0), ym1(ymm1), ym2(ymm2), ym3(ymm3), ym4(ymm4), ym5(ymm5), ym6(ymm6), ym7(ymm7) + , zm0(zmm0), zm1(zmm1), zm2(zmm2), zm3(zmm3), zm4(zmm4), zm5(zmm5), zm6(zmm6), zm7(zmm7) + + , eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI) + , ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI) + , al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH) + , ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512) + , ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true) + , st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7) + , k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7) + , bnd0(0), bnd1(1), bnd2(2), bnd3(3) + , T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE) + , T_z() +#ifdef XBYAK64 + , rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15) + , r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15) + , r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15) + , r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15) + , spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true) + , xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15) + , xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23) + , xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31) + , ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15) + , ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23) + , ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31) + , zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15) + , zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23) + , zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31) + // for my convenience + , xm8(xmm8), xm9(xmm9), xm10(xmm10), xm11(xmm11), xm12(xmm12), xm13(xmm13), xm14(xmm14), xm15(xmm15) + , xm16(xmm16), xm17(xmm17), xm18(xmm18), xm19(xmm19), xm20(xmm20), xm21(xmm21), xm22(xmm22), xm23(xmm23) + , xm24(xmm24), xm25(xmm25), xm26(xmm26), xm27(xmm27), xm28(xmm28), xm29(xmm29), xm30(xmm30), xm31(xmm31) + , ym8(ymm8), ym9(ymm9), ym10(ymm10), ym11(ymm11), ym12(ymm12), ym13(ymm13), ym14(ymm14), ym15(ymm15) + , ym16(ymm16), ym17(ymm17), ym18(ymm18), ym19(ymm19), ym20(ymm20), ym21(ymm21), ym22(ymm22), ym23(ymm23) + , ym24(ymm24), ym25(ymm25), ym26(ymm26), ym27(ymm27), ym28(ymm28), ym29(ymm29), ym30(ymm30), ym31(ymm31) + , zm8(zmm8), zm9(zmm9), zm10(zmm10), zm11(zmm11), zm12(zmm12), zm13(zmm13), zm14(zmm14), zm15(zmm15) + , zm16(zmm16), zm17(zmm17), zm18(zmm18), zm19(zmm19), zm20(zmm20), zm21(zmm21), zm22(zmm22), zm23(zmm23) + , zm24(zmm24), zm25(zmm25), zm26(zmm26), zm27(zmm27), zm28(zmm28), zm29(zmm29), zm30(zmm30), zm31(zmm31) + , rip() +#endif +#ifndef XBYAK_DISABLE_SEGMENT + , es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs) +#endif + { + labelMgr_.set(this); + } + void reset() + { + resetSize(); + labelMgr_.reset(); + labelMgr_.set(this); + } + bool hasUndefinedLabel() const { return labelMgr_.hasUndefSlabel() || labelMgr_.hasUndefClabel(); } + /* + MUST call ready() to complete generating code if you use AutoGrow mode. + It is not necessary for the other mode if hasUndefinedLabel() is true. + */ + void ready(ProtectMode mode = PROTECT_RWE) + { + if (hasUndefinedLabel()) throw Error(ERR_LABEL_IS_NOT_FOUND); + if (isAutoGrow()) { + calcJmpAddress(); + if (useProtect()) setProtectMode(mode); + } + } + // set read/exec + void readyRE() { return ready(PROTECT_RE); } +#ifdef XBYAK_TEST + void dump(bool doClear = true) + { + CodeArray::dump(); + if (doClear) size_ = 0; + } +#endif + +#ifdef XBYAK_UNDEF_JNL + #undef jnl +#endif + + /* + use single byte nop if useMultiByteNop = false + */ + void nop(size_t size = 1, bool useMultiByteNop = true) + { + if (!useMultiByteNop) { + for (size_t i = 0; i < size; i++) { + db(0x90); + } + return; + } + /* + Intel Architectures Software Developer's Manual Volume 2 + recommended multi-byte sequence of NOP instruction + AMD and Intel seem to agree on the same sequences for up to 9 bytes: + https://support.amd.com/TechDocs/55723_SOG_Fam_17h_Processors_3.00.pdf + */ + static const uint8 nopTbl[9][9] = { + {0x90}, + {0x66, 0x90}, + {0x0F, 0x1F, 0x00}, + {0x0F, 0x1F, 0x40, 0x00}, + {0x0F, 0x1F, 0x44, 0x00, 0x00}, + {0x66, 0x0F, 0x1F, 0x44, 0x00, 0x00}, + {0x0F, 0x1F, 0x80, 0x00, 0x00, 0x00, 0x00}, + {0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x66, 0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00}, + }; + const size_t n = sizeof(nopTbl) / sizeof(nopTbl[0]); + while (size > 0) { + size_t len = (std::min)(n, size); + const uint8 *seq = nopTbl[len - 1]; + db(seq, len); + size -= len; + } + } + +#ifndef XBYAK_DONT_READ_LIST +#include "xbyak_mnemonic.h" + /* + use single byte nop if useMultiByteNop = false + */ + void align(size_t x = 16, bool useMultiByteNop = true) + { + if (x == 1) return; + if (x < 1 || (x & (x - 1))) throw Error(ERR_BAD_ALIGN); + if (isAutoGrow() && x > inner::ALIGN_PAGE_SIZE) fprintf(stderr, "warning:autoGrow mode does not support %d align\n", (int)x); + size_t remain = size_t(getCurr()) % x; + if (remain) { + nop(x - remain, useMultiByteNop); + } + } +#endif +}; + +namespace util { +static const Mmx mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7); +static const Xmm xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7); +static const Ymm ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7); +static const Zmm zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7); +static const Reg32 eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI); +static const Reg16 ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI); +static const Reg8 al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH); +static const AddressFrame ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512); +static const AddressFrame ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true); +static const Fpu st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7); +static const Opmask k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7); +static const BoundsReg bnd0(0), bnd1(1), bnd2(2), bnd3(3); +static const EvexModifierRounding T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE); +static const EvexModifierZero T_z; +#ifdef XBYAK64 +static const Reg64 rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15); +static const Reg32 r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15); +static const Reg16 r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15); +static const Reg8 r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15), spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true); +static const Xmm xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15); +static const Xmm xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23); +static const Xmm xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31); +static const Ymm ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15); +static const Ymm ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23); +static const Ymm ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31); +static const Zmm zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15); +static const Zmm zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23); +static const Zmm zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31); +static const RegRip rip; +#endif +#ifndef XBYAK_DISABLE_SEGMENT +static const Segment es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs); +#endif +} // util + +#ifdef _MSC_VER + #pragma warning(pop) +#endif + +} // end of namespace + +#endif // XBYAK_XBYAK_H_ diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h new file mode 100644 index 0000000000..a22e5224c3 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h @@ -0,0 +1,303 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +enum { + B00000000= 0, + B00000001= 1, + B00000010= 2, + B00000011= 3, + B00000100= 4, + B00000101= 5, + B00000110= 6, + B00000111= 7, + B00001000= 8, + B00001001= 9, + B00001010= 10, + B00001011= 11, + B00001100= 12, + B00001101= 13, + B00001110= 14, + B00001111= 15, + B00010000= 16, + B00010001= 17, + B00010010= 18, + B00010011= 19, + B00010100= 20, + B00010101= 21, + B00010110= 22, + B00010111= 23, + B00011000= 24, + B00011001= 25, + B00011010= 26, + B00011011= 27, + B00011100= 28, + B00011101= 29, + B00011110= 30, + B00011111= 31, + B00100000= 32, + B00100001= 33, + B00100010= 34, + B00100011= 35, + B00100100= 36, + B00100101= 37, + B00100110= 38, + B00100111= 39, + B00101000= 40, + B00101001= 41, + B00101010= 42, + B00101011= 43, + B00101100= 44, + B00101101= 45, + B00101110= 46, + B00101111= 47, + B00110000= 48, + B00110001= 49, + B00110010= 50, + B00110011= 51, + B00110100= 52, + B00110101= 53, + B00110110= 54, + B00110111= 55, + B00111000= 56, + B00111001= 57, + B00111010= 58, + B00111011= 59, + B00111100= 60, + B00111101= 61, + B00111110= 62, + B00111111= 63, + B01000000= 64, + B01000001= 65, + B01000010= 66, + B01000011= 67, + B01000100= 68, + B01000101= 69, + B01000110= 70, + B01000111= 71, + B01001000= 72, + B01001001= 73, + B01001010= 74, + B01001011= 75, + B01001100= 76, + B01001101= 77, + B01001110= 78, + B01001111= 79, + B01010000= 80, + B01010001= 81, + B01010010= 82, + B01010011= 83, + B01010100= 84, + B01010101= 85, + B01010110= 86, + B01010111= 87, + B01011000= 88, + B01011001= 89, + B01011010= 90, + B01011011= 91, + B01011100= 92, + B01011101= 93, + B01011110= 94, + B01011111= 95, + B01100000= 96, + B01100001= 97, + B01100010= 98, + B01100011= 99, + B01100100= 100, + B01100101= 101, + B01100110= 102, + B01100111= 103, + B01101000= 104, + B01101001= 105, + B01101010= 106, + B01101011= 107, + B01101100= 108, + B01101101= 109, + B01101110= 110, + B01101111= 111, + B01110000= 112, + B01110001= 113, + B01110010= 114, + B01110011= 115, + B01110100= 116, + B01110101= 117, + B01110110= 118, + B01110111= 119, + B01111000= 120, + B01111001= 121, + B01111010= 122, + B01111011= 123, + B01111100= 124, + B01111101= 125, + B01111110= 126, + B01111111= 127, + B10000000= 128, + B10000001= 129, + B10000010= 130, + B10000011= 131, + B10000100= 132, + B10000101= 133, + B10000110= 134, + B10000111= 135, + B10001000= 136, + B10001001= 137, + B10001010= 138, + B10001011= 139, + B10001100= 140, + B10001101= 141, + B10001110= 142, + B10001111= 143, + B10010000= 144, + B10010001= 145, + B10010010= 146, + B10010011= 147, + B10010100= 148, + B10010101= 149, + B10010110= 150, + B10010111= 151, + B10011000= 152, + B10011001= 153, + B10011010= 154, + B10011011= 155, + B10011100= 156, + B10011101= 157, + B10011110= 158, + B10011111= 159, + B10100000= 160, + B10100001= 161, + B10100010= 162, + B10100011= 163, + B10100100= 164, + B10100101= 165, + B10100110= 166, + B10100111= 167, + B10101000= 168, + B10101001= 169, + B10101010= 170, + B10101011= 171, + B10101100= 172, + B10101101= 173, + B10101110= 174, + B10101111= 175, + B10110000= 176, + B10110001= 177, + B10110010= 178, + B10110011= 179, + B10110100= 180, + B10110101= 181, + B10110110= 182, + B10110111= 183, + B10111000= 184, + B10111001= 185, + B10111010= 186, + B10111011= 187, + B10111100= 188, + B10111101= 189, + B10111110= 190, + B10111111= 191, + B11000000= 192, + B11000001= 193, + B11000010= 194, + B11000011= 195, + B11000100= 196, + B11000101= 197, + B11000110= 198, + B11000111= 199, + B11001000= 200, + B11001001= 201, + B11001010= 202, + B11001011= 203, + B11001100= 204, + B11001101= 205, + B11001110= 206, + B11001111= 207, + B11010000= 208, + B11010001= 209, + B11010010= 210, + B11010011= 211, + B11010100= 212, + B11010101= 213, + B11010110= 214, + B11010111= 215, + B11011000= 216, + B11011001= 217, + B11011010= 218, + B11011011= 219, + B11011100= 220, + B11011101= 221, + B11011110= 222, + B11011111= 223, + B11100000= 224, + B11100001= 225, + B11100010= 226, + B11100011= 227, + B11100100= 228, + B11100101= 229, + B11100110= 230, + B11100111= 231, + B11101000= 232, + B11101001= 233, + B11101010= 234, + B11101011= 235, + B11101100= 236, + B11101101= 237, + B11101110= 238, + B11101111= 239, + B11110000= 240, + B11110001= 241, + B11110010= 242, + B11110011= 243, + B11110100= 244, + B11110101= 245, + B11110110= 246, + B11110111= 247, + B11111000= 248, + B11111001= 249, + B11111010= 250, + B11111011= 251, + B11111100= 252, + B11111101= 253, + B11111110= 254, + B11111111= 255 +}; diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h new file mode 100644 index 0000000000..28d2d222f9 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h @@ -0,0 +1,2017 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +const char *getVersionString() const { return "5.76"; } +void adc(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x10, 2); } +void adc(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x10); } +void adcx(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0x66, isREG32_REG32orMEM, NONE, 0x38); } +void add(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x00, 0); } +void add(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x00); } +void addpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x66, isXMM_XMMorMEM); } +void addps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x100, isXMM_XMMorMEM); } +void addsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF2, isXMM_XMMorMEM); } +void addss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF3, isXMM_XMMorMEM); } +void addsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0x66, isXMM_XMMorMEM); } +void addsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0xF2, isXMM_XMMorMEM); } +void adox(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0xF3, isREG32_REG32orMEM, NONE, 0x38); } +void aesdec(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDE, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesdeclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesenc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDC, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesenclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDD, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aesimc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDB, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void aeskeygenassist(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void and_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x20, 4); } +void and_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x20); } +void andn(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_0F38, 0xf2, true); } +void andnpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x66, isXMM_XMMorMEM); } +void andnps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x100, isXMM_XMMorMEM); } +void andpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x66, isXMM_XMMorMEM); } +void andps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x100, isXMM_XMMorMEM); } +void bextr(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf7, false); } +void blendpd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0D, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void blendps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0C, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void blendvpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void blendvps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void blsi(const Reg32e& r, const Operand& op) { opGpr(Reg32e(3, r.getBit()), op, r, T_0F38, 0xf3, false); } +void blsmsk(const Reg32e& r, const Operand& op) { opGpr(Reg32e(2, r.getBit()), op, r, T_0F38, 0xf3, false); } +void blsr(const Reg32e& r, const Operand& op) { opGpr(Reg32e(1, r.getBit()), op, r, T_0F38, 0xf3, false); } +void bnd() { db(0xF2); } +void bndcl(const BoundsReg& bnd, const Operand& op) { db(0xF3); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); } +void bndcn(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1B, NONE, !op.isMEM()); } +void bndcu(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); } +void bndldx(const BoundsReg& bnd, const Address& addr) { opMIB(addr, bnd, 0x0F, 0x1A); } +void bndmk(const BoundsReg& bnd, const Address& addr) { db(0xF3); opModM(addr, bnd, 0x0F, 0x1B); } +void bndmov(const Address& addr, const BoundsReg& bnd) { db(0x66); opModM(addr, bnd, 0x0F, 0x1B); } +void bndmov(const BoundsReg& bnd, const Operand& op) { db(0x66); opModRM(bnd, op, op.isBNDREG(), op.isMEM(), 0x0F, 0x1A); } +void bndstx(const Address& addr, const BoundsReg& bnd) { opMIB(addr, bnd, 0x0F, 0x1B); } +void bsf(const Reg®, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBC); } +void bsr(const Reg®, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBD); } +void bswap(const Reg32e& reg) { opModR(Reg32(1), reg, 0x0F); } +void bt(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xA3); } +void bt(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 4, 0x0f, 0xba, NONE, false, 1); db(imm); } +void btc(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xBB); } +void btc(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 7, 0x0f, 0xba, NONE, false, 1); db(imm); } +void btr(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xB3); } +void btr(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 6, 0x0f, 0xba, NONE, false, 1); db(imm); } +void bts(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xAB); } +void bts(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 5, 0x0f, 0xba, NONE, false, 1); db(imm); } +void bzhi(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf5, false); } +void cbw() { db(0x66); db(0x98); } +void cdq() { db(0x99); } +void clc() { db(0xF8); } +void cld() { db(0xFC); } +void clflush(const Address& addr) { opModM(addr, Reg32(7), 0x0F, 0xAE); } +void cli() { db(0xFA); } +void cmc() { db(0xF5); } +void cmova(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524 +void cmovae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmovbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524 +void cmovc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmove(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524 +void cmovg(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524 +void cmovge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524 +void cmovl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524 +void cmovle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524 +void cmovna(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524 +void cmovnae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524 +void cmovnb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovnbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524 +void cmovnc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524 +void cmovne(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524 +void cmovng(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524 +void cmovnge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524 +void cmovnl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524 +void cmovnle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524 +void cmovno(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 1); }//-V524 +void cmovnp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524 +void cmovns(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 9); }//-V524 +void cmovnz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524 +void cmovo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 0); }//-V524 +void cmovp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524 +void cmovpe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524 +void cmovpo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524 +void cmovs(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 8); }//-V524 +void cmovz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524 +void cmp(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x38, 7); } +void cmp(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x38); } +void cmpeqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 0); } +void cmpeqps(const Xmm& x, const Operand& op) { cmpps(x, op, 0); } +void cmpeqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 0); } +void cmpeqss(const Xmm& x, const Operand& op) { cmpss(x, op, 0); } +void cmplepd(const Xmm& x, const Operand& op) { cmppd(x, op, 2); } +void cmpleps(const Xmm& x, const Operand& op) { cmpps(x, op, 2); } +void cmplesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 2); } +void cmpless(const Xmm& x, const Operand& op) { cmpss(x, op, 2); } +void cmpltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 1); } +void cmpltps(const Xmm& x, const Operand& op) { cmpps(x, op, 1); } +void cmpltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 1); } +void cmpltss(const Xmm& x, const Operand& op) { cmpss(x, op, 1); } +void cmpneqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 4); } +void cmpneqps(const Xmm& x, const Operand& op) { cmpps(x, op, 4); } +void cmpneqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 4); } +void cmpneqss(const Xmm& x, const Operand& op) { cmpss(x, op, 4); } +void cmpnlepd(const Xmm& x, const Operand& op) { cmppd(x, op, 6); } +void cmpnleps(const Xmm& x, const Operand& op) { cmpps(x, op, 6); } +void cmpnlesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 6); } +void cmpnless(const Xmm& x, const Operand& op) { cmpss(x, op, 6); } +void cmpnltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 5); } +void cmpnltps(const Xmm& x, const Operand& op) { cmpps(x, op, 5); } +void cmpnltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 5); } +void cmpnltss(const Xmm& x, const Operand& op) { cmpss(x, op, 5); } +void cmpordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 7); } +void cmpordps(const Xmm& x, const Operand& op) { cmpps(x, op, 7); } +void cmpordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 7); } +void cmpordss(const Xmm& x, const Operand& op) { cmpss(x, op, 7); } +void cmppd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x66, isXMM_XMMorMEM, imm8); } +void cmpps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x100, isXMM_XMMorMEM, imm8); } +void cmpsb() { db(0xA6); } +void cmpsd() { db(0xA7); } +void cmpsd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF2, isXMM_XMMorMEM, imm8); } +void cmpss(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF3, isXMM_XMMorMEM, imm8); } +void cmpsw() { db(0x66); db(0xA7); } +void cmpunordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 3); } +void cmpunordps(const Xmm& x, const Operand& op) { cmpps(x, op, 3); } +void cmpunordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 3); } +void cmpunordss(const Xmm& x, const Operand& op) { cmpss(x, op, 3); } +void cmpxchg(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xB0 | (reg.isBit(8) ? 0 : 1)); } +void cmpxchg8b(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0xC7); } +void comisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x66, isXMM_XMMorMEM); } +void comiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x100, isXMM_XMMorMEM); } +void cpuid() { db(0x0F); db(0xA2); } +void crc32(const Reg32e& reg, const Operand& op) { if (reg.isBit(32) && op.isBit(16)) db(0x66); db(0xF2); opModRM(reg, op, op.isREG(), op.isMEM(), 0x0F, 0x38, 0xF0 | (op.isBit(8) ? 0 : 1)); } +void cvtdq2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF3, isXMM_XMMorMEM); } +void cvtdq2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x100, isXMM_XMMorMEM); } +void cvtpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF2, isXMM_XMMorMEM); } +void cvtpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x66, isMMX_XMMorMEM); } +void cvtpd2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x66, isXMM_XMMorMEM); } +void cvtpi2pd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x66, isXMM_MMXorMEM); } +void cvtpi2ps(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x100, isXMM_MMXorMEM); } +void cvtps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x66, isXMM_XMMorMEM); } +void cvtps2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x100, isXMM_XMMorMEM); } +void cvtps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x100, isMMX_XMMorMEM); } +void cvtsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF2, isREG32_XMMorMEM); } +void cvtsd2ss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF2, isXMM_XMMorMEM); } +void cvtsi2sd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF2, isXMM_REG32orMEM); } +void cvtsi2ss(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF3, isXMM_REG32orMEM); } +void cvtss2sd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF3, isXMM_XMMorMEM); } +void cvtss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF3, isREG32_XMMorMEM); } +void cvttpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0x66, isXMM_XMMorMEM); } +void cvttpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x66, isMMX_XMMorMEM); } +void cvttps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0xF3, isXMM_XMMorMEM); } +void cvttps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x100, isMMX_XMMorMEM); } +void cvttsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF2, isREG32_XMMorMEM); } +void cvttss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF3, isREG32_XMMorMEM); } +void cwd() { db(0x66); db(0x99); } +void cwde() { db(0x98); } +void dec(const Operand& op) { opIncDec(op, 0x48, 1); } +void div(const Operand& op) { opR_ModM(op, 0, 6, 0xF6); } +void divpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x66, isXMM_XMMorMEM); } +void divps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x100, isXMM_XMMorMEM); } +void divsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF2, isXMM_XMMorMEM); } +void divss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF3, isXMM_XMMorMEM); } +void dppd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void dpps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void emms() { db(0x0F); db(0x77); } +void extractps(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x17, imm); } +void f2xm1() { db(0xD9); db(0xF0); } +void fabs() { db(0xD9); db(0xE1); } +void fadd(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 0, 0); } +void fadd(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C0, 0xDCC0); } +void fadd(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C0, 0xDCC0); } +void faddp() { db(0xDE); db(0xC1); } +void faddp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC0); } +void faddp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC0); } +void fchs() { db(0xD9); db(0xE0); } +void fcmovb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC0, 0x00C0); } +void fcmovb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC0, 0x00C0); } +void fcmovbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD0, 0x00D0); } +void fcmovbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD0, 0x00D0); } +void fcmove(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC8, 0x00C8); } +void fcmove(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC8, 0x00C8); } +void fcmovnb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC0, 0x00C0); } +void fcmovnb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC0, 0x00C0); } +void fcmovnbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD0, 0x00D0); } +void fcmovnbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD0, 0x00D0); } +void fcmovne(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC8, 0x00C8); } +void fcmovne(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC8, 0x00C8); } +void fcmovnu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD8, 0x00D8); } +void fcmovnu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD8, 0x00D8); } +void fcmovu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD8, 0x00D8); } +void fcmovu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD8, 0x00D8); } +void fcom() { db(0xD8); db(0xD1); } +void fcom(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 2, 0); } +void fcom(const Fpu& reg) { opFpu(reg, 0xD8, 0xD0); } +void fcomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBF0, 0x00F0); } +void fcomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBF0, 0x00F0); } +void fcomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFF0, 0x00F0); } +void fcomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFF0, 0x00F0); } +void fcomp() { db(0xD8); db(0xD9); } +void fcomp(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 3, 0); } +void fcomp(const Fpu& reg) { opFpu(reg, 0xD8, 0xD8); } +void fcompp() { db(0xDE); db(0xD9); } +void fcos() { db(0xD9); db(0xFF); } +void fdecstp() { db(0xD9); db(0xF6); } +void fdiv(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 6, 0); } +void fdiv(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F0, 0xDCF8); } +void fdiv(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F0, 0xDCF8); } +void fdivp() { db(0xDE); db(0xF9); } +void fdivp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF8); } +void fdivp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF8); } +void fdivr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 7, 0); } +void fdivr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F8, 0xDCF0); } +void fdivr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F8, 0xDCF0); } +void fdivrp() { db(0xDE); db(0xF1); } +void fdivrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF0); } +void fdivrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF0); } +void ffree(const Fpu& reg) { opFpu(reg, 0xDD, 0xC0); } +void fiadd(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 0, 0); } +void ficom(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 2, 0); } +void ficomp(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 3, 0); } +void fidiv(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 6, 0); } +void fidivr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 7, 0); } +void fild(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 0, 5); } +void fimul(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 1, 0); } +void fincstp() { db(0xD9); db(0xF7); } +void finit() { db(0x9B); db(0xDB); db(0xE3); } +void fist(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0x00, 2, 0); } +void fistp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 3, 7); } +void fisttp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDD, 1, 0); } +void fisub(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 4, 0); } +void fisubr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 5, 0); } +void fld(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 0, 0); } +void fld(const Fpu& reg) { opFpu(reg, 0xD9, 0xC0); } +void fld1() { db(0xD9); db(0xE8); } +void fldcw(const Address& addr) { opModM(addr, Reg32(5), 0xD9, 0x100); } +void fldl2e() { db(0xD9); db(0xEA); } +void fldl2t() { db(0xD9); db(0xE9); } +void fldlg2() { db(0xD9); db(0xEC); } +void fldln2() { db(0xD9); db(0xED); } +void fldpi() { db(0xD9); db(0xEB); } +void fldz() { db(0xD9); db(0xEE); } +void fmul(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 1, 0); } +void fmul(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C8, 0xDCC8); } +void fmul(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C8, 0xDCC8); } +void fmulp() { db(0xDE); db(0xC9); } +void fmulp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC8); } +void fmulp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC8); } +void fninit() { db(0xDB); db(0xE3); } +void fnop() { db(0xD9); db(0xD0); } +void fpatan() { db(0xD9); db(0xF3); } +void fprem() { db(0xD9); db(0xF8); } +void fprem1() { db(0xD9); db(0xF5); } +void fptan() { db(0xD9); db(0xF2); } +void frndint() { db(0xD9); db(0xFC); } +void fscale() { db(0xD9); db(0xFD); } +void fsin() { db(0xD9); db(0xFE); } +void fsincos() { db(0xD9); db(0xFB); } +void fsqrt() { db(0xD9); db(0xFA); } +void fst(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 2, 0); } +void fst(const Fpu& reg) { opFpu(reg, 0xDD, 0xD0); } +void fstcw(const Address& addr) { db(0x9B); opModM(addr, Reg32(7), 0xD9, NONE); } +void fstp(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 3, 0); } +void fstp(const Fpu& reg) { opFpu(reg, 0xDD, 0xD8); } +void fsub(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 4, 0); } +void fsub(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E0, 0xDCE8); } +void fsub(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E0, 0xDCE8); } +void fsubp() { db(0xDE); db(0xE9); } +void fsubp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE8); } +void fsubp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE8); } +void fsubr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 5, 0); } +void fsubr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E8, 0xDCE0); } +void fsubr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E8, 0xDCE0); } +void fsubrp() { db(0xDE); db(0xE1); } +void fsubrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE0); } +void fsubrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE0); } +void ftst() { db(0xD9); db(0xE4); } +void fucom() { db(0xDD); db(0xE1); } +void fucom(const Fpu& reg) { opFpu(reg, 0xDD, 0xE0); } +void fucomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBE8, 0x00E8); } +void fucomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBE8, 0x00E8); } +void fucomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFE8, 0x00E8); } +void fucomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFE8, 0x00E8); } +void fucomp() { db(0xDD); db(0xE9); } +void fucomp(const Fpu& reg) { opFpu(reg, 0xDD, 0xE8); } +void fucompp() { db(0xDA); db(0xE9); } +void fwait() { db(0x9B); } +void fxam() { db(0xD9); db(0xE5); } +void fxch() { db(0xD9); db(0xC9); } +void fxch(const Fpu& reg) { opFpu(reg, 0xD9, 0xC8); } +void fxtract() { db(0xD9); db(0xF4); } +void fyl2x() { db(0xD9); db(0xF1); } +void fyl2xp1() { db(0xD9); db(0xF9); } +void gf2p8affineinvqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void gf2p8affineqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCE, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void gf2p8mulb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void haddpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0x66, isXMM_XMMorMEM); } +void haddps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0xF2, isXMM_XMMorMEM); } +void hsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0x66, isXMM_XMMorMEM); } +void hsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0xF2, isXMM_XMMorMEM); } +void idiv(const Operand& op) { opR_ModM(op, 0, 7, 0xF6); } +void imul(const Operand& op) { opR_ModM(op, 0, 5, 0xF6); } +void inc(const Operand& op) { opIncDec(op, 0x40, 0); } +void insertps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void ja(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void ja(const char *label, LabelType type = T_AUTO) { ja(std::string(label), type); }//-V524 +void ja(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 +void ja(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jae(const char *label, LabelType type = T_AUTO) { jae(std::string(label), type); }//-V524 +void jae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jb(const char *label, LabelType type = T_AUTO) { jb(std::string(label), type); }//-V524 +void jb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jbe(const char *label, LabelType type = T_AUTO) { jbe(std::string(label), type); }//-V524 +void jbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 +void jbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jc(const char *label, LabelType type = T_AUTO) { jc(std::string(label), type); }//-V524 +void jc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void je(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void je(const char *label, LabelType type = T_AUTO) { je(std::string(label), type); }//-V524 +void je(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 +void je(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void jg(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jg(const char *label, LabelType type = T_AUTO) { jg(std::string(label), type); }//-V524 +void jg(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 +void jg(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jge(const char *label, LabelType type = T_AUTO) { jge(std::string(label), type); }//-V524 +void jge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 +void jge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jl(const char *label, LabelType type = T_AUTO) { jl(std::string(label), type); }//-V524 +void jl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 +void jl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jle(const char *label, LabelType type = T_AUTO) { jle(std::string(label), type); }//-V524 +void jle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 +void jle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jna(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jna(const char *label, LabelType type = T_AUTO) { jna(std::string(label), type); }//-V524 +void jna(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524 +void jna(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524 +void jnae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jnae(const char *label, LabelType type = T_AUTO) { jnae(std::string(label), type); }//-V524 +void jnae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524 +void jnae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524 +void jnb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnb(const char *label, LabelType type = T_AUTO) { jnb(std::string(label), type); }//-V524 +void jnb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jnb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jnbe(const char *label, LabelType type = T_AUTO) { jnbe(std::string(label), type); }//-V524 +void jnbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524 +void jnbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524 +void jnc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jnc(const char *label, LabelType type = T_AUTO) { jnc(std::string(label), type); }//-V524 +void jnc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524 +void jnc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524 +void jne(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jne(const char *label, LabelType type = T_AUTO) { jne(std::string(label), type); }//-V524 +void jne(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 +void jne(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jng(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jng(const char *label, LabelType type = T_AUTO) { jng(std::string(label), type); }//-V524 +void jng(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524 +void jng(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524 +void jnge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jnge(const char *label, LabelType type = T_AUTO) { jnge(std::string(label), type); }//-V524 +void jnge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524 +void jnge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524 +void jnl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jnl(const char *label, LabelType type = T_AUTO) { jnl(std::string(label), type); }//-V524 +void jnl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524 +void jnl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524 +void jnle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jnle(const char *label, LabelType type = T_AUTO) { jnle(std::string(label), type); }//-V524 +void jnle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524 +void jnle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524 +void jno(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 +void jno(const char *label, LabelType type = T_AUTO) { jno(std::string(label), type); }//-V524 +void jno(const void *addr) { opJmpAbs(addr, T_NEAR, 0x71, 0x81, 0x0F); }//-V524 +void jno(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524 +void jnp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jnp(const char *label, LabelType type = T_AUTO) { jnp(std::string(label), type); }//-V524 +void jnp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 +void jnp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jns(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 +void jns(const char *label, LabelType type = T_AUTO) { jns(std::string(label), type); }//-V524 +void jns(const void *addr) { opJmpAbs(addr, T_NEAR, 0x79, 0x89, 0x0F); }//-V524 +void jns(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524 +void jnz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jnz(const char *label, LabelType type = T_AUTO) { jnz(std::string(label), type); }//-V524 +void jnz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524 +void jnz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524 +void jo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 +void jo(const char *label, LabelType type = T_AUTO) { jo(std::string(label), type); }//-V524 +void jo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x70, 0x80, 0x0F); }//-V524 +void jo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524 +void jp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jp(const char *label, LabelType type = T_AUTO) { jp(std::string(label), type); }//-V524 +void jp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 +void jp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(const char *label, LabelType type = T_AUTO) { jpe(std::string(label), type); }//-V524 +void jpe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524 +void jpe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524 +void jpo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void jpo(const char *label, LabelType type = T_AUTO) { jpo(std::string(label), type); }//-V524 +void jpo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524 +void jpo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524 +void js(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 +void js(const char *label, LabelType type = T_AUTO) { js(std::string(label), type); }//-V524 +void js(const void *addr) { opJmpAbs(addr, T_NEAR, 0x78, 0x88, 0x0F); }//-V524 +void js(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524 +void jz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void jz(const char *label, LabelType type = T_AUTO) { jz(std::string(label), type); }//-V524 +void jz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524 +void jz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524 +void lahf() { db(0x9F); } +void lddqu(const Xmm& xmm, const Address& addr) { db(0xF2); opModM(addr, xmm, 0x0F, 0xF0); } +void ldmxcsr(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0xAE); } +void lea(const Reg& reg, const Address& addr) { if (!reg.isBit(16 | i32e)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModM(addr, reg, 0x8D); } +void lfence() { db(0x0F); db(0xAE); db(0xE8); } +void lock() { db(0xF0); } +void lzcnt(const Reg®, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBD); } +void maskmovdqu(const Xmm& reg1, const Xmm& reg2) { db(0x66); opModR(reg1, reg2, 0x0F, 0xF7); } +void maskmovq(const Mmx& reg1, const Mmx& reg2) { if (!reg1.isMMX() || !reg2.isMMX()) throw Error(ERR_BAD_COMBINATION); opModR(reg1, reg2, 0x0F, 0xF7); } +void maxpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x66, isXMM_XMMorMEM); } +void maxps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x100, isXMM_XMMorMEM); } +void maxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF2, isXMM_XMMorMEM); } +void maxss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF3, isXMM_XMMorMEM); } +void mfence() { db(0x0F); db(0xAE); db(0xF0); } +void minpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x66, isXMM_XMMorMEM); } +void minps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x100, isXMM_XMMorMEM); } +void minsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF2, isXMM_XMMorMEM); } +void minss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF3, isXMM_XMMorMEM); } +void monitor() { db(0x0F); db(0x01); db(0xC8); } +void movapd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x29); } +void movapd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x66); } +void movaps(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x29); } +void movaps(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x100); } +void movbe(const Address& addr, const Reg& reg) { opModM(addr, reg, 0x0F, 0x38, 0xF1); } +void movbe(const Reg& reg, const Address& addr) { opModM(addr, reg, 0x0F, 0x38, 0xF0); } +void movd(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x7E); } +void movd(const Mmx& mmx, const Address& addr) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x6E); } +void movd(const Mmx& mmx, const Reg32& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); } +void movd(const Reg32& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); } +void movddup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF2, isXMM_XMMorMEM, NONE, NONE); } +void movdq2q(const Mmx& mmx, const Xmm& xmm) { db(0xF2); opModR(mmx, xmm, 0x0F, 0xD6); } +void movdqa(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x7F); } +void movdqa(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0x66); } +void movdqu(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x7F); } +void movdqu(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0xF3); } +void movhlps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x12); } +void movhpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x66); } +void movhps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x100); } +void movlhps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x16); } +void movlpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x66); } +void movlps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x100); } +void movmskpd(const Reg32e& reg, const Xmm& xmm) { db(0x66); movmskps(reg, xmm); } +void movmskps(const Reg32e& reg, const Xmm& xmm) { opModR(reg, xmm, 0x0F, 0x50); } +void movntdq(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0xE7); } +void movntdqa(const Xmm& xmm, const Address& addr) { db(0x66); opModM(addr, xmm, 0x0F, 0x38, 0x2A); } +void movnti(const Address& addr, const Reg32e& reg) { opModM(addr, reg, 0x0F, 0xC3); } +void movntpd(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0x2B); } +void movntps(const Address& addr, const Xmm& xmm) { opModM(addr, Mmx(xmm.getIdx()), 0x0F, 0x2B); } +void movntq(const Address& addr, const Mmx& mmx) { if (!mmx.isMMX()) throw Error(ERR_BAD_COMBINATION); opModM(addr, mmx, 0x0F, 0xE7); } +void movq(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, mmx.isXMM() ? 0xD6 : 0x7F); } +void movq(const Mmx& mmx, const Operand& op) { if (mmx.isXMM()) db(0xF3); opModRM(mmx, op, (mmx.getKind() == op.getKind()), op.isMEM(), 0x0F, mmx.isXMM() ? 0x7E : 0x6F); } +void movq2dq(const Xmm& xmm, const Mmx& mmx) { db(0xF3); opModR(xmm, mmx, 0x0F, 0xD6); } +void movsb() { db(0xA4); } +void movsd() { db(0xA5); } +void movsd(const Address& addr, const Xmm& xmm) { db(0xF2); opModM(addr, xmm, 0x0F, 0x11); } +void movsd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF2); } +void movshdup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x16, 0xF3, isXMM_XMMorMEM, NONE, NONE); } +void movsldup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF3, isXMM_XMMorMEM, NONE, NONE); } +void movss(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x11); } +void movss(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF3); } +void movsw() { db(0x66); db(0xA5); } +void movsx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xBE); } +void movupd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x11); } +void movupd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x66); } +void movups(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x11); } +void movups(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x100); } +void movzx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xB6); } +void mpsadbw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x42, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void mul(const Operand& op) { opR_ModM(op, 0, 4, 0xF6); } +void mulpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x66, isXMM_XMMorMEM); } +void mulps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x100, isXMM_XMMorMEM); } +void mulsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF2, isXMM_XMMorMEM); } +void mulss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF3, isXMM_XMMorMEM); } +void mulx(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf6, true); } +void mwait() { db(0x0F); db(0x01); db(0xC9); } +void neg(const Operand& op) { opR_ModM(op, 0, 3, 0xF6); } +void not_(const Operand& op) { opR_ModM(op, 0, 2, 0xF6); } +void or_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x08, 1); } +void or_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x08); } +void orpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x66, isXMM_XMMorMEM); } +void orps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x100, isXMM_XMMorMEM); } +void pabsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1C, 0x66, NONE, 0x38); } +void pabsd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1E, 0x66, NONE, 0x38); } +void pabsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1D, 0x66, NONE, 0x38); } +void packssdw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6B); } +void packsswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x63); } +void packusdw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2B, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void packuswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x67); } +void paddb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFC); } +void paddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFE); } +void paddq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD4); } +void paddsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEC); } +void paddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xED); } +void paddusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDC); } +void paddusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDD); } +void paddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFD); } +void palignr(const Mmx& mmx, const Operand& op, int imm) { opMMX(mmx, op, 0x0f, 0x66, static_cast<uint8>(imm), 0x3a); } +void pand(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDB); } +void pandn(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDF); } +void pause() { db(0xF3); db(0x90); } +void pavgb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE0); } +void pavgw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE3); } +void pblendvb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x10, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pblendw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0E, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void pclmulhqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x11); } +void pclmulhqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x01); } +void pclmullqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x10); } +void pclmullqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x00); } +void pclmulqdq(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x44, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void pcmpeqb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x74); } +void pcmpeqd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x76); } +void pcmpeqq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x29, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pcmpeqw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x75); } +void pcmpestri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x61, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpestrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x60, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpgtb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x64); } +void pcmpgtd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x66); } +void pcmpgtq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x37, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pcmpgtw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x65); } +void pcmpistri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x63, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pcmpistrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x62, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void pdep(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf5, true); } +void pext(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F3 | T_0F38, 0xf5, true); } +void pextrb(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x14, imm); } +void pextrd(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x16, imm); } +void pextrw(const Operand& op, const Mmx& xmm, uint8 imm) { opExt(op, xmm, 0x15, imm, true); } +void phaddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x02, 0x66, NONE, 0x38); } +void phaddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x03, 0x66, NONE, 0x38); } +void phaddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x01, 0x66, NONE, 0x38); } +void phminposuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void phsubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x06, 0x66, NONE, 0x38); } +void phsubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x07, 0x66, NONE, 0x38); } +void phsubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x05, 0x66, NONE, 0x38); } +void pinsrb(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x20, 0x66, isXMM_REG32orMEM, imm, 0x3A); } +void pinsrd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x22, 0x66, isXMM_REG32orMEM, imm, 0x3A); } +void pinsrw(const Mmx& mmx, const Operand& op, int imm) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(mmx, op, 0xC4, mmx.isXMM() ? 0x66 : NONE, 0, imm); } +void pmaddubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x04, 0x66, NONE, 0x38); } +void pmaddwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF5); } +void pmaxsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3C, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3D, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEE); } +void pmaxub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDE); } +void pmaxud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3F, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmaxuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3E, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x38, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x39, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEA); } +void pminub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDA); } +void pminud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3B, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pminuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3A, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovmskb(const Reg32e& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(reg, mmx, 0x0F, 0xD7); } +void pmovsxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x22, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x20, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x25, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x23, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovsxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x24, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x31, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x32, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x30, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x35, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x33, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmovzxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x34, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmuldq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x28, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmulhrsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0B, 0x66, NONE, 0x38); } +void pmulhuw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE4); } +void pmulhw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE5); } +void pmulld(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void pmullw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD5); } +void pmuludq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF4); } +void popcnt(const Reg®, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xB8); } +void popf() { db(0x9D); } +void por(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEB); } +void prefetchnta(const Address& addr) { opModM(addr, Reg32(0), 0x0F, 0x18); } +void prefetcht0(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x18); } +void prefetcht1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x18); } +void prefetcht2(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0x18); } +void prefetchw(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x0D); } +void prefetchwt1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x0D); } +void psadbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF6); } +void pshufb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x00, 0x66, NONE, 0x38); } +void pshufd(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x66, imm8); } +void pshufhw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF3, imm8); } +void pshuflw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF2, imm8); } +void pshufw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x00, imm8); } +void psignb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x08, 0x66, NONE, 0x38); } +void psignd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0A, 0x66, NONE, 0x38); } +void psignw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x09, 0x66, NONE, 0x38); } +void pslld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF2); } +void pslld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 6); } +void pslldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 7); } +void psllq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF3); } +void psllq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 6); } +void psllw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF1); } +void psllw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 6); } +void psrad(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE2); } +void psrad(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 4); } +void psraw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE1); } +void psraw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 4); } +void psrld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD2); } +void psrld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 2); } +void psrldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 3); } +void psrlq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD3); } +void psrlq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 2); } +void psrlw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD1); } +void psrlw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 2); } +void psubb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF8); } +void psubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFA); } +void psubq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFB); } +void psubsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE8); } +void psubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE9); } +void psubusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD8); } +void psubusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD9); } +void psubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF9); } +void ptest(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x17, 0x66, isXMM_XMMorMEM, NONE, 0x38); } +void punpckhbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x68); } +void punpckhdq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6A); } +void punpckhqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6D, 0x66, isXMM_XMMorMEM); } +void punpckhwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x69); } +void punpcklbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x60); } +void punpckldq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x62); } +void punpcklqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6C, 0x66, isXMM_XMMorMEM); } +void punpcklwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x61); } +void pushf() { db(0x9C); } +void pxor(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEF); } +void rcl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 2); } +void rcl(const Operand& op, int imm) { opShift(op, imm, 2); } +void rcpps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0x100, isXMM_XMMorMEM); } +void rcpss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0xF3, isXMM_XMMorMEM); } +void rcr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 3); } +void rcr(const Operand& op, int imm) { opShift(op, imm, 3); } +void rdmsr() { db(0x0F); db(0x32); } +void rdpmc() { db(0x0F); db(0x33); } +void rdrand(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(6, Operand::REG, r.getBit()), r, 0x0F, 0xC7); } +void rdseed(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(7, Operand::REG, r.getBit()), r, 0x0F, 0xC7); } +void rdtsc() { db(0x0F); db(0x31); } +void rdtscp() { db(0x0F); db(0x01); db(0xF9); } +void rep() { db(0xF3); } +void ret(int imm = 0) { if (imm) { db(0xC2); dw(imm); } else { db(0xC3); } } +void rol(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 0); } +void rol(const Operand& op, int imm) { opShift(op, imm, 0); } +void ror(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 1); } +void ror(const Operand& op, int imm) { opShift(op, imm, 1); } +void rorx(const Reg32e& r, const Operand& op, uint8 imm) { opGpr(r, op, Reg32e(0, r.getBit()), T_0F3A | T_F2, 0xF0, false, imm); } +void roundpd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x09, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void roundps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x08, 0x66, isXMM_XMMorMEM, imm, 0x3A); } +void roundsd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0B, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void roundss(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0A, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); } +void rsqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0x100, isXMM_XMMorMEM); } +void rsqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0xF3, isXMM_XMMorMEM); } +void sahf() { db(0x9E); } +void sal(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); } +void sal(const Operand& op, int imm) { opShift(op, imm, 4); } +void sar(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 7); } +void sar(const Operand& op, int imm) { opShift(op, imm, 7); } +void sarx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F3 | T_0F38, 0xf7, false); } +void sbb(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x18, 3); } +void sbb(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x18); } +void scasb() { db(0xAE); } +void scasd() { db(0xAF); } +void scasw() { db(0x66); db(0xAF); } +void seta(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524 +void setae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void setbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524 +void setc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void sete(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524 +void setg(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524 +void setge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524 +void setl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524 +void setle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524 +void setna(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524 +void setnae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524 +void setnb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setnbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524 +void setnc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524 +void setne(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524 +void setng(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524 +void setnge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524 +void setnl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524 +void setnle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524 +void setno(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 1); }//-V524 +void setnp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524 +void setns(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 9); }//-V524 +void setnz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524 +void seto(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 0); }//-V524 +void setp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524 +void setpe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524 +void setpo(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524 +void sets(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 8); }//-V524 +void setz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524 +void sfence() { db(0x0F); db(0xAE); db(0xF8); } +void sha1msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC9, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCA, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1nexte(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC8, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha1rnds4(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, imm, 0x3A); } +void sha256msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha256msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCD, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void sha256rnds2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCB, NONE, isXMM_XMMorMEM, NONE, 0x38); } +void shl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); } +void shl(const Operand& op, int imm) { opShift(op, imm, 4); } +void shld(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xA4, &_cl); } +void shld(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xA4); } +void shlx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_66 | T_0F38, 0xf7, false); } +void shr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 5); } +void shr(const Operand& op, int imm) { opShift(op, imm, 5); } +void shrd(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xAC, &_cl); } +void shrd(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xAC); } +void shrx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F2 | T_0F38, 0xf7, false); } +void shufpd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x66, isXMM_XMMorMEM, imm8); } +void shufps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x100, isXMM_XMMorMEM, imm8); } +void sqrtpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x66, isXMM_XMMorMEM); } +void sqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x100, isXMM_XMMorMEM); } +void sqrtsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF2, isXMM_XMMorMEM); } +void sqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF3, isXMM_XMMorMEM); } +void stac() { db(0x0F); db(0x01); db(0xCB); } +void stc() { db(0xF9); } +void std() { db(0xFD); } +void sti() { db(0xFB); } +void stmxcsr(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0xAE); } +void stosb() { db(0xAA); } +void stosd() { db(0xAB); } +void stosw() { db(0x66); db(0xAB); } +void sub(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x28, 5); } +void sub(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x28); } +void subpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x66, isXMM_XMMorMEM); } +void subps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x100, isXMM_XMMorMEM); } +void subsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF2, isXMM_XMMorMEM); } +void subss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF3, isXMM_XMMorMEM); } +void tzcnt(const Reg®, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBC); } +void ucomisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x66, isXMM_XMMorMEM); } +void ucomiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x100, isXMM_XMMorMEM); } +void ud2() { db(0x0F); db(0x0B); } +void unpckhpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM); } +void unpckhps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x100, isXMM_XMMorMEM); } +void unpcklpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM); } +void unpcklps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x100, isXMM_XMMorMEM); } +void vaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x58); } +void vaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x58); } +void vaddsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x58); } +void vaddss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x58); } +void vaddsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0xD0); } +void vaddsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0xD0); } +void vaesdec(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDE); } +void vaesdeclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDF); } +void vaesenc(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDC); } +void vaesenclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDD); } +void vaesimc(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_W0, 0xDB); } +void vaeskeygenassist(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0xDF, imm); } +void vandnpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x55); } +void vandnps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x55); } +void vandpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x54); } +void vandps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x54); } +void vblendpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0D, imm); } +void vblendps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0C, imm); } +void vblendvpd(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4B, x4.getIdx() << 4); } +void vblendvps(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4A, x4.getIdx() << 4); } +void vbroadcastf128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x1A); } +void vbroadcasti128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x5A); } +void vbroadcastsd(const Ymm& y, const Operand& op) { if (!op.isMEM() && !(y.isYMM() && op.isXMM()) && !(y.isZMM() && op.isXMM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(y, op, T_0F38 | T_66 | T_W0 | T_YMM | T_EVEX | T_EW1 | T_N8, 0x19); } +void vbroadcastss(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x18); } +void vcmpeq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 16); } +void vcmpeq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 16); } +void vcmpeq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 16); } +void vcmpeq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 16); } +void vcmpeq_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 8); } +void vcmpeq_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 8); } +void vcmpeq_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 8); } +void vcmpeq_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 8); } +void vcmpeq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 24); } +void vcmpeq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 24); } +void vcmpeq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 24); } +void vcmpeq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 24); } +void vcmpeqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 0); } +void vcmpeqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 0); } +void vcmpeqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 0); } +void vcmpeqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 0); } +void vcmpfalse_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 27); } +void vcmpfalse_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 27); } +void vcmpfalse_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 27); } +void vcmpfalse_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 27); } +void vcmpfalsepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 11); } +void vcmpfalseps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 11); } +void vcmpfalsesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 11); } +void vcmpfalsess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 11); } +void vcmpge_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 29); } +void vcmpge_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 29); } +void vcmpge_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 29); } +void vcmpge_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 29); } +void vcmpgepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 13); } +void vcmpgeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 13); } +void vcmpgesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 13); } +void vcmpgess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 13); } +void vcmpgt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 30); } +void vcmpgt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 30); } +void vcmpgt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 30); } +void vcmpgt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 30); } +void vcmpgtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 14); } +void vcmpgtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 14); } +void vcmpgtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 14); } +void vcmpgtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 14); } +void vcmple_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 18); } +void vcmple_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 18); } +void vcmple_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 18); } +void vcmple_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 18); } +void vcmplepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 2); } +void vcmpleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 2); } +void vcmplesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 2); } +void vcmpless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 2); } +void vcmplt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 17); } +void vcmplt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 17); } +void vcmplt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 17); } +void vcmplt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 17); } +void vcmpltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 1); } +void vcmpltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 1); } +void vcmpltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 1); } +void vcmpltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 1); } +void vcmpneq_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 12); } +void vcmpneq_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 12); } +void vcmpneq_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 12); } +void vcmpneq_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 12); } +void vcmpneq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 28); } +void vcmpneq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 28); } +void vcmpneq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 28); } +void vcmpneq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 28); } +void vcmpneq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 20); } +void vcmpneq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 20); } +void vcmpneq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 20); } +void vcmpneq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 20); } +void vcmpneqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 4); } +void vcmpneqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 4); } +void vcmpneqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 4); } +void vcmpneqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 4); } +void vcmpnge_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 25); } +void vcmpnge_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 25); } +void vcmpnge_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 25); } +void vcmpnge_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 25); } +void vcmpngepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 9); } +void vcmpngeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 9); } +void vcmpngesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 9); } +void vcmpngess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 9); } +void vcmpngt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 26); } +void vcmpngt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 26); } +void vcmpngt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 26); } +void vcmpngt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 26); } +void vcmpngtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 10); } +void vcmpngtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 10); } +void vcmpngtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 10); } +void vcmpngtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 10); } +void vcmpnle_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 22); } +void vcmpnle_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 22); } +void vcmpnle_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 22); } +void vcmpnle_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 22); } +void vcmpnlepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 6); } +void vcmpnleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 6); } +void vcmpnlesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 6); } +void vcmpnless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 6); } +void vcmpnlt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 21); } +void vcmpnlt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 21); } +void vcmpnlt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 21); } +void vcmpnlt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 21); } +void vcmpnltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 5); } +void vcmpnltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 5); } +void vcmpnltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 5); } +void vcmpnltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 5); } +void vcmpord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 23); } +void vcmpord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 23); } +void vcmpord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 23); } +void vcmpord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 23); } +void vcmpordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 7); } +void vcmpordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 7); } +void vcmpordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 7); } +void vcmpordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 7); } +void vcmppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xC2, imm); } +void vcmpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_YMM, 0xC2, imm); } +void vcmpsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F2 | T_0F, 0xC2, imm); } +void vcmpss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0xC2, imm); } +void vcmptrue_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 31); } +void vcmptrue_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 31); } +void vcmptrue_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 31); } +void vcmptrue_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 31); } +void vcmptruepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 15); } +void vcmptrueps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 15); } +void vcmptruesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 15); } +void vcmptruess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 15); } +void vcmpunord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 19); } +void vcmpunord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 19); } +void vcmpunord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 19); } +void vcmpunord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 19); } +void vcmpunordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 3); } +void vcmpunordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 3); } +void vcmpunordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 3); } +void vcmpunordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 3); } +void vcomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2F); } +void vcomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2F); } +void vcvtdq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_F3 | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0xE6); } +void vcvtdq2ps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); } +void vcvtpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_F2 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0xE6); } +void vcvtpd2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_66 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5A); } +void vcvtph2ps(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F38 | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x13); } +void vcvtps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); } +void vcvtps2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x5A); } +void vcvtps2ph(const Operand& op, const Xmm& x, uint8 imm) { checkCvt1(x, op); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x1D, imm); } +void vcvtsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_ER_X, 0x2D); } +void vcvtsd2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x5A); } +void vcvtsi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F2 | T_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } +void vcvtsi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F3 | T_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); } +void vcvtss2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x5A); } +void vcvtss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_ER_X | T_N8, 0x2D); } +void vcvttpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_66 | T_0F | T_YMM | T_EVEX |T_EW1 | T_B64 | T_ER_Z, 0xE6); } +void vcvttps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX | T_SAE_Z | T_B32, 0x5B); } +void vcvttsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_SAE_X, 0x2C); } +void vcvttss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_SAE_X | T_N8, 0x2C); } +void vdivpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5E); } +void vdivps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5E); } +void vdivsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5E); } +void vdivss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5E); } +void vdppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x41, imm); } +void vdpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x40, imm); } +void vextractf128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x19, imm); } +void vextracti128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x39, imm); } +void vextractps(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_N4, 0x17, imm); } +void vfmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x98); } +void vfmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x98); } +void vfmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x99); } +void vfmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x99); } +void vfmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA8); } +void vfmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA8); } +void vfmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xA9); } +void vfmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xA9); } +void vfmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB8); } +void vfmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB8); } +void vfmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xB9); } +void vfmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xB9); } +void vfmaddsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x96); } +void vfmaddsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x96); } +void vfmaddsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA6); } +void vfmaddsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA6); } +void vfmaddsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB6); } +void vfmaddsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB6); } +void vfmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9A); } +void vfmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9A); } +void vfmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9B); } +void vfmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9B); } +void vfmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAA); } +void vfmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAA); } +void vfmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAB); } +void vfmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAB); } +void vfmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBA); } +void vfmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBA); } +void vfmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBB); } +void vfmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBB); } +void vfmsubadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x97); } +void vfmsubadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x97); } +void vfmsubadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA7); } +void vfmsubadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA7); } +void vfmsubadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB7); } +void vfmsubadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB7); } +void vfnmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9C); } +void vfnmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9C); } +void vfnmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9D); } +void vfnmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9D); } +void vfnmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAC); } +void vfnmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAC); } +void vfnmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAD); } +void vfnmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAD); } +void vfnmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBC); } +void vfnmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBC); } +void vfnmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBD); } +void vfnmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBD); } +void vfnmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9E); } +void vfnmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9E); } +void vfnmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9F); } +void vfnmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9F); } +void vfnmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAE); } +void vfnmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAE); } +void vfnmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAF); } +void vfnmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAF); } +void vfnmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBE); } +void vfnmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBE); } +void vfnmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBF); } +void vfnmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBF); } +void vgatherdpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x92, 0); } +void vgatherdps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x92, 1); } +void vgatherqpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x93, 1); } +void vgatherqps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x93, 2); } +void vgf2p8affineinvqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCF, imm); } +void vgf2p8affineqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCE, imm); } +void vgf2p8mulb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_SAE_Z, 0xCF); } +void vhaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7C); } +void vhaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7C); } +void vhsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7D); } +void vhsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7D); } +void vinsertf128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x18, imm); } +void vinserti128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x38, imm); } +void vinsertps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_W0 | T_EW0 | T_EVEX, 0x21, imm); } +void vlddqu(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, cvtIdx0(x), addr, T_0F | T_F2 | T_W0 | T_YMM, 0xF0); } +void vldmxcsr(const Address& addr) { opAVX_X_X_XM(xm2, xm0, addr, T_0F, 0xAE); } +void vmaskmovdqu(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_66, 0xF7); } +void vmaskmovpd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2F); } +void vmaskmovpd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2D); } +void vmaskmovps(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2E); } +void vmaskmovps(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2C); } +void vmaxpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5F); } +void vmaxps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5F); } +void vmaxsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5F); } +void vmaxss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5F); } +void vminpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5D); } +void vminps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5D); } +void vminsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5D); } +void vminss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5D); } +void vmovapd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x29); } +void vmovapd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x28); } +void vmovaps(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x29); } +void vmovaps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x28); } +void vmovd(const Operand& op, const Xmm& x) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x7E); } +void vmovd(const Xmm& x, const Operand& op) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x6E); } +void vmovddup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_DUP | T_F2 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_X | T_ER_Y | T_ER_Z, 0x12); } +void vmovdqa(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_YMM, 0x7F); } +void vmovdqa(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_YMM, 0x6F); } +void vmovdqu(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_F3 | T_0F | T_YMM, 0x7F); } +void vmovdqu(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM, 0x6F); } +void vmovhlps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x12); } +void vmovhpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x17); } +void vmovhpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x16); } +void vmovhps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x17); } +void vmovhps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x16); } +void vmovlhps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x16); } +void vmovlpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x13); } +void vmovlpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x12); } +void vmovlps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x13); } +void vmovlps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x12); } +void vmovmskpd(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_66 | T_W0 | T_YMM, 0x50); } +void vmovmskps(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_W0 | T_YMM, 0x50); } +void vmovntdq(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW0, 0xE7); } +void vmovntdqa(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_0F38 | T_66 | T_YMM | T_EVEX | T_EW0, 0x2A); } +void vmovntpd(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW1, 0x2B); } +void vmovntps(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_YMM | T_EVEX | T_EW0, 0x2B); } +void vmovq(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, x.getIdx() < 16 ? 0xD6 : 0x7E); } +void vmovq(const Xmm& x, const Address& addr) { int type, code; if (x.getIdx() < 16) { type = T_0F | T_F3; code = 0x7E; } else { type = T_0F | T_66 | T_EVEX | T_EW1 | T_N8; code = 0x6E; } opAVX_X_X_XM(x, xm0, addr, type, code); } +void vmovq(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_F3 | T_EVEX | T_EW1 | T_N8, 0x7E); } +void vmovsd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_M_K, 0x11); } +void vmovsd(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); } +void vmovsd(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); } +void vmovshdup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x16); } +void vmovsldup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x12); } +void vmovss(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_M_K, 0x11); } +void vmovss(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); } +void vmovss(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); } +void vmovupd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x11); } +void vmovupd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x10); } +void vmovups(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x11); } +void vmovups(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x10); } +void vmpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x42, imm); } +void vmulpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x59); } +void vmulps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x59); } +void vmulsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x59); } +void vmulss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x59); } +void vorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x56); } +void vorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x56); } +void vpabsb(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1C); } +void vpabsd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x1E); } +void vpabsw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1D); } +void vpackssdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6B); } +void vpacksswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x63); } +void vpackusdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x2B); } +void vpackuswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x67); } +void vpaddb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFC); } +void vpaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFE); } +void vpaddq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xD4); } +void vpaddsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEC); } +void vpaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xED); } +void vpaddusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDC); } +void vpaddusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDD); } +void vpaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFD); } +void vpalignr(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_YMM | T_EVEX, 0x0F, imm); } +void vpand(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDB); } +void vpandn(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDF); } +void vpavgb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE0); } +void vpavgw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE3); } +void vpblendd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x02, imm); } +void vpblendvb(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4C, x4.getIdx() << 4); } +void vpblendw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0E, imm); } +void vpbroadcastb(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x78); } +void vpbroadcastd(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x58); } +void vpbroadcastq(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX, 0x59); } +void vpbroadcastw(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x79); } +void vpclmulqdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM | T_EVEX, 0x44, imm); } +void vpcmpeqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x74); } +void vpcmpeqd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x76); } +void vpcmpeqq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x29); } +void vpcmpeqw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x75); } +void vpcmpestri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x61, imm); } +void vpcmpestrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x60, imm); } +void vpcmpgtb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x64); } +void vpcmpgtd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x66); } +void vpcmpgtq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x37); } +void vpcmpgtw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x65); } +void vpcmpistri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x63, imm); } +void vpcmpistrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x62, imm); } +void vperm2f128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x06, imm); } +void vperm2i128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x46, imm); } +void vpermd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x36); } +void vpermilpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x0D); } +void vpermilpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_EVEX | T_B64, 0x05, imm); } +void vpermilps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x0C); } +void vpermilps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_EVEX | T_B32, 0x04, imm); } +void vpermpd(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x01, imm); } +void vpermpd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x16); } +void vpermps(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x16); } +void vpermq(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x00, imm); } +void vpermq(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x36); } +void vpextrb(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(8|16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x14, imm); } +void vpextrd(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x16, imm); } +void vpextrq(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(64) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x16, imm); } +void vpextrw(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); if (op.isREG() && x.getIdx() < 16) { opAVX_X_X_XM(Xmm(op.getIdx()), xm0, x, T_0F | T_66, 0xC5, imm); } else { opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N2, 0x15, imm); } } +void vpgatherdd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x90, 1); } +void vpgatherdq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x90, 0); } +void vpgatherqd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x91, 2); } +void vpgatherqq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x91, 1); } +void vphaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x02); } +void vphaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x03); } +void vphaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x01); } +void vphminposuw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38, 0x41); } +void vphsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x06); } +void vphsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x07); } +void vphsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x05); } +void vpinsrb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x20, imm); } +void vpinsrd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x22, imm); } +void vpinsrq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(64) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x22, imm); } +void vpinsrw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F | T_66 | T_EVEX | T_N2, 0xC4, imm); } +void vpmaddubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x04); } +void vpmaddwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF5); } +void vpmaskmovd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8E); } +void vpmaskmovd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8C); } +void vpmaskmovq(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8E); } +void vpmaskmovq(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8C); } +void vpmaxsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3C); } +void vpmaxsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3D); } +void vpmaxsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEE); } +void vpmaxub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDE); } +void vpmaxud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3F); } +void vpmaxuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3E); } +void vpminsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x38); } +void vpminsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x39); } +void vpminsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEA); } +void vpminub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDA); } +void vpminud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3B); } +void vpminuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3A); } +void vpmovmskb(const Reg32e& r, const Xmm& x) { if (!x.is(Operand::XMM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(x.isYMM() ? Ymm(r.getIdx()) : Xmm(r.getIdx()), 0, x, T_0F | T_66 | T_YMM, 0xD7); } +void vpmovsxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x21); } +void vpmovsxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x22); } +void vpmovsxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x20); } +void vpmovsxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x25); } +void vpmovsxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x23); } +void vpmovsxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x24); } +void vpmovzxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x31); } +void vpmovzxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x32); } +void vpmovzxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x30); } +void vpmovzxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x35); } +void vpmovzxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x33); } +void vpmovzxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x34); } +void vpmuldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x28); } +void vpmulhrsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x0B); } +void vpmulhuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE4); } +void vpmulhw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE5); } +void vpmulld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x40); } +void vpmullw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD5); } +void vpmuludq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xF4); } +void vpor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEB); } +void vpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF6); } +void vpshufb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x00); } +void vpshufd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x70, imm); } +void vpshufhw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM | T_EVEX, 0x70, imm); } +void vpshuflw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F2 | T_0F | T_YMM | T_EVEX, 0x70, imm); } +void vpsignb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x08); } +void vpsignd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x0A); } +void vpsignw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x09); } +void vpslld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpslld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xF2); } +void vpslldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 7), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); } +void vpsllq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); } +void vpsllq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xF3); } +void vpsllvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x47); } +void vpsllvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x47); } +void vpsllw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsllw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xF1); } +void vpsrad(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpsrad(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xE2); } +void vpsravd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x46); } +void vpsraw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsraw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xE1); } +void vpsrld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); } +void vpsrld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xD2); } +void vpsrldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 3), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); } +void vpsrlq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); } +void vpsrlq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xD3); } +void vpsrlvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x45); } +void vpsrlvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x45); } +void vpsrlw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); } +void vpsrlw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xD1); } +void vpsubb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF8); } +void vpsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFA); } +void vpsubq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xFB); } +void vpsubsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE8); } +void vpsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE9); } +void vpsubusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD8); } +void vpsubusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD9); } +void vpsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF9); } +void vptest(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x17); } +void vpunpckhbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x68); } +void vpunpckhdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6A); } +void vpunpckhqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6D); } +void vpunpckhwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x69); } +void vpunpcklbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x60); } +void vpunpckldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x62); } +void vpunpcklqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6C); } +void vpunpcklwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x61); } +void vpxor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEF); } +void vrcpps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x53); } +void vrcpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x53); } +void vroundpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x09, imm); } +void vroundps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x08, imm); } +void vroundsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0B, imm); } +void vroundss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0A, imm); } +void vrsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x52); } +void vrsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x52); } +void vshufpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xC6, imm); } +void vshufps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xC6, imm); } +void vsqrtpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x51); } +void vsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x51); } +void vsqrtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x51); } +void vsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_ER_X, 0x51); } +void vstmxcsr(const Address& addr) { opAVX_X_X_XM(xm3, xm0, addr, T_0F, 0xAE); } +void vsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5C); } +void vsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5C); } +void vsubsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5C); } +void vsubss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5C); } +void vtestpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0F); } +void vtestps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0E); } +void vucomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2E); } +void vucomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2E); } +void vunpckhpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x15); } +void vunpckhps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x15); } +void vunpcklpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x14); } +void vunpcklps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x14); } +void vxorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x57); } +void vxorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x57); } +void vzeroall() { db(0xC5); db(0xFC); db(0x77); } +void vzeroupper() { db(0xC5); db(0xF8); db(0x77); } +void wait() { db(0x9B); } +void wbinvd() { db(0x0F); db(0x09); } +void wrmsr() { db(0x0F); db(0x30); } +void xadd(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xC0 | (reg.isBit(8) ? 0 : 1)); } +void xgetbv() { db(0x0F); db(0x01); db(0xD0); } +void xlatb() { db(0xD7); } +void xor_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x30, 6); } +void xor_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x30); } +void xorpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x66, isXMM_XMMorMEM); } +void xorps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x100, isXMM_XMMorMEM); } +#ifdef XBYAK_ENABLE_OMITTED_OPERAND +void vblendpd(const Xmm& x, const Operand& op, uint8 imm) { vblendpd(x, x, op, imm); } +void vblendps(const Xmm& x, const Operand& op, uint8 imm) { vblendps(x, x, op, imm); } +void vblendvpd(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvpd(x1, x1, op, x4); } +void vblendvps(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvps(x1, x1, op, x4); } +void vcmpeq_ospd(const Xmm& x, const Operand& op) { vcmpeq_ospd(x, x, op); } +void vcmpeq_osps(const Xmm& x, const Operand& op) { vcmpeq_osps(x, x, op); } +void vcmpeq_ossd(const Xmm& x, const Operand& op) { vcmpeq_ossd(x, x, op); } +void vcmpeq_osss(const Xmm& x, const Operand& op) { vcmpeq_osss(x, x, op); } +void vcmpeq_uqpd(const Xmm& x, const Operand& op) { vcmpeq_uqpd(x, x, op); } +void vcmpeq_uqps(const Xmm& x, const Operand& op) { vcmpeq_uqps(x, x, op); } +void vcmpeq_uqsd(const Xmm& x, const Operand& op) { vcmpeq_uqsd(x, x, op); } +void vcmpeq_uqss(const Xmm& x, const Operand& op) { vcmpeq_uqss(x, x, op); } +void vcmpeq_uspd(const Xmm& x, const Operand& op) { vcmpeq_uspd(x, x, op); } +void vcmpeq_usps(const Xmm& x, const Operand& op) { vcmpeq_usps(x, x, op); } +void vcmpeq_ussd(const Xmm& x, const Operand& op) { vcmpeq_ussd(x, x, op); } +void vcmpeq_usss(const Xmm& x, const Operand& op) { vcmpeq_usss(x, x, op); } +void vcmpeqpd(const Xmm& x, const Operand& op) { vcmpeqpd(x, x, op); } +void vcmpeqps(const Xmm& x, const Operand& op) { vcmpeqps(x, x, op); } +void vcmpeqsd(const Xmm& x, const Operand& op) { vcmpeqsd(x, x, op); } +void vcmpeqss(const Xmm& x, const Operand& op) { vcmpeqss(x, x, op); } +void vcmpfalse_ospd(const Xmm& x, const Operand& op) { vcmpfalse_ospd(x, x, op); } +void vcmpfalse_osps(const Xmm& x, const Operand& op) { vcmpfalse_osps(x, x, op); } +void vcmpfalse_ossd(const Xmm& x, const Operand& op) { vcmpfalse_ossd(x, x, op); } +void vcmpfalse_osss(const Xmm& x, const Operand& op) { vcmpfalse_osss(x, x, op); } +void vcmpfalsepd(const Xmm& x, const Operand& op) { vcmpfalsepd(x, x, op); } +void vcmpfalseps(const Xmm& x, const Operand& op) { vcmpfalseps(x, x, op); } +void vcmpfalsesd(const Xmm& x, const Operand& op) { vcmpfalsesd(x, x, op); } +void vcmpfalsess(const Xmm& x, const Operand& op) { vcmpfalsess(x, x, op); } +void vcmpge_oqpd(const Xmm& x, const Operand& op) { vcmpge_oqpd(x, x, op); } +void vcmpge_oqps(const Xmm& x, const Operand& op) { vcmpge_oqps(x, x, op); } +void vcmpge_oqsd(const Xmm& x, const Operand& op) { vcmpge_oqsd(x, x, op); } +void vcmpge_oqss(const Xmm& x, const Operand& op) { vcmpge_oqss(x, x, op); } +void vcmpgepd(const Xmm& x, const Operand& op) { vcmpgepd(x, x, op); } +void vcmpgeps(const Xmm& x, const Operand& op) { vcmpgeps(x, x, op); } +void vcmpgesd(const Xmm& x, const Operand& op) { vcmpgesd(x, x, op); } +void vcmpgess(const Xmm& x, const Operand& op) { vcmpgess(x, x, op); } +void vcmpgt_oqpd(const Xmm& x, const Operand& op) { vcmpgt_oqpd(x, x, op); } +void vcmpgt_oqps(const Xmm& x, const Operand& op) { vcmpgt_oqps(x, x, op); } +void vcmpgt_oqsd(const Xmm& x, const Operand& op) { vcmpgt_oqsd(x, x, op); } +void vcmpgt_oqss(const Xmm& x, const Operand& op) { vcmpgt_oqss(x, x, op); } +void vcmpgtpd(const Xmm& x, const Operand& op) { vcmpgtpd(x, x, op); } +void vcmpgtps(const Xmm& x, const Operand& op) { vcmpgtps(x, x, op); } +void vcmpgtsd(const Xmm& x, const Operand& op) { vcmpgtsd(x, x, op); } +void vcmpgtss(const Xmm& x, const Operand& op) { vcmpgtss(x, x, op); } +void vcmple_oqpd(const Xmm& x, const Operand& op) { vcmple_oqpd(x, x, op); } +void vcmple_oqps(const Xmm& x, const Operand& op) { vcmple_oqps(x, x, op); } +void vcmple_oqsd(const Xmm& x, const Operand& op) { vcmple_oqsd(x, x, op); } +void vcmple_oqss(const Xmm& x, const Operand& op) { vcmple_oqss(x, x, op); } +void vcmplepd(const Xmm& x, const Operand& op) { vcmplepd(x, x, op); } +void vcmpleps(const Xmm& x, const Operand& op) { vcmpleps(x, x, op); } +void vcmplesd(const Xmm& x, const Operand& op) { vcmplesd(x, x, op); } +void vcmpless(const Xmm& x, const Operand& op) { vcmpless(x, x, op); } +void vcmplt_oqpd(const Xmm& x, const Operand& op) { vcmplt_oqpd(x, x, op); } +void vcmplt_oqps(const Xmm& x, const Operand& op) { vcmplt_oqps(x, x, op); } +void vcmplt_oqsd(const Xmm& x, const Operand& op) { vcmplt_oqsd(x, x, op); } +void vcmplt_oqss(const Xmm& x, const Operand& op) { vcmplt_oqss(x, x, op); } +void vcmpltpd(const Xmm& x, const Operand& op) { vcmpltpd(x, x, op); } +void vcmpltps(const Xmm& x, const Operand& op) { vcmpltps(x, x, op); } +void vcmpltsd(const Xmm& x, const Operand& op) { vcmpltsd(x, x, op); } +void vcmpltss(const Xmm& x, const Operand& op) { vcmpltss(x, x, op); } +void vcmpneq_oqpd(const Xmm& x, const Operand& op) { vcmpneq_oqpd(x, x, op); } +void vcmpneq_oqps(const Xmm& x, const Operand& op) { vcmpneq_oqps(x, x, op); } +void vcmpneq_oqsd(const Xmm& x, const Operand& op) { vcmpneq_oqsd(x, x, op); } +void vcmpneq_oqss(const Xmm& x, const Operand& op) { vcmpneq_oqss(x, x, op); } +void vcmpneq_ospd(const Xmm& x, const Operand& op) { vcmpneq_ospd(x, x, op); } +void vcmpneq_osps(const Xmm& x, const Operand& op) { vcmpneq_osps(x, x, op); } +void vcmpneq_ossd(const Xmm& x, const Operand& op) { vcmpneq_ossd(x, x, op); } +void vcmpneq_osss(const Xmm& x, const Operand& op) { vcmpneq_osss(x, x, op); } +void vcmpneq_uspd(const Xmm& x, const Operand& op) { vcmpneq_uspd(x, x, op); } +void vcmpneq_usps(const Xmm& x, const Operand& op) { vcmpneq_usps(x, x, op); } +void vcmpneq_ussd(const Xmm& x, const Operand& op) { vcmpneq_ussd(x, x, op); } +void vcmpneq_usss(const Xmm& x, const Operand& op) { vcmpneq_usss(x, x, op); } +void vcmpneqpd(const Xmm& x, const Operand& op) { vcmpneqpd(x, x, op); } +void vcmpneqps(const Xmm& x, const Operand& op) { vcmpneqps(x, x, op); } +void vcmpneqsd(const Xmm& x, const Operand& op) { vcmpneqsd(x, x, op); } +void vcmpneqss(const Xmm& x, const Operand& op) { vcmpneqss(x, x, op); } +void vcmpnge_uqpd(const Xmm& x, const Operand& op) { vcmpnge_uqpd(x, x, op); } +void vcmpnge_uqps(const Xmm& x, const Operand& op) { vcmpnge_uqps(x, x, op); } +void vcmpnge_uqsd(const Xmm& x, const Operand& op) { vcmpnge_uqsd(x, x, op); } +void vcmpnge_uqss(const Xmm& x, const Operand& op) { vcmpnge_uqss(x, x, op); } +void vcmpngepd(const Xmm& x, const Operand& op) { vcmpngepd(x, x, op); } +void vcmpngeps(const Xmm& x, const Operand& op) { vcmpngeps(x, x, op); } +void vcmpngesd(const Xmm& x, const Operand& op) { vcmpngesd(x, x, op); } +void vcmpngess(const Xmm& x, const Operand& op) { vcmpngess(x, x, op); } +void vcmpngt_uqpd(const Xmm& x, const Operand& op) { vcmpngt_uqpd(x, x, op); } +void vcmpngt_uqps(const Xmm& x, const Operand& op) { vcmpngt_uqps(x, x, op); } +void vcmpngt_uqsd(const Xmm& x, const Operand& op) { vcmpngt_uqsd(x, x, op); } +void vcmpngt_uqss(const Xmm& x, const Operand& op) { vcmpngt_uqss(x, x, op); } +void vcmpngtpd(const Xmm& x, const Operand& op) { vcmpngtpd(x, x, op); } +void vcmpngtps(const Xmm& x, const Operand& op) { vcmpngtps(x, x, op); } +void vcmpngtsd(const Xmm& x, const Operand& op) { vcmpngtsd(x, x, op); } +void vcmpngtss(const Xmm& x, const Operand& op) { vcmpngtss(x, x, op); } +void vcmpnle_uqpd(const Xmm& x, const Operand& op) { vcmpnle_uqpd(x, x, op); } +void vcmpnle_uqps(const Xmm& x, const Operand& op) { vcmpnle_uqps(x, x, op); } +void vcmpnle_uqsd(const Xmm& x, const Operand& op) { vcmpnle_uqsd(x, x, op); } +void vcmpnle_uqss(const Xmm& x, const Operand& op) { vcmpnle_uqss(x, x, op); } +void vcmpnlepd(const Xmm& x, const Operand& op) { vcmpnlepd(x, x, op); } +void vcmpnleps(const Xmm& x, const Operand& op) { vcmpnleps(x, x, op); } +void vcmpnlesd(const Xmm& x, const Operand& op) { vcmpnlesd(x, x, op); } +void vcmpnless(const Xmm& x, const Operand& op) { vcmpnless(x, x, op); } +void vcmpnlt_uqpd(const Xmm& x, const Operand& op) { vcmpnlt_uqpd(x, x, op); } +void vcmpnlt_uqps(const Xmm& x, const Operand& op) { vcmpnlt_uqps(x, x, op); } +void vcmpnlt_uqsd(const Xmm& x, const Operand& op) { vcmpnlt_uqsd(x, x, op); } +void vcmpnlt_uqss(const Xmm& x, const Operand& op) { vcmpnlt_uqss(x, x, op); } +void vcmpnltpd(const Xmm& x, const Operand& op) { vcmpnltpd(x, x, op); } +void vcmpnltps(const Xmm& x, const Operand& op) { vcmpnltps(x, x, op); } +void vcmpnltsd(const Xmm& x, const Operand& op) { vcmpnltsd(x, x, op); } +void vcmpnltss(const Xmm& x, const Operand& op) { vcmpnltss(x, x, op); } +void vcmpord_spd(const Xmm& x, const Operand& op) { vcmpord_spd(x, x, op); } +void vcmpord_sps(const Xmm& x, const Operand& op) { vcmpord_sps(x, x, op); } +void vcmpord_ssd(const Xmm& x, const Operand& op) { vcmpord_ssd(x, x, op); } +void vcmpord_sss(const Xmm& x, const Operand& op) { vcmpord_sss(x, x, op); } +void vcmpordpd(const Xmm& x, const Operand& op) { vcmpordpd(x, x, op); } +void vcmpordps(const Xmm& x, const Operand& op) { vcmpordps(x, x, op); } +void vcmpordsd(const Xmm& x, const Operand& op) { vcmpordsd(x, x, op); } +void vcmpordss(const Xmm& x, const Operand& op) { vcmpordss(x, x, op); } +void vcmppd(const Xmm& x, const Operand& op, uint8 imm) { vcmppd(x, x, op, imm); } +void vcmpps(const Xmm& x, const Operand& op, uint8 imm) { vcmpps(x, x, op, imm); } +void vcmpsd(const Xmm& x, const Operand& op, uint8 imm) { vcmpsd(x, x, op, imm); } +void vcmpss(const Xmm& x, const Operand& op, uint8 imm) { vcmpss(x, x, op, imm); } +void vcmptrue_uspd(const Xmm& x, const Operand& op) { vcmptrue_uspd(x, x, op); } +void vcmptrue_usps(const Xmm& x, const Operand& op) { vcmptrue_usps(x, x, op); } +void vcmptrue_ussd(const Xmm& x, const Operand& op) { vcmptrue_ussd(x, x, op); } +void vcmptrue_usss(const Xmm& x, const Operand& op) { vcmptrue_usss(x, x, op); } +void vcmptruepd(const Xmm& x, const Operand& op) { vcmptruepd(x, x, op); } +void vcmptrueps(const Xmm& x, const Operand& op) { vcmptrueps(x, x, op); } +void vcmptruesd(const Xmm& x, const Operand& op) { vcmptruesd(x, x, op); } +void vcmptruess(const Xmm& x, const Operand& op) { vcmptruess(x, x, op); } +void vcmpunord_spd(const Xmm& x, const Operand& op) { vcmpunord_spd(x, x, op); } +void vcmpunord_sps(const Xmm& x, const Operand& op) { vcmpunord_sps(x, x, op); } +void vcmpunord_ssd(const Xmm& x, const Operand& op) { vcmpunord_ssd(x, x, op); } +void vcmpunord_sss(const Xmm& x, const Operand& op) { vcmpunord_sss(x, x, op); } +void vcmpunordpd(const Xmm& x, const Operand& op) { vcmpunordpd(x, x, op); } +void vcmpunordps(const Xmm& x, const Operand& op) { vcmpunordps(x, x, op); } +void vcmpunordsd(const Xmm& x, const Operand& op) { vcmpunordsd(x, x, op); } +void vcmpunordss(const Xmm& x, const Operand& op) { vcmpunordss(x, x, op); } +void vcvtsd2ss(const Xmm& x, const Operand& op) { vcvtsd2ss(x, x, op); } +void vcvtsi2sd(const Xmm& x, const Operand& op) { vcvtsi2sd(x, x, op); } +void vcvtsi2ss(const Xmm& x, const Operand& op) { vcvtsi2ss(x, x, op); } +void vcvtss2sd(const Xmm& x, const Operand& op) { vcvtss2sd(x, x, op); } +void vdppd(const Xmm& x, const Operand& op, uint8 imm) { vdppd(x, x, op, imm); } +void vdpps(const Xmm& x, const Operand& op, uint8 imm) { vdpps(x, x, op, imm); } +void vinsertps(const Xmm& x, const Operand& op, uint8 imm) { vinsertps(x, x, op, imm); } +void vmpsadbw(const Xmm& x, const Operand& op, uint8 imm) { vmpsadbw(x, x, op, imm); } +void vpackssdw(const Xmm& x, const Operand& op) { vpackssdw(x, x, op); } +void vpacksswb(const Xmm& x, const Operand& op) { vpacksswb(x, x, op); } +void vpackusdw(const Xmm& x, const Operand& op) { vpackusdw(x, x, op); } +void vpackuswb(const Xmm& x, const Operand& op) { vpackuswb(x, x, op); } +void vpaddb(const Xmm& x, const Operand& op) { vpaddb(x, x, op); } +void vpaddd(const Xmm& x, const Operand& op) { vpaddd(x, x, op); } +void vpaddq(const Xmm& x, const Operand& op) { vpaddq(x, x, op); } +void vpaddsb(const Xmm& x, const Operand& op) { vpaddsb(x, x, op); } +void vpaddsw(const Xmm& x, const Operand& op) { vpaddsw(x, x, op); } +void vpaddusb(const Xmm& x, const Operand& op) { vpaddusb(x, x, op); } +void vpaddusw(const Xmm& x, const Operand& op) { vpaddusw(x, x, op); } +void vpaddw(const Xmm& x, const Operand& op) { vpaddw(x, x, op); } +void vpalignr(const Xmm& x, const Operand& op, uint8 imm) { vpalignr(x, x, op, imm); } +void vpand(const Xmm& x, const Operand& op) { vpand(x, x, op); } +void vpandn(const Xmm& x, const Operand& op) { vpandn(x, x, op); } +void vpavgb(const Xmm& x, const Operand& op) { vpavgb(x, x, op); } +void vpavgw(const Xmm& x, const Operand& op) { vpavgw(x, x, op); } +void vpblendd(const Xmm& x, const Operand& op, uint8 imm) { vpblendd(x, x, op, imm); } +void vpblendvb(const Xmm& x1, const Operand& op, const Xmm& x4) { vpblendvb(x1, x1, op, x4); } +void vpblendw(const Xmm& x, const Operand& op, uint8 imm) { vpblendw(x, x, op, imm); } +void vpclmulqdq(const Xmm& x, const Operand& op, uint8 imm) { vpclmulqdq(x, x, op, imm); } +void vpcmpeqb(const Xmm& x, const Operand& op) { vpcmpeqb(x, x, op); } +void vpcmpeqd(const Xmm& x, const Operand& op) { vpcmpeqd(x, x, op); } +void vpcmpeqq(const Xmm& x, const Operand& op) { vpcmpeqq(x, x, op); } +void vpcmpeqw(const Xmm& x, const Operand& op) { vpcmpeqw(x, x, op); } +void vpcmpgtb(const Xmm& x, const Operand& op) { vpcmpgtb(x, x, op); } +void vpcmpgtd(const Xmm& x, const Operand& op) { vpcmpgtd(x, x, op); } +void vpcmpgtq(const Xmm& x, const Operand& op) { vpcmpgtq(x, x, op); } +void vpcmpgtw(const Xmm& x, const Operand& op) { vpcmpgtw(x, x, op); } +void vphaddd(const Xmm& x, const Operand& op) { vphaddd(x, x, op); } +void vphaddsw(const Xmm& x, const Operand& op) { vphaddsw(x, x, op); } +void vphaddw(const Xmm& x, const Operand& op) { vphaddw(x, x, op); } +void vphsubd(const Xmm& x, const Operand& op) { vphsubd(x, x, op); } +void vphsubsw(const Xmm& x, const Operand& op) { vphsubsw(x, x, op); } +void vphsubw(const Xmm& x, const Operand& op) { vphsubw(x, x, op); } +void vpinsrb(const Xmm& x, const Operand& op, uint8 imm) { vpinsrb(x, x, op, imm); } +void vpinsrd(const Xmm& x, const Operand& op, uint8 imm) { vpinsrd(x, x, op, imm); } +void vpinsrq(const Xmm& x, const Operand& op, uint8 imm) { vpinsrq(x, x, op, imm); } +void vpinsrw(const Xmm& x, const Operand& op, uint8 imm) { vpinsrw(x, x, op, imm); } +void vpmaddubsw(const Xmm& x, const Operand& op) { vpmaddubsw(x, x, op); } +void vpmaddwd(const Xmm& x, const Operand& op) { vpmaddwd(x, x, op); } +void vpmaxsb(const Xmm& x, const Operand& op) { vpmaxsb(x, x, op); } +void vpmaxsd(const Xmm& x, const Operand& op) { vpmaxsd(x, x, op); } +void vpmaxsw(const Xmm& x, const Operand& op) { vpmaxsw(x, x, op); } +void vpmaxub(const Xmm& x, const Operand& op) { vpmaxub(x, x, op); } +void vpmaxud(const Xmm& x, const Operand& op) { vpmaxud(x, x, op); } +void vpmaxuw(const Xmm& x, const Operand& op) { vpmaxuw(x, x, op); } +void vpminsb(const Xmm& x, const Operand& op) { vpminsb(x, x, op); } +void vpminsd(const Xmm& x, const Operand& op) { vpminsd(x, x, op); } +void vpminsw(const Xmm& x, const Operand& op) { vpminsw(x, x, op); } +void vpminub(const Xmm& x, const Operand& op) { vpminub(x, x, op); } +void vpminud(const Xmm& x, const Operand& op) { vpminud(x, x, op); } +void vpminuw(const Xmm& x, const Operand& op) { vpminuw(x, x, op); } +void vpmuldq(const Xmm& x, const Operand& op) { vpmuldq(x, x, op); } +void vpmulhrsw(const Xmm& x, const Operand& op) { vpmulhrsw(x, x, op); } +void vpmulhuw(const Xmm& x, const Operand& op) { vpmulhuw(x, x, op); } +void vpmulhw(const Xmm& x, const Operand& op) { vpmulhw(x, x, op); } +void vpmulld(const Xmm& x, const Operand& op) { vpmulld(x, x, op); } +void vpmullw(const Xmm& x, const Operand& op) { vpmullw(x, x, op); } +void vpmuludq(const Xmm& x, const Operand& op) { vpmuludq(x, x, op); } +void vpor(const Xmm& x, const Operand& op) { vpor(x, x, op); } +void vpsadbw(const Xmm& x, const Operand& op) { vpsadbw(x, x, op); } +void vpsignb(const Xmm& x, const Operand& op) { vpsignb(x, x, op); } +void vpsignd(const Xmm& x, const Operand& op) { vpsignd(x, x, op); } +void vpsignw(const Xmm& x, const Operand& op) { vpsignw(x, x, op); } +void vpslld(const Xmm& x, const Operand& op) { vpslld(x, x, op); } +void vpslld(const Xmm& x, uint8 imm) { vpslld(x, x, imm); } +void vpslldq(const Xmm& x, uint8 imm) { vpslldq(x, x, imm); } +void vpsllq(const Xmm& x, const Operand& op) { vpsllq(x, x, op); } +void vpsllq(const Xmm& x, uint8 imm) { vpsllq(x, x, imm); } +void vpsllw(const Xmm& x, const Operand& op) { vpsllw(x, x, op); } +void vpsllw(const Xmm& x, uint8 imm) { vpsllw(x, x, imm); } +void vpsrad(const Xmm& x, const Operand& op) { vpsrad(x, x, op); } +void vpsrad(const Xmm& x, uint8 imm) { vpsrad(x, x, imm); } +void vpsraw(const Xmm& x, const Operand& op) { vpsraw(x, x, op); } +void vpsraw(const Xmm& x, uint8 imm) { vpsraw(x, x, imm); } +void vpsrld(const Xmm& x, const Operand& op) { vpsrld(x, x, op); } +void vpsrld(const Xmm& x, uint8 imm) { vpsrld(x, x, imm); } +void vpsrldq(const Xmm& x, uint8 imm) { vpsrldq(x, x, imm); } +void vpsrlq(const Xmm& x, const Operand& op) { vpsrlq(x, x, op); } +void vpsrlq(const Xmm& x, uint8 imm) { vpsrlq(x, x, imm); } +void vpsrlw(const Xmm& x, const Operand& op) { vpsrlw(x, x, op); } +void vpsrlw(const Xmm& x, uint8 imm) { vpsrlw(x, x, imm); } +void vpsubb(const Xmm& x, const Operand& op) { vpsubb(x, x, op); } +void vpsubd(const Xmm& x, const Operand& op) { vpsubd(x, x, op); } +void vpsubq(const Xmm& x, const Operand& op) { vpsubq(x, x, op); } +void vpsubsb(const Xmm& x, const Operand& op) { vpsubsb(x, x, op); } +void vpsubsw(const Xmm& x, const Operand& op) { vpsubsw(x, x, op); } +void vpsubusb(const Xmm& x, const Operand& op) { vpsubusb(x, x, op); } +void vpsubusw(const Xmm& x, const Operand& op) { vpsubusw(x, x, op); } +void vpsubw(const Xmm& x, const Operand& op) { vpsubw(x, x, op); } +void vpunpckhbw(const Xmm& x, const Operand& op) { vpunpckhbw(x, x, op); } +void vpunpckhdq(const Xmm& x, const Operand& op) { vpunpckhdq(x, x, op); } +void vpunpckhqdq(const Xmm& x, const Operand& op) { vpunpckhqdq(x, x, op); } +void vpunpckhwd(const Xmm& x, const Operand& op) { vpunpckhwd(x, x, op); } +void vpunpcklbw(const Xmm& x, const Operand& op) { vpunpcklbw(x, x, op); } +void vpunpckldq(const Xmm& x, const Operand& op) { vpunpckldq(x, x, op); } +void vpunpcklqdq(const Xmm& x, const Operand& op) { vpunpcklqdq(x, x, op); } +void vpunpcklwd(const Xmm& x, const Operand& op) { vpunpcklwd(x, x, op); } +void vpxor(const Xmm& x, const Operand& op) { vpxor(x, x, op); } +void vrcpss(const Xmm& x, const Operand& op) { vrcpss(x, x, op); } +void vroundsd(const Xmm& x, const Operand& op, uint8 imm) { vroundsd(x, x, op, imm); } +void vroundss(const Xmm& x, const Operand& op, uint8 imm) { vroundss(x, x, op, imm); } +void vrsqrtss(const Xmm& x, const Operand& op) { vrsqrtss(x, x, op); } +void vshufpd(const Xmm& x, const Operand& op, uint8 imm) { vshufpd(x, x, op, imm); } +void vshufps(const Xmm& x, const Operand& op, uint8 imm) { vshufps(x, x, op, imm); } +void vsqrtsd(const Xmm& x, const Operand& op) { vsqrtsd(x, x, op); } +void vsqrtss(const Xmm& x, const Operand& op) { vsqrtss(x, x, op); } +void vunpckhpd(const Xmm& x, const Operand& op) { vunpckhpd(x, x, op); } +void vunpckhps(const Xmm& x, const Operand& op) { vunpckhps(x, x, op); } +void vunpcklpd(const Xmm& x, const Operand& op) { vunpcklpd(x, x, op); } +void vunpcklps(const Xmm& x, const Operand& op) { vunpcklps(x, x, op); } +#endif +#ifdef XBYAK64 +void jecxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jrcxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jrcxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void cdqe() { db(0x48); db(0x98); } +void cqo() { db(0x48); db(0x99); } +void cmpsq() { db(0x48); db(0xA7); } +void movsq() { db(0x48); db(0xA5); } +void scasq() { db(0x48); db(0xAF); } +void stosq() { db(0x48); db(0xAB); } +void cmpxchg16b(const Address& addr) { opModM(addr, Reg64(1), 0x0F, 0xC7); } +void movq(const Reg64& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); } +void movq(const Mmx& mmx, const Reg64& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); } +void movsxd(const Reg64& reg, const Operand& op) { if (!op.isBit(32)) throw Error(ERR_BAD_COMBINATION); opModRM(reg, op, op.isREG(), op.isMEM(), 0x63); } +void pextrq(const Operand& op, const Xmm& xmm, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x16, 0x66, 0, imm, 0x3A); } +void pinsrq(const Xmm& xmm, const Operand& op, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x22, 0x66, 0, imm, 0x3A); } +void vcvtss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_ER_X | T_N8, 0x2D); } +void vcvttss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_SAE_X | T_N8, 0x2C); } +void vcvtsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_ER_X, 0x2D); } +void vcvttsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_SAE_X, 0x2C); } +void vmovq(const Xmm& x, const Reg64& r) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x6E); } +void vmovq(const Reg64& r, const Xmm& x) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x7E); } +#else +void jcxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jcxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void jecxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); } +void aaa() { db(0x37); } +void aad() { db(0xD5); db(0x0A); } +void aam() { db(0xD4); db(0x0A); } +void aas() { db(0x3F); } +void daa() { db(0x27); } +void das() { db(0x2F); } +void popad() { db(0x61); } +void popfd() { db(0x9D); } +void pusha() { db(0x60); } +void pushad() { db(0x60); } +void pushfd() { db(0x9C); } +void popa() { db(0x61); } +#endif +#ifndef XBYAK_NO_OP_NAMES +void and(const Operand& op1, const Operand& op2) { and_(op1, op2); } +void and(const Operand& op, uint32 imm) { and_(op, imm); } +void or(const Operand& op1, const Operand& op2) { or_(op1, op2); } +void or(const Operand& op, uint32 imm) { or_(op, imm); } +void xor(const Operand& op1, const Operand& op2) { xor_(op1, op2); } +void xor(const Operand& op, uint32 imm) { xor_(op, imm); } +void not(const Operand& op) { not_(op); } +#endif +#ifndef XBYAK_DISABLE_AVX512 +void kaddb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4A); } +void kaddd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x4A); } +void kaddq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4A); } +void kaddw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4A); } +void kandb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x41); } +void kandd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x41); } +void kandnb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x42); } +void kandnd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x42); } +void kandnq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x42); } +void kandnw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x42); } +void kandq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x41); } +void kandw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x41); } +void kmovb(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W0, 0x91); } +void kmovb(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W0, 0x90); } +void kmovb(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_66 | T_W0, 0x92); } +void kmovb(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_66 | T_W0, 0x93); } +void kmovd(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W1, 0x91); } +void kmovd(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W1, 0x90); } +void kmovd(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W0, 0x92); } +void kmovd(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W0, 0x93); } +void kmovq(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W1, 0x91); } +void kmovq(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W1, 0x90); } +void kmovw(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W0, 0x91); } +void kmovw(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W0, 0x90); } +void kmovw(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_W0, 0x92); } +void kmovw(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_W0, 0x93); } +void knotb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x44); } +void knotd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x44); } +void knotq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x44); } +void knotw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x44); } +void korb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x45); } +void kord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x45); } +void korq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x45); } +void kortestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x98); } +void kortestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x98); } +void kortestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x98); } +void kortestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x98); } +void korw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x45); } +void kshiftlb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x32, imm); } +void kshiftld(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x33, imm); } +void kshiftlq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x33, imm); } +void kshiftlw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x32, imm); } +void kshiftrb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x30, imm); } +void kshiftrd(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x31, imm); } +void kshiftrq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x31, imm); } +void kshiftrw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x30, imm); } +void ktestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x99); } +void ktestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x99); } +void ktestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x99); } +void ktestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x99); } +void kunpckbw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4B); } +void kunpckdq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4B); } +void kunpckwd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4B); } +void kxnorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x46); } +void kxnord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x46); } +void kxnorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x46); } +void kxnorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x46); } +void kxorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x47); } +void kxord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x47); } +void kxorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x47); } +void kxorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x47); } +void v4fmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x9A); } +void v4fmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0x9B); } +void v4fnmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0xAA); } +void v4fnmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0xAB); } +void valignd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x03, imm); } +void valignq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x03, imm); } +void vblendmpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x65); } +void vblendmps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x65); } +void vbroadcastf32x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x19); } +void vbroadcastf32x4(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x1A); } +void vbroadcastf32x8(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x1B); } +void vbroadcastf64x2(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x1A); } +void vbroadcastf64x4(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x1B); } +void vbroadcasti32x2(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x59); } +void vbroadcasti32x4(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x5A); } +void vbroadcasti32x8(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x5B); } +void vbroadcasti64x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x5A); } +void vbroadcasti64x4(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x5B); } +void vcmppd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpps(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpsd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N8 | T_F2 | T_0F | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcmpss(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N4 | T_F3 | T_0F | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); } +void vcompressb(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x63); } +void vcompresspd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8A); } +void vcompressps(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8A); } +void vcompressw(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x63); } +void vcvtpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7B); } +void vcvtpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x79); } +void vcvtpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x79); } +void vcvtps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x7B); } +void vcvtps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x79); } +void vcvtps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x79); } +void vcvtqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0xE6); } +void vcvtqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5B); } +void vcvtsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); } +void vcvtss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); } +void vcvttpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x7A); } +void vcvttpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_SAE_Z, 0x78); } +void vcvttpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x78); } +void vcvttps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x7A); } +void vcvttps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x78); } +void vcvttps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x78); } +void vcvttsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); } +void vcvttss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); } +void vcvtudq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_F3 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0x7A); } +void vcvtudq2ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x7A); } +void vcvtuqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7A); } +void vcvtuqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_F2 | T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x7A); } +void vcvtusi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F2 | T_0F | T_MUST_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } +void vcvtusi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F3 | T_0F | T_MUST_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); } +void vdbpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x42, imm); } +void vexp2pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xC8); } +void vexp2ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xC8); } +void vexpandpd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x88); } +void vexpandps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x88); } +void vextractf32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x19, imm); } +void vextractf32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1B, imm); } +void vextractf64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x19, imm); } +void vextractf64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1B, imm); } +void vextracti32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x39, imm); } +void vextracti32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3B, imm); } +void vextracti64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x39, imm); } +void vextracti64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3B, imm); } +void vfixupimmpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x54, imm); } +void vfixupimmps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x54, imm); } +void vfixupimmsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); } +void vfixupimmss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); } +void vfpclasspd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW1 | T_B64, 0x66, imm); } +void vfpclassps(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B32, 0x66, imm); } +void vfpclasssd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW1 | T_N8, 0x67, imm); } +void vfpclassss(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW0 | T_N4, 0x67, imm); } +void vgatherdpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 1); } +void vgatherdps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 0); } +void vgatherpf0dpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vgatherpf0dps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vgatherpf0qpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf0qps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf1dpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vgatherpf1dps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vgatherpf1qpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherpf1qps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vgatherqpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 0); } +void vgatherqps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 2); } +void vgetexppd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x42); } +void vgetexpps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x42); } +void vgetexpsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x43); } +void vgetexpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x43); } +void vgetmantpd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x26, imm); } +void vgetmantps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x26, imm); } +void vgetmantsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x27, imm); } +void vgetmantss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x27, imm); } +void vinsertf32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x18, imm); } +void vinsertf32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1A, imm); } +void vinsertf64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x18, imm); } +void vinsertf64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1A, imm); } +void vinserti32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x38, imm); } +void vinserti32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3A, imm); } +void vinserti64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x38, imm); } +void vinserti64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3A, imm); } +void vmovdqa32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqa32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqa64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqa64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu16(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu16(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vmovdqu8(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); } +void vmovdqu8(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); } +void vp4dpwssd(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x52); } +void vp4dpwssds(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x53); } +void vpabsq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_MUST_EVEX | T_EW1 | T_B64 | T_YMM, 0x1F); } +void vpandd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDB); } +void vpandnd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDF); } +void vpandnq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDF); } +void vpandq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDB); } +void vpblendmb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x66); } +void vpblendmd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x64); } +void vpblendmq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x64); } +void vpblendmw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x66); } +void vpbroadcastb(const Xmm& x, const Reg8& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7A); } +void vpbroadcastd(const Xmm& x, const Reg32& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7C); } +void vpbroadcastmb2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1, 0x2A); } +void vpbroadcastmw2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0, 0x3A); } +void vpbroadcastw(const Xmm& x, const Reg16& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7B); } +void vpcmpb(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3F, imm); } +void vpcmpd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1F, imm); } +void vpcmpeqb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x74); } +void vpcmpeqd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_B32, 0x76); } +void vpcmpeqq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x29); } +void vpcmpeqw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x75); } +void vpcmpgtb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x64); } +void vpcmpgtd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x66); } +void vpcmpgtq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x37); } +void vpcmpgtw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x65); } +void vpcmpq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1F, imm); } +void vpcmpub(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3E, imm); } +void vpcmpud(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1E, imm); } +void vpcmpuq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1E, imm); } +void vpcmpuw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3E, imm); } +void vpcmpw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3F, imm); } +void vpcompressd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8B); } +void vpcompressq(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8B); } +void vpconflictd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xC4); } +void vpconflictq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xC4); } +void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50); } +void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x51); } +void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x52); } +void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x53); } +void vpermb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8D); } +void vpermi2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x75); } +void vpermi2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x76); } +void vpermi2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x77); } +void vpermi2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x77); } +void vpermi2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x76); } +void vpermi2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x75); } +void vpermt2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7D); } +void vpermt2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7E); } +void vpermt2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7F); } +void vpermt2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7F); } +void vpermt2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7E); } +void vpermt2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7D); } +void vpermw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8D); } +void vpexpandb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); } +void vpexpandd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x89); } +void vpexpandq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x89); } +void vpexpandw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); } +void vpgatherdd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 0); } +void vpgatherdq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 1); } +void vpgatherqd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 2); } +void vpgatherqq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 0); } +void vplzcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x44); } +void vplzcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x44); } +void vpmadd52huq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB5); } +void vpmadd52luq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB4); } +void vpmaxsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3D); } +void vpmaxuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3F); } +void vpminsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x39); } +void vpminuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3B); } +void vpmovb2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x29); } +void vpmovd2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x39); } +void vpmovdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x31, false); } +void vpmovdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x33, true); } +void vpmovm2b(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x28); } +void vpmovm2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x38); } +void vpmovm2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x38); } +void vpmovm2w(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x28); } +void vpmovq2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x39); } +void vpmovqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x32, false); } +void vpmovqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x35, true); } +void vpmovqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x34, false); } +void vpmovsdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x21, false); } +void vpmovsdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x23, true); } +void vpmovsqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x22, false); } +void vpmovsqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x25, true); } +void vpmovsqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x24, false); } +void vpmovswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x20, true); } +void vpmovusdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x11, false); } +void vpmovusdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x13, true); } +void vpmovusqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x12, false); } +void vpmovusqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x15, true); } +void vpmovusqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x14, false); } +void vpmovuswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x10, true); } +void vpmovw2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x29); } +void vpmovwb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x30, true); } +void vpmullq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x40); } +void vpmultishiftqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x83); } +void vpopcntb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); } +void vpopcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x55); } +void vpopcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x55); } +void vpopcntw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); } +void vpord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEB); } +void vporq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEB); } +void vprold(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); } +void vprolq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vprolvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x15); } +void vprolvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x15); } +void vprord(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); } +void vprorq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vprorvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x14); } +void vprorvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x14); } +void vpscatterdd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 0); } +void vpscatterdq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 1); } +void vpscatterqd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 2); } +void vpscatterqq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 0); } +void vpshldd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71, imm); } +void vpshldq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71, imm); } +void vpshldvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71); } +void vpshldvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71); } +void vpshldvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70); } +void vpshldw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70, imm); } +void vpshrdd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73, imm); } +void vpshrdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73, imm); } +void vpshrdvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73); } +void vpshrdvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73); } +void vpshrdvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72); } +void vpshrdw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72, imm); } +void vpshufbitqmb(const Opmask& k, const Xmm& x, const Operand& op) { opVex(k, &x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8F); } +void vpsllvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x12); } +void vpsraq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); } +void vpsraq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX, 0xE2); } +void vpsravq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x46); } +void vpsravw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x11); } +void vpsrlvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x10); } +void vpternlogd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x25, imm); } +void vpternlogq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x25, imm); } +void vptestmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); } +void vptestmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); } +void vptestmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestnmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); } +void vptestnmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); } +void vptestnmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); } +void vptestnmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); } +void vpxord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEF); } +void vpxorq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEF); } +void vrangepd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x50, imm); } +void vrangeps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50, imm); } +void vrangesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x51, imm); } +void vrangess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x51, imm); } +void vrcp14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4C); } +void vrcp14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4C); } +void vrcp14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX, 0x4D); } +void vrcp14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX, 0x4D); } +void vrcp28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCA); } +void vrcp28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCA); } +void vrcp28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCB); } +void vrcp28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCB); } +void vreducepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x56, imm); } +void vreduceps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x56, imm); } +void vreducesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x57, imm); } +void vreducess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x57, imm); } +void vrndscalepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x09, imm); } +void vrndscaleps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x08, imm); } +void vrndscalesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_MUST_EVEX, 0x0B, imm); } +void vrndscaless(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_MUST_EVEX, 0x0A, imm); } +void vrsqrt14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4E); } +void vrsqrt14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4E); } +void vrsqrt14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x4F); } +void vrsqrt14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x4F); } +void vrsqrt28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCC); } +void vrsqrt28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCC); } +void vrsqrt28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCD); } +void vrsqrt28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCD); } +void vscalefpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x2C); } +void vscalefps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x2C); } +void vscalefsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_ER_X | T_MUST_EVEX, 0x2D); } +void vscalefss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_ER_X | T_MUST_EVEX, 0x2D); } +void vscatterdpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 1); } +void vscatterdps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 0); } +void vscatterpf0dpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vscatterpf0dps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vscatterpf0qpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf0qps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf1dpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); } +void vscatterpf1dps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); } +void vscatterpf1qpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterpf1qps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); } +void vscatterqpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 0); } +void vscatterqps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 2); } +void vshuff32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x23, imm); } +void vshuff64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x23, imm); } +void vshufi32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x43, imm); } +void vshufi64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x43, imm); } +#ifdef XBYAK64 +void kmovq(const Opmask& k, const Reg64& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W1, 0x92); } +void kmovq(const Reg64& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W1, 0x93); } +void vpbroadcastq(const Xmm& x, const Reg64& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7C); } +#endif +#endif diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h new file mode 100644 index 0000000000..8ef076e680 --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h @@ -0,0 +1,772 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/******************************************************************************* +* Copyright (c) 2007 MITSUNARI Shigeo +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* Neither the name of the copyright owner nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +*******************************************************************************/ + +#ifndef XBYAK_XBYAK_UTIL_H_ +#define XBYAK_XBYAK_UTIL_H_ + +/** + utility class and functions for Xbyak + Xbyak::util::Clock ; rdtsc timer + Xbyak::util::Cpu ; detect CPU + @note this header is UNDER CONSTRUCTION! +*/ +#include "xbyak.h" + +#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64) + #define XBYAK_INTEL_CPU_SPECIFIC +#endif + +#ifdef XBYAK_INTEL_CPU_SPECIFIC +#ifdef _MSC_VER + #if (_MSC_VER < 1400) && defined(XBYAK32) + static inline __declspec(naked) void __cpuid(int[4], int) + { + __asm { + push ebx + push esi + mov eax, dword ptr [esp + 4 * 2 + 8] // eaxIn + cpuid + mov esi, dword ptr [esp + 4 * 2 + 4] // data + mov dword ptr [esi], eax + mov dword ptr [esi + 4], ebx + mov dword ptr [esi + 8], ecx + mov dword ptr [esi + 12], edx + pop esi + pop ebx + ret + } + } + #else + #include <intrin.h> // for __cpuid + #endif +#else + #ifndef __GNUC_PREREQ + #define __GNUC_PREREQ(major, minor) ((((__GNUC__) << 16) + (__GNUC_MINOR__)) >= (((major) << 16) + (minor))) + #endif + #if __GNUC_PREREQ(4, 3) && !defined(__APPLE__) + #include <cpuid.h> + #else + #if defined(__APPLE__) && defined(XBYAK32) // avoid err : can't find a register in class `BREG' while reloading `asm' + #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn)) + #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn)) + #else + #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn)) + #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn)) + #endif + #endif +#endif +#endif + +namespace Xbyak { namespace util { + +typedef enum { + SmtLevel = 1, + CoreLevel = 2 +} IntelCpuTopologyLevel; + +/** + CPU detection class +*/ +class Cpu { + uint64 type_; + //system topology + bool x2APIC_supported_; + static const size_t maxTopologyLevels = 2; + unsigned int numCores_[maxTopologyLevels]; + + static const unsigned int maxNumberCacheLevels = 10; + unsigned int dataCacheSize_[maxNumberCacheLevels]; + unsigned int coresSharignDataCache_[maxNumberCacheLevels]; + unsigned int dataCacheLevels_; + + unsigned int get32bitAsBE(const char *x) const + { + return x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); + } + unsigned int mask(int n) const + { + return (1U << n) - 1; + } + void setFamily() + { + unsigned int data[4] = {}; + getCpuid(1, data); + stepping = data[0] & mask(4); + model = (data[0] >> 4) & mask(4); + family = (data[0] >> 8) & mask(4); + // type = (data[0] >> 12) & mask(2); + extModel = (data[0] >> 16) & mask(4); + extFamily = (data[0] >> 20) & mask(8); + if (family == 0x0f) { + displayFamily = family + extFamily; + } else { + displayFamily = family; + } + if (family == 6 || family == 0x0f) { + displayModel = (extModel << 4) + model; + } else { + displayModel = model; + } + } + unsigned int extractBit(unsigned int val, unsigned int base, unsigned int end) + { + return (val >> base) & ((1u << (end - base)) - 1); + } + void setNumCores() + { + if ((type_ & tINTEL) == 0) return; + + unsigned int data[4] = {}; + + /* CAUTION: These numbers are configuration as shipped by Intel. */ + getCpuidEx(0x0, 0, data); + if (data[0] >= 0xB) { + /* + if leaf 11 exists(x2APIC is supported), + we use it to get the number of smt cores and cores on socket + + leaf 0xB can be zeroed-out by a hypervisor + */ + x2APIC_supported_ = true; + for (unsigned int i = 0; i < maxTopologyLevels; i++) { + getCpuidEx(0xB, i, data); + IntelCpuTopologyLevel level = (IntelCpuTopologyLevel)extractBit(data[2], 8, 15); + if (level == SmtLevel || level == CoreLevel) { + numCores_[level - 1] = extractBit(data[1], 0, 15); + } + } + } else { + /* + Failed to deremine num of cores without x2APIC support. + TODO: USE initial APIC ID to determine ncores. + */ + numCores_[SmtLevel - 1] = 0; + numCores_[CoreLevel - 1] = 0; + } + + } + void setCacheHierarchy() + { + if ((type_ & tINTEL) == 0) return; + const unsigned int NO_CACHE = 0; + const unsigned int DATA_CACHE = 1; +// const unsigned int INSTRUCTION_CACHE = 2; + const unsigned int UNIFIED_CACHE = 3; + unsigned int smt_width = 0; + unsigned int logical_cores = 0; + unsigned int data[4] = {}; + + if (x2APIC_supported_) { + smt_width = numCores_[0]; + logical_cores = numCores_[1]; + } + + /* + Assumptions: + the first level of data cache is not shared (which is the + case for every existing architecture) and use this to + determine the SMT width for arch not supporting leaf 11. + when leaf 4 reports a number of core less than numCores_ + on socket reported by leaf 11, then it is a correct number + of cores not an upperbound. + */ + for (int i = 0; dataCacheLevels_ < maxNumberCacheLevels; i++) { + getCpuidEx(0x4, i, data); + unsigned int cacheType = extractBit(data[0], 0, 4); + if (cacheType == NO_CACHE) break; + if (cacheType == DATA_CACHE || cacheType == UNIFIED_CACHE) { + unsigned int actual_logical_cores = extractBit(data[0], 14, 25) + 1; + if (logical_cores != 0) { // true only if leaf 0xB is supported and valid + actual_logical_cores = (std::min)(actual_logical_cores, logical_cores); + } + assert(actual_logical_cores != 0); + dataCacheSize_[dataCacheLevels_] = + (extractBit(data[1], 22, 31) + 1) + * (extractBit(data[1], 12, 21) + 1) + * (extractBit(data[1], 0, 11) + 1) + * (data[2] + 1); + if (cacheType == DATA_CACHE && smt_width == 0) smt_width = actual_logical_cores; + assert(smt_width != 0); + // FIXME: check and fix number of cores sharing L3 cache for different configurations + // (HT-, 2 sockets), (HT-, 1 socket), (HT+, 2 sockets), (HT+, 1 socket) + coresSharignDataCache_[dataCacheLevels_] = (std::max)(actual_logical_cores / smt_width, 1u); + dataCacheLevels_++; + } + } + } + +public: + int model; + int family; + int stepping; + int extModel; + int extFamily; + int displayFamily; // family + extFamily + int displayModel; // model + extModel + + unsigned int getNumCores(IntelCpuTopologyLevel level) { + if (level != SmtLevel && level != CoreLevel) throw Error(ERR_BAD_PARAMETER); + if (!x2APIC_supported_) throw Error(ERR_X2APIC_IS_NOT_SUPPORTED); + return (level == CoreLevel) + ? numCores_[level - 1] / numCores_[SmtLevel - 1] + : numCores_[level - 1]; + } + + unsigned int getDataCacheLevels() const { return dataCacheLevels_; } + unsigned int getCoresSharingDataCache(unsigned int i) const + { + if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER); + return coresSharignDataCache_[i]; + } + unsigned int getDataCacheSize(unsigned int i) const + { + if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER); + return dataCacheSize_[i]; + } + + /* + data[] = { eax, ebx, ecx, edx } + */ + static inline void getCpuid(unsigned int eaxIn, unsigned int data[4]) + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + __cpuid(reinterpret_cast<int*>(data), eaxIn); + #else + __cpuid(eaxIn, data[0], data[1], data[2], data[3]); + #endif +#else + (void)eaxIn; + (void)data; +#endif + } + static inline void getCpuidEx(unsigned int eaxIn, unsigned int ecxIn, unsigned int data[4]) + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + __cpuidex(reinterpret_cast<int*>(data), eaxIn, ecxIn); + #else + __cpuid_count(eaxIn, ecxIn, data[0], data[1], data[2], data[3]); + #endif +#else + (void)eaxIn; + (void)ecxIn; + (void)data; +#endif + } + static inline uint64 getXfeature() + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + return _xgetbv(0); + #else + unsigned int eax, edx; + // xgetvb is not support on gcc 4.2 +// __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0)); + __asm__ volatile(".byte 0x0f, 0x01, 0xd0" : "=a"(eax), "=d"(edx) : "c"(0)); + return ((uint64)edx << 32) | eax; + #endif +#else + return 0; +#endif + } + typedef uint64 Type; + + static const Type NONE = 0; + static const Type tMMX = 1 << 0; + static const Type tMMX2 = 1 << 1; + static const Type tCMOV = 1 << 2; + static const Type tSSE = 1 << 3; + static const Type tSSE2 = 1 << 4; + static const Type tSSE3 = 1 << 5; + static const Type tSSSE3 = 1 << 6; + static const Type tSSE41 = 1 << 7; + static const Type tSSE42 = 1 << 8; + static const Type tPOPCNT = 1 << 9; + static const Type tAESNI = 1 << 10; + static const Type tSSE5 = 1 << 11; + static const Type tOSXSAVE = 1 << 12; + static const Type tPCLMULQDQ = 1 << 13; + static const Type tAVX = 1 << 14; + static const Type tFMA = 1 << 15; + + static const Type t3DN = 1 << 16; + static const Type tE3DN = 1 << 17; + static const Type tSSE4a = 1 << 18; + static const Type tRDTSCP = 1 << 19; + static const Type tAVX2 = 1 << 20; + static const Type tBMI1 = 1 << 21; // andn, bextr, blsi, blsmsk, blsr, tzcnt + static const Type tBMI2 = 1 << 22; // bzhi, mulx, pdep, pext, rorx, sarx, shlx, shrx + static const Type tLZCNT = 1 << 23; + + static const Type tINTEL = 1 << 24; + static const Type tAMD = 1 << 25; + + static const Type tENHANCED_REP = 1 << 26; // enhanced rep movsb/stosb + static const Type tRDRAND = 1 << 27; + static const Type tADX = 1 << 28; // adcx, adox + static const Type tRDSEED = 1 << 29; // rdseed + static const Type tSMAP = 1 << 30; // stac + static const Type tHLE = uint64(1) << 31; // xacquire, xrelease, xtest + static const Type tRTM = uint64(1) << 32; // xbegin, xend, xabort + static const Type tF16C = uint64(1) << 33; // vcvtph2ps, vcvtps2ph + static const Type tMOVBE = uint64(1) << 34; // mobve + static const Type tAVX512F = uint64(1) << 35; + static const Type tAVX512DQ = uint64(1) << 36; + static const Type tAVX512_IFMA = uint64(1) << 37; + static const Type tAVX512IFMA = tAVX512_IFMA; + static const Type tAVX512PF = uint64(1) << 38; + static const Type tAVX512ER = uint64(1) << 39; + static const Type tAVX512CD = uint64(1) << 40; + static const Type tAVX512BW = uint64(1) << 41; + static const Type tAVX512VL = uint64(1) << 42; + static const Type tAVX512_VBMI = uint64(1) << 43; + static const Type tAVX512VBMI = tAVX512_VBMI; // changed by Intel's manual + static const Type tAVX512_4VNNIW = uint64(1) << 44; + static const Type tAVX512_4FMAPS = uint64(1) << 45; + static const Type tPREFETCHWT1 = uint64(1) << 46; + static const Type tPREFETCHW = uint64(1) << 47; + static const Type tSHA = uint64(1) << 48; + static const Type tMPX = uint64(1) << 49; + static const Type tAVX512_VBMI2 = uint64(1) << 50; + static const Type tGFNI = uint64(1) << 51; + static const Type tVAES = uint64(1) << 52; + static const Type tVPCLMULQDQ = uint64(1) << 53; + static const Type tAVX512_VNNI = uint64(1) << 54; + static const Type tAVX512_BITALG = uint64(1) << 55; + static const Type tAVX512_VPOPCNTDQ = uint64(1) << 56; + + Cpu() + : type_(NONE) + , x2APIC_supported_(false) + , numCores_() + , dataCacheSize_() + , coresSharignDataCache_() + , dataCacheLevels_(0) + { + unsigned int data[4] = {}; + const unsigned int& EAX = data[0]; + const unsigned int& EBX = data[1]; + const unsigned int& ECX = data[2]; + const unsigned int& EDX = data[3]; + getCpuid(0, data); + const unsigned int maxNum = EAX; + static const char intel[] = "ntel"; + static const char amd[] = "cAMD"; + if (ECX == get32bitAsBE(amd)) { + type_ |= tAMD; + getCpuid(0x80000001, data); + if (EDX & (1U << 31)) type_ |= t3DN; + if (EDX & (1U << 15)) type_ |= tCMOV; + if (EDX & (1U << 30)) type_ |= tE3DN; + if (EDX & (1U << 22)) type_ |= tMMX2; + if (EDX & (1U << 27)) type_ |= tRDTSCP; + } + if (ECX == get32bitAsBE(intel)) { + type_ |= tINTEL; + getCpuid(0x80000001, data); + if (EDX & (1U << 27)) type_ |= tRDTSCP; + if (ECX & (1U << 5)) type_ |= tLZCNT; + if (ECX & (1U << 8)) type_ |= tPREFETCHW; + } + getCpuid(1, data); + if (ECX & (1U << 0)) type_ |= tSSE3; + if (ECX & (1U << 9)) type_ |= tSSSE3; + if (ECX & (1U << 19)) type_ |= tSSE41; + if (ECX & (1U << 20)) type_ |= tSSE42; + if (ECX & (1U << 22)) type_ |= tMOVBE; + if (ECX & (1U << 23)) type_ |= tPOPCNT; + if (ECX & (1U << 25)) type_ |= tAESNI; + if (ECX & (1U << 1)) type_ |= tPCLMULQDQ; + if (ECX & (1U << 27)) type_ |= tOSXSAVE; + if (ECX & (1U << 30)) type_ |= tRDRAND; + if (ECX & (1U << 29)) type_ |= tF16C; + + if (EDX & (1U << 15)) type_ |= tCMOV; + if (EDX & (1U << 23)) type_ |= tMMX; + if (EDX & (1U << 25)) type_ |= tMMX2 | tSSE; + if (EDX & (1U << 26)) type_ |= tSSE2; + + if (type_ & tOSXSAVE) { + // check XFEATURE_ENABLED_MASK[2:1] = '11b' + uint64 bv = getXfeature(); + if ((bv & 6) == 6) { + if (ECX & (1U << 28)) type_ |= tAVX; + if (ECX & (1U << 12)) type_ |= tFMA; + if (((bv >> 5) & 7) == 7) { + getCpuidEx(7, 0, data); + if (EBX & (1U << 16)) type_ |= tAVX512F; + if (type_ & tAVX512F) { + if (EBX & (1U << 17)) type_ |= tAVX512DQ; + if (EBX & (1U << 21)) type_ |= tAVX512_IFMA; + if (EBX & (1U << 26)) type_ |= tAVX512PF; + if (EBX & (1U << 27)) type_ |= tAVX512ER; + if (EBX & (1U << 28)) type_ |= tAVX512CD; + if (EBX & (1U << 30)) type_ |= tAVX512BW; + if (EBX & (1U << 31)) type_ |= tAVX512VL; + if (ECX & (1U << 1)) type_ |= tAVX512_VBMI; + if (ECX & (1U << 6)) type_ |= tAVX512_VBMI2; + if (ECX & (1U << 8)) type_ |= tGFNI; + if (ECX & (1U << 9)) type_ |= tVAES; + if (ECX & (1U << 10)) type_ |= tVPCLMULQDQ; + if (ECX & (1U << 11)) type_ |= tAVX512_VNNI; + if (ECX & (1U << 12)) type_ |= tAVX512_BITALG; + if (ECX & (1U << 14)) type_ |= tAVX512_VPOPCNTDQ; + if (EDX & (1U << 2)) type_ |= tAVX512_4VNNIW; + if (EDX & (1U << 3)) type_ |= tAVX512_4FMAPS; + } + } + } + } + if (maxNum >= 7) { + getCpuidEx(7, 0, data); + if (type_ & tAVX && (EBX & (1U << 5))) type_ |= tAVX2; + if (EBX & (1U << 3)) type_ |= tBMI1; + if (EBX & (1U << 8)) type_ |= tBMI2; + if (EBX & (1U << 9)) type_ |= tENHANCED_REP; + if (EBX & (1U << 18)) type_ |= tRDSEED; + if (EBX & (1U << 19)) type_ |= tADX; + if (EBX & (1U << 20)) type_ |= tSMAP; + if (EBX & (1U << 4)) type_ |= tHLE; + if (EBX & (1U << 11)) type_ |= tRTM; + if (EBX & (1U << 14)) type_ |= tMPX; + if (EBX & (1U << 29)) type_ |= tSHA; + if (ECX & (1U << 0)) type_ |= tPREFETCHWT1; + } + setFamily(); + setNumCores(); + setCacheHierarchy(); + } + void putFamily() const + { + printf("family=%d, model=%X, stepping=%d, extFamily=%d, extModel=%X\n", + family, model, stepping, extFamily, extModel); + printf("display:family=%X, model=%X\n", displayFamily, displayModel); + } + bool has(Type type) const + { + return (type & type_) != 0; + } +}; + +class Clock { +public: + static inline uint64 getRdtsc() + { +#ifdef XBYAK_INTEL_CPU_SPECIFIC + #ifdef _MSC_VER + return __rdtsc(); + #else + unsigned int eax, edx; + __asm__ volatile("rdtsc" : "=a"(eax), "=d"(edx)); + return ((uint64)edx << 32) | eax; + #endif +#else + // TODO: Need another impl of Clock or rdtsc-equivalent for non-x86 cpu + return 0; +#endif + } + Clock() + : clock_(0) + , count_(0) + { + } + void begin() + { + clock_ -= getRdtsc(); + } + void end() + { + clock_ += getRdtsc(); + count_++; + } + int getCount() const { return count_; } + uint64 getClock() const { return clock_; } + void clear() { count_ = 0; clock_ = 0; } +private: + uint64 clock_; + int count_; +}; + +#ifdef XBYAK64 +const int UseRCX = 1 << 6; +const int UseRDX = 1 << 7; + +class Pack { + static const size_t maxTblNum = 15; + const Xbyak::Reg64 *tbl_[maxTblNum]; + size_t n_; +public: + Pack() : tbl_(), n_(0) {} + Pack(const Xbyak::Reg64 *tbl, size_t n) { init(tbl, n); } + Pack(const Pack& rhs) + : n_(rhs.n_) + { + for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i]; + } + Pack& operator=(const Pack& rhs) + { + n_ = rhs.n_; + for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i]; + return *this; + } + Pack(const Xbyak::Reg64& t0) + { n_ = 1; tbl_[0] = &t0; } + Pack(const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 2; tbl_[0] = &t0; tbl_[1] = &t1; } + Pack(const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 3; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; } + Pack(const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 4; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; } + Pack(const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 5; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; } + Pack(const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 6; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; } + Pack(const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 7; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; } + Pack(const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 8; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; } + Pack(const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 9; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; } + Pack(const Xbyak::Reg64& t9, const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0) + { n_ = 10; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; tbl_[9] = &t9; } + Pack& append(const Xbyak::Reg64& t) + { + if (n_ == maxTblNum) { + fprintf(stderr, "ERR Pack::can't append\n"); + throw Error(ERR_BAD_PARAMETER); + } + tbl_[n_++] = &t; + return *this; + } + void init(const Xbyak::Reg64 *tbl, size_t n) + { + if (n > maxTblNum) { + fprintf(stderr, "ERR Pack::init bad n=%d\n", (int)n); + throw Error(ERR_BAD_PARAMETER); + } + n_ = n; + for (size_t i = 0; i < n; i++) { + tbl_[i] = &tbl[i]; + } + } + const Xbyak::Reg64& operator[](size_t n) const + { + if (n >= n_) { + fprintf(stderr, "ERR Pack bad n=%d(%d)\n", (int)n, (int)n_); + throw Error(ERR_BAD_PARAMETER); + } + return *tbl_[n]; + } + size_t size() const { return n_; } + /* + get tbl[pos, pos + num) + */ + Pack sub(size_t pos, size_t num = size_t(-1)) const + { + if (num == size_t(-1)) num = n_ - pos; + if (pos + num > n_) { + fprintf(stderr, "ERR Pack::sub bad pos=%d, num=%d\n", (int)pos, (int)num); + throw Error(ERR_BAD_PARAMETER); + } + Pack pack; + pack.n_ = num; + for (size_t i = 0; i < num; i++) { + pack.tbl_[i] = tbl_[pos + i]; + } + return pack; + } + void put() const + { + for (size_t i = 0; i < n_; i++) { + printf("%s ", tbl_[i]->toString()); + } + printf("\n"); + } +}; + +class StackFrame { +#ifdef XBYAK64_WIN + static const int noSaveNum = 6; + static const int rcxPos = 0; + static const int rdxPos = 1; +#else + static const int noSaveNum = 8; + static const int rcxPos = 3; + static const int rdxPos = 2; +#endif + static const int maxRegNum = 14; // maxRegNum = 16 - rsp - rax + Xbyak::CodeGenerator *code_; + int pNum_; + int tNum_; + bool useRcx_; + bool useRdx_; + int saveNum_; + int P_; + bool makeEpilog_; + Xbyak::Reg64 pTbl_[4]; + Xbyak::Reg64 tTbl_[maxRegNum]; + Pack p_; + Pack t_; + StackFrame(const StackFrame&); + void operator=(const StackFrame&); +public: + const Pack& p; + const Pack& t; + /* + make stack frame + @param sf [in] this + @param pNum [in] num of function parameter(0 <= pNum <= 4) + @param tNum [in] num of temporary register(0 <= tNum, with UseRCX, UseRDX) #{pNum + tNum [+rcx] + [rdx]} <= 14 + @param stackSizeByte [in] local stack size + @param makeEpilog [in] automatically call close() if true + + you can use + rax + gp0, ..., gp(pNum - 1) + gt0, ..., gt(tNum-1) + rcx if tNum & UseRCX + rdx if tNum & UseRDX + rsp[0..stackSizeByte - 1] + */ + StackFrame(Xbyak::CodeGenerator *code, int pNum, int tNum = 0, int stackSizeByte = 0, bool makeEpilog = true) + : code_(code) + , pNum_(pNum) + , tNum_(tNum & ~(UseRCX | UseRDX)) + , useRcx_((tNum & UseRCX) != 0) + , useRdx_((tNum & UseRDX) != 0) + , saveNum_(0) + , P_(0) + , makeEpilog_(makeEpilog) + , p(p_) + , t(t_) + { + using namespace Xbyak; + if (pNum < 0 || pNum > 4) throw Error(ERR_BAD_PNUM); + const int allRegNum = pNum + tNum_ + (useRcx_ ? 1 : 0) + (useRdx_ ? 1 : 0); + if (tNum_ < 0 || allRegNum > maxRegNum) throw Error(ERR_BAD_TNUM); + const Reg64& _rsp = code->rsp; + saveNum_ = (std::max)(0, allRegNum - noSaveNum); + const int *tbl = getOrderTbl() + noSaveNum; + for (int i = 0; i < saveNum_; i++) { + code->push(Reg64(tbl[i])); + } + P_ = (stackSizeByte + 7) / 8; + if (P_ > 0 && (P_ & 1) == (saveNum_ & 1)) P_++; // (rsp % 16) == 8, then increment P_ for 16 byte alignment + P_ *= 8; + if (P_ > 0) code->sub(_rsp, P_); + int pos = 0; + for (int i = 0; i < pNum; i++) { + pTbl_[i] = Xbyak::Reg64(getRegIdx(pos)); + } + for (int i = 0; i < tNum_; i++) { + tTbl_[i] = Xbyak::Reg64(getRegIdx(pos)); + } + if (useRcx_ && rcxPos < pNum) code_->mov(code_->r10, code_->rcx); + if (useRdx_ && rdxPos < pNum) code_->mov(code_->r11, code_->rdx); + p_.init(pTbl_, pNum); + t_.init(tTbl_, tNum_); + } + /* + make epilog manually + @param callRet [in] call ret() if true + */ + void close(bool callRet = true) + { + using namespace Xbyak; + const Reg64& _rsp = code_->rsp; + const int *tbl = getOrderTbl() + noSaveNum; + if (P_ > 0) code_->add(_rsp, P_); + for (int i = 0; i < saveNum_; i++) { + code_->pop(Reg64(tbl[saveNum_ - 1 - i])); + } + + if (callRet) code_->ret(); + } + ~StackFrame() + { + if (!makeEpilog_) return; + try { + close(); + } catch (std::exception& e) { + printf("ERR:StackFrame %s\n", e.what()); + //exit(1); + } + } +private: + const int *getOrderTbl() const + { + using namespace Xbyak; + static const int tbl[] = { +#ifdef XBYAK64_WIN + Operand::RCX, Operand::RDX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, Operand::RDI, Operand::RSI, +#else + Operand::RDI, Operand::RSI, Operand::RDX, Operand::RCX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, +#endif + Operand::RBX, Operand::RBP, Operand::R12, Operand::R13, Operand::R14, Operand::R15 + }; + return &tbl[0]; + } + int getRegIdx(int& pos) const + { + assert(pos < maxRegNum); + using namespace Xbyak; + const int *tbl = getOrderTbl(); + int r = tbl[pos++]; + if (useRcx_) { + if (r == Operand::RCX) { return Operand::R10; } + if (r == Operand::R10) { r = tbl[pos++]; } + } + if (useRdx_) { + if (r == Operand::RDX) { return Operand::R11; } + if (r == Operand::R11) { return tbl[pos++]; } + } + return r; + } +}; +#endif + +} } // end of util +#endif diff --git a/thirdparty/oidn/weights/rtlightmap_hdr.tza b/thirdparty/oidn/weights/rtlightmap_hdr.tza Binary files differnew file mode 100644 index 0000000000..12459a33bc --- /dev/null +++ b/thirdparty/oidn/weights/rtlightmap_hdr.tza diff --git a/thirdparty/xatlas/xatlas.cpp b/thirdparty/xatlas/xatlas.cpp index 80cacf9746..b1cbeb980f 100644 --- a/thirdparty/xatlas/xatlas.cpp +++ b/thirdparty/xatlas/xatlas.cpp @@ -33,19 +33,19 @@ https://github.com/brandonpelfrey/Fast-BVH MIT License Copyright (c) 2012 Brandon Pelfrey */ -#include <atomic> -#include <condition_variable> -#include <mutex> -#include <thread> #include <assert.h> #include <float.h> // FLT_MAX #include <limits.h> #include <math.h> +#include <atomic> +#include <condition_variable> +#include <mutex> +#include <thread> #define __STDC_LIMIT_MACROS +#include "xatlas.h" #include <stdint.h> #include <stdio.h> #include <string.h> -#include "xatlas.h" #ifndef XA_DEBUG #ifdef NDEBUG @@ -70,7 +70,10 @@ Copyright (c) 2012 Brandon Pelfrey #define XA_XSTR(x) XA_STR(x) #ifndef XA_ASSERT -#define XA_ASSERT(exp) if (!(exp)) { XA_PRINT_WARNING("\rASSERT: %s %s %d\n", XA_XSTR(exp), __FILE__, __LINE__); } +#define XA_ASSERT(exp) \ + if (!(exp)) { \ + XA_PRINT_WARNING("\rASSERT: %s %s %d\n", XA_XSTR(exp), __FILE__, __LINE__); \ + } #endif #ifndef XA_DEBUG_ASSERT @@ -78,13 +81,13 @@ Copyright (c) 2012 Brandon Pelfrey #endif #ifndef XA_PRINT -#define XA_PRINT(...) \ +#define XA_PRINT(...) \ if (xatlas::internal::s_print && xatlas::internal::s_printVerbose) \ xatlas::internal::s_print(__VA_ARGS__); #endif #ifndef XA_PRINT_WARNING -#define XA_PRINT_WARNING(...) \ +#define XA_PRINT_WARNING(...) \ if (xatlas::internal::s_print) \ xatlas::internal::s_print(__VA_ARGS__); #endif @@ -136,19 +139,13 @@ Copyright (c) 2012 Brandon Pelfrey #define XA_DEBUG_EXPORT_OBJ_INVALID_PARAMETERIZATION 0 #define XA_DEBUG_EXPORT_OBJ_RECOMPUTED_CHARTS 0 -#define XA_DEBUG_EXPORT_OBJ (0 \ - || XA_DEBUG_EXPORT_OBJ_SOURCE_MESHES \ - || XA_DEBUG_EXPORT_OBJ_CHART_GROUPS \ - || XA_DEBUG_EXPORT_OBJ_PLANAR_REGIONS \ - || XA_DEBUG_EXPORT_OBJ_CHARTS \ - || XA_DEBUG_EXPORT_OBJ_BEFORE_FIX_TJUNCTION \ - || XA_DEBUG_EXPORT_OBJ_CLOSE_HOLES_ERROR \ - || XA_DEBUG_EXPORT_OBJ_CHARTS_AFTER_PARAMETERIZATION \ - || XA_DEBUG_EXPORT_OBJ_INVALID_PARAMETERIZATION \ - || XA_DEBUG_EXPORT_OBJ_RECOMPUTED_CHARTS) +#define XA_DEBUG_EXPORT_OBJ (0 || XA_DEBUG_EXPORT_OBJ_SOURCE_MESHES || XA_DEBUG_EXPORT_OBJ_CHART_GROUPS || XA_DEBUG_EXPORT_OBJ_PLANAR_REGIONS || XA_DEBUG_EXPORT_OBJ_CHARTS || XA_DEBUG_EXPORT_OBJ_BEFORE_FIX_TJUNCTION || XA_DEBUG_EXPORT_OBJ_CLOSE_HOLES_ERROR || XA_DEBUG_EXPORT_OBJ_CHARTS_AFTER_PARAMETERIZATION || XA_DEBUG_EXPORT_OBJ_INVALID_PARAMETERIZATION || XA_DEBUG_EXPORT_OBJ_RECOMPUTED_CHARTS) #ifdef _MSC_VER -#define XA_FOPEN(_file, _filename, _mode) { if (fopen_s(&_file, _filename, _mode) != 0) _file = NULL; } +#define XA_FOPEN(_file, _filename, _mode) \ + { \ + if (fopen_s(&_file, _filename, _mode) != 0) _file = NULL; \ + } #define XA_SPRINTF(_buffer, _size, _format, ...) sprintf_s(_buffer, _size, _format, __VA_ARGS__) #else #define XA_FOPEN(_file, _filename, _mode) _file = fopen(_filename, _mode) @@ -163,10 +160,8 @@ static FreeFunc s_free = free; static PrintFunc s_print = printf; static bool s_printVerbose = false; -struct MemTag -{ - enum - { +struct MemTag { + enum { Default, BitImage, BVH, @@ -189,8 +184,7 @@ struct MemTag }; #if XA_DEBUG_HEAP -struct AllocHeader -{ +struct AllocHeader { size_t size; const char *file; int line; @@ -203,11 +197,10 @@ struct AllocHeader static std::mutex s_allocMutex; static AllocHeader *s_allocRoot = nullptr; static size_t s_allocTotalCount = 0, s_allocTotalSize = 0, s_allocPeakSize = 0, s_allocCount[MemTag::Count] = { 0 }, s_allocTotalTagSize[MemTag::Count] = { 0 }, s_allocPeakTagSize[MemTag::Count] = { 0 }; -static uint32_t s_allocId =0 ; +static uint32_t s_allocId = 0; static constexpr uint32_t kAllocRedzone = 0x12345678; -static void *Realloc(void *ptr, size_t size, int tag, const char *file, int line) -{ +static void *Realloc(void *ptr, size_t size, int tag, const char *file, int line) { std::unique_lock<std::mutex> lock(s_allocMutex); if (!size && !ptr) return nullptr; @@ -268,8 +261,7 @@ static void *Realloc(void *ptr, size_t size, int tag, const char *file, int line return newPtr + sizeof(AllocHeader); } -static void ReportLeaks() -{ +static void ReportLeaks() { printf("Checking for memory leaks...\n"); bool anyLeaks = false; AllocHeader *header = s_allocRoot; @@ -297,8 +289,7 @@ static void ReportLeaks() s_allocTotalTagSize[i] = s_allocPeakTagSize[i] = 0; } -static void PrintMemoryUsage() -{ +static void PrintMemoryUsage() { XA_PRINT("Total allocations: %zu\n", s_allocTotalCount); XA_PRINT("Memory usage: %0.2fMB current, %0.2fMB peak\n", internal::s_allocTotalSize / 1024.0f / 1024.0f, internal::s_allocPeakSize / 1024.0f / 1024.0f); static const char *labels[] = { // Sync with MemTag @@ -327,8 +318,7 @@ static void PrintMemoryUsage() #define XA_PRINT_MEM_USAGE internal::PrintMemoryUsage(); #else -static void *Realloc(void *ptr, size_t size, int /*tag*/, const char * /*file*/, int /*line*/) -{ +static void *Realloc(void *ptr, size_t size, int /*tag*/, const char * /*file*/, int /*line*/) { if (size == 0 && !ptr) return nullptr; if (size == 0 && s_free) { @@ -347,10 +337,11 @@ static void *Realloc(void *ptr, size_t size, int /*tag*/, const char * /*file*/, #if XA_PROFILE #define XA_PROFILE_START(var) const clock_t var##Start = clock(); #define XA_PROFILE_END(var) internal::s_profile.var += clock() - var##Start; -#define XA_PROFILE_PRINT_AND_RESET(label, var) XA_PRINT("%s%.2f seconds (%g ms)\n", label, internal::clockToSeconds(internal::s_profile.var), internal::clockToMs(internal::s_profile.var)); internal::s_profile.var = 0; +#define XA_PROFILE_PRINT_AND_RESET(label, var) \ + XA_PRINT("%s%.2f seconds (%g ms)\n", label, internal::clockToSeconds(internal::s_profile.var), internal::clockToMs(internal::s_profile.var)); \ + internal::s_profile.var = 0; -struct ProfileData -{ +struct ProfileData { clock_t addMeshReal; clock_t addMeshCopyData; std::atomic<clock_t> addMeshThread; @@ -390,13 +381,11 @@ struct ProfileData static ProfileData s_profile; -static double clockToMs(clock_t c) -{ +static double clockToMs(clock_t c) { return c * 1000.0 / CLOCKS_PER_SEC; } -static double clockToSeconds(clock_t c) -{ +static double clockToSeconds(clock_t c) { return c / (double)CLOCKS_PER_SEC; } #else @@ -412,89 +401,75 @@ static constexpr float kEpsilon = 0.0001f; static constexpr float kAreaEpsilon = FLT_EPSILON; static constexpr float kNormalEpsilon = 0.001f; -static int align(int x, int a) -{ +static int align(int x, int a) { return (x + a - 1) & ~(a - 1); } template <typename T> -static T max(const T &a, const T &b) -{ +static T max(const T &a, const T &b) { return a > b ? a : b; } template <typename T> -static T min(const T &a, const T &b) -{ +static T min(const T &a, const T &b) { return a < b ? a : b; } template <typename T> -static T max3(const T &a, const T &b, const T &c) -{ +static T max3(const T &a, const T &b, const T &c) { return max(a, max(b, c)); } /// Return the maximum of the three arguments. template <typename T> -static T min3(const T &a, const T &b, const T &c) -{ +static T min3(const T &a, const T &b, const T &c) { return min(a, min(b, c)); } /// Clamp between two values. template <typename T> -static T clamp(const T &x, const T &a, const T &b) -{ +static T clamp(const T &x, const T &a, const T &b) { return min(max(x, a), b); } template <typename T> -static void swap(T &a, T &b) -{ +static void swap(T &a, T &b) { T temp = a; a = b; b = temp; } -union FloatUint32 -{ +union FloatUint32 { float f; uint32_t u; }; -static bool isFinite(float f) -{ +static bool isFinite(float f) { FloatUint32 fu; fu.f = f; return fu.u != 0x7F800000u && fu.u != 0x7F800001u; } -static bool isNan(float f) -{ +static bool isNan(float f) { return f != f; } // Robust floating point comparisons: // http://realtimecollisiondetection.net/blog/?p=89 -static bool equal(const float f0, const float f1, const float epsilon) -{ +static bool equal(const float f0, const float f1, const float epsilon) { //return fabs(f0-f1) <= epsilon; return fabs(f0 - f1) <= epsilon * max3(1.0f, fabsf(f0), fabsf(f1)); } -static int ftoi_ceil(float val) -{ +static int ftoi_ceil(float val) { return (int)ceilf(val); } -static bool isZero(const float f, const float epsilon) -{ +static bool isZero(const float f, const float epsilon) { return fabs(f) <= epsilon; } -static float square(float f) -{ +static float square(float f) { return f * f; } @@ -504,9 +479,8 @@ static float square(float f) * @note isPowerOfTwo(x) == true -> nextPowerOfTwo(x) == x * @note nextPowerOfTwo(x) = 2 << log2(x-1) */ -static uint32_t nextPowerOfTwo(uint32_t x) -{ - XA_DEBUG_ASSERT( x != 0 ); +static uint32_t nextPowerOfTwo(uint32_t x) { + XA_DEBUG_ASSERT(x != 0); // On modern CPUs this is supposed to be as fast as using the bsr instruction. x--; x |= x >> 1; @@ -517,65 +491,59 @@ static uint32_t nextPowerOfTwo(uint32_t x) return x + 1; } -static uint32_t sdbmHash(const void *data_in, uint32_t size, uint32_t h = 5381) -{ - const uint8_t *data = (const uint8_t *) data_in; +static uint32_t sdbmHash(const void *data_in, uint32_t size, uint32_t h = 5381) { + const uint8_t *data = (const uint8_t *)data_in; uint32_t i = 0; while (i < size) { - h = (h << 16) + (h << 6) - h + (uint32_t ) data[i++]; + h = (h << 16) + (h << 6) - h + (uint32_t)data[i++]; } return h; } template <typename T> -static uint32_t hash(const T &t, uint32_t h = 5381) -{ +static uint32_t hash(const T &t, uint32_t h = 5381) { return sdbmHash(&t, sizeof(T), h); } // Functors for hash table: -template <typename Key> struct Hash -{ +template <typename Key> +struct Hash { uint32_t operator()(const Key &k) const { return hash(k); } }; -template <typename Key> struct Equal -{ +template <typename Key> +struct Equal { bool operator()(const Key &k0, const Key &k1) const { return k0 == k1; } }; -class Vector2 -{ +class Vector2 { public: Vector2() {} - explicit Vector2(float f) : x(f), y(f) {} - Vector2(float x, float y): x(x), y(y) {} + explicit Vector2(float f) : + x(f), y(f) {} + Vector2(float x, float y) : + x(x), y(y) {} - Vector2 operator-() const - { + Vector2 operator-() const { return Vector2(-x, -y); } - void operator+=(const Vector2 &v) - { + void operator+=(const Vector2 &v) { x += v.x; y += v.y; } - void operator-=(const Vector2 &v) - { + void operator-=(const Vector2 &v) { x -= v.x; y -= v.y; } - void operator*=(float s) - { + void operator*=(float s) { x *= s; y *= s; } - void operator*=(const Vector2 &v) - { + void operator*=(const Vector2 &v) { x *= v.x; y *= v.y; } @@ -583,13 +551,11 @@ public: float x, y; }; -static bool operator==(const Vector2 &a, const Vector2 &b) -{ +static bool operator==(const Vector2 &a, const Vector2 &b) { return a.x == b.x && a.y == b.y; } -static bool operator!=(const Vector2 &a, const Vector2 &b) -{ +static bool operator!=(const Vector2 &a, const Vector2 &b) { return a.x != b.x || a.y != b.y; } @@ -598,40 +564,33 @@ static bool operator!=(const Vector2 &a, const Vector2 &b) return Vector2(a.x + b.x, a.y + b.y); }*/ -static Vector2 operator-(const Vector2 &a, const Vector2 &b) -{ +static Vector2 operator-(const Vector2 &a, const Vector2 &b) { return Vector2(a.x - b.x, a.y - b.y); } -static Vector2 operator*(const Vector2 &v, float s) -{ +static Vector2 operator*(const Vector2 &v, float s) { return Vector2(v.x * s, v.y * s); } -static float dot(const Vector2 &a, const Vector2 &b) -{ +static float dot(const Vector2 &a, const Vector2 &b) { return a.x * b.x + a.y * b.y; } -static float lengthSquared(const Vector2 &v) -{ +static float lengthSquared(const Vector2 &v) { return v.x * v.x + v.y * v.y; } -static float length(const Vector2 &v) -{ +static float length(const Vector2 &v) { return sqrtf(lengthSquared(v)); } #if XA_DEBUG -static bool isNormalized(const Vector2 &v, float epsilon = kNormalEpsilon) -{ +static bool isNormalized(const Vector2 &v, float epsilon = kNormalEpsilon) { return equal(length(v), 1, epsilon); } #endif -static Vector2 normalize(const Vector2 &v, float epsilon) -{ +static Vector2 normalize(const Vector2 &v, float epsilon) { float l = length(v); XA_DEBUG_ASSERT(!isZero(l, epsilon)); XA_UNUSED(epsilon); @@ -640,36 +599,30 @@ static Vector2 normalize(const Vector2 &v, float epsilon) return n; } -static Vector2 normalizeSafe(const Vector2 &v, const Vector2 &fallback, float epsilon) -{ +static Vector2 normalizeSafe(const Vector2 &v, const Vector2 &fallback, float epsilon) { float l = length(v); if (isZero(l, epsilon)) return fallback; return v * (1.0f / l); } -static bool equal(const Vector2 &v1, const Vector2 &v2, float epsilon) -{ +static bool equal(const Vector2 &v1, const Vector2 &v2, float epsilon) { return equal(v1.x, v2.x, epsilon) && equal(v1.y, v2.y, epsilon); } -static Vector2 min(const Vector2 &a, const Vector2 &b) -{ +static Vector2 min(const Vector2 &a, const Vector2 &b) { return Vector2(min(a.x, b.x), min(a.y, b.y)); } -static Vector2 max(const Vector2 &a, const Vector2 &b) -{ +static Vector2 max(const Vector2 &a, const Vector2 &b) { return Vector2(max(a.x, b.x), max(a.y, b.y)); } -static bool isFinite(const Vector2 &v) -{ +static bool isFinite(const Vector2 &v) { return isFinite(v.x) && isFinite(v.y); } -static float triangleArea(const Vector2 &a, const Vector2 &b, const Vector2 &c) -{ +static float triangleArea(const Vector2 &a, const Vector2 &b, const Vector2 &c) { // IC: While it may be appealing to use the following expression: //return (c.x * a.y + a.x * b.y + b.x * c.y - b.x * a.y - c.x * b.y - a.x * c.y) * 0.5f; // That's actually a terrible idea. Small triangles far from the origin can end up producing fairly large floating point @@ -683,8 +636,7 @@ static float triangleArea(const Vector2 &a, const Vector2 &b, const Vector2 &c) return (v0.x * v1.y - v0.y * v1.x) * 0.5f; } -static bool linesIntersect(const Vector2 &a1, const Vector2 &a2, const Vector2 &b1, const Vector2 &b2, float epsilon) -{ +static bool linesIntersect(const Vector2 &a1, const Vector2 &a2, const Vector2 &b1, const Vector2 &b2, float epsilon) { const Vector2 v0 = a2 - a1; const Vector2 v1 = b2 - b1; const float denom = -v1.x * v0.y + v0.x * v1.y; @@ -692,76 +644,70 @@ static bool linesIntersect(const Vector2 &a1, const Vector2 &a2, const Vector2 & return false; const float s = (-v0.y * (a1.x - b1.x) + v0.x * (a1.y - b1.y)) / denom; if (s > epsilon && s < 1.0f - epsilon) { - const float t = ( v1.x * (a1.y - b1.y) - v1.y * (a1.x - b1.x)) / denom; + const float t = (v1.x * (a1.y - b1.y) - v1.y * (a1.x - b1.x)) / denom; return t > epsilon && t < 1.0f - epsilon; } return false; } -struct Vector2i -{ +struct Vector2i { Vector2i() {} - Vector2i(int32_t x, int32_t y) : x(x), y(y) {} + Vector2i(int32_t x, int32_t y) : + x(x), y(y) {} int32_t x, y; }; -class Vector3 -{ +class Vector3 { public: Vector3() {} - explicit Vector3(float f) : x(f), y(f), z(f) {} - Vector3(float x, float y, float z) : x(x), y(y), z(z) {} - Vector3(const Vector2 &v, float z) : x(v.x), y(v.y), z(z) {} - - Vector2 xy() const - { + explicit Vector3(float f) : + x(f), y(f), z(f) {} + Vector3(float x, float y, float z) : + x(x), y(y), z(z) {} + Vector3(const Vector2 &v, float z) : + x(v.x), y(v.y), z(z) {} + + Vector2 xy() const { return Vector2(x, y); } - Vector3 operator-() const - { + Vector3 operator-() const { return Vector3(-x, -y, -z); } - void operator+=(const Vector3 &v) - { + void operator+=(const Vector3 &v) { x += v.x; y += v.y; z += v.z; } - void operator-=(const Vector3 &v) - { + void operator-=(const Vector3 &v) { x -= v.x; y -= v.y; z -= v.z; } - void operator*=(float s) - { + void operator*=(float s) { x *= s; y *= s; z *= s; } - void operator/=(float s) - { + void operator/=(float s) { float is = 1.0f / s; x *= is; y *= is; z *= is; } - void operator*=(const Vector3 &v) - { + void operator*=(const Vector3 &v) { x *= v.x; y *= v.y; z *= v.z; } - void operator/=(const Vector3 &v) - { + void operator/=(const Vector3 &v) { x /= v.x; y /= v.y; z /= v.z; @@ -770,53 +716,43 @@ public: float x, y, z; }; -static Vector3 operator+(const Vector3 &a, const Vector3 &b) -{ +static Vector3 operator+(const Vector3 &a, const Vector3 &b) { return Vector3(a.x + b.x, a.y + b.y, a.z + b.z); } -static Vector3 operator-(const Vector3 &a, const Vector3 &b) -{ +static Vector3 operator-(const Vector3 &a, const Vector3 &b) { return Vector3(a.x - b.x, a.y - b.y, a.z - b.z); } -static Vector3 cross(const Vector3 &a, const Vector3 &b) -{ +static Vector3 cross(const Vector3 &a, const Vector3 &b) { return Vector3(a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x); } -static Vector3 operator*(const Vector3 &v, float s) -{ +static Vector3 operator*(const Vector3 &v, float s) { return Vector3(v.x * s, v.y * s, v.z * s); } -static Vector3 operator/(const Vector3 &v, float s) -{ +static Vector3 operator/(const Vector3 &v, float s) { return v * (1.0f / s); } -static float dot(const Vector3 &a, const Vector3 &b) -{ +static float dot(const Vector3 &a, const Vector3 &b) { return a.x * b.x + a.y * b.y + a.z * b.z; } -static float lengthSquared(const Vector3 &v) -{ +static float lengthSquared(const Vector3 &v) { return v.x * v.x + v.y * v.y + v.z * v.z; } -static float length(const Vector3 &v) -{ +static float length(const Vector3 &v) { return sqrtf(lengthSquared(v)); } -static bool isNormalized(const Vector3 &v, float epsilon = kNormalEpsilon) -{ +static bool isNormalized(const Vector3 &v, float epsilon = kNormalEpsilon) { return equal(length(v), 1, epsilon); } -static Vector3 normalize(const Vector3 &v, float epsilon) -{ +static Vector3 normalize(const Vector3 &v, float epsilon) { float l = length(v); XA_DEBUG_ASSERT(!isZero(l, epsilon)); XA_UNUSED(epsilon); @@ -825,8 +761,7 @@ static Vector3 normalize(const Vector3 &v, float epsilon) return n; } -static Vector3 normalizeSafe(const Vector3 &v, const Vector3 &fallback, float epsilon) -{ +static Vector3 normalizeSafe(const Vector3 &v, const Vector3 &fallback, float epsilon) { float l = length(v); if (isZero(l, epsilon)) { return fallback; @@ -834,72 +769,59 @@ static Vector3 normalizeSafe(const Vector3 &v, const Vector3 &fallback, float ep return v * (1.0f / l); } -static bool equal(const Vector3 &v0, const Vector3 &v1, float epsilon) -{ +static bool equal(const Vector3 &v0, const Vector3 &v1, float epsilon) { return fabs(v0.x - v1.x) <= epsilon && fabs(v0.y - v1.y) <= epsilon && fabs(v0.z - v1.z) <= epsilon; } -static Vector3 min(const Vector3 &a, const Vector3 &b) -{ +static Vector3 min(const Vector3 &a, const Vector3 &b) { return Vector3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z)); } -static Vector3 max(const Vector3 &a, const Vector3 &b) -{ +static Vector3 max(const Vector3 &a, const Vector3 &b) { return Vector3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z)); } #if XA_DEBUG -bool isFinite(const Vector3 &v) -{ +bool isFinite(const Vector3 &v) { return isFinite(v.x) && isFinite(v.y) && isFinite(v.z); } #endif -struct Extents2 -{ +struct Extents2 { Vector2 min, max; - void reset() - { + void reset() { min.x = min.y = FLT_MAX; max.x = max.y = -FLT_MAX; } - void add(Vector2 p) - { + void add(Vector2 p) { min = xatlas::internal::min(min, p); max = xatlas::internal::max(max, p); } - Vector2 midpoint() const - { + Vector2 midpoint() const { return Vector2(min.x + (max.x - min.x) * 0.5f, min.y + (max.y - min.y) * 0.5f); } - static bool intersect(Extents2 e1, Extents2 e2) - { + static bool intersect(Extents2 e1, Extents2 e2) { return e1.min.x <= e2.max.x && e1.max.x >= e2.min.x && e1.min.y <= e2.max.y && e1.max.y >= e2.min.y; } }; -struct Plane -{ +struct Plane { Plane() = default; - - Plane(const Vector3 &p1, const Vector3 &p2, const Vector3 &p3) - { + + Plane(const Vector3 &p1, const Vector3 &p2, const Vector3 &p3) { normal = cross(p2 - p1, p3 - p1); dist = dot(normal, p1); } - float distance(const Vector3 &p) const - { + float distance(const Vector3 &p) const { return dot(normal, p) - dist; } - void normalize() - { + void normalize() { const float len = length(normal); if (len > 0.0f) { const float il = 1.0f / len; @@ -912,8 +834,7 @@ struct Plane float dist; }; -static bool lineIntersectsPoint(const Vector3 &point, const Vector3 &lineStart, const Vector3 &lineEnd, float *t, float epsilon) -{ +static bool lineIntersectsPoint(const Vector3 &point, const Vector3 &lineStart, const Vector3 &lineEnd, float *t, float epsilon) { float tt; if (!t) t = &tt; @@ -930,22 +851,19 @@ static bool lineIntersectsPoint(const Vector3 &point, const Vector3 &lineStart, return *t > kEpsilon && *t < 1.0f - kEpsilon; } -static bool sameSide(const Vector3 &p1, const Vector3 &p2, const Vector3 &a, const Vector3 &b) -{ +static bool sameSide(const Vector3 &p1, const Vector3 &p2, const Vector3 &a, const Vector3 &b) { const Vector3 &ab = b - a; return dot(cross(ab, p1 - a), cross(ab, p2 - a)) >= 0.0f; } // http://blackpawn.com/texts/pointinpoly/default.html -static bool pointInTriangle(const Vector3 &p, const Vector3 &a, const Vector3 &b, const Vector3 &c) -{ +static bool pointInTriangle(const Vector3 &p, const Vector3 &a, const Vector3 &b, const Vector3 &c) { return sameSide(p, a, b, c) && sameSide(p, b, a, c) && sameSide(p, c, a, b); } #if XA_CLOSE_HOLES_CHECK_EDGE_INTERSECTION // https://en.wikipedia.org/wiki/M%C3%B6ller%E2%80%93Trumbore_intersection_algorithm -static bool rayIntersectsTriangle(const Vector3 &rayOrigin, const Vector3 &rayDir, const Vector3 *tri, float *t) -{ +static bool rayIntersectsTriangle(const Vector3 &rayOrigin, const Vector3 &rayDir, const Vector3 *tri, float *t) { *t = 0.0f; const Vector3 &edge1 = tri[1] - tri[0]; const Vector3 &edge2 = tri[2] - tri[0]; @@ -972,50 +890,47 @@ static bool rayIntersectsTriangle(const Vector3 &rayOrigin, const Vector3 &rayDi #endif // From Fast-BVH -struct AABB -{ - AABB() : min(FLT_MAX, FLT_MAX, FLT_MAX), max(-FLT_MAX, -FLT_MAX, -FLT_MAX) {} - AABB(const Vector3 &min, const Vector3 &max) : min(min), max(max) { } - AABB(const Vector3 &p, float radius = 0.0f) : min(p), max(p) { if (radius > 0.0f) expand(radius); } +struct AABB { + AABB() : + min(FLT_MAX, FLT_MAX, FLT_MAX), max(-FLT_MAX, -FLT_MAX, -FLT_MAX) {} + AABB(const Vector3 &min, const Vector3 &max) : + min(min), max(max) {} + AABB(const Vector3 &p, float radius = 0.0f) : + min(p), max(p) { + if (radius > 0.0f) expand(radius); + } - bool intersect(const AABB &other) const - { + bool intersect(const AABB &other) const { return min.x <= other.max.x && max.x >= other.min.x && min.y <= other.max.y && max.y >= other.min.y && min.z <= other.max.z && max.z >= other.min.z; } - void expandToInclude(const Vector3 &p) - { + void expandToInclude(const Vector3 &p) { min = internal::min(min, p); max = internal::max(max, p); } - void expandToInclude(const AABB &aabb) - { + void expandToInclude(const AABB &aabb) { min = internal::min(min, aabb.min); max = internal::max(max, aabb.max); } - void expand(float amount) - { + void expand(float amount) { min -= Vector3(amount); max += Vector3(amount); } - Vector3 centroid() const - { + Vector3 centroid() const { return min + (max - min) * 0.5f; } - uint32_t maxDimension() const - { + uint32_t maxDimension() const { const Vector3 extent = max - min; uint32_t result = 0; if (extent.y > extent.x) { result = 1; if (extent.z > extent.y) result = 2; - } - else if(extent.z > extent.x) + } else if (extent.z > extent.x) result = 2; return result; } @@ -1023,10 +938,9 @@ struct AABB Vector3 min, max; }; -struct ArrayBase -{ - ArrayBase(uint32_t elementSize, int memTag = MemTag::Default) : buffer(nullptr), elementSize(elementSize), size(0), capacity(0) - { +struct ArrayBase { + ArrayBase(uint32_t elementSize, int memTag = MemTag::Default) : + buffer(nullptr), elementSize(elementSize), size(0), capacity(0) { #if XA_DEBUG_HEAP this->memTag = memTag; #else @@ -1034,31 +948,26 @@ struct ArrayBase #endif } - ~ArrayBase() - { + ~ArrayBase() { XA_FREE(buffer); } - XA_INLINE void clear() - { + XA_INLINE void clear() { size = 0; } - void copyFrom(const uint8_t *data, uint32_t length) - { + void copyFrom(const uint8_t *data, uint32_t length) { resize(length, true); memcpy(buffer, data, length * elementSize); } - void copyTo(ArrayBase &other) const - { + void copyTo(ArrayBase &other) const { XA_DEBUG_ASSERT(elementSize == other.elementSize); other.resize(size, true); memcpy(other.buffer, buffer, size * elementSize); } - void destroy() - { + void destroy() { size = 0; XA_FREE(buffer); buffer = nullptr; @@ -1067,8 +976,7 @@ struct ArrayBase } // Insert the given element at the given index shifting all the elements up. - void insertAt(uint32_t index, const uint8_t *value) - { + void insertAt(uint32_t index, const uint8_t *value) { XA_DEBUG_ASSERT(index >= 0 && index <= size); resize(size + 1, false); if (index < size - 1) @@ -1076,8 +984,7 @@ struct ArrayBase memcpy(&buffer[index * elementSize], value, elementSize); } - void moveTo(ArrayBase &other) - { + void moveTo(ArrayBase &other) { XA_DEBUG_ASSERT(elementSize == other.elementSize); other.destroy(); other.buffer = buffer; @@ -1091,21 +998,18 @@ struct ArrayBase elementSize = size = capacity = 0; } - void pop_back() - { + void pop_back() { XA_DEBUG_ASSERT(size > 0); resize(size - 1, false); } - void push_back(const uint8_t *value) - { + void push_back(const uint8_t *value) { XA_DEBUG_ASSERT(value < buffer || value >= buffer + size); resize(size + 1, false); memcpy(&buffer[(size - 1) * elementSize], value, elementSize); } - void push_back(const ArrayBase &other) - { + void push_back(const ArrayBase &other) { XA_DEBUG_ASSERT(elementSize == other.elementSize); if (other.size == 0) return; @@ -1115,22 +1019,19 @@ struct ArrayBase } // Remove the element at the given index. This is an expensive operation! - void removeAt(uint32_t index) - { + void removeAt(uint32_t index) { XA_DEBUG_ASSERT(index >= 0 && index < size); if (size != 1) memmove(buffer + elementSize * index, buffer + elementSize * (index + 1), elementSize * (size - 1 - index)); size--; } - void reserve(uint32_t desiredSize) - { + void reserve(uint32_t desiredSize) { if (desiredSize > capacity) setArrayCapacity(desiredSize); } - void resize(uint32_t newSize, bool exact) - { + void resize(uint32_t newSize, bool exact) { size = newSize; if (size > capacity) { // First allocation is always exact. Otherwise, following allocations grow array to 150% of desired size. @@ -1143,8 +1044,7 @@ struct ArrayBase } } - void setArrayCapacity(uint32_t newCapacity) - { + void setArrayCapacity(uint32_t newCapacity) { XA_DEBUG_ASSERT(newCapacity >= size); if (newCapacity == 0) { // free the buffer. @@ -1164,8 +1064,7 @@ struct ArrayBase } #if XA_DEBUG_HEAP - void setMemTag(int memTag) - { + void setMemTag(int memTag) { this->memTag = memTag; } #endif @@ -1179,28 +1078,25 @@ struct ArrayBase #endif }; -template<typename T> -class Array -{ +template <typename T> +class Array { public: - Array(int memTag = MemTag::Default) : m_base(sizeof(T), memTag) {} - Array(const Array&) = delete; + Array(int memTag = MemTag::Default) : + m_base(sizeof(T), memTag) {} + Array(const Array &) = delete; Array &operator=(const Array &) = delete; - XA_INLINE const T &operator[](uint32_t index) const - { + XA_INLINE const T &operator[](uint32_t index) const { XA_DEBUG_ASSERT(index < m_base.size); return ((const T *)m_base.buffer)[index]; } - XA_INLINE T &operator[](uint32_t index) - { + XA_INLINE T &operator[](uint32_t index) { XA_DEBUG_ASSERT(index < m_base.size); return ((T *)m_base.buffer)[index]; } - XA_INLINE const T &back() const - { + XA_INLINE const T &back() const { XA_DEBUG_ASSERT(!isEmpty()); return ((const T *)m_base.buffer)[m_base.size - 1]; } @@ -1208,8 +1104,7 @@ public: XA_INLINE T *begin() { return (T *)m_base.buffer; } XA_INLINE void clear() { m_base.clear(); } - bool contains(const T &value) const - { + bool contains(const T &value) const { for (uint32_t i = 0; i < m_base.size; i++) { if (((const T *)m_base.buffer)[i] == value) return true; @@ -1232,20 +1127,17 @@ public: void reserve(uint32_t desiredSize) { m_base.reserve(desiredSize); } void resize(uint32_t newSize) { m_base.resize(newSize, true); } - void runCtors() - { + void runCtors() { for (uint32_t i = 0; i < m_base.size; i++) new (&((T *)m_base.buffer)[i]) T; } - void runDtors() - { + void runDtors() { for (uint32_t i = 0; i < m_base.size; i++) ((T *)m_base.buffer)[i].~T(); } - void setAll(const T &value) - { + void setAll(const T &value) { auto buffer = (T *)m_base.buffer; for (uint32_t i = 0; i < m_base.size; i++) buffer[i] = value; @@ -1262,33 +1154,47 @@ private: ArrayBase m_base; }; -template<typename T> -struct ArrayView -{ - ArrayView(Array<T> &a) : data(a.data()), length(a.size()) {} - ArrayView(T *data, uint32_t length) : data(data), length(length) {} - ArrayView &operator=(Array<T> &a) { data = a.data(); length = a.size(); return *this; } - XA_INLINE const T &operator[](uint32_t index) const { XA_DEBUG_ASSERT(index < length); return data[index]; } +template <typename T> +struct ArrayView { + ArrayView(Array<T> &a) : + data(a.data()), length(a.size()) {} + ArrayView(T *data, uint32_t length) : + data(data), length(length) {} + ArrayView &operator=(Array<T> &a) { + data = a.data(); + length = a.size(); + return *this; + } + XA_INLINE const T &operator[](uint32_t index) const { + XA_DEBUG_ASSERT(index < length); + return data[index]; + } T *data; uint32_t length; }; -template<typename T> -struct ConstArrayView -{ - ConstArrayView(const Array<T> &a) : data(a.data()), length(a.size()) {} - ConstArrayView(const T *data, uint32_t length) : data(data), length(length) {} - ConstArrayView &operator=(const Array<T> &a) { data = a.data(); length = a.size(); return *this; } - XA_INLINE const T &operator[](uint32_t index) const { XA_DEBUG_ASSERT(index < length); return data[index]; } +template <typename T> +struct ConstArrayView { + ConstArrayView(const Array<T> &a) : + data(a.data()), length(a.size()) {} + ConstArrayView(const T *data, uint32_t length) : + data(data), length(length) {} + ConstArrayView &operator=(const Array<T> &a) { + data = a.data(); + length = a.size(); + return *this; + } + XA_INLINE const T &operator[](uint32_t index) const { + XA_DEBUG_ASSERT(index < length); + return data[index]; + } const T *data; uint32_t length; }; /// Basis class to compute tangent space basis, ortogonalizations and to transform vectors from one space to another. -struct Basis -{ - XA_NODISCARD static Vector3 computeTangent(const Vector3 &normal) - { +struct Basis { + XA_NODISCARD static Vector3 computeTangent(const Vector3 &normal) { XA_ASSERT(isNormalized(normal)); // Choose minimum axis. Vector3 tangent; @@ -1304,8 +1210,7 @@ struct Basis return tangent; } - XA_NODISCARD static Vector3 computeBitangent(const Vector3 &normal, const Vector3 &tangent) - { + XA_NODISCARD static Vector3 computeBitangent(const Vector3 &normal, const Vector3 &tangent) { return cross(normal, tangent); } @@ -1315,36 +1220,31 @@ struct Basis }; // Simple bit array. -class BitArray -{ +class BitArray { public: - BitArray() : m_size(0) {} + BitArray() : + m_size(0) {} - BitArray(uint32_t sz) - { + BitArray(uint32_t sz) { resize(sz); } - void resize(uint32_t new_size) - { + void resize(uint32_t new_size) { m_size = new_size; m_wordArray.resize((m_size + 31) >> 5); } - bool get(uint32_t index) const - { + bool get(uint32_t index) const { XA_DEBUG_ASSERT(index < m_size); return (m_wordArray[index >> 5] & (1 << (index & 31))) != 0; } - void set(uint32_t index) - { + void set(uint32_t index) { XA_DEBUG_ASSERT(index < m_size); - m_wordArray[index >> 5] |= (1 << (index & 31)); + m_wordArray[index >> 5] |= (1 << (index & 31)); } - void zeroOutMemory() - { + void zeroOutMemory() { m_wordArray.zeroOutMemory(); } @@ -1353,13 +1253,13 @@ private: Array<uint32_t> m_wordArray; }; -class BitImage -{ +class BitImage { public: - BitImage() : m_width(0), m_height(0), m_rowStride(0), m_data(MemTag::BitImage) {} + BitImage() : + m_width(0), m_height(0), m_rowStride(0), m_data(MemTag::BitImage) {} - BitImage(uint32_t w, uint32_t h) : m_width(w), m_height(h), m_data(MemTag::BitImage) - { + BitImage(uint32_t w, uint32_t h) : + m_width(w), m_height(h), m_data(MemTag::BitImage) { m_rowStride = (m_width + 63) >> 6; m_data.resize(m_rowStride * m_height); m_data.zeroOutMemory(); @@ -1370,16 +1270,14 @@ public: uint32_t width() const { return m_width; } uint32_t height() const { return m_height; } - void copyTo(BitImage &other) - { + void copyTo(BitImage &other) { other.m_width = m_width; other.m_height = m_height; other.m_rowStride = m_rowStride; m_data.copyTo(other.m_data); } - void resize(uint32_t w, uint32_t h, bool discard) - { + void resize(uint32_t w, uint32_t h, bool discard) { const uint32_t rowStride = (w + 63) >> 6; if (discard) { m_data.resize(rowStride * h); @@ -1403,28 +1301,24 @@ public: m_rowStride = rowStride; } - bool get(uint32_t x, uint32_t y) const - { + bool get(uint32_t x, uint32_t y) const { XA_DEBUG_ASSERT(x < m_width && y < m_height); const uint32_t index = (x >> 6) + y * m_rowStride; return (m_data[index] & (UINT64_C(1) << (uint64_t(x) & UINT64_C(63)))) != 0; } - void set(uint32_t x, uint32_t y) - { + void set(uint32_t x, uint32_t y) { XA_DEBUG_ASSERT(x < m_width && y < m_height); const uint32_t index = (x >> 6) + y * m_rowStride; m_data[index] |= UINT64_C(1) << (uint64_t(x) & UINT64_C(63)); XA_DEBUG_ASSERT(get(x, y)); } - void zeroOutMemory() - { + void zeroOutMemory() { m_data.zeroOutMemory(); } - bool canBlit(const BitImage &image, uint32_t offsetX, uint32_t offsetY) const - { + bool canBlit(const BitImage &image, uint32_t offsetX, uint32_t offsetY) const { for (uint32_t y = 0; y < image.m_height; y++) { const uint32_t thisY = y + offsetY; if (thisY >= m_height) @@ -1448,8 +1342,7 @@ public: return true; } - void dilate(uint32_t padding) - { + void dilate(uint32_t padding) { BitImage tmp(m_width, m_height); for (uint32_t p = 0; p < padding; p++) { tmp.zeroOutMemory(); @@ -1486,11 +1379,10 @@ private: }; // From Fast-BVH -class BVH -{ +class BVH { public: - BVH(const Array<AABB> &objectAabbs, uint32_t leafSize = 4) : m_objectIds(MemTag::BVH), m_nodes(MemTag::BVH) - { + BVH(const Array<AABB> &objectAabbs, uint32_t leafSize = 4) : + m_objectIds(MemTag::BVH), m_nodes(MemTag::BVH) { m_objectAabbs = &objectAabbs; if (m_objectAabbs->isEmpty()) return; @@ -1510,7 +1402,7 @@ public: Node node; m_nodes.reserve(objectAabbs.size() * 2); uint32_t nNodes = 0; - while(stackptr > 0) { + while (stackptr > 0) { // Pop the next item off of the stack const BuildEntry &bnode = todo[--stackptr]; const uint32_t start = bnode.start; @@ -1523,7 +1415,7 @@ public: // Calculate the bounding box for this node AABB bb(objectAabbs[m_objectIds[start]]); AABB bc(objectAabbs[m_objectIds[start]].centroid()); - for(uint32_t p = start + 1; p < end; ++p) { + for (uint32_t p = start + 1; p < end; ++p) { bb.expandToInclude(objectAabbs[m_objectIds[p]]); bc.expandToInclude(objectAabbs[m_objectIds[p]].centroid()); } @@ -1539,7 +1431,7 @@ public: m_nodes[bnode.parent].rightOffset--; // When this is the second touch, this is the right child. // The right child sets up the offset for the flat tree. - if (m_nodes[bnode.parent].rightOffset == kTouchedTwice ) + if (m_nodes[bnode.parent].rightOffset == kTouchedTwice) m_nodes[bnode.parent].rightOffset = nNodes - 1 - bnode.parent; } // If this is a leaf, no need to subdivide. @@ -1574,21 +1466,20 @@ public: } } - void query(const AABB &queryAabb, Array<uint32_t> &result) const - { + void query(const AABB &queryAabb, Array<uint32_t> &result) const { result.clear(); // Working set uint32_t todo[64]; int32_t stackptr = 0; // "Push" on the root node to the working set todo[stackptr] = 0; - while(stackptr >= 0) { + while (stackptr >= 0) { // Pop off the next node to work on. const int ni = todo[stackptr--]; const Node &node = m_nodes[ni]; // Is leaf -> Intersect if (node.rightOffset == 0) { - for(uint32_t o = 0; o < node.nPrims; ++o) { + for (uint32_t o = 0; o < node.nPrims; ++o) { const uint32_t obj = node.start + o; if (queryAabb.intersect((*m_objectAabbs)[m_objectIds[obj]])) result.push_back(m_objectIds[obj]); @@ -1605,14 +1496,12 @@ public: } private: - struct BuildEntry - { + struct BuildEntry { uint32_t parent; // If non-zero then this is the index of the parent. (used in offsets) uint32_t start, end; // The range of objects in the object list covered by this node. }; - struct Node - { + struct Node { AABB aabb; uint32_t start, nPrims, rightOffset; }; @@ -1622,10 +1511,8 @@ private: Array<Node> m_nodes; }; -struct Fit -{ - static bool computeBasis(const Vector3 *points, uint32_t pointsCount, Basis *basis) - { +struct Fit { + static bool computeBasis(const Vector3 *points, uint32_t pointsCount, Basis *basis) { if (computeLeastSquaresNormal(points, pointsCount, &basis->normal)) { basis->tangent = Basis::computeTangent(basis->normal); basis->bitangent = Basis::computeBitangent(basis->normal, basis->tangent); @@ -1639,8 +1526,7 @@ private: // Fast, and accurate to within a few degrees. // Returns None if the points do not span a plane. // https://www.ilikebigbits.com/2015_03_04_plane_from_points.html - static bool computeLeastSquaresNormal(const Vector3 *points, uint32_t pointsCount, Vector3 *normal) - { + static bool computeLeastSquaresNormal(const Vector3 *points, uint32_t pointsCount, Vector3 *normal) { XA_DEBUG_ASSERT(pointsCount >= 3); if (pointsCount == 3) { *normal = normalize(cross(points[2] - points[0], points[1] - points[0]), kEpsilon); @@ -1705,7 +1591,7 @@ private: // Pick path with best conditioning: Vector3 dir(0.0f); if (det_max == det_x) - dir = Vector3(det_x,xz * yz - xy * zz,xy * yz - xz * yy); + dir = Vector3(det_x, xz * yz - xy * zz, xy * yz - xz * yy); else if (det_max == det_y) dir = Vector3(xz * yz - xy * zz, det_y, xy * xz - yz * xx); else if (det_max == det_z) @@ -1718,8 +1604,7 @@ private: return isNormalized(*normal); } - static bool computeEigen(const Vector3 *points, uint32_t pointsCount, Basis *basis) - { + static bool computeEigen(const Vector3 *points, uint32_t pointsCount, Basis *basis) { float matrix[6]; computeCovariance(pointsCount, points, matrix); if (matrix[0] == 0 && matrix[3] == 0 && matrix[5] == 0) @@ -1734,8 +1619,7 @@ private: return true; } - static Vector3 computeCentroid(int n, const Vector3 * points) - { + static Vector3 computeCentroid(int n, const Vector3 *points) { Vector3 centroid(0.0f); for (int i = 0; i < n; i++) { centroid += points[i]; @@ -1744,8 +1628,7 @@ private: return centroid; } - static Vector3 computeCovariance(int n, const Vector3 * points, float * covariance) - { + static Vector3 computeCovariance(int n, const Vector3 *points, float *covariance) { // compute the centroid Vector3 centroid = computeCentroid(n, points); // compute covariance matrix @@ -1767,8 +1650,7 @@ private: // Tridiagonal solver from Charles Bloom. // Householder transforms followed by QL decomposition. // Seems to be based on the code from Numerical Recipes in C. - static bool eigenSolveSymmetric3(const float matrix[6], float eigenValues[3], Vector3 eigenVectors[3]) - { + static bool eigenSolveSymmetric3(const float matrix[6], float eigenValues[3], Vector3 eigenVectors[3]) { XA_DEBUG_ASSERT(matrix != nullptr && eigenValues != nullptr && eigenVectors != nullptr); float subd[3]; float diag[3]; @@ -1793,7 +1675,7 @@ private: // eigenvectors are the columns; make them the rows : for (int i = 0; i < 3; i++) { for (int j = 0; j < 3; j++) { - (&eigenVectors[j].x)[i] = (float) work[i][j]; + (&eigenVectors[j].x)[i] = (float)work[i][j]; } } // shuffle to sort by singular value : @@ -1815,8 +1697,7 @@ private: } private: - static void EigenSolver3_Tridiagonal(float mat[3][3], float *diag, float *subd) - { + static void EigenSolver3_Tridiagonal(float mat[3][3], float *diag, float *subd) { // Householder reduction T = Q^t M Q // Input: // mat, symmetric 3x3 matrix M @@ -1868,8 +1749,7 @@ private: } } - static bool EigenSolver3_QLAlgorithm(float mat[3][3], float *diag, float *subd) - { + static bool EigenSolver3_QLAlgorithm(float mat[3][3], float *diag, float *subd) { // QL iteration with implicit shifting to reduce matrix from tridiagonal // to diagonal const int maxiter = 32; @@ -1879,21 +1759,21 @@ private: int m; for (m = ell; m <= 1; m++) { float dd = fabsf(diag[m]) + fabsf(diag[m + 1]); - if ( fabsf(subd[m]) + dd == dd ) + if (fabsf(subd[m]) + dd == dd) break; } - if ( m == ell ) + if (m == ell) break; float g = (diag[ell + 1] - diag[ell]) / (2 * subd[ell]); float r = sqrtf(g * g + 1); - if ( g < 0 ) + if (g < 0) g = diag[m] - diag[ell] + subd[ell] / (g - r); else g = diag[m] - diag[ell] + subd[ell] / (g + r); float s = 1, c = 1, p = 0; for (int i = m - 1; i >= ell; i--) { float f = s * subd[i], b = c * subd[i]; - if ( fabsf(f) >= fabsf(g) ) { + if (fabsf(f) >= fabsf(g)) { c = g / f; r = sqrtf(c * c + 1); subd[i + 1] = f * r; @@ -1919,7 +1799,7 @@ private: subd[ell] = g; subd[m] = 0; } - if ( iter == maxiter ) + if (iter == maxiter) // should not get here under normal circumstances return false; } @@ -1928,18 +1808,18 @@ private: }; /// Fixed size vector class. -class FullVector -{ +class FullVector { public: - FullVector(uint32_t dim) : m_array(MemTag::FullVector) { m_array.resize(dim); } - FullVector(const FullVector &v) : m_array(MemTag::FullVector) { v.m_array.copyTo(m_array); } + FullVector(uint32_t dim) : + m_array(MemTag::FullVector) { m_array.resize(dim); } + FullVector(const FullVector &v) : + m_array(MemTag::FullVector) { v.m_array.copyTo(m_array); } FullVector &operator=(const FullVector &v) = delete; XA_INLINE uint32_t dimension() const { return m_array.size(); } XA_INLINE const float &operator[](uint32_t index) const { return m_array[index]; } XA_INLINE float &operator[](uint32_t index) { return m_array[index]; } - void fill(float f) - { + void fill(float f) { const uint32_t dim = dimension(); for (uint32_t i = 0; i < dim; i++) m_array[i] = f; @@ -1949,22 +1829,19 @@ private: Array<float> m_array; }; -template<typename Key, typename H = Hash<Key>, typename E = Equal<Key> > -class HashMap -{ +template <typename Key, typename H = Hash<Key>, typename E = Equal<Key>> +class HashMap { public: - HashMap(int memTag, uint32_t size) : m_memTag(memTag), m_size(size), m_numSlots(0), m_slots(nullptr), m_keys(memTag), m_next(memTag) - { + HashMap(int memTag, uint32_t size) : + m_memTag(memTag), m_size(size), m_numSlots(0), m_slots(nullptr), m_keys(memTag), m_next(memTag) { } - ~HashMap() - { + ~HashMap() { if (m_slots) XA_FREE(m_slots); } - void add(const Key &key) - { + void add(const Key &key) { if (!m_slots) alloc(); const uint32_t hash = computeHash(key); @@ -1973,8 +1850,7 @@ public: m_slots[hash] = m_next.size() - 1; } - uint32_t get(const Key &key) const - { + uint32_t get(const Key &key) const { if (!m_slots) return UINT32_MAX; const uint32_t hash = computeHash(key); @@ -1988,8 +1864,7 @@ public: return UINT32_MAX; } - uint32_t getNext(uint32_t current) const - { + uint32_t getNext(uint32_t current) const { uint32_t i = m_next[current]; E equal; while (i != UINT32_MAX) { @@ -2001,8 +1876,7 @@ public: } private: - void alloc() - { + void alloc() { XA_DEBUG_ASSERT(m_size > 0); m_numSlots = nextPowerOfTwo(m_size); auto minNumSlots = uint32_t(m_size * 1.3); @@ -2015,8 +1889,7 @@ private: m_next.reserve(m_size); } - uint32_t computeHash(const Key &key) const - { + uint32_t computeHash(const Key &key) const { H hash; return hash(key) & (m_numSlots - 1); } @@ -2029,9 +1902,8 @@ private: Array<uint32_t> m_next; }; -template<typename T> -static void insertionSort(T *data, uint32_t length) -{ +template <typename T> +static void insertionSort(T *data, uint32_t length) { for (int32_t i = 1; i < (int32_t)length; i++) { T x = data[i]; int32_t j = i - 1; @@ -2043,21 +1915,18 @@ static void insertionSort(T *data, uint32_t length) } } -class KISSRng -{ +class KISSRng { public: KISSRng() { reset(); } - void reset() - { + void reset() { x = 123456789; y = 362436000; z = 521288629; c = 7654321; } - uint32_t getRange(uint32_t range) - { + uint32_t getRange(uint32_t range) { if (range == 0) return 0; x = 69069 * x + 12345; @@ -2076,20 +1945,18 @@ private: // Based on Pierre Terdiman's and Michael Herf's source code. // http://www.codercorner.com/RadixSortRevisited.htm // http://www.stereopsis.com/radix.html -class RadixSort -{ +class RadixSort { public: - RadixSort() : m_size(0), m_ranks(nullptr), m_ranks2(nullptr), m_validRanks(false) {} + RadixSort() : + m_size(0), m_ranks(nullptr), m_ranks2(nullptr), m_validRanks(false) {} - ~RadixSort() - { + ~RadixSort() { // Release everything XA_FREE(m_ranks2); XA_FREE(m_ranks); } - RadixSort &sort(const float *input, uint32_t count) - { + RadixSort &sort(const float *input, uint32_t count) { if (input == nullptr || count == 0) return *this; // Resize lists if needed if (count != m_size) { @@ -2115,20 +1982,17 @@ public: return *this; } - RadixSort &sort(const Array<float> &input) - { + RadixSort &sort(const Array<float> &input) { return sort(input.data(), input.size()); } // Access to results. m_ranks is a list of indices in sorted order, i.e. in the order you may further process your data - const uint32_t *ranks() const - { + const uint32_t *ranks() const { XA_DEBUG_ASSERT(m_validRanks); return m_ranks; } - uint32_t *ranks() - { + uint32_t *ranks() { XA_DEBUG_ASSERT(m_validRanks); return m_ranks; } @@ -2139,21 +2003,18 @@ private: uint32_t *m_ranks2; bool m_validRanks; - void FloatFlip(uint32_t &f) - { + void FloatFlip(uint32_t &f) { int32_t mask = (int32_t(f) >> 31) | 0x80000000; // Warren Hunt, Manchor Ko. f ^= mask; } - void IFloatFlip(uint32_t &f) - { + void IFloatFlip(uint32_t &f) { uint32_t mask = ((f >> 31) - 1) | 0x80000000; // Michael Herf. f ^= mask; } - template<typename T> - void createHistograms(const T *buffer, uint32_t count, uint32_t *histogram) - { + template <typename T> + void createHistograms(const T *buffer, uint32_t count, uint32_t *histogram) { const uint32_t bucketCount = sizeof(T); // (8 * sizeof(T)) / log2(radix) // Init bucket pointers. uint32_t *h[bucketCount]; @@ -2161,10 +2022,10 @@ private: h[i] = histogram + 256 * i; } // Clear histograms. - memset(histogram, 0, 256 * bucketCount * sizeof(uint32_t )); + memset(histogram, 0, 256 * bucketCount * sizeof(uint32_t)); // @@ Add support for signed integers. // Build histograms. - const uint8_t *p = (const uint8_t *)buffer; // @@ Does this break aliasing rules? + const uint8_t *p = (const uint8_t *)buffer; // @@ Does this break aliasing rules? const uint8_t *pe = p + count * sizeof(T); while (p != pe) { h[0][*p++]++, h[1][*p++]++, h[2][*p++]++, h[3][*p++]++; @@ -2179,8 +2040,8 @@ private: } } - template <typename T> void insertionSort(const T *input, uint32_t count) - { + template <typename T> + void insertionSort(const T *input, uint32_t count) { if (!m_validRanks) { m_ranks[0] = 0; for (uint32_t i = 1; i != count; ++i) { @@ -2210,8 +2071,8 @@ private: } } - template <typename T> void radixSort(const T *input, uint32_t count) - { + template <typename T> + void radixSort(const T *input, uint32_t count) { const uint32_t P = sizeof(T); // pass count // Allocate histograms & offsets on the stack uint32_t histogram[256 * P]; @@ -2229,7 +2090,8 @@ private: } // Create offsets link[0] = m_ranks2; - for (uint32_t i = 1; i < 256; i++) link[i] = link[i - 1] + h[i - 1]; + for (uint32_t i = 1; i < 256; i++) + link[i] = link[i - 1] + h[i - 1]; // Perform Radix Sort if (!m_validRanks) { for (uint32_t i = 0; i < count; i++) { @@ -2256,25 +2118,21 @@ private: }; // Wrapping this in a class allows temporary arrays to be re-used. -class BoundingBox2D -{ +class BoundingBox2D { public: Vector2 majorAxis, minorAxis, minCorner, maxCorner; - void clear() - { + void clear() { m_boundaryVertices.clear(); } - void appendBoundaryVertex(Vector2 v) - { + void appendBoundaryVertex(Vector2 v) { m_boundaryVertices.push_back(v); } // This should compute convex hull and use rotating calipers to find the best box. Currently it uses a brute force method. // If vertices is null or vertexCount is 0, the boundary vertices are used. - void compute(const Vector2 *vertices = nullptr, uint32_t vertexCount = 0) - { + void compute(const Vector2 *vertices = nullptr, uint32_t vertexCount = 0) { if (!vertices || vertexCount == 0) { vertices = m_boundaryVertices.data(); vertexCount = m_boundaryVertices.size(); @@ -2322,8 +2180,7 @@ public: private: // Compute the convex hull using Graham Scan. - void convexHull(const Vector2 *input, uint32_t inputCount, Array<Vector2> &output, float epsilon) - { + void convexHull(const Vector2 *input, uint32_t inputCount, Array<Vector2> &output, float epsilon) { m_coords.resize(inputCount); for (uint32_t i = 0; i < inputCount; i++) m_coords[i] = input[i].x; @@ -2353,7 +2210,7 @@ private: XA_DEBUG_ASSERT(m_top.size() >= 2); output.push_back(m_top[0]); output.push_back(m_top[1]); - for (uint32_t i = 2; i < m_top.size(); ) { + for (uint32_t i = 2; i < m_top.size();) { Vector2 a = output[output.size() - 2]; Vector2 b = output[output.size() - 1]; Vector2 c = m_top[i]; @@ -2369,7 +2226,7 @@ private: XA_DEBUG_ASSERT(m_bottom.size() >= 2); output.push_back(m_bottom[1]); // Filter bottom list. - for (uint32_t i = 2; i < m_bottom.size(); ) { + for (uint32_t i = 2; i < m_bottom.size();) { Vector2 a = output[output.size() - 2]; Vector2 b = output[output.size() - 1]; Vector2 c = m_bottom[i]; @@ -2391,33 +2248,33 @@ private: Array<Vector2> m_top, m_bottom, m_hull; }; -static uint32_t meshEdgeFace(uint32_t edge) { return edge / 3; } -static uint32_t meshEdgeIndex0(uint32_t edge) { return edge; } +static uint32_t meshEdgeFace(uint32_t edge) { + return edge / 3; +} +static uint32_t meshEdgeIndex0(uint32_t edge) { + return edge; +} -static uint32_t meshEdgeIndex1(uint32_t edge) -{ +static uint32_t meshEdgeIndex1(uint32_t edge) { const uint32_t faceFirstEdge = edge / 3 * 3; return faceFirstEdge + (edge - faceFirstEdge + 1) % 3; } -struct MeshFlags -{ - enum - { - HasFaceGroups = 1<<0, - HasIgnoredFaces = 1<<1, - HasNormals = 1<<2 +struct MeshFlags { + enum { + HasFaceGroups = 1 << 0, + HasIgnoredFaces = 1 << 1, + HasNormals = 1 << 2 }; }; class Mesh; static void meshGetBoundaryLoops(const Mesh &mesh, Array<uint32_t> &boundaryLoops); -class Mesh -{ +class Mesh { public: - Mesh(float epsilon, uint32_t approxVertexCount, uint32_t approxFaceCount, uint32_t flags = 0, uint32_t id = UINT32_MAX) : m_epsilon(epsilon), m_flags(flags), m_id(id), m_faceIgnore(MemTag::Mesh), m_ignoredFaceCount(0), m_indices(MemTag::MeshIndices), m_positions(MemTag::MeshPositions), m_normals(MemTag::MeshNormals), m_texcoords(MemTag::MeshTexcoords), m_faceGroups(MemTag::Mesh), m_faceGroupFirstFace(MemTag::Mesh), m_faceGroupNextFace(MemTag::Mesh), m_faceGroupFaceCounts(MemTag::Mesh), m_colocalVertexCount(0), m_nextColocalVertex(MemTag::MeshColocals), m_boundaryEdges(MemTag::MeshBoundaries), m_oppositeEdges(MemTag::MeshBoundaries), m_nextBoundaryEdges(MemTag::MeshBoundaries), m_edgeMap(MemTag::MeshEdgeMap, approxFaceCount * 3) - { + Mesh(float epsilon, uint32_t approxVertexCount, uint32_t approxFaceCount, uint32_t flags = 0, uint32_t id = UINT32_MAX) : + m_epsilon(epsilon), m_flags(flags), m_id(id), m_faceIgnore(MemTag::Mesh), m_ignoredFaceCount(0), m_indices(MemTag::MeshIndices), m_positions(MemTag::MeshPositions), m_normals(MemTag::MeshNormals), m_texcoords(MemTag::MeshTexcoords), m_faceGroups(MemTag::Mesh), m_faceGroupFirstFace(MemTag::Mesh), m_faceGroupNextFace(MemTag::Mesh), m_faceGroupFaceCounts(MemTag::Mesh), m_colocalVertexCount(0), m_nextColocalVertex(MemTag::MeshColocals), m_boundaryEdges(MemTag::MeshBoundaries), m_oppositeEdges(MemTag::MeshBoundaries), m_nextBoundaryEdges(MemTag::MeshBoundaries), m_edgeMap(MemTag::MeshEdgeMap, approxFaceCount * 3) { m_indices.reserve(approxFaceCount * 3); m_positions.reserve(approxVertexCount); m_texcoords.reserve(approxVertexCount); @@ -2433,8 +2290,7 @@ public: uint32_t flags() const { return m_flags; } uint32_t id() const { return m_id; } - void addVertex(const Vector3 &pos, const Vector3 &normal = Vector3(0.0f), const Vector2 &texcoord = Vector2(0.0f)) - { + void addVertex(const Vector3 &pos, const Vector3 &normal = Vector3(0.0f), const Vector2 &texcoord = Vector2(0.0f)) { XA_DEBUG_ASSERT(isFinite(pos)); m_positions.push_back(pos); if (m_flags & MeshFlags::HasNormals) @@ -2442,17 +2298,14 @@ public: m_texcoords.push_back(texcoord); } - struct AddFaceResult - { - enum Enum - { + struct AddFaceResult { + enum Enum { OK, DuplicateEdge = 1 }; }; - AddFaceResult::Enum addFace(uint32_t v0, uint32_t v1, uint32_t v2, bool ignore = false, bool hashEdge = true) - { + AddFaceResult::Enum addFace(uint32_t v0, uint32_t v1, uint32_t v2, bool ignore = false, bool hashEdge = true) { uint32_t indexArray[3]; indexArray[0] = v0; indexArray[1] = v1; @@ -2460,8 +2313,7 @@ public: return addFace(indexArray, ignore, hashEdge); } - AddFaceResult::Enum addFace(const uint32_t *indices, bool ignore = false, bool hashEdge = true) - { + AddFaceResult::Enum addFace(const uint32_t *indices, bool ignore = false, bool hashEdge = true) { AddFaceResult::Enum result = AddFaceResult::OK; if (m_flags & MeshFlags::HasFaceGroups) m_faceGroups.push_back(kInvalidFaceGroup); @@ -2486,8 +2338,7 @@ public: return result; } - void createColocals() - { + void createColocals() { const uint32_t vertexCount = m_positions.size(); Array<AABB> aabbs(MemTag::BVH); aabbs.resize(vertexCount); @@ -2515,7 +2366,7 @@ public: if (colocals.size() == 1) { // No colocals for this vertex. m_nextColocalVertex[i] = i; - continue; + continue; } m_colocalVertexCount += colocals.size(); // Link in ascending order. @@ -2527,8 +2378,7 @@ public: } // Check if the face duplicates any edges of any face already in the group. - bool faceDuplicatesGroupEdge(uint16_t group, uint32_t face) const - { + bool faceDuplicatesGroupEdge(uint16_t group, uint32_t face) const { for (FaceEdgeIterator edgeIt(this, face); !edgeIt.isDone(); edgeIt.advance()) { for (ColocalEdgeIterator colocalEdgeIt(this, edgeIt.vertex0(), edgeIt.vertex1()); !colocalEdgeIt.isDone(); colocalEdgeIt.advance()) { if (m_faceGroups[meshEdgeFace(colocalEdgeIt.edge())] == group) @@ -2538,8 +2388,7 @@ public: return false; } - void createFaceGroups() - { + void createFaceGroups() { uint32_t firstUnassignedFace = 0; uint16_t group = 0; Array<uint32_t> growFaces; @@ -2619,8 +2468,7 @@ public: } } - void createBoundaries() - { + void createBoundaries() { const uint32_t edgeCount = m_indices.size(); const uint32_t vertexCount = m_positions.size(); m_oppositeEdges.resize(edgeCount); @@ -2650,8 +2498,7 @@ public: } } - void linkBoundaries() - { + void linkBoundaries() { const uint32_t edgeCount = m_indices.size(); HashMap<uint32_t> vertexToEdgeMap(MemTag::Mesh, edgeCount); // Edge is index / 2 for (uint32_t i = 0; i < edgeCount; i++) { @@ -2744,8 +2591,7 @@ public: } /// Find edge, test all colocals. - uint32_t findEdge(uint32_t vertex0, uint32_t vertex1) const - { + uint32_t findEdge(uint32_t vertex0, uint32_t vertex1) const { uint32_t result = UINT32_MAX; if (m_nextColocalVertex.isEmpty()) { EdgeKey key(vertex0, vertex1); @@ -2784,8 +2630,7 @@ public: } #if XA_DEBUG_EXPORT_OBJ - void writeObjVertices(FILE *file) const - { + void writeObjVertices(FILE *file) const { for (uint32_t i = 0; i < m_positions.size(); i++) fprintf(file, "v %g %g %g\n", m_positions[i].x, m_positions[i].y, m_positions[i].z); if (m_flags & MeshFlags::HasNormals) { @@ -2796,8 +2641,7 @@ public: fprintf(file, "vt %g %g\n", m_texcoords[i].x, m_texcoords[i].y); } - void writeObjFace(FILE *file, uint32_t face) const - { + void writeObjFace(FILE *file, uint32_t face) const { fprintf(file, "f "); for (uint32_t j = 0; j < 3; j++) { const uint32_t index = m_indices[face * 3 + j] + 1; // 1-indexed @@ -2805,8 +2649,7 @@ public: } } - void writeObjBoundaryEges(FILE *file) const - { + void writeObjBoundaryEges(FILE *file) const { if (m_oppositeEdges.isEmpty()) return; // Boundaries haven't been created. fprintf(file, "o boundary_edges\n"); @@ -2817,8 +2660,7 @@ public: } } - void writeObjLinkedBoundaries(FILE *file) const - { + void writeObjLinkedBoundaries(FILE *file) const { if (m_oppositeEdges.isEmpty() || m_nextBoundaryEdges.isEmpty()) return; // Boundaries haven't been created and/or linked. Array<uint32_t> boundaryLoops; @@ -2840,8 +2682,7 @@ public: } } - void writeObjFile(const char *filename) const - { + void writeObjFile(const char *filename) const { FILE *file; XA_FOPEN(file, filename, "w"); if (!file) @@ -2857,8 +2698,7 @@ public: } #endif - float computeSurfaceArea() const - { + float computeSurfaceArea() const { float area = 0; for (uint32_t f = 0; f < faceCount(); f++) area += computeFaceArea(f); @@ -2866,24 +2706,21 @@ public: return area; } - float computeParametricArea() const - { + float computeParametricArea() const { float area = 0; for (uint32_t f = 0; f < faceCount(); f++) area += computeFaceParametricArea(f); return fabsf(area); // May be negative, depends on texcoord winding. } - float computeFaceArea(uint32_t face) const - { + float computeFaceArea(uint32_t face) const { const Vector3 &p0 = m_positions[m_indices[face * 3 + 0]]; const Vector3 &p1 = m_positions[m_indices[face * 3 + 1]]; const Vector3 &p2 = m_positions[m_indices[face * 3 + 2]]; return length(cross(p1 - p0, p2 - p0)) * 0.5f; } - Vector3 computeFaceCentroid(uint32_t face) const - { + Vector3 computeFaceCentroid(uint32_t face) const { Vector3 sum(0.0f); for (uint32_t i = 0; i < 3; i++) sum += m_positions[m_indices[face * 3 + i]]; @@ -2892,8 +2729,7 @@ public: // Average of the edge midpoints weighted by the edge length. // I want a point inside the triangle, but closer to the cirumcenter. - Vector3 computeFaceCenter(uint32_t face) const - { + Vector3 computeFaceCenter(uint32_t face) const { const Vector3 &p0 = m_positions[m_indices[face * 3 + 0]]; const Vector3 &p1 = m_positions[m_indices[face * 3 + 1]]; const Vector3 &p2 = m_positions[m_indices[face * 3 + 2]]; @@ -2906,8 +2742,7 @@ public: return m0 + m1 + m2; } - Vector3 computeFaceNormal(uint32_t face) const - { + Vector3 computeFaceNormal(uint32_t face) const { const Vector3 &p0 = m_positions[m_indices[face * 3 + 0]]; const Vector3 &p1 = m_positions[m_indices[face * 3 + 1]]; const Vector3 &p2 = m_positions[m_indices[face * 3 + 2]]; @@ -2917,17 +2752,15 @@ public: return normalizeSafe(normalAreaScaled, Vector3(0, 0, 1), 0.0f); } - float computeFaceParametricArea(uint32_t face) const - { + float computeFaceParametricArea(uint32_t face) const { const Vector2 &t0 = m_texcoords[m_indices[face * 3 + 0]]; const Vector2 &t1 = m_texcoords[m_indices[face * 3 + 1]]; const Vector2 &t2 = m_texcoords[m_indices[face * 3 + 2]]; return triangleArea(t0, t1, t2); } - + // @@ This is not exactly accurate, we should compare the texture coordinates... - bool isSeam(uint32_t edge) const - { + bool isSeam(uint32_t edge) const { const uint32_t oppositeEdge = m_oppositeEdges[edge]; if (oppositeEdge == UINT32_MAX) return false; // boundary edge @@ -2938,8 +2771,7 @@ public: return m_indices[e0] != m_indices[oe1] || m_indices[e1] != m_indices[oe0]; } - bool isTextureSeam(uint32_t edge) const - { + bool isTextureSeam(uint32_t edge) const { const uint32_t oppositeEdge = m_oppositeEdges[edge]; if (oppositeEdge == UINT32_MAX) return false; // boundary edge @@ -2950,8 +2782,7 @@ public: return m_texcoords[m_indices[e0]] != m_texcoords[m_indices[oe1]] || m_texcoords[m_indices[e1]] != m_texcoords[m_indices[oe0]]; } - uint32_t firstColocal(uint32_t vertex) const - { + uint32_t firstColocal(uint32_t vertex) const { for (ColocalVertexIterator it(this, vertex); !it.isDone(); it.advance()) { if (it.vertex() < vertex) vertex = it.vertex(); @@ -2959,8 +2790,7 @@ public: return vertex; } - bool areColocal(uint32_t vertex0, uint32_t vertex1) const - { + bool areColocal(uint32_t vertex0, uint32_t vertex1) const { if (vertex0 == vertex1) return true; if (m_nextColocalVertex.isEmpty()) @@ -2982,17 +2812,32 @@ public: XA_INLINE uint32_t vertexCount() const { return m_positions.size(); } XA_INLINE uint32_t vertexAt(uint32_t i) const { return m_indices[i]; } XA_INLINE const Vector3 &position(uint32_t vertex) const { return m_positions[vertex]; } - XA_INLINE const Vector3 &normal(uint32_t vertex) const { XA_DEBUG_ASSERT(m_flags & MeshFlags::HasNormals); return m_normals[vertex]; } + XA_INLINE const Vector3 &normal(uint32_t vertex) const { + XA_DEBUG_ASSERT(m_flags & MeshFlags::HasNormals); + return m_normals[vertex]; + } XA_INLINE const Vector2 &texcoord(uint32_t vertex) const { return m_texcoords[vertex]; } XA_INLINE Vector2 &texcoord(uint32_t vertex) { return m_texcoords[vertex]; } XA_INLINE const Vector2 *texcoords() const { return m_texcoords.data(); } XA_INLINE Vector2 *texcoords() { return m_texcoords.data(); } XA_INLINE uint32_t ignoredFaceCount() const { return m_ignoredFaceCount; } XA_INLINE uint32_t faceCount() const { return m_indices.size() / 3; } - XA_INLINE uint16_t faceGroupAt(uint32_t face) const { XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); return m_faceGroups[face]; } - XA_INLINE uint32_t faceGroupCount() const { XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); return m_faceGroupFaceCounts.size(); } - XA_INLINE uint32_t faceGroupNextFace(uint32_t face) const { XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); return m_faceGroupNextFace[face]; } - XA_INLINE uint32_t faceGroupFaceCount(uint32_t group) const { XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); return m_faceGroupFaceCounts[group]; } + XA_INLINE uint16_t faceGroupAt(uint32_t face) const { + XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); + return m_faceGroups[face]; + } + XA_INLINE uint32_t faceGroupCount() const { + XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); + return m_faceGroupFaceCounts.size(); + } + XA_INLINE uint32_t faceGroupNextFace(uint32_t face) const { + XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); + return m_faceGroupNextFace[face]; + } + XA_INLINE uint32_t faceGroupFaceCount(uint32_t group) const { + XA_DEBUG_ASSERT(m_flags & MeshFlags::HasFaceGroups); + return m_faceGroupFaceCounts[group]; + } XA_INLINE const uint32_t *indices() const { return m_indices.data(); } XA_INLINE uint32_t indexCount() const { return m_indices.size(); } @@ -3027,49 +2872,45 @@ private: // Populated by linkBoundaries Array<uint32_t> m_nextBoundaryEdges; // The index of the next boundary edge. UINT32_MAX if the edge is not a boundary edge. - struct EdgeKey - { + struct EdgeKey { EdgeKey() {} - EdgeKey(const EdgeKey &k) : v0(k.v0), v1(k.v1) {} - EdgeKey(uint32_t v0, uint32_t v1) : v0(v0), v1(v1) {} + EdgeKey(const EdgeKey &k) : + v0(k.v0), v1(k.v1) {} + EdgeKey(uint32_t v0, uint32_t v1) : + v0(v0), v1(v1) {} bool operator==(const EdgeKey &k) const { return v0 == k.v0 && v1 == k.v1; } uint32_t v0; uint32_t v1; }; - struct EdgeHash - { + struct EdgeHash { uint32_t operator()(const EdgeKey &k) const { return k.v0 * 32768u + k.v1; } }; HashMap<EdgeKey, EdgeHash> m_edgeMap; public: - class BoundaryLoopEdgeIterator - { + class BoundaryLoopEdgeIterator { public: - BoundaryLoopEdgeIterator(const Mesh *mesh, uint32_t edge) : m_mesh(mesh), m_first(UINT32_MAX), m_current(edge) {} + BoundaryLoopEdgeIterator(const Mesh *mesh, uint32_t edge) : + m_mesh(mesh), m_first(UINT32_MAX), m_current(edge) {} - void advance() - { + void advance() { if (m_first == UINT32_MAX) m_first = m_current; m_current = m_mesh->m_nextBoundaryEdges[m_current]; } - bool isDone() const - { + bool isDone() const { return m_first == m_current || m_current == UINT32_MAX; } - uint32_t edge() const - { + uint32_t edge() const { return m_current; } - uint32_t nextEdge() const - { + uint32_t nextEdge() const { return m_mesh->m_nextBoundaryEdges[m_current]; } @@ -3079,31 +2920,27 @@ public: uint32_t m_current; }; - class ColocalVertexIterator - { + class ColocalVertexIterator { public: - ColocalVertexIterator(const Mesh *mesh, uint32_t v) : m_mesh(mesh), m_first(UINT32_MAX), m_current(v) {} + ColocalVertexIterator(const Mesh *mesh, uint32_t v) : + m_mesh(mesh), m_first(UINT32_MAX), m_current(v) {} - void advance() - { + void advance() { if (m_first == UINT32_MAX) m_first = m_current; if (!m_mesh->m_nextColocalVertex.isEmpty()) m_current = m_mesh->m_nextColocalVertex[m_current]; } - bool isDone() const - { + bool isDone() const { return m_first == m_current; } - uint32_t vertex() const - { + uint32_t vertex() const { return m_current; } - const Vector3 *pos() const - { + const Vector3 *pos() const { return &m_mesh->m_positions[m_current]; } @@ -3113,39 +2950,33 @@ public: uint32_t m_current; }; - class ColocalEdgeIterator - { + class ColocalEdgeIterator { public: - ColocalEdgeIterator(const Mesh *mesh, uint32_t vertex0, uint32_t vertex1) : m_mesh(mesh), m_vertex0It(mesh, vertex0), m_vertex1It(mesh, vertex1), m_vertex1(vertex1) - { + ColocalEdgeIterator(const Mesh *mesh, uint32_t vertex0, uint32_t vertex1) : + m_mesh(mesh), m_vertex0It(mesh, vertex0), m_vertex1It(mesh, vertex1), m_vertex1(vertex1) { do { if (!resetElement()) { advanceVertex1(); - } - else { + } else { break; } } while (!isDone()); } - void advance() - { + void advance() { advanceElement(); } - bool isDone() const - { + bool isDone() const { return m_vertex0It.isDone() && m_vertex1It.isDone() && m_edge == UINT32_MAX; } - uint32_t edge() const - { + uint32_t edge() const { return m_edge; } private: - bool resetElement() - { + bool resetElement() { m_edge = m_mesh->m_edgeMap.get(Mesh::EdgeKey(m_vertex0It.vertex(), m_vertex1It.vertex())); while (m_edge != UINT32_MAX) { if (!isIgnoredFace()) @@ -3158,8 +2989,7 @@ public: return true; } - void advanceElement() - { + void advanceElement() { for (;;) { m_edge = m_mesh->m_edgeMap.getNext(m_edge); if (m_edge == UINT32_MAX) @@ -3171,17 +3001,15 @@ public: advanceVertex1(); } - void advanceVertex1() - { + void advanceVertex1() { auto successful = false; - while (!successful) { + while (!successful) { m_vertex1It.advance(); if (m_vertex1It.isDone()) { if (!m_vertex0It.isDone()) { m_vertex0It.advance(); m_vertex1It = ColocalVertexIterator(m_mesh, m_vertex1); - } - else { + } else { return; } } @@ -3189,8 +3017,7 @@ public: } } - bool isIgnoredFace() const - { + bool isIgnoredFace() const { return m_mesh->m_faceIgnore[meshEdgeFace(m_edge)]; } @@ -3200,24 +3027,21 @@ public: uint32_t m_edge; }; - class FaceEdgeIterator - { + class FaceEdgeIterator { public: - FaceEdgeIterator (const Mesh *mesh, uint32_t face) : m_mesh(mesh), m_face(face), m_relativeEdge(0) - { + FaceEdgeIterator(const Mesh *mesh, uint32_t face) : + m_mesh(mesh), m_face(face), m_relativeEdge(0) { m_edge = m_face * 3; } - void advance() - { + void advance() { if (m_relativeEdge < 3) { m_edge++; m_relativeEdge++; } } - bool isDone() const - { + bool isDone() const { return m_relativeEdge == 3; } @@ -3228,9 +3052,8 @@ public: uint32_t relativeEdge() const { return m_relativeEdge; } uint32_t face() const { return m_face; } uint32_t oppositeEdge() const { return m_mesh->m_oppositeEdges[m_edge]; } - - uint32_t oppositeFace() const - { + + uint32_t oppositeFace() const { const uint32_t oedge = m_mesh->m_oppositeEdges[m_edge]; if (oedge == UINT32_MAX) return UINT32_MAX; @@ -3253,27 +3076,23 @@ public: uint32_t m_relativeEdge; }; - class GroupFaceIterator - { + class GroupFaceIterator { public: - GroupFaceIterator(const Mesh *mesh, uint32_t group) : m_mesh(mesh) - { + GroupFaceIterator(const Mesh *mesh, uint32_t group) : + m_mesh(mesh) { XA_DEBUG_ASSERT(group != UINT32_MAX); m_current = mesh->m_faceGroupFirstFace[group]; } - void advance() - { + void advance() { m_current = m_mesh->m_faceGroupNextFace[m_current]; } - bool isDone() const - { + bool isDone() const { return m_current == UINT32_MAX; } - uint32_t face() const - { + uint32_t face() const { return m_current; } @@ -3285,8 +3104,7 @@ public: constexpr uint16_t Mesh::kInvalidFaceGroup; -static bool meshCloseHole(Mesh *mesh, const Array<uint32_t> &holeVertices, const Vector3 &normal) -{ +static bool meshCloseHole(Mesh *mesh, const Array<uint32_t> &holeVertices, const Vector3 &normal) { #if XA_CLOSE_HOLES_CHECK_EDGE_INTERSECTION const uint32_t faceCount = mesh->faceCount(); #endif @@ -3412,8 +3230,7 @@ static bool meshCloseHole(Mesh *mesh, const Array<uint32_t> &holeVertices, const return true; } -static bool meshCloseHoles(Mesh *mesh, const Array<uint32_t> &boundaryLoops, const Vector3 &normal, uint32_t *holeCount, Array<uint32_t> *holeFaceCounts) -{ +static bool meshCloseHoles(Mesh *mesh, const Array<uint32_t> &boundaryLoops, const Vector3 &normal, uint32_t *holeCount, Array<uint32_t> *holeFaceCounts) { if (holeFaceCounts) holeFaceCounts->clear(); // Compute lengths. @@ -3469,8 +3286,7 @@ static bool meshCloseHoles(Mesh *mesh, const Array<uint32_t> &boundaryLoops, con return result; } -static bool meshIsPlanar(const Mesh &mesh) -{ +static bool meshIsPlanar(const Mesh &mesh) { const Vector3 p1 = mesh.position(mesh.vertexAt(0)); const Vector3 p2 = mesh.position(mesh.vertexAt(1)); const Vector3 p3 = mesh.position(mesh.vertexAt(2)); @@ -3496,14 +3312,12 @@ Fixing T-junctions. - Split edge. */ -struct SplitEdge -{ +struct SplitEdge { uint32_t edge; float t; uint32_t vertex; - bool operator<(const SplitEdge &other) const - { + bool operator<(const SplitEdge &other) const { if (edge < other.edge) return true; else if (edge == other.edge) { @@ -3515,8 +3329,7 @@ struct SplitEdge }; // Returns nullptr if there were no t-junctions to fix. -static Mesh *meshFixTJunctions(const Mesh &inputMesh, bool *duplicatedEdge, bool *failed, uint32_t *fixedTJunctionsCount) -{ +static Mesh *meshFixTJunctions(const Mesh &inputMesh, bool *duplicatedEdge, bool *failed, uint32_t *fixedTJunctionsCount) { if (duplicatedEdge) *duplicatedEdge = false; if (failed) @@ -3591,8 +3404,7 @@ static Mesh *meshFixTJunctions(const Mesh &inputMesh, bool *duplicatedEdge, bool } // boundaryLoops are the first edges for each boundary loop. -static void meshGetBoundaryLoops(const Mesh &mesh, Array<uint32_t> &boundaryLoops) -{ +static void meshGetBoundaryLoops(const Mesh &mesh, Array<uint32_t> &boundaryLoops) { const uint32_t edgeCount = mesh.edgeCount(); BitArray bitFlags(edgeCount); bitFlags.zeroOutMemory(); @@ -3607,26 +3419,23 @@ static void meshGetBoundaryLoops(const Mesh &mesh, Array<uint32_t> &boundaryLoop } } -struct Progress -{ - Progress(ProgressCategory::Enum category, ProgressFunc func, void *userData, uint32_t maxValue) : value(0), cancel(false), m_category(category), m_func(func), m_userData(userData), m_maxValue(maxValue), m_progress(0) - { +struct Progress { + Progress(ProgressCategory::Enum category, ProgressFunc func, void *userData, uint32_t maxValue) : + value(0), cancel(false), m_category(category), m_func(func), m_userData(userData), m_maxValue(maxValue), m_progress(0) { if (m_func) { if (!m_func(category, 0, userData)) cancel = true; } } - ~Progress() - { + ~Progress() { if (m_func) { if (!m_func(m_category, 100, m_userData)) cancel = true; } } - void update() - { + void update() { if (!m_func) return; m_mutex.lock(); @@ -3639,8 +3448,7 @@ struct Progress m_mutex.unlock(); } - void setMaxValue(uint32_t maxValue) - { + void setMaxValue(uint32_t maxValue) { m_mutex.lock(); m_maxValue = maxValue; m_mutex.unlock(); @@ -3658,32 +3466,31 @@ private: std::mutex m_mutex; }; -struct Spinlock -{ - void lock() { while(m_lock.test_and_set(std::memory_order_acquire)) {} } +struct Spinlock { + void lock() { + while (m_lock.test_and_set(std::memory_order_acquire)) { + } + } void unlock() { m_lock.clear(std::memory_order_release); } private: std::atomic_flag m_lock = ATOMIC_FLAG_INIT; }; -struct TaskGroupHandle -{ +struct TaskGroupHandle { uint32_t value = UINT32_MAX; }; -struct Task -{ +struct Task { void (*func)(void *userData); void *userData; }; #if XA_MULTITHREADED -class TaskScheduler -{ +class TaskScheduler { public: - TaskScheduler() : m_shutdown(false) - { + TaskScheduler() : + m_shutdown(false) { m_threadIndex = 0; // Max with current task scheduler usage is 1 per thread + 1 deep nesting, but allow for some slop. m_maxGroups = std::thread::hardware_concurrency() * 4; @@ -3701,8 +3508,7 @@ public: } } - ~TaskScheduler() - { + ~TaskScheduler() { m_shutdown = true; for (uint32_t i = 0; i < m_workers.size(); i++) { Worker &worker = m_workers[i]; @@ -3720,13 +3526,11 @@ public: XA_FREE(m_groups); } - uint32_t threadCount() const - { + uint32_t threadCount() const { return max(1u, std::thread::hardware_concurrency()); // Including the main thread. } - TaskGroupHandle createTaskGroup(uint32_t reserveSize = 0) - { + TaskGroupHandle createTaskGroup(uint32_t reserveSize = 0) { // Claim the first free group. for (uint32_t i = 0; i < m_maxGroups; i++) { TaskGroup &group = m_groups[i]; @@ -3748,8 +3552,7 @@ public: return handle; } - void run(TaskGroupHandle handle, Task task) - { + void run(TaskGroupHandle handle, Task task) { XA_DEBUG_ASSERT(handle.value != UINT32_MAX); TaskGroup &group = m_groups[handle.value]; group.queueLock.lock(); @@ -3763,8 +3566,7 @@ public: } } - void wait(TaskGroupHandle *handle) - { + void wait(TaskGroupHandle *handle) { if (handle->value == UINT32_MAX) { XA_DEBUG_ASSERT(false); return; @@ -3792,8 +3594,7 @@ public: static uint32_t currentThreadIndex() { return m_threadIndex; } private: - struct TaskGroup - { + struct TaskGroup { std::atomic<bool> free; Array<Task> queue; // Items are never removed. queueHead is incremented to pop items. uint32_t queueHead = 0; @@ -3801,8 +3602,7 @@ private: std::atomic<uint32_t> ref; // Increment when a task is enqueued, decrement when a task finishes. }; - struct Worker - { + struct Worker { std::thread *thread = nullptr; std::mutex mutex; std::condition_variable cv; @@ -3815,12 +3615,11 @@ private: std::atomic<bool> m_shutdown; static thread_local uint32_t m_threadIndex; - static void workerThread(TaskScheduler *scheduler, Worker *worker, uint32_t threadIndex) - { + static void workerThread(TaskScheduler *scheduler, Worker *worker, uint32_t threadIndex) { m_threadIndex = threadIndex; std::unique_lock<std::mutex> lock(worker->mutex); for (;;) { - worker->cv.wait(lock, [=]{ return worker->wakeup.load(); }); + worker->cv.wait(lock, [=] { return worker->wakeup.load(); }); worker->wakeup = false; for (;;) { if (scheduler->m_shutdown) @@ -3851,22 +3650,18 @@ private: thread_local uint32_t TaskScheduler::m_threadIndex; #else -class TaskScheduler -{ +class TaskScheduler { public: - ~TaskScheduler() - { + ~TaskScheduler() { for (uint32_t i = 0; i < m_groups.size(); i++) destroyGroup({ i }); } - uint32_t threadCount() const - { + uint32_t threadCount() const { return 1; } - TaskGroupHandle createTaskGroup(uint32_t reserveSize = 0) - { + TaskGroupHandle createTaskGroup(uint32_t reserveSize = 0) { TaskGroup *group = XA_NEW(MemTag::Default, TaskGroup); group->queue.reserve(reserveSize); m_groups.push_back(group); @@ -3875,13 +3670,11 @@ public: return handle; } - void run(TaskGroupHandle handle, Task task) - { + void run(TaskGroupHandle handle, Task task) { m_groups[handle.value]->queue.push_back(task); } - void wait(TaskGroupHandle *handle) - { + void wait(TaskGroupHandle *handle) { if (handle->value == UINT32_MAX) { XA_DEBUG_ASSERT(false); return; @@ -3897,8 +3690,7 @@ public: static uint32_t currentThreadIndex() { return 0; } private: - void destroyGroup(TaskGroupHandle handle) - { + void destroyGroup(TaskGroupHandle handle) { TaskGroup *group = m_groups[handle.value]; if (group) { group->~TaskGroup(); @@ -3907,8 +3699,7 @@ private: } } - struct TaskGroup - { + struct TaskGroup { Array<Task> queue; }; @@ -3921,8 +3712,7 @@ const uint8_t TGA_TYPE_RGB = 2; const uint8_t TGA_ORIGIN_UPPER = 0x20; #pragma pack(push, 1) -struct TgaHeader -{ +struct TgaHeader { uint8_t id_length; uint8_t colormap_type; uint8_t image_type; @@ -3939,8 +3729,7 @@ struct TgaHeader }; #pragma pack(pop) -static void WriteTga(const char *filename, const uint8_t *data, uint32_t width, uint32_t height) -{ +static void WriteTga(const char *filename, const uint8_t *data, uint32_t width, uint32_t height) { XA_DEBUG_ASSERT(sizeof(TgaHeader) == TgaHeader::Size); FILE *f; XA_FOPEN(f, filename, "wb"); @@ -3965,12 +3754,10 @@ static void WriteTga(const char *filename, const uint8_t *data, uint32_t width, } #endif -template<typename T> -class ThreadLocal -{ +template <typename T> +class ThreadLocal { public: - ThreadLocal() - { + ThreadLocal() { #if XA_MULTITHREADED const uint32_t n = std::thread::hardware_concurrency(); #else @@ -3981,8 +3768,7 @@ public: new (&m_array[i]) T; } - ~ThreadLocal() - { + ~ThreadLocal() { #if XA_MULTITHREADED const uint32_t n = std::thread::hardware_concurrency(); #else @@ -3993,8 +3779,7 @@ public: XA_FREE(m_array); } - T &get() const - { + T &get() const { return m_array[TaskScheduler::currentThreadIndex()]; } @@ -4002,11 +3787,9 @@ private: T *m_array; }; -class UniformGrid2 -{ +class UniformGrid2 { public: - void reset(const Vector2 *positions, const uint32_t *indices = nullptr, uint32_t reserveEdgeCount = 0) - { + void reset(const Vector2 *positions, const uint32_t *indices = nullptr, uint32_t reserveEdgeCount = 0) { m_edges.clear(); if (reserveEdgeCount > 0) m_edges.reserve(reserveEdgeCount); @@ -4015,14 +3798,12 @@ public: m_cellDataOffsets.clear(); } - void append(uint32_t edge) - { + void append(uint32_t edge) { XA_DEBUG_ASSERT(m_cellDataOffsets.isEmpty()); m_edges.push_back(edge); } - bool intersect(Vector2 v1, Vector2 v2, float epsilon) - { + bool intersect(Vector2 v1, Vector2 v2, float epsilon) { const uint32_t edgeCount = m_edges.size(); bool bruteForce = edgeCount <= 64; if (!bruteForce && m_cellDataOffsets.isEmpty()) @@ -4048,8 +3829,7 @@ public: return false; } - bool intersectSelf(float epsilon) - { + bool intersectSelf(float epsilon) { const uint32_t edgeCount = m_edges.size(); bool bruteForce = edgeCount <= 64; if (!bruteForce && m_cellDataOffsets.isEmpty()) @@ -4079,8 +3859,7 @@ public: } #if XA_DEBUG_EXPORT_BOUNDARY_GRID - void debugExport(const char *filename) - { + void debugExport(const char *filename) { Array<uint8_t> image; image.resize(m_gridWidth * m_gridHeight * 3); for (uint32_t y = 0; y < m_gridHeight; y++) { @@ -4102,8 +3881,7 @@ public: #endif private: - bool createGrid() - { + bool createGrid() { // Compute edge extents. Min will be the grid origin. const uint32_t edgeCount = m_edges.size(); Extents2 edgeExtents; @@ -4155,8 +3933,7 @@ private: return true; } - void computePotentialEdges(Vector2 p1, Vector2 p2) - { + void computePotentialEdges(Vector2 p1, Vector2 p2) { m_potentialEdges.clear(); traverse(p1, p2); for (uint32_t j = 0; j < m_traversedCellOffsets.size(); j++) { @@ -4174,8 +3951,7 @@ private: } // "A Fast Voxel Traversal Algorithm for Ray Tracing" - void traverse(Vector2 p1, Vector2 p2) - { + void traverse(Vector2 p1, Vector2 p2) { const Vector2 dir = p2 - p1; const Vector2 normal = normalizeSafe(dir, Vector2(0.0f), kEpsilon); const int stepX = dir.x >= 0 ? 1 : -1; @@ -4196,14 +3972,12 @@ private: if (normal.x > kEpsilon || normal.x < -kEpsilon) { tMaxX = (distToNextCellX * stepX) / normal.x; tDeltaX = (m_cellSize * stepX) / normal.x; - } - else + } else tMaxX = tDeltaX = FLT_MAX; if (normal.y > kEpsilon || normal.y < -kEpsilon) { tMaxY = (distToNextCellY * stepY) / normal.y; tDeltaY = (m_cellSize * stepY) / normal.y; - } - else + } else tMaxY = tDeltaY = FLT_MAX; m_traversedCellOffsets.clear(); m_traversedCellOffsets.push_back(firstCell[0] + firstCell[1] * m_gridWidth); @@ -4230,8 +4004,7 @@ private: } } - bool edgesIntersect(uint32_t edge1, uint32_t edge2, float epsilon) const - { + bool edgesIntersect(uint32_t edge1, uint32_t edge2, float epsilon) const { if (edge1 == edge2) return false; const uint32_t ai[2] = { vertexAt(meshEdgeIndex0(edge1)), vertexAt(meshEdgeIndex1(edge1)) }; @@ -4242,28 +4015,23 @@ private: return linesIntersect(m_positions[ai[0]], m_positions[ai[1]], m_positions[bi[0]], m_positions[bi[1]], epsilon); } - uint32_t cellX(float x) const - { + uint32_t cellX(float x) const { return min((uint32_t)max(0.0f, (x - m_gridOrigin.x) / m_cellSize), m_gridWidth - 1u); } - uint32_t cellY(float y) const - { + uint32_t cellY(float y) const { return min((uint32_t)max(0.0f, (y - m_gridOrigin.y) / m_cellSize), m_gridHeight - 1u); } - Vector2 edgePosition0(uint32_t edge) const - { + Vector2 edgePosition0(uint32_t edge) const { return m_positions[vertexAt(meshEdgeIndex0(edge))]; } - Vector2 edgePosition1(uint32_t edge) const - { + Vector2 edgePosition1(uint32_t edge) const { return m_positions[vertexAt(meshEdgeIndex1(edge))]; } - uint32_t vertexAt(uint32_t index) const - { + uint32_t vertexAt(uint32_t index) const { return m_indices ? m_indices[index] : index; } @@ -4279,34 +4047,29 @@ private: Array<uint32_t> m_traversedCellOffsets; }; -struct UvMeshChart -{ +struct UvMeshChart { Array<uint32_t> faces; Array<uint32_t> indices; uint32_t material; }; -struct UvMesh -{ +struct UvMesh { UvMeshDecl decl; Array<uint32_t> indices; Array<UvMeshChart *> charts; Array<uint32_t> vertexToChartMap; }; -struct UvMeshInstance -{ +struct UvMeshInstance { UvMesh *mesh; Array<Vector2> texcoords; bool rotateCharts; }; namespace raster { -class ClippedTriangle -{ +class ClippedTriangle { public: - ClippedTriangle(const Vector2 &a, const Vector2 &b, const Vector2 &c) - { + ClippedTriangle(const Vector2 &a, const Vector2 &b, const Vector2 &c) { m_numVertices = 3; m_activeVertexBuffer = 0; m_verticesA[0] = a; @@ -4316,20 +4079,19 @@ public: m_vertexBuffers[1] = m_verticesB; } - void clipHorizontalPlane(float offset, float clipdirection) - { - Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; + void clipHorizontalPlane(float offset, float clipdirection) { + Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; m_activeVertexBuffer ^= 1; Vector2 *v2 = m_vertexBuffers[m_activeVertexBuffer]; v[m_numVertices] = v[0]; - float dy2, dy1 = offset - v[0].y; - int dy2in, dy1in = clipdirection * dy1 >= 0; - uint32_t p = 0; + float dy2, dy1 = offset - v[0].y; + int dy2in, dy1in = clipdirection * dy1 >= 0; + uint32_t p = 0; for (uint32_t k = 0; k < m_numVertices; k++) { - dy2 = offset - v[k + 1].y; + dy2 = offset - v[k + 1].y; dy2in = clipdirection * dy2 >= 0; if (dy1in) v2[p++] = v[k]; - if ( dy1in + dy2in == 1 ) { // not both in/out + if (dy1in + dy2in == 1) { // not both in/out float dx = v[k + 1].x - v[k].x; float dy = v[k + 1].y - v[k].y; v2[p++] = Vector2(v[k].x + dy1 * (dx / dy), offset); @@ -4340,20 +4102,19 @@ public: m_numVertices = p; } - void clipVerticalPlane(float offset, float clipdirection) - { - Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; + void clipVerticalPlane(float offset, float clipdirection) { + Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; m_activeVertexBuffer ^= 1; Vector2 *v2 = m_vertexBuffers[m_activeVertexBuffer]; v[m_numVertices] = v[0]; - float dx2, dx1 = offset - v[0].x; - int dx2in, dx1in = clipdirection * dx1 >= 0; - uint32_t p = 0; + float dx2, dx1 = offset - v[0].x; + int dx2in, dx1in = clipdirection * dx1 >= 0; + uint32_t p = 0; for (uint32_t k = 0; k < m_numVertices; k++) { dx2 = offset - v[k + 1].x; dx2in = clipdirection * dx2 >= 0; if (dx1in) v2[p++] = v[k]; - if ( dx1in + dx2in == 1 ) { // not both in/out + if (dx1in + dx2in == 1) { // not both in/out float dx = v[k + 1].x - v[k].x; float dy = v[k + 1].y - v[k].y; v2[p++] = Vector2(offset, v[k].y + dx1 * (dy / dx)); @@ -4364,9 +4125,8 @@ public: m_numVertices = p; } - void computeArea() - { - Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; + void computeArea() { + Vector2 *v = m_vertexBuffers[m_activeVertexBuffer]; v[m_numVertices] = v[0]; m_area = 0; float centroidx = 0, centroidy = 0; @@ -4380,8 +4140,7 @@ public: m_area = 0.5f * fabsf(m_area); } - void clipAABox(float x0, float y0, float x1, float y1) - { + void clipAABox(float x0, float y0, float x1, float y1) { clipVerticalPlane(x0, -1); clipHorizontalPlane(y0, -1); clipVerticalPlane(x1, 1); @@ -4389,8 +4148,7 @@ public: computeArea(); } - float area() const - { + float area() const { return m_area; } @@ -4407,10 +4165,8 @@ private: typedef bool (*SamplingCallback)(void *param, int x, int y); /// A triangle for rasterization. -struct Triangle -{ - Triangle(const Vector2 &v0, const Vector2 &v1, const Vector2 &v2) - { +struct Triangle { + Triangle(const Vector2 &v0, const Vector2 &v1, const Vector2 &v2) { // Init vertices. this->v1 = v0; this->v2 = v2; @@ -4422,8 +4178,7 @@ struct Triangle computeUnitInwardNormals(); } - bool isValid() - { + bool isValid() { const Vector2 e0 = v3 - v1; const Vector2 e1 = v2 - v1; const float area = e0.y * e1.x - e1.y * e0.x; @@ -4431,18 +4186,17 @@ struct Triangle } // extents has to be multiple of BK_SIZE!! - bool drawAA(const Vector2 &extents, SamplingCallback cb, void *param) - { - const float PX_INSIDE = 1.0f/sqrtf(2.0f); - const float PX_OUTSIDE = -1.0f/sqrtf(2.0f); + bool drawAA(const Vector2 &extents, SamplingCallback cb, void *param) { + const float PX_INSIDE = 1.0f / sqrtf(2.0f); + const float PX_OUTSIDE = -1.0f / sqrtf(2.0f); const float BK_SIZE = 8; - const float BK_INSIDE = sqrtf(BK_SIZE*BK_SIZE/2.0f); - const float BK_OUTSIDE = -sqrtf(BK_SIZE*BK_SIZE/2.0f); + const float BK_INSIDE = sqrtf(BK_SIZE * BK_SIZE / 2.0f); + const float BK_OUTSIDE = -sqrtf(BK_SIZE * BK_SIZE / 2.0f); // Bounding rectangle float minx = floorf(max(min3(v1.x, v2.x, v3.x), 0.0f)); float miny = floorf(max(min3(v1.y, v2.y, v3.y), 0.0f)); - float maxx = ceilf( min(max3(v1.x, v2.x, v3.x), extents.x - 1.0f)); - float maxy = ceilf( min(max3(v1.y, v2.y, v3.y), extents.y - 1.0f)); + float maxx = ceilf(min(max3(v1.x, v2.x, v3.x), extents.x - 1.0f)); + float maxy = ceilf(min(max3(v1.y, v2.y, v3.y), extents.y - 1.0f)); // There's no reason to align the blocks to the viewport, instead we align them to the origin of the triangle bounds. minx = floorf(minx); miny = floorf(miny); @@ -4467,9 +4221,9 @@ struct Triangle float bC = C2 + n2.x * xc + n2.y * yc; float cC = C3 + n3.x * xc + n3.y * yc; // Skip block when outside an edge - if ( (aC <= BK_OUTSIDE) || (bC <= BK_OUTSIDE) || (cC <= BK_OUTSIDE) ) continue; + if ((aC <= BK_OUTSIDE) || (bC <= BK_OUTSIDE) || (cC <= BK_OUTSIDE)) continue; // Accept whole block when totally covered - if ( (aC >= BK_INSIDE) && (bC >= BK_INSIDE) && (cC >= BK_INSIDE) ) { + if ((aC >= BK_INSIDE) && (bC >= BK_INSIDE) && (cC >= BK_INSIDE)) { for (float y = y0; y < y0 + BK_SIZE; y++) { for (float x = x0; x < x0 + BK_SIZE; x++) { if (!cb(param, (int)x, (int)y)) @@ -4512,10 +4266,9 @@ struct Triangle } private: - void flipBackface() - { + void flipBackface() { // check if triangle is backfacing, if so, swap two vertices - if ( ((v3.x - v1.x) * (v2.y - v1.y) - (v3.y - v1.y) * (v2.x - v1.x)) < 0 ) { + if (((v3.x - v1.x) * (v2.y - v1.y) - (v3.y - v1.y) * (v2.x - v1.x)) < 0) { Vector2 hv = v1; v1 = v2; v2 = hv; // swap pos @@ -4523,8 +4276,7 @@ private: } // compute unit inward normals for each edge. - void computeUnitInwardNormals() - { + void computeUnitInwardNormals() { n1 = v1 - v2; n1 = Vector2(-n1.y, n1.x); n1 = n1 * (1.0f / sqrtf(dot(n1, n1))); @@ -4542,8 +4294,7 @@ private: }; // Process the given triangle. Returns false if rasterization was interrupted by the callback. -static bool drawTriangle(const Vector2 &extents, const Vector2 v[3], SamplingCallback cb, void *param) -{ +static bool drawTriangle(const Vector2 &extents, const Vector2 v[3], SamplingCallback cb, void *param) { Triangle tri(v[0], v[1], v[2]); // @@ It would be nice to have a conservative drawing mode that enlarges the triangle extents by one texel and is able to handle degenerate triangles. // @@ Maybe the simplest thing to do would be raster triangle edges. @@ -4566,18 +4317,16 @@ namespace sparse { * elements for each row of the matrix. As with the FullVector the * dimension of the matrix is constant. **/ -class Matrix -{ +class Matrix { public: // An element of the sparse array. - struct Coefficient - { - uint32_t x; // column + struct Coefficient { + uint32_t x; // column float v; // value }; - Matrix(uint32_t d) : m_width(d), m_array(MemTag::Matrix) - { + Matrix(uint32_t d) : + m_width(d), m_array(MemTag::Matrix) { m_array.resize(d); m_array.runCtors(); #if XA_DEBUG_HEAP @@ -4585,9 +4334,9 @@ public: m_array[i].setMemTag(MemTag::Matrix); #endif } - - Matrix(uint32_t w, uint32_t h) : m_width(w), m_array(MemTag::Matrix) - { + + Matrix(uint32_t w, uint32_t h) : + m_width(w), m_array(MemTag::Matrix) { m_array.resize(h); m_array.runCtors(); #if XA_DEBUG_HEAP @@ -4595,9 +4344,8 @@ public: m_array[i].setMemTag(MemTag::Matrix); #endif } - - ~Matrix() - { + + ~Matrix() { m_array.runDtors(); } @@ -4608,10 +4356,9 @@ public: bool isSquare() const { return width() == height(); } // x is column, y is row - float getCoefficient(uint32_t x, uint32_t y) const - { - XA_DEBUG_ASSERT( x < width() ); - XA_DEBUG_ASSERT( y < height() ); + float getCoefficient(uint32_t x, uint32_t y) const { + XA_DEBUG_ASSERT(x < width()); + XA_DEBUG_ASSERT(y < height()); const uint32_t count = m_array[y].size(); for (uint32_t i = 0; i < count; i++) { if (m_array[y][i].x == x) return m_array[y][i].v; @@ -4619,10 +4366,9 @@ public: return 0.0f; } - void setCoefficient(uint32_t x, uint32_t y, float f) - { - XA_DEBUG_ASSERT( x < width() ); - XA_DEBUG_ASSERT( y < height() ); + void setCoefficient(uint32_t x, uint32_t y, float f) { + XA_DEBUG_ASSERT(x < width()); + XA_DEBUG_ASSERT(y < height()); const uint32_t count = m_array[y].size(); for (uint32_t i = 0; i < count; i++) { if (m_array[y][i].x == x) { @@ -4632,13 +4378,12 @@ public: } if (f != 0.0f) { Coefficient c = { x, f }; - m_array[y].push_back( c ); + m_array[y].push_back(c); } } - float dotRow(uint32_t y, const FullVector &v) const - { - XA_DEBUG_ASSERT( y < height() ); + float dotRow(uint32_t y, const FullVector &v) const { + XA_DEBUG_ASSERT(y < height()); const uint32_t count = m_array[y].size(); float sum = 0; for (uint32_t i = 0; i < count; i++) { @@ -4647,8 +4392,7 @@ public: return sum; } - void madRow(uint32_t y, float alpha, FullVector &v) const - { + void madRow(uint32_t y, float alpha, FullVector &v) const { XA_DEBUG_ASSERT(y < height()); const uint32_t count = m_array[y].size(); for (uint32_t i = 0; i < count; i++) { @@ -4656,9 +4400,8 @@ public: } } - void clearRow(uint32_t y) - { - XA_DEBUG_ASSERT( y < height() ); + void clearRow(uint32_t y) { + XA_DEBUG_ASSERT(y < height()); m_array[y].clear(); } @@ -4669,12 +4412,11 @@ private: const uint32_t m_width; /// Array of matrix elements. - Array< Array<Coefficient> > m_array; + Array<Array<Coefficient>> m_array; }; // y = a * x + y -static void saxpy(float a, const FullVector &x, FullVector &y) -{ +static void saxpy(float a, const FullVector &x, FullVector &y) { XA_DEBUG_ASSERT(x.dimension() == y.dimension()); const uint32_t dim = x.dimension(); for (uint32_t i = 0; i < dim; i++) { @@ -4682,8 +4424,7 @@ static void saxpy(float a, const FullVector &x, FullVector &y) } } -static void copy(const FullVector &x, FullVector &y) -{ +static void copy(const FullVector &x, FullVector &y) { XA_DEBUG_ASSERT(x.dimension() == y.dimension()); const uint32_t dim = x.dimension(); for (uint32_t i = 0; i < dim; i++) { @@ -4691,16 +4432,14 @@ static void copy(const FullVector &x, FullVector &y) } } -static void scal(float a, FullVector &x) -{ +static void scal(float a, FullVector &x) { const uint32_t dim = x.dimension(); for (uint32_t i = 0; i < dim; i++) { x[i] *= a; } } -static float dot(const FullVector &x, const FullVector &y) -{ +static float dot(const FullVector &x, const FullVector &y) { XA_DEBUG_ASSERT(x.dimension() == y.dimension()); const uint32_t dim = x.dimension(); float sum = 0; @@ -4711,24 +4450,22 @@ static float dot(const FullVector &x, const FullVector &y) } // y = M * x -static void mult(const Matrix &M, const FullVector &x, FullVector &y) -{ +static void mult(const Matrix &M, const FullVector &x, FullVector &y) { uint32_t w = M.width(); uint32_t h = M.height(); - XA_DEBUG_ASSERT( w == x.dimension() ); + XA_DEBUG_ASSERT(w == x.dimension()); XA_UNUSED(w); - XA_DEBUG_ASSERT( h == y.dimension() ); + XA_DEBUG_ASSERT(h == y.dimension()); for (uint32_t i = 0; i < h; i++) y[i] = M.dotRow(i, x); } // y = alpha*A*x + beta*y -static void sgemv(float alpha, const Matrix &A, const FullVector &x, float beta, FullVector &y) -{ +static void sgemv(float alpha, const Matrix &A, const FullVector &x, float beta, FullVector &y) { const uint32_t w = A.width(); const uint32_t h = A.height(); - XA_DEBUG_ASSERT( w == x.dimension() ); - XA_DEBUG_ASSERT( h == y.dimension() ); + XA_DEBUG_ASSERT(w == x.dimension()); + XA_DEBUG_ASSERT(h == y.dimension()); XA_UNUSED(w); XA_UNUSED(h); for (uint32_t i = 0; i < h; i++) @@ -4736,8 +4473,7 @@ static void sgemv(float alpha, const Matrix &A, const FullVector &x, float beta, } // dot y-row of A by x-column of B -static float dotRowColumn(int y, const Matrix &A, int x, const Matrix &B) -{ +static float dotRowColumn(int y, const Matrix &A, int x, const Matrix &B) { const Array<Matrix::Coefficient> &row = A.getRow(y); const uint32_t count = row.size(); float sum = 0.0f; @@ -4748,8 +4484,7 @@ static float dotRowColumn(int y, const Matrix &A, int x, const Matrix &B) return sum; } -static void transpose(const Matrix &A, Matrix &B) -{ +static void transpose(const Matrix &A, Matrix &B) { XA_DEBUG_ASSERT(A.width() == B.height()); XA_DEBUG_ASSERT(B.width() == A.height()); const uint32_t w = A.width(); @@ -4768,8 +4503,7 @@ static void transpose(const Matrix &A, Matrix &B) } } -static void sgemm(float alpha, const Matrix &A, const Matrix &B, float beta, Matrix &C) -{ +static void sgemm(float alpha, const Matrix &A, const Matrix &B, float beta, Matrix &C) { const uint32_t w = C.width(); const uint32_t h = C.height(); #if XA_DEBUG @@ -4793,8 +4527,7 @@ static void sgemm(float alpha, const Matrix &A, const Matrix &B, float beta, Mat } // C = A * B -static void mult(const Matrix &A, const Matrix &B, Matrix &C) -{ +static void mult(const Matrix &A, const Matrix &B, Matrix &C) { sgemm(1.0f, A, B, 0.0f, C); } @@ -4804,22 +4537,19 @@ namespace segment { // - Insertion is o(n) // - Smallest element goes at the end, so that popping it is o(1). -struct CostQueue -{ - CostQueue(uint32_t size = UINT32_MAX) : m_maxSize(size), m_pairs(MemTag::SegmentAtlasChartCandidates) {} +struct CostQueue { + CostQueue(uint32_t size = UINT32_MAX) : + m_maxSize(size), m_pairs(MemTag::SegmentAtlasChartCandidates) {} - float peekCost() const - { + float peekCost() const { return m_pairs.back().cost; } - uint32_t peekFace() const - { + uint32_t peekFace() const { return m_pairs.back().face; } - void push(float cost, uint32_t face) - { + void push(float cost, uint32_t face) { const Pair p = { cost, face }; if (m_pairs.isEmpty() || cost < peekCost()) m_pairs.push_back(p); @@ -4836,29 +4566,25 @@ struct CostQueue } } - uint32_t pop() - { + uint32_t pop() { XA_DEBUG_ASSERT(!m_pairs.isEmpty()); uint32_t f = m_pairs.back().face; m_pairs.pop_back(); return f; } - XA_INLINE void clear() - { + XA_INLINE void clear() { m_pairs.clear(); } - XA_INLINE uint32_t count() const - { + XA_INLINE uint32_t count() const { return m_pairs.size(); } private: const uint32_t m_maxSize; - struct Pair - { + struct Pair { float cost; uint32_t face; }; @@ -4866,9 +4592,9 @@ private: Array<Pair> m_pairs; }; -struct Chart -{ - Chart() : faces(MemTag::SegmentAtlasChartFaces) {} +struct Chart { + Chart() : + faces(MemTag::SegmentAtlasChartFaces) {} int id = -1; Basis basis; // Best fit normal. @@ -4882,12 +4608,11 @@ struct Chart CostQueue candidates; }; -struct Atlas -{ - Atlas() : m_edgeLengths(MemTag::SegmentAtlasMeshData), m_faceAreas(MemTag::SegmentAtlasMeshData), m_faceNormals(MemTag::SegmentAtlasMeshData), m_texcoords(MemTag::SegmentAtlasMeshData), m_bestTriangles(10), m_nextPlanarRegionFace(MemTag::SegmentAtlasPlanarRegions), m_facePlanarRegionId(MemTag::SegmentAtlasPlanarRegions) {} +struct Atlas { + Atlas() : + m_edgeLengths(MemTag::SegmentAtlasMeshData), m_faceAreas(MemTag::SegmentAtlasMeshData), m_faceNormals(MemTag::SegmentAtlasMeshData), m_texcoords(MemTag::SegmentAtlasMeshData), m_bestTriangles(10), m_nextPlanarRegionFace(MemTag::SegmentAtlasPlanarRegions), m_facePlanarRegionId(MemTag::SegmentAtlasPlanarRegions) {} - ~Atlas() - { + ~Atlas() { const uint32_t chartCount = m_charts.size(); for (uint32_t i = 0; i < chartCount; i++) { m_charts[i]->~Chart(); @@ -4900,8 +4625,7 @@ struct Atlas const Array<uint32_t> &chartFaces(uint32_t i) const { return m_charts[i]->faces; } const Basis &chartBasis(uint32_t chartIndex) const { return m_charts[chartIndex]->basis; } - void reset(uint32_t meshId, uint32_t chartGroupId, const Mesh *mesh, const ChartOptions &options) - { + void reset(uint32_t meshId, uint32_t chartGroupId, const Mesh *mesh, const ChartOptions &options) { XA_UNUSED(meshId); XA_UNUSED(chartGroupId); XA_PROFILE_START(buildAtlasInit) @@ -4995,8 +4719,7 @@ struct Atlas XA_PROFILE_END(buildAtlasInit) } - void placeSeeds(float threshold) - { + void placeSeeds(float threshold) { XA_PROFILE_START(buildAtlasPlaceSeeds) // Instead of using a predefiened number of seeds: // - Add seeds one by one, growing chart until a certain treshold. @@ -5010,8 +4733,7 @@ struct Atlas } // Returns true if any of the charts can grow more. - void growCharts(float threshold) - { + void growCharts(float threshold) { XA_PROFILE_START(buildAtlasGrowCharts) for (;;) { if (m_facesLeft == 0) @@ -5057,8 +4779,7 @@ struct Atlas XA_PROFILE_END(buildAtlasGrowCharts) } - void resetCharts() - { + void resetCharts() { XA_PROFILE_START(buildAtlasResetCharts) const uint32_t faceCount = m_mesh->faceCount(); for (uint32_t i = 0; i < faceCount; i++) @@ -5083,8 +4804,7 @@ struct Atlas XA_PROFILE_END(buildAtlasResetCharts) } - bool relocateSeeds() - { + bool relocateSeeds() { XA_PROFILE_START(buildAtlasRelocateSeeds) bool anySeedChanged = false; const uint32_t chartCount = m_charts.size(); @@ -5097,8 +4817,7 @@ struct Atlas return anySeedChanged; } - void fillHoles(float threshold) - { + void fillHoles(float threshold) { XA_PROFILE_START(buildAtlasFillHoles) while (m_facesLeft > 0) createRandomChart(threshold); @@ -5106,8 +4825,7 @@ struct Atlas } #if XA_MERGE_CHARTS - void mergeCharts() - { + void mergeCharts() { XA_PROFILE_START(buildAtlasMergeCharts) const uint32_t chartCount = m_charts.size(); // Merge charts progressively until there's none left to merge. @@ -5165,7 +4883,7 @@ struct Atlas // Merge if chart2 has a single face. // chart1 must have more than 1 face. // chart2 area must be <= 10% of chart1 area. - if (m_sharedBoundaryLengthsNoSeams[cc] > 0.0f && chart->faces.size() > 1 && chart2->faces.size() == 1 && chart2->area <= chart->area * 0.1f) + if (m_sharedBoundaryLengthsNoSeams[cc] > 0.0f && chart->faces.size() > 1 && chart2->faces.size() == 1 && chart2->area <= chart->area * 0.1f) goto merge; // Merge if chart2 has two faces (probably a quad), and chart1 bounds at least 2 of its edges. if (chart2->faces.size() == 2 && m_sharedBoundaryEdgeCountNoSeams[cc] >= 2) @@ -5173,8 +4891,8 @@ struct Atlas // Merge if chart2 is wholely inside chart1, ignoring seams. if (m_sharedBoundaryLengthsNoSeams[cc] > 0.0f && equal(m_sharedBoundaryLengthsNoSeams[cc], chart2->boundaryLength, kEpsilon)) goto merge; - if (m_sharedBoundaryLengths[cc] > 0.2f * max(0.0f, chart->boundaryLength - externalBoundaryLength) || - m_sharedBoundaryLengths[cc] > 0.75f * chart2->boundaryLength) + if (m_sharedBoundaryLengths[cc] > 0.2f * max(0.0f, chart->boundaryLength - externalBoundaryLength) || + m_sharedBoundaryLengths[cc] > 0.75f * chart2->boundaryLength) goto merge; continue; merge: @@ -5212,8 +4930,7 @@ struct Atlas #endif private: - void createRandomChart(float threshold) - { + void createRandomChart(float threshold) { Chart *chart = XA_NEW(MemTag::Default, Chart); chart->id = (int)m_charts.size(); m_charts.push_back(chart); @@ -5239,15 +4956,13 @@ private: } } - bool isChartBoundaryEdge(const Chart *chart, uint32_t edge) const - { + bool isChartBoundaryEdge(const Chart *chart, uint32_t edge) const { const uint32_t oppositeEdge = m_mesh->oppositeEdge(edge); const uint32_t oppositeFace = meshEdgeFace(oppositeEdge); return oppositeEdge == UINT32_MAX || m_faceCharts[oppositeFace] != chart->id; } - bool computeChartBasis(Chart *chart, Basis *basis) - { + bool computeChartBasis(Chart *chart, Basis *basis) { const uint32_t faceCount = chart->faces.size(); m_tempPoints.resize(chart->faces.size() * 3); for (uint32_t i = 0; i < faceCount; i++) { @@ -5258,8 +4973,7 @@ private: return Fit::computeBasis(m_tempPoints.data(), m_tempPoints.size(), basis); } - bool isFaceFlipped(uint32_t face) const - { + bool isFaceFlipped(uint32_t face) const { const Vector2 &v1 = m_texcoords[face * 3 + 0]; const Vector2 &v2 = m_texcoords[face * 3 + 1]; const Vector2 &v3 = m_texcoords[face * 3 + 2]; @@ -5267,8 +4981,7 @@ private: return parametricArea < 0.0f; } - void parameterizeChart(const Chart *chart) - { + void parameterizeChart(const Chart *chart) { const uint32_t faceCount = chart->faces.size(); for (uint32_t i = 0; i < faceCount; i++) { const uint32_t face = chart->faces[i]; @@ -5281,8 +4994,7 @@ private: } // m_faceCharts for the chart faces must be set to the chart ID. Needed to compute boundary edges. - bool isChartParameterizationValid(const Chart *chart) - { + bool isChartParameterizationValid(const Chart *chart) { const uint32_t faceCount = chart->faces.size(); // Check for flipped faces in the parameterization. OK if all are flipped. uint32_t flippedFaceCount = 0; @@ -5307,15 +5019,14 @@ private: return true; } - bool addFaceToChart(Chart *chart, uint32_t face) - { + bool addFaceToChart(Chart *chart, uint32_t face) { XA_DEBUG_ASSERT(m_faceCharts[face] == -1); const uint32_t oldFaceCount = chart->faces.size(); const bool firstFace = oldFaceCount == 0; // Append the face and any coplanar connected faces to the chart faces array. chart->faces.push_back(face); uint32_t coplanarFace = m_nextPlanarRegionFace[face]; - while (coplanarFace != face) { + while (coplanarFace != face) { XA_DEBUG_ASSERT(m_faceCharts[coplanarFace] == -1); chart->faces.push_back(coplanarFace); coplanarFace = m_nextPlanarRegionFace[coplanarFace]; @@ -5327,7 +5038,7 @@ private: // Use the first face normal. // Use any edge as the tangent vector. basis.normal = m_faceNormals[face]; - basis.tangent = normalize(m_mesh->position(m_mesh->vertexAt(face * 3 + 0)) - m_mesh->position(m_mesh->vertexAt(face * 3 + 1)), kEpsilon); + basis.tangent = normalize(m_mesh->position(m_mesh->vertexAt(face * 3 + 0)) - m_mesh->position(m_mesh->vertexAt(face * 3 + 1)), 0); basis.bitangent = cross(basis.normal, basis.tangent); } else { // Use best fit normal. @@ -5385,8 +5096,7 @@ private: } // Returns true if the seed has changed. - bool relocateSeed(Chart *chart) - { + bool relocateSeed(Chart *chart) { // Find the first N triangles that fit the proxy best. const uint32_t faceCount = chart->faces.size(); m_bestTriangles.clear(); @@ -5425,8 +5135,7 @@ private: } // Evaluate combined metric. - float evaluateCost(Chart *chart, uint32_t face) const - { + float evaluateCost(Chart *chart, uint32_t face) const { if (dot(m_faceNormals[face], chart->basis.normal) <= 0.26f) // ~75 degrees return FLT_MAX; // Estimate boundary length and area: @@ -5467,16 +5176,14 @@ private: } // Returns a value in [0-1]. - float evaluateProxyFitMetric(Chart *chart, uint32_t face) const - { + float evaluateProxyFitMetric(Chart *chart, uint32_t face) const { // All faces in coplanar regions have the same normal, can use any face. const Vector3 faceNormal = m_faceNormals[face]; // Use plane fitting metric for now: return 1 - dot(faceNormal, chart->basis.normal); // @@ normal deviations should be weighted by face area } - float evaluateRoundnessMetric(Chart *chart, float newBoundaryLength, float newChartArea) const - { + float evaluateRoundnessMetric(Chart *chart, float newBoundaryLength, float newChartArea) const { const float roundness = square(chart->boundaryLength) / chart->area; const float newBoundaryLengthSq = square(newBoundaryLength); const float newRoundness = newBoundaryLengthSq / newChartArea; @@ -5486,12 +5193,11 @@ private: return 0; } - float evaluateStraightnessMetric(Chart *chart, uint32_t firstFace) const - { + float evaluateStraightnessMetric(Chart *chart, uint32_t firstFace) const { float l_out = 0.0f, l_in = 0.0f; const uint32_t planarRegionId = m_facePlanarRegionId[firstFace]; uint32_t face = firstFace; - for (;;) { + for (;;) { for (Mesh::FaceEdgeIterator it(m_mesh, face); !it.isDone(); it.advance()) { const float l = m_edgeLengths[it.edge()]; if (it.isBoundary()) { @@ -5512,8 +5218,7 @@ private: return min(ratio, 0.0f); // Only use the straightness metric to close gaps. } - bool isNormalSeam(uint32_t edge) const - { + bool isNormalSeam(uint32_t edge) const { const uint32_t oppositeEdge = m_mesh->oppositeEdge(edge); if (oppositeEdge == UINT32_MAX) return false; // boundary edge @@ -5533,11 +5238,10 @@ private: return !equal(m_faceNormals[f0], m_faceNormals[f1], kNormalEpsilon); } - float evaluateNormalSeamMetric(Chart *chart, uint32_t firstFace) const - { + float evaluateNormalSeamMetric(Chart *chart, uint32_t firstFace) const { float seamFactor = 0.0f, totalLength = 0.0f; uint32_t face = firstFace; - for (;;) { + for (;;) { for (Mesh::FaceEdgeIterator it(m_mesh, face); !it.isDone(); it.advance()) { if (it.isBoundary()) continue; @@ -5574,11 +5278,10 @@ private: return seamFactor / totalLength; } - float evaluateTextureSeamMetric(Chart *chart, uint32_t firstFace) const - { + float evaluateTextureSeamMetric(Chart *chart, uint32_t firstFace) const { float seamLength = 0.0f, totalLength = 0.0f; uint32_t face = firstFace; - for (;;) { + for (;;) { for (Mesh::FaceEdgeIterator it(m_mesh, face); !it.isDone(); it.advance()) { if (it.isBoundary()) continue; @@ -5601,11 +5304,10 @@ private: return seamLength / totalLength; } - float computeArea(Chart *chart, uint32_t firstFace) const - { + float computeArea(Chart *chart, uint32_t firstFace) const { float area = chart->area; uint32_t face = firstFace; - for (;;) { + for (;;) { area += m_faceAreas[face]; face = m_nextPlanarRegionFace[face]; if (face == firstFace) @@ -5614,13 +5316,12 @@ private: return area; } - float computeBoundaryLength(Chart *chart, uint32_t firstFace) const - { + float computeBoundaryLength(Chart *chart, uint32_t firstFace) const { float boundaryLength = chart->boundaryLength; // Add new edges, subtract edges shared with the chart. const uint32_t planarRegionId = m_facePlanarRegionId[firstFace]; uint32_t face = firstFace; - for (;;) { + for (;;) { for (Mesh::FaceEdgeIterator it(m_mesh, face); !it.isDone(); it.advance()) { const float edgeLength = m_edgeLengths[it.edge()]; if (it.isBoundary()) { @@ -5636,11 +5337,10 @@ private: if (face == firstFace) break; } - return max(0.0f, boundaryLength); // @@ Hack! + return max(0.0f, boundaryLength); // @@ Hack! } - bool mergeChart(Chart *owner, Chart *chart, float sharedBoundaryLength) - { + bool mergeChart(Chart *owner, Chart *chart, float sharedBoundaryLength) { const uint32_t oldOwnerFaceCount = owner->faces.size(); const uint32_t chartFaceCount = chart->faces.size(); owner->faces.push_back(chart->faces); @@ -5706,11 +5406,10 @@ private: namespace param { -class JacobiPreconditioner -{ +class JacobiPreconditioner { public: - JacobiPreconditioner(const sparse::Matrix &M, bool symmetric) : m_inverseDiagonal(M.width()) - { + JacobiPreconditioner(const sparse::Matrix &M, bool symmetric) : + m_inverseDiagonal(M.width()) { XA_ASSERT(M.isSquare()); for (uint32_t x = 0; x < M.width(); x++) { float elem = M.getCoefficient(x, x); @@ -5723,8 +5422,7 @@ public: } } - void apply(const FullVector &x, FullVector &y) const - { + void apply(const FullVector &x, FullVector &y) const { XA_DEBUG_ASSERT(x.dimension() == m_inverseDiagonal.dimension()); XA_DEBUG_ASSERT(y.dimension() == m_inverseDiagonal.dimension()); // @@ Wrap vector component-wise product into a separate function. @@ -5739,12 +5437,10 @@ private: }; // Linear solvers. -class Solver -{ +class Solver { public: // Solve the symmetric system: At·A·x = At·b - static bool LeastSquaresSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon = 1e-5f) - { + static bool LeastSquaresSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon = 1e-5f) { XA_DEBUG_ASSERT(A.width() == x.dimension()); XA_DEBUG_ASSERT(A.height() == b.dimension()); XA_DEBUG_ASSERT(A.height() >= A.width()); // @@ If height == width we could solve it directly... @@ -5759,8 +5455,7 @@ public: } // See section 10.4.3 in: Mesh Parameterization: Theory and Practice, Siggraph Course Notes, August 2007 - static bool LeastSquaresSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, const uint32_t *lockedParameters, uint32_t lockedCount, float epsilon = 1e-5f) - { + static bool LeastSquaresSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, const uint32_t *lockedParameters, uint32_t lockedCount, float epsilon = 1e-5f) { XA_DEBUG_ASSERT(A.width() == x.dimension()); XA_DEBUG_ASSERT(A.height() == b.dimension()); XA_DEBUG_ASSERT(A.height() >= A.width() - lockedCount); @@ -5859,18 +5554,17 @@ private: * **/ // Conjugate gradient with preconditioner. - static bool ConjugateGradientSolver(const JacobiPreconditioner &preconditioner, const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon) - { - XA_DEBUG_ASSERT( A.isSquare() ); - XA_DEBUG_ASSERT( A.width() == b.dimension() ); - XA_DEBUG_ASSERT( A.width() == x.dimension() ); + static bool ConjugateGradientSolver(const JacobiPreconditioner &preconditioner, const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon) { + XA_DEBUG_ASSERT(A.isSquare()); + XA_DEBUG_ASSERT(A.width() == b.dimension()); + XA_DEBUG_ASSERT(A.width() == x.dimension()); int i = 0; const int D = A.width(); - const int i_max = 4 * D; // Convergence should be linear, but in some cases, it's not. - FullVector r(D); // residual - FullVector p(D); // search direction - FullVector q(D); // - FullVector s(D); // preconditioned + const int i_max = 4 * D; // Convergence should be linear, but in some cases, it's not. + FullVector r(D); // residual + FullVector p(D); // search direction + FullVector q(D); // + FullVector s(D); // preconditioned float delta_0; float delta_old; float delta_new; @@ -5896,7 +5590,7 @@ private: // x = alfa·p + x sparse::saxpy(alpha, p, x); if ((i & 31) == 0) { // recompute r after 32 steps - // r = b - A·x + // r = b - A·x sparse::copy(b, r); sparse::sgemv(-1, A, x, 1, r); } else { @@ -5906,7 +5600,7 @@ private: // s = M^-1 · r preconditioner.apply(r, s); delta_old = delta_new; - delta_new = sparse::dot( r, s ); + delta_new = sparse::dot(r, s); beta = delta_new / delta_old; // p = s + beta·p sparse::scal(beta, p); @@ -5915,8 +5609,7 @@ private: return delta_new <= epsilon * epsilon * delta_0; } - static bool SymmetricSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon = 1e-5f) - { + static bool SymmetricSolver(const sparse::Matrix &A, const FullVector &b, FullVector &x, float epsilon = 1e-5f) { XA_DEBUG_ASSERT(A.height() == A.width()); XA_DEBUG_ASSERT(A.height() == b.dimension()); XA_DEBUG_ASSERT(b.dimension() == x.dimension()); @@ -5926,8 +5619,7 @@ private: }; // Fast sweep in 3 directions -static bool findApproximateDiameterVertices(Mesh *mesh, uint32_t *a, uint32_t *b) -{ +static bool findApproximateDiameterVertices(Mesh *mesh, uint32_t *a, uint32_t *b) { XA_DEBUG_ASSERT(a != nullptr); XA_DEBUG_ASSERT(b != nullptr); const uint32_t vertexCount = mesh->vertexCount(); @@ -5984,28 +5676,24 @@ static bool findApproximateDiameterVertices(Mesh *mesh, uint32_t *a, uint32_t *b // Conformal relations from Brecht Van Lommel (based on ABF): -static float vec_angle_cos(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3) -{ +static float vec_angle_cos(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3) { Vector3 d1 = v1 - v2; Vector3 d2 = v3 - v2; return clamp(dot(d1, d2) / (length(d1) * length(d2)), -1.0f, 1.0f); } -static float vec_angle(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3) -{ +static float vec_angle(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3) { float dot = vec_angle_cos(v1, v2, v3); return acosf(dot); } -static void triangle_angles(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3, float *a1, float *a2, float *a3) -{ +static void triangle_angles(const Vector3 &v1, const Vector3 &v2, const Vector3 &v3, float *a1, float *a2, float *a3) { *a1 = vec_angle(v3, v1, v2); *a2 = vec_angle(v1, v2, v3); *a3 = kPi - *a2 - *a1; } -static void setup_abf_relations(sparse::Matrix &A, int row, int id0, int id1, int id2, const Vector3 &p0, const Vector3 &p1, const Vector3 &p2) -{ +static void setup_abf_relations(sparse::Matrix &A, int row, int id0, int id1, int id2, const Vector3 &p0, const Vector3 &p1, const Vector3 &p2) { // @@ IC: Wouldn't it be more accurate to return cos and compute 1-cos^2? // It does indeed seem to be a little bit more robust. // @@ Need to revisit this more carefully! @@ -6055,8 +5743,7 @@ static void setup_abf_relations(sparse::Matrix &A, int row, int id0, int id1, in A.setCoefficient(v2_id, 2 * row + 1, 1); } -static bool computeLeastSquaresConformalMap(Mesh *mesh) -{ +static bool computeLeastSquaresConformalMap(Mesh *mesh) { // For this to work properly, mesh should not have colocals that have the same // attributes, unless you want the vertices to actually have different texcoords. const uint32_t vertexCount = mesh->vertexCount(); @@ -6114,10 +5801,8 @@ static bool computeLeastSquaresConformalMap(Mesh *mesh) } #if XA_RECOMPUTE_CHARTS -struct PiecewiseParam -{ - void reset(const Mesh *mesh, uint32_t faceCount) - { +struct PiecewiseParam { + void reset(const Mesh *mesh, uint32_t faceCount) { m_mesh = mesh; m_faceCount = faceCount; const uint32_t vertexCount = m_mesh->vertexCount(); @@ -6134,8 +5819,7 @@ struct PiecewiseParam ConstArrayView<uint32_t> chartFaces() const { return m_patch; } const Vector2 *texcoords() const { return m_texcoords.data(); } - bool computeChart() - { + bool computeChart() { m_patch.clear(); m_faceInvalid.zeroOutMemory(); m_faceInPatch.zeroOutMemory(); @@ -6242,8 +5926,7 @@ struct PiecewiseParam } private: - struct Candidate - { + struct Candidate { uint32_t face, vertex; uint32_t next; // The next candidate with the same vertex. Vector2 position; @@ -6253,10 +5936,12 @@ private: float patchVertexOrient; }; - struct CandidateIterator - { - CandidateIterator(Array<Candidate> &candidates, uint32_t first) : m_candidates(candidates), m_current(first) {} - void advance() { if (m_current != UINT32_MAX) m_current = m_candidates[m_current].next; } + struct CandidateIterator { + CandidateIterator(Array<Candidate> &candidates, uint32_t first) : + m_candidates(candidates), m_current(first) {} + void advance() { + if (m_current != UINT32_MAX) m_current = m_candidates[m_current].next; + } bool isDone() const { return m_current == UINT32_MAX; } Candidate ¤t() { return m_candidates[m_current]; } @@ -6277,8 +5962,7 @@ private: UniformGrid2 m_boundaryGrid; // Find candidate faces on the patch front. - void findCandidates() - { + void findCandidates() { m_candidates.clear(); m_faceInCandidates.zeroOutMemory(); for (uint32_t i = 0; i < m_patch.size(); i++) { @@ -6335,8 +6019,7 @@ private: } } - void addCandidateFace(uint32_t patchEdge, float patchVertexOrient, uint32_t face, uint32_t edge, uint32_t freeVertex) - { + void addCandidateFace(uint32_t patchEdge, float patchVertexOrient, uint32_t face, uint32_t edge, uint32_t freeVertex) { Vector2 texcoords[3]; orthoProjectFace(face, texcoords); // Find corresponding vertices between the patch edge and candidate edge. @@ -6412,8 +6095,7 @@ private: m_faceInCandidates.set(face); } - void orthoProjectFace(uint32_t face, Vector2 *texcoords) const - { + void orthoProjectFace(uint32_t face, Vector2 *texcoords) const { const Vector3 normal = m_mesh->computeFaceNormal(face); const Vector3 tangent = normalize(m_mesh->position(m_mesh->vertexAt(face * 3 + 1)) - m_mesh->position(m_mesh->vertexAt(face * 3 + 0)), kEpsilon); const Vector3 bitangent = cross(normal, tangent); @@ -6423,16 +6105,14 @@ private: } } - float parametricArea(const Vector2 *texcoords) const - { + float parametricArea(const Vector2 *texcoords) const { const Vector2 &v1 = texcoords[0]; const Vector2 &v2 = texcoords[1]; const Vector2 &v3 = texcoords[2]; return ((v2.x - v1.x) * (v3.y - v1.y) - (v3.x - v1.x) * (v2.y - v1.y)) * 0.5f; } - float computeStretch(Vector3 p1, Vector3 p2, Vector3 p3, Vector2 t1, Vector2 t2, Vector2 t3) const - { + float computeStretch(Vector3 p1, Vector3 p2, Vector3 p3, Vector2 t1, Vector2 t2, Vector2 t3) const { float parametricArea = ((t2.y - t1.y) * (t3.x - t1.x) - (t3.y - t1.y) * (t2.x - t1.x)) * 0.5f; if (isZero(parametricArea, kAreaEpsilon)) return FLT_MAX; @@ -6446,16 +6126,14 @@ private: } // Return value is positive if the point is one side of the edge, negative if on the other side. - float orientToEdge(Vector2 edgeVertex0, Vector2 edgeVertex1, Vector2 point) const - { + float orientToEdge(Vector2 edgeVertex0, Vector2 edgeVertex1, Vector2 point) const { return (edgeVertex0.x - point.x) * (edgeVertex1.y - point.y) - (edgeVertex0.y - point.y) * (edgeVertex1.x - point.x); } }; #endif // Estimate quality of existing parameterization. -struct Quality -{ +struct Quality { // computeBoundaryIntersection bool boundaryIntersection = false; @@ -6472,8 +6150,7 @@ struct Quality float conformalMetric = 0.0f; float authalicMetric = 0.0f; - void computeBoundaryIntersection(const Mesh *mesh, UniformGrid2 &boundaryGrid) - { + void computeBoundaryIntersection(const Mesh *mesh, UniformGrid2 &boundaryGrid) { const Array<uint32_t> &boundaryEdges = mesh->boundaryEdges(); const uint32_t boundaryEdgeCount = boundaryEdges.size(); boundaryGrid.reset(mesh->texcoords(), mesh->indices(), boundaryEdgeCount); @@ -6489,8 +6166,7 @@ struct Quality #endif } - void computeFlippedFaces(const Mesh *mesh, uint32_t faceCount, Array<uint32_t> *flippedFaces) - { + void computeFlippedFaces(const Mesh *mesh, uint32_t faceCount, Array<uint32_t> *flippedFaces) { totalTriangleCount = flippedTriangleCount = zeroAreaTriangleCount = 0; if (flippedFaces) flippedFaces->clear(); @@ -6525,8 +6201,7 @@ struct Quality flippedFaces->clear(); flippedTriangleCount = 0; } - if (flippedTriangleCount > totalTriangleCount / 2) - { + if (flippedTriangleCount > totalTriangleCount / 2) { // If more than half the triangles are flipped, reverse the flipped / not flipped classification. flippedTriangleCount = totalTriangleCount - flippedTriangleCount; if (flippedFaces) { @@ -6548,8 +6223,7 @@ struct Quality } } - void computeMetrics(const Mesh *mesh, uint32_t faceCount) - { + void computeMetrics(const Mesh *mesh, uint32_t faceCount) { totalGeometricArea = totalParametricArea = 0.0f; stretchMetric = maxStretchMetric = conformalMetric = authalicMetric = 0.0f; for (uint32_t f = 0; f < faceCount; f++) { @@ -6580,7 +6254,7 @@ struct Quality const float a = dot(Ss, Ss); // E const float b = dot(Ss, St); // F const float c = dot(St, St); // G - // Compute eigen-values of the first fundamental form: + // Compute eigen-values of the first fundamental form: const float sigma1 = sqrtf(0.5f * max(0.0f, a + c - sqrtf(square(a - c) + 4 * square(b)))); // gamma uppercase, min eigenvalue. const float sigma2 = sqrtf(0.5f * max(0.0f, a + c + sqrtf(square(a - c) + 4 * square(b)))); // gamma lowercase, max eigenvalue. XA_ASSERT(sigma2 > sigma1 || equal(sigma1, sigma2, kEpsilon)); @@ -6611,37 +6285,33 @@ struct Quality if (totalGeometricArea > 0.0f) { const float normFactor = sqrtf(totalParametricArea / totalGeometricArea); stretchMetric = sqrtf(stretchMetric / totalGeometricArea) * normFactor; - maxStretchMetric *= normFactor; + maxStretchMetric *= normFactor; conformalMetric = sqrtf(conformalMetric / totalGeometricArea); authalicMetric = sqrtf(authalicMetric / totalGeometricArea); } } }; -struct ChartWarningFlags -{ - enum Enum - { - CloseHolesFailed = 1<<1, - FixTJunctionsDuplicatedEdge = 1<<2, - FixTJunctionsFailed = 1<<3, - TriangulateDuplicatedEdge = 1<<4, +struct ChartWarningFlags { + enum Enum { + CloseHolesFailed = 1 << 1, + FixTJunctionsDuplicatedEdge = 1 << 2, + FixTJunctionsFailed = 1 << 3, + TriangulateDuplicatedEdge = 1 << 4, }; }; -struct ChartCtorBuffers -{ +struct ChartCtorBuffers { Array<uint32_t> chartMeshIndices; Array<uint32_t> unifiedMeshIndices; Array<uint32_t> boundaryLoops; }; /// A chart is a connected set of faces with a certain topology (usually a disk). -class Chart -{ +class Chart { public: - Chart(ChartCtorBuffers &buffers, const Basis &basis, ConstArrayView<uint32_t> faces, const Mesh *originalMesh, uint32_t meshId, uint32_t chartGroupId, uint32_t chartId) : m_basis(basis), m_mesh(nullptr), m_unifiedMesh(nullptr), m_unmodifiedUnifiedMesh(nullptr), m_type(ChartType::LSCM), m_warningFlags(0), m_closedHolesCount(0), m_fixedTJunctionsCount(0) - { + Chart(ChartCtorBuffers &buffers, const Basis &basis, ConstArrayView<uint32_t> faces, const Mesh *originalMesh, uint32_t meshId, uint32_t chartGroupId, uint32_t chartId) : + m_basis(basis), m_mesh(nullptr), m_unifiedMesh(nullptr), m_unmodifiedUnifiedMesh(nullptr), m_type(ChartType::LSCM), m_warningFlags(0), m_closedHolesCount(0), m_fixedTJunctionsCount(0) { XA_UNUSED(meshId); XA_UNUSED(chartGroupId); XA_UNUSED(chartId); @@ -6780,8 +6450,8 @@ public: } #if XA_RECOMPUTE_CHARTS - Chart(ChartCtorBuffers &buffers, const Chart *parent, const Mesh *parentMesh, ConstArrayView<uint32_t> faces, const Vector2 *texcoords, const Mesh *originalMesh, uint32_t meshId, uint32_t chartGroupId, uint32_t chartId) : m_mesh(nullptr), m_unifiedMesh(nullptr), m_unmodifiedUnifiedMesh(nullptr), m_type(ChartType::Piecewise), m_warningFlags(0), m_closedHolesCount(0), m_fixedTJunctionsCount(0) - { + Chart(ChartCtorBuffers &buffers, const Chart *parent, const Mesh *parentMesh, ConstArrayView<uint32_t> faces, const Vector2 *texcoords, const Mesh *originalMesh, uint32_t meshId, uint32_t chartGroupId, uint32_t chartId) : + m_mesh(nullptr), m_unifiedMesh(nullptr), m_unmodifiedUnifiedMesh(nullptr), m_type(ChartType::Piecewise), m_warningFlags(0), m_closedHolesCount(0), m_fixedTJunctionsCount(0) { XA_UNUSED(meshId); XA_UNUSED(chartGroupId); XA_UNUSED(chartId); @@ -6846,8 +6516,7 @@ public: } #endif - ~Chart() - { + ~Chart() { if (m_mesh) { m_mesh->~Mesh(); XA_FREE(m_mesh); @@ -6880,8 +6549,7 @@ public: const Mesh *unmodifiedUnifiedMesh() const { return m_unmodifiedUnifiedMesh; } uint32_t mapChartVertexToOriginalVertex(uint32_t i) const { return m_chartToOriginalMap[i]; } - void evaluateOrthoQuality(UniformGrid2 &boundaryGrid) - { + void evaluateOrthoQuality(UniformGrid2 &boundaryGrid) { XA_PROFILE_START(parameterizeChartsEvaluateQuality) m_quality.computeBoundaryIntersection(m_unifiedMesh, boundaryGrid); m_quality.computeFlippedFaces(m_unifiedMesh, m_initialFaceCount, nullptr); @@ -6892,8 +6560,7 @@ public: m_type = ChartType::Ortho; } - void evaluateQuality(UniformGrid2 &boundaryGrid) - { + void evaluateQuality(UniformGrid2 &boundaryGrid) { XA_PROFILE_START(parameterizeChartsEvaluateQuality) m_quality.computeBoundaryIntersection(m_unifiedMesh, boundaryGrid); #if XA_DEBUG_EXPORT_OBJ_INVALID_PARAMETERIZATION @@ -6906,15 +6573,13 @@ public: } // Transfer parameterization from unified mesh to chart mesh. - void transferParameterization() - { + void transferParameterization() { const uint32_t vertexCount = m_mesh->vertexCount(); for (uint32_t v = 0; v < vertexCount; v++) m_mesh->texcoord(v) = m_unifiedMesh->texcoord(m_chartToUnifiedMap[v]); } - Vector2 computeParametricBounds() const - { + Vector2 computeParametricBounds() const { Vector2 minCorner(FLT_MAX, FLT_MAX); Vector2 maxCorner(-FLT_MAX, -FLT_MAX); const uint32_t vertexCount = m_mesh->vertexCount(); @@ -6949,8 +6614,7 @@ private: #endif }; -struct CreateChartTaskArgs -{ +struct CreateChartTaskArgs { const Mesh *mesh; const Basis *basis; ConstArrayView<uint32_t> faces; @@ -6961,27 +6625,23 @@ struct CreateChartTaskArgs Chart **chart; }; -static void runCreateChartTask(void *userData) -{ +static void runCreateChartTask(void *userData) { XA_PROFILE_START(createChartMeshesThread) auto args = (CreateChartTaskArgs *)userData; *(args->chart) = XA_NEW_ARGS(MemTag::Default, Chart, args->chartBuffers->get(), *(args->basis), args->faces, args->mesh, args->meshId, args->chartGroupId, args->chartId); XA_PROFILE_END(createChartMeshesThread) } -struct ParameterizeChartTaskArgs -{ +struct ParameterizeChartTaskArgs { Chart *chart; ParameterizeFunc func; ThreadLocal<UniformGrid2> *boundaryGrid; }; -static void runParameterizeChartTask(void *userData) -{ +static void runParameterizeChartTask(void *userData) { auto args = (ParameterizeChartTaskArgs *)userData; Mesh *mesh = args->chart->unifiedMesh(); - XA_PROFILE_START(parameterizeChartsOrthogonal) - { + XA_PROFILE_START(parameterizeChartsOrthogonal) { // Project vertices to plane. const uint32_t vertexCount = mesh->vertexCount(); const Basis &basis = args->chart->basis(); @@ -7006,11 +6666,10 @@ static void runParameterizeChartTask(void *userData) } // Set of charts corresponding to mesh faces in the same face group. -class ChartGroup -{ +class ChartGroup { public: - ChartGroup(uint32_t id, const Mesh *sourceMesh, uint16_t faceGroup) : m_sourceId(sourceMesh->id()), m_id(id), m_isVertexMap(faceGroup == Mesh::kInvalidFaceGroup), m_paramAddedChartsCount(0), m_paramDeletedChartsCount(0) - { + ChartGroup(uint32_t id, const Mesh *sourceMesh, uint16_t faceGroup) : + m_sourceId(sourceMesh->id()), m_id(id), m_isVertexMap(faceGroup == Mesh::kInvalidFaceGroup), m_paramAddedChartsCount(0), m_paramDeletedChartsCount(0) { // Create new mesh from the source mesh, using faces that belong to this group. const uint32_t sourceFaceCount = sourceMesh->faceCount(); if (!m_isVertexMap) { @@ -7072,8 +6731,7 @@ public: #endif } - ~ChartGroup() - { + ~ChartGroup() { m_mesh->~Mesh(); XA_FREE(m_mesh); for (uint32_t i = 0; i < m_charts.size(); i++) { @@ -7151,8 +6809,7 @@ public: - emphasize roundness metrics to prevent those cases. - If interior self-overlaps: preserve boundary parameterization and use mean-value map. */ - void computeCharts(TaskScheduler *taskScheduler, const ChartOptions &options, segment::Atlas &atlas, ThreadLocal<ChartCtorBuffers> *chartBuffers) - { + void computeCharts(TaskScheduler *taskScheduler, const ChartOptions &options, segment::Atlas &atlas, ThreadLocal<ChartCtorBuffers> *chartBuffers) { m_chartOptions = options; // This function may be called multiple times, so destroy existing charts. for (uint32_t i = 0; i < m_charts.size(); i++) { @@ -7222,7 +6879,7 @@ public: #if XA_RECOMPUTE_CHARTS void parameterizeCharts(TaskScheduler *taskScheduler, ParameterizeFunc func, ThreadLocal<UniformGrid2> *boundaryGrid, ThreadLocal<ChartCtorBuffers> *chartBuffers, ThreadLocal<PiecewiseParam> *piecewiseParam) #else - void parameterizeCharts(TaskScheduler* taskScheduler, ParameterizeFunc func, ThreadLocal<UniformGrid2>* boundaryGrid, ThreadLocal<ChartCtorBuffers>* /*chartBuffers*/) + void parameterizeCharts(TaskScheduler *taskScheduler, ParameterizeFunc func, ThreadLocal<UniformGrid2> *boundaryGrid, ThreadLocal<ChartCtorBuffers> * /*chartBuffers*/) #endif { m_paramAddedChartsCount = 0; @@ -7316,8 +6973,7 @@ public: } private: - void buildAtlas(segment::Atlas &atlas, const ChartOptions &options) - { + void buildAtlas(segment::Atlas &atlas, const ChartOptions &options) { if (atlas.facesLeft() == 0) return; // Create initial charts greedely. @@ -7347,8 +7003,7 @@ private: XA_DEBUG_ASSERT(atlas.facesLeft() == 0); } - void removeChart(const Chart *chart) - { + void removeChart(const Chart *chart) { for (uint32_t i = 0; i < m_charts.size(); i++) { if (m_charts[i] == chart) { m_charts.removeAt(i); @@ -7368,24 +7023,21 @@ private: uint32_t m_paramDeletedChartsCount; // Number of charts with invalid parameterizations that were deleted, after charts were recomputed. }; -struct CreateChartGroupTaskArgs -{ +struct CreateChartGroupTaskArgs { uint16_t faceGroup; uint32_t groupId; const Mesh *mesh; ChartGroup **chartGroup; }; -static void runCreateChartGroupTask(void *userData) -{ +static void runCreateChartGroupTask(void *userData) { XA_PROFILE_START(addMeshCreateChartGroupsThread) auto args = (CreateChartGroupTaskArgs *)userData; *(args->chartGroup) = XA_NEW_ARGS(MemTag::Default, ChartGroup, args->groupId, args->mesh, args->faceGroup); XA_PROFILE_END(addMeshCreateChartGroupsThread) } -struct ComputeChartsTaskArgs -{ +struct ComputeChartsTaskArgs { TaskScheduler *taskScheduler; ChartGroup *chartGroup; ThreadLocal<segment::Atlas> *atlas; @@ -7394,8 +7046,7 @@ struct ComputeChartsTaskArgs Progress *progress; }; -static void runComputeChartsJob(void *userData) -{ +static void runComputeChartsJob(void *userData) { auto args = (ComputeChartsTaskArgs *)userData; if (args->progress->cancel) return; @@ -7406,8 +7057,7 @@ static void runComputeChartsJob(void *userData) args->progress->update(); } -struct ParameterizeChartsTaskArgs -{ +struct ParameterizeChartsTaskArgs { TaskScheduler *taskScheduler; ChartGroup *chartGroup; ParameterizeFunc func; @@ -7419,8 +7069,7 @@ struct ParameterizeChartsTaskArgs Progress *progress; }; -static void runParameterizeChartsJob(void *userData) -{ +static void runParameterizeChartsJob(void *userData) { auto args = (ParameterizeChartsTaskArgs *)userData; if (args->progress->cancel) return; @@ -7436,13 +7085,12 @@ static void runParameterizeChartsJob(void *userData) } /// An atlas is a set of chart groups. -class Atlas -{ +class Atlas { public: - Atlas() : m_meshCount(0), m_chartsComputed(false), m_chartsParameterized(false) {} + Atlas() : + m_meshCount(0), m_chartsComputed(false), m_chartsParameterized(false) {} - ~Atlas() - { + ~Atlas() { for (uint32_t i = 0; i < m_chartGroups.size(); i++) { m_chartGroups[i]->~ChartGroup(); XA_FREE(m_chartGroups[i]); @@ -7454,8 +7102,7 @@ public: uint32_t chartGroupCount() const { return m_chartGroups.size(); } const ChartGroup *chartGroupAt(uint32_t index) const { return m_chartGroups[index]; } - uint32_t chartGroupCount(uint32_t mesh) const - { + uint32_t chartGroupCount(uint32_t mesh) const { uint32_t count = 0; for (uint32_t i = 0; i < m_chartGroups.size(); i++) { if (m_chartGroupSourceMeshes[i] == mesh) @@ -7464,8 +7111,7 @@ public: return count; } - const ChartGroup *chartGroupAt(uint32_t mesh, uint32_t group) const - { + const ChartGroup *chartGroupAt(uint32_t mesh, uint32_t group) const { for (uint32_t c = 0; c < m_chartGroups.size(); c++) { if (m_chartGroupSourceMeshes[c] != mesh) continue; @@ -7477,8 +7123,7 @@ public: } // This function is thread safe. - void addMesh(TaskScheduler *taskScheduler, const Mesh *mesh) - { + void addMesh(TaskScheduler *taskScheduler, const Mesh *mesh) { // Create one chart group per face group. // If there's any ignored faces in the mesh, create an extra face group for that (vertex map). // Chart group creation is slow since it copies a chunk of the source mesh, so use tasks. @@ -7513,8 +7158,7 @@ public: // Chart id/index is determined by depth-first hierarchy of mesh -> chart group -> chart. // For chart index to be consistent here, chart groups needs to sorted by mesh index. Since addMesh is called by multithreaded tasks, order is indeterminate, so chart groups need to be explicitly sorted after all meshes are added. - void sortChartGroups() - { + void sortChartGroups() { Array<ChartGroup *> oldChartGroups; oldChartGroups.resize(m_chartGroups.size()); memcpy(oldChartGroups.data(), m_chartGroups.data(), sizeof(ChartGroup *) * m_chartGroups.size()); @@ -7533,8 +7177,7 @@ public: } } - bool computeCharts(TaskScheduler *taskScheduler, const ChartOptions &options, ProgressFunc progressFunc, void *progressUserData) - { + bool computeCharts(TaskScheduler *taskScheduler, const ChartOptions &options, ProgressFunc progressFunc, void *progressUserData) { m_chartsComputed = false; m_chartsParameterized = false; // Ignore vertex maps. @@ -7582,8 +7225,7 @@ public: return true; } - bool parameterizeCharts(TaskScheduler *taskScheduler, ParameterizeFunc func, ProgressFunc progressFunc, void *progressUserData) - { + bool parameterizeCharts(TaskScheduler *taskScheduler, ParameterizeFunc func, ProgressFunc progressFunc, void *progressUserData) { m_chartsParameterized = false; // Ignore vertex maps. uint32_t chartGroupCount = 0; @@ -7643,17 +7285,15 @@ private: namespace pack { -class AtlasImage -{ +class AtlasImage { public: - AtlasImage(uint32_t width, uint32_t height) : m_width(width), m_height(height) - { + AtlasImage(uint32_t width, uint32_t height) : + m_width(width), m_height(height) { m_data.resize(m_width * m_height); memset(m_data.data(), 0, sizeof(uint32_t) * m_data.size()); } - void resize(uint32_t width, uint32_t height) - { + void resize(uint32_t width, uint32_t height) { Array<uint32_t> data; data.resize(width * height); memset(data.data(), 0, sizeof(uint32_t) * data.size()); @@ -7664,8 +7304,7 @@ public: data.moveTo(m_data); } - void addChart(uint32_t chartIndex, const BitImage *image, const BitImage *imageBilinear, const BitImage *imagePadding, int atlas_w, int atlas_h, int offset_x, int offset_y) - { + void addChart(uint32_t chartIndex, const BitImage *image, const BitImage *imageBilinear, const BitImage *imagePadding, int atlas_w, int atlas_h, int offset_x, int offset_y) { const int w = image->width(); const int h = image->height(); for (int y = 0; y < h; y++) { @@ -7691,15 +7330,13 @@ public: } } - void copyTo(uint32_t *dest, uint32_t destWidth, uint32_t destHeight, int padding) const - { + void copyTo(uint32_t *dest, uint32_t destWidth, uint32_t destHeight, int padding) const { for (uint32_t y = 0; y < destHeight; y++) memcpy(&dest[y * destWidth], &m_data[padding + (y + padding) * m_width], destWidth * sizeof(uint32_t)); } #if XA_DEBUG_EXPORT_ATLAS_IMAGES - void writeTga(const char *filename, uint32_t width, uint32_t height) const - { + void writeTga(const char *filename, uint32_t width, uint32_t height) const { Array<uint8_t> image; image.resize(width * height * 3); for (uint32_t y = 0; y < height; y++) { @@ -7741,8 +7378,7 @@ private: Array<uint32_t> m_data; }; -struct Chart -{ +struct Chart { int32_t atlasIndex; uint32_t material; uint32_t indexCount; @@ -7764,15 +7400,13 @@ struct Chart uint32_t uniqueVertexCount() const { return uniqueVertices.isEmpty() ? vertexCount : uniqueVertices.size(); } }; -struct AddChartTaskArgs -{ +struct AddChartTaskArgs { ThreadLocal<BoundingBox2D> *boundingBox; param::Chart *paramChart; Chart *chart; // out }; -static void runAddChartTask(void *userData) -{ +static void runAddChartTask(void *userData) { XA_PROFILE_START(packChartsAddChartsThread) auto args = (AddChartTaskArgs *)userData; param::Chart *paramChart = args->paramChart; @@ -7811,10 +7445,8 @@ static void runAddChartTask(void *userData) XA_PROFILE_END(packChartsAddChartsThread) } -struct Atlas -{ - ~Atlas() - { +struct Atlas { + ~Atlas() { for (uint32_t i = 0; i < m_atlasImages.size(); i++) { m_atlasImages[i]->~AtlasImage(); XA_FREE(m_atlasImages[i]); @@ -7838,8 +7470,7 @@ struct Atlas const Array<AtlasImage *> &getImages() const { return m_atlasImages; } float getUtilization(uint32_t atlas) const { return m_utilization[atlas]; } - void addCharts(TaskScheduler *taskScheduler, param::Atlas *paramAtlas) - { + void addCharts(TaskScheduler *taskScheduler, param::Atlas *paramAtlas) { // Count charts. uint32_t chartCount = 0; const uint32_t chartGroupsCount = paramAtlas->chartGroupCount(); @@ -7880,8 +7511,7 @@ struct Atlas m_charts[i] = taskArgs[i].chart; } - void addUvMeshCharts(UvMeshInstance *mesh) - { + void addUvMeshCharts(UvMeshInstance *mesh) { BitArray vertexUsed(mesh->texcoords.size()); BoundingBox2D boundingBox; for (uint32_t c = 0; c < mesh->mesh->charts.size(); c++) { @@ -7942,8 +7572,7 @@ struct Atlas } // Pack charts in the smallest possible rectangle. - bool packCharts(const PackOptions &options, ProgressFunc progressFunc, void *progressUserData) - { + bool packCharts(const PackOptions &options, ProgressFunc progressFunc, void *progressUserData) { if (progressFunc) { if (!progressFunc(ProgressCategory::PackCharts, 0, progressUserData)) return false; @@ -8178,8 +7807,7 @@ struct Atlas int best_x = 0, best_y = 0; int best_cw = 0, best_ch = 0; int best_r = 0; - for (;;) - { + for (;;) { bool firstChartInBitImage = false; XA_UNUSED(firstChartInBitImage); if (currentAtlas + 1 > m_bitImages.size()) { @@ -8212,8 +7840,7 @@ struct Atlas if (best_x + best_cw > atlasSizes[currentAtlas].x || best_y + best_ch > atlasSizes[currentAtlas].y) { for (uint32_t j = 0; j < chartStartPositions.size(); j++) chartStartPositions[j] = Vector2i(0, 0); - } - else { + } else { chartStartPositions[currentAtlas] = Vector2i(best_x, best_y); } } @@ -8312,8 +7939,7 @@ struct Atlas } if (m_utilization.size() > 1) { XA_PRINT(" %u: %f%% utilization\n", i, m_utilization[i] * 100.0f); - } - else { + } else { XA_PRINT(" %f%% utilization\n", m_utilization[i] * 100.0f); } } @@ -8336,16 +7962,14 @@ private: // is occupied at this point. At the end we have many small charts and a large atlas with sparse holes. Finding those holes randomly is slow. A better approach would be to // start stacking large charts as if they were tetris pieces. Once charts get small try to place them randomly. It may be interesting to try a intermediate strategy, first try // along one axis and then try exhaustively along that axis. - bool findChartLocation(const Vector2i &startPosition, bool bruteForce, const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, bool blockAligned, uint32_t maxResolution, bool allowRotate) - { + bool findChartLocation(const Vector2i &startPosition, bool bruteForce, const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, bool blockAligned, uint32_t maxResolution, bool allowRotate) { const int attempts = 4096; if (bruteForce || attempts >= w * h) return findChartLocation_bruteForce(startPosition, atlasBitImage, chartBitImage, chartBitImageRotated, w, h, best_x, best_y, best_w, best_h, best_r, blockAligned, maxResolution, allowRotate); return findChartLocation_random(atlasBitImage, chartBitImage, chartBitImageRotated, w, h, best_x, best_y, best_w, best_h, best_r, attempts, blockAligned, maxResolution, allowRotate); } - bool findChartLocation_bruteForce(const Vector2i &startPosition, const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, bool blockAligned, uint32_t maxResolution, bool allowRotate) - { + bool findChartLocation_bruteForce(const Vector2i &startPosition, const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, bool blockAligned, uint32_t maxResolution, bool allowRotate) { const int stepSize = blockAligned ? 4 : 1; int best_metric = INT_MAX; // Try two different orientations. @@ -8390,8 +8014,7 @@ private: return best_metric != INT_MAX; } - bool findChartLocation_random(const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, int minTrialCount, bool blockAligned, uint32_t maxResolution, bool allowRotate) - { + bool findChartLocation_random(const BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int w, int h, int *best_x, int *best_y, int *best_w, int *best_h, int *best_r, int minTrialCount, bool blockAligned, uint32_t maxResolution, bool allowRotate) { bool result = false; const int BLOCK_SIZE = 4; int best_metric = INT_MAX; @@ -8446,8 +8069,7 @@ private: return result; } - void addChart(BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int atlas_w, int atlas_h, int offset_x, int offset_y, int r) - { + void addChart(BitImage *atlasBitImage, const BitImage *chartBitImage, const BitImage *chartBitImageRotated, int atlas_w, int atlas_h, int offset_x, int offset_y, int r) { XA_DEBUG_ASSERT(r == 0 || r == 1); const BitImage *image = r == 0 ? chartBitImage : chartBitImageRotated; const int w = image->width(); @@ -8470,8 +8092,7 @@ private: } } - void bilinearExpand(const Chart *chart, BitImage *source, BitImage *dest, BitImage *destRotated, UniformGrid2 &boundaryEdgeGrid) const - { + void bilinearExpand(const Chart *chart, BitImage *source, BitImage *dest, BitImage *destRotated, UniformGrid2 &boundaryEdgeGrid) const { boundaryEdgeGrid.reset(chart->vertices, chart->indices); if (chart->boundaryEdges) { const uint32_t edgeCount = chart->boundaryEdges->size(); @@ -8526,13 +8147,11 @@ private: } } - struct DrawTriangleCallbackArgs - { + struct DrawTriangleCallbackArgs { BitImage *chartBitImage, *chartBitImageRotated; }; - static bool drawTriangleCallback(void *param, int x, int y) - { + static bool drawTriangleCallback(void *param, int x, int y) { auto args = (DrawTriangleCallbackArgs *)param; args->chartBitImage->set(x, y); if (args->chartBitImageRotated) @@ -8554,8 +8173,7 @@ private: } // namespace pack } // namespace internal -struct Context -{ +struct Context { Atlas atlas; uint32_t meshCount = 0; internal::Progress *addMeshProgress = nullptr; @@ -8568,16 +8186,14 @@ struct Context internal::Array<internal::UvMeshInstance *> uvMeshInstances; }; -Atlas *Create() -{ +Atlas *Create() { Context *ctx = XA_NEW(internal::MemTag::Default, Context); memset(&ctx->atlas, 0, sizeof(Atlas)); ctx->taskScheduler = XA_NEW(internal::MemTag::Default, internal::TaskScheduler); return &ctx->atlas; } -static void DestroyOutputMeshes(Context *ctx) -{ +static void DestroyOutputMeshes(Context *ctx) { if (!ctx->atlas.meshes) return; for (int i = 0; i < (int)ctx->atlas.meshCount; i++) { @@ -8598,8 +8214,7 @@ static void DestroyOutputMeshes(Context *ctx) ctx->atlas.meshes = nullptr; } -void Destroy(Atlas *atlas) -{ +void Destroy(Atlas *atlas) { XA_DEBUG_ASSERT(atlas); Context *ctx = (Context *)atlas; if (atlas->utilization) @@ -8634,14 +8249,12 @@ void Destroy(Atlas *atlas) #endif } -struct AddMeshTaskArgs -{ +struct AddMeshTaskArgs { Context *ctx; internal::Mesh *mesh; }; -static void runAddMeshTask(void *userData) -{ +static void runAddMeshTask(void *userData) { XA_PROFILE_START(addMeshThread) auto args = (AddMeshTaskArgs *)userData; // Responsible for freeing this. internal::Mesh *mesh = args->mesh; @@ -8710,37 +8323,32 @@ cleanup: XA_PROFILE_END(addMeshThread) } -static internal::Vector3 DecodePosition(const MeshDecl &meshDecl, uint32_t index) -{ +static internal::Vector3 DecodePosition(const MeshDecl &meshDecl, uint32_t index) { XA_DEBUG_ASSERT(meshDecl.vertexPositionData); XA_DEBUG_ASSERT(meshDecl.vertexPositionStride > 0); return *((const internal::Vector3 *)&((const uint8_t *)meshDecl.vertexPositionData)[meshDecl.vertexPositionStride * index]); } -static internal::Vector3 DecodeNormal(const MeshDecl &meshDecl, uint32_t index) -{ +static internal::Vector3 DecodeNormal(const MeshDecl &meshDecl, uint32_t index) { XA_DEBUG_ASSERT(meshDecl.vertexNormalData); XA_DEBUG_ASSERT(meshDecl.vertexNormalStride > 0); return *((const internal::Vector3 *)&((const uint8_t *)meshDecl.vertexNormalData)[meshDecl.vertexNormalStride * index]); } -static internal::Vector2 DecodeUv(const MeshDecl &meshDecl, uint32_t index) -{ +static internal::Vector2 DecodeUv(const MeshDecl &meshDecl, uint32_t index) { XA_DEBUG_ASSERT(meshDecl.vertexUvData); XA_DEBUG_ASSERT(meshDecl.vertexUvStride > 0); return *((const internal::Vector2 *)&((const uint8_t *)meshDecl.vertexUvData)[meshDecl.vertexUvStride * index]); } -static uint32_t DecodeIndex(IndexFormat::Enum format, const void *indexData, int32_t offset, uint32_t i) -{ +static uint32_t DecodeIndex(IndexFormat::Enum format, const void *indexData, int32_t offset, uint32_t i) { XA_DEBUG_ASSERT(indexData); if (format == IndexFormat::UInt16) return uint16_t((int32_t)((const uint16_t *)indexData)[i] + offset); return uint32_t((int32_t)((const uint32_t *)indexData)[i] + offset); } -AddMeshError::Enum AddMesh(Atlas *atlas, const MeshDecl &meshDecl, uint32_t meshCountHint) -{ +AddMeshError::Enum AddMesh(Atlas *atlas, const MeshDecl &meshDecl, uint32_t meshCountHint) { XA_DEBUG_ASSERT(atlas); if (!atlas) { XA_PRINT_WARNING("AddMesh: atlas is null.\n"); @@ -8758,8 +8366,7 @@ AddMeshError::Enum AddMesh(Atlas *atlas, const MeshDecl &meshDecl, uint32_t mesh // Don't know how many times AddMesh will be called, so progress needs to adjusted each time. if (!ctx->addMeshProgress) { ctx->addMeshProgress = XA_NEW_ARGS(internal::MemTag::Default, internal::Progress, ProgressCategory::AddMesh, ctx->progressFunc, ctx->progressUserData, 1); - } - else { + } else { ctx->addMeshProgress->setMaxValue(internal::max(ctx->meshCount + 1, meshCountHint)); } XA_PROFILE_START(addMeshCopyData) @@ -8875,8 +8482,7 @@ AddMeshError::Enum AddMesh(Atlas *atlas, const MeshDecl &meshDecl, uint32_t mesh return AddMeshError::Success; } -void AddMeshJoin(Atlas *atlas) -{ +void AddMeshJoin(Atlas *atlas) { XA_DEBUG_ASSERT(atlas); if (!atlas) { XA_PRINT_WARNING("AddMeshJoin: atlas is null.\n"); @@ -8904,19 +8510,19 @@ void AddMeshJoin(Atlas *atlas) XA_PRINT_MEM_USAGE } -struct EdgeKey -{ +struct EdgeKey { EdgeKey() {} - EdgeKey(const EdgeKey &k) : v0(k.v0), v1(k.v1) {} - EdgeKey(uint32_t v0, uint32_t v1) : v0(v0), v1(v1) {} + EdgeKey(const EdgeKey &k) : + v0(k.v0), v1(k.v1) {} + EdgeKey(uint32_t v0, uint32_t v1) : + v0(v0), v1(v1) {} bool operator==(const EdgeKey &k) const { return v0 == k.v0 && v1 == k.v1; } uint32_t v0; uint32_t v1; }; -AddMeshError::Enum AddUvMesh(Atlas *atlas, const UvMeshDecl &decl) -{ +AddMeshError::Enum AddUvMesh(Atlas *atlas, const UvMeshDecl &decl) { XA_DEBUG_ASSERT(atlas); if (!atlas) { XA_PRINT_WARNING("AddUvMesh: atlas is null.\n"); @@ -9026,8 +8632,7 @@ AddMeshError::Enum AddUvMesh(Atlas *atlas, const UvMeshDecl &decl) return AddMeshError::Success; } -void ComputeCharts(Atlas *atlas, ChartOptions chartOptions) -{ +void ComputeCharts(Atlas *atlas, ChartOptions chartOptions) { if (!atlas) { XA_PRINT_WARNING("ComputeCharts: atlas is null.\n"); return; @@ -9100,8 +8705,7 @@ void ComputeCharts(Atlas *atlas, ChartOptions chartOptions) XA_PRINT_MEM_USAGE } -void ParameterizeCharts(Atlas *atlas, ParameterizeFunc func) -{ +void ParameterizeCharts(Atlas *atlas, ParameterizeFunc func) { if (!atlas) { XA_PRINT_WARNING("ParameterizeCharts: atlas is null.\n"); return; @@ -9132,7 +8736,7 @@ void ParameterizeCharts(Atlas *atlas, ParameterizeFunc func) XA_PROFILE_START(parameterizeChartsReal) if (!ctx->paramAtlas.parameterizeCharts(ctx->taskScheduler, func, ctx->progressFunc, ctx->progressUserData)) { XA_PRINT(" Cancelled by user\n"); - return; + return; } XA_PROFILE_END(parameterizeChartsReal) uint32_t chartCount = 0, orthoChartsCount = 0, planarChartsCount = 0, lscmChartsCount = 0, piecewiseChartsCount = 0, chartsAddedCount = 0, chartsDeletedCount = 0; @@ -9234,8 +8838,7 @@ void ParameterizeCharts(Atlas *atlas, ParameterizeFunc func) XA_PRINT_MEM_USAGE } -void PackCharts(Atlas *atlas, PackOptions packOptions) -{ +void PackCharts(Atlas *atlas, PackOptions packOptions) { // Validate arguments and context state. if (!atlas) { XA_PRINT_WARNING("PackCharts: atlas is null.\n"); @@ -9277,8 +8880,7 @@ void PackCharts(Atlas *atlas, PackOptions packOptions) if (!ctx->uvMeshInstances.isEmpty()) { for (uint32_t i = 0; i < ctx->uvMeshInstances.size(); i++) packAtlas.addUvMeshCharts(ctx->uvMeshInstances[i]); - } - else + } else packAtlas.addCharts(ctx->taskScheduler, &ctx->paramAtlas); XA_PROFILE_END(packChartsAddCharts) XA_PROFILE_START(packCharts) @@ -9479,8 +9081,7 @@ void PackCharts(Atlas *atlas, PackOptions packOptions) XA_PRINT_MEM_USAGE } -void Generate(Atlas *atlas, ChartOptions chartOptions, ParameterizeFunc paramFunc, PackOptions packOptions) -{ +void Generate(Atlas *atlas, ChartOptions chartOptions, ParameterizeFunc paramFunc, PackOptions packOptions) { if (!atlas) { XA_PRINT_WARNING("Generate: atlas is null.\n"); return; @@ -9499,8 +9100,7 @@ void Generate(Atlas *atlas, ChartOptions chartOptions, ParameterizeFunc paramFun PackCharts(atlas, packOptions); } -void SetProgressCallback(Atlas *atlas, ProgressFunc progressFunc, void *progressUserData) -{ +void SetProgressCallback(Atlas *atlas, ProgressFunc progressFunc, void *progressUserData) { if (!atlas) { XA_PRINT_WARNING("SetProgressCallback: atlas is null.\n"); return; @@ -9510,20 +9110,17 @@ void SetProgressCallback(Atlas *atlas, ProgressFunc progressFunc, void *progress ctx->progressUserData = progressUserData; } -void SetAlloc(ReallocFunc reallocFunc, FreeFunc freeFunc) -{ +void SetAlloc(ReallocFunc reallocFunc, FreeFunc freeFunc) { internal::s_realloc = reallocFunc; internal::s_free = freeFunc; } -void SetPrint(PrintFunc print, bool verbose) -{ +void SetPrint(PrintFunc print, bool verbose) { internal::s_print = print; internal::s_printVerbose = verbose; } -const char *StringForEnum(AddMeshError::Enum error) -{ +const char *StringForEnum(AddMeshError::Enum error) { if (error == AddMeshError::Error) return "Unspecified error"; if (error == AddMeshError::IndexOutOfRange) @@ -9533,8 +9130,7 @@ const char *StringForEnum(AddMeshError::Enum error) return "Success"; } -const char *StringForEnum(ProgressCategory::Enum category) -{ +const char *StringForEnum(ProgressCategory::Enum category) { if (category == ProgressCategory::AddMesh) return "Adding mesh(es)"; if (category == ProgressCategory::ComputeCharts) |