Files
NXST/source/server.cpp
T

346 lines
11 KiB
C++

#include <algorithm>
#include <arpa/inet.h>
#include <atomic>
#include <cerrno>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <netinet/in.h>
#include <pthread.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <vector>
#ifdef __SWITCH__
#include <server.hpp>
#include <switch.h>
#include <main.hpp>
#endif
#include <protocol.hpp>
#include <TransferState.hpp>
#include <net/Socket.hpp>
// Progress/status of the current receive session. Written by the worker
// threads in this file and read through the is*/get* accessors.
static TransferState g_server_state;
// Socket slots: each holds an open fd, or -1 when empty. The intended protocol
// is that whoever removes the fd with exchange(-1) closes it. That is either
// the owning worker when it finishes, or cancelServerTransfer().
// Accepted TCP client connection.
static std::atomic<int> g_server_client_sock{-1};
// Listening TCP socket (closed after the first accept).
static std::atomic<int> g_server_listen_sock{-1};
// UDP socket used for multicast discovery.
static std::atomic<int> g_broadcast_sock{-1};
// Set to true when the accept / discovery worker starts; cleared when it returns.
static std::atomic<bool> g_accept_thread_active{false};
static std::atomic<bool> g_broadcast_thread_active{false};
// Handle of the (detached) discovery thread, recorded by startSendingThread().
static pthread_t g_broadcast_thread{};
/// True once the accept/receive worker has finished, whether or not a client connected.
bool isServerTransferDone() {
    return g_server_state.done.load();
}

/// True once cancellation has been requested (see cancelServerTransfer()).
bool isServerTransferCancelled() {
    return g_server_state.cancelled.load();
}

/// True when neither the accept worker nor the discovery worker is flagged active.
bool isServerWorkersIdle() {
    const bool anyWorkerActive =
        g_accept_thread_active.load() || g_broadcast_thread_active.load();
    return !anyWorkerActive;
}

/// Progress of the current transfer, as computed by TransferState::progress().
double getServerProgress() {
    return g_server_state.progress();
}

/// Current human-readable status line (e.g. the name of the file being received).
std::string getServerStatusText() {
    return g_server_state.getStatus();
}
void cancelServerTransfer() {
g_server_state.cancelled.store(true);
int sock = g_server_client_sock.exchange(-1);
if (sock >= 0) {
shutdown(sock, SHUT_RDWR);
close(sock);
}
int lsock = g_server_listen_sock.exchange(-1);
if (lsock >= 0) {
shutdown(lsock, SHUT_RDWR);
close(lsock);
}
int bsock = g_broadcast_sock.exchange(-1);
if (bsock >= 0) {
shutdown(bsock, SHUT_RDWR);
close(bsock);
}
if (g_broadcast_thread_active.load()) {
pthread_cancel(g_broadcast_thread);
}
}
#ifdef __SWITCH__
/// Swaps the directory name directly in front of the file name in `path` for
/// the current account's username (accents removed, then non-ASCII removed).
///   "a/b/file" -> "a/<user>/file"
///   "b/file"   -> "<user>/file"
/// A path with no '/' is returned unchanged.
static std::string replaceUsername(const std::string& path) {
    const std::string user = StringUtils::removeNotAscii(
        StringUtils::removeAccents(Account::username(g_currentUId)));
    const size_t fileSep = path.rfind('/');
    if (fileSep == std::string::npos) {
        return path;
    }
    // NOTE(review): when fileSep == 0, `fileSep - 1` wraps to npos and rfind
    // finds index 0 again, so "/file" becomes "/<user>/file". Confirm intent.
    const size_t dirSep = path.rfind('/', fileSep - 1);
    std::string result;
    if (dirSep != std::string::npos) {
        result.assign(path, 0, dirSep + 1);
    }
    result += user;
    result.append(path, fileSep, std::string::npos);
    return result;
}
#endif
/// Reads exactly `len` bytes from `sock` into `buf`.
///
/// Loops over short reads. A read interrupted by a signal before any data
/// arrives (EINTR) is retried instead of being treated as a broken stream.
/// @return true once all `len` bytes have been read (trivially true for 0);
///         false if the peer closed the stream (read returned 0) or read
///         failed with any other error.
static bool recv_all(int sock, void* buf, size_t len) {
    char* dst = static_cast<char*>(buf);
    size_t received = 0;
    while (received < len) {
        ssize_t n = read(sock, dst + received, len - received);
        if (n < 0 && errno == EINTR) continue;
        if (n <= 0) return false;
        received += static_cast<size_t>(n);
    }
    return true;
}
/// Best-effort `mkdir -p`: creates every ancestor prefix of `path` (each
/// prefix ending just before a '/'), then `path` itself, all with mode 0777.
/// Errors such as "already exists" are ignored. The search starts at index 1
/// so a leading '/' never produces mkdir("").
static void mkdirs(const std::string& path) {
    std::string::size_type sep = path.find('/', 1);
    while (sep != std::string::npos) {
        const std::string prefix = path.substr(0, sep);
        mkdir(prefix.c_str(), 0777);
        sep = path.find('/', sep + 1);
    }
    mkdir(path.c_str(), 0777);
}
/// read(2) that retries when a signal interrupts it before any data arrives.
/// @return bytes read (> 0), 0 at end of stream, or -1 on error (errno set).
static ssize_t read_retrying(int sock, void* buf, size_t len) {
    ssize_t n;
    do {
        n = read(sock, buf, len);
    } while (n < 0 && errno == EINTR);
    return n;
}

