Skip to content

Commit

Permalink
Refactored and added pixelConfidence parameter for Segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
NickSwardh committed Oct 13, 2024
1 parent f599d49 commit 9e2e1e1
Show file tree
Hide file tree
Showing 19 changed files with 87 additions and 133 deletions.
13 changes: 8 additions & 5 deletions YoloDotNet/Data/YoloCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ public IDisposableReadOnlyCollection<OrtValue> Run(SKImage image)
public Dictionary<int, List<T>> RunVideo<T>(
VideoOptions options,
double confidence,
double pixelConfidence,
double iouThreshold,
Func<SKImage, double, double, List<T>> func) where T : class, new()
Func<SKImage, double, double, double, List<T>> func) where T : class, new()
{
var output = new Dictionary<int, List<T>>();

Expand All @@ -116,7 +117,7 @@ public Dictionary<int, List<T>> RunVideo<T>(
_videoHandler.StatusChangeEvent += (sender, e) => VideoStatusEvent?.Invoke(sender, e);
_videoHandler.FramesExtractedEvent += (sender, e) =>
{
output = RunBatchInferenceOnVideoFrames<T>(_videoHandler, confidence, iouThreshold, func);
output = RunBatchInferenceOnVideoFrames<T>(_videoHandler, confidence, pixelConfidence, iouThreshold, func);

if (options.GenerateVideo)
_videoHandler.ProcessVideoPipeline(VideoAction.CompileFrames);
Expand All @@ -134,8 +135,10 @@ public Dictionary<int, List<T>> RunVideo<T>(
/// </summary>
private Dictionary<int, List<T>> RunBatchInferenceOnVideoFrames<T>(
VideoHandler.VideoHandler _videoHandler,
double confidence, double iouThreshold,
Func<SKImage, double, double, List<T>> func) where T : class, new()
double confidence,
double pixelConfidence,
double iouThreshold,
Func<SKImage, double, double, double, List<T>> func) where T : class, new()
{
var frames = _videoHandler.GetExtractedFrames();
int progressCounter = 0;
Expand All @@ -151,7 +154,7 @@ private Dictionary<int, List<T>> RunBatchInferenceOnVideoFrames<T>(
using var img = SKImage.FromEncodedData(frame);


var results = func.Invoke(img, confidence, iouThreshold);
var results = func.Invoke(img, confidence, pixelConfidence, iouThreshold);
batch[i] = results;

if (shouldDrawLabelsOnKeptFrames || shouldDrawLabelsOnVideoFrames)
Expand Down
4 changes: 2 additions & 2 deletions YoloDotNet/Modules/Interfaces/IClassificationModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
public interface IClassificationModule : IModule
{
List<Classification> ProcessImage(SKImage image, double classes, double iou);
Dictionary<int, List<Classification>> ProcessVideo(VideoOptions options, double confidence, double iou);
List<Classification> ProcessImage(SKImage image, double classes, double pixelConfidence, double iou);
Dictionary<int, List<Classification>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou);
}
}
4 changes: 2 additions & 2 deletions YoloDotNet/Modules/Interfaces/IOBBDetectionModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
internal interface IOBBDetectionModule : IModule
{
List<OBBDetection> ProcessImage(SKImage image, double confidence, double iou);
Dictionary<int, List<OBBDetection>> ProcessVideo(VideoOptions options, double confidence, double iou);
List<OBBDetection> ProcessImage(SKImage image, double confidence, double pixelConfidence,double iou);
Dictionary<int, List<OBBDetection>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence,double iou);
}
}
4 changes: 2 additions & 2 deletions YoloDotNet/Modules/Interfaces/IObjectDetectionModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
public interface IObjectDetectionModule : IModule
{
List<ObjectDetection> ProcessImage(SKImage image, double confidence, double iou);
Dictionary<int, List<ObjectDetection>> ProcessVideo(VideoOptions options, double confidence, double iou);
List<ObjectDetection> ProcessImage(SKImage image, double confidence, double pixelConfidence,double iou);
Dictionary<int, List<ObjectDetection>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou);
}
}
4 changes: 2 additions & 2 deletions YoloDotNet/Modules/Interfaces/IPoseEstimationModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
internal interface IPoseEstimationModule : IModule
{
List<PoseEstimation> ProcessImage(SKImage image, double confidence, double iou);
Dictionary<int, List<PoseEstimation>> ProcessVideo(VideoOptions options, double confidence, double iou);
List<PoseEstimation> ProcessImage(SKImage image, double confidence, double pixelConfidence, double iou);
Dictionary<int, List<PoseEstimation>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou);
}
}
4 changes: 2 additions & 2 deletions YoloDotNet/Modules/Interfaces/ISegmentationModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
{
public interface ISegmentationModule : IModule
{
List<Segmentation> ProcessImage(SKImage image, double confidence, double iou);
Dictionary<int, List<Segmentation>> ProcessVideo(VideoOptions options, double confidence, double iou);
List<Segmentation> ProcessImage(SKImage image, double confidence, double pixelConfidence, double iou);
Dictionary<int, List<Segmentation>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou);
}
}
6 changes: 3 additions & 3 deletions YoloDotNet/Modules/V10/ObjectDetectionModuleV10.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public ObjectDetectionModuleV10(YoloCore yoloCore)
SubscribeToVideoEvents();
}

