Core/Network: Fix invalid NetworkThread array access for derived classes that have additional data members (only HttpService threads were affected)

This commit is contained in:
Shauren
2025-12-22 13:06:28 +01:00
parent a4bbb61970
commit b69a1a71c6
7 changed files with 62 additions and 49 deletions

View File

@@ -113,8 +113,27 @@ template<typename Callable, typename SessionImpl>
concept HttpRequestHandler = invocable_r<Callable, RequestHandlerResult, std::shared_ptr<SessionImpl>, RequestContext&>;
template<typename SessionImpl>
class HttpService : public SocketMgr<SessionImpl>, public DispatcherService, public SessionService
class HttpNetworkThread final : public NetworkThread<SessionImpl>
{
public:
explicit HttpNetworkThread(SessionService* service) : _service(service) { }
protected:
void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override
{
if (Optional<boost::uuids::uuid> id = session->GetSessionId())
_service->MarkSessionInactive(*id);
}
private:
SessionService* _service = nullptr;
};
template<typename SessionImpl>
class HttpService : public SocketMgr<SessionImpl, HttpNetworkThread<SessionImpl>>, public DispatcherService, public SessionService
{
using BaseSocketMgr = SocketMgr<SessionImpl, HttpNetworkThread<SessionImpl>>;
public:
HttpService(std::string_view loggerSuffix) : DispatcherService(loggerSuffix), SessionService(loggerSuffix), _ioContext(nullptr), _logger("server.http.")
{
@@ -123,7 +142,7 @@ public:
bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount = 1) override
{
if (!SocketMgr<SessionImpl>::StartNetwork(ioContext, bindIp, port, threadCount))
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
SessionService::Start(ioContext);
@@ -133,7 +152,7 @@ public:
void StopNetwork() override
{
SessionService::Stop();
SocketMgr<SessionImpl>::StopNetwork();
BaseSocketMgr::StopNetwork();
}
// http handling
@@ -157,26 +176,11 @@ public:
}
protected:
class Thread : public NetworkThread<SessionImpl>
HttpNetworkThread<SessionImpl>* CreateThreads() const final
{
protected:
void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override
{
if (Optional<boost::uuids::uuid> id = session->GetSessionId())
_service->MarkSessionInactive(*id);
}
private:
friend HttpService;
SessionService* _service;
};
NetworkThread<SessionImpl>* CreateThreads() const override
{
Thread* threads = new Thread[this->GetNetworkThreadCount()];
HttpNetworkThread<SessionImpl>* threads = static_cast<HttpNetworkThread<SessionImpl>*>(::operator new(sizeof(HttpNetworkThread<SessionImpl>) * this->GetNetworkThreadCount()));
for (int32 i = 0; i < this->GetNetworkThreadCount(); ++i)
threads[i]._service = const_cast<HttpService*>(this);
new (&threads[i]) HttpNetworkThread<SessionImpl>(const_cast<HttpService*>(this));
return threads;
}

View File

@@ -27,9 +27,12 @@
namespace Trinity::Net
{
template<class SocketType>
template <typename SocketType, typename ThreadType>
class SocketMgr
{
static_assert(std::is_base_of_v<NetworkThread<SocketType>, ThreadType>);
static_assert(std::is_final_v<ThreadType>);
public:
SocketMgr(SocketMgr const&) = delete;
SocketMgr(SocketMgr&&) = delete;
@@ -135,10 +138,10 @@ protected:
{
}
virtual NetworkThread<SocketType>* CreateThreads() const = 0;
virtual ThreadType* CreateThreads() const = 0;
std::unique_ptr<AsyncAcceptor> _acceptor;
std::unique_ptr<NetworkThread<SocketType>[]> _threads;
std::unique_ptr<ThreadType[]> _threads;
int32 _threadCount;
};
}

View File

@@ -42,7 +42,7 @@ enum class BanMode
BAN_ACCOUNT = 1
};
class LoginRESTService : public Trinity::Net::Http::HttpService<LoginHttpSession>
class LoginRESTService final : public Trinity::Net::Http::HttpService<LoginHttpSession>
{
public:
using RequestHandlerResult = Trinity::Net::Http::RequestHandlerResult;

View File

@@ -16,7 +16,6 @@
*/
#include "SessionManager.h"
#include "Util.h"
bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
{
@@ -30,9 +29,9 @@ bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext
return true;
}
Trinity::Net::NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const
Battlenet::SessionNetworkThread* Battlenet::SessionManager::CreateThreads() const
{
return new Trinity::Net::NetworkThread<Session>[GetNetworkThreadCount()];
return new SessionNetworkThread[GetNetworkThreadCount()];
}
Battlenet::SessionManager& Battlenet::SessionManager::Instance()

View File

@@ -23,9 +23,13 @@
namespace Battlenet
{
class SessionManager : public Trinity::Net::SocketMgr<Session>
class SessionNetworkThread final : public Trinity::Net::NetworkThread<Session>
{
typedef SocketMgr<Session> BaseSocketMgr;
};
class SessionManager final : public Trinity::Net::SocketMgr<Session, SessionNetworkThread>
{
using BaseSocketMgr = SocketMgr;
public:
static SessionManager& Instance();
@@ -33,7 +37,7 @@ namespace Battlenet
bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override;
protected:
Trinity::Net::NetworkThread<Session>* CreateThreads() const override;
SessionNetworkThread* CreateThreads() const override;
};
}

View File

@@ -17,24 +17,19 @@
#include "WorldSocketMgr.h"
#include "Config.h"
#include "NetworkThread.h"
#include "ScriptMgr.h"
#include <boost/system/error_code.hpp>
class WorldSocketThread : public Trinity::Net::NetworkThread<WorldSocket>
void WorldSocketThread::SocketAdded(std::shared_ptr<WorldSocket> const& sock)
{
public:
void SocketAdded(std::shared_ptr<WorldSocket> const& sock) override
{
sock->SetSendBufferSize(sWorldSocketMgr.GetApplicationSendBufferSize());
sScriptMgr->OnSocketOpen(sock);
}
sock->SetSendBufferSize(sWorldSocketMgr.GetApplicationSendBufferSize());
sScriptMgr->OnSocketOpen(sock);
}
void SocketRemoved(std::shared_ptr<WorldSocket>const& sock) override
{
sScriptMgr->OnSocketClose(sock);
}
};
void WorldSocketThread::SocketRemoved(std::shared_ptr<WorldSocket>const& sock)
{
sScriptMgr->OnSocketClose(sock);
}
WorldSocketMgr::WorldSocketMgr() : BaseSocketMgr(), _socketSystemSendBufferSize(-1), _socketApplicationSendBufferSize(65536), _tcpNoDelay(true)
{
@@ -114,7 +109,7 @@ void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint3
BaseSocketMgr::OnSocketOpen(std::move(sock), threadIndex);
}
Trinity::Net::NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const
WorldSocketThread* WorldSocketMgr::CreateThreads() const
{
return new WorldSocketThread[GetNetworkThreadCount()];
}

View File

@@ -21,10 +21,18 @@
#include "SocketMgr.h"
#include "WorldSocket.h"
/// Manages all sockets connected to peers and network threads
class TC_GAME_API WorldSocketMgr : public Trinity::Net::SocketMgr<WorldSocket>
class WorldSocketThread final : public Trinity::Net::NetworkThread<WorldSocket>
{
typedef SocketMgr<WorldSocket> BaseSocketMgr;
public:
void SocketAdded(std::shared_ptr<WorldSocket> const& sock) override;
void SocketRemoved(std::shared_ptr<WorldSocket>const& sock) override;
};
/// Manages all sockets connected to peers and network threads
class TC_GAME_API WorldSocketMgr final : public Trinity::Net::SocketMgr<WorldSocket, WorldSocketThread>
{
using BaseSocketMgr = SocketMgr;
public:
~WorldSocketMgr();
@@ -44,7 +52,7 @@ public:
protected:
WorldSocketMgr();
Trinity::Net::NetworkThread<WorldSocket>* CreateThreads() const override;
WorldSocketThread* CreateThreads() const override;
private:
int32 _socketSystemSendBufferSize;