Skip to content

Commit e1cccaa

Browse files
authored
Split Nemo specific speech features into a separate header (#1061)
* Split speech features normalize * clang * Rename header * Copilot fixes
1 parent b62dd46 commit e1cccaa

3 files changed

Lines changed: 158 additions & 140 deletions

File tree

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "ocos.h"
7+
#include "nemo_mel_spectrogram.h"
8+
9+
#include <cmath>
10+
#include <cstring>
11+
#include <cstdint>
12+
#include <variant>
13+
14+
namespace ort_extensions {
15+
16+
// Per-feature (per-mel-bin) normalization: for each feature row,
17+
// compute mean and std across time, then normalize.
18+
// Input: [1, num_features, num_frames] (feature_first) or [1, num_frames, num_features]
19+
// Output: same shape, normalized.
20+
class PerFeatureNormalize {
21+
public:
22+
template <typename DictT>
23+
OrtxStatus Init(const DictT& attrs) {
24+
for (const auto& [key, value] : attrs) {
25+
if (key == "eps") {
26+
eps_ = static_cast<float>(std::get<double>(value));
27+
} else if (key == "feature_first") {
28+
feature_first_ = std::get<int64_t>(value);
29+
} else if (key != "_comment") {
30+
return {kOrtxErrorInvalidArgument, "[PerFeatureNormalize]: Invalid key in the JSON configuration."};
31+
}
32+
}
33+
return {};
34+
}
35+
36+
OrtxStatus Compute(const ortc::Tensor<float>& input, ortc::Tensor<float>& output) {
37+
const auto& shape = input.Shape();
38+
int64_t num_features, num_frames;
39+
40+
if (shape.size() == 2) {
41+
// 2D: [features, frames] or [frames, features]
42+
num_features = feature_first_ ? shape[0] : shape[1];
43+
num_frames = feature_first_ ? shape[1] : shape[0];
44+
} else if (shape.size() == 3 && shape[0] == 1) {
45+
// 3D: [1, features, frames] or [1, frames, features]
46+
num_features = feature_first_ ? shape[1] : shape[2];
47+
num_frames = feature_first_ ? shape[2] : shape[1];
48+
} else {
49+
return {kOrtxErrorInvalidArgument,
50+
"[PerFeatureNormalize]: Expected input shape [features, frames] or [1, features, frames]."};
51+
}
52+
53+
const float* in_data = input.Data();
54+
float* out_data = output.Allocate(shape);
55+
56+
// Copy input to output first
57+
std::memcpy(out_data, in_data, num_features * num_frames * sizeof(float));
58+
59+
// Need at least 2 frames for sample std (N-1 denominator)
60+
if (num_frames <= 1) {
61+
// Single frame or empty: output zeros (value - mean = 0 for constant)
62+
std::memset(out_data, 0, num_features * num_frames * sizeof(float));
63+
return {};
64+
}
65+
66+
for (int64_t f = 0; f < num_features; ++f) {
67+
// Compute mean
68+
float sum = 0.0f;
69+
for (int64_t t = 0; t < num_frames; ++t) {
70+
int64_t idx = feature_first_ ? (f * num_frames + t) : (t * num_features + f);
71+
sum += out_data[idx];
72+
}
73+
float mean = sum / static_cast<float>(num_frames);
74+
75+
// Compute std (sample std, divide by N-1)
76+
float var_sum = 0.0f;
77+
for (int64_t t = 0; t < num_frames; ++t) {
78+
int64_t idx = feature_first_ ? (f * num_frames + t) : (t * num_features + f);
79+
float d = out_data[idx] - mean;
80+
var_sum += d * d;
81+
}
82+
float std_val = std::sqrt(var_sum / static_cast<float>(num_frames - 1)) + eps_;
83+
84+
// Normalize
85+
for (int64_t t = 0; t < num_frames; ++t) {
86+
int64_t idx = feature_first_ ? (f * num_frames + t) : (t * num_features + f);
87+
out_data[idx] = (out_data[idx] - mean) / std_val;
88+
}
89+
}
90+
91+
return {};
92+
}
93+
94+
private:
95+
float eps_{1e-5f};
96+
int64_t feature_first_{1}; // 1 = [1, features, frames], 0 = [1, frames, features]
97+
};
98+
99+
// NeMo-compatible log-mel spectrogram kernel.
100+
// Wraps nemo_mel::NemoComputeLogMelBatch for use in the SpeechFeatureExtractor pipeline.
101+
// Input: [num_samples] or [1, num_samples] float32 PCM audio
102+
// Output: [num_mels, num_frames] float32 log-mel spectrogram per example;
103+
// StackTensors adds the batch dimension later in the pipeline.
104+
class NemoLogMel {
105+
public:
106+
template <typename DictT>
107+
OrtxStatus Init(const DictT& attrs) {
108+
for (const auto& [key, value] : attrs) {
109+
if (key == "num_mels") {
110+
cfg_.num_mels = static_cast<int>(std::get<int64_t>(value));
111+
} else if (key == "fft_size") {
112+
cfg_.fft_size = static_cast<int>(std::get<int64_t>(value));
113+
} else if (key == "hop_length") {
114+
cfg_.hop_length = static_cast<int>(std::get<int64_t>(value));
115+
} else if (key == "win_length") {
116+
cfg_.win_length = static_cast<int>(std::get<int64_t>(value));
117+
} else if (key == "sample_rate") {
118+
cfg_.sample_rate = static_cast<int>(std::get<int64_t>(value));
119+
} else if (key == "preemph") {
120+
cfg_.preemph = static_cast<float>(std::get<double>(value));
121+
} else if (key == "log_eps") {
122+
cfg_.log_eps = static_cast<float>(std::get<double>(value));
123+
} else if (key != "_comment") {
124+
return {kOrtxErrorInvalidArgument, "[NemoLogMel]: Invalid key in the JSON configuration."};
125+
}
126+
}
127+
return {};
128+
}
129+
130+
OrtxStatus Compute(const ortc::Tensor<float>& pcm, ortc::Tensor<float>& logmel) {
131+
const auto& shape = pcm.Shape();
132+
size_t num_samples;
133+
if (shape.size() == 1) {
134+
num_samples = static_cast<size_t>(shape[0]);
135+
} else if (shape.size() == 2 && shape[0] == 1) {
136+
num_samples = static_cast<size_t>(shape[1]);
137+
} else {
138+
return {kOrtxErrorInvalidArgument, "[NemoLogMel]: Expected input shape [num_samples] or [1, num_samples]."};
139+
}
140+
141+
int num_frames = 0;
142+
auto mel_data = nemo_mel::NemoComputeLogMelBatch(pcm.Data(), num_samples, cfg_, num_frames);
143+
144+
// Output [num_mels, num_frames] (no batch dim) — StackTensors adds the batch dim
145+
auto* out = logmel.Allocate({cfg_.num_mels, num_frames});
146+
std::memcpy(out, mel_data.data(), mel_data.size() * sizeof(float));
147+
return {};
148+
}
149+
150+
private:
151+
nemo_mel::NemoMelConfig cfg_{128, 512, 160, 400, 16000, 0.97f, 5.96046448e-08f};
152+
};
153+
154+
} // namespace ort_extensions

shared/api/speech_features.hpp

Lines changed: 3 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
#include <dlib/matrix.h>
77
#include <math/dlib/stft_norm.hpp>
8-
#include "nemo_mel_spectrogram.h"
8+
9+
#include "nemo_speech_features.hpp"
910

1011
#ifndef M_PI
1112
#define M_PI 3.14159265358979323846
@@ -660,141 +661,4 @@ class Phi4AudioEmbed {
660661
int64_t qformer_compression_rate_{1};
661662
};
662663

663-
// Per-feature (per-mel-bin) normalization: for each feature row,
664-
// compute mean and std across time, then normalize.
665-
// Input: [1, num_features, num_frames] (feature_first) or [1, num_frames, num_features]
666-
// Output: same shape, normalized.
667-
class PerFeatureNormalize {
668-
public:
669-
template <typename DictT>
670-
OrtxStatus Init(const DictT& attrs) {
671-
for (const auto& [key, value] : attrs) {
672-
if (key == "eps") {
673-
eps_ = static_cast<float>(std::get<double>(value));
674-
} else if (key == "feature_first") {
675-
feature_first_ = std::get<int64_t>(value);
676-
} else if (key != "_comment") {
677-
return {kOrtxErrorInvalidArgument, "[PerFeatureNormalize]: Invalid key in the JSON configuration."};
678-
}
679-
}
680-
return {};
681-
}
682-
683-
OrtxStatus Compute(const ortc::Tensor<float>& input, ortc::Tensor<float>& output) {
684-
const auto& shape = input.Shape();
685-
int64_t num_features, num_frames;
686-
687-
if (shape.size() == 2) {
688-
// 2D: [features, frames] or [frames, features]
689-
num_features = feature_first_ ? shape[0] : shape[1];
690-
num_frames = feature_first_ ? shape[1] : shape[0];
691-
} else if (shape.size() == 3 && shape[0] == 1) {
692-
// 3D: [1, features, frames] or [1, frames, features]
693-
num_features = feature_first_ ? shape[1] : shape[2];
694-
num_frames = feature_first_ ? shape[2] : shape[1];
695-
} else {
696-
return {kOrtxErrorInvalidArgument, "[PerFeatureNormalize]: Expected input shape [features, frames] or [1, features, frames]."};
697-
}
698-
699-
const float* in_data = input.Data();
700-
float* out_data = output.Allocate(shape);
701-
702-
// Copy input to output first
703-
std::memcpy(out_data, in_data, num_features * num_frames * sizeof(float));
704-
705-
// Need at least 2 frames for sample std (N-1 denominator)
706-
if (num_frames <= 1) {
707-
// Single frame or empty: output zeros (value - mean = 0 for constant)
708-
std::memset(out_data, 0, num_features * num_frames * sizeof(float));
709-
return {};
710-
}
711-
712-
for (int64_t f = 0; f < num_features; ++f) {
713-
// Compute mean
714-
float sum = 0.0f;
715-
for (int64_t t = 0; t < num_frames; ++t) {
716-
int64_t idx = feature_first_ ? (f * num_frames + t) : (t * num_features + f);
717-
sum += out_data[idx];
718-
}
719-
float mean = sum / static_cast<float>(num_frames);
720-
721-
// Compute std (sample std, divide by N-1)
722-
float var_sum = 0.0f;
723-
for (int64_t t = 0; t < num_frames; ++t) {
724-
int64_t idx = feature_first_ ? (f * num_frames + t) : (t * num_features + f);
725-
float d = out_data[idx] - mean;
726-
var_sum += d * d;
727-
}
728-
float std_val = std::sqrt(var_sum / static_cast<float>(num_frames - 1)) + eps_;
729-
730-
// Normalize
731-
for (int64_t t = 0; t < num_frames; ++t) {
732-
int64_t idx = feature_first_ ? (f * num_frames + t) : (t * num_features + f);
733-
out_data[idx] = (out_data[idx] - mean) / std_val;
734-
}
735-
}
736-
737-
return {};
738-
}
739-
740-
private:
741-
float eps_{1e-5f};
742-
int64_t feature_first_{1}; // 1 = [1, features, frames], 0 = [1, frames, features]
743-
};
744-
745-
// NeMo-compatible log-mel spectrogram kernel.
746-
// Wraps nemo_mel::NemoComputeLogMelBatch for use in the SpeechFeatureExtractor pipeline.
747-
// Input: [num_samples] or [1, num_samples] float32 PCM audio
748-
// Output: [num_mels, num_frames] float32 log-mel spectrogram per example;
749-
// StackTensors adds the batch dimension later in the pipeline.
750-
class NemoLogMel {
751-
public:
752-
template <typename DictT>
753-
OrtxStatus Init(const DictT& attrs) {
754-
for (const auto& [key, value] : attrs) {
755-
if (key == "num_mels") {
756-
cfg_.num_mels = static_cast<int>(std::get<int64_t>(value));
757-
} else if (key == "fft_size") {
758-
cfg_.fft_size = static_cast<int>(std::get<int64_t>(value));
759-
} else if (key == "hop_length") {
760-
cfg_.hop_length = static_cast<int>(std::get<int64_t>(value));
761-
} else if (key == "win_length") {
762-
cfg_.win_length = static_cast<int>(std::get<int64_t>(value));
763-
} else if (key == "sample_rate") {
764-
cfg_.sample_rate = static_cast<int>(std::get<int64_t>(value));
765-
} else if (key == "preemph") {
766-
cfg_.preemph = static_cast<float>(std::get<double>(value));
767-
} else if (key == "log_eps") {
768-
cfg_.log_eps = static_cast<float>(std::get<double>(value));
769-
} else if (key != "_comment") {
770-
return {kOrtxErrorInvalidArgument, "[NemoLogMel]: Invalid key in the JSON configuration."};
771-
}
772-
}
773-
return {};
774-
}
775-
776-
OrtxStatus Compute(const ortc::Tensor<float>& pcm, ortc::Tensor<float>& logmel) {
777-
const auto& shape = pcm.Shape();
778-
size_t num_samples;
779-
if (shape.size() == 1) {
780-
num_samples = static_cast<size_t>(shape[0]);
781-
} else if (shape.size() == 2 && shape[0] == 1) {
782-
num_samples = static_cast<size_t>(shape[1]);
783-
} else {
784-
return {kOrtxErrorInvalidArgument, "[NemoLogMel]: Expected input shape [num_samples] or [1, num_samples]."};
785-
}
786-
787-
int num_frames = 0;
788-
auto mel_data = nemo_mel::NemoComputeLogMelBatch(pcm.Data(), num_samples, cfg_, num_frames);
789-
790-
// Output [num_mels, num_frames] (no batch dim) — StackTensors adds the batch dim
791-
auto* out = logmel.Allocate({cfg_.num_mels, num_frames});
792-
std::memcpy(out, mel_data.data(), mel_data.size() * sizeof(float));
793-
return {};
794-
}
795-
796-
private:
797-
nemo_mel::NemoMelConfig cfg_{128, 512, 160, 400, 16000, 0.97f, 5.96046448e-08f};
798-
};
799-
800-
} // namespace ort_extensions
664+
} // namespace ort_extensions

test/static_test/test_nemo_mel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include "nemo_mel_spectrogram.h"
1313
#include "c_api_utils.hpp"
1414
#include "runner.hpp"
15-
#include "speech_features.hpp"
15+
#include "nemo_speech_features.hpp"
1616

1717
#ifndef M_PI
1818
#define M_PI 3.14159265358979323846

0 commit comments

Comments
 (0)