fixedpoint.h 34 KB


  1. // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // fixedpoint.h: fixed-point arithmetic, with basic operations and
  15. // a few math functions such as tanh.
  16. #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
  17. #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
  18. #include <algorithm>
  19. #include <cassert>
  20. #include <cmath>
  21. #include <cstdint>
  22. #include <limits>
  23. #include "../internal/detect_platform.h"
  24. namespace gemmlowp {
  25. // Part 1: Low-level integer-arithmetic primitives.
  26. // The implementations here are generic implementations valid for
  27. // scalar types (e.g. std::int32_t). Architecture-specific SIMD types
  28. // (e.g. NEON int32x4_t) may be supported by providing
  29. // specializations for them in separate files.
  30. //
  31. // The purpose of these primitives is two-fold:
  32. // - They will be used to implement higher-level fixed-point
  33. // abstractions, namely the FixedPoint class and its arithmetic
  34. // operators.
  35. // - They will be directly used to implement some more involved
  36. // fixed-point computations, e.g. the fixed-point implementation
  37. // of math functions such as tanh.
  // Compile-time description of a raw integer type used to store fixed-point
  // values: how many SIMD lanes it holds and what its per-lane scalar type is.
  // Only plain scalar types (one lane each) are described here; SIMD raw types
  // get their own specializations in separate files.
  template <typename tIntegerType>
  struct FixedPointRawTypeTraits {};

  template <>
  struct FixedPointRawTypeTraits<std::int32_t> {
    using ScalarRawType = std::int32_t;
    static constexpr int kLanes = 1;
  };

  template <>
  struct FixedPointRawTypeTraits<std::int16_t> {
    using ScalarRawType = std::int16_t;
    static constexpr int kLanes = 1;
  };

  // Broadcasts the scalar x into every lane of tRawType. For the single-lane
  // scalar raw types handled by this generic version, that is x itself.
  template <typename tRawType>
  tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
    const tRawType broadcast = x;
    return broadcast;
  }
  // Bitwise AND of a and b.
  template <typename tIntegerType>
  tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
    const tIntegerType both_set = a & b;
    return both_set;
  }

  // Bitwise OR of a and b.
  template <typename tIntegerType>
  tIntegerType BitOr(tIntegerType a, tIntegerType b) {
    const tIntegerType either_set = a | b;
    return either_set;
  }

  // Bitwise exclusive OR of a and b.
  template <typename tIntegerType>
  tIntegerType BitXor(tIntegerType a, tIntegerType b) {
    const tIntegerType exactly_one_set = a ^ b;
    return exactly_one_set;
  }

  // Bitwise complement of a.
  template <typename tIntegerType>
  tIntegerType BitNot(tIntegerType a) {
    const tIntegerType flipped = ~a;
    return flipped;
  }
  // Integer addition. Not saturating. For std::int32_t operands, overflow is
  // undefined behavior; narrower types are promoted to int, so the arithmetic
  // itself cannot overflow and the result is converted back on return.
  template <typename tIntegerType>
  tIntegerType Add(tIntegerType a, tIntegerType b) {
    return a + b;
  }

  // Integer multiplication. Not saturating. Same overflow caveats as Add.
  template <typename tIntegerType>
  tIntegerType Mul(tIntegerType a, tIntegerType b) {
    return a * b;
  }

  // Integer subtraction. Not saturating. Same overflow caveats as Add.
  template <typename tIntegerType>
  tIntegerType Sub(tIntegerType a, tIntegerType b) {
    return a - b;
  }

  // Integer unary negation. Not saturating: negating the minimum value of
  // std::int32_t overflows, which is undefined behavior.
  template <typename tIntegerType>
  tIntegerType Neg(tIntegerType a) {
    return -a;
  }
  // Integer arithmetic left-shift, equivalent to multiplying with a power of
  // two. Negative values of `a` are OK. In case of overflow there is no
  // undefined behavior, but the results are implementation-defined (in
  // practice they currently are saturated, but we make no commitment to that).
  // The idea is that the caller will want to implement the overflowing cases
  // with saturation via compare-and-mask, so we don't care about the results in
  // the overflow case; we just want to avoid undefined behavior.
  //
  // tIntegerType may be int32 or any narrower signed type. `offset` must be in
  // [0, 31]. The power of two is formed in 64 bits because `1 << offset` in int
  // does not produce +2^31 for offset == 31 (it is INT_MIN in practice, which
  // would flip the sign of the product), and is undefined for larger offsets.
  template <typename tIntegerType>
  tIntegerType ShiftLeft(tIntegerType a, int offset) {
    assert(offset >= 0);
    assert(offset <= 31);
    const std::int64_t wide_a = static_cast<std::int64_t>(a);
    // |wide_a| <= 2^31 and the factor is <= 2^31, so the product fits in int64.
    const std::int64_t wide_shifted = wide_a * (std::int64_t{1} << offset);
    const auto min = std::numeric_limits<tIntegerType>::min();
    const auto max = std::numeric_limits<tIntegerType>::max();
    if (wide_shifted < min) {
      return min;
    }
    if (wide_shifted > max) {
      return max;
    }
    return static_cast<tIntegerType>(wide_shifted);
  }
  // Arithmetic right shift by `offset` bits, truncating (no rounding).
  // For negative `a` this relies on `>>` sign-extending, which is
  // implementation-defined before C++20 but what compilers do in practice.
  template <typename tIntegerType>
  tIntegerType ShiftRight(tIntegerType a, int offset) {
    const tIntegerType shifted = a >> offset;
    return shifted;
  }
  // Bitwise select: every result bit comes from then_val where the matching
  // bit of if_mask is 1, and from else_val where it is 0 (like ARM NEON VBSL).
  // The two masked halves have no bits in common, so XOR merges them exactly.
  template <typename tIntegerType>
  tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                               tIntegerType else_val) {
    const tIntegerType taken = BitAnd(if_mask, then_val);
    const tIntegerType not_taken = BitAnd(BitNot(if_mask), else_val);
    return BitXor(taken, not_taken);
  }
  // Comparison-to-mask primitives. Each returns, per lane, a value with all
  // bits set when the tested condition holds and all bits clear otherwise
  // (for scalar types: -1 or 0). Such masks are combined with SelectUsingMask
  // and BitAnd elsewhere in this file.

  // All bits set if a is nonzero.
  template <typename tIntegerType>
  tIntegerType MaskIfNonZero(tIntegerType a) {
    const tIntegerType no_bits = 0;
    const tIntegerType all_bits = static_cast<tIntegerType>(~no_bits);
    if (a) {
      return all_bits;
    }
    return no_bits;
  }

  // All bits set if a is zero.
  template <typename tIntegerType>
  tIntegerType MaskIfZero(tIntegerType a) {
    return MaskIfNonZero<tIntegerType>(static_cast<tIntegerType>(!a));
  }

  // All bits set if a == b.
  template <typename tIntegerType>
  tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
    const bool holds = (a == b);
    return MaskIfNonZero<tIntegerType>(holds);
  }

  // All bits set if a != b.
  template <typename tIntegerType>
  tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
    const bool holds = (a != b);
    return MaskIfNonZero<tIntegerType>(holds);
  }

  // All bits set if a > b.
  template <typename tIntegerType>
  tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
    const bool holds = (a > b);
    return MaskIfNonZero<tIntegerType>(holds);
  }

  // All bits set if a >= b.
  template <typename tIntegerType>
  tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
    const bool holds = (a >= b);
    return MaskIfNonZero<tIntegerType>(holds);
  }

  // All bits set if a < b.
  template <typename tIntegerType>
  tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
    const bool holds = (a < b);
    return MaskIfNonZero<tIntegerType>(holds);
  }

  // All bits set if a <= b.
  template <typename tIntegerType>
  tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
    const bool holds = (a <= b);
    return MaskIfNonZero<tIntegerType>(holds);
  }
  // True if every lane of `a` is nonzero. A scalar has one lane, so this is
  // simply a != 0. Lanes are assumed to be all-ones or all-zeros masks; other
  // lane values are currently unsupported.
  template <typename tIntegerType>
  bool All(tIntegerType a) {
    return static_cast<bool>(a);
  }

  // True if at least one lane of `a` is nonzero. Same single-lane reduction and
  // the same all-ones/all-zeros assumption as All().
  template <typename tIntegerType>
  bool Any(tIntegerType a) {
    return static_cast<bool>(a);
  }
  // Returns (a + b) / 2, rounded to the nearest integer with ties rounded away
  // from zero. The sum is formed in a wider type, so it cannot overflow.
  //
  // This is the scalar counterpart of VRHADD in the ARM NEON instruction set,
  // but the two are not bit-identical on negative ties. VRHADD computes
  // (a + b + 1) >> 1 and so rounds ties toward +infinity (a + b == -3 gives -1),
  // whereas the scalar code below gives -2 for that input.
  //
  // Only the explicit specializations (and any SIMD ones provided elsewhere)
  // may be used; instantiating the primary template is a compile-time error.
  template <typename IntegerType>
  IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
    static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
    (void)b;
    return a;
  }

  template <>
  inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
    std::int64_t a64 = a;
    std::int64_t b64 = b;
    std::int64_t sum = a64 + b64;
    // Adding +/-1 before the truncating division rounds ties away from zero.
    std::int64_t sign = sum >= 0 ? 1 : -1;
    return static_cast<std::int32_t>((sum + sign) / 2);
  }

  template <>
  inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
    std::int32_t a32 = a;
    std::int32_t b32 = b;
    std::int32_t sum = a32 + b32;
    // Adding +/-1 before the truncating division rounds ties away from zero.
    std::int32_t sign = sum >= 0 ? 1 : -1;
    return static_cast<std::int16_t>((sum + sign) / 2);
  }
  // Saturating addition: a + b clamped to the representable range of
  // IntegerType. Only explicit specializations may be used; instantiating the
  // primary template is a compile-time error.
  template <typename IntegerType>
  IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
    static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
    (void)b;
    return a;
  }

  // So far this is only needed for int16. The sum is formed in 32 bits, where
  // it cannot overflow, and then clamped to [-32768, 32767].
  template <>
  inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
    const std::int32_t wide_sum =
        static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
    if (wide_sum > 32767) {
      return static_cast<std::int16_t>(32767);
    }
    if (wide_sum < -32768) {
      return static_cast<std::int16_t>(-32768);
    }
    return static_cast<std::int16_t>(wide_sum);
  }
  234. // Returns a+b, saturating if the integers are 16bit or narrower,
  235. // otherwise just a plain addition.
  236. template <typename IntegerType, bool Is16Bit>
  237. struct AddSaturatingIf16BitImpl {
  238. static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
  239. };
  240. template <typename IntegerType>
  241. struct AddSaturatingIf16BitImpl<IntegerType, true> {
  242. static IntegerType Run(IntegerType a, IntegerType b) {
  243. return SaturatingAdd(a, b);
  244. }
  245. };
  246. template <typename IntegerType>
  247. IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
  248. using ScalarType =
  249. typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  250. return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
  251. b);
  252. }
  // Fixed-point multiplication of two values in the interval [-1, 1).
  //
  // Every integer of the signed type IntegerType is read as a real number in
  // [-1, 1); for std::int32_t the mapping is real = integer / 2^31. Ignoring
  // rounding and saturation, the result is therefore
  //   ((a / 2^31) * (b / 2^31)) * 2^31 == (a * b) / 2^31.
  // A plain "multiply-high" keeps the top half of the product, i.e. divides by
  // 2^32; dividing by 2^31 instead is that value doubled, which is where the
  // "doubling" in the name comes from. It uses all 32 bits of the result.
  //
  // The result is rounded to nearest. The only overflowing input, -1 * -1, is
  // saturated to the largest representable value, since +1 itself lies outside
  // [-1, 1). This matches the ARM NEON VQRDMULH instruction.
  //
  // Only explicit specializations may be used; instantiating the primary
  // template is a compile-time error.
  template <typename IntegerType>
  IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
    static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
    (void)b;
    return a;
  }

  // Same computation as the ARMv7 NEON VQRDMULH instruction on 32-bit lanes.
  template <>
  inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                        std::int32_t b) {
    const std::int32_t kMin = std::numeric_limits<std::int32_t>::min();
    if (a == kMin && b == kMin) {
      return std::numeric_limits<std::int32_t>::max();
    }
    const std::int64_t product =
        static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
    // Half of the 2^31 divisor; for negative products it is 1 - 2^30 so that
    // the truncating division below still rounds to nearest.
    const std::int64_t kHalf = std::int64_t{1} << 30;
    const std::int64_t nudge = (product >= 0) ? kHalf : 1 - kHalf;
    return static_cast<std::int32_t>((product + nudge) /
                                     (std::int64_t{1} << 31));
  }

  // 16-bit version: the same computation with a 32-bit intermediate product.
  template <>
  inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
                                                        std::int16_t b) {
    const std::int16_t kMin = std::numeric_limits<std::int16_t>::min();
    if (a == kMin && b == kMin) {
      return std::numeric_limits<std::int16_t>::max();
    }
    const std::int32_t product =
        static_cast<std::int32_t>(a) * static_cast<std::int32_t>(b);
    const std::int32_t kHalf = 1 << 14;
    const std::int32_t nudge = (product >= 0) ? kHalf : 1 - kHalf;
    return static_cast<std::int16_t>((product + nudge) / (1 << 15));
  }
  313. // Correctly-rounded-to-nearest division by a power-of-two.
  314. // Also known as a rounding arithmetic right shift.
  315. template <typename IntegerType>
  316. inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
  317. assert(exponent >= 0);
  318. assert(exponent <= 31);
  319. const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
  320. const IntegerType zero = Dup<IntegerType>(0);
  321. const IntegerType one = Dup<IntegerType>(1);
  322. const IntegerType remainder = BitAnd(x, mask);
  323. const IntegerType threshold =
  324. Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
  325. return Add(ShiftRight(x, exponent),
  326. BitAnd(MaskIfGreaterThan(remainder, threshold), one));
  327. }
  328. // Returns the product of a run-time integer value by a compile-time power
  329. // of two, with either a positive exponent (equivalent to an arithmetic
  330. // left shift, saturating) or a negative exponent (equivalent to an arithmetic
  331. // right shift, rounding to nearest).
  332. template <int Exponent, typename IntegerType,
  333. int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
  334. struct ImplSaturatingRoundingMultiplyByPOT {};
  335. template <int Exponent, typename IntegerType>
  336. struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
  337. static IntegerType eval(IntegerType x) { return x; }
  338. };
  339. template <int Exponent, typename IntegerType>
  340. struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
  341. static IntegerType eval(IntegerType x) {
  342. using ScalarIntegerType =
  343. typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
  344. const IntegerType min =
  345. Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
  346. const IntegerType max =
  347. Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
  348. const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
  349. const std::int32_t threshold =
  350. ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
  351. const IntegerType positive_mask =
  352. MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
  353. const IntegerType negative_mask =
  354. MaskIfLessThan(x, Dup<IntegerType>(-threshold));
  355. IntegerType result = ShiftLeft(x, Exponent);
  356. result = SelectUsingMask(positive_mask, max, result);
  357. result = SelectUsingMask(negative_mask, min, result);
  358. return result;
  359. }
  360. };
  361. template <int Exponent, typename IntegerType>
  362. struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
  363. static IntegerType eval(IntegerType x) {
  364. return RoundingDivideByPOT<IntegerType>(x, -Exponent);
  365. }
  366. };
  367. template <int Exponent, typename IntegerType>
  368. IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
  369. return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
  370. }
  371. // Part 2: the FixedPoint class.
  372. // A FixedPoint object represents a fixed-point value stored in the underlying
  373. // integer type tRawType, if tRawType is a plain scalar integer type.
  374. // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
  375. // case a FixedPoint object represents a corresponding SIMD vector of fixed
  376. // point values.
  377. //
  378. // tIntegerBits describes the range of the fixed-point format: if
  379. // tIntegerBits == m then the range of representable values is the half-open
  380. // interval [-2^m; 2^m) where the open boundary on the right side means that
  381. // 2^m is not representable (how close the maximum representable value is to
  382. // it, depends on bit-depth of tRawType).
  383. //
  384. // In "Q format notation",
  385. // https://en.wikipedia.org/wiki/Q_(number_format)
  386. // we are describing the format
  387. // Qm.n
  388. // where
  389. // m = tIntegerBits
  390. // and
  391. // n = NumberOfBits(tRawType) - (m + 1)
  392. // Note that the (m + 1) in the above line is because we adopt the convention
  393. // that we count the integer bits exclusively of the sign bit; so (m + 1) is
  394. // the total number of integer bits inclusive of the sign bit.
  395. //
  396. // Accordingly, the number of integral representable values in our range
  397. // [-2^m ; 2^m)
  398. // is equal to 2^(m+1).
  399. template <typename tRawType, int tIntegerBits>
  400. class FixedPoint {
  401. public:
  402. typedef tRawType RawType;
  403. typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
  404. typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
  405. static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
  406. static constexpr int kIntegerBits = tIntegerBits;
  407. static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
  408. static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
  409. "bad IntegerBits");
  410. typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
  411. static const ScalarRawType ScalarRawMin() {
  412. return std::numeric_limits<ScalarRawType>::min();
  413. }
  414. static const ScalarRawType ScalarRawMax() {
  415. return std::numeric_limits<ScalarRawType>::max();
  416. }
  417. static const ScalarRawType RawMin() {
  418. return VectorFromScalar(ScalarRawMin());
  419. }
  420. static const ScalarRawType RawMax() {
  421. return VectorFromScalar(ScalarRawMax());
  422. }
  423. static FixedPoint FromRaw(RawType x) {
  424. FixedPoint retval;
  425. retval.raw() = x;
  426. return retval;
  427. }
  428. static FixedPoint FromScalarRaw(ScalarRawType x) {
  429. FixedPoint retval;
  430. retval.raw() = Dup<RawType>(x);
  431. return retval;
  432. }
  433. static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
  434. return FromScalarRaw(x.raw());
  435. }
  436. template <int Exponent>
  437. static FixedPoint ConstantPOT() {
  438. static constexpr int kOffset = kFractionalBits + Exponent;
  439. static_assert(
  440. kOffset < 31,
  441. "Constant not exactly representable in this fixed-point format");
  442. return FromScalarRaw(ScalarRawType(1) << kOffset);
  443. }
  444. static FixedPoint Zero() { return FromScalarRaw(0); }
  445. static FixedPoint One() {
  446. return FromScalarRaw(
  447. kIntegerBits == 0
  448. ? ScalarRawMax()
  449. : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
  450. }
  451. static FixedPoint FromDouble(double x) {
  452. const double min_bound = static_cast<double>(ScalarRawMin());
  453. const double max_bound = static_cast<double>(ScalarRawMax());
  454. return FromScalarRaw(static_cast<ScalarRawType>(std::min(
  455. std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
  456. min_bound),
  457. max_bound)));
  458. }
  459. RawType raw() const { return i_; }
  460. RawType& raw() { return i_; }
  461. private:
  462. RawType i_;
  463. };
  464. // Part 3: implementation of arithmetic operators for the
  465. // FixedPoint class, and a few related functions.
  466. // A FixedPoint multiplication is just a
  467. // SaturatingRoundingDoublingHighMul operation on the underlying
  468. // raw integer values. The IntegerBits simply add up, as is obvious
  469. // from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
  470. template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
  471. FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
  472. FixedPoint<tRawType, tIntegerBits_a> a,
  473. FixedPoint<tRawType, tIntegerBits_b> b) {
  474. FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
  475. c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
  476. return c;
  477. }
  478. // Tweaking IntegerBits gives exact multiplication by a power of two.
  479. template <int tExponent, typename tRawType, int tIntegerBits>
  480. FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
  481. FixedPoint<tRawType, tIntegerBits> a) {
  482. FixedPoint<tRawType, tExponent + tIntegerBits> c;
  483. c.raw() = a.raw();
  484. return c;
  485. }
  486. // If we want to leave IntegerBits fixed, then multiplication
  487. // by a power of two has to be saturating/rounding, not exact anymore.
  488. template <int tExponent, typename tRawType, int tIntegerBits>
  489. FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
  490. FixedPoint<tRawType, tIntegerBits> a) {
  491. return FixedPoint<tRawType, tIntegerBits>::FromRaw(
  492. SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
  493. }
  494. // Generic arithmetic operators.
  495. #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \
  496. template <typename tRawType, int tIntegerBits> \
  497. FixedPoint<tRawType, tIntegerBits> FuncName( \
  498. FixedPoint<tRawType, tIntegerBits> a) { \
  499. return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
  500. }
  501. #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
  502. template <typename tRawType, int tIntegerBits> \
  503. FixedPoint<tRawType, tIntegerBits> FuncName( \
  504. FixedPoint<tRawType, tIntegerBits> a, \
  505. FixedPoint<tRawType, tIntegerBits> b) { \
  506. return FixedPoint<tRawType, tIntegerBits>::FromRaw( \
  507. ImplFuncName(a.raw(), b.raw())); \
  508. }
  509. MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
  510. MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
  511. MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
  512. MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
  513. MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
  514. MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
  515. MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
  516. MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
  517. #undef MAKE_FIXEDPOINT_UNARY_FUNC
  518. #undef MAKE_FIXEDPOINT_BINARY_FUNC
  519. #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \
  520. template <typename tRawType, int tIntegerBits> \
  521. tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
  522. return FuncName(a.raw()); \
  523. }
  524. #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
  525. template <typename tRawType, int tIntegerBits> \
  526. tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, \
  527. FixedPoint<tRawType, tIntegerBits> b) { \
  528. return FuncName(a.raw(), b.raw()); \
  529. }
  530. MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
  531. MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
  532. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
  533. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
  534. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
  535. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
  536. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
  537. MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
  538. #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
  539. #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
  540. template <typename tRawType, int tIntegerBits>
  541. FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
  542. tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
  543. FixedPoint<tRawType, tIntegerBits> else_val) {
  544. return FixedPoint<tRawType, tIntegerBits>::FromRaw(
  545. SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
  546. }
  547. template <typename tRawType, int tIntegerBits>
  548. bool operator==(FixedPoint<tRawType, tIntegerBits> a,
  549. FixedPoint<tRawType, tIntegerBits> b) {
  550. return All(MaskIfEqual(a.raw(), b.raw()));
  551. }
  552. template <typename tRawType, int tIntegerBits>
  553. bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
  554. FixedPoint<tRawType, tIntegerBits> b) {
  555. return !(a == b);
  556. }
  557. template <typename tRawType, int tIntegerBits>
  558. FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
  559. FixedPoint<tRawType, tIntegerBits> a,
  560. FixedPoint<tRawType, tIntegerBits> b) {
  561. return FixedPoint<tRawType, tIntegerBits>::FromRaw(
  562. SaturatingAdd(a.raw(), b.raw()));
  563. }
  564. template <typename tRawType, int tIntegerBits>
  565. FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
  566. FixedPoint<tRawType, tIntegerBits> a,
  567. FixedPoint<tRawType, tIntegerBits> b) {
  568. return FixedPoint<tRawType, tIntegerBits>::FromRaw(
  569. AddSaturatingIf16Bit(a.raw(), b.raw()));
  570. }
  571. // Conversion to floating-point.
  572. template <typename tRawType, int tIntegerBits>
  573. double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
  574. static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
  575. "not applicable to SIMD types");
  576. typedef FixedPoint<tRawType, tIntegerBits> F;
  577. return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
  578. }
  579. // Rescale changes the number of IntegerBits and updates the underlying
  580. // raw integer value accordingly.
  581. template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
  582. FixedPoint<tRawType, tIntegerBitsDst> Rescale(
  583. FixedPoint<tRawType, tIntegerBitsSrc> x) {
  584. static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
  585. FixedPoint<tRawType, tIntegerBitsDst> result;
  586. result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
  587. return result;
  588. }
  589. // CheckedFixedPointConstant allows to specify fixed-point constants
  590. // initialized as real numbers, in a way that does not compile floating-point
  591. // arithmetic in production code, yet still checks agreement with the
  592. // floating-point expressions when asserts are enabled.
  593. //
  594. // The raw integer value provided is always a int32, encoding a 32-bit
  595. // fixed-point value, regardless of the actual Scalar type. This allows
  596. // writing generic code that applies just as well to the 32-bit and 16-bit
  597. // cases. In the 16-bit case, the raw integer value is internally
  598. // rounding-shifted by 16 bits to the right.
  599. template <typename FixedPointType>
  600. inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
  601. std::int32_t int32_value) {
  602. typedef typename FixedPointType::ScalarRawType ScalarRawType;
  603. static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
  604. return static_cast<ScalarRawType>(
  605. RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
  606. }
  #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
  // Builds the constant from its (already rescaled) raw value and asserts that
  // it equals the rounded conversion of double_value.
  template <typename FixedPointType>
  FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
                                           double double_value) {
    const FixedPointType constant = FixedPointType::FromScalarRaw(raw_value);
    assert(FixedPointType::FromDouble(double_value) == constant);
    return constant;
  }

  // Checked variant: raw value rescaled to the scalar type, then verified
  // against DoubleValue.
  #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,        \
                                               ScalarRawInt32Value,   \
                                               DoubleValue)           \
    (gemmlowp::CheckedFixedPointConstant<FixedPointType>(             \
        gemmlowp::RescaleConstantInitializer<FixedPointType>(         \
            ScalarRawInt32Value),                                     \
        DoubleValue))

  #else
  // Unchecked variant: DoubleValue is ignored, so no floating-point code is
  // generated.
  #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,        \
                                               ScalarRawInt32Value,   \
                                               DoubleValue)           \
    (FixedPointType::FromScalarRaw(                                   \
        gemmlowp::RescaleConstantInitializer<FixedPointType>(         \
            ScalarRawInt32Value)))

  #endif
  628. // Implementation of exponential function.
  629. // Returns exp(x) for x in [-1/4, 0).
  630. template <typename tRawType>
  631. FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
  632. FixedPoint<tRawType, 0> a) {
  633. typedef FixedPoint<tRawType, 0> F;
  634. const F constant_term =
  635. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
  636. const F constant_1_over_3 =
  637. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
  638. // We're evaluating a Taylor expansion around -1/8, so we do the change of
  639. // variable: x = a + 1/8.
  640. // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
  641. F x = a + F::template ConstantPOT<-3>();
  642. F x2 = x * x;
  643. F x3 = x2 * x;
  644. F x4 = x2 * x2;
  645. F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
  646. F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
  647. SaturatingRoundingMultiplyByPOT<-1>(
  648. ((x4_over_4 + x3) * constant_1_over_3) + x2);
  649. return AddSaturatingIf16Bit(
  650. constant_term,
  651. constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
  652. }
  // Exponential function, part 2: returns exp(a) for a < 0 (and exactly One()
  // for a == 0). The result has 0 integer bits.
  //
  // Method: split a = r + (-q), where r is in [-1/4, 0) and q >= 0 is a multiple
  // of 1/4. Then exp(a) = exp(r) * prod over the set bits of q of exp(-2^k).
  // exp(r) comes from the Taylor-series helper above; each factor exp(-2^k)
  // is applied by the "barrel shifter" steps below.
  template <typename tRawType, int tIntegerBits>
  FixedPoint<tRawType, 0> exp_on_negative_values(
      FixedPoint<tRawType, tIntegerBits> a) {
    typedef FixedPoint<tRawType, tIntegerBits> InputF;
    typedef FixedPoint<tRawType, 0> ResultF;
    static constexpr int kFractionalBits = InputF::kFractionalBits;
    static constexpr int kIntegerBits = InputF::kIntegerBits;
    const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
    // All raw bits strictly below the 1/4 bit: (1/4) minus one raw unit.
    InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
    // (a mod 1/4) - 1/4, which lies in [-1/4, 0).
    InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
    ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
        Rescale<0>(a_mod_quarter_minus_one_quarter));
    // q = r - a: the multiple of 1/4 that still has to be subtracted. Its set
    // bits select which exp(-2^Exponent) factors are multiplied in below.
    tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();

    // One barrel-shifter step: if the bit of `remainder` worth 2^Exponent
    // (raw bit kFractionalBits + Exponent) is set, multiply `result` by
    // exp(-2^Exponent). FixedPointMultiplier is that factor as a 32-bit raw
    // value with 0 integer bits. Steps with Exponent >= kIntegerBits are
    // compiled out, since the input range cannot reach them.
  #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
    if (kIntegerBits > Exponent) {                                            \
      const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
          ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
      static constexpr int kShiftAmount =                                     \
          kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
      result = SelectUsingMask(                                               \
          MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
          result * kMultiplier, result);                                      \
    }

    GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
    GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
    GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
    GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
    GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
    GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
    GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
  #undef GEMMLOWP_EXP_BARREL_SHIFTER

    // Formats with more than 5 integer bits can hold inputs below -32. Those
    // are forced to 0. clampB is chosen so that the 32-bit raw value
    // -(1 << clampB) encodes -32.0 in this format (checked against -32.0).
    static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
    if (kIntegerBits > 5) {
      const InputF clamp =
          GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
      result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
    }

    // a == 0 lies outside the decomposition above (it assumes a < 0), so it is
    // special-cased to ResultF::One().
    result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
    return result;
  }
  694. // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
  695. // Returns (1 - x) / (1 + x) for x in (0, 1).
  696. template <typename tRawType>
  697. FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
  698. FixedPoint<tRawType, 0> a) {
  699. typedef FixedPoint<tRawType, 0> F0;
  700. typedef FixedPoint<tRawType, 2> F2;
  701. F0 half_denominator = RoundingHalfSum(a, F0::One());
  702. // Newton-Raphson division
  703. // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
  704. // Refer to that page for the logic behind the 48/17 and 32/17 constants.
  705. const F2 constant_48_over_17 =
  706. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
  707. const F2 constant_neg_32_over_17 =
  708. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
  709. F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
  710. for (int i = 0; i < 3; i++) {
  711. F2 half_denominator_times_x = half_denominator * x;
  712. F2 one_minus_half_denominator_times_x =
  713. F2::One() - half_denominator_times_x;
  714. x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
  715. }
  716. return Rescale<0>(x - F2::One());
  717. }
  718. // Returns -tanh(x) for x < 0.
  719. template <typename tRawType, int tIntegerBits>
  720. FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
  721. FixedPoint<tRawType, tIntegerBits> a) {
  722. return one_minus_x_over_one_plus_x_for_x_in_0_1(
  723. exp_on_negative_values(ExactMulByPot<1>(a)));
  724. }
  725. // Returns tanh(x) for any x.
  726. template <typename tRawType, int tIntegerBits>
  727. FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
  728. typedef FixedPoint<tRawType, tIntegerBits> InputF;
  729. typedef FixedPoint<tRawType, 0> ResultF;
  730. tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
  731. tRawType mask_if_zero = MaskIfZero(a);
  732. InputF n = SelectUsingMask(mask_if_negative, a, -a);
  733. ResultF t = neg_tanh_on_negative_values(n);
  734. return SelectUsingMask(mask_if_zero, ResultF::Zero(),
  735. SelectUsingMask(mask_if_negative, -t, t));
  736. }
  737. // Implementation of logistic function.
  738. // Returns 1 / (1 + x) for x in (0, 1).
  739. template <typename tRawType>
  740. FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
  741. FixedPoint<tRawType, 0> a) {
  742. typedef FixedPoint<tRawType, 0> F0;
  743. typedef FixedPoint<tRawType, 2> F2;
  744. F0 half_denominator = RoundingHalfSum(a, F0::One());
  745. // Newton-Raphson division
  746. // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
  747. // Refer to that page for the logic behind the 48/17 and 32/17 constants.
  748. const F2 constant_48_over_17 =
  749. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
  750. const F2 constant_neg_32_over_17 =
  751. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
  752. F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
  753. for (int i = 0; i < 3; i++) {
  754. F2 half_denominator_times_x = half_denominator * x;
  755. F2 one_minus_half_denominator_times_x =
  756. F2::One() - half_denominator_times_x;
  757. x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
  758. }
  759. return Rescale<0>(ExactMulByPot<-1>(x));
  760. }
  761. // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
  762. template <typename tRawType, int tIntegerBits>
  763. FixedPoint<tRawType, 0> logistic_on_positive_values(
  764. FixedPoint<tRawType, tIntegerBits> a) {
  765. return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
  766. }
  767. // Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
  768. template <typename tRawType, int tIntegerBits>
  769. FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
  770. typedef FixedPoint<tRawType, tIntegerBits> InputF;
  771. typedef FixedPoint<tRawType, 0> ResultF;
  772. tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
  773. tRawType mask_if_zero = MaskIfZero(a);
  774. InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
  775. ResultF result_if_positive = logistic_on_positive_values(abs_input);
  776. ResultF result_if_negative = ResultF::One() - result_if_positive;
  777. const ResultF one_half =
  778. GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
  779. return SelectUsingMask(mask_if_zero, one_half,
  780. SelectUsingMask(mask_if_positive, result_if_positive,
  781. result_if_negative));
  782. }
  783. } // end namespace gemmlowp
  784. #ifdef GEMMLOWP_NEON
  785. #include "./fixedpoint_neon.h"
  786. #elif defined(GEMMLOWP_AVX2)
  787. #include "./fixedpoint_avx.h"
  788. #elif defined(GEMMLOWP_SSE4)
  789. #include "./fixedpoint_sse.h"
  790. #elif defined(GEMMLOWP_MSA)
  791. #include "./fixedpoint_msa.h"
  792. #endif
  793. #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_