@@ -11,12 +11,10 @@ const Alert = @import("alert.zig");
const Extensions = @import("extensions.zig");
const Extension = Extensions.Extension;
const Protocol = @import("protocol.zig");
const State = @import("state.zig");
const ConnCtx = @import("context.zig");
const Handshake = @import("handshake.zig");
const Cipher = @import("cipher.zig");
var csprng = std.Random.ChaCha.init([_]u8{0} ** 32);
const ContentType = enum(u8) {
change_cipher_spec = 20,
alert = 21,
@@ -63,7 +61,7 @@ const TLSRecord = struct {
return record.packFragment(buffer);
}
pub fn unpackFragment(buffer: []const u8, sess: *State) !TLSRecord {
pub fn unpackFragment(buffer: []const u8, sess: *ConnCtx) !TLSRecord {
var fba = fixedBufferStream(buffer);
var r = fba.reader().any();
@@ -88,7 +86,7 @@ const TLSRecord = struct {
},
};
}
pub fn unpack(buffer: []const u8, sess: *State) !TLSRecord {
pub fn unpack(buffer: []const u8, sess: *ConnCtx) !TLSRecord {
return try unpackFragment(buffer, sess);
}
};
@@ -123,7 +121,10 @@ pub const ChangeCipherSpec = struct {
test "Handshake ClientHello" {
var buffer = [_]u8{0} ** 0x400;
const client_hello = Handshake.ClientHello.init();
var csprng = std.Random.ChaCha.init([_]u8{0} ** 32);
var rand = [_]u8{0} ** 32;
csprng.fill(&rand);
const client_hello = Handshake.ClientHello.init(rand);
const record = TLSRecord{
.kind = .{
.handshake = try Handshake.Handshake.wrap(client_hello),
@@ -134,9 +135,12 @@ test "Handshake ClientHello" {
_ = len;
}
fn startHandshake(conn: std.net.Stream) !void {
fn startHandshake(conn: std.net.Stream) !ConnCtx {
var buffer = [_]u8{0} ** 0x1000;
const client_hello = Handshake.ClientHello.init();
var csprng = std.Random.ChaCha.init([_]u8{0} ** 32);
var rand = [_]u8{0} ** 32;
csprng.fill(&rand);
const client_hello = Handshake.ClientHello.init(rand);
const record = TLSRecord{
.kind = .{
.handshake = try Handshake.Handshake.wrap(client_hello),
@@ -147,6 +151,9 @@ fn startHandshake(conn: std.net.Stream) !void {
const dout = try conn.write(buffer[0..len]);
if (false) print("data count {}\n", .{dout});
if (false) print("data out {any}\n", .{buffer[0..len]});
return .{
.our_random = client_hello.random,
};
}
/// Forgive me, I'm tired
@@ -160,13 +167,12 @@ fn readServer(conn: std.net.Stream, server: []u8) !usize {
return s_read;
}
fn buildServer(data: []const u8) !void {
var session = State{};
fn buildServer(data: []const u8, ctx: *ConnCtx) !void {
var next_block: []const u8 = data;
while (next_block.len > 0) {
if (false) print("server block\n{any}\n", .{next_block});
const tlsr = try TLSRecord.unpack(next_block, &session);
const tlsr = try TLSRecord.unpack(next_block, ctx);
if (false) print("mock {}\n", .{tlsr.length});
next_block = next_block[tlsr.length + 5 ..];
@@ -178,9 +184,10 @@ fn buildServer(data: []const u8) !void {
print("server hello {}\n", .{@TypeOf(hello)});
print("srv selected suite {any}\n", .{hello.cipher});
print("test selected suite {any}\n", .{.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256});
if (session.cipher.suite != .ecc) {
if (ctx.cipher.suite != .ecc) {
return error.UnexpectedCipherSuite;
}
ctx.cipher.suite.ecc.cli_material = try Cipher.X25519.KeyPair.create(null);
},
.certificate => |cert| {
print("server cert {}\n", .{@TypeOf(cert)});
@@ -201,9 +208,10 @@ fn buildServer(data: []const u8) !void {
}
}
fn completeClient(conn: std.net.Stream) !void {
fn completeClient(conn: std.net.Stream, ctx: *ConnCtx) !void {
var buffer = [_]u8{0} ** 0x1000;
const cke = try Handshake.ClientKeyExchange.init();
const cke = try Handshake.ClientKeyExchange.init(ctx);
const cke_record = TLSRecord{
.kind = .{
.handshake = try Handshake.Handshake.wrap(cke),
@@ -217,11 +225,10 @@ fn completeClient(conn: std.net.Stream) !void {
if (true) print("cke delivered, {}\n", .{ckeout});
var r_buf: [0x1000]u8 = undefined;
var session = State{};
if (false) { // check for alerts
const num = try conn.read(&r_buf);
print("sin: {any}\n", .{r_buf[0..num]});
const sin = try TLSRecord.unpack(r_buf[0..num], &session);
const sin = try TLSRecord.unpack(r_buf[0..num], ctx);
print("server thing {}\n", .{sin});
}
@@ -246,16 +253,16 @@ fn completeClient(conn: std.net.Stream) !void {
const num2 = try conn.read(&r_buf);
print("sin: {any}\n", .{r_buf[0..num2]});
const sin2 = try TLSRecord.unpack(r_buf[0..num2], &session);
const sin2 = try TLSRecord.unpack(r_buf[0..num2], ctx);
print("server thing {}\n", .{sin2});
}
fn fullHandshake(conn: std.net.Stream) !void {
try startHandshake(conn);
var ctx = try startHandshake(conn);
var server: [0x1000]u8 = undefined;
const l = try readServer(conn, &server);
try buildServer(server[0..l]);
try completeClient(conn);
try buildServer(server[0..l], &ctx);
try completeClient(conn, &ctx);
}
test "tls" {
@@ -265,15 +272,18 @@ test "tls" {
};
const conn = try net.tcpConnectToAddress(addr);
try startHandshake(conn);
var ctx = try startHandshake(conn);
var server: [0x1000]u8 = undefined;
const l = try readServer(conn, &server);
try buildServer(server[0..l]);
try completeClient(conn);
try buildServer(server[0..l], &ctx);
try completeClient(conn, &ctx);
}
test "mock server response" {
if (true) return error.SkipZigTest;
var ctx = ConnCtx{};
// zig fmt: off
const server_data = [_]u8{
22, 3, 3, 0, 74,
@@ -377,7 +387,7 @@ test "mock server response" {
};
// zig fmt: on
try buildServer(&server_data);
try buildServer(&server_data, &ctx);
//const cke = try ClientKeyExchange.init();
//const record = TLSRecord{