// blueis: a small Redis-protocol (RESP) key/value server in C++ built on Boost.Asio.
// C / POSIX system headers
#include <netinet/in.h>
#include <signal.h>
#include <sys/socket.h>
#include <unistd.h>

// C++ standard library
#include <algorithm>
#include <cctype>
#include <charconv>
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <format>
#include <fstream>
#include <iostream>
#include <memory>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <system_error>
#include <thread>
#include <vector>

// Third-party
#include <boost/asio.hpp>
#include <boost/asio/detached.hpp>
#include <boost/asio/executor_work_guard.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/use_awaitable.hpp>

// Project
#include "json.hpp"
// In-memory key/value table: every row written by the server is {key, value},
// stored in insertion order.
using Db = std::vector<std::vector<std::string>>;

// The server's in-memory data set; replaced wholesale by load_database().
Db database{};
void save_database(std::string file_name) {
|
|
nlohmann::json json_data = database;
|
|
std::string serialised_json_data = json_data.dump();
|
|
std::ofstream db_file(file_name);
|
|
db_file << serialised_json_data;
|
|
db_file.close();
|
|
}
void load_database(std::string file_name) {
|
|
std::ifstream db_file(file_name);
|
|
std::stringstream raw_file_text;
|
|
raw_file_text << db_file.rdbuf();
|
|
db_file.close();
|
|
database = nlohmann::json::parse(raw_file_text.str());
|
|
}
/**
 * Splits one client request into its command arguments.
 *
 * Two request shapes are understood:
 *  - RESP arrays of bulk strings ("*<n>\r\n$<len>\r\n<bytes>\r\n..."). Each
 *    argument is taken by its declared byte length, so arguments may contain
 *    '*', '$' or even CRLF.
 *  - Inline commands (any text not starting with '*'), split on spaces, tabs
 *    and line breaks, e.g. "GET foo\r\n" -> {"GET", "foo"}.
 *
 * Parsing stops at the first malformed or truncated element and returns the
 * arguments decoded up to that point.
 *
 * @param request NUL-terminated request text; nullptr yields an empty result.
 * @return the decoded arguments (command name first); empty if there were none.
 */
std::vector<std::string> parse_resp_request(const char* request) {
    std::vector<std::string> args;
    if (request == nullptr) {
        return args;
    }

    const std::string_view text(request);
    constexpr std::string_view crlf = "\r\n";

    if (text.empty() || text.front() != '*') {
        // Inline command: whitespace-separated tokens.
        constexpr std::string_view separators = " \t\r\n";
        std::size_t start = text.find_first_not_of(separators);
        while (start != std::string_view::npos) {
            const std::size_t found = text.find_first_of(separators, start);
            const std::size_t end = (found == std::string_view::npos) ? text.size() : found;
            args.emplace_back(text.substr(start, end - start));
            start = text.find_first_not_of(separators, end);
        }
        return args;
    }

    std::size_t pos = 0;

    // Reads the CRLF-terminated line starting at `pos` (without the CRLF) and moves past it.
    const auto next_line = [&](std::string_view& line) {
        const std::size_t end = text.find(crlf, pos);
        if (end == std::string_view::npos) {
            return false;
        }
        line = text.substr(pos, end - pos);
        pos = end + crlf.size();
        return true;
    };

    // Parses a header line "<prefix><non-negative decimal>", such as "*3" or "$5".
    const auto parse_header = [](std::string_view line, char prefix, long long& value) {
        if (line.size() < 2 || line.front() != prefix) {
            return false;
        }
        const char* first = line.data() + 1;
        const char* last = line.data() + line.size();
        const auto [ptr, ec] = std::from_chars(first, last, value);
        return ec == std::errc() && ptr == last && value >= 0;
    };

    std::string_view line;
    long long count = 0;
    if (!next_line(line) || !parse_header(line, '*', count)) {
        return args;
    }

    for (long long i = 0; i < count; ++i) {
        long long length = 0;
        if (!next_line(line) || !parse_header(line, '$', length)) {
            break;
        }
        const auto size = static_cast<std::size_t>(length);
        if (size > text.size() - pos) {
            break;  // payload truncated (e.g. request larger than one read)
        }
        args.emplace_back(text.substr(pos, size));
        pos += size;
        if (text.substr(pos, crlf.size()) == crlf) {
            pos += crlf.size();
        }
    }
    return args;
}
namespace asio = boost::asio;
|
|
using asio::ip::tcp;
|
|
using asio::use_awaitable;
|
|
boost::asio::io_context io_context;
/// Guards every runtime read/write of the global `database`.
/// main() runs the io_context on several threads, so client coroutines can
/// execute concurrently. The lock is never held across a co_await, because a
/// coroutine may resume on a different thread than the one that locked.
static std::mutex& database_mutex() {
    static std::mutex mutex;
    return mutex;
}

