fix leaked client socket after receiving file

This commit is contained in:
2026-04-26 17:05:35 +03:00
parent 64b30e9835
commit 844093e3e7
4 changed files with 63 additions and 25 deletions
+1
View File
@@ -6,6 +6,7 @@ bool isClientTransferDone();
bool isClientTransferCancelled(); bool isClientTransferCancelled();
bool isClientConnectionFailed(); bool isClientConnectionFailed();
bool isClientProgressKnown(); bool isClientProgressKnown();
bool isClientWorkersIdle();
void cancelClientTransfer(); void cancelClientTransfer();
double getClientProgress(); double getClientProgress();
std::string getClientStatusText(); std::string getClientStatusText();
+3 -1
View File
@@ -2,6 +2,7 @@
#include "util.hpp" #include "util.hpp"
#include "main.hpp" #include "main.hpp"
#include <server.hpp> #include <server.hpp>
#include <client.hpp>
#include <unistd.h> #include <unistd.h>
static int nxlink_sock = -1; static int nxlink_sock = -1;
@@ -22,7 +23,8 @@ extern "C" void userAppInit() {
extern "C" void userAppExit() { extern "C" void userAppExit() {
cancelServerTransfer(); cancelServerTransfer();
for (int i = 0; i < 150 && !isServerWorkersIdle(); i++) { cancelClientTransfer();
for (int i = 0; i < 150 && (!isServerWorkersIdle() || !isClientWorkersIdle()); i++) {
usleep(10000); usleep(10000);
} }
if (nxlink_sock != -1) { if (nxlink_sock != -1) {
+2 -2
View File
@@ -54,7 +54,7 @@ namespace ui {
PanelX + space::lg, btnY, BtnW, BtnH, color::BgSurface2, radius::md); PanelX + space::lg, btnY, BtnW, BtnH, color::BgSurface2, radius::md);
this->Add(this->btnTransferBg); this->Add(this->btnTransferBg);
this->btnTransferText = pu::ui::elm::TextBlock::New( this->btnTransferText = pu::ui::elm::TextBlock::New(
PanelX + space::lg + space::md, btnY + 14, "Transfer to PC"); PanelX + space::lg + space::md, btnY + 14, "Transfer to another device");
this->btnTransferText->SetFont(type::font(type::Body)); this->btnTransferText->SetFont(type::font(type::Body));
this->btnTransferText->SetColor(color::TextSecondary); this->btnTransferText->SetColor(color::TextSecondary);
this->Add(this->btnTransferText); this->Add(this->btnTransferText);
@@ -64,7 +64,7 @@ namespace ui {
PanelX + space::lg, btnY2, BtnW, BtnH, color::BgSurface2, radius::md); PanelX + space::lg, btnY2, BtnW, BtnH, color::BgSurface2, radius::md);
this->Add(this->btnReceiveBg); this->Add(this->btnReceiveBg);
this->btnReceiveText = pu::ui::elm::TextBlock::New( this->btnReceiveText = pu::ui::elm::TextBlock::New(
PanelX + space::lg + space::md, btnY2 + 14, "Receive from PC"); PanelX + space::lg + space::md, btnY2 + 14, "Receive from another device");
this->btnReceiveText->SetFont(type::font(type::Body)); this->btnReceiveText->SetFont(type::font(type::Body));
this->btnReceiveText->SetColor(color::TextSecondary); this->btnReceiveText->SetColor(color::TextSecondary);
this->Add(this->btnReceiveText); this->Add(this->btnReceiveText);
+57 -22
View File
@@ -20,22 +20,38 @@
#include <protocol.hpp> #include <protocol.hpp>
#include <TransferState.hpp> #include <TransferState.hpp>
#include <net/Socket.hpp>
namespace fs = std::filesystem; namespace fs = std::filesystem;
using path = fs::path; using path = fs::path;
static TransferState g_client_state; static TransferState g_client_state;
static std::atomic<int> g_client_udp_sock{-1};
static std::atomic<int> g_client_tcp_sock{-1};
static std::atomic<bool> g_client_thread_active{false};
bool isClientTransferDone() { return g_client_state.done.load(); } bool isClientTransferDone() { return g_client_state.done.load(); }
bool isClientTransferCancelled() { return g_client_state.cancelled.load(); } bool isClientTransferCancelled() { return g_client_state.cancelled.load(); }
bool isClientConnectionFailed() { return g_client_state.connection_failed.load(); } bool isClientConnectionFailed() { return g_client_state.connection_failed.load(); }
bool isClientProgressKnown() { return g_client_state.bytes_total.load() > 0; } bool isClientProgressKnown() { return g_client_state.bytes_total.load() > 0; }
void cancelClientTransfer() { g_client_state.cancelled.store(true); } bool isClientWorkersIdle() { return !g_client_thread_active.load(); }
double getClientProgress() { return g_client_state.progress(); } double getClientProgress() { return g_client_state.progress(); }
std::string getClientStatusText() { return g_client_state.getStatus(); } std::string getClientStatusText() { return g_client_state.getStatus(); }
std::string getClientFailReason() { return g_client_state.fail_reason; } std::string getClientFailReason() { return g_client_state.fail_reason; }
void cancelClientTransfer() {
g_client_state.cancelled.store(true);
int udp = g_client_udp_sock.exchange(-1);
if (udp >= 0) {
shutdown(udp, SHUT_RDWR);
close(udp);
}
int tcp = g_client_tcp_sock.exchange(-1);
if (tcp >= 0) {
shutdown(tcp, SHUT_RDWR);
close(tcp);
}
}
static bool send_all(int sock, const void* buf, size_t len) { static bool send_all(int sock, const void* buf, size_t len) {
size_t sent = 0; size_t sent = 0;
while (sent < len) { while (sent < len) {
@@ -92,42 +108,55 @@ static void fail_connect(const std::string& reason) {
} }
static void* discovery_and_send_thread(void* arg) { static void* discovery_and_send_thread(void* arg) {
g_client_thread_active.store(true);
ThreadArgs* targs = static_cast<ThreadArgs*>(arg); ThreadArgs* targs = static_cast<ThreadArgs*>(arg);
size_t index = targs->index; size_t index = targs->index;
AccountUid uid = targs->uid; AccountUid uid = targs->uid;
delete targs; delete targs;
auto finish = [](void*) {
g_client_state.done.store(true);
g_client_thread_active.store(false);
return (void*)nullptr;
};
char server_ip[INET_ADDRSTRLEN]; char server_ip[INET_ADDRSTRLEN];
if (find_server(server_ip) != 0) { if (find_server(server_ip) != 0) {
if (!g_client_state.cancelled.load()) if (!g_client_state.cancelled.load())
fail_connect("No receiver found.\nMake sure the other Switch is in Receive mode."); fail_connect("No receiver found.\nMake sure the other Switch is in Receive mode.");
else return finish(nullptr);
g_client_state.done.store(true);
return nullptr;
} }
if (g_client_state.cancelled.load()) { g_client_state.done.store(true); return nullptr; } if (g_client_state.cancelled.load()) return finish(nullptr);
g_client_state.setStatus("Creating backup..."); g_client_state.setStatus("Creating backup...");
auto backupResult = io::backup(index, uid); auto backupResult = io::backup(index, uid);
if (!std::get<0>(backupResult)) { if (!std::get<0>(backupResult)) {
fail_connect("Failed to create backup:\n" + std::get<2>(backupResult)); fail_connect("Failed to create backup:\n" + std::get<2>(backupResult));
return nullptr; return finish(nullptr);
} }
fs::path directory = std::get<2>(backupResult); fs::path directory = std::get<2>(backupResult);
if (g_client_state.cancelled.load()) { g_client_state.done.store(true); return nullptr; } if (g_client_state.cancelled.load()) return finish(nullptr);
g_client_state.setStatus("Connecting..."); g_client_state.setStatus("Connecting...");
Socket tcp(socket(AF_INET, SOCK_STREAM, 0)); int tcp_fd = socket(AF_INET, SOCK_STREAM, 0);
if (!tcp.valid()) { fail_connect("Failed to open socket."); return nullptr; } if (tcp_fd < 0) { fail_connect("Failed to open socket."); return finish(nullptr); }
g_client_tcp_sock.store(tcp_fd);
auto release_tcp = [&]() {
int owned = g_client_tcp_sock.exchange(-1);
if (owned == tcp_fd) close(tcp_fd);
};
sockaddr_in serv{}; sockaddr_in serv{};
serv.sin_family = AF_INET; serv.sin_family = AF_INET;
serv.sin_port = htons(proto::TCP_PORT); serv.sin_port = htons(proto::TCP_PORT);
if (inet_pton(AF_INET, server_ip, &serv.sin_addr) <= 0 || if (inet_pton(AF_INET, server_ip, &serv.sin_addr) <= 0 ||
connect(tcp, (sockaddr*)&serv, sizeof(serv)) < 0) { connect(tcp_fd, (sockaddr*)&serv, sizeof(serv)) < 0) {
fail_connect("Failed to connect to receiver."); if (!g_client_state.cancelled.load())
return nullptr; fail_connect("Failed to connect to receiver.");
release_tcp();
return finish(nullptr);
} }
uint64_t total = 0; uint64_t total = 0;
@@ -141,21 +170,27 @@ static void* discovery_and_send_thread(void* arg) {
const path& p = entry.path(); const path& p = entry.path();
if (fs::is_regular_file(p)) { if (fs::is_regular_file(p)) {
g_client_state.setStatus(p.filename().string()); g_client_state.setStatus(p.filename().string());
if (!sendFile(tcp.fd, p)) break; if (!sendFile(tcp_fd, p)) break;
} }
} }
uint32_t sentinel = proto::EOF_SENTINEL; uint32_t sentinel = proto::EOF_SENTINEL;
send_all(tcp.fd, &sentinel, sizeof(sentinel)); send_all(tcp_fd, &sentinel, sizeof(sentinel));
release_tcp();
g_client_state.setStatus(""); g_client_state.setStatus("");
g_client_state.done.store(true); return finish(nullptr);
return nullptr;
} }
static int find_server(char* server_ip) { static int find_server(char* server_ip) {
int udp_fd = socket(AF_INET, SOCK_DGRAM, 0); int udp_fd = socket(AF_INET, SOCK_DGRAM, 0);
if (udp_fd < 0) return -1; if (udp_fd < 0) return -1;
g_client_udp_sock.store(udp_fd);
auto release_udp = [&]() {
int owned = g_client_udp_sock.exchange(-1);
if (owned == udp_fd) close(udp_fd);
};
sockaddr_in addr{}; sockaddr_in addr{};
addr.sin_family = AF_INET; addr.sin_family = AF_INET;
@@ -164,7 +199,7 @@ static int find_server(char* server_ip) {
const char* msg = "DISCOVER_SERVER"; const char* msg = "DISCOVER_SERVER";
if (sendto(udp_fd, msg, strlen(msg), 0, (sockaddr*)&addr, sizeof(addr)) < 0) { if (sendto(udp_fd, msg, strlen(msg), 0, (sockaddr*)&addr, sizeof(addr)) < 0) {
close(udp_fd); release_udp();
return -1; return -1;
} }
@@ -172,7 +207,7 @@ static int find_server(char* server_ip) {
auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(3); auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(3);
while (std::chrono::steady_clock::now() < deadline) { while (std::chrono::steady_clock::now() < deadline) {
if (g_client_state.cancelled.load()) { if (g_client_state.cancelled.load()) {
close(udp_fd); release_udp();
return -1; return -1;
} }
struct timeval tv{0, 100000}; struct timeval tv{0, 100000};
@@ -188,14 +223,14 @@ static int find_server(char* server_ip) {
buf[n] = '\0'; buf[n] = '\0';
if (strcmp(buf, "SERVER_HERE") == 0) { if (strcmp(buf, "SERVER_HERE") == 0) {
inet_ntop(AF_INET, &from.sin_addr, server_ip, INET_ADDRSTRLEN); inet_ntop(AF_INET, &from.sin_addr, server_ip, INET_ADDRSTRLEN);
close(udp_fd); release_udp();
return 0; return 0;
} }
} }
} }
} }
close(udp_fd); release_udp();
return -1; return -1;
} }