Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ if(BUILD_TESTING)
message(STATUS "building tests is enabled")

enable_testing()

# Server + Client tests
add_custom_target(tests)

message(STATUS "generating protocols")
Expand All @@ -133,6 +135,29 @@ if(BUILD_TESTING)
"tests/generated/test_protocol_v1-server.cpp")
target_link_libraries(fork PRIVATE PkgConfig::deps hyprwire)
add_dependencies(tests fork)

# GTests
find_package(GTest CONFIG REQUIRED)
include(GoogleTest)
file(GLOB_RECURSE TESTFILES CONFIGURE_DEPENDS "tests/unit/*.cpp")
add_executable(hyprwire_tests ${TESTFILES})

target_compile_options(hyprwire_tests PRIVATE --coverage -fsanitize=address)
target_link_options(hyprwire_tests PRIVATE --coverage)

target_include_directories(
hyprwire_tests
PUBLIC "./include"
PRIVATE "./src" "./src/include" "./protocols" "${CMAKE_BINARY_DIR}")
target_link_libraries(hyprwire_tests PRIVATE asan hyprwire GTest::gtest_main
PkgConfig::deps)
gtest_discover_tests(hyprwire_tests
PROPERTIES ENVIRONMENT "ASAN_OPTIONS=detect_leaks=0"
)

# Add coverage to hyprwire for test builds
target_compile_options(hyprwire PRIVATE --coverage)
target_link_options(hyprwire PRIVATE --coverage)
else()
message(STATUS "building tests is disabled")
endif()
Expand Down
7 changes: 4 additions & 3 deletions include/hyprwire/core/implementation/Types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ namespace Hyprwire {
struct SMethod {
uint32_t idx = 0;
std::vector<uint8_t> params;
std::string returnsType = "";
uint32_t since = 0;
std::string returnsType = "";
uint32_t since = 0;
bool isDestructor = false;
};

class IProtocolObjectSpec {
Expand All @@ -26,4 +27,4 @@ namespace Hyprwire {
IProtocolObjectSpec() = default;
};

};
};
4 changes: 3 additions & 1 deletion nix/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
stdenv,
cmake,
pkg-config,
gtest,
hyprutils,
libffi,
pugixml,
Expand All @@ -18,7 +19,8 @@ stdenv.mkDerivation {
nativeBuildInputs = [
cmake
pkg-config
];
]
++ lib.optionals doCheck [ gtest ];

buildInputs = [
hyprutils
Expand Down
6 changes: 4 additions & 2 deletions scanner/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,9 @@ Hyprwire::SMethod{{
.params = {{ {} }},
.returnsType = "{}",
.since = {},
.isDestructor = {},
}},)#",
m.idx, argArrayStr, m.returns, m.since);
m.idx, argArrayStr, m.returns, m.since, m.destructor ? "true" : "false");
}

if (!object.c2s.empty())
Expand Down Expand Up @@ -391,8 +392,9 @@ Hyprwire::SMethod{{
.idx = {},
.params = {{ {} }},
.since = {},
.isDestructor = {},
}},)#",
m.idx, argArrayStr, m.since);
m.idx, argArrayStr, m.since, m.destructor ? "true" : "false");
}

if (!object.s2c.empty())
Expand Down
25 changes: 25 additions & 0 deletions src/core/client/ClientObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,31 @@ CClientObject::CClientObject(SP<CClientSocket> client) : m_client(client) {
}

CClientObject::~CClientObject() {
if (!m_destroyed && m_id != 0 && m_spec && m_client && m_client->m_fd.isValid()) {
const auto methods = methodsOut();
for (const auto& method : methods) {
if (!method.isDestructor)
continue;

if (method.since > m_version)
continue;

if (!method.returnsType.empty()) {
Debug::log(WARN, "can't auto-call destructor for object {}: method {} has returns type", m_id, method.idx);
break;
}

if (!method.params.empty()) {
Debug::log(WARN, "can't auto-call destructor for object {}: method {} has params", m_id, method.idx);
break;
}

TRACE(Debug::log(TRACE, "auto-calling protocol destructor {} for object {}", method.idx, m_id));
call(method.idx);
break;
}
}

TRACE(Debug::log(TRACE, "destroying object {}", m_id));
}

