runtime_shape.h 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_
  #define TENSORFLOW_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_

  #include <cstdint>
  #include <cstring>

  // NOTE(review): the TFLITE_CHECK_* / TFLITE_DCHECK* macros used below are
  // neither defined nor included here; they must come from a project header
  // included before this one — confirm and include it explicitly.
  14. namespace tflite {
  // Plain aggregate with one `sizes` entry and one `strides` entry per
  // dimension, for a dimension count `kNumDims` fixed at compile time. Both
  // arrays are left for the caller to fill in.
  template <int kNumDims>
  struct Dims {
    int sizes[kNumDims];
    int strides[kNumDims];
  };
  20. class RuntimeShape {
  21. public:
  22. RuntimeShape& operator=(RuntimeShape const&) = delete;
  23. RuntimeShape() : size_(0) {}
  24. explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {}
  25. RuntimeShape(int shape_size, int32_t value) : size_(shape_size) {
  26. for (int i = 0; i < shape_size; ++i) {
  27. SetDim(i, value);
  28. }
  29. }
  30. RuntimeShape(int dimensions_count, const int32_t* dims_data)
  31. : size_(dimensions_count) {
  32. ReplaceWith(dimensions_count, dims_data);
  33. }
  34. bool operator==(const RuntimeShape& comp) const {
  35. return this->size_ == comp.size_ &&
  36. std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32_t)) ==
  37. 0;
  38. }
  39. ~RuntimeShape() {}
  40. int32_t DimensionsCount() const { return size_; }
  41. int32_t Dims(int i) const {
  42. TFLITE_DCHECK_GE(i, 0);
  43. TFLITE_DCHECK_LT(i, size_);
  44. return dims_[i];
  45. }
  46. void SetDim(int i, int32_t val) {
  47. TFLITE_DCHECK_GE(i, 0);
  48. TFLITE_DCHECK_LT(i, size_);
  49. dims_[i] = val;
  50. }
  51. static RuntimeShape ExtendedShape(int new_shape_size,
  52. const RuntimeShape& shape) {
  53. return RuntimeShape(new_shape_size, shape, 1);
  54. }
  55. int32_t* DimsData() { return dims_; }
  56. const int32_t* DimsData() const { return dims_; }
  57. const int32_t* DimsDataUpTo5D() const { return dims_; }
  58. void ReplaceWith(int dimensions_count, const int32_t* dims_data) {
  59. size_ = dimensions_count;
  60. int32_t* dst_dims = DimsData();
  61. std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
  62. }
  63. // Returns the total count of elements, that is the size when flattened into a
  64. // vector.
  65. int FlatSize() const {
  66. int buffer_size = 1;
  67. const int* dims_data = reinterpret_cast<const int*>(DimsData());
  68. for (int i = 0; i < size_; i++) {
  69. buffer_size *= dims_data[i];
  70. }
  71. return buffer_size;
  72. }
  73. private:
  74. // For use only by ExtendedShape(), written to guarantee (return-value) copy
  75. // elision in C++17.
  76. // This creates a shape padded to the desired size with the specified value.
  77. RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
  78. : size_(new_shape_size) {
  79. // If the following check fails, it is likely because a 4D-only kernel is
  80. // being used with an array of larger dimension count.
  81. TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
  82. const int size_increase = new_shape_size - shape.DimensionsCount();
  83. for (int i = 0; i < size_increase; ++i) {
  84. SetDim(i, pad_value);
  85. }
  86. std::memcpy(DimsData() + size_increase, shape.DimsData(),
  87. sizeof(int32_t) * shape.DimensionsCount());
  88. }
  89. // A maximum of 4 dimensions are supported on TFLM.
  90. static constexpr int kMaxSize = 5;
  91. int32_t size_;
  92. union {
  93. int32_t dims_[kMaxSize];
  94. };
  95. };
  96. // Since tensors with '0' in their shape are valid in TF, these offset functions
  97. // allow that as long as the corresponding index is also 0. It is upto the
  98. // calling ops to ensure that they perform verification checks on tensor shapes
  99. // if they don't support a particular behavior.
  100. inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
  101. TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
  102. const int* dims_data = reinterpret_cast<const int*>(shape.DimsData());
  103. TFLITE_DCHECK((dims_data[0] == 0 && i0 == 0) ||
  104. (i0 >= 0 && i0 < dims_data[0]));
  105. TFLITE_DCHECK((dims_data[1] == 0 && i1 == 0) ||
  106. (i1 >= 0 && i1 < dims_data[1]));
  107. TFLITE_DCHECK((dims_data[2] == 0 && i2 == 0) ||
  108. (i2 >= 0 && i2 < dims_data[2]));
  109. TFLITE_DCHECK((dims_data[3] == 0 && i3 == 0) ||
  110. (i3 >= 0 && i3 < dims_data[3]));
  111. return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
  112. }
  113. inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3,
  114. int i4) {
  115. TFLITE_DCHECK_EQ(shape.DimensionsCount(), 5);
  116. const int* dims_data = reinterpret_cast<const int*>(shape.DimsData());
  117. TFLITE_DCHECK((dims_data[0] == 0 && i0 == 0) ||
  118. (i0 >= 0 && i0 < dims_data[0]));
  119. TFLITE_DCHECK((dims_data[1] == 0 && i1 == 0) ||
  120. (i1 >= 0 && i1 < dims_data[1]));
  121. TFLITE_DCHECK((dims_data[2] == 0 && i2 == 0) ||
  122. (i2 >= 0 && i2 < dims_data[2]));
  123. TFLITE_DCHECK((dims_data[3] == 0 && i3 == 0) ||
  124. (i3 >= 0 && i3 < dims_data[3]));
  125. TFLITE_DCHECK((dims_data[4] == 0 && i4 == 0) ||
  126. (i4 >= 0 && i4 < dims_data[4]));
  127. return (((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3) *
  128. dims_data[4] +
  129. i4;
  130. }
  131. } // namespace tflite
  132. #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_