-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathCredentialManager.cs
440 lines (375 loc) · 15.7 KB
/
CredentialManager.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
using System;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using Coder.Desktop.App.Models;
using Coder.Desktop.CoderSdk;
using Coder.Desktop.CoderSdk.Coder;
using Coder.Desktop.Vpn.Utilities;
namespace Coder.Desktop.App.Services;
/// <summary>
/// Credentials exactly as they are persisted by a credential backend, before
/// they have been verified against a Coder server.
/// </summary>
public class RawCredentials
{
    /// <summary>
    /// URL of the Coder deployment these credentials belong to.
    /// </summary>
    public required string CoderUrl { get; set; }

    /// <summary>
    /// Session token presented to the Coder deployment at <c>CoderUrl</c>.
    /// </summary>
    public required string ApiToken { get; set; }
}
/// <summary>
/// Source-generated System.Text.Json serialization context for
/// <c>RawCredentials</c>; the Windows credential backend uses it to
/// serialize and deserialize the stored credential blob.
/// </summary>
[JsonSerializable(typeof(RawCredentials))]
public partial class RawCredentialsJsonContext : JsonSerializerContext
{
}
/// <summary>
/// Owns the app's Coder credentials: loading them from storage, verifying
/// them, caching the result and notifying listeners of changes.
/// </summary>
public interface ICredentialManager
{
    /// <summary>
    /// Raised with the new credential model whenever the cached credentials change.
    /// </summary>
    public event EventHandler<CredentialModel> CredentialsChanged;

    /// <summary>
    /// Returns cached credentials, or a model whose State is
    /// <c>CredentialState.Unknown</c> if nothing has been cached yet. It's
    /// preferable to use LoadCredentials if you are operating in an async context.
    /// </summary>
    public CredentialModel GetCachedCredentials();

    /// <summary>
    /// Get any sign-in URL (the Coder URL from stored credentials), or null if
    /// none is available. The returned value is not parsed to check if it's a
    /// valid URI.
    /// </summary>
    public Task<string?> GetSignInUri();

    /// <summary>
    /// Returns cached credentials or loads/verifies them from storage if not cached.
    /// </summary>
    /// <param name="ct">Cancels waiting for the load.</param>
    public Task<CredentialModel> LoadCredentials(CancellationToken ct = default);

    /// <summary>
    /// Validates the given Coder URL and API token, verifies them against the
    /// server, stores them and updates the cached credentials.
    /// </summary>
    /// <param name="coderUrl">Root URL of the Coder deployment (http or https).</param>
    /// <param name="apiToken">API/session token for that deployment.</param>
    /// <param name="ct">Cancels the operation.</param>
    /// <exception cref="ArgumentException">The URL or token is missing or malformed.</exception>
    public Task SetCredentials(string coderUrl, string apiToken, CancellationToken ct = default);

    /// <summary>
    /// Deletes any stored credentials and marks the cached credentials as invalid.
    /// </summary>
    /// <param name="ct">Cancels the operation.</param>
    public Task ClearCredentials(CancellationToken ct = default);
}
/// <summary>
/// Persistent storage for raw, unverified credentials.
/// </summary>
public interface ICredentialBackend
{
    /// <summary>
    /// Reads the stored credentials; the result is null when no usable
    /// credentials are stored.
    /// </summary>
    Task<RawCredentials?> ReadCredentials(CancellationToken ct = default);

    /// <summary>
    /// Persists the given credentials, replacing any previously stored ones.
    /// </summary>
    Task WriteCredentials(RawCredentials credentials, CancellationToken ct = default);

    /// <summary>
    /// Removes any stored credentials.
    /// </summary>
    Task DeleteCredentials(CancellationToken ct = default);
}
/// <summary>
/// Implements ICredentialManager using an ICredentialBackend to store
/// credentials. Credentials are verified against the Coder server before they
/// are cached, and all state transitions are serialized through one lock.
/// </summary>
public class CredentialManager : ICredentialManager
{
    private const string CredentialsTargetName = "Coder.Desktop.App.Credentials";

