--[[
An asynchronous socket library built on epoll and coroutines.
If I ever implement a full executor this may get more complicated,
but for now the whole event loop is a single standalone function.
]]
-- c imports
## cinclude '<sys/socket.h>'
local AF_INET: cint <cimport, nodecl>
local AF_INET6: cint <cimport, nodecl>
-- local AF_UNIX: cint <cimport, nodecl>
local SOCK_STREAM: cint <cimport, nodecl>
-- local SOCK_DGRAM: cint <cimport, nodecl>
-- local SOCK_RAW: cint <cimport, nodecl>
local MSG_DONTWAIT: cint <cimport, nodecl>
local MSG_PEEK: cint <cimport, nodecl>
local function c_socket(domain: cint, type: cint, protocol: cint): cint <cimport 'socket', nodecl> end
local function c_accept(fd: cint, address: pointer, address_len: *cint): cint <cimport 'accept', nodecl> end
local function c_listen(fd: cint, backlog: cint): cint <cimport 'listen', nodecl> end
local function c_send(fd: cint, buf: *[0]byte, len: usize, flags: cint): isize <cimport 'send', nodecl> end
local function c_recv(fd: cint, buf: *[0]byte, len: usize, flags: cint): isize <cimport 'recv', nodecl> end
## cinclude '<arpa/inet.h>'
local function c_inet_pton(af: cint, src: cstring <const>, dst: pointer): cint <cimport 'inet_pton', nodecl> end
-- `nodecl` like every other import here: glibc defines htons/ntohs as
-- function-like macros when optimizing, so letting Nelua re-declare them as
-- extern functions would clash with the system header in the generated C.
local function c_htons(hostshort: uint16): uint16 <cimport 'htons', nodecl> end
local function c_ntohs(netshort: uint16): uint16 <cimport 'ntohs', nodecl> end
## cinclude '<sys/epoll.h>'
local EPOLL_CLOEXEC: cint <cimport, nodecl>
local EPOLL_CTL_ADD: cint <cimport, nodecl>
local EPOLL_CTL_MOD: cint <cimport, nodecl>
local EPOLL_CTL_DEL: cint <cimport, nodecl>
## for i,epoll_val in ipairs({ 'IN', 'OUT', 'RDHUP', 'PRI', 'ERR', 'HUP', 'ET'}) do
local #|'EPOLL' .. epoll_val|# : uint32 <cimport, nodecl>
## end
--[[
Bindings for <sys/epoll.h>'s `epoll_data_t` and `struct epoll_event`.
They are imported (`nodecl`) instead of being redeclared so the C compiler
uses the system definition. On x86-64 glibc declares `struct epoll_event` as
packed (12 bytes, `data` at offset 4), which a plain Nelua record (16 bytes,
`data` at offset 8) does not match. An event array filled by epoll_wait would
then be misread from its second element on.
]]
local epoll_data <cimport 'epoll_data_t', nodecl> = @union{
  ptr: pointer,
  fd: cint,
  u32: uint32,
  u64: uint64
}
local epoll_event <cimport 'epoll_event', nodecl, ctypedef 'epoll_event'> = @record{
  events: uint32,
  data: epoll_data
}
local function c_epoll_create(flags: cint): cint <cimport 'epoll_create1', nodecl> end
local function c_epoll_ctl(epfd: cint, op: cint, fd: cint, event: *epoll_event): cint <cimport 'epoll_ctl', nodecl> end
local function c_epoll_wait(epfd: cint, events: *[0]epoll_event, maxevents: cint, timeout: cint): cint <cimport 'epoll_wait', nodecl> end
## cinclude '<unistd.h>'
local function c_close(fd: cint): cint <cimport 'close', nodecl> end
## cinclude '<netinet/in.h>'
-- end c imports
-- Module namespace: the public API (listen_sock, listen_tcp, end_conn, send,
-- send_line, recv_line) is attached to it below and returned at the end of
-- the file.
local export = @record{}
require 'string'
require 'hashmap'
require 'stringbuilder'
require 'coroutine'
require 'math'
require 'C.stdio'
require 'allocators.default'
-- No-op accepting any arguments. listen_tcp passes `port` to it because that
-- variable is otherwise only read by the emitted C code; presumably this keeps
-- Nelua from warning about or dropping it — TODO confirm.
local function fakeuse(...: varargs) end
-- Address patterns (Lua patterns, not regexes). Captures: (address, port).
--   IPv4: "192.168.2.1:2000" -> "192.168.2.1", "2000"
--         (1-3 digits per octet, 1-5 port digits; ranges checked later)
local inet_re = [[^(%d?%d?%d%.%d?%d?%d%.%d?%d?%d%.%d?%d?%d):(%d?%d?%d?%d?%d)$]]
--   IPv6: "{::1}:2000" -> "::1", "2000" (address wrapped in braces);
--         not currently used by listen_tcp
local inet6_re = [[^{([^}]+)}:(%d?%d?%d?%d?%d)$]]
-- Raise `msg` as an error when `cond` is true, after printing the description
-- of the current errno (via C perror, with no prefix). Does nothing otherwise.
local function die_errno(cond: boolean, msg: string): void
  if not cond then return end
  C.perror(nilptr)
  error(msg)
