@@ -8,9 +8,10 @@ pub const AuthN = @import("authentication.zig");
pub const User = @import("auth/user.zig");
pub const Error = error{
UnknownUser,
Unauthenticated,
InvalidAuth,
NotProvided,
Unauthenticated,
UnknownUser,
};
/// Fails closed: the provider used may return an error which will be caught and
@@ -26,14 +27,31 @@ pub const Error = error{
//}
pub const MTLS = struct {
pub fn provider(mtls: *MTLS) Provider {
return Provider{
.ctx = mtls,
.vtable = .{
.valid = validPtr,
.lookup_user = lookupUserPtr,
},
};
base: ?Provider = null,
pub fn authenticatePtr(ptr: *anyopaque, headers: *const Headers) Error!User {
const self: *MTLS = @ptrCast(@alignCast(ptr));
return self.authenticate(headers);
}
pub fn authenticate(mtls: *MTLS, headers: *const Headers) Error!User {
var success: bool = false;
if (headers.get("MTLS_ENABLED")) |enabled| {
if (enabled.value_list.next) |_| return error.InvalidAuth;
if (std.mem.eql(u8, enabled.value_list.value, "SUCCESS")) {
success = true;
}
}
if (!success) return error.UnknownUser;
if (mtls.base) |base| {
if (headers.get("MTLS_FINGERPRINT")) |enabled| {
if (enabled.value_list.next != null) return error.InvalidAuth;
return base.lookupUser(enabled.value_list.value);
}
}
return .{ .user_ptr = null };
}
pub fn valid(_: *MTLS, _: *const User) bool {
@@ -41,7 +59,7 @@ pub const MTLS = struct {
}
fn validPtr(ptr: *anyopaque, user: *const User) bool {
const self: *MTLS = @ptrCast(ptr);
const self: *MTLS = @ptrCast(@alignCast(ptr));
return self.valid(user);
}
@@ -50,14 +68,47 @@ pub const MTLS = struct {
}
pub fn lookupUserPtr(ptr: *anyopaque, user_id: []const u8) Error!User {
const self: *MTLS = @ptrCast(ptr);
const self: *MTLS = @ptrCast(@alignCast(ptr));
return self.lookupUser(user_id);
}
pub fn provider(mtls: *MTLS) Provider {
return Provider{
.ctx = mtls,
.vtable = .{
.authenticate = authenticatePtr,
.valid = validPtr,
.lookup_user = lookupUserPtr,
},
};
}
};
test MTLS {
//const a = std.testing.allocator;
const a = std.testing.allocator;
var mtls = MTLS{};
var provider = mtls.provider();
var headers = Headers.init(a);
defer headers.raze();
try headers.add("MTLS_ENABLED", "SUCCESS");
try headers.add("MTLS_FINGERPRINT", "LOLTOTALLYVALID");
const user = try provider.authenticate(&headers);
try std.testing.expectEqual(null, user.user_ptr);
try headers.add("MTLS_ENABLED", "SUCCESS");
const err = provider.authenticate(&headers);
try std.testing.expectError(error.InvalidAuth, err);
headers.raze();
headers = Headers.init(a);
try headers.add("MTLS_ENABLED", "FAILURE!");
const err2 = provider.authenticate(&headers);
try std.testing.expectError(error.UnknownUser, err2);
// TODO there's likely a few more error states we should validate;
}
pub const InvalidAuth = struct {
@@ -65,6 +116,7 @@ pub const InvalidAuth = struct {
return Provider{
.ctx = undefined,
.vtable = .{
.authenticate = null, // TODO write invalid
.valid = valid,
.lookup_user = lookupUser,
},
@@ -105,6 +157,7 @@ const TestingAuth = struct {
return .{
.ctx = self,
.vtable = .{
.authenticate = null,
.valid = null,
.lookup_user = lookupUserUntyped,
},
@@ -127,3 +180,5 @@ test Provider {
const std = @import("std");
const Allocator = std.mem.Allocator;
const Verse = @import("verse.zig");
const Headers = @import("headers.zig");