diff --git a/.clangd b/.clangd new file mode 100644 index 0000000..6418b4e --- /dev/null +++ b/.clangd @@ -0,0 +1,2 @@ +CompileFlags: + Add: [-std=c++23] diff --git a/README.md b/README.md index 2186ca8..72bed3e 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ And also my first serious C++ project. - [X] Serialisation and saving - [X] FUCKING NETWORKING (probably minimal RESP) - [X] Proper startup/shutdown w/ load & save -- [ ] Multithreading +- [X] Multithreading - [ ] More types -- [ ] Async +- [X] Async - [ ] Backflip off the Rio-Antirrio bridge diff --git a/main.cpp b/main.cpp index 8ca88db..5817f23 100644 --- a/main.cpp +++ b/main.cpp @@ -1,4 +1,7 @@ -#include +#include +#include +#include +#include #include #include #include @@ -12,6 +15,7 @@ #include #include #include +#include using Db = std::vector>; Db database; @@ -63,18 +67,18 @@ std::vector parse_resp_request(char* request) { return split_request; } -void safe_exit(int signum) { - std::cout << "[W] Exiting..." << std::endl; - save_database("db.json"); - exit(signum); -} +namespace asio = boost::asio; +using asio::ip::tcp; +using asio::use_awaitable; +boost::asio::io_context io_context; -void handle_client(int client_socket) { - while (true) { - char buffer[4096]; - int bytes_received = recv(client_socket, buffer, 4096, 0); +asio::awaitable handle_client(tcp::socket client_socket) { + char buffer[4096]; + + try { + while (true) { + std::size_t bytes_received = co_await client_socket.async_read_some(asio::buffer(buffer), use_awaitable); - if (bytes_received > 0) { buffer[bytes_received] = '\0'; std::vector parsed = parse_resp_request(buffer); @@ -84,97 +88,100 @@ void handle_client(int client_socket) { for (const auto& inner_vec: database) { // Iterate through DB and find the required item if (!inner_vec.empty() && inner_vec[0] == parsed[1]) { int length = inner_vec[1].length(); - std::string return_string = std::format("${}\r\n{}\r\n", length, inner_vec[1]); // Serialise to RESP - std::cout << "[d] Sending: " << return_string << std::endl; - send(client_socket, return_string.data(), return_string.length(), 0); - found = true; + std::string return_string = std::format("${}\r\n{}\r\n", length, inner_vec[1]); // Serialise to RESP + std::cout << "[d] Sending: " << return_string << std::endl; + co_await asio::async_write(client_socket, asio::buffer(return_string.data(), return_string.length()), use_awaitable); + found = true; + } } - } - - if (!found) { - send(client_socket, "$-1\r\n", strlen("$-1\r\n"), 0); // Not found - } - } else if (parsed[0] == "SET") { - bool found = false; // Avoids SEGFAULT - - for (auto& inner_vec: database) { - if (!inner_vec.empty() && inner_vec[0] == parsed[1]) { - inner_vec[1] = parsed[2]; - send(client_socket, "+OK\r\n", strlen("+OK\r\n"), 0); - found = true; - } - + if (!found) { - std::vector new_kv = { parsed[1], parsed[2] }; - database.push_back(new_kv); - send(client_socket, "+OK\r\n", strlen("+OK\r\n"), 0); - found = true; + co_await asio::async_write(client_socket, asio::buffer("$-1\r\n", strlen("$-1\r\n")), use_awaitable); // Not found } + } else if (parsed[0] == "SET") { + bool found = false; // Avoids SEGFAULT + + for (auto& inner_vec: database) { + if (!inner_vec.empty() && inner_vec[0] == parsed[1]) { + inner_vec[1] = parsed[2]; + co_await asio::async_write(client_socket, asio::buffer("+OK\r\n", strlen("+OK\r\n")), use_awaitable); + found = true; + } + + if (!found) { + std::vector new_kv = { parsed[1], parsed[2] }; + database.push_back(new_kv); + co_await asio::async_write(client_socket, asio::buffer("+OK\r\n", strlen("+OK\r\n")), use_awaitable); + found = true; + } + } + } else { + co_await asio::async_write(client_socket, asio::buffer("+OK\r\n", strlen("+OK\r\n")), use_awaitable); // Temporary catch-all (for more advanced handshake) } - } else { - send(client_socket, "+OK\r\n", strlen("+OK\r\n"), 0); // Temporary catch-all (for more advanced handshake) } - } else { - close(client_socket); - std::cout << "[i] Client disconnected." << std::endl; - break; + } catch (const std::exception& e) { + std::cout << "[i] Client session ened/error" << e.what() << std::endl; } - } +} + + + +asio::awaitable tcp_server_task(tcp::acceptor& acceptor) { + std::cout << "[i] Server listening on port 6379" << std::endl; + while (true) { + tcp::socket socket = co_await acceptor.async_accept(use_awaitable); + + asio::co_spawn( + io_context, + handle_client(std::move(socket)), + asio::detached + ); + } +} + +asio::awaitable shutdown_handler(asio::io_context& io_context, tcp::acceptor& acceptor, std::shared_ptr> work_guard) { + asio::signal_set signals(io_context, SIGINT, SIGTERM); + co_await signals.async_wait(use_awaitable); + std::cout << "[i] Stopping blueis server..." << std::endl; + acceptor.cancel(); + acceptor.close(); + + work_guard->reset(); } int main() { - std::cout << "[i] Starting blueis server..." << std::endl; - - signal(SIGINT, safe_exit); - load_database("db.json"); - int listening_socket = socket(AF_INET, SOCK_STREAM, 0 ); - if (listening_socket < 0) { - std::cerr << "[E] Failed to create socket." << std::endl; - return 1; - } - - int optval = 1; - if (setsockopt(listening_socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)) < 0) { - std::cerr << "[W] Failed to use SO_REUSEADDR. Continuing." << std::endl; - } + const int N_THREADS = std::thread::hardware_concurrency(); + std::vector threads; + tcp::endpoint endpoint(tcp::v4(), 6379); + tcp::acceptor acceptor(io_context, endpoint); + auto work_guard = std::make_shared>(asio::make_work_guard(io_context)); - sockaddr_in server_address; - server_address.sin_family = AF_INET; - server_address.sin_port = htons(6379); - server_address.sin_addr.s_addr = INADDR_ANY; + asio::co_spawn( + io_context, + tcp_server_task(acceptor), + asio::detached + ); - if (bind(listening_socket, (sockaddr*)&server_address, sizeof(server_address)) < 0) { - std::cerr << "[E] Failed to bind to port 6379. Check for any other running redis or blueis instances." << std::endl; - close(listening_socket); - return 1; + asio::co_spawn( + io_context, + shutdown_handler(io_context, acceptor, work_guard), + asio::detached + ); + + for (int i = 0; i < N_THREADS - 1; i++) { + threads.emplace_back([&io_context] { + io_context.run(); + }); } - if (listen(listening_socket, 5) < 0) { - std::cerr << "[E] Failed to listen on socket." << std::endl; - close(listening_socket); - return 1; - } + std::cout << "[i] Starting blueis server on 6379..." << std::endl; + io_context.run(); - std::cout << "[i] Server listening on port 6379..." << std::endl; - - sockaddr_in client_address; - socklen_t client_size = sizeof(client_address); - - while (true) { - int client_socket = accept(listening_socket, (sockaddr*)&client_address, &client_size); - if (client_socket < 0) { - std::cerr << "[E] Failed to accept connection." << std::endl; - close(listening_socket); - return 1; - } - std::thread t(handle_client, client_socket); // Temporarily using the shit method of one thread per client despite the DoS potential until I can get async down. + for (auto& t: threads) { if (t.joinable()) { - t.detach(); + t.join(); } - std::cout << "[i] Client connected." << std::endl; } - - return 0; }