public List<ObjectDetection> ProcessImage(SKImage image, double confidence, double iou)
public List<ObjectDetection> ProcessImage(SKImage image, double confidence, double pixelConfidence, double iou)
{
using var ortValues = _yoloCore!.Run(image);
using var ort = ortValues[0];
Expand All @@ -26,8 +26,8 @@ public List<ObjectDetection> ProcessImage(SKImage image, double confidence, doub
return [.. results];
}

public Dictionary<int, List<ObjectDetection>> ProcessVideo(VideoOptions options, double confidence, double iou)
=> _yoloCore.RunVideo(options, confidence, iou, ProcessImage);
public Dictionary<int, List<ObjectDetection>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou)
=> _yoloCore.RunVideo(options, confidence, pixelConfidence, iou, ProcessImage);

#region Helper methods

Expand Down
14 changes: 4 additions & 10 deletions YoloDotNet/Modules/V11/ClassificationModuleV11.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,13 @@ public ClassificationModuleV11(YoloCore yoloCore)
SubscribeToVideoEvents();
}

public List<Classification> ProcessImage(SKImage image, double classes, double iou)
{
using var ortValues = _yoloCore.Run(image);
using var ort = ortValues[0];

return _classificationModuleV8.ProcessImage(image, classes, iou);
}
public List<Classification> ProcessImage(SKImage image, double classes, double pixelConfidence,double iou)
=> _classificationModuleV8.ProcessImage(image, classes, pixelConfidence, iou);

public Dictionary<int, List<Classification>> ProcessVideo(VideoOptions options, double confidence, double iou)
=> _yoloCore.RunVideo(options, confidence, iou, ProcessImage);
public Dictionary<int, List<Classification>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou)
=> _yoloCore.RunVideo(options, confidence, pixelConfidence, iou, ProcessImage);

#region Helper methods

private void SubscribeToVideoEvents()
{
_yoloCore.VideoProgressEvent += (sender, e) => VideoProgressEvent?.Invoke(sender, e);
Expand Down
21 changes: 7 additions & 14 deletions YoloDotNet/Modules/V11/OBBDetectionModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ internal class OBBDetectionModuleV11 : IOBBDetectionModule
public event EventHandler VideoCompleteEvent = delegate { };

private readonly YoloCore _yoloCore;
private readonly ObjectDetectionModuleV8 _objectDetectionModuleV8 = default!;
private readonly OBBDetectionModuleV8 _obbDetectionModuleV8 = default!;

public OnnxModel OnnxModel => _yoloCore.OnnxModel;

Expand All @@ -17,23 +17,16 @@ public OBBDetectionModuleV11(YoloCore yoloCore)

// Yolov11 has the same model input/output shapes as Yolov8,
// so the Yolov8 module is reused here.
_objectDetectionModuleV8 = new ObjectDetectionModuleV8(_yoloCore);
_obbDetectionModuleV8 = new OBBDetectionModuleV8(_yoloCore);

SubscribeToVideoEvents();
}

public List<OBBDetection> ProcessImage(SKImage image, double confidence, double iou)
{
using IDisposableReadOnlyCollection<OrtValue>? ortValues = _yoloCore.Run(image);
var ortSpan = ortValues[0].GetTensorDataAsSpan<float>();

var objectDetectionResults = _objectDetectionModuleV8.ObjectDetection(image, ortSpan, confidence, iou);

return [.. objectDetectionResults.Select(x => (OBBDetection)x)];
}
public List<OBBDetection> ProcessImage(SKImage image, double confidence, double pixelConfidence, double iou)
=> _obbDetectionModuleV8.ProcessImage(image, confidence, pixelConfidence, iou);

public Dictionary<int, List<OBBDetection>> ProcessVideo(VideoOptions options, double confidence, double iou)
=> _yoloCore.RunVideo(options, confidence, iou, ProcessImage);
public Dictionary<int, List<OBBDetection>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou)
=> _yoloCore.RunVideo(options, confidence, pixelConfidence, iou, ProcessImage);

#region Helper methods

Expand All @@ -50,7 +43,7 @@ public void Dispose()
_yoloCore.VideoCompleteEvent -= (sender, e) => VideoCompleteEvent?.Invoke(sender, e);
_yoloCore.VideoStatusEvent -= (sender, e) => VideoStatusEvent?.Invoke(sender, e);

