本文聚焦于在Qt 6框架下,面向Android平台,利用C++与TensorFlow Lite(TFLite)实现的人脸识别程序开发。研究目的在于探索一种跨平台、高性能的移动端人脸识别解决方案,突破传统Java或Python方案在实时性与资源管控上的局限。方法上,程序以Qt 6的Android JNI接口为桥梁,实现C++业务逻辑与Android系统服务的交互;核心推理引擎采用TFLite的C++ API(下文所附的Android端辅助类TFHelper则通过TFLite Java API完成推理),加载预训练的轻量化人脸检测与特征提取模型(如MobileNet-SSD或FaceNet变体;下文示例代码使用的是BlazeFace短距离人脸检测模型)。在架构设计中,通过Qt的信号槽机制处理摄像头帧数据,并利用OpenCV进行图像预处理与后处理,显著提升了数据流转效率。项目重点解决了Android平台下内存管理、模型量化(INT8/FP16)对精度的影响以及多线程推理等关键问题。实验结果表明,该程序在主流Android设备上能够实现实时(>30 FPS)且准确的人脸定位与识别,验证了Qt 6与TFLite在移动端C++生态下联合开发的可行性。本文为后续在嵌入式或移动设备上开发高性能、低延迟的视觉AI应用提供了可复用的技术框架与实践参考,具有重要的工程应用价值。

#include "mainwindow.h"
#include "ui_mainwindow.h"
#include "tfliteengine.h"

#include <cmath>

#include <QCoreApplication>
#include <QDebug>
#include <QDir>
#include <QElapsedTimer>
#include <QFile>
#include <QFileDialog>
#include <QImageReader>
#include <QMessageBox>
#include <QStandardPaths>

/**
 * Builds the main window, loads the face-detection model and wires up the UI.
 *
 * The "run inference" button starts disabled: there is nothing to run on until an
 * image has been selected, and onSelectImage() enables it only when the engine
 * reports a loaded model. If loading the model fails, a warning box is shown but the
 * window is still constructed.
 *
 * @param parent optional parent widget
 */
MainWindow::MainWindow(QWidget *parent)
    : QMainWindow(parent)
    , ui(new Ui::MainWindow)
    , m_engine(new TfLiteEngine(this))
{
    ui->setupUi(this);

    // Do not rely on the .ui default: pressing the button before an image is loaded
    // would otherwise silently do nothing (onRunInference() returns early).
    ui->inferBtn->setEnabled(false);

    const QString modelPath = getModelPath();
    if (!m_engine->loadModel(modelPath)) {
        QMessageBox::warning(this, tr("Error"),
            tr("Failed to load model from:\n%1").arg(modelPath));
    }

    connect(ui->selectBtn, &QPushButton::clicked, this, &MainWindow::onSelectImage);
    connect(ui->inferBtn, &QPushButton::clicked, this, &MainWindow::onRunInference);
    connect(m_engine, &TfLiteEngine::inferenceFinished,
            this, &MainWindow::onInferenceFinished);
    connect(m_engine, &TfLiteEngine::errorOccurred,
            this, &MainWindow::onError);
}

// Destroys the Designer-generated UI object allocated in the constructor.
// m_engine is constructed with `this` (presumably as its QObject parent), so it is
// not deleted explicitly here.
MainWindow::~MainWindow() { delete ui; }

/**
 * Resolves where the TFLite model should be loaded from.
 *
 * Android builds address the model inside the APK assets. Desktop builds prefer the
 * development-tree copy (../tflite_cpp/model/model.tflite relative to the executable
 * directory) when that file exists, and otherwise fall back to model.tflite next to
 * the executable. The fallback is returned whether or not that file exists.
 */
QString MainWindow::getModelPath() const
{
#ifdef Q_OS_ANDROID
    return QStringLiteral("assets:/model.tflite");
#else
    const QDir appDir(QCoreApplication::applicationDirPath());
    const QString devTreeModel =
        appDir.filePath(QStringLiteral("../tflite_cpp/model/model.tflite"));
    return QFile::exists(devTreeModel)
        ? devTreeModel
        : appDir.filePath(QStringLiteral("model.tflite"));
#endif
}