Expand Down
73 changes: 61 additions & 12 deletions src/core/client/ClientSocket.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "ClientSocket.hpp"
#include "../../helpers/Memory.hpp"
#include "../../helpers/Log.hpp"
#include "../../helpers/Syscalls.hpp"
#include "../../Macros.hpp"
#include "../message/MessageParser.hpp"
#include "../message/messages/IMessage.hpp"
Expand Down Expand Up @@ -30,6 +31,10 @@ using namespace Hyprwire;
using namespace Hyprutils::OS;
using namespace Hyprutils::Utils;

namespace {
std::chrono::milliseconds g_handshakeMax = std::chrono::milliseconds(5000);
}

SP<IClientSocket> IClientSocket::open(const std::string& path) {
SP<CClientSocket> sock = makeShared<CClientSocket>();
sock->m_self = sock;
Expand Down Expand Up @@ -101,18 +106,34 @@ void CClientSocket::addImplementation(SP<IProtocolClientImplementation>&& x) {
m_impls.emplace_back(std::move(x));
}

constexpr const size_t HANDSHAKE_MAX_MS = 5000;
void CClientSocket::setHandshakeTimeoutForTests(std::chrono::milliseconds timeout) {
g_handshakeMax = timeout;
}

void CClientSocket::resetHandshakeTimeoutForTests() {
g_handshakeMax = std::chrono::milliseconds(5000);
}

//
bool CClientSocket::dispatchEvents(bool block) {

if (m_error)
return false;

collectOrphanedObjects();

if (!m_handshakeDone) {
const auto MAX_MS =
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::milliseconds(HANDSHAKE_MAX_MS) - (std::chrono::steady_clock::now() - m_handshakeBegin)).count();
int ret = poll(m_pollfds.data(), m_pollfds.size(), block ? MAX_MS : 0);
const auto elapsed = std::chrono::steady_clock::now() - m_handshakeBegin;
const auto maxMs = g_handshakeMax;

if (block && elapsed >= maxMs) {
Debug::log(ERR, "handshake error: timed out");
disconnectOnError();
return false;
}

const auto timeout = block ? std::chrono::duration_cast<std::chrono::milliseconds>(maxMs - elapsed).count() : 0;
int ret = Syscalls::poll(m_pollfds.data(), m_pollfds.size(), static_cast<int>(timeout));
if (block && !ret) {
Debug::log(ERR, "handshake error: timed out");
disconnectOnError();
Expand All @@ -121,13 +142,15 @@ bool CClientSocket::dispatchEvents(bool block) {
}

if (m_handshakeDone)
poll(m_pollfds.data(), m_pollfds.size(), block ? -1 : 0);
Syscalls::poll(m_pollfds.data(), m_pollfds.size(), block ? -1 : 0);

if (m_pollfds[0].revents & POLLHUP)
return false;

if (!(m_pollfds[0].revents & POLLIN))
if (!(m_pollfds[0].revents & POLLIN)) {
collectOrphanedObjects();
return true;
}

// dispatch

Expand Down Expand Up @@ -165,6 +188,8 @@ bool CClientSocket::dispatchEvents(bool block) {
return true;
});

collectOrphanedObjects();

return !m_error;
}

Expand Down Expand Up @@ -203,13 +228,13 @@ void CClientSocket::sendMessage(const IMessage& message) {
}

while (m_fd.isValid()) {
int ret = sendmsg(m_fd.get(), &msg, 0);
int ret = Syscalls::sendmsg(m_fd.get(), &msg, 0);
if (ret < 0 && (errno == EWOULDBLOCK || errno == EAGAIN)) {
pollfd pfd = {
.fd = m_fd.get(),
.events = POLLOUT | POLLWRBAND,
};
poll(&pfd, 1, -1);
Syscalls::poll(&pfd, 1, -1);
} else
break;
}
Expand Down Expand Up @@ -327,14 +352,38 @@ void CClientSocket::waitForObject(SP<IWireObject> x) {
}

void CClientSocket::onGeneric(const CGenericProtocolMessage& msg) {
SP<CClientObject> object;

for (const auto& o : m_objects) {
if (o->m_id == msg.m_object) {
o->called(msg.m_method, msg.m_dataSpan, msg.m_fds);
return;
if (o && o->m_id == msg.m_object) {
object = o;
break;
}
}

Debug::log(WARN, "[{} @ {:.3f}] -> Generic message not handled. No object with id {}!", m_fd.get(), steadyMillis(), msg.m_object);
if (!object) {
Debug::log(ERR, "[{} @ {:.3f}] -> Generic message references unknown object {}", m_fd.get(), steadyMillis(), msg.m_object);
disconnectOnError();
return;
}

object->called(msg.m_method, msg.m_dataSpan, msg.m_fds);
}

void CClientSocket::destroyObject(uint32_t id) {
std::erase_if(m_objects, [id](const auto& obj) { return obj && obj->m_id == id; });
}

void CClientSocket::collectOrphanedObjects() {
std::erase_if(m_objects, [](const auto& obj) {
if (!obj)
return true;

if (obj->m_id == 0)
return false;

return obj.strongRef() == 1;
});
}

SP<IObject> CClientSocket::objectForId(uint32_t id) {
Expand Down
8 changes: 7 additions & 1 deletion src/core/client/ClientSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <vector>
#include <sys/poll.h>
#include <chrono>

namespace Hyprwire {
class IMessage;
Expand All @@ -33,11 +34,16 @@ namespace Hyprwire {
virtual void roundtrip();
virtual bool isHandshakeDone();

static void setHandshakeTimeoutForTests(std::chrono::milliseconds timeout);
static void resetHandshakeTimeoutForTests();

void sendMessage(const IMessage& message);
void serverSpecs(const std::vector<std::string>& s);
void recheckPollFds();
void onSeq(uint32_t seq, uint32_t id);
void onGeneric(const CGenericProtocolMessage& msg);
void destroyObject(uint32_t id);
void collectOrphanedObjects();
SP<CClientObject> makeObject(const std::string& protocolName, const std::string& objectName, uint32_t seq);
void waitForObject(SP<IWireObject>);

Expand Down Expand Up @@ -65,4 +71,4 @@ namespace Hyprwire {
uint32_t m_lastAckdRoundtripSeq = 0;
uint32_t m_lastSentRoundtripSeq = 0;
};
};
};
11 changes: 9 additions & 2 deletions src/core/message/messages/FatalProtocolError.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,19 @@ CFatalErrorMessage::CFatalErrorMessage(const std::vector<uint8_t>& data, size_t
}

CFatalErrorMessage::CFatalErrorMessage(SP<IWireObject> obj, uint32_t errorId, const std::string_view& msg) {
uint32_t objectId = 0;
if (obj)
objectId = obj->m_id;

*this = CFatalErrorMessage(objectId, errorId, msg);
}

CFatalErrorMessage::CFatalErrorMessage(uint32_t objectId, uint32_t errorId, const std::string_view& msg) {
m_type = HW_MESSAGE_TYPE_FATAL_PROTOCOL_ERROR;

m_data = {HW_MESSAGE_TYPE_FATAL_PROTOCOL_ERROR, HW_MESSAGE_MAGIC_TYPE_UINT, 0, 0, 0, 0, HW_MESSAGE_MAGIC_TYPE_UINT, 0, 0, 0, 0, HW_MESSAGE_MAGIC_TYPE_VARCHAR};

if (obj)
std::memcpy(&m_data[2], &obj->m_id, sizeof(obj->m_id));
std::memcpy(&m_data[2], &objectId, sizeof(objectId));
std::memcpy(&m_data[7], &errorId, sizeof(errorId));

m_data.append_range(g_messageParser->encodeVarInt(msg.size()));
Expand Down
3 changes: 2 additions & 1 deletion src/core/message/messages/FatalProtocolError.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ namespace Hyprwire {
public:
CFatalErrorMessage(const std::vector<uint8_t>& data, size_t offset);
CFatalErrorMessage(SP<IWireObject> obj, uint32_t errorId, const std::string_view& msg);
CFatalErrorMessage(uint32_t objectId, uint32_t errorId, const std::string_view& msg);

virtual ~CFatalErrorMessage() = default;

uint32_t m_objectId = 0;
uint32_t m_errorId = 0;
std::string m_errorMsg;
};
};
};
Loading
Loading