From 844093e3e712082e8e5b6db7cbbb4191bffd6b3d Mon Sep 17 00:00:00 2001 From: Nikolai Fedorov Date: Sun, 26 Apr 2026 17:05:35 +0300 Subject: [PATCH] fix leaked client socket after receiving file --- include/client.hpp | 1 + source/Main.cpp | 4 ++- source/TitlesLayout.cpp | 4 +-- source/client.cpp | 79 +++++++++++++++++++++++++++++------------ 4 files changed, 63 insertions(+), 25 deletions(-) diff --git a/include/client.hpp b/include/client.hpp index c46446c..582b8a6 100644 --- a/include/client.hpp +++ b/include/client.hpp @@ -6,6 +6,7 @@ bool isClientTransferDone(); bool isClientTransferCancelled(); bool isClientConnectionFailed(); bool isClientProgressKnown(); +bool isClientWorkersIdle(); void cancelClientTransfer(); double getClientProgress(); std::string getClientStatusText(); diff --git a/source/Main.cpp b/source/Main.cpp index 3e2a78e..47748d1 100644 --- a/source/Main.cpp +++ b/source/Main.cpp @@ -2,6 +2,7 @@ #include "util.hpp" #include "main.hpp" #include +#include #include static int nxlink_sock = -1; @@ -22,7 +23,8 @@ extern "C" void userAppInit() { extern "C" void userAppExit() { cancelServerTransfer(); - for (int i = 0; i < 150 && !isServerWorkersIdle(); i++) { + cancelClientTransfer(); + for (int i = 0; i < 150 && (!isServerWorkersIdle() || !isClientWorkersIdle()); i++) { usleep(10000); } if (nxlink_sock != -1) { diff --git a/source/TitlesLayout.cpp b/source/TitlesLayout.cpp index 8431cd0..3af2d07 100644 --- a/source/TitlesLayout.cpp +++ b/source/TitlesLayout.cpp @@ -54,7 +54,7 @@ namespace ui { PanelX + space::lg, btnY, BtnW, BtnH, color::BgSurface2, radius::md); this->Add(this->btnTransferBg); this->btnTransferText = pu::ui::elm::TextBlock::New( - PanelX + space::lg + space::md, btnY + 14, "Transfer to PC"); + PanelX + space::lg + space::md, btnY + 14, "Transfer to another device"); this->btnTransferText->SetFont(type::font(type::Body)); this->btnTransferText->SetColor(color::TextSecondary); this->Add(this->btnTransferText); @@ -64,7 +64,7 @@ namespace ui { PanelX + space::lg, btnY2, BtnW, BtnH, color::BgSurface2, radius::md); this->Add(this->btnReceiveBg); this->btnReceiveText = pu::ui::elm::TextBlock::New( - PanelX + space::lg + space::md, btnY2 + 14, "Receive from PC"); + PanelX + space::lg + space::md, btnY2 + 14, "Receive from another device"); this->btnReceiveText->SetFont(type::font(type::Body)); this->btnReceiveText->SetColor(color::TextSecondary); this->Add(this->btnReceiveText); diff --git a/source/client.cpp b/source/client.cpp index 710ca0e..0ea5350 100644 --- a/source/client.cpp +++ b/source/client.cpp @@ -20,22 +20,38 @@ #include #include -#include namespace fs = std::filesystem; using path = fs::path; -static TransferState g_client_state; +static TransferState g_client_state; +static std::atomic g_client_udp_sock{-1}; +static std::atomic g_client_tcp_sock{-1}; +static std::atomic g_client_thread_active{false}; bool isClientTransferDone() { return g_client_state.done.load(); } bool isClientTransferCancelled() { return g_client_state.cancelled.load(); } bool isClientConnectionFailed() { return g_client_state.connection_failed.load(); } bool isClientProgressKnown() { return g_client_state.bytes_total.load() > 0; } -void cancelClientTransfer() { g_client_state.cancelled.store(true); } +bool isClientWorkersIdle() { return !g_client_thread_active.load(); } double getClientProgress() { return g_client_state.progress(); } std::string getClientStatusText() { return g_client_state.getStatus(); } std::string getClientFailReason() { return g_client_state.fail_reason; } +void cancelClientTransfer() { + g_client_state.cancelled.store(true); + int udp = g_client_udp_sock.exchange(-1); + if (udp >= 0) { + shutdown(udp, SHUT_RDWR); + close(udp); + } + int tcp = g_client_tcp_sock.exchange(-1); + if (tcp >= 0) { + shutdown(tcp, SHUT_RDWR); + close(tcp); + } +} + static bool send_all(int sock, const void* buf, size_t len) { size_t sent = 0; while (sent < len) { @@ -92,42 +108,55 @@ static void fail_connect(const std::string& reason) { } static void* discovery_and_send_thread(void* arg) { + g_client_thread_active.store(true); ThreadArgs* targs = static_cast(arg); size_t index = targs->index; AccountUid uid = targs->uid; delete targs; + auto finish = [](void*) { + g_client_state.done.store(true); + g_client_thread_active.store(false); + return (void*)nullptr; + }; + char server_ip[INET_ADDRSTRLEN]; if (find_server(server_ip) != 0) { if (!g_client_state.cancelled.load()) fail_connect("No receiver found.\nMake sure the other Switch is in Receive mode."); - else - g_client_state.done.store(true); - return nullptr; + return finish(nullptr); } - if (g_client_state.cancelled.load()) { g_client_state.done.store(true); return nullptr; } + if (g_client_state.cancelled.load()) return finish(nullptr); g_client_state.setStatus("Creating backup..."); auto backupResult = io::backup(index, uid); if (!std::get<0>(backupResult)) { fail_connect("Failed to create backup:\n" + std::get<2>(backupResult)); - return nullptr; + return finish(nullptr); } fs::path directory = std::get<2>(backupResult); - if (g_client_state.cancelled.load()) { g_client_state.done.store(true); return nullptr; } + if (g_client_state.cancelled.load()) return finish(nullptr); g_client_state.setStatus("Connecting..."); - Socket tcp(socket(AF_INET, SOCK_STREAM, 0)); - if (!tcp.valid()) { fail_connect("Failed to open socket."); return nullptr; } + int tcp_fd = socket(AF_INET, SOCK_STREAM, 0); + if (tcp_fd < 0) { fail_connect("Failed to open socket."); return finish(nullptr); } + g_client_tcp_sock.store(tcp_fd); + + auto release_tcp = [&]() { + int owned = g_client_tcp_sock.exchange(-1); + if (owned == tcp_fd) close(tcp_fd); + }; sockaddr_in serv{}; serv.sin_family = AF_INET; serv.sin_port = htons(proto::TCP_PORT); if (inet_pton(AF_INET, server_ip, &serv.sin_addr) <= 0 || - connect(tcp, (sockaddr*)&serv, sizeof(serv)) < 0) { - fail_connect("Failed to connect to receiver."); - return nullptr; + connect(tcp_fd, (sockaddr*)&serv, sizeof(serv)) < 0) { + if (!g_client_state.cancelled.load()) + fail_connect("Failed to connect to receiver."); + release_tcp(); + return finish(nullptr); } uint64_t total = 0; @@ -141,21 +170,27 @@ static void* discovery_and_send_thread(void* arg) { const path& p = entry.path(); if (fs::is_regular_file(p)) { g_client_state.setStatus(p.filename().string()); - if (!sendFile(tcp.fd, p)) break; + if (!sendFile(tcp_fd, p)) break; } } uint32_t sentinel = proto::EOF_SENTINEL; - send_all(tcp.fd, &sentinel, sizeof(sentinel)); + send_all(tcp_fd, &sentinel, sizeof(sentinel)); + release_tcp(); g_client_state.setStatus(""); - g_client_state.done.store(true); - return nullptr; + return finish(nullptr); } static int find_server(char* server_ip) { int udp_fd = socket(AF_INET, SOCK_DGRAM, 0); if (udp_fd < 0) return -1; + g_client_udp_sock.store(udp_fd); + + auto release_udp = [&]() { + int owned = g_client_udp_sock.exchange(-1); + if (owned == udp_fd) close(udp_fd); + }; sockaddr_in addr{}; addr.sin_family = AF_INET; @@ -164,7 +199,7 @@ static int find_server(char* server_ip) { const char* msg = "DISCOVER_SERVER"; if (sendto(udp_fd, msg, strlen(msg), 0, (sockaddr*)&addr, sizeof(addr)) < 0) { - close(udp_fd); + release_udp(); return -1; } @@ -172,7 +207,7 @@ static int find_server(char* server_ip) { auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(3); while (std::chrono::steady_clock::now() < deadline) { if (g_client_state.cancelled.load()) { - close(udp_fd); + release_udp(); return -1; } struct timeval tv{0, 100000}; @@ -188,14 +223,14 @@ static int find_server(char* server_ip) { buf[n] = '\0'; if (strcmp(buf, "SERVER_HERE") == 0) { inet_ntop(AF_INET, &from.sin_addr, server_ip, INET_ADDRSTRLEN); - close(udp_fd); + release_udp(); return 0; } } } } - close(udp_fd); + release_udp(); return -1; }