void MainWindow::onSelectImage()
{
    QStringList picturesLocations = QStandardPaths::standardLocations(QStandardPaths::PicturesLocation);
    QString startDir = picturesLocations.isEmpty() ? QString() : picturesLocations.first();

    QString fileName = QFileDialog::getOpenFileName(this,
        tr("Select Image"), startDir,
        tr("Images (*.png *.jpg *.jpeg *.bmp *.gif)"));

    if (fileName.isEmpty())
        return;

    QImage image(fileName);
    if (image.isNull()) {
        QMessageBox::warning(this, tr("Error"), tr("Failed to load image"));
        return;
    }

    m_currentImage = image;

    QPixmap pixmap = QPixmap::fromImage(image);
    QSize labelSize = ui->imageLabel->size();
    ui->imageLabel->setPixmap(pixmap.scaled(labelSize, Qt::KeepAspectRatio, Qt::SmoothTransformation));

    ui->inferBtn->setEnabled(m_engine->isLoaded());
    ui->resultText->clear();
}

void MainWindow::onRunInference()
{
    if (m_currentImage.isNull() || !m_engine->isLoaded())
        return;

    ui->inferBtn->setEnabled(false);
    ui->resultText->setText(tr("Running face detection..."));

    QElapsedTimer timer;
    timer.start();

    // Run face detection with bounding box visualization
    QImage annotated = m_engine->runInferenceWithBoxes(m_currentImage);

    qint64 elapsed = timer.elapsed();
    qDebug() << "Face detection took" << elapsed << "ms";

    // Display the annotated image
    if (!annotated.isNull() && annotated != m_currentImage) {
        QPixmap pixmap = QPixmap::fromImage(annotated);
        QSize labelSize = ui->imageLabel->size();
        ui->imageLabel->setPixmap(pixmap.scaled(labelSize, Qt::KeepAspectRatio, Qt::SmoothTransformation));
    }

    // Show results text
    QString output = tr("Face Detection Results\n");
    output += QStringLiteral("========================================\n");
    output += tr("Inference time: %1 ms\n\n").arg(elapsed);

    if (!annotated.isNull() && annotated != m_currentImage) {
        output += tr("Faces detected and highlighted on image.\n");
        output += tr("Green corners mark detected face regions.\n");
    } else {
        output += tr("No faces detected in the image.\n");
        output += tr("Try a different image with visible faces.\n");
    }
    ui->resultText->setText(output);

    ui->inferBtn->setEnabled(true);
}

// Slot connected to TfLiteEngine::inferenceFinished. The raw result vector is
// intentionally ignored here; onRunInference() builds the on-screen report itself.
void MainWindow::onInferenceFinished(const QVector<float> & /*results*/)
{
}

void MainWindow::onError(const QString &error)
{
    QMessageBox::warning(this, tr("Error"), error);
}

/**
 * Formats a flat detection vector as human-readable text.
 *
 * Expected layout (as produced by the Android TFHelper): results[0] is the detection
 * count. Each detection i then occupies 5 floats starting at 1 + i*5:
 * ymin, xmin, ymax, xmax (normalized) and score (0..1).
 *
 * The count is validated before use. NaN, infinite, negative or oversized values
 * cannot trigger an undefined float-to-int conversion, and the reported count never
 * exceeds the number of complete detections actually present in the vector.
 *
 * @param results flat result vector; may be empty
 * @return multi-line report text
 */
QString MainWindow::formatResults(const QVector<float> &results) const
{
    QString output;
    output += tr("Face Detection Results\n");
    output += QStringLiteral("========================================\n");

    if (results.isEmpty()) {
        output += tr("No results.\n");
        return output;
    }

    constexpr int kValuesPerDetection = 5;
    // Number of complete 5-float records after the count header.
    const qsizetype available = (results.size() - 1) / kValuesPerDetection;

    // Converting a non-finite or out-of-range float to int is undefined behaviour,
    // so clamp in the float domain before casting.
    const float rawCount = results[0];
    int numDetections = 0;
    if (std::isfinite(rawCount) && rawCount >= 1.0f) {
        numDetections = (rawCount >= static_cast<float>(available))
            ? static_cast<int>(available)
            : static_cast<int>(rawCount);
    }
    output += tr("Faces detected: %1\n\n").arg(numDetections);

    for (int i = 0; i < numDetections; ++i) {
        const int offset = 1 + i * kValuesPerDetection;  // in bounds: i < available

        const float ymin = results[offset + 0];
        const float xmin = results[offset + 1];
        const float ymax = results[offset + 2];
        const float xmax = results[offset + 3];
        const float score = results[offset + 4];

        output += tr("Face #%1: confidence=%2%\n")
            .arg(i + 1)
            .arg(score * 100, 0, 'f', 1);
        output += tr("  Box: [ymin=%1, xmin=%2, ymax=%3, xmax=%4]\n\n")
            .arg(ymin, 0, 'f', 3)
            .arg(xmin, 0, 'f', 3)
            .arg(ymax, 0, 'f', 3)
            .arg(xmax, 0, 'f', 3);
    }

    return output;
}
package com.tf.d;

