aboutsummaryrefslogtreecommitdiff
path: root/source/server.c
diff options
context:
space:
mode:
Diffstat (limited to 'source/server.c')
-rw-r--r--source/server.c581
1 files changed, 581 insertions, 0 deletions
diff --git a/source/server.c b/source/server.c
new file mode 100644
index 0000000..a6613d6
--- /dev/null
+++ b/source/server.c
@@ -0,0 +1,581 @@
+#include <errno.h>
+#include <fcntl.h>
+#include <netinet/in.h>
+#include <poll.h>
+#include <signal.h>
+#include <stdarg.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+/* Assertion macro */
+#ifndef Assert
+#ifdef DEBUG
+#define Assert(expr) if (!(expr)) { \
+ raise(SIGTRAP); \
+}
+#else
+#define Assert(expr) if (!(expr)) { \
+ raise(SIGTRAP); \
+}
+#endif // DEBUG
+#endif // Assert
+
+/* Dependencies */
+#define CHATTY_IMPL
+#include "chatty.h"
+#undef CHATTY_IMPL
+
+#define ARENA_IMPL
+#include "arena.h"
+#undef ARENA_IMPL
+#include "protocol.h"
+
+/* Configuration options */
+// timeout on polling
+#define TIMEOUT 60 * 1000
+// max pending connections
+#define MAX_CONNECTIONS 1600
+// Get number of connections from arena position
+// NOTE: this is somewhat wrong, because of when disconnections happen
+#define FDS_SIZE (fdsArena.pos / sizeof(struct pollfd))
+#define CLIENTS_SIZE (clientsArena.pos / sizeof(Client))
+
+// Where to save clients
+#define CLIENTS_FILE "_clients"
+// Where to write logs
+#define LOGFILE "server.log"
+// Log to LOGFILE instead of stderr
+// #define LOGGING
+
+// enum for indexing the fds array
+enum { FDS_STDIN = 0,
+ FDS_SERVER,
+ FDS_CLIENTS };
+
+// Client information
+typedef struct {
+ u8 author[AUTHOR_LEN]; // matches author property on other message types
+ ID id;
+ struct pollfd* bifd; // Index in fds array
+ struct pollfd* unifd; // Index in fds array
+} Client;
+#define CLIENT_FMT "[%s](%lu)"
+#define CLIENT_ARG(client) client.author, client.id
+
+typedef enum {
+ BIFD = 0,
+ UNIFD,
+} ClientFD;
+
+// TODO: remove global variable
+// For handing out new ids to connections.
+// Start at 1 because this makes 0 an invalid client id.
+global_variable u32 nclients = 1;
+
+// Returns client matching id in clients nclients number of clients.
+// Returns 0 if no client was found or if id was 0.
+Client*
+getClientByID(Client* clients, u32 nclients, ID id)
+{
+ if (!id) return 0;
+
+ for (u32 i = 0; i < nclients; i++)
+ {
+ if (clients[i].id == id)
+ return clients + i;
+ }
+ return 0;
+}
+
+// Returns client matching fd in clients nclients number of clients.
+// Returns 0 if no clients was found or if fd was -1.
+Client*
+getClientByFD(Client* clients, u32 nclients, s32 fd)
+{
+ if (fd == -1) return 0;
+
+ for (u32 i = 0; i < nclients; i++)
+ {
+ if ((clients[i].unifd && clients[i].unifd->fd == fd) ||
+ (clients[i].bifd && clients[i].bifd->fd == fd))
+ return clients + i;
+ }
+ return 0;
+}
+
+// Print TextMessage prettily
+void
+printTextMessage(TextMessage* message, Client* client, u8 wide)
+{
+ u8 timestamp[TIMESTAMP_LEN] = {0};
+ formatTimestamp(timestamp, message->timestamp);
+
+ if (wide)
+ {
+ setlocale(LC_ALL, "");
+ wprintf(L"TextMessage: %s [%s] %ls\n", timestamp, client->author, (wchar_t*)&message->text);
+ } else {
+ u8 str[message->len];
+ wcstombs((char*)str, (wchar_t*)&message->text, message->len * sizeof(*message->text));
+ LoggingF("TextMessage: %s [%s] (%d)%s\n", timestamp, client->author, message->len, str);
+ }
+}
+
+// Send header and anyMessage to each connection in fds that is nfds number of connections except
+// for connfd.
+// Does not send if pollfd is not set or pollfd->fd is -1.
+// Type will filter out only connections matching the type.
+void
+sendToOthers(Client* clients, u32 nclients, Client* client, ClientFD type, HeaderMessage* header, void* anyMessage)
+{
+ s32 nsend, fd;
+ for (u32 i = 0; i < nclients - 1; i ++)
+ {
+ if (clients + i == client) continue;
+
+ if (type == UNIFD)
+ {
+ if (clients[i].unifd && clients[i].unifd->fd != -1)
+ fd = clients[i].unifd->fd;
+ else
+ continue;
+ }
+ else if (type == BIFD)
+ {
+ if (clients[i].bifd && clients[i].bifd->fd != -1)
+ fd = clients[i].bifd->fd;
+ else
+ continue;
+ }
+ nsend = sendAnyMessage(fd, *header, anyMessage);
+
+ assert(nsend != -1);
+ LoggingF("sendToOthers "CLIENT_FMT"|%d<-%s %d bytes\n", CLIENT_ARG((clients[i])), fd, headerTypeString(header->type), nsend);
+ }
+}
+
+// Send header and anyMessage to each connection in fds that is nfds number of connections.
+// Does not send if pollfd is not set or pollfd->fd is -1.
+// Type will filter out only connections matching the type.
+void
+sendToAll(Client* clients, u32 nclients, ClientFD type, HeaderMessage* header, void* anyMessage)
+{
+ s32 nsend;
+ for (u32 i = 0; i < nclients - 1; i++)
+ {
+ if (type == UNIFD)
+ {
+ if (clients[i].unifd && clients[i].unifd->fd != -1)
+ nsend = sendAnyMessage(clients[i].unifd->fd, *header, anyMessage);
+ else
+ continue;
+ }
+ else if (type == BIFD)
+ {
+ if (clients[i].bifd && clients[i].bifd->fd != -1)
+ nsend = sendAnyMessage(clients[i].bifd->fd, *header, anyMessage);
+ else
+ continue;
+ }
+ else
+ assert(0);
+ assert(nsend != -1);
+ LoggingF("sendToAll|[%s]->"CLIENT_FMT" %d bytes\n", headerTypeString(header->type),
+ CLIENT_ARG(clients[i]),
+ nsend);
+ }
+}
+
+// Disconnect a client by closing the matching file descriptors
+void
+disconnect(Client* client)
+{
+ LoggingF("Disconnecting "CLIENT_FMT"\n", CLIENT_ARG((*client)));
+ if (client->unifd && client->unifd->fd != -1)
+ {
+ close(client->unifd->fd);
+ client->unifd->fd = -1;
+ client->unifd = 0;
+ }
+ if (client->bifd && client->bifd->fd != -1)
+ {
+ close(client->bifd->fd);
+ client->bifd->fd = -1;
+ client->bifd = 0;
+ }
+}
+
+// Disconnects fds+conn from fds with nfds connections, then send a PresenceMessage to other
+// clients about disconnection.
+void
+disconnectAndNotify(Client* clients, u32 nclients, Client* client)
+{
+ disconnect(client);
+
+ local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE);
+ header.id = client->id;
+ PresenceMessage message = {.type = PRESENCE_TYPE_DISCONNECTED};
+ sendToAll(clients, nclients, UNIFD, &header, &message);
+}
+
+// Receive authentication from pollfd->fd and create client out of it. Look in
+// clientsArena if it already exists. Otherwise push a new onto the arena and write its information
+// to clients_file.
+// See "Authentication" in chatty.h
+// Assumes that the client will send a IDMessage or IntroductionMessage
+// Returns authenticated client
+Client*
+authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* pollfd, HeaderMessage header)
+{
+ s32 nrecv = 0;
+ Client* client = 0;
+
+ LoggingF("authenticate (%d)|" HEADER_FMT "\n", pollfd->fd, HEADER_ARG(header));
+
+ /* Scenario 1: Search for existing client */
+ if (header.type == HEADER_TYPE_ID)
+ {
+ IDMessage message;
+ s32 nrecv = recv(pollfd->fd, &message, sizeof(message), 0);
+ assert(nrecv == sizeof(message));
+
+ client = getClientByID((Client*)clientsArena->addr, nclients, message.id);
+ if (!client)
+ {
+ LoggingF("authenticate (%d)|notfound\n", pollfd->fd);
+ header.type = HEADER_TYPE_ERROR;
+ ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_NOTFOUND);
+ sendAnyMessage(pollfd->fd, header, &error_message);
+ return 0;
+ }
+ else
+ {
+ LoggingF("authenticate (%d)|found [%s](%lu)\n", pollfd->fd, client->author, client->id);
+ header.type = HEADER_TYPE_ERROR;
+ ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_SUCCESS);
+ sendAnyMessage(pollfd->fd, header, &error_message);
+ }
+
+ if (!client->bifd)
+ client->bifd = pollfd;
+ else if (!client->unifd)
+ client->unifd = pollfd;
+ else
+ assert(0);
+
+
+ return client;
+ }
+ /* Scenario 2: Create a new client */
+ else if (header.type == HEADER_TYPE_INTRODUCTION)
+ {
+ IntroductionMessage message;
+ nrecv = recv(pollfd->fd, &message, sizeof(message), 0);
+ if (nrecv != sizeof(message))
+ {
+ LoggingF("authenticate (%d)|err: %d/%lu bytes\n", pollfd->fd, nrecv, sizeof(message));
+ return 0;
+ }
+
+ // Copy metadata from IntroductionMessage
+ client = ArenaPush(clientsArena, sizeof(*client));
+ memcpy(client->author, message.author, AUTHOR_LEN);
+ client->id = nclients;
+
+ if (!client->bifd)
+ client->bifd = pollfd;
+ else if (!client->unifd)
+ client->unifd = pollfd;
+ else
+ assert(0);
+
+ nclients++;
+
+#ifdef IMPORT_ID
+ write(clients_file, client, sizeof(*client));
+#endif
+ LoggingF("authenticate (%d)|Added [%s](%lu)\n", pollfd->fd, client->author, client->id);
+
+ // Send ID to new client
+ HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID);
+ IDMessage id_message;
+ id_message.id = client->id;
+
+ s32 nsend = sendAnyMessage(pollfd->fd, header, &id_message);
+ assert(nsend != -1);
+
+ return client;
+ }
+
+ LoggingF("authenticate (%d)|Wrong header expected %s or %s\n", pollfd->fd,
+ headerTypeString(HEADER_TYPE_INTRODUCTION),
+ headerTypeString(HEADER_TYPE_ID));
+ return 0;
+}
+
+int
+main(int argc, char** argv)
+{
+ signal(SIGPIPE, SIG_IGN);
+
+ LogFD = 2;
+ // optional logging
+ if (argc > 1)
+ {
+ if (*argv[1] == '-')
+ if (argv[1][1] == 'l')
+ {
+ LogFD = open(LOGFILE, O_RDWR | O_CREAT | O_TRUNC, 0600);
+ assert(LogFD != -1);
+ }
+ }
+
+ s32 serverfd;
+ // Start listening on the socket
+ {
+ s32 err;
+ u32 on = 1;
+ serverfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ assert(serverfd > 2);
+
+ err = setsockopt(serverfd, SOL_SOCKET, SO_REUSEADDR, (u8*)&on, sizeof(on));
+ assert(!err);
+
+ const struct sockaddr_in address = {
+ AF_INET,
+ htons(PORT),
+ {0},
+ {0},
+ };
+
+ err = bind(serverfd, (const struct sockaddr*)&address, sizeof(address));
+ assert(!err);
+
+ err = listen(serverfd, MAX_CONNECTIONS);
+ assert(!err);
+ LoggingF("Listening on :%d\n", PORT);
+ }
+
+ Arena clientsArena;
+ Arena fdsArena;
+ Arena msgsArena;
+ ArenaAlloc(&clientsArena, MAX_CONNECTIONS * sizeof(Client));
+ ArenaAlloc(&fdsArena, MAX_CONNECTIONS * 2 * sizeof(struct pollfd));
+ ArenaAlloc(&msgsArena, Megabytes(128)); // storing received messages
+ struct pollfd* fds = fdsArena.addr;
+ Client* clients = clientsArena.addr;
+
+ // Initializing fds
+ struct pollfd* fdsAddr;
+ struct pollfd newpollfd = {-1, POLLIN, 0}; // for copying with events already set
+ // initialize fds structure
+ newpollfd.fd = 0;
+ fdsAddr = ArenaPush(&fdsArena, sizeof(*fds));
+ memcpy(fdsAddr, &newpollfd, sizeof(*fds));
+ // add serverfd
+ newpollfd.fd = serverfd;
+ fdsAddr = ArenaPush(&fdsArena, sizeof(*fds));
+ memcpy(fdsAddr, &newpollfd, sizeof(*fds));
+ newpollfd.fd = -1;
+
+ s32 clients_file;
+#ifdef IMPORT_ID
+ clients_file = open(CLIENTS_FILE, O_RDWR | O_CREAT | O_APPEND, 0600);
+ assert(clients_file != -1);
+ struct stat statbuf;
+ assert(fstat(clients_file, &statbuf) != -1);
+
+ read(clients_file, clients, statbuf.st_size);
+ if (statbuf.st_size > 0)
+ {
+ ArenaPush(&clientsArena, statbuf.st_size);
+ LoggingF("Imported %lu client(s)\n", statbuf.st_size / sizeof(*clients));
+ nclients += statbuf.st_size / sizeof(*clients);
+
+ // Reset pointers on imported clients
+ for (u32 i = 0; i < nclients - 1; i++)
+ {
+ clients[i].unifd = 0;
+ clients[i].bifd = 0;
+ }
+ }
+ for (u32 i = 0; i < nclients - 1; i++)
+ LoggingF("Imported: " CLIENT_FMT "\n", CLIENT_ARG(clients[i]));
+#else
+ clients_file = 0;
+#endif
+
+ // Initialize the rest of the fds array
+ for (u32 i = FDS_CLIENTS; i < MAX_CONNECTIONS; i++)
+ fds[i] = newpollfd;
+
+ while (1)
+ {
+ s32 err = poll(fds, FDS_SIZE, TIMEOUT);
+ assert(err != -1);
+
+ if (fds[FDS_STDIN].revents & POLLIN)
+ {
+ u8 c; // exit on ctrl-d
+ if (!read(fds[FDS_STDIN].fd, &c, 1))
+ break;
+ }
+ else if (fds[FDS_SERVER].revents & POLLIN)
+ {
+ // TODO: what if we are not aligned by 2 anymore?
+ s32 clientfd = accept(serverfd, 0, 0);
+
+ if (clientfd == -1)
+ {
+ LoggingF("Error while accepting connection (%d)\n", clientfd);
+ continue;
+ }
+ else
+ LoggingF("New connection(%d)\n", clientfd);
+
+ // TODO: find empty space in arena (fragmentation)
+ if (nclients + 1 == MAX_CONNECTIONS)
+ {
+ local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_ERROR);
+ local_persist ErrorMessage message = ERROR_INIT(ERROR_TYPE_TOOMANYCONNECTIONS);
+ sendAnyMessage(clientfd, header, &message);
+ if (clientfd != -1)
+ close(clientfd);
+ LoggingF("Max clients reached. Rejected connection\n");
+ }
+ else
+ {
+ // no more space, allocate
+ struct pollfd* pollfd = ArenaPush(&fdsArena, sizeof(*pollfd));
+ pollfd->fd = clientfd;
+ LoggingF("Added pollfd(%d)\n", clientfd);
+ }
+ }
+
+ for (u32 conn = FDS_CLIENTS; conn < FDS_SIZE; conn++)
+ {
+ if (!(fds[conn].revents & POLLIN)) continue;
+ if (fds[conn].fd == -1) continue;
+ LoggingF("Message(%d)\n", fds[conn].fd);
+
+ // We received a message, try to parse the header
+ HeaderMessage header;
+ s32 nrecv = recv(fds[conn].fd, &header, sizeof(header), 0);
+ if(nrecv == -1)
+ {
+ LoggingF("Received error from fd: %d, errno: %d\n", fds[conn].fd, errno);
+ };
+
+ Client* client;
+ if (nrecv != sizeof(header))
+ {
+ client = getClientByFD(clients, nclients, fds[conn].fd);
+ if (client)
+ {
+ LoggingF("Received %d/%lu bytes "CLIENT_FMT"\n", nrecv, sizeof(header), CLIENT_ARG((*client)));
+ disconnectAndNotify(clients, nclients, client);
+ }
+ else
+ {
+ LoggingF("Got error/disconnect from unauthenticated client\n");
+ close(fds[conn].fd);
+ fds[conn].fd = -1;
+ }
+ continue;
+ }
+ LoggingF("Received(%d): " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header));
+
+ // Authentication
+ if (!header.id)
+ {
+ LoggingF("No client for connection(%d)\n", fds[conn].fd);
+
+ client = authenticate(&clientsArena, clients_file, fds + conn, header);
+
+ if (!client)
+ {
+ LoggingF("Could not initialize client (%d)\n", fds[conn].fd);
+ close(fds[conn].fd);
+ fds[conn].fd = -1;
+ }
+ /* This is the first time a message is sent, because unifd is not yet set. */
+ else if (!client->unifd)
+ {
+ LoggingF("Send connected message\n");
+ local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE);
+ header.id = client->id;
+ PresenceMessage message = {.type = PRESENCE_TYPE_CONNECTED};
+ sendToOthers(clients, nclients, client, UNIFD, &header, &message);
+ }
+ continue;
+ }
+
+ client = getClientByID(clients, nclients, header.id);
+ if (!client)
+ {
+ LoggingF("No client for id %d\n", fds[conn].fd);
+
+ header.type = HEADER_TYPE_ERROR;
+ ErrorMessage message = ERROR_INIT(ERROR_TYPE_NOTFOUND);
+
+ sendAnyMessage(fds[conn].fd, header, &message);
+
+ // Reject connection
+ fds[conn].fd = -1;
+ close(fds[conn].fd);
+ continue;
+ }
+
+ switch (header.type) {
+ /* Send text message to all other clients */
+ case HEADER_TYPE_TEXT:
+ {
+ TextMessage* text_message = recvTextMessage(&msgsArena, fds[conn].fd);
+ LoggingF("Received(%d): ", fds[conn].fd);
+ printTextMessage(text_message, client, 0);
+
+ sendToOthers(clients, nclients, client, UNIFD, &header, text_message);
+ } break;
+ /* Send back client information */
+ case HEADER_TYPE_ID:
+ {
+ IDMessage id_message;
+ s32 nrecv = recv(fds[conn].fd, &id_message, sizeof(id_message), 0);
+ assert(nrecv == sizeof(id_message));
+
+ client = getClientByID(clients, nclients, id_message.id);
+ if (!client)
+ {
+ header.type = HEADER_TYPE_ERROR;
+ ErrorMessage message = ERROR_INIT(ERROR_TYPE_NOTFOUND);
+ s32 nsend = sendAnyMessage(fds[conn].fd, header, &message);
+ assert(nsend != -1);
+ break;
+ }
+
+ HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION);
+ IntroductionMessage introduction_message;
+ header.id = client->id;
+ memcpy(introduction_message.author, client->author, AUTHOR_LEN);
+
+ nrecv = sendAnyMessage(fds[conn].fd, header, &introduction_message);
+ assert(nrecv != -1);
+ } break;
+ default:
+ LoggingF("Unhandled '%s' from "CLIENT_FMT"(%d)\n", headerTypeString(header.type),
+ CLIENT_ARG((*client)),
+ fds[conn].fd);
+ disconnectAndNotify(client, nclients, client);
+ continue;
+ }
+ }
+ }
+
+#ifdef IMPORT_ID
+ close(clients_file);
+#endif
+
+ return 0;
+}