diff --git a/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml b/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml
index 64af657..ef20cb0 100644
--- a/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml
+++ b/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml
@@ -2,6 +2,7 @@
+
\ No newline at end of file
diff --git a/App/Models/RpcModel.cs b/App/Models/RpcModel.cs
index 034f405..08d2303 100644
--- a/App/Models/RpcModel.cs
+++ b/App/Models/RpcModel.cs
@@ -1,4 +1,7 @@
+using System;
using System.Collections.Generic;
+using System.Diagnostics;
+using Coder.Desktop.App.Converters;
using Coder.Desktop.Vpn.Proto;
namespace Coder.Desktop.App.Models;
@@ -19,11 +22,168 @@ public enum VpnLifecycle
Stopping,
}
+public enum VpnStartupStage
+{
+ Unknown,
+ Initializing,
+ Downloading,
+ Finalizing,
+}
+
+public class VpnDownloadProgress
+{
+ public ulong BytesWritten { get; set; } = 0;
+ public ulong? BytesTotal { get; set; } = null; // null means unknown total size
+
+ public double Progress
+ {
+ get
+ {
+ if (BytesTotal is > 0)
+ {
+ return (double)BytesWritten / BytesTotal.Value;
+ }
+ return 0.0;
+ }
+ }
+
+ public override string ToString()
+ {
+ // TODO: it would be nice if the two suffixes could match
+ var s = FriendlyByteConverter.FriendlyBytes(BytesWritten);
+ if (BytesTotal != null)
+ s += $" of {FriendlyByteConverter.FriendlyBytes(BytesTotal.Value)}";
+ else
+ s += " of unknown";
+ if (BytesTotal != null)
+ s += $" ({Progress:0%})";
+ return s;
+ }
+
+ public VpnDownloadProgress Clone()
+ {
+ return new VpnDownloadProgress
+ {
+ BytesWritten = BytesWritten,
+ BytesTotal = BytesTotal,
+ };
+ }
+
+ public static VpnDownloadProgress FromProto(StartProgressDownloadProgress proto)
+ {
+ return new VpnDownloadProgress
+ {
+ BytesWritten = proto.BytesWritten,
+ BytesTotal = proto.HasBytesTotal ? proto.BytesTotal : null,
+ };
+ }
+}
+
+public class VpnStartupProgress
+{
+ public const string DefaultStartProgressMessage = "Starting Coder Connect...";
+
+ // Scale the download progress to an overall progress value between these
+ // numbers.
+ private const double DownloadProgressMin = 0.05;
+ private const double DownloadProgressMax = 0.80;
+
+ public VpnStartupStage Stage { get; init; } = VpnStartupStage.Unknown;
+ public VpnDownloadProgress? DownloadProgress { get; init; } = null;
+
+ // 0.0 to 1.0
+ public double Progress
+ {
+ get
+ {
+ switch (Stage)
+ {
+ case VpnStartupStage.Unknown:
+ case VpnStartupStage.Initializing:
+ return 0.0;
+ case VpnStartupStage.Downloading:
+ var progress = DownloadProgress?.Progress ?? 0.0;
+ return DownloadProgressMin + (DownloadProgressMax - DownloadProgressMin) * progress;
+ case VpnStartupStage.Finalizing:
+ return DownloadProgressMax;
+ default:
+ throw new ArgumentOutOfRangeException();
+ }
+ }
+ }
+
+ public override string ToString()
+ {
+ switch (Stage)
+ {
+ case VpnStartupStage.Unknown:
+ case VpnStartupStage.Initializing:
+ return DefaultStartProgressMessage;
+ case VpnStartupStage.Downloading:
+ var s = "Downloading Coder Connect binary...";
+ if (DownloadProgress is not null)
+ {
+ s += "\n" + DownloadProgress;
+ }
+
+ return s;
+ case VpnStartupStage.Finalizing:
+ return "Finalizing Coder Connect startup...";
+ default:
+ throw new ArgumentOutOfRangeException();
+ }
+ }
+
+ public VpnStartupProgress Clone()
+ {
+ return new VpnStartupProgress
+ {
+ Stage = Stage,
+ DownloadProgress = DownloadProgress?.Clone(),
+ };
+ }
+
+ public static VpnStartupProgress FromProto(StartProgress proto)
+ {
+ return new VpnStartupProgress
+ {
+ Stage = proto.Stage switch
+ {
+ StartProgressStage.Initializing => VpnStartupStage.Initializing,
+ StartProgressStage.Downloading => VpnStartupStage.Downloading,
+ StartProgressStage.Finalizing => VpnStartupStage.Finalizing,
+ _ => VpnStartupStage.Unknown,
+ },
+ DownloadProgress = proto.Stage is StartProgressStage.Downloading ?
+ VpnDownloadProgress.FromProto(proto.DownloadProgress) :
+ null,
+ };
+ }
+}
+
public class RpcModel
{
public RpcLifecycle RpcLifecycle { get; set; } = RpcLifecycle.Disconnected;
- public VpnLifecycle VpnLifecycle { get; set; } = VpnLifecycle.Unknown;
+ public VpnLifecycle VpnLifecycle
+ {
+ get;
+ set
+ {
+ if (VpnLifecycle != value && value == VpnLifecycle.Starting)
+ // Reset the startup progress when the VPN lifecycle changes to
+ // Starting.
+ VpnStartupProgress = null;
+ field = value;
+ }
+ }
+
+ // Nullable because it is only set when the VpnLifecycle is Starting
+ public VpnStartupProgress? VpnStartupProgress
+ {
+ get => VpnLifecycle is VpnLifecycle.Starting ? field ?? new VpnStartupProgress() : null;
+ set;
+ }
public IReadOnlyList Workspaces { get; set; } = [];
@@ -35,6 +195,7 @@ public RpcModel Clone()
{
RpcLifecycle = RpcLifecycle,
VpnLifecycle = VpnLifecycle,
+ VpnStartupProgress = VpnStartupProgress?.Clone(),
Workspaces = Workspaces,
Agents = Agents,
};
diff --git a/App/Services/RpcController.cs b/App/Services/RpcController.cs
index 7beff66..168a1be 100644
--- a/App/Services/RpcController.cs
+++ b/App/Services/RpcController.cs
@@ -161,7 +161,10 @@ public async Task StartVpn(CancellationToken ct = default)
throw new RpcOperationException(
$"Cannot start VPN without valid credentials, current state: {credentials.State}");
- MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; });
+ MutateState(state =>
+ {
+ state.VpnLifecycle = VpnLifecycle.Starting;
+ });
ServiceMessage reply;
try
@@ -283,15 +286,28 @@ private void ApplyStatusUpdate(Status status)
});
}
+ private void ApplyStartProgressUpdate(StartProgress message)
+ {
+ MutateState(state =>
+ {
+ // The model itself will ignore this value if we're not in the
+ // starting state.
+ state.VpnStartupProgress = VpnStartupProgress.FromProto(message);
+ });
+ }
+
private void SpeakerOnReceive(ReplyableRpcMessage message)
{
switch (message.Message.MsgCase)
{
+ case ServiceMessage.MsgOneofCase.Start:
+ case ServiceMessage.MsgOneofCase.Stop:
case ServiceMessage.MsgOneofCase.Status:
ApplyStatusUpdate(message.Message.Status);
break;
- case ServiceMessage.MsgOneofCase.Start:
- case ServiceMessage.MsgOneofCase.Stop:
+ case ServiceMessage.MsgOneofCase.StartProgress:
+ ApplyStartProgressUpdate(message.Message.StartProgress);
+ break;
case ServiceMessage.MsgOneofCase.None:
default:
// TODO: log unexpected message
diff --git a/App/ViewModels/TrayWindowViewModel.cs b/App/ViewModels/TrayWindowViewModel.cs
index d8b3182..820ff12 100644
--- a/App/ViewModels/TrayWindowViewModel.cs
+++ b/App/ViewModels/TrayWindowViewModel.cs
@@ -29,7 +29,6 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost
{
private const int MaxAgents = 5;
private const string DefaultDashboardUrl = "https://coder.com";
- private const string DefaultHostnameSuffix = ".coder";
private readonly IServiceProvider _services;
private readonly IRpcController _rpcController;
@@ -53,6 +52,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(ShowEnableSection))]
+ [NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))]
[NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))]
[NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))]
[NotifyPropertyChangedFor(nameof(ShowAgentsSection))]
@@ -63,6 +63,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(ShowEnableSection))]
+ [NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))]
[NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))]
[NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))]
[NotifyPropertyChangedFor(nameof(ShowAgentsSection))]
@@ -70,7 +71,25 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost
[NotifyPropertyChangedFor(nameof(ShowFailedSection))]
public partial string? VpnFailedMessage { get; set; } = null;
- public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Started;
+ [ObservableProperty]
+ [NotifyPropertyChangedFor(nameof(VpnStartProgressIsIndeterminate))]
+ [NotifyPropertyChangedFor(nameof(VpnStartProgressValueOrDefault))]
+ public partial int? VpnStartProgressValue { get; set; } = null;
+
+ public int VpnStartProgressValueOrDefault => VpnStartProgressValue ?? 0;
+
+ [ObservableProperty]
+ [NotifyPropertyChangedFor(nameof(VpnStartProgressMessageOrDefault))]
+ public partial string? VpnStartProgressMessage { get; set; } = null;
+
+ public string VpnStartProgressMessageOrDefault =>
+ string.IsNullOrEmpty(VpnStartProgressMessage) ? VpnStartupProgress.DefaultStartProgressMessage : VpnStartProgressMessage;
+
+ public bool VpnStartProgressIsIndeterminate => VpnStartProgressValueOrDefault == 0;
+
+ public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Starting and not VpnLifecycle.Started;
+
+ public bool ShowVpnStartProgressSection => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Starting;
public bool ShowWorkspacesHeader => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Started;
@@ -170,6 +189,20 @@ private void UpdateFromRpcModel(RpcModel rpcModel)
VpnLifecycle = rpcModel.VpnLifecycle;
VpnSwitchActive = rpcModel.VpnLifecycle is VpnLifecycle.Starting or VpnLifecycle.Started;
+ // VpnStartupProgress is only set when the VPN is starting.
+ if (rpcModel.VpnLifecycle is VpnLifecycle.Starting && rpcModel.VpnStartupProgress != null)
+ {
+ // Convert 0.00-1.00 to 0-100.
+ var progress = (int)(rpcModel.VpnStartupProgress.Progress * 100);
+ VpnStartProgressValue = Math.Clamp(progress, 0, 100);
+ VpnStartProgressMessage = rpcModel.VpnStartupProgress.ToString();
+ }
+ else
+ {
+ VpnStartProgressValue = null;
+ VpnStartProgressMessage = null;
+ }
+
// Add every known agent.
HashSet workspacesWithAgents = [];
List agents = [];
diff --git a/App/Views/Pages/TrayWindowLoginRequiredPage.xaml b/App/Views/Pages/TrayWindowLoginRequiredPage.xaml
index c1d69aa..171e292 100644
--- a/App/Views/Pages/TrayWindowLoginRequiredPage.xaml
+++ b/App/Views/Pages/TrayWindowLoginRequiredPage.xaml
@@ -36,7 +36,7 @@
diff --git a/App/Views/Pages/TrayWindowMainPage.xaml b/App/Views/Pages/TrayWindowMainPage.xaml
index 283867d..f488454 100644
--- a/App/Views/Pages/TrayWindowMainPage.xaml
+++ b/App/Views/Pages/TrayWindowMainPage.xaml
@@ -43,6 +43,8 @@
+
+
+ HorizontalContentAlignment="Left">
diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs
index 986ce46..bb9b39c 100644
--- a/Tests.Vpn.Service/DownloaderTest.cs
+++ b/Tests.Vpn.Service/DownloaderTest.cs
@@ -277,8 +277,8 @@ public async Task Download(CancellationToken ct)
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
NullDownloadValidator.Instance, ct);
await dlTask.Task;
- Assert.That(dlTask.TotalBytes, Is.EqualTo(4));
- Assert.That(dlTask.BytesRead, Is.EqualTo(4));
+ Assert.That(dlTask.BytesTotal, Is.EqualTo(4));
+ Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
Assert.That(dlTask.Progress, Is.EqualTo(1));
Assert.That(dlTask.IsCompleted, Is.True);
Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
@@ -300,18 +300,62 @@ public async Task DownloadSameDest(CancellationToken ct)
NullDownloadValidator.Instance, ct);
var dlTask0 = await startTask0;
await dlTask0.Task;
- Assert.That(dlTask0.TotalBytes, Is.EqualTo(5));
- Assert.That(dlTask0.BytesRead, Is.EqualTo(5));
+ Assert.That(dlTask0.BytesTotal, Is.EqualTo(5));
+ Assert.That(dlTask0.BytesWritten, Is.EqualTo(5));
Assert.That(dlTask0.Progress, Is.EqualTo(1));
Assert.That(dlTask0.IsCompleted, Is.True);
var dlTask1 = await startTask1;
await dlTask1.Task;
- Assert.That(dlTask1.TotalBytes, Is.EqualTo(5));
- Assert.That(dlTask1.BytesRead, Is.EqualTo(5));
+ Assert.That(dlTask1.BytesTotal, Is.EqualTo(5));
+ Assert.That(dlTask1.BytesWritten, Is.EqualTo(5));
Assert.That(dlTask1.Progress, Is.EqualTo(1));
Assert.That(dlTask1.IsCompleted, Is.True);
}
+ [Test(Description = "Download with X-Original-Content-Length")]
+ [CancelAfter(30_000)]
+ public async Task DownloadWithXOriginalContentLength(CancellationToken ct)
+ {
+ using var httpServer = new TestHttpServer(async ctx =>
+ {
+ ctx.Response.StatusCode = 200;
+ ctx.Response.Headers.Add("X-Original-Content-Length", "4");
+ ctx.Response.ContentType = "text/plain";
+ // Don't set Content-Length.
+ await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct);
+ });
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+ var manager = new Downloader(NullLogger.Instance);
+ var req = new HttpRequestMessage(HttpMethod.Get, url);
+ var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct);
+
+ await dlTask.Task;
+ Assert.That(dlTask.BytesTotal, Is.EqualTo(4));
+ Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
+ }
+
+ [Test(Description = "Download with mismatched Content-Length")]
+ [CancelAfter(30_000)]
+ public async Task DownloadWithMismatchedContentLength(CancellationToken ct)
+ {
+ using var httpServer = new TestHttpServer(async ctx =>
+ {
+ ctx.Response.StatusCode = 200;
+ ctx.Response.Headers.Add("X-Original-Content-Length", "5"); // incorrect
+ ctx.Response.ContentType = "text/plain";
+ await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct);
+ });
+ var url = new Uri(httpServer.BaseUrl + "/test");
+ var destPath = Path.Combine(_tempDir, "test");
+ var manager = new Downloader(NullLogger.Instance);
+ var req = new HttpRequestMessage(HttpMethod.Get, url);
+ var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct);
+
+ var ex = Assert.ThrowsAsync(() => dlTask.Task);
+ Assert.That(ex.Message, Is.EqualTo("Downloaded file size does not match expected response content length: Expected=5, BytesWritten=4"));
+ }
+
[Test(Description = "Download with custom headers")]
[CancelAfter(30_000)]
public async Task WithHeaders(CancellationToken ct)
@@ -347,7 +391,7 @@ public async Task DownloadExisting(CancellationToken ct)
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
NullDownloadValidator.Instance, ct);
await dlTask.Task;
- Assert.That(dlTask.BytesRead, Is.Zero);
+ Assert.That(dlTask.BytesWritten, Is.Zero);
Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
Assert.That(File.GetLastWriteTime(destPath), Is.LessThan(DateTime.Now - TimeSpan.FromDays(1)));
}
@@ -368,7 +412,7 @@ public async Task DownloadExistingDifferentContent(CancellationToken ct)
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
NullDownloadValidator.Instance, ct);
await dlTask.Task;
- Assert.That(dlTask.BytesRead, Is.EqualTo(4));
+ Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
Assert.That(File.GetLastWriteTime(destPath), Is.GreaterThan(DateTime.Now - TimeSpan.FromDays(1)));
}
diff --git a/Vpn.Proto/vpn.proto b/Vpn.Proto/vpn.proto
index 2561a4b..bace7e0 100644
--- a/Vpn.Proto/vpn.proto
+++ b/Vpn.Proto/vpn.proto
@@ -60,7 +60,8 @@ message ServiceMessage {
oneof msg {
StartResponse start = 2;
StopResponse stop = 3;
- Status status = 4; // either in reply to a StatusRequest or broadcasted
+ Status status = 4; // either in reply to a StatusRequest or broadcasted
+ StartProgress start_progress = 5; // broadcasted during startup
}
}
@@ -218,6 +219,28 @@ message StartResponse {
string error_message = 2;
}
+// StartProgress is sent from the manager to the client to indicate the
+// download/startup progress of the tunnel. This will be sent during the
+// processing of a StartRequest before the StartResponse is sent.
+//
+// Note: this is currently a broadcasted message to all clients due to the
+// inability to easily send messages to a specific client in the Speaker
+// implementation. If clients are not expecting these messages, they
+// should ignore them.
+enum StartProgressStage {
+ Initializing = 0;
+ Downloading = 1;
+ Finalizing = 2;
+}
+message StartProgressDownloadProgress {
+ uint64 bytes_written = 1;
+ optional uint64 bytes_total = 2; // unknown in some situations
+}
+message StartProgress {
+ StartProgressStage stage = 1;
+ optional StartProgressDownloadProgress download_progress = 2; // only set when stage == Downloading
+}
+
// StopRequest is a request from the manager to stop the tunnel. The tunnel replies with a
// StopResponse.
message StopRequest {}
diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs
index 6a3108b..c4a916f 100644
--- a/Vpn.Service/Downloader.cs
+++ b/Vpn.Service/Downloader.cs
@@ -339,31 +339,35 @@ internal static async Task TaskOrCancellation(Task task, CancellationToken cance
}
///
-/// Downloads an Url to a file on disk. The download will be written to a temporary file first, then moved to the final
+/// Downloads a Url to a file on disk. The download will be written to a temporary file first, then moved to the final
/// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if
/// it hasn't changed.
///
public class DownloadTask
{
- private const int BufferSize = 4096;
+ private const int BufferSize = 64 * 1024;
+ private const string XOriginalContentLengthHeader = "X-Original-Content-Length"; // overrides Content-Length if available
- private static readonly HttpClient HttpClient = new();
+ private static readonly HttpClient HttpClient = new(new HttpClientHandler
+ {
+ AutomaticDecompression = DecompressionMethods.All,
+ });
private readonly string _destinationDirectory;
private readonly ILogger _logger;
private readonly RaiiSemaphoreSlim _semaphore = new(1, 1);
private readonly IDownloadValidator _validator;
- public readonly string DestinationPath;
+ private readonly string _destinationPath;
+ private readonly string _tempDestinationPath;
public readonly HttpRequestMessage Request;
- public readonly string TempDestinationPath;
- public ulong? TotalBytes { get; private set; }
- public ulong BytesRead { get; private set; }
public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync
-
- public double? Progress => TotalBytes == null ? null : (double)BytesRead / TotalBytes.Value;
+ public bool DownloadStarted { get; private set; } // Whether we've received headers yet and started the actual download
+ public ulong BytesWritten { get; private set; }
+ public ulong? BytesTotal { get; private set; }
+ public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value;
public bool IsCompleted => Task.IsCompleted;
internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator)
@@ -374,17 +378,17 @@ internal DownloadTask(ILogger logger, HttpRequestMessage req, string destination
if (string.IsNullOrWhiteSpace(destinationPath))
throw new ArgumentException("Destination path must not be empty", nameof(destinationPath));
- DestinationPath = Path.GetFullPath(destinationPath);
- if (Path.EndsInDirectorySeparator(DestinationPath))
- throw new ArgumentException($"Destination path '{DestinationPath}' must not end in a directory separator",
+ _destinationPath = Path.GetFullPath(destinationPath);
+ if (Path.EndsInDirectorySeparator(_destinationPath))
+ throw new ArgumentException($"Destination path '{_destinationPath}' must not end in a directory separator",
nameof(destinationPath));
- _destinationDirectory = Path.GetDirectoryName(DestinationPath)
+ _destinationDirectory = Path.GetDirectoryName(_destinationPath)
?? throw new ArgumentException(
- $"Destination path '{DestinationPath}' must have a parent directory",
+ $"Destination path '{_destinationPath}' must have a parent directory",
nameof(destinationPath));
- TempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(DestinationPath) +
+ _tempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(_destinationPath) +
".download-" + Path.GetRandomFileName());
}
@@ -406,9 +410,9 @@ private async Task Start(CancellationToken ct = default)
// If the destination path exists, generate a Coder SHA1 ETag and send
// it in the If-None-Match header to the server.
- if (File.Exists(DestinationPath))
+ if (File.Exists(_destinationPath))
{
- await using var stream = File.OpenRead(DestinationPath);
+ await using var stream = File.OpenRead(_destinationPath);
var etag = Convert.ToHexString(await SHA1.HashDataAsync(stream, ct)).ToLower();
Request.Headers.Add("If-None-Match", "\"" + etag + "\"");
}
@@ -419,11 +423,11 @@ private async Task Start(CancellationToken ct = default)
_logger.LogInformation("File has not been modified, skipping download");
try
{
- await _validator.ValidateAsync(DestinationPath, ct);
+ await _validator.ValidateAsync(_destinationPath, ct);
}
catch (Exception e)
{
- _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", DestinationPath);
+ _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", _destinationPath);
throw new Exception("Existing file failed validation after 304 Not Modified", e);
}
@@ -446,24 +450,38 @@ private async Task Start(CancellationToken ct = default)
}
if (res.Content.Headers.ContentLength >= 0)
- TotalBytes = (ulong)res.Content.Headers.ContentLength;
+ BytesTotal = (ulong)res.Content.Headers.ContentLength;
+
+ // X-Original-Content-Length overrules Content-Length if set.
+ if (res.Headers.TryGetValues(XOriginalContentLengthHeader, out var headerValues))
+ {
+ // If there are multiple we only look at the first one.
+ var headerValue = headerValues.ToList().FirstOrDefault();
+ if (!string.IsNullOrEmpty(headerValue) && ulong.TryParse(headerValue, out var originalContentLength))
+ BytesTotal = originalContentLength;
+ else
+ _logger.LogWarning(
+ "Failed to parse {XOriginalContentLengthHeader} header value '{HeaderValue}'",
+ XOriginalContentLengthHeader, headerValue);
+ }
await Download(res, ct);
}
private async Task Download(HttpResponseMessage res, CancellationToken ct)
{
+ DownloadStarted = true;
try
{
var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null;
FileStream tempFile;
try
{
- tempFile = File.Create(TempDestinationPath, BufferSize, FileOptions.SequentialScan);
+ tempFile = File.Create(_tempDestinationPath, BufferSize, FileOptions.SequentialScan);
}
catch (Exception e)
{
- _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath);
+ _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", _tempDestinationPath);
throw;
}
@@ -476,13 +494,14 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct)
{
await tempFile.WriteAsync(buffer.AsMemory(0, n), ct);
sha1?.TransformBlock(buffer, 0, n, null, 0);
- BytesRead += (ulong)n;
+ BytesWritten += (ulong)n;
}
}
- if (TotalBytes != null && BytesRead != TotalBytes)
+ BytesTotal ??= BytesWritten;
+ if (BytesWritten != BytesTotal)
throw new IOException(
- $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesRead}");
+ $"Downloaded file size does not match expected response content length: Expected={BytesTotal}, BytesWritten={BytesWritten}");
// Verify the ETag if it was sent by the server.
if (res.Headers.Contains("ETag") && sha1 != null)
@@ -497,26 +516,34 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct)
try
{
- await _validator.ValidateAsync(TempDestinationPath, ct);
+ await _validator.ValidateAsync(_tempDestinationPath, ct);
}
catch (Exception e)
{
_logger.LogWarning(e, "Downloaded file '{TempDestinationPath}' failed custom validation",
- TempDestinationPath);
+ _tempDestinationPath);
throw new HttpRequestException("Downloaded file failed validation", e);
}
- File.Move(TempDestinationPath, DestinationPath, true);
+ File.Move(_tempDestinationPath, _destinationPath, true);
}
- finally
+ catch
{
#if DEBUG
_logger.LogWarning("Not deleting temporary file '{TempDestinationPath}' in debug mode",
- TempDestinationPath);
+ _tempDestinationPath);
#else
- if (File.Exists(TempDestinationPath))
- File.Delete(TempDestinationPath);
+ try
+ {
+ if (File.Exists(_tempDestinationPath))
+ File.Delete(_tempDestinationPath);
+ }
+ catch (Exception e)
+ {
+ _logger.LogError(e, "Failed to delete temporary file '{TempDestinationPath}'", _tempDestinationPath);
+ }
#endif
+ throw;
}
}
}
diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs
index fc014c0..fdb62af 100644
--- a/Vpn.Service/Manager.cs
+++ b/Vpn.Service/Manager.cs
@@ -131,6 +131,8 @@ private async ValueTask HandleClientMessageStart(ClientMessage me
{
try
{
+ await BroadcastStartProgress(StartProgressStage.Initializing, cancellationToken: ct);
+
var serverVersion =
await CheckServerVersionAndCredentials(message.Start.CoderUrl, message.Start.ApiToken, ct);
if (_status == TunnelStatus.Started && _lastStartRequest != null &&
@@ -151,10 +153,14 @@ private async ValueTask HandleClientMessageStart(ClientMessage me
_lastServerVersion = serverVersion;
// TODO: each section of this operation needs a timeout
+
// Stop the tunnel if it's running so we don't have to worry about
// permissions issues when replacing the binary.
await _tunnelSupervisor.StopAsync(ct);
+
await DownloadTunnelBinaryAsync(message.Start.CoderUrl, serverVersion.SemVersion, ct);
+
+ await BroadcastStartProgress(StartProgressStage.Finalizing, cancellationToken: ct);
await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage,
HandleTunnelRpcError,
ct);
@@ -237,6 +243,9 @@ private void HandleTunnelRpcMessage(ReplyableRpcMessage CurrentStatus(CancellationToken ct = default)
private async Task BroadcastStatus(TunnelStatus? newStatus = null, CancellationToken ct = default)
{
if (newStatus != null) _status = newStatus.Value;
- await _managerRpc.BroadcastAsync(new ServiceMessage
+ await FallibleBroadcast(new ServiceMessage
{
Status = await CurrentStatus(ct),
}, ct);
}
+ private async Task FallibleBroadcast(ServiceMessage message, CancellationToken ct = default)
+ {
+ // Broadcast the messages out with a low timeout. If clients don't
+ // receive broadcasts in time, it's not a big deal.
+ using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
+ cts.CancelAfter(TimeSpan.FromMilliseconds(30));
+ try
+ {
+ await _managerRpc.BroadcastAsync(message, cts.Token);
+ }
+ catch (Exception ex)
+ {
+ _logger.LogWarning(ex, "Could not broadcast low priority message to all RPC clients: {Message}", message);
+ }
+ }
+
private void HandleTunnelRpcError(Exception e)
{
_logger.LogError(e, "Manager<->Tunnel RPC error");
@@ -425,12 +450,61 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected
_logger.LogDebug("Skipping tunnel binary version validation");
}
+ // Note: all ETag, signature and version validation is performed by the
+ // DownloadTask.
var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct);
- // TODO: monitor and report progress when we have a mechanism to do so
+ // Wait for the download to complete, sending progress updates every
+ // 50ms.
+ while (true)
+ {
+ // Wait for the download to complete, or for a short delay before
+ // we send a progress update.
+ var delayTask = Task.Delay(TimeSpan.FromMilliseconds(50), ct);
+ var winner = await Task.WhenAny([
+ downloadTask.Task,
+ delayTask,
+ ]);
+ if (winner == downloadTask.Task)
+ break;
+
+ // Task.WhenAny will not throw if the winner was cancelled, so
+ // check CT afterward and not beforehand.
+ ct.ThrowIfCancellationRequested();
+
+ if (!downloadTask.DownloadStarted)
+ // Don't send progress updates if we don't know what the
+ // progress is yet.
+ continue;
+
+ var progress = new StartProgressDownloadProgress
+ {
+ BytesWritten = downloadTask.BytesWritten,
+ };
+ if (downloadTask.BytesTotal != null)
+ progress.BytesTotal = downloadTask.BytesTotal.Value;
- // Awaiting this will check the checksum (via the ETag) if the file
- // exists, and will also validate the signature and version.
+ await BroadcastStartProgress(StartProgressStage.Downloading, progress, ct);
+ }
+
+ // Await again to re-throw any exceptions that occurred during the
+ // download.
await downloadTask.Task;
+
+ // We don't send a broadcast here as we immediately send one in the
+ // parent routine.
+ _logger.LogInformation("Completed downloading VPN binary");
+ }
+
+ private async Task BroadcastStartProgress(StartProgressStage stage, StartProgressDownloadProgress? downloadProgress = null, CancellationToken cancellationToken = default)
+ {
+ await FallibleBroadcast(new ServiceMessage
+ {
+ StartProgress = new StartProgress
+ {
+ Stage = stage,
+ DownloadProgress = downloadProgress,
+ },
+ }, cancellationToken);
}
}
diff --git a/Vpn.Service/ManagerRpc.cs b/Vpn.Service/ManagerRpc.cs
index c23752f..4920570 100644
--- a/Vpn.Service/ManagerRpc.cs
+++ b/Vpn.Service/ManagerRpc.cs
@@ -127,14 +127,20 @@ public async Task ExecuteAsync(CancellationToken stoppingToken)
public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct)
{
+ // Sends messages to all clients simultaneously and waits for them all
+ // to send or fail/timeout.
+ //
// Looping over a ConcurrentDictionary is exception-safe, but any items
// added or removed during the loop may or may not be included.
- foreach (var (clientId, client) in _activeClients)
+ await Task.WhenAll(_activeClients.Select(async item =>
+ {
try
{
- var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
- cts.CancelAfter(5 * 1000);
- await client.Speaker.SendMessage(message, cts.Token);
+ // Enforce upper bound in case a CT with a timeout wasn't
+ // supplied.
+ using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
+ cts.CancelAfter(TimeSpan.FromSeconds(2));
+ await item.Value.Speaker.SendMessage(message, cts.Token);
}
catch (ObjectDisposedException)
{
@@ -142,11 +148,12 @@ public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct)
}
catch (Exception e)
{
- _logger.LogWarning(e, "Failed to send message to client {ClientId}", clientId);
+ _logger.LogWarning(e, "Failed to send message to client {ClientId}", item.Key);
// TODO: this should probably kill the client, but due to the
// async nature of the client handling, calling Dispose
// will not remove the client from the active clients list
}
+ }));
}
private async Task HandleRpcClientAsync(ulong clientId, Speaker speaker,
diff --git a/Vpn.Service/Program.cs b/Vpn.Service/Program.cs
index fc61247..094875d 100644
--- a/Vpn.Service/Program.cs
+++ b/Vpn.Service/Program.cs
@@ -16,10 +16,12 @@ public static class Program
#if !DEBUG
private const string ServiceName = "Coder Desktop";
private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\VpnService";
+ private const string DefaultLogLevel = "Information";
#else
// This value matches Create-Service.ps1.
private const string ServiceName = "Coder Desktop (Debug)";
private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\DebugVpnService";
+ private const string DefaultLogLevel = "Debug";
#endif
private const string ManagerConfigSection = "Manager";
@@ -81,6 +83,10 @@ private static async Task BuildAndRun(string[] args)
builder.Services.AddSingleton();
// Services
+ builder.Services.AddHostedService();
+ builder.Services.AddHostedService();
+
+ // Either run as a Windows service or a console application
if (!Environment.UserInteractive)
{
MainLogger.Information("Running as a windows service");
@@ -91,9 +97,6 @@ private static async Task BuildAndRun(string[] args)
MainLogger.Information("Running as a console application");
}
- builder.Services.AddHostedService();
- builder.Services.AddHostedService();
-
var host = builder.Build();
Log.Logger = (ILogger)host.Services.GetService(typeof(ILogger))!;
MainLogger.Information("Application is starting");
@@ -108,7 +111,7 @@ private static void AddDefaultConfig(IConfigurationBuilder builder)
["Serilog:Using:0"] = "Serilog.Sinks.File",
["Serilog:Using:1"] = "Serilog.Sinks.Console",
- ["Serilog:MinimumLevel"] = "Information",
+ ["Serilog:MinimumLevel"] = DefaultLogLevel,
["Serilog:Enrich:0"] = "FromLogContext",
["Serilog:WriteTo:0:Name"] = "File",
diff --git a/Vpn.Service/TunnelSupervisor.cs b/Vpn.Service/TunnelSupervisor.cs
index a323cac..7dd6738 100644
--- a/Vpn.Service/TunnelSupervisor.cs
+++ b/Vpn.Service/TunnelSupervisor.cs
@@ -99,18 +99,16 @@ public async Task StartAsync(string binPath,
},
};
// TODO: maybe we should change the log format in the inner binary
- // to something without a timestamp
- var outLogger = Log.ForContext("SourceContext", "coder-vpn.exe[OUT]");
- var errLogger = Log.ForContext("SourceContext", "coder-vpn.exe[ERR]");
+ // to something without a timestamp
_subprocess.OutputDataReceived += (_, args) =>
{
if (!string.IsNullOrWhiteSpace(args.Data))
- outLogger.Debug("{Data}", args.Data);
+ _logger.LogInformation("stdout: {Data}", args.Data);
};
_subprocess.ErrorDataReceived += (_, args) =>
{
if (!string.IsNullOrWhiteSpace(args.Data))
- errLogger.Debug("{Data}", args.Data);
+ _logger.LogInformation("stderr: {Data}", args.Data);
};
// Pass the other end of the pipes to the subprocess and dispose
diff --git a/Vpn/Speaker.cs b/Vpn/Speaker.cs
index d113a50..37ec554 100644
--- a/Vpn/Speaker.cs
+++ b/Vpn/Speaker.cs
@@ -123,7 +123,7 @@ public async Task StartAsync(CancellationToken ct = default)
// Handshakes should always finish quickly, so enforce a 5s timeout.
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
cts.CancelAfter(TimeSpan.FromSeconds(5));
- await PerformHandshake(ct);
+ await PerformHandshake(cts.Token);
// Start ReceiveLoop in the background.
_receiveTask = ReceiveLoop(_cts.Token);