import android.app.Activity;
import android.content.res.AssetManager;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.Tensor;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.MappedByteBuffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import android.util.Log;

/**
 * Android-side helper that runs a BlazeFace-style face detector with the TensorFlow
 * Lite Java {@link Interpreter}.
 *
 * <p>Usage: call {@link #setActivity(Activity)} (needed to extract the bundled
 * {@code model.tflite} asset), then {@link #loadModel(String)}, then
 * {@link #runFaceDetection(byte[])} per image. Results are returned as a flat float
 * array; see {@link #runFaceDetection(byte[])} for the layout.
 *
 * <p>NOTE(review): this class performs no synchronization; presumably it is driven
 * from a single thread. Confirm with the native caller.
 */
public class TFHelper {
    private static final String TAG = "TFHelper";

    // Values per anchor in the regressor output: box (4) + 6 keypoints as (y, x) pairs.
    private static final int REGRESSOR_SIZE = 16;
    // Thresholds passed to decodeDetections().
    private static final float SCORE_THRESHOLD = 0.3f;
    private static final float IOU_THRESHOLD = 0.3f;

    private Interpreter interpreter;
    private int inputHeight = 128;
    private int inputWidth = 128;
    private int inputChannels = 3;
    private boolean isFloatInput = true;
    private boolean isFloatOutput = true;
    private Activity activity;

    // BlazeFace Short Range anchor config (numAnchors is recomputed in generateAnchors()).
    // anchorW / anchorH are filled but not used by the current decoder.
    private int numAnchors = 896;
    private float[] anchorCx;
    private float[] anchorCy;
    private float[] anchorW;
    private float[] anchorH;

    // Face detection result.
    // detectionBoxes / detectionScores are not written anywhere in this class.
    private float[] detectionBoxes;   // [numDetections][4] flattened: ymin,xmin,ymax,xmax
    private float[] detectionScores;  // [numDetections]
    private int numDetections = 0;

    // Last result produced by processOutputs(); layout as documented on runFaceDetection().
    private float[] lastResult = null;

    // Output tensor index mapping (auto-detected from model shapes in loadModel()).
    private int regressorOutputIdx = 0;   // output index for regressors (16 values per anchor)
    private int classifierOutputIdx = 1;  // output index for classifiers (1 value per anchor)

    /** Sets the activity used to reach the APK assets and the cache directory. */
    public void setActivity(Activity activity) {
        this.activity = activity;
    }

    /**
     * Loads a TFLite model and prepares the input description, output mapping and anchors.
     *
     * <p>The model is first loaded from {@code modelPath} if that file exists. Otherwise,
     * when an activity is set, the asset {@code model.tflite} is copied to the cache
     * directory and loaded from there. A previously loaded interpreter is closed once
     * the new one has been created.
     *
     * @param modelPath file-system path of the model
     * @return true on success; false on any failure (details are logged)
     */
    public boolean loadModel(String modelPath) {
        try {
            MappedByteBuffer modelBuffer = null;
            File modelFile = new File(modelPath);
            if (modelFile.exists()) {
                modelBuffer = loadModelFile(modelFile.getAbsolutePath());
            }
            if (modelBuffer == null && activity != null) {
                String assetName = "model.tflite";
                File extractedFile = extractAssetToCache(activity, assetName);
                if (extractedFile != null) {
                    modelBuffer = loadModelFile(extractedFile.getAbsolutePath());
                }
            }
            if (modelBuffer == null) {
                Log.e(TAG, "Failed to load model buffer");
                return false;
            }
            Interpreter.Options options = new Interpreter.Options();
            options.setNumThreads(4);
            Interpreter newInterpreter = new Interpreter(modelBuffer, options);
            // Release the native resources of the previous model only after the new
            // interpreter was created successfully.
            if (interpreter != null) {
                interpreter.close();
            }
            interpreter = newInterpreter;

            parseInputTensor();
            detectOutputMapping();

            // Generate anchors for BlazeFace (depends on the parsed input size)
            generateAnchors();

            String outDtype = interpreter.getOutputTensor(0).dataType().toString();
            isFloatOutput = outDtype.contains("FLOAT") || outDtype.contains("float");
            if (!isFloatOutput) {
                // runFaceDetection() always supplies float output arrays.
                Log.w(TAG, "Output tensor type is " + outDtype + "; float output arrays will be used");
            }

            Log.d(TAG, "Model loaded: input=" + inputWidth + "x" + inputHeight + "x" + inputChannels
                    + ", anchors=" + numAnchors);
            return true;
        } catch (Exception e) {
            Log.e(TAG, "loadModel failed", e);
            return false;
        }
    }

