diff options
Diffstat (limited to 'server.c')
-rw-r--r-- | server.c | 259 |
1 files changed, 177 insertions, 82 deletions
@@ -1,30 +1,38 @@ -#include "arena.h" -#include "common.h" +#include "chatty.h" #include <assert.h> #include <netinet/in.h> #include <poll.h> #include <stdarg.h> +#include <string.h> #include <sys/socket.h> +#include <unistd.h> // timeout on polling #define TIMEOUT 60 * 1000 // max pending connections -#define PENDING_MAX 16 - -// the size of pollfd element in the fdsArena -// note: clientsArena and pollfd_size must have been initialisezd -#define FDS_SIZE fdsArena->pos / pollfd_size +#define MAX_CONNECTIONS 16 +// Get number of connections from arena position +// NOTE: this is somewhat wrong, because of when disconnections happen +#define FDS_SIZE (fdsArena->pos / sizeof(*fds)) // enum for indexing the fds array enum { FDS_STDIN = 0, FDS_SERVER, FDS_CLIENTS }; -int main(void) +// Has information on clients +// For each pollfd in fds there should be a matching client in clients +// clients[i - FDS_CLIENTS] <=> fds[i] +typedef struct { + u8 author[AUTHOR_LEN]; // matches author property on other message types + Bool initialized; // boolean +} Client; + +int +main(void) { - u32 err, serverfd, clientfd; - u16 nclient = 0; + s32 err, serverfd, clientfd; u32 on = 1; // Start listening on the socket @@ -32,7 +40,7 @@ int main(void) serverfd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP); assert(serverfd > 2); - err = setsockopt(serverfd, SOL_SOCKET, SO_REUSEADDR, (u8 *)&on, sizeof(on)); + err = setsockopt(serverfd, SOL_SOCKET, SO_REUSEADDR, (u8*)&on, sizeof(on)); assert(err == 0); const struct sockaddr_in address = { @@ -41,36 +49,39 @@ int main(void) {0}, }; - err = bind(serverfd, (const struct sockaddr *)&address, sizeof(address)); + err = bind(serverfd, (const struct sockaddr*)&address, sizeof(address)); assert(err == 0); - err = listen(serverfd, PENDING_MAX); + err = listen(serverfd, MAX_CONNECTIONS); assert(err == 0); } - Arena *msgTextArena = ArenaAlloc(); // allocating text in messages that have a dynamic sized - Message mrecv = {0}; // message used for receiving messages from clients - u32 nrecv = 0; // number of bytes received - u32 recv_len; // Number of bytes of the message received over stream - u32 nsend = 0; // number of bytes sent - Arena *bufArena = ArenaAlloc(); // data in buf - u8 *buf = ArenaPush(bufArena, STREAM_LIMIT); // temporary buffer for receiving and sending data - Message *mbuf = (Message *)buf; // pointer for indexing buf as a message - - Arena *fdsArena = ArenaAlloc(); // arena for fds to accomodate multiple clients - struct pollfd *fds = fdsArena->memory; // helper for indexing memory - struct pollfd c = {0, POLLIN, 0}; // helper client structure fore reusing - struct pollfd *fdsAddr; // used for copying clients - const u64 pollfd_size = sizeof(struct pollfd); + Arena* msgsArena = ArenaAlloc(Megabytes(128)); // storing received messages + // NOTE: sent messages? + s32 nrecv = 0; // number of bytes received + s32 nsend = 0; // number of bytes sent + + Arena* clientsArena = ArenaAlloc(MAX_CONNECTIONS * sizeof(Client)); + Arena* fdsArena = ArenaAlloc(MAX_CONNECTIONS * sizeof(struct pollfd)); + struct pollfd* fds = fdsArena->addr; + Client* clients = clientsArena->addr; + + struct pollfd* fdsAddr; + struct pollfd newpollfd = {-1, POLLIN, 0}; // initialize fds structure - // add stdin (c.fd == 0) - fdsAddr = ArenaPush(fdsArena, pollfd_size); - memcpy(fdsAddr, &c, pollfd_size); + newpollfd.fd = 0; + fdsAddr = ArenaPush(fdsArena, sizeof(*fds)); + memcpy(fdsAddr, &newpollfd, sizeof(*fds)); // add serverfd - c.fd = serverfd; - fdsAddr = ArenaPush(fdsArena, pollfd_size); - memcpy(fdsAddr, &c, pollfd_size); + newpollfd.fd = serverfd; + fdsAddr = ArenaPush(fdsArena, sizeof(*fds)); + memcpy(fdsAddr, &newpollfd, sizeof(*fds)); + newpollfd.fd = -1; + + // Initialize the rest of the fds array + for (u32 i = FDS_CLIENTS; i < MAX_CONNECTIONS; i++) + fds[i] = newpollfd; while (1) { err = poll(fds, FDS_SIZE, TIMEOUT); @@ -83,81 +94,165 @@ int main(void) clientfd = accept(serverfd, NULL, NULL); assert(clientfd != -1); assert(clientfd > serverfd); + fprintf(stdout, "New connection(%d).\n", clientfd); // fill up a hole - u8 found = 0; - for (u32 i = FDS_CLIENTS; i < FDS_SIZE; i++) { - if (fds[i].fd == -1) { - fds[i].fd = clientfd; - // note we do not have to reset .revents because poll will set it to 0 next time - found = 1; + u8 found; + for (found = FDS_CLIENTS; found < FDS_SIZE; found++) + if (fds[found].fd == -1) break; - } - } - - // allocate an extra client because there was no empty spot in the fds array - if (!found) { - // add client to arena - fdsAddr = ArenaPush(fdsArena, pollfd_size); - c.fd = clientfd; - memcpy(fdsAddr, &c, pollfd_size); + if (found == FDS_SIZE) { + // no more space, allocate + struct pollfd* pollfd = ArenaPush(fdsArena, sizeof(*pollfd)); + pollfd->fd = clientfd; + pollfd->events = POLLIN; + } else if (found == MAX_CONNECTIONS) { + // TODO: reject connection + close(clientfd); + fprintf(stdout, "Max clients reached."); + } else { + // hole found + fds[found].fd = clientfd; + fds[found].events = POLLIN; + fprintf(stdout, "Added pollfd(%d).\n", clientfd); } - - nclient++; - fprintf(stdout, "connected(%d).\n", clientfd - serverfd); } + // Check for messages from clients for (u32 i = FDS_CLIENTS; i < (FDS_SIZE); i++) { if (!(fds[i].revents & POLLIN)) continue; - if (fds[i].fd == -1) + assert(fds[i].fd != -1); + + fprintf(stdout, "Message(%d).\n", fds[i].fd); + // If this is the first message from the client it must be a presence message indicated + // it connected. + Client* client = clients + i - FDS_CLIENTS; + if (!client->initialized) { + fprintf(stdout, " Adding to clients(%d).\n", fds[i].fd); + // Wait for PresenceMessage from new client to get author information + HeaderMessage header; + // TODO: handle wrong message, disconnection, etc. + nrecv = recv(clientfd, &header, sizeof(header), 0); + assert(nrecv != -1); + assert(nrecv == sizeof(header)); + if (header.type != HEADER_TYPE_PRESENCE) { + // TODO: reject connection + close(clientfd); + continue; + } + fprintf(stdout, " Got header(%d).\n", fds[i].fd); + + PresenceMessage message; + // TODO: handle wrong message + nrecv = recv(clientfd, &message, sizeof(message), 0); + assert(nrecv != -1); + assert(nrecv == sizeof(message)); + fprintf(stdout, " Got presence message(%d).\n", fds[i].fd); + + memcpy(client->author, message.author, AUTHOR_LEN); + client->initialized = True; + + fprintf(stdout, " Added to clients(%d): %s\n", fds[i].fd, client->author); + + // Notify other clients from this new one + // Reuse header and message + for (u32 j = FDS_CLIENTS; j < (FDS_SIZE); j++) { + if (fds[j].fd == fds[i].fd) + continue; + if (fds[j].fd == -1) + continue; + fprintf(stdout, " Notifying (%d)\n", fds[j].fd); + + nsend = send(fds[j].fd, &header, sizeof(header), 0); + assert(nsend != -1); + assert(nsend == sizeof(header)); + nsend = send(fds[j].fd, &message, sizeof(message), 0); + assert(nsend != -1); + assert(nsend == sizeof(message)); + } continue; + } - nrecv = recv(fds[i].fd, buf, bufArena->pos, 0); + // We received a message, try to parse the header + HeaderMessage header; + nrecv = recv(fds[i].fd, &header, sizeof(header), 0); assert(nrecv != -1); if (nrecv == 0) { - fprintf(stdout, "disconnected(%d). \n", fds[i].fd - serverfd); + fprintf(stdout, "Disconnected(%d). \n", fds[i].fd); shutdown(fds[i].fd, SHUT_RDWR); - close(fds[i].fd); // send close to client - fds[i].fd = -1; // ignore in the future + close(fds[i].fd); // send close to client + fds[i].fd = -1; // ignore in the future + clients[i - FDS_CLIENTS].initialized = False; // deinitialize client + // + // Send disconnection to other connected clients + HeaderMessage header = HEADER_PRESENCEMESSAGE; + PresenceMessage message = { + .type = PRESENCE_TYPE_DISCONNECTED + }; + memcpy(message.author, clients[i - FDS_CLIENTS].author, AUTHOR_LEN); + for (u32 j = FDS_CLIENTS; j < FDS_SIZE; j++) { + if (fds[j].fd == fds[i].fd) + continue; + if (fds[j].fd == -1) + continue; + nsend = send(fds[j].fd, &header, sizeof(header), 0); + assert(nsend != -1); + assert(nsend == sizeof(header)); + nsend = send(fds[j].fd, &message, sizeof(message), 0); + assert(nsend != -1); + assert(nsend == sizeof(message)); + } + continue; } - recv_len = sizeof(*mbuf) - sizeof(mbuf->text) + mbuf->text_len * sizeof(*mbuf->text); - if (recv_len > nrecv) { - // allocate needed space for buf - if (recv_len > bufArena->pos) - ArenaPush(bufArena, recv_len - bufArena->pos); - - // receive remaining bytes - u32 nr = recv(fds[i].fd, buf + nrecv, recv_len - nrecv, 0); - assert(nr != -1); - nrecv += nr; - assert(nrecv == recv_len); - } + assert(nrecv == sizeof(header)); + fprintf(stderr, " Received(%d): %d bytes -> " PH_FMT "\n", fds[i].fd, nrecv, PH_ARG(header)); - // TODO: Do not print the message in the logs - fprintf(stdout, "message(%d): %d bytes.\n", fds[i].fd - serverfd, nrecv); + switch (header.type) { + case HEADER_TYPE_TEXT:; + TextMessage* message; + nrecv = recvTextMessage(msgsArena, fds[i].fd, &message); + fprintf(stderr, " Received(%d): %d bytes -> ", fds[i].fd, nrecv); + printTextMessage(message, 0); - for (u32 j = FDS_CLIENTS; j < (FDS_SIZE); j++) { - if (j == i) - continue; - if (fds[j].fd == -1) - continue; + HeaderMessage header = HEADER_TEXTMESSAGE; + // Send message to all other clients + for (u32 j = FDS_CLIENTS; j < FDS_SIZE; j++) { + if (fds[j].fd == fds[i].fd) continue; + if (fds[j].fd == -1) continue; - nsend = send(fds[j].fd, buf, nrecv, 0); - assert(nsend != 1); - assert(nsend == nrecv); - fprintf(stdout, "retransmitted(%d->%d).\n", fds[i].fd - serverfd, fds[j].fd - serverfd); - } + // NOTE: I wonder if this is more expensive than constructing a buffer and sending + // that + u32 nsend_total = 0; + nsend = send(fds[j].fd, &header, sizeof(header), 0); + assert(nsend != 1); + assert(nsend == sizeof(header)); + nsend_total += nsend; + nsend = send(fds[j].fd, message, TEXTMESSAGE_SIZE, 0); + assert(nsend != -1); + assert(nsend == TEXTMESSAGE_SIZE); + nsend_total += nsend; + nsend = send(fds[j].fd, &message->text, message->len * sizeof(*message->text), 0); + assert(nsend != -1); + assert(nsend == (message->len * sizeof(*message->text))); + nsend_total += nsend; - ArenaPop(msgTextArena, mrecv.text_len); + fprintf(stdout, " Retransmitted(%d->%d) %d bytes.\n", fds[i].fd, fds[j].fd, nsend_total); + } + break; + default: + fprintf(stdout, " Got unhandled message type '%s' from client %d", headerTypeString(header.type), fds[i].fd); + continue; + } } } + ArenaRelease(clientsArena); ArenaRelease(fdsArena); - ArenaRelease(msgTextArena); + ArenaRelease(msgsArena); return 0; } |