    // Coder URLs longer than this are rejected by SetCredentials.
    private const int MaxCoderUrlLength = 128;

    // Upper bound on how long a single load or verification may take.
    private static readonly TimeSpan OperationTimeout = TimeSpan.FromSeconds(15);

    // _opLock is held for the full duration of SetCredentials and
    // ClearCredentials, and partially during LoadCredentials. _opLock protects
    // _inFlightLoad, _loadCts, and writes to _latestCredentials.
    private readonly RaiiSemaphoreSlim _opLock = new(1, 1);

    // _inFlightLoad and _loadCts are set at the beginning of a LoadCredentials
    // call, and cleared when that load finishes or is preempted by
    // SetCredentials/ClearCredentials.
    private Task<CredentialModel>? _inFlightLoad;
    private CancellationTokenSource? _loadCts;

    // Reading and writing a reference in C# is always atomic, so this doesn't
    // need to be protected on reads with a lock in GetCachedCredentials.
    //
    // The volatile keyword disables optimizations on reads/writes which helps
    // other threads see the new value quickly (no guarantee that it's
    // immediate).
    private volatile CredentialModel? _latestCredentials;

    private ICredentialBackend Backend { get; } = new WindowsCredentialBackend(CredentialsTargetName);

    private ICoderApiClientFactory CoderApiClientFactory { get; } = new CoderApiClientFactory();

    public CredentialManager()
    {
    }

    public CredentialManager(ICredentialBackend backend, ICoderApiClientFactory coderApiClientFactory)
    {
        Backend = backend;
        CoderApiClientFactory = coderApiClientFactory;
    }

    public event EventHandler<CredentialModel>? CredentialsChanged;

    public CredentialModel GetCachedCredentials()
    {
        // No lock required to read the reference, and no clone is needed as
        // the model is immutable.
        var latestCreds = _latestCredentials;
        if (latestCreds != null) return latestCreds;

        return new CredentialModel
        {
            State = CredentialState.Unknown,
        };
    }

    public async Task<string?> GetSignInUri()
    {
        try
        {
            var raw = await Backend.ReadCredentials();
            if (raw is not null && !string.IsNullOrWhiteSpace(raw.CoderUrl)) return raw.CoderUrl;
        }
        catch
        {
            // Best effort: an unreadable store just means there is no URL to offer.
        }

        return null;
    }

    // LoadCredentials may be preempted by SetCredentials or ClearCredentials.
    public async Task<CredentialModel> LoadCredentials(CancellationToken ct = default)
    {
        Task<CredentialModel> load;

        // Await the lock rather than blocking on it: blocking a thread that
        // owns a synchronization context (e.g. the UI thread) while another
        // operation holds the lock can deadlock, and a cancelled wait would
        // surface as an AggregateException.
        using (await _opLock.LockAsync(ct))
        {
            // If we already have a cached value, return it.
            var latestCreds = _latestCredentials;
            if (latestCreds != null) return latestCreds;

            // Otherwise join the in-flight load, or kick off a new one.
            // Note: callers that join an existing load don't get their
            // CancellationToken applied to it. The load is bounded by
            // OperationTimeout anyway.
            if (_inFlightLoad == null)
            {
                _loadCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
                _loadCts.CancelAfter(OperationTimeout);
                _inFlightLoad = LoadCredentialsInner(_loadCts.Token);
            }

            load = _inFlightLoad;
        }

        // The lock must be released before awaiting, because
        // LoadCredentialsInner re-acquires it to publish its result.
        return await load;
    }

    public async Task SetCredentials(string coderUrl, string apiToken, CancellationToken ct = default)
    {
        // Validate first so that bad input neither waits for the lock nor
        // cancels an in-flight load.
        var raw = ValidateRawCredentials(coderUrl, apiToken);

        using var _ = await _opLock.LockAsync(ct);

        // If there's an ongoing load, cancel it; we're about to replace its result.
        await CancelInFlightLoad();

        using var populateCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        populateCts.CancelAfter(OperationTimeout);
        var model = await PopulateModel(raw, populateCts.Token);
        await Backend.WriteCredentials(raw, ct);
        UpdateState(model);
    }

