diff options
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | README.md | 13 | ||||
-rw-r--r-- | chatty.c | 393 | ||||
-rw-r--r-- | chatty.h | 220 | ||||
-rw-r--r-- | protocol.h | 360 | ||||
-rw-r--r-- | send.c | 78 | ||||
-rw-r--r-- | server.c | 567 |
7 files changed, 1134 insertions, 500 deletions
@@ -1,3 +1,6 @@ chatty send server +_id +_clients +*.log @@ -15,19 +15,28 @@ The idea is the following: - [x] wrapping messages - [x] bug: when sending message after diconnect (serverfd?) - [x] Handle disconnection thiin a thread, the best way would be -- [ ] ctrl+z to suspend +- [x] Add limit_y to printf_wrap +- [x] id2string on clients +- [x] ctrl+z to suspend +- [ ] bug(tb_printf_wrap): text after pfx is wrapped one too soon ## server +- [x] import clients - [ ] check if when sending and the client is offline (due to connection loss) what happens - [ ] timeout on recv? - [ ] use threads to handle clients/ timeout when receiving because a client could theoretically stall the entire server. - [ ] do not crash on errors from clients - implement error message? + - timeout on recv with setsockopt ## common - [x] handle messages that are too large -- [ ] log messages to file (save history) +- [x] refactor i&self into conn +- [x] logging +- [x] Req|Inf connection per client +- [ ] bug: blocking after `Added pollfd`, after importing a client and then connecting with the + id/or without? After reconnection fails chatty blocks (remove sleep) - [ ] connect/disconnections messages - [ ] use IP address / domain - [ ] chat history @@ -2,6 +2,7 @@ #include "termbox2.h" #include "chatty.h" +#include "protocol.h" #include <arpa/inet.h> #include <assert.h> @@ -14,19 +15,39 @@ // time to reconnect in seconds #define TIMEOUT_RECONNECT 1 #define INPUT_LIMIT 512 - -// must be of AUTHOR_LEN -1 -static u8 username[AUTHOR_LEN] = "(null)"; -// file descriptros for polling -static struct pollfd* fds = NULL; -// mutex for locking fds when in thread_reconnect() -static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER; - -enum { FDS_SERVER = 0, +// Filepath where user ID is stored +#define ID_FILE "_id" +// Import id from ID_FILE +#define IMPORT_ID +// Filepath where logged +#define LOGFILE "chatty.log" +// enable logging +#define LOGGING + +enum { FDS_UNI = 0, // for one-way communication with the server (eg. TextMessage) + FDS_BI, // For two-way communication with the server (eg. IDMessage) FDS_TTY, FDS_RESIZE, FDS_MAX }; +typedef struct { + u8 author[AUTHOR_LEN]; + ID id; +} Client; +#define CLIENT_FMT "[%s](%lu)" +#define CLIENT_ARG(client) client.author, client.id + +typedef struct { + s32 err; // Error while connecting + s32 unifd; + s32 bifd; +} ConnectionResult; + +// Client used by chatty +global_variable Client user = {0}; +// Address of chatty server +global_variable struct sockaddr_in address; + // fill str array with char void fillstr(u32* str, u32 ch, u32 len) @@ -44,39 +65,118 @@ popup(u32 fg, u32 bg, char* text) tb_print(global.width / 2 - len / 2, global.height / 2, fg, bg, text); } +// Returns client in clientsArena matching id +// Returns user if the id was the user's ID +// Returns 0 if nothing was found +Client* +getClientById(Arena* clientsArena, ID id) +{ + if (id == user.id) return &user; + + Client* clients = clientsArena->addr; + for (u64 i = 0; i < (clientsArena->pos / sizeof(*clients)); i++) { + if (clients[i].id == id) + return clients + i; + } + return 0; +} + +// Request information of client from fd byd id and add it to clientsArena +// Returns pointer to added client +Client* +addClientInfo(Arena* clientsArena, s32 fd, u64 id) +{ + // Request information about ID + HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); + IDMessage id_message = {.id = id}; + sendAnyMessage(fd, &header, &id_message); + + Client* client = ArenaPush(clientsArena, sizeof(*client)); + + // Wait for response + IntroductionMessage introduction_message; + recvAnyMessageType(fd, &header, &introduction_message, HEADER_TYPE_INTRODUCTION); + + // Add the information + memcpy(client->author, introduction_message.author, AUTHOR_LEN); + client->id = id; + + loggingf("Got " CLIENT_FMT "\n", CLIENT_ARG((*client))); + return client; +} + +// Tries to connect to address and populates resulting file descriptors in ConnectionResult. +ConnectionResult +getConnection(struct sockaddr_in* address) +{ + ConnectionResult result; + result.unifd = socket(AF_INET, SOCK_STREAM, 0); + result.bifd = socket(AF_INET, SOCK_STREAM, 0); + result.err = connect(result.unifd, (struct sockaddr*)address, sizeof(*address)); + if (result.err) return result; // We do not overwrite the error and return early so we can be + // certain of what error errno belongs to. + result.err = connect(result.bifd, (struct sockaddr*)address, sizeof(*address)); + return result; +} + // Connect to *address_ptr of type `struct sockaddr_in*`. If it failed wait for TIMEOUT_RECONNECT // seconds. // This function is meant to be run by a thread. // An offline server means fds[FDS_SERVER] is set to -1. When online // it is set to with the appropriate file descriptor. -// Returns NULL. +// Returns 0. +#define Miliseconds(s) (s*1000*1000) void* -thread_reconnect(void* address_ptr) +threadReconnect(void* fds_ptr) { - u32 serverfd, err; - struct sockaddr_in* address = address_ptr; - + struct pollfd* fds = fds_ptr; + ConnectionResult result; + struct timespec t = { 0, Miliseconds(300) }; // 300 miliseconds + loggingf("Trying to reconnect\n"); while (1) { - serverfd = socket(AF_INET, SOCK_STREAM, 0); - assert(serverfd > 2); // greater than STDERR - err = connect(serverfd, (struct sockaddr*)address, sizeof(*address)); - if (err == 0) - break; - assert(errno == ECONNREFUSED); - // TODO: faster reconnection? (too many open files) - sleep(TIMEOUT_RECONNECT); + nanosleep(&t, &t); + result = getConnection(&address); + if (result.err) { + // loggingf("err: %d\n", result.err); + loggingf("errno: %d\n", errno); + } else if (result.unifd != -1 && result.bifd != -1) { + loggingf("Reconnect succeeded (%d, %d), authenticating\n", result.unifd, result.bifd); + // We assume that we already have an ID + // TODO: could there be a problem if a message is received at the same time? + // - not on server restart, but what if we lost connection? + HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); + IDMessage id_message = {.id = user.id}; + sendAnyMessage(result.bifd, &header, &id_message); + + ErrorMessage error_message; + s32 nrecv = recvAnyMessageType(result.bifd, &header, &error_message, HEADER_TYPE_ERROR); + if (nrecv == -1 || nrecv == 0) { + loggingf("Error on receive, retrying...\n"); + continue; + } + + assert(header.type == HEADER_TYPE_ERROR); + if (error_message.type == ERROR_TYPE_SUCCESS) { + loggingf("Reconnected\n"); + break; + } else { + loggingf("err: %s\n", errorTypeString(error_message.type)); + } + } + if (result.unifd != -1) + close(result.unifd); + if (result.bifd != -1) + close(result.bifd); + loggingf("Failed, retrying..\n"); } - // if the server would send a disconnect again and the polling catches up there could be two - // threads accessing fds. - pthread_mutex_lock(&mutex); - fds[FDS_SERVER].fd = serverfd; - pthread_mutex_unlock(&mutex); + fds[FDS_BI].fd = result.bifd; + fds[FDS_UNI].fd = result.unifd; - // ask to redraw screen + // Redraw screen raise(SIGWINCH); - return NULL; + return 0; } // Print `text` wrapped to limit_x. It will print no more than limit_y lines. x, y, fg and @@ -85,9 +185,8 @@ thread_reconnect(void* address_ptr) // this is useful when for example: printing messages and wanting to have consistent // timestamp+author name. // Returns the number of lines printed. -// TODO: add y limit -// TODO:(bug) text after pfx is wrapped one too soon -// TODO: text == NULL to know how many lines *would* be printed +// TODO: (bug) text after pfx is wrapped one too soon +// TODO: text == 0 to know how many lines *would* be printed // - no this should be a separate function // TODO: check if text[i] goes out of bounds u32 @@ -107,7 +206,7 @@ tb_printf_wrap(u32 x, u32 y, u32 fg, u32 bg, u32* text, s32 text_len, u32 fg_pfx u32 failed = 0; // NOTE: We can assume that we need to wrap, therefore print a newline after the prefix string - if (pfx != NULL) { + if (pfx != 0) { tb_printf(x, ly, fg_pfx, bg_pfx, "%s", pfx); // If the text fits on one line print the text and return @@ -174,7 +273,7 @@ tb_printf_wrap(u32 x, u32 y, u32 fg, u32 bg, u32* text, s32 text_len, u32 fg_pfx // it displays a prompt with the user input of input_len wide characters // and the received messages from msgsArena void -screen_home(Arena* msgsArena, u32 nmessages, u32 input[], u32 input_len) +screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd* fds, u32 input[], u32 input_len) { // config options const s32 box_max_len = 80; @@ -205,7 +304,7 @@ screen_home(Arena* msgsArena, u32 nmessages, u32 input[], u32 input_len) goto draw_prompt; u8* addr = msgsArena->addr; - assert(addr != NULL); + assert(addr != 0); // on what line to print the current message, used for scrolling u32 msg_y = 0; @@ -238,46 +337,62 @@ screen_home(Arena* msgsArena, u32 nmessages, u32 input[], u32 input_len) HeaderMessage* header = (HeaderMessage*)addr; addr += sizeof(*header); + // Get Client for message + ID* id; + Client* client; + switch (header->type) { + case HEADER_TYPE_TEXT: + id = &((TextMessage*)addr)->id; + case HEADER_TYPE_PRESENCE: + id = &((PresenceMessage*)addr)->id; + client = getClientById(clientsArena, *id); + if (!client) { + loggingf("Client not known, requesting from server\n"); + client = addClientInfo(clientsArena, fds[FDS_BI].fd, *id); + } + assert(client); + break; + } + switch (header->type) { case HEADER_TYPE_TEXT: { TextMessage* message = (TextMessage*)addr; + // Color own messages u32 fg = 0; - if (strncmp((char*)username, (char*)message->author, AUTHOR_LEN) == 0) { + if (user.id == message->id) { fg = TB_CYAN; } else { fg = TB_MAGENTA; } - // prefix is of format "HH:MM:SS [<author>] ", so + + // prefix is of format "HH:MM:SS [<author>] ", create it u8 pfx[AUTHOR_LEN - 1 + TIMESTAMP_LEN - 1 + 4 + 1] = {0}; u8 timestamp[TIMESTAMP_LEN]; formatTimestamp(timestamp, message->timestamp); - sprintf((char*)pfx, "%s [%s] ", timestamp, message->author); - // TODO: y_limit + sprintf((char*)pfx, "%s [%s] ", timestamp, client->author); + msg_y += tb_printf_wrap(0, msg_y, TB_WHITE, 0, (u32*)&message->text, message->len, fg, 0, pfx, global.width, free_y - msg_y); u32 message_size = TEXTMESSAGE_SIZE + message->len * sizeof(*message->text); addr += message_size; - break; - } + } break; case HEADER_TYPE_PRESENCE: { PresenceMessage* message = (PresenceMessage*)addr; - tb_printf(0, msg_y, 0, 0, " [%s] *%s*", message->author, presenceTypeString(message->type)); + tb_printf(0, msg_y, 0, 0, " [%s] *%s*", client->author, presenceTypeString(message->type)); msg_y++; addr += sizeof(*message); - break; - } + } break; case HEADER_TYPE_HISTORY: { HistoryMessage* message = (HistoryMessage*)addr; addr += sizeof(*message); // TODO: implement - } - default: { + } break; + default: tb_printf(0, msg_y, 0, 0, "%s", headerTypeString(header->type)); msg_y++; break; } - } } draw_prompt: @@ -340,7 +455,7 @@ screen_home(Arena* msgsArena, u32 nmessages, u32 input[], u32 input_len) } } - if (fds[FDS_SERVER].fd == -1) { + if (fds[FDS_UNI].fd == -1 || fds[FDS_BI].fd == -1) { // show error popup popup(TB_RED, TB_BLACK, "Server disconnected."); } @@ -350,77 +465,128 @@ screen_home(Arena* msgsArena, u32 nmessages, u32 input[], u32 input_len) int main(int argc, char** argv) { - // Use first argument as username - if (argc > 1) { - u32 arg_len = strlen(argv[1]); - assert(arg_len <= AUTHOR_LEN - 1); - memcpy(username, argv[1], arg_len); - username[arg_len] = '\0'; + if (argc < 2) { + fprintf(stderr, "usage: chatty <username>\n"); + return 1; } + u32 arg_len = strlen(argv[1]); + assert(arg_len <= AUTHOR_LEN - 1); + memcpy(user.author, argv[1], arg_len); + user.author[arg_len] = '\0'; + s32 err = 0; // error code for functions - Arena* msgsArena = ArenaAlloc(Megabytes(64)); // Messages received & sent - u32 nmessages = 0; // Number of messages in msgsArena - s32 nrecv = 0; // number of bytes received - s32 nsend = 0; // number of bytes sent + u32 nmessages = 0; // Number of messages in msgsArena + s32 nrecv = 0; // number of bytes received u32 input[INPUT_LIMIT] = {0}; // input buffer u32 ninput = 0; // number of characters in input + Arena msgsArena; + Arena clientsArena; + ArenaAlloc(&msgsArena, Megabytes(64)); // Messages received & sent + ArenaAlloc(&clientsArena, Megabytes(1)); // Arena for storing clients + struct tb_event ev; // event fork keypress & resize u8 quit = 0; // boolean to indicate if we want to quit the main loop - u8* quitmsg = NULL; // this string will be printed before returning from main + u8* quitmsg = 0; // this string will be printed before returning from main pthread_t thr_rec; // thread for reconnecting to server when disconnected +#ifdef LOGGING + logfd = open(LOGFILE, O_RDWR | O_CREAT | O_TRUNC, 0600); + assert(logfd != -1); +#else + logfd = 2; // stderr +#endif + // poopoo C cannot infer type - fds = (struct pollfd[FDS_MAX]){ - {-1, POLLIN, 0}, // FDS_SERVER + struct pollfd fds[FDS_MAX] = { + {-1, POLLIN, 0}, // FDS_UNI + {-1, POLLIN, 0}, // FDS_BI {-1, POLLIN, 0}, // FDS_TTY {-1, POLLIN, 0}, // FDS_RESIZE }; - const struct sockaddr_in address = { + address = (struct sockaddr_in){ AF_INET, htons(PORT), {0}, {0}, }; - // Connecting to server - { - s32 serverfd; - serverfd = socket(AF_INET, SOCK_STREAM, 0); - assert(serverfd > 2); // greater than STDERR - - err = connect(serverfd, (struct sockaddr*)&address, sizeof(address)); - if (err != 0) { - perror("Server"); + ConnectionResult result = getConnection(&address); + if (result.err) { + perror("Server"); + return 1; + } + assert(result.unifd != -1); + assert(result.bifd != -1); + assert(!result.err); + fds[FDS_BI].fd = result.bifd; + fds[FDS_UNI].fd = result.unifd; + +#ifdef IMPORT_ID + // File for storing the user's ID. + u32 idfile = open(ID_FILE, O_RDWR | O_CREAT, 0600); + s32 nread = read(idfile, &user.id, sizeof(user.id)); + assert(nread != -1); + // see "Authentication" in chatty.h + if (nread == sizeof(user.id)) { + // Scenario 1: We know our id + + // Send IDMessage and check if it is correct + HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); + IDMessage message = {.id = user.id}; + sendAnyMessage(fds[FDS_BI].fd, &header, &message); + + ErrorMessage error_message = {0}; + recvAnyMessageType(fds[FDS_BI].fd, &header, &error_message, HEADER_TYPE_ERROR); + + switch (error_message.type) { + case ERROR_TYPE_SUCCESS: break; + case ERROR_TYPE_NOTFOUND: + printf("Server does not know our ID. Consider removing '" ID_FILE "'\n"); + return 1; + default: + printf("Server: %s\n", errorTypeString(error_message.type)); return 1; } - fds[FDS_SERVER].fd = serverfd; - - // Introduce ourselves - HeaderMessage header = HEADER_PRESENCEMESSAGE; - PresenceMessage message = {.type = PRESENCE_TYPE_CONNECTED}; - memcpy(message.author, username, AUTHOR_LEN); - nsend = send(serverfd, &header, sizeof(header), 0); - assert(nsend != -1); - assert(nsend == sizeof(header)); - nsend = send(serverfd, &message, sizeof(message), 0); - assert(nsend != -1); - assert(nsend == sizeof(message)); + } else { +#else + if (1) { +#endif + // Scenario 2: We do not have an ID + HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION); + IntroductionMessage message = {0}; + // copy user data into message + memcpy(message.author, user.author, AUTHOR_LEN); + + // Send the introduction message + sendAnyMessage(fds[FDS_BI].fd, &header, &message); + + IDMessage id_message = {0}; + // Receive the response IDMessage + recvAnyMessageType(fds[FDS_BI].fd, &header, &id_message, HEADER_TYPE_ID); + assert(header.type == HEADER_TYPE_ID); + user.id = id_message.id; +#ifdef IMPORT_ID + // Save permanently + write(idfile, &user.id, sizeof(user.id)); + close(idfile); +#endif } + loggingf("Got ID: %lu\n", user.id); // for wide character printing - assert(setlocale(LC_ALL, "") != NULL); + assert(setlocale(LC_ALL, "") != 0); // init tb_init(); tb_get_fds(&fds[FDS_TTY].fd, &fds[FDS_RESIZE].fd); - screen_home(msgsArena, nmessages, input, ninput); + screen_home(&msgsArena, nmessages, &clientsArena, fds, input, ninput); tb_present(); // main loop @@ -431,43 +597,45 @@ main(int argc, char** argv) tb_clear(); - if (fds[FDS_SERVER].revents & POLLIN) { + if (fds[FDS_UNI].revents & POLLIN) { // got data from server HeaderMessage header; - nrecv = recv(fds[FDS_SERVER].fd, &header, sizeof(header), 0); + nrecv = recv(fds[FDS_UNI].fd, &header, sizeof(header), 0); assert(nrecv != -1); // Server disconnects if (nrecv == 0) { // close diconnected server's socket - err = close(fds[FDS_SERVER].fd); + err = close(fds[FDS_UNI].fd); assert(err == 0); - fds[FDS_SERVER].fd = -1; // ignore + fds[FDS_UNI].fd = -1; // ignore // start trying to reconnect in a thread - err = pthread_create(&thr_rec, NULL, &thread_reconnect, (void*)&address); + err = pthread_create(&thr_rec, 0, &threadReconnect, (void*)fds); assert(err == 0); } else { - // TODO: validate version - // if (header.version == PROTOCOL_VERSION) - // continue; + if (header.version != PROTOCOL_VERSION) { + loggingf("Header received does not match version\n"); + continue; + } - void* addr = ArenaPush(msgsArena, sizeof(header)); + void* addr = ArenaPush(&msgsArena, sizeof(header)); memcpy(addr, &header, sizeof(header)); + // Messages handled from server switch (header.type) { case HEADER_TYPE_TEXT: - recvTextMessage(msgsArena, fds[FDS_SERVER].fd, NULL); + recvTextMessage(&msgsArena, fds[FDS_UNI].fd); nmessages++; break; case HEADER_TYPE_PRESENCE:; - PresenceMessage* message = ArenaPush(msgsArena, sizeof(*message)); - nrecv = recv(fds[FDS_SERVER].fd, message, sizeof(*message), 0); + PresenceMessage* message = ArenaPush(&msgsArena, sizeof(*message)); + nrecv = recv(fds[FDS_UNI].fd, message, sizeof(*message), 0); assert(nrecv != -1); assert(nrecv == sizeof(*message)); nmessages++; break; default: - // TODO: log + loggingf("Got unhandled message: %s\n", headerTypeString(header.type)); break; } } @@ -512,7 +680,7 @@ main(int argc, char** argv) if (ninput == 0) // do not send empty message break; - if (fds[FDS_SERVER].fd == -1) + if (fds[FDS_UNI].fd == -1) // do not send message to disconnected server break; @@ -521,25 +689,21 @@ main(int argc, char** argv) ninput++; // Save header - HeaderMessage header = HEADER_TEXTMESSAGE; - void* addr = ArenaPush(msgsArena, sizeof(header)); + HeaderMessage header = HEADER_INIT(HEADER_TYPE_TEXT); + void* addr = ArenaPush(&msgsArena, sizeof(header)); memcpy(addr, &header, sizeof(header)); // Save message - TextMessage* sendmsg = ArenaPush(msgsArena, TEXTMESSAGE_SIZE); - memcpy(sendmsg->author, username, AUTHOR_LEN); - sendmsg->timestamp = time(NULL); + TextMessage* sendmsg = ArenaPush(&msgsArena, TEXTMESSAGE_SIZE); + sendmsg->id = user.id; + sendmsg->timestamp = time(0); sendmsg->len = ninput; u32 text_size = ninput * sizeof(*input); - ArenaPush(msgsArena, text_size); + ArenaPush(&msgsArena, text_size); memcpy(&sendmsg->text, input, text_size); - // Send message - nsend = send(fds[FDS_SERVER].fd, &header, sizeof(header), 0); - assert(nsend != -1); - nsend = send(fds[FDS_SERVER].fd, sendmsg, TEXTMESSAGE_SIZE + TEXTMESSAGE_TEXT_SIZE((*sendmsg)), 0); - assert(nsend != -1); + sendAnyMessage(fds[FDS_UNI].fd, &header, sendmsg); nmessages++; // also clear input @@ -550,7 +714,8 @@ main(int argc, char** argv) default: if (ev.ch == 0) break; - // TODO: logging + + // TODO: show error if (ninput == INPUT_LIMIT - 1) // last byte reserved for \0 break; @@ -568,17 +733,15 @@ main(int argc, char** argv) tb_poll_event(&ev); } - screen_home(msgsArena, nmessages, input, ninput); + screen_home(&msgsArena, nmessages, &clientsArena, fds, input, ninput); tb_present(); } tb_shutdown(); - if (quitmsg != NULL) + if (quitmsg != 0) printf("%s\n", quitmsg); - ArenaRelease(msgsArena); - return 0; } @@ -1,7 +1,8 @@ -#ifndef CHATTY_IMPL +#ifndef CHATTY_H #include <assert.h> #include <locale.h> +#include <stdarg.h> #include <stdbool.h> #include <stddef.h> #include <stdint.h> @@ -11,6 +12,7 @@ #include <sys/mman.h> #include <sys/socket.h> #include <time.h> +#include <unistd.h> #include <wchar.h> typedef uint8_t u8; @@ -28,13 +30,54 @@ typedef enum { // port for chatty #define PORT 9983 +// max number of bytes that can be logged at once +#define LOGMESSAGE_MAX 2048 +#define LOG_FMT "%H:%M:%S " +#define LOG_LEN 10 #define Kilobytes(Value) ((Value) * 1024) #define Megabytes(Value) (Kilobytes(Value) * 1024) #define Gigabytes(Value) (Megabytes((u64)Value) * 1024) #define Terabytes(Value) (Gigabytes((u64)Value) * 1024) #define PAGESIZE 4096 +#define local_persist static +#define global_variable +#define internal static +global_variable s32 logfd; + +u32 +wstrlen(u32* str) +{ + u32 i = 0; + while (str[i] != 0) + i++; + return i; +} + +void +loggingf(char* format, ...) +{ + char buf[LOGMESSAGE_MAX]; + va_list args; + va_start(args, format); + + vsnprintf(buf, sizeof(buf), format, args); + va_end(args); + + int n = 0; + while (*(buf + n) != 0) n++; + + u64 t = time(0); + u8 timestamp[LOG_LEN]; + struct tm* ltime = localtime((time_t*)&t); + strftime((char*)timestamp, LOG_LEN, LOG_FMT, ltime); + write(logfd, timestamp, LOG_LEN - 1); + + write(logfd, buf, n); +} + +// Arena Allocator struct Arena { void* addr; u64 size; @@ -46,25 +89,20 @@ struct Arena { #define PushStruct(arena, type) PushArray((arena), (type), 1) #define PushStructZero(arena, type) PushArrayZero((arena), (type), 1) -Arena* -ArenaAlloc(u64 size) +// Returns arena in case of success, or 0 if it failed to alllocate the memory +void +ArenaAlloc(Arena* arena, u64 size) { - Arena* arena = (Arena*)malloc(sizeof(Arena)); - - arena->addr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0); - if (arena->addr == MAP_FAILED) - return NULL; + arena->addr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_ANONYMOUS | MAP_PRIVATE, -1, 0); + assert(arena->addr != MAP_FAILED); arena->pos = 0; arena->size = size; - - return arena; } void ArenaRelease(Arena* arena) { munmap(arena->addr, arena->size); - free(arena); } void* @@ -73,167 +111,9 @@ ArenaPush(Arena* arena, u64 size) u8* mem; mem = (u8*)arena->addr + arena->pos; arena->pos += size; + assert(arena->pos <= arena->size); return mem; } -/// Protocol -// - every message has format Header + Message -// TODO: authentication -// TODO: encryption - -/// Protocol Header -// - 2 bytes for version -// - 1 byte for message type -// - 16 bytes for checksum -// -// Text Message -// - 12 bytes for the author -// - 8 bytes for the timestamp -// - 2 bytes for the text length -// - x*4 bytes for the text -// -// History Message -// This message is for requesting messages sent after a timestamp. -// - 8 bytes for the timestamp - -/// Naming convention -// Messages end with the Message suffix (eg. TextMessag, HistoryMessage) -// A function that is coupled to a type works like -// <noun><type> eg. (printTextMessage, formatTimestamp) - -#define PROTOCOL_VERSION 0 - -typedef struct { - u16 version; - u8 type; -} HeaderMessage; - -enum { HEADER_TYPE_TEXT = 0, - HEADER_TYPE_HISTORY, - HEADER_TYPE_PRESENCE }; -#define HEADER_TEXTMESSAGE {.version = PROTOCOL_VERSION, .type = HEADER_TYPE_TEXT}; -#define HEADER_HISTORYMESSAGE {.version = PROTOCOL_VERSION, .type = HEADER_TYPE_HISTORY}; -#define HEADER_PRESENCEMESSAGE {.version = PROTOCOL_VERSION, .type = HEADER_TYPE_PRESENCE}; - -// Size of author string including null terminator -#define AUTHOR_LEN 13 -// Size of formatted timestamp string including null terminator -#define TIMESTAMP_LEN 9 - -typedef struct { - u8 checksum[16]; - u8 author[AUTHOR_LEN]; - u64 timestamp; - u16 len; // including null terminator - u32* text; // placeholder for indexing - // TODO: 0-length field? -} TextMessage; - -// Size of TextMessage without text pointer, used when receiving the message over a stream -#define TEXTMESSAGE_TEXT_SIZE(m) (m.len * sizeof(*m.text)) -#define TEXTMESSAGE_SIZE (sizeof(TextMessage) - sizeof(u32*)) - -typedef struct { - u64 timestamp; -} HistoryMessage; - -typedef struct { - u8 author[AUTHOR_LEN]; - u8 type; -} PresenceMessage; -enum { PRESENCE_TYPE_CONNECTED = 0, - PRESENCE_TYPE_DISCONNECTED }; - -// Returns string for type byte in HeaderMessage -u8* -headerTypeString(u8 type) -{ - switch (type) { - case HEADER_TYPE_TEXT: return (u8*)"TextMessage"; - case HEADER_TYPE_HISTORY: return (u8*)"HistoryMessage"; - case HEADER_TYPE_PRESENCE: return (u8*)"PresenceMessage"; - default: return (u8*)"Unknown"; - } -} - -u8* -presenceTypeString(u8 type) -{ - switch (type) { - case PRESENCE_TYPE_CONNECTED: return (u8*)"connected"; - case PRESENCE_TYPE_DISCONNECTED: return (u8*)"disconnected"; - default: return (u8*)"Unknown"; - } -} - -// from Tsoding video on minicel (https://youtu.be/HCAgvKQDJng?t=4546) -// sv(https://github.com/tsoding/sv) -#define PH_FMT "header: v%d %s(%d)" -#define PH_ARG(header) header.version, headerTypeString(header.type), header.type - -void -formatTimestamp(u8 tmsp[TIMESTAMP_LEN], u64 t) -{ - struct tm* ltime; - ltime = localtime((time_t*)&t); - strftime((char*)tmsp, TIMESTAMP_LEN, "%H:%M:%S", ltime); -} - -void -printTextMessage(TextMessage* message, u8 wide) -{ - u8 timestamp[TIMESTAMP_LEN] = {0}; - formatTimestamp(timestamp, message->timestamp); - - assert(setlocale(LC_ALL, "") != NULL); - - if (wide) - wprintf(L"TextMessage: %s [%s] %ls\n", timestamp, message->author, (wchar_t*)&message->text); - else { - u8 str[message->len]; - wcstombs((char*)str, (wchar_t*)&message->text, message->len * sizeof(*message->text)); - printf("TextMessage: %s [%s] (%d)%s\n", timestamp, message->author, message->len, str); - } -} - -// Receive a message from fd and store it to the msgsArena, -// if dest is not NULL point it to the new message created on msgsArena -// Returns the number of bytes received -u32 -recvTextMessage(Arena* msgsArena, u32 fd, TextMessage** dest) -{ - s32 nrecv = 0; - - TextMessage* message = ArenaPush(msgsArena, TEXTMESSAGE_SIZE); - if (dest != NULL) - *dest = message; - - // Receive everything but the text so we can know the text's size and act accordingly - nrecv = recv(fd, message, TEXTMESSAGE_SIZE, 0); - assert(nrecv != -1); - assert(nrecv == TEXTMESSAGE_SIZE); - - nrecv = 0; - - // Allocate memory for text and receive in that memory - u32 text_size = message->len * sizeof(*message->text); - ArenaPush(msgsArena, text_size); - - nrecv = recv(fd, (u8*)&message->text, text_size, 0); - assert(nrecv != -1); - assert(nrecv == message->len * sizeof(*message->text)); - - return TEXTMESSAGE_SIZE + nrecv; -} - -u32 -wstrlen(u32* str) -{ - u32 i = 0; - while (str[i] != 0) - i++; - return i; -} - #endif #define CHATTY_H diff --git a/protocol.h b/protocol.h new file mode 100644 index 0000000..84d636f --- /dev/null +++ b/protocol.h @@ -0,0 +1,360 @@ +#ifndef PROTOCOL_H +#define PROTOCOL_H + +#include "chatty.h" + +/// Protocol +// - every message has format Header + Message +// TODO: security +// +/// ID +// - So clients can be identified uniquely. +// - 8 bytes +// - number that increments for each new client +// +/// Strings +// - strings are sent with their null terminator +// +/// Authentication +// This is what happens when the first time a client connects. +// Scenario 1. We alreayd have an ID +// 1. client-> Send own ID +// 2. server-> knows ID? +// y. server-> Success +// n. 1. server-> Error 'notfound' +// 2. client-> exit +// Scenario 2. We do not have an ID +// 1. client-> Introduces +// 2. server-> Sends & Saves ID +// 3. client-> Saves ID +// +/// Naming convention +// Messages end with the Message suffix (eg. TextMessag, HistoryMessage) +// +// A function that is coupled to a type works like +// <noun><type> eg. (printTextMessage, formatTimestamp) + +#define PROTOCOL_VERSION 0 +// Size of author string including null terminator +#define AUTHOR_LEN 13 +// Size of formatted timestamp string including null terminator +#define TIMESTAMP_LEN 9 +#define TIMESTAMP_FORMAT "%H:%M:%S" + +typedef u64 ID; + +// - 2 bytes for version +// - 1 byte for message type +// - 16 bytes for checksum +typedef struct { + u16 version; + u8 type; +} HeaderMessage; + +typedef enum { + HEADER_TYPE_TEXT = 0, + HEADER_TYPE_HISTORY, + HEADER_TYPE_PRESENCE, + HEADER_TYPE_ID, + HEADER_TYPE_INTRODUCTION, + HEADER_TYPE_ERROR +} HeaderType; +// shorthand for creating a header with a value from the enum +#define HEADER_INIT(t) {.version = PROTOCOL_VERSION, .type = t} +// from Tsoding video on minicel (https://youtu.be/HCAgvKQDJng?t=4546) +// sv(https://github.com/tsoding/sv) +#define HEADER_FMT "header: v%d %s(%d)" +#define HEADER_ARG(header) header.version, headerTypeString(header.type), header.type + +// For sending texts to other clients +// - 13 bytes for the author +// - 8 bytes for the timestamp +// - 8 bytes for id +// - 2 bytes for the text length +// - x*4 bytes for the text +typedef struct { + ID id; + u64 timestamp; // timestamp of when the message was sent + u16 len; + wchar_t* text; // placeholder for indexing + // wchar_t* is used, because this renders the text in the debugger +} TextMessage; +// Size of TextMessage without text pointer +#define TEXTMESSAGE_SIZE (sizeof(TextMessage) - sizeof(u32*)) + +// Requesting messages sent after a timestamp. +// - 8 bytes for the timestamp +typedef struct { + u64 timestamp; +} HistoryMessage; + +// Introduce the client to the server by sending the client's information. +// See "First connection". +// - 13 bytes for author +typedef struct { + u8 author[AUTHOR_LEN]; +} IntroductionMessage; +#define INTRODUCTION_FMT "introduction: %s" +#define INTRODUCTION_ARG(message) message.author + +// Request IntroductionMessage for client with that id. +// See "First connection" if this message is used when the client connects for the first time. +// be used to retrieve information about a client with an unknown ID. +// - 8 bytes for id +typedef struct { + ID id; +} IDMessage; + +// Notifying the sender's state, such as "connected", "disconnected", "AFK", ... +// - 8 bytes for id +// - 1 byte for type +typedef struct { + ID id; + u8 type; +} PresenceMessage; +typedef enum { + PRESENCE_TYPE_CONNECTED = 0, + PRESENCE_TYPE_DISCONNECTED, + PRESENCE_TYPE_AFK +} PresenceType; + +// Send an error message +// - 1 byte for type +typedef struct { + u8 type; +} ErrorMessage; +typedef enum { + ERROR_TYPE_BADMESSAGE = 0, + ERROR_TYPE_NOTFOUND, + ERROR_TYPE_SUCCESS, + ERROR_TYPE_ALREADYCONNECTED, + ERROR_TYPE_TOOMANYCONNECTIONS +} ErrorType; +#define ERROR_INIT(t) {.type = t} + +typedef struct { + s32 nrecv; + TextMessage* message; +} recvTextMessageResult; + +// Returns string for type byte in HeaderMessage +u8* +headerTypeString(HeaderType type) +{ + switch (type) { + case HEADER_TYPE_TEXT: return (u8*)"TextMessage"; + case HEADER_TYPE_HISTORY: return (u8*)"HistoryMessage"; + case HEADER_TYPE_PRESENCE: return (u8*)"PresenceMessage"; + case HEADER_TYPE_ID: return (u8*)"IDMessage"; + case HEADER_TYPE_INTRODUCTION: return (u8*)"IntroductionMessage"; + case HEADER_TYPE_ERROR: return (u8*)"ErrorMessage"; + default: return (u8*)"Unknown"; + } +} + +u8* +presenceTypeString(PresenceType type) +{ + switch (type) { + case PRESENCE_TYPE_CONNECTED: return (u8*)"connected"; + case PRESENCE_TYPE_DISCONNECTED: return (u8*)"disconnected"; + case PRESENCE_TYPE_AFK: return (u8*)"afk"; + default: return (u8*)"Unknown"; + } +} + +u8* +errorTypeString(ErrorType type) +{ + switch (type) { + case ERROR_TYPE_BADMESSAGE: return (u8*)"bad message"; + case ERROR_TYPE_NOTFOUND: return (u8*)"not found"; + case ERROR_TYPE_SUCCESS: return (u8*)"success"; + case ERROR_TYPE_ALREADYCONNECTED: return (u8*)"already connected"; + case ERROR_TYPE_TOOMANYCONNECTIONS: return (u8*)"too many connections"; + default: return (u8*)"Unknown"; + } +} + +// Formats time t into tmsp string +void +formatTimestamp(u8 timestamp_str[TIMESTAMP_LEN], u64 timestamp) +{ + struct tm* ltime; + ltime = localtime((time_t*)×tamp); + strftime((char*)timestamp_str, TIMESTAMP_LEN, TIMESTAMP_FORMAT, ltime); +} + +// Receive a message from fd and store it in the msgsArena, +// Returns pointer to the allocated memory +TextMessage* +recvTextMessage(Arena* msgsArena, u32 fd) +{ + TextMessage* message = ArenaPush(msgsArena, TEXTMESSAGE_SIZE); + + // Receive everything but the text so we can know the text's size and act accordingly + s32 nrecv = recv(fd, message, TEXTMESSAGE_SIZE, 0); + assert(nrecv != -1); + assert(nrecv == TEXTMESSAGE_SIZE); + + // Allocate memory for text and receive in that memory + u32 text_size = message->len * sizeof(*message->text); + ArenaPush(msgsArena, text_size); + + nrecv = recv(fd, (u8*)&message->text, text_size, 0); + assert(nrecv != -1); + assert(nrecv == message->len * sizeof(*message->text)); + + return message; +} + +typedef struct { + HeaderMessage* header; + void* message; +} Message; + +u32 +getMessageSize(HeaderType type) +{ + u32 size = 0; + switch (type) { + case HEADER_TYPE_ERROR: size = sizeof(ErrorMessage); break; + case HEADER_TYPE_HISTORY: size = sizeof(HistoryMessage); break; + case HEADER_TYPE_ID: size = sizeof(IDMessage); break; + case HEADER_TYPE_INTRODUCTION: size = sizeof(IntroductionMessage); break; + case HEADER_TYPE_PRESENCE: size = sizeof(PresenceMessage); break; + default: assert(0); + } + return size; +} + +s32 +recvAnyMessageType(s32 fd, HeaderMessage* header, void *anyMessage, HeaderType type) +{ + s32 nrecv = recv(fd, header, sizeof(*header), 0); + if (nrecv == -1 || nrecv == 0) + return nrecv; + assert(nrecv == sizeof(*header)); + + s32 size = 0; + switch (type) { + case HEADER_TYPE_ERROR: + case HEADER_TYPE_HISTORY: + case HEADER_TYPE_ID: + case HEADER_TYPE_INTRODUCTION: + case HEADER_TYPE_PRESENCE: + size = getMessageSize(header->type); + break; + case HEADER_TYPE_TEXT: { + TextMessage* message = anyMessage; + size = TEXTMESSAGE_SIZE + message->len * sizeof(*message->text); + } break; + default: assert(0); break; + } + assert(header->type == type); + + nrecv = recv(fd, anyMessage, size, 0); + assert(nrecv != -1); + assert(nrecv == size); + + return size; +} + +// Get any message into arena +Message +recvAnyMessage(Arena* arena, s32 fd) +{ + HeaderMessage* header = ArenaPush(arena, sizeof(*header)); + s32 nrecv = recv(fd, header, sizeof(*header), 0); + assert(nrecv != -1); + assert(nrecv == sizeof(*header)); + + s32 size = 0; + switch (header->type) { + case HEADER_TYPE_ERROR: + case HEADER_TYPE_HISTORY: + case HEADER_TYPE_ID: + case HEADER_TYPE_INTRODUCTION: + case HEADER_TYPE_PRESENCE: + size = getMessageSize(header->type); + break; + case HEADER_TYPE_TEXT: { + Message result; + result.header = header; + result.message = recvTextMessage(arena, fd); + return result; + } break; + default: assert(0); break; + } + + void* message = ArenaPush(arena, size); + nrecv = recv(fd, message, size, 0); + assert(nrecv != -1); + assert(nrecv == size); + + Message result; + result.header = header; + result.message = message; + + return result; +} + +Message +waitForMessageType(Arena* arena, Arena* queueArena, u32 fd, HeaderType type) +{ + Message message; + while (1) { + message = recvAnyMessage(arena, fd); + if (message.header->type == type) + break; + ArenaPush(queueArena, getMessageSize(message.header->type)); + } + return message; +} + +// Generic sending function for sending any type of message to fd +// Returns number of bytes sent in message or -1 if there was an error. +s32 +sendAnyMessage(u32 fd, HeaderMessage* header, void* anyMessage) +{ + s32 nsend_total; + s32 nsend = send(fd, header, sizeof(*header), 0); + if (nsend == -1) return nsend; + assert(nsend == sizeof(*header)); + nsend_total = nsend; + + s32 size = 0; + switch (header->type) { + case HEADER_TYPE_ERROR: + case HEADER_TYPE_HISTORY: + case HEADER_TYPE_ID: + case HEADER_TYPE_INTRODUCTION: + case HEADER_TYPE_PRESENCE: + size = getMessageSize(header->type); + break; + case HEADER_TYPE_TEXT: { + nsend = send(fd, anyMessage, TEXTMESSAGE_SIZE, 0); + assert(nsend != -1); + assert(nsend == TEXTMESSAGE_SIZE); + nsend_total += nsend; + // set size to remaning text size that should be sent + TextMessage* message = (TextMessage*)anyMessage; + size = message->len * sizeof(*message->text); + nsend = 0; + + anyMessage = &message->text; + } break; + default: + fprintf(stdout, "sendAnyMessage(%d)|Cannot send %s\n", fd, headerTypeString(header->type)); + return 0; + } + + nsend = send(fd, anyMessage, size, 0); + if (nsend == -1) return nsend; + assert(nsend == size); + nsend_total += nsend; + + return nsend_total; +} + +#endif @@ -8,6 +8,7 @@ #include <unistd.h> #include "chatty.h" +#include "protocol.h" int main(int argc, char** argv) @@ -17,61 +18,70 @@ main(int argc, char** argv) return 1; } - s32 err, serverfd, nsend; + s32 err, serverfd, nsend, nrecv; serverfd = socket(AF_INET, SOCK_STREAM, 0); assert(serverfd != -1); - const struct sockaddr_in address = { - AF_INET, - htons(PORT), - {0}, - }; + const struct sockaddr_in address = {AF_INET, htons(PORT), {0}, {0}}; err = connect(serverfd, (struct sockaddr*)&address, sizeof(address)); assert(err == 0); - // convert text to wide string - u32 text_len = strlen(argv[2]) + 1; - u32 text_wide[text_len]; - u32 size = mbstowcs((wchar_t*)text_wide, argv[2], text_len - 1); - assert(size == text_len - 1); - text_wide[text_len - 1] = 0; - u32 author_len = strlen(argv[1]); - assert(author_len + 1 <= AUTHOR_LEN); // add 1 for null terminator - - // Introduce ourselves + // Get our ID + ID id = 0; { - HeaderMessage header = HEADER_PRESENCEMESSAGE; - PresenceMessage message; + // get author len + u32 author_len = strlen(argv[1]); + assert(author_len + 1 <= AUTHOR_LEN); // add 1 for null terminator + + // Introduce ourselves + HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION); + IntroductionMessage message; memcpy(message.author, argv[1], author_len); nsend = send(serverfd, &header, sizeof(header), 0); assert(nsend != -1); nsend = send(serverfd, &message, sizeof(message), 0); assert(nsend != -1); - } - HeaderMessage header = HEADER_TEXTMESSAGE; - TextMessage* message; + // Get id + nrecv = recv(serverfd, &header, sizeof(header), 0); + assert(nrecv != -1); + if (header.type == HEADER_TYPE_ERROR) { + ErrorMessage message; + nrecv = recv(serverfd, &message, sizeof(message), 0); + fprintf(stderr, "Got '%s' error.\n'", errorTypeString(message.type)); + close(serverfd); + return 1; + } + assert(header.type == HEADER_TYPE_ID); + IDMessage idmessage; + nrecv = recv(serverfd, &idmessage, sizeof(idmessage), 0); + assert(nrecv != -1); + fprintf(stderr, "Got id: %lu\n", idmessage.id); + } - u8 buf[text_len * sizeof(*text_wide) + TEXTMESSAGE_SIZE]; - bzero(buf, sizeof(buf)); - message = (TextMessage*)buf; + // convert text to wide string + u32 text_len = strlen(argv[2]) + 1; + u32 text_wide[text_len]; + u32 size = mbstowcs((wchar_t*)text_wide, argv[2], text_len - 1); + assert(size == text_len - 1); + text_wide[text_len - 1] = 0; - memcpy(message->author, argv[1], author_len); - message->timestamp = time(NULL); - message->len = text_len; - memcpy(&message->text, text_wide, text_len * sizeof(*message->text)); + HeaderMessage header = HEADER_INIT(HEADER_TYPE_TEXT); + TextMessage message; + bzero(&message, TEXTMESSAGE_SIZE); + message = (TextMessage){.id = id, .timestamp = time(NULL), .len = text_len}; nsend = send(serverfd, &header, sizeof(header), 0); assert(nsend != -1); - printf("header bytes sent: %d\n", nsend); - nsend = send(serverfd, buf, sizeof(buf), 0); + fprintf(stderr, "header bytes sent: %d\n", nsend); + nsend = send(serverfd, &message, TEXTMESSAGE_SIZE, 0); assert(nsend != -1); + fprintf(stderr, "message bytes sent: %d\n", nsend); - printf("text length: %d\n", text_len); - printf("buf size: %lu\n", sizeof(buf)); - printf("text size: %lu\n", sizeof(*text_wide) * text_len); - printf("message bytes sent: %d\n", nsend); + u32 text_size = message.len * sizeof(*message.text); + nsend = send(serverfd, text_wide, text_size, 0); + fprintf(stderr, "text bytes sent: %d\n", nsend); return 0; } @@ -1,11 +1,15 @@ #include "chatty.h" +#include "protocol.h" #include <assert.h> +#include <fcntl.h> #include <netinet/in.h> #include <poll.h> +#include <signal.h> #include <stdarg.h> #include <string.h> #include <sys/socket.h> +#include <sys/stat.h> #include <unistd.h> // timeout on polling @@ -14,276 +18,481 @@ #define MAX_CONNECTIONS 16 // Get number of connections from arena position // NOTE: this is somewhat wrong, because of when disconnections happen -#define FDS_SIZE (fdsArena->pos / sizeof(*fds)) +#define FDS_SIZE (fdsArena.pos / sizeof(struct pollfd)) +#define CLIENTS_SIZE (clientsArena.pos / sizeof(Client)) + +// Enable/Disable saving clients permanently to file +#define IMPORT_ID +// Where to save clients +#define CLIENTS_FILE "_clients" +// Where to write logs +#define LOGFILE "server.log" +// Log to LOGFILE instead of stderr +// #define LOGGING // enum for indexing the fds array enum { FDS_STDIN = 0, FDS_SERVER, FDS_CLIENTS }; -// Has information on clients -// For each pollfd in fds there should be a matching client in clients -// clients[i] <=> fds[i] +// Client information typedef struct { u8 author[AUTHOR_LEN]; // matches author property on other message types - Bool initialized; // boolean + ID id; + struct pollfd* pollunifd; // Index in fds array + struct pollfd* pollbifd; // Index in fds array } Client; +#define CLIENT_FMT "[%s](%lu)" +#define CLIENT_ARG(client) client.author, client.id + +typedef enum { + UNIFD = 0, + BIFD +} ClientFD; + +// TODO: remove +// For handing out new ids to connections. +global_variable u32 nclients = 0; + +// Returns client matching id in clients. +// clientsArena is used to get an upper bound. +// Returns 0 if there was no client found. +Client* +getClientByID(Arena* clientsArena, ID id) +{ + Client* clients = clientsArena->addr; + for (u32 i = 0; i < (clientsArena->pos / sizeof(*clients)); i++) { + if (clients[i].id == id) + return clients + i; + } + return 0; +} + +// Print TextMessage prettily +void +printTextMessage(TextMessage* message, Client* client, u8 wide) +{ + u8 timestamp[TIMESTAMP_LEN] = {0}; + formatTimestamp(timestamp, message->timestamp); + + if (wide) { + setlocale(LC_ALL, ""); + wprintf(L"TextMessage: %s [%s] %ls\n", timestamp, client->author, (wchar_t*)&message->text); + } else { + u8 str[message->len]; + wcstombs((char*)str, (wchar_t*)&message->text, message->len * sizeof(*message->text)); + loggingf("TextMessage: %s [%s] (%d)%s\n", timestamp, client->author, message->len, str); + } +} -// Send anyMessage to all clients in fds from fdsArena except for fds[i]. +// Send header and anyMessage to each connection in fds that is nfds number of connections except +// for connfd. +// Type will filter out only connections matching the type. void -sendToOthers(Arena* fdsArena, struct pollfd* fds, u32 i, HeaderMessage* header, void* anyMessage) +sendToOthers(struct pollfd* fds, u32 nfds, s32 connfd, ClientFD type, HeaderMessage* header, void* anyMessage) { s32 nsend; - for (u32 j = FDS_CLIENTS; j < FDS_SIZE; j++) { - if (fds[j].fd == fds[i].fd) continue; - if (fds[j].fd == -1) continue; - - // send header - u32 nsend_total = 0; - nsend = send(fds[j].fd, header, sizeof(*header), 0); - assert(nsend != -1); - assert(nsend == sizeof(*header)); - nsend_total += nsend; - - // send message - switch (header->type) { - case HEADER_TYPE_PRESENCE: { - PresenceMessage* message = (PresenceMessage*)anyMessage; - nsend = send(fds[j].fd, message, sizeof(*message), 0); - assert(nsend != -1); - assert(nsend == sizeof(*message)); - fprintf(stdout, " Notifying(%d->%d).\n", fds[i].fd, fds[j].fd); - break; - } - case HEADER_TYPE_TEXT: { - TextMessage* message = (TextMessage*)anyMessage; - nsend = send(fds[j].fd, message, TEXTMESSAGE_SIZE, 0); - assert(nsend != -1); - assert(nsend == TEXTMESSAGE_SIZE); - nsend_total += nsend; - nsend = send(fds[j].fd, &message->text, message->len * sizeof(*message->text), 0); - assert(nsend != -1); - assert(nsend == (message->len * sizeof(*message->text))); - nsend_total += nsend; - break; - } - default: - fprintf(stdout, " Cannot retransmit %s\n", headerTypeString(header->type)); - } + for (u32 i = FDS_CLIENTS + type; i < nfds; i += 2) { + if (fds[i].fd == connfd) continue; + if (fds[i].fd == -1) continue; - fprintf(stdout, " Retransmitted(%d->%d) %d bytes.\n", fds[i].fd, fds[j].fd, nsend_total); + nsend = sendAnyMessage(fds[i].fd, header, anyMessage); + loggingf("sendToOthers(%d)|[%s]->%d %d bytes\n", connfd, headerTypeString(header->type), fds[i].fd, nsend); } } +// Send header and anyMessage to each connection in fds that is nfds number of connections. +// Type will filter out only connections matching the type. void -disconnect(Arena* fdsArena, struct pollfd* fds, u32 i, Client* client) +sendToAll(struct pollfd* fds, u32 nfds, ClientFD type, HeaderMessage* header, void* anyMessage) { - fprintf(stdout, "Disconnected(%d). \n", fds[i].fd); - shutdown(fds[i].fd, SHUT_RDWR); - close(fds[i].fd); // send close to client - - // Send disconnection to other connected clients - HeaderMessage header = HEADER_PRESENCEMESSAGE; - PresenceMessage message = { - .type = PRESENCE_TYPE_DISCONNECTED - }; - memcpy(message.author, client->author, AUTHOR_LEN); - sendToOthers(fdsArena, fds, i, &header, &message); - - fds[i].fd = -1; // ignore in the future - client->initialized = False; // deinitialize client + for (u32 i = FDS_CLIENTS + type; i < nfds; i += 2) { + if (fds[i].fd == -1) continue; + s32 nsend = sendAnyMessage(fds[i].fd, header, anyMessage); + loggingf("sendToAll|[%s]->%d %d bytes\n", headerTypeString(header->type), fds[i].fd, nsend); + } } -// Initialize a client that connects for the first time or reconnects. -// Receive HeaderMessage and PresenceMessage from fd and set client with the data from -// PresenceMessage. -// Notify fds in fdsArena. -// TODO: handle wrong messages +// Disconnect a client by closing the matching file descriptors void -initClient(Arena* fdsArena, struct pollfd* fds, s32 fd, Client* client) +disconnect(struct pollfd* pollfd, Client* client) { - s32 nrecv = 0; - s32 nsend = 0; + loggingf("Disconnecting "CLIENT_FMT"\n", CLIENT_ARG((*client))); + if (pollfd[UNIFD].fd != -1) { + close(pollfd[UNIFD].fd); + } + if (pollfd[BIFD].fd != -1) { + close(pollfd[BIFD].fd); + } + pollfd[UNIFD].fd = -1; + pollfd[BIFD].fd = -1; + // TODO: mark as free + if (client) { + client->pollunifd = 0; + client->pollbifd = 0; + } +} - fprintf(stdout, " Adding to clients(%d).\n", fd); +// Disconnects fds+conn from fds with nfds connections, then send a PresenceMessage to other +// clients about disconnection. +void +disconnectAndNotify(Client* client, struct pollfd* fds, u32 nfds, u32 conn) +{ + disconnect(fds + conn, client); + + local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE); + PresenceMessage message = {.id = client->id, .type = PRESENCE_TYPE_DISCONNECTED}; + sendToAll(fds, nfds, UNIFD, &header, &message); +} + +// Receive authentication from pollfd->fd and create client out of it. Look in +// clientsArena if it already exists. Otherwise push a new onto the arena and write its information +// to clients_file. +// See "Authentication" in chatty.h +Client* +authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* clientfds) +{ + s32 nrecv = 0; + Client* clients = clientsArena->addr; HeaderMessage header; - nrecv = recv(fd, &header, sizeof(header), 0); - assert(nrecv != -1); - assert(nrecv == sizeof(header)); - if (header.type != HEADER_TYPE_PRESENCE) { - // reject connection - close(fd); - fprintf(stdout, " Got wrong header(%d).\n", fd); - return; + nrecv = recv(clientfds[BIFD].fd, &header, sizeof(header), 0); + if (nrecv != sizeof(header)) { + loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(header)); + return 0; } - fprintf(stdout, " Got header(%d).\n", fd); - - PresenceMessage message; - nrecv = recv(fd, &message, sizeof(message), 0); - assert(nrecv != -1); - assert(nrecv == sizeof(message)); - fprintf(stdout, " Got presence message(%d).\n", fd); - - // Copy author from PresenceMessage. - memcpy(client->author, message.author, AUTHOR_LEN); - - // Notify other clients from this new one - // Reuse header and message - for (u32 j = FDS_CLIENTS; j < FDS_SIZE; j++) { - if (fds[j].fd == fd) - continue; - if (fds[j].fd == -1) - continue; - fprintf(stdout, " Notifying(%d->%d).\n", fd, fds[j].fd); - nsend = send(fds[j].fd, &header, sizeof(header), 0); - assert(nsend != -1); - assert(nsend == sizeof(header)); - nsend = send(fds[j].fd, &message, sizeof(message), 0); - assert(nsend != -1); - assert(nsend == sizeof(message)); + loggingf("authenticate(%d)|" HEADER_FMT "\n", clientfds[BIFD].fd, HEADER_ARG(header)); + + Client* client = 0; + // Scenario 1: Search for existing client + if (header.type == HEADER_TYPE_ID) { + IDMessage message; + nrecv = recv(clientfds[BIFD].fd, &message, sizeof(message), 0); + if (nrecv != sizeof(message)) { + loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(message)); + return 0; + } + + client = getClientByID(clientsArena, message.id); + if (client) { + loggingf("authenticate(%d)|found [%s](%lu)\n", clientfds[BIFD].fd, client->author, client->id); + header.type = HEADER_TYPE_ERROR; + // TODO: allow multiple connections + if (client->pollunifd != 0 || client->pollbifd != 0) { + loggingf("authenticate(%d)|err: already connected\n", clientfds[BIFD].fd); + ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_ALREADYCONNECTED); + sendAnyMessage(clientfds[BIFD].fd, &header, &error_message); + return 0; + } + ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_SUCCESS); + sendAnyMessage(clientfds[BIFD].fd, &header, &error_message); + } else { + loggingf("authenticate(%d)|notfound\n", clientfds[BIFD].fd); + header.type = HEADER_TYPE_ERROR; + ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_NOTFOUND); + sendAnyMessage(clientfds[BIFD].fd, &header, &error_message); + return 0; + } + // Scenario 2: Create a new client + } else if (header.type == HEADER_TYPE_INTRODUCTION) { + IntroductionMessage message; + nrecv = recv(clientfds[BIFD].fd, &message, sizeof(message), 0); + if (nrecv != sizeof(message)) { + loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(message)); + return 0; + } + + // Copy metadata from IntroductionMessage + client = ArenaPush(clientsArena, sizeof(*clients)); + memcpy(client->author, message.author, AUTHOR_LEN); + client->id = nclients; + nclients++; + + // Save client +#ifdef IMPORT_ID + write(clients_file, client, sizeof(*client)); +#endif + loggingf("authenticate(%d)|Added [%s](%lu)\n", clientfds[BIFD].fd, client->author, client->id); + + HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID); + IDMessage id_message = {.id = client->id}; + sendAnyMessage(clientfds[BIFD].fd, &header, &id_message); + } else { + loggingf("authenticate(%d)|Wrong header expected %s or %s\n", clientfds[BIFD].fd, headerTypeString(HEADER_TYPE_INTRODUCTION), headerTypeString(HEADER_TYPE_ID)); + return 0; } + assert(client != 0); + + client->pollunifd = clientfds; + client->pollbifd = clientfds + 1; + + return client; } int -main(void) +main(int argc, char** argv) { - s32 err, serverfd, clientfd; - u32 on = 1; + signal(SIGPIPE, SIG_IGN); + + logfd = 2; + // optional logging + if (argc > 1) { + if (*argv[1] == '-') + if (argv[1][1] == 'l') { + logfd = open(LOGFILE, O_RDWR | O_CREAT | O_TRUNC, 0600); + assert(logfd != -1); + } + } + s32 serverfd; // Start listening on the socket { - serverfd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP); + s32 err; + u32 on = 1; + serverfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); assert(serverfd > 2); err = setsockopt(serverfd, SOL_SOCKET, SO_REUSEADDR, (u8*)&on, sizeof(on)); - assert(err == 0); + assert(!err); const struct sockaddr_in address = { AF_INET, htons(PORT), {0}, + {0}, }; err = bind(serverfd, (const struct sockaddr*)&address, sizeof(address)); - assert(err == 0); + assert(!err); err = listen(serverfd, MAX_CONNECTIONS); - assert(err == 0); + assert(!err); + loggingf("Listening on :%d\n", PORT); } - Arena* msgsArena = ArenaAlloc(Megabytes(128)); // storing received messages - // NOTE: sent messages? - s32 nrecv = 0; // number of bytes received - - Arena* clientsArena = ArenaAlloc(MAX_CONNECTIONS * sizeof(Client)); - Arena* fdsArena = ArenaAlloc(MAX_CONNECTIONS * sizeof(struct pollfd)); - struct pollfd* fds = fdsArena->addr; - Client* clients = clientsArena->addr; + Arena clientsArena; + Arena fdsArena; + Arena msgsArena; + ArenaAlloc(&clientsArena, MAX_CONNECTIONS * sizeof(Client)); + ArenaAlloc(&fdsArena, MAX_CONNECTIONS * 2 * sizeof(struct pollfd)); + ArenaAlloc(&msgsArena, Megabytes(128)); // storing received messages + struct pollfd* fds = fdsArena.addr; + Client* clients = clientsArena.addr; + // Initializing fds struct pollfd* fdsAddr; - struct pollfd newpollfd = {-1, POLLIN, 0}; - + struct pollfd newpollfd = {-1, POLLIN, 0}; // for copying with events already set // initialize fds structure newpollfd.fd = 0; - fdsAddr = ArenaPush(fdsArena, sizeof(*fds)); + fdsAddr = ArenaPush(&fdsArena, sizeof(*fds)); memcpy(fdsAddr, &newpollfd, sizeof(*fds)); // add serverfd newpollfd.fd = serverfd; - fdsAddr = ArenaPush(fdsArena, sizeof(*fds)); + fdsAddr = ArenaPush(&fdsArena, sizeof(*fds)); memcpy(fdsAddr, &newpollfd, sizeof(*fds)); newpollfd.fd = -1; +#ifdef IMPORT_ID + s32 clients_file = open(CLIENTS_FILE, O_RDWR | O_CREAT | O_APPEND, 0600); + assert(clients_file != -1); + struct stat statbuf; + assert(fstat(clients_file, &statbuf) != -1); + + read(clients_file, clients, statbuf.st_size); + if (statbuf.st_size > 0) { + ArenaPush(&clientsArena, statbuf.st_size); + loggingf("Imported %lu client(s)\n", statbuf.st_size / sizeof(*clients)); + nclients += statbuf.st_size / sizeof(*clients); + } + for (u32 i = 0; i < nclients; i++) + loggingf("Imported: " CLIENT_FMT "\n", CLIENT_ARG(clients[i])); +#endif + // Initialize the rest of the fds array for (u32 i = FDS_CLIENTS; i < MAX_CONNECTIONS; i++) fds[i] = newpollfd; + // Reset file descriptors on imported clients + for (u32 i = 0; i < CLIENTS_SIZE; i++) { + clients[i].pollunifd = 0; + clients[i].pollbifd = 0; + } while (1) { - err = poll(fds, FDS_SIZE, TIMEOUT); + s32 err = poll(fds, FDS_SIZE, TIMEOUT); assert(err != -1); if (fds[FDS_STDIN].revents & POLLIN) { - // helps for testing and exiting gracefully - break; + u8 c; // exit on ctrl-d + if (!read(fds[FDS_STDIN].fd, &c, 1)) + break; } else if (fds[FDS_SERVER].revents & POLLIN) { - clientfd = accept(serverfd, NULL, NULL); - assert(clientfd != -1); - assert(clientfd > serverfd); - fprintf(stdout, "New connection(%d).\n", clientfd); - - // If there is a slot in fds with fds[found].fd == -1 use it instead, otherwise allocate - // some space on the arena. - u8 found; - for (found = FDS_CLIENTS; found < FDS_SIZE; found++) - if (fds[found].fd == -1) - break; - if (found == MAX_CONNECTIONS) { - // TODO: reject connection - close(clientfd); - fprintf(stdout, "Max clients reached."); - } else if (found == FDS_SIZE) { - // no more space, allocate - struct pollfd* pollfd = ArenaPush(fdsArena, sizeof(*pollfd)); - pollfd->fd = clientfd; - pollfd->events = POLLIN; + // TODO: what if we are not aligned by 2 anymore? + s32 unifd = accept(serverfd, 0, 0); + s32 bifd = accept(serverfd, 0, 0); + + if (unifd == -1 || bifd == -1) { + loggingf("Error while accepting connection (%d,%d)\n", unifd, bifd); + if (unifd != -1) close(unifd); + if (bifd != -1) close(bifd); + continue; + } else + loggingf("New connection(%d,%d)\n", unifd, bifd); + + // TODO: find empty space in arena + if (nclients + 1 == MAX_CONNECTIONS) { + local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_ERROR); + local_persist ErrorMessage message = ERROR_INIT(ERROR_TYPE_TOOMANYCONNECTIONS); + sendAnyMessage(unifd, &header, &message); + if (unifd != -1) + close(unifd); + if (bifd != -1) + close(bifd); + loggingf("Max clients reached. Rejected connection\n"); } else { - // hole found - fds[found].fd = clientfd; - fds[found].events = POLLIN; - fprintf(stdout, "Added pollfd(%d).\n", clientfd); + // no more space, allocate + struct pollfd* clientfds = ArenaPush(&fdsArena, 2 * sizeof(*clientfds)); + clientfds[UNIFD].fd = unifd; + clientfds[UNIFD].events = POLLIN; + clientfds[BIFD].fd = bifd; + clientfds[BIFD].events = POLLIN; + loggingf("Added pollfd(%d,%d)\n", unifd, bifd); } } - // Check for messages from clients - for (u32 i = FDS_CLIENTS; i < FDS_SIZE; i++) { - if (!(fds[i].revents & POLLIN)) - continue; - assert(fds[i].fd != -1); - fprintf(stdout, "Message(%d).\n", fds[i].fd); - Client* client = clients + i; - - // Initialize the client if this is the first time - if (!client->initialized) { - initClient(fdsArena, fds, fds[i].fd, client); - client->initialized = True; - fprintf(stdout, " Added to clients(%d): %s\n", fds[i].fd, client->author); + // Check for messages from clients in their unifd + for (u32 conn = FDS_CLIENTS; conn < FDS_SIZE; conn += 2) { + if (!(fds[conn].revents & POLLIN)) continue; + if (fds[conn].fd == -1) continue; + loggingf("Message unifd (%d)\n", fds[conn].fd); + + // Get client associated with connection + Client* client = 0; + for (u32 j = 0; j < CLIENTS_SIZE; j++) { + if (!clients[j].pollunifd) + continue; + if (clients[j].pollunifd == fds + conn) { + client = clients + j; + break; + } + } + if (!client) { + loggingf("No client associated(%d)\n", fds[conn].fd); + close(fds[conn].fd); continue; } + loggingf("Found client(%lu) [%s] (%d)\n", client->id, client->author, fds[conn].fd); // We received a message, try to parse the header HeaderMessage header; - nrecv = recv(fds[i].fd, &header, sizeof(header), 0); - assert(nrecv != -1); - + s32 nrecv = recv(fds[conn].fd, &header, sizeof(header), 0); if (nrecv == 0) { - disconnect(fdsArena, fds, i, (clients + i)); + disconnectAndNotify(client, fds, FDS_SIZE, conn); + loggingf("Disconnected(%lu) [%s]\n", client->id, client->author); + continue; + } else if (nrecv != sizeof(header)) { + disconnectAndNotify(client, fds, FDS_SIZE, conn); + loggingf("error(%lu) [%s] %d/%lu bytes\n", client->id, client->author, nrecv, sizeof(header)); continue; } - - assert(nrecv == sizeof(header)); - fprintf(stderr, " Received(%d): %d bytes -> " PH_FMT "\n", fds[i].fd, nrecv, PH_ARG(header)); + loggingf("Received(%d) -> " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header)); switch (header.type) { - case HEADER_TYPE_TEXT:; - TextMessage* message; - nrecv = recvTextMessage(msgsArena, fds[i].fd, &message); - fprintf(stderr, " Received(%d): %d bytes -> ", fds[i].fd, nrecv); - printTextMessage(message, 0); + case HEADER_TYPE_TEXT: { + TextMessage* text_message = recvTextMessage(&msgsArena, fds[conn].fd); + loggingf("Received(%d)", fds[conn].fd); + printTextMessage(text_message, client, 0); + + sendToOthers(fds, FDS_SIZE, fds[conn].fd, UNIFD, &header, text_message); + } break; + // handle request for information about client id + default: + loggingf("Unhandled '%s' from client(%d)\n", headerTypeString(header.type), fds[conn].fd); + disconnectAndNotify(client, fds, FDS_SIZE, conn); + continue; + } + } - // Send message to all other clients - sendToOthers(fdsArena, fds, i, &header, message); + // Check for messages from clients in their bifd + for (u32 conn = FDS_CLIENTS + BIFD; conn < FDS_SIZE; conn += 2) { + if (!(fds[conn].revents & POLLIN)) continue; + if (fds[conn].fd == -1) continue; + loggingf("Message bifd (%d)\n", fds[conn].fd); + + // Get client associated with connection + Client* client = 0; + for (u32 j = 0; j < CLIENTS_SIZE; j++) { + if (!clients[j].pollbifd) + continue; + if (clients[j].pollbifd == fds + conn) { + client = clients + j; + break; + } + } + if (!client) { + loggingf("No client for connection(%d)\n", fds[conn].fd); +#ifdef IMPORT_ID + client = authenticate(&clientsArena, clients_file, fds + conn - 1); +#else + client = authenticate(&clientsArena, 0, fds + conn - 1); +#endif + // If the client sent an IDMessage but no ID was found authenticate() could return null + if (!client) { + loggingf("Could not initialize client\n"); + disconnect(fds + conn, 0); + } else { // client was added/connected + local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE); + PresenceMessage message = {.id = client->id, .type = PRESENCE_TYPE_CONNECTED}; + sendToOthers(fds, FDS_SIZE, fds[conn - BIFD].fd, UNIFD, &header, &message); + } + continue; + } + loggingf("Found client(%lu) [%s] (%d)\n", client->id, client->author, fds[conn].fd); - break; + // We received a message, try to parse the header + HeaderMessage header; + s32 nrecv = recv(fds[conn].fd, &header, sizeof(header), 0); + if (nrecv == 0) { + disconnectAndNotify(client, fds, FDS_SIZE, conn); + loggingf("Disconnected(%lu) [%s]\n", client->id, client->author); + continue; + } else if (nrecv != sizeof(header)) { + disconnectAndNotify(client, fds, FDS_SIZE, conn); + loggingf("error(%lu) [%s] %d/%lu bytes\n", client->id, client->author, nrecv, sizeof(header)); + continue; + } + loggingf("Received(%d) -> " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header)); + + switch (header.type) { + case HEADER_TYPE_ID: { + // handle request for information about client id + IDMessage id_message; + nrecv = recv(fds[conn].fd, &id_message, sizeof(id_message), 0); + + Client* client = getClientByID(&clientsArena, id_message.id); + if (!client) { + local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_ERROR); + local_persist ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_NOTFOUND); + sendAnyMessage(fds[conn].fd, &header, &error_message); + loggingf("Could not find %lu\n", id_message.id); + break; + } + HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION); + IntroductionMessage introduction_message; + memcpy(introduction_message.author, client->author, AUTHOR_LEN); + + sendAnyMessage(fds[conn].fd, &header, &introduction_message); + } break; default: - fprintf(stdout, " Got unhandled message type '%s' from client %d", headerTypeString(header.type), fds[i].fd); + loggingf("Unhandled '%s' from client(%d)\n", headerTypeString(header.type), fds[conn].fd); + disconnectAndNotify(client, fds, FDS_SIZE, conn); continue; } } } - ArenaRelease(clientsArena); - ArenaRelease(fdsArena); - ArenaRelease(msgsArena); +#ifdef IMPORT_ID + close(clients_file); +#endif return 0; } |