Skip to content

Commit a7a6d88

Browse files
Adds in a way to add settings for the MLContext. (#7273)
* api, no tests * updates from pr * fixed rebase errors and pr comments * updates based on ONNX team
1 parent 869dc9f commit a7a6d88

File tree

7 files changed

+347
-19
lines changed

7 files changed

+347
-19
lines changed

src/Microsoft.ML.Core/Data/IHostEnvironment.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
67
using Microsoft.ML.Data;
78

89
namespace Microsoft.ML.Runtime;
@@ -104,6 +105,16 @@ internal interface IHostEnvironmentInternal : IHostEnvironment
104105
/// GPU device ID to run execution on, <see langword="null" /> to run on CPU.
105106
/// </summary>
106107
int? GpuDeviceId { get; set; }
108+
109+
bool TryAddOption<T>(string name, T value);
110+
111+
void SetOption<T>(string name, T value);
112+
113+
bool TryGetOption<T>(string name, out T value);
114+
115+
T GetOptionOrDefault<T>(string name);
116+
117+
bool RemoveOption(string name);
107118
}
108119

109120
/// <summary>

src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,10 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
334334

335335
public bool FallbackToCpu { get; set; }
336336

337+
#pragma warning disable MSML_NoInstanceInitializers // Need this to have a default value.
338+
protected Dictionary<string, object> Options { get; } = [];
339+
#pragma warning restore MSML_NoInstanceInitializers
340+
337341
protected readonly TEnv Root;
338342
// This is non-null iff this environment was a fork of another. Disposing a fork
339343
// doesn't free temp files. That is handled when the master is disposed.
@@ -567,4 +571,91 @@ public virtual void PrintMessageNormalized(TextWriter writer, string message, bo
567571
else if (!removeLastNewLine)
568572
writer.WriteLine();
569573
}
574+
575+
/// <summary>
576+
/// Trys to add a new runtime option.
577+
/// </summary>
578+
/// <typeparam name="T"></typeparam>
579+
/// <param name="name">Name of the option to add.</param>
580+
/// <param name="value">Value to set.</param>
581+
/// <returns><see langword="true"/> if successful. <see langword="false"/> otherwise.</returns>
582+
/// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
583+
public bool TryAddOption<T>(string name, T value)
584+
{
585+
if (string.IsNullOrWhiteSpace(name))
586+
throw new ArgumentNullException(nameof(name));
587+
588+
if (Options.ContainsKey(name))
589+
return false;
590+
SetOption(name, value);
591+
return true;
592+
}
593+
594+
/// <summary>
595+
/// Adds or Sets the <paramref name="value"/> with the given <paramref name="name"/>. Is cast to <typeparamref name="T"/>.
596+
/// </summary>
597+
/// <typeparam name="T"></typeparam>
598+
/// <param name="name">Name of the option to set.</param>
599+
/// <param name="value">Value to set.</param>
600+
/// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
601+
public void SetOption<T>(string name, T value)
602+
{
603+
if (string.IsNullOrWhiteSpace(name))
604+
throw new ArgumentNullException(nameof(name));
605+
Options[name] = value;
606+
}
607+
608+
/// <summary>
609+
/// Gets an option by <paramref name="name"/> and returns <see langword="true"/> if that has been added and <see langword="false"/> otherwise.
610+
/// </summary>
611+
/// <typeparam name="T"></typeparam>
612+
/// <param name="name">Name of the option to get.</param>
613+
/// <param name="value">Options value of type <typeparamref name="T"/>.</param>
614+
/// <returns><see langword="true"/> if the option was able to be retrieved, else <see langword="false"/></returns>
615+
/// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
616+
public bool TryGetOption<T>(string name, out T value)
617+
{
618+
if (string.IsNullOrWhiteSpace(name))
619+
throw new ArgumentNullException(nameof(name));
620+
621+
if (!Options.TryGetValue(name, out var val) || val is not T)
622+
{
623+
value = default;
624+
return false;
625+
}
626+
value = (T)val;
627+
return true;
628+
}
629+
630+
/// <summary>
631+
/// Gets either the option stored by that <paramref name="name"/>, or adds the default value of <typeparamref name="T"/> with that <paramref name="name"/> and returns it.
632+
/// </summary>
633+
/// <typeparam name="T"></typeparam>
634+
/// <param name="name">Name of the option to get.</param>
635+
/// <returns>Options value of type <typeparamref name="T"/>.</returns>
636+
/// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
637+
public T GetOptionOrDefault<T>(string name)
638+
{
639+
if (string.IsNullOrWhiteSpace(name))
640+
throw new ArgumentNullException(nameof(name));
641+
642+
if (!Options.TryGetValue(name, out object value))
643+
SetOption<T>(name, default);
644+
else
645+
return (T)value;
646+
return (T)Options[name];
647+
}
648+
649+
/// <summary>
650+
/// Removes an option.
651+
/// </summary>
652+
/// <param name="name">Name of the option to remove.</param>
653+
/// <returns><see langword="true"/> if successfully removed, else <see langword="false"/>.</returns>
654+
/// <exception cref="ArgumentNullException">When <paramref name="name"/> is null or empty.</exception>
655+
public bool RemoveOption(string name)
656+
{
657+
if (string.IsNullOrWhiteSpace(name))
658+
throw new ArgumentNullException(nameof(name));
659+
return Options.Remove(name);
660+
}
570661
}

