77#include < math/dlib/stft_norm.hpp>
88#include " nemo_mel_spectrogram.h"
99
10+ #include " speech_features_normalize.hpp"
11+
1012#ifndef M_PI
1113#define M_PI 3.14159265358979323846
1214#endif
@@ -660,141 +662,9 @@ class Phi4AudioEmbed {
660662 int64_t qformer_compression_rate_{1 };
661663};
662664
663- // Per-feature (per-mel-bin) normalization: for each feature row,
664- // compute mean and std across time, then normalize.
665- // Input: [1, num_features, num_frames] (feature_first) or [1, num_frames, num_features]
666- // Output: same shape, normalized.
667- class PerFeatureNormalize {
668- public:
669- template <typename DictT>
670- OrtxStatus Init (const DictT& attrs) {
671- for (const auto & [key, value] : attrs) {
672- if (key == " eps" ) {
673- eps_ = static_cast <float >(std::get<double >(value));
674- } else if (key == " feature_first" ) {
675- feature_first_ = std::get<int64_t >(value);
676- } else if (key != " _comment" ) {
677- return {kOrtxErrorInvalidArgument , " [PerFeatureNormalize]: Invalid key in the JSON configuration." };
678- }
679- }
680- return {};
681- }
682-
683- OrtxStatus Compute (const ortc::Tensor<float >& input, ortc::Tensor<float >& output) {
684- const auto & shape = input.Shape ();
685- int64_t num_features, num_frames;
686-
687- if (shape.size () == 2 ) {
688- // 2D: [features, frames] or [frames, features]
689- num_features = feature_first_ ? shape[0 ] : shape[1 ];
690- num_frames = feature_first_ ? shape[1 ] : shape[0 ];
691- } else if (shape.size () == 3 && shape[0 ] == 1 ) {
692- // 3D: [1, features, frames] or [1, frames, features]
693- num_features = feature_first_ ? shape[1 ] : shape[2 ];
694- num_frames = feature_first_ ? shape[2 ] : shape[1 ];
695- } else {
696- return {kOrtxErrorInvalidArgument , " [PerFeatureNormalize]: Expected input shape [features, frames] or [1, features, frames]." };
697- }
698-
699- const float * in_data = input.Data ();
700- float * out_data = output.Allocate (shape);
701-
702- // Copy input to output first
703- std::memcpy (out_data, in_data, num_features * num_frames * sizeof (float ));
665+ // PerFeatureNormalize and NemoLogMel have been moved to
666+ // "speech_features_normalize.hpp" (included above). They remain available
667+ // to consumers of this header unchanged.
704668
705- // Need at least 2 frames for sample std (N-1 denominator)
706- if (num_frames <= 1 ) {
707- // Single frame or empty: output zeros (value - mean = 0 for constant)
708- std::memset (out_data, 0 , num_features * num_frames * sizeof (float ));
709- return {};
710- }
711-
712- for (int64_t f = 0 ; f < num_features; ++f) {
713- // Compute mean
714- float sum = 0 .0f ;
715- for (int64_t t = 0 ; t < num_frames; ++t) {
716- int64_t idx = feature_first_ ? (f * num_frames + t) : (t * num_features + f);
717- sum += out_data[idx];
718- }
719- float mean = sum / static_cast <float >(num_frames);
720-
721- // Compute std (sample std, divide by N-1)
722- float var_sum = 0 .0f ;
723- for (int64_t t = 0 ; t < num_frames; ++t) {
724- int64_t idx = feature_first_ ? (f * num_frames + t) : (t * num_features + f);
725- float d = out_data[idx] - mean;
726- var_sum += d * d;
727- }
728- float std_val = std::sqrt (var_sum / static_cast <float >(num_frames - 1 )) + eps_;
729-
730- // Normalize
731- for (int64_t t = 0 ; t < num_frames; ++t) {
732- int64_t idx = feature_first_ ? (f * num_frames + t) : (t * num_features + f);
733- out_data[idx] = (out_data[idx] - mean) / std_val;
734- }
735- }
736-
737- return {};
738- }
739-
740- private:
741- float eps_{1e-5f };
742- int64_t feature_first_{1 }; // 1 = [1, features, frames], 0 = [1, frames, features]
743- };
744-
745- // NeMo-compatible log-mel spectrogram kernel.
746- // Wraps nemo_mel::NemoComputeLogMelBatch for use in the SpeechFeatureExtractor pipeline.
747- // Input: [num_samples] or [1, num_samples] float32 PCM audio
748- // Output: [num_mels, num_frames] float32 log-mel spectrogram per example;
749- // StackTensors adds the batch dimension later in the pipeline.
750- class NemoLogMel {
751- public:
752- template <typename DictT>
753- OrtxStatus Init (const DictT& attrs) {
754- for (const auto & [key, value] : attrs) {
755- if (key == " num_mels" ) {
756- cfg_.num_mels = static_cast <int >(std::get<int64_t >(value));
757- } else if (key == " fft_size" ) {
758- cfg_.fft_size = static_cast <int >(std::get<int64_t >(value));
759- } else if (key == " hop_length" ) {
760- cfg_.hop_length = static_cast <int >(std::get<int64_t >(value));
761- } else if (key == " win_length" ) {
762- cfg_.win_length = static_cast <int >(std::get<int64_t >(value));
763- } else if (key == " sample_rate" ) {
764- cfg_.sample_rate = static_cast <int >(std::get<int64_t >(value));
765- } else if (key == " preemph" ) {
766- cfg_.preemph = static_cast <float >(std::get<double >(value));
767- } else if (key == " log_eps" ) {
768- cfg_.log_eps = static_cast <float >(std::get<double >(value));
769- } else if (key != " _comment" ) {
770- return {kOrtxErrorInvalidArgument , " [NemoLogMel]: Invalid key in the JSON configuration." };
771- }
772- }
773- return {};
774- }
775-
776- OrtxStatus Compute (const ortc::Tensor<float >& pcm, ortc::Tensor<float >& logmel) {
777- const auto & shape = pcm.Shape ();
778- size_t num_samples;
779- if (shape.size () == 1 ) {
780- num_samples = static_cast <size_t >(shape[0 ]);
781- } else if (shape.size () == 2 && shape[0 ] == 1 ) {
782- num_samples = static_cast <size_t >(shape[1 ]);
783- } else {
784- return {kOrtxErrorInvalidArgument , " [NemoLogMel]: Expected input shape [num_samples] or [1, num_samples]." };
785- }
786-
787- int num_frames = 0 ;
788- auto mel_data = nemo_mel::NemoComputeLogMelBatch (pcm.Data (), num_samples, cfg_, num_frames);
789-
790- // Output [num_mels, num_frames] (no batch dim) — StackTensors adds the batch dim
791- auto * out = logmel.Allocate ({cfg_.num_mels , num_frames});
792- std::memcpy (out, mel_data.data (), mel_data.size () * sizeof (float ));
793- return {};
794- }
795-
796- private:
797- nemo_mel::NemoMelConfig cfg_{128 , 512 , 160 , 400 , 16000 , 0 .97f , 5 .96046448e-08f };
798- };
799669
800- } // namespace ort_extensions
670+ } // namespace ort_extensions
0 commit comments