-
Notifications
You must be signed in to change notification settings - Fork 1
/
OnnxModelScorer.cs
93 lines (79 loc) · 2.86 KB
/
OnnxModelScorer.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
using Microsoft.ML;
using Microsoft.ML.Data;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Text;
using System.Threading.Tasks;
namespace godot_net_server
{
/// <summary>
/// Scores input feature vectors with an ONNX model via ML.NET's <c>ApplyOnnxModel</c>
/// transform and returns the values of the model output column <c>"31"</c> for one input row.
/// </summary>
/// <remarks>
/// The scoring pipeline is fitted lazily on the first call to a <c>Score</c> overload and the
/// resulting <see cref="ITransformer"/> is cached for later calls.
/// NOTE(review): the lazy initialization is not synchronized; confirm the scorer is only used
/// from one thread at a time, or add locking.
/// </remarks>
class OnnxModelScorer
{
    // Tensor names from the exported ONNX graph, used as ML.NET column names. These must
    // match the model file exactly; they are shared by the attributes and the pipeline below.
    private const string InputColumnName = "input.1";
    private const string ActionColumnName = "31";
    private const string StateColumnName = "34";

    private readonly string modelLocation;
    private readonly MLContext mlContext;

    // Fitted scoring pipeline; null until the first Score call loads it.
    private ITransformer model;

    /// <summary>Creates a scorer for the ONNX model file at <paramref name="modelLocation"/>.</summary>
    /// <param name="modelLocation">Path passed to <c>ApplyOnnxModel</c> as the model file.</param>
    /// <param name="mlContext">ML.NET context used to build data views and the pipeline.</param>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    public OnnxModelScorer(string modelLocation, MLContext mlContext)
    {
        if (modelLocation == null)
        {
            throw new ArgumentNullException(nameof(modelLocation));
        }
        if (mlContext == null)
        {
            throw new ArgumentNullException(nameof(mlContext));
        }
        this.modelLocation = modelLocation;
        this.mlContext = mlContext;
    }

    /// <summary>One input row for the model, mapped to the ONNX input tensor.</summary>
    public class ModelInput
    {
        // Declared as a fixed-size vector of 10 floats. NOTE(review): presumably ML.NET rejects
        // arrays of any other length at scoring time — confirm against the model's input shape.
        [VectorType(10)]
        [ColumnName(InputColumnName)]
        public float[] Features { get; set; }

        /// <summary>Copies <paramref name="input"/> into <see cref="Features"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
        public ModelInput(IEnumerable<float> input)
        {
            if (input == null)
            {
                throw new ArgumentNullException(nameof(input));
            }
            Features = input.ToArray();
        }

        /// <summary>Wraps <paramref name="input"/> in a single-element list of <see cref="ModelInput"/>.</summary>
        public static List<ModelInput> MakeInput(IEnumerable<float> input)
        {
            return new List<ModelInput> { new ModelInput(input) };
        }
    }

    /// <summary>Builds and fits the ONNX scoring pipeline for the model at <paramref name="modelLocation"/>.</summary>
    private ITransformer LoadModel(string modelLocation)
    {
        // An empty IDataView is enough to give Fit the input schema derived from ModelInput.
        var emptyData = mlContext.Data.LoadFromEnumerable(new List<ModelInput>());

        var pipeline = mlContext.Transforms.ApplyOnnxModel(
            modelFile: modelLocation,
            outputColumnNames: new[] { ActionColumnName, StateColumnName },
            inputColumnNames: new[] { InputColumnName });

        return pipeline.Fit(emptyData);
    }

    /// <summary>Typed view of both model outputs (column "31" as 6 floats, column "34" as 1 float).</summary>
    public class Prediction
    {
        [VectorType(6)]
        [ColumnName(ActionColumnName)]
        public float[] action { get; set; }

        [VectorType(1)]
        [ColumnName(StateColumnName)]
        public float[] state { get; set; }
    }

    /// <summary>Runs <paramref name="model"/> on <paramref name="testData"/> and returns output column "31" of the first row.</summary>
    /// <exception cref="InvalidOperationException">Scoring produced no rows.</exception>
    private IEnumerable<float> PredictDataUsingModel(IDataView testData, ITransformer model)
    {
        IDataView scoredData = model.Transform(testData);

        // Only the first row is needed; FirstOrDefault stops the cursor after one row instead of
        // materializing every scored row.
        float[] firstRow = scoredData.GetColumn<float[]>(ActionColumnName).FirstOrDefault();
        if (firstRow == null)
        {
            throw new InvalidOperationException(
                "Scoring produced no rows for output column '" + ActionColumnName + "'.");
        }
        return firstRow;
    }

    /// <summary>Scores a single <see cref="ModelInput"/> row.</summary>
    /// <returns>The values of output column "31" for that row.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
    public IEnumerable<float> Score(ModelInput input)
    {
        if (input == null)
        {
            throw new ArgumentNullException(nameof(input));
        }
        var data = mlContext.Data.LoadFromEnumerable(new[] { input });
        return PredictDataUsingModel(data, GetOrLoadModel());
    }

    /// <summary>Scores a raw feature sequence by wrapping it in a <see cref="ModelInput"/>.</summary>
    /// <returns>The values of output column "31" for that row.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
    public IEnumerable<float> Score(IEnumerable<float> input)
    {
        return Score(new ModelInput(input));
    }

    /// <summary>Returns the cached fitted pipeline, loading it on first use.</summary>
    private ITransformer GetOrLoadModel()
    {
        if (model == null)
        {
            model = LoadModel(modelLocation);
        }
        return model;
    }
}
}