1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
| public class TensorFlowImageClassifier implements Classifier {
private TensorFlowImageClassifier() { }
public static Classifier create( String modelPath, String modelFilename, String modelScoreFileName, String[] labels, int inputSize, String inputName, String outputName) { TensorFlowImageClassifier c = new TensorFlowImageClassifier(); c.inputName = inputName; c.outputName = outputName; c.tags = labels; String path = Environment.getExternalStorageDirectory()+modelPath + modelFilename; String pathScore = Environment.getExternalStorageDirectory()+modelPath + modelScoreFileName; try { c.inferenceInterface = new TensorFlowInferenceInterface(new FileInputStream(new File(path))); c.inferenceInterfaceScore = new TensorFlowInferenceInterface(new FileInputStream(new File(pathScore))); } catch (FileNotFoundException e) { e.printStackTrace(); } c.inputSize = inputSize; c.outputNames = new String[]{outputName}; c.intValues = new int[inputSize * inputSize]; c.floatValues = new float[inputSize * inputSize * 3]; c.outputs = new float[labels.length]; c.outputScore= new float[2]; return c; }
@Override public List<Recognition> recognizeImage(final Bitmap bitmap) { final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); for (int i = 0; i < intValues.length; ++i) { final int val = intValues[i]; floatValues[i * 3 + 0] = (((val >> 16) & 0xFF)) / 128f - 1.0f; floatValues[i * 3 + 1] = (((val >> 8) & 0xFF)) / 128f - 1.0f; floatValues[i * 3 + 2] = ((val & 0xFF)) / 128f - 1.0f; } inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); inferenceInterface.run(outputNames, logStats); inferenceInterface.fetch(outputName, outputs); int index = 0; float maxPercent = 0f; for (int i = 0; i < outputs.length; i++) { if (maxPercent < outputs[i]) { maxPercent = outputs[i]; index = i; } } recognitions.add(new Recognition( "LABEL", tags[index], maxPercent, null));
inferenceInterfaceScore.feed(inputName,floatValues,1,inputSize,inputSize,3); inferenceInterfaceScore.run(outputNames,logStats); inferenceInterfaceScore.fetch(outputName,outputScore); recognitions.add(new Recognition("SCORE","GOOD="+outputScore[0]+" POOR="+outputScore[1],null,null));
return recognitions; }
@Override public void enableStatLogging(boolean logStats) { this.logStats = logStats; }
@Override public String getStatString() { return inferenceInterface.getStatString(); }
@Override public void close() { inferenceInterface.close(); } }
|