From 5e64be597ba0b30fcb99de33da605c93ddd30fdc Mon Sep 17 00:00:00 2001 From: Raymaekers Luca Date: Mon, 4 Nov 2024 21:23:03 +0100 Subject: Connect BIFD first to avoid errors --- README.md | 3 ++ chatty.c | 91 +++++++++++++++++++++----------------- protocol.h | 50 ++++++++++++++------- server.c | 146 ++++++++++++++++++++++++++++++++++++++----------------------- 4 files changed, 178 insertions(+), 112 deletions(-) diff --git a/README.md b/README.md index 50e9b7a..3889dad 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,9 @@ The idea is the following: - [x] id2string on clients - [x] ctrl+z to suspend - [ ] bug(tb_printf_wrap): text after pfx is wrapped one too soon +- [ ] bug: when reconnecting nrecv != -1 +- [ ] bug: when disconnecting +- [ ] use error type success to say that authentication succeeded ## server - [x] import clients diff --git a/chatty.c b/chatty.c index 6bb133b..a1747a0 100644 --- a/chatty.c +++ b/chatty.c @@ -22,8 +22,8 @@ // enable logging #define LOGGING -enum { FDS_UNI = 0, // for one-way communication with the server (eg. TextMessage) - FDS_BI, // For two-way communication with the server (eg. IDMessage) +enum { FDS_BI = 0, // for one-way communication with the server (eg. TextMessage) + FDS_UNI, // For two-way communication with the server (eg. IDMessage) FDS_TTY, FDS_RESIZE, FDS_MAX }; @@ -63,6 +63,7 @@ popup(u32 fg, u32 bg, char* text) User* getUserByID(Arena* clientsArena, ID id) { + // User is not in the clientsArena if (id == user.id) return &user; User* clients = clientsArena->addr; @@ -81,10 +82,10 @@ addUserInfo(Arena* clientsArena, s32 fd, u64 id) { // Request information about ID HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); - header.id = id; - s32 nsend = send(fd, &header, sizeof(header), 0); + header.id = user.id; + IDMessage message = {id}; + s32 nsend = sendAnyMessage(fd, header, &message); assert(nsend != -1); - assert(nsend == sizeof(header)); // Wait for response IntroductionMessage introduction_message; @@ -112,44 +113,46 @@ getConnection(struct sockaddr_in* address) return fd; } -ID +// Authenticates a file descriptor with either the user's id if non-zero or +// it's information if id is zero. +// Returns 0 if an error occurred. Non-zero on success. +u32 authenticate(User* user, s32 fd) { + /* Scenario 1: Already have an ID */ if (user->id) { HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); - header.id = user->id; - s32 nsend = send(fd, &header, sizeof(header), 0); - assert(nsend == -1); - assert(nsend == sizeof(header)); + IDMessage message = {user->id}; + s32 nsend = sendAnyMessage(fd, header, &message); + assert(nsend != -1); - s32 nrecv = recv(fd, &header, sizeof(header), 0); + ErrorMessage error_message; + s32 nrecv = recvAnyMessageType(fd, &header, &error_message, HEADER_TYPE_ERROR); + assert(nrecv != -1); + // TODO: handle not found if (nrecv == 0) return 0; - assert(nrecv != -1); - assert(nrecv == sizeof(header)); - assert(header.type == HEADER_TYPE_ID); - if (header.id == user->id) - return header.id; + + if (error_message.type == ERROR_TYPE_SUCCESS) + return 1; else return 0; } + /* Scenario 2: No ID, request one from server */ else { HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION); IntroductionMessage message; memcpy(message.author, user->author, AUTHOR_LEN); - sendAnyMessage(fd, header, &message); + s32 nsend = sendAnyMessage(fd, header, &message); + assert(nsend != -1); - s32 nrecv = recv(fd, &header, sizeof(header), 0); - if (nrecv == 0) - return 0; + IDMessage id_message; + s32 nrecv = recvAnyMessageType(fd, &header, &id_message, HEADER_TYPE_ID); assert(nrecv != -1); - assert(nrecv == sizeof(header)); - assert(header.type == HEADER_TYPE_ID); - assert(header.id); - user->id = header.id; - return header.id; + user->id = id_message.id; + return 1; } } @@ -172,30 +175,31 @@ threadReconnect(void* fds_ptr) // timeout nanosleep(&t, &t); - unifd = getConnection(&address); - if (unifd == -1) + bifd = getConnection(&address); + if (bifd == -1) { loggingf("errno: %d\n", errno); continue; } - bifd = getConnection(&address); - if (bifd == -1) + unifd = getConnection(&address); + if (unifd == -1) { loggingf("errno: %d\n", errno); - close(unifd); + close(bifd); continue; } loggingf("Reconnect succeeded (%d, %d), authenticating\n", unifd, bifd); - if (authenticate(&user, unifd) && - authenticate(&user, bifd)) + if (authenticate(&user, bifd) && + authenticate(&user, unifd)) { break; } - close(unifd); close(bifd); + close(unifd); + loggingf("Failed, retrying...\n"); } @@ -559,8 +563,8 @@ main(int argc, char** argv) // poopoo C cannot infer type struct pollfd fds[FDS_MAX] = { - {-1, POLLIN, 0}, // FDS_UNI {-1, POLLIN, 0}, // FDS_BI + {-1, POLLIN, 0}, // FDS_UNI {-1, POLLIN, 0}, // FDS_TTY {-1, POLLIN, 0}, // FDS_RESIZE }; @@ -581,26 +585,31 @@ main(int argc, char** argv) /* Authentication */ { s32 unifd, bifd; - unifd = getConnection(&address); - if (unifd == -1) + bifd = getConnection(&address); + if (bifd == -1) { loggingf("errno: %d\n", errno); return 1; } - bifd = getConnection(&address); - if (bifd == -1) + unifd = getConnection(&address); + if (unifd == -1) { loggingf("errno: %d\n", errno); return 1; } - loggingf("(%d,%d)\n", unifd, bifd); - if (!authenticate(&user, unifd)) + loggingf("(%d,%d)\n", bifd, unifd); + if (!authenticate(&user, bifd) || + !authenticate(&user, unifd)) { loggingf("errno: %d\n", errno); return 1; } - fds[FDS_UNI].fd = unifd; + else + { + loggingf("Authenticated (%d,%d)\n", bifd, unifd); + } fds[FDS_BI].fd = bifd; + fds[FDS_UNI].fd = unifd; } #ifdef IMPORT_ID diff --git a/protocol.h b/protocol.h index 40bc9f3..d35126c 100644 --- a/protocol.h +++ b/protocol.h @@ -16,19 +16,32 @@ // - strings are sent with their null terminator // /// Authentication -// This is what happens when the first time a client connects. -// Scenario 1. We alreayd have an ID +// Each header contains the id of the sender, because ids start at 1 +// message with id 0 is considered unauthenticated. +// +// When the server receives a header with id 0, this can happen +// Scenario 1: IDMessage, client already has an ID // 1. client-> Send own ID // 2. server-> knows ID? // y. server-> Success // n. 1. server-> Error 'notfound' // 2. client-> exit -// Scenario 2. We do not have an ID +// Scenario 2: IntroductionMessage, client requests a new ID // 1. client-> Introduces // 2. server-> Sends & Saves ID -// 3. client-> Saves ID +// 3. Save ID +// +// BIFD & UNIFD +// Each client has 2 connections that must be authenticated one for 2-way +// communication and one for 1-way communication. Respectively BIFD and +// UNIFD. BIFD must be authenticated first and is meant for requests such +// as getting an IntroductionMessage for a sent IDMessage. +// UNIFD is for messages that are like notifications. For example +// PresenceMessage that tells us when another user connected. +// These two connections separate these message types so we do not have to +// worry about receiving a PresenceMessage when waiting for an a response. // -/// Naming convention +/// Naming conventions // Messages end with the Message suffix (eg. TextMessag, HistoryMessage) // // A function that is coupled to a type works like @@ -64,8 +77,8 @@ typedef enum { #define HEADER_INIT(t) {.version = PROTOCOL_VERSION, .type = t, .id = 0} // from Tsoding video on minicel (https://youtu.be/HCAgvKQDJng?t=4546) // sv(https://github.com/tsoding/sv) -#define HEADER_FMT "header: v%d %s(%d)" -#define HEADER_ARG(header) header.version, headerTypeString(header.type), header.type +#define HEADER_FMT "header: v%d %s(%d) [%d]" +#define HEADER_ARG(header) header.version, headerTypeString(header.type), header.type, header.id // For sending texts to other clients // - 13 bytes for the author @@ -121,6 +134,10 @@ typedef enum { } ErrorType; #define ERROR_INIT(t) {.type = t} +typedef struct { + ID id; +} IDMessage; + typedef struct { s32 nrecv; TextMessage* message; @@ -215,6 +232,7 @@ getMessageSize(HeaderType type) case HEADER_TYPE_HISTORY: size = sizeof(HistoryMessage); break; case HEADER_TYPE_INTRODUCTION: size = sizeof(IntroductionMessage); break; case HEADER_TYPE_PRESENCE: size = sizeof(PresenceMessage); break; + case HEADER_TYPE_ID: size = sizeof(IDMessage); break; default: assert(0); } return size; @@ -235,6 +253,7 @@ recvAnyMessageType(s32 fd, HeaderMessage* header, void *anyMessage, HeaderType t case HEADER_TYPE_HISTORY: case HEADER_TYPE_INTRODUCTION: case HEADER_TYPE_PRESENCE: + case HEADER_TYPE_ID: size = getMessageSize(header->type); break; case HEADER_TYPE_TEXT: @@ -265,12 +284,6 @@ recvAnyMessage(Arena* arena, s32 fd) s32 size = 0; switch (header->type) { - case HEADER_TYPE_ERROR: - case HEADER_TYPE_HISTORY: - case HEADER_TYPE_INTRODUCTION: - case HEADER_TYPE_PRESENCE: - size = getMessageSize(header->type); - break; case HEADER_TYPE_TEXT: { Message result; @@ -278,7 +291,10 @@ recvAnyMessage(Arena* arena, s32 fd) result.message = recvTextMessage(arena, fd); return result; } break; - default: assert(0); break; + default: + { + size = getMessageSize(header->type); + } break; } void* message = ArenaPush(arena, size); @@ -315,6 +331,7 @@ sendAnyMessage(u32 fd, HeaderMessage header, void* anyMessage) s32 nsend_total; s32 nsend = send(fd, &header, sizeof(header), 0); if (nsend == -1) return nsend; + loggingf("sendAnyMessage (%d)|sending "HEADER_FMT"\n", fd, HEADER_ARG(header)); assert(nsend == sizeof(header)); nsend_total = nsend; @@ -325,6 +342,7 @@ sendAnyMessage(u32 fd, HeaderMessage header, void* anyMessage) case HEADER_TYPE_HISTORY: case HEADER_TYPE_INTRODUCTION: case HEADER_TYPE_PRESENCE: + case HEADER_TYPE_ID: size = getMessageSize(header.type); break; case HEADER_TYPE_TEXT: @@ -341,8 +359,8 @@ sendAnyMessage(u32 fd, HeaderMessage header, void* anyMessage) anyMessage = &message->text; } break; default: - fprintf(stdout, "sendAnyMessage(%d)|Cannot send %s\n", fd, headerTypeString(header.type)); - return 0; + loggingf("sendAnyMessage (%d)|Cannot send %s\n", fd, headerTypeString(header.type)); + return -1; } nsend = send(fd, anyMessage, size, 0); diff --git a/server.c b/server.c index 47b2df3..cd0cc70 100644 --- a/server.c +++ b/server.c @@ -37,15 +37,15 @@ enum { FDS_STDIN = 0, typedef struct { u8 author[AUTHOR_LEN]; // matches author property on other message types ID id; - struct pollfd* unifd; // Index in fds array struct pollfd* bifd; // Index in fds array + struct pollfd* unifd; // Index in fds array } Client; #define CLIENT_FMT "[%s](%lu)" #define CLIENT_ARG(client) client.author, client.id typedef enum { - UNIFD = 0, - BIFD + BIFD = 0, + UNIFD, } ClientFD; // TODO: remove global variable @@ -104,46 +104,62 @@ printTextMessage(TextMessage* message, Client* client, u8 wide) // Send header and anyMessage to each connection in fds that is nfds number of connections except // for connfd. +// Does not send if pollfd is not set or pollfd->fd is -1. // Type will filter out only connections matching the type. void sendToOthers(Client* clients, u32 nclients, Client* client, ClientFD type, HeaderMessage* header, void* anyMessage) { - s32 nsend; - for (u32 i = 0; i < nclients; i ++) + s32 nsend, fd; + for (u32 i = 0; i < nclients - 1; i ++) { if (clients + i == client) continue; if (type == UNIFD) { - nsend = sendAnyMessage(client->unifd->fd, *header, anyMessage); + if (clients[i].unifd && clients[i].unifd->fd != -1) + fd = clients[i].unifd->fd; + else + continue; } else if (type == BIFD) { - nsend = sendAnyMessage(client->bifd->fd, *header, anyMessage); + if (clients[i].bifd && clients[i].bifd->fd != -1) + fd = clients[i].bifd->fd; + else + continue; } + nsend = sendAnyMessage(fd, *header, anyMessage); + assert(nsend != -1); - loggingf("sendToOthers "CLIENT_FMT"|%s %d bytes\n", CLIENT_ARG((*client)), headerTypeString(header->type), nsend); + loggingf("sendToOthers "CLIENT_FMT"|%d<-%s %d bytes\n", CLIENT_ARG((clients[i])), fd, headerTypeString(header->type), nsend); } } // Send header and anyMessage to each connection in fds that is nfds number of connections. +// Does not send if pollfd is not set or pollfd->fd is -1. // Type will filter out only connections matching the type. void sendToAll(Client* clients, u32 nclients, ClientFD type, HeaderMessage* header, void* anyMessage) { s32 nsend; - for (u32 i = 0; i < nclients; i++) + for (u32 i = 0; i < nclients - 1; i++) { if (type == UNIFD) { if (clients[i].unifd && clients[i].unifd->fd != -1) nsend = sendAnyMessage(clients[i].unifd->fd, *header, anyMessage); + else + continue; } else if (type == BIFD) { if (clients[i].bifd && clients[i].bifd->fd != -1) nsend = sendAnyMessage(clients[i].bifd->fd, *header, anyMessage); + else + continue; } + else + assert(0); assert(nsend != -1); loggingf("sendToAll|[%s]->"CLIENT_FMT" %d bytes\n", headerTypeString(header->type), CLIENT_ARG(clients[i]), @@ -187,6 +203,8 @@ disconnectAndNotify(Client* clients, u32 nclients, Client* client) // clientsArena if it already exists. Otherwise push a new onto the arena and write its information // to clients_file. // See "Authentication" in chatty.h +// Assumes that the client will send a IDMessage or IntroductionMessage +// Returns authenticated client Client* authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* pollfd, HeaderMessage header) { @@ -198,7 +216,11 @@ authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* pollfd, Heade /* Scenario 1: Search for existing client */ if (header.type == HEADER_TYPE_ID) { - client = getClientByID((Client*)clientsArena->addr, nclients, header.id); + IDMessage message; + s32 nrecv = recv(pollfd->fd, &message, sizeof(message), 0); + assert(nrecv == sizeof(message)); + + client = getClientByID((Client*)clientsArena->addr, nclients, message.id); if (!client) { loggingf("authenticate (%d)|notfound\n", pollfd->fd); @@ -207,24 +229,26 @@ authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* pollfd, Heade sendAnyMessage(pollfd->fd, header, &error_message); return 0; } + else + { + loggingf("authenticate (%d)|found [%s](%lu)\n", pollfd->fd, client->author, client->id); + header.type = HEADER_TYPE_ERROR; + ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_SUCCESS); + sendAnyMessage(pollfd->fd, header, &error_message); + } - loggingf("authenticate (%d)|found [%s](%lu)\n", pollfd->fd, client->author, client->id); - if (!client->unifd) - client->unifd = pollfd; - else if(!client->bifd) + if (!client->bifd) client->bifd = pollfd; + else if (!client->unifd) + client->unifd = pollfd; else assert(0); - header.type = HEADER_TYPE_ERROR; - ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_SUCCESS); - sendAnyMessage(pollfd->fd, header, &error_message); return client; } - /* Scenario 2: Create a new client */ - if (header.type == HEADER_TYPE_INTRODUCTION) + else if (header.type == HEADER_TYPE_INTRODUCTION) { IntroductionMessage message; nrecv = recv(pollfd->fd, &message, sizeof(message), 0); @@ -239,10 +263,10 @@ authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* pollfd, Heade memcpy(client->author, message.author, AUTHOR_LEN); client->id = nclients; - if (!client->unifd) - client->unifd = pollfd; - else if(!client->bifd) - client->bifd = pollfd; + if (!client->bifd) + client->bifd = pollfd; + else if (!client->unifd) + client->unifd = pollfd; else assert(0); @@ -255,10 +279,11 @@ authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* pollfd, Heade // Send ID to new client HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); - header.id = client->id; - s32 nsend = send(pollfd->fd, &header, sizeof(header), 0); + IDMessage id_message; + id_message.id = client->id; + + s32 nsend = sendAnyMessage(pollfd->fd, header, &id_message); assert(nsend != -1); - assert(nsend == sizeof(header)); return client; } @@ -349,13 +374,13 @@ main(int argc, char** argv) nclients += statbuf.st_size / sizeof(*clients); // Reset pointers on imported clients - for (u32 i = 0; i < nclients; i++) + for (u32 i = 0; i < nclients - 1; i++) { clients[i].unifd = 0; clients[i].bifd = 0; } } - for (u32 i = 0; i < nclients; i++) + for (u32 i = 0; i < nclients - 1; i++) loggingf("Imported: " CLIENT_FMT "\n", CLIENT_ARG(clients[i])); #else clients_file = 0; @@ -425,9 +450,8 @@ main(int argc, char** argv) client = getClientByFD(clients, nclients, fds[conn].fd); if (client) { - loggingf(CLIENT_FMT" %d/%lu bytes\n", CLIENT_ARG((*client)), nrecv, sizeof(header)); + loggingf("Received %d/%lu bytes "CLIENT_FMT"\n", nrecv, sizeof(header), CLIENT_ARG((*client))); disconnectAndNotify(clients, nclients, client); - loggingf("Disconnected(%lu) [%s]\n", client->id, client->author); } else { @@ -437,28 +461,10 @@ main(int argc, char** argv) } continue; } - loggingf("Received(%d) -> " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header)); + loggingf("Received(%d): " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header)); // Authentication - if (header.id) - { - client = getClientByID(clients, nclients, header.id); - if (!client) - { - loggingf("No client for id %d\n", fds[conn].fd); - - header.type = HEADER_TYPE_ERROR; - ErrorMessage message = ERROR_INIT(ERROR_TYPE_NOTFOUND); - - sendAnyMessage(fds[conn].fd, header, &message); - - // Reject connection - fds[conn].fd = -1; - close(fds[conn].fd); - continue; - } - } - else + if (!header.id) { loggingf("No client for connection(%d)\n", fds[conn].fd); @@ -466,10 +472,10 @@ main(int argc, char** argv) if (!client) loggingf("Could not initialize client (%d)\n", fds[conn].fd); - else if (!client->bifd) + /* This is the first time a message is sent, because unifd is not yet set. */ + else if (!client->unifd) { - // Send connected message to other clients if this was the first time -> bifd is - // not set yet. + loggingf("Send connected message\n"); local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE); header.id = client->id; PresenceMessage message = {.type = PRESENCE_TYPE_CONNECTED}; @@ -478,12 +484,28 @@ main(int argc, char** argv) continue; } + client = getClientByID(clients, nclients, header.id); + if (!client) + { + loggingf("No client for id %d\n", fds[conn].fd); + + header.type = HEADER_TYPE_ERROR; + ErrorMessage message = ERROR_INIT(ERROR_TYPE_NOTFOUND); + + sendAnyMessage(fds[conn].fd, header, &message); + + // Reject connection + fds[conn].fd = -1; + close(fds[conn].fd); + continue; + } + switch (header.type) { /* Send text message to all other clients */ case HEADER_TYPE_TEXT: { TextMessage* text_message = recvTextMessage(&msgsArena, fds[conn].fd); - loggingf("Received(%d)", fds[conn].fd); + loggingf("Received(%d): ", fds[conn].fd); printTextMessage(text_message, client, 0); sendToOthers(clients, nclients, client, UNIFD, &header, text_message); @@ -491,12 +513,26 @@ main(int argc, char** argv) /* Send back client information */ case HEADER_TYPE_ID: { + IDMessage id_message; + s32 nrecv = recv(fds[conn].fd, &id_message, sizeof(id_message), 0); + assert(nrecv == sizeof(id_message)); + + client = getClientByID(clients, nclients, id_message.id); + if (!client) + { + header.type = HEADER_TYPE_ERROR; + ErrorMessage message = ERROR_INIT(ERROR_TYPE_NOTFOUND); + s32 nsend = sendAnyMessage(fds[conn].fd, header, &message); + assert(nsend != -1); + break; + } + HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION); IntroductionMessage introduction_message; header.id = client->id; memcpy(introduction_message.author, client->author, AUTHOR_LEN); - s32 nrecv = sendAnyMessage(fds[conn].fd, header, &introduction_message); + nrecv = sendAnyMessage(fds[conn].fd, header, &introduction_message); assert(nrecv != -1); } break; default: -- cgit v1.2.3