aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaymaekers Luca <raymaekers.luca@gmail.com>2024-11-03 16:25:10 +0100
committerRaymaekers Luca <raymaekers.luca@gmail.com>2024-11-03 16:32:11 +0100
commit4b3dfeddc15908fdf53df42818ed167addc71659 (patch)
tree8c307416a19df07e8977896101553458144391e6
parent0d175d022fff8b87628dbc3fd0ead14f6660eb20 (diff)
Add id field to HeaderMessage for simplifying
- Changed bracket style as well
-rw-r--r--.gitignore3
-rw-r--r--README.md10
-rw-r--r--chatty.c345
-rw-r--r--protocol.h61
-rw-r--r--send.c17
-rw-r--r--server.c462
6 files changed, 478 insertions, 420 deletions
diff --git a/.gitignore b/.gitignore
index 91c2cfe..688b48f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,9 @@
chatty
send
server
+
_id
_clients
*.log
+
+tags
diff --git a/README.md b/README.md
index c3a7d19..50e9b7a 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,7 @@ The idea is the following:
## server
- [x] import clients
+- [ ] check that fds arena does not overflow
- [ ] check if when sending and the client is offline (due to connection loss) what happens
- [ ] timeout on recv?
- [ ] use threads to handle clients/ timeout when receiving because a client could theoretically
@@ -29,31 +30,32 @@ The idea is the following:
- [ ] do not crash on errors from clients
- implement error message?
- timeout on recv with setsockopt
+- [ ] theoretically two clients can connect at the same time. The uni/bi connections should be
+ negotiated.
## common
- [x] handle messages that are too large
- [x] refactor i&self into conn
- [x] logging
- [x] Req|Inf connection per client
+- [x] connect/disconnect messages
- [ ] bug: blocking after `Added pollfd`, after importing a client and then connecting with the
id/or without? After reconnection fails chatty blocks (remove sleep)
- [ ] connect/disconnections messages
- [ ] use IP address / domain
- [ ] chat history
- [ ] asserting, logging if fail / halt execution
+- [ ] compression
## Protocol
- see `protocol.h` for more info
- [ ] make sections per message
- request chat logs from a certain point up to now (history)
- connect to a specific room
-- connect/disconnect messages
- The null terminator must be sent with the string.
- The text can be arbitrary length
-- [ ] compression
-
## Arena's
1. There is an arena for the messages' texts (`msgTextArena`) and an arena for the messages
(`msgsArena`).
@@ -82,4 +84,4 @@ Notice, that this depends on knowing the text's length before allocating the mem
- *pthreads*: [C for dummies](https://c-for-dummies.com/blog/?p=5365)
- *unicode and wide characters*: [C for dummies](https://c-for-dummies.com/blog/?p=2578)
- *sockets*: [Nir Lichtman - Making Minimalist Chat Server in C on Linux](https://www.youtube.com/watch?v=gGfTjKwLQxY)
-- syscall manpages
+- syscall manpages `man`
diff --git a/chatty.c b/chatty.c
index 4746aa7..ccc722a 100644
--- a/chatty.c
+++ b/chatty.c
@@ -37,12 +37,6 @@ typedef struct {
#define CLIENT_FMT "[%s](%lu)"
#define CLIENT_ARG(client) client.author, client.id
-typedef struct {
- s32 err; // Error while connecting
- s32 unifd;
- s32 bifd;
-} ConnectionResult;
-
// Client used by chatty
global_variable Client user = {0};
// Address of chatty server
@@ -74,7 +68,8 @@ getClientById(Arena* clientsArena, ID id)
if (id == user.id) return &user;
Client* clients = clientsArena->addr;
- for (u64 i = 0; i < (clientsArena->pos / sizeof(*clients)); i++) {
+ for (u64 i = 0; i < (clientsArena->pos / sizeof(*clients)); i++)
+ {
if (clients[i].id == id)
return clients + i;
}
@@ -88,16 +83,17 @@ addClientInfo(Arena* clientsArena, s32 fd, u64 id)
{
// Request information about ID
HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID);
- IDMessage id_message = {.id = id};
- sendAnyMessage(fd, &header, &id_message);
-
- Client* client = ArenaPush(clientsArena, sizeof(*client));
+ header.id = id;
+ s32 nsend = send(fd, &header, sizeof(header), 0);
+ assert(nsend != -1);
+ assert(nsend == sizeof(header));
// Wait for response
IntroductionMessage introduction_message;
recvAnyMessageType(fd, &header, &introduction_message, HEADER_TYPE_INTRODUCTION);
// Add the information
+ Client* client = ArenaPush(clientsArena, sizeof(*client));
memcpy(client->author, introduction_message.author, AUTHOR_LEN);
client->id = id;
@@ -106,17 +102,57 @@ addClientInfo(Arena* clientsArena, s32 fd, u64 id)
}
// Tries to connect to address and populates resulting file descriptors in ConnectionResult.
-ConnectionResult
+s32
getConnection(struct sockaddr_in* address)
{
- ConnectionResult result;
- result.unifd = socket(AF_INET, SOCK_STREAM, 0);
- result.bifd = socket(AF_INET, SOCK_STREAM, 0);
- result.err = connect(result.unifd, (struct sockaddr*)address, sizeof(*address));
- if (result.err) return result; // We do not overwrite the error and return early so we can be
- // certain of what error errno belongs to.
- result.err = connect(result.bifd, (struct sockaddr*)address, sizeof(*address));
- return result;
+ s32 fd = socket(AF_INET, SOCK_STREAM, 0);
+ if (fd == -1) return -1;
+
+ s32 err = connect(fd, (struct sockaddr*)address, sizeof(*address));
+ if (err) return -1;
+
+ return fd;
+}
+
+ID
+authenticate(Client* user, s32 fd)
+{
+ if (user->id)
+ {
+ HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID);
+ header.id = user->id;
+ s32 nsend = send(fd, &header, sizeof(header), 0);
+ assert(nsend == -1);
+ assert(nsend == sizeof(header));
+
+ s32 nrecv = recv(fd, &header, sizeof(header), 0);
+ if (nrecv == 0)
+ return 0;
+ assert(nrecv != -1);
+ assert(nrecv == sizeof(header));
+ assert(header.type == HEADER_TYPE_ID);
+ if (header.id == user->id)
+ return header.id;
+ else
+ return 0;
+ }
+ else
+ {
+ HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION);
+ IntroductionMessage message;
+ memcpy(message.author, user->author, AUTHOR_LEN);
+ sendAnyMessage(fd, header, &message);
+
+ s32 nrecv = recv(fd, &header, sizeof(header), 0);
+ if (nrecv == 0)
+ return 0;
+ assert(nrecv != -1);
+ assert(nrecv == sizeof(header));
+ assert(header.type == HEADER_TYPE_ID);
+ assert(!header.id);
+ user->id = header.id;
+ return header.id;
+ }
}
// Connect to *address_ptr of type `struct sockaddr_in*`. If it failed wait for TIMEOUT_RECONNECT
@@ -129,49 +165,44 @@ getConnection(struct sockaddr_in* address)
void*
threadReconnect(void* fds_ptr)
{
+ s32 unifd, bifd;
struct pollfd* fds = fds_ptr;
- ConnectionResult result;
struct timespec t = { 0, Miliseconds(300) }; // 300 miliseconds
loggingf("Trying to reconnect\n");
- while (1) {
+ while (1)
+ {
+ // timeout
nanosleep(&t, &t);
- result = getConnection(&address);
- if (result.err) {
- // loggingf("err: %d\n", result.err);
+
+ unifd = getConnection(&address);
+ if (unifd == -1)
+ {
loggingf("errno: %d\n", errno);
- } else if (result.unifd != -1 && result.bifd != -1) {
- loggingf("Reconnect succeeded (%d, %d), authenticating\n", result.unifd, result.bifd);
- // We assume that we already have an ID
- // TODO: could there be a problem if a message is received at the same time?
- // - not on server restart, but what if we lost connection?
- HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID);
- IDMessage id_message = {.id = user.id};
- sendAnyMessage(result.bifd, &header, &id_message);
-
- ErrorMessage error_message;
- s32 nrecv = recvAnyMessageType(result.bifd, &header, &error_message, HEADER_TYPE_ERROR);
- if (nrecv == -1 || nrecv == 0) {
- loggingf("Error on receive, retrying...\n");
- continue;
- }
+ continue;
+ }
+ bifd = getConnection(&address);
+ if (bifd == -1)
+ {
+ loggingf("errno: %d\n", errno);
+ close(unifd);
+ continue;
+ }
- assert(header.type == HEADER_TYPE_ERROR);
- if (error_message.type == ERROR_TYPE_SUCCESS) {
- loggingf("Reconnected\n");
- break;
- } else {
- loggingf("err: %s\n", errorTypeString(error_message.type));
- }
+ loggingf("Reconnect succeeded (%d, %d), authenticating\n", unifd, bifd);
+
+ if (authenticate(&user, unifd) &&
+ authenticate(&user, bifd))
+ {
+ break;
}
- if (result.unifd != -1)
- close(result.unifd);
- if (result.bifd != -1)
- close(result.bifd);
- loggingf("Failed, retrying..\n");
+
+ close(unifd);
+ close(bifd);
+ loggingf("Failed, retrying...\n");
}
- fds[FDS_BI].fd = result.bifd;
- fds[FDS_UNI].fd = result.unifd;
+ fds[FDS_BI].fd = bifd;
+ fds[FDS_UNI].fd = unifd;
// Redraw screen
raise(SIGWINCH);
@@ -206,16 +237,20 @@ tb_printf_wrap(u32 x, u32 y, u32 fg, u32 bg, u32* text, s32 text_len, u32 fg_pfx
u32 failed = 0;
// NOTE: We can assume that we need to wrap, therefore print a newline after the prefix string
- if (pfx != 0) {
+ if (pfx != 0)
+ {
tb_printf(x, ly, fg_pfx, bg_pfx, "%s", pfx);
// If the text fits on one line print the text and return
// Otherwise print the text on the next line
s32 pfx_len = strlen((char*)pfx);
- if (limit_x > pfx_len + text_len) {
+ if (limit_x > pfx_len + text_len)
+ {
tb_printf(x + pfx_len, y, fg, bg, "%ls", text);
return 1;
- } else {
+ }
+ else
+ {
ly++;
}
}
@@ -235,18 +270,22 @@ tb_printf_wrap(u32 x, u32 y, u32 fg, u32 bg, u32* text, s32 text_len, u32 fg_pfx
// 8. step 2. until i >= text_len
// 9. print remaining part of the string
- while (i < text_len && ly - y < limit_y) {
+ while (i < text_len && ly - y < limit_y)
+ {
// search backwards for whitespace
while (i > offset && text[i] != L' ')
i--;
// retry with bigger limit
- if (i == offset) {
+ if (i == offset)
+ {
offset = i;
failed++;
i += limit_x + failed * limit_x;
continue;
- } else {
+ }
+ else
+ {
failed = 0;
}
@@ -261,7 +300,8 @@ tb_printf_wrap(u32 x, u32 y, u32 fg, u32 bg, u32* text, s32 text_len, u32 fg_pfx
offset = i;
i += limit_x;
}
- if ((u32)ly <= limit_y) {
+ if ((u32)ly <= limit_y)
+ {
tb_printf(x, ly, fg, bg, "%ls", text + offset);
ly++;
}
@@ -283,11 +323,14 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd*
// the minimum height required is the hight for the box prompt
// the minimum width required is that one character should fit in the box prompt
if (global.height < box_height ||
- global.width < (box_x + box_mar_x * 2 + box_pad_x * 2 + box_bwith * 2 + 1)) {
+ global.width < (box_x + box_mar_x * 2 + box_pad_x * 2 + box_bwith * 2 + 1))
+ {
// + 1 for cursor
tb_hide_cursor();
return;
- } else {
+ }
+ else
+ {
// show cursor
// TODO: show cursor as block character instead of using the real cursor
bytebuf_puts(&global.out, global.caps[TB_CAP_SHOW_CURSOR]);
@@ -310,11 +353,14 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd*
u32 offs = (nmessages > free_y) ? nmessages - free_y : 0;
// skip offs ccount messages
- for (u32 i = 0; i < offs; i++) {
+ for (u32 i = 0; i < offs; i++)
+ {
HeaderMessage* header = (HeaderMessage*)addr;
addr += sizeof(*header);
- switch (header->type) {
- case HEADER_TYPE_TEXT: {
+ switch (header->type)
+ {
+ case HEADER_TYPE_TEXT:
+ {
TextMessage* message = (TextMessage*)addr;
addr += TEXTMESSAGE_SIZE;
addr += message->len * sizeof(*message->text);
@@ -333,36 +379,41 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd*
}
// In each case statement advance the addr pointer by the size of the message
- for (u32 i = offs; i < nmessages && msg_y < free_y; i++) {
+ for (u32 i = offs; i < nmessages && msg_y < free_y; i++)
+ {
HeaderMessage* header = (HeaderMessage*)addr;
addr += sizeof(*header);
// Get Client for message
- ID* id;
Client* client;
- switch (header->type) {
+ switch (header->type)
+ {
case HEADER_TYPE_TEXT:
- id = &((TextMessage*)addr)->id;
case HEADER_TYPE_PRESENCE:
- id = &((PresenceMessage*)addr)->id;
- client = getClientById(clientsArena, *id);
- if (!client) {
+ client = getClientById(clientsArena, header->id);
+ if (!client)
+ {
loggingf("Client not known, requesting from server\n");
- client = addClientInfo(clientsArena, fds[FDS_BI].fd, *id);
+ client = addClientInfo(clientsArena, fds[FDS_BI].fd, header->id);
}
assert(client);
break;
}
- switch (header->type) {
- case HEADER_TYPE_TEXT: {
+ switch (header->type)
+ {
+ case HEADER_TYPE_TEXT:
+ {
TextMessage* message = (TextMessage*)addr;
// Color own messages
u32 fg = 0;
- if (user.id == message->id) {
+ if (user.id == header->id)
+ {
fg = TB_CYAN;
- } else {
+ }
+ else
+ {
fg = TB_MAGENTA;
}
@@ -377,13 +428,15 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd*
u32 message_size = TEXTMESSAGE_SIZE + message->len * sizeof(*message->text);
addr += message_size;
} break;
- case HEADER_TYPE_PRESENCE: {
+ case HEADER_TYPE_PRESENCE:
+ {
PresenceMessage* message = (PresenceMessage*)addr;
tb_printf(0, msg_y, 0, 0, " [%s] *%s*", client->author, presenceTypeString(message->type));
msg_y++;
addr += sizeof(*message);
} break;
- case HEADER_TYPE_HISTORY: {
+ case HEADER_TYPE_HISTORY:
+ {
HistoryMessage* message = (HistoryMessage*)addr;
addr += sizeof(*message);
// TODO: implement
@@ -445,17 +498,21 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd*
if (freesp <= 0)
return;
- if (input_len > freesp) {
+ if (input_len > freesp)
+ {
u32* text_offs = input + (input_len - freesp);
tb_printf(box_x + box_mar_x + box_pad_x + box_bwith, box_y + 1, 0, 0, "%ls", text_offs);
global.cursor_x = box_x + box_pad_x + box_mar_x + box_bwith + freesp;
- } else {
+ }
+ else
+ {
global.cursor_x = prompt_x;
tb_printf(box_x + box_mar_x + box_pad_x + box_bwith, box_y + 1, 0, 0, "%ls", input);
}
}
- if (fds[FDS_UNI].fd == -1 || fds[FDS_BI].fd == -1) {
+ if (fds[FDS_UNI].fd == -1 || fds[FDS_BI].fd == -1)
+ {
// show error popup
popup(TB_RED, TB_BLACK, "Server disconnected.");
}
@@ -465,7 +522,8 @@ screen_home(Arena* msgsArena, u32 nmessages, Arena* clientsArena, struct pollfd*
int
main(int argc, char** argv)
{
- if (argc < 2) {
+ if (argc < 2)
+ {
fprintf(stderr, "usage: chatty <username>\n");
return 1;
}
@@ -516,67 +574,42 @@ main(int argc, char** argv)
{0},
};
- ConnectionResult result = getConnection(&address);
- if (result.err) {
- perror("Server");
- return 1;
- }
- assert(result.unifd != -1);
- assert(result.bifd != -1);
- assert(!result.err);
- fds[FDS_BI].fd = result.bifd;
- fds[FDS_UNI].fd = result.unifd;
-
#ifdef IMPORT_ID
// File for storing the user's ID.
u32 idfile = open(ID_FILE, O_RDWR | O_CREAT, 0600);
s32 nread = read(idfile, &user.id, sizeof(user.id));
assert(nread != -1);
- // see "Authentication" in chatty.h
- if (nread == sizeof(user.id)) {
- // Scenario 1: We know our id
-
- // Send IDMessage and check if it is correct
- HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID);
- IDMessage message = {.id = user.id};
- sendAnyMessage(fds[FDS_BI].fd, &header, &message);
-
- ErrorMessage error_message = {0};
- recvAnyMessageType(fds[FDS_BI].fd, &header, &error_message, HEADER_TYPE_ERROR);
-
- switch (error_message.type) {
- case ERROR_TYPE_SUCCESS: break;
- case ERROR_TYPE_NOTFOUND:
- printf("Server does not know our ID. Consider removing '" ID_FILE "'\n");
+#endif
+ /* Authentication */
+ {
+ s32 unifd, bifd;
+ unifd = getConnection(&address);
+ if (unifd == -1)
+ {
+ loggingf("errno: %d\n", errno);
return 1;
- default:
- printf("Server: %s\n", errorTypeString(error_message.type));
+ }
+ bifd = getConnection(&address);
+ if (bifd == -1)
+ {
+ loggingf("errno: %d\n", errno);
return 1;
}
- } else {
-#else
- if (1) {
-#endif
- // Scenario 2: We do not have an ID
- HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION);
- IntroductionMessage message = {0};
- // copy user data into message
- memcpy(message.author, user.author, AUTHOR_LEN);
-
- // Send the introduction message
- sendAnyMessage(fds[FDS_BI].fd, &header, &message);
+ if (!authenticate(&user, unifd) ||
+ !authenticate(&user, bifd))
+ {
+ loggingf("errno: %d\n", errno);
+ return 1;
+ }
+ fds[FDS_UNI].fd = unifd;
+ fds[FDS_BI].fd = bifd;
+ }
- IDMessage id_message = {0};
- // Receive the response IDMessage
- recvAnyMessageType(fds[FDS_BI].fd, &header, &id_message, HEADER_TYPE_ID);
- assert(header.type == HEADER_TYPE_ID);
- user.id = id_message.id;
#ifdef IMPORT_ID
- // Save permanently
- write(idfile, &user.id, sizeof(user.id));
- close(idfile);
+ // Save id
+ write(idfile, &user.id, sizeof(user.id));
#endif
- }
+
loggingf("Got ID: %lu\n", user.id);
// for wide character printing
@@ -590,21 +623,24 @@ main(int argc, char** argv)
tb_present();
// main loop
- while (!quit) {
+ while (!quit)
+ {
err = poll(fds, FDS_MAX, TIMEOUT_POLL);
// ignore resize events and use them to redraw the screen
assert(err != -1 || errno == EINTR);
tb_clear();
- if (fds[FDS_UNI].revents & POLLIN) {
+ if (fds[FDS_UNI].revents & POLLIN)
+ {
// got data from server
HeaderMessage header;
nrecv = recv(fds[FDS_UNI].fd, &header, sizeof(header), 0);
assert(nrecv != -1);
// Server disconnects
- if (nrecv == 0) {
+ if (nrecv == 0)
+ {
// close diconnected server's socket
err = close(fds[FDS_UNI].fd);
assert(err == 0);
@@ -612,8 +648,11 @@ main(int argc, char** argv)
// start trying to reconnect in a thread
err = pthread_create(&thr_rec, 0, &threadReconnect, (void*)fds);
assert(err == 0);
- } else {
- if (header.version != PROTOCOL_VERSION) {
+ }
+ else
+ {
+ if (header.version != PROTOCOL_VERSION)
+ {
loggingf("Header received does not match version\n");
continue;
}
@@ -622,7 +661,8 @@ main(int argc, char** argv)
memcpy(addr, &header, sizeof(header));
// Messages handled from server
- switch (header.type) {
+ switch (header.type)
+ {
case HEADER_TYPE_TEXT:
recvTextMessage(&msgsArena, fds[FDS_UNI].fd);
nmessages++;
@@ -641,15 +681,19 @@ main(int argc, char** argv)
}
}
- if (fds[FDS_TTY].revents & POLLIN) {
+ if (fds[FDS_TTY].revents & POLLIN)
+ {
// got a key event
tb_poll_event(&ev);
- switch (ev.key) {
+ switch (ev.key)
+ {
case TB_KEY_CTRL_W:
// delete consecutive whitespace
- while (ninput) {
- if (input[ninput - 1] == L' ') {
+ while (ninput)
+ {
+ if (input[ninput - 1] == L' ')
+ {
input[ninput - 1] = 0;
ninput--;
continue;
@@ -657,7 +701,8 @@ main(int argc, char** argv)
break;
}
// delete until whitespace
- while (ninput) {
+ while (ninput)
+ {
if (input[ninput - 1] == L' ')
break;
// erase
@@ -665,7 +710,8 @@ main(int argc, char** argv)
ninput--;
}
break;
- case TB_KEY_CTRL_Z: {
+ case TB_KEY_CTRL_Z:
+ {
pid_t pid = getpid();
tb_shutdown();
kill(pid, SIGSTOP);
@@ -692,10 +738,10 @@ main(int argc, char** argv)
HeaderMessage header = HEADER_INIT(HEADER_TYPE_TEXT);
void* addr = ArenaPush(&msgsArena, sizeof(header));
memcpy(addr, &header, sizeof(header));
+ header.id = user.id;
// Save message
TextMessage* sendmsg = ArenaPush(&msgsArena, TEXTMESSAGE_SIZE);
- sendmsg->id = user.id;
sendmsg->timestamp = time(0);
sendmsg->len = ninput;
@@ -703,7 +749,7 @@ main(int argc, char** argv)
ArenaPush(&msgsArena, text_size);
memcpy(&sendmsg->text, input, text_size);
- sendAnyMessage(fds[FDS_UNI].fd, &header, sendmsg);
+ sendAnyMessage(fds[FDS_UNI].fd, header, sendmsg);
nmessages++;
// also clear input
@@ -728,7 +774,8 @@ main(int argc, char** argv)
}
// These are used to redraw the screen from threads
- if (fds[FDS_RESIZE].revents & POLLIN) {
+ if (fds[FDS_RESIZE].revents & POLLIN)
+ {
// ignore
tb_poll_event(&ev);
}
diff --git a/protocol.h b/protocol.h
index 84d636f..5d9344b 100644
--- a/protocol.h
+++ b/protocol.h
@@ -49,6 +49,7 @@ typedef u64 ID;
typedef struct {
u16 version;
u8 type;
+ ID id;
} HeaderMessage;
typedef enum {
@@ -60,7 +61,7 @@ typedef enum {
HEADER_TYPE_ERROR
} HeaderType;
// shorthand for creating a header with a value from the enum
-#define HEADER_INIT(t) {.version = PROTOCOL_VERSION, .type = t}
+#define HEADER_INIT(t) {.version = PROTOCOL_VERSION, .type = t, .id = 0}
// from Tsoding video on minicel (https://youtu.be/HCAgvKQDJng?t=4546)
// sv(https://github.com/tsoding/sv)
#define HEADER_FMT "header: v%d %s(%d)"
@@ -69,11 +70,9 @@ typedef enum {
// For sending texts to other clients
// - 13 bytes for the author
// - 8 bytes for the timestamp
-// - 8 bytes for id
// - 2 bytes for the text length
// - x*4 bytes for the text
typedef struct {
- ID id;
u64 timestamp; // timestamp of when the message was sent
u16 len;
wchar_t* text; // placeholder for indexing
@@ -97,19 +96,9 @@ typedef struct {
#define INTRODUCTION_FMT "introduction: %s"
#define INTRODUCTION_ARG(message) message.author
-// Request IntroductionMessage for client with that id.
-// See "First connection" if this message is used when the client connects for the first time.
-// be used to retrieve information about a client with an unknown ID.
-// - 8 bytes for id
-typedef struct {
- ID id;
-} IDMessage;
-
// Notifying the sender's state, such as "connected", "disconnected", "AFK", ...
-// - 8 bytes for id
// - 1 byte for type
typedef struct {
- ID id;
u8 type;
} PresenceMessage;
typedef enum {
@@ -141,11 +130,11 @@ typedef struct {
u8*
headerTypeString(HeaderType type)
{
- switch (type) {
+ switch (type)
+ {
case HEADER_TYPE_TEXT: return (u8*)"TextMessage";
case HEADER_TYPE_HISTORY: return (u8*)"HistoryMessage";
case HEADER_TYPE_PRESENCE: return (u8*)"PresenceMessage";
- case HEADER_TYPE_ID: return (u8*)"IDMessage";
case HEADER_TYPE_INTRODUCTION: return (u8*)"IntroductionMessage";
case HEADER_TYPE_ERROR: return (u8*)"ErrorMessage";
default: return (u8*)"Unknown";
@@ -155,7 +144,8 @@ headerTypeString(HeaderType type)
u8*
presenceTypeString(PresenceType type)
{
- switch (type) {
+ switch (type)
+ {
case PRESENCE_TYPE_CONNECTED: return (u8*)"connected";
case PRESENCE_TYPE_DISCONNECTED: return (u8*)"disconnected";
case PRESENCE_TYPE_AFK: return (u8*)"afk";
@@ -166,7 +156,8 @@ presenceTypeString(PresenceType type)
u8*
errorTypeString(ErrorType type)
{
- switch (type) {
+ switch (type)
+ {
case ERROR_TYPE_BADMESSAGE: return (u8*)"bad message";
case ERROR_TYPE_NOTFOUND: return (u8*)"not found";
case ERROR_TYPE_SUCCESS: return (u8*)"success";
@@ -217,10 +208,11 @@ u32
getMessageSize(HeaderType type)
{
u32 size = 0;
- switch (type) {
+ switch (type)
+ {
case HEADER_TYPE_ERROR: size = sizeof(ErrorMessage); break;
case HEADER_TYPE_HISTORY: size = sizeof(HistoryMessage); break;
- case HEADER_TYPE_ID: size = sizeof(IDMessage); break;
+ case HEADER_TYPE_ID: size = sizeof(HeaderMessage); break;
case HEADER_TYPE_INTRODUCTION: size = sizeof(IntroductionMessage); break;
case HEADER_TYPE_PRESENCE: size = sizeof(PresenceMessage); break;
default: assert(0);
@@ -237,7 +229,8 @@ recvAnyMessageType(s32 fd, HeaderMessage* header, void *anyMessage, HeaderType t
assert(nrecv == sizeof(*header));
s32 size = 0;
- switch (type) {
+ switch (type)
+ {
case HEADER_TYPE_ERROR:
case HEADER_TYPE_HISTORY:
case HEADER_TYPE_ID:
@@ -245,7 +238,8 @@ recvAnyMessageType(s32 fd, HeaderMessage* header, void *anyMessage, HeaderType t
case HEADER_TYPE_PRESENCE:
size = getMessageSize(header->type);
break;
- case HEADER_TYPE_TEXT: {
+ case HEADER_TYPE_TEXT:
+ {
TextMessage* message = anyMessage;
size = TEXTMESSAGE_SIZE + message->len * sizeof(*message->text);
} break;
@@ -270,7 +264,8 @@ recvAnyMessage(Arena* arena, s32 fd)
assert(nrecv == sizeof(*header));
s32 size = 0;
- switch (header->type) {
+ switch (header->type)
+ {
case HEADER_TYPE_ERROR:
case HEADER_TYPE_HISTORY:
case HEADER_TYPE_ID:
@@ -278,7 +273,8 @@ recvAnyMessage(Arena* arena, s32 fd)
case HEADER_TYPE_PRESENCE:
size = getMessageSize(header->type);
break;
- case HEADER_TYPE_TEXT: {
+ case HEADER_TYPE_TEXT:
+ {
Message result;
result.header = header;
result.message = recvTextMessage(arena, fd);
@@ -303,7 +299,8 @@ Message
waitForMessageType(Arena* arena, Arena* queueArena, u32 fd, HeaderType type)
{
Message message;
- while (1) {
+ while (1)
+ {
message = recvAnyMessage(arena, fd);
if (message.header->type == type)
break;
@@ -315,24 +312,26 @@ waitForMessageType(Arena* arena, Arena* queueArena, u32 fd, HeaderType type)
// Generic sending function for sending any type of message to fd
// Returns number of bytes sent in message or -1 if there was an error.
s32
-sendAnyMessage(u32 fd, HeaderMessage* header, void* anyMessage)
+sendAnyMessage(u32 fd, HeaderMessage header, void* anyMessage)
{
s32 nsend_total;
- s32 nsend = send(fd, header, sizeof(*header), 0);
+ s32 nsend = send(fd, &header, sizeof(header), 0);
if (nsend == -1) return nsend;
- assert(nsend == sizeof(*header));
+ assert(nsend == sizeof(header));
nsend_total = nsend;
s32 size = 0;
- switch (header->type) {
+ switch (header.type)
+ {
case HEADER_TYPE_ERROR:
case HEADER_TYPE_HISTORY:
case HEADER_TYPE_ID:
case HEADER_TYPE_INTRODUCTION:
case HEADER_TYPE_PRESENCE:
- size = getMessageSize(header->type);
+ size = getMessageSize(header.type);
break;
- case HEADER_TYPE_TEXT: {
+ case HEADER_TYPE_TEXT:
+ {
nsend = send(fd, anyMessage, TEXTMESSAGE_SIZE, 0);
assert(nsend != -1);
assert(nsend == TEXTMESSAGE_SIZE);
@@ -345,7 +344,7 @@ sendAnyMessage(u32 fd, HeaderMessage* header, void* anyMessage)
anyMessage = &message->text;
} break;
default:
- fprintf(stdout, "sendAnyMessage(%d)|Cannot send %s\n", fd, headerTypeString(header->type));
+ fprintf(stdout, "sendAnyMessage(%d)|Cannot send %s\n", fd, headerTypeString(header.type));
return 0;
}
diff --git a/send.c b/send.c
index 689c6dc..067cb46 100644
--- a/send.c
+++ b/send.c
@@ -46,18 +46,9 @@ main(int argc, char** argv)
// Get id
nrecv = recv(serverfd, &header, sizeof(header), 0);
assert(nrecv != -1);
- if (header.type == HEADER_TYPE_ERROR) {
- ErrorMessage message;
- nrecv = recv(serverfd, &message, sizeof(message), 0);
- fprintf(stderr, "Got '%s' error.\n'", errorTypeString(message.type));
- close(serverfd);
- return 1;
- }
assert(header.type == HEADER_TYPE_ID);
- IDMessage idmessage;
- nrecv = recv(serverfd, &idmessage, sizeof(idmessage), 0);
- assert(nrecv != -1);
- fprintf(stderr, "Got id: %lu\n", idmessage.id);
+ id = header.id;
+ fprintf(stderr, "Got id: %lu\n", header.id);
}
// convert text to wide string
@@ -68,13 +59,15 @@ main(int argc, char** argv)
text_wide[text_len - 1] = 0;
HeaderMessage header = HEADER_INIT(HEADER_TYPE_TEXT);
+ header.id = id;
TextMessage message;
bzero(&message, TEXTMESSAGE_SIZE);
- message = (TextMessage){.id = id, .timestamp = time(NULL), .len = text_len};
+ message = (TextMessage){.timestamp = time(NULL), .len = text_len};
nsend = send(serverfd, &header, sizeof(header), 0);
assert(nsend != -1);
fprintf(stderr, "header bytes sent: %d\n", nsend);
+
nsend = send(serverfd, &message, TEXTMESSAGE_SIZE, 0);
assert(nsend != -1);
fprintf(stderr, "message bytes sent: %d\n", nsend);
diff --git a/server.c b/server.c
index 03dbf4d..881abf5 100644
--- a/server.c
+++ b/server.c
@@ -39,8 +39,8 @@ enum { FDS_STDIN = 0,
typedef struct {
u8 author[AUTHOR_LEN]; // matches author property on other message types
ID id;
- struct pollfd* pollunifd; // Index in fds array
- struct pollfd* pollbifd; // Index in fds array
+ struct pollfd* unifd; // Index in fds array
+ struct pollfd* bifd; // Index in fds array
} Client;
#define CLIENT_FMT "[%s](%lu)"
#define CLIENT_ARG(client) client.author, client.id
@@ -50,24 +50,42 @@ typedef enum {
BIFD
} ClientFD;
-// TODO: remove
+// TODO: remove global variable
// For handing out new ids to connections.
-global_variable u32 nclients = 0;
+// Start at 1 because this makes 0 an invalid client id.
+global_variable u32 nclients = 1;
-// Returns client matching id in clients.
-// clientsArena is used to get an upper bound.
-// Returns 0 if there was no client found.
+// Returns client matching id in clients nclients number of clients.
+// Returns 0 if no client was found or if id was 0.
Client*
-getClientByID(Arena* clientsArena, ID id)
+getClientByID(Client* clients, u32 nclients, ID id)
{
- Client* clients = clientsArena->addr;
- for (u32 i = 0; i < (clientsArena->pos / sizeof(*clients)); i++) {
+ if (!id) return 0;
+
+ for (u32 i = 0; i < nclients - 1; i++)
+ {
if (clients[i].id == id)
return clients + i;
}
return 0;
}
+// Returns client matching fd in clients nclients number of clients.
+// Returns 0 if no clients was found or if fd was -1.
+Client*
+getClientByFD(Client* clients, u32 nclients, s32 fd)
+{
+ if (fd == -1) return 0;
+
+ for (u32 i = 0; i < nclients - 1; i++)
+ {
+ if ((clients[i].unifd && clients[i].unifd->fd == fd) ||
+ (clients[i].bifd && clients[i].bifd->fd == fd))
+ return clients + i;
+ }
+ return 0;
+}
+
// Print TextMessage prettily
void
printTextMessage(TextMessage* message, Client* client, u8 wide)
@@ -75,7 +93,8 @@ printTextMessage(TextMessage* message, Client* client, u8 wide)
u8 timestamp[TIMESTAMP_LEN] = {0};
formatTimestamp(timestamp, message->timestamp);
- if (wide) {
+ if (wide)
+ {
setlocale(LC_ALL, "");
wprintf(L"TextMessage: %s [%s] %ls\n", timestamp, client->author, (wchar_t*)&message->text);
} else {
@@ -89,60 +108,81 @@ printTextMessage(TextMessage* message, Client* client, u8 wide)
// for connfd.
// Type will filter out only connections matching the type.
void
-sendToOthers(struct pollfd* fds, u32 nfds, s32 connfd, ClientFD type, HeaderMessage* header, void* anyMessage)
+sendToOthers(Client* clients, u32 nclients, Client* client, ClientFD type, HeaderMessage* header, void* anyMessage)
{
s32 nsend;
- for (u32 i = FDS_CLIENTS + type; i < nfds; i += 2) {
- if (fds[i].fd == connfd) continue;
- if (fds[i].fd == -1) continue;
+ for (u32 i = 0; i < nclients - 1; i ++)
+ {
+ if (clients + i == client) continue;
- nsend = sendAnyMessage(fds[i].fd, header, anyMessage);
- loggingf("sendToOthers(%d)|[%s]->%d %d bytes\n", connfd, headerTypeString(header->type), fds[i].fd, nsend);
+ if (type == UNIFD)
+ {
+ nsend = sendAnyMessage(client->unifd->fd, *header, anyMessage);
+ }
+ else if (type == BIFD)
+ {
+ nsend = sendAnyMessage(client->bifd->fd, *header, anyMessage);
+ }
+ assert(nsend != -1);
+ loggingf("sendToOthers "CLIENT_FMT"|%s %d bytes\n", CLIENT_ARG((*client)), headerTypeString(header->type), nsend);
}
}
// Send header and anyMessage to each connection in fds that is nfds number of connections.
// Type will filter out only connections matching the type.
void
-sendToAll(struct pollfd* fds, u32 nfds, ClientFD type, HeaderMessage* header, void* anyMessage)
+sendToAll(Client* clients, u32 nclients, ClientFD type, HeaderMessage* header, void* anyMessage)
{
- for (u32 i = FDS_CLIENTS + type; i < nfds; i += 2) {
- if (fds[i].fd == -1) continue;
- s32 nsend = sendAnyMessage(fds[i].fd, header, anyMessage);
- loggingf("sendToAll|[%s]->%d %d bytes\n", headerTypeString(header->type), fds[i].fd, nsend);
+ s32 nsend;
+ for (u32 i = 0; i < nclients - 1; i++)
+ {
+ if (type == UNIFD)
+ {
+ if (clients[i].unifd && clients[i].unifd->fd != -1)
+ nsend = sendAnyMessage(clients[i].unifd->fd, *header, anyMessage);
+ }
+ else if (type == BIFD)
+ {
+ if (clients[i].bifd && clients[i].bifd->fd != -1)
+ nsend = sendAnyMessage(clients[i].bifd->fd, *header, anyMessage);
+ }
+ assert(nsend != -1);
+ loggingf("sendToAll|[%s]->"CLIENT_FMT" %d bytes\n", headerTypeString(header->type),
+ CLIENT_ARG(clients[i]),
+ nsend);
}
}
// Disconnect a client by closing the matching file descriptors
void
-disconnect(struct pollfd* pollfd, Client* client)
+disconnect(Client* client)
{
loggingf("Disconnecting "CLIENT_FMT"\n", CLIENT_ARG((*client)));
- if (pollfd[UNIFD].fd != -1) {
- close(pollfd[UNIFD].fd);
- }
- if (pollfd[BIFD].fd != -1) {
- close(pollfd[BIFD].fd);
+ if (client->unifd && client->unifd->fd != -1)
+ {
+ close(client->unifd->fd);
+ client->unifd->fd = -1;
+ client->unifd = 0;
}
- pollfd[UNIFD].fd = -1;
- pollfd[BIFD].fd = -1;
- // TODO: mark as free
- if (client) {
- client->pollunifd = 0;
- client->pollbifd = 0;
+ if (client->bifd && client->bifd->fd != -1)
+ {
+ close(client->bifd->fd);
+ client->bifd->fd = -1;
+ client->bifd = 0;
}
}
// Disconnects fds+conn from fds with nfds connections, then send a PresenceMessage to other
// clients about disconnection.
void
-disconnectAndNotify(Client* client, struct pollfd* fds, u32 nfds, u32 conn)
+disconnectAndNotify(Client* clients, u32 nclients, Client* client)
{
- disconnect(fds + conn, client);
+ disconnect(client);
local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE);
- PresenceMessage message = {.id = client->id, .type = PRESENCE_TYPE_DISCONNECTED};
- sendToAll(fds, nfds, UNIFD, &header, &message);
+ header.id = client->id;
+ PresenceMessage message = {.type = PRESENCE_TYPE_DISCONNECTED};
+ sendToAll(clients, nclients, UNIFD, &header, &message);
}
// Receive authentication from pollfd->fd and create client out of it. Look in
@@ -150,83 +190,76 @@ disconnectAndNotify(Client* client, struct pollfd* fds, u32 nfds, u32 conn)
// to clients_file.
// See "Authentication" in chatty.h
Client*
-authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* clientfds)
+authenticate(Arena* clientsArena, s32 clients_file, struct pollfd* pollfd, HeaderMessage header)
{
s32 nrecv = 0;
- Client* clients = clientsArena->addr;
-
- HeaderMessage header;
- nrecv = recv(clientfds[BIFD].fd, &header, sizeof(header), 0);
- if (nrecv != sizeof(header)) {
- loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(header));
- return 0;
- }
- loggingf("authenticate(%d)|" HEADER_FMT "\n", clientfds[BIFD].fd, HEADER_ARG(header));
-
Client* client = 0;
- // Scenario 1: Search for existing client
- if (header.type == HEADER_TYPE_ID) {
- IDMessage message;
- nrecv = recv(clientfds[BIFD].fd, &message, sizeof(message), 0);
- if (nrecv != sizeof(message)) {
- loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(message));
- return 0;
- }
- client = getClientByID(clientsArena, message.id);
- if (client) {
- loggingf("authenticate(%d)|found [%s](%lu)\n", clientfds[BIFD].fd, client->author, client->id);
- header.type = HEADER_TYPE_ERROR;
- // TODO: allow multiple connections
- if (client->pollunifd != 0 || client->pollbifd != 0) {
- loggingf("authenticate(%d)|err: already connected\n", clientfds[BIFD].fd);
- ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_ALREADYCONNECTED);
- sendAnyMessage(clientfds[BIFD].fd, &header, &error_message);
- return 0;
- }
- ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_SUCCESS);
- sendAnyMessage(clientfds[BIFD].fd, &header, &error_message);
- } else {
- loggingf("authenticate(%d)|notfound\n", clientfds[BIFD].fd);
+ /* Scenario 1: Search for existing client */
+ if (header.type == HEADER_TYPE_ID)
+ {
+ client = getClientByID((Client*)clientsArena->addr, nclients, header.id);
+ if (!client)
+ {
+ loggingf("authenticate(%d)|notfound\n", pollfd->fd);
header.type = HEADER_TYPE_ERROR;
ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_NOTFOUND);
- sendAnyMessage(clientfds[BIFD].fd, &header, &error_message);
+ sendAnyMessage(pollfd->fd, header, &error_message);
return 0;
}
- // Scenario 2: Create a new client
- } else if (header.type == HEADER_TYPE_INTRODUCTION) {
+
+ loggingf("authenticate(%d)|found [%s](%lu)\n", pollfd->fd, client->author, client->id);
+ header.type = HEADER_TYPE_ERROR;
+ if (!client->unifd)
+ client->unifd = pollfd;
+ else if(!client->bifd)
+ client->bifd = pollfd;
+ else
+ assert(0);
+
+ ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_SUCCESS);
+ header.type = HEADER_TYPE_ERROR;
+ sendAnyMessage(pollfd->fd, header, &error_message);
+
+ return client;
+ }
+
+ /* Scenario 2: Create a new client */
+ if (header.type == HEADER_TYPE_INTRODUCTION)
+ {
IntroductionMessage message;
- nrecv = recv(clientfds[BIFD].fd, &message, sizeof(message), 0);
- if (nrecv != sizeof(message)) {
- loggingf("authenticate(%d)|err: %d/%lu bytes\n", clientfds[BIFD].fd, nrecv, sizeof(message));
+ nrecv = recv(pollfd->fd, &message, sizeof(message), 0);
+ if (nrecv != sizeof(message))
+ {
+ loggingf("authenticate(%d)|err: %d/%lu bytes\n", pollfd->fd, nrecv, sizeof(message));
return 0;
}
// Copy metadata from IntroductionMessage
- client = ArenaPush(clientsArena, sizeof(*clients));
+ client = ArenaPush(clientsArena, sizeof(*client));
memcpy(client->author, message.author, AUTHOR_LEN);
client->id = nclients;
nclients++;
- // Save client
#ifdef IMPORT_ID
write(clients_file, client, sizeof(*client));
#endif
- loggingf("authenticate(%d)|Added [%s](%lu)\n", clientfds[BIFD].fd, client->author, client->id);
+ loggingf("authenticate(%d)|Added [%s](%lu)\n", pollfd->fd, client->author, client->id);
+ // Send ID to new client
HeaderMessage header = HEADER_INIT(HEADER_TYPE_ID);
- IDMessage id_message = {.id = client->id};
- sendAnyMessage(clientfds[BIFD].fd, &header, &id_message);
- } else {
- loggingf("authenticate(%d)|Wrong header expected %s or %s\n", clientfds[BIFD].fd, headerTypeString(HEADER_TYPE_INTRODUCTION), headerTypeString(HEADER_TYPE_ID));
- return 0;
- }
- assert(client != 0);
+ header.id = client->id;
+ s32 nsend = send(pollfd->fd, &header, sizeof(header), 0);
+ assert(nsend != -1);
+ assert(nsend == sizeof(header));
- client->pollunifd = clientfds;
- client->pollbifd = clientfds + 1;
+ return client;
+ }
- return client;
+ loggingf("authenticate(%d)|Wrong header expected %s or %s\n", pollfd->fd,
+ headerTypeString(HEADER_TYPE_INTRODUCTION),
+ headerTypeString(HEADER_TYPE_ID));
+ return 0;
}
int
@@ -236,9 +269,11 @@ main(int argc, char** argv)
logfd = 2;
// optional logging
- if (argc > 1) {
+ if (argc > 1)
+ {
if (*argv[1] == '-')
- if (argv[1][1] == 'l') {
+ if (argv[1][1] == 'l')
+ {
logfd = open(LOGFILE, O_RDWR | O_CREAT | O_TRUNC, 0600);
assert(logfd != -1);
}
@@ -299,192 +334,171 @@ main(int argc, char** argv)
assert(fstat(clients_file, &statbuf) != -1);
read(clients_file, clients, statbuf.st_size);
- if (statbuf.st_size > 0) {
+ if (statbuf.st_size > 0)
+ {
ArenaPush(&clientsArena, statbuf.st_size);
loggingf("Imported %lu client(s)\n", statbuf.st_size / sizeof(*clients));
nclients += statbuf.st_size / sizeof(*clients);
+
+ // Reset pointers on imported clients
+ for (u32 i = 0; i < nclients - 1; i++)
+ {
+ clients[i].unifd = 0;
+ clients[i].bifd = 0;
+ }
}
- for (u32 i = 0; i < nclients; i++)
+ for (u32 i = 0; i < nclients - 1; i++)
loggingf("Imported: " CLIENT_FMT "\n", CLIENT_ARG(clients[i]));
#endif
// Initialize the rest of the fds array
for (u32 i = FDS_CLIENTS; i < MAX_CONNECTIONS; i++)
fds[i] = newpollfd;
- // Reset file descriptors on imported clients
- for (u32 i = 0; i < CLIENTS_SIZE; i++) {
- clients[i].pollunifd = 0;
- clients[i].pollbifd = 0;
- }
- while (1) {
+ while (1)
+ {
s32 err = poll(fds, FDS_SIZE, TIMEOUT);
assert(err != -1);
- if (fds[FDS_STDIN].revents & POLLIN) {
+ if (fds[FDS_STDIN].revents & POLLIN)
+ {
u8 c; // exit on ctrl-d
if (!read(fds[FDS_STDIN].fd, &c, 1))
break;
- } else if (fds[FDS_SERVER].revents & POLLIN) {
+ }
+ else if (fds[FDS_SERVER].revents & POLLIN)
+ {
// TODO: what if we are not aligned by 2 anymore?
- s32 unifd = accept(serverfd, 0, 0);
- s32 bifd = accept(serverfd, 0, 0);
+ s32 clientfd = accept(serverfd, 0, 0);
- if (unifd == -1 || bifd == -1) {
- loggingf("Error while accepting connection (%d,%d)\n", unifd, bifd);
- if (unifd != -1) close(unifd);
- if (bifd != -1) close(bifd);
+ if (clientfd == -1)
+ {
+ loggingf("Error while accepting connection (%d)\n", clientfd);
continue;
- } else
- loggingf("New connection(%d,%d)\n", unifd, bifd);
+ }
+ else
+ loggingf("New connection(%d)\n", clientfd);
- // TODO: find empty space in arena
- if (nclients + 1 == MAX_CONNECTIONS) {
+ // TODO: find empty space in arena (fragmentation)
+ if (nclients + 1 == MAX_CONNECTIONS)
+ {
local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_ERROR);
local_persist ErrorMessage message = ERROR_INIT(ERROR_TYPE_TOOMANYCONNECTIONS);
- sendAnyMessage(unifd, &header, &message);
- if (unifd != -1)
- close(unifd);
- if (bifd != -1)
- close(bifd);
+ sendAnyMessage(clientfd, header, &message);
+ if (clientfd != -1)
+ close(clientfd);
loggingf("Max clients reached. Rejected connection\n");
- } else {
+ }
+ else
+ {
// no more space, allocate
- struct pollfd* clientfds = ArenaPush(&fdsArena, 2 * sizeof(*clientfds));
- clientfds[UNIFD].fd = unifd;
- clientfds[UNIFD].events = POLLIN;
- clientfds[BIFD].fd = bifd;
- clientfds[BIFD].events = POLLIN;
- loggingf("Added pollfd(%d,%d)\n", unifd, bifd);
+ struct pollfd* pollfd = ArenaPush(&fdsArena, sizeof(*pollfd));
+ pollfd->fd = clientfd;
+ loggingf("Added pollfd(%d)\n", clientfd);
}
}
- // Check for messages from clients in their unifd
- for (u32 conn = FDS_CLIENTS; conn < FDS_SIZE; conn += 2) {
+ for (u32 conn = FDS_CLIENTS; conn < FDS_SIZE; conn++)
+ {
if (!(fds[conn].revents & POLLIN)) continue;
if (fds[conn].fd == -1) continue;
loggingf("Message unifd (%d)\n", fds[conn].fd);
- // Get client associated with connection
- Client* client = 0;
- for (u32 j = 0; j < CLIENTS_SIZE; j++) {
- if (!clients[j].pollunifd)
- continue;
- if (clients[j].pollunifd == fds + conn) {
- client = clients + j;
- break;
- }
- }
- if (!client) {
- loggingf("No client associated(%d)\n", fds[conn].fd);
- close(fds[conn].fd);
- continue;
- }
- loggingf("Found client(%lu) [%s] (%d)\n", client->id, client->author, fds[conn].fd);
-
// We received a message, try to parse the header
HeaderMessage header;
s32 nrecv = recv(fds[conn].fd, &header, sizeof(header), 0);
- if (nrecv == 0) {
- disconnectAndNotify(client, fds, FDS_SIZE, conn);
- loggingf("Disconnected(%lu) [%s]\n", client->id, client->author);
- continue;
- } else if (nrecv != sizeof(header)) {
- disconnectAndNotify(client, fds, FDS_SIZE, conn);
- loggingf("error(%lu) [%s] %d/%lu bytes\n", client->id, client->author, nrecv, sizeof(header));
+ assert(nrecv != -1);
+
+ Client* client;
+ if (nrecv != sizeof(header))
+ {
+ client = getClientByFD(clients, nclients, fds[conn].fd);
+ loggingf(CLIENT_FMT" %d/%lu bytes\n", CLIENT_ARG((*client)), nrecv, sizeof(header));
+ if (client)
+ {
+ disconnectAndNotify(clients, nclients, client);
+ loggingf("Disconnected(%lu) [%s]\n", client->id, client->author);
+ }
+ else
+ {
+ loggingf("Got error from unauntheticated client\n");
+ close(fds[conn].fd);
+ fds[conn].fd = -1;
+ }
continue;
}
loggingf("Received(%d) -> " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header));
- switch (header.type) {
- case HEADER_TYPE_TEXT: {
- TextMessage* text_message = recvTextMessage(&msgsArena, fds[conn].fd);
- loggingf("Received(%d)", fds[conn].fd);
- printTextMessage(text_message, client, 0);
+ // Authentication
+ if (header.id)
+ {
+ client = getClientByID(clients, nclients, header.id);
+ if (!client)
+ {
+ loggingf("No client for id %d\n", fds[conn].fd);
- sendToOthers(fds, FDS_SIZE, fds[conn].fd, UNIFD, &header, text_message);
- } break;
- // handle request for information about client id
- default:
- loggingf("Unhandled '%s' from client(%d)\n", headerTypeString(header.type), fds[conn].fd);
- disconnectAndNotify(client, fds, FDS_SIZE, conn);
- continue;
- }
- }
+ header.type = HEADER_TYPE_ERROR;
+ ErrorMessage message = ERROR_INIT(ERROR_TYPE_NOTFOUND);
- // Check for messages from clients in their bifd
- for (u32 conn = FDS_CLIENTS + BIFD; conn < FDS_SIZE; conn += 2) {
- if (!(fds[conn].revents & POLLIN)) continue;
- if (fds[conn].fd == -1) continue;
- loggingf("Message bifd (%d)\n", fds[conn].fd);
-
- // Get client associated with connection
- Client* client = 0;
- for (u32 j = 0; j < CLIENTS_SIZE; j++) {
- if (!clients[j].pollbifd)
- continue;
- if (clients[j].pollbifd == fds + conn) {
- client = clients + j;
- break;
+ sendAnyMessage(fds[conn].fd, header, &message);
+
+ // Reject connection
+ fds[conn].fd = -1;
+ close(fds[conn].fd);
+ }
+ else
+ {
+ header.type = HEADER_TYPE_ERROR;
+ ErrorMessage message = ERROR_INIT(ERROR_TYPE_SUCCESS);
+
+ sendAnyMessage(fds[conn].fd, header, &message);
}
}
- if (!client) {
+ else
+ {
loggingf("No client for connection(%d)\n", fds[conn].fd);
-#ifdef IMPORT_ID
- client = authenticate(&clientsArena, clients_file, fds + conn - 1);
-#else
- client = authenticate(&clientsArena, 0, fds + conn - 1);
-#endif
- // If the client sent an IDMessage but no ID was found authenticate() could return null
- if (!client) {
- loggingf("Could not initialize client\n");
- disconnect(fds + conn, 0);
- } else { // client was added/connected
+
+ client = authenticate(&clientsArena, clients_file, fds + conn, header);
+
+ if (!client)
+ loggingf("Could not initialize client (%d)\n", fds[conn].fd);
+ else if (!client->bifd)
+ {
+ // Send connected message to other clients if this was the first time -> bifd is
+ // not set yet.
local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_PRESENCE);
- PresenceMessage message = {.id = client->id, .type = PRESENCE_TYPE_CONNECTED};
- sendToOthers(fds, FDS_SIZE, fds[conn - BIFD].fd, UNIFD, &header, &message);
+ header.id = client->id;
+ PresenceMessage message = {.type = PRESENCE_TYPE_CONNECTED};
+ sendToOthers(clients, nclients, client, UNIFD, &header, &message);
}
- continue;
}
- loggingf("Found client(%lu) [%s] (%d)\n", client->id, client->author, fds[conn].fd);
-
- // We received a message, try to parse the header
- HeaderMessage header;
- s32 nrecv = recv(fds[conn].fd, &header, sizeof(header), 0);
- if (nrecv == 0) {
- disconnectAndNotify(client, fds, FDS_SIZE, conn);
- loggingf("Disconnected(%lu) [%s]\n", client->id, client->author);
- continue;
- } else if (nrecv != sizeof(header)) {
- disconnectAndNotify(client, fds, FDS_SIZE, conn);
- loggingf("error(%lu) [%s] %d/%lu bytes\n", client->id, client->author, nrecv, sizeof(header));
- continue;
- }
- loggingf("Received(%d) -> " HEADER_FMT "\n", fds[conn].fd, HEADER_ARG(header));
switch (header.type) {
- case HEADER_TYPE_ID: {
- // handle request for information about client id
- IDMessage id_message;
- nrecv = recv(fds[conn].fd, &id_message, sizeof(id_message), 0);
-
- Client* client = getClientByID(&clientsArena, id_message.id);
- if (!client) {
- local_persist HeaderMessage header = HEADER_INIT(HEADER_TYPE_ERROR);
- local_persist ErrorMessage error_message = ERROR_INIT(ERROR_TYPE_NOTFOUND);
- sendAnyMessage(fds[conn].fd, &header, &error_message);
- loggingf("Could not find %lu\n", id_message.id);
- break;
- }
+ /* Send text message to all other clients */
+ case HEADER_TYPE_TEXT:
+ {
+ TextMessage* text_message = recvTextMessage(&msgsArena, fds[conn].fd);
+ loggingf("Received(%d)", fds[conn].fd);
+ printTextMessage(text_message, client, 0);
+
+ sendToOthers(clients, nclients, client, UNIFD, &header, text_message);
+ } break;
+ /* Send back client information */
+ case HEADER_TYPE_ID:
+ {
HeaderMessage header = HEADER_INIT(HEADER_TYPE_INTRODUCTION);
IntroductionMessage introduction_message;
+ header.id = client->id;
memcpy(introduction_message.author, client->author, AUTHOR_LEN);
- sendAnyMessage(fds[conn].fd, &header, &introduction_message);
+ s32 nrecv = sendAnyMessage(fds[conn].fd, header, &introduction_message);
+ assert(nrecv != -1);
} break;
default:
- loggingf("Unhandled '%s' from client(%d)\n", headerTypeString(header.type), fds[conn].fd);
- disconnectAndNotify(client, fds, FDS_SIZE, conn);
+ loggingf("Unhandled '%s' from "CLIENT_FMT"(%d)\n", headerTypeString(header.type),
+ CLIENT_ARG((*client)),
+ fds[conn].fd);
+ disconnectAndNotify(client, nclients, client);
continue;
}
}