Improve naming of buffers
BenBE committed Mar 13, 2022
1 parent 9f7cac1 commit 297bca4
Showing 1 changed file with 45 additions and 27 deletions.
72 changes: 45 additions & 27 deletions lib/libbackscrub.cc
@@ -27,29 +27,38 @@ struct normalization_t {
struct backscrub_ctx_t {
// Loaded inference model
std::unique_ptr<tflite::FlatBufferModel> model;

// Model interpreter instance
std::unique_ptr<tflite::Interpreter> interpreter;

// Specific model type & input normalization
modeltype_t modeltype;
normalization_t norm;

// Optional callbacks with caller-provided context
void (*ondebug)(void *ctx, const char *msg);
void (*onprep)(void *ctx);
void (*oninfer)(void *ctx);
void (*onmask)(void *ctx);
void *caller_ctx;
- // Processing state
- cv::Mat input;
- cv::Mat output;
- cv::Rect roidim;
- cv::Mat mask;
- cv::Mat mroi;
- cv::Mat ofinal;
- cv::Size blur;

+ // Single step variables
+ cv::Mat input; // NN input tensors
+ cv::Mat output; // NN output tensors
+ cv::Mat ofinal; // NN output (post-processed mask)

+ float src_ratio; // Source image aspect ratio
+ cv::Rect src_roidim; // Source image rect of interest
+ cv::Mat mask_region; // Region of the final mask to operate on

+ float net_ratio; // NN input image aspect ratio
+ cv::Rect net_roidim; // NN input image rect of interest

+ // Result stitching variables
cv::Mat in_u8_bgr;
- cv::Rect in_roidim;
- float ratio;
- float frameratio;

+ cv::Size blur; // Size of blur on final mask
+ cv::Mat mask; // Fully processed mask (full image)
};

// Debug helper
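The new `mask_region` name ("Region of the final mask to operate on") relies on OpenCV view semantics: a `cv::Mat` constructed as `mask(src_roidim)`, as done later in `bs_maskgen_new`, shares pixel storage with the full-frame `mask`, so anything written through the region lands directly in the final mask. A standalone illustration with arbitrary sizes (not code from this commit):

    cv::Mat full = cv::Mat::zeros(480, 640, CV_8UC1);   // full-frame mask
    cv::Mat region = full(cv::Rect(160, 0, 320, 480));  // ROI view, no copy
    region.setTo(255);                                   // writes land in `full`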
@@ -190,14 +199,17 @@ void *bs_maskgen_new(
) {
// Allocate context
backscrub_ctx_t *pctx = new backscrub_ctx_t;

// Take a reference so we can write tidy code with ctx.<x>
backscrub_ctx_t &ctx = *pctx;

// Save callbacks
ctx.ondebug = ondebug;
ctx.onprep = onprep;
ctx.oninfer = oninfer;
ctx.onmask = onmask;
ctx.caller_ctx = caller_ctx;

// Load model
ctx.model = tflite::FlatBufferModel::BuildFromFile(modelname.c_str());

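The `ondebug` pointer saved above is later invoked with the opaque `caller_ctx` handed back unchanged; a minimal sink matching that signature might look as follows (how it reaches `bs_maskgen_new` is not visible in this hunk, since the parameter list is cut off, so the wiring is assumed):

    #include <cstdio>

    static void my_debug_sink(void *caller_ctx, const char *msg) {
        (void)caller_ctx;                             // opaque pointer, unused here
        std::fprintf(stderr, "[backscrub] %s", msg);  // _dbg() messages already end in '\n'
    }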
@@ -209,18 +221,23 @@ void *bs_maskgen_new(

// Determine model type and normalization values
ctx.modeltype = get_modeltype(modelname);
- ctx.norm = get_normalization(ctx.modeltype);

if (modeltype_t::Unknown == ctx.modeltype) {
_dbg(ctx, "error: unknown model type '%s'.\n", modelname.c_str());
bs_maskgen_delete(pctx);
return nullptr;
}

+ ctx.norm = get_normalization(ctx.modeltype);

// Build the interpreter
tflite::ops::builtin::BuiltinOpResolver resolver;

// custom op for Google Meet network
resolver.AddCustom("Convolution2DTransposeBias", mediapipe::tflite_operations::RegisterConvolution2DTransposeBias());
resolver.AddCustom(
"Convolution2DTransposeBias",
mediapipe::tflite_operations::RegisterConvolution2DTransposeBias()
);
tflite::InterpreterBuilder builder(*ctx.model, resolver);
builder(&ctx.interpreter);

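For context, the resolver and builder calls above are the standard TFLite C++ setup sequence; condensed into one standalone sketch (file name and error handling are illustrative, not code from this commit):

    auto model = tflite::FlatBufferModel::BuildFromFile("segmentation.tflite");
    if (!model)
        return nullptr;
    tflite::ops::builtin::BuiltinOpResolver resolver;
    resolver.AddCustom(
        "Convolution2DTransposeBias",
        mediapipe::tflite_operations::RegisterConvolution2DTransposeBias()
    );
    std::unique_ptr<tflite::Interpreter> interpreter;
    tflite::InterpreterBuilder(*model, resolver)(&interpreter);
    if (!interpreter || interpreter->AllocateTensors() != kTfLiteOk)
        return nullptr;
    // the FlatBufferModel must outlive the Interpreter, which is why both
    // are kept in backscrub_ctx_t rather than as locals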
@@ -250,22 +267,22 @@ void *bs_maskgen_new(
return nullptr;
}

- ctx.ratio = (float)ctx.input.rows / (float)ctx.input.cols;
- ctx.frameratio = (float)height / (float)width;
+ ctx.net_ratio = (float)ctx.input.rows / (float)ctx.input.cols;
+ ctx.src_ratio = (float)height / (float)width;

// initialize mask and model-aspect ROI in center
- if (ctx.frameratio < ctx.ratio) {
+ if (ctx.src_ratio < ctx.net_ratio) {
// if frame is wider than model, then use only the frame center
- ctx.roidim = cv::Rect((width - height / ctx.ratio) / 2, 0, height / ctx.ratio, height);
- ctx.in_roidim = cv::Rect(0, 0, ctx.input.cols, ctx.input.rows);
+ ctx.src_roidim = cv::Rect((width - height / ctx.net_ratio) / 2, 0, height / ctx.net_ratio, height);
+ ctx.net_roidim = cv::Rect(0, 0, ctx.input.cols, ctx.input.rows);
} else {
// if model is wider than the frame, center the frame in the model
- ctx.roidim = cv::Rect(0, 0, width, height);
- ctx.in_roidim = cv::Rect((ctx.input.cols - ctx.input.rows / ctx.frameratio) / 2, 0, ctx.input.rows / ctx.frameratio, ctx.input.rows);
+ ctx.src_roidim = cv::Rect(0, 0, width, height);
+ ctx.net_roidim = cv::Rect((ctx.input.cols - ctx.input.rows / ctx.src_ratio) / 2, 0, ctx.input.rows / ctx.src_ratio, ctx.input.rows);
}

ctx.mask = cv::Mat::ones(height, width, CV_8UC1) * 255;
- ctx.mroi = ctx.mask(ctx.roidim);
+ ctx.mask_region = ctx.mask(ctx.src_roidim);

ctx.in_u8_bgr = cv::Mat(ctx.input.rows, ctx.input.cols, CV_8UC3, cv::Scalar(0, 0, 0));

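To make the two aspect-ratio branches concrete, a worked example with hypothetical sizes (not taken from this commit):

    // source frame 640x480 -> src_ratio = 480.0f / 640.0f = 0.75
    // NN input     256x144 -> net_ratio = 144.0f / 256.0f = 0.5625
    // src_ratio < net_ratio is false, so the else branch runs:
    //   src_roidim = cv::Rect(0, 0, 640, 480);            // whole source frame
    //   net_roidim = cv::Rect((256 - 144 / 0.75f) / 2, 0, // = cv::Rect(32, 0, 192, 144):
    //                         144 / 0.75f, 144);          //   frame centred in the tensor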
@@ -301,11 +318,12 @@ bool bs_maskgen_process(void *context, cv::Mat &frame, cv::Mat &mask) {
backscrub_ctx_t &ctx = *((backscrub_ctx_t *)context);

// map ROI
- cv::Mat roi = frame(ctx.roidim);
+ cv::Mat roi = frame(ctx.src_roidim);

+ cv::Mat in_roi = ctx.in_u8_bgr(ctx.net_roidim);
+ cv::resize(roi, in_roi, ctx.net_roidim.size());

cv::Mat in_u8_rgb;
- cv::Mat in_roi = ctx.in_u8_bgr(ctx.in_roidim);
- cv::resize(roi, in_roi, ctx.in_roidim.size());
cv::cvtColor(ctx.in_u8_bgr, in_u8_rgb, cv::COLOR_BGR2RGB);

// TODO: can convert directly to float?
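One possible answer to the TODO: `cv::Mat::convertTo` performs dst = src * alpha + beta with a type change in a single pass, so the RGB bytes could be scaled straight into the float input tensor. The `normalization_t` field names below are guesses (the struct body is not shown in this diff), and the input tensor is assumed to be float32 RGB:

    // scale and shift directly into the tensor-backed ctx.input (CV_32FC3)
    in_u8_rgb.convertTo(ctx.input, CV_32FC3, ctx.norm.scaling, ctx.norm.offset);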
@@ -378,7 +396,7 @@ bool bs_maskgen_process(void *context, cv::Mat &frame, cv::Mat &mask) {
* probability in [0.0, 1.0].
*/
for (unsigned int n = 0; n < ctx.output.total(); n++) {
float exp0 = expf(tmp[2 * n ]);
float exp1 = expf(tmp[2 * n + 1]);
float p0 = exp0 / (exp0 + exp1);
float p1 = exp1 / (exp0 + exp1);
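Since only two logits are involved, the normalisation above reduces to a sigmoid of the logit difference, which saves one `expf` per pixel and never exponentiates a large logit on its own:

    // p1 = exp1 / (exp0 + exp1) == 1 / (1 + expf(l0 - l1)), and p0 = 1 - p1
    float p1 = 1.0f / (1.0f + expf(tmp[2 * n] - tmp[2 * n + 1]));
    float p0 = 1.0f - p1;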
@@ -398,10 +416,10 @@ bool bs_maskgen_process(void *context, cv::Mat &frame, cv::Mat &mask) {

// scale up into full-sized mask
cv::Mat tmpbuf;
- cv::resize(ctx.ofinal(ctx.in_roidim), tmpbuf, ctx.mroi.size());
+ cv::resize(ctx.ofinal(ctx.net_roidim), tmpbuf, ctx.mask_region.size());

// blur at full size for maximum smoothness
- cv::blur(tmpbuf, ctx.mroi, ctx.blur);
+ cv::blur(tmpbuf, ctx.mask_region, ctx.blur);

// copy out
mask = ctx.mask;
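For reference, a minimal caller-side sketch of the lifecycle around these functions; the full `bs_maskgen_new` argument list is not visible in this diff, so its parameters are elided and all names below are illustrative:

    void *maskctx = bs_maskgen_new(/* model path, frame size, callbacks, ... */);
    if (maskctx) {
        cv::Mat frame = cv::imread("frame.png");  // BGR, same size as given to bs_maskgen_new
        cv::Mat mask;
        if (bs_maskgen_process(maskctx, frame, mask)) {
            // mask is a frame-sized CV_8UC1 matte, smoothed by the configured blur,
            // ready for blending the frame against a replacement background
        }
        bs_maskgen_delete(maskctx);
    }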