@@ -11,6 +11,8 @@ request: *std.http.Server.Request,
recv_fifo: std.fifo.LinearFifo(u8, .Slice),
reader: std.io.AnyReader,
response: std.http.Server.Response,
/// Number of bytes that have been peeked but not discarded yet.
outstanding_len: usize,
pub const InitError = error{WebSocketUpgradeMissingKey} ||
std.http.Server.Request.ReaderError;
@@ -65,6 +67,7 @@ pub fn init(
},
}),
.request = request,
.outstanding_len = 0,
};
return true;
}
@@ -92,6 +95,8 @@ pub const Opcode = enum(u4) {
binary = 2,
connection_close = 8,
ping = 9,
/// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional
/// heartbeat. A response to an unsolicited Pong frame is not expected."
pong = 10,
_,
};
@@ -106,58 +111,63 @@ pub const ReadSmallTextMessageError = error{
/// Reads the next message from the WebSocket stream, failing if the message does not fit
/// into `recv_buffer`.
pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError![]u8 {
const header_bytes = (try recv(ws, 2))[0..2];
const h0: Header0 = @bitCast(header_bytes[0]);
const h1: Header1 = @bitCast(header_bytes[1]);
while (true) {
const header_bytes = (try recv(ws, 2))[0..2];
const h0: Header0 = @bitCast(header_bytes[0]);
const h1: Header1 = @bitCast(header_bytes[1]);
switch (h0.opcode) {
.text, .binary => {},
.connection_close => return error.ConnectionClose,
else => return error.UnexpectedOpCode,
switch (h0.opcode) {
.text, .binary, .pong => {},
.connection_close => return error.ConnectionClose,
else => return error.UnexpectedOpCode,
}
if (!h0.fin) return error.MessageTooBig;
if (!h1.mask) return error.MissingMaskBit;
const len: usize = switch (h1.payload_len) {
.len16 => try recvReadInt(ws, u16),
.len64 => std.math.cast(usize, try recvReadInt(ws, u64)) orelse return error.MessageTooBig,
else => @intFromEnum(h1.payload_len),
};
if (len > ws.recv_fifo.buf.len) return error.MessageTooBig;
const mask: u32 = @bitCast((try recv(ws, 4))[0..4].*);
const payload = try recv(ws, len);
// Skip pongs.
if (h0.opcode == .pong) continue;
// The last item may contain a partial word of unused data.
const floored_len = (payload.len / 4) * 4;
const u32_payload: []align(2) u32 = @alignCast(std.mem.bytesAsSlice(u32, payload[0..floored_len]));
for (u32_payload) |*elem| elem.* ^= mask;
const mask_bytes = std.mem.asBytes(&mask)[0 .. payload.len - floored_len];
for (payload[floored_len..], mask_bytes) |*leftover, m| leftover.* ^= m;
return payload;
}
if (!h0.fin) return error.MessageTooBig;
if (!h1.mask) return error.MissingMaskBit;
const len: usize = switch (h1.payload_len) {
.len16 => try recvReadInt(ws, u16),
.len64 => std.math.cast(usize, try recvReadInt(ws, u64)) orelse return error.MessageTooBig,
else => @intFromEnum(h1.payload_len),
};
if (len > ws.recv_fifo.buf.len) return error.MessageTooBig;
const mask: u32 = @bitCast((try recv(ws, 4))[0..4].*);
const payload = try recv(ws, len);
// The last item may contain a partial word of unused data.
const u32_payload: []align(2) u32 = @alignCast(std.mem.bytesAsSlice(u32, payload));
for (u32_payload) |*elem| elem.* ^= mask;
return payload;
}
fn recv(ws: *WebSocket, n: usize) ![]u8 {
const result = try recvPeek(ws, n);
ws.recv_fifo.discard(n);
return result;
}
const RecvError = std.http.Server.Request.ReadError || error{EndOfStream};
fn recvPeek(ws: *WebSocket, n: usize) RecvError![]u8 {
assert(n <= ws.recv_fifo.buf.len);
if (n > ws.recv_fifo.count) {
fn recv(ws: *WebSocket, len: usize) RecvError![]u8 {
ws.recv_fifo.discard(ws.outstanding_len);
assert(len <= ws.recv_fifo.buf.len);
if (len > ws.recv_fifo.count) {
const small_buf = ws.recv_fifo.writableSlice(0);
const needed = n - ws.recv_fifo.count;
const needed = len - ws.recv_fifo.count;
const buf = if (small_buf.len >= needed) small_buf else b: {
ws.recv_fifo.realign();
break :b ws.recv_fifo.writableSlice(0);
};
try @as(RecvError!void, @errorCast(ws.reader.readNoEof(buf)));
ws.recv_fifo.update(buf.len);
const n = try @as(RecvError!usize, @errorCast(ws.reader.readAtLeast(buf, needed)));
if (n < needed) return error.EndOfStream;
ws.recv_fifo.update(n);
}
ws.outstanding_len = len;
// TODO: improve the std lib API so this cast isn't necessary.
return @constCast(ws.recv_fifo.readableSliceOfLen(n));
return @constCast(ws.recv_fifo.readableSliceOfLen(len));
}
fn recvReadInt(ws: *WebSocket, comptime I: type) !I {