src/Microsoft.ML.Data/MLContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,5 +191,11 @@ private static bool InitializeOneDalDispatchingEnabled()
191191
return false;
192192
}
193193
}
194+
195+
public bool TryAddOption<T>(string name, T value) => _env.TryAddOption(name, value);
196+
public void SetOption<T>(string name, T value) => _env.SetOption(name, value);
197+
public bool TryGetOption<T>(string name, out T value) => _env.TryGetOption<T>(name, out value);
198+
public T GetOptionOrDefault<T>(string name) => _env.GetOptionOrDefault<T>(name);
199+
public bool RemoveOption(string name) => _env.RemoveOption(name);
194200
}
195201
}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
using Microsoft.ML.Data;
9+
using Microsoft.ML.OnnxRuntime;
10+
using Microsoft.ML.Runtime;
11+
using Microsoft.ML.Transforms.Onnx;
12+
13+
namespace Microsoft.ML.Transforms.Onnx
14+
{
15+
public static class OnnxSessionOptionsExtensions
16+
{
17+
private const string OnnxSessionOptionsName = "OnnxSessionOptions";
18+
19+
public static OnnxSessionOptions GetOnnxSessionOption(this IHostEnvironment env)
20+
{
21+
if (env is IHostEnvironmentInternal localEnvironment)
22+
{
23+
return localEnvironment.GetOptionOrDefault<OnnxSessionOptions>(OnnxSessionOptionsName);
24+
}
25+
26+
throw new ArgumentException("No Onnx Session Options");
27+
}
28+
29+
public static void SetOnnxSessionOption(this IHostEnvironment env, OnnxSessionOptions onnxSessionOptions)
30+
{
31+
if (env is IHostEnvironmentInternal localEnvironment)
32+
{
33+
localEnvironment.SetOption(OnnxSessionOptionsName, onnxSessionOptions);
34+
}
35+
else
36+
throw new ArgumentException("No Onnx Session Options");
37+
}
38+
}
39+
40+
public sealed class OnnxSessionOptions
41+
{
42+
internal void CopyTo(SessionOptions sessionOptions)
43+
{
44+
sessionOptions.EnableMemoryPattern = EnableMemoryPattern;
45+
sessionOptions.ProfileOutputPathPrefix = ProfileOutputPathPrefix;
46+
sessionOptions.EnableProfiling = EnableProfiling;
47+
sessionOptions.OptimizedModelFilePath = OptimizedModelFilePath;
48+
sessionOptions.EnableCpuMemArena = EnableCpuMemArena;
49+
if (!PerSessionThreads)
50+
sessionOptions.DisablePerSessionThreads();
51+
sessionOptions.LogId = LogId;
52+
sessionOptions.LogSeverityLevel = LogSeverityLevel;
53+
sessionOptions.LogVerbosityLevel = LogVerbosityLevel;
54+
sessionOptions.InterOpNumThreads = InterOpNumThreads;
55+
sessionOptions.IntraOpNumThreads = IntraOpNumThreads;
56+
sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel;
57+
sessionOptions.ExecutionMode = ExecutionMode;
58+
}
59+
60+
/// <summary>
61+
/// Enables the use of the memory allocation patterns in the first Run() call for subsequent runs. Default = true.
62+
/// </summary>
63+
#pragma warning disable MSML_NoInstanceInitializers // No initializers on instance fields or properties
64+
public bool EnableMemoryPattern { get; set; } = true;
65+
66+
/// <summary>
67+
/// Path prefix to use for output of profiling data
68+
/// </summary>
69+
public string ProfileOutputPathPrefix { get; set; } = "onnxruntime_profile_"; // this is the same default in C++ implementation
70+
71+
/// <summary>
72+
/// Enables profiling of InferenceSession.Run() calls. Default is false
73+
/// </summary>
74+
public bool EnableProfiling { get; set; } = false;
75+
76+
/// <summary>
77+
/// Set filepath to save optimized model after graph level transformations. Default is empty, which implies saving is disabled.
78+
/// </summary>
79+
public string OptimizedModelFilePath { get; set; } = string.Empty;
80+
81+
/// <summary>
82+
/// Enables Arena allocator for the CPU memory allocations. Default is true.
83+
/// </summary>
84+
public bool EnableCpuMemArena { get; set; } = true;
85+
86+
/// <summary>
87+
/// Per session threads. Default is true.
88+
/// If false this makes all sessions in the process use a global TP.
89+
/// </summary>
90+
public bool PerSessionThreads { get; set; } = true;
91+
92+
/// <summary>
93+
/// Sets the number of threads used to parallelize the execution within nodes
94+
/// A value of 0 means ORT will pick a default. Only used when <see cref="PerSessionThreads"/> is false.
95+
/// </summary>
96+
public int GlobalIntraOpNumThreads { get; set; } = 0;
97+
98+
/// <summary>
99+
/// Sets the number of threads used to parallelize the execution of the graph (across nodes)
100+
/// If sequential execution is enabled this value is ignored
101+
/// A value of 0 means ORT will pick a default. Only used when <see cref="PerSessionThreads"/> is false.
102+
/// </summary>
103+
public int GlobalInterOpNumThreads { get; set; } = 0;
104+
105+
/// <summary>
106+
/// Log Id to be used for the session. Default is empty string.
107+
/// </summary>
108+
public string LogId { get; set; } = string.Empty;
109+
110+
/// <summary>
111+
/// Log Severity Level for the session logs. Default = ORT_LOGGING_LEVEL_WARNING
112+
/// </summary>
113+
public OrtLoggingLevel LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
114+
115+
/// <summary>
116+
/// Log Verbosity Level for the session logs. Default = 0. Valid values are >=0.
117+
/// This takes into effect only when the LogSeverityLevel is set to ORT_LOGGING_LEVEL_VERBOSE.
118+
/// </summary>
119+
public int LogVerbosityLevel { get; set; } = 0;
120+
121+
/// <summary>
122+
/// Sets the number of threads used to parallelize the execution within nodes
123+
/// A value of 0 means ORT will pick a default
124+
/// </summary>
125+
public int IntraOpNumThreads { get; set; } = 0;
126+
127+
/// <summary>
128+
/// Sets the number of threads used to parallelize the execution of the graph (across nodes)
129+
/// If sequential execution is enabled this value is ignored
130+
/// A value of 0 means ORT will pick a default
131+
/// </summary>
132+
public int InterOpNumThreads { get; set; } = 0;
133+
134+
/// <summary>
135+
/// Sets the graph optimization level for the session. Default is set to ORT_ENABLE_ALL.
136+
/// </summary>
137+
public GraphOptimizationLevel GraphOptimizationLevel { get; set; } = GraphOptimizationLevel.ORT_ENABLE_ALL;
138+
139+
/// <summary>
140+
/// Sets the execution mode for the session. Default is set to ORT_SEQUENTIAL.
141+
/// See [ONNX_Runtime_Perf_Tuning.md] for more details.
142+
/// </summary>
143+
public ExecutionMode ExecutionMode { get; set; } = ExecutionMode.ORT_SEQUENTIAL;
144+
#pragma warning restore MSML_NoInstanceInitializers // No initializers on instance fields or properties
145+
146+
public delegate SessionOptions CreateOnnxSessionOptions();
147+
148+
public CreateOnnxSessionOptions CreateSessionOptions { get; set; }
149+
}
150+
}

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
262262
Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
263263
Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile);
264264
// Because we cannot delete the user file, ownModelFile should be false.
265-
Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false, shapeDictionary: shapeDictionary, options.RecursionLimit,
265+
Model = new OnnxModel(env, options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false, shapeDictionary: shapeDictionary, options.RecursionLimit,
266266
options.InterOpNumThreads, options.IntraOpNumThreads);
267267
}
268268
else

0 commit comments

Comments
 (0)