    public async Task ClearCredentials(CancellationToken ct = default)
    {
        using var _ = await _opLock.LockAsync(ct);

        // Cancel any ongoing load, otherwise it could finish after us and
        // overwrite the cleared state with the credentials it read earlier.
        await CancelInFlightLoad();

        await Backend.DeleteCredentials(ct);
        UpdateState(new CredentialModel
        {
            State = CredentialState.Invalid,
        });
    }

    // Checks and normalizes user-supplied credentials. Throws
    // ArgumentException (named after the offending parameter) when invalid.
    private static RawCredentials ValidateRawCredentials(string coderUrl, string apiToken)
    {
        if (string.IsNullOrWhiteSpace(coderUrl)) throw new ArgumentException("Coder URL is required", nameof(coderUrl));
        coderUrl = coderUrl.Trim();
        if (coderUrl.Length > MaxCoderUrlLength) throw new ArgumentException("Coder URL is too long", nameof(coderUrl));
        if (!Uri.TryCreate(coderUrl, UriKind.Absolute, out var uri))
            throw new ArgumentException($"Coder URL '{coderUrl}' is not a valid URL", nameof(coderUrl));
        if (uri.Scheme != "http" && uri.Scheme != "https")
            throw new ArgumentException("Coder URL must be HTTP or HTTPS", nameof(coderUrl));
        if (uri.PathAndQuery != "/") throw new ArgumentException("Coder URL must be the root URL", nameof(coderUrl));
        if (string.IsNullOrWhiteSpace(apiToken)) throw new ArgumentException("API token is required", nameof(apiToken));
        apiToken = apiToken.Trim();

        return new RawCredentials
        {
            CoderUrl = coderUrl,
            ApiToken = apiToken,
        };
    }

    // Cancels and forgets the in-flight load, if any. The load task keeps
    // running until it observes the cancellation. Lock must be held.
    private async Task CancelInFlightLoad()
    {
        if (_loadCts == null) return;

        await _loadCts.CancelAsync();
        _loadCts.Dispose();
        _loadCts = null;
        _inFlightLoad = null;
    }

    private async Task<CredentialModel> LoadCredentialsInner(CancellationToken ct)
    {
        CredentialModel model;
        try
        {
            var raw = await Backend.ReadCredentials(ct);
            model = await PopulateModel(raw, ct);
        }
        catch
        {
            // This catch will be hit if a SetCredentials/ClearCredentials
            // operation started, or if the read/populate failed for some other
            // reason (e.g. HTTP timeout).
            //
            // We don't need to clear the credentials here, the app will think
            // they're unset and any subsequent SetCredentials call after the
            // user signs in again will overwrite the old invalid ones.
            model = new CredentialModel
            {
                State = CredentialState.Invalid,
            };
        }

        // Grab the lock again so we can update the state. The token is
        // deliberately not passed: if we were preempted, waiting for the
        // preempting operation to release the lock lets us return the
        // credentials it stored instead of failing immediately on our
        // already-cancelled token.
        using (await _opLock.LockAsync(CancellationToken.None))
        {
            // Prevent new LoadCredentials calls from returning this task. Only
            // touch the shared state if it still belongs to this load; a
            // preempting call may already have cleared it and a newer load may
            // have replaced it.
            if (_loadCts != null && _loadCts.Token == ct)
            {
                _loadCts.Dispose();
                _loadCts = null;
                _inFlightLoad = null;
            }

            // If we were canceled but made it this far, try to return the
            // latest credentials instead.
            if (ct.IsCancellationRequested)
            {
                var latestCreds = _latestCredentials;
                if (latestCreds is not null) return latestCreds;
            }

            // If there aren't any latest credentials after a cancellation, we
            // most likely timed out and should throw.
            ct.ThrowIfCancellationRequested();

            UpdateState(model);
            return model;
        }
    }