/// Reads and discards up to `remaining` bytes from `sock` so the next header
/// in the stream stays aligned. Stops early at end of stream or on a read error.
static void drain_bytes(int sock, uint64_t remaining) {
    std::vector<char> drain(proto::BUF_SIZE);
    while (remaining > 0) {
        size_t to_read = (size_t)std::min(remaining, (uint64_t)proto::BUF_SIZE);
        ssize_t n = read_retrying(sock, drain.data(), to_read);
        if (n <= 0) break;
        remaining -= (uint64_t)n;
    }
}

/// True if any '/'-separated component of `path` is exactly "..".
static bool has_parent_component(const std::string& path) {
    size_t start = 0;
    while (start <= path.size()) {
        size_t end = path.find('/', start);
        if (end == std::string::npos) end = path.size();
        if (path.compare(start, end - start, "..") == 0) return true;
        start = end + 1;
    }
    return false;
}

/// Receives one file payload of `file_size` bytes from `sock` and writes it to
/// `relative_path`, creating parent directories first.
///
/// The payload is consumed from the socket even when the path is refused, the
/// file cannot be opened, or a write fails, so the next header stays aligned.
/// Updates g_server_state.bytes_total / bytes_done for progress reporting.
/// "Received:" is logged only when every byte was read and written successfully.
static void receive_file(int sock, const std::string& relative_path, uint64_t file_size) {
    std::cout << "Receiving: " << relative_path << " (" << file_size << " bytes)" << std::endl;
    // The path is chosen by the network peer. Refuse ".." components so a
    // sender cannot escape the destination tree via directory traversal.
    // NOTE(review): absolute paths are still accepted; confirm whether the
    // protocol relies on them before restricting further.
    if (has_parent_component(relative_path)) {
        std::cerr << "Refusing path with '..' component: " << relative_path << std::endl;
        drain_bytes(sock, file_size);
        return;
    }
    size_t last_slash = relative_path.rfind('/');
    if (last_slash != std::string::npos) {
        std::string dir = relative_path.substr(0, last_slash);
        if (!dir.empty()) mkdirs(dir);
    }
    FILE* outfile = fopen(relative_path.c_str(), "wb");
    if (!outfile) {
        const int open_errno = errno;  // capture before stream output can clobber it
        std::cerr << "Failed to open for writing: " << relative_path
                  << " errno=" << open_errno << std::endl;
        // Drain so sender doesn't hang
        drain_bytes(sock, file_size);
        return;
    }
    g_server_state.bytes_total.store(file_size);
    g_server_state.bytes_done.store(0);
    std::vector<char> buffer(proto::BUF_SIZE);
    uint64_t total_received = 0;
    bool write_ok = true;
    while (total_received < file_size) {
        size_t to_read = (size_t)std::min(file_size - total_received, (uint64_t)proto::BUF_SIZE);
        ssize_t n = read_retrying(sock, buffer.data(), to_read);
        if (n <= 0) {
            std::cerr << "Read error receiving: " << relative_path << std::endl;
            break;
        }
        if (write_ok && fwrite(buffer.data(), 1, (size_t)n, outfile) != (size_t)n) {
            const int write_errno = errno;
            std::cerr << "Write error receiving: " << relative_path
                      << " errno=" << write_errno << std::endl;
            // Keep consuming the payload so the stream stays aligned.
            write_ok = false;
        }
        total_received += (uint64_t)n;
        g_server_state.bytes_done.store(total_received);
    }
    // fclose flushes buffered data, so it can be where a write failure surfaces.
    if (fclose(outfile) != 0 && write_ok) {
        const int close_errno = errno;
        std::cerr << "Write error closing: " << relative_path
                  << " errno=" << close_errno << std::endl;
        write_ok = false;
    }
    if (write_ok && total_received == file_size) {
        std::cout << "Received: " << relative_path << std::endl;
    }
}
/// Receives a sequence of files from a connected client until end of transfer.
///
/// Each file on the wire is a header followed by a payload:
///   uint32 name length (or proto::EOF_SENTINEL to finish), the name bytes,
///   uint64 payload size, then the payload (consumed by receive_file()).
/// Stops on EOF_SENTINEL, an oversize name length, or any short read.
///
/// @param socket_desc heap-allocated int holding the client fd. It is deleted
///        here. The fd is closed at the end unless cancelServerTransfer()
///        already claimed it from g_server_client_sock.
/// @return always nullptr.
static void* handle_client(void* socket_desc) {
    int* const boxed_fd = static_cast<int*>(socket_desc);
    const int client_socket = *boxed_fd;
    delete boxed_fd;

    for (;;) {
        uint32_t name_len = 0;
        if (!recv_all(client_socket, &name_len, sizeof name_len)) {
            break;
        }
        if (name_len == proto::EOF_SENTINEL) {
            std::cout << "End of transfer." << std::endl;
            break;
        }
        if (name_len > proto::MAX_FILENAME) {
            std::cerr << "filename_len=" << name_len << " exceeds MAX_FILENAME, aborting." << std::endl;
            break;
        }

        std::string path(name_len, '\0');
        if (!recv_all(client_socket, &path[0], name_len)) {
            std::cerr << "Short read on filename, aborting." << std::endl;
            break;
        }
#ifdef __SWITCH__
        path = replaceUsername(path);
#endif
        // Show only the base name in the status line.
        const size_t cut = path.rfind('/');
        g_server_state.setStatus(cut == std::string::npos ? path : path.substr(cut + 1));

        uint64_t file_size = 0;
        if (!recv_all(client_socket, &file_size, sizeof file_size)) {
            std::cerr << "Short read on file_size, aborting." << std::endl;
            break;
        }
        receive_file(client_socket, path, file_size);
    }

    // Close only if cancelServerTransfer() has not already taken (and closed) it.
    if (g_server_client_sock.exchange(-1) == client_socket) {
        close(client_socket);
    }
    return nullptr;
}
/// Heap-allocated hand-off from startSendingThread() to accept_and_handle().
struct AcceptArgs {
    int server_fd;  // listening TCP socket; ownership passes to accept_and_handle()
};
/// Thread entry: accepts a single TCP client and receives its files via
/// handle_client().
///
/// @param arg heap-allocated AcceptArgs. It is deleted here, and this thread
///        takes ownership of its listening socket, which it closes after the
///        first accept() unless cancelServerTransfer() already claimed it.
/// On exit, sets g_server_state.done (whether or not a client connected) and
/// clears g_accept_thread_active.
/// @return always nullptr.
static void* accept_and_handle(void* arg) {
    g_accept_thread_active.store(true);
    int server_fd = static_cast<AcceptArgs*>(arg)->server_fd;
    delete static_cast<AcceptArgs*>(arg);

    // Publish the listener so cancelServerTransfer() can shut it down and
    // thereby unblock accept().
    g_server_listen_sock.store(server_fd);

    int client_socket = -1;
    // The loop condition runs after the store above. A cancel that completed
    // before the store could not see (or close) the listener, and accept()
    // would otherwise block forever.
    while (!g_server_state.cancelled.load()) {
        sockaddr_in client_addr{};
        socklen_t client_len = sizeof(client_addr);
        client_socket = accept(server_fd, (sockaddr*)&client_addr, &client_len);
        // Retry only when a signal interrupted the wait.
        if (client_socket >= 0 || errno != EINTR) break;
    }

    // Only one client is served: release the listener unless cancel already did.
    int owned_listen = g_server_listen_sock.exchange(-1);
    if (owned_listen == server_fd) {
        close(server_fd);
    }

    if (client_socket >= 0) {
        g_server_client_sock.store(client_socket);
        int* pclient = nullptr;
        // Same window as above: a cancel that finished before this store did
        // not shut the client socket down, so don't start a blocking receive.
        if (!g_server_state.cancelled.load()) {
            pclient = new (std::nothrow) int(client_socket);
        }
        if (pclient) {
            // handle_client() deletes pclient and closes the socket if it still owns it.
            handle_client(pclient);
        } else if (g_server_client_sock.exchange(-1) == client_socket) {
            // Cancelled or out of memory. Clear the slot before closing so
            // cancelServerTransfer() cannot close this fd a second time.
            close(client_socket);
        }
    }
    g_server_state.done.store(true);
    g_accept_thread_active.store(false);
    return nullptr;
}
/// Releases the discovery socket `udp` (unless cancelServerTransfer() already
/// claimed it from g_broadcast_sock) and marks the discovery worker finished.
/// @return always nullptr, so a thread entry can `return` its result.
static void* finish_broadcast_listener(int udp) {
    int owned = g_broadcast_sock.exchange(-1);
    if (owned == udp) close(udp);
    g_broadcast_thread_active.store(false);
    return nullptr;
}

