diff options
Diffstat (limited to 'server.c')
-rw-r--r-- | server.c | 462 |
1 files changed, 238 insertions, 224 deletions
@@ -39,8 +39,8 @@ enum { FDS_STDIN = 0, typedef struct { u8 author[AUTHOR_LEN]; // matches author property on other message types ID id; - struct pollfd* pollunifd; // Index in fds array - struct pollfd* pollbifd; // Index in fds array + struct pollfd* unifd; // Index in fds array + struct pollfd* bifd; // Index in fds array } Client; #define CLIENT_FMT "[%s](%lu)" #define CLIENT_ARG(client) client.author, client.id @@ -50,24 +50,42 @@ typedef enum { BIFD } ClientFD; -// TODO: remove +// TODO: remove global variable // For handing out new ids to connections. -global_variable u32 nclients = 0; +// Start at 1 because this makes 0 an invalid client id. +global_variable u32 nclients = 1; -// Returns client matching id in clients. -// clientsArena is used to get an upper bound. -// Returns 0 if there was no client found. +// Returns client matching id in clients nclients number of clients. +// Returns 0 if no client was found or if id was 0. Client* -getClientByID(Arena* clientsArena, ID id) +getClientByID(Client* clients, u32 nclients, ID id) { - Client* clients = clientsArena->addr; - for (u32 i = 0; i < (clientsArena->pos / sizeof(*clients)); i++) { + if (!id) return 0; + + for (u32 i = 0; i < nclients - 1; i++) + { if (clients[i].id == id) return clients + i; } return 0; } +// Returns client matching fd in clients nclients number of clients. +// Returns 0 if no clients was found or if fd was -1. +Client* +getClientByFD(Client* clients, u32 nclients, s32 fd) +{ + if (fd == -1) return 0; + + for (u32 i = 0; i < nclients - 1; i++) + { + if ((clients[i].unifd && clients[i].unifd->fd == fd) || + (clients[i].bifd && clients[i].bifd->fd == fd)) + return clients + i; + } + return 0; +} + // Print TextMessage prettily void printTextMessage(TextMessage* message, Client* client, u8 wide) @@ -75,7 +93,8 @@ printTextMessage(TextMessage* message, Client* client, u8 wide) u8 timestamp[TIMESTAMP_LEN] = {0}; formatTimestamp(timestamp, message->timestamp); - if (wide) { + if (wide) + { setlocale(LC_ALL, ""); wprintf(L"TextMessage: %s [%s] %ls\n", timestamp, client->author, (wchar_t*)&message->text); } else { @@ -89,60 +108,81 @@ printTextMessage(TextMessage* message, Client* client, u8 wide) // for connfd. // Type will filter out only connections matching the type. void -sendToOthers(struct pollfd* fds, u32 nfds, s32 connfd, ClientFD type, HeaderMessage* header, void* anyMessage) +sendToOthers(Client* clients, u32 nclients, Client* client, ClientFD type, HeaderMessage* header, void* anyMessage) { s32 nsend; - for (u32 i = FDS_CLIENTS + type; i < nfds; i += 2) { - if (fds[i].fd == connfd) continue; - if (fds[i].fd == -1) continue; + for (u32 i = 0; i < nclients - 1; i ++) + { + if (clients + i == client) continue; - nsend = sendAnyMessage(fds[i].fd, header, anyMessage); - loggingf("sendToOthers(%d)|[%s]->%d %d bytes\n", connfd, headerTypeString(header->type), fds[i].fd, nsend); + if (type == UNIFD) + { + nsend = sendAnyMessage(client->unifd->fd, *header, anyMessage); + } + else if (type == BIFD) + { + nsend = sendAnyMessage(client->bifd->fd, *header, anyMessage); + } + assert(nsend != -1); + loggingf("sendToOthers "CLIENT_FMT"|%s %d bytes\n", CLIENT_ARG((*client)), headerTypeString(header->type), nsend); } } // Send header and anyMessage to each connection in fds that is nfds number of connections. // Type will filter out only connections matching the type. void -sendToAll(struct pollfd* fds, u32 nfds, ClientFD type, HeaderMessage* header, void* anyMessage) +sendToAll(Client* clients, u32 nclients, ClientFD type, HeaderMessage* header, void* anyMessage) { - for (u32 i = FDS_CLIENTS + type; i < nfds; i += 2) { - if (fds[i].fd == -1) continue; - s32 nsend = sendAnyMessage(fds[i].fd, header, anyMessage); - loggingf("sendToAll|[%s]->%d %d bytes\n", headerTypeString(header->type), fds[i].fd, nsend); + s32 nsend; + for (u32 i = 0; i < nclients - 1; i++) + { + if (type == UNIFD) + { + if (clients[i].unifd && clients[i].unifd->fd != -1) + nsend = sendAnyMessage(clients[i].unifd->fd, *header, anyMessage); + } + else if (type == BIFD) + { + if (clients[i].bifd && clients[i].bifd->fd != -1) + nsend = sendAnyMessage(clients[i].bifd->fd, *header, anyMessage); + } + assert(nsend != -1); + loggingf("sendToAll|[%s]->"CLIENT_FMT" %d bytes\n", headerTypeString(header->type), + CLIENT_ARG(clients[i]), + nsend); } } // Disconnect a client by closing the matching file descriptors void -disconnect(struct pollfd* pollfd, Client* client) +disconnect(Client* client) { loggingf("Disconnecting "CLIENT_FMT"\n", CLIENT_ARG((*client))); - if (pollfd[UNIFD].fd != -1) { - close(pollfd[UNIFD].fd); - } - if (pollfd[BIFD].fd != -1) { - close(pollfd[BIFD].fd); + if (client->unifd && client->unifd->fd != -1) + { + close(client->unifd->fd); + client->unifd->fd = -1; + client->unifd = 0; } - pollfd[UNIFD].fd = -1; - pollfd[BIFD].fd = -1; - // TODO: mark as free - if (client) { - client->pollunifd = 0; - client->pollbifd = 0; + if (client->bifd && client->bifd->fd != -1) + { + close(client->bifd->fd); + client->bifd->fd = -1; + client->bifd = 0; } } // Disconnects fds+conn from fds with nfds connections, then send a PresenceMessage to other // clients about disconnection. void -disconnectAndNotify(Client* client, struct pollfd* fds, u32 nfds, u32 conn) +disconnectAndNotify(Client* clients, u32 nclients, Client* client) { - disconnect(fds + conn, client); + disconnect(client); local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE); - PresenceMessage message = {.id = client->id, .type = PRESENCE_TYPE_DISCONNECTED}; - sendToAll(fds, nfds, UNIFD, &header, &message); + header.id = client->id; + PresenceMessage message = {.type = PRESENCE_TYPE_DISCONNECTED}; + sendToAll(clients, nclients, UNIFD, &header, &message); } // Receive authentication from pollfd->fd and create client out of it. Look in @@ -150,83 +190,76 @@ disconnectAndNotify(Client* client, struct pollfd* fds, u32 nfds, u32 conn) // to clients_file. // See "Authentication" in chatty.h Client* -authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* clientfds) +authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* pollfd, HeaderMessage header) { s32 nrecv = 0; - Client* clients = clientsArena->addr; - - HeaderMessage header; - nrecv = recv(clientfds[BIFD].fd, &header, sizeof(header), 0); - if (nrecv != sizeof(header)) { - loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(header)); - return 0; - } - loggingf("authenticate(%d)|" HEADER_FMT "\n", clientfds[BIFD].fd, HEADER_ARG(header)); - Client* client = 0; - // Scenario 1: Search for existing client - if (header.type == HEADER_TYPE_ID) { - IDMessage message; - nrecv = recv(clientfds[BIFD].fd, &message, sizeof(message), 0); - if (nrecv != sizeof(message)) { - loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(message)); - return 0; - } - client = getClientByID(clientsArena, message.id); - if (client) { - loggingf("authenticate(%d)|found [%s](%lu)\n", clientfds[BIFD].fd, client->author, client->id); - header.type = HEADER_TYPE_ERROR; - // TODO: allow multiple connections - if (client->pollunifd != 0 || client->pollbifd != 0) { - loggingf("authenticate(%d)|err: already connected\n", clientfds[BIFD].fd); - ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_ALREADYCONNECTED); - sendAnyMessage(clientfds[BIFD].fd, &header, &error_message); - return 0; - } - ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_SUCCESS); - sendAnyMessage(clientfds[BIFD].fd, &header, &error_message); - } else { - loggingf("authenticate(%d)|notfound\n", clientfds[BIFD].fd); + /* Scenario 1: Search for existing client */ + if (header.type == HEADER_TYPE_ID) + { + client = getClientByID((Client*)clientsArena->addr, nclients, header.id); + if (!client) + { + loggingf("authenticate(%d)|notfound\n", pollfd->fd); header.type = HEADER_TYPE_ERROR; ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_NOTFOUND); - sendAnyMessage(clientfds[BIFD].fd, &header, &error_message); + sendAnyMessage(pollfd->fd, header, &error_message); return 0; } - // Scenario 2: Create a new client - } else if (header.type == HEADER_TYPE_INTRODUCTION) { + + loggingf("authenticate(%d)|found [%s](%lu)\n", pollfd->fd, client->author, client->id); + header.type = HEADER_TYPE_ERROR; + if (!client->unifd) + client->unifd = pollfd; + else if(!client->bifd) + client->bifd = pollfd; + else + assert(0); + + ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_SUCCESS); + header.type = HEADER_TYPE_ERROR; + sendAnyMessage(pollfd->fd, header, &error_message); + + return client; + } + + /* Scenario 2: Create a new client */ + if (header.type == HEADER_TYPE_INTRODUCTION) + { IntroductionMessage message; - nrecv = recv(clientfds[BIFD].fd, &message, sizeof(message), 0); - if (nrecv != sizeof(message)) { - loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(message)); + nrecv = recv(pollfd->fd, &message, sizeof(message), 0); + if (nrecv != sizeof(message)) + { + loggingf("authenticate(%d)|err: %d/%lu bytes\n", pollfd->fd, nrecv, sizeof(message)); return 0; } // Copy metadata from IntroductionMessage - client = ArenaPush(clientsArena, sizeof(*clients)); + client = ArenaPush(clientsArena, sizeof(*client)); memcpy(client->author, message.author, AUTHOR_LEN); client->id = nclients; nclients++; - // Save client #ifdef IMPORT_ID write(clients_file, client, sizeof(*client)); #endif - loggingf("authenticate(%d)|Added [%s](%lu)\n", clientfds[BIFD].fd, client->author, client->id); + loggingf("authenticate(%d)|Added [%s](%lu)\n", pollfd->fd, client->author, client->id); + // Send ID to new client HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); - IDMessage id_message = {.id = client->id}; - sendAnyMessage(clientfds[BIFD].fd, &header, &id_message); - } else { - loggingf("authenticate(%d)|Wrong header expected %s or %s\n", clientfds[BIFD].fd, headerTypeString(HEADER_TYPE_INTRODUCTION), headerTypeString(HEADER_TYPE_ID)); - return 0; - } - assert(client != 0); + header.id = client->id; + s32 nsend = send(pollfd->fd, &header, sizeof(header), 0); + assert(nsend != -1); + assert(nsend == sizeof(header)); - client->pollunifd = clientfds; - client->pollbifd = clientfds + 1; + return client; + } - return client; + loggingf("authenticate(%d)|Wrong header expected %s or %s\n", pollfd->fd, + headerTypeString(HEADER_TYPE_INTRODUCTION), + headerTypeString(HEADER_TYPE_ID)); + return 0; } int @@ -236,9 +269,11 @@ main(int argc, char** argv) logfd = 2; // optional logging - if (argc > 1) { + if (argc > 1) + { if (*argv[1] == '-') - if (argv[1][1] == 'l') { + if (argv[1][1] == 'l') + { logfd = open(LOGFILE, O_RDWR | O_CREAT | O_TRUNC, 0600); assert(logfd != -1); } @@ -299,192 +334,171 @@ main(int argc, char** argv) assert(fstat(clients_file, &statbuf) != -1); read(clients_file, clients, statbuf.st_size); - if (statbuf.st_size > 0) { + if (statbuf.st_size > 0) + { ArenaPush(&clientsArena, statbuf.st_size); loggingf("Imported %lu client(s)\n", statbuf.st_size / sizeof(*clients)); nclients += statbuf.st_size / sizeof(*clients); + + // Reset pointers on imported clients + for (u32 i = 0; i < nclients - 1; i++) + { + clients[i].unifd = 0; + clients[i].bifd = 0; + } } - for (u32 i = 0; i < nclients; i++) + for (u32 i = 0; i < nclients - 1; i++) loggingf("Imported: " CLIENT_FMT "\n", CLIENT_ARG(clients[i])); #endif // Initialize the rest of the fds array for (u32 i = FDS_CLIENTS; i < MAX_CONNECTIONS; i++) fds[i] = newpollfd; - // Reset file descriptors on imported clients - for (u32 i = 0; i < CLIENTS_SIZE; i++) { - clients[i].pollunifd = 0; - clients[i].pollbifd = 0; - } - while (1) { + while (1) + { s32 err = poll(fds, FDS_SIZE, TIMEOUT); assert(err != -1); - if (fds[FDS_STDIN].revents & POLLIN) { + if (fds[FDS_STDIN].revents & POLLIN) + { u8 c; // exit on ctrl-d if (!read(fds[FDS_STDIN].fd, &c, 1)) break; - } else if (fds[FDS_SERVER].revents & POLLIN) { + } + else if (fds[FDS_SERVER].revents & POLLIN) + { // TODO: what if we are not aligned by 2 anymore? - s32 unifd = accept(serverfd, 0, 0); - s32 bifd = accept(serverfd, 0, 0); + s32 clientfd = accept(serverfd, 0, 0); - if (unifd == -1 || bifd == -1) { - loggingf("Error while accepting connection (%d,%d)\n", unifd, bifd); - if (unifd != -1) close(unifd); - if (bifd != -1) close(bifd); + if (clientfd == -1) + { + loggingf("Error while accepting connection (%d)\n", clientfd); continue; - } else - loggingf("New connection(%d,%d)\n", unifd, bifd); + } + else + loggingf("New connection(%d)\n", clientfd); - // TODO: find empty space in arena - if (nclients + 1 == MAX_CONNECTIONS) { + // TODO: find empty space in arena (fragmentation) + if (nclients + 1 == MAX_CONNECTIONS) + { local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_ERROR); local_persist ErrorMessage message = ERROR_INIT(ERROR_TYPE_TOOMANYCONNECTIONS); - sendAnyMessage(unifd, &header, &message); - if (unifd != -1) - close(unifd); - if (bifd != -1) - close(bifd); + sendAnyMessage(clientfd, header, &message); + if (clientfd != -1) + close(clientfd); loggingf("Max clients reached. Rejected connection\n"); - } else { + } + else + { // no more space, allocate - struct pollfd* clientfds = ArenaPush(&fdsArena, 2 * sizeof(*clientfds)); - clientfds[UNIFD].fd = unifd; - clientfds[UNIFD].events = POLLIN; - clientfds[BIFD].fd = bifd; - clientfds[BIFD].events = POLLIN; - loggingf("Added pollfd(%d,%d)\n", unifd, bifd); + struct pollfd* pollfd = ArenaPush(&fdsArena, sizeof(*pollfd)); + pollfd->fd = clientfd; + loggingf("Added pollfd(%d)\n", clientfd); } } - // Check for messages from clients in their unifd - for (u32 conn = FDS_CLIENTS; conn < FDS_SIZE; conn += 2) { + for (u32 conn = FDS_CLIENTS; conn < FDS_SIZE; conn++) + { if (!(fds[conn].revents & POLLIN)) continue; if (fds[conn].fd == -1) continue; loggingf("Message unifd (%d)\n", fds[conn].fd); - // Get client associated with connection - Client* client = 0; - for (u32 j = 0; j < CLIENTS_SIZE; j++) { - if (!clients[j].pollunifd) - continue; - if (clients[j].pollunifd == fds + conn) { - client = clients + j; - break; - } - } - if (!client) { - loggingf("No client associated(%d)\n", fds[conn].fd); - close(fds[conn].fd); - continue; - } - loggingf("Found client(%lu) [%s] (%d)\n", client->id, client->author, fds[conn].fd); - // We received a message, try to parse the header HeaderMessage header; s32 nrecv = recv(fds[conn].fd, &header, sizeof(header), 0); - if (nrecv == 0) { - disconnectAndNotify(client, fds, FDS_SIZE, conn); - loggingf("Disconnected(%lu) [%s]\n", client->id, client->author); - continue; - } else if (nrecv != sizeof(header)) { - disconnectAndNotify(client, fds, FDS_SIZE, conn); - loggingf("error(%lu) [%s] %d/%lu bytes\n", client->id, client->author, nrecv, sizeof(header)); + assert(nrecv != -1); + + Client* client; + if (nrecv != sizeof(header)) + { + client = getClientByFD(clients, nclients, fds[conn].fd); + loggingf(CLIENT_FMT" %d/%lu bytes\n", CLIENT_ARG((*client)), nrecv, sizeof(header)); + if (client) + { + disconnectAndNotify(clients, nclients, client); + loggingf("Disconnected(%lu) [%s]\n", client->id, client->author); + } + else + { + loggingf("Got error from unauntheticated client\n"); + close(fds[conn].fd); + fds[conn].fd = -1; + } continue; } loggingf("Received(%d) -> " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header)); - switch (header.type) { - case HEADER_TYPE_TEXT: { - TextMessage* text_message = recvTextMessage(&msgsArena, fds[conn].fd); - loggingf("Received(%d)", fds[conn].fd); - printTextMessage(text_message, client, 0); + // Authentication + if (header.id) + { + client = getClientByID(clients, nclients, header.id); + if (!client) + { + loggingf("No client for id %d\n", fds[conn].fd); - sendToOthers(fds, FDS_SIZE, fds[conn].fd, UNIFD, &header, text_message); - } break; - // handle request for information about client id - default: - loggingf("Unhandled '%s' from client(%d)\n", headerTypeString(header.type), fds[conn].fd); - disconnectAndNotify(client, fds, FDS_SIZE, conn); - continue; - } - } + header.type = HEADER_TYPE_ERROR; + ErrorMessage message = ERROR_INIT(ERROR_TYPE_NOTFOUND); - // Check for messages from clients in their bifd - for (u32 conn = FDS_CLIENTS + BIFD; conn < FDS_SIZE; conn += 2) { - if (!(fds[conn].revents & POLLIN)) continue; - if (fds[conn].fd == -1) continue; - loggingf("Message bifd (%d)\n", fds[conn].fd); - - // Get client associated with connection - Client* client = 0; - for (u32 j = 0; j < CLIENTS_SIZE; j++) { - if (!clients[j].pollbifd) - continue; - if (clients[j].pollbifd == fds + conn) { - client = clients + j; - break; + sendAnyMessage(fds[conn].fd, header, &message); + + // Reject connection + fds[conn].fd = -1; + close(fds[conn].fd); + } + else + { + header.type = HEADER_TYPE_ERROR; + ErrorMessage message = ERROR_INIT(ERROR_TYPE_SUCCESS); + + sendAnyMessage(fds[conn].fd, header, &message); } } - if (!client) { + else + { loggingf("No client for connection(%d)\n", fds[conn].fd); -#ifdef IMPORT_ID - client = authenticate(&clientsArena, clients_file, fds + conn - 1); -#else - client = authenticate(&clientsArena, 0, fds + conn - 1); -#endif - // If the client sent an IDMessage but no ID was found authenticate() could return null - if (!client) { - loggingf("Could not initialize client\n"); - disconnect(fds + conn, 0); - } else { // client was added/connected + + client = authenticate(&clientsArena, clients_file, fds + conn, header); + + if (!client) + loggingf("Could not initialize client (%d)\n", fds[conn].fd); + else if (!client->bifd) + { + // Send connected message to other clients if this was the first time -> bifd is + // not set yet. local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE); - PresenceMessage message = {.id = client->id, .type = PRESENCE_TYPE_CONNECTED}; - sendToOthers(fds, FDS_SIZE, fds[conn - BIFD].fd, UNIFD, &header, &message); + header.id = client->id; + PresenceMessage message = {.type = PRESENCE_TYPE_CONNECTED}; + sendToOthers(clients, nclients, client, UNIFD, &header, &message); } - continue; } - loggingf("Found client(%lu) [%s] (%d)\n", client->id, client->author, fds[conn].fd); - - // We received a message, try to parse the header - HeaderMessage header; - s32 nrecv = recv(fds[conn].fd, &header, sizeof(header), 0); - if (nrecv == 0) { - disconnectAndNotify(client, fds, FDS_SIZE, conn); - loggingf("Disconnected(%lu) [%s]\n", client->id, client->author); - continue; - } else if (nrecv != sizeof(header)) { - disconnectAndNotify(client, fds, FDS_SIZE, conn); - loggingf("error(%lu) [%s] %d/%lu bytes\n", client->id, client->author, nrecv, sizeof(header)); - continue; - } - loggingf("Received(%d) -> " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header)); switch (header.type) { - case HEADER_TYPE_ID: { - // handle request for information about client id - IDMessage id_message; - nrecv = recv(fds[conn].fd, &id_message, sizeof(id_message), 0); - - Client* client = getClientByID(&clientsArena, id_message.id); - if (!client) { - local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_ERROR); - local_persist ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_NOTFOUND); - sendAnyMessage(fds[conn].fd, &header, &error_message); - loggingf("Could not find %lu\n", id_message.id); - break; - } + /* Send text message to all other clients */ + case HEADER_TYPE_TEXT: + { + TextMessage* text_message = recvTextMessage(&msgsArena, fds[conn].fd); + loggingf("Received(%d)", fds[conn].fd); + printTextMessage(text_message, client, 0); + + sendToOthers(clients, nclients, client, UNIFD, &header, text_message); + } break; + /* Send back client information */ + case HEADER_TYPE_ID: + { HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION); IntroductionMessage introduction_message; + header.id = client->id; memcpy(introduction_message.author, client->author, AUTHOR_LEN); - sendAnyMessage(fds[conn].fd, &header, &introduction_message); + s32 nrecv = sendAnyMessage(fds[conn].fd, header, &introduction_message); + assert(nrecv != -1); } break; default: - loggingf("Unhandled '%s' from client(%d)\n", headerTypeString(header.type), fds[conn].fd); - disconnectAndNotify(client, fds, FDS_SIZE, conn); + loggingf("Unhandled '%s' from "CLIENT_FMT"(%d)\n", headerTypeString(header.type), + CLIENT_ARG((*client)), + fds[conn].fd); + disconnectAndNotify(client, nclients, client); continue; } } |