@@ -0,0 +1,212 @@
//! See https://tools.ietf.org/html/rfc6455
const builtin = @import("builtin");
const std = @import("std");
const WebSocket = @This();
const assert = std.debug.assert;
const native_endian = builtin.cpu.arch.endian();
key: []const u8,
request: *std.http.Server.Request,
recv_fifo: std.fifo.LinearFifo(u8, .Slice),
reader: std.io.AnyReader,
response: std.http.Server.Response,
pub const InitError = error{WebSocketUpgradeMissingKey} ||
std.http.Server.Request.ReaderError;
pub fn init(
ws: *WebSocket,
request: *std.http.Server.Request,
send_buffer: []u8,
recv_buffer: []align(4) u8,
) InitError!bool {
var sec_websocket_key: ?[]const u8 = null;
var upgrade_websocket: bool = false;
var it = request.iterateHeaders();
while (it.next()) |header| {
if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) {
sec_websocket_key = header.value;
} else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) {
if (!std.mem.eql(u8, header.value, "websocket"))
return false;
upgrade_websocket = true;
}
}
if (!upgrade_websocket)
return false;
const key = sec_websocket_key orelse return error.WebSocketUpgradeMissingKey;
var sha1 = std.crypto.hash.Sha1.init(.{});
sha1.update(key);
sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
sha1.final(&digest);
var base64_digest: [28]u8 = undefined;
assert(std.base64.standard.Encoder.encode(&base64_digest, &digest).len == base64_digest.len);
request.head.content_length = std.math.maxInt(u64);
ws.* = .{
.key = key,
.recv_fifo = std.fifo.LinearFifo(u8, .Slice).init(recv_buffer),
.reader = try request.reader(),
.response = request.respondStreaming(.{
.send_buffer = send_buffer,
.respond_options = .{
.status = .switching_protocols,
.extra_headers = &.{
.{ .name = "upgrade", .value = "websocket" },
.{ .name = "connection", .value = "upgrade" },
.{ .name = "sec-websocket-accept", .value = &base64_digest },
},
.transfer_encoding = .none,
},
}),
.request = request,
};
return true;
}
pub const Header0 = packed struct(u8) {
opcode: Opcode,
rsv3: u1 = 0,
rsv2: u1 = 0,
rsv1: u1 = 0,
fin: bool,
};
pub const Header1 = packed struct(u8) {
payload_len: enum(u7) {
len16 = 126,
len64 = 127,
_,
},
mask: bool,
};
pub const Opcode = enum(u4) {
continuation = 0,
text = 1,
binary = 2,
connection_close = 8,
ping = 9,
pong = 10,
_,
};
pub const ReadSmallTextMessageError = error{
ConnectionClose,
UnexpectedOpCode,
MessageTooBig,
};
/// Reads the next message from the WebSocket stream, failing if the message does not fit
/// into `recv_buffer`.
pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError![]u8 {
const header_bytes = (try recv(ws, 2))[0..2];
const h0: Header0 = @bitCast(header_bytes[0]);
const h1: Header1 = @bitCast(header_bytes[1]);
switch (h0.opcode) {
.text, .binary => {},
.connection_close => return error.ConnectionClose,
else => return error.UnexpectedOpCode,
}
if (!h0.fin) return error.MessageTooBig;
if (!h1.mask) return error.MissingMaskBit;
const len: usize = switch (h1.payload_len) {
.len16 => try recvReadInt(ws, u16),
.len64 => std.math.cast(usize, try recvReadInt(ws, u64)) orelse return error.MessageTooBig,
else => @intFromEnum(h1.payload_len),
};
if (len > ws.recv_fifo.buf.len) return error.MessageTooBig;
const mask: u32 = @bitCast((try recv(ws, 4))[0..4].*);
const payload = try recv(ws, len);
// The last item may contain a partial word of unused data.
const u32_payload: []u32 = std.mem.bytesAsSlice(u32, payload);
for (u32_payload) |*elem| elem.* ^= mask;
return payload;
}
fn recv(ws: *WebSocket, n: usize) ![]u8 {
const result = try recvPeek(ws, n);
ws.recv_fifo.discard(n);
return result;
}
fn recvPeek(ws: *WebSocket, n: usize) ![]u8 {
assert(n <= ws.recv_fifo.buf.len);
if (n > ws.recv_fifo.count) {
const small_buf = ws.recv_fifo.writableSlice();
const needed = n - ws.recv_fifo.count;
const buf = if (small_buf.len >= needed) small_buf else b: {
ws.recv_fifo.realign();
break :b ws.recv_fifo.writableSlice();
};
try ws.reader.readNoEof(buf);
ws.recv_fifo.update(buf.len);
}
return ws.recv_fifo.readableSliceOfLen(n);
}
fn recvReadInt(ws: *WebSocket, comptime I: type) I {
const unswapped: I = @bitCast((try recv(ws, @sizeOf(I))[0..@sizeOf(I)]).*);
return switch (native_endian) {
.little => @byteSwap(unswapped),
.big => unswapped,
};
}
pub const WriteError = std.http.Server.Response.WriteError;
pub fn writeMessage(ws: *WebSocket, message: []const std.posix.iovec_const) WriteError!void {
const total_len = l: {
var total_len: u64 = 0;
for (message) |iovec| total_len += iovec.iov_len;
break :l total_len;
};
var header_buf: [2 + 8]u8 = undefined;
header_buf[0] = @bitCast(@as(Header0, .{
.opcode = .binary,
.fin = true,
}));
const header = switch (total_len) {
0...125 => blk: {
header_buf[1] = @bitCast(@as(Header1, .{
.payload_len = @enumFromInt(total_len),
.mask = false,
}));
break :blk header_buf[0..2];
},
126...0xffff => blk: {
header_buf[1] = @bitCast(@as(Header1, .{
.payload_len = .len16,
.mask = false,
}));
std.mem.writeInt(u16, header_buf[2..4], @intCast(total_len), .big);
break :blk header_buf[0..4];
},
else => blk: {
header_buf[1] = @bitCast(@as(Header1, .{
.payload_len = .len64,
.mask = false,
}));
std.mem.writeInt(u64, header_buf[2..10], total_len, .big);
break :blk header_buf[0..10];
},
};
const response = &ws.response;
try response.writeAll(header);
for (message) |iovec|
try response.writeAll(iovec.iov_base[0..iovec.iov_len]);
try response.flush();
}