From b3fe9af23f33f487ea377969b6e90d7177c1161e Mon Sep 17 00:00:00 2001 From: Koki Ibukuro Date: Thu, 17 Oct 2024 09:23:39 +0200 Subject: [PATCH 1/2] Fix `TfLiteSignatureRunnerGetInputCount` return type to match with C API --- .../Runtime/InterpreterExtension.cs | 4 ++-- .../com.github.asus4.tflite/Runtime/SignatureRunner.cs | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Packages/com.github.asus4.tflite.common/Runtime/InterpreterExtension.cs b/Packages/com.github.asus4.tflite.common/Runtime/InterpreterExtension.cs index b22575f07..18f8be182 100644 --- a/Packages/com.github.asus4.tflite.common/Runtime/InterpreterExtension.cs +++ b/Packages/com.github.asus4.tflite.common/Runtime/InterpreterExtension.cs @@ -54,7 +54,7 @@ public static void LogIOInfo(this SignatureRunner runner) } sb.AppendLine(); - int signatureInputCount = runner.GetSignatureInputCount(); + int signatureInputCount = (int)runner.GetSignatureInputCount(); for (int i = 0; i < signatureInputCount; i++) { string name = runner.GetSignatureInputName(i); @@ -62,7 +62,7 @@ public static void LogIOInfo(this SignatureRunner runner) } sb.AppendLine(); - int signatureOutputCount = runner.GetSignatureOutputCount(); + int signatureOutputCount = (int)runner.GetSignatureOutputCount(); for (int i = 0; i < signatureOutputCount; i++) { string name = runner.GetSignatureOutputName(i); diff --git a/Packages/com.github.asus4.tflite/Runtime/SignatureRunner.cs b/Packages/com.github.asus4.tflite/Runtime/SignatureRunner.cs index e0f07cfb2..903cd5c54 100644 --- a/Packages/com.github.asus4.tflite/Runtime/SignatureRunner.cs +++ b/Packages/com.github.asus4.tflite/Runtime/SignatureRunner.cs @@ -73,7 +73,7 @@ public string GetSignatureKey(int index) return ToString(TfLiteInterpreterGetSignatureKey(InterpreterPointer, index)); } - public int GetSignatureInputCount() + public ulong GetSignatureInputCount() { return TfLiteSignatureRunnerGetInputCount(runner); } @@ -138,7 +138,7 @@ public override void Invoke() ThrowIfError(TfLiteSignatureRunnerInvoke(runner)); } - public int GetSignatureOutputCount() + public ulong GetSignatureOutputCount() { return TfLiteSignatureRunnerGetOutputCount(runner); } @@ -199,7 +199,7 @@ private void Initialize(string signatureName) private Dictionary CreateMap(bool isInput) { - int signatureCount = isInput ? GetSignatureInputCount() : GetSignatureOutputCount(); + int signatureCount = (int)(isInput ? GetSignatureInputCount() : GetSignatureOutputCount()); int tensorCount = isInput ? GetInputTensorCount() : GetOutputTensorCount(); Assert.AreEqual(signatureCount, tensorCount); @@ -270,7 +270,7 @@ private Dictionary CreateMap(bool isInput) private static extern TfLiteSignatureRunner TfLiteInterpreterGetSignatureRunner(TfLiteInterpreter interpreter, string signature_name); [DllImport(TensorFlowLibrary)] - private static extern int TfLiteSignatureRunnerGetInputCount(TfLiteSignatureRunner signature_runner); + private static extern ulong TfLiteSignatureRunnerGetInputCount(TfLiteSignatureRunner signature_runner); [DllImport(TensorFlowLibrary)] private static extern IntPtr TfLiteSignatureRunnerGetInputName(TfLiteSignatureRunner signature_runner, int input_index); @@ -290,7 +290,7 @@ private static extern Status TfLiteSignatureRunnerResizeInputTensor( private static extern Status TfLiteSignatureRunnerInvoke(TfLiteSignatureRunner signature_runner); [DllImport(TensorFlowLibrary)] - private static extern int TfLiteSignatureRunnerGetOutputCount(TfLiteSignatureRunner signature_runner); + private static extern ulong TfLiteSignatureRunnerGetOutputCount(TfLiteSignatureRunner signature_runner); [DllImport(TensorFlowLibrary)] private static extern IntPtr TfLiteSignatureRunnerGetOutputName(TfLiteSignatureRunner signature_runner, int output_index); From bd29e2e994c14ea604fb4abf19550688e8ebb69a Mon Sep 17 00:00:00 2001 From: Koki Ibukuro Date: Thu, 17 Oct 2024 10:07:35 +0200 Subject: [PATCH 2/2] Fix VideoClassification crashing --- Assets/Samples/VideoClassification/VideoClassification.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/Assets/Samples/VideoClassification/VideoClassification.cs b/Assets/Samples/VideoClassification/VideoClassification.cs index ad6969f24..1ce4f3c8f 100644 --- a/Assets/Samples/VideoClassification/VideoClassification.cs +++ b/Assets/Samples/VideoClassification/VideoClassification.cs @@ -57,6 +57,7 @@ public VideoClassification(Options options) try { runner = new SignatureRunner(SIGNATURE_KEY, FileUtil.LoadFile(options.modelPath), interpreterOptions); + runner.AllocateSignatureTensors(); } catch (Exception e) {