summaryrefslogtreecommitdiff
path: root/core/templates
diff options
context:
space:
mode:
authorRémi Verschelde <rverschelde@gmail.com>2022-01-10 13:56:55 +0100
committerRémi Verschelde <rverschelde@gmail.com>2022-01-10 22:42:03 +0100
commitc6cefb1b79d207af1bc78ce20c01b5788e806252 (patch)
treeca713ba3bf904b57a7408ec50029c9fbcc275d40 /core/templates
parent4acc819f9bafacf5f912caf5ba2ebc15f70e3dbb (diff)
`Array`: Relax `slice` bound checks to properly handle negative indices
The same is done for `Vector` (and thus `Packed*Array`). `begin` and `end` can now take any value and will be clamped to `[-size(), size()]`. Negative values are a shorthand for indexing the array from the last element upward. `end` is given a default `INT_MAX` value (which will be clamped to `size()`) so that the `end` parameter can be omitted to go from `begin` to the max size of the array. This makes `slice` works similarly to numpy's and JavaScript's.
Diffstat (limited to 'core/templates')
-rw-r--r--core/templates/vector.h23
1 files changed, 14 insertions, 9 deletions
diff --git a/core/templates/vector.h b/core/templates/vector.h
index 4ada3b597a..d1408125c8 100644
--- a/core/templates/vector.h
+++ b/core/templates/vector.h
@@ -43,6 +43,7 @@
#include "core/templates/search_array.h"
#include "core/templates/sort_array.h"
+#include <climits>
#include <initializer_list>
template <class T>
@@ -145,25 +146,29 @@ public:
return ret;
}
- Vector<T> slice(int p_begin, int p_end) const {
+ Vector<T> slice(int p_begin, int p_end = INT_MAX) const {
Vector<T> result;
- if (p_end < 0) {
- p_end += size() + 1;
- }
+ const int s = size();
- ERR_FAIL_INDEX_V(p_begin, size(), result);
- ERR_FAIL_INDEX_V(p_end, size() + 1, result);
+ int begin = CLAMP(p_begin, -s, s);
+ if (begin < 0) {
+ begin += s;
+ }
+ int end = CLAMP(p_end, -s, s);
+ if (end < 0) {
+ end += s;
+ }
- ERR_FAIL_COND_V(p_begin > p_end, result);
+ ERR_FAIL_COND_V(begin > end, result);
- int result_size = p_end - p_begin;
+ int result_size = end - begin;
result.resize(result_size);
const T *const r = ptr();
T *const w = result.ptrw();
for (int i = 0; i < result_size; ++i) {
- w[i] = r[p_begin + i];
+ w[i] = r[begin + i];
}
return result;