Skip to content

Commit

Permalink
Piped Mode: Flush data stream before closing StdIn
Browse files Browse the repository at this point in the history
  • Loading branch information
gerardog committed Mar 12, 2022
1 parent a3fe406 commit aad1281
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 95 deletions.
20 changes: 9 additions & 11 deletions src/gsudo/ProcessHosts/PipedProcessHost.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public async Task Start(Connection connection, ElevationRequest request)

var t1 = process.StandardOutput.ConsumeOutput((s) => WriteToPipe(s));
var t2 = process.StandardError.ConsumeOutput((s) => WriteToErrorPipe(s));
var t3 = new StreamReader(connection.DataStream, Settings.Encoding).ConsumeOutput((s) => WriteToProcessStdIn(s, process), CloseProcessStdIn);
var t3 = new StreamReader(connection.DataStream, Settings.Encoding).ConsumeOutput((s) => WriteToProcessStdIn(s, process));
var t4 = new StreamReader(connection.ControlStream, Settings.Encoding).ConsumeOutput((s) => HandleControl(s, process));

if (Settings.SecurityEnforceUacIsolation)
Expand Down Expand Up @@ -108,7 +108,7 @@ private async Task WriteToProcessStdIn(string s, Process process)
}

static readonly string[] TOKENS = new string[] { "\0", Constants.TOKEN_KEY_CTRLBREAK, Constants.TOKEN_KEY_CTRLC, Constants.TOKEN_EOF };
private Task HandleControl(string s, Process process)
private async Task HandleControl(string s, Process process)
{
var tokens = new Stack<string>(StringTokenizer.Split(s, TOKENS));

Expand All @@ -134,9 +134,14 @@ private Task HandleControl(string s, Process process)

if (token == Constants.TOKEN_EOF)
{
Logger.Instance.Log("Incoming StdIn EOF", LogLevel.Debug);
// Logger.Instance.Log("Incoming StdIn EOF", LogLevel.Debug);

// There is a race condition here. Lets ensure StdIn is depleted before closing sending EOF.
await _connection.FlushDataStream().ConfigureAwait(false);
await Task.Delay(1).ConfigureAwait(false);

bool done = false;
while (!done)
while (!done) // Loop until process.StandardInput is not used by other thread.
{
try
{
Expand All @@ -150,7 +155,6 @@ private Task HandleControl(string s, Process process)
continue;
}
}
return Task.CompletedTask;
}

private async Task WriteToErrorPipe(string s)
Expand Down Expand Up @@ -199,11 +203,5 @@ private static int EqualCharsCount(string s1, string s2)
{ }
return i;
}

private Task CloseProcessStdIn()
{
process.StandardInput.Close();
return Task.CompletedTask;
}
}
}
4 changes: 2 additions & 2 deletions src/gsudo/ProcessRenderers/PipedClientRenderer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.IO.Pipes;
using System.Linq;
using System.Threading.Tasks;

Expand Down Expand Up @@ -193,7 +192,8 @@ private async Task SendKeysToHost(string s)

private async Task CloseStdIn()
{
(_connection.DataStream as NamedPipeClientStream)?.WaitForPipeDrain();
// flush data before control command.
await _connection.FlushDataStream().ConfigureAwait(false);
await _connection.ControlStream.WriteAsync(Constants.TOKEN_EOF).ConfigureAwait(false);
}
}
Expand Down
85 changes: 85 additions & 0 deletions src/gsudo/Rpc/Connection.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
using System;
using System.IO;
using System.IO.Pipes;
using System.Runtime.Serialization.Formatters.Binary;
using System.Threading;
using System.Threading.Tasks;