    /**
     * Reads input tensor 0 into inputHeight / inputWidth / inputChannels / isFloatInput.
     * Supports NHWC (rank 4), HWC (rank 3) and flattened (rank 2) shapes.
     */
    private void parseInputTensor() {
        Tensor inputTensor = interpreter.getInputTensor(0);
        int[] inputShape = inputTensor.shape();
        Log.d(TAG, "Input shape: " + Arrays.toString(inputShape));
        Log.d(TAG, "Input type: " + inputTensor.dataType());

        if (inputShape.length == 4) {
            inputHeight = inputShape[1];
            inputWidth = inputShape[2];
            inputChannels = inputShape[3];
        } else if (inputShape.length == 3) {
            inputHeight = inputShape[0];
            inputWidth = inputShape[1];
            inputChannels = inputShape[2];
        } else if (inputShape.length == 2) {
            int flatSize = inputShape[1];
            int side = (int) Math.round(Math.sqrt(flatSize));
            if (side * side == flatSize) {
                inputHeight = side;
                inputWidth = side;
            } else {
                inputHeight = 1;
                inputWidth = flatSize;
            }
            inputChannels = 1;
        }

        String dtype = inputTensor.dataType().toString();
        isFloatInput = dtype.contains("FLOAT") || dtype.contains("float");
    }

    /**
     * Determines which output tensor holds the regressors (last dim 16) and which holds
     * the classifier logits (last dim 1). Indices are reset to the defaults (0 / 1) on
     * every call, so a previously loaded model does not leak its mapping. The two
     * indices are always kept distinct.
     */
    private void detectOutputMapping() {
        regressorOutputIdx = 0;
        classifierOutputIdx = 1;

        int numOutputs = interpreter.getOutputTensorCount();
        Log.d(TAG, "Number of output tensors: " + numOutputs);
        boolean foundRegressor = false;
        boolean foundClassifier = false;
        for (int i = 0; i < numOutputs; i++) {
            Tensor outTensor = interpreter.getOutputTensor(i);
            int[] outShape = outTensor.shape();
            Log.d(TAG, "Output[" + i + "]: name='" + outTensor.name()
                    + "' shape=" + Arrays.toString(outShape)
                    + " type=" + outTensor.dataType());

            // Detect by last dimension: 16 = regressors, 1 = classifiers
            if (outShape.length >= 2) {
                int lastDim = outShape[outShape.length - 1];
                if (lastDim == REGRESSOR_SIZE && !foundRegressor) {
                    regressorOutputIdx = i;
                    foundRegressor = true;
                    Log.d(TAG, "  -> Output[" + i + "] is REGRESSORS (16 channels)");
                } else if (lastDim == 1 && !foundClassifier) {
                    classifierOutputIdx = i;
                    foundClassifier = true;
                    Log.d(TAG, "  -> Output[" + i + "] is CLASSIFIERS (1 channel)");
                }
            }
        }
        if (!foundRegressor || !foundClassifier) {
            // If only one tensor was identified it may have landed on the other one's
            // default index; move the unidentified one to the remaining default.
            if (regressorOutputIdx == classifierOutputIdx) {
                if (foundRegressor) {
                    classifierOutputIdx = (regressorOutputIdx == 0) ? 1 : 0;
                } else {
                    regressorOutputIdx = (classifierOutputIdx == 0) ? 1 : 0;
                }
            }
            Log.w(TAG, "Could not auto-detect output mapping, using default: regressor="
                    + regressorOutputIdx + ", classifier=" + classifierOutputIdx);
        }
    }

