@@ -377,40 +377,49 @@ inline Integer FloorLog2(Integer n) {
377
377
}
378
378
}
379
379
380
- // The size of the LUT depends on the type of input. For uint8 and int8 inputs
381
- // we use a 256 entries LUT to map all the values in the (u)int8 range. For
382
- // int16 inputs the high 9 bits are used for indexing and the 7 remaining bits
383
- // are used for interpolation. We thus use a 513-entries LUT for int16 cases,
384
- // 512 for the 9-bit indexing and 1 extra entry to interpolate the last value.
385
- template <typename T>
386
- constexpr int LUTSize () {
387
- static_assert (std::is_same<T, uint8_t >::value ||
388
- std::is_same<T, int8_t >::value ||
389
- std::is_same<T, int16_t >::value,
390
- " Only LUTs with uint8, int8 or int16 inputs are supported." );
391
- // As per c++11: constexpr methods cannot have more than one return statement.
392
- return (std::is_same<T, uint8_t >::value || std::is_same<T, int8_t >::value)
393
- ? 256
394
- : 513 ;
380
+ namespace detail {
381
+
382
+ // LUTPopulate takes an optional type-erased transform_params to allow passing
383
+ // extra parameters to the transform function pointer. const void* is used
384
+ // instead of std::function to be compatible with TFLite Micro
385
+ template <typename FloatT, typename Func>
386
+ inline typename std::enable_if<std::is_same<Func, FloatT (*)(FloatT)>::value,
387
+ FloatT>::type
388
+ LUTTransform (Func transform, const void * /* transform_params*/ , FloatT value) {
389
+ static_assert (std::is_floating_point<FloatT>::value,
390
+ " FloatT must be a floating-point type." );
391
+ return transform (value);
392
+ }
393
+
394
+ template <typename FloatT, typename Func>
395
+ inline typename std::enable_if<
396
+ std::is_same<Func, FloatT (*)(FloatT, const void *)>::value, FloatT>::type
397
+ LUTTransform (Func transform, const void * transform_params, FloatT value) {
398
+ static_assert (std::is_floating_point<FloatT>::value,
399
+ " FloatT must be a floating-point type." );
400
+ return transform (value, transform_params);
395
401
}
396
402
397
403
// Use the same LUT generation code for both uint8_t and int8_t. Int8_t indexes
398
404
// will be directly casted to uint8_t, the int8 LUT will thus be ordered as [0,
399
405
// 1, ..., 127, -128, ..., -2, -1] instead of [-128, -127, ..., -1, 0, 1, ...,
400
406
// 126, 127].
401
- template <typename T>
402
- inline typename std::enable_if<std::is_same<T, uint8_t >::value ||
403
- std::is_same<T, int8_t >::value,
404
- void >::type
405
- LUTPopulate (float input_scale, int32_t input_zero_point, float output_scale,
406
- int32_t output_zero_point, float (*transform)(float ), T* lut) {
407
+ template <typename T, typename Func>
408
+ inline void LUTPopulateInt8 (float input_scale, int32_t input_zero_point,
409
+ float output_scale, int32_t output_zero_point,
410
+ Func transform, const void * transform_params,
411
+ T* lut) {
412
+ static_assert (
413
+ std::is_same<T, uint8_t >::value || std::is_same<T, int8_t >::value,
414
+ " T must be an uint8 or int8 type." );
407
415
uint8_t * lut_uint8 = reinterpret_cast <uint8_t *>(lut);
408
416
const float inverse_scale = 1 / output_scale;
409
417
int32_t maxval = std::numeric_limits<T>::max ();
410
418
int32_t minval = std::numeric_limits<T>::min ();
411
419
for (int32_t val = minval; val <= maxval; ++val) {
412
420
const float dequantized = input_scale * (val - input_zero_point);
413
- const float transformed = transform (dequantized);
421
+ const float transformed =
422
+ LUTTransform (transform, transform_params, dequantized);
414
423
const float rescaled = TfLiteRound (transformed * inverse_scale);
415
424
const int32_t quantized =
416
425
static_cast <int32_t >(rescaled + output_zero_point);
@@ -421,10 +430,11 @@ LUTPopulate(float input_scale, int32_t input_zero_point, float output_scale,
421
430
422
431
// Keep floating-point type configurable for backward compatibility. float
423
432
// should be used for FloatT by default.
424
- template <typename T, typename FloatT>
425
- inline typename std::enable_if<std::is_same<T, int16_t >::value, void >::type
426
- LUTPopulate (FloatT input_scale, int32_t input_zero_point, FloatT output_scale,
427
- int32_t output_zero_point, FloatT (*transform)(FloatT), T* lut) {
433
+ template <typename FloatT, typename Func>
434
+ inline void LUTPopulateInt16 (FloatT input_scale, int32_t input_zero_point,
435
+ FloatT output_scale, int32_t output_zero_point,
436
+ Func transform, const void * transform_params,
437
+ int16_t * lut) {
428
438
static_assert (std::is_floating_point<FloatT>::value,
429
439
" FloatT must be a floating-point type." );
430
440
const FloatT input_min =
@@ -440,16 +450,21 @@ LUTPopulate(FloatT input_scale, int32_t input_zero_point, FloatT output_scale,
440
450
const FloatT step = (input_max - input_min) / nb_steps;
441
451
const FloatT half_step = step / 2 ;
442
452
const FloatT output_scaling_inv =
443
- static_cast <FloatT>(std::numeric_limits<T >::max () -
444
- std::numeric_limits<T >::min () + 1 ) /
453
+ static_cast <FloatT>(std::numeric_limits<int16_t >::max () -
454
+ std::numeric_limits<int16_t >::min () + 1 ) /
445
455
(output_max - output_min);
446
- const FloatT table_min = static_cast <FloatT>(std::numeric_limits<T>::min ());
447
- const FloatT table_max = static_cast <FloatT>(std::numeric_limits<T>::max ());
456
+ const FloatT table_min =
457
+ static_cast <FloatT>(std::numeric_limits<int16_t >::min ());
458
+ const FloatT table_max =
459
+ static_cast <FloatT>(std::numeric_limits<int16_t >::max ());
448
460
449
461
for (int i = 0 ; i < nb_steps; i++) {
450
- const FloatT val = transform (input_min + i * step);
451
- const FloatT val_midpoint = transform (input_min + i * step + half_step);
452
- const FloatT val_next = transform (input_min + (i + 1 ) * step);
462
+ const FloatT val =
463
+ LUTTransform<FloatT>(transform, transform_params, input_min + i * step);
464
+ const FloatT val_midpoint = LUTTransform<FloatT>(
465
+ transform, transform_params, input_min + i * step + half_step);
466
+ const FloatT val_next = LUTTransform<FloatT>(transform, transform_params,
467
+ input_min + (i + 1 ) * step);
453
468
454
469
const FloatT sample_val = TfLiteRound (val * output_scaling_inv);
455
470
const FloatT midpoint_interp_val =
@@ -460,54 +475,84 @@ LUTPopulate(FloatT input_scale, int32_t input_zero_point, FloatT output_scale,
460
475
const FloatT midpoint_err = midpoint_interp_val - midpoint_val;
461
476
const FloatT bias = TfLiteRound (midpoint_err / 2 );
462
477
463
- lut[i] = static_cast <T >(std::min<FloatT>(
478
+ lut[i] = static_cast <int16_t >(std::min<FloatT>(
464
479
std::max<FloatT>(sample_val - bias, table_min), table_max));
465
480
}
466
481
467
- lut[nb_steps] = static_cast <T>(std::min<FloatT>(
468
- std::max<FloatT>(TfLiteRound (transform (input_max) * output_scaling_inv),
482
+ lut[nb_steps] = static_cast <int16_t >(std::min<FloatT>(
483
+ std::max<FloatT>(TfLiteRound (LUTTransform<FloatT>(
484
+ transform, transform_params, input_max) *
485
+ output_scaling_inv),
469
486
table_min),
470
487
table_max));
471
488
}
472
489
490
+ } // namespace detail
491
+
492
+ template <typename T>
493
+ inline typename std::enable_if<std::is_same<T, uint8_t >::value ||
494
+ std::is_same<T, int8_t >::value,
495
+ void >::type
496
+ LUTPopulate (float input_scale, int32_t input_zero_point, float output_scale,
497
+ int32_t output_zero_point, float (*transform)(float ), T* lut) {
498
+ detail::LUTPopulateInt8 (input_scale, input_zero_point, output_scale,
499
+ output_zero_point, transform, nullptr , lut);
500
+ }
501
+
502
+ template <typename T>
503
+ inline typename std::enable_if<std::is_same<T, uint8_t >::value ||
504
+ std::is_same<T, int8_t >::value,
505
+ void >::type
506
+ LUTPopulate (float input_scale, int32_t input_zero_point, float output_scale,
507
+ int32_t output_zero_point, float (*transform)(float , const void *),
508
+ const void * transform_params, T* lut) {
509
+ detail::LUTPopulateInt8 (input_scale, input_zero_point, output_scale,
510
+ output_zero_point, transform, transform_params, lut);
511
+ }
512
+
473
513
template <typename T>
474
514
inline typename std::enable_if<std::is_same<T, int16_t >::value, void >::type
475
515
LUTPopulate (float input_scale, int32_t input_zero_point, float output_scale,
476
516
int32_t output_zero_point, float (*transform)(float ), T* lut) {
477
- LUTPopulate<T, float >(input_scale, input_zero_point, output_scale,
478
- output_zero_point, transform, lut);
517
+ detail::LUTPopulateInt16<float >(input_scale, input_zero_point, output_scale,
518
+ output_zero_point, transform, nullptr , lut);
519
+ }
520
+
521
+ template <typename T>
522
+ inline typename std::enable_if<std::is_same<T, int16_t >::value, void >::type
523
+ LUTPopulate (float input_scale, int32_t input_zero_point, float output_scale,
524
+ int32_t output_zero_point, float (*transform)(float , const void *),
525
+ const void * transform_params, T* lut) {
526
+ detail::LUTPopulateInt16<float >(input_scale, input_zero_point, output_scale,
527
+ output_zero_point, transform,
528
+ transform_params, lut);
479
529
}
480
530
481
- // Deprecated and will be removed in future, please use LUTPopulate instead
482
- template <typename FloatT, typename LutInT, typename LutOutT>
483
- inline void gen_lut (FloatT (*func)(FloatT), FloatT input_min, FloatT input_max,
484
- FloatT output_min, FloatT output_max, LutOutT* lut) {
485
- static_assert (std::is_same<LutInT, LutOutT>::value,
486
- " Input and output type of the LUT must be the same." );
487
- static_assert (std::is_same<LutInT, int16_t >::value,
488
- " Only int16_t type LUT are supported." );
489
- static_assert (std::is_same<FloatT, float >::value,
490
- " Only float type is supported for FloatT." );
491
- using T = LutInT;
492
-
493
- const auto zero_point = [](float min, float max, float scale) {
494
- // Symmetric int16 LUT, we know the zero-point will not overflow an int32_t
495
- // and zero-point from min will be the same as from max.
496
- return static_cast <int32_t >(
497
- static_cast <float >(std::numeric_limits<T>::min ()) - min / scale);
498
- };
499
-
500
- const float scale = static_cast <float >(std::numeric_limits<T>::max () -
501
- std::numeric_limits<T>::min ());
502
- const float input_scale = (input_max - input_min) / scale;
503
- const FloatT output_scale = (output_max - output_min) / scale;
504
- const int32_t input_zero_point =
505
- zero_point (input_min, input_max, input_scale);
506
- const int32_t output_zero_point =
507
- zero_point (output_min, output_max, output_scale);
508
-
509
- return LUTPopulate<T, float >(input_scale, input_zero_point, output_scale,
510
- output_zero_point, func, lut);
531
+ // Deprecated, avoid usage and prefer the float version. Kept for
532
+ // backward-compatiblity.
533
+ template <typename T>
534
+ inline typename std::enable_if<std::is_same<T, int16_t >::value, void >::type
535
+ LUTPopulate (double input_scale, int32_t input_zero_point, double output_scale,
536
+ int32_t output_zero_point, double (*transform)(double ), T* lut) {
537
+ detail::LUTPopulateInt16<double >(input_scale, input_zero_point, output_scale,
538
+ output_zero_point, transform, nullptr , lut);
539
+ }
540
+
541
+ // The size of the LUT depends on the type of input. For uint8 and int8 inputs a
542
+ // simple 256 entries LUT is used. For int16 inputs the high 9 bits are used for
543
+ // indexing and the 7 remaining bits are used for interpolation. We thus use a
544
+ // 513-entries LUT for int16 cases, 512 for the 9-bit indexing and 1 extra entry
545
+ // to interpolate the last value.
546
+ template <typename T>
547
+ constexpr int LUTSize () {
548
+ static_assert (std::is_same<T, uint8_t >::value ||
549
+ std::is_same<T, int8_t >::value ||
550
+ std::is_same<T, int16_t >::value,
551
+ " Only LUTs with uint8, int8 or int16 inputs are supported." );
552
+ // As per c++11: constexpr methods cannot have more than one return statement.
553
+ return (std::is_same<T, uint8_t >::value || std::is_same<T, int8_t >::value)
554
+ ? 256
555
+ : 513 ;
511
556
}
512
557
513
558
// int16_t -> int16_t table lookup with interpolation
0 commit comments