1
0

BlockingSslClientSocket: Migrated to cNetwork API.

This commit is contained in:
Mattes D
2015-01-24 23:17:13 +01:00
parent 7dfeb67f01
commit 86f2f82d2a
3 changed files with 198 additions and 15 deletions

View File

@@ -10,6 +10,80 @@
////////////////////////////////////////////////////////////////////////////////
// cBlockingSslClientSocketConnectCallbacks:
class cBlockingSslClientSocketConnectCallbacks:
public cNetwork::cConnectCallbacks
{
/** The socket object that is using this instance of the callbacks. */
cBlockingSslClientSocket & m_Socket;
virtual void OnConnected(cTCPLink & a_Link) override
{
m_Socket.OnConnected();
}
virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override
{
m_Socket.OnConnectError(a_ErrorMsg);
}
public:
cBlockingSslClientSocketConnectCallbacks(cBlockingSslClientSocket & a_Socket):
m_Socket(a_Socket)
{
}
};
////////////////////////////////////////////////////////////////////////////////
// cBlockingSslClientSocketLinkCallbacks:
class cBlockingSslClientSocketLinkCallbacks:
public cTCPLink::cCallbacks
{
cBlockingSslClientSocket & m_Socket;
virtual void OnLinkCreated(cTCPLinkPtr a_Link) override
{
m_Socket.SetLink(a_Link);
}
virtual void OnReceivedData(const char * a_Data, size_t a_Length)
{
m_Socket.OnReceivedData(a_Data, a_Length);
}
virtual void OnRemoteClosed(void)
{
m_Socket.OnDisconnected();
}
virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg)
{
m_Socket.OnDisconnected();
}
public:
cBlockingSslClientSocketLinkCallbacks(cBlockingSslClientSocket & a_Socket):
m_Socket(a_Socket)
{
}
};
////////////////////////////////////////////////////////////////////////////////
// cBlockingSslClientSocket:
cBlockingSslClientSocket::cBlockingSslClientSocket(void) :
m_Ssl(*this),
m_IsConnected(false)
@@ -32,10 +106,19 @@ bool cBlockingSslClientSocket::Connect(const AString & a_ServerName, UInt16 a_Po
}
// Connect the underlying socket:
m_Socket.CreateSocket(cSocket::IPv4);
if (!m_Socket.ConnectIPv4(a_ServerName.c_str(), a_Port))
m_ServerName = a_ServerName;
if (!cNetwork::Connect(a_ServerName, a_Port,
std::make_shared<cBlockingSslClientSocketConnectCallbacks>(*this),
std::make_shared<cBlockingSslClientSocketLinkCallbacks>(*this))
)
{
return false;
}
// Wait for the connection to succeed or fail:
m_Event.Wait();
if (!m_IsConnected)
{
Printf(m_LastErrorText, "Socket connect failed: %s", m_Socket.GetLastErrorString().c_str());
return false;
}
@@ -102,7 +185,7 @@ bool cBlockingSslClientSocket::Send(const void * a_Data, size_t a_NumBytes)
ASSERT(m_IsConnected);
// Keep sending the data until all of it is sent:
const char * Data = (const char *)a_Data;
const char * Data = reinterpret_cast<const char *>(a_Data);
size_t NumBytes = a_NumBytes;
for (;;)
{
@@ -156,7 +239,8 @@ void cBlockingSslClientSocket::Disconnect(void)
}
m_Ssl.NotifyClose();
m_Socket.CloseSocket();
m_Socket->Close();
m_Socket.reset();
m_IsConnected = false;
}
@@ -166,13 +250,25 @@ void cBlockingSslClientSocket::Disconnect(void)
int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes)
{
int res = m_Socket.Receive((char *)a_Buffer, a_NumBytes, 0);
if (res < 0)
// Wait for any incoming data, if there is none:
cCSLock Lock(m_CSIncomingData);
while (m_IsConnected && m_IncomingData.empty())
{
cCSUnlock Unlock(Lock);
m_Event.Wait();
}
// If we got disconnected, report an error after processing all data:
if (!m_IsConnected && m_IncomingData.empty())
{
// PolarSSL's net routines distinguish between connection reset and general failure, we don't need to
return POLARSSL_ERR_NET_RECV_FAILED;
}
return res;
// Copy the data from the incoming buffer into the specified space:
size_t NumToCopy = std::min(a_NumBytes, m_IncomingData.size());
memcpy(a_Buffer, m_IncomingData.data(), NumToCopy);
m_IncomingData.erase(0, NumToCopy);
return static_cast<int>(NumToCopy);
}
@@ -181,13 +277,69 @@ int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t
int cBlockingSslClientSocket::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes)
{
int res = m_Socket.Send((const char *)a_Buffer, a_NumBytes);
if (res < 0)
cTCPLinkPtr Socket(m_Socket); // Make a copy so that multiple threads don't race on deleting the socket.
if (Socket == nullptr)
{
return POLARSSL_ERR_NET_SEND_FAILED;
}
if (!Socket->Send(a_Buffer, a_NumBytes))
{
// PolarSSL's net routines distinguish between connection reset and general failure, we don't need to
return POLARSSL_ERR_NET_SEND_FAILED;
}
return res;
return static_cast<int>(a_NumBytes);
}
void cBlockingSslClientSocket::OnConnected(void)
{
m_IsConnected = true;
m_Event.Set();
}
void cBlockingSslClientSocket::OnConnectError(const AString & a_ErrorMsg)
{
LOG("Cannot connect to %s: %s", m_ServerName.c_str(), a_ErrorMsg.c_str());
m_Event.Set();
}
void cBlockingSslClientSocket::OnReceivedData(const char * a_Data, size_t a_Size)
{
{
cCSLock Lock(m_CSIncomingData);
m_IncomingData.append(a_Data, a_Size);
}
m_Event.Set();
}
void cBlockingSslClientSocket::SetLink(cTCPLinkPtr a_Link)
{
m_Socket = a_Link;
}
void cBlockingSslClientSocket::OnDisconnected(void)
{
m_Socket.reset();
m_IsConnected = false;
m_Event.Set();
}