srctree

Gregory Mullen parent 4a0099a4 c92ca7df
refactor unpacking to read into ctx directly

src/cipher.zig added: 118, removed: 97, total 21
@@ -1,4 +1,7 @@
const std = @import("std");
const ConnCtx = @import("context.zig");
 
pub const X25519 = std.crypto.dh.X25519;
 
const print = std.debug.print;
const fixedBufferStream = std.io.fixedBufferStream;
@@ -10,7 +13,7 @@ suite: union(enum) {
ecc: EllipticCurve,
} = .{ .invalid = {} },
 
pub const Suites = enum(u16) {
pub const UnsupportedSuites = enum(u16) {
TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA = 0x0013,
TLS_DHE_DSS_WITH_AES_128_CBC_SHA = 0x0032,
TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 = 0x0040,
@@ -39,19 +42,23 @@ pub const Suites = enum(u16) {
TLS_DH_anon_WITH_AES_256_CBC_SHA = 0x003A,
TLS_DH_anon_WITH_AES_256_CBC_SHA256 = 0x006D,
TLS_DH_anon_WITH_RC4_128_MD5 = 0x0018,
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9,
// Planned to implement
TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 = 0xCCAC,
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8,
TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 = 0xCCAB,
TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 = 0xCCAE,
};
 
pub const Suites = enum(u16) {
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9,
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8,
 
pub fn fromInt(s: u16) Suites {
return switch (s) {
0xCCA9 => .TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
0xCCAC => .TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256,
0xCCA8 => .TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
0xCCAB => .TLS_PSK_WITH_CHACHA20_POLY1305_SHA256,
0xCCAE => .TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256,
//0xCCAC => .TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256,
//0xCCAB => .TLS_PSK_WITH_CHACHA20_POLY1305_SHA256,
//0xCCAE => .TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256,
else => unreachable,
};
}
@@ -60,8 +67,9 @@ pub const Suites = enum(u16) {
pub const EllipticCurve = struct {
curve: Curves = .{ .invalid = {} },
 
clt_key_mat: ?std.crypto.dh.X25519.KeyPair = null,
srv_key_mat: ?std.crypto.dh.X25519.KeyPair = null,
srv_material: ?X25519.KeyPair = undefined,
cli_material: ?X25519.KeyPair = null,
premaster: [X25519.shared_length]u8 = [_]u8{0} ** X25519.shared_length,
 
pub const Curves = union(CurveType) {
invalid: void,
@@ -88,11 +96,14 @@ pub const EllipticCurve = struct {
pub const ExplicitChar2 = struct {};
pub const NamedCurve = struct {};
 
pub fn init() EllipticCurve {
return .{};
/// srv is copied, and will not zero any arguments
pub fn init(srv: [X25519.public_length]u8) !EllipticCurve {
return .{
.srv_material = try X25519.KeyPair.create(srv),
};
}
 
fn packNamedCurve(_: EllipticCurve, buffer: []u8) !usize {
fn packNamed(_: EllipticCurve, buffer: []u8) !usize {
var fba = fixedBufferStream(buffer);
const w = fba.writer().any();
 
@@ -120,26 +131,39 @@ pub const EllipticCurve = struct {
 
pub fn packKeyExchange(ecc: EllipticCurve, buffer: []u8) !usize {
return switch (ecc.curve) {
.named_curve => try ecc.packNamedCurve(buffer),
else => unreachable,
.named_curve => try ecc.packNamed(buffer),
// LOL sorry future me!
else => return try ecc.packNamed(buffer),
};
}
 
pub fn unpackKeyExchange(buffer: []const u8) !EllipticCurve {
fn unpackNamed(buffer: []const u8, ctx: *ConnCtx) !void {
var fba = fixedBufferStream(buffer);
const r = fba.reader().any();
const name = try r.readInt(u16, .big);
if (name != 0x001d) return error.UnknownCurveName;
if (ctx.cipher.suite.ecc.cli_material == null) unreachable;
const our_seckey = ctx.cipher.suite.ecc.cli_material.?.secret_key;
ctx.cipher.suite.ecc.srv_material = undefined;
const peer_key = &ctx.cipher.suite.ecc.srv_material.?.public_key;
try r.readNoEof(peer_key);
 
ctx.cipher.suite.ecc.premaster = try X25519.scalarmult(our_seckey, peer_key.*);
 
// TODO verify signature
 
}
 
pub fn unpackKeyExchange(buffer: []const u8, ctx: *ConnCtx) !void {
if (ctx.cipher.suite != .ecc) unreachable;
var fba = fixedBufferStream(buffer);
const r = fba.reader().any();
 
const curve_type = try CurveType.fromByte(try r.readByte());
//print("named curve {} {}\n", .{ try r.readByte(), try r.readByte() });
//print("full buffer {any}\n", .{buffer[0..4]});
//print("full buffer {any}\n", .{buffer[4..][0..32]});
//print("full buffer {any}\n", .{buffer[36..]});
return .{
.curve = switch (curve_type) {
.named_curve => .{ .named_curve = .{} },
else => return error.UnsupportedCurve,
},
};
switch (curve_type) {
.named_curve => try unpackNamed(buffer[1..], ctx),
else => return error.UnsupportedCurve,
}
}
};
 
 
src/state.zig added: 118, removed: 97, total 21
@@ -11,9 +11,11 @@ const MACAlgorithm = root.MACAlgorithm;
const CompressionMethod = root.CompressionMethod;
const SessionID = root.SessionID;
 
pub const State = @This();
pub const ConnCtx = @This();
 
cipher: Cipher = .{},
our_random: [32]u8 = [_]u8{0} ** 32,
peer_random: [32]u8 = [_]u8{0} ** 32,
session_id: ?SessionID = null,
entity: ConnectionEnd = .{},
prf_algorithm: PRFAlgorithm = .{},
 
src/handshake.zig added: 118, removed: 97, total 21
@@ -17,13 +17,13 @@
///
const std = @import("std");
 
const State = @import("state.zig");
const ConnCtx = @import("context.zig");
const Protocol = @import("protocol.zig");
const root = @import("root.zig");
const Extensions = @import("extensions.zig");
const Cipher = @import("cipher.zig");
 
const Random = root.Random;
const Random = [32]u8;
const SessionID = root.SessionID;
const Extension = Extensions.Extension;
 
@@ -40,11 +40,11 @@ const HelloRequest = struct {};
 
/// Client Section
pub const ClientHello = struct {
version: Protocol.Version,
version: Protocol.Version = Protocol.TLSv1_2,
random: Random,
session_id: SessionID,
ciphers: []const Cipher.Suites = &[0]Cipher.Suites{},
compression: Compression,
ciphers: []const Cipher.Suites = &SupportedSuiteList,
compression: Compression = .null,
extensions: []const Extension = &[0]Extension{},
 
pub const SupportedExtensions = [_]type{
@@ -59,18 +59,13 @@ pub const ClientHello = struct {
 
pub const length = @sizeOf(ClientHello);
 
pub fn init() ClientHello {
pub fn init(rand: Random) ClientHello {
var hello = ClientHello{
.version = .{ .major = 3, .minor = 3 },
.random = .{
.random_bytes = undefined,
},
.random = rand,
.session_id = [_]u8{0} ** 32,
.ciphers = &SupportedSuiteList,
.compression = .null,
};
 
csprng.fill(&hello.random.random_bytes);
//csprng.fill(&hello.random.random_bytes);
csprng.fill(&hello.session_id);
return hello;
}
@@ -80,7 +75,7 @@ pub const ClientHello = struct {
var w = fba.writer().any();
try w.writeByte(ch.version.major);
try w.writeByte(ch.version.minor);
try w.writeStruct(ch.random);
try w.writeAll(&ch.random);
try w.writeByte(0);
//try w.writeByte(ch.session_id.len);
//try w.writeAll(&ch.session_id);
@@ -105,7 +100,7 @@ pub const ClientHello = struct {
return fba.pos;
}
 
pub fn unpack(buffer: []const u8, _: *State) !ClientHello {
pub fn unpack(buffer: []const u8, _: *ConnCtx) !ClientHello {
_ = buffer;
unreachable;
}
@@ -121,14 +116,8 @@ pub const ClientKeyExchange = struct {
} = .explicit,
cipher: *const Cipher,
 
pub fn init() !ClientKeyExchange {
const cke = ClientKeyExchange{
.cipher = &.{
.suite = .{ .ecc = Cipher.EllipticCurve{
.curve = .{ .named_curve = .{} },
} },
},
};
pub fn init(ctx: *ConnCtx) !ClientKeyExchange {
const cke = ClientKeyExchange{ .cipher = &ctx.cipher };
return cke;
}
 
@@ -165,7 +154,7 @@ pub const ServerHello = struct {
compression: Compression,
extensions: []const Extension,
 
pub fn unpack(buffer: []const u8, sess: *State) !ServerHello {
pub fn unpack(buffer: []const u8, ctx: *ConnCtx) !ServerHello {
print("buffer:: {any}\n", .{buffer});
var fba = fixedBufferStream(buffer);
const r = fba.reader().any();
@@ -174,10 +163,8 @@ pub const ServerHello = struct {
.major = try r.readByte(),
.minor = try r.readByte(),
};
var random = Random{
.random_bytes = undefined,
};
try r.readNoEof(&random.random_bytes);
var random: [32]u8 = undefined;
try r.readNoEof(&random);
 
const session_size = try r.readByte();
var session_id: [32]u8 = [_]u8{0} ** 32;
@@ -192,16 +179,16 @@ pub const ServerHello = struct {
.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
=> {
sess.cipher.suite = .{ .ecc = .{} };
ctx.cipher.suite = .{ .ecc = undefined };
},
else => unreachable,
//else => unreachable,
}
 
// compression
if (try r.readByte() != 0) return error.InvalidCompression;
 
// extensions
if (r.readInt(u16, std.builtin.Endian.big)) |extbytes| {
if (r.readInt(u16, .big)) |extbytes| {
var extbuffer: [0x1000]u8 = undefined;
try r.readNoEof(extbuffer[0..extbytes]);
} else |err| switch (err) {
@@ -227,25 +214,23 @@ pub const ServerKeyExchange = struct {
}
 
/// Will modify sess with supplied
pub fn unpack(buffer: []const u8, sess: *State) !ServerKeyExchange {
switch (sess.cipher.suite) {
.ecc => {
sess.cipher.suite.ecc = try Cipher.EllipticCurve.unpackKeyExchange(buffer);
},
pub fn unpack(buffer: []const u8, ctx: *ConnCtx) !ServerKeyExchange {
switch (ctx.cipher.suite) {
.ecc => try Cipher.EllipticCurve.unpackKeyExchange(buffer, ctx),
else => unreachable,
}
return .{
.buffer = buffer,
.cipher = &sess.cipher,
.cipher = &ctx.cipher,
};
}
};
 
pub const ServerHelloDone = struct {
buffer: []const u8,
session: *State,
session: *ConnCtx,
 
pub fn unpack(buffer: []const u8, sess: *State) !ServerHelloDone {
pub fn unpack(buffer: []const u8, sess: *ConnCtx) !ServerHelloDone {
return .{
.buffer = buffer,
.session = sess,
@@ -256,9 +241,9 @@ pub const ServerHelloDone = struct {
/// Certs
pub const Certificate = struct {
buffer: []const u8,
session: *State,
session: *ConnCtx,
 
pub fn unpack(buffer: []const u8, sess: *State) !Certificate {
pub fn unpack(buffer: []const u8, sess: *ConnCtx) !Certificate {
return .{
.buffer = buffer,
.session = sess,
@@ -268,9 +253,9 @@ pub const Certificate = struct {
 
pub const CertificateRequest = struct {
buffer: []const u8,
session: *State,
session: *ConnCtx,
 
pub fn unpack(buffer: []const u8, sess: *State) !CertificateRequest {
pub fn unpack(buffer: []const u8, sess: *ConnCtx) !CertificateRequest {
return .{
.buffer = buffer,
.session = sess,
@@ -385,7 +370,7 @@ pub const Handshake = struct {
return len + 4;
}
 
pub fn unpack(buffer: []const u8, sess: *State) !Handshake {
pub fn unpack(buffer: []const u8, sess: *ConnCtx) !Handshake {
const hs_type = try Type.fromByte(buffer[0]);
const hsbuf = buffer[4..];
return .{
 
src/root.zig added: 118, removed: 97, total 21
@@ -11,12 +11,10 @@ const Alert = @import("alert.zig");
const Extensions = @import("extensions.zig");
const Extension = Extensions.Extension;
const Protocol = @import("protocol.zig");
const State = @import("state.zig");
const ConnCtx = @import("context.zig");
const Handshake = @import("handshake.zig");
const Cipher = @import("cipher.zig");
 
var csprng = std.Random.ChaCha.init([_]u8{0} ** 32);
 
const ContentType = enum(u8) {
change_cipher_spec = 20,
alert = 21,
@@ -63,7 +61,7 @@ const TLSRecord = struct {
return record.packFragment(buffer);
}
 
pub fn unpackFragment(buffer: []const u8, sess: *State) !TLSRecord {
pub fn unpackFragment(buffer: []const u8, sess: *ConnCtx) !TLSRecord {
var fba = fixedBufferStream(buffer);
var r = fba.reader().any();
 
@@ -88,7 +86,7 @@ const TLSRecord = struct {
},
};
}
pub fn unpack(buffer: []const u8, sess: *State) !TLSRecord {
pub fn unpack(buffer: []const u8, sess: *ConnCtx) !TLSRecord {
return try unpackFragment(buffer, sess);
}
};
@@ -123,7 +121,10 @@ pub const ChangeCipherSpec = struct {
 
test "Handshake ClientHello" {
var buffer = [_]u8{0} ** 0x400;
const client_hello = Handshake.ClientHello.init();
var csprng = std.Random.ChaCha.init([_]u8{0} ** 32);
var rand = [_]u8{0} ** 32;
csprng.fill(&rand);
const client_hello = Handshake.ClientHello.init(rand);
const record = TLSRecord{
.kind = .{
.handshake = try Handshake.Handshake.wrap(client_hello),
@@ -134,9 +135,12 @@ test "Handshake ClientHello" {
_ = len;
}
 
fn startHandshake(conn: std.net.Stream) !void {
fn startHandshake(conn: std.net.Stream) !ConnCtx {
var buffer = [_]u8{0} ** 0x1000;
const client_hello = Handshake.ClientHello.init();
var csprng = std.Random.ChaCha.init([_]u8{0} ** 32);
var rand = [_]u8{0} ** 32;
csprng.fill(&rand);
const client_hello = Handshake.ClientHello.init(rand);
const record = TLSRecord{
.kind = .{
.handshake = try Handshake.Handshake.wrap(client_hello),
@@ -147,6 +151,9 @@ fn startHandshake(conn: std.net.Stream) !void {
const dout = try conn.write(buffer[0..len]);
if (false) print("data count {}\n", .{dout});
if (false) print("data out {any}\n", .{buffer[0..len]});
return .{
.our_random = client_hello.random,
};
}
 
/// Forgive me, I'm tired
@@ -160,13 +167,12 @@ fn readServer(conn: std.net.Stream, server: []u8) !usize {
return s_read;
}
 
fn buildServer(data: []const u8) !void {
var session = State{};
fn buildServer(data: []const u8, ctx: *ConnCtx) !void {
var next_block: []const u8 = data;
 
while (next_block.len > 0) {
if (false) print("server block\n{any}\n", .{next_block});
const tlsr = try TLSRecord.unpack(next_block, &session);
const tlsr = try TLSRecord.unpack(next_block, ctx);
if (false) print("mock {}\n", .{tlsr.length});
next_block = next_block[tlsr.length + 5 ..];
 
@@ -178,9 +184,10 @@ fn buildServer(data: []const u8) !void {
print("server hello {}\n", .{@TypeOf(hello)});
print("srv selected suite {any}\n", .{hello.cipher});
print("test selected suite {any}\n", .{.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256});
if (session.cipher.suite != .ecc) {
if (ctx.cipher.suite != .ecc) {
return error.UnexpectedCipherSuite;
}
ctx.cipher.suite.ecc.cli_material = try Cipher.X25519.KeyPair.create(null);
},
.certificate => |cert| {
print("server cert {}\n", .{@TypeOf(cert)});
@@ -201,9 +208,10 @@ fn buildServer(data: []const u8) !void {
}
}
 
fn completeClient(conn: std.net.Stream) !void {
fn completeClient(conn: std.net.Stream, ctx: *ConnCtx) !void {
var buffer = [_]u8{0} ** 0x1000;
const cke = try Handshake.ClientKeyExchange.init();
 
const cke = try Handshake.ClientKeyExchange.init(ctx);
const cke_record = TLSRecord{
.kind = .{
.handshake = try Handshake.Handshake.wrap(cke),
@@ -217,11 +225,10 @@ fn completeClient(conn: std.net.Stream) !void {
if (true) print("cke delivered, {}\n", .{ckeout});
 
var r_buf: [0x1000]u8 = undefined;
var session = State{};
if (false) { // check for alerts
const num = try conn.read(&r_buf);
print("sin: {any}\n", .{r_buf[0..num]});
const sin = try TLSRecord.unpack(r_buf[0..num], &session);
const sin = try TLSRecord.unpack(r_buf[0..num], ctx);
print("server thing {}\n", .{sin});
}
 
@@ -246,16 +253,16 @@ fn completeClient(conn: std.net.Stream) !void {
 
const num2 = try conn.read(&r_buf);
print("sin: {any}\n", .{r_buf[0..num2]});
const sin2 = try TLSRecord.unpack(r_buf[0..num2], &session);
const sin2 = try TLSRecord.unpack(r_buf[0..num2], ctx);
print("server thing {}\n", .{sin2});
}
 
fn fullHandshake(conn: std.net.Stream) !void {
try startHandshake(conn);
var ctx = try startHandshake(conn);
var server: [0x1000]u8 = undefined;
const l = try readServer(conn, &server);
try buildServer(server[0..l]);
try completeClient(conn);
try buildServer(server[0..l], &ctx);
try completeClient(conn, &ctx);
}
 
test "tls" {
@@ -265,15 +272,18 @@ test "tls" {
};
const conn = try net.tcpConnectToAddress(addr);
 
try startHandshake(conn);
var ctx = try startHandshake(conn);
var server: [0x1000]u8 = undefined;
const l = try readServer(conn, &server);
try buildServer(server[0..l]);
try completeClient(conn);
try buildServer(server[0..l], &ctx);
try completeClient(conn, &ctx);
}
 
test "mock server response" {
if (true) return error.SkipZigTest;
 
var ctx = ConnCtx{};
 
// zig fmt: off
const server_data = [_]u8{
22, 3, 3, 0, 74,
@@ -377,7 +387,7 @@ test "mock server response" {
};
// zig fmt: on
 
try buildServer(&server_data);
try buildServer(&server_data, &ctx);
 
//const cke = try ClientKeyExchange.init();
//const record = TLSRecord{