/// Thread entry for UDP multicast discovery.
///
/// Binds proto::MULTICAST_PORT and joins proto::MULTICAST_GROUP. It then waits
/// for a "DISCOVER_SERVER" datagram, answers its sender with "SERVER_HERE" once,
/// and exits. It also exits once g_server_state.cancelled is set; the 20 ms
/// SO_RCVTIMEO bounds how long that takes to be noticed.
///
/// Thread cancellation is disabled. A pthread_cancel() acting inside
/// recvfrom() or a std::cout write would skip the cleanup at the end, leaving
/// g_broadcast_thread_active stuck at true and the socket open, and could
/// leave a stream lock held. Cooperative exit via the flag is used instead.
/// @return always nullptr.
static void* broadcast_listener(void* arg) {
    (void)arg;
    g_broadcast_thread_active.store(true);
    pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, nullptr);
    int udp = socket(AF_INET, SOCK_DGRAM, 0);
    if (udp < 0) {
        perror("broadcast_listener: socket");
        g_broadcast_thread_active.store(false);
        return nullptr;
    }
    g_broadcast_sock.store(udp);
    struct timeval tv{0, 20000}; // 20ms poll so cancel/exit wins race with socketExit
    setsockopt(udp, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
    sockaddr_in addr{};
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = htonl(INADDR_ANY);
    addr.sin_port = htons(proto::MULTICAST_PORT);
    if (bind(udp, (sockaddr*)&addr, sizeof(addr)) < 0) {
        perror("broadcast_listener: bind");
        return finish_broadcast_listener(udp);
    }
    ip_mreq group{};
    group.imr_multiaddr.s_addr = inet_addr(proto::MULTICAST_GROUP);
    group.imr_interface.s_addr = htonl(INADDR_ANY);
    if (setsockopt(udp, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)) < 0) {
        perror("broadcast_listener: setsockopt");
        return finish_broadcast_listener(udp);
    }
    std::cout << "Broadcast listener started" << std::endl;
    char buf[256];
    // Checked on every iteration, not just on receive errors, so a stream of
    // unrelated datagrams cannot keep the thread alive after cancellation.
    while (!g_server_state.cancelled.load()) {
        sockaddr_in from{};
        socklen_t fromlen = sizeof(from);  // value-result argument: reset before every call
        ssize_t n = recvfrom(udp, buf, sizeof(buf) - 1, 0, (sockaddr*)&from, &fromlen);
        if (n < 0) continue;  // timeout or error; loop condition re-checks cancellation
        buf[n] = '\0';
        if (strcmp(buf, "DISCOVER_SERVER") == 0) {
            const char* reply = "SERVER_HERE";
            sendto(udp, reply, strlen(reply), 0, (sockaddr*)&from, fromlen);
            std::cout << "Discovery replied." << std::endl;
            break;
        }
    }
    return finish_broadcast_listener(udp);
}
int startSendingThread() {
g_server_state.reset();
g_server_state.setStatus("Waiting for connection...");
pthread_t broadcast_thread;
if (pthread_create(&broadcast_thread, nullptr, broadcast_listener, nullptr) != 0) {
perror("startSendingThread: broadcast thread");
return 1;
}
g_broadcast_thread = broadcast_thread;
pthread_detach(broadcast_thread);
Socket server(socket(AF_INET, SOCK_STREAM, 0));
if (!server.valid()) {
perror("startSendingThread: socket");
cancelServerTransfer();
return 1;
}
int yes = 1;
setsockopt(server, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes));
sockaddr_in addr{};
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = INADDR_ANY;
addr.sin_port = htons(proto::TCP_PORT);
if (bind(server, (sockaddr*)&addr, sizeof(addr)) < 0) {
perror("startSendingThread: bind");
cancelServerTransfer();
return 1;
}
if (listen(server, 3) < 0) {
perror("startSendingThread: listen");
cancelServerTransfer();
return 1;
}
AcceptArgs* acc_args = new AcceptArgs{server.fd};
pthread_t accept_thread;
if (pthread_create(&accept_thread, nullptr, accept_and_handle, acc_args) != 0) {
delete acc_args;
cancelServerTransfer();
return 1;
}
pthread_detach(accept_thread);
server.release(); // accepted by accept_and_handle
return 0;
}
#ifndef __SWITCH__
/// Desktop entry point: runs one receive session and waits until it ends.
/// @return 1 if the session could not be started, otherwise 0.
int main() {
    const bool started = (startSendingThread() == 0);
    if (!started) {
        return 1;
    }
    // Poll about every 16 ms until the accept/receive worker reports done.
    for (;;) {
        if (isServerTransferDone()) {
            break;
        }
        usleep(16000);
    }
    return 0;
}
#endif