@@ -37,7 +37,7 @@ const TLSRecord = struct {
version: Protocol.Version = Protocol.TLSv1_2,
length: u16 = 0,
kind: union(ContentType) {
change_cipher_spec: []const u8,
change_cipher_spec: void, // This is const packet
alert: Alert,
handshake: Handshake.Handshake,
application_data: []const u8,
@@ -47,6 +47,7 @@ const TLSRecord = struct {
var fba = fixedBufferStream(buffer);
const len = switch (record.kind) {
.handshake => |ch| try ch.pack(buffer[5..]),
.change_cipher_spec => ChangeCipherSpec.pack(buffer[5..]),
else => unreachable,
};
var w = fba.writer().any();
@@ -79,7 +80,7 @@ const TLSRecord = struct {
.version = version,
.length = length,
.kind = switch (fragtype) {
.change_cipher_spec => .{ .change_cipher_spec = unreachable },
.change_cipher_spec => .{ .change_cipher_spec = if (fragbuff[0] != 1) return error.InvalidCCSPacket else {} },
.alert => .{ .alert = try Alert.unpack(fragbuff) },
.handshake => .{ .handshake = try Handshake.Handshake.unpack(fragbuff, sess) },
.application_data => .{ .application_data = unreachable },
@@ -157,11 +158,10 @@ pub const EllipticCurveCipher = struct {
const key_material = try std.crypto.dh.X25519.KeyPair.create(empty);
//try w.writeByte(@intFromEnum(cke.pve));
try w.writeByte(0);
//try w.writeInt(u16, @truncate(key_material.public_key.len), std.builtin.Endian.big);
try w.writeInt(u8, @truncate(key_material.public_key.len), std.builtin.Endian.big);
try w.writeAll(&key_material.public_key);
return 2 + key_material.public_key.len;
return 1 + key_material.public_key.len;
}
pub fn packKeyExchange(ecc: EllipticCurveCipher, buffer: []u8) !usize {
@@ -176,10 +176,10 @@ pub const EllipticCurveCipher = struct {
const r = fba.reader().any();
const curve_type = try CurveType.fromByte(try r.readByte());
print("named curve {} {}\n", .{ try r.readByte(), try r.readByte() });
print("full buffer {any}\n", .{buffer[0..4]});
print("full buffer {any}\n", .{buffer[4..][0..32]});
print("full buffer {any}\n", .{buffer[36..]});
//print("named curve {} {}\n", .{ try r.readByte(), try r.readByte() });
//print("full buffer {any}\n", .{buffer[0..4]});
//print("full buffer {any}\n", .{buffer[4..][0..32]});
//print("full buffer {any}\n", .{buffer[36..]});
return .{
.curve = switch (curve_type) {
.named_curve => .{ .named_curve = .{} },
@@ -239,7 +239,6 @@ const ClientECDH = struct {
// 1..255
point: []u8 = &[0]u8{},
};
const Finished = struct {};
const HashAlgorithm = enum(u8) {
none = 0,
@@ -264,6 +263,17 @@ const SignatureAndHashAlgorithm = struct {
signature: SignatureAlgorithm,
};
pub const ChangeCipherSpec = struct {
pub fn unpack(_: []const u8) !ChangeCipherSpec {
unreachable;
}
pub fn pack(buffer: []u8) usize {
buffer[0] = 0x01;
return 1;
}
};
test "Handshake ClientHello" {
var buffer = [_]u8{0} ** 0x400;
const client_hello = Handshake.ClientHello.init();
@@ -293,14 +303,14 @@ fn startHandshake(conn: std.net.Stream) !void {
}
/// Forgive me, I'm tired
fn getServer(conn: std.net.Stream) !void {
var server_hello: [0x1000]u8 = undefined;
const s_hello_read = try conn.read(&server_hello);
if (s_hello_read == 0) return error.InvalidSHello;
fn readServer(conn: std.net.Stream, server: []u8) !usize {
const s_read = try conn.read(server);
if (s_read == 0) return error.InvalidSHello;
const server_msg = server_hello[0..s_hello_read];
const server_msg = server[0..s_read];
if (false) print("server data: {any}\n", .{server_msg});
try std.testing.expect(s_hello_read > 7);
try std.testing.expect(s_read > 7);
return s_read;
}
fn buildServer(data: []const u8) !void {
@@ -308,6 +318,7 @@ fn buildServer(data: []const u8) !void {
var next_block: []const u8 = data;
while (next_block.len > 0) {
if (false) print("server block\n{any}\n", .{next_block});
const tlsr = try TLSRecord.unpack(next_block, &session);
if (false) print("mock {}\n", .{tlsr.length});
next_block = next_block[tlsr.length + 5 ..];
@@ -318,10 +329,12 @@ fn buildServer(data: []const u8) !void {
switch (hs.body) {
.server_hello => |hello| {
print("server hello {}\n", .{@TypeOf(hello)});
print("srv selected suite {any}\n", .{hello.cipher});
print("test selected suite {any}\n", .{CipherSuites.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256});
if (std.mem.eql(
u8,
&hello.cipher,
&CipherSuites.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
&CipherSuites.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
)) {
session.cipher.suite = .{ .ecc = .{} };
} else {
@@ -357,15 +370,50 @@ fn completeClient(conn: std.net.Stream) !void {
};
const cke_len = try cke_record.pack(&buffer);
try std.testing.expectEqual(43, cke_len);
try std.testing.expectEqual(42, cke_len);
print("CKE: {any}\n", .{buffer[0..cke_len]});
const ckeout = try conn.write(buffer[0..cke_len]);
if (true) print("cke delivered, {}\n", .{ckeout});
var r_buf: [0x1000]u8 = undefined;
var session = State{};
if (false) { // check for alerts
const num = try conn.read(&r_buf);
print("sin: {any}\n", .{r_buf[0..num]});
const sin = try TLSRecord.unpack(r_buf[0..num], &session);
print("server thing {}\n", .{sin});
}
const ccs_record = TLSRecord{
.kind = .{ .change_cipher_spec = {} },
};
const ccs_len = try ccs_record.pack(&buffer);
const ccsout = try conn.write(buffer[0..ccs_len]);
try std.testing.expectEqual(6, ccsout);
const fin = Handshake.Finished{};
//const fin_len = try fin.pack(&buffer);
const fin_record = TLSRecord{
.kind = .{
.handshake = try Handshake.Handshake.wrap(fin),
},
};
const fin_len = try fin_record.pack(&buffer);
print("fin: {any}\n", .{buffer[0..fin_len]});
const finout = try conn.write(buffer[0..fin_len]);
if (true) print("fin delivered, {}\n", .{finout});
const num2 = try conn.read(&r_buf);
print("sin: {any}\n", .{r_buf[0..num2]});
const sin2 = try TLSRecord.unpack(r_buf[0..num2], &session);
print("server thing {}\n", .{sin2});
}
fn fullHandshake(conn: std.net.Stream) !void {
try startHandshake(conn);
try getServer(conn);
var server: [0x1000]u8 = undefined;
const l = try readServer(conn, &server);
try buildServer(server[0..l]);
try completeClient(conn);
}
@@ -377,16 +425,25 @@ test "tls" {
const conn = try net.tcpConnectToAddress(addr);
try startHandshake(conn);
try getServer(conn);
var server: [0x1000]u8 = undefined;
const l = try readServer(conn, &server);
try buildServer(server[0..l]);
try completeClient(conn);
}
test "mock server response" {
if (true) return error.SkipZigTest;
// zig fmt: off
const server_data = [_]u8{
22, 3, 3, 0, 74,
2, 0, 0, 70,
3, 3, 75, 127, 236, 41, 6, 185, 127, 156, 38, 101, 41, 80, 93, 16, 140, 154, 60, 40, 250, 248, 115, 110, 115, 15, 68, 79, 87, 78, 71, 82, 68, 1, 32, 24, 187, 143, 225, 245, 127, 101, 130, 182, 200, 134, 201, 74, 38, 128, 15, 14, 35, 146, 216, 106, 109, 225, 72, 177, 41, 225, 227, 146, 101, 101, 10, 204, 169, 0,
3, 3, 75, 127, 236, 41, 6, 185, 127, 156, 38, 101,
41, 80, 93, 16, 140, 154, 60, 40, 250, 248, 115,
110, 115, 15, 68, 79, 87, 78, 71, 82, 68, 1, 32,
24, 187, 143, 225, 245, 127, 101, 130, 182, 200,
134, 201, 74, 38, 128, 15, 14, 35, 146, 216, 106,
109, 225, 72, 177, 41, 225, 227, 146, 101, 101,
10, 204, 169, 0,
22, 3, 3, 2, 64,
11, 0, 2, 60,
0, 2, 57, 0, 2, 54, 48, 130, 2, 50, 48,