    private async Task<CredentialModel> PopulateModel(RawCredentials? credentials, CancellationToken ct)
    {
        if (credentials is null || string.IsNullOrWhiteSpace(credentials.CoderUrl) ||
            string.IsNullOrWhiteSpace(credentials.ApiToken))
            return new CredentialModel
            {
                State = CredentialState.Invalid,
            };

        BuildInfo buildInfo;
        User me;
        try
        {
            var sdkClient = CoderApiClientFactory.Create(credentials.CoderUrl);
            // BuildInfo does not require authentication.
            buildInfo = await sdkClient.GetBuildInfo(ct);
            sdkClient.SetSessionToken(credentials.ApiToken);
            me = await sdkClient.GetUser(User.Me, ct);
        }
        catch (CoderApiHttpException)
        {
            throw;
        }
        catch (Exception e)
        {
            throw new InvalidOperationException("Could not connect to or verify Coder server", e);
        }

        ServerVersionUtilities.ParseAndValidateServerVersion(buildInfo.Version);
        if (string.IsNullOrWhiteSpace(me.Username))
            throw new InvalidOperationException("Could not retrieve user information, username is empty");

        return new CredentialModel
        {
            State = CredentialState.Valid,
            CoderUrl = credentials.CoderUrl,
            ApiToken = credentials.ApiToken,
            Username = me.Username,
        };
    }

    // Lock must be held when calling this function.
    private void UpdateState(CredentialModel newModel)
    {
        _latestCredentials = newModel;

        // Since the event handlers could block (or call back the
        // CredentialManager and deadlock), we run these in a new task.
        if (CredentialsChanged == null) return;
        Task.Run(() => { CredentialsChanged?.Invoke(this, newModel); });
    }
}
/// <summary>
/// ICredentialBackend that stores RawCredentials as a JSON string in a single
/// generic credential in the Windows Credential Manager.
/// </summary>
public class WindowsCredentialBackend : ICredentialBackend
{
    private readonly string _credentialsTargetName;

    /// <param name="credentialsTargetName">Target name of the generic credential to use.</param>
    public WindowsCredentialBackend(string credentialsTargetName)
    {
        _credentialsTargetName = credentialsTargetName;
    }

    public Task<RawCredentials?> ReadCredentials(CancellationToken ct = default)
    {
        var raw = NativeApi.ReadCredentials(_credentialsTargetName);
        if (raw == null) return Task.FromResult<RawCredentials?>(null);

        RawCredentials? credentials;
        try
        {
            credentials = JsonSerializer.Deserialize(raw, RawCredentialsJsonContext.Default.RawCredentials);
        }
        catch (JsonException)
        {
            // Malformed or incomplete blob: treat it as if nothing is stored.
            credentials = null;
        }

        return Task.FromResult(credentials);
    }

    public Task WriteCredentials(RawCredentials credentials, CancellationToken ct = default)
    {
        var raw = JsonSerializer.Serialize(credentials, RawCredentialsJsonContext.Default.RawCredentials);
        NativeApi.WriteCredentials(_credentialsTargetName, raw);
        return Task.CompletedTask;
    }

    public Task DeleteCredentials(CancellationToken ct = default)
    {
        NativeApi.DeleteCredentials(_credentialsTargetName);
        return Task.CompletedTask;
    }

    private static class NativeApi
    {
        // CRED_TYPE_GENERIC
        private const int CredentialTypeGeneric = 1;
        // CRED_PERSIST_LOCAL_MACHINE
        private const int PersistenceTypeLocalComputer = 2;
        // ERROR_NOT_FOUND
        private const int ErrorNotFound = 1168;
        // CRED_MAX_CREDENTIAL_BLOB_SIZE
        private const int CredMaxCredentialBlobSize = 5 * 512;