    /**
     * Generate BlazeFace Short Range anchors.
     * Config: strides=[8,16,16,16], 2 anchors per location (896 for a 128x128 input).
     * Anchors are square (aspect ratio = 1.0) with interpolated scales; centers are
     * normalized to [0, 1].
     */
    private void generateAnchors() {
        int[] strides = {8, 16, 16, 16};
        float minScale = 0.1484375f;
        float maxScale = 0.75f;
        int anchorsPerLocation = 2;
        int numLayers = strides.length;

        int totalAnchors = 0;
        for (int stride : strides) {
            int featureH = inputHeight / stride;
            int featureW = inputWidth / stride;
            totalAnchors += featureH * featureW * anchorsPerLocation;
        }
        numAnchors = totalAnchors;
        Log.d(TAG, "Generating " + numAnchors + " anchors");

        anchorCx = new float[numAnchors];
        anchorCy = new float[numAnchors];
        anchorW = new float[numAnchors];
        anchorH = new float[numAnchors];

        int idx = 0;
        for (int layerId = 0; layerId < numLayers; layerId++) {
            int stride = strides[layerId];
            int featureH = inputHeight / stride;
            int featureW = inputWidth / stride;

            float scale = minScale + (maxScale - minScale) * (2.0f * layerId + 1) / (2.0f * numLayers);
            float scaleNext = minScale + (maxScale - minScale) * (2.0f * layerId + 2) / (2.0f * numLayers);
            float interpolatedScale = (float) Math.sqrt(scale * scaleNext);

            for (int y = 0; y < featureH; y++) {
                for (int x = 0; x < featureW; x++) {
                    float cx = (x + 0.5f) / featureW;
                    float cy = (y + 0.5f) / featureH;

                    anchorCx[idx] = cx;
                    anchorCy[idx] = cy;
                    anchorW[idx] = scale;
                    anchorH[idx] = scale;
                    idx++;

                    anchorCx[idx] = cx;
                    anchorCy[idx] = cy;
                    anchorW[idx] = interpolatedScale;
                    anchorH[idx] = interpolatedScale;
                    idx++;
                }
            }
        }
    }

    /**
     * Memory-maps a model file read-only. The stream (and its channel) is closed before
     * returning; per the FileChannel.map contract the mapping remains valid.
     */
    private MappedByteBuffer loadModelFile(String modelPath) throws IOException {
        try (FileInputStream inputStream = new FileInputStream(new File(modelPath))) {
            java.nio.channels.FileChannel fileChannel = inputStream.getChannel();
            return fileChannel.map(java.nio.channels.FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
        }
    }

    /**
     * Copies an APK asset into the app cache directory, always overwriting any existing
     * copy so that a newer bundled model replaces a stale cached one.
     *
     * @return the extracted file, or null if the copy failed (a partial file is deleted)
     */
    private File extractAssetToCache(Activity activity, String assetName) {
        AssetManager assetManager = activity.getAssets();
        File cacheDir = activity.getCacheDir();
        File outFile = new File(cacheDir, assetName);
        try (InputStream in = assetManager.open(assetName);
             OutputStream out = new FileOutputStream(outFile)) {
            byte[] buffer = new byte[4096];
            int read;
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
        } catch (IOException e) {
            Log.e(TAG, "extractAssetToCache failed", e);
            // Do not leave a truncated model behind.
            if (outFile.exists() && !outFile.delete()) {
                Log.w(TAG, "Could not delete partial file " + outFile.getAbsolutePath());
            }
            return null;
        }
        Log.d(TAG, "Model extracted to cache: " + outFile.getAbsolutePath()
                + " (" + outFile.length() + " bytes)");
        return outFile;
    }

    /** Returns the model input dimensions as {height, width, channels}. */
    public int[] getInputDims() {
        return new int[]{inputHeight, inputWidth, inputChannels};
    }

    public int getOutputSize() { return numAnchors * 5 + 1; } // max possible detections * 5 + count header

    /**
     * Run face detection on the input image.
     *
     * <p>{@code imageData} holds 8-bit samples in the model's input layout (presumably
     * interleaved RGB888 of size H*W*C; confirm with the native caller). Extra samples
     * are ignored and missing ones are zero-filled.
     *
     * @return a flat float array, or null if no model is loaded, imageData is null, or
     *     inference fails:
     *   [0] = number of detections
     *   For each detection i (starting at index 1):
     *     [1 + i*5 + 0] = ymin (normalized 0-1)
     *     [1 + i*5 + 1] = xmin (normalized 0-1)
     *     [1 + i*5 + 2] = ymax (normalized 0-1)
     *     [1 + i*5 + 3] = xmax (normalized 0-1)
     *     [1 + i*5 + 4] = score (0-1)
     */
    public float[] runFaceDetection(byte[] imageData) {
        if (interpreter == null) {
            Log.e(TAG, "interpreter is null");
            return null;
        }
        if (imageData == null) {
            Log.e(TAG, "imageData is null");
            return null;
        }

        try {
            // Prepare input buffer (inputPixels is the element count H*W*C)
            int inputPixels = inputHeight * inputWidth * inputChannels;
            int inputBytes = isFloatInput ? inputPixels * 4 : inputPixels;
            Log.d(TAG, "Input bytes needed: " + inputBytes + ", received: " + imageData.length);

            ByteBuffer inputBuffer = isFloatInput
                    ? buildFloatInput(imageData, inputPixels)
                    : buildUint8Input(imageData, inputBytes);

            // Run inference
            float[][][] regressorsOut = new float[1][numAnchors][REGRESSOR_SIZE];
            float[][][] classificatorsOut = new float[1][numAnchors][1];
            java.util.Map<Integer, Object> outputs = new java.util.HashMap<>();
            outputs.put(regressorOutputIdx, regressorsOut);
            outputs.put(classifierOutputIdx, classificatorsOut);
            interpreter.runForMultipleInputsOutputs(new Object[]{inputBuffer}, outputs);

            // Process outputs
            processOutputs(regressorsOut, classificatorsOut);
            return getLastResult();

        } catch (Exception e) {
            Log.e(TAG, "runFaceDetection failed", e);
            return null;
        }
    }

    /**
     * Converts uint8 samples to native-order floats scaled to [0, 1]. Copies at most
     * {@code elementCount} samples and zero-fills the rest of the buffer.
     */
    private static ByteBuffer buildFloatInput(byte[] imageData, int elementCount) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(elementCount * 4);
        buffer.order(ByteOrder.nativeOrder());
        int copyCount = Math.min(imageData.length, elementCount);
        for (int i = 0; i < copyCount; i++) {
            buffer.putFloat((imageData[i] & 0xFF) / 255.0f);
        }
        // Pad until the buffer is full (position() counts bytes, not floats).
        while (buffer.hasRemaining()) {
            buffer.putFloat(0.0f);
        }
        buffer.rewind();
        return buffer;
    }

