srctree

Gregory Mullen parent 90071865 4b824149
enable domain blocking

src/dns.zig added: 202, removed: 37, total 165
@@ -13,7 +13,7 @@ pub const Cache = struct {
};
 
pub const Zone = struct {
domain: std.StringHashMapUnmanaged(CacheRes),
domains: std.StringHashMapUnmanaged(CacheRes),
};
 
pub const CacheRes = union(enum) {
@@ -268,8 +268,9 @@ pub const Message = struct {
}
 
pub const Iterator = struct {
index: ?usize = 0,
msg: *const Message,
index: usize = 0,
name_buffer: [255]u8 = undefined,
 
pub fn init(msg: *const Message) Iterator {
return .{
@@ -280,10 +281,11 @@ pub const Message = struct {
 
pub fn next(iter: *Iterator) !?Payload {
const h = iter.msg.header;
if (iter.index > h.qdcount + h.ancount + h.nscount + h.arcount) return null;
 
if (iter.index >= h.qdcount + h.ancount + h.nscount + h.arcount) return null;
 
defer iter.index += 1;
return iter.msg.payload(iter.index);
return try iter.msg.payload(iter.index, &iter.name_buffer);
}
};
 
@@ -291,17 +293,15 @@ pub const Message = struct {
return Iterator.init(msg);
}
 
pub fn payload(msg: Message, index: usize) !?Payload {
pub fn payload(msg: Message, index: usize, name_buf: []u8) !Payload {
const payload_end = msg.header.qdcount + msg.header.ancount +
msg.header.nscount + msg.header.arcount;
if (index >= payload_end) return error.InvalidIndex;
 
var name_buf: [128]u8 = undefined;
 
var idx: usize = 12;
for (0..payload_end) |payload_idx| {
if (payload_idx < msg.header.qdcount) {
const name = try Label.getName(&name_buf, msg.bytes, &idx);
const name = try Label.getName(name_buf, msg.bytes, &idx);
log.warn("label name {s}", .{name});
if (payload_idx == index) {
return .{ .question = .{
@@ -310,35 +310,38 @@ pub const Message = struct {
.class = @enumFromInt(@byteSwap(@as(u16, @bitCast(msg.bytes[idx..][2..4].*)))),
} };
} else {
_ = try Label.getName(&name_buf, msg.bytes, &idx);
idx += 4;
}
//log.warn("{any}", .{q.*});
 
} else {
} else if (payload_idx < msg.header.qdcount + msg.header.ancount) {
if (payload_idx == index) {
log.warn("{} {}", .{ idx, msg.bytes[idx] });
const name = try Label.getName(&name_buf, msg.bytes, &idx);
log.warn("{s}", .{name});
log.warn("answer {} {} {}", .{ idx, msg.bytes[idx], msg.bytes.len });
const name = try Label.getName(name_buf, msg.bytes, &idx);
log.warn("answer label {s}", .{name});
const rdlen: u16 = @byteSwap(@as(u16, @bitCast(msg.bytes[idx..][8..10].*)));
if (idx == index) {
if (payload_idx == index) {
const r: Resource = .{
.name = name,
.rtype = @enumFromInt(@byteSwap(@as(u16, @bitCast(msg.bytes[idx..][0..2].*)))),
.class = @enumFromInt(@byteSwap(@as(u16, @bitCast(msg.bytes[idx..][2..4].*)))),
.ttl = @byteSwap(@as(u32, @bitCast(msg.bytes[idx..][4..8].*))),
.rdlength = rdlen,
.rdata = msg.bytes[idx..][10..][0..rdlen],
.addr = switch (rdlen) {
4 => .{ .a = msg.bytes[idx..][10..][0..4].* },
16 => .{ .aaaa = msg.bytes[idx..][10..][0..16].* },
else => return error.InvalidResourceData,
},
};
if (r.rtype != .a) @panic("not implemented");
if (r.rtype != .a and r.rtype != .aaaa) return error.ResponseTypeNotImplemented;
 
return .{ .resource = r };
return .{ .answer = r };
} else {
_ = try Label.getName(&name_buf, msg.bytes, &idx);
idx += @byteSwap(@as(u16, @bitCast(msg.bytes[idx..][8..10].*)));
_ = try Label.getName(name_buf, msg.bytes, &idx);
idx += rdlen;
continue;
}
}
}
}
} else return error.InvalidIndex;
}
 
pub fn query(fqdns: []const []const u8, buffer: []u8) !Message {
@@ -409,7 +412,7 @@ pub const Message = struct {
 
return .{
.header = h,
.bytes = bytes,
.bytes = bytes[0..idx],
};
}
 
@@ -468,7 +471,7 @@ pub const Label = struct {
sw: switch (bytes[idx]) {
0 => {
if (!pointered) index.* = idx + 1;
return try name.items;
return name.items;
},
1...63 => |b| {
idx += b + 1;
@@ -616,6 +619,45 @@ test "build answer" {
}, &buffer);
try std.testing.expectEqual(msg1.header.qdcount, 1);
try std.testing.expectEqual(msg1.header.ancount, 1);
 
var big_buffer: [100]u8 = undefined;
 
const msg2 = try Message.answer(
31337,
&[1][]const u8{"gr.ht."},
&[1]Address{.{ .a = .{ 127, 4, 20, 69 } }},
&big_buffer,
);
try std.testing.expectEqualSlices(u8, &[_]u8{
122, 105, 133, 128, 0, 1, 0, 1, 0, 0, 0, 0,
2, 103, 114, 2, 104, 116, 0, 0, 1, 0, 1, 192,
12, 0, 1, 0, 1, 0, 0, 1, 44, 0, 4, 127,
4, 20, 69,
}, msg2.bytes);
}
 
test "response iter" {
const base = [_]u8{
197, 22, 129, 128, 0, 1, 0, 1, 0, 0, 0, 0,
7, 122, 105, 103, 108, 97, 110, 103, 3, 111, 114, 103,
0, 0, 28, 0, 1, 192, 12, 0, 28, 0, 1, 0,
0, 1, 44, 0, 16, 42, 1, 4, 249, 48, 81, 75,
210, 0, 0, 0, 0, 0, 0, 0, 2,
};
const msg: Message = try .fromBytes(&base);
var it = msg.iterator();
try std.testing.expectEqual(msg.header.qdcount, 1);
try std.testing.expectEqual(msg.header.ancount, 1);
var count: usize = 2;
while (try it.next()) |r| {
switch (r) {
.question => try std.testing.expectEqual(count, 2),
.answer => try std.testing.expectEqual(count, 1),
}
count -%= 1;
}
 
try std.testing.expectEqual(count, 0);
}
 
test "build answerDrop" {
 
src/server.zig added: 202, removed: 37, total 165
@@ -2,16 +2,17 @@ fn usage() !void {}
 
pub fn main() !void {
const a = std.heap.page_allocator;
// TODO smp alloc
 
var blocks: ?[]const u8 = null;
var blocked_ips: std.ArrayListUnmanaged([4]u8) = .{};
var blocked_domains: std.ArrayListUnmanaged([]const u8) = .{};
 
var argv = std.process.args();
while (argv.next()) |arg| {
if (std.mem.eql(u8, arg, "--block")) {
blocks = argv.next() orelse @panic("invalid argv --block");
}
if (std.mem.eql(u8, arg, "--drop-ip")) {
} else if (std.mem.eql(u8, arg, "--drop-ip")) {
const ip_str = argv.next() orelse @panic("invalid argv --drop-ip");
var ip: [4]u8 = undefined;
var itr = std.mem.splitScalar(u8, ip_str, '.');
@@ -19,6 +20,9 @@ pub fn main() !void {
oct.* = std.fmt.parseInt(u8, itr.next() orelse "0", 10) catch 0;
}
try blocked_ips.append(a, ip);
} else if (std.mem.eql(u8, arg, "--drop-domain")) {
const domain_str = argv.next() orelse @panic("invalid argv --drop-domain");
try blocked_domains.append(a, try a.dupe(u8, domain_str));
}
}
 
@@ -37,13 +41,26 @@ pub fn main() !void {
.tld = .{},
};
 
try cache.tld.put(a, "com", .{ .domain = .{} });
try cache.tld.put(a, "ht", .{ .domain = .{} });
try cache.tld.put(a, "com", .{ .domains = .{} });
try cache.tld.put(a, "net", .{ .domains = .{} });
try cache.tld.put(a, "org", .{ .domains = .{} });
try cache.tld.put(a, "ht", .{ .domains = .{} });
 
if (blocks) |b| {
a.free(try parse(a, b));
}
 
for (blocked_domains.items) |dd| {
const domain: Domain = .init(dd);
log.err("tld {s}", .{domain.tld});
var tld = cache.tld.getPtr(domain.tld).?;
log.err("zone {s}", .{domain.zone});
try tld.domains.put(a, domain.zone, .{ .static = .{
.ttl = 0xffffffff,
.addr = .{ .a = .{ 0, 0, 0, 0 } },
} });
}
 
var upconns: [4]DNS.Peer = undefined;
for (&upconns, upstreams) |*dst, ip| {
dst.* = try .connect(ip, 53);
@@ -55,7 +72,7 @@ pub fn main() !void {
//const msgsize = try msg.write(&request);
 
var timer: std.time.Timer = try .start();
while (true) {
root: while (true) {
var addr: std.net.Address = .{ .in = .{ .sa = .{ .port = 0, .addr = 0 } } };
var buffer: [1024]u8 = undefined;
const icnt = try downstream.recvFrom(&buffer, &addr);
@@ -65,30 +82,136 @@ pub fn main() !void {
log.err("received from {any}", .{addr.in});
 
const msg = try DNS.Message.fromBytes(buffer[0..icnt]);
//log.err("data {any}", .{msg});
_ = msg;
var address_bufs: [16]DNS.Address = undefined;
var addresses: std.ArrayListUnmanaged(DNS.Address) = .{
.items = &address_bufs,
.capacity = address_bufs.len,
};
addresses.items.len = 0;
 
var qdomains: [16][255]u8 = undefined;
var dbuf: [16][]const u8 = undefined;
var domains: std.ArrayListUnmanaged([]const u8) = .{
.items = &dbuf,
.capacity = dbuf.len,
};
domains.items.len = 0;
 
if (msg.header.qdcount >= 16) {
log.err("dropping invalid msg", .{});
log.debug("that message {any}", .{buffer[0..icnt]});
continue;
}
 
var iter = msg.iterator();
while (iter.next() catch |err| e: {
log.err("question iter error {}", .{err});
log.debug("qdata {any}", .{buffer[0..icnt]});
break :e null;
}) |pay| switch (pay) {
.question => |q| {
log.err("name {s}", .{q.name});
@memcpy(qdomains[iter.index][0..q.name.len], q.name);
domains.appendAssumeCapacity(qdomains[iter.index][0..q.name.len]);
 
const domain: Domain = .init(q.name);
const tld: *DNS.Zone = f: {
if (cache.tld.getOrPut(a, domain.tld)) |goptr| {
if (!goptr.found_existing) {
goptr.value_ptr.* = .{ .domains = .{} };
}
break :f goptr.value_ptr;
} else |err| {
log.err("hash error {}", .{err});
break;
}
};
 
if (tld.domains.getPtr(domain.zone)) |zone| switch (zone.*) {
.static => {
var ans_bytes: [512]u8 = undefined;
if (msg.header.qdcount == 1) {
const ans: DNS.Message = try .answerDrop(
msg.header.id,
q.name,
&ans_bytes,
);
try downstream.sendTo(addr, ans.bytes);
log.err("responded {d}", .{@as(f64, @floatFromInt(timer.lap())) / 1000});
continue :root;
}
},
else => log.err("zone {}", .{zone}),
};
 
addresses.appendAssumeCapacity(.{ .a = .{ 127, 0, 0, 1 } });
},
.answer => break,
};
 
var answer_bytes: [512]u8 = undefined;
const answer = try DNS.Message.answer(
msg.header.id,
domains.items,
addresses.items,
&answer_bytes,
);
 
log.err("answer = {}", .{answer});
 
log.info("bounce", .{});
up_idx +%= 1;
try upconns[up_idx].send(buffer[0..icnt]);
var relay_buf: [1024]u8 = undefined;
const b_cnt = try upconns[up_idx].recv(&relay_buf);
const relayed = relay_buf[0..b_cnt];
log.info("bounce received {}", .{b_cnt});
log.debug("bounce data {any}", .{relay_buf[0..b_cnt]});
log.debug("bounce data {any}", .{relayed});
 
for (blocked_ips.items) |banned| {
if (std.mem.eql(u8, relay_buf[b_cnt - 4 .. b_cnt], &banned)) {
@memset(relay_buf[b_cnt - 4 .. b_cnt], 0);
if (std.mem.eql(u8, relayed[relayed.len - 4 .. relayed.len], &banned)) {
@memset(relayed[relayed.len - 4 .. relayed.len], 0);
}
}
 
try downstream.sendTo(addr, relay_buf[0..b_cnt]);
log.err("responded {}", .{@as(f64, @floatFromInt(timer.lap())) / 1000});
log.err("responded {d}", .{@as(f64, @floatFromInt(timer.lap())) / 1000});
 
const rmsg: DNS.Message = try .fromBytes(relayed);
var rit = rmsg.iterator();
 
while (rit.next() catch |err| e: {
log.err("relayed iter error {}", .{err});
log.debug("rdata {any}", .{relayed});
break :e null;
}) |pay| switch (pay) {
.question => |q| log.err("r question = {}", .{q}),
.answer => |r| log.err("r answer = {}", .{r}),
};
}
 
log.err("done", .{});
}
 
pub const Domain = struct {
tld: []const u8,
zone: []const u8,
 
pub fn init(dn: []const u8) Domain {
if (dn.len < 3) return .{ .tld = dn, .zone = "" };
var bit = std.mem.splitBackwardsScalar(
u8,
if (dn[dn.len - 1] == '.') dn[0 .. dn.len - 1] else dn,
'.',
);
 
return .{
.tld = bit.first(),
.zone = bit.rest(),
};
}
};
 
const upstreams: [4][4]u8 = .{
.{ 1, 1, 1, 1 },
.{ 1, 0, 0, 1 },