        /// <summary>
        /// Reads the blob of the generic credential as a UTF-16 string, or
        /// returns null if the credential does not exist.
        /// </summary>
        public static string? ReadCredentials(string targetName)
        {
            if (!CredReadW(targetName, CredentialTypeGeneric, 0, out var credentialPtr))
            {
                var error = Marshal.GetLastWin32Error();
                if (error == ErrorNotFound) return null;
                throw new InvalidOperationException($"Failed to read credentials (Error {error})");
            }

            try
            {
                var cred = Marshal.PtrToStructure<CREDENTIAL>(credentialPtr);
                // A credential with an empty blob can have a null blob pointer,
                // which PtrToStringUni rejects with ArgumentNullException.
                if (cred.CredentialBlob == IntPtr.Zero || cred.CredentialBlobSize <= 0) return string.Empty;
                return Marshal.PtrToStringUni(cred.CredentialBlob, cred.CredentialBlobSize / sizeof(char));
            }
            finally
            {
                CredFree(credentialPtr);
            }
        }

        /// <summary>
        /// Creates or replaces the generic credential with the UTF-16 bytes of
        /// <paramref name="secret"/> as its blob.
        /// </summary>
        public static void WriteCredentials(string targetName, string secret)
        {
            var byteCount = Encoding.Unicode.GetByteCount(secret);
            if (byteCount > CredMaxCredentialBlobSize)
                throw new ArgumentOutOfRangeException(nameof(secret),
                    $"The secret is greater than {CredMaxCredentialBlobSize} bytes");

            var credentialBlob = Marshal.StringToHGlobalUni(secret);
            var cred = new CREDENTIAL
            {
                Type = CredentialTypeGeneric,
                TargetName = targetName,
                CredentialBlobSize = byteCount,
                CredentialBlob = credentialBlob,
                Persist = PersistenceTypeLocalComputer,
            };
            try
            {
                if (!CredWriteW(ref cred, 0))
                {
                    var error = Marshal.GetLastWin32Error();
                    throw new InvalidOperationException($"Failed to write credentials (Error {error})");
                }
            }
            finally
            {
                // Scrub the unmanaged copy of the secret before releasing it so
                // the token isn't left behind in freed process memory.
                Marshal.Copy(new byte[byteCount], 0, credentialBlob, byteCount);
                Marshal.FreeHGlobal(credentialBlob);
            }
        }

        /// <summary>
        /// Deletes the generic credential; a missing credential is not an error.
        /// </summary>
        public static void DeleteCredentials(string targetName)
        {
            if (!CredDeleteW(targetName, CredentialTypeGeneric, 0))
            {
                var error = Marshal.GetLastWin32Error();
                if (error == ErrorNotFound) return;
                throw new InvalidOperationException($"Failed to delete credentials (Error {error})");
            }
        }

        [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern bool CredReadW(string target, int type, int reservedFlag, out IntPtr credentialPtr);

        [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern bool CredWriteW([In] ref CREDENTIAL userCredential, [In] uint flags);

        [DllImport("Advapi32.dll", SetLastError = true)]
        private static extern void CredFree([In] IntPtr cred);

        [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern bool CredDeleteW(string target, int type, int flags);

        // Managed mirror of the Win32 CREDENTIALW structure.
        [StructLayout(LayoutKind.Sequential)]
        private struct CREDENTIAL
        {
            public int Flags;
            public int Type;
            [MarshalAs(UnmanagedType.LPWStr)] public string TargetName;
            [MarshalAs(UnmanagedType.LPWStr)] public string Comment;
            // FILETIME (two DWORDs) represented as a single 64-bit value.
            public long LastWritten;
            public int CredentialBlobSize;
            public IntPtr CredentialBlob;
            public int Persist;
            public int AttributeCount;
            public IntPtr Attributes;
            [MarshalAs(UnmanagedType.LPWStr)] public string TargetAlias;
            [MarshalAs(UnmanagedType.LPWStr)] public string UserName;
        }
    }
}