end
-- Requests a handler coroutine yields to the event loop.
local handler_req = @enum{
  close = 0,     -- close the connection (also a fresh handler's initial state)
  read_line = 1, -- wait until a full '\n'-terminated line has arrived
  write = 2      -- send the string pushed right before yielding
}
-- State the event loop (export.listen_sock) keeps for each open connection.
local handler_state = @record{
fd: cint, -- the connection's socket
co: coroutine, -- coroutine running the user-supplied handler
last_req: handler_req, -- last request yielded by the handler; selects what step() does on the next event
buf: stringbuilder -- bytes of the line received so far (no '\n' seen yet)
}
-- Deregister this handler's socket from the epoll instance `epfd` and close
-- it. Always returns false, so the failure paths of `step` can
-- `return self:teardown(epfd)` to tell the event loop to forget the handler.
function handler_state:teardown(epfd: cint): boolean
  c_epoll_ctl(epfd, EPOLL_CTL_DEL, self.fd, nilptr)
  c_close(self.fd)
  return false
end

--[[
Advance one connection after epoll reported activity on it (and once, right
after the connection was accepted).

What happens depends on the request the handler yielded last (`last_req`):
- `read_line`: read from the socket. Bytes accumulate in `self.buf` until a
  '\n' arrives. The completed line (without the '\n') is then pushed into the
  coroutine and the coroutine is resumed. If the peer closed the connection,
  the handler is dropped.
- `close` (also the initial state) or `write`: resume the coroutine directly.

Once resumed, each `write` request is served immediately with send(), and the
coroutine is resumed again until it yields something else. A `close` request,
or the handler function returning, closes the connection. A `read_line`
request re-arms the socket for EPOLLIN. `key` is the epoll user data (u32)
identifying this connection.

Returns false when the connection was torn down here (fd deregistered and
closed); the caller must then forget this handler. Returns true otherwise.
]]
function handler_state:step(epfd: cint, key: uint32): (boolean)
  local should_call = false
  switch self.last_req do
  case handler_req.write then fallthrough
  case handler_req.close then
    should_call = true
  case handler_req.read_line then
    local maxread <comptime> = 4096
    local buf: [maxread]byte = {}
    -- Peek first so that at most one line (up to and including its '\n') is
    -- consumed. Anything after it stays queued in the kernel, and the
    -- level-triggered EPOLLIN fires again for it.
    local len = c_recv(self.fd, &buf, maxread, MSG_PEEK)
    if len == -1 then
      local errmsg = "Internal error, please try again later.\n"
      C.perror("Error reading for handler")
      c_send(self.fd, errmsg.data, errmsg.size, 0)
      return self:teardown(epfd)
    elseif len == 0 then
      -- The peer closed the connection. The socket would stay readable
      -- forever, so drop the handler instead of spinning on EPOLLIN.
      return self:teardown(epfd)
    end
    local lf_idx = -1
    for i = 0, < len do
      if buf[i] == '\n'_b then
        lf_idx = i
        break
      end
    end
    if lf_idx ~= -1 then
      -- consume the line including its '\n', but keep the '\n' out of it
      c_recv(self.fd, &buf, lf_idx + 1, 0)
      self.buf:write((@span(byte))({data = &buf, size = lf_idx}))
      self.co:push(self.buf:promote())
      self.buf = stringbuilder()
      should_call = true
    else
      -- No complete line yet. Consume exactly the bytes that were scanned:
      -- more data (possibly including a '\n') may have arrived since the peek.
      len = c_recv(self.fd, &buf, len, 0)
      if len > 0 then
        self.buf:write((@span(byte))({data = &buf, size = len}))
      end
    end
  end
  if should_call then
    local req: handler_req
    while true do
      local ok, err = self.co:resume()
      if not ok then
        local errmsg = "Internal error, please try again later.\n"
        print("Error in handler:", err)
        c_send(self.fd, errmsg.data, errmsg.size, 0)
        return self:teardown(epfd)
      end
      if self.co:status() == 'dead' then
        -- The handler returned without calling end_conn(). Nothing was
        -- yielded, so there is no request to pop; treat it as a close.
        req = handler_req.close
        break
      end
      self.co:pop(&req)
      if req ~= handler_req.write then break end
      -- write requests carry the string to send below the request itself
      local data: string
      self.co:pop(&data)
      if c_send(self.fd, data.data, data.size, 0) == -1 then
        C.perror("Error writing for handler")
        return self:teardown(epfd)
      end
    end
    if req == handler_req.close then
      return self:teardown(epfd)
    end
    -- re-arm the socket: EPOLLIN only while the handler waits for a line
    local event: epoll_event = {
      events = 0,
      data = { u32 = key }
    }
    if req == handler_req.read_line then
      event.events = EPOLLIN
    end
    c_epoll_ctl(epfd, EPOLL_CTL_MOD, self.fd, &event)
    self.last_req = req
  end
  return true
end
--[[
Serve connections on an already-bound TCP socket `sock`, forever.

`sock` is put into listening mode (backlog 32) and registered with a new
epoll instance under key 0. Each accepted connection gets its own coroutine
running `handler`, identified in epoll by a random nonzero key. The handler
talks to its peer through export.recv_line / send / send_line / end_conn.
At most 1024 connections are served at once; further clients get an
"overloaded" message and are disconnected.

Raises an error (after printing errno) if listening, creating the epoll
instance, registering `sock`, or waiting on epoll fails.
]]
function export.listen_sock(sock: cint, handler: function(): void): void
  local max_handlers <comptime> = 1024
  die_errno(c_listen(sock, 32) ~= 0, "couldn't listen on TCP socket")
  local epfd = c_epoll_create(0)
  die_errno(epfd == -1, "couldn't create epoll instance")
  local handlers: hashmap(uint32, handler_state)
  -- key 0 identifies the listening socket; connections get nonzero keys
  local sock_event: epoll_event = {
    events = EPOLLIN,
    data = { u32 = 0 }
  }
  die_errno(c_epoll_ctl(epfd, EPOLL_CTL_ADD, sock, &sock_event) == -1,
    "couldn't register TCP socket with epoll")
  while true do
    local maxevents <comptime> = 8
    local events: [maxevents]epoll_event = {}
    local event_count = c_epoll_wait(epfd, &events, maxevents, -1)
    die_errno(event_count == -1, "couldn't wait on epoll instance")
    for i = 0, < event_count do
      local key = events[i].data.u32
      if key == 0 then
        local fd = c_accept(sock, nilptr, nilptr)
        if fd == -1 then
          C.perror("couldn't accept connection")
          continue
        end
        if #handlers >= max_handlers then -- connection cap reached: turn the client away
          local errmsg = "The server is overloaded, please try again later.\n"
          c_send(fd, errmsg.data, errmsg.size, 0)
          c_close(fd)
          continue
        end
        -- Pick a key that is neither the listener's (0) nor already in use, so
        -- a collision can never hijack another connection's events.
        local conn_key: uint32 = 0
        repeat
          conn_key = math.random(1_u32, (@uint32)(2^32-1))
        until not handlers:has(conn_key)
        -- registered with no events yet; step() re-arms it as needed
        local fd_event: epoll_event = { data = { u32 = conn_key } }
        if c_epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &fd_event) == -1 then
          C.perror("couldn't register connection with epoll")
          c_close(fd)
          continue
        end
        local state: handler_state = {
          fd = fd,
          co = coroutine.create(handler),
          last_req = handler_req.close
        }
        -- Run the handler up to its first blocking request. When step() fails
        -- it has already deregistered and closed fd (sending an error message
        -- on failure), so there is nothing left to clean up here.
        if state:step(epfd, conn_key) then
          handlers[conn_key] = state
        end
      else
        local state = handlers:peek(key)
        if state ~= nilptr and not state:step(epfd, key) then
          handlers:remove(key)
        end
      end
    end
  end
