aboutsummaryrefslogtreecommitdiff
path: root/server.c
diff options
context:
space:
mode:
Diffstat (limited to 'server.c')
-rw-r--r--server.c259
1 files changed, 177 insertions, 82 deletions
diff --git a/server.c b/server.c
index 428ddde..18ce0f9 100644
--- a/server.c
+++ b/server.c
@@ -1,30 +1,38 @@
-#include "arena.h"
-#include "common.h"
+#include "chatty.h"
#include <assert.h>
#include <netinet/in.h>
#include <poll.h>
#include <stdarg.h>
+#include <string.h>
#include <sys/socket.h>
+#include <unistd.h>
// timeout on polling
#define TIMEOUT 60 * 1000
// max pending connections
-#define PENDING_MAX 16
-
-// the size of pollfd element in the fdsArena
-// note: clientsArena and pollfd_size must have been initialisezd
-#define FDS_SIZE fdsArena->pos / pollfd_size
+#define MAX_CONNECTIONS 16
+// Get number of connections from arena position
+// NOTE: this is somewhat wrong, because of when disconnections happen
+#define FDS_SIZE (fdsArena->pos / sizeof(*fds))
// enum for indexing the fds array
enum { FDS_STDIN = 0,
FDS_SERVER,
FDS_CLIENTS };
-int main(void)
+// Has information on clients
+// For each pollfd in fds there should be a matching client in clients
+// clients[i - FDS_CLIENTS] <=> fds[i]
+typedef struct {
+ u8 author[AUTHOR_LEN]; // matches author property on other message types
+ Bool initialized; // boolean
+} Client;
+
+int
+main(void)
{
- u32 err, serverfd, clientfd;
- u16 nclient = 0;
+ s32 err, serverfd, clientfd;
u32 on = 1;
// Start listening on the socket
@@ -32,7 +40,7 @@ int main(void)
serverfd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP);
assert(serverfd > 2);
- err = setsockopt(serverfd, SOL_SOCKET, SO_REUSEADDR, (u8 *)&on, sizeof(on));
+ err = setsockopt(serverfd, SOL_SOCKET, SO_REUSEADDR, (u8*)&on, sizeof(on));
assert(err == 0);
const struct sockaddr_in address = {
@@ -41,36 +49,39 @@ int main(void)
{0},
};
- err = bind(serverfd, (const struct sockaddr *)&address, sizeof(address));
+ err = bind(serverfd, (const struct sockaddr*)&address, sizeof(address));
assert(err == 0);
- err = listen(serverfd, PENDING_MAX);
+ err = listen(serverfd, MAX_CONNECTIONS);
assert(err == 0);
}
- Arena *msgTextArena = ArenaAlloc(); // allocating text in messages that have a dynamic sized
- Message mrecv = {0}; // message used for receiving messages from clients
- u32 nrecv = 0; // number of bytes received
- u32 recv_len; // Number of bytes of the message received over stream
- u32 nsend = 0; // number of bytes sent
- Arena *bufArena = ArenaAlloc(); // data in buf
- u8 *buf = ArenaPush(bufArena, STREAM_LIMIT); // temporary buffer for receiving and sending data
- Message *mbuf = (Message *)buf; // pointer for indexing buf as a message
-
- Arena *fdsArena = ArenaAlloc(); // arena for fds to accomodate multiple clients
- struct pollfd *fds = fdsArena->memory; // helper for indexing memory
- struct pollfd c = {0, POLLIN, 0}; // helper client structure fore reusing
- struct pollfd *fdsAddr; // used for copying clients
- const u64 pollfd_size = sizeof(struct pollfd);
+ Arena* msgsArena = ArenaAlloc(Megabytes(128)); // storing received messages
+ // NOTE: sent messages?
+ s32 nrecv = 0; // number of bytes received
+ s32 nsend = 0; // number of bytes sent
+
+ Arena* clientsArena = ArenaAlloc(MAX_CONNECTIONS * sizeof(Client));
+ Arena* fdsArena = ArenaAlloc(MAX_CONNECTIONS * sizeof(struct pollfd));
+ struct pollfd* fds = fdsArena->addr;
+ Client* clients = clientsArena->addr;
+
+ struct pollfd* fdsAddr;
+ struct pollfd newpollfd = {-1, POLLIN, 0};
// initialize fds structure
- // add stdin (c.fd == 0)
- fdsAddr = ArenaPush(fdsArena, pollfd_size);
- memcpy(fdsAddr, &c, pollfd_size);
+ newpollfd.fd = 0;
+ fdsAddr = ArenaPush(fdsArena, sizeof(*fds));
+ memcpy(fdsAddr, &newpollfd, sizeof(*fds));
// add serverfd
- c.fd = serverfd;
- fdsAddr = ArenaPush(fdsArena, pollfd_size);
- memcpy(fdsAddr, &c, pollfd_size);
+ newpollfd.fd = serverfd;
+ fdsAddr = ArenaPush(fdsArena, sizeof(*fds));
+ memcpy(fdsAddr, &newpollfd, sizeof(*fds));
+ newpollfd.fd = -1;
+
+ // Initialize the rest of the fds array
+ for (u32 i = FDS_CLIENTS; i < MAX_CONNECTIONS; i++)
+ fds[i] = newpollfd;
while (1) {
err = poll(fds, FDS_SIZE, TIMEOUT);
@@ -83,81 +94,165 @@ int main(void)
clientfd = accept(serverfd, NULL, NULL);
assert(clientfd != -1);
assert(clientfd > serverfd);
+ fprintf(stdout, "New connection(%d).\n", clientfd);
// fill up a hole
- u8 found = 0;
- for (u32 i = FDS_CLIENTS; i < FDS_SIZE; i++) {
- if (fds[i].fd == -1) {
- fds[i].fd = clientfd;
- // note we do not have to reset .revents because poll will set it to 0 next time
- found = 1;
+ u8 found;
+ for (found = FDS_CLIENTS; found < FDS_SIZE; found++)
+ if (fds[found].fd == -1)
break;
- }
- }
-
- // allocate an extra client because there was no empty spot in the fds array
- if (!found) {
- // add client to arena
- fdsAddr = ArenaPush(fdsArena, pollfd_size);
- c.fd = clientfd;
- memcpy(fdsAddr, &c, pollfd_size);
+ if (found == FDS_SIZE) {
+ // no more space, allocate
+ struct pollfd* pollfd = ArenaPush(fdsArena, sizeof(*pollfd));
+ pollfd->fd = clientfd;
+ pollfd->events = POLLIN;
+ } else if (found == MAX_CONNECTIONS) {
+ // TODO: reject connection
+ close(clientfd);
+ fprintf(stdout, "Max clients reached.");
+ } else {
+ // hole found
+ fds[found].fd = clientfd;
+ fds[found].events = POLLIN;
+ fprintf(stdout, "Added pollfd(%d).\n", clientfd);
}
-
- nclient++;
- fprintf(stdout, "connected(%d).\n", clientfd - serverfd);
}
+ // Check for messages from clients
for (u32 i = FDS_CLIENTS; i < (FDS_SIZE); i++) {
if (!(fds[i].revents & POLLIN))
continue;
- if (fds[i].fd == -1)
+ assert(fds[i].fd != -1);
+
+ fprintf(stdout, "Message(%d).\n", fds[i].fd);
+ // If this is the first message from the client it must be a presence message indicated
+ // it connected.
+ Client* client = clients + i - FDS_CLIENTS;
+ if (!client->initialized) {
+ fprintf(stdout, " Adding to clients(%d).\n", fds[i].fd);
+ // Wait for PresenceMessage from new client to get author information
+ HeaderMessage header;
+ // TODO: handle wrong message, disconnection, etc.
+ nrecv = recv(clientfd, &header, sizeof(header), 0);
+ assert(nrecv != -1);
+ assert(nrecv == sizeof(header));
+ if (header.type != HEADER_TYPE_PRESENCE) {
+ // TODO: reject connection
+ close(clientfd);
+ continue;
+ }
+ fprintf(stdout, " Got header(%d).\n", fds[i].fd);
+
+ PresenceMessage message;
+ // TODO: handle wrong message
+ nrecv = recv(clientfd, &message, sizeof(message), 0);
+ assert(nrecv != -1);
+ assert(nrecv == sizeof(message));
+ fprintf(stdout, " Got presence message(%d).\n", fds[i].fd);
+
+ memcpy(client->author, message.author, AUTHOR_LEN);
+ client->initialized = True;
+
+ fprintf(stdout, " Added to clients(%d): %s\n", fds[i].fd, client->author);
+
+ // Notify other clients from this new one
+ // Reuse header and message
+ for (u32 j = FDS_CLIENTS; j < (FDS_SIZE); j++) {
+ if (fds[j].fd == fds[i].fd)
+ continue;
+ if (fds[j].fd == -1)
+ continue;
+ fprintf(stdout, " Notifying (%d)\n", fds[j].fd);
+
+ nsend = send(fds[j].fd, &header, sizeof(header), 0);
+ assert(nsend != -1);
+ assert(nsend == sizeof(header));
+ nsend = send(fds[j].fd, &message, sizeof(message), 0);
+ assert(nsend != -1);
+ assert(nsend == sizeof(message));
+ }
continue;
+ }
- nrecv = recv(fds[i].fd, buf, bufArena->pos, 0);
+ // We received a message, try to parse the header
+ HeaderMessage header;
+ nrecv = recv(fds[i].fd, &header, sizeof(header), 0);
assert(nrecv != -1);
if (nrecv == 0) {
- fprintf(stdout, "disconnected(%d). \n", fds[i].fd - serverfd);
+ fprintf(stdout, "Disconnected(%d). \n", fds[i].fd);
shutdown(fds[i].fd, SHUT_RDWR);
- close(fds[i].fd); // send close to client
- fds[i].fd = -1; // ignore in the future
+ close(fds[i].fd); // send close to client
+ fds[i].fd = -1; // ignore in the future
+ clients[i - FDS_CLIENTS].initialized = False; // deinitialize client
+ //
+ // Send disconnection to other connected clients
+ HeaderMessage header = HEADER_PRESENCEMESSAGE;
+ PresenceMessage message = {
+ .type = PRESENCE_TYPE_DISCONNECTED
+ };
+ memcpy(message.author, clients[i - FDS_CLIENTS].author, AUTHOR_LEN);
+ for (u32 j = FDS_CLIENTS; j < FDS_SIZE; j++) {
+ if (fds[j].fd == fds[i].fd)
+ continue;
+ if (fds[j].fd == -1)
+ continue;
+ nsend = send(fds[j].fd, &header, sizeof(header), 0);
+ assert(nsend != -1);
+ assert(nsend == sizeof(header));
+ nsend = send(fds[j].fd, &message, sizeof(message), 0);
+ assert(nsend != -1);
+ assert(nsend == sizeof(message));
+ }
+
continue;
}
- recv_len = sizeof(*mbuf) - sizeof(mbuf->text) + mbuf->text_len * sizeof(*mbuf->text);
- if (recv_len > nrecv) {
- // allocate needed space for buf
- if (recv_len > bufArena->pos)
- ArenaPush(bufArena, recv_len - bufArena->pos);
-
- // receive remaining bytes
- u32 nr = recv(fds[i].fd, buf + nrecv, recv_len - nrecv, 0);
- assert(nr != -1);
- nrecv += nr;
- assert(nrecv == recv_len);
- }
+ assert(nrecv == sizeof(header));
+ fprintf(stderr, " Received(%d): %d bytes -> " PH_FMT "\n", fds[i].fd, nrecv, PH_ARG(header));
- // TODO: Do not print the message in the logs
- fprintf(stdout, "message(%d): %d bytes.\n", fds[i].fd - serverfd, nrecv);
+ switch (header.type) {
+ case HEADER_TYPE_TEXT:;
+ TextMessage* message;
+ nrecv = recvTextMessage(msgsArena, fds[i].fd, &message);
+ fprintf(stderr, " Received(%d): %d bytes -> ", fds[i].fd, nrecv);
+ printTextMessage(message, 0);
- for (u32 j = FDS_CLIENTS; j < (FDS_SIZE); j++) {
- if (j == i)
- continue;
- if (fds[j].fd == -1)
- continue;
+ HeaderMessage header = HEADER_TEXTMESSAGE;
+ // Send message to all other clients
+ for (u32 j = FDS_CLIENTS; j < FDS_SIZE; j++) {
+ if (fds[j].fd == fds[i].fd) continue;
+ if (fds[j].fd == -1) continue;
- nsend = send(fds[j].fd, buf, nrecv, 0);
- assert(nsend != 1);
- assert(nsend == nrecv);
- fprintf(stdout, "retransmitted(%d->%d).\n", fds[i].fd - serverfd, fds[j].fd - serverfd);
- }
+ // NOTE: I wonder if this is more expensive than constructing a buffer and sending
+ // that
+ u32 nsend_total = 0;
+ nsend = send(fds[j].fd, &header, sizeof(header), 0);
+ assert(nsend != 1);
+ assert(nsend == sizeof(header));
+ nsend_total += nsend;
+ nsend = send(fds[j].fd, message, TEXTMESSAGE_SIZE, 0);
+ assert(nsend != -1);
+ assert(nsend == TEXTMESSAGE_SIZE);
+ nsend_total += nsend;
+ nsend = send(fds[j].fd, &message->text, message->len * sizeof(*message->text), 0);
+ assert(nsend != -1);
+ assert(nsend == (message->len * sizeof(*message->text)));
+ nsend_total += nsend;
- ArenaPop(msgTextArena, mrecv.text_len);
+ fprintf(stdout, " Retransmitted(%d->%d) %d bytes.\n", fds[i].fd, fds[j].fd, nsend_total);
+ }
+ break;
+ default:
+ fprintf(stdout, " Got unhandled message type '%s' from client %d", headerTypeString(header.type), fds[i].fd);
+ continue;
+ }
}
}
+ ArenaRelease(clientsArena);
ArenaRelease(fdsArena);
- ArenaRelease(msgTextArena);
+ ArenaRelease(msgsArena);
return 0;
}