From 4b3dfeddc15908fdf53df42818ed167addc71659 Mon Sep 17 00:00:00 2001 From: Raymaekers Luca Date: Sun, 3 Nov 2024 16:25:10 +0100 Subject: Add id field to HeaderMessage for simplifying - Changed bracket style as well --- .gitignore | 3 + README.md | 10 +- chatty.c | 345 +++++++++++++++++++++++++-------------------- protocol.h | 61 ++++---- send.c | 17 +-- server.c | 462 +++++++++++++++++++++++++++++++------------------------------ 6 files changed, 478 insertions(+), 420 deletions(-) diff --git a/.gitignore b/.gitignore index 91c2cfe..688b48f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ chatty send server + _id _clients *.log + +tags diff --git a/README.md b/README.md index c3a7d19..50e9b7a 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ The idea is the following: ## server - [x] import clients +- [ ] check that fds arena does not overflow - [ ] check if when sending and the client is offline (due to connection loss) what happens - [ ] timeout on recv? - [ ] use threads to handle clients/ timeout when receiving because a client could theoretically @@ -29,31 +30,32 @@ The idea is the following: - [ ] do not crash on errors from clients - implement error message? - timeout on recv with setsockopt +- [ ] theoretically two clients can connect at the same time. The uni/bi connections should be + negotiated. ## common - [x] handle messages that are too large - [x] refactor i&self into conn - [x] logging - [x] Req|Inf connection per client +- [x] connect/disconnect messages - [ ] bug: blocking after `Added pollfd`, after importing a client and then connecting with the id/or without? After reconnection fails chatty blocks (remove sleep) - [ ] connect/disconnections messages - [ ] use IP address / domain - [ ] chat history - [ ] asserting, logging if fail / halt execution +- [ ] compression ## Protocol - see `protocol.h` for more info - [ ] make sections per message - request chat logs from a certain point up to now (history) - connect to a specific room -- connect/disconnect messages - The null terminator must be sent with the string. - The text can be arbitrary length -- [ ] compression - ## Arena's 1. There is an arena for the messages' texts (`msgTextArena`) and an arena for the messages (`msgsArena`). @@ -82,4 +84,4 @@ Notice, that this depends on knowing the text's length before allocating the mem - *pthreads*: [C for dummies](https://c-for-dummies.com/blog/?p=5365) - *unicode and wide characters*: [C for dummies](https://c-for-dummies.com/blog/?p=2578) - *sockets*: [Nir Lichtman - Making Minimalist Chat Server in C on Linux](https://www.youtube.com/watch?v=gGfTjKwLQxY) -- syscall manpages +- syscall manpages `man` diff --git a/chatty.c b/chatty.c index 4746aa7..ccc722a 100644 --- a/chatty.c +++ b/chatty.c @@ -37,12 +37,6 @@ typedef struct { #define CLIENT_FMT "[%s](%lu)" #define CLIENT_ARG(client) client.author, client.id -typedef struct { - s32 err; // Error while connecting - s32 unifd; - s32 bifd; -} ConnectionResult; - // Client used by chatty global_variable Client user = {0}; // Address of chatty server @@ -74,7 +68,8 @@ getClientById(Arena* clientsArena, ID id) if (id == user.id) return &user; Client* clients = clientsArena->addr; - for (u64 i = 0; i < (clientsArena->pos / sizeof(*clients)); i++) { + for (u64 i = 0; i < (clientsArena->pos / sizeof(*clients)); i++) + { if (clients[i].id == id) return clients + i; } @@ -88,16 +83,17 @@ addClientInfo(Arena* clientsArena, s32 fd, u64 id) { // Request information about ID HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); - IDMessage id_message = {.id = id}; - sendAnyMessage(fd, &header, &id_message); - - Client* client = ArenaPush(clientsArena, sizeof(*client)); + header.id = id; + s32 nsend = send(fd, &header, sizeof(header), 0); + assert(nsend != -1); + assert(nsend == sizeof(header)); // Wait for response IntroductionMessage introduction_message; recvAnyMessageType(fd, &header, &introduction_message, HEADER_TYPE_INTRODUCTION); // Add the information + Client* client = ArenaPush(clientsArena, sizeof(*client)); memcpy(client->author, introduction_message.author, AUTHOR_LEN); client->id = id; @@ -106,17 +102,57 @@ addClientInfo(Arena* clientsArena, s32 fd, u64 id) } // Tries to connect to address and populates resulting file descriptors in ConnectionResult. -ConnectionResult +s32 getConnection(struct sockaddr_in* address) { - ConnectionResult result; - result.unifd = socket(AF_INET, SOCK_STREAM, 0); - result.bifd = socket(AF_INET, SOCK_STREAM, 0); - result.err = connect(result.unifd, (struct sockaddr*)address, sizeof(*address)); - if (result.err) return result; // We do not overwrite the error and return early so we can be - // certain of what error errno belongs to. - result.err = connect(result.bifd, (struct sockaddr*)address, sizeof(*address)); - return result; + s32 fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd == -1) return -1; + + s32 err = connect(fd, (struct sockaddr*)address, sizeof(*address)); + if (err) return -1; + + return fd; +} + +ID +authenticate(Client* user, s32 fd) +{ + if (user->id) + { + HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); + header.id = user->id; + s32 nsend = send(fd, &header, sizeof(header), 0); + assert(nsend == -1); + assert(nsend == sizeof(header)); + + s32 nrecv = recv(fd, &header, sizeof(header), 0); + if (nrecv == 0) + return 0; + assert(nrecv != -1); + assert(nrecv == sizeof(header)); + assert(header.type == HEADER_TYPE_ID); + if (header.id == user->id) + return header.id; + else + return 0; + } + else + { + HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION); + IntroductionMessage message; + memcpy(message.author, user->author, AUTHOR_LEN); + sendAnyMessage(fd, header, &message); + + s32 nrecv = recv(fd, &header, sizeof(header), 0); + if (nrecv == 0) + return 0; + assert(nrecv != -1); + assert(nrecv == sizeof(header)); + assert(header.type == HEADER_TYPE_ID); + assert(!header.id); + user->id = header.id; + return header.id; + } } // Connect to *address_ptr of type `struct sockaddr_in*`. If it failed wait for TIMEOUT_RECONNECT @@ -129,49 +165,44 @@ getConnection(struct sockaddr_in* address) void* threadReconnect(void* fds_ptr) { + s32 unifd, bifd; struct pollfd* fds = fds_ptr; - ConnectionResult result; struct timespec t = { 0, Miliseconds(300) }; // 300 miliseconds loggingf("Trying to reconnect\n"); - while (1) { + while (1) + { + // timeout nanosleep(&t, &t); - result = getConnection(&address); - if (result.err) { - // loggingf("err: %d\n", result.err); + + unifd = getConnection(&address); + if (unifd == -1) + { loggingf("errno: %d\n", errno); - } else if (result.unifd != -1 && result.bifd != -1) { - loggingf("Reconnect succeeded (%d, %d), authenticating\n", result.unifd, result.bifd); - // We assume that we already have an ID - // TODO: could there be a problem if a message is received at the same time? - // - not on server restart, but what if we lost connection? - HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); - IDMessage id_message = {.id = user.id}; - sendAnyMessage(result.bifd, &header, &id_message); - - ErrorMessage error_message; - s32 nrecv = recvAnyMessageType(result.bifd, &header, &error_message, HEADER_TYPE_ERROR); - if (nrecv == -1 || nrecv == 0) { - loggingf("Error on receive, retrying...\n"); - continue; - } + continue; + } + bifd = getConnection(&address); + if (bifd == -1) + { + loggingf("errno: %d\n", errno); + close(unifd); + continue; + } - assert(header.type == HEADER_TYPE_ERROR); - if (error_message.type == ERROR_TYPE_SUCCESS) { - loggingf("Reconnected\n"); - break; - } else { - loggingf("err: %s\n", errorTypeString(error_message.type)); - } + loggingf("Reconnect succeeded (%d, %d), authenticating\n", unifd, bifd); + + if (authenticate(&user, unifd) && + authenticate(&user, bifd)) + { + break; } - if (result.unifd != -1) - close(result.unifd); - if (result.bifd != -1) - close(result.bifd); - loggingf("Failed, retrying..\n"); + + close(unifd); + close(bifd); + loggingf("Failed, retrying...\n"); } - fds[FDS_BI].fd = result.bifd; - fds[FDS_UNI].fd = result.unifd; + fds[FDS_BI].fd = bifd; + fds[FDS_UNI].fd = unifd; // Redraw screen raise(SIGWINCH); @@ -206,16 +237,20 @@ tb_printf_wrap(u32 x, u32 y, u32 fg, u32 bg, u32* text, s32 text_len, u32 fg_pfx u32 failed = 0; // NOTE: We can assume that we need to wrap, therefore print a newline after the prefix string - if (pfx != 0) { + if (pfx != 0) + { tb_printf(x, ly, fg_pfx, bg_pfx, "%s", pfx); // If the text fits on one line print the text and return // Otherwise print the text on the next line s32 pfx_len = strlen((char*)pfx); - if (limit_x > pfx_len + text_len) { + if (limit_x > pfx_len + text_len) + { tb_printf(x + pfx_len, y, fg, bg, "%ls", text); return 1; - } else { + } + else + { ly++; } } @@ -235,18 +270,22 @@ tb_printf_wrap(u32 x, u32 y, u32 fg, u32 bg, u32* text, s32 text_len, u32 fg_pfx // 8. step 2. until i >= text_len // 9. print remaining part of the string - while (i < text_len && ly - y < limit_y) { + while (i < text_len && ly - y < limit_y) + { // search backwards for whitespace while (i > offset && text[i] != L' ') i--; // retry with bigger limit - if (i == offset) { + if (i == offset) + { offset = i; failed++; i += limit_x + failed * limit_x; continue; - } else { + } + else + { failed = 0; } @@ -261,7 +300,8 @@ tb_printf_wrap(u32 x, u32 y, u32 fg, u32 bg, u32* text, s32 text_len, u32 fg_pfx offset = i; i += limit_x; } - if ((u32)ly <= limit_y) { + if ((u32)ly <= limit_y) + { tb_printf(x, ly, fg, bg, "%ls", text + offset); ly++; } @@ -283,11 +323,14 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd* // the minimum height required is the hight for the box prompt // the minimum width required is that one character should fit in the box prompt if (global.height < box_height || - global.width < (box_x + box_mar_x * 2 + box_pad_x * 2 + box_bwith * 2 + 1)) { + global.width < (box_x + box_mar_x * 2 + box_pad_x * 2 + box_bwith * 2 + 1)) + { // + 1 for cursor tb_hide_cursor(); return; - } else { + } + else + { // show cursor // TODO: show cursor as block character instead of using the real cursor bytebuf_puts(&global.out, global.caps[TB_CAP_SHOW_CURSOR]); @@ -310,11 +353,14 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd* u32 offs = (nmessages > free_y) ? nmessages - free_y : 0; // skip offs ccount messages - for (u32 i = 0; i < offs; i++) { + for (u32 i = 0; i < offs; i++) + { HeaderMessage* header = (HeaderMessage*)addr; addr += sizeof(*header); - switch (header->type) { - case HEADER_TYPE_TEXT: { + switch (header->type) + { + case HEADER_TYPE_TEXT: + { TextMessage* message = (TextMessage*)addr; addr += TEXTMESSAGE_SIZE; addr += message->len * sizeof(*message->text); @@ -333,36 +379,41 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd* } // In each case statement advance the addr pointer by the size of the message - for (u32 i = offs; i < nmessages && msg_y < free_y; i++) { + for (u32 i = offs; i < nmessages && msg_y < free_y; i++) + { HeaderMessage* header = (HeaderMessage*)addr; addr += sizeof(*header); // Get Client for message - ID* id; Client* client; - switch (header->type) { + switch (header->type) + { case HEADER_TYPE_TEXT: - id = &((TextMessage*)addr)->id; case HEADER_TYPE_PRESENCE: - id = &((PresenceMessage*)addr)->id; - client = getClientById(clientsArena, *id); - if (!client) { + client = getClientById(clientsArena, header->id); + if (!client) + { loggingf("Client not known, requesting from server\n"); - client = addClientInfo(clientsArena, fds[FDS_BI].fd, *id); + client = addClientInfo(clientsArena, fds[FDS_BI].fd, header->id); } assert(client); break; } - switch (header->type) { - case HEADER_TYPE_TEXT: { + switch (header->type) + { + case HEADER_TYPE_TEXT: + { TextMessage* message = (TextMessage*)addr; // Color own messages u32 fg = 0; - if (user.id == message->id) { + if (user.id == header->id) + { fg = TB_CYAN; - } else { + } + else + { fg = TB_MAGENTA; } @@ -377,13 +428,15 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd* u32 message_size = TEXTMESSAGE_SIZE + message->len * sizeof(*message->text); addr += message_size; } break; - case HEADER_TYPE_PRESENCE: { + case HEADER_TYPE_PRESENCE: + { PresenceMessage* message = (PresenceMessage*)addr; tb_printf(0, msg_y, 0, 0, " [%s] *%s*", client->author, presenceTypeString(message->type)); msg_y++; addr += sizeof(*message); } break; - case HEADER_TYPE_HISTORY: { + case HEADER_TYPE_HISTORY: + { HistoryMessage* message = (HistoryMessage*)addr; addr += sizeof(*message); // TODO: implement @@ -445,17 +498,21 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd* if (freesp <= 0) return; - if (input_len > freesp) { + if (input_len > freesp) + { u32* text_offs = input + (input_len - freesp); tb_printf(box_x + box_mar_x + box_pad_x + box_bwith, box_y + 1, 0, 0, "%ls", text_offs); global.cursor_x = box_x + box_pad_x + box_mar_x + box_bwith + freesp; - } else { + } + else + { global.cursor_x = prompt_x; tb_printf(box_x + box_mar_x + box_pad_x + box_bwith, box_y + 1, 0, 0, "%ls", input); } } - if (fds[FDS_UNI].fd == -1 || fds[FDS_BI].fd == -1) { + if (fds[FDS_UNI].fd == -1 || fds[FDS_BI].fd == -1) + { // show error popup popup(TB_RED, TB_BLACK, "Server disconnected."); } @@ -465,7 +522,8 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd* int main(int argc, char** argv) { - if (argc < 2) { + if (argc < 2) + { fprintf(stderr, "usage: chatty \n"); return 1; } @@ -516,67 +574,42 @@ main(int argc, char** argv) {0}, }; - ConnectionResult result = getConnection(&address); - if (result.err) { - perror("Server"); - return 1; - } - assert(result.unifd != -1); - assert(result.bifd != -1); - assert(!result.err); - fds[FDS_BI].fd = result.bifd; - fds[FDS_UNI].fd = result.unifd; - #ifdef IMPORT_ID // File for storing the user's ID. u32 idfile = open(ID_FILE, O_RDWR | O_CREAT, 0600); s32 nread = read(idfile, &user.id, sizeof(user.id)); assert(nread != -1); - // see "Authentication" in chatty.h - if (nread == sizeof(user.id)) { - // Scenario 1: We know our id - - // Send IDMessage and check if it is correct - HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); - IDMessage message = {.id = user.id}; - sendAnyMessage(fds[FDS_BI].fd, &header, &message); - - ErrorMessage error_message = {0}; - recvAnyMessageType(fds[FDS_BI].fd, &header, &error_message, HEADER_TYPE_ERROR); - - switch (error_message.type) { - case ERROR_TYPE_SUCCESS: break; - case ERROR_TYPE_NOTFOUND: - printf("Server does not know our ID. Consider removing '" ID_FILE "'\n"); +#endif + /* Authentication */ + { + s32 unifd, bifd; + unifd = getConnection(&address); + if (unifd == -1) + { + loggingf("errno: %d\n", errno); return 1; - default: - printf("Server: %s\n", errorTypeString(error_message.type)); + } + bifd = getConnection(&address); + if (bifd == -1) + { + loggingf("errno: %d\n", errno); return 1; } - } else { -#else - if (1) { -#endif - // Scenario 2: We do not have an ID - HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION); - IntroductionMessage message = {0}; - // copy user data into message - memcpy(message.author, user.author, AUTHOR_LEN); - - // Send the introduction message - sendAnyMessage(fds[FDS_BI].fd, &header, &message); + if (!authenticate(&user, unifd) || + !authenticate(&user, bifd)) + { + loggingf("errno: %d\n", errno); + return 1; + } + fds[FDS_UNI].fd = unifd; + fds[FDS_BI].fd = bifd; + } - IDMessage id_message = {0}; - // Receive the response IDMessage - recvAnyMessageType(fds[FDS_BI].fd, &header, &id_message, HEADER_TYPE_ID); - assert(header.type == HEADER_TYPE_ID); - user.id = id_message.id; #ifdef IMPORT_ID - // Save permanently - write(idfile, &user.id, sizeof(user.id)); - close(idfile); + // Save id + write(idfile, &user.id, sizeof(user.id)); #endif - } + loggingf("Got ID: %lu\n", user.id); // for wide character printing @@ -590,21 +623,24 @@ main(int argc, char** argv) tb_present(); // main loop - while (!quit) { + while (!quit) + { err = poll(fds, FDS_MAX, TIMEOUT_POLL); // ignore resize events and use them to redraw the screen assert(err != -1 || errno == EINTR); tb_clear(); - if (fds[FDS_UNI].revents & POLLIN) { + if (fds[FDS_UNI].revents & POLLIN) + { // got data from server HeaderMessage header; nrecv = recv(fds[FDS_UNI].fd, &header, sizeof(header), 0); assert(nrecv != -1); // Server disconnects - if (nrecv == 0) { + if (nrecv == 0) + { // close diconnected server's socket err = close(fds[FDS_UNI].fd); assert(err == 0); @@ -612,8 +648,11 @@ main(int argc, char** argv) // start trying to reconnect in a thread err = pthread_create(&thr_rec, 0, &threadReconnect, (void*)fds); assert(err == 0); - } else { - if (header.version != PROTOCOL_VERSION) { + } + else + { + if (header.version != PROTOCOL_VERSION) + { loggingf("Header received does not match version\n"); continue; } @@ -622,7 +661,8 @@ main(int argc, char** argv) memcpy(addr, &header, sizeof(header)); // Messages handled from server - switch (header.type) { + switch (header.type) + { case HEADER_TYPE_TEXT: recvTextMessage(&msgsArena, fds[FDS_UNI].fd); nmessages++; @@ -641,15 +681,19 @@ main(int argc, char** argv) } } - if (fds[FDS_TTY].revents & POLLIN) { + if (fds[FDS_TTY].revents & POLLIN) + { // got a key event tb_poll_event(&ev); - switch (ev.key) { + switch (ev.key) + { case TB_KEY_CTRL_W: // delete consecutive whitespace - while (ninput) { - if (input[ninput - 1] == L' ') { + while (ninput) + { + if (input[ninput - 1] == L' ') + { input[ninput - 1] = 0; ninput--; continue; @@ -657,7 +701,8 @@ main(int argc, char** argv) break; } // delete until whitespace - while (ninput) { + while (ninput) + { if (input[ninput - 1] == L' ') break; // erase @@ -665,7 +710,8 @@ main(int argc, char** argv) ninput--; } break; - case TB_KEY_CTRL_Z: { + case TB_KEY_CTRL_Z: + { pid_t pid = getpid(); tb_shutdown(); kill(pid, SIGSTOP); @@ -692,10 +738,10 @@ main(int argc, char** argv) HeaderMessage header = HEADER_INIT(HEADER_TYPE_TEXT); void* addr = ArenaPush(&msgsArena, sizeof(header)); memcpy(addr, &header, sizeof(header)); + header.id = user.id; // Save message TextMessage* sendmsg = ArenaPush(&msgsArena, TEXTMESSAGE_SIZE); - sendmsg->id = user.id; sendmsg->timestamp = time(0); sendmsg->len = ninput; @@ -703,7 +749,7 @@ main(int argc, char** argv) ArenaPush(&msgsArena, text_size); memcpy(&sendmsg->text, input, text_size); - sendAnyMessage(fds[FDS_UNI].fd, &header, sendmsg); + sendAnyMessage(fds[FDS_UNI].fd, header, sendmsg); nmessages++; // also clear input @@ -728,7 +774,8 @@ main(int argc, char** argv) } // These are used to redraw the screen from threads - if (fds[FDS_RESIZE].revents & POLLIN) { + if (fds[FDS_RESIZE].revents & POLLIN) + { // ignore tb_poll_event(&ev); } diff --git a/protocol.h b/protocol.h index 84d636f..5d9344b 100644 --- a/protocol.h +++ b/protocol.h @@ -49,6 +49,7 @@ typedef u64 ID; typedef struct { u16 version; u8 type; + ID id; } HeaderMessage; typedef enum { @@ -60,7 +61,7 @@ typedef enum { HEADER_TYPE_ERROR } HeaderType; // shorthand for creating a header with a value from the enum -#define HEADER_INIT(t) {.version = PROTOCOL_VERSION, .type = t} +#define HEADER_INIT(t) {.version = PROTOCOL_VERSION, .type = t, .id = 0} // from Tsoding video on minicel (https://youtu.be/HCAgvKQDJng?t=4546) // sv(https://github.com/tsoding/sv) #define HEADER_FMT "header: v%d %s(%d)" @@ -69,11 +70,9 @@ typedef enum { // For sending texts to other clients // - 13 bytes for the author // - 8 bytes for the timestamp -// - 8 bytes for id // - 2 bytes for the text length // - x*4 bytes for the text typedef struct { - ID id; u64 timestamp; // timestamp of when the message was sent u16 len; wchar_t* text; // placeholder for indexing @@ -97,19 +96,9 @@ typedef struct { #define INTRODUCTION_FMT "introduction: %s" #define INTRODUCTION_ARG(message) message.author -// Request IntroductionMessage for client with that id. -// See "First connection" if this message is used when the client connects for the first time. -// be used to retrieve information about a client with an unknown ID. -// - 8 bytes for id -typedef struct { - ID id; -} IDMessage; - // Notifying the sender's state, such as "connected", "disconnected", "AFK", ... -// - 8 bytes for id // - 1 byte for type typedef struct { - ID id; u8 type; } PresenceMessage; typedef enum { @@ -141,11 +130,11 @@ typedef struct { u8* headerTypeString(HeaderType type) { - switch (type) { + switch (type) + { case HEADER_TYPE_TEXT: return (u8*)"TextMessage"; case HEADER_TYPE_HISTORY: return (u8*)"HistoryMessage"; case HEADER_TYPE_PRESENCE: return (u8*)"PresenceMessage"; - case HEADER_TYPE_ID: return (u8*)"IDMessage"; case HEADER_TYPE_INTRODUCTION: return (u8*)"IntroductionMessage"; case HEADER_TYPE_ERROR: return (u8*)"ErrorMessage"; default: return (u8*)"Unknown"; @@ -155,7 +144,8 @@ headerTypeString(HeaderType type) u8* presenceTypeString(PresenceType type) { - switch (type) { + switch (type) + { case PRESENCE_TYPE_CONNECTED: return (u8*)"connected"; case PRESENCE_TYPE_DISCONNECTED: return (u8*)"disconnected"; case PRESENCE_TYPE_AFK: return (u8*)"afk"; @@ -166,7 +156,8 @@ presenceTypeString(PresenceType type) u8* errorTypeString(ErrorType type) { - switch (type) { + switch (type) + { case ERROR_TYPE_BADMESSAGE: return (u8*)"bad message"; case ERROR_TYPE_NOTFOUND: return (u8*)"not found"; case ERROR_TYPE_SUCCESS: return (u8*)"success"; @@ -217,10 +208,11 @@ u32 getMessageSize(HeaderType type) { u32 size = 0; - switch (type) { + switch (type) + { case HEADER_TYPE_ERROR: size = sizeof(ErrorMessage); break; case HEADER_TYPE_HISTORY: size = sizeof(HistoryMessage); break; - case HEADER_TYPE_ID: size = sizeof(IDMessage); break; + case HEADER_TYPE_ID: size = sizeof(HeaderMessage); break; case HEADER_TYPE_INTRODUCTION: size = sizeof(IntroductionMessage); break; case HEADER_TYPE_PRESENCE: size = sizeof(PresenceMessage); break; default: assert(0); @@ -237,7 +229,8 @@ recvAnyMessageType(s32 fd, HeaderMessage* header, void *anyMessage, HeaderType t assert(nrecv == sizeof(*header)); s32 size = 0; - switch (type) { + switch (type) + { case HEADER_TYPE_ERROR: case HEADER_TYPE_HISTORY: case HEADER_TYPE_ID: @@ -245,7 +238,8 @@ recvAnyMessageType(s32 fd, HeaderMessage* header, void *anyMessage, HeaderType t case HEADER_TYPE_PRESENCE: size = getMessageSize(header->type); break; - case HEADER_TYPE_TEXT: { + case HEADER_TYPE_TEXT: + { TextMessage* message = anyMessage; size = TEXTMESSAGE_SIZE + message->len * sizeof(*message->text); } break; @@ -270,7 +264,8 @@ recvAnyMessage(Arena* arena, s32 fd) assert(nrecv == sizeof(*header)); s32 size = 0; - switch (header->type) { + switch (header->type) + { case HEADER_TYPE_ERROR: case HEADER_TYPE_HISTORY: case HEADER_TYPE_ID: @@ -278,7 +273,8 @@ recvAnyMessage(Arena* arena, s32 fd) case HEADER_TYPE_PRESENCE: size = getMessageSize(header->type); break; - case HEADER_TYPE_TEXT: { + case HEADER_TYPE_TEXT: + { Message result; result.header = header; result.message = recvTextMessage(arena, fd); @@ -303,7 +299,8 @@ Message waitForMessageType(Arena* arena, Arena* queueArena, u32 fd, HeaderType type) { Message message; - while (1) { + while (1) + { message = recvAnyMessage(arena, fd); if (message.header->type == type) break; @@ -315,24 +312,26 @@ waitForMessageType(Arena* arena, Arena* queueArena, u32 fd, HeaderType type) // Generic sending function for sending any type of message to fd // Returns number of bytes sent in message or -1 if there was an error. s32 -sendAnyMessage(u32 fd, HeaderMessage* header, void* anyMessage) +sendAnyMessage(u32 fd, HeaderMessage header, void* anyMessage) { s32 nsend_total; - s32 nsend = send(fd, header, sizeof(*header), 0); + s32 nsend = send(fd, &header, sizeof(header), 0); if (nsend == -1) return nsend; - assert(nsend == sizeof(*header)); + assert(nsend == sizeof(header)); nsend_total = nsend; s32 size = 0; - switch (header->type) { + switch (header.type) + { case HEADER_TYPE_ERROR: case HEADER_TYPE_HISTORY: case HEADER_TYPE_ID: case HEADER_TYPE_INTRODUCTION: case HEADER_TYPE_PRESENCE: - size = getMessageSize(header->type); + size = getMessageSize(header.type); break; - case HEADER_TYPE_TEXT: { + case HEADER_TYPE_TEXT: + { nsend = send(fd, anyMessage, TEXTMESSAGE_SIZE, 0); assert(nsend != -1); assert(nsend == TEXTMESSAGE_SIZE); @@ -345,7 +344,7 @@ sendAnyMessage(u32 fd, HeaderMessage* header, void* anyMessage) anyMessage = &message->text; } break; default: - fprintf(stdout, "sendAnyMessage(%d)|Cannot send %s\n", fd, headerTypeString(header->type)); + fprintf(stdout, "sendAnyMessage(%d)|Cannot send %s\n", fd, headerTypeString(header.type)); return 0; } diff --git a/send.c b/send.c index 689c6dc..067cb46 100644 --- a/send.c +++ b/send.c @@ -46,18 +46,9 @@ main(int argc, char** argv) // Get id nrecv = recv(serverfd, &header, sizeof(header), 0); assert(nrecv != -1); - if (header.type == HEADER_TYPE_ERROR) { - ErrorMessage message; - nrecv = recv(serverfd, &message, sizeof(message), 0); - fprintf(stderr, "Got '%s' error.\n'", errorTypeString(message.type)); - close(serverfd); - return 1; - } assert(header.type == HEADER_TYPE_ID); - IDMessage idmessage; - nrecv = recv(serverfd, &idmessage, sizeof(idmessage), 0); - assert(nrecv != -1); - fprintf(stderr, "Got id: %lu\n", idmessage.id); + id = header.id; + fprintf(stderr, "Got id: %lu\n", header.id); } // convert text to wide string @@ -68,13 +59,15 @@ main(int argc, char** argv) text_wide[text_len - 1] = 0; HeaderMessage header = HEADER_INIT(HEADER_TYPE_TEXT); + header.id = id; TextMessage message; bzero(&message, TEXTMESSAGE_SIZE); - message = (TextMessage){.id = id, .timestamp = time(NULL), .len = text_len}; + message = (TextMessage){.timestamp = time(NULL), .len = text_len}; nsend = send(serverfd, &header, sizeof(header), 0); assert(nsend != -1); fprintf(stderr, "header bytes sent: %d\n", nsend); + nsend = send(serverfd, &message, TEXTMESSAGE_SIZE, 0); assert(nsend != -1); fprintf(stderr, "message bytes sent: %d\n", nsend); diff --git a/server.c b/server.c index 03dbf4d..881abf5 100644 --- a/server.c +++ b/server.c @@ -39,8 +39,8 @@ enum { FDS_STDIN = 0, typedef struct { u8 author[AUTHOR_LEN]; // matches author property on other message types ID id; - struct pollfd* pollunifd; // Index in fds array - struct pollfd* pollbifd; // Index in fds array + struct pollfd* unifd; // Index in fds array + struct pollfd* bifd; // Index in fds array } Client; #define CLIENT_FMT "[%s](%lu)" #define CLIENT_ARG(client) client.author, client.id @@ -50,24 +50,42 @@ typedef enum { BIFD } ClientFD; -// TODO: remove +// TODO: remove global variable // For handing out new ids to connections. -global_variable u32 nclients = 0; +// Start at 1 because this makes 0 an invalid client id. +global_variable u32 nclients = 1; -// Returns client matching id in clients. -// clientsArena is used to get an upper bound. -// Returns 0 if there was no client found. +// Returns client matching id in clients nclients number of clients. +// Returns 0 if no client was found or if id was 0. Client* -getClientByID(Arena* clientsArena, ID id) +getClientByID(Client* clients, u32 nclients, ID id) { - Client* clients = clientsArena->addr; - for (u32 i = 0; i < (clientsArena->pos / sizeof(*clients)); i++) { + if (!id) return 0; + + for (u32 i = 0; i < nclients - 1; i++) + { if (clients[i].id == id) return clients + i; } return 0; } +// Returns client matching fd in clients nclients number of clients. +// Returns 0 if no clients was found or if fd was -1. +Client* +getClientByFD(Client* clients, u32 nclients, s32 fd) +{ + if (fd == -1) return 0; + + for (u32 i = 0; i < nclients - 1; i++) + { + if ((clients[i].unifd && clients[i].unifd->fd == fd) || + (clients[i].bifd && clients[i].bifd->fd == fd)) + return clients + i; + } + return 0; +} + // Print TextMessage prettily void printTextMessage(TextMessage* message, Client* client, u8 wide) @@ -75,7 +93,8 @@ printTextMessage(TextMessage* message, Client* client, u8 wide) u8 timestamp[TIMESTAMP_LEN] = {0}; formatTimestamp(timestamp, message->timestamp); - if (wide) { + if (wide) + { setlocale(LC_ALL, ""); wprintf(L"TextMessage: %s [%s] %ls\n", timestamp, client->author, (wchar_t*)&message->text); } else { @@ -89,60 +108,81 @@ printTextMessage(TextMessage* message, Client* client, u8 wide) // for connfd. // Type will filter out only connections matching the type. void -sendToOthers(struct pollfd* fds, u32 nfds, s32 connfd, ClientFD type, HeaderMessage* header, void* anyMessage) +sendToOthers(Client* clients, u32 nclients, Client* client, ClientFD type, HeaderMessage* header, void* anyMessage) { s32 nsend; - for (u32 i = FDS_CLIENTS + type; i < nfds; i += 2) { - if (fds[i].fd == connfd) continue; - if (fds[i].fd == -1) continue; + for (u32 i = 0; i < nclients - 1; i ++) + { + if (clients + i == client) continue; - nsend = sendAnyMessage(fds[i].fd, header, anyMessage); - loggingf("sendToOthers(%d)|[%s]->%d %d bytes\n", connfd, headerTypeString(header->type), fds[i].fd, nsend); + if (type == UNIFD) + { + nsend = sendAnyMessage(client->unifd->fd, *header, anyMessage); + } + else if (type == BIFD) + { + nsend = sendAnyMessage(client->bifd->fd, *header, anyMessage); + } + assert(nsend != -1); + loggingf("sendToOthers "CLIENT_FMT"|%s %d bytes\n", CLIENT_ARG((*client)), headerTypeString(header->type), nsend); } } // Send header and anyMessage to each connection in fds that is nfds number of connections. // Type will filter out only connections matching the type. void -sendToAll(struct pollfd* fds, u32 nfds, ClientFD type, HeaderMessage* header, void* anyMessage) +sendToAll(Client* clients, u32 nclients, ClientFD type, HeaderMessage* header, void* anyMessage) { - for (u32 i = FDS_CLIENTS + type; i < nfds; i += 2) { - if (fds[i].fd == -1) continue; - s32 nsend = sendAnyMessage(fds[i].fd, header, anyMessage); - loggingf("sendToAll|[%s]->%d %d bytes\n", headerTypeString(header->type), fds[i].fd, nsend); + s32 nsend; + for (u32 i = 0; i < nclients - 1; i++) + { + if (type == UNIFD) + { + if (clients[i].unifd && clients[i].unifd->fd != -1) + nsend = sendAnyMessage(clients[i].unifd->fd, *header, anyMessage); + } + else if (type == BIFD) + { + if (clients[i].bifd && clients[i].bifd->fd != -1) + nsend = sendAnyMessage(clients[i].bifd->fd, *header, anyMessage); + } + assert(nsend != -1); + loggingf("sendToAll|[%s]->"CLIENT_FMT" %d bytes\n", headerTypeString(header->type), + CLIENT_ARG(clients[i]), + nsend); } } // Disconnect a client by closing the matching file descriptors void -disconnect(struct pollfd* pollfd, Client* client) +disconnect(Client* client) { loggingf("Disconnecting "CLIENT_FMT"\n", CLIENT_ARG((*client))); - if (pollfd[UNIFD].fd != -1) { - close(pollfd[UNIFD].fd); - } - if (pollfd[BIFD].fd != -1) { - close(pollfd[BIFD].fd); + if (client->unifd && client->unifd->fd != -1) + { + close(client->unifd->fd); + client->unifd->fd = -1; + client->unifd = 0; } - pollfd[UNIFD].fd = -1; - pollfd[BIFD].fd = -1; - // TODO: mark as free - if (client) { - client->pollunifd = 0; - client->pollbifd = 0; + if (client->bifd && client->bifd->fd != -1) + { + close(client->bifd->fd); + client->bifd->fd = -1; + client->bifd = 0; } } // Disconnects fds+conn from fds with nfds connections, then send a PresenceMessage to other // clients about disconnection. void -disconnectAndNotify(Client* client, struct pollfd* fds, u32 nfds, u32 conn) +disconnectAndNotify(Client* clients, u32 nclients, Client* client) { - disconnect(fds + conn, client); + disconnect(client); local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE); - PresenceMessage message = {.id = client->id, .type = PRESENCE_TYPE_DISCONNECTED}; - sendToAll(fds, nfds, UNIFD, &header, &message); + header.id = client->id; + PresenceMessage message = {.type = PRESENCE_TYPE_DISCONNECTED}; + sendToAll(clients, nclients, UNIFD, &header, &message); } // Receive authentication from pollfd->fd and create client out of it. Look in @@ -150,83 +190,76 @@ disconnectAndNotify(Client* client, struct pollfd* fds, u32 nfds, u32 conn) // to clients_file. // See "Authentication" in chatty.h Client* -authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* clientfds) +authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* pollfd, HeaderMessage header) { s32 nrecv = 0; - Client* clients = clientsArena->addr; - - HeaderMessage header; - nrecv = recv(clientfds[BIFD].fd, &header, sizeof(header), 0); - if (nrecv != sizeof(header)) { - loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(header)); - return 0; - } - loggingf("authenticate(%d)|" HEADER_FMT "\n", clientfds[BIFD].fd, HEADER_ARG(header)); - Client* client = 0; - // Scenario 1: Search for existing client - if (header.type == HEADER_TYPE_ID) { - IDMessage message; - nrecv = recv(clientfds[BIFD].fd, &message, sizeof(message), 0); - if (nrecv != sizeof(message)) { - loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(message)); - return 0; - } - client = getClientByID(clientsArena, message.id); - if (client) { - loggingf("authenticate(%d)|found [%s](%lu)\n", clientfds[BIFD].fd, client->author, client->id); - header.type = HEADER_TYPE_ERROR; - // TODO: allow multiple connections - if (client->pollunifd != 0 || client->pollbifd != 0) { - loggingf("authenticate(%d)|err: already connected\n", clientfds[BIFD].fd); - ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_ALREADYCONNECTED); - sendAnyMessage(clientfds[BIFD].fd, &header, &error_message); - return 0; - } - ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_SUCCESS); - sendAnyMessage(clientfds[BIFD].fd, &header, &error_message); - } else { - loggingf("authenticate(%d)|notfound\n", clientfds[BIFD].fd); + /* Scenario 1: Search for existing client */ + if (header.type == HEADER_TYPE_ID) + { + client = getClientByID((Client*)clientsArena->addr, nclients, header.id); + if (!client) + { + loggingf("authenticate(%d)|notfound\n", pollfd->fd); header.type = HEADER_TYPE_ERROR; ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_NOTFOUND); - sendAnyMessage(clientfds[BIFD].fd, &header, &error_message); + sendAnyMessage(pollfd->fd, header, &error_message); return 0; } - // Scenario 2: Create a new client - } else if (header.type == HEADER_TYPE_INTRODUCTION) { + + loggingf("authenticate(%d)|found [%s](%lu)\n", pollfd->fd, client->author, client->id); + header.type = HEADER_TYPE_ERROR; + if (!client->unifd) + client->unifd = pollfd; + else if(!client->bifd) + client->bifd = pollfd; + else + assert(0); + + ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_SUCCESS); + header.type = HEADER_TYPE_ERROR; + sendAnyMessage(pollfd->fd, header, &error_message); + + return client; + } + + /* Scenario 2: Create a new client */ + if (header.type == HEADER_TYPE_INTRODUCTION) + { IntroductionMessage message; - nrecv = recv(clientfds[BIFD].fd, &message, sizeof(message), 0); - if (nrecv != sizeof(message)) { - loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(message)); + nrecv = recv(pollfd->fd, &message, sizeof(message), 0); + if (nrecv != sizeof(message)) + { + loggingf("authenticate(%d)|err: %d/%lu bytes\n", pollfd->fd, nrecv, sizeof(message)); return 0; } // Copy metadata from IntroductionMessage - client = ArenaPush(clientsArena, sizeof(*clients)); + client = ArenaPush(clientsArena, sizeof(*client)); memcpy(client->author, message.author, AUTHOR_LEN); client->id = nclients; nclients++; - // Save client #ifdef IMPORT_ID write(clients_file, client, sizeof(*client)); #endif - loggingf("authenticate(%d)|Added [%s](%lu)\n", clientfds[BIFD].fd, client->author, client->id); + loggingf("authenticate(%d)|Added [%s](%lu)\n", pollfd->fd, client->author, client->id); + // Send ID to new client HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); - IDMessage id_message = {.id = client->id}; - sendAnyMessage(clientfds[BIFD].fd, &header, &id_message); - } else { - loggingf("authenticate(%d)|Wrong header expected %s or %s\n", clientfds[BIFD].fd, headerTypeString(HEADER_TYPE_INTRODUCTION), headerTypeString(HEADER_TYPE_ID)); - return 0; - } - assert(client != 0); + header.id = client->id; + s32 nsend = send(pollfd->fd, &header, sizeof(header), 0); + assert(nsend != -1); + assert(nsend == sizeof(header)); - client->pollunifd = clientfds; - client->pollbifd = clientfds + 1; + return client; + } - return client; + loggingf("authenticate(%d)|Wrong header expected %s or %s\n", pollfd->fd, + headerTypeString(HEADER_TYPE_INTRODUCTION), + headerTypeString(HEADER_TYPE_ID)); + return 0; } int @@ -236,9 +269,11 @@ main(int argc, char** argv) logfd = 2; // optional logging - if (argc > 1) { + if (argc > 1) + { if (*argv[1] == '-') - if (argv[1][1] == 'l') { + if (argv[1][1] == 'l') + { logfd = open(LOGFILE, O_RDWR | O_CREAT | O_TRUNC, 0600); assert(logfd != -1); } @@ -299,192 +334,171 @@ main(int argc, char** argv) assert(fstat(clients_file, &statbuf) != -1); read(clients_file, clients, statbuf.st_size); - if (statbuf.st_size > 0) { + if (statbuf.st_size > 0) + { ArenaPush(&clientsArena, statbuf.st_size); loggingf("Imported %lu client(s)\n", statbuf.st_size / sizeof(*clients)); nclients += statbuf.st_size / sizeof(*clients); + + // Reset pointers on imported clients + for (u32 i = 0; i < nclients - 1; i++) + { + clients[i].unifd = 0; + clients[i].bifd = 0; + } } - for (u32 i = 0; i < nclients; i++) + for (u32 i = 0; i < nclients - 1; i++) loggingf("Imported: " CLIENT_FMT "\n", CLIENT_ARG(clients[i])); #endif // Initialize the rest of the fds array for (u32 i = FDS_CLIENTS; i < MAX_CONNECTIONS; i++) fds[i] = newpollfd; - // Reset file descriptors on imported clients - for (u32 i = 0; i < CLIENTS_SIZE; i++) { - clients[i].pollunifd = 0; - clients[i].pollbifd = 0; - } - while (1) { + while (1) + { s32 err = poll(fds, FDS_SIZE, TIMEOUT); assert(err != -1); - if (fds[FDS_STDIN].revents & POLLIN) { + if (fds[FDS_STDIN].revents & POLLIN) + { u8 c; // exit on ctrl-d if (!read(fds[FDS_STDIN].fd, &c, 1)) break; - } else if (fds[FDS_SERVER].revents & POLLIN) { + } + else if (fds[FDS_SERVER].revents & POLLIN) + { // TODO: what if we are not aligned by 2 anymore? - s32 unifd = accept(serverfd, 0, 0); - s32 bifd = accept(serverfd, 0, 0); + s32 clientfd = accept(serverfd, 0, 0); - if (unifd == -1 || bifd == -1) { - loggingf("Error while accepting connection (%d,%d)\n", unifd, bifd); - if (unifd != -1) close(unifd); - if (bifd != -1) close(bifd); + if (clientfd == -1) + { + loggingf("Error while accepting connection (%d)\n", clientfd); continue; - } else - loggingf("New connection(%d,%d)\n", unifd, bifd); + } + else + loggingf("New connection(%d)\n", clientfd); - // TODO: find empty space in arena - if (nclients + 1 == MAX_CONNECTIONS) { + // TODO: find empty space in arena (fragmentation) + if (nclients + 1 == MAX_CONNECTIONS) + { local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_ERROR); local_persist ErrorMessage message = ERROR_INIT(ERROR_TYPE_TOOMANYCONNECTIONS); - sendAnyMessage(unifd, &header, &message); - if (unifd != -1) - close(unifd); - if (bifd != -1) - close(bifd); + sendAnyMessage(clientfd, header, &message); + if (clientfd != -1) + close(clientfd); loggingf("Max clients reached. Rejected connection\n"); - } else { + } + else + { // no more space, allocate - struct pollfd* clientfds = ArenaPush(&fdsArena, 2 * sizeof(*clientfds)); - clientfds[UNIFD].fd = unifd; - clientfds[UNIFD].events = POLLIN; - clientfds[BIFD].fd = bifd; - clientfds[BIFD].events = POLLIN; - loggingf("Added pollfd(%d,%d)\n", unifd, bifd); + struct pollfd* pollfd = ArenaPush(&fdsArena, sizeof(*pollfd)); + pollfd->fd = clientfd; + loggingf("Added pollfd(%d)\n", clientfd); } } - // Check for messages from clients in their unifd - for (u32 conn = FDS_CLIENTS; conn < FDS_SIZE; conn += 2) { + for (u32 conn = FDS_CLIENTS; conn < FDS_SIZE; conn++) + { if (!(fds[conn].revents & POLLIN)) continue; if (fds[conn].fd == -1) continue; loggingf("Message unifd (%d)\n", fds[conn].fd); - // Get client associated with connection - Client* client = 0; - for (u32 j = 0; j < CLIENTS_SIZE; j++) { - if (!clients[j].pollunifd) - continue; - if (clients[j].pollunifd == fds + conn) { - client = clients + j; - break; - } - } - if (!client) { - loggingf("No client associated(%d)\n", fds[conn].fd); - close(fds[conn].fd); - continue; - } - loggingf("Found client(%lu) [%s] (%d)\n", client->id, client->author, fds[conn].fd); - // We received a message, try to parse the header HeaderMessage header; s32 nrecv = recv(fds[conn].fd, &header, sizeof(header), 0); - if (nrecv == 0) { - disconnectAndNotify(client, fds, FDS_SIZE, conn); - loggingf("Disconnected(%lu) [%s]\n", client->id, client->author); - continue; - } else if (nrecv != sizeof(header)) { - disconnectAndNotify(client, fds, FDS_SIZE, conn); - loggingf("error(%lu) [%s] %d/%lu bytes\n", client->id, client->author, nrecv, sizeof(header)); + assert(nrecv != -1); + + Client* client; + if (nrecv != sizeof(header)) + { + client = getClientByFD(clients, nclients, fds[conn].fd); + loggingf(CLIENT_FMT" %d/%lu bytes\n", CLIENT_ARG((*client)), nrecv, sizeof(header)); + if (client) + { + disconnectAndNotify(clients, nclients, client); + loggingf("Disconnected(%lu) [%s]\n", client->id, client->author); + } + else + { + loggingf("Got error from unauntheticated client\n"); + close(fds[conn].fd); + fds[conn].fd = -1; + } continue; } loggingf("Received(%d) -> " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header)); - switch (header.type) { - case HEADER_TYPE_TEXT: { - TextMessage* text_message = recvTextMessage(&msgsArena, fds[conn].fd); - loggingf("Received(%d)", fds[conn].fd); - printTextMessage(text_message, client, 0); + // Authentication + if (header.id) + { + client = getClientByID(clients, nclients, header.id); + if (!client) + { + loggingf("No client for id %d\n", fds[conn].fd); - sendToOthers(fds, FDS_SIZE, fds[conn].fd, UNIFD, &header, text_message); - } break; - // handle request for information about client id - default: - loggingf("Unhandled '%s' from client(%d)\n", headerTypeString(header.type), fds[conn].fd); - disconnectAndNotify(client, fds, FDS_SIZE, conn); - continue; - } - } + header.type = HEADER_TYPE_ERROR; + ErrorMessage message = ERROR_INIT(ERROR_TYPE_NOTFOUND); - // Check for messages from clients in their bifd - for (u32 conn = FDS_CLIENTS + BIFD; conn < FDS_SIZE; conn += 2) { - if (!(fds[conn].revents & POLLIN)) continue; - if (fds[conn].fd == -1) continue; - loggingf("Message bifd (%d)\n", fds[conn].fd); - - // Get client associated with connection - Client* client = 0; - for (u32 j = 0; j < CLIENTS_SIZE; j++) { - if (!clients[j].pollbifd) - continue; - if (clients[j].pollbifd == fds + conn) { - client = clients + j; - break; + sendAnyMessage(fds[conn].fd, header, &message); + + // Reject connection + fds[conn].fd = -1; + close(fds[conn].fd); + } + else + { + header.type = HEADER_TYPE_ERROR; + ErrorMessage message = ERROR_INIT(ERROR_TYPE_SUCCESS); + + sendAnyMessage(fds[conn].fd, header, &message); } } - if (!client) { + else + { loggingf("No client for connection(%d)\n", fds[conn].fd); -#ifdef IMPORT_ID - client = authenticate(&clientsArena, clients_file, fds + conn - 1); -#else - client = authenticate(&clientsArena, 0, fds + conn - 1); -#endif - // If the client sent an IDMessage but no ID was found authenticate() could return null - if (!client) { - loggingf("Could not initialize client\n"); - disconnect(fds + conn, 0); - } else { // client was added/connected + + client = authenticate(&clientsArena, clients_file, fds + conn, header); + + if (!client) + loggingf("Could not initialize client (%d)\n", fds[conn].fd); + else if (!client->bifd) + { + // Send connected message to other clients if this was the first time -> bifd is + // not set yet. local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE); - PresenceMessage message = {.id = client->id, .type = PRESENCE_TYPE_CONNECTED}; - sendToOthers(fds, FDS_SIZE, fds[conn - BIFD].fd, UNIFD, &header, &message); + header.id = client->id; + PresenceMessage message = {.type = PRESENCE_TYPE_CONNECTED}; + sendToOthers(clients, nclients, client, UNIFD, &header, &message); } - continue; } - loggingf("Found client(%lu) [%s] (%d)\n", client->id, client->author, fds[conn].fd); - - // We received a message, try to parse the header - HeaderMessage header; - s32 nrecv = recv(fds[conn].fd, &header, sizeof(header), 0); - if (nrecv == 0) { - disconnectAndNotify(client, fds, FDS_SIZE, conn); - loggingf("Disconnected(%lu) [%s]\n", client->id, client->author); - continue; - } else if (nrecv != sizeof(header)) { - disconnectAndNotify(client, fds, FDS_SIZE, conn); - loggingf("error(%lu) [%s] %d/%lu bytes\n", client->id, client->author, nrecv, sizeof(header)); - continue; - } - loggingf("Received(%d) -> " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header)); switch (header.type) { - case HEADER_TYPE_ID: { - // handle request for information about client id - IDMessage id_message; - nrecv = recv(fds[conn].fd, &id_message, sizeof(id_message), 0); - - Client* client = getClientByID(&clientsArena, id_message.id); - if (!client) { - local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_ERROR); - local_persist ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_NOTFOUND); - sendAnyMessage(fds[conn].fd, &header, &error_message); - loggingf("Could not find %lu\n", id_message.id); - break; - } + /* Send text message to all other clients */ + case HEADER_TYPE_TEXT: + { + TextMessage* text_message = recvTextMessage(&msgsArena, fds[conn].fd); + loggingf("Received(%d)", fds[conn].fd); + printTextMessage(text_message, client, 0); + + sendToOthers(clients, nclients, client, UNIFD, &header, text_message); + } break; + /* Send back client information */ + case HEADER_TYPE_ID: + { HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION); IntroductionMessage introduction_message; + header.id = client->id; memcpy(introduction_message.author, client->author, AUTHOR_LEN); - sendAnyMessage(fds[conn].fd, &header, &introduction_message); + s32 nrecv = sendAnyMessage(fds[conn].fd, header, &introduction_message); + assert(nrecv != -1); } break; default: - loggingf("Unhandled '%s' from client(%d)\n", headerTypeString(header.type), fds[conn].fd); - disconnectAndNotify(client, fds, FDS_SIZE, conn); + loggingf("Unhandled '%s' from "CLIENT_FMT"(%d)\n", headerTypeString(header.type), + CLIENT_ARG((*client)), + fds[conn].fd); + disconnectAndNotify(client, nclients, client); continue; } } -- cgit v1.2.3