diff --git a/CMakeLists.txt b/CMakeLists.txt index 1796f83..40916a2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,6 +111,8 @@ if(BUILD_TESTING) message(STATUS "building tests is enabled") enable_testing() + + # Server + Client tests add_custom_target(tests) message(STATUS "generating protocols") @@ -133,6 +135,29 @@ if(BUILD_TESTING) "tests/generated/test_protocol_v1-server.cpp") target_link_libraries(fork PRIVATE PkgConfig::deps hyprwire) add_dependencies(tests fork) + + # GTests + find_package(GTest CONFIG REQUIRED) + include(GoogleTest) + file(GLOB_RECURSE TESTFILES CONFIGURE_DEPENDS "tests/unit/*.cpp") + add_executable(hyprwire_tests ${TESTFILES}) + + target_compile_options(hyprwire_tests PRIVATE --coverage -fsanitize=address) + target_link_options(hyprwire_tests PRIVATE --coverage) + + target_include_directories( + hyprwire_tests + PUBLIC "./include" + PRIVATE "./src" "./src/include" "./protocols" "${CMAKE_BINARY_DIR}") + target_link_libraries(hyprwire_tests PRIVATE asan hyprwire GTest::gtest_main + PkgConfig::deps) + gtest_discover_tests(hyprwire_tests + PROPERTIES ENVIRONMENT "ASAN_OPTIONS=detect_leaks=0" + ) + + # Add coverage to hyprwire for test builds + target_compile_options(hyprwire PRIVATE --coverage) + target_link_options(hyprwire PRIVATE --coverage) else() message(STATUS "building tests is disabled") endif() diff --git a/include/hyprwire/core/implementation/Types.hpp b/include/hyprwire/core/implementation/Types.hpp index 3489c11..9785799 100644 --- a/include/hyprwire/core/implementation/Types.hpp +++ b/include/hyprwire/core/implementation/Types.hpp @@ -9,8 +9,9 @@ namespace Hyprwire { struct SMethod { uint32_t idx = 0; std::vector params; - std::string returnsType = ""; - uint32_t since = 0; + std::string returnsType = ""; + uint32_t since = 0; + bool isDestructor = false; }; class IProtocolObjectSpec { @@ -26,4 +27,4 @@ namespace Hyprwire { IProtocolObjectSpec() = default; }; -}; \ No newline at end of file +}; diff --git a/nix/default.nix b/nix/default.nix index 377a1fc..9983046 100644 --- a/nix/default.nix +++ b/nix/default.nix @@ -3,6 +3,7 @@ stdenv, cmake, pkg-config, + gtest, hyprutils, libffi, pugixml, @@ -18,7 +19,8 @@ stdenv.mkDerivation { nativeBuildInputs = [ cmake pkg-config - ]; + ] + ++ lib.optionals doCheck [ gtest ]; buildInputs = [ hyprutils diff --git a/scanner/main.cpp b/scanner/main.cpp index c5d0310..4b6166f 100644 --- a/scanner/main.cpp +++ b/scanner/main.cpp @@ -359,8 +359,9 @@ Hyprwire::SMethod{{ .params = {{ {} }}, .returnsType = "{}", .since = {}, +.isDestructor = {}, }},)#", - m.idx, argArrayStr, m.returns, m.since); + m.idx, argArrayStr, m.returns, m.since, m.destructor ? "true" : "false"); } if (!object.c2s.empty()) @@ -391,8 +392,9 @@ Hyprwire::SMethod{{ .idx = {}, .params = {{ {} }}, .since = {}, +.isDestructor = {}, }},)#", - m.idx, argArrayStr, m.since); + m.idx, argArrayStr, m.since, m.destructor ? "true" : "false"); } if (!object.s2c.empty()) diff --git a/src/core/client/ClientObject.cpp b/src/core/client/ClientObject.cpp index 0f8c526..c335580 100644 --- a/src/core/client/ClientObject.cpp +++ b/src/core/client/ClientObject.cpp @@ -18,6 +18,31 @@ CClientObject::CClientObject(SP client) : m_client(client) { } CClientObject::~CClientObject() { + if (!m_destroyed && m_id != 0 && m_spec && m_client && m_client->m_fd.isValid()) { + const auto methods = methodsOut(); + for (const auto& method : methods) { + if (!method.isDestructor) + continue; + + if (method.since > m_version) + continue; + + if (!method.returnsType.empty()) { + Debug::log(WARN, "can't auto-call destructor for object {}: method {} has returns type", m_id, method.idx); + break; + } + + if (!method.params.empty()) { + Debug::log(WARN, "can't auto-call destructor for object {}: method {} has params", m_id, method.idx); + break; + } + + TRACE(Debug::log(TRACE, "auto-calling protocol destructor {} for object {}", method.idx, m_id)); + call(method.idx); + break; + } + } + TRACE(Debug::log(TRACE, "destroying object {}", m_id)); } diff --git a/src/core/client/ClientSocket.cpp b/src/core/client/ClientSocket.cpp index 3c20152..008b625 100644 --- a/src/core/client/ClientSocket.cpp +++ b/src/core/client/ClientSocket.cpp @@ -1,6 +1,7 @@ #include "ClientSocket.hpp" #include "../../helpers/Memory.hpp" #include "../../helpers/Log.hpp" +#include "../../helpers/Syscalls.hpp" #include "../../Macros.hpp" #include "../message/MessageParser.hpp" #include "../message/messages/IMessage.hpp" @@ -30,6 +31,10 @@ using namespace Hyprwire; using namespace Hyprutils::OS; using namespace Hyprutils::Utils; +namespace { + std::chrono::milliseconds g_handshakeMax = std::chrono::milliseconds(5000); +} + SP IClientSocket::open(const std::string& path) { SP sock = makeShared(); sock->m_self = sock; @@ -101,7 +106,13 @@ void CClientSocket::addImplementation(SP&& x) { m_impls.emplace_back(std::move(x)); } -constexpr const size_t HANDSHAKE_MAX_MS = 5000; +void CClientSocket::setHandshakeTimeoutForTests(std::chrono::milliseconds timeout) { + g_handshakeMax = timeout; +} + +void CClientSocket::resetHandshakeTimeoutForTests() { + g_handshakeMax = std::chrono::milliseconds(5000); +} // bool CClientSocket::dispatchEvents(bool block) { @@ -109,10 +120,20 @@ bool CClientSocket::dispatchEvents(bool block) { if (m_error) return false; + collectOrphanedObjects(); + if (!m_handshakeDone) { - const auto MAX_MS = - std::chrono::duration_cast(std::chrono::milliseconds(HANDSHAKE_MAX_MS) - (std::chrono::steady_clock::now() - m_handshakeBegin)).count(); - int ret = poll(m_pollfds.data(), m_pollfds.size(), block ? MAX_MS : 0); + const auto elapsed = std::chrono::steady_clock::now() - m_handshakeBegin; + const auto maxMs = g_handshakeMax; + + if (block && elapsed >= maxMs) { + Debug::log(ERR, "handshake error: timed out"); + disconnectOnError(); + return false; + } + + const auto timeout = block ? std::chrono::duration_cast(maxMs - elapsed).count() : 0; + int ret = Syscalls::poll(m_pollfds.data(), m_pollfds.size(), static_cast(timeout)); if (block && !ret) { Debug::log(ERR, "handshake error: timed out"); disconnectOnError(); @@ -121,13 +142,15 @@ bool CClientSocket::dispatchEvents(bool block) { } if (m_handshakeDone) - poll(m_pollfds.data(), m_pollfds.size(), block ? -1 : 0); + Syscalls::poll(m_pollfds.data(), m_pollfds.size(), block ? -1 : 0); if (m_pollfds[0].revents & POLLHUP) return false; - if (!(m_pollfds[0].revents & POLLIN)) + if (!(m_pollfds[0].revents & POLLIN)) { + collectOrphanedObjects(); return true; + } // dispatch @@ -165,6 +188,8 @@ bool CClientSocket::dispatchEvents(bool block) { return true; }); + collectOrphanedObjects(); + return !m_error; } @@ -203,13 +228,13 @@ void CClientSocket::sendMessage(const IMessage& message) { } while (m_fd.isValid()) { - int ret = sendmsg(m_fd.get(), &msg, 0); + int ret = Syscalls::sendmsg(m_fd.get(), &msg, 0); if (ret < 0 && (errno == EWOULDBLOCK || errno == EAGAIN)) { pollfd pfd = { .fd = m_fd.get(), .events = POLLOUT | POLLWRBAND, }; - poll(&pfd, 1, -1); + Syscalls::poll(&pfd, 1, -1); } else break; } @@ -327,14 +352,38 @@ void CClientSocket::waitForObject(SP x) { } void CClientSocket::onGeneric(const CGenericProtocolMessage& msg) { + SP object; + for (const auto& o : m_objects) { - if (o->m_id == msg.m_object) { - o->called(msg.m_method, msg.m_dataSpan, msg.m_fds); - return; + if (o && o->m_id == msg.m_object) { + object = o; + break; } } - Debug::log(WARN, "[{} @ {:.3f}] -> Generic message not handled. No object with id {}!", m_fd.get(), steadyMillis(), msg.m_object); + if (!object) { + Debug::log(ERR, "[{} @ {:.3f}] -> Generic message references unknown object {}", m_fd.get(), steadyMillis(), msg.m_object); + disconnectOnError(); + return; + } + + object->called(msg.m_method, msg.m_dataSpan, msg.m_fds); +} + +void CClientSocket::destroyObject(uint32_t id) { + std::erase_if(m_objects, [id](const auto& obj) { return obj && obj->m_id == id; }); +} + +void CClientSocket::collectOrphanedObjects() { + std::erase_if(m_objects, [](const auto& obj) { + if (!obj) + return true; + + if (obj->m_id == 0) + return false; + + return obj.strongRef() == 1; + }); } SP CClientSocket::objectForId(uint32_t id) { diff --git a/src/core/client/ClientSocket.hpp b/src/core/client/ClientSocket.hpp index cdfbec0..5a70179 100644 --- a/src/core/client/ClientSocket.hpp +++ b/src/core/client/ClientSocket.hpp @@ -8,6 +8,7 @@ #include #include +#include namespace Hyprwire { class IMessage; @@ -33,11 +34,16 @@ namespace Hyprwire { virtual void roundtrip(); virtual bool isHandshakeDone(); + static void setHandshakeTimeoutForTests(std::chrono::milliseconds timeout); + static void resetHandshakeTimeoutForTests(); + void sendMessage(const IMessage& message); void serverSpecs(const std::vector& s); void recheckPollFds(); void onSeq(uint32_t seq, uint32_t id); void onGeneric(const CGenericProtocolMessage& msg); + void destroyObject(uint32_t id); + void collectOrphanedObjects(); SP makeObject(const std::string& protocolName, const std::string& objectName, uint32_t seq); void waitForObject(SP); @@ -65,4 +71,4 @@ namespace Hyprwire { uint32_t m_lastAckdRoundtripSeq = 0; uint32_t m_lastSentRoundtripSeq = 0; }; -}; \ No newline at end of file +}; diff --git a/src/core/message/messages/FatalProtocolError.cpp b/src/core/message/messages/FatalProtocolError.cpp index e8976f5..82b3f21 100644 --- a/src/core/message/messages/FatalProtocolError.cpp +++ b/src/core/message/messages/FatalProtocolError.cpp @@ -53,12 +53,19 @@ CFatalErrorMessage::CFatalErrorMessage(const std::vector& data, size_t } CFatalErrorMessage::CFatalErrorMessage(SP obj, uint32_t errorId, const std::string_view& msg) { + uint32_t objectId = 0; + if (obj) + objectId = obj->m_id; + + *this = CFatalErrorMessage(objectId, errorId, msg); +} + +CFatalErrorMessage::CFatalErrorMessage(uint32_t objectId, uint32_t errorId, const std::string_view& msg) { m_type = HW_MESSAGE_TYPE_FATAL_PROTOCOL_ERROR; m_data = {HW_MESSAGE_TYPE_FATAL_PROTOCOL_ERROR, HW_MESSAGE_MAGIC_TYPE_UINT, 0, 0, 0, 0, HW_MESSAGE_MAGIC_TYPE_UINT, 0, 0, 0, 0, HW_MESSAGE_MAGIC_TYPE_VARCHAR}; - if (obj) - std::memcpy(&m_data[2], &obj->m_id, sizeof(obj->m_id)); + std::memcpy(&m_data[2], &objectId, sizeof(objectId)); std::memcpy(&m_data[7], &errorId, sizeof(errorId)); m_data.append_range(g_messageParser->encodeVarInt(msg.size())); diff --git a/src/core/message/messages/FatalProtocolError.hpp b/src/core/message/messages/FatalProtocolError.hpp index 3b42002..b996d38 100644 --- a/src/core/message/messages/FatalProtocolError.hpp +++ b/src/core/message/messages/FatalProtocolError.hpp @@ -13,6 +13,7 @@ namespace Hyprwire { public: CFatalErrorMessage(const std::vector& data, size_t offset); CFatalErrorMessage(SP obj, uint32_t errorId, const std::string_view& msg); + CFatalErrorMessage(uint32_t objectId, uint32_t errorId, const std::string_view& msg); virtual ~CFatalErrorMessage() = default; @@ -20,4 +21,4 @@ namespace Hyprwire { uint32_t m_errorId = 0; std::string m_errorMsg; }; -}; \ No newline at end of file +}; diff --git a/src/core/server/ServerClient.cpp b/src/core/server/ServerClient.cpp index 495edcc..27f4ec2 100644 --- a/src/core/server/ServerClient.cpp +++ b/src/core/server/ServerClient.cpp @@ -4,7 +4,9 @@ #include "../message/messages/IMessage.hpp" #include "../message/messages/NewObject.hpp" #include "../message/messages/GenericProtocolMessage.hpp" +#include "../message/messages/FatalProtocolError.hpp" #include "../../helpers/Log.hpp" +#include "../../helpers/Syscalls.hpp" #include "../../Macros.hpp" #include @@ -80,13 +82,13 @@ void CServerClient::sendMessage(const IMessage& message) { } while (m_fd.isValid()) { - int ret = sendmsg(m_fd.get(), &msg, 0); + int ret = Syscalls::sendmsg(m_fd.get(), &msg, 0); if (ret < 0 && (errno == EWOULDBLOCK || errno == EAGAIN)) { pollfd pfd = { .fd = m_fd.get(), .events = POLLOUT | POLLWRBAND, }; - poll(&pfd, 1, -1); + Syscalls::poll(&pfd, 1, -1); } else break; } @@ -143,6 +145,10 @@ SP CServerClient::createObject(const std::string& protocol, const return obj; } +void CServerClient::destroyObject(uint32_t id) { + std::erase_if(m_objects, [id](const auto& obj) { return obj && obj->m_id == id; }); +} + void CServerClient::onBind(SP obj) { for (const auto& p : m_server->m_impls) { if (p->protocol()->specName() != obj->m_protocolName) @@ -162,14 +168,24 @@ void CServerClient::onBind(SP obj) { } void CServerClient::onGeneric(const CGenericProtocolMessage& msg) { + SP object; + for (const auto& o : m_objects) { - if (o->m_id == msg.m_object) { - o->called(msg.m_method, msg.m_dataSpan, msg.m_fds); - return; + if (o && o->m_id == msg.m_object) { + object = o; + break; } } - Debug::log(WARN, "[{} @ {:.3f}] -> Generic message not handled. No object with id {}!", m_fd.get(), steadyMillis(), msg.m_object); + if (!object) { + const auto error = std::format("generic message references unknown object {}", msg.m_object); + Debug::log(ERR, "[{} @ {:.3f}] -> {}", m_fd.get(), steadyMillis(), error); + sendMessage(CFatalErrorMessage(msg.m_object, static_cast(-1), error)); + m_error = true; + return; + } + + object->called(msg.m_method, msg.m_dataSpan, msg.m_fds); } int CServerClient::getPID() { diff --git a/src/core/server/ServerClient.hpp b/src/core/server/ServerClient.hpp index 5ce4cf2..2c84da5 100644 --- a/src/core/server/ServerClient.hpp +++ b/src/core/server/ServerClient.hpp @@ -21,6 +21,7 @@ namespace Hyprwire { void sendMessage(const IMessage& message); SP createObject(const std::string& protocol, const std::string& object, uint32_t version, uint32_t seq); + void destroyObject(uint32_t id); void onBind(SP obj); void onGeneric(const CGenericProtocolMessage& msg); void dispatchFirstPoll(); @@ -40,4 +41,4 @@ namespace Hyprwire { WP m_server; WP m_self; }; -}; \ No newline at end of file +}; diff --git a/src/core/server/ServerSocket.cpp b/src/core/server/ServerSocket.cpp index 4509100..1122aaf 100644 --- a/src/core/server/ServerSocket.cpp +++ b/src/core/server/ServerSocket.cpp @@ -3,6 +3,7 @@ #include "ServerObject.hpp" #include "../../helpers/Memory.hpp" #include "../../helpers/Log.hpp" +#include "../../helpers/Syscalls.hpp" #include "../../Macros.hpp" #include "../message/MessageParser.hpp" #include "../message/messages/FatalProtocolError.hpp" @@ -142,7 +143,7 @@ void CServerSocket::addImplementation(SP&& x) { } bool CServerSocket::dispatchPending() { - poll(m_pollfds.data(), m_pollfds.size(), 0); + Syscalls::poll(m_pollfds.data(), m_pollfds.size(), 0); if (dispatchNewConnections()) return dispatchPending(); @@ -162,7 +163,7 @@ bool CServerSocket::dispatchEvents(bool block) { clearWakeupFd(); if (block) { - poll(m_pollfds.data(), m_pollfds.size(), -1); + Syscalls::poll(m_pollfds.data(), m_pollfds.size(), -1); while (dispatchPending()) { ; } @@ -186,7 +187,7 @@ void CServerSocket::clearFd(const Hyprutils::OS::CFileDescriptor& fd) { }; while (fd.isValid()) { - poll(&pfd, 1, 0); + Syscalls::poll(&pfd, 1, 0); if (pfd.revents & POLLIN) { sc(read(fd.get(), buf, 127)); @@ -302,8 +303,10 @@ bool CServerSocket::dispatchExistingConnections() { continue; } - if (m_clients.at(i - internalFds())->m_error) + if (m_clients.at(i - internalFds())->m_error) { + needsPollRecheck = true; TRACE(Debug::log(TRACE, "[{} @ {:.3f}] Dropping client (protocol error)", m_clients.at(i - internalFds())->m_fd.get(), steadyMillis())); + } } if (needsPollRecheck) { @@ -400,7 +403,7 @@ int CServerSocket::extractLoopFD() { m_pollmtx.unlock(); - poll(pollfds.data(), pollfds.size(), -1); + Syscalls::poll(pollfds.data(), pollfds.size(), -1); if (!m_threadCanPoll) return; diff --git a/src/core/socket/SocketHelpers.cpp b/src/core/socket/SocketHelpers.cpp index bf75385..2d98611 100644 --- a/src/core/socket/SocketHelpers.cpp +++ b/src/core/socket/SocketHelpers.cpp @@ -2,6 +2,7 @@ #include "../../helpers/Log.hpp" #include "../../helpers/Memory.hpp" +#include "../../helpers/Syscalls.hpp" #include "../../Macros.hpp" #include @@ -38,7 +39,7 @@ SSocketRawParsedMessage Hyprwire::parseFromFd(const Hyprutils::OS::CFileDescript msg.msg_control = controlBuf.data(); msg.msg_controllen = controlBuf.size(); - sizeWritten = recvmsg(fd.get(), &msg, 0); + sizeWritten = Syscalls::recvmsg(fd.get(), &msg, 0); if (sizeWritten < 0) return {.bad = true}; diff --git a/src/core/wireObject/IWireObject.cpp b/src/core/wireObject/IWireObject.cpp index 81c7eca..264b831 100644 --- a/src/core/wireObject/IWireObject.cpp +++ b/src/core/wireObject/IWireObject.cpp @@ -4,6 +4,7 @@ #include "../../helpers/Log.hpp" #include "../../helpers/FFI.hpp" #include "../client/ClientObject.hpp" +#include "../server/ServerObject.hpp" #include "../message/MessageType.hpp" #include "../message/MessageParser.hpp" #include "../message/MessageMagic.hpp" @@ -19,6 +20,30 @@ using namespace Hyprwire; using namespace Hyprutils::Utils; +namespace { + void destroyIfNeeded(Hyprwire::IWireObject* obj, const Hyprwire::SMethod& method) { + if (!method.isDestructor) + return; + + obj->m_destroyed = true; + + const auto id = obj->m_id; + if (id == 0) + return; + + if (obj->server()) { + auto serverObj = reinterpret_cast(obj); + if (serverObj->m_client) + serverObj->m_client->destroyObject(id); + return; + } + + auto clientObj = reinterpret_cast(obj); + if (clientObj->m_client) + clientObj->m_client->destroyObject(id); + } +} + IWireObject::~IWireObject() = default; uint32_t IWireObject::call(uint32_t id, ...) { @@ -50,6 +75,9 @@ uint32_t IWireObject::call(uint32_t id, ...) { return 0; } + if (method.isDestructor) + m_destroyed = true; + // encode the message std::vector data; std::vector fds; @@ -224,10 +252,13 @@ void IWireObject::called(uint32_t id, const std::span& data, cons return; } - if (m_listeners.size() <= id || m_listeners.at(id) == nullptr) + const auto& method = METHODS.at(id); + + if (m_listeners.size() <= id || m_listeners.at(id) == nullptr) { + destroyIfNeeded(this, method); return; + } - const auto& method = METHODS.at(id); std::vector params; if (!method.returnsType.empty()) @@ -511,4 +542,6 @@ void IWireObject::called(uint32_t id, const std::span& data, cons auto fptr = reinterpret_cast(m_listeners.at(id)); ffi_call(&cif, fptr, nullptr, avalues.data()); + + destroyIfNeeded(this, method); } diff --git a/src/core/wireObject/IWireObject.hpp b/src/core/wireObject/IWireObject.hpp index 2dde5d4..9d32966 100644 --- a/src/core/wireObject/IWireObject.hpp +++ b/src/core/wireObject/IWireObject.hpp @@ -26,6 +26,7 @@ namespace Hyprwire { std::vector m_listeners; uint32_t m_id = 0, m_version = 0, m_seq = 1; + bool m_destroyed = false; std::string m_protocolName; SP m_spec; @@ -34,4 +35,4 @@ namespace Hyprwire { protected: IWireObject() = default; }; -}; \ No newline at end of file +}; diff --git a/src/helpers/Env.cpp b/src/helpers/Env.cpp index 9455384..88bc184 100644 --- a/src/helpers/Env.cpp +++ b/src/helpers/Env.cpp @@ -6,6 +6,11 @@ using namespace Hyprwire; using namespace Hyprwire::Env; +namespace { + bool g_traceCached = false; + bool g_trace = false; +} + bool Hyprwire::Env::envEnabled(const std::string& env) { auto ret = getenv(env.c_str()); if (!ret) @@ -17,6 +22,15 @@ bool Hyprwire::Env::envEnabled(const std::string& env) { } bool Hyprwire::Env::isTrace() { - static bool TRACE = envEnabled("HW_TRACE"); - return TRACE; -} \ No newline at end of file + if (!g_traceCached) { + g_trace = envEnabled("HW_TRACE"); + g_traceCached = true; + } + + return g_trace; +} + +void Hyprwire::Env::resetTraceCache() { + g_traceCached = false; + g_trace = false; +} diff --git a/src/helpers/Env.hpp b/src/helpers/Env.hpp index 4870ec9..d8a60c3 100644 --- a/src/helpers/Env.hpp +++ b/src/helpers/Env.hpp @@ -5,4 +5,5 @@ namespace Hyprwire::Env { bool envEnabled(const std::string& env); bool isTrace(); + void resetTraceCache(); } diff --git a/src/helpers/Syscalls.cpp b/src/helpers/Syscalls.cpp new file mode 100644 index 0000000..bc09659 --- /dev/null +++ b/src/helpers/Syscalls.cpp @@ -0,0 +1,38 @@ +#include "Syscalls.hpp" + +#include + +using namespace Hyprwire; + +namespace { + Hyprwire::Syscalls::SHooks g_hooks; +} + +int Hyprwire::Syscalls::poll(pollfd* fds, nfds_t nfds, int timeout) { + if (g_hooks.poll) + return g_hooks.poll(fds, nfds, timeout); + + return ::poll(fds, nfds, timeout); +} + +ssize_t Hyprwire::Syscalls::sendmsg(int sockfd, const msghdr* msg, int flags) { + if (g_hooks.sendmsg) + return g_hooks.sendmsg(sockfd, msg, flags); + + return ::sendmsg(sockfd, msg, flags); +} + +ssize_t Hyprwire::Syscalls::recvmsg(int sockfd, msghdr* msg, int flags) { + if (g_hooks.recvmsg) + return g_hooks.recvmsg(sockfd, msg, flags); + + return ::recvmsg(sockfd, msg, flags); +} + +void Hyprwire::Syscalls::setHooks(const SHooks& hooks) { + g_hooks = hooks; +} + +void Hyprwire::Syscalls::resetHooks() { + g_hooks = {}; +} diff --git a/src/helpers/Syscalls.hpp b/src/helpers/Syscalls.hpp new file mode 100644 index 0000000..59b5c18 --- /dev/null +++ b/src/helpers/Syscalls.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +namespace Hyprwire::Syscalls { + using PFN_poll = int (*)(pollfd* fds, nfds_t nfds, int timeout); + using PFN_sendmsg = ssize_t (*)(int sockfd, const msghdr* msg, int flags); + using PFN_recvmsg = ssize_t (*)(int sockfd, msghdr* msg, int flags); + + struct SHooks { + PFN_poll poll = nullptr; + PFN_sendmsg sendmsg = nullptr; + PFN_recvmsg recvmsg = nullptr; + }; + + int poll(pollfd* fds, nfds_t nfds, int timeout); + ssize_t sendmsg(int sockfd, const msghdr* msg, int flags); + ssize_t recvmsg(int sockfd, msghdr* msg, int flags); + + void setHooks(const SHooks& hooks); + void resetHooks(); +}; diff --git a/tests/Client.cpp b/tests/Client.cpp index a29563d..4092365 100644 --- a/tests/Client.cpp +++ b/tests/Client.cpp @@ -51,10 +51,10 @@ int main(int argc, char** argv, char** envp) { int pips2[2]; int pips3[2]; - + sc(pipe(pips2)); sc(pipe(pips3)); - + sc(write(pips2[1], "o kurwa", 7)); sc(write(pips3[1], "bober!!", 7)); diff --git a/tests/Server.cpp b/tests/Server.cpp index 36cdaa7..171835d 100644 --- a/tests/Server.cpp +++ b/tests/Server.cpp @@ -26,7 +26,7 @@ static SP spec = makeShared(1, [] }); manager->setSendMessageArrayFd([](const std::vector& fds) { std::println("Received {} fds", fds.size()); - + for (int fd : fds) { char msgbuf[8] = {0}; sc(read(fd, msgbuf, 7)); diff --git a/tests/unit/FFIEnv.cpp b/tests/unit/FFIEnv.cpp new file mode 100644 index 0000000..06f2c58 --- /dev/null +++ b/tests/unit/FFIEnv.cpp @@ -0,0 +1,70 @@ +#include + +#include "helpers/Env.hpp" +#include "helpers/FFI.hpp" + +#include +#include + +#include + +using namespace Hyprwire; + +TEST(FFI, MapsKnownMagicTypesToExpectedFfiTypes) { + EXPECT_EQ(FFI::ffiTypeFrom(HW_MESSAGE_MAGIC_TYPE_UINT), &ffi_type_uint32); + EXPECT_EQ(FFI::ffiTypeFrom(HW_MESSAGE_MAGIC_TYPE_OBJECT), &ffi_type_uint32); + EXPECT_EQ(FFI::ffiTypeFrom(HW_MESSAGE_MAGIC_TYPE_SEQ), &ffi_type_uint32); + + EXPECT_EQ(FFI::ffiTypeFrom(HW_MESSAGE_MAGIC_TYPE_INT), &ffi_type_sint32); + EXPECT_EQ(FFI::ffiTypeFrom(HW_MESSAGE_MAGIC_TYPE_FD), &ffi_type_sint32); + + EXPECT_EQ(FFI::ffiTypeFrom(HW_MESSAGE_MAGIC_TYPE_F32), &ffi_type_float); + + EXPECT_EQ(FFI::ffiTypeFrom(HW_MESSAGE_MAGIC_TYPE_VARCHAR), &ffi_type_pointer); + EXPECT_EQ(FFI::ffiTypeFrom(HW_MESSAGE_MAGIC_TYPE_ARRAY), &ffi_type_pointer); +} + +TEST(FFI, UnknownMagicReturnsNull) { + EXPECT_EQ(FFI::ffiTypeFrom(static_cast(0xFF)), nullptr); +} + +TEST(Env, EnvEnabledFollowsVariableContents) { + constexpr const char* name = "HW_TEST_ENV_ENABLED"; + + unsetenv(name); + EXPECT_FALSE(Env::envEnabled(name)); + + setenv(name, "0", 1); + EXPECT_FALSE(Env::envEnabled(name)); + + setenv(name, "1", 1); + EXPECT_TRUE(Env::envEnabled(name)); + + setenv(name, "hello", 1); + EXPECT_TRUE(Env::envEnabled(name)); + + unsetenv(name); +} + +TEST(Env, TraceCacheCanBeResetForDeterministicTests) { + constexpr const char* traceName = "HW_TRACE"; + + unsetenv(traceName); + Env::resetTraceCache(); + EXPECT_FALSE(Env::isTrace()); + + setenv(traceName, "1", 1); + EXPECT_FALSE(Env::isTrace()) << "isTrace should stay cached until reset"; + + Env::resetTraceCache(); + EXPECT_TRUE(Env::isTrace()); + + setenv(traceName, "0", 1); + EXPECT_TRUE(Env::isTrace()) << "isTrace should stay cached until reset"; + + Env::resetTraceCache(); + EXPECT_FALSE(Env::isTrace()); + + unsetenv(traceName); + Env::resetTraceCache(); +} diff --git a/tests/unit/IWireObjectMatrix.cpp b/tests/unit/IWireObjectMatrix.cpp new file mode 100644 index 0000000..e14ae21 --- /dev/null +++ b/tests/unit/IWireObjectMatrix.cpp @@ -0,0 +1,290 @@ +#include + +#include "core/message/MessageParser.hpp" +#include "core/message/MessageType.hpp" +#include "core/message/messages/IMessage.hpp" +#include "core/wireObject/IWireObject.hpp" + +#include + +#include +#include +#include +#include + +using namespace Hyprwire; + +namespace { + + template + void* fnToVoid(Fn fn) { + static_assert(std::is_pointer_v); + static_assert(sizeof(Fn) == sizeof(void*)); + + union { + Fn f; + void* p; + } caster = { + .f = fn, + }; + + return caster.p; + } + + class CTestWireObject final : public IWireObject { + public: + explicit CTestWireObject(bool isServer) : m_server(isServer) { + ; + } + + const std::vector& methodsOut() override { + return m_methodsOut; + } + + const std::vector& methodsIn() override { + return m_methodsIn; + } + + void errd() override { + m_errd = true; + } + + void sendMessage(const IMessage& msg) override { + m_lastSentData = msg.m_data; + m_lastSentFds = msg.fds(); + } + + bool server() override { + return m_server; + } + + SP self() override { + return m_self.lock(); + } + + SP client() override { + return nullptr; + } + + void error(uint32_t id, const std::string_view& message) override { + m_lastErrorId = id; + m_lastErrorMsg = std::string{message}; + } + + public: + bool m_server = false; + bool m_errd = false; + + uint32_t m_lastErrorId = 0; + std::string m_lastErrorMsg; + + std::vector m_lastSentData; + std::vector m_lastSentFds; + + std::vector m_methodsOut; + std::vector m_methodsIn; + }; + + SP makeObject(bool isServer = false) { + auto obj = makeShared(isServer); + obj->m_self = reinterpretPointerCast(obj); + return obj; + } + + int g_destructorListenerCalls = 0; + + void onNoop(IObject*) { + ; + } + + void onDestructor(IObject*) { + ++g_destructorListenerCalls; + } + +} // namespace + +TEST(IWireObjectMatrix, CallRejectsInvalidMethodIndex) { + auto obj = makeObject(); + obj->m_id = 77; + + EXPECT_EQ(obj->call(0), 0u); + EXPECT_NE(obj->m_lastErrorMsg.find("invalid method"), std::string::npos); +} + +TEST(IWireObjectMatrix, CallRejectsMethodSinceNewerThanObjectVersion) { + auto obj = makeObject(); + obj->m_id = 55; + obj->m_version = 1; + obj->m_methodsOut = { + SMethod{.idx = 0, .params = {}, .returnsType = "", .since = 3}, + }; + + EXPECT_EQ(obj->call(0), 0u); + EXPECT_NE(obj->m_lastErrorMsg.find("since"), std::string::npos); +} + +TEST(IWireObjectMatrix, ServerSideCallRejectsReturnsTypeMethods) { + auto obj = makeObject(true); + obj->m_id = 5; + obj->m_methodsOut = { + SMethod{.idx = 0, .params = {}, .returnsType = "child", .since = 0}, + }; + + EXPECT_EQ(obj->call(0), 0u); + EXPECT_NE(obj->m_lastErrorMsg.find("server cannot call returnsType methods"), std::string::npos); +} + +TEST(IWireObjectMatrix, CallFailsForUnsupportedArrayElementType) { + auto obj = makeObject(); + obj->m_id = 12; + obj->m_methodsOut = { + SMethod{.idx = 0, .params = {HW_MESSAGE_MAGIC_TYPE_ARRAY, HW_MESSAGE_MAGIC_TYPE_OBJECT_ID}, .returnsType = "", .since = 0}, + }; + + uint32_t dummy = 1; + EXPECT_EQ(obj->call(0, &dummy, static_cast(1)), 0u); + EXPECT_TRUE(obj->m_errd); +} + +TEST(IWireObjectMatrix, CallMarksObjectDestroyedWhenDestructorMethodIsCalled) { + auto obj = makeObject(); + obj->m_id = 44; + obj->m_methodsOut = { + SMethod{.idx = 0, .params = {}, .returnsType = "", .since = 0, .isDestructor = true}, + }; + + EXPECT_FALSE(obj->m_destroyed); + EXPECT_EQ(obj->call(0), 0u); + EXPECT_TRUE(obj->m_destroyed); + ASSERT_FALSE(obj->m_lastSentData.empty()); + EXPECT_EQ(obj->m_lastSentData[0], HW_MESSAGE_TYPE_GENERIC_PROTOCOL_MESSAGE); +} + +TEST(IWireObjectMatrix, CalledRejectsInvalidMethodIndex) { + auto obj = makeObject(); + obj->m_id = 91; + + const std::array data = {HW_MESSAGE_MAGIC_END}; + obj->called(0, std::span{data.data(), data.size()}, {}); + + EXPECT_NE(obj->m_lastErrorMsg.find("invalid method"), std::string::npos); +} + +TEST(IWireObjectMatrix, CalledRejectsMethodSinceNewerThanObjectVersion) { + auto obj = makeObject(); + obj->m_id = 77; + obj->m_version = 1; + obj->m_methodsIn = { + SMethod{.idx = 0, .params = {}, .returnsType = "", .since = 2}, + }; + obj->listen(0, fnToVoid(&onNoop)); + + const std::array data = {HW_MESSAGE_MAGIC_END}; + obj->called(0, std::span{data.data(), data.size()}, {}); + + EXPECT_NE(obj->m_lastErrorMsg.find("since"), std::string::npos); +} + +TEST(IWireObjectMatrix, CalledRejectsTypeMismatchBetweenSpecAndWire) { + auto obj = makeObject(); + obj->m_id = 73; + obj->m_methodsIn = { + SMethod{.idx = 0, .params = {HW_MESSAGE_MAGIC_TYPE_UINT}, .returnsType = "", .since = 0}, + }; + obj->listen(0, fnToVoid(&onNoop)); + + const std::array badData = { + HW_MESSAGE_MAGIC_TYPE_INT, 0, 0, 0, 0, HW_MESSAGE_MAGIC_END, + }; + + obj->called(0, std::span{badData.data(), badData.size()}, {}); + + EXPECT_NE(obj->m_lastErrorMsg.find("should be"), std::string::npos); +} + +TEST(IWireObjectMatrix, CalledRejectsArrayWireTypeMismatch) { + auto obj = makeObject(); + obj->m_id = 12; + obj->m_methodsIn = { + SMethod{.idx = 0, .params = {HW_MESSAGE_MAGIC_TYPE_ARRAY, HW_MESSAGE_MAGIC_TYPE_UINT}, .returnsType = "", .since = 0}, + }; + obj->listen(0, fnToVoid(&onNoop)); + + const std::array badData = { + HW_MESSAGE_MAGIC_TYPE_ARRAY, + HW_MESSAGE_MAGIC_TYPE_INT, + 0x00, + HW_MESSAGE_MAGIC_END, + }; + + obj->called(0, std::span{badData.data(), badData.size()}, {}); + + EXPECT_NE(obj->m_lastErrorMsg.find("should be"), std::string::npos); +} + +TEST(IWireObjectMatrix, CalledRejectsOversizedArrayPayload) { + auto obj = makeObject(); + obj->m_id = 88; + obj->m_methodsIn = { + SMethod{.idx = 0, .params = {HW_MESSAGE_MAGIC_TYPE_ARRAY, HW_MESSAGE_MAGIC_TYPE_UINT}, .returnsType = "", .since = 0}, + }; + obj->listen(0, fnToVoid(&onNoop)); + + CMessageParser parser; + auto lenVarInt = parser.encodeVarInt(10001); + + std::vector data = { + HW_MESSAGE_MAGIC_TYPE_ARRAY, + HW_MESSAGE_MAGIC_TYPE_UINT, + }; + data.insert(data.end(), lenVarInt.begin(), lenVarInt.end()); + data.push_back(HW_MESSAGE_MAGIC_END); + + obj->called(0, std::span{data.data(), data.size()}, {}); + + EXPECT_NE(obj->m_lastErrorMsg.find("max array size"), std::string::npos); +} + +TEST(IWireObjectMatrix, CalledRejectsObjectIdMagicType) { + auto obj = makeObject(); + obj->m_id = 19; + obj->m_methodsIn = { + SMethod{.idx = 0, .params = {HW_MESSAGE_MAGIC_TYPE_OBJECT_ID}, .returnsType = "", .since = 0}, + }; + obj->listen(0, fnToVoid(&onNoop)); + + const std::array data = {HW_MESSAGE_MAGIC_TYPE_OBJECT_ID}; + obj->called(0, std::span{data.data(), data.size()}, {}); + + EXPECT_NE(obj->m_lastErrorMsg.find("object type is not impld"), std::string::npos); +} + +TEST(IWireObjectMatrix, CalledMarksDestroyedForDestructorWithoutListener) { + auto obj = makeObject(); + obj->m_id = 0; // avoid concrete cast path + obj->m_methodsIn = { + SMethod{.idx = 0, .params = {}, .returnsType = "", .since = 0, .isDestructor = true}, + }; + + const std::array data = {HW_MESSAGE_MAGIC_END}; + obj->called(0, std::span{data.data(), data.size()}, {}); + + EXPECT_TRUE(obj->m_destroyed); +} + +TEST(IWireObjectMatrix, CalledMarksDestroyedForDestructorWithListener) { + auto obj = makeObject(); + obj->m_id = 0; // avoid concrete cast path + obj->m_methodsIn = { + SMethod{.idx = 0, .params = {}, .returnsType = "", .since = 0, .isDestructor = true}, + }; + obj->listen(0, fnToVoid(&onDestructor)); + + g_destructorListenerCalls = 0; + + const std::array data = {HW_MESSAGE_MAGIC_END}; + obj->called(0, std::span{data.data(), data.size()}, {}); + + EXPECT_EQ(g_destructorListenerCalls, 1); + EXPECT_TRUE(obj->m_destroyed); +} diff --git a/tests/unit/IntegrationCore.cpp b/tests/unit/IntegrationCore.cpp new file mode 100644 index 0000000..f8d2157 --- /dev/null +++ b/tests/unit/IntegrationCore.cpp @@ -0,0 +1,503 @@ +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace Hyprwire; +using namespace Hyprutils::Memory; + +namespace { + + template + using SP = CSharedPointer; + + template + void* fnToVoid(Fn fn) { + static_assert(std::is_pointer_v); + static_assert(sizeof(Fn) == sizeof(void*)); + + union { + Fn f; + void* p; + } caster = { + .f = fn, + }; + + return caster.p; + } + + class CManagerSpec final : public IProtocolObjectSpec { + public: + std::string objectName() override { + return "manager"; + } + + const std::vector& c2s() override { + static const std::vector methods = { + SMethod{.idx = 0, .params = {HW_MESSAGE_MAGIC_TYPE_VARCHAR}, .returnsType = "", .since = 0}, + SMethod{.idx = 1, .params = {}, .returnsType = "child", .since = 0}, + SMethod{.idx = 2, .params = {HW_MESSAGE_MAGIC_TYPE_FD}, .returnsType = "", .since = 0}, + SMethod{.idx = 3, .params = {HW_MESSAGE_MAGIC_TYPE_ARRAY, HW_MESSAGE_MAGIC_TYPE_UINT}, .returnsType = "", .since = 0}, + }; + + return methods; + } + + const std::vector& s2c() override { + static const std::vector methods = { + SMethod{.idx = 0, .params = {HW_MESSAGE_MAGIC_TYPE_VARCHAR}, .returnsType = "", .since = 0}, + }; + + return methods; + } + }; + + class CChildSpec final : public IProtocolObjectSpec { + public: + std::string objectName() override { + return "child"; + } + + const std::vector& c2s() override { + static const std::vector methods = { + SMethod{.idx = 0, .params = {HW_MESSAGE_MAGIC_TYPE_UINT}, .returnsType = "", .since = 0}, + }; + + return methods; + } + + const std::vector& s2c() override { + static const std::vector methods = { + SMethod{.idx = 0, .params = {HW_MESSAGE_MAGIC_TYPE_UINT}, .returnsType = "", .since = 0}, + }; + + return methods; + } + }; + + class CIntegrationProtocolSpec final : public IProtocolSpec { + public: + std::string specName() override { + return "integration_protocol"; + } + + uint32_t specVer() override { + return 1; + } + + std::vector> objects() override { + return { + m_managerSpec, + m_childSpec, + }; + } + + private: + SP m_managerSpec = makeShared(); + SP m_childSpec = makeShared(); + }; + + class CIntegrationClientImpl final : public IProtocolClientImplementation { + public: + explicit CIntegrationClientImpl(SP protocolSpec) : m_protocolSpec(std::move(protocolSpec)) { + ; + } + + SP protocol() override { + return m_protocolSpec; + } + + std::vector> implementation() override { + return { + makeShared(SClientObjectImplementation{.objectName = "manager", .version = 1}), + makeShared(SClientObjectImplementation{.objectName = "child", .version = 1}), + }; + } + + private: + SP m_protocolSpec; + }; + + class CIntegrationServerImpl final : public IProtocolServerImplementation { + public: + CIntegrationServerImpl(SP protocolSpec, std::function)>&& bindFn) : m_protocolSpec(std::move(protocolSpec)), m_bindFn(std::move(bindFn)) { + ; + } + + SP protocol() override { + return m_protocolSpec; + } + + std::vector> implementation() override { + return { + makeShared(SServerObjectImplementation{.objectName = "manager", .version = 1, .onBind = m_bindFn}), + makeShared(SServerObjectImplementation{.objectName = "child", .version = 1}), + }; + } + + private: + SP m_protocolSpec; + std::function)> m_bindFn; + }; + + class CIntegrationHarness { + public: + CIntegrationHarness() { + int fds[2] = {-1, -1}; + if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) + throw std::runtime_error("socketpair failed"); + + m_server = IServerSocket::open(); + if (!m_server) + throw std::runtime_error("server open failed"); + + m_protocolSpec = makeShared(); + m_serverImpl = makeShared(m_protocolSpec, [this](SP obj) { onManagerBind(std::move(obj)); }); + m_clientImpl = makeShared(m_protocolSpec); + + m_server->addImplementation(std::move(m_serverImpl)); + m_serverClient = m_server->addClient(fds[0]); + if (!m_serverClient) + throw std::runtime_error("server addClient failed"); + + m_client = IClientSocket::open(fds[1]); + if (!m_client) + throw std::runtime_error("client open failed"); + + m_client->addImplementation(std::move(m_clientImpl)); + + m_pumpThread = std::thread([this] { + while (!m_stopPump) { + if (m_server) + m_server->dispatchEvents(false); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + }); + } + + ~CIntegrationHarness() { + m_stopPump = true; + if (m_pumpThread.joinable()) + m_pumpThread.join(); + } + + bool waitForHandshake() { + return m_client->waitForHandshake(); + } + + SP bindManager() { + auto manager = m_client->bindProtocol(m_protocolSpec, 1); + if (!manager) + return nullptr; + + manager->setData(this); + manager->listen(0, fnToVoid(&CIntegrationHarness::onManagerNotify)); + + { + std::scoped_lock lock(m_stateMutex); + m_managerClientObject = manager; + } + + return manager; + } + + bool pumpClientUntil(const std::function& pred, std::chrono::milliseconds timeout = std::chrono::milliseconds(1500)) { + const auto start = std::chrono::steady_clock::now(); + while (std::chrono::steady_clock::now() - start < timeout) { + if (!m_client->dispatchEvents(false) && pred()) + return true; + + if (pred()) + return true; + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + return pred(); + } + + SP childClientObjectForSeq(uint32_t seq) { + return m_client->objectForSeq(seq); + } + + void attachChildClientListener(const SP& child) { + child->setData(this); + child->listen(0, fnToVoid(&CIntegrationHarness::onChildNotify)); + } + + void roundtrip() { + m_client->roundtrip(); + } + + bool dispatchClient() { + return m_client->dispatchEvents(false); + } + + int clientFD() { + return m_client->extractLoopFD(); + } + + SP client() { + return m_client; + } + + SP protocolSpec() { + return m_protocolSpec; + } + + std::vector serverPings() { + std::scoped_lock lock(m_stateMutex); + return m_serverPings; + } + + std::vector clientNotifications() { + std::scoped_lock lock(m_stateMutex); + return m_clientNotifications; + } + + std::vector serverArrayPayload() { + std::scoped_lock lock(m_stateMutex); + return m_serverArrayPayload; + } + + std::string serverFdPayload() { + std::scoped_lock lock(m_stateMutex); + return m_serverFdPayload; + } + + uint32_t childPingValue() { + std::scoped_lock lock(m_stateMutex); + return m_childPingValue; + } + + uint32_t childNotifyValue() { + std::scoped_lock lock(m_stateMutex); + return m_childNotifyValue; + } + + private: + static void onManagerPing(IObject* object, const char* msg) { + auto* self = static_cast(object->getData()); + { + std::scoped_lock lock(self->m_stateMutex); + self->m_serverPings.emplace_back(msg ? msg : ""); + } + + object->call(0, "pong"); + } + + static void onManagerMakeChild(IObject* object, uint32_t seq) { + auto* self = static_cast(object->getData()); + + auto child = object->serverSock()->createObject(object->client(), object->self(), "child", seq); + if (!child) + return; + + child->setData(self); + child->listen(0, fnToVoid(&CIntegrationHarness::onChildPing)); + } + + static void onManagerSendFd(IObject* object, int32_t fd) { + auto* self = static_cast(object->getData()); + + char buf[32] = {0}; + int n = static_cast(read(fd, buf, sizeof(buf) - 1)); + + std::scoped_lock lock(self->m_stateMutex); + self->m_serverFdPayload = n > 0 ? std::string{buf, static_cast(n)} : ""; + } + + static void onManagerSendArray(IObject* object, uint32_t* data, uint32_t len) { + auto* self = static_cast(object->getData()); + + std::vector payload; + payload.reserve(len); + for (uint32_t i = 0; i < len; ++i) { + payload.emplace_back(data[i]); + } + + std::scoped_lock lock(self->m_stateMutex); + self->m_serverArrayPayload = std::move(payload); + } + + static void onChildPing(IObject* object, uint32_t value) { + auto* self = static_cast(object->getData()); + + { + std::scoped_lock lock(self->m_stateMutex); + self->m_childPingValue = value; + } + + object->call(0, value + 1); + } + + static void onManagerNotify(IObject* object, const char* msg) { + auto* self = static_cast(object->getData()); + std::scoped_lock lock(self->m_stateMutex); + self->m_clientNotifications.emplace_back(msg ? msg : ""); + } + + static void onChildNotify(IObject* object, uint32_t value) { + auto* self = static_cast(object->getData()); + std::scoped_lock lock(self->m_stateMutex); + self->m_childNotifyValue = value; + } + + void onManagerBind(SP obj) { + obj->setData(this); + obj->listen(0, fnToVoid(&CIntegrationHarness::onManagerPing)); + obj->listen(1, fnToVoid(&CIntegrationHarness::onManagerMakeChild)); + obj->listen(2, fnToVoid(&CIntegrationHarness::onManagerSendFd)); + obj->listen(3, fnToVoid(&CIntegrationHarness::onManagerSendArray)); + } + + private: + std::atomic m_stopPump = false; + std::thread m_pumpThread; + + SP m_protocolSpec; + SP m_serverImpl; + SP m_clientImpl; + + SP m_server; + SP m_serverClient; + SP m_client; + + SP m_managerClientObject; + + std::mutex m_stateMutex; + std::vector m_serverPings; + std::vector m_clientNotifications; + std::vector m_serverArrayPayload; + std::string m_serverFdPayload; + uint32_t m_childPingValue = 0; + uint32_t m_childNotifyValue = 0; + }; + +} // namespace + +TEST(IntegrationCore, AnonymousServerCanAddAndRemoveClients) { + int fds[2] = {-1, -1}; + ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0); + + auto server = IServerSocket::open(); + ASSERT_NE(server, nullptr); + + auto addedClient = server->addClient(fds[0]); + ASSERT_NE(addedClient, nullptr); + + EXPECT_TRUE(server->removeClient(fds[0])); + EXPECT_FALSE(server->removeClient(fds[0])); + + close(fds[1]); +} + +TEST(IntegrationCore, HandshakeAndSpecDiscoveryWorks) { + CIntegrationHarness harness; + + ASSERT_TRUE(harness.waitForHandshake()); + ASSERT_TRUE(harness.client()->isHandshakeDone()); + + auto spec = harness.client()->getSpec(harness.protocolSpec()->specName()); + ASSERT_NE(spec, nullptr); + EXPECT_EQ(spec->specName(), "integration_protocol"); + EXPECT_EQ(spec->specVer(), 1u); + + EXPECT_EQ(harness.client()->getSpec("does_not_exist"), nullptr); +} + +TEST(IntegrationCore, BindPingAndRoundtripFlowWorks) { + CIntegrationHarness harness; + + ASSERT_TRUE(harness.waitForHandshake()); + + auto manager = harness.bindManager(); + ASSERT_NE(manager, nullptr); + + manager->call(0, "hello"); + harness.roundtrip(); + + ASSERT_TRUE(harness.pumpClientUntil([&harness] { return !harness.clientNotifications().empty(); })); + + const auto pings = harness.serverPings(); + ASSERT_EQ(pings.size(), 1u); + EXPECT_EQ(pings[0], "hello"); + + const auto notifications = harness.clientNotifications(); + ASSERT_FALSE(notifications.empty()); + EXPECT_EQ(notifications.back(), "pong"); +} + +TEST(IntegrationCore, FdArrayAndReturnedObjectFlowWorks) { + CIntegrationHarness harness; + + ASSERT_TRUE(harness.waitForHandshake()); + + auto manager = harness.bindManager(); + ASSERT_NE(manager, nullptr); + + int pipefd[2] = {-1, -1}; + ASSERT_EQ(pipe(pipefd), 0); + ASSERT_EQ(write(pipefd[1], "pipe!", 5), 5); + + manager->call(2, pipefd[0]); + + uint32_t numbers[3] = {69, 420, 2137}; + manager->call(3, numbers, static_cast(3)); + + const uint32_t childSeq = manager->call(1); + ASSERT_NE(childSeq, 0u); + + auto child = harness.childClientObjectForSeq(childSeq); + ASSERT_NE(child, nullptr); + harness.attachChildClientListener(child); + + child->call(0, 41u); + + harness.roundtrip(); + + ASSERT_TRUE(harness.pumpClientUntil([&harness] { return harness.childNotifyValue() != 0; })); + + EXPECT_EQ(harness.serverFdPayload(), "pipe!"); + EXPECT_EQ(harness.serverArrayPayload(), (std::vector{69, 420, 2137})); + EXPECT_EQ(harness.childPingValue(), 41u); + EXPECT_EQ(harness.childNotifyValue(), 42u); + + close(pipefd[0]); + close(pipefd[1]); +} + +TEST(IntegrationCore, MalformedMessageDisconnectsClient) { + CIntegrationHarness harness; + + ASSERT_TRUE(harness.waitForHandshake()); + + const std::array badMessage = {0xFF, HW_MESSAGE_MAGIC_END}; + ASSERT_EQ(write(harness.clientFD(), badMessage.data(), badMessage.size()), static_cast(badMessage.size())); + + bool disconnected = false; + const auto start = std::chrono::steady_clock::now(); + while (std::chrono::steady_clock::now() - start < std::chrono::milliseconds(1500)) { + if (!harness.dispatchClient()) { + disconnected = true; + break; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + EXPECT_TRUE(disconnected); +} diff --git a/tests/unit/IntegrationSemantics.cpp b/tests/unit/IntegrationSemantics.cpp new file mode 100644 index 0000000..8c9063f --- /dev/null +++ b/tests/unit/IntegrationSemantics.cpp @@ -0,0 +1,353 @@ +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace Hyprwire; +using namespace Hyprutils::Memory; + +namespace { + + template + using SP = CSharedPointer; + + template + void* fnToVoid(Fn fn) { + static_assert(std::is_pointer_v); + static_assert(sizeof(Fn) == sizeof(void*)); + + union { + Fn f; + void* p; + } caster = { + .f = fn, + }; + + return caster.p; + } + + class CManagerSpec final : public IProtocolObjectSpec { + public: + std::string objectName() override { + return "manager"; + } + + const std::vector& c2s() override { + static const std::vector methods = { + SMethod{.idx = 0, .params = {}, .returnsType = "child", .since = 0}, + }; + return methods; + } + + const std::vector& s2c() override { + static const std::vector methods = {}; + return methods; + } + }; + + class CChildSpec final : public IProtocolObjectSpec { + public: + std::string objectName() override { + return "child"; + } + + const std::vector& c2s() override { + static const std::vector methods = { + SMethod{.idx = 0, .params = {}, .returnsType = "", .since = 0, .isDestructor = true}, + SMethod{.idx = 1, .params = {HW_MESSAGE_MAGIC_TYPE_UINT}, .returnsType = "", .since = 0}, + }; + return methods; + } + + const std::vector& s2c() override { + static const std::vector methods = { + SMethod{.idx = 0, .params = {HW_MESSAGE_MAGIC_TYPE_UINT}, .returnsType = "", .since = 0}, + }; + return methods; + } + }; + + class CSemanticsProtocolSpec final : public IProtocolSpec { + public: + std::string specName() override { + return "semantics_protocol"; + } + + uint32_t specVer() override { + return 1; + } + + std::vector> objects() override { + return {m_manager, m_child}; + } + + private: + SP m_manager = makeShared(); + SP m_child = makeShared(); + }; + + class CClientImpl final : public IProtocolClientImplementation { + public: + explicit CClientImpl(SP spec) : m_spec(std::move(spec)) { + ; + } + + SP protocol() override { + return m_spec; + } + + std::vector> implementation() override { + return { + makeShared(SClientObjectImplementation{.objectName = "manager", .version = 1}), + makeShared(SClientObjectImplementation{.objectName = "child", .version = 1}), + }; + } + + private: + SP m_spec; + }; + + class CServerImpl final : public IProtocolServerImplementation { + public: + CServerImpl(SP spec, std::function)>&& bindFn) : m_spec(std::move(spec)), m_bindFn(std::move(bindFn)) { + ; + } + + SP protocol() override { + return m_spec; + } + + std::vector> implementation() override { + return { + makeShared(SServerObjectImplementation{.objectName = "manager", .version = 1, .onBind = m_bindFn}), + makeShared(SServerObjectImplementation{.objectName = "child", .version = 1}), + }; + } + + private: + SP m_spec; + std::function)> m_bindFn; + }; + + class CSemanticsHarness { + public: + CSemanticsHarness() { + int fds[2] = {-1, -1}; + if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) + throw std::runtime_error("socketpair failed"); + + m_server = IServerSocket::open(); + if (!m_server) + throw std::runtime_error("server open failed"); + + m_spec = makeShared(); + m_serverImpl = makeShared(m_spec, [this](SP obj) { onManagerBind(std::move(obj)); }); + m_clientImpl = makeShared(m_spec); + + m_server->addImplementation(std::move(m_serverImpl)); + auto serverClient = m_server->addClient(fds[0]); + if (!serverClient) + throw std::runtime_error("server addClient failed"); + + m_client = IClientSocket::open(fds[1]); + if (!m_client) + throw std::runtime_error("client open failed"); + + m_client->addImplementation(std::move(m_clientImpl)); + + m_pumpThread = std::thread([this] { + while (!m_stop) { + m_server->dispatchEvents(false); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + }); + } + + ~CSemanticsHarness() { + m_stop = true; + if (m_pumpThread.joinable()) + m_pumpThread.join(); + } + + bool handshake() { + return m_client->waitForHandshake(); + } + + SP bindManager() { + return m_client->bindProtocol(m_spec, 1); + } + + SP objectForSeq(uint32_t seq) { + return m_client->objectForSeq(seq); + } + + void attachChildListener(const SP& child) { + child->setData(this); + child->listen(0, fnToVoid(&CSemanticsHarness::onChildPong)); + } + + void roundtrip() { + m_client->roundtrip(); + } + + bool dispatchClient() { + return m_client->dispatchEvents(false); + } + + uint32_t childPong() const { + return m_childPong.load(); + } + + uint32_t childPingCount() const { + return m_childPingCount.load(); + } + + uint32_t childDestroyRequests() const { + return m_childDestroyRequests.load(); + } + + uint32_t childDestroyCallbacks() const { + return m_childDestroyCallbacks.load(); + } + + private: + static void onManagerCreateChild(IObject* object, uint32_t seq) { + auto* self = static_cast(object->getData()); + + auto child = object->serverSock()->createObject(object->client(), object->self(), "child", seq); + if (!child) + return; + + child->setData(self); + child->listen(0, fnToVoid(&CSemanticsHarness::onChildDestroy)); + child->listen(1, fnToVoid(&CSemanticsHarness::onChildPing)); + child->setOnDestroy([self] { self->m_childDestroyCallbacks.fetch_add(1); }); + } + + static void onChildDestroy(IObject* object) { + auto* self = static_cast(object->getData()); + self->m_childDestroyRequests.fetch_add(1); + } + + static void onChildPing(IObject* object, uint32_t value) { + auto* self = static_cast(object->getData()); + self->m_childPingCount.fetch_add(1); + object->call(0, value + 1); + } + + static void onChildPong(IObject* object, uint32_t value) { + auto* self = static_cast(object->getData()); + self->m_childPong = value; + } + + void onManagerBind(SP object) { + object->setData(this); + object->listen(0, fnToVoid(&CSemanticsHarness::onManagerCreateChild)); + } + + private: + std::atomic m_stop = false; + std::thread m_pumpThread; + + SP m_spec; + SP m_serverImpl; + SP m_clientImpl; + SP m_server; + SP m_client; + + std::atomic m_childPong = 0; + std::atomic m_childPingCount = 0; + std::atomic m_childDestroyRequests = 0; + std::atomic m_childDestroyCallbacks = 0; + }; + +} // namespace + +TEST(IntegrationSemantics, DestructorMethodsDestroyObjectsAndRejectFurtherCalls) { + CSemanticsHarness harness; + ASSERT_TRUE(harness.handshake()); + + auto manager = harness.bindManager(); + ASSERT_NE(manager, nullptr); + + const uint32_t childSeq = manager->call(0); + ASSERT_NE(childSeq, 0u); + + auto child = harness.objectForSeq(childSeq); + ASSERT_NE(child, nullptr); + harness.attachChildListener(child); + + child->call(1, 41u); + harness.roundtrip(); + + for (int i = 0; i < 500 && harness.childPong() != 42u; ++i) { + harness.dispatchClient(); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + EXPECT_EQ(harness.childPong(), 42u); + EXPECT_EQ(harness.childPingCount(), 1u); + + child->call(0); // destructor method + harness.roundtrip(); + + for (int i = 0; i < 500 && harness.childDestroyCallbacks() < 1u; ++i) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + EXPECT_EQ(harness.childDestroyRequests(), 1u); + EXPECT_EQ(harness.childDestroyCallbacks(), 1u); + + child->call(1, 55u); // stale object ID, should trigger fatal + disconnect + + bool disconnected = false; + for (int i = 0; i < 500; ++i) { + if (!harness.dispatchClient()) { + disconnected = true; + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + EXPECT_TRUE(disconnected); + EXPECT_EQ(harness.childPingCount(), 1u); +} + +TEST(IntegrationSemantics, LostClientObjectAutoCallsProtocolDestructor) { + CSemanticsHarness harness; + ASSERT_TRUE(harness.handshake()); + + auto manager = harness.bindManager(); + ASSERT_NE(manager, nullptr); + + const uint32_t childSeq = manager->call(0); + ASSERT_NE(childSeq, 0u); + + // We intentionally never grab the returned object by seq. + // This emulates the user "losing" the object immediately. + + for (int i = 0; i < 500 && harness.childDestroyRequests() == 0; ++i) { + harness.dispatchClient(); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + for (int i = 0; i < 500 && harness.childDestroyCallbacks() == 0; ++i) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + + EXPECT_EQ(harness.childDestroyRequests(), 1u); + EXPECT_EQ(harness.childDestroyCallbacks(), 1u); +} diff --git a/tests/unit/IntegrationSocketPath.cpp b/tests/unit/IntegrationSocketPath.cpp new file mode 100644 index 0000000..d54426c --- /dev/null +++ b/tests/unit/IntegrationSocketPath.cpp @@ -0,0 +1,141 @@ +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +using namespace Hyprwire; + +namespace { + + std::string makeSocketPath(const std::string& stem) { + auto dir = std::filesystem::temp_directory_path() / std::format("hyprwire-tests-{}-{}", stem, getpid()); + std::filesystem::create_directories(dir); + return (dir / "wire.sock").string(); + } + + class CPathHarness { + public: + explicit CPathHarness(std::string path) : m_path(std::move(path)) { + m_server = IServerSocket::open(m_path); + if (!m_server) + throw std::runtime_error("server open(path) failed"); + + m_pumpThread = std::thread([this] { + while (!m_stop) { + m_server->dispatchEvents(false); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + }); + } + + ~CPathHarness() { + m_stop = true; + if (m_pumpThread.joinable()) + m_pumpThread.join(); + + m_server.reset(); + std::error_code ec; + std::filesystem::remove(m_path, ec); + std::filesystem::remove(std::filesystem::path(m_path).parent_path(), ec); + } + + int loopFD() { + return m_server->extractLoopFD(); + } + + Hyprutils::Memory::CSharedPointer openClient() { + return IClientSocket::open(m_path); + } + + private: + std::string m_path; + std::atomic m_stop = false; + std::thread m_pumpThread; + Hyprutils::Memory::CSharedPointer m_server; + }; + +} // namespace + +TEST(IntegrationSocketPath, PathOpenHandshakeWorks) { + CPathHarness harness{makeSocketPath("open")}; + + auto client = harness.openClient(); + ASSERT_NE(client, nullptr); + + EXPECT_TRUE(client->waitForHandshake()); + EXPECT_TRUE(client->isHandshakeDone()); +} + +TEST(IntegrationSocketPath, ExtractLoopFdSignalsPendingConnectionWork) { + const auto socketPath = makeSocketPath("loopfd"); + + auto server = IServerSocket::open(socketPath); + ASSERT_NE(server, nullptr); + + const int loopFD = server->extractLoopFD(); + ASSERT_GE(loopFD, 0); + + auto client = IClientSocket::open(socketPath); + ASSERT_NE(client, nullptr); + + pollfd pfd = { + .fd = loopFD, + .events = POLLIN, + }; + + const int pollRet = poll(&pfd, 1, 1000); + ASSERT_GT(pollRet, 0); + EXPECT_TRUE(pfd.revents & POLLIN); + + server->dispatchEvents(false); + + std::error_code ec; + std::filesystem::remove(socketPath, ec); + std::filesystem::remove(std::filesystem::path(socketPath).parent_path(), ec); +} + +TEST(IntegrationSocketPath, RecoversFromStaleSocketFile) { + const auto socketPath = makeSocketPath("stale"); + + int staleFd = socket(AF_UNIX, SOCK_STREAM, 0); + ASSERT_GE(staleFd, 0); + + sockaddr_un addr = { + .sun_family = AF_UNIX, + }; + strncpy(addr.sun_path, socketPath.c_str(), sizeof(addr.sun_path) - 1); + + ASSERT_EQ(bind(staleFd, reinterpret_cast(&addr), SUN_LEN(&addr)), 0); + close(staleFd); + + ASSERT_TRUE(std::filesystem::exists(socketPath)); + + { + auto server = IServerSocket::open(socketPath); + ASSERT_NE(server, nullptr); + } + + std::error_code ec; + std::filesystem::remove(socketPath, ec); + std::filesystem::remove(std::filesystem::path(socketPath).parent_path(), ec); +} + +TEST(IntegrationSocketPath, RejectsPathLongerThanUnixSocketLimit) { + const std::string longPath = "/tmp/" + std::string(200, 'x'); + + EXPECT_EQ(IServerSocket::open(longPath), nullptr); + EXPECT_EQ(IClientSocket::open(longPath), nullptr); +} diff --git a/tests/unit/MessageControl.cpp b/tests/unit/MessageControl.cpp new file mode 100644 index 0000000..bdfded2 --- /dev/null +++ b/tests/unit/MessageControl.cpp @@ -0,0 +1,77 @@ +#include + +#include "core/message/messages/BindProtocol.hpp" +#include "core/message/messages/FatalProtocolError.hpp" +#include "core/message/messages/NewObject.hpp" +#include "core/message/messages/RoundtripDone.hpp" +#include "core/message/messages/RoundtripRequest.hpp" + +#include + +using namespace Hyprwire; + +TEST(MessagesControl, BindProtocolRoundtripParsesFields) { + CBindProtocolMessage out("my_proto", 42, 7); + CBindProtocolMessage in(out.m_data, 0); + + EXPECT_EQ(in.m_len, out.m_data.size()); + EXPECT_EQ(in.m_protocol, "my_proto"); + EXPECT_EQ(in.m_seq, 42); + EXPECT_EQ(in.m_version, 7); +} + +TEST(MessagesControl, BindProtocolRejectsZeroVersion) { + CBindProtocolMessage out("my_proto", 42, 0); + CBindProtocolMessage in(out.m_data, 0); + + EXPECT_EQ(in.m_len, 0); +} + +TEST(MessagesControl, NewObjectRoundtripParsesFields) { + CNewObjectMessage out(123, 0xBEEF); + CNewObjectMessage in(out.m_data, 0); + + EXPECT_EQ(in.m_len, out.m_data.size()); + EXPECT_EQ(in.m_seq, 123); + EXPECT_EQ(in.m_id, 0xBEEF); +} + +TEST(MessagesControl, FatalErrorRoundtripParsesFields) { + CFatalErrorMessage out(nullptr, 99, "boom"); + CFatalErrorMessage in(out.m_data, 0); + + EXPECT_EQ(in.m_len, out.m_data.size()); + EXPECT_EQ(in.m_objectId, 0u); + EXPECT_EQ(in.m_errorId, 99u); + EXPECT_EQ(in.m_errorMsg, "boom"); +} + +TEST(MessagesControl, RoundtripMessagesRoundtripParsesSeq) { + CRoundtripRequestMessage reqOut(777); + CRoundtripRequestMessage reqIn(reqOut.m_data, 0); + + EXPECT_EQ(reqIn.m_len, reqOut.m_data.size()); + EXPECT_EQ(reqIn.m_seq, 777u); + + CRoundtripDoneMessage doneOut(888); + CRoundtripDoneMessage doneIn(doneOut.m_data, 0); + + EXPECT_EQ(doneIn.m_len, doneOut.m_data.size()); + EXPECT_EQ(doneIn.m_seq, 888u); +} + +TEST(MessagesControl, RoundtripRejectsMalformedPayload) { + const std::vector badReq = { + HW_MESSAGE_TYPE_ROUNDTRIP_REQUEST, HW_MESSAGE_MAGIC_TYPE_VARCHAR, 0x01, 'x', HW_MESSAGE_MAGIC_END, + }; + + const std::vector badDone = { + HW_MESSAGE_TYPE_ROUNDTRIP_DONE, HW_MESSAGE_MAGIC_TYPE_UINT, 0x01, 0x02, 0x03, 0x04, HW_MESSAGE_MAGIC_TYPE_UINT, + }; + + CRoundtripRequestMessage req(badReq, 0); + CRoundtripDoneMessage done(badDone, 0); + + EXPECT_EQ(req.m_len, 0); + EXPECT_EQ(done.m_len, 0); +} diff --git a/tests/unit/MessageGenericProtocol.cpp b/tests/unit/MessageGenericProtocol.cpp new file mode 100644 index 0000000..634c638 --- /dev/null +++ b/tests/unit/MessageGenericProtocol.cpp @@ -0,0 +1,133 @@ +#include + +#include "core/message/MessageParser.hpp" +#include "core/message/messages/GenericProtocolMessage.hpp" + +#include + +#include + +using namespace Hyprwire; + +static void appendU32(std::vector& data, uint32_t value) { + data.resize(data.size() + 4); + std::memcpy(data.data() + data.size() - 4, &value, sizeof(value)); +} + +static std::vector makeGenericHeader(uint32_t object, uint32_t method) { + std::vector data = { + HW_MESSAGE_TYPE_GENERIC_PROTOCOL_MESSAGE, + HW_MESSAGE_MAGIC_TYPE_OBJECT, + }; + appendU32(data, object); + data.push_back(HW_MESSAGE_MAGIC_TYPE_UINT); + appendU32(data, method); + return data; +} + +TEST(MessagesGenericProtocol, ParsesObjectMethodAndPayloadSpan) { + auto raw = makeGenericHeader(0xABCD, 3); + + raw.push_back(HW_MESSAGE_MAGIC_TYPE_UINT); + appendU32(raw, 55); + raw.push_back(HW_MESSAGE_MAGIC_END); + + std::vector fds; + CGenericProtocolMessage msg(raw, fds, 0); + + ASSERT_EQ(msg.m_len, raw.size()); + EXPECT_EQ(msg.m_object, 0xABCDu); + EXPECT_EQ(msg.m_method, 3u); + ASSERT_FALSE(msg.m_dataSpan.empty()); + EXPECT_EQ(msg.m_dataSpan.front(), HW_MESSAGE_MAGIC_TYPE_UINT); + EXPECT_EQ(msg.m_dataSpan.back(), HW_MESSAGE_MAGIC_END); + EXPECT_TRUE(msg.m_fds.empty()); + EXPECT_TRUE(fds.empty()); +} + +TEST(MessagesGenericProtocol, ConsumesSingleFdToken) { + auto raw = makeGenericHeader(1, 2); + raw.push_back(HW_MESSAGE_MAGIC_TYPE_FD); + raw.push_back(HW_MESSAGE_MAGIC_END); + + std::vector fds = {11}; + CGenericProtocolMessage msg(raw, fds, 0); + + ASSERT_EQ(msg.m_len, raw.size()); + ASSERT_EQ(msg.m_fds.size(), 1u); + EXPECT_EQ(msg.m_fds[0], 11); + EXPECT_TRUE(fds.empty()); +} + +TEST(MessagesGenericProtocol, ConsumesArrayFdTokens) { + CMessageParser parser; + auto raw = makeGenericHeader(1, 2); + + raw.push_back(HW_MESSAGE_MAGIC_TYPE_ARRAY); + raw.push_back(HW_MESSAGE_MAGIC_TYPE_FD); + const auto count = parser.encodeVarInt(2); + raw.insert(raw.end(), count.begin(), count.end()); + raw.push_back(HW_MESSAGE_MAGIC_END); + + std::vector fds = {4, 5}; + CGenericProtocolMessage msg(raw, fds, 0); + + ASSERT_EQ(msg.m_len, raw.size()); + ASSERT_EQ(msg.m_fds.size(), 2u); + EXPECT_EQ(msg.m_fds[0], 4); + EXPECT_EQ(msg.m_fds[1], 5); + EXPECT_TRUE(fds.empty()); +} + +TEST(MessagesGenericProtocol, RejectsFdTokenWithEmptyFdQueue) { + auto raw = makeGenericHeader(1, 2); + raw.push_back(HW_MESSAGE_MAGIC_TYPE_FD); + raw.push_back(HW_MESSAGE_MAGIC_END); + + std::vector fds; + CGenericProtocolMessage msg(raw, fds, 0); + + EXPECT_EQ(msg.m_len, 0); +} + +TEST(MessagesGenericProtocol, RejectsInvalidArrayType) { + auto raw = makeGenericHeader(1, 2); + raw.push_back(HW_MESSAGE_MAGIC_TYPE_ARRAY); + raw.push_back(HW_MESSAGE_MAGIC_END); // invalid element type + raw.push_back(0x00); // arrLen varint = 0 + raw.push_back(HW_MESSAGE_MAGIC_END); + + std::vector fds; + CGenericProtocolMessage msg(raw, fds, 0); + + EXPECT_EQ(msg.m_len, 0); +} + +TEST(MessagesGenericProtocol, RejectsTooLargeArray) { + CMessageParser parser; + auto raw = makeGenericHeader(1, 2); + raw.push_back(HW_MESSAGE_MAGIC_TYPE_ARRAY); + raw.push_back(HW_MESSAGE_MAGIC_TYPE_UINT); + + const auto oversized = parser.encodeVarInt(10000); + raw.insert(raw.end(), oversized.begin(), oversized.end()); + + std::vector fds; + CGenericProtocolMessage msg(raw, fds, 0); + + EXPECT_EQ(msg.m_len, 0); +} + +TEST(MessagesGenericProtocol, ResolveSeqUpdatesObjectAndSerializedPayload) { + auto raw = makeGenericHeader(1, 9); + raw.push_back(HW_MESSAGE_MAGIC_END); + + CGenericProtocolMessage msg(std::move(raw), std::vector{}); + msg.resolveSeq(0xAABBCCDD); + + EXPECT_EQ(msg.m_object, 0xAABBCCDDu); + + uint32_t encodedId = 0; + std::memcpy(&encodedId, msg.m_data.data() + 2, sizeof(encodedId)); + EXPECT_EQ(encodedId, 0xAABBCCDDu); +} diff --git a/tests/unit/MessageHelloHandshake.cpp b/tests/unit/MessageHelloHandshake.cpp new file mode 100644 index 0000000..b77817d --- /dev/null +++ b/tests/unit/MessageHelloHandshake.cpp @@ -0,0 +1,107 @@ +#include + +#include "core/message/MessageParser.hpp" +#include "core/message/messages/HandshakeAck.hpp" +#include "core/message/messages/HandshakeBegin.hpp" +#include "core/message/messages/HandshakeProtocols.hpp" +#include "core/message/messages/Hello.hpp" + +#include + +using namespace Hyprwire; + +TEST(MessagesHelloHandshake, HelloCtorBuildsExpectedWireBytes) { + CHelloMessage msg; + + const std::vector expected = { + HW_MESSAGE_TYPE_SUP, HW_MESSAGE_MAGIC_TYPE_VARCHAR, 0x03, 'V', 'A', 'X', HW_MESSAGE_MAGIC_END, + }; + + EXPECT_EQ(msg.m_data, expected); +} + +TEST(MessagesHelloHandshake, HelloParserAcceptsValidAndRejectsInvalid) { + const std::vector valid = { + HW_MESSAGE_TYPE_SUP, HW_MESSAGE_MAGIC_TYPE_VARCHAR, 0x03, 'V', 'A', 'X', HW_MESSAGE_MAGIC_END, + }; + + CHelloMessage parsedValid(valid, 0); + EXPECT_EQ(parsedValid.m_len, valid.size()); + + auto invalid = valid; + invalid[5] = '!'; + CHelloMessage parsedInvalid(invalid, 0); + EXPECT_EQ(parsedInvalid.m_len, 0); +} + +TEST(MessagesHelloHandshake, HandshakeBeginRoundtripParsesVersions) { + const std::vector versions = {1, 2, 255}; + + CHandshakeBeginMessage out(versions); + CHandshakeBeginMessage in(out.m_data, 0); + + EXPECT_EQ(in.m_len, out.m_data.size()); + EXPECT_EQ(in.m_versionsSupported, versions); +} + +TEST(MessagesHelloHandshake, HandshakeBeginRejectsWrongArrayType) { + const std::vector raw = { + HW_MESSAGE_TYPE_HANDSHAKE_BEGIN, HW_MESSAGE_MAGIC_TYPE_ARRAY, HW_MESSAGE_MAGIC_TYPE_VARCHAR, 0x00, HW_MESSAGE_MAGIC_END, + }; + + CHandshakeBeginMessage msg(raw, 0); + EXPECT_EQ(msg.m_len, 0); +} + +TEST(MessagesHelloHandshake, HandshakeBeginRejectsTooManyVersions) { + CMessageParser parser; + std::vector raw = { + HW_MESSAGE_TYPE_HANDSHAKE_BEGIN, + HW_MESSAGE_MAGIC_TYPE_ARRAY, + HW_MESSAGE_MAGIC_TYPE_UINT, + }; + + const auto encodedCount = parser.encodeVarInt(256); + raw.insert(raw.end(), encodedCount.begin(), encodedCount.end()); + + CHandshakeBeginMessage msg(raw, 0); + EXPECT_EQ(msg.m_len, 0); +} + +TEST(MessagesHelloHandshake, HandshakeAckRoundtripParsesVersion) { + CHandshakeAckMessage out(7); + CHandshakeAckMessage in(out.m_data, 0); + + EXPECT_EQ(in.m_len, out.m_data.size()); + EXPECT_EQ(in.m_version, 7); +} + +TEST(MessagesHelloHandshake, HandshakeAckRejectsMalformedPayload) { + const std::vector raw = { + HW_MESSAGE_TYPE_HANDSHAKE_ACK, HW_MESSAGE_MAGIC_TYPE_UINT, 0x01, 0x02, 0x03, 0x04, HW_MESSAGE_MAGIC_TYPE_UINT, + }; + + CHandshakeAckMessage msg(raw, 0); + EXPECT_EQ(msg.m_len, 0); +} + +TEST(MessagesHelloHandshake, HandshakeProtocolsRoundtripParsesProtocols) { + const std::vector protocols = { + "test_protocol@1", + "another@12", + }; + + CHandshakeProtocolsMessage out(protocols); + CHandshakeProtocolsMessage in(out.m_data, 0); + + EXPECT_EQ(in.m_len, out.m_data.size()); + EXPECT_EQ(in.m_protocols, protocols); +} + +TEST(MessagesHelloHandshake, HandshakeProtocolsSupportsEmptyArray) { + CHandshakeProtocolsMessage out(std::vector{}); + CHandshakeProtocolsMessage in(out.m_data, 0); + + EXPECT_EQ(in.m_len, out.m_data.size()); + EXPECT_TRUE(in.m_protocols.empty()); +} diff --git a/tests/unit/MessageMagic.cpp b/tests/unit/MessageMagic.cpp new file mode 100644 index 0000000..a367c3e --- /dev/null +++ b/tests/unit/MessageMagic.cpp @@ -0,0 +1,27 @@ +#include + +#include "core/message/MessageMagic.hpp" + +#include + +TEST(MessageMagic, KnownValuesMapToReadableStrings) { + using namespace Hyprwire; + + const std::array known = { + HW_MESSAGE_MAGIC_END, HW_MESSAGE_MAGIC_TYPE_UINT, HW_MESSAGE_MAGIC_TYPE_INT, HW_MESSAGE_MAGIC_TYPE_F32, + HW_MESSAGE_MAGIC_TYPE_SEQ, HW_MESSAGE_MAGIC_TYPE_OBJECT_ID, HW_MESSAGE_MAGIC_TYPE_VARCHAR, HW_MESSAGE_MAGIC_TYPE_ARRAY, + HW_MESSAGE_MAGIC_TYPE_OBJECT, HW_MESSAGE_MAGIC_TYPE_FD, + }; + + for (const auto magic : known) { + const char* str = magicToString(magic); + ASSERT_NE(str, nullptr); + EXPECT_STRNE(str, "ERROR"); + } +} + +TEST(MessageMagic, UnknownValueReturnsError) { + using namespace Hyprwire; + + EXPECT_STREQ(magicToString(static_cast(0xFF)), "ERROR"); +} diff --git a/tests/unit/MessageParseData.cpp b/tests/unit/MessageParseData.cpp new file mode 100644 index 0000000..42eb815 --- /dev/null +++ b/tests/unit/MessageParseData.cpp @@ -0,0 +1,79 @@ +#include + +#include "core/message/messages/BindProtocol.hpp" +#include "core/message/messages/FatalProtocolError.hpp" +#include "core/message/messages/HandshakeAck.hpp" +#include "core/message/messages/HandshakeProtocols.hpp" +#include "core/message/messages/Hello.hpp" +#include "core/message/messages/NewObject.hpp" +#include "core/message/messages/RoundtripDone.hpp" +#include "core/message/messages/RoundtripRequest.hpp" + +using namespace Hyprwire; + +TEST(IMessageParseData, HelloContainsTypeAndPayloadString) { + CHelloMessage msg; + const auto parsed = msg.parseData(); + + EXPECT_NE(parsed.find("SUP"), std::string::npos); + EXPECT_NE(parsed.find("\"VAX\""), std::string::npos); +} + +TEST(IMessageParseData, HandshakeAckContainsTypeAndVersion) { + CHandshakeAckMessage msg(7); + const auto parsed = msg.parseData(); + + EXPECT_NE(parsed.find("HANDSHAKE_ACK"), std::string::npos); + EXPECT_NE(parsed.find('7'), std::string::npos); +} + +TEST(IMessageParseData, HandshakeProtocolsContainsProtocolNames) { + CHandshakeProtocolsMessage msg(std::vector{"proto@1", "second@2"}); + const auto parsed = msg.parseData(); + + EXPECT_NE(parsed.find("HANDSHAKE_PROTOCOLS"), std::string::npos); + EXPECT_NE(parsed.find("\"proto@1\""), std::string::npos); + EXPECT_NE(parsed.find("\"second@2\""), std::string::npos); +} + +TEST(IMessageParseData, BindProtocolContainsCoreFields) { + CBindProtocolMessage msg("my_proto", 12, 3); + const auto parsed = msg.parseData(); + + EXPECT_NE(parsed.find("BIND_PROTOCOL"), std::string::npos); + EXPECT_NE(parsed.find("12"), std::string::npos); + EXPECT_NE(parsed.find("\"my_proto\""), std::string::npos); + EXPECT_NE(parsed.find('3'), std::string::npos); +} + +TEST(IMessageParseData, NewObjectContainsObjectAndSeq) { + CNewObjectMessage msg(9, 77); + const auto parsed = msg.parseData(); + + EXPECT_NE(parsed.find("NEW_OBJECT"), std::string::npos); + EXPECT_NE(parsed.find("77"), std::string::npos); + EXPECT_NE(parsed.find('9'), std::string::npos); +} + +TEST(IMessageParseData, FatalErrorContainsIdentifiersAndMessage) { + CFatalErrorMessage msg(nullptr, 123, "oops"); + const auto parsed = msg.parseData(); + + EXPECT_NE(parsed.find("PROTOCOL_ERROR"), std::string::npos); + EXPECT_NE(parsed.find("123"), std::string::npos); + EXPECT_NE(parsed.find("\"oops\""), std::string::npos); +} + +TEST(IMessageParseData, RoundtripMessagesContainTypeAndSequence) { + CRoundtripRequestMessage req(777); + CRoundtripDoneMessage done(888); + + const auto reqParsed = req.parseData(); + const auto doneParsed = done.parseData(); + + EXPECT_NE(reqParsed.find("ROUNDTRIP_REQUEST"), std::string::npos); + EXPECT_NE(reqParsed.find("777"), std::string::npos); + + EXPECT_NE(doneParsed.find("ROUNDTRIP_DONE"), std::string::npos); + EXPECT_NE(doneParsed.find("888"), std::string::npos); +} diff --git a/tests/unit/MessageParserMatrix.cpp b/tests/unit/MessageParserMatrix.cpp new file mode 100644 index 0000000..0816ee1 --- /dev/null +++ b/tests/unit/MessageParserMatrix.cpp @@ -0,0 +1,119 @@ +#include + +#include "core/client/ClientSocket.hpp" +#include "core/message/MessageParser.hpp" +#include "core/message/MessageType.hpp" +#include "core/message/messages/GenericProtocolMessage.hpp" +#include "core/message/messages/BindProtocol.hpp" +#include "core/message/messages/FatalProtocolError.hpp" +#include "core/message/messages/HandshakeBegin.hpp" +#include "core/message/messages/HandshakeProtocols.hpp" +#include "core/message/messages/Hello.hpp" +#include "core/server/ServerClient.hpp" + +#include + +using namespace Hyprwire; + +namespace { + SP makeClientSocket() { + auto client = makeShared(); + client->m_self = client; + return client; + } + + SP makeServerClient() { + auto client = makeShared(-1); + client->m_self = client; + return client; + } +} + +TEST(MessageParserMatrix, ServerParserRejectsInvalidTypeCode) { + CMessageParser parser; + auto serverClient = makeServerClient(); + + SSocketRawParsedMessage raw = { + .data = {0xFF}, + }; + + EXPECT_EQ(parser.handleMessage(raw, serverClient), MESSAGE_PARSED_ERROR); + EXPECT_TRUE(serverClient->m_error); +} + +TEST(MessageParserMatrix, ServerParserReportsStrayFds) { + CMessageParser parser; + auto serverClient = makeServerClient(); + + SSocketRawParsedMessage raw = { + .data = CHelloMessage().m_data, + .fds = {123}, + }; + + EXPECT_EQ(parser.handleMessage(raw, serverClient), MESSAGE_PARSED_STRAY_FDS); +} + +TEST(MessageParserMatrix, ServerParserRejectsMalformedBindProtocol) { + CMessageParser parser; + auto serverClient = makeServerClient(); + + SSocketRawParsedMessage raw = { + .data = {HW_MESSAGE_TYPE_BIND_PROTOCOL, HW_MESSAGE_MAGIC_END}, + }; + + EXPECT_EQ(parser.handleMessage(raw, serverClient), MESSAGE_PARSED_ERROR); +} + +TEST(MessageParserMatrix, ClientParserRejectsUnsupportedVersionNegotiation) { + CMessageParser parser; + auto client = makeClientSocket(); + + CHandshakeBeginMessage begin({9999}); + + SSocketRawParsedMessage raw = { + .data = begin.m_data, + }; + + EXPECT_EQ(parser.handleMessage(raw, client), MESSAGE_PARSED_ERROR); + EXPECT_FALSE(client->m_handshakeDone); +} + +TEST(MessageParserMatrix, ClientParserReportsStrayFds) { + CMessageParser parser; + auto client = makeClientSocket(); + + CHandshakeProtocolsMessage protocols(std::vector{}); + + SSocketRawParsedMessage raw = { + .data = protocols.m_data, + .fds = {11}, + }; + + EXPECT_EQ(parser.handleMessage(raw, client), MESSAGE_PARSED_STRAY_FDS); + EXPECT_TRUE(client->m_handshakeDone); +} + +TEST(MessageParserMatrix, ClientParserRejectsInvalidTypeCode) { + CMessageParser parser; + auto client = makeClientSocket(); + + SSocketRawParsedMessage raw = { + .data = {0xFF}, + }; + + EXPECT_EQ(parser.handleMessage(raw, client), MESSAGE_PARSED_ERROR); +} + +TEST(MessageParserMatrix, ClientParserHandlesFatalProtocolErrorAndFlagsClient) { + CMessageParser parser; + auto client = makeClientSocket(); + + CFatalErrorMessage msg(7, 123, "boom"); + + SSocketRawParsedMessage raw = { + .data = msg.m_data, + }; + + EXPECT_EQ(parser.handleMessage(raw, client), MESSAGE_PARSED_OK); + EXPECT_TRUE(client->m_error); +} diff --git a/tests/unit/MessageType.cpp b/tests/unit/MessageType.cpp new file mode 100644 index 0000000..86d2f9e --- /dev/null +++ b/tests/unit/MessageType.cpp @@ -0,0 +1,28 @@ +#include + +#include "core/message/MessageType.hpp" + +#include + +TEST(MessageType, KnownValuesMapToStrings) { + using namespace Hyprwire; + + const std::array known = { + HW_MESSAGE_TYPE_INVALID, HW_MESSAGE_TYPE_SUP, HW_MESSAGE_TYPE_HANDSHAKE_BEGIN, HW_MESSAGE_TYPE_HANDSHAKE_ACK, HW_MESSAGE_TYPE_HANDSHAKE_PROTOCOLS, + HW_MESSAGE_TYPE_BIND_PROTOCOL, HW_MESSAGE_TYPE_NEW_OBJECT, HW_MESSAGE_TYPE_FATAL_PROTOCOL_ERROR, HW_MESSAGE_TYPE_ROUNDTRIP_REQUEST, HW_MESSAGE_TYPE_ROUNDTRIP_DONE, + }; + + for (const auto type : known) { + const char* str = messageTypeToStr(type); + ASSERT_NE(str, nullptr); + EXPECT_STRNE(str, "ERROR"); + } + + EXPECT_STREQ(messageTypeToStr(HW_MESSAGE_TYPE_GENERIC_PROTOCOL_MESSAGE), "GENERIC_PROTOCOL_MESSAGE"); +} + +TEST(MessageType, UnknownValueReturnsError) { + using namespace Hyprwire; + + EXPECT_STREQ(messageTypeToStr(static_cast(0xFF)), "ERROR"); +} diff --git a/tests/unit/SyscallSeams.cpp b/tests/unit/SyscallSeams.cpp new file mode 100644 index 0000000..80d426a --- /dev/null +++ b/tests/unit/SyscallSeams.cpp @@ -0,0 +1,166 @@ +#include + +#include "core/client/ClientSocket.hpp" +#include "core/message/MessageType.hpp" +#include "core/socket/SocketHelpers.hpp" +#include "helpers/Syscalls.hpp" + +#include + +#include +#include +#include +#include + +#include +#include + +using namespace Hyprwire; + +namespace { + class CScopedSyscallHooks { + public: + explicit CScopedSyscallHooks(const Syscalls::SHooks& hooks) { + Syscalls::setHooks(hooks); + } + + ~CScopedSyscallHooks() { + Syscalls::resetHooks(); + } + }; + + class CScopedHandshakeTimeout { + public: + explicit CScopedHandshakeTimeout(std::chrono::milliseconds timeout) { + CClientSocket::setHandshakeTimeoutForTests(timeout); + } + + ~CScopedHandshakeTimeout() { + CClientSocket::resetHandshakeTimeoutForTests(); + } + }; + + int g_sendmsgCalls = 0; + int g_pollCalls = 0; + ssize_t hookSendmsgRetryEagain(int, const msghdr* msg, int) { + ++g_sendmsgCalls; + if (g_sendmsgCalls == 1) { + errno = EAGAIN; + return -1; + } + + return static_cast(msg->msg_iov[0].iov_len); + } + + int hookPollAwake(pollfd*, nfds_t, int) { + ++g_pollCalls; + return 1; + } + + int hookPollTimeout(pollfd*, nfds_t, int) { + return 0; + } + + ssize_t hookRecvmsgError(int, msghdr*, int) { + errno = EIO; + return -1; + } + + ssize_t hookRecvmsgInvalidControl(int, msghdr* msg, int) { + auto* io = msg->msg_iov; + if (io && io->iov_len > 0) { + auto* bytes = static_cast(io->iov_base); + bytes[0] = HW_MESSAGE_TYPE_SUP; + } + + auto* cmsg = CMSG_FIRSTHDR(msg); + if (!cmsg) + return -1; + + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = 0x7FFF; + cmsg->cmsg_len = CMSG_LEN(sizeof(int)); + *reinterpret_cast(CMSG_DATA(cmsg)) = 123; + + msg->msg_controllen = cmsg->cmsg_len; + + return 1; + } +} + +TEST(SyscallSeams, ClientSendRetriesOnEagain) { + g_sendmsgCalls = 0; + g_pollCalls = 0; + + CScopedSyscallHooks hooks(Syscalls::SHooks{ + .poll = hookPollAwake, + .sendmsg = hookSendmsgRetryEagain, + .recvmsg = nullptr, + }); + + int fds[2] = {-1, -1}; + ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0); + + auto client = IClientSocket::open(fds[0]); + ASSERT_NE(client, nullptr); + + close(fds[1]); + + EXPECT_GE(g_sendmsgCalls, 2); + EXPECT_GE(g_pollCalls, 1); +} + +TEST(SyscallSeams, HandshakeTimeoutCanBeOverridden) { + CScopedSyscallHooks hooks(Syscalls::SHooks{ + .poll = hookPollTimeout, + .sendmsg = nullptr, + .recvmsg = nullptr, + }); + + int fds[2] = {-1, -1}; + ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0); + + CScopedHandshakeTimeout timeout{std::chrono::milliseconds(1)}; + + auto client = IClientSocket::open(fds[0]); + ASSERT_NE(client, nullptr); + + EXPECT_FALSE(client->waitForHandshake()); + EXPECT_FALSE(client->isHandshakeDone()); + + close(fds[1]); +} + +TEST(SyscallSeams, ParseFromFdReturnsBadWhenRecvmsgFails) { + CScopedSyscallHooks hooks(Syscalls::SHooks{ + .poll = nullptr, + .sendmsg = nullptr, + .recvmsg = hookRecvmsgError, + }); + + int fds[2] = {-1, -1}; + ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0); + + Hyprutils::OS::CFileDescriptor fd{fds[0]}; + close(fds[1]); + + const auto parsed = parseFromFd(fd); + EXPECT_TRUE(parsed.bad); +} + +TEST(SyscallSeams, ParseFromFdReturnsBadForInvalidControlMessage) { + CScopedSyscallHooks hooks(Syscalls::SHooks{ + .poll = nullptr, + .sendmsg = nullptr, + .recvmsg = hookRecvmsgInvalidControl, + }); + + int fds[2] = {-1, -1}; + ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0); + + Hyprutils::OS::CFileDescriptor fd{fds[0]}; + close(fds[1]); + + const auto parsed = parseFromFd(fd); + EXPECT_TRUE(parsed.bad); +} diff --git a/tests/unit/VarInt.cpp b/tests/unit/VarInt.cpp new file mode 100644 index 0000000..65d3c37 --- /dev/null +++ b/tests/unit/VarInt.cpp @@ -0,0 +1,47 @@ +#include + +#include "core/message/MessageParser.hpp" + +#include + +TEST(VarInt, EncodeThenParseRoundtripAtBoundaries) { + Hyprwire::CMessageParser parser; + + const std::array values = { + 0, 1, 127, 128, 16383, 16384, 2097151, 2097152, 268435455, + }; + + for (const auto value : values) { + const auto encoded = parser.encodeVarInt(value); + const auto [decoded, n] = parser.parseVarInt(std::span{encoded.data(), encoded.size()}); + + ASSERT_FALSE(encoded.empty()); + EXPECT_EQ(decoded, value); + EXPECT_EQ(n, encoded.size()); + + EXPECT_EQ(encoded.back() & 0x80, 0); + } +} + +TEST(VarInt, ParseFromVectorWithOffset) { + Hyprwire::CMessageParser parser; + + const auto encoded = parser.encodeVarInt(420); + + std::vector data = {0xAA, 0xBB}; + data.insert(data.end(), encoded.begin(), encoded.end()); + data.push_back(0xCC); + + const auto [decoded, n] = parser.parseVarInt(data, 2); + EXPECT_EQ(decoded, 420); + EXPECT_EQ(n, encoded.size()); +} + +TEST(VarInt, ParseOutOfBoundsOffsetReturnsZeroPair) { + Hyprwire::CMessageParser parser; + + const std::vector data = {1, 2, 3}; + const auto expected = std::pair{0, 0}; + EXPECT_EQ(parser.parseVarInt(data, data.size()), expected); + EXPECT_EQ(parser.parseVarInt(data, data.size() + 42), expected); +}