/// Returns a copy of the value stored for `key`, or std::nullopt if there is none.
/// Rows with fewer than two columns hold no value and are skipped.
static std::optional<std::string> database_get(const std::string& key) {
    const std::lock_guard<std::mutex> lock(database_mutex());
    for (const auto& row : database) {
        if (row.size() >= 2 && row[0] == key) {
            return row[1];
        }
    }
    return std::nullopt;
}

/// Stores `value` under `key`. Every existing row for the key is overwritten
/// (and widened to two columns if needed); otherwise a new {key, value} row is appended.
static void database_set(const std::string& key, const std::string& value) {
    const std::lock_guard<std::mutex> lock(database_mutex());
    bool updated = false;
    for (auto& row : database) {
        if (!row.empty() && row[0] == key) {
            if (row.size() < 2) {
                row.resize(2);
            }
            row[1] = value;
            updated = true;
        }
    }
    // Appending only after the loop: push_back during iteration would invalidate it.
    if (!updated) {
        database.push_back({key, value});
    }
}

/// ASCII-uppercases a command name so that "get", "Get" and "GET" are treated alike.
static std::string to_upper_ascii(std::string text) {
    std::transform(text.begin(), text.end(), text.begin(),
                   [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
    return text;
}

/**
 * Serves one client connection until it disconnects or an I/O error occurs.
 *
 * Each read is parsed with parse_resp_request and answered with exactly one RESP reply:
 *  - GET <key>          -> "$<len>\r\n<value>\r\n", or "$-1\r\n" when the key is absent
 *  - SET <key> <value>  -> "+OK\r\n" (further arguments are ignored)
 *  - anything else      -> "+OK\r\n" (temporary catch-all, e.g. for client handshakes)
 * Command names are matched case-insensitively. GET/SET with too few arguments
 * get a RESP error reply. A read that yields no arguments gets no reply.
 *
 * NOTE(review): each async_read_some is treated as exactly one command. A
 * command split across reads, or several pipelined commands in one read, is
 * not reassembled.
 */
asio::awaitable<void> handle_client(tcp::socket client_socket) {
    char buffer[4096];

    try {
        while (true) {
            // Read at most sizeof(buffer) - 1 bytes so the NUL terminator that
            // parse_resp_request relies on always fits inside the buffer.
            const std::size_t bytes_received = co_await client_socket.async_read_some(
                asio::buffer(buffer, sizeof(buffer) - 1), use_awaitable);
            buffer[bytes_received] = '\0';

            const std::vector<std::string> parsed = parse_resp_request(buffer);
            if (parsed.empty()) {
                continue;  // nothing to execute (e.g. a bare CRLF)
            }
            const std::string command = to_upper_ascii(parsed[0]);

            std::string reply;
            if (command == "GET") {
                if (parsed.size() < 2) {
                    reply = "-ERR wrong number of arguments for 'get' command\r\n";
                } else if (const auto value = database_get(parsed[1])) {
                    reply = std::format("${}\r\n{}\r\n", value->length(), *value);  // Serialise to RESP
                    std::cout << "[d] Sending: " << reply << std::endl;
                } else {
                    reply = "$-1\r\n";  // Not found
                }
            } else if (command == "SET") {
                if (parsed.size() < 3) {
                    reply = "-ERR wrong number of arguments for 'set' command\r\n";
                } else {
                    database_set(parsed[1], parsed[2]);
                    reply = "+OK\r\n";
                }
            } else {
                reply = "+OK\r\n";  // Temporary catch-all (for more advanced handshake)
            }

            co_await asio::async_write(client_socket, asio::buffer(reply), use_awaitable);
        }
    } catch (const std::exception& e) {
        std::cout << "[i] Client session ended/error: " << e.what() << std::endl;
    }
}
/**
 * Accepts connections on `acceptor` and spawns a detached handle_client
 * coroutine for each one.
 *
 * Accept errors are handled per connection:
 *  - Once the accept is cancelled or the acceptor is closed (as
 *    shutdown_handler does), the loop ends quietly.
 *  - Any other failure (e.g. running out of file descriptors) is logged and
 *    retried after a 100 ms pause. Previously such an exception silently ended
 *    the accept loop for good.
 *
 * @param acceptor bound, listening acceptor; must outlive this coroutine.
 */
asio::awaitable<void> tcp_server_task(tcp::acceptor& acceptor) {
    std::cout << "[i] Server listening on port " << acceptor.local_endpoint().port() << std::endl;

    while (true) {
        boost::system::error_code error;
        tcp::socket socket = co_await acceptor.async_accept(asio::redirect_error(use_awaitable, error));

        if (error) {
            if (error == asio::error::operation_aborted || !acceptor.is_open()) {
                break;  // shutdown: acceptor cancelled/closed
            }
            std::cout << "[w] Accept failed: " << error.message() << std::endl;
            // Back off briefly so a persistent error does not spin the CPU.
            asio::steady_timer backoff(acceptor.get_executor(), std::chrono::milliseconds(100));
            co_await backoff.async_wait(use_awaitable);
            continue;
        }

        // The acceptor's executor is the server's io_context.
        asio::co_spawn(acceptor.get_executor(), handle_client(std::move(socket)), asio::detached);
    }
}
/**
 * Waits for SIGINT or SIGTERM, then shuts the server down.
 *
 * Steps performed:
 *  1. Cancels and closes the listening acceptor, which ends tcp_server_task's
 *     accept loop.
 *  2. Releases the work guard.
 *  3. Stops `io_context`, so every io_context::run() call returns even while
 *     clients are still connected. Without this, their pending reads would keep
 *     the io_context busy and the process would not exit.
 *
 * @param io_context  the server's io_context (also used for the signal_set).
 * @param acceptor    listening acceptor to close.
 * @param work_guard  work guard created by main(); reset so it no longer keeps run() alive.
 */
asio::awaitable<void> shutdown_handler(asio::io_context& io_context, tcp::acceptor& acceptor, std::shared_ptr<asio::executor_work_guard<asio::io_context::executor_type>> work_guard) {
    asio::signal_set signals(io_context, SIGINT, SIGTERM);
    co_await signals.async_wait(use_awaitable);

    std::cout << "[i] Stopping blueis server..." << std::endl;

    // error_code overloads: a failure here must not prevent the io_context from stopping.
    boost::system::error_code ignored;
    acceptor.cancel(ignored);
    acceptor.close(ignored);

    if (work_guard) {
        work_guard->reset();
    }
    io_context.stop();
}
int main() {
|
|
load_database("db.json");
|
|
const int N_THREADS = std::thread::hardware_concurrency();
|
|
std::vector<std::thread> threads;
|
|
tcp::endpoint endpoint(tcp::v4(), 6379);
|
|
tcp::acceptor acceptor(io_context, endpoint);
|
|
auto work_guard = std::make_shared<asio::executor_work_guard<asio::io_context::executor_type>>(asio::make_work_guard(io_context));
|
|
|
|
asio::co_spawn(
|
|
io_context,
|
|
tcp_server_task(acceptor),
|
|
asio::detached
|
|
);
|
|
|
|
asio::co_spawn(
|
|
io_context,
|
|
shutdown_handler(io_context, acceptor, work_guard),
|
|
asio::detached
|
|
);
|
|
|
|
for (int i = 0; i < N_THREADS - 1; i++) {
|
|
threads.emplace_back([&io_context] {
|
|
io_context.run();
|
|
});
|
|
}
|
|
|
|
std::cout << "[i] Starting blueis server on 6379..." << std::endl;
|
|
io_context.run();
|
|
|
|
for (auto& t: threads) {
|
|
if (t.joinable()) {
|
|
t.join();
|
|
}
|
|
}
|
|
|
|
}