@@ -143,19 +143,71 @@ pub fn cookieAuth(HMAC: type) type {
};
}
pub fn validateToken(hm: *HMAC, token: []const u8, user_buffer: []u8) Error![]u8 {
var buffer: [ibuf_size]u8 = undefined;
const len = b64_dec.calcSizeForSlice(token) catch return error.InvalidAuth;
const decoded = buffer[0..len];
b64_dec.decode(decoded, token) catch return error.InvalidAuth;
const time = decoded[0..8];
if (std.mem.indexOfScalar(u8, decoded[8..], ':')) |i| {
const username = decoded[8..][0..i];
var extra_data: ?[]const u8 = null;
var given_hash: []const u8 = decoded[8..][i + 1 ..][0..HMAC.mac_length];
if (std.mem.indexOfScalar(u8, decoded[8 + i + 1 ..], ':')) |ed| {
extra_data = decoded[8..][i + 1 ..][0..ed];
given_hash = decoded[8..][i + 1 ..][ed + 1 .. HMAC.mac_length];
}
@memcpy(user_buffer[0..username.len], username);
hm.update(time);
hm.update(username);
if (extra_data) |ed| hm.update(ed);
var our_hash: [HMAC.mac_length]u8 = undefined;
hm.final(our_hash[0..]);
if (std.crypto.utils.timingSafeEql([HMAC.mac_length]u8, given_hash[0..HMAC.mac_length].*, our_hash)) {
return username;
}
return error.InvalidAuth;
} else return error.InvalidAuth;
return error.InvalidAuth;
}
pub fn authenticate(ptr: *anyopaque, headers: *const Headers) Error!User {
const ca: *Self = @ptrCast(@alignCast(ptr));
if (ca.base) |base| {
// If base provider offers authenticate, we should defer to it
if (base.vtable.authenticate) |_| {
return base.authenticate(headers);
}
if (headers.get("Cookie")) |cookies| {
// This actually isn't technically invalid, it's only
// currently not implemented.
if (cookies.value_list.next != null) return error.InvalidAuth;
const cookie = cookies.value_list.value;
std.debug.print("cookie: {s} \n", .{cookie});
if (std.mem.indexOf(u8, cookie, ca.cookie_name)) |i| {
return base.lookupUser(cookie[i..]);
var itr = std.mem.tokenizeSequence(u8, cookie, "; ");
while (itr.next()) |tkn| {
if (startsWith(u8, tkn, ca.cookie_name)) {
var un_buf: [64]u8 = undefined;
var hmac = HMAC.init(ca.server_secret_key);
const username = try validateToken(
&hmac,
tkn[ca.cookie_name.len + 1 ..],
un_buf[0..],
);
return base.lookupUser(username);
}
}
} else {
std.debug.print("no cookie\n{any}", .{headers});
}
}
return .{ .user_ptr = null };
return error.UnknownUser;
}
pub fn valid(ptr: *anyopaque, user: *const User) bool {
@@ -201,7 +253,10 @@ pub fn cookieAuth(HMAC: type) type {
pub fn createSession(ptr: *anyopaque, user: *User) Error!void {
const ca: *Self = @ptrCast(@alignCast(ptr));
if (ca.base) |base| try base.createSession(user);
if (ca.base) |base| base.createSession(user) catch |e| switch (e) {
error.NotProvided => {},
else => return e,
};
const prefix_len: usize = (if (user.username) |u| u.len + 1 else 0) +
if (user.session_extra_data) |ed| ed.len + 1 else 0;
@@ -217,7 +272,9 @@ pub fn cookieAuth(HMAC: type) type {
pub fn getCookie(ptr: *anyopaque, user: User) Error!?Cookie {
const ca: *Self = @ptrCast(@alignCast(ptr));
if (ca.base) |base| return base.getCookie(user);
if (ca.base) |base| if (base.vtable.get_cookie) |_| {
return base.getCookie(user);
};
if (user.session_next) |next| {
return .{
@@ -306,6 +363,30 @@ test "CookieAuth ExtraData" {
try std.testing.expectStringStartsWith(decoded[21..], "extra data:");
}
test "CookieAuth token" {
const a = std.testing.allocator;
var auth = CookieAuth.init(.{
.alloc = a,
.server_secret_key = "This may surprise you; but this secret_key is more secure than most of the secret keys in prod use",
});
const provider = auth.provider();
var user = User{ .username = "testing user" };
try provider.createSession(&user);
try std.testing.expect(user.session_next != null);
const cookie = try provider.getCookie(user);
try std.testing.expect(cookie != null);
try std.testing.expectStringStartsWith(cookie.?.value[8..], "AAB0ZXN0aW5nIHVzZXI6");
var username_buf: [64]u8 = undefined;
var hm = Hmac.sha2.HmacSha256.init(auth.server_secret_key);
const valid = try CookieAuth.validateToken(&hm, cookie.?.value, username_buf[0..]);
try std.testing.expectEqualStrings(user.username.?, valid);
}
pub const InvalidAuth = struct {
pub fn provider() Provider {
return Provider{
@@ -384,6 +465,7 @@ test Provider {
const std = @import("std");
const Allocator = std.mem.Allocator;
const toBytes = std.mem.toBytes;
const startsWith = std.mem.startsWith;
const nativeToLittle = std.mem.nativeToLittle;
const Hmac = std.crypto.auth.hmac;
const b64_enc = std.base64.url_safe.Encoder;