end
--[[
Bind a TCP socket to `address` and serve it with `handler` (see
export.listen_sock).

`address` must look like "a.b.c.d:port", e.g. "192.168.2.1:2000"; any other
form (including the "{addr}:port" IPv6 form, which is not implemented yet)
raises an error. Also raises for an invalid IPv4 address, a port outside
[0, 65535], or when the socket cannot be created or bound. The socket is
closed (via defer) when this function's scope exits.
]]
function export.listen_tcp(address: string, handler: function(): void): void
  local matched, matches = string.match(address, inet_re)
  if not matched then
    error("unsupported or malformed address (expected a.b.c.d:port): " .. address)
  end
  local c_err: cint = 0
  local s_addr, s_port = matches:unpack(1, 2)
  local addr: uint32
  -- inet_pton returns 0 (without setting errno) for a malformed address, and
  -- -1 (with errno set) for an unsupported address family
  local pton_res = c_inet_pton(AF_INET, (@cstring)(s_addr), &addr)
  if pton_res == 0 then
    error("bad IPv4 address")
  end
  die_errno(pton_res < 0, "bad IPv4 address")
  local i_port = tointeger(s_port)
  assert((i_port >= 0) and (i_port <= 65535), "port not within range [0,65536)")
  local port: uint16 = c_htons((@uint16)(i_port))
  local fd = c_socket(AF_INET, SOCK_STREAM, 0)
  die_errno(fd == -1, "couldn't open TCP socket")
  defer c_close(fd) end
  fakeuse(port)
  -- memset in the emitted C below
  ## cinclude '<string.h>'
  ##[==[ cemit [[
  struct sockaddr_in sa;
  memset(&sa, 0, sizeof(sa));
  sa.sin_family = AF_INET;
  sa.sin_port = port;
  sa.sin_addr.s_addr = addr;
  c_err = bind(fd, (struct sockaddr*)&sa, sizeof(sa));
  ]] ]==]
  die_errno(c_err == -1, "couldn't bind TCP socket")
  export.listen_sock(fd, handler)
end
-- Ask the event loop to close this handler's connection. Must be called from
-- a handler run by export.listen_sock; the loop does not resume it afterwards.
function export.end_conn(): void
  local request: handler_req = handler_req.close
  coroutine.yield(request)
end
-- Send `line` to the peer exactly as given (no newline is appended). The
-- string is handed to the event loop, which sends it before resuming us.
function export.send(line: string): void
  local current = coroutine.running()
  current:push(line)
  coroutine.yield(handler_req.write)
end
-- Send `line` to the peer followed by a single '\n'.
function export.send_line(line: string): void
  export.send(line.."\n")
end
-- Wait for the next '\n'-terminated line from the peer and return it without
-- the trailing '\n'. Must be called from a handler run by export.listen_sock.
function export.recv_line(): string
  coroutine.yield(handler_req.read_line)
  -- the event loop pushed the completed line before resuming us
  local received: string
  coroutine.running():pop(&received)
  return received
end
-- expose the module's public API to `require`
return export