namespace gsudo.Rpc
{
class Connection : IDisposable
{
private PipeStream _dataStream;
private PipeStream _controlStream;

public Connection(PipeStream ControlStream, PipeStream DataStream)
{
_dataStream = DataStream;
_controlStream = ControlStream;
}

public Stream DataStream => _dataStream;
public Stream ControlStream => _controlStream;

private ManualResetEvent DisconnectedResetEvent { get; } = new ManualResetEvent(false);
public WaitHandle DisconnectedWaitHandle => DisconnectedResetEvent;

public bool IsAlive { get; private set; } = true;
public void SignalDisconnected()
{
IsAlive = false;
DisconnectedResetEvent.Set();
}

public async Task FlushAndCloseAll()
{
IsAlive = false;
await FlushDataStream().ConfigureAwait(false);
await FlushControlStream().ConfigureAwait(false);
DataStream.Close();
ControlStream.Close();
}

public Task FlushDataStream() => Flush(_dataStream);
public Task FlushControlStream() => Flush(_controlStream);

private async Task Flush(PipeStream npStream)
{
try
{
await Task.Delay(1).ConfigureAwait(false);
await npStream.FlushAsync().ConfigureAwait(false);
npStream.WaitForPipeDrain();
await Task.Delay(1).ConfigureAwait(false);
}
catch (ObjectDisposedException) { }
catch (Exception) { }
}

public async Task WriteElevationRequest(ElevationRequest elevationRequest)
{
// Using Binary instead of Newtonsoft.JSON to reduce load times.
var ms = new System.IO.MemoryStream();
new BinaryFormatter()
{ TypeFormat = System.Runtime.Serialization.Formatters.FormatterTypeStyle.TypesAlways, Binder = new MySerializationBinder() }
.Serialize(ms, elevationRequest);
ms.Seek(0, System.IO.SeekOrigin.Begin);

byte[] lengthArray = BitConverter.GetBytes(ms.Length);
Logger.Instance.Log($"ElevationRequest length {ms.Length}", LogLevel.Debug);

await ControlStream.WriteAsync(lengthArray, 0, sizeof(int)).ConfigureAwait(false);
await ControlStream.WriteAsync(ms.ToArray(), 0, (int)ms.Length).ConfigureAwait(false);
await ControlStream.FlushAsync().ConfigureAwait(false);
}

public void Dispose()
{
DataStream?.Close();
DataStream?.Dispose();
ControlStream?.Close();
ControlStream?.Dispose();
IsAlive = false;
}
}
}
76 changes: 1 addition & 75 deletions src/gsudo/Rpc/IRpcClient.cs
Original file line number Diff line number Diff line change
@@ -1,83 +1,9 @@
using System;
using System.IO;
using System.IO.Pipes;
using System.Runtime.Serialization.Formatters.Binary;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks;

namespace gsudo.Rpc
{
internal interface IRpcClient
{
Task<Connection> Connect(int? clientPid, bool failFast);
}

class Connection : IDisposable
{
public Stream DataStream { get; set; }
public Stream ControlStream { get; set; }

private ManualResetEvent DisconnectedResetEvent { get; } = new ManualResetEvent(false);
public WaitHandle DisconnectedWaitHandle => DisconnectedResetEvent;

public bool IsAlive { get; private set; } = true;
public void SignalDisconnected()
{
IsAlive = false;
DisconnectedResetEvent.Set();
}

public async Task FlushAndCloseAll()
{
IsAlive = false;
await Flush(DataStream).ConfigureAwait(false);
await Flush(ControlStream).ConfigureAwait(false);
DataStream.Close();
ControlStream.Close();
}

private static async Task Flush(Stream DataStream)
{
if (DataStream is NamedPipeServerStream)
{
var npStream = DataStream as NamedPipeServerStream;
try
{
await Task.Delay(1).ConfigureAwait(false);
await npStream.FlushAsync().ConfigureAwait(false);
npStream.WaitForPipeDrain();
await Task.Delay(1).ConfigureAwait(false);
}
catch (Exception) { }
}
else
DataStream.Close();
}

public async Task WriteElevationRequest(ElevationRequest elevationRequest)
{
// Using Binary instead of Newtonsoft.JSON to reduce load times.
var ms = new System.IO.MemoryStream();
new BinaryFormatter()
{ TypeFormat = System.Runtime.Serialization.Formatters.FormatterTypeStyle.TypesAlways, Binder = new MySerializationBinder() }
.Serialize(ms, elevationRequest);
ms.Seek(0, System.IO.SeekOrigin.Begin);

byte[] lengthArray = BitConverter.GetBytes(ms.Length);
Logger.Instance.Log($"ElevationRequest length {ms.Length}", LogLevel.Debug);

await ControlStream.WriteAsync(lengthArray, 0, sizeof(int)).ConfigureAwait(false);
await ControlStream.WriteAsync(ms.ToArray(), 0, (int)ms.Length).ConfigureAwait(false);
await ControlStream.FlushAsync().ConfigureAwait(false);
}

public void Dispose()
{
DataStream?.Close();
DataStream?.Dispose();
ControlStream?.Close();
ControlStream?.Dispose();
IsAlive = false;
}
}
}
7 changes: 1 addition & 6 deletions src/gsudo/Rpc/NamedPipeClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,7 @@ public async Task<Connection> Connect(int? clientPid, bool failFast)

Logger.Instance.Log($"Connected via Named Pipe {pipeName}.", LogLevel.Debug);

var conn = new Connection()
{
ControlStream = controlPipe,
DataStream = dataPipe,
};

var conn = new Connection(controlPipe, dataPipe);
return conn;
}
catch (System.TimeoutException)
Expand Down
2 changes: 1 addition & 1 deletion src/gsudo/Rpc/NamedPipeServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public async Task Listen()

if (dataPipe.IsConnected && controlPipe.IsConnected && !_cancellationTokenSource.IsCancellationRequested)
{
var connection = new Connection() { ControlStream = controlPipe, DataStream = dataPipe };
var connection = new Connection(controlPipe, dataPipe);

ConnectionKeepAliveThread.Start(connection);

Expand Down

0 comments on commit aad1281

Please sign in to comment.