diff --git a/libfreerdp/core/gateway/ncacn_http.c b/libfreerdp/core/gateway/ncacn_http.c
index cffb378..f288a0f 100644
--- a/libfreerdp/core/gateway/ncacn_http.c
+++ b/libfreerdp/core/gateway/ncacn_http.c
@@ -121,6 +121,7 @@ BOOL rpc_ncacn_http_recv_in_channel_response(RpcChannel* inChannel, HttpResponse
if (ntlmTokenData && ntlmTokenLength)
return ntlm_client_set_input_buffer(ntlm, FALSE, ntlmTokenData, ntlmTokenLength);
+ free(ntlmTokenData);
return TRUE;
}
@@ -274,5 +275,6 @@ BOOL rpc_ncacn_http_recv_out_channel_response(RpcChannel* outChannel, HttpRespon
if (ntlmTokenData && ntlmTokenLength)
return ntlm_client_set_input_buffer(ntlm, FALSE, ntlmTokenData, ntlmTokenLength);
+ free(ntlmTokenData);
return TRUE;
}
diff --git a/libfreerdp/core/gateway/rdg.c b/libfreerdp/core/gateway/rdg.c
index 6ea9e4f..35575db 100644
--- a/libfreerdp/core/gateway/rdg.c
+++ b/libfreerdp/core/gateway/rdg.c
@@ -1159,6 +1159,9 @@ static BOOL rdg_tunnel_connect(rdpRdg* rdg)
if (!status)
{
+ assert(rdg);
+ assert(rdg->context);
+ assert(rdg->context->rdp);
rdg->context->rdp->transport->layer = TRANSPORT_LAYER_CLOSED;
return FALSE;
}
@@ -1190,6 +1193,9 @@ BOOL rdg_connect(rdpRdg* rdg, int timeout, BOOL* rpcFallback)
if (!status)
{
+ assert(rdg);
+ assert(rdg->context);
+ assert(rdg->context->rdp);
rdg->context->rdp->transport->layer = TRANSPORT_LAYER_CLOSED;
return FALSE;
}
@@ -1535,10 +1541,10 @@ static int rdg_bio_gets(BIO* bio, char* str, int size)
return -2;
}
-static long rdg_bio_ctrl(BIO* bio, int cmd, long arg1, void* arg2)
+static long rdg_bio_ctrl(BIO* in_bio, int cmd, long arg1, void* arg2)
{
long status = -1;
- rdpRdg* rdg = (rdpRdg*)BIO_get_data(bio);
+ rdpRdg* rdg = (rdpRdg*)BIO_get_data(in_bio);
rdpTls* tlsOut = rdg->tlsOut;
rdpTls* tlsIn = rdg->tlsIn;
diff --git a/libfreerdp/core/gateway/rpc.c b/libfreerdp/core/gateway/rpc.c
index 0b47024..4ba52f3 100644
--- a/libfreerdp/core/gateway/rpc.c
+++ b/libfreerdp/core/gateway/rpc.c
@@ -24,6 +24,7 @@
#endif
#include <winpr/crt.h>
+#include <assert.h>
#include <winpr/tchar.h>
#include <winpr/synch.h>
#include <winpr/dsparse.h>
@@ -46,6 +47,7 @@
#include "rpc_client.h"
#include "rpc.h"
+#include "rts.h"
#define TAG FREERDP_TAG("core.gateway.rpc")
@@ -88,8 +90,10 @@ static const char* PTYPE_STRINGS[] = { "PTYPE_REQUEST", "PTYPE_PING",
*
*/
-void rpc_pdu_header_print(rpcconn_hdr_t* header)
+void rpc_pdu_header_print(const rpcconn_hdr_t* header)
{
+ assert(header);
+
WLog_INFO(TAG, "rpc_vers: %" PRIu8 "", header->common.rpc_vers);
WLog_INFO(TAG, "rpc_vers_minor: %" PRIu8 "", header->common.rpc_vers_minor);
@@ -139,26 +143,30 @@ void rpc_pdu_header_print(rpcconn_hdr_t* header)
}
}
-void rpc_pdu_header_init(rdpRpc* rpc, rpcconn_hdr_t* header)
+rpcconn_common_hdr_t rpc_pdu_header_init(const rdpRpc* rpc)
{
- header->common.rpc_vers = rpc->rpc_vers;
- header->common.rpc_vers_minor = rpc->rpc_vers_minor;
- header->common.packed_drep[0] = rpc->packed_drep[0];
- header->common.packed_drep[1] = rpc->packed_drep[1];
- header->common.packed_drep[2] = rpc->packed_drep[2];
- header->common.packed_drep[3] = rpc->packed_drep[3];
+ rpcconn_common_hdr_t header = { 0 };
+ assert(rpc);
+
+ header.rpc_vers = rpc->rpc_vers;
+ header.rpc_vers_minor = rpc->rpc_vers_minor;
+ header.packed_drep[0] = rpc->packed_drep[0];
+ header.packed_drep[1] = rpc->packed_drep[1];
+ header.packed_drep[2] = rpc->packed_drep[2];
+ header.packed_drep[3] = rpc->packed_drep[3];
+ return header;
}
-UINT32 rpc_offset_align(UINT32* offset, UINT32 alignment)
+size_t rpc_offset_align(size_t* offset, size_t alignment)
{
- UINT32 pad;
+ size_t pad;
pad = *offset;
*offset = (*offset + alignment - 1) & ~(alignment - 1);
pad = *offset - pad;
return pad;
}
-UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad)
+size_t rpc_offset_pad(size_t* offset, size_t pad)
{
*offset += pad;
return pad;
@@ -239,64 +247,67 @@ UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad)
*
*/
-BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* length)
+BOOL rpc_get_stub_data_info(const rpcconn_hdr_t* header, size_t* poffset, size_t* length)
{
- UINT32 alloc_hint = 0;
- rpcconn_hdr_t* header;
+ size_t used = 0;
+ size_t offset = 0;
+ BOOL rc = FALSE;
UINT32 frag_length;
UINT32 auth_length;
- UINT32 auth_pad_length;
+ UINT32 auth_pad_length = 0;
UINT32 sec_trailer_offset;
- rpc_sec_trailer* sec_trailer;
- *offset = RPC_COMMON_FIELDS_LENGTH;
- header = ((rpcconn_hdr_t*)buffer);
+ const rpc_sec_trailer* sec_trailer = NULL;
+
+ assert(header);
+ assert(poffset);
+ assert(length);
+
+ offset = RPC_COMMON_FIELDS_LENGTH;
switch (header->common.ptype)
{
case PTYPE_RESPONSE:
- *offset += 8;
- rpc_offset_align(offset, 8);
- alloc_hint = header->response.alloc_hint;
+ offset += 8;
+ rpc_offset_align(&offset, 8);
+ sec_trailer = &header->response.auth_verifier;
break;
case PTYPE_REQUEST:
- *offset += 4;
- rpc_offset_align(offset, 8);
- alloc_hint = header->request.alloc_hint;
+ offset += 4;
+ rpc_offset_align(&offset, 8);
+ sec_trailer = &header->request.auth_verifier;
break;
case PTYPE_RTS:
- *offset += 4;
+ offset += 4;
break;
default:
WLog_ERR(TAG, "Unknown PTYPE: 0x%02" PRIX8 "", header->common.ptype);
- return FALSE;
+ goto fail;
}
- if (!length)
- return TRUE;
+ frag_length = header->common.frag_length;
+ auth_length = header->common.auth_length;
+
+ if (poffset)
+ *poffset = offset;
- if (header->common.ptype == PTYPE_REQUEST)
+ /* The fragment must be larger than the authentication trailer */
+ used = offset + auth_length + 8ull;
+ if (sec_trailer)
{
- UINT32 sec_trailer_offset;
- sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8;
- *length = sec_trailer_offset - *offset;
- return TRUE;
+ auth_pad_length = sec_trailer->auth_pad_length;
+ used += sec_trailer->auth_pad_length;
}
- frag_length = header->common.frag_length;
- auth_length = header->common.auth_length;
+ if (frag_length < used)
+ goto fail;
+
+ if (!length)
+ return TRUE;
+
sec_trailer_offset = frag_length - auth_length - 8;
- sec_trailer = (rpc_sec_trailer*)&buffer[sec_trailer_offset];
- auth_pad_length = sec_trailer->auth_pad_length;
-#if 0
- WLog_DBG(TAG,
- "sec_trailer: type: %"PRIu8" level: %"PRIu8" pad_length: %"PRIu8" reserved: %"PRIu8" context_id: %"PRIu32"",
- sec_trailer->auth_type, sec_trailer->auth_level,
- sec_trailer->auth_pad_length, sec_trailer->auth_reserved,
- sec_trailer->auth_context_id);
-#endif
/**
* According to [MS-RPCE], auth_pad_length is the number of padding
@@ -310,18 +321,21 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l
auth_length, (frag_length - (sec_trailer_offset + 8)));
}
- *length = frag_length - auth_length - 24 - 8 - auth_pad_length;
- return TRUE;
+ *length = sec_trailer_offset - auth_pad_length - offset;
+
+ rc = TRUE;
+fail:
+ return rc;
}
SSIZE_T rpc_channel_read(RpcChannel* channel, wStream* s, size_t length)
{
int status;
- if (!channel)
+ if (!channel || (length > INT32_MAX))
return -1;
- status = BIO_read(channel->tls->bio, Stream_Pointer(s), length);
+ status = BIO_read(channel->tls->bio, Stream_Pointer(s), (INT32)length);
if (status > 0)
{
@@ -340,10 +354,10 @@ SSIZE_T rpc_channel_write(RpcChannel* channel, const BYTE* data, size_t length)
{
int status;
- if (!channel)
+ if (!channel || (length > INT32_MAX))
return -1;
- status = tls_write_all(channel->tls, data, length);
+ status = tls_write_all(channel->tls, data, (INT32)length);
return status;
}
@@ -629,7 +643,7 @@ static void rpc_virtual_connection_free(RpcVirtualConnection* connection)
free(connection);
}
-static BOOL rpc_channel_tls_connect(RpcChannel* channel, int timeout)
+static BOOL rpc_channel_tls_connect(RpcChannel* channel, UINT32 timeout)
{
int sockfd;
rdpTls* tls;
@@ -719,7 +733,7 @@ static BOOL rpc_channel_tls_connect(RpcChannel* channel, int timeout)
return TRUE;
}
-static int rpc_in_channel_connect(RpcInChannel* inChannel, int timeout)
+static int rpc_in_channel_connect(RpcInChannel* inChannel, UINT32 timeout)
{
rdpContext* context;
@@ -814,7 +828,7 @@ int rpc_out_channel_replacement_connect(RpcOutChannel* outChannel, int timeout)
return 1;
}
-BOOL rpc_connect(rdpRpc* rpc, int timeout)
+BOOL rpc_connect(rdpRpc* rpc, UINT32 timeout)
{
RpcInChannel* inChannel;
RpcOutChannel* outChannel;
@@ -840,7 +854,15 @@ BOOL rpc_connect(rdpRpc* rpc, int timeout)
rdpRpc* rpc_new(rdpTransport* transport)
{
- rdpRpc* rpc = (rdpRpc*)calloc(1, sizeof(rdpRpc));
+ rdpContext* context;
+ rdpRpc* rpc;
+
+ assert(transport);
+
+ context = transport->context;
+ assert(context);
+
+ rpc = (rdpRpc*)calloc(1, sizeof(rdpRpc));
if (!rpc)
return NULL;
@@ -848,7 +870,7 @@ rdpRpc* rpc_new(rdpTransport* transport)
rpc->State = RPC_CLIENT_STATE_INITIAL;
rpc->transport = transport;
rpc->settings = transport->settings;
- rpc->context = transport->context;
+ rpc->context = context;
rpc->SendSeqNum = 0;
rpc->ntlm = ntlm_new();
@@ -873,7 +895,7 @@ rdpRpc* rpc_new(rdpTransport* transport)
rpc->CurrentKeepAliveInterval = rpc->KeepAliveInterval;
rpc->CurrentKeepAliveTime = 0;
rpc->CallId = 2;
- rpc->client = rpc_client_new(rpc->context, rpc->max_recv_frag);
+ rpc->client = rpc_client_new(context, rpc->max_recv_frag);
if (!rpc->client)
goto out_free;
diff --git a/libfreerdp/core/gateway/rpc.h b/libfreerdp/core/gateway/rpc.h
index 5c315d9..295085b 100644
--- a/libfreerdp/core/gateway/rpc.h
+++ b/libfreerdp/core/gateway/rpc.h
@@ -46,14 +46,6 @@ typedef struct rdp_rpc rdpRpc;
#define RPC_COMMON_FIELDS_LENGTH 16
-typedef struct
-{
- DEFINE_RPC_COMMON_FIELDS();
-
- UINT16 Flags;
- UINT16 NumberOfCommands;
-} rpcconn_rts_hdr_t;
-
#define RTS_PDU_HEADER_LENGTH 20
#define RPC_PDU_FLAG_STUB 0x00000001
@@ -71,7 +63,6 @@ typedef struct _RPC_PDU
#include "../tcp.h"
#include "../transport.h"
-#include "rts.h"
#include "http.h"
#include "ntlm.h"
@@ -146,6 +137,14 @@ typedef struct
DEFINE_RPC_COMMON_FIELDS();
} rpcconn_common_hdr_t;
+typedef struct
+{
+ rpcconn_common_hdr_t header;
+
+ UINT16 Flags;
+ UINT16 NumberOfCommands;
+} rpcconn_rts_hdr_t;
+
typedef UINT16 p_context_id_t;
typedef UINT16 p_reject_reason_t;
@@ -314,7 +313,7 @@ typedef struct auth_verifier_co_s auth_verifier_co_t;
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
UINT16 max_xmit_frag;
UINT16 max_recv_frag;
@@ -328,7 +327,7 @@ typedef struct
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
UINT16 max_xmit_frag;
UINT16 max_recv_frag;
@@ -345,7 +344,7 @@ typedef struct
/* bind header */
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
UINT16 max_xmit_frag;
UINT16 max_recv_frag;
@@ -358,7 +357,7 @@ typedef struct
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
UINT16 max_xmit_frag;
UINT16 max_recv_frag;
@@ -375,7 +374,7 @@ typedef struct
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
UINT16 max_xmit_frag;
UINT16 max_recv_frag;
@@ -385,7 +384,7 @@ typedef struct
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
p_reject_reason_t provider_reject_reason;
@@ -394,7 +393,7 @@ typedef struct
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
auth_verifier_co_t auth_verifier;
@@ -460,7 +459,7 @@ typedef struct _RPC_FAULT_CODE RPC_FAULT_CODE;
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
UINT32 alloc_hint;
p_context_id_t p_cont_id;
@@ -479,14 +478,14 @@ typedef struct
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
auth_verifier_co_t auth_verifier;
} rpcconn_orphaned_hdr_t;
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
UINT32 alloc_hint;
@@ -505,7 +504,7 @@ typedef struct
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
UINT32 alloc_hint;
p_context_id_t p_cont_id;
@@ -522,10 +521,11 @@ typedef struct
typedef struct
{
- DEFINE_RPC_COMMON_FIELDS();
+ rpcconn_common_hdr_t header;
} rpcconn_shutdown_hdr_t;
-typedef union {
+typedef union
+{
rpcconn_common_hdr_t common;
rpcconn_alter_context_hdr_t alter_context;
rpcconn_alter_context_response_hdr_t alter_context_response;
@@ -767,14 +767,14 @@ struct rdp_rpc
RpcVirtualConnection* VirtualConnection;
};
-FREERDP_LOCAL void rpc_pdu_header_print(rpcconn_hdr_t* header);
-FREERDP_LOCAL void rpc_pdu_header_init(rdpRpc* rpc, rpcconn_hdr_t* header);
+FREERDP_LOCAL void rpc_pdu_header_print(const rpcconn_hdr_t* header);
+FREERDP_LOCAL rpcconn_common_hdr_t rpc_pdu_header_init(const rdpRpc* rpc);
-FREERDP_LOCAL UINT32 rpc_offset_align(UINT32* offset, UINT32 alignment);
-FREERDP_LOCAL UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad);
+FREERDP_LOCAL size_t rpc_offset_align(size_t* offset, size_t alignment);
+FREERDP_LOCAL size_t rpc_offset_pad(size_t* offset, size_t pad);
-FREERDP_LOCAL BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset,
- UINT32* length);
+FREERDP_LOCAL BOOL rpc_get_stub_data_info(const rpcconn_hdr_t* header, size_t* offset,
+ size_t* length);
FREERDP_LOCAL SSIZE_T rpc_channel_write(RpcChannel* channel, const BYTE* data, size_t length);
@@ -794,7 +794,7 @@ FREERDP_LOCAL BOOL rpc_virtual_connection_transition_to_state(rdpRpc* rpc,
RpcVirtualConnection* connection,
VIRTUAL_CONNECTION_STATE state);
-FREERDP_LOCAL BOOL rpc_connect(rdpRpc* rpc, int timeout);
+FREERDP_LOCAL BOOL rpc_connect(rdpRpc* rpc, UINT32 timeout);
FREERDP_LOCAL rdpRpc* rpc_new(rdpTransport* transport);
FREERDP_LOCAL void rpc_free(rdpRpc* rpc);
diff --git a/libfreerdp/core/gateway/rpc_bind.c b/libfreerdp/core/gateway/rpc_bind.c
index 4cfd022..5bc227b 100644
--- a/libfreerdp/core/gateway/rpc_bind.c
+++ b/libfreerdp/core/gateway/rpc_bind.c
@@ -22,11 +22,14 @@
#endif
#include <winpr/crt.h>
+#include <assert.h>
#include <freerdp/log.h>
#include "rpc_client.h"
+#include "rts.h"
+
#include "rpc_bind.h"
#define TAG FREERDP_TAG("core.gateway.rpc")
@@ -106,18 +109,32 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
{
BOOL continueNeeded = FALSE;
int status = -1;
- BYTE* buffer = NULL;
+ wStream* buffer = NULL;
UINT32 offset;
- UINT32 length;
RpcClientCall* clientCall;
p_cont_elem_t* p_cont_elem;
- rpcconn_bind_hdr_t* bind_pdu = NULL;
+ rpcconn_bind_hdr_t bind_pdu = { 0 };
BOOL promptPassword = FALSE;
- rdpSettings* settings = rpc->settings;
- freerdp* instance = (freerdp*)settings->instance;
- RpcVirtualConnection* connection = rpc->VirtualConnection;
- RpcInChannel* inChannel = connection->DefaultInChannel;
+ rdpSettings* settings;
+ freerdp* instance;
+ RpcVirtualConnection* connection;
+ RpcInChannel* inChannel;
const SecBuffer* sbuffer = NULL;
+
+ assert(rpc);
+
+ settings = rpc->settings;
+ assert(settings);
+
+ instance = (freerdp*)settings->instance;
+ assert(instance);
+
+ connection = rpc->VirtualConnection;
+
+ assert(connection);
+
+ inChannel = connection->DefaultInChannel;
+
WLog_DBG(TAG, "Sending Bind PDU");
ntlm_free(rpc->ntlm);
rpc->ntlm = ntlm_new();
@@ -180,35 +197,30 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
if (!continueNeeded)
goto fail;
- bind_pdu = (rpcconn_bind_hdr_t*)calloc(1, sizeof(rpcconn_bind_hdr_t));
-
- if (!bind_pdu)
- goto fail;
-
sbuffer = ntlm_client_get_output_buffer(rpc->ntlm);
if (!sbuffer)
goto fail;
- rpc_pdu_header_init(rpc, (rpcconn_hdr_t*)bind_pdu);
- bind_pdu->auth_length = (UINT16)sbuffer->cbBuffer;
- bind_pdu->auth_verifier.auth_value = sbuffer->pvBuffer;
- bind_pdu->ptype = PTYPE_BIND;
- bind_pdu->pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG | PFC_SUPPORT_HEADER_SIGN | PFC_CONC_MPX;
- bind_pdu->call_id = 2;
- bind_pdu->max_xmit_frag = rpc->max_xmit_frag;
- bind_pdu->max_recv_frag = rpc->max_recv_frag;
- bind_pdu->assoc_group_id = 0;
- bind_pdu->p_context_elem.n_context_elem = 2;
- bind_pdu->p_context_elem.reserved = 0;
- bind_pdu->p_context_elem.reserved2 = 0;
- bind_pdu->p_context_elem.p_cont_elem =
- calloc(bind_pdu->p_context_elem.n_context_elem, sizeof(p_cont_elem_t));
-
- if (!bind_pdu->p_context_elem.p_cont_elem)
+ bind_pdu.header = rpc_pdu_header_init(rpc);
+ bind_pdu.header.auth_length = (UINT16)sbuffer->cbBuffer;
+ bind_pdu.auth_verifier.auth_value = sbuffer->pvBuffer;
+ bind_pdu.header.ptype = PTYPE_BIND;
+ bind_pdu.header.pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG | PFC_SUPPORT_HEADER_SIGN | PFC_CONC_MPX;
+ bind_pdu.header.call_id = 2;
+ bind_pdu.max_xmit_frag = rpc->max_xmit_frag;
+ bind_pdu.max_recv_frag = rpc->max_recv_frag;
+ bind_pdu.assoc_group_id = 0;
+ bind_pdu.p_context_elem.n_context_elem = 2;
+ bind_pdu.p_context_elem.reserved = 0;
+ bind_pdu.p_context_elem.reserved2 = 0;
+ bind_pdu.p_context_elem.p_cont_elem =
+ calloc(bind_pdu.p_context_elem.n_context_elem, sizeof(p_cont_elem_t));
+
+ if (!bind_pdu.p_context_elem.p_cont_elem)
goto fail;
- p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0];
+ p_cont_elem = &bind_pdu.p_context_elem.p_cont_elem[0];
p_cont_elem->p_cont_id = 0;
p_cont_elem->n_transfer_syn = 1;
p_cont_elem->reserved = 0;
@@ -221,7 +233,7 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
CopyMemory(&(p_cont_elem->transfer_syntaxes[0].if_uuid), &NDR_UUID, sizeof(p_uuid_t));
p_cont_elem->transfer_syntaxes[0].if_version = NDR_SYNTAX_IF_VERSION;
- p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[1];
+ p_cont_elem = &bind_pdu.p_context_elem.p_cont_elem[1];
p_cont_elem->p_cont_id = 1;
p_cont_elem->n_transfer_syn = 1;
p_cont_elem->reserved = 0;
@@ -235,31 +247,22 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
CopyMemory(&(p_cont_elem->transfer_syntaxes[0].if_uuid), &BTFN_UUID, sizeof(p_uuid_t));
p_cont_elem->transfer_syntaxes[0].if_version = BTFN_SYNTAX_IF_VERSION;
offset = 116;
- bind_pdu->auth_verifier.auth_pad_length = rpc_offset_align(&offset, 4);
- bind_pdu->auth_verifier.auth_type = RPC_C_AUTHN_WINNT;
- bind_pdu->auth_verifier.auth_level = RPC_C_AUTHN_LEVEL_PKT_INTEGRITY;
- bind_pdu->auth_verifier.auth_reserved = 0x00;
- bind_pdu->auth_verifier.auth_context_id = 0x00000000;
- offset += (8 + bind_pdu->auth_length);
- bind_pdu->frag_length = offset;
- buffer = (BYTE*)malloc(bind_pdu->frag_length);
+ bind_pdu.auth_verifier.auth_type = RPC_C_AUTHN_WINNT;
+ bind_pdu.auth_verifier.auth_level = RPC_C_AUTHN_LEVEL_PKT_INTEGRITY;
+ bind_pdu.auth_verifier.auth_reserved = 0x00;
+ bind_pdu.auth_verifier.auth_context_id = 0x00000000;
+ offset += (8 + bind_pdu.header.auth_length);
+ bind_pdu.header.frag_length = offset;
+
+ buffer = Stream_New(NULL, bind_pdu.header.frag_length);
if (!buffer)
goto fail;
- CopyMemory(buffer, bind_pdu, 24);
- CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4);
- CopyMemory(&buffer[28], &bind_pdu->p_context_elem.p_cont_elem[0], 24);
- CopyMemory(&buffer[52], bind_pdu->p_context_elem.p_cont_elem[0].transfer_syntaxes, 20);
- CopyMemory(&buffer[72], &bind_pdu->p_context_elem.p_cont_elem[1], 24);
- CopyMemory(&buffer[96], bind_pdu->p_context_elem.p_cont_elem[1].transfer_syntaxes, 20);
- offset = 116;
- rpc_offset_pad(&offset, bind_pdu->auth_verifier.auth_pad_length);
- CopyMemory(&buffer[offset], &bind_pdu->auth_verifier.auth_type, 8);
- CopyMemory(&buffer[offset + 8], bind_pdu->auth_verifier.auth_value, bind_pdu->auth_length);
- offset += (8 + bind_pdu->auth_length);
- length = bind_pdu->frag_length;
- clientCall = rpc_client_call_new(bind_pdu->call_id, 0);
+ if (!rts_write_pdu_bind(buffer, &bind_pdu))
+ goto fail;
+
+ clientCall = rpc_client_call_new(bind_pdu.header.call_id, 0);
if (!clientCall)
goto fail;
@@ -270,22 +273,19 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
goto fail;
}
- status = rpc_in_channel_send_pdu(inChannel, buffer, length);
+ Stream_SealLength(buffer);
+ status = rpc_in_channel_send_pdu(inChannel, Stream_Buffer(buffer), Stream_Length(buffer));
fail:
- if (bind_pdu)
+ if (bind_pdu.p_context_elem.p_cont_elem)
{
- if (bind_pdu->p_context_elem.p_cont_elem)
- {
- free(bind_pdu->p_context_elem.p_cont_elem[0].transfer_syntaxes);
- free(bind_pdu->p_context_elem.p_cont_elem[1].transfer_syntaxes);
- }
-
- free(bind_pdu->p_context_elem.p_cont_elem);
+ free(bind_pdu.p_context_elem.p_cont_elem[0].transfer_syntaxes);
+ free(bind_pdu.p_context_elem.p_cont_elem[1].transfer_syntaxes);
}
- free(bind_pdu);
- free(buffer);
+ free(bind_pdu.p_context_elem.p_cont_elem);
+
+ Stream_Free(buffer, TRUE);
return (status > 0) ? 1 : -1;
}
@@ -315,31 +315,47 @@ fail:
* example.
*/
-int rpc_recv_bind_ack_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+BOOL rpc_recv_bind_ack_pdu(rdpRpc* rpc, wStream* s)
{
+ BOOL rc = FALSE;
BOOL continueNeeded = FALSE;
- BYTE* auth_data;
- rpcconn_hdr_t* header;
- header = (rpcconn_hdr_t*)buffer;
+ const BYTE* auth_data;
+ size_t pos, end;
+ rpcconn_hdr_t header = { 0 };
+
+ assert(rpc);
+ assert(rpc->ntlm);
+ assert(s);
+
+ pos = Stream_GetPosition(s);
+ if (!rts_read_pdu_header(s, &header))
+ goto fail;
+
WLog_DBG(TAG, "Receiving BindAck PDU");
- if (!rpc || !rpc->ntlm)
- return -1;
+ rpc->max_recv_frag = header.bind_ack.max_xmit_frag;
+ rpc->max_xmit_frag = header.bind_ack.max_recv_frag;
- rpc->max_recv_frag = header->bind_ack.max_xmit_frag;
- rpc->max_xmit_frag = header->bind_ack.max_recv_frag;
- auth_data = buffer + (header->common.frag_length - header->common.auth_length);
+ /* Get the correct offset in the input data and pass that on as input buffer.
+ * rts_read_pdu_header did already do consistency checks */
+ end = Stream_GetPosition(s);
+ Stream_SetPosition(s, pos + header.common.frag_length - header.common.auth_length);
+ auth_data = Stream_Pointer(s);
+ Stream_SetPosition(s, end);
- if (!ntlm_client_set_input_buffer(rpc->ntlm, TRUE, auth_data, header->common.auth_length))
- return -1;
+ if (!ntlm_client_set_input_buffer(rpc->ntlm, TRUE, auth_data, header.common.auth_length))
+ goto fail;
if (!ntlm_authenticate(rpc->ntlm, &continueNeeded))
- return -1;
+ goto fail;
if (continueNeeded)
- return -1;
+ goto fail;
- return (int)length;
+ rc = TRUE;
+fail:
+ rts_free_pdu_header(&header, FALSE);
+ return rc;
}
/**
@@ -352,67 +368,63 @@ int rpc_recv_bind_ack_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
int rpc_send_rpc_auth_3_pdu(rdpRpc* rpc)
{
int status = -1;
- BYTE* buffer;
- UINT32 offset;
- UINT32 length;
+ wStream* buffer;
+ size_t offset;
const SecBuffer* sbuffer;
RpcClientCall* clientCall;
- rpcconn_rpc_auth_3_hdr_t* auth_3_pdu;
- RpcVirtualConnection* connection = rpc->VirtualConnection;
- RpcInChannel* inChannel = connection->DefaultInChannel;
- WLog_DBG(TAG, "Sending RpcAuth3 PDU");
- auth_3_pdu = (rpcconn_rpc_auth_3_hdr_t*)calloc(1, sizeof(rpcconn_rpc_auth_3_hdr_t));
+ rpcconn_rpc_auth_3_hdr_t auth_3_pdu = { 0 };
+ RpcVirtualConnection* connection;
+ RpcInChannel* inChannel;
- if (!auth_3_pdu)
- return -1;
+ assert(rpc);
+
+ connection = rpc->VirtualConnection;
+ assert(connection);
+
+ inChannel = connection->DefaultInChannel;
+ assert(inChannel);
+
+ WLog_DBG(TAG, "Sending RpcAuth3 PDU");
sbuffer = ntlm_client_get_output_buffer(rpc->ntlm);
if (!sbuffer)
- {
- free(auth_3_pdu);
return -1;
- }
- rpc_pdu_header_init(rpc, (rpcconn_hdr_t*)auth_3_pdu);
- auth_3_pdu->auth_length = (UINT16)sbuffer->cbBuffer;
- auth_3_pdu->auth_verifier.auth_value = sbuffer->pvBuffer;
- auth_3_pdu->ptype = PTYPE_RPC_AUTH_3;
- auth_3_pdu->pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG | PFC_CONC_MPX;
- auth_3_pdu->call_id = 2;
- auth_3_pdu->max_xmit_frag = rpc->max_xmit_frag;
- auth_3_pdu->max_recv_frag = rpc->max_recv_frag;
+ auth_3_pdu.header = rpc_pdu_header_init(rpc);
+ auth_3_pdu.header.auth_length = (UINT16)sbuffer->cbBuffer;
+ auth_3_pdu.auth_verifier.auth_value = sbuffer->pvBuffer;
+ auth_3_pdu.header.ptype = PTYPE_RPC_AUTH_3;
+ auth_3_pdu.header.pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG | PFC_CONC_MPX;
+ auth_3_pdu.header.call_id = 2;
+ auth_3_pdu.max_xmit_frag = rpc->max_xmit_frag;
+ auth_3_pdu.max_recv_frag = rpc->max_recv_frag;
offset = 20;
- auth_3_pdu->auth_verifier.auth_pad_length = rpc_offset_align(&offset, 4);
- auth_3_pdu->auth_verifier.auth_type = RPC_C_AUTHN_WINNT;
- auth_3_pdu->auth_verifier.auth_level = RPC_C_AUTHN_LEVEL_PKT_INTEGRITY;
- auth_3_pdu->auth_verifier.auth_reserved = 0x00;
- auth_3_pdu->auth_verifier.auth_context_id = 0x00000000;
- offset += (8 + auth_3_pdu->auth_length);
- auth_3_pdu->frag_length = offset;
- buffer = (BYTE*)malloc(auth_3_pdu->frag_length);
+ auth_3_pdu.auth_verifier.auth_pad_length = rpc_offset_align(&offset, 4);
+ auth_3_pdu.auth_verifier.auth_type = RPC_C_AUTHN_WINNT;
+ auth_3_pdu.auth_verifier.auth_level = RPC_C_AUTHN_LEVEL_PKT_INTEGRITY;
+ auth_3_pdu.auth_verifier.auth_reserved = 0x00;
+ auth_3_pdu.auth_verifier.auth_context_id = 0x00000000;
+ offset += (8 + auth_3_pdu.header.auth_length);
+ auth_3_pdu.header.frag_length = offset;
+
+ buffer = Stream_New(NULL, auth_3_pdu.header.frag_length);
if (!buffer)
- {
- free(auth_3_pdu);
return -1;
- }
- CopyMemory(buffer, auth_3_pdu, 20);
- offset = 20;
- rpc_offset_pad(&offset, auth_3_pdu->auth_verifier.auth_pad_length);
- CopyMemory(&buffer[offset], &auth_3_pdu->auth_verifier.auth_type, 8);
- CopyMemory(&buffer[offset + 8], auth_3_pdu->auth_verifier.auth_value, auth_3_pdu->auth_length);
- offset += (8 + auth_3_pdu->auth_length);
- length = auth_3_pdu->frag_length;
- clientCall = rpc_client_call_new(auth_3_pdu->call_id, 0);
+ if (!rts_write_pdu_auth3(buffer, &auth_3_pdu))
+ goto fail;
+
+ clientCall = rpc_client_call_new(auth_3_pdu.header.call_id, 0);
if (ArrayList_Add(rpc->client->ClientCallList, clientCall) >= 0)
{
- status = rpc_in_channel_send_pdu(inChannel, buffer, length);
+ Stream_SealLength(buffer);
+ status = rpc_in_channel_send_pdu(inChannel, Stream_Buffer(buffer), Stream_Length(buffer));
}
- free(auth_3_pdu);
- free(buffer);
+fail:
+ Stream_Free(buffer, TRUE);
return (status > 0) ? 1 : -1;
}
diff --git a/libfreerdp/core/gateway/rpc_bind.h b/libfreerdp/core/gateway/rpc_bind.h
index 759555f..69758e5 100644
--- a/libfreerdp/core/gateway/rpc_bind.h
+++ b/libfreerdp/core/gateway/rpc_bind.h
@@ -35,7 +35,7 @@ FREERDP_LOCAL extern const p_uuid_t BTFN_UUID;
#define BTFN_SYNTAX_IF_VERSION 0x00000001
FREERDP_LOCAL int rpc_send_bind_pdu(rdpRpc* rpc);
-FREERDP_LOCAL int rpc_recv_bind_ack_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length);
+FREERDP_LOCAL BOOL rpc_recv_bind_ack_pdu(rdpRpc* rpc, wStream* s);
FREERDP_LOCAL int rpc_send_rpc_auth_3_pdu(rdpRpc* rpc);
#endif /* FREERDP_LIB_CORE_GATEWAY_RPC_BIND_H */
diff --git a/libfreerdp/core/gateway/rpc_client.c b/libfreerdp/core/gateway/rpc_client.c
index 23518f8..e6a89a6 100644
--- a/libfreerdp/core/gateway/rpc_client.c
+++ b/libfreerdp/core/gateway/rpc_client.c
@@ -24,6 +24,7 @@
#include <freerdp/log.h>
#include <winpr/crt.h>
+#include <assert.h>
#include <winpr/print.h>
#include <winpr/synch.h>
#include <winpr/thread.h>
@@ -35,6 +36,8 @@
#include "rpc_bind.h"
#include "rpc_fault.h"
#include "rpc_client.h"
+#include "rts_signature.h"
+
#include "../rdp.h"
#include "../proxy.h"
@@ -99,7 +102,7 @@ static int rpc_client_receive_pipe_write(RpcClient* client, const BYTE* buffer,
int rpc_client_receive_pipe_read(RpcClient* client, BYTE* buffer, size_t length)
{
int index = 0;
- int status = 0;
+ size_t status = 0;
int nchunks = 0;
DataChunk chunks[2];
@@ -122,7 +125,10 @@ int rpc_client_receive_pipe_read(RpcClient* client, BYTE* buffer, size_t length)
ResetEvent(client->PipeEvent);
LeaveCriticalSection(&(client->PipeLock));
- return status;
+
+ if (status > INT_MAX)
+ return -1;
+ return (int)status;
}
static int rpc_client_transition_to_state(rdpRpc* rpc, RPC_CLIENT_STATE state)
@@ -173,8 +179,15 @@ static int rpc_client_transition_to_state(rdpRpc* rpc, RPC_CLIENT_STATE state)
static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
{
int status = -1;
- rpcconn_rts_hdr_t* rts;
- rdpTsg* tsg = rpc->transport->tsg;
+ rdpTsg* tsg;
+
+ assert(rpc);
+ assert(pdu);
+
+ Stream_SealLength(pdu->s);
+ Stream_SetPosition(pdu->s, 0);
+
+ tsg = rpc->transport->tsg;
if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED)
{
@@ -187,17 +200,13 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
break;
case VIRTUAL_CONNECTION_STATE_WAIT_A3W:
- rts = (rpcconn_rts_hdr_t*)Stream_Buffer(pdu->s);
-
- if (!rts_match_pdu_signature(&RTS_PDU_CONN_A3_SIGNATURE, rts))
+ if (!rts_match_pdu_signature(&RTS_PDU_CONN_A3_SIGNATURE, pdu->s, NULL))
{
WLog_ERR(TAG, "unexpected RTS PDU: Expected CONN/A3");
return -1;
}
- status = rts_recv_CONN_A3_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
-
- if (status < 0)
+ if (!rts_recv_CONN_A3_pdu(rpc, pdu->s))
{
WLog_ERR(TAG, "rts_recv_CONN_A3_pdu failure");
return -1;
@@ -209,17 +218,13 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
break;
case VIRTUAL_CONNECTION_STATE_WAIT_C2:
- rts = (rpcconn_rts_hdr_t*)Stream_Buffer(pdu->s);
-
- if (!rts_match_pdu_signature(&RTS_PDU_CONN_C2_SIGNATURE, rts))
+ if (!rts_match_pdu_signature(&RTS_PDU_CONN_C2_SIGNATURE, pdu->s, NULL))
{
WLog_ERR(TAG, "unexpected RTS PDU: Expected CONN/C2");
return -1;
}
- status = rts_recv_CONN_C2_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
-
- if (status < 0)
+ if (!rts_recv_CONN_C2_pdu(rpc, pdu->s))
{
WLog_ERR(TAG, "rts_recv_CONN_C2_pdu failure");
return -1;
@@ -252,7 +257,7 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
{
if (pdu->Type == PTYPE_BIND_ACK)
{
- if (rpc_recv_bind_ack_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)) <= 0)
+ if (!rpc_recv_bind_ack_pdu(rpc, pdu->s))
{
WLog_ERR(TAG, "rpc_recv_bind_ack_pdu failure");
return -1;
@@ -301,89 +306,117 @@ static int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
static int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
{
- BYTE* buffer;
+ int rc = -1;
RPC_PDU* pdu;
- UINT32 StubOffset;
- UINT32 StubLength;
+ size_t StubOffset;
+ size_t StubLength;
RpcClientCall* call;
- rpcconn_hdr_t* header;
+ rpcconn_hdr_t header = { 0 };
+
+ assert(rpc);
+ assert(rpc->client);
+ assert(fragment);
+
pdu = rpc->client->pdu;
- buffer = (BYTE*)Stream_Buffer(fragment);
- header = (rpcconn_hdr_t*)Stream_Buffer(fragment);
+ assert(pdu);
+
+ Stream_SealLength(fragment);
+ Stream_SetPosition(fragment, 0);
+
+ if (!rts_read_pdu_header(fragment, &header))
+ goto fail;
- if (header->common.ptype == PTYPE_RESPONSE)
+ if (header.common.ptype == PTYPE_RESPONSE)
{
- rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length;
+ rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header.common.frag_length;
rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow -=
- header->common.frag_length;
+ header.common.frag_length;
if (rpc->VirtualConnection->DefaultOutChannel->ReceiverAvailableWindow <
(rpc->ReceiveWindow / 2))
{
- if (rts_send_flow_control_ack_pdu(rpc) < 0)
- return -1;
+ if (!rts_send_flow_control_ack_pdu(rpc))
+ goto fail;
}
- if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength))
+ if (!rpc_get_stub_data_info(&header, &StubOffset, &StubLength))
{
WLog_ERR(TAG, "expected stub");
- return -1;
+ goto fail;
}
if (StubLength == 4)
{
- if ((header->common.call_id == rpc->PipeCallId) &&
- (header->common.pfc_flags & PFC_LAST_FRAG))
+ if ((header.common.call_id == rpc->PipeCallId) &&
+ (header.common.pfc_flags & PFC_LAST_FRAG))
{
/* End of TsProxySetupReceivePipe */
TerminateEventArgs e;
- rpc->result = *((UINT32*)&buffer[StubOffset]);
- freerdp_abort_connect(rpc->context->instance);
- tsg_set_state(rpc->transport->tsg, TSG_STATE_TUNNEL_CLOSE_PENDING);
+ rdpContext* context = rpc->transport->context;
+ rdpTsg* tsg = rpc->transport->tsg;
+
+ assert(context);
+
+ if (Stream_Length(fragment) < StubOffset + 4)
+ goto fail;
+ Stream_SetPosition(fragment, StubOffset);
+ Stream_Read_UINT32(fragment, rpc->result);
+
+ freerdp_abort_connect(context->instance);
+ tsg_set_state(tsg, TSG_STATE_TUNNEL_CLOSE_PENDING);
EventArgsInit(&e, "freerdp");
e.code = 0;
- PubSub_OnTerminate(rpc->context->pubSub, rpc->context, &e);
- return 0;
+ PubSub_OnTerminate(context->pubSub, context, &e);
+ rc = 0;
+ goto success;
}
- if (header->common.call_id != rpc->PipeCallId)
+ if (header.common.call_id != rpc->PipeCallId)
{
/* Ignoring non-TsProxySetupReceivePipe Response */
- return 0;
+ rc = 0;
+ goto success;
}
}
if (rpc->StubFragCount == 0)
- rpc->StubCallId = header->common.call_id;
+ rpc->StubCallId = header.common.call_id;
- if (rpc->StubCallId != header->common.call_id)
+ if (rpc->StubCallId != header.common.call_id)
{
WLog_ERR(TAG,
"invalid call_id: actual: %" PRIu32 ", expected: %" PRIu32
", frag_count: %" PRIu32 "",
- rpc->StubCallId, header->common.call_id, rpc->StubFragCount);
+ rpc->StubCallId, header.common.call_id, rpc->StubFragCount);
}
call = rpc_client_call_find_by_id(rpc->client, rpc->StubCallId);
if (!call)
- return -1;
+ goto fail;
if (call->OpNum != TsProxySetupReceivePipeOpnum)
{
- if (!Stream_EnsureCapacity(pdu->s, header->response.alloc_hint))
- return -1;
+ const rpcconn_response_hdr_t* response =
+ (const rpcconn_response_hdr_t*)&header.response;
+ if (!Stream_EnsureCapacity(pdu->s, response->alloc_hint))
+ goto fail;
+
+ if (Stream_Length(fragment) < StubOffset + StubLength)
+ goto fail;
- Stream_Write(pdu->s, &buffer[StubOffset], StubLength);
+ Stream_SetPosition(fragment, StubOffset);
+ Stream_Write(pdu->s, Stream_Pointer(fragment), StubLength);
rpc->StubFragCount++;
- if (header->response.alloc_hint == StubLength)
+ if (response->alloc_hint == StubLength)
{
pdu->Flags = RPC_PDU_FLAG_STUB;
pdu->Type = PTYPE_RESPONSE;
pdu->CallId = rpc->StubCallId;
- Stream_SealLength(pdu->s);
- rpc_client_recv_pdu(rpc, pdu);
+
+ if (rpc_client_recv_pdu(rpc, pdu) < 0)
+ goto fail;
rpc_pdu_reset(pdu);
rpc->StubFragCount = 0;
rpc->StubCallId = 0;
@@ -391,75 +424,84 @@ static int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
}
else
{
- rpc_client_receive_pipe_write(rpc->client, &buffer[StubOffset], (size_t)StubLength);
+ const rpcconn_response_hdr_t* response = &header.response;
+ if (Stream_Length(fragment) < StubOffset + StubLength)
+ goto fail;
+ Stream_SetPosition(fragment, StubOffset);
+ rpc_client_receive_pipe_write(rpc->client, Stream_Pointer(fragment),
+ (size_t)StubLength);
rpc->StubFragCount++;
- if (header->response.alloc_hint == StubLength)
+ if (response->alloc_hint == StubLength)
{
rpc->StubFragCount = 0;
rpc->StubCallId = 0;
}
}
- return 1;
+ goto success;
}
- else if (header->common.ptype == PTYPE_RTS)
+ else if (header.common.ptype == PTYPE_RTS)
{
if (rpc->State < RPC_CLIENT_STATE_CONTEXT_NEGOTIATED)
{
pdu->Flags = 0;
- pdu->Type = header->common.ptype;
- pdu->CallId = header->common.call_id;
+ pdu->Type = header.common.ptype;
+ pdu->CallId = header.common.call_id;
if (!Stream_EnsureCapacity(pdu->s, Stream_Length(fragment)))
- return -1;
+ goto fail;
- Stream_Write(pdu->s, buffer, Stream_Length(fragment));
- Stream_SealLength(pdu->s);
+ Stream_Write(pdu->s, Stream_Buffer(fragment), Stream_Length(fragment));
if (rpc_client_recv_pdu(rpc, pdu) < 0)
- return -1;
+ goto fail;
rpc_pdu_reset(pdu);
}
else
{
- if (rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length) < 0)
- return -1;
+ if (!rts_recv_out_of_sequence_pdu(rpc, fragment, &header))
+ goto fail;
}
- return 1;
+ goto success;
}
- else if (header->common.ptype == PTYPE_BIND_ACK)
+ else if (header.common.ptype == PTYPE_BIND_ACK)
{
pdu->Flags = 0;
- pdu->Type = header->common.ptype;
- pdu->CallId = header->common.call_id;
+ pdu->Type = header.common.ptype;
+ pdu->CallId = header.common.call_id;
if (!Stream_EnsureCapacity(pdu->s, Stream_Length(fragment)))
- return -1;
+ goto fail;
- Stream_Write(pdu->s, buffer, Stream_Length(fragment));
- Stream_SealLength(pdu->s);
+ Stream_Write(pdu->s, Stream_Buffer(fragment), Stream_Length(fragment));
if (rpc_client_recv_pdu(rpc, pdu) < 0)
- return -1;
+ goto fail;
rpc_pdu_reset(pdu);
- return 1;
+ goto success;
}
- else if (header->common.ptype == PTYPE_FAULT)
+ else if (header.common.ptype == PTYPE_FAULT)
{
- rpc_recv_fault_pdu(header->fault.status);
- return -1;
+ const rpcconn_fault_hdr_t* fault = (const rpcconn_fault_hdr_t*)&header.fault;
+ rpc_recv_fault_pdu(fault->status);
+ goto fail;
}
else
{
- WLog_ERR(TAG, "unexpected RPC PDU type 0x%02" PRIX8 "", header->common.ptype);
- return -1;
+ WLog_ERR(TAG, "unexpected RPC PDU type 0x%02" PRIX8 "", header.common.ptype);
+ goto fail;
}
- return 1;
+success:
+ rc = (rc < 0) ? 1 : 0; /* In case of default error return change to 1, otherwise we already set
+ the return code */
+fail:
+ rts_free_pdu_header(&header, FALSE);
+ return rc;
}
static int rpc_client_default_out_channel_recv(rdpRpc* rpc)
@@ -509,7 +551,7 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc)
/* Send CONN/A1 PDU over OUT channel */
- if (rts_send_CONN_A1_pdu(rpc) < 0)
+ if (!rts_send_CONN_A1_pdu(rpc))
{
http_response_free(response);
WLog_ERR(TAG, "rpc_send_CONN_A1_pdu error!");
@@ -549,7 +591,8 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc)
if (statusCode == HTTP_STATUS_DENIED)
{
- freerdp_set_last_error_if_not(rpc->context, FREERDP_ERROR_AUTHENTICATION_FAILED);
+ rdpContext* context = rpc->context;
+ freerdp_set_last_error_if_not(context, FREERDP_ERROR_AUTHENTICATION_FAILED);
}
http_response_free(response);
@@ -563,12 +606,13 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc)
}
else
{
- wStream* fragment;
- rpcconn_common_hdr_t* header;
- fragment = rpc->client->ReceiveFragment;
+ wStream* fragment = rpc->client->ReceiveFragment;
while (1)
{
+ size_t pos;
+ rpcconn_common_hdr_t header = { 0 };
+
while (Stream_GetPosition(fragment) < RPC_COMMON_FIELDS_LENGTH)
{
status = rpc_channel_read(&outChannel->common, fragment,
@@ -581,22 +625,27 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc)
return 0;
}
- header = (rpcconn_common_hdr_t*)Stream_Buffer(fragment);
+ pos = Stream_GetPosition(fragment);
+ Stream_SetPosition(fragment, 0);
+
+ /* Ignore errors, the PDU might not be complete. */
+ rts_read_common_pdu_header(fragment, &header);
+ Stream_SetPosition(fragment, pos);
- if (header->frag_length > rpc->max_recv_frag)
+ if (header.frag_length > rpc->max_recv_frag)
{
WLog_ERR(TAG,
"rpc_client_recv: invalid fragment size: %" PRIu16 " (max: %" PRIu16 ")",
- header->frag_length, rpc->max_recv_frag);
+ header.frag_length, rpc->max_recv_frag);
winpr_HexDump(TAG, WLOG_ERROR, Stream_Buffer(fragment),
Stream_GetPosition(fragment));
return -1;
}
- while (Stream_GetPosition(fragment) < header->frag_length)
+ while (Stream_GetPosition(fragment) < header.frag_length)
{
status = rpc_channel_read(&outChannel->common, fragment,
- header->frag_length - Stream_GetPosition(fragment));
+ header.frag_length - Stream_GetPosition(fragment));
if (status < 0)
{
@@ -604,14 +653,12 @@ static int rpc_client_default_out_channel_recv(rdpRpc* rpc)
return -1;
}
- if (Stream_GetPosition(fragment) < header->frag_length)
+ if (Stream_GetPosition(fragment) < header.frag_length)
return 0;
}
{
/* complete fragment received */
- Stream_SealLength(fragment);
- Stream_SetPosition(fragment, 0);
status = rpc_client_recv_fragment(rpc, fragment);
if (status < 0)
@@ -663,10 +710,10 @@ static int rpc_client_nondefault_out_channel_recv(rdpRpc* rpc)
if (rpc_ncacn_http_send_out_channel_request(&nextOutChannel->common, TRUE))
{
rpc_ncacn_http_ntlm_uninit(&nextOutChannel->common);
- status = rts_send_OUT_R1_A3_pdu(rpc);
- if (status >= 0)
+ if (rts_send_OUT_R1_A3_pdu(rpc))
{
+ status = 1;
rpc_out_channel_transition_to_state(
nextOutChannel, CLIENT_OUT_CHANNEL_STATE_OPENED_A6W);
}
@@ -687,11 +734,14 @@ static int rpc_client_nondefault_out_channel_recv(rdpRpc* rpc)
break;
+ case CLIENT_OUT_CHANNEL_STATE_INITIAL:
+ case CLIENT_OUT_CHANNEL_STATE_CONNECTED:
+ case CLIENT_OUT_CHANNEL_STATE_NEGOTIATED:
default:
WLog_ERR(TAG,
"rpc_client_nondefault_out_channel_recv: Unexpected message %08" PRIx32,
nextOutChannel->State);
- return -1;
+ status = -1;
}
http_response_free(response);
@@ -769,7 +819,7 @@ int rpc_client_in_channel_recv(rdpRpc* rpc)
/* Send CONN/B1 PDU over IN channel */
- if (rts_send_CONN_B1_pdu(rpc) < 0)
+ if (!rts_send_CONN_B1_pdu(rpc))
{
WLog_ERR(TAG, "rpc_send_CONN_B1_pdu error!");
http_response_free(response);
@@ -810,8 +860,8 @@ int rpc_client_in_channel_recv(rdpRpc* rpc)
RpcClientCall* rpc_client_call_find_by_id(RpcClient* client, UINT32 CallId)
{
- int index;
- int count;
+ size_t index;
+ size_t count;
RpcClientCall* clientCall = NULL;
if (!client)
@@ -856,18 +906,23 @@ static void rpc_array_client_call_free(void* call)
rpc_client_call_free((RpcClientCall*)call);
}
-int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length)
+int rpc_in_channel_send_pdu(RpcInChannel* inChannel, const BYTE* buffer, size_t length)
{
- int status;
+ SSIZE_T status;
RpcClientCall* clientCall;
- rpcconn_common_hdr_t* header;
+ wStream s;
+ rpcconn_common_hdr_t header = { 0 };
+
status = rpc_channel_write(&inChannel->common, buffer, length);
if (status <= 0)
return -1;
- header = (rpcconn_common_hdr_t*)buffer;
- clientCall = rpc_client_call_find_by_id(inChannel->common.client, header->call_id);
+ Stream_StaticInit(&s, buffer, length);
+ if (!rts_read_common_pdu_header(&s, &header))
+ return -1;
+
+ clientCall = rpc_client_call_find_by_id(inChannel->common.client, header.call_id);
clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED;
/*
@@ -877,7 +932,7 @@ int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length
* variables specified by this abstract data model.
*/
- if (header->ptype == PTYPE_REQUEST)
+ if (header.ptype == PTYPE_REQUEST)
{
inChannel->BytesSent += status;
inChannel->SenderAvailableWindow -= status;
@@ -888,7 +943,7 @@ int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length
BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum)
{
- UINT32 offset;
+ size_t offset;
BYTE* buffer = NULL;
UINT32 stub_data_pad;
SecBuffer Buffers[2] = { 0 };
@@ -941,15 +996,15 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum)
if (size < 0)
goto fail;
- rpc_pdu_header_init(rpc, (rpcconn_hdr_t*)request_pdu);
- request_pdu->ptype = PTYPE_REQUEST;
- request_pdu->pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG;
- request_pdu->auth_length = (UINT16)size;
- request_pdu->call_id = rpc->CallId++;
+ request_pdu->header = rpc_pdu_header_init(rpc);
+ request_pdu->header.ptype = PTYPE_REQUEST;
+ request_pdu->header.pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG;
+ request_pdu->header.auth_length = (UINT16)size;
+ request_pdu->header.call_id = rpc->CallId++;
request_pdu->alloc_hint = length;
request_pdu->p_cont_id = 0x0000;
request_pdu->opnum = opnum;
- clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum);
+ clientCall = rpc_client_call_new(request_pdu->header.call_id, request_pdu->opnum);
if (!clientCall)
goto fail;
@@ -961,7 +1016,7 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum)
}
if (request_pdu->opnum == TsProxySetupReceivePipeOpnum)
- rpc->PipeCallId = request_pdu->call_id;
+ rpc->PipeCallId = request_pdu->header.call_id;
request_pdu->stub_data = Stream_Buffer(s);
offset = 24;
@@ -972,9 +1027,9 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum)
request_pdu->auth_verifier.auth_level = RPC_C_AUTHN_LEVEL_PKT_INTEGRITY;
request_pdu->auth_verifier.auth_reserved = 0x00;
request_pdu->auth_verifier.auth_context_id = 0x00000000;
- offset += (8 + request_pdu->auth_length);
- request_pdu->frag_length = offset;
- buffer = (BYTE*)calloc(1, request_pdu->frag_length);
+ offset += (8 + request_pdu->header.auth_length);
+ request_pdu->header.frag_length = offset;
+ buffer = (BYTE*)calloc(1, request_pdu->header.frag_length);
if (!buffer)
goto fail;
@@ -1007,7 +1062,7 @@ BOOL rpc_client_write_call(rdpRpc* rpc, wStream* s, UINT16 opnum)
CopyMemory(&buffer[offset], Buffers[1].pvBuffer, Buffers[1].cbBuffer);
offset += Buffers[1].cbBuffer;
- if (rpc_in_channel_send_pdu(inChannel, buffer, request_pdu->frag_length) < 0)
+ if (rpc_in_channel_send_pdu(inChannel, buffer, request_pdu->header.frag_length) < 0)
goto fail;
rc = TRUE;
@@ -1031,7 +1086,7 @@ static BOOL rpc_client_resolve_gateway(rdpSettings* settings, char** host, UINT1
const char* peerHostname = settings->GatewayHostname;
const char* proxyUsername = settings->ProxyUsername;
const char* proxyPassword = settings->ProxyPassword;
- *port = settings->GatewayPort;
+ *port = (UINT16)settings->GatewayPort;
*isProxy = proxy_prepare(settings, &peerHostname, port, &proxyUsername, &proxyPassword);
result = freerdp_tcp_resolve_host(peerHostname, *port, 0);
diff --git a/libfreerdp/core/gateway/rpc_client.h b/libfreerdp/core/gateway/rpc_client.h
index af0b8ce..7b509de 100644
--- a/libfreerdp/core/gateway/rpc_client.h
+++ b/libfreerdp/core/gateway/rpc_client.h
@@ -31,7 +31,8 @@ FREERDP_LOCAL RpcClientCall* rpc_client_call_find_by_id(RpcClient* client, UINT3
FREERDP_LOCAL RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum);
FREERDP_LOCAL void rpc_client_call_free(RpcClientCall* client_call);
-FREERDP_LOCAL int rpc_in_channel_send_pdu(RpcInChannel* inChannel, BYTE* buffer, UINT32 length);
+FREERDP_LOCAL int rpc_in_channel_send_pdu(RpcInChannel* inChannel, const BYTE* buffer,
+ size_t length);
FREERDP_LOCAL int rpc_client_in_channel_recv(rdpRpc* rpc);
FREERDP_LOCAL int rpc_client_out_channel_recv(rdpRpc* rpc);
diff --git a/libfreerdp/core/gateway/rpc_fault.c b/libfreerdp/core/gateway/rpc_fault.c
index 7259f04..c4cb086 100644
--- a/libfreerdp/core/gateway/rpc_fault.c
+++ b/libfreerdp/core/gateway/rpc_fault.c
@@ -133,10 +133,7 @@ static const RPC_FAULT_CODE RPC_FAULT_CODES[] = {
CAT_GATEWAY)
DEFINE_RPC_FAULT_CODE(
RPC_S_INVALID_OBJECT,
- CAT_GATEWAY){
- 0,
- NULL,
- NULL }
+ CAT_GATEWAY)
};
static const RPC_FAULT_CODE RPC_TSG_FAULT_CODES[] = {
@@ -222,9 +219,7 @@ static const RPC_FAULT_CODE RPC_TSG_FAULT_CODES[] = {
DEFINE_RPC_FAULT_CODE(
HRESULT_CODE(
RPC_S_CALL_CANCELLED),
- CAT_GATEWAY){
- 0, NULL,
- NULL }
+ CAT_GATEWAY)
};
/**
@@ -377,22 +372,22 @@ const char* rpc_error_to_string(UINT32 code)
size_t index;
static char buffer[1024];
- for (index = 0; RPC_FAULT_CODES[index].name != NULL; index++)
+ for (index = 0; index < ARRAYSIZE(RPC_FAULT_CODES); index++)
{
- if (RPC_FAULT_CODES[index].code == code)
+ const RPC_FAULT_CODE* const current = &RPC_FAULT_CODES[index];
+ if (current->code == code)
{
- sprintf_s(buffer, ARRAYSIZE(buffer), "%s [0x%08" PRIX32 "]",
- RPC_FAULT_CODES[index].name, code);
+ sprintf_s(buffer, ARRAYSIZE(buffer), "%s", current->name);
goto out;
}
}
- for (index = 0; RPC_TSG_FAULT_CODES[index].name != NULL; index++)
+ for (index = 0; index < ARRAYSIZE(RPC_TSG_FAULT_CODES); index++)
{
- if (RPC_TSG_FAULT_CODES[index].code == code)
+ const RPC_FAULT_CODE* const current = &RPC_TSG_FAULT_CODES[index];
+ if (current->code == code)
{
- sprintf_s(buffer, ARRAYSIZE(buffer), "%s [0x%08" PRIX32 "]",
- RPC_TSG_FAULT_CODES[index].name, code);
+ sprintf_s(buffer, ARRAYSIZE(buffer), "%s", current->name);
goto out;
}
}
@@ -406,16 +401,18 @@ const char* rpc_error_to_category(UINT32 code)
{
size_t index;
- for (index = 0; RPC_FAULT_CODES[index].category != NULL; index++)
+ for (index = 0; index < ARRAYSIZE(RPC_FAULT_CODES); index++)
{
- if (RPC_FAULT_CODES[index].code == code)
- return RPC_FAULT_CODES[index].category;
+ const RPC_FAULT_CODE* const current = &RPC_FAULT_CODES[index];
+ if (current->code == code)
+ return current->category;
}
- for (index = 0; RPC_TSG_FAULT_CODES[index].category != NULL; index++)
+ for (index = 0; index < ARRAYSIZE(RPC_TSG_FAULT_CODES); index++)
{
- if (RPC_TSG_FAULT_CODES[index].code == code)
- return RPC_TSG_FAULT_CODES[index].category;
+ const RPC_FAULT_CODE* const current = &RPC_TSG_FAULT_CODES[index];
+ if (current->code == code)
+ return current->category;
}
return "UNKNOWN";
diff --git a/libfreerdp/core/gateway/rts.c b/libfreerdp/core/gateway/rts.c
index 6218f81..a053dfc 100644
--- a/libfreerdp/core/gateway/rts.c
+++ b/libfreerdp/core/gateway/rts.c
@@ -21,6 +21,7 @@
#include "config.h"
#endif
+#include <assert.h>
#include <winpr/crt.h>
#include <winpr/crypto.h>
#include <winpr/winhttp.h>
@@ -29,6 +30,7 @@
#include "ncacn_http.h"
#include "rpc_client.h"
+#include "rts_signature.h"
#include "rts.h"
@@ -67,543 +69,1639 @@
*
*/
-static void rts_pdu_header_init(rpcconn_rts_hdr_t* header)
+static const char* rts_pdu_ptype_to_string(UINT32 ptype)
{
- ZeroMemory(header, sizeof(*header));
- header->rpc_vers = 5;
- header->rpc_vers_minor = 0;
- header->ptype = PTYPE_RTS;
- header->packed_drep[0] = 0x10;
- header->packed_drep[1] = 0x00;
- header->packed_drep[2] = 0x00;
- header->packed_drep[3] = 0x00;
- header->pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG;
- header->auth_length = 0;
- header->call_id = 0;
+ switch (ptype)
+ {
+ case PTYPE_REQUEST:
+ return "PTYPE_REQUEST";
+ case PTYPE_PING:
+ return "PTYPE_PING";
+ case PTYPE_RESPONSE:
+ return "PTYPE_RESPONSE";
+ case PTYPE_FAULT:
+ return "PTYPE_FAULT";
+ case PTYPE_WORKING:
+ return "PTYPE_WORKING";
+ case PTYPE_NOCALL:
+ return "PTYPE_NOCALL";
+ case PTYPE_REJECT:
+ return "PTYPE_REJECT";
+ case PTYPE_ACK:
+ return "PTYPE_ACK";
+ case PTYPE_CL_CANCEL:
+ return "PTYPE_CL_CANCEL";
+ case PTYPE_FACK:
+ return "PTYPE_FACK";
+ case PTYPE_CANCEL_ACK:
+ return "PTYPE_CANCEL_ACK";
+ case PTYPE_BIND:
+ return "PTYPE_BIND";
+ case PTYPE_BIND_ACK:
+ return "PTYPE_BIND_ACK";
+ case PTYPE_BIND_NAK:
+ return "PTYPE_BIND_NAK";
+ case PTYPE_ALTER_CONTEXT:
+ return "PTYPE_ALTER_CONTEXT";
+ case PTYPE_ALTER_CONTEXT_RESP:
+ return "PTYPE_ALTER_CONTEXT_RESP";
+ case PTYPE_RPC_AUTH_3:
+ return "PTYPE_RPC_AUTH_3";
+ case PTYPE_SHUTDOWN:
+ return "PTYPE_SHUTDOWN";
+ case PTYPE_CO_CANCEL:
+ return "PTYPE_CO_CANCEL";
+ case PTYPE_ORPHANED:
+ return "PTYPE_ORPHANED";
+ case PTYPE_RTS:
+ return "PTYPE_RTS";
+ default:
+ return "UNKNOWN";
+ }
+}
+
+static rpcconn_rts_hdr_t rts_pdu_header_init(void)
+{
+ rpcconn_rts_hdr_t header = { 0 };
+ header.header.rpc_vers = 5;
+ header.header.rpc_vers_minor = 0;
+ header.header.ptype = PTYPE_RTS;
+ header.header.packed_drep[0] = 0x10;
+ header.header.packed_drep[1] = 0x00;
+ header.header.packed_drep[2] = 0x00;
+ header.header.packed_drep[3] = 0x00;
+ header.header.pfc_flags = PFC_FIRST_FRAG | PFC_LAST_FRAG;
+ header.header.auth_length = 0;
+ header.header.call_id = 0;
+
+ return header;
+}
+
+static BOOL rts_align_stream(wStream* s, size_t alignment)
+{
+ size_t pos, pad;
+
+ assert(s);
+ assert(alignment > 0);
+
+ pos = Stream_GetPosition(s);
+ pad = rpc_offset_align(&pos, alignment);
+ return Stream_SafeSeek(s, pad);
+}
+
+static char* sdup(const void* src, size_t length)
+{
+ char* dst;
+ assert(src || (length == 0));
+ if (length == 0)
+ return NULL;
+
+ dst = calloc(length + 1, sizeof(char));
+ if (!dst)
+ return NULL;
+ memcpy(dst, src, length);
+ return dst;
+}
+
+static BOOL rts_write_common_pdu_header(wStream* s, const rpcconn_common_hdr_t* header)
+{
+ assert(s);
+ assert(header);
+ if (!Stream_EnsureRemainingCapacity(s, sizeof(rpcconn_common_hdr_t)))
+ return FALSE;
+
+ Stream_Write_UINT8(s, header->rpc_vers);
+ Stream_Write_UINT8(s, header->rpc_vers_minor);
+ Stream_Write_UINT8(s, header->ptype);
+ Stream_Write_UINT8(s, header->pfc_flags);
+ Stream_Write(s, header->packed_drep, ARRAYSIZE(header->packed_drep));
+ Stream_Write_UINT16(s, header->frag_length);
+ Stream_Write_UINT16(s, header->auth_length);
+ Stream_Write_UINT32(s, header->call_id);
+ return TRUE;
+}
+
+BOOL rts_read_common_pdu_header(wStream* s, rpcconn_common_hdr_t* header)
+{
+ size_t left;
+ assert(s);
+ assert(header);
+
+ if (Stream_GetRemainingLength(s) < sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+
+ Stream_Read_UINT8(s, header->rpc_vers);
+ Stream_Read_UINT8(s, header->rpc_vers_minor);
+ Stream_Read_UINT8(s, header->ptype);
+ Stream_Read_UINT8(s, header->pfc_flags);
+ Stream_Read(s, header->packed_drep, ARRAYSIZE(header->packed_drep));
+ Stream_Read_UINT16(s, header->frag_length);
+ Stream_Read_UINT16(s, header->auth_length);
+ Stream_Read_UINT32(s, header->call_id);
+
+ if (header->frag_length < sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+
+ left = Stream_GetRemainingLength(s);
+ if (left < header->frag_length - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+
+ return TRUE;
+}
+
+static BOOL rts_read_auth_verifier_no_checks(wStream* s, auth_verifier_co_t* auth,
+ const rpcconn_common_hdr_t* header, size_t* startPos)
+{
+ assert(s);
+ assert(auth);
+ assert(header);
+
+ assert(header->frag_length > header->auth_length);
+
+ if (startPos)
+ *startPos = Stream_GetPosition(s);
+
+ /* Read the auth verifier and check padding matches frag_length */
+ {
+ const size_t expected = header->frag_length - header->auth_length - 8;
+
+ Stream_SetPosition(s, expected);
+ if (Stream_GetRemainingLength(s) < sizeof(auth_verifier_co_t))
+ return FALSE;
+
+ Stream_Read_UINT8(s, auth->auth_type);
+ Stream_Read_UINT8(s, auth->auth_level);
+ Stream_Read_UINT8(s, auth->auth_pad_length);
+ Stream_Read_UINT8(s, auth->auth_reserved);
+ Stream_Read_UINT32(s, auth->auth_context_id);
+ }
+
+ if (header->auth_length != 0)
+ {
+ const void* ptr = Stream_Pointer(s);
+ if (!Stream_SafeSeek(s, header->auth_length))
+ return FALSE;
+ auth->auth_value = (BYTE*)sdup(ptr, header->auth_length);
+ if (auth->auth_value == NULL)
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+static BOOL rts_read_auth_verifier(wStream* s, auth_verifier_co_t* auth,
+ const rpcconn_common_hdr_t* header)
+{
+ size_t pos;
+ assert(s);
+ assert(auth);
+ assert(header);
+
+ if (!rts_read_auth_verifier_no_checks(s, auth, header, &pos))
+ return FALSE;
+
+ {
+ const size_t expected = header->frag_length - header->auth_length - 8;
+ assert(pos + auth->auth_pad_length == expected);
+ }
+
+ return TRUE;
+}
+
+static BOOL rts_read_auth_verifier_with_stub(wStream* s, auth_verifier_co_t* auth,
+ rpcconn_common_hdr_t* header)
+{
+ size_t pos;
+ size_t alloc_hint = 0;
+ BYTE** ptr = NULL;
+
+ if (!rts_read_auth_verifier_no_checks(s, auth, header, &pos))
+ return FALSE;
+
+ switch (header->ptype)
+ {
+ case PTYPE_FAULT:
+ {
+ rpcconn_fault_hdr_t* hdr = (rpcconn_fault_hdr_t*)header;
+ alloc_hint = hdr->alloc_hint;
+ ptr = &hdr->stub_data;
+ }
+ break;
+ case PTYPE_RESPONSE:
+ {
+ rpcconn_response_hdr_t* hdr = (rpcconn_response_hdr_t*)header;
+ alloc_hint = hdr->alloc_hint;
+ ptr = &hdr->stub_data;
+ }
+ break;
+ case PTYPE_REQUEST:
+ {
+ rpcconn_request_hdr_t* hdr = (rpcconn_request_hdr_t*)header;
+ alloc_hint = hdr->alloc_hint;
+ ptr = &hdr->stub_data;
+ }
+ break;
+ default:
+ return FALSE;
+ }
+
+ if (alloc_hint > 0)
+ {
+ const size_t size =
+ header->frag_length - header->auth_length - 8 - auth->auth_pad_length - pos;
+ const void* src = Stream_Buffer(s) + pos;
+
+ *ptr = (BYTE*)sdup(src, size);
+ if (!*ptr)
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+static void rts_free_auth_verifier(auth_verifier_co_t* auth)
+{
+ if (!auth)
+ return;
+ free(auth->auth_value);
+}
+
+static BOOL rts_write_auth_verifier(wStream* s, const auth_verifier_co_t* auth,
+ const rpcconn_common_hdr_t* header)
+{
+ size_t pos;
+ UINT8 auth_pad_length = 0;
+
+ assert(s);
+ assert(auth);
+ assert(header);
+
+ /* Align start to a multiple of 4 */
+ pos = Stream_GetPosition(s);
+ if ((pos % 4) != 0)
+ {
+ auth_pad_length = 4 - (pos % 4);
+ if (!Stream_EnsureRemainingCapacity(s, auth_pad_length))
+ return FALSE;
+ Stream_Zero(s, auth_pad_length);
+ }
+
+ assert(header->frag_length + 8ull > header->auth_length);
+ {
+ size_t pos = Stream_GetPosition(s);
+ size_t expected = header->frag_length - header->auth_length - 8;
+
+ assert(pos == expected);
+ }
+
+ if (!Stream_EnsureRemainingCapacity(s, sizeof(auth_verifier_co_t)))
+ return FALSE;
+
+ Stream_Write_UINT8(s, auth->auth_type);
+ Stream_Write_UINT8(s, auth->auth_level);
+ Stream_Write_UINT8(s, auth_pad_length);
+ Stream_Write_UINT8(s, 0); /* auth->auth_reserved */
+ Stream_Write_UINT32(s, auth->auth_context_id);
+
+ if (!Stream_EnsureRemainingCapacity(s, header->auth_length))
+ return FALSE;
+ Stream_Write(s, auth->auth_value, header->auth_length);
+ return TRUE;
+}
+
+static BOOL rts_read_version(wStream* s, p_rt_version_t* version)
+{
+ assert(s);
+ assert(version);
+
+ if (Stream_GetRemainingLength(s) < 2 * sizeof(UINT8))
+ return FALSE;
+ Stream_Read_UINT8(s, version->major);
+ Stream_Read_UINT8(s, version->minor);
+ return TRUE;
+}
+
+void rts_free_supported_versions(p_rt_versions_supported_t* versions)
+{
+ if (!versions)
+ return;
+ free(versions->p_protocols);
+ versions->p_protocols = NULL;
+}
+
+static BOOL rts_read_supported_versions(wStream* s, p_rt_versions_supported_t* versions)
+{
+ BYTE x;
+
+ assert(s);
+ assert(versions);
+
+ if (Stream_GetRemainingLength(s) < sizeof(UINT8))
+ return FALSE;
+
+ Stream_Read_UINT8(s, versions->n_protocols); /* count */
+
+ if (versions->n_protocols > 0)
+ {
+ versions->p_protocols = calloc(versions->n_protocols, sizeof(p_rt_version_t));
+ if (!versions->p_protocols)
+ return FALSE;
+ }
+ for (x = 0; x < versions->n_protocols; x++)
+ {
+ p_rt_version_t* version = &versions->p_protocols[x];
+ if (!rts_read_version(s, version)) /* size_is(n_protocols) */
+ {
+ rts_free_supported_versions(versions);
+ return FALSE;
+ }
+ }
+
+ return TRUE;
+}
+
+static BOOL rts_read_port_any(wStream* s, port_any_t* port)
+{
+ const void* ptr;
+
+ assert(s);
+ assert(port);
+
+ if (Stream_GetRemainingLength(s) < sizeof(UINT16))
+ return FALSE;
+
+ Stream_Read_UINT16(s, port->length);
+ if (port->length == 0)
+ return TRUE;
+
+ ptr = Stream_Pointer(s);
+ if (!Stream_SafeSeek(s, port->length))
+ return FALSE;
+ port->port_spec = sdup(ptr, port->length);
+ return port->port_spec != NULL;
+}
+
+static void rts_free_port_any(port_any_t* port)
+{
+ if (!port)
+ return;
+ free(port->port_spec);
+}
+
+static BOOL rts_read_uuid(wStream* s, p_uuid_t* uuid)
+{
+ assert(s);
+ assert(uuid);
+
+ if (Stream_GetRemainingLength(s) < sizeof(p_uuid_t))
+ return FALSE;
+
+ Stream_Read_UINT32(s, uuid->time_low);
+ Stream_Read_UINT16(s, uuid->time_mid);
+ Stream_Read_UINT16(s, uuid->time_hi_and_version);
+ Stream_Read_UINT8(s, uuid->clock_seq_hi_and_reserved);
+ Stream_Read_UINT8(s, uuid->clock_seq_low);
+ Stream_Read(s, uuid->node, ARRAYSIZE(uuid->node));
+ return TRUE;
+}
+
+static BOOL rts_write_uuid(wStream* s, const p_uuid_t* uuid)
+{
+ assert(s);
+ assert(uuid);
+
+ if (!Stream_EnsureRemainingCapacity(s, sizeof(p_uuid_t)))
+ return FALSE;
+
+ Stream_Write_UINT32(s, uuid->time_low);
+ Stream_Write_UINT16(s, uuid->time_mid);
+ Stream_Write_UINT16(s, uuid->time_hi_and_version);
+ Stream_Write_UINT8(s, uuid->clock_seq_hi_and_reserved);
+ Stream_Write_UINT8(s, uuid->clock_seq_low);
+ Stream_Write(s, uuid->node, ARRAYSIZE(uuid->node));
+ return TRUE;
+}
+
+static p_syntax_id_t* rts_syntax_id_new(size_t count)
+{
+ return calloc(count, sizeof(p_syntax_id_t));
+}
+
+static void rts_syntax_id_free(p_syntax_id_t* ptr)
+{
+ free(ptr);
+}
+
+static BOOL rts_read_syntax_id(wStream* s, p_syntax_id_t* syntax_id)
+{
+ assert(s);
+ assert(syntax_id);
+
+ if (!rts_read_uuid(s, &syntax_id->if_uuid))
+ return FALSE;
+
+ if (Stream_GetRemainingLength(s) < 4)
+ return FALSE;
+
+ Stream_Read_UINT32(s, syntax_id->if_version);
+ return TRUE;
+}
+
+static BOOL rts_write_syntax_id(wStream* s, const p_syntax_id_t* syntax_id)
+{
+ assert(s);
+ assert(syntax_id);
+
+ if (!rts_write_uuid(s, &syntax_id->if_uuid))
+ return FALSE;
+
+ if (!Stream_EnsureRemainingCapacity(s, 4))
+ return FALSE;
+
+ Stream_Write_UINT32(s, syntax_id->if_version);
+ return TRUE;
+}
+
+p_cont_elem_t* rts_context_elem_new(size_t count)
+{
+ p_cont_elem_t* ctx = calloc(count, sizeof(p_cont_elem_t));
+ return ctx;
+}
+
+void rts_context_elem_free(p_cont_elem_t* ptr)
+{
+ if (!ptr)
+ return;
+ rts_syntax_id_free(ptr->transfer_syntaxes);
+ free(ptr);
+}
+
+static BOOL rts_read_context_elem(wStream* s, p_cont_elem_t* element)
+{
+ BYTE x;
+ assert(s);
+ assert(element);
+
+ if (Stream_GetRemainingLength(s) < 4)
+ return FALSE;
+
+ Stream_Read_UINT16(s, element->p_cont_id);
+ Stream_Read_UINT8(s, element->n_transfer_syn); /* number of items */
+ Stream_Read_UINT8(s, element->reserved); /* alignment pad, m.b.z. */
+
+ if (!rts_read_syntax_id(s, &element->abstract_syntax)) /* transfer syntax list */
+ return FALSE;
+
+ if (element->n_transfer_syn > 0)
+ {
+ element->transfer_syntaxes = rts_syntax_id_new(element->n_transfer_syn);
+ if (!element->transfer_syntaxes)
+ return FALSE;
+ for (x = 0; x < element->n_transfer_syn; x++)
+ {
+ p_syntax_id_t* syn = &element->transfer_syntaxes[x];
+ if (!rts_read_syntax_id(s, syn)) /* size_is(n_transfer_syn) */
+ return FALSE;
+ }
+ }
+
+ return TRUE;
+}
+
+static BOOL rts_write_context_elem(wStream* s, const p_cont_elem_t* element)
+{
+ BYTE x;
+ assert(s);
+ assert(element);
+
+ if (!Stream_EnsureRemainingCapacity(s, 4))
+ return FALSE;
+ Stream_Write_UINT16(s, element->p_cont_id);
+ Stream_Write_UINT8(s, element->n_transfer_syn); /* number of items */
+ Stream_Write_UINT8(s, element->reserved); /* alignment pad, m.b.z. */
+ if (!rts_write_syntax_id(s, &element->abstract_syntax)) /* transfer syntax list */
+ return FALSE;
+
+ for (x = 0; x < element->n_transfer_syn; x++)
+ {
+ const p_syntax_id_t* syn = &element->transfer_syntaxes[x];
+ if (!rts_write_syntax_id(s, syn)) /* size_is(n_transfer_syn) */
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+static BOOL rts_read_context_list(wStream* s, p_cont_list_t* list)
+{
+ BYTE x;
+
+ assert(s);
+ assert(list);
+
+ if (Stream_GetRemainingLength(s) < 4)
+ return FALSE;
+ Stream_Read_UINT8(s, list->n_context_elem); /* number of items */
+ Stream_Read_UINT8(s, list->reserved); /* alignment pad, m.b.z. */
+ Stream_Read_UINT16(s, list->reserved2); /* alignment pad, m.b.z. */
+
+ if (list->n_context_elem > 0)
+ {
+ list->p_cont_elem = rts_context_elem_new(list->n_context_elem);
+ if (!list->p_cont_elem)
+ return FALSE;
+ for (x = 0; x < list->n_context_elem; x++)
+ {
+ p_cont_elem_t* element = &list->p_cont_elem[x];
+ if (!rts_read_context_elem(s, element))
+ return FALSE;
+ }
+ }
+ return TRUE;
+}
+
+static void rts_free_context_list(p_cont_list_t* list)
+{
+ if (!list)
+ return;
+ rts_context_elem_free(list->p_cont_elem);
+}
+
+static BOOL rts_write_context_list(wStream* s, const p_cont_list_t* list)
+{
+ BYTE x;
+
+ assert(s);
+ assert(list);
+
+ if (!Stream_EnsureRemainingCapacity(s, 4))
+ return FALSE;
+ Stream_Write_UINT8(s, list->n_context_elem); /* number of items */
+ Stream_Write_UINT8(s, 0); /* alignment pad, m.b.z. */
+ Stream_Write_UINT16(s, 0); /* alignment pad, m.b.z. */
+
+ for (x = 0; x < list->n_context_elem; x++)
+ {
+ const p_cont_elem_t* element = &list->p_cont_elem[x];
+ if (!rts_write_context_elem(s, element))
+ return FALSE;
+ }
+ return TRUE;
+}
+
+static p_result_t* rts_result_new(size_t count)
+{
+ return calloc(count, sizeof(p_result_t));
+}
+
+static void rts_result_free(p_result_t* results)
+{
+ if (!results)
+ return;
+ free(results);
+}
+
+static BOOL rts_read_result(wStream* s, p_result_t* result)
+{
+ assert(s);
+ assert(result);
+
+ if (Stream_GetRemainingLength(s) < 2)
+ return FALSE;
+ Stream_Read_UINT16(s, result->result);
+ Stream_Read_UINT16(s, result->reason);
+
+ return rts_read_syntax_id(s, &result->transfer_syntax);
+}
+
+static void rts_free_result(p_result_t* result)
+{
+ if (!result)
+ return;
+}
+
+static BOOL rts_read_result_list(wStream* s, p_result_list_t* list)
+{
+ BYTE x;
+
+ assert(s);
+ assert(list);
+
+ if (Stream_GetRemainingLength(s) < 4)
+ return FALSE;
+ Stream_Read_UINT8(s, list->n_results); /* count */
+ Stream_Read_UINT8(s, list->reserved); /* alignment pad, m.b.z. */
+ Stream_Read_UINT16(s, list->reserved2); /* alignment pad, m.b.z. */
+
+ if (list->n_results > 0)
+ {
+ list->p_results = rts_result_new(list->n_results);
+ if (!list->p_results)
+ return FALSE;
+
+ for (x = 0; x < list->n_results; x++)
+ {
+ p_result_t* result = &list->p_results[x]; /* size_is(n_results) */
+ if (!rts_read_result(s, result))
+ return FALSE;
+ }
+ }
+
+ return TRUE;
+}
+
+static void rts_free_result_list(p_result_list_t* list)
+{
+ BYTE x;
+
+ if (!list)
+ return;
+ for (x = 0; x < list->n_results; x++)
+ {
+ p_result_t* result = &list->p_results[x];
+ rts_free_result(result);
+ }
+ rts_result_free(list->p_results);
+}
+
+static void rts_free_pdu_alter_context(rpcconn_alter_context_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+
+ rts_free_context_list(&ctx->p_context_elem);
+ rts_free_auth_verifier(&ctx->auth_verifier);
+}
+
+static BOOL rts_read_pdu_alter_context(wStream* s, rpcconn_alter_context_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) <
+ sizeof(rpcconn_alter_context_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+
+ Stream_Read_UINT16(s, ctx->max_xmit_frag);
+ Stream_Read_UINT16(s, ctx->max_recv_frag);
+ Stream_Read_UINT32(s, ctx->assoc_group_id);
+
+ if (!rts_read_context_list(s, &ctx->p_context_elem))
+ return FALSE;
+
+ if (!rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header))
+ return FALSE;
+
+ return TRUE;
+}
+
+static BOOL rts_read_pdu_alter_context_response(wStream* s,
+ rpcconn_alter_context_response_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) <
+ sizeof(rpcconn_alter_context_response_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+ Stream_Read_UINT16(s, ctx->max_xmit_frag);
+ Stream_Read_UINT16(s, ctx->max_recv_frag);
+ Stream_Read_UINT32(s, ctx->assoc_group_id);
+
+ if (!rts_read_port_any(s, &ctx->sec_addr))
+ return FALSE;
+
+ if (!rts_align_stream(s, 4))
+ return FALSE;
+
+ if (!rts_read_result_list(s, &ctx->p_result_list))
+ return FALSE;
+
+ if (!rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header))
+ return FALSE;
+
+ return TRUE;
+}
+
+static void rts_free_pdu_alter_context_response(rpcconn_alter_context_response_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+
+ rts_free_port_any(&ctx->sec_addr);
+ rts_free_result_list(&ctx->p_result_list);
+ rts_free_auth_verifier(&ctx->auth_verifier);
+}
+
+static BOOL rts_read_pdu_bind(wStream* s, rpcconn_bind_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) < sizeof(rpcconn_bind_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+ Stream_Read_UINT16(s, ctx->max_xmit_frag);
+ Stream_Read_UINT16(s, ctx->max_recv_frag);
+ Stream_Read_UINT32(s, ctx->assoc_group_id);
+
+ if (!rts_read_context_list(s, &ctx->p_context_elem))
+ return FALSE;
+
+ if (!rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header))
+ return FALSE;
+
+ return TRUE;
+}
+
+static void rts_free_pdu_bind(rpcconn_bind_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+ rts_free_context_list(&ctx->p_context_elem);
+ rts_free_auth_verifier(&ctx->auth_verifier);
+}
+
+static BOOL rts_read_pdu_bind_ack(wStream* s, rpcconn_bind_ack_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) <
+ sizeof(rpcconn_bind_ack_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+ Stream_Read_UINT16(s, ctx->max_xmit_frag);
+ Stream_Read_UINT16(s, ctx->max_recv_frag);
+ Stream_Read_UINT32(s, ctx->assoc_group_id);
+
+ if (!rts_read_port_any(s, &ctx->sec_addr))
+ return FALSE;
+
+ if (!rts_align_stream(s, 4))
+ return FALSE;
+
+ if (!rts_read_result_list(s, &ctx->p_result_list))
+ return FALSE;
+
+ return rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header);
+}
+
+static void rts_free_pdu_bind_ack(rpcconn_bind_ack_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+ rts_free_port_any(&ctx->sec_addr);
+ rts_free_result_list(&ctx->p_result_list);
+ rts_free_auth_verifier(&ctx->auth_verifier);
+}
+
+static BOOL rts_read_pdu_bind_nak(wStream* s, rpcconn_bind_nak_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) <
+ sizeof(rpcconn_bind_nak_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+ Stream_Read_UINT16(s, ctx->provider_reject_reason);
+ return rts_read_supported_versions(s, &ctx->versions);
+}
+
+static void rts_free_pdu_bind_nak(rpcconn_bind_nak_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+
+ rts_free_supported_versions(&ctx->versions);
+}
+
+static BOOL rts_read_pdu_auth3(wStream* s, rpcconn_rpc_auth_3_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) <
+ sizeof(rpcconn_rpc_auth_3_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+ Stream_Read_UINT16(s, ctx->max_xmit_frag);
+ Stream_Read_UINT16(s, ctx->max_recv_frag);
+
+ return rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header);
+}
+
+static void rts_free_pdu_auth3(rpcconn_rpc_auth_3_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+ rts_free_auth_verifier(&ctx->auth_verifier);
+}
+
+static BOOL rts_read_pdu_fault(wStream* s, rpcconn_fault_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) < sizeof(rpcconn_fault_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+ Stream_Read_UINT32(s, ctx->alloc_hint);
+ Stream_Read_UINT16(s, ctx->p_cont_id);
+ Stream_Read_UINT8(s, ctx->cancel_count);
+ Stream_Read_UINT8(s, ctx->reserved);
+ Stream_Read_UINT32(s, ctx->status);
+
+ return rts_read_auth_verifier_with_stub(s, &ctx->auth_verifier, &ctx->header);
+}
+
+static void rts_free_pdu_fault(rpcconn_fault_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+ rts_free_auth_verifier(&ctx->auth_verifier);
+}
+
+static BOOL rts_read_pdu_cancel_ack(wStream* s, rpcconn_cancel_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) < sizeof(rpcconn_cancel_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+ return rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header);
+}
+
+static void rts_free_pdu_cancel_ack(rpcconn_cancel_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+ rts_free_auth_verifier(&ctx->auth_verifier);
+}
+
+static BOOL rts_read_pdu_orphaned(wStream* s, rpcconn_orphaned_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) <
+ sizeof(rpcconn_orphaned_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+ return rts_read_auth_verifier(s, &ctx->auth_verifier, &ctx->header);
+}
+
+static void rts_free_pdu_orphaned(rpcconn_orphaned_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+ rts_free_auth_verifier(&ctx->auth_verifier);
+}
+
+static BOOL rts_read_pdu_request(wStream* s, rpcconn_request_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) < sizeof(rpcconn_request_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+ Stream_Read_UINT32(s, ctx->alloc_hint);
+ Stream_Read_UINT16(s, ctx->p_cont_id);
+ Stream_Read_UINT16(s, ctx->opnum);
+ if (!rts_read_uuid(s, &ctx->object))
+ return FALSE;
+
+ return rts_read_auth_verifier_with_stub(s, &ctx->auth_verifier, &ctx->header);
+}
+
+static void rts_free_pdu_request(rpcconn_request_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+ rts_free_auth_verifier(&ctx->auth_verifier);
+}
+
+static BOOL rts_read_pdu_response(wStream* s, rpcconn_response_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) <
+ sizeof(rpcconn_response_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+ Stream_Read_UINT32(s, ctx->alloc_hint);
+ Stream_Read_UINT16(s, ctx->p_cont_id);
+ Stream_Read_UINT8(s, ctx->cancel_count);
+ Stream_Read_UINT8(s, ctx->reserved);
+
+ if (!rts_align_stream(s, 8))
+ return FALSE;
+
+ return rts_read_auth_verifier_with_stub(s, &ctx->auth_verifier, &ctx->header);
+}
+
+static void rts_free_pdu_response(rpcconn_response_hdr_t* ctx)
+{
+ if (!ctx)
+ return;
+ free(ctx->stub_data);
+ rts_free_auth_verifier(&ctx->auth_verifier);
+}
+
+static BOOL rts_read_pdu_rts(wStream* s, rpcconn_rts_hdr_t* ctx)
+{
+ assert(s);
+ assert(ctx);
+
+ if (Stream_GetRemainingLength(s) < sizeof(rpcconn_rts_hdr_t) - sizeof(rpcconn_common_hdr_t))
+ return FALSE;
+
+ Stream_Read_UINT16(s, ctx->Flags);
+ Stream_Read_UINT16(s, ctx->NumberOfCommands);
+ return TRUE;
+}
+
+static void rts_free_pdu_rts(rpcconn_rts_hdr_t* ctx)
+{
+ WINPR_UNUSED(ctx);
+}
+
+void rts_free_pdu_header(rpcconn_hdr_t* header, BOOL allocated)
+{
+ if (!header)
+ return;
+
+ switch (header->common.ptype)
+ {
+ case PTYPE_ALTER_CONTEXT:
+ rts_free_pdu_alter_context(&header->alter_context);
+ break;
+ case PTYPE_ALTER_CONTEXT_RESP:
+ rts_free_pdu_alter_context_response(&header->alter_context_response);
+ break;
+ case PTYPE_BIND:
+ rts_free_pdu_bind(&header->bind);
+ break;
+ case PTYPE_BIND_ACK:
+ rts_free_pdu_bind_ack(&header->bind_ack);
+ break;
+ case PTYPE_BIND_NAK:
+ rts_free_pdu_bind_nak(&header->bind_nak);
+ break;
+ case PTYPE_RPC_AUTH_3:
+ rts_free_pdu_auth3(&header->rpc_auth_3);
+ break;
+ case PTYPE_CANCEL_ACK:
+ rts_free_pdu_cancel_ack(&header->cancel);
+ break;
+ case PTYPE_FAULT:
+ rts_free_pdu_fault(&header->fault);
+ break;
+ case PTYPE_ORPHANED:
+ rts_free_pdu_orphaned(&header->orphaned);
+ break;
+ case PTYPE_REQUEST:
+ rts_free_pdu_request(&header->request);
+ break;
+ case PTYPE_RESPONSE:
+ rts_free_pdu_response(&header->response);
+ break;
+ case PTYPE_RTS:
+ rts_free_pdu_rts(&header->rts);
+ break;
+ /* No extra fields */
+ case PTYPE_SHUTDOWN:
+ break;
+
+ /* not handled */
+ case PTYPE_PING:
+ case PTYPE_WORKING:
+ case PTYPE_NOCALL:
+ case PTYPE_REJECT:
+ case PTYPE_ACK:
+ case PTYPE_CL_CANCEL:
+ case PTYPE_FACK:
+ case PTYPE_CO_CANCEL:
+ default:
+ break;
+ }
+
+ if (allocated)
+ free(header);
+}
+
+BOOL rts_read_pdu_header(wStream* s, rpcconn_hdr_t* header)
+{
+ BOOL rc = FALSE;
+ assert(s);
+ assert(header);
+
+ if (!rts_read_common_pdu_header(s, &header->common))
+ return FALSE;
+
+ WLog_DBG(TAG, "Reading PDU type %s", rts_pdu_ptype_to_string(header->common.ptype));
+ fflush(stdout);
+ switch (header->common.ptype)
+ {
+ case PTYPE_ALTER_CONTEXT:
+ rc = rts_read_pdu_alter_context(s, &header->alter_context);
+ break;
+ case PTYPE_ALTER_CONTEXT_RESP:
+ rc = rts_read_pdu_alter_context_response(s, &header->alter_context_response);
+ break;
+ case PTYPE_BIND:
+ rc = rts_read_pdu_bind(s, &header->bind);
+ break;
+ case PTYPE_BIND_ACK:
+ rc = rts_read_pdu_bind_ack(s, &header->bind_ack);
+ break;
+ case PTYPE_BIND_NAK:
+ rc = rts_read_pdu_bind_nak(s, &header->bind_nak);
+ break;
+ case PTYPE_RPC_AUTH_3:
+ rc = rts_read_pdu_auth3(s, &header->rpc_auth_3);
+ break;
+ case PTYPE_CANCEL_ACK:
+ rc = rts_read_pdu_cancel_ack(s, &header->cancel);
+ break;
+ case PTYPE_FAULT:
+ rc = rts_read_pdu_fault(s, &header->fault);
+ break;
+ case PTYPE_ORPHANED:
+ rc = rts_read_pdu_orphaned(s, &header->orphaned);
+ break;
+ case PTYPE_REQUEST:
+ rc = rts_read_pdu_request(s, &header->request);
+ break;
+ case PTYPE_RESPONSE:
+ rc = rts_read_pdu_response(s, &header->response);
+ break;
+ case PTYPE_RTS:
+ rc = rts_read_pdu_rts(s, &header->rts);
+ break;
+ case PTYPE_SHUTDOWN:
+ rc = TRUE; /* No extra fields */
+ break;
+
+ /* not handled */
+ case PTYPE_PING:
+ case PTYPE_WORKING:
+ case PTYPE_NOCALL:
+ case PTYPE_REJECT:
+ case PTYPE_ACK:
+ case PTYPE_CL_CANCEL:
+ case PTYPE_FACK:
+ case PTYPE_CO_CANCEL:
+ default:
+ break;
+ }
+
+ return rc;
+}
+
+static BOOL rts_write_pdu_header(wStream* s, const rpcconn_rts_hdr_t* header)
+{
+ assert(s);
+ assert(header);
+ if (!Stream_EnsureRemainingCapacity(s, sizeof(rpcconn_rts_hdr_t)))
+ return FALSE;
+
+ if (!rts_write_common_pdu_header(s, &header->header))
+ return FALSE;
+
+ Stream_Write_UINT16(s, header->Flags);
+ Stream_Write_UINT16(s, header->NumberOfCommands);
+ return TRUE;
}
-static int rts_receive_window_size_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length,
+static int rts_receive_window_size_command_read(rdpRpc* rpc, wStream* buffer,
UINT32* ReceiveWindowSize)
{
+ UINT32 val;
+
+ assert(rpc);
+ assert(buffer);
+
+ if (Stream_GetRemainingLength(buffer) < 4)
+ return -1;
+ Stream_Read_UINT32(buffer, val);
if (ReceiveWindowSize)
- *ReceiveWindowSize = *((UINT32*)&buffer[0]); /* ReceiveWindowSize (4 bytes) */
+ *ReceiveWindowSize = val; /* ReceiveWindowSize (4 bytes) */
return 4;
}
-static int rts_receive_window_size_command_write(BYTE* buffer, UINT32 ReceiveWindowSize)
+static BOOL rts_receive_window_size_command_write(wStream* s, UINT32 ReceiveWindowSize)
{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_RECEIVE_WINDOW_SIZE; /* CommandType (4 bytes) */
- *((UINT32*)&buffer[4]) = ReceiveWindowSize; /* ReceiveWindowSize (4 bytes) */
- }
+ assert(s);
+
+ if (!Stream_EnsureRemainingCapacity(s, 2 * sizeof(UINT32)))
+ return FALSE;
- return 8;
+ Stream_Write_UINT32(s, RTS_CMD_RECEIVE_WINDOW_SIZE); /* CommandType (4 bytes) */
+ Stream_Write_UINT32(s, ReceiveWindowSize); /* ReceiveWindowSize (4 bytes) */
+
+ return TRUE;
}
-static int rts_flow_control_ack_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length,
- UINT32* BytesReceived, UINT32* AvailableWindow,
- BYTE* ChannelCookie)
+static int rts_flow_control_ack_command_read(rdpRpc* rpc, wStream* buffer, UINT32* BytesReceived,
+ UINT32* AvailableWindow, BYTE* ChannelCookie)
{
+ UINT32 val;
+ assert(rpc);
+ assert(buffer);
+
/* Ack (24 bytes) */
+ if (Stream_GetRemainingLength(buffer) < 24)
+ return -1;
+
+ Stream_Read_UINT32(buffer, val);
if (BytesReceived)
- *BytesReceived = *((UINT32*)&buffer[0]); /* BytesReceived (4 bytes) */
+ *BytesReceived = val; /* BytesReceived (4 bytes) */
+ Stream_Read_UINT32(buffer, val);
if (AvailableWindow)
- *AvailableWindow = *((UINT32*)&buffer[4]); /* AvailableWindow (4 bytes) */
+ *AvailableWindow = val; /* AvailableWindow (4 bytes) */
if (ChannelCookie)
- CopyMemory(ChannelCookie, &buffer[8], 16); /* ChannelCookie (16 bytes) */
+ Stream_Read(buffer, ChannelCookie, 16); /* ChannelCookie (16 bytes) */
+ else
+ Stream_Seek(buffer, 16);
return 24;
}
-static int rts_flow_control_ack_command_write(BYTE* buffer, UINT32 BytesReceived,
- UINT32 AvailableWindow, BYTE* ChannelCookie)
+static BOOL rts_flow_control_ack_command_write(wStream* s, UINT32 BytesReceived,
+ UINT32 AvailableWindow, BYTE* ChannelCookie)
{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_FLOW_CONTROL_ACK; /* CommandType (4 bytes) */
- /* Ack (24 bytes) */
- *((UINT32*)&buffer[4]) = BytesReceived; /* BytesReceived (4 bytes) */
- *((UINT32*)&buffer[8]) = AvailableWindow; /* AvailableWindow (4 bytes) */
- CopyMemory(&buffer[12], ChannelCookie, 16); /* ChannelCookie (16 bytes) */
- }
+ assert(s);
- return 28;
-}
+ if (!Stream_EnsureRemainingCapacity(s, 28))
+ return FALSE;
-static int rts_connection_timeout_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length,
- UINT32* ConnectionTimeout)
-{
- if (ConnectionTimeout)
- *ConnectionTimeout = *((UINT32*)&buffer[0]); /* ConnectionTimeout (4 bytes) */
+ Stream_Write_UINT32(s, RTS_CMD_FLOW_CONTROL_ACK); /* CommandType (4 bytes) */
+ Stream_Write_UINT32(s, BytesReceived); /* BytesReceived (4 bytes) */
+ Stream_Write_UINT32(s, AvailableWindow); /* AvailableWindow (4 bytes) */
+ Stream_Write(s, ChannelCookie, 16); /* ChannelCookie (16 bytes) */
- return 4;
+ return TRUE;
}
-static int rts_connection_timeout_command_write(BYTE* buffer, UINT32 ConnectionTimeout)
+static BOOL rts_connection_timeout_command_read(rdpRpc* rpc, wStream* buffer,
+ UINT32* ConnectionTimeout)
{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_CONNECTION_TIMEOUT; /* CommandType (4 bytes) */
- *((UINT32*)&buffer[4]) = ConnectionTimeout; /* ConnectionTimeout (4 bytes) */
- }
+ UINT32 val;
+ assert(rpc);
+ assert(buffer);
- return 8;
-}
+ if (Stream_GetRemainingLength(buffer) < 4)
+ return FALSE;
-static int rts_cookie_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length)
-{
- /* Cookie (16 bytes) */
- return 16;
+ Stream_Read_UINT32(buffer, val);
+ if (ConnectionTimeout)
+ *ConnectionTimeout = val; /* ConnectionTimeout (4 bytes) */
+
+ return TRUE;
}
-static int rts_cookie_command_write(BYTE* buffer, BYTE* Cookie)
+static BOOL rts_cookie_command_write(wStream* s, const BYTE* Cookie)
{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_COOKIE; /* CommandType (4 bytes) */
- CopyMemory(&buffer[4], Cookie, 16); /* Cookie (16 bytes) */
- }
+ assert(s);
- return 20;
-}
+ if (!Stream_EnsureRemainingCapacity(s, 20))
+ return FALSE;
-static int rts_channel_lifetime_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length)
-{
- /* ChannelLifetime (4 bytes) */
- return 4;
+ Stream_Write_UINT32(s, RTS_CMD_COOKIE); /* CommandType (4 bytes) */
+ Stream_Write(s, Cookie, 16); /* Cookie (16 bytes) */
+
+ return TRUE;
}
-static int rts_channel_lifetime_command_write(BYTE* buffer, UINT32 ChannelLifetime)
+static BOOL rts_channel_lifetime_command_write(wStream* s, UINT32 ChannelLifetime)
{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_CHANNEL_LIFETIME; /* CommandType (4 bytes) */
- *((UINT32*)&buffer[4]) = ChannelLifetime; /* ChannelLifetime (4 bytes) */
- }
+ assert(s);
- return 8;
-}
+ if (!Stream_EnsureRemainingCapacity(s, 8))
+ return FALSE;
+ Stream_Write_UINT32(s, RTS_CMD_CHANNEL_LIFETIME); /* CommandType (4 bytes) */
+ Stream_Write_UINT32(s, ChannelLifetime); /* ChannelLifetime (4 bytes) */
-static int rts_client_keepalive_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length)
-{
- /* ClientKeepalive (4 bytes) */
- return 4;
+ return TRUE;
}
-static int rts_client_keepalive_command_write(BYTE* buffer, UINT32 ClientKeepalive)
+static BOOL rts_client_keepalive_command_write(wStream* s, UINT32 ClientKeepalive)
{
+ assert(s);
+
+ if (!Stream_EnsureRemainingCapacity(s, 8))
+ return FALSE;
/**
* An unsigned integer that specifies the keep-alive interval, in milliseconds,
* that this connection is configured to use. This value MUST be 0 or in the inclusive
* range of 60,000 through 4,294,967,295. If it is 0, it MUST be interpreted as 300,000.
*/
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_CLIENT_KEEPALIVE; /* CommandType (4 bytes) */
- *((UINT32*)&buffer[4]) = ClientKeepalive; /* ClientKeepalive (4 bytes) */
- }
+ Stream_Write_UINT32(s, RTS_CMD_CLIENT_KEEPALIVE); /* CommandType (4 bytes) */
+ Stream_Write_UINT32(s, ClientKeepalive); /* ClientKeepalive (4 bytes) */
- return 8;
+ return TRUE;
}
-static int rts_version_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+static BOOL rts_version_command_read(rdpRpc* rpc, wStream* buffer)
{
- /* Version (4 bytes) */
- return 4;
-}
+ assert(rpc);
+ assert(buffer);
-static int rts_version_command_write(BYTE* buffer)
-{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_VERSION; /* CommandType (4 bytes) */
- *((UINT32*)&buffer[4]) = 1; /* Version (4 bytes) */
- }
+ if (!Stream_SafeSeek(buffer, 4))
+ return FALSE;
- return 8;
+ /* Version (4 bytes) */
+ return TRUE;
}
-static int rts_empty_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+static BOOL rts_version_command_write(wStream* buffer)
{
- return 0;
-}
+ assert(buffer);
-static int rts_empty_command_write(BYTE* buffer)
-{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_EMPTY; /* CommandType (4 bytes) */
- }
+ if (Stream_GetRemainingCapacity(buffer) < 8)
+ return FALSE;
- return 4;
-}
+ Stream_Write_UINT32(buffer, RTS_CMD_VERSION); /* CommandType (4 bytes) */
+ Stream_Write_UINT32(buffer, 1); /* Version (4 bytes) */
-static SSIZE_T rts_padding_command_read(const BYTE* buffer, size_t length)
-{
- UINT32 ConformanceCount;
- ConformanceCount = *((UINT32*)&buffer[0]); /* ConformanceCount (4 bytes) */
- /* Padding (variable) */
- return ConformanceCount + 4;
+ return TRUE;
}
-static int rts_padding_command_write(BYTE* buffer, UINT32 ConformanceCount)
+static BOOL rts_empty_command_write(wStream* s)
{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_PADDING; /* CommandType (4 bytes) */
- *((UINT32*)&buffer[4]) = ConformanceCount; /* ConformanceCount (4 bytes) */
- ZeroMemory(&buffer[8], ConformanceCount); /* Padding (variable) */
- }
-
- return 8 + ConformanceCount;
-}
+ assert(s);
-static int rts_negative_ance_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length)
-{
- return 0;
-}
+ if (!Stream_EnsureRemainingCapacity(s, 8))
+ return FALSE;
-static int rts_negative_ance_command_write(BYTE* buffer)
-{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_NEGATIVE_ANCE; /* CommandType (4 bytes) */
- }
+ Stream_Write_UINT32(s, RTS_CMD_EMPTY); /* CommandType (4 bytes) */
- return 4;
+ return TRUE;
}
-static int rts_ance_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+static BOOL rts_padding_command_read(wStream* s, size_t* length)
{
- return 0;
+ UINT32 ConformanceCount;
+ assert(s);
+ assert(length);
+ if (Stream_GetRemainingLength(s) < 4)
+ return FALSE;
+ Stream_Read_UINT32(s, ConformanceCount); /* ConformanceCount (4 bytes) */
+ *length = ConformanceCount + 4;
+ return TRUE;
}
-static int rts_ance_command_write(BYTE* buffer)
+static BOOL rts_client_address_command_read(wStream* s, size_t* length)
{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_ANCE; /* CommandType (4 bytes) */
- }
+ UINT32 AddressType;
- return 4;
-}
+ assert(s);
+ assert(length);
-static SSIZE_T rts_client_address_command_read(const BYTE* buffer, size_t length)
-{
- UINT32 AddressType;
- AddressType = *((UINT32*)&buffer[0]); /* AddressType (4 bytes) */
+ if (Stream_GetRemainingLength(s) < 4)
+ return FALSE;
+ Stream_Read_UINT32(s, AddressType); /* AddressType (4 bytes) */
if (AddressType == 0)
{
/* ClientAddress (4 bytes) */
/* padding (12 bytes) */
- return 4 + 4 + 12;
+ *length = 4 + 4 + 12;
}
else
{
/* ClientAddress (16 bytes) */
/* padding (12 bytes) */
- return 4 + 16 + 12;
+ *length = 4 + 16 + 12;
}
+ return TRUE;
}
-static int rts_client_address_command_write(BYTE* buffer, UINT32 AddressType, BYTE* ClientAddress)
+static BOOL rts_association_group_id_command_write(wStream* s, const BYTE* AssociationGroupId)
{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_CLIENT_ADDRESS; /* CommandType (4 bytes) */
- *((UINT32*)&buffer[4]) = AddressType; /* AddressType (4 bytes) */
- }
-
- if (AddressType == 0)
- {
- if (buffer)
- {
- CopyMemory(&buffer[8], ClientAddress, 4); /* ClientAddress (4 bytes) */
- ZeroMemory(&buffer[12], 12); /* padding (12 bytes) */
- }
+ assert(s);
- return 24;
- }
- else
- {
- if (buffer)
- {
- CopyMemory(&buffer[8], ClientAddress, 16); /* ClientAddress (16 bytes) */
- ZeroMemory(&buffer[24], 12); /* padding (12 bytes) */
- }
+ if (!Stream_EnsureRemainingCapacity(s, 20))
+ return FALSE;
- return 36;
- }
-}
+ Stream_Write_UINT32(s, RTS_CMD_ASSOCIATION_GROUP_ID); /* CommandType (4 bytes) */
+ Stream_Write(s, AssociationGroupId, 16); /* AssociationGroupId (16 bytes) */
-static int rts_association_group_id_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length)
-{
- /* AssociationGroupId (16 bytes) */
- return 16;
+ return TRUE;
}
-static int rts_association_group_id_command_write(BYTE* buffer, BYTE* AssociationGroupId)
+static int rts_destination_command_read(rdpRpc* rpc, wStream* buffer, UINT32* Destination)
{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_ASSOCIATION_GROUP_ID; /* CommandType (4 bytes) */
- CopyMemory(&buffer[4], AssociationGroupId, 16); /* AssociationGroupId (16 bytes) */
- }
-
- return 20;
-}
+ UINT32 val;
+ assert(rpc);
+ assert(buffer);
-static int rts_destination_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length,
- UINT32* Destination)
-{
+ if (Stream_GetRemainingLength(buffer) < 4)
+ return -1;
+ Stream_Read_UINT32(buffer, val);
if (Destination)
- *Destination = *((UINT32*)&buffer[0]); /* Destination (4 bytes) */
+ *Destination = val; /* Destination (4 bytes) */
return 4;
}
-static int rts_destination_command_write(BYTE* buffer, UINT32 Destination)
+static BOOL rts_destination_command_write(wStream* s, UINT32 Destination)
{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_DESTINATION; /* CommandType (4 bytes) */
- *((UINT32*)&buffer[4]) = Destination; /* Destination (4 bytes) */
- }
-
- return 8;
-}
+ assert(s);
-static int rts_ping_traffic_sent_notify_command_read(rdpRpc* rpc, BYTE* buffer, UINT32 length)
-{
- /* PingTrafficSent (4 bytes) */
- return 4;
-}
+ if (!Stream_EnsureRemainingCapacity(s, 8))
+ return FALSE;
-static int rts_ping_traffic_sent_notify_command_write(BYTE* buffer, UINT32 PingTrafficSent)
-{
- if (buffer)
- {
- *((UINT32*)&buffer[0]) = RTS_CMD_PING_TRAFFIC_SENT_NOTIFY; /* CommandType (4 bytes) */
- *((UINT32*)&buffer[4]) = PingTrafficSent; /* PingTrafficSent (4 bytes) */
- }
+ Stream_Write_UINT32(s, RTS_CMD_DESTINATION); /* CommandType (4 bytes) */
+ Stream_Write_UINT32(s, Destination); /* Destination (4 bytes) */
- return 8;
+ return TRUE;
}
void rts_generate_cookie(BYTE* cookie)
{
+ assert(cookie);
winpr_RAND(cookie, 16);
}
+static BOOL rts_send_buffer(RpcChannel* channel, wStream* s, size_t frag_length)
+{
+ BOOL status = FALSE;
+ SSIZE_T rc;
+
+ assert(channel);
+ assert(s);
+
+ Stream_SealLength(s);
+ if (Stream_Length(s) < sizeof(rpcconn_common_hdr_t))
+ goto fail;
+ if (Stream_Length(s) != frag_length)
+ goto fail;
+
+ rc = rpc_channel_write(channel, Stream_Buffer(s), Stream_Length(s));
+ if (rc < 0)
+ goto fail;
+ if ((size_t)rc != Stream_Length(s))
+ goto fail;
+ status = TRUE;
+fail:
+ return status;
+}
+
+
/* CONN/A Sequence */
-int rts_send_CONN_A1_pdu(rdpRpc* rpc)
+BOOL rts_send_CONN_A1_pdu(rdpRpc* rpc)
{
- int status;
- BYTE* buffer;
- rpcconn_rts_hdr_t header;
+ BOOL status = FALSE;
+ wStream* buffer;
+ rpcconn_rts_hdr_t header = rts_pdu_header_init();
UINT32 ReceiveWindowSize;
BYTE* OUTChannelCookie;
BYTE* VirtualConnectionCookie;
- RpcVirtualConnection* connection = rpc->VirtualConnection;
- RpcOutChannel* outChannel = connection->DefaultOutChannel;
- rts_pdu_header_init(&header);
- header.frag_length = 76;
+ RpcVirtualConnection* connection;
+ RpcOutChannel* outChannel;
+
+ assert(rpc);
+
+ connection = rpc->VirtualConnection;
+ assert(connection);
+
+ outChannel = connection->DefaultOutChannel;
+ assert(outChannel);
+
+ header.header.frag_length = 76;
header.Flags = RTS_FLAG_NONE;
header.NumberOfCommands = 4;
WLog_DBG(TAG, "Sending CONN/A1 RTS PDU");
VirtualConnectionCookie = (BYTE*)&(connection->Cookie);
OUTChannelCookie = (BYTE*)&(outChannel->common.Cookie);
ReceiveWindowSize = outChannel->ReceiveWindow;
- buffer = (BYTE*)malloc(header.frag_length);
+ buffer = Stream_New(NULL, header.header.frag_length);
if (!buffer)
return -1;
- CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */
- rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
- rts_cookie_command_write(&buffer[28],
- VirtualConnectionCookie); /* VirtualConnectionCookie (20 bytes) */
- rts_cookie_command_write(&buffer[48], OUTChannelCookie); /* OUTChannelCookie (20 bytes) */
- rts_receive_window_size_command_write(&buffer[68],
- ReceiveWindowSize); /* ReceiveWindowSize (8 bytes) */
- status = rpc_channel_write(&outChannel->common, buffer, header.frag_length);
- free(buffer);
- return (status > 0) ? 1 : -1;
+ if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */
+ goto fail;
+ status = rts_version_command_write(buffer); /* Version (8 bytes) */
+ if (!status)
+ goto fail;
+ status = rts_cookie_command_write(
+ buffer, VirtualConnectionCookie); /* VirtualConnectionCookie (20 bytes) */
+ if (!status)
+ goto fail;
+ status = rts_cookie_command_write(buffer, OUTChannelCookie); /* OUTChannelCookie (20 bytes) */
+ if (!status)
+ goto fail;
+ status = rts_receive_window_size_command_write(
+ buffer, ReceiveWindowSize); /* ReceiveWindowSize (8 bytes) */
+ if (!status)
+ goto fail;
+ status = rts_send_buffer(&outChannel->common, buffer, header.header.frag_length);
+fail:
+ Stream_Free(buffer, TRUE);
+ return status;
}
-int rts_recv_CONN_A3_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+BOOL rts_recv_CONN_A3_pdu(rdpRpc* rpc, wStream* buffer)
{
+ BOOL rc;
UINT32 ConnectionTimeout;
- rts_connection_timeout_command_read(rpc, &buffer[24], length - 24, &ConnectionTimeout);
+ if (!Stream_SafeSeek(buffer, 24))
+ return FALSE;
+
+ rc = rts_connection_timeout_command_read(rpc, buffer, &ConnectionTimeout);
+ if (!rc)
+ return rc;
+
WLog_DBG(TAG, "Receiving CONN/A3 RTS PDU: ConnectionTimeout: %" PRIu32 "", ConnectionTimeout);
+
+ assert(rpc);
+ assert(rpc->VirtualConnection);
+ assert(rpc->VirtualConnection->DefaultInChannel);
+
rpc->VirtualConnection->DefaultInChannel->PingOriginator.ConnectionTimeout = ConnectionTimeout;
- return 1;
+ return TRUE;
}
/* CONN/B Sequence */
-int rts_send_CONN_B1_pdu(rdpRpc* rpc)
+BOOL rts_send_CONN_B1_pdu(rdpRpc* rpc)
{
- int status;
- BYTE* buffer;
- UINT32 length;
- rpcconn_rts_hdr_t header;
+ BOOL status = FALSE;
+ wStream* buffer;
+ rpcconn_rts_hdr_t header = rts_pdu_header_init();
BYTE* INChannelCookie;
BYTE* AssociationGroupId;
BYTE* VirtualConnectionCookie;
- RpcVirtualConnection* connection = rpc->VirtualConnection;
- RpcInChannel* inChannel = connection->DefaultInChannel;
- rts_pdu_header_init(&header);
- header.frag_length = 104;
+ RpcVirtualConnection* connection;
+ RpcInChannel* inChannel;
+
+ assert(rpc);
+
+ connection = rpc->VirtualConnection;
+ assert(connection);
+
+ inChannel = connection->DefaultInChannel;
+ assert(inChannel);
+
+ header.header.frag_length = 104;
header.Flags = RTS_FLAG_NONE;
header.NumberOfCommands = 6;
WLog_DBG(TAG, "Sending CONN/B1 RTS PDU");
VirtualConnectionCookie = (BYTE*)&(connection->Cookie);
INChannelCookie = (BYTE*)&(inChannel->common.Cookie);
AssociationGroupId = (BYTE*)&(connection->AssociationGroupId);
- buffer = (BYTE*)malloc(header.frag_length);
+ buffer = Stream_New(NULL, header.header.frag_length);
if (!buffer)
- return -1;
-
- CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */
- rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
- rts_cookie_command_write(&buffer[28],
- VirtualConnectionCookie); /* VirtualConnectionCookie (20 bytes) */
- rts_cookie_command_write(&buffer[48], INChannelCookie); /* INChannelCookie (20 bytes) */
- rts_channel_lifetime_command_write(&buffer[68],
- rpc->ChannelLifetime); /* ChannelLifetime (8 bytes) */
- rts_client_keepalive_command_write(&buffer[76],
- rpc->KeepAliveInterval); /* ClientKeepalive (8 bytes) */
- rts_association_group_id_command_write(&buffer[84],
- AssociationGroupId); /* AssociationGroupId (20 bytes) */
- length = header.frag_length;
- status = rpc_channel_write(&inChannel->common, buffer, length);
- free(buffer);
- return (status > 0) ? 1 : -1;
+ goto fail;
+ if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */
+ goto fail;
+ if (!rts_version_command_write(buffer)) /* Version (8 bytes) */
+ goto fail;
+ if (!rts_cookie_command_write(buffer,
+ VirtualConnectionCookie)) /* VirtualConnectionCookie (20 bytes) */
+ goto fail;
+ if (!rts_cookie_command_write(buffer, INChannelCookie)) /* INChannelCookie (20 bytes) */
+ goto fail;
+ if (!rts_channel_lifetime_command_write(buffer,
+ rpc->ChannelLifetime)) /* ChannelLifetime (8 bytes) */
+ goto fail;
+ if (!rts_client_keepalive_command_write(buffer,
+ rpc->KeepAliveInterval)) /* ClientKeepalive (8 bytes) */
+ goto fail;
+ if (!rts_association_group_id_command_write(
+ buffer, AssociationGroupId)) /* AssociationGroupId (20 bytes) */
+ goto fail;
+ status = rts_send_buffer(&inChannel->common, buffer, header.header.frag_length);
+fail:
+ Stream_Free(buffer, TRUE);
+ return status;
}
/* CONN/C Sequence */
-int rts_recv_CONN_C2_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+BOOL rts_recv_CONN_C2_pdu(rdpRpc* rpc, wStream* buffer)
{
- UINT32 offset;
+ BOOL rc;
UINT32 ReceiveWindowSize;
UINT32 ConnectionTimeout;
- offset = 24;
- offset += rts_version_command_read(rpc, &buffer[offset], length - offset) + 4;
- offset += rts_receive_window_size_command_read(rpc, &buffer[offset], length - offset,
- &ReceiveWindowSize) +
- 4;
- offset += rts_connection_timeout_command_read(rpc, &buffer[offset], length - offset,
- &ConnectionTimeout) +
- 4;
+ assert(rpc);
+ assert(buffer);
+
+ if (!Stream_SafeSeek(buffer, 24))
+ return FALSE;
+
+ rc = rts_version_command_read(rpc, buffer);
+ if (rc < 0)
+ return rc;
+ rc = rts_receive_window_size_command_read(rpc, buffer, &ReceiveWindowSize);
+ if (rc < 0)
+ return rc;
+ rc = rts_connection_timeout_command_read(rpc, buffer, &ConnectionTimeout);
+ if (rc < 0)
+ return rc;
WLog_DBG(TAG,
"Receiving CONN/C2 RTS PDU: ConnectionTimeout: %" PRIu32 " ReceiveWindowSize: %" PRIu32
"",
ConnectionTimeout, ReceiveWindowSize);
+
+ assert(rpc);
+ assert(rpc->VirtualConnection);
+ assert(rpc->VirtualConnection->DefaultInChannel);
+
rpc->VirtualConnection->DefaultInChannel->PingOriginator.ConnectionTimeout = ConnectionTimeout;
rpc->VirtualConnection->DefaultInChannel->PeerReceiveWindow = ReceiveWindowSize;
- return 1;
+ return TRUE;
}
/* Out-of-Sequence PDUs */
-static int rts_send_keep_alive_pdu(rdpRpc* rpc)
+BOOL rts_send_flow_control_ack_pdu(rdpRpc* rpc)
{
- int status;
- BYTE* buffer;
- UINT32 length;
- rpcconn_rts_hdr_t header;
- RpcInChannel* inChannel = rpc->VirtualConnection->DefaultInChannel;
- rts_pdu_header_init(&header);
- header.frag_length = 28;
- header.Flags = RTS_FLAG_OTHER_CMD;
- header.NumberOfCommands = 1;
- WLog_DBG(TAG, "Sending Keep-Alive RTS PDU");
- buffer = (BYTE*)malloc(header.frag_length);
+ BOOL status = FALSE;
+ wStream* buffer;
+ rpcconn_rts_hdr_t header = rts_pdu_header_init();
+ UINT32 BytesReceived;
+ UINT32 AvailableWindow;
+ BYTE* ChannelCookie;
+ RpcVirtualConnection* connection;
+ RpcInChannel* inChannel;
+ RpcOutChannel* outChannel;
- if (!buffer)
- return -1;
+ assert(rpc);
- CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */
- rts_client_keepalive_command_write(
- &buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */
- length = header.frag_length;
- status = rpc_channel_write(&inChannel->common, buffer, length);
- free(buffer);
- return (status > 0) ? 1 : -1;
-}
+ connection = rpc->VirtualConnection;
+ assert(connection);
-int rts_send_flow_control_ack_pdu(rdpRpc* rpc)
-{
- int status;
- BYTE* buffer;
- UINT32 length;
- rpcconn_rts_hdr_t header;
- UINT32 BytesReceived;
- UINT32 AvailableWindow;
- BYTE* ChannelCookie;
- RpcVirtualConnection* connection = rpc->VirtualConnection;
- RpcInChannel* inChannel = connection->DefaultInChannel;
- RpcOutChannel* outChannel = connection->DefaultOutChannel;
- rts_pdu_header_init(&header);
- header.frag_length = 56;
+ inChannel = connection->DefaultInChannel;
+ assert(inChannel);
+
+ outChannel = connection->DefaultOutChannel;
+ assert(outChannel);
+
+ header.header.frag_length = 56;
header.Flags = RTS_FLAG_OTHER_CMD;
header.NumberOfCommands = 2;
+
WLog_DBG(TAG, "Sending FlowControlAck RTS PDU");
+
BytesReceived = outChannel->BytesReceived;
AvailableWindow = outChannel->AvailableWindowAdvertised;
ChannelCookie = (BYTE*)&(outChannel->common.Cookie);
outChannel->ReceiverAvailableWindow = outChannel->AvailableWindowAdvertised;
- buffer = (BYTE*)malloc(header.frag_length);
+ buffer = Stream_New(NULL, header.header.frag_length);
if (!buffer)
- return -1;
+ goto fail;
+
+ if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */
+ goto fail;
+ if (!rts_destination_command_write(buffer, FDOutProxy)) /* Destination Command (8 bytes) */
+ goto fail;
- CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */
- rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */
/* FlowControlAck Command (28 bytes) */
- rts_flow_control_ack_command_write(&buffer[28], BytesReceived, AvailableWindow, ChannelCookie);
- length = header.frag_length;
- status = rpc_channel_write(&inChannel->common, buffer, length);
- free(buffer);
- return (status > 0) ? 1 : -1;
+ if (!rts_flow_control_ack_command_write(buffer, BytesReceived, AvailableWindow, ChannelCookie))
+ goto fail;
+
+ status = rts_send_buffer(&inChannel->common, buffer, header.header.frag_length);
+fail:
+ Stream_Free(buffer, TRUE);
+ return status;
}
-static int rts_recv_flow_control_ack_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+static int rts_recv_flow_control_ack_pdu(rdpRpc* rpc, wStream* buffer)
{
- UINT32 offset;
+ int rc;
UINT32 BytesReceived;
UINT32 AvailableWindow;
- BYTE ChannelCookie[16];
- offset = 24;
- offset +=
- rts_flow_control_ack_command_read(rpc, &buffer[offset], length - offset, &BytesReceived,
- &AvailableWindow, (BYTE*)&ChannelCookie) +
- 4;
+ BYTE ChannelCookie[16] = { 0 };
+
+ rc = rts_flow_control_ack_command_read(rpc, buffer, &BytesReceived, &AvailableWindow,
+ (BYTE*)&ChannelCookie);
+ if (rc < 0)
+ return rc;
WLog_ERR(TAG,
"Receiving FlowControlAck RTS PDU: BytesReceived: %" PRIu32
" AvailableWindow: %" PRIu32 "",
BytesReceived, AvailableWindow);
+
+ assert(rpc->VirtualConnection);
+ assert(rpc->VirtualConnection->DefaultInChannel);
+
rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow =
AvailableWindow - (rpc->VirtualConnection->DefaultInChannel->BytesSent - BytesReceived);
return 1;
}
-static int rts_recv_flow_control_ack_with_destination_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+static int rts_recv_flow_control_ack_with_destination_pdu(rdpRpc* rpc, wStream* buffer)
{
- UINT32 offset;
+ int rc;
UINT32 Destination;
UINT32 BytesReceived;
UINT32 AvailableWindow;
- BYTE ChannelCookie[16];
+ BYTE ChannelCookie[16] = { 0 };
/**
* When the sender receives a FlowControlAck RTS PDU, it MUST use the following formula to
* recalculate its Sender AvailableWindow variable:
@@ -620,16 +1718,23 @@ static int rts_recv_flow_control_ack_with_destination_pdu(rdpRpc* rpc, BYTE* buf
* structure in the PDU received.
*
*/
- offset = 24;
- offset += rts_destination_command_read(rpc, &buffer[offset], length - offset, &Destination) + 4;
- offset +=
- rts_flow_control_ack_command_read(rpc, &buffer[offset], length - offset, &BytesReceived,
- &AvailableWindow, (BYTE*)&ChannelCookie) +
- 4;
+
+ rc = rts_destination_command_read(rpc, buffer, &Destination);
+ if (rc < 0)
+ return rc;
+
+ rc = rts_flow_control_ack_command_read(rpc, buffer, &BytesReceived, &AvailableWindow,
+ ChannelCookie);
+ if (rc < 0)
+ return rc;
+
WLog_DBG(TAG,
"Receiving FlowControlAckWithDestination RTS PDU: BytesReceived: %" PRIu32
" AvailableWindow: %" PRIu32 "",
BytesReceived, AvailableWindow);
+
+ assert(rpc->VirtualConnection);
+ assert(rpc->VirtualConnection->DefaultInChannel);
rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow =
AvailableWindow - (rpc->VirtualConnection->DefaultInChannel->BytesSent - BytesReceived);
return 1;
@@ -637,31 +1742,40 @@ static int rts_recv_flow_control_ack_with_destination_pdu(rdpRpc* rpc, BYTE* buf
static int rts_send_ping_pdu(rdpRpc* rpc)
{
- int status;
- BYTE* buffer;
- UINT32 length;
- rpcconn_rts_hdr_t header;
- RpcInChannel* inChannel = rpc->VirtualConnection->DefaultInChannel;
- rts_pdu_header_init(&header);
- header.frag_length = 20;
+ BOOL status = FALSE;
+ wStream* buffer;
+ rpcconn_rts_hdr_t header = rts_pdu_header_init();
+ RpcInChannel* inChannel;
+
+ assert(rpc);
+ assert(rpc->VirtualConnection);
+
+ inChannel = rpc->VirtualConnection->DefaultInChannel;
+ assert(inChannel);
+
+ header.header.frag_length = 20;
header.Flags = RTS_FLAG_PING;
header.NumberOfCommands = 0;
WLog_DBG(TAG, "Sending Ping RTS PDU");
- buffer = (BYTE*)malloc(header.frag_length);
+ buffer = Stream_New(NULL, header.header.frag_length);
if (!buffer)
- return -1;
-
- CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */
- length = header.frag_length;
- status = rpc_channel_write(&inChannel->common, buffer, length);
- free(buffer);
- return (status > 0) ? 1 : -1;
+ goto fail;
+
+ if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */
+ goto fail;
+ status = rts_send_buffer(&inChannel->common, buffer, header.header.frag_length);
+fail:
+ Stream_Free(buffer, TRUE);
+ return (status) ? 1 : -1;
}
-SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length)
+BOOL rts_command_length(UINT32 CommandType, wStream* s, size_t* length)
{
- int CommandLength = 0;
+ size_t padding = 0;
+ size_t CommandLength = 0;
+
+ assert(s);
switch (CommandType)
{
@@ -698,7 +1812,8 @@ SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length
break;
case RTS_CMD_PADDING: /* variable-size */
- CommandLength = rts_padding_command_read(buffer, length);
+ if (!rts_padding_command_read(s, &padding))
+ return FALSE;
break;
case RTS_CMD_NEGATIVE_ANCE:
@@ -710,7 +1825,8 @@ SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length
break;
case RTS_CMD_CLIENT_ADDRESS: /* variable-size */
- CommandLength = rts_client_address_command_read(buffer, length);
+ if (!rts_client_address_command_read(s, &CommandLength))
+ return FALSE;
break;
case RTS_CMD_ASSOCIATION_GROUP_ID:
@@ -727,118 +1843,172 @@ SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length
default:
WLog_ERR(TAG, "Error: Unknown RTS Command Type: 0x%" PRIx32 "", CommandType);
- return -1;
+ return FALSE;
}
- return CommandLength;
+ CommandLength += padding;
+ if (Stream_GetRemainingLength(s) < CommandLength)
+ return FALSE;
+
+ if (length)
+ *length = CommandLength;
+ return TRUE;
}
static int rts_send_OUT_R2_A7_pdu(rdpRpc* rpc)
{
- int status;
- BYTE* buffer;
- rpcconn_rts_hdr_t header;
+ BOOL status = FALSE;
+ wStream* buffer;
+ rpcconn_rts_hdr_t header = rts_pdu_header_init();
BYTE* SuccessorChannelCookie;
- RpcInChannel* inChannel = rpc->VirtualConnection->DefaultInChannel;
- RpcOutChannel* nextOutChannel = rpc->VirtualConnection->NonDefaultOutChannel;
- rts_pdu_header_init(&header);
- header.frag_length = 56;
+ RpcInChannel* inChannel;
+ RpcOutChannel* nextOutChannel;
+
+ assert(rpc);
+ assert(rpc->VirtualConnection);
+
+ inChannel = rpc->VirtualConnection->DefaultInChannel;
+ assert(inChannel);
+
+ nextOutChannel = rpc->VirtualConnection->NonDefaultOutChannel;
+ assert(nextOutChannel);
+
+ header.header.frag_length = 56;
header.Flags = RTS_FLAG_OUT_CHANNEL;
header.NumberOfCommands = 3;
WLog_DBG(TAG, "Sending OUT_R2/A7 RTS PDU");
SuccessorChannelCookie = (BYTE*)&(nextOutChannel->common.Cookie);
- buffer = (BYTE*)malloc(header.frag_length);
+ buffer = Stream_New(NULL, header.header.frag_length);
if (!buffer)
return -1;
- CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */
- rts_destination_command_write(&buffer[20], FDServer); /* Destination (8 bytes)*/
- rts_cookie_command_write(&buffer[28],
- SuccessorChannelCookie); /* SuccessorChannelCookie (20 bytes) */
- rts_version_command_write(&buffer[48]); /* Version (8 bytes) */
- status = rpc_channel_write(&inChannel->common, buffer, header.frag_length);
- free(buffer);
- return (status > 0) ? 1 : -1;
+ if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */
+ goto fail;
+ if (!rts_destination_command_write(buffer, FDServer)) /* Destination (8 bytes)*/
+ goto fail;
+ if (!rts_cookie_command_write(buffer,
+ SuccessorChannelCookie)) /* SuccessorChannelCookie (20 bytes) */
+ goto fail;
+ if (!rts_version_command_write(buffer)) /* Version (8 bytes) */
+ goto fail;
+ status = rts_send_buffer(&inChannel->common, buffer, header.header.frag_length);
+fail:
+ Stream_Free(buffer, TRUE);
+ return (status) ? 1 : -1;
}
static int rts_send_OUT_R2_C1_pdu(rdpRpc* rpc)
{
- int status;
- BYTE* buffer;
- rpcconn_rts_hdr_t header;
- RpcOutChannel* nextOutChannel = rpc->VirtualConnection->NonDefaultOutChannel;
- rts_pdu_header_init(&header);
- header.frag_length = 24;
+ BOOL status = FALSE;
+ wStream* buffer;
+ rpcconn_rts_hdr_t header = rts_pdu_header_init();
+ RpcOutChannel* nextOutChannel;
+
+ assert(rpc);
+ assert(rpc->VirtualConnection);
+
+ nextOutChannel = rpc->VirtualConnection->NonDefaultOutChannel;
+ assert(nextOutChannel);
+
+ header.header.frag_length = 24;
header.Flags = RTS_FLAG_PING;
header.NumberOfCommands = 1;
WLog_DBG(TAG, "Sending OUT_R2/C1 RTS PDU");
- buffer = (BYTE*)malloc(header.frag_length);
+ buffer = Stream_New(NULL, header.header.frag_length);
if (!buffer)
return -1;
- CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */
- rts_empty_command_write(&buffer[20]); /* Empty command (4 bytes) */
- status = rpc_channel_write(&nextOutChannel->common, buffer, header.frag_length);
- free(buffer);
- return (status > 0) ? 1 : -1;
+ if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */
+ goto fail;
+
+ if (!rts_empty_command_write(buffer)) /* Empty command (4 bytes) */
+ goto fail;
+ status = rts_send_buffer(&nextOutChannel->common, buffer, header.header.frag_length);
+fail:
+ Stream_Free(buffer, TRUE);
+ return (status) ? 1 : -1;
}
-int rts_send_OUT_R1_A3_pdu(rdpRpc* rpc)
+BOOL rts_send_OUT_R1_A3_pdu(rdpRpc* rpc)
{
- int status;
- BYTE* buffer;
- rpcconn_rts_hdr_t header;
+ BOOL status = FALSE;
+ wStream* buffer;
+ rpcconn_rts_hdr_t header = rts_pdu_header_init();
UINT32 ReceiveWindowSize;
BYTE* VirtualConnectionCookie;
BYTE* PredecessorChannelCookie;
BYTE* SuccessorChannelCookie;
- RpcVirtualConnection* connection = rpc->VirtualConnection;
- RpcOutChannel* outChannel = connection->DefaultOutChannel;
- RpcOutChannel* nextOutChannel = connection->NonDefaultOutChannel;
- rts_pdu_header_init(&header);
- header.frag_length = 96;
+ RpcVirtualConnection* connection;
+ RpcOutChannel* outChannel;
+ RpcOutChannel* nextOutChannel;
+
+ assert(rpc);
+
+ connection = rpc->VirtualConnection;
+ assert(connection);
+
+ outChannel = connection->DefaultOutChannel;
+ assert(outChannel);
+
+ nextOutChannel = connection->NonDefaultOutChannel;
+ assert(nextOutChannel);
+
+ header.header.frag_length = 96;
header.Flags = RTS_FLAG_RECYCLE_CHANNEL;
header.NumberOfCommands = 5;
+
WLog_DBG(TAG, "Sending OUT_R1/A3 RTS PDU");
+
VirtualConnectionCookie = (BYTE*)&(connection->Cookie);
PredecessorChannelCookie = (BYTE*)&(outChannel->common.Cookie);
SuccessorChannelCookie = (BYTE*)&(nextOutChannel->common.Cookie);
ReceiveWindowSize = outChannel->ReceiveWindow;
- buffer = (BYTE*)malloc(header.frag_length);
+ buffer = Stream_New(NULL, header.header.frag_length);
if (!buffer)
return -1;
- CopyMemory(buffer, ((BYTE*)&header), 20); /* RTS Header (20 bytes) */
- rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
- rts_cookie_command_write(&buffer[28],
- VirtualConnectionCookie); /* VirtualConnectionCookie (20 bytes) */
- rts_cookie_command_write(&buffer[48],
- PredecessorChannelCookie); /* PredecessorChannelCookie (20 bytes) */
- rts_cookie_command_write(&buffer[68],
- SuccessorChannelCookie); /* SuccessorChannelCookie (20 bytes) */
- rts_receive_window_size_command_write(&buffer[88],
- ReceiveWindowSize); /* ReceiveWindowSize (8 bytes) */
- status = rpc_channel_write(&nextOutChannel->common, buffer, header.frag_length);
- free(buffer);
- return (status > 0) ? 1 : -1;
+ if (!rts_write_pdu_header(buffer, &header)) /* RTS Header (20 bytes) */
+ goto fail;
+ if (!rts_version_command_write(buffer)) /* Version (8 bytes) */
+ goto fail;
+ if (!rts_cookie_command_write(buffer,
+ VirtualConnectionCookie)) /* VirtualConnectionCookie (20 bytes) */
+ goto fail;
+ if (!rts_cookie_command_write(
+ buffer, PredecessorChannelCookie)) /* PredecessorChannelCookie (20 bytes) */
+ goto fail;
+ if (!rts_cookie_command_write(buffer,
+ SuccessorChannelCookie)) /* SuccessorChannelCookie (20 bytes) */
+ goto fail;
+ if (!rts_receive_window_size_command_write(buffer,
+ ReceiveWindowSize)) /* ReceiveWindowSize (8 bytes) */
+ goto fail;
+
+ status = rts_send_buffer(&nextOutChannel->common, buffer, header.header.frag_length);
+fail:
+ Stream_Free(buffer, TRUE);
+ return status;
}
-static int rts_recv_OUT_R1_A2_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+static int rts_recv_OUT_R1_A2_pdu(rdpRpc* rpc, wStream* buffer)
{
int status;
- UINT32 offset;
UINT32 Destination = 0;
- RpcVirtualConnection* connection = rpc->VirtualConnection;
+ RpcVirtualConnection* connection;
+ assert(rpc);
+ assert(buffer);
+
+ connection = rpc->VirtualConnection;
+ assert(connection);
WLog_DBG(TAG, "Receiving OUT R1/A2 RTS PDU");
- offset = 24;
- if (length < offset)
- return -1;
+ status = rts_destination_command_read(rpc, buffer, &Destination);
+ if (status < 0)
+ return status;;
- rts_destination_command_read(rpc, &buffer[offset], length - offset, &Destination);
connection->NonDefaultOutChannel = rpc_out_channel_new(rpc);
if (!connection->NonDefaultOutChannel)
@@ -857,10 +2027,17 @@ static int rts_recv_OUT_R1_A2_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
return 1;
}
-static int rts_recv_OUT_R2_A6_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+static int rts_recv_OUT_R2_A6_pdu(rdpRpc* rpc, wStream* buffer)
{
int status;
- RpcVirtualConnection* connection = rpc->VirtualConnection;
+ RpcVirtualConnection* connection;
+
+ assert(rpc);
+ assert(buffer);
+
+ connection = rpc->VirtualConnection;
+ assert(connection);
+
WLog_DBG(TAG, "Receiving OUT R2/A6 RTS PDU");
status = rts_send_OUT_R2_C1_pdu(rpc);
@@ -885,47 +2062,59 @@ static int rts_recv_OUT_R2_A6_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
return 1;
}
-static int rts_recv_OUT_R2_B3_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+static int rts_recv_OUT_R2_B3_pdu(rdpRpc* rpc, wStream* buffer)
{
- RpcVirtualConnection* connection = rpc->VirtualConnection;
+ RpcVirtualConnection* connection;
+
+ assert(rpc);
+ assert(buffer);
+
+ connection = rpc->VirtualConnection;
+ assert(connection);
+
WLog_DBG(TAG, "Receiving OUT R2/B3 RTS PDU");
rpc_out_channel_transition_to_state(connection->DefaultOutChannel,
CLIENT_OUT_CHANNEL_STATE_RECYCLED);
return 1;
}
-int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
+BOOL rts_recv_out_of_sequence_pdu(rdpRpc* rpc, wStream* buffer, const rpcconn_hdr_t* header)
{
- int status = -1;
+ BOOL status = FALSE;
UINT32 SignatureId;
- rpcconn_rts_hdr_t* rts;
- RtsPduSignature signature;
+ size_t length, total;
+ RtsPduSignature signature = { 0 };
RpcVirtualConnection* connection;
- if (!rpc || !buffer)
- return -1;
+ assert(rpc);
+ assert(buffer);
+ assert(header);
+
+ total = Stream_Length(buffer);
+ length = header->common.frag_length;
+ if (total < length)
+ return FALSE;
connection = rpc->VirtualConnection;
if (!connection)
- return -1;
-
- rts = (rpcconn_rts_hdr_t*)buffer;
+ return FALSE;
- if (!rts_extract_pdu_signature(&signature, rts))
- return -1;
+ if (!rts_extract_pdu_signature(&signature, buffer, header))
+ return FALSE;
SignatureId = rts_identify_pdu_signature(&signature, NULL);
- if (rts_match_pdu_signature(&RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE, rts))
+ if (rts_match_pdu_signature(&RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE, buffer, header))
{
- status = rts_recv_flow_control_ack_pdu(rpc, buffer, length);
+ status = rts_recv_flow_control_ack_pdu(rpc, buffer);
}
- else if (rts_match_pdu_signature(&RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE, rts))
+ else if (rts_match_pdu_signature(&RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE, buffer,
+ header))
{
- status = rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length);
+ status = rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer);
}
- else if (rts_match_pdu_signature(&RTS_PDU_PING_SIGNATURE, rts))
+ else if (rts_match_pdu_signature(&RTS_PDU_PING_SIGNATURE, buffer, header))
{
status = rts_send_ping_pdu(rpc);
}
@@ -933,28 +2122,28 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
{
if (connection->DefaultOutChannel->State == CLIENT_OUT_CHANNEL_STATE_OPENED)
{
- if (rts_match_pdu_signature(&RTS_PDU_OUT_R1_A2_SIGNATURE, rts))
+ if (rts_match_pdu_signature(&RTS_PDU_OUT_R1_A2_SIGNATURE, buffer, header))
{
- status = rts_recv_OUT_R1_A2_pdu(rpc, buffer, length);
+ status = rts_recv_OUT_R1_A2_pdu(rpc, buffer);
}
}
else if (connection->DefaultOutChannel->State == CLIENT_OUT_CHANNEL_STATE_OPENED_A6W)
{
- if (rts_match_pdu_signature(&RTS_PDU_OUT_R2_A6_SIGNATURE, rts))
+ if (rts_match_pdu_signature(&RTS_PDU_OUT_R2_A6_SIGNATURE, buffer, header))
{
- status = rts_recv_OUT_R2_A6_pdu(rpc, buffer, length);
+ status = rts_recv_OUT_R2_A6_pdu(rpc, buffer);
}
}
else if (connection->DefaultOutChannel->State == CLIENT_OUT_CHANNEL_STATE_OPENED_B3W)
{
- if (rts_match_pdu_signature(&RTS_PDU_OUT_R2_B3_SIGNATURE, rts))
+ if (rts_match_pdu_signature(&RTS_PDU_OUT_R2_B3_SIGNATURE, buffer, header))
{
- status = rts_recv_OUT_R2_B3_pdu(rpc, buffer, length);
+ status = rts_recv_OUT_R2_B3_pdu(rpc, buffer);
}
}
}
- if (status < 0)
+ if (!status)
{
WLog_ERR(TAG, "error parsing RTS PDU with signature id: 0x%08" PRIX32 "", SignatureId);
rts_print_pdu_signature(&signature);
@@ -962,3 +2151,42 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
return status;
}
+
+BOOL rts_write_pdu_auth3(wStream* s, const rpcconn_rpc_auth_3_hdr_t* auth)
+{
+ assert(s);
+ assert(auth);
+
+ if (!rts_write_common_pdu_header(s, &auth->header))
+ return FALSE;
+
+ if (!Stream_EnsureRemainingCapacity(s, 2 * sizeof(UINT16)))
+ return FALSE;
+
+ Stream_Write_UINT16(s, auth->max_xmit_frag);
+ Stream_Write_UINT16(s, auth->max_recv_frag);
+
+ return rts_write_auth_verifier(s, &auth->auth_verifier, &auth->header);
+}
+
+BOOL rts_write_pdu_bind(wStream* s, const rpcconn_bind_hdr_t* bind)
+{
+
+ assert(s);
+ assert(bind);
+
+ if (!rts_write_common_pdu_header(s, &bind->header))
+ return FALSE;
+
+ if (!Stream_EnsureRemainingCapacity(s, 8))
+ return FALSE;
+
+ Stream_Write_UINT16(s, bind->max_xmit_frag);
+ Stream_Write_UINT16(s, bind->max_recv_frag);
+ Stream_Write_UINT32(s, bind->assoc_group_id);
+
+ if (!rts_write_context_list(s, &bind->p_context_elem))
+ return FALSE;
+
+ return rts_write_auth_verifier(s, &bind->auth_verifier, &bind->header);
+}
diff --git a/libfreerdp/core/gateway/rts.h b/libfreerdp/core/gateway/rts.h
index ccc4cfd..01a66a7 100644
--- a/libfreerdp/core/gateway/rts.h
+++ b/libfreerdp/core/gateway/rts.h
@@ -24,12 +24,14 @@
#include "config.h"
#endif
-#include "rpc.h"
+#include <winpr/stream.h>
#include <freerdp/api.h>
#include <freerdp/types.h>
#include <freerdp/log.h>
+#include "rpc.h"
+
#define RTS_FLAG_NONE 0x0000
#define RTS_FLAG_PING 0x0001
#define RTS_FLAG_OTHER_CMD 0x0002
@@ -79,21 +81,28 @@
FREERDP_LOCAL void rts_generate_cookie(BYTE* cookie);
-FREERDP_LOCAL SSIZE_T rts_command_length(UINT32 CommandType, const BYTE* buffer, size_t length);
+FREERDP_LOCAL BOOL rts_write_pdu_auth3(wStream* s, const rpcconn_rpc_auth_3_hdr_t* auth);
+FREERDP_LOCAL BOOL rts_write_pdu_bind(wStream* s, const rpcconn_bind_hdr_t* bind);
+
+FREERDP_LOCAL BOOL rts_read_pdu_header(wStream* s, rpcconn_hdr_t* header);
+FREERDP_LOCAL void rts_free_pdu_header(rpcconn_hdr_t* header, BOOL allocated);
+
+FREERDP_LOCAL BOOL rts_read_common_pdu_header(wStream* s, rpcconn_common_hdr_t* header);
-FREERDP_LOCAL int rts_send_CONN_A1_pdu(rdpRpc* rpc);
-FREERDP_LOCAL int rts_recv_CONN_A3_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length);
+FREERDP_LOCAL BOOL rts_command_length(UINT32 CommandType, wStream* s, size_t* length);
-FREERDP_LOCAL int rts_send_CONN_B1_pdu(rdpRpc* rpc);
+FREERDP_LOCAL BOOL rts_send_CONN_A1_pdu(rdpRpc* rpc);
+FREERDP_LOCAL BOOL rts_recv_CONN_A3_pdu(rdpRpc* rpc, wStream* buffer);
-FREERDP_LOCAL int rts_recv_CONN_C2_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length);
+FREERDP_LOCAL BOOL rts_send_CONN_B1_pdu(rdpRpc* rpc);
-FREERDP_LOCAL int rts_send_OUT_R1_A3_pdu(rdpRpc* rpc);
+FREERDP_LOCAL BOOL rts_recv_CONN_C2_pdu(rdpRpc* rpc, wStream* buffer);
-FREERDP_LOCAL int rts_send_flow_control_ack_pdu(rdpRpc* rpc);
+FREERDP_LOCAL BOOL rts_send_OUT_R1_A3_pdu(rdpRpc* rpc);
-FREERDP_LOCAL int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length);
+FREERDP_LOCAL BOOL rts_send_flow_control_ack_pdu(rdpRpc* rpc);
-#include "rts_signature.h"
+FREERDP_LOCAL BOOL rts_recv_out_of_sequence_pdu(rdpRpc* rpc, wStream* buffer,
+ const rpcconn_hdr_t* header);
#endif /* FREERDP_LIB_CORE_GATEWAY_RTS_H */
diff --git a/libfreerdp/core/gateway/rts_signature.c b/libfreerdp/core/gateway/rts_signature.c
index d3b376d..9d2fd94 100644
--- a/libfreerdp/core/gateway/rts_signature.c
+++ b/libfreerdp/core/gateway/rts_signature.c
@@ -17,6 +17,9 @@
* limitations under the License.
*/
+#include <assert.h>
+#include <winpr/stream.h>
+
#include <freerdp/log.h>
#include "rts_signature.h"
@@ -277,89 +280,74 @@ static const RTS_PDU_SIGNATURE_ENTRY RTS_PDU_SIGNATURE_TABLE[] = {
{ RTS_PDU_FLOW_CONTROL_ACK, TRUE, &RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE, "FlowControlAck" },
{ RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION, TRUE,
&RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE, "FlowControlAckWithDestination" },
-
- { 0, FALSE, NULL, NULL }
};
-BOOL rts_match_pdu_signature(const RtsPduSignature* signature, const rpcconn_rts_hdr_t* rts)
-{
- UINT16 i;
- int status;
- const BYTE* buffer;
- UINT32 length;
- UINT32 offset;
- UINT32 CommandType;
- UINT32 CommandLength;
-
- if (!signature || !rts)
- return FALSE;
-
- if (rts->Flags != signature->Flags)
- return FALSE;
+BOOL rts_match_pdu_signature(const RtsPduSignature* signature, wStream* src,
+ const rpcconn_hdr_t* header){
+ RtsPduSignature extracted = { 0 };
- if (rts->NumberOfCommands != signature->NumberOfCommands)
- return FALSE;
+ assert(signature);
+ assert(src);
- buffer = (const BYTE*)rts;
- offset = RTS_PDU_HEADER_LENGTH;
- length = rts->frag_length - offset;
- for (i = 0; i < rts->NumberOfCommands; i++)
- {
- CommandType = *((UINT32*)&buffer[offset]); /* CommandType (4 bytes) */
- offset += 4;
+ if (!rts_extract_pdu_signature(&extracted, src, header))
+ return FALSE;
- if (CommandType != signature->CommandTypes[i])
- return FALSE;
+ return memcmp(signature, &extracted, sizeof(extracted)) == 0;
+}
- status = rts_command_length(CommandType, &buffer[offset], length);
+BOOL rts_extract_pdu_signature(RtsPduSignature* signature, wStream* src,
+ const rpcconn_hdr_t* header)
+{
+ BOOL rc = FALSE;
+ UINT16 i;
+ wStream tmp;
+ rpcconn_hdr_t rheader = { 0 };
+ const rpcconn_rts_hdr_t* rts;
- if (status < 0)
- return FALSE;
+ assert(signature);
+ assert(src);
- CommandLength = (UINT32)status;
- offset += CommandLength;
- length = rts->frag_length - offset;
+ Stream_StaticInit(&tmp, Stream_Pointer(src), Stream_GetRemainingLength(src));
+ if (!header)
+ {
+ if (!rts_read_pdu_header(&tmp, &rheader))
+ goto fail;
+ header = &rheader;
}
- return TRUE;
-}
-
-BOOL rts_extract_pdu_signature(RtsPduSignature* signature, const rpcconn_rts_hdr_t* rts)
-{
- int i;
- int status;
- BYTE* buffer;
- UINT32 length;
- UINT32 offset;
- UINT32 CommandType;
- UINT32 CommandLength;
-
- if (!signature || !rts)
- return FALSE;
+ rts = &header->rts;
+ if (rts->header.frag_length < sizeof(rpcconn_rts_hdr_t))
+ goto fail;
signature->Flags = rts->Flags;
signature->NumberOfCommands = rts->NumberOfCommands;
- buffer = (BYTE*)rts;
- offset = RTS_PDU_HEADER_LENGTH;
- length = rts->frag_length - offset;
for (i = 0; i < rts->NumberOfCommands; i++)
{
- CommandType = *((UINT32*)&buffer[offset]); /* CommandType (4 bytes) */
- offset += 4;
- signature->CommandTypes[i] = CommandType;
- status = rts_command_length(CommandType, &buffer[offset], length);
+ UINT32 CommandType;
+ size_t CommandLength;
+
+ if (Stream_GetRemainingLength(&tmp) < 4)
+ goto fail;
- if (status < 0)
- return FALSE;
+ Stream_Read_UINT32(&tmp, CommandType); /* CommandType (4 bytes) */
- CommandLength = (UINT32)status;
- offset += CommandLength;
- length = rts->frag_length - offset;
+ /* We only need this for comparison against known command types */
+ if (i < ARRAYSIZE(signature->CommandTypes))
+ signature->CommandTypes[i] = CommandType;
+
+ if (!rts_command_length(CommandType, &tmp, &CommandLength))
+ goto fail;
+ if (!Stream_SafeSeek(&tmp, CommandLength))
+ goto fail;
}
- return TRUE;
+ rc = TRUE;
+fail:
+ rts_free_pdu_header(&rheader, FALSE);
+ Stream_Free(&tmp, FALSE);
+ return rc;
}
UINT32 rts_identify_pdu_signature(const RtsPduSignature* signature,
@@ -367,11 +355,15 @@ UINT32 rts_identify_pdu_signature(const RtsPduSignature* signature,
{
size_t i, j;
- for (i = 0; RTS_PDU_SIGNATURE_TABLE[i].SignatureId != 0; i++)
+ if (entry)
+ *entry = NULL;
+
+ for (i = 0; i < ARRAYSIZE(RTS_PDU_SIGNATURE_TABLE); i++)
{
- const RtsPduSignature* pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature;
+ const RTS_PDU_SIGNATURE_ENTRY* current = &RTS_PDU_SIGNATURE_TABLE[i];
+ const RtsPduSignature* pSignature = current->Signature;
- if (!RTS_PDU_SIGNATURE_TABLE[i].SignatureClient)
+ if (!current->SignatureClient)
continue;
if (signature->Flags != pSignature->Flags)
@@ -387,9 +379,9 @@ UINT32 rts_identify_pdu_signature(const RtsPduSignature* signature,
}
if (entry)
- *entry = &RTS_PDU_SIGNATURE_TABLE[i];
+ *entry = current;
- return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
+ return current->SignatureId;
}
return 0;
diff --git a/libfreerdp/core/gateway/rts_signature.h b/libfreerdp/core/gateway/rts_signature.h
index 2c43cdc..31f0e81 100644
--- a/libfreerdp/core/gateway/rts_signature.h
+++ b/libfreerdp/core/gateway/rts_signature.h
@@ -178,10 +178,10 @@ FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_PING_SIGNATURE;
FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_SIGNATURE;
FREERDP_LOCAL extern const RtsPduSignature RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION_SIGNATURE;
-FREERDP_LOCAL BOOL rts_match_pdu_signature(const RtsPduSignature* signature,
- const rpcconn_rts_hdr_t* rts);
-FREERDP_LOCAL BOOL rts_extract_pdu_signature(RtsPduSignature* signature,
- const rpcconn_rts_hdr_t* rts);
+FREERDP_LOCAL BOOL rts_match_pdu_signature(const RtsPduSignature* signature, wStream* s,
+ const rpcconn_hdr_t* header);
+FREERDP_LOCAL BOOL rts_extract_pdu_signature(RtsPduSignature* signature, wStream* s,
+ const rpcconn_hdr_t* header);
FREERDP_LOCAL UINT32 rts_identify_pdu_signature(const RtsPduSignature* signature,
const RTS_PDU_SIGNATURE_ENTRY** entry);
FREERDP_LOCAL BOOL rts_print_pdu_signature(const RtsPduSignature* signature);
diff --git a/libfreerdp/core/gateway/tsg.c b/libfreerdp/core/gateway/tsg.c
index 7537f18..58cd4b0 100644
--- a/libfreerdp/core/gateway/tsg.c
+++ b/libfreerdp/core/gateway/tsg.c
@@ -24,6 +24,8 @@
#include "config.h"
#endif
+#include <assert.h>
+
#include <winpr/crt.h>
#include <winpr/ndr.h>
#include <winpr/error.h>
@@ -206,7 +208,6 @@ struct rdp_tsg
UINT32 TunnelId;
UINT32 ChannelId;
BOOL reauthSequence;
- rdpSettings* settings;
rdpTransport* transport;
UINT64 ReauthTunnelContext;
CONTEXT_HANDLE TunnelContext;
@@ -276,7 +277,7 @@ static int TsProxySendToServer(handle_t IDL_handle, const byte pRpcMessage[], UI
{
wStream* s;
rdpTsg* tsg;
- int length;
+ size_t length;
const byte* buffer1 = NULL;
const byte* buffer2 = NULL;
const byte* buffer3 = NULL;
@@ -312,7 +313,9 @@ static int TsProxySendToServer(handle_t IDL_handle, const byte pRpcMessage[], UI
totalDataBytes += lengths[2] + 4;
}
- length = 28 + totalDataBytes;
+ length = 28ull + totalDataBytes;
+ if (length > INT_MAX)
+ return -1;
s = Stream_New(NULL, length);
if (!s)
@@ -348,7 +351,7 @@ static int TsProxySendToServer(handle_t IDL_handle, const byte pRpcMessage[], UI
if (!rpc_client_write_call(tsg->rpc, s, TsProxySendToServerOpnum))
return -1;
- return length;
+ return (int)length;
}
/**
@@ -501,13 +504,22 @@ static BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu,
UINT32 SwitchValue;
UINT32 MessageSwitchValue = 0;
UINT32 IsMessagePresent;
+ rdpContext* context;
UINT32 MsgBytes;
+ TSG_PACKET_STRING_MESSAGE packetStringMessage;
+
PTSG_PACKET_CAPABILITIES tsgCaps = NULL;
PTSG_PACKET_VERSIONCAPS versionCaps = NULL;
PTSG_PACKET_CAPS_RESPONSE packetCapsResponse = NULL;
PTSG_PACKET_QUARENC_RESPONSE packetQuarEncResponse = NULL;
WLog_DBG(TAG, "TsProxyCreateTunnelReadResponse");
+ assert(tsg);
+ assert(tsg->rpc);
+
+ context = tsg->rpc->context;
+ assert(context);
+
if (!pdu)
return FALSE;
@@ -652,8 +664,8 @@ static BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu,
if (Stream_GetRemainingLength(pdu->s) < 16)
goto fail;
- Stream_Seek_UINT32(pdu->s); /* IsDisplayMandatory (4 bytes) */
- Stream_Seek_UINT32(pdu->s); /* IsConsent Mandatory (4 bytes) */
+ Stream_Read_INT32(pdu->s, packetStringMessage.isDisplayMandatory);
+ Stream_Read_INT32(pdu->s, packetStringMessage.isConsentMandatory);
Stream_Read_UINT32(pdu->s, MsgBytes);
Stream_Read_UINT32(pdu->s, Pointer);
@@ -833,16 +845,19 @@ fail:
static BOOL TsProxyAuthorizeTunnelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunnelContext)
{
- UINT32 pad;
+ size_t pad;
wStream* s;
size_t count;
- UINT32 offset;
+ size_t offset;
rdpRpc* rpc;
if (!tsg || !tsg->rpc || !tunnelContext || !tsg->MachineName)
return FALSE;
count = _wcslen(tsg->MachineName) + 1;
+ if (count > UINT32_MAX)
+ return FALSE;
+
rpc = tsg->rpc;
WLog_DBG(TAG, "TsProxyAuthorizeTunnelWriteRequest");
s = Stream_New(NULL, 1024 + count * 2);
@@ -859,13 +874,13 @@ static BOOL TsProxyAuthorizeTunnelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunn
Stream_Write_UINT32(s, 0x00020000); /* PacketQuarRequestPtr (4 bytes) */
Stream_Write_UINT32(s, 0x00000000); /* Flags (4 bytes) */
Stream_Write_UINT32(s, 0x00020004); /* MachineNamePtr (4 bytes) */
- Stream_Write_UINT32(s, count); /* NameLength (4 bytes) */
+ Stream_Write_UINT32(s, (UINT32)count); /* NameLength (4 bytes) */
Stream_Write_UINT32(s, 0x00020008); /* DataPtr (4 bytes) */
Stream_Write_UINT32(s, 0); /* DataLength (4 bytes) */
/* MachineName */
- Stream_Write_UINT32(s, count); /* MaxCount (4 bytes) */
+ Stream_Write_UINT32(s, (UINT32)count); /* MaxCount (4 bytes) */
Stream_Write_UINT32(s, 0); /* Offset (4 bytes) */
- Stream_Write_UINT32(s, count); /* ActualCount (4 bytes) */
+ Stream_Write_UINT32(s, (UINT32)count); /* ActualCount (4 bytes) */
Stream_Write_UTF16_String(s, tsg->MachineName, count); /* Array */
/* 4-byte alignment */
offset = Stream_GetPosition(s);
@@ -876,7 +891,7 @@ static BOOL TsProxyAuthorizeTunnelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunn
return rpc_client_write_call(rpc, s, TsProxyAuthorizeTunnelOpnum);
}
-static BOOL TsProxyAuthorizeTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
+static BOOL TsProxyAuthorizeTunnelReadResponse(RPC_PDU* pdu)
{
BOOL rc = FALSE;
UINT32 Pointer;
@@ -937,25 +952,24 @@ static BOOL TsProxyAuthorizeTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
Stream_Seek_UINT32(pdu->s); /* Reserved (4 bytes) */
Stream_Read_UINT32(pdu->s, Pointer); /* ResponseDataPtr (4 bytes) */
Stream_Read_UINT32(pdu->s, packetResponse->responseDataLen); /* ResponseDataLength (4 bytes) */
- Stream_Read_UINT32(pdu->s, packetResponse->redirectionFlags
- .enableAllRedirections); /* EnableAllRedirections (4 bytes) */
- Stream_Read_UINT32(pdu->s, packetResponse->redirectionFlags
- .disableAllRedirections); /* DisableAllRedirections (4 bytes) */
- Stream_Read_UINT32(pdu->s,
- packetResponse->redirectionFlags
- .driveRedirectionDisabled); /* DriveRedirectionDisabled (4 bytes) */
- Stream_Read_UINT32(pdu->s,
- packetResponse->redirectionFlags
- .printerRedirectionDisabled); /* PrinterRedirectionDisabled (4 bytes) */
- Stream_Read_UINT32(pdu->s,
- packetResponse->redirectionFlags
- .portRedirectionDisabled); /* PortRedirectionDisabled (4 bytes) */
- Stream_Read_UINT32(pdu->s, packetResponse->redirectionFlags.reserved); /* Reserved (4 bytes) */
- Stream_Read_UINT32(
+ Stream_Read_INT32(pdu->s, packetResponse->redirectionFlags
+ .enableAllRedirections); /* EnableAllRedirections (4 bytes) */
+ Stream_Read_INT32(pdu->s, packetResponse->redirectionFlags
+ .disableAllRedirections); /* DisableAllRedirections (4 bytes) */
+ Stream_Read_INT32(pdu->s,
+ packetResponse->redirectionFlags
+ .driveRedirectionDisabled); /* DriveRedirectionDisabled (4 bytes) */
+ Stream_Read_INT32(pdu->s,
+ packetResponse->redirectionFlags
+ .printerRedirectionDisabled); /* PrinterRedirectionDisabled (4 bytes) */
+ Stream_Read_INT32(pdu->s, packetResponse->redirectionFlags
+ .portRedirectionDisabled); /* PortRedirectionDisabled (4 bytes) */
+ Stream_Read_INT32(pdu->s, packetResponse->redirectionFlags.reserved); /* Reserved (4 bytes) */
+ Stream_Read_INT32(
pdu->s, packetResponse->redirectionFlags
.clipboardRedirectionDisabled); /* ClipboardRedirectionDisabled (4 bytes) */
- Stream_Read_UINT32(pdu->s, packetResponse->redirectionFlags
- .pnpRedirectionDisabled); /* PnpRedirectionDisabled (4 bytes) */
+ Stream_Read_INT32(pdu->s, packetResponse->redirectionFlags
+ .pnpRedirectionDisabled); /* PnpRedirectionDisabled (4 bytes) */
Stream_Read_UINT32(pdu->s, SizeValue); /* (4 bytes) */
if (SizeValue != packetResponse->responseDataLen)
@@ -1055,12 +1069,19 @@ static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
UINT32 Pointer;
UINT32 SwitchValue;
TSG_PACKET packet;
+ rdpContext* context;
char* messageText = NULL;
TSG_PACKET_MSG_RESPONSE packetMsgResponse = { 0 };
TSG_PACKET_STRING_MESSAGE packetStringMessage = { 0 };
TSG_PACKET_REAUTH_MESSAGE packetReauthMessage = { 0 };
WLog_DBG(TAG, "TsProxyMakeTunnelCallReadResponse");
+ assert(tsg);
+ assert(tsg->rpc);
+
+ context = tsg->rpc->context;
+ assert(context);
+
/* This is an asynchronous response */
if (!pdu)
@@ -1108,10 +1129,10 @@ static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
WLog_INFO(TAG, "Consent Message: %s", messageText);
free(messageText);
- if (tsg->rpc && tsg->rpc->context && tsg->rpc->context->instance)
+ if (context->instance)
{
- rc = IFCALLRESULT(TRUE, tsg->rpc->context->instance->PresentGatewayMessage,
- tsg->rpc->context->instance, SwitchValue,
+ rc = IFCALLRESULT(TRUE, context->instance->PresentGatewayMessage,
+ context->instance, SwitchValue,
packetStringMessage.isDisplayMandatory != 0,
packetStringMessage.isConsentMandatory != 0,
packetStringMessage.msgBytes, packetStringMessage.msgBuffer);
@@ -1129,10 +1150,10 @@ static BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
WLog_INFO(TAG, "Service Message: %s", messageText);
free(messageText);
- if (tsg->rpc && tsg->rpc->context && tsg->rpc->context->instance)
+ if (context->instance)
{
- rc = IFCALLRESULT(TRUE, tsg->rpc->context->instance->PresentGatewayMessage,
- tsg->rpc->context->instance, SwitchValue,
+ rc = IFCALLRESULT(TRUE, context->instance->PresentGatewayMessage,
+ context->instance, SwitchValue,
packetStringMessage.isDisplayMandatory != 0,
packetStringMessage.isConsentMandatory != 0,
packetStringMessage.msgBytes, packetStringMessage.msgBuffer);
@@ -1184,6 +1205,8 @@ static BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunnel
rpc = tsg->rpc;
count = _wcslen(tsg->Hostname) + 1;
+ if (count > UINT32_MAX)
+ return FALSE;
s = Stream_New(NULL, 60 + count * 2);
if (!s)
@@ -1203,15 +1226,15 @@ static BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* tunnel
Stream_Write_UINT16(s, tsg->Port); /* PortNumber (0xD3D = 3389) (2 bytes) */
Stream_Write_UINT32(s, 0x00000001); /* NumResourceNames (4 bytes) */
Stream_Write_UINT32(s, 0x00020004); /* ResourceNamePtr (4 bytes) */
- Stream_Write_UINT32(s, count); /* MaxCount (4 bytes) */
+ Stream_Write_UINT32(s, (UINT32)count); /* MaxCount (4 bytes) */
Stream_Write_UINT32(s, 0); /* Offset (4 bytes) */
- Stream_Write_UINT32(s, count); /* ActualCount (4 bytes) */
+ Stream_Write_UINT32(s, (UINT32)count); /* ActualCount (4 bytes) */
Stream_Write_UTF16_String(s, tsg->Hostname, count); /* Array */
return rpc_client_write_call(rpc, s, TsProxyCreateChannelOpnum);
}
-static BOOL TsProxyCreateChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu,
- CONTEXT_HANDLE* channelContext, UINT32* channelId)
+static BOOL TsProxyCreateChannelReadResponse(RPC_PDU* pdu, CONTEXT_HANDLE* channelContext,
+ UINT32* channelId)
{
BOOL rc = FALSE;
WLog_DBG(TAG, "TsProxyCreateChannelReadResponse");
@@ -1259,7 +1282,7 @@ static BOOL TsProxyCloseChannelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* context
return rpc_client_write_call(rpc, s, TsProxyCloseChannelOpnum);
}
-static BOOL TsProxyCloseChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, CONTEXT_HANDLE* context)
+static BOOL TsProxyCloseChannelReadResponse(RPC_PDU* pdu, CONTEXT_HANDLE* context)
{
BOOL rc = FALSE;
WLog_DBG(TAG, "TsProxyCloseChannelReadResponse");
@@ -1306,7 +1329,7 @@ static BOOL TsProxyCloseTunnelWriteRequest(rdpTsg* tsg, CONTEXT_HANDLE* context)
return rpc_client_write_call(rpc, s, TsProxyCloseTunnelOpnum);
}
-static BOOL TsProxyCloseTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu, CONTEXT_HANDLE* context)
+static BOOL TsProxyCloseTunnelReadResponse(RPC_PDU* pdu, CONTEXT_HANDLE* context)
{
BOOL rc = FALSE;
WLog_DBG(TAG, "TsProxyCloseTunnelReadResponse");
@@ -1491,8 +1514,6 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
return FALSE;
rpc = tsg->rpc;
- Stream_SealLength(pdu->s);
- Stream_SetPosition(pdu->s, 0);
if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
{
@@ -1531,7 +1552,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
CONTEXT_HANDLE* TunnelContext;
TunnelContext = (tsg->reauthSequence) ? &tsg->NewTunnelContext : &tsg->TunnelContext;
- if (!TsProxyAuthorizeTunnelReadResponse(tsg, pdu))
+ if (!TsProxyAuthorizeTunnelReadResponse(pdu))
{
WLog_ERR(TAG, "TsProxyAuthorizeTunnelReadResponse failure");
return FALSE;
@@ -1580,7 +1601,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
{
CONTEXT_HANDLE ChannelContext;
- if (!TsProxyCreateChannelReadResponse(tsg, pdu, &ChannelContext, &tsg->ChannelId))
+ if (!TsProxyCreateChannelReadResponse(pdu, &ChannelContext, &tsg->ChannelId))
{
WLog_ERR(TAG, "TsProxyCreateChannelReadResponse failure");
return FALSE;
@@ -1653,7 +1674,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
{
CONTEXT_HANDLE ChannelContext;
- if (!TsProxyCloseChannelReadResponse(tsg, pdu, &ChannelContext))
+ if (!TsProxyCloseChannelReadResponse(pdu, &ChannelContext))
{
WLog_ERR(TAG, "TsProxyCloseChannelReadResponse failure");
return FALSE;
@@ -1665,7 +1686,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
{
CONTEXT_HANDLE TunnelContext;
- if (!TsProxyCloseTunnelReadResponse(tsg, pdu, &TunnelContext))
+ if (!TsProxyCloseTunnelReadResponse(pdu, &TunnelContext))
{
WLog_ERR(TAG, "TsProxyCloseTunnelReadResponse failure");
return FALSE;
@@ -1680,7 +1701,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
{
CONTEXT_HANDLE ChannelContext;
- if (!TsProxyCloseChannelReadResponse(tsg, pdu, &ChannelContext))
+ if (!TsProxyCloseChannelReadResponse(pdu, &ChannelContext))
{
WLog_ERR(TAG, "TsProxyCloseChannelReadResponse failure");
return FALSE;
@@ -1710,7 +1731,7 @@ BOOL tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
{
CONTEXT_HANDLE TunnelContext;
- if (!TsProxyCloseTunnelReadResponse(tsg, pdu, &TunnelContext))
+ if (!TsProxyCloseTunnelReadResponse(pdu, &TunnelContext))
{
WLog_ERR(TAG, "TsProxyCloseTunnelReadResponse failure");
return FALSE;
@@ -1818,10 +1839,25 @@ static BOOL tsg_set_machine_name(rdpTsg* tsg, const char* machineName)
BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port, int timeout)
{
DWORD nCount;
- HANDLE events[64];
- rdpRpc* rpc = tsg->rpc;
- rdpSettings* settings = rpc->settings;
- rdpTransport* transport = rpc->transport;
+ HANDLE events[MAXIMUM_WAIT_OBJECTS] = { 0 };
+ rdpRpc* rpc;
+ rdpContext* context;
+ rdpSettings* settings;
+ rdpTransport* transport;
+
+ assert(tsg);
+
+ rpc = tsg->rpc;
+ assert(rpc);
+
+ transport = rpc->transport;
+ assert(transport);
+
+ context = tsg->rpc->context;
+ assert(context);
+
+ settings = context->settings;
+
tsg->Port = port;
tsg->transport = transport;
@@ -1840,7 +1876,7 @@ BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port, int timeout)
return FALSE;
}
- nCount = tsg_get_event_handles(tsg, events, 64);
+ nCount = tsg_get_event_handles(tsg, events, ARRAYSIZE(events));
if (nCount == 0)
return FALSE;
@@ -1911,7 +1947,7 @@ BOOL tsg_disconnect(rdpTsg* tsg)
* @return < 0 on error; 0 if not enough data is available (non blocking mode); > 0 bytes to read
*/
-static int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length)
+static int tsg_read(rdpTsg* tsg, BYTE* data, size_t length)
{
rdpRpc* rpc;
int status = 0;
@@ -1929,7 +1965,7 @@ static int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length)
do
{
- status = rpc_client_receive_pipe_read(rpc->client, data, (size_t)length);
+ status = rpc_client_receive_pipe_read(rpc->client, data, length);
if (status < 0)
return -1;
@@ -1979,7 +2015,7 @@ static int tsg_write(rdpTsg* tsg, const BYTE* data, UINT32 length)
if (status < 0)
return -1;
- return length;
+ return (int)length;
}
rdpTsg* tsg_new(rdpTransport* transport)
@@ -1991,7 +2027,6 @@ rdpTsg* tsg_new(rdpTransport* transport)
return NULL;
tsg->transport = transport;
- tsg->settings = transport->settings;
tsg->rpc = rpc_new(tsg->transport);
if (!tsg->rpc)
@@ -2019,7 +2054,10 @@ static int transport_bio_tsg_write(BIO* bio, const char* buf, int num)
int status;
rdpTsg* tsg = (rdpTsg*)BIO_get_data(bio);
BIO_clear_flags(bio, BIO_FLAGS_WRITE);
- status = tsg_write(tsg, (BYTE*)buf, num);
+
+ if (num < 0)
+ return -1;
+ status = tsg_write(tsg, (const BYTE*)buf, (UINT32)num);
if (status < 0)
{
@@ -2051,7 +2089,7 @@ static int transport_bio_tsg_read(BIO* bio, char* buf, int size)
}
BIO_clear_flags(bio, BIO_FLAGS_READ);
- status = tsg_read(tsg, (BYTE*)buf, size);
+ status = tsg_read(tsg, (BYTE*)buf, (size_t)size);
if (status < 0)
{
@@ -2073,17 +2111,22 @@ static int transport_bio_tsg_read(BIO* bio, char* buf, int size)
static int transport_bio_tsg_puts(BIO* bio, const char* str)
{
+ WINPR_UNUSED(bio);
+ WINPR_UNUSED(str);
return 1;
}
static int transport_bio_tsg_gets(BIO* bio, char* str, int size)
{
+ WINPR_UNUSED(bio);
+ WINPR_UNUSED(str);
+ WINPR_UNUSED(size);
return 1;
}
static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2)
{
- int status = -1;
+ long status = -1;
rdpTsg* tsg = (rdpTsg*)BIO_get_data(bio);
RpcVirtualConnection* connection = tsg->rpc->VirtualConnection;
RpcInChannel* inChannel = connection->DefaultInChannel;
@@ -2112,27 +2155,27 @@ static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2)
case BIO_C_READ_BLOCKED:
{
- BIO* bio = outChannel->common.bio;
- status = BIO_read_blocked(bio);
+ BIO* cbio = outChannel->common.bio;
+ status = BIO_read_blocked(cbio);
}
break;
case BIO_C_WRITE_BLOCKED:
{
- BIO* bio = inChannel->common.bio;
- status = BIO_write_blocked(bio);
+ BIO* cbio = inChannel->common.bio;
+ status = BIO_write_blocked(cbio);
}
break;
case BIO_C_WAIT_READ:
{
int timeout = (int)arg1;
- BIO* bio = outChannel->common.bio;
+ BIO* cbio = outChannel->common.bio;
- if (BIO_read_blocked(bio))
- return BIO_wait_read(bio, timeout);
- else if (BIO_write_blocked(bio))
- return BIO_wait_write(bio, timeout);
+ if (BIO_read_blocked(cbio))
+ return BIO_wait_read(cbio, timeout);
+ else if (BIO_write_blocked(cbio))
+ return BIO_wait_write(cbio, timeout);
else
status = 1;
}
@@ -2141,12 +2184,12 @@ static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2)
case BIO_C_WAIT_WRITE:
{
int timeout = (int)arg1;
- BIO* bio = inChannel->common.bio;
+ BIO* cbio = inChannel->common.bio;
- if (BIO_write_blocked(bio))
- status = BIO_wait_write(bio, timeout);
- else if (BIO_read_blocked(bio))
- status = BIO_wait_read(bio, timeout);
+ if (BIO_write_blocked(cbio))
+ status = BIO_wait_write(cbio, timeout);
+ else if (BIO_read_blocked(cbio))
+ status = BIO_wait_read(cbio, timeout);
else
status = 1;
}
@@ -2161,6 +2204,7 @@ static long transport_bio_tsg_ctrl(BIO* bio, int cmd, long arg1, void* arg2)
static int transport_bio_tsg_new(BIO* bio)
{
+ assert(bio);
BIO_set_init(bio, 1);
BIO_set_flags(bio, BIO_FLAGS_SHOULD_RETRY);
return 1;
@@ -2168,6 +2212,8 @@ static int transport_bio_tsg_new(BIO* bio)
static int transport_bio_tsg_free(BIO* bio)
{
+ assert(bio);
+ WINPR_UNUSED(bio);
return 1;
}