_objectDetectionModuleV8?.Dispose();
_obbDetectionModuleV8?.Dispose();
_yoloCore?.Dispose();

GC.SuppressFinalize(this);
Expand Down
16 changes: 4 additions & 12 deletions YoloDotNet/Modules/V11/ObjectDetectionModuleV11.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,11 @@ public ObjectDetectionModuleV11(YoloCore yoloCore)
SubscribeToVideoEvents();
}

public List<ObjectDetection> ProcessImage(SKImage image, double confidence, double iou)
{
using var ortValues = _yoloCore.Run(image);
var ortSpan = ortValues[0].GetTensorDataAsSpan<float>();

var results = _objectDetectionModuleV8.ObjectDetection(image, ortSpan, confidence, iou)
.Select(x => (ObjectDetection)x);

return [..results];
}
public List<ObjectDetection> ProcessImage(SKImage image, double confidence, double pixelConfidence, double iou)
=> _objectDetectionModuleV8.ProcessImage(image, confidence, pixelConfidence, iou);

public Dictionary<int, List<ObjectDetection>> ProcessVideo(VideoOptions options, double confidence, double iou)
=> _yoloCore.RunVideo(options, confidence, iou, ProcessImage);
public Dictionary<int, List<ObjectDetection>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou)
=> _yoloCore.RunVideo(options, confidence, pixelConfidence, iou, ProcessImage);

private void SubscribeToVideoEvents()
{
Expand Down
13 changes: 4 additions & 9 deletions YoloDotNet/Modules/V11/PoseEstimationModuleV11.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,11 @@ public PoseEstimationModuleV11(YoloCore yoloCore)
SubscribeToVideoEvents();
}

public List<PoseEstimation> ProcessImage(SKImage image, double confidence, double iou)
{
using IDisposableReadOnlyCollection<OrtValue>? ortValues = _yoloCore.Run(image);
var ortSpan = ortValues[0].GetTensorDataAsSpan<float>();

return _poseEstimationModuleV8.ProcessImage(image, confidence, iou);
}
public List<PoseEstimation> ProcessImage(SKImage image, double confidence, double pixelConfidence, double iou)
=> _poseEstimationModuleV8.ProcessImage(image, confidence, pixelConfidence, iou);

public Dictionary<int, List<PoseEstimation>> ProcessVideo(VideoOptions options, double confidence, double iou)
=> _yoloCore.RunVideo(options, confidence, iou, ProcessImage);
public Dictionary<int, List<PoseEstimation>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou)
=> _yoloCore.RunVideo(options, confidence, pixelConfidence, iou, ProcessImage);

#region Helper methods

Expand Down
12 changes: 4 additions & 8 deletions YoloDotNet/Modules/V11/SegmentationModuleV8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,11 @@ public SegmentationModuleV11(YoloCore yoloCore)
SubscribeToVideoEvents();
}

public List<Segmentation> ProcessImage(SKImage image, double confidence, double iou)
{
using IDisposableReadOnlyCollection<OrtValue>? ortValues = _yoloCore.Run(image);

return _segmentationModuleV8.ProcessImage(image, confidence, iou);
}
public List<Segmentation> ProcessImage(SKImage image, double confidence, double pixelConfidence, double iou)
=> _segmentationModuleV8.ProcessImage(image, confidence, pixelConfidence, iou);

public Dictionary<int, List<Segmentation>> ProcessVideo(VideoOptions options, double confidence, double iou)
=> _yoloCore.RunVideo(options, confidence, iou, ProcessImage);
public Dictionary<int, List<Segmentation>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou)
=> _yoloCore.RunVideo(options, confidence, pixelConfidence, iou, ProcessImage);

#region Helper methods

Expand Down
35 changes: 12 additions & 23 deletions YoloDotNet/Modules/V8/ClassificationModuleV8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ public ClassificationModuleV8(YoloCore yoloCore)
SubscribeToVideoEvents();
}

public List<Classification> ProcessImage(SKImage image, double classes, double iou)
public List<Classification> ProcessImage(SKImage image, double classes, double pixelConfidence, double iou)
{
using var ortValues = _yoloCore.Run(image);
using var ort = ortValues[0];
return ClassifyTensor(ort, (int)classes);
}

public Dictionary<int, List<Classification>> ProcessVideo(VideoOptions options, double confidence, double iou)
=> _yoloCore.RunVideo(options, confidence, iou, ProcessImage);
public Dictionary<int, List<Classification>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou)
=> _yoloCore.RunVideo(options, confidence, pixelConfidence, iou, ProcessImage);

#region Classification