    /** Copies at most {@code byteCount} raw bytes and zero-fills the rest of the buffer. */
    private static ByteBuffer buildUint8Input(byte[] imageData, int byteCount) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(byteCount);
        buffer.order(ByteOrder.nativeOrder());
        buffer.put(imageData, 0, Math.min(imageData.length, byteCount));
        while (buffer.hasRemaining()) {
            buffer.put((byte) 0);
        }
        buffer.rewind();
        return buffer;
    }

    /**
     * Flattens the raw model outputs, logs debug statistics, and decodes detections
     * into {@link #lastResult}.
     */
    private void processOutputs(float[][][] regressorsOut, float[][][] classificatorsOut) {
        // Flatten 3D output to 1D arrays
        float[] regressors = new float[numAnchors * REGRESSOR_SIZE];
        float[] classificators = new float[numAnchors];
        for (int a = 0; a < numAnchors; a++) {
            System.arraycopy(regressorsOut[0][a], 0, regressors, a * REGRESSOR_SIZE, REGRESSOR_SIZE);
            classificators[a] = classificatorsOut[0][a][0];
        }

        Log.d(TAG, "Regressors: " + regressors.length + ", Classificators: " + classificators.length);

        // Debug: log classification score distribution
        float minLogit = Float.MAX_VALUE, maxLogit = -Float.MAX_VALUE;
        for (float c : classificators) {
            if (c < minLogit) minLogit = c;
            if (c > maxLogit) maxLogit = c;
        }
        float minScore = 1.0f / (1.0f + (float) Math.exp(-minLogit));
        float maxScore = 1.0f / (1.0f + (float) Math.exp(-maxLogit));
        Log.d(TAG, "Classification logit range: [" + minLogit + ", " + maxLogit + "]");
        Log.d(TAG, "Score range (sigmoid): [" + minScore + ", " + maxScore + "]");

        // Debug: log raw regressor values for the first few anchors scoring above 0.3
        int debugCount = 0;
        for (int i = 0; i < numAnchors && debugCount < 3; i++) {
            float s = 1.0f / (1.0f + (float) Math.exp(-classificators[i]));
            if (s > 0.3f) {
                int base = i * REGRESSOR_SIZE;
                Log.d(TAG, "Top anchor #" + i + " score=" + s
                        + " reg=[y0=" + String.format("%.4f", regressors[base])
                        + ", x0=" + String.format("%.4f", regressors[base+1])
                        + ", y1=" + String.format("%.4f", regressors[base+2])
                        + ", x1=" + String.format("%.4f", regressors[base+3])
                        + "] anchor=(" + anchorCx[i] + "," + anchorCy[i] + ")");
                debugCount++;
            }
        }

        // Decode detections
        lastResult = decodeDetections(regressors, classificators, SCORE_THRESHOLD, IOU_THRESHOLD);
    }

    /** Returns the last decoded result, or {0} (no detections) if none exists yet. */
    private float[] getLastResult() {
        return lastResult != null ? lastResult : new float[]{0};
    }

    /**
     * Decode BlazeFace detections.
     * The 16 regression values per anchor are:
     *   [0..3] = bounding box: cy, cx, h, w as pixel offsets from anchor center
     *          normalized by dividing by input dimensions
     *   [4..15] = 6 face keypoints as (y, x) pairs, same encoding
     * All output in normalized [0, 1] coordinate space.
     *
     * <p>Pipeline: sigmoid score threshold, box decode (optionally widened by the
     * keypoint extent), 15% margin, size filter, greedy NMS, and finally only the
     * largest surviving box is kept and shifted by 1/6 of its size (down and left).
     *
     * @return {0} if nothing survives, else {1, ymin, xmin, ymax, xmax, score}
     */
    private float[] decodeDetections(float[] regressors, float[] classificators,
                                      float scoreThreshold, float iouThreshold) {
        ArrayList<float[]> detections = new ArrayList<>(); // [ymin, xmin, ymax, xmax, score]
        float invInputH = 1.0f / inputHeight;
        float invInputW = 1.0f / inputWidth;

        // Step 1: Decode all anchors
        for (int i = 0; i < numAnchors; i++) {
            float score = 1.0f / (1.0f + (float) Math.exp(-classificators[i]));
            if (score < scoreThreshold) continue;

            int base = i * REGRESSOR_SIZE;

            // Get anchor position
            float acx = this.anchorCx[i];
            float acy = this.anchorCy[i];

            // Decode bounding box from first 4 regression values
            // All values are pixel offsets from anchor center, normalized by input size
            float boxCy = acy + regressors[base] * invInputH;
            float boxCx = acx + regressors[base + 1] * invInputW;
            float boxH = regressors[base + 2] * invInputH;
            float boxW = regressors[base + 3] * invInputW;

            // Convert center+size to corners (normalized [0,1])
            float ymin = boxCy - boxH / 2.0f;
            float xmin = boxCx - boxW / 2.0f;
            float ymax = boxCy + boxH / 2.0f;
            float xmax = boxCx + boxW / 2.0f;

            // Validate box center is in reasonable range
            if (boxCy < -0.2f || boxCy > 1.2f || boxCx < -0.2f || boxCx > 1.2f) continue;
            if (boxW <= 0 || boxH <= 0) continue;

            // Optionally refine box using keypoints (values 4..15)
            // Keypoints use same pixel-offset / inputSize encoding
            float kpMinY = ymin, kpMaxY = ymax;
            float kpMinX = xmin, kpMaxX = xmax;
            int validKps = 0;

            for (int k = 0; k < 6; k++) {
                float kpY = acy + regressors[base + 4 + k * 2] * invInputH;
                float kpX = acx + regressors[base + 4 + k * 2 + 1] * invInputW;

                if (kpX >= -0.2f && kpX <= 1.2f && kpY >= -0.2f && kpY <= 1.2f) {
                    kpMinY = Math.min(kpMinY, kpY);
                    kpMaxY = Math.max(kpMaxY, kpY);
                    kpMinX = Math.min(kpMinX, kpX);
                    kpMaxX = Math.max(kpMaxX, kpX);
                    validKps++;
                }
            }

            // Use union of box and keypoint bounds if we have enough valid keypoints
            float finalYmin, finalXmin, finalYmax, finalXmax;
            if (validKps >= 3) {
                finalYmin = Math.min(ymin, kpMinY);
                finalXmin = Math.min(xmin, kpMinX);
                finalYmax = Math.max(ymax, kpMaxY);
                finalXmax = Math.max(xmax, kpMaxX);
            } else {
                finalYmin = ymin;
                finalXmin = xmin;
                finalYmax = ymax;
                finalXmax = xmax;
            }

            // Add margin (15%)
            float bW = finalXmax - finalXmin;
            float bH = finalYmax - finalYmin;
            float margin = Math.max(bW, bH) * 0.15f;

            float outYmin = clamp01(finalYmin - margin);
            float outXmin = clamp01(finalXmin - margin);
            float outYmax = clamp01(finalYmax + margin);
            float outXmax = clamp01(finalXmax + margin);

            // Validate box size
            float finalW = outXmax - outXmin;
            float finalH = outYmax - outYmin;
            if (finalW < 0.02f || finalH < 0.02f) continue;
            if (finalW > 0.98f || finalH > 0.98f) continue;

            detections.add(new float[]{outYmin, outXmin, outYmax, outXmax, score});

            // Debug
            if (score > 0.5f) {
                Log.d(TAG, "Face #" + i + " score=" + score
                        + " anchor=(" + String.format("%.3f", acx) + "," + String.format("%.3f", acy) + ")"
                        + " kps=[" + validKps + "]"
                        + " box=[" + String.format("%.3f", outYmin) + "," + String.format("%.3f", outXmin)
                        + "," + String.format("%.3f", outYmax) + "," + String.format("%.3f", outXmax) + "]"
                        + " raw=(" + String.format("%.2f", regressors[base])
                        + "," + String.format("%.2f", regressors[base+1])
                        + "," + String.format("%.2f", regressors[base+2])
                        + "," + String.format("%.2f", regressors[base+3]) + ")"
                        + " decoded_c=(" + String.format("%.3f", boxCy) + "," + String.format("%.3f", boxCx) + ")");
            }
        }

        Log.d(TAG, "Anchors above threshold (" + scoreThreshold + "): " + detections.size());

        if (detections.isEmpty()) {
            numDetections = 0;
            return new float[]{0};
        }

        // Step 2: Sort by score descending
        int n = detections.size();
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Arrays.sort(order, (a, b) -> Float.compare(detections.get(b)[4], detections.get(a)[4]));

        // Step 3: NMS deduplication (greedy, highest score first)
        boolean[] suppressed = new boolean[n];
        ArrayList<Integer> kept = new ArrayList<>();

        for (int i = 0; i < n; i++) {
            int ci = order[i];
            if (suppressed[ci]) continue;
            kept.add(ci);
            for (int j = i + 1; j < n; j++) {
                int cj = order[j];
                if (suppressed[cj]) continue;
                float iou = computeIoU(detections.get(ci), detections.get(cj));
                if (iou > iouThreshold) {
                    suppressed[cj] = true;
                }
            }
        }

        Log.d(TAG, "After NMS: " + kept.size() + " detections");

        if (kept.isEmpty()) {
            numDetections = 0;
            return new float[]{0};
        }

        // Step 4: Keep only the detection with the largest area
        int bestIdx = 0;
        float bestArea = 0;
        for (int k = 0; k < kept.size(); k++) {
            float[] d = detections.get(kept.get(k));
            float area = (d[2] - d[0]) * (d[3] - d[1]);
            if (area > bestArea) {
                bestArea = area;
                bestIdx = k;
            }
        }

        numDetections = 1;
        float[] result = new float[6];
        result[0] = 1;
        float[] det = detections.get(kept.get(bestIdx));

        // Shift box: down by 1/6 height, left by 1/6 width
        // (empirical correction; the reason is not documented in this file)
        float boxH = det[2] - det[0];
        float boxW = det[3] - det[1];
        float shiftY = boxH / 6.0f;
        float shiftX = boxW / 6.0f;
        result[1] = clamp01(det[0] + shiftY); // ymin down
        result[2] = clamp01(det[1] - shiftX); // xmin left
        result[3] = clamp01(det[2] + shiftY); // ymax down
        result[4] = clamp01(det[3] - shiftX); // xmax left
        result[5] = det[4];

        Log.d(TAG, "Best face (largest area): score=" + det[4]
                + " area=" + String.format("%.4f", bestArea)
                + " box=[" + det[0] + "," + det[1] + "," + det[2] + "," + det[3] + "]");

        return result;
    }

    /** Clamps v to [0, 1]. */
    private float clamp01(float v) {
        return Math.max(0, Math.min(1, v));
    }

    /** Intersection-over-union of two [ymin, xmin, ymax, xmax, ...] boxes; 0 if the union is empty. */
    private float computeIoU(float[] boxA, float[] boxB) {
        float ymin = Math.max(boxA[0], boxB[0]);
        float xmin = Math.max(boxA[1], boxB[1]);
        float ymax = Math.min(boxA[2], boxB[2]);
        float xmax = Math.min(boxA[3], boxB[3]);

        float interW = Math.max(0, xmax - xmin);
        float interH = Math.max(0, ymax - ymin);
        float interArea = interW * interH;

        float areaA = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]);
        float areaB = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1]);
        float unionArea = areaA + areaB - interArea;

        return unionArea > 0 ? interArea / unionArea : 0;
    }

    /**
     * Legacy method for backward compatibility.
     * Returns raw detection results as flat array (same as {@link #runFaceDetection(byte[])}).
     */
    public float[] runInference(byte[] imageData) {
        return runFaceDetection(imageData);
    }
}

需要完整程序请私信或者留邮箱。

更多推荐