Expand All @@ -38,30 +38,19 @@ public Dictionary<int, List<Classification>> ProcessVideo(VideoOptions options,
/// <param name="numberOfClasses">Number of highest-confidence classifications to return.</param>
private List<Classification> ClassifyTensor(OrtValue ortTensor, int numberOfClasses)
{
var span = ortTensor.GetTensorMutableDataAsSpan<float>();
var len = span.Length;
var span = ortTensor.GetTensorDataAsSpan<float>();
var tmp = new Classification[span.Length];

var tmp = _classificationPool.Rent(len);

try
for (int i = 0; i < tmp.Length; i++)
{
for (int i = 0; i < len; i++)
tmp[i] = new Classification
{
tmp[i] = new Classification
{
Confidence = span[i],
Label = _yoloCore.OnnxModel.Labels[i].Name
};
}

// Use Array.Sort() instead of LINQ for performance
Array.Sort(tmp[.. len], (a, b) => b.Confidence.CompareTo(a.Confidence));
return [.. tmp[..numberOfClasses]];
}
finally
{
_classificationPool.Return(tmp, true);
Confidence = span[i],
Label = _yoloCore.OnnxModel.Labels[i].Name
};
}

return [.. tmp.OrderByDescending(x => x.Confidence).Take(numberOfClasses)];
}

#endregion
Expand Down
6 changes: 3 additions & 3 deletions YoloDotNet/Modules/V8/OBBDetectionModuleV8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public OBBDetectionModuleV8(YoloCore yoloCore)
SubscribeToVideoEvents();
}

public List<OBBDetection> ProcessImage(SKImage image, double confidence, double iou)
public List<OBBDetection> ProcessImage(SKImage image, double confidence, double pixelConfidence, double iou)
{
using IDisposableReadOnlyCollection<OrtValue>? ortValues = _yoloCore.Run(image);
var ortSpan = ortValues[0].GetTensorDataAsSpan<float>();
Expand All @@ -28,8 +28,8 @@ public List<OBBDetection> ProcessImage(SKImage image, double confidence, double
return [.. objectDetectionResults.Select(x => (OBBDetection)x)];
}

public Dictionary<int, List<OBBDetection>> ProcessVideo(VideoOptions options, double confidence, double iou)
=> _yoloCore.RunVideo(options, confidence, iou, ProcessImage);
public Dictionary<int, List<OBBDetection>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou)
=> _yoloCore.RunVideo(options, confidence, pixelConfidence, iou, ProcessImage);

#region Helper methods

Expand Down
6 changes: 3 additions & 3 deletions YoloDotNet/Modules/V8/ObjectDetectionModuleV8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public ObjectDetectionModuleV8(YoloCore yoloCore)
SubscribeToVideoEvents();
}

public List<ObjectDetection> ProcessImage(SKImage image, double confidence, double iou)
public List<ObjectDetection> ProcessImage(SKImage image, double confidence, double pixelConfidence, double iou)
{
using var ortValues = _yoloCore.Run(image);
var ortSpan = ortValues[0].GetTensorDataAsSpan<float>();
Expand All @@ -42,8 +42,8 @@ public List<ObjectDetection> ProcessImage(SKImage image, double confidence, doub
return [..results];
}

public Dictionary<int, List<ObjectDetection>> ProcessVideo(VideoOptions options, double confidence, double iou)
=> _yoloCore.RunVideo(options, confidence, iou, ProcessImage);
public Dictionary<int, List<ObjectDetection>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou)
=> _yoloCore.RunVideo(options, confidence, pixelConfidence, iou, ProcessImage);

#region Helper methods

Expand Down
6 changes: 3 additions & 3 deletions YoloDotNet/Modules/V8/PoseEstimationModuleV8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ public PoseEstimationModuleV8(YoloCore yoloCore)
SubscribeToVideoEvents();
}

public List<PoseEstimation> ProcessImage(SKImage image, double confidence, double iou)
public List<PoseEstimation> ProcessImage(SKImage image, double confidence, double pixelConfidence, double iou)
{
using IDisposableReadOnlyCollection<OrtValue>? ortValues = _yoloCore.Run(image);
var ortSpan = ortValues[0].GetTensorDataAsSpan<float>(); ;

return PoseEstimateImage(image, ortSpan, confidence, iou);
}

public Dictionary<int, List<PoseEstimation>> ProcessVideo(VideoOptions options, double confidence, double iou)
=> _yoloCore.RunVideo(options, confidence, iou, ProcessImage);
public Dictionary<int, List<PoseEstimation>> ProcessVideo(VideoOptions options, double confidence, double pixelConfidence, double iou)
=> _yoloCore.RunVideo(options, confidence, pixelConfidence, iou, ProcessImage);

#region Helper methods

Expand Down
Loading

0 comments on commit 9e2e1e1

Please sign in to comment.