rcolebaugh / rpms / openssh

Forked from rpms/openssh 2 years ago
Clone
089d79
diff --color -ruNp a/gss-genr.c b/gss-genr.c
089d79
--- a/gss-genr.c	2024-05-16 15:49:43.999411060 +0200
089d79
+++ b/gss-genr.c	2024-06-26 12:17:55.586856954 +0200
089d79
@@ -346,6 +346,7 @@ ssh_gssapi_build_ctx(Gssctxt **ctx)
089d79
 	(*ctx)->creds = GSS_C_NO_CREDENTIAL;
089d79
 	(*ctx)->client = GSS_C_NO_NAME;
089d79
 	(*ctx)->client_creds = GSS_C_NO_CREDENTIAL;
089d79
+	(*ctx)->first = 1;
089d79
 }
089d79
 
089d79
 /* Delete our context, providing it has been built correctly */
089d79
@@ -371,6 +372,12 @@ ssh_gssapi_delete_ctx(Gssctxt **ctx)
089d79
 		gss_release_name(&ms, &(*ctx)->client);
089d79
 	if ((*ctx)->client_creds != GSS_C_NO_CREDENTIAL)
089d79
 		gss_release_cred(&ms, &(*ctx)->client_creds);
089d79
+	sshbuf_free((*ctx)->shared_secret);
089d79
+	sshbuf_free((*ctx)->server_pubkey);
089d79
+	sshbuf_free((*ctx)->server_host_key_blob);
089d79
+	sshbuf_free((*ctx)->server_blob);
089d79
+	explicit_bzero((*ctx)->hash, sizeof((*ctx)->hash));
089d79
+        BN_clear_free((*ctx)->dh_client_pub);
089d79
 
089d79
 	free(*ctx);
089d79
 	*ctx = NULL;
089d79
diff --color -ruNp a/kexgssc.c b/kexgssc.c
089d79
--- a/kexgssc.c	2024-05-16 15:49:43.820407648 +0200
089d79
+++ b/kexgssc.c	2024-07-02 16:26:25.628746744 +0200
089d79
@@ -47,566 +47,658 @@
089d79
 
089d79
 #include "ssh-gss.h"
089d79
 
089d79
-int
089d79
-kexgss_client(struct ssh *ssh)
089d79
+static int input_kexgss_hostkey(int, u_int32_t, struct ssh *);
089d79
+static int input_kexgss_continue(int, u_int32_t, struct ssh *);
089d79
+static int input_kexgss_complete(int, u_int32_t, struct ssh *);
089d79
+static int input_kexgss_error(int, u_int32_t, struct ssh *);
089d79
+static int input_kexgssgex_group(int, u_int32_t, struct ssh *);
089d79
+static int input_kexgssgex_continue(int, u_int32_t, struct ssh *);
089d79
+static int input_kexgssgex_complete(int, u_int32_t, struct ssh *);
089d79
+
089d79
+static int
089d79
+kexgss_final(struct ssh *ssh)
089d79
 {
089d79
 	struct kex *kex = ssh->kex;
089d79
-	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER,
089d79
-	    recv_tok = GSS_C_EMPTY_BUFFER,
089d79
-	    gssbuf, msg_tok = GSS_C_EMPTY_BUFFER, *token_ptr;
089d79
-	Gssctxt *ctxt;
089d79
-	OM_uint32 maj_status, min_status, ret_flags;
089d79
-	struct sshbuf *server_blob = NULL;
089d79
-	struct sshbuf *shared_secret = NULL;
089d79
-	struct sshbuf *server_host_key_blob = NULL;
089d79
+	Gssctxt *gss = kex->gss;
089d79
 	struct sshbuf *empty = NULL;
089d79
-	u_char *msg;
089d79
-	int type = 0;
089d79
-	int first = 1;
089d79
+	struct sshbuf *shared_secret = NULL;
089d79
 	u_char hash[SSH_DIGEST_MAX_LENGTH];
089d79
 	size_t hashlen;
089d79
-	u_char c;
089d79
 	int r;
089d79
 
089d79
-	/* Initialise our GSSAPI world */
089d79
-	ssh_gssapi_build_ctx(&ctxt);
089d79
-	if (ssh_gssapi_id_kex(ctxt, kex->name, kex->kex_type)
089d79
-	    == GSS_C_NO_OID)
089d79
-		fatal("Couldn't identify host exchange");
089d79
-
089d79
-	if (ssh_gssapi_import_name(ctxt, kex->gss_host))
089d79
-		fatal("Couldn't import hostname");
089d79
-
089d79
-	if (kex->gss_client &&
089d79
-	    ssh_gssapi_client_identity(ctxt, kex->gss_client))
089d79
-		fatal("Couldn't acquire client credentials");
089d79
-
089d79
-	/* Step 1 */
089d79
-	switch (kex->kex_type) {
089d79
-	case KEX_GSS_GRP1_SHA1:
089d79
-	case KEX_GSS_GRP14_SHA1:
089d79
-	case KEX_GSS_GRP14_SHA256:
089d79
-	case KEX_GSS_GRP16_SHA512:
089d79
-		r = kex_dh_keypair(kex);
089d79
-		break;
089d79
-	case KEX_GSS_NISTP256_SHA256:
089d79
-		r = kex_ecdh_keypair(kex);
089d79
-		break;
089d79
-	case KEX_GSS_C25519_SHA256:
089d79
-		r = kex_c25519_keypair(kex);
089d79
-		break;
089d79
-	default:
089d79
-		fatal_f("Unexpected KEX type %d", kex->kex_type);
089d79
-	}
089d79
-	if (r != 0) {
089d79
-		ssh_gssapi_delete_ctx(&ctxt);
089d79
-		return r;
089d79
-	}
089d79
-
089d79
-	token_ptr = GSS_C_NO_BUFFER;
089d79
-
089d79
-	do {
089d79
-		debug("Calling gss_init_sec_context");
089d79
-
089d79
-		maj_status = ssh_gssapi_init_ctx(ctxt,
089d79
-		    kex->gss_deleg_creds, token_ptr, &send_tok,
089d79
-		    &ret_flags);
089d79
-
089d79
-		if (GSS_ERROR(maj_status)) {
089d79
-			/* XXX Useles code: Missing send? */
089d79
-			if (send_tok.length != 0) {
089d79
-				if ((r = sshpkt_start(ssh,
089d79
-				        SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
-				    (r = sshpkt_put_string(ssh, send_tok.value,
089d79
-				        send_tok.length)) != 0)
089d79
-					fatal("sshpkt failed: %s", ssh_err(r));
089d79
-			}
089d79
-			fatal("gss_init_context failed");
089d79
-		}
089d79
-
089d79
-		/* If we've got an old receive buffer get rid of it */
089d79
-		if (token_ptr != GSS_C_NO_BUFFER)
089d79
-			gss_release_buffer(&min_status, &recv_tok);
089d79
-
089d79
-		if (maj_status == GSS_S_COMPLETE) {
089d79
-			/* If mutual state flag is not true, kex fails */
089d79
-			if (!(ret_flags & GSS_C_MUTUAL_FLAG))
089d79
-				fatal("Mutual authentication failed");
089d79
-
089d79
-			/* If integ avail flag is not true kex fails */
089d79
-			if (!(ret_flags & GSS_C_INTEG_FLAG))
089d79
-				fatal("Integrity check failed");
089d79
-		}
089d79
-
089d79
-		/*
089d79
-		 * If we have data to send, then the last message that we
089d79
-		 * received cannot have been a 'complete'.
089d79
-		 */
089d79
-		if (send_tok.length != 0) {
089d79
-			if (first) {
089d79
-				if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_INIT)) != 0 ||
089d79
-				    (r = sshpkt_put_string(ssh, send_tok.value,
089d79
-				        send_tok.length)) != 0 ||
089d79
-				    (r = sshpkt_put_stringb(ssh, kex->client_pub)) != 0)
089d79
-					fatal("failed to construct packet: %s", ssh_err(r));
089d79
-				first = 0;
089d79
-			} else {
089d79
-				if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
-				    (r = sshpkt_put_string(ssh, send_tok.value,
089d79
-				        send_tok.length)) != 0)
089d79
-					fatal("failed to construct packet: %s", ssh_err(r));
089d79
-			}
089d79
-			if ((r = sshpkt_send(ssh)) != 0)
089d79
-				fatal("failed to send packet: %s", ssh_err(r));
089d79
-			gss_release_buffer(&min_status, &send_tok);
089d79
-
089d79
-			/* If we've sent them data, they should reply */
089d79
-			do {
089d79
-				type = ssh_packet_read(ssh);
089d79
-				if (type == SSH2_MSG_KEXGSS_HOSTKEY) {
089d79
-					u_char *tmp = NULL;
089d79
-					size_t tmp_len = 0;
089d79
-
089d79
-					debug("Received KEXGSS_HOSTKEY");
089d79
-					if (server_host_key_blob)
089d79
-						fatal("Server host key received more than once");
089d79
-					if ((r = sshpkt_get_string(ssh, &tmp, &tmp_len)) != 0)
089d79
-						fatal("Failed to read server host key: %s", ssh_err(r));
089d79
-					if ((server_host_key_blob = sshbuf_from(tmp, tmp_len)) == NULL)
089d79
-						fatal("sshbuf_from failed");
089d79
-				}
089d79
-			} while (type == SSH2_MSG_KEXGSS_HOSTKEY);
089d79
-
089d79
-			switch (type) {
089d79
-			case SSH2_MSG_KEXGSS_CONTINUE:
089d79
-				debug("Received GSSAPI_CONTINUE");
089d79
-				if (maj_status == GSS_S_COMPLETE)
089d79
-					fatal("GSSAPI Continue received from server when complete");
089d79
-				if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
089d79
-				        &recv_tok)) != 0 ||
089d79
-				    (r = sshpkt_get_end(ssh)) != 0)
089d79
-					fatal("Failed to read token: %s", ssh_err(r));
089d79
-				break;
089d79
-			case SSH2_MSG_KEXGSS_COMPLETE:
089d79
-				debug("Received GSSAPI_COMPLETE");
089d79
-				if (msg_tok.value != NULL)
089d79
-				        fatal("Received GSSAPI_COMPLETE twice?");
089d79
-				if ((r = sshpkt_getb_froms(ssh, &server_blob)) != 0 ||
089d79
-				    (r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
089d79
-				        &msg_tok)) != 0)
089d79
-					fatal("Failed to read message: %s", ssh_err(r));
089d79
-
089d79
-				/* Is there a token included? */
089d79
-				if ((r = sshpkt_get_u8(ssh, &c)) != 0)
089d79
-					fatal("sshpkt failed: %s", ssh_err(r));
089d79
-				if (c) {
089d79
-					if ((r = ssh_gssapi_sshpkt_get_buffer_desc(
089d79
-					    ssh, &recv_tok)) != 0)
089d79
-						fatal("Failed to read token: %s", ssh_err(r));
089d79
-					/* If we're already complete - protocol error */
089d79
-					if (maj_status == GSS_S_COMPLETE)
089d79
-						sshpkt_disconnect(ssh, "Protocol error: received token when complete");
089d79
-				} else {
089d79
-					/* No token included */
089d79
-					if (maj_status != GSS_S_COMPLETE)
089d79
-						sshpkt_disconnect(ssh, "Protocol error: did not receive final token");
089d79
-				}
089d79
-				if ((r = sshpkt_get_end(ssh)) != 0) {
089d79
-					fatal("Expecting end of packet.");
089d79
-				}
089d79
-				break;
089d79
-			case SSH2_MSG_KEXGSS_ERROR:
089d79
-				debug("Received Error");
089d79
-				if ((r = sshpkt_get_u32(ssh, &maj_status)) != 0 ||
089d79
-				    (r = sshpkt_get_u32(ssh, &min_status)) != 0 ||
089d79
-				    (r = sshpkt_get_string(ssh, &msg, NULL)) != 0 ||
089d79
-				    (r = sshpkt_get_string(ssh, NULL, NULL)) != 0 || /* lang tag */
089d79
-				    (r = sshpkt_get_end(ssh)) != 0)
089d79
-					fatal("sshpkt_get failed: %s", ssh_err(r));
089d79
-				fatal("GSSAPI Error: \n%.400s", msg);
089d79
-			default:
089d79
-				sshpkt_disconnect(ssh, "Protocol error: didn't expect packet type %d",
089d79
-				    type);
089d79
-			}
089d79
-			token_ptr = &recv_tok;
089d79
-		} else {
089d79
-			/* No data, and not complete */
089d79
-			if (maj_status != GSS_S_COMPLETE)
089d79
-				fatal("Not complete, and no token output");
089d79
-		}
089d79
-	} while (maj_status & GSS_S_CONTINUE_NEEDED);
089d79
-
089d79
 	/*
089d79
 	 * We _must_ have received a COMPLETE message in reply from the
089d79
 	 * server, which will have set server_blob and msg_tok
089d79
 	 */
089d79
 
089d79
-	if (type != SSH2_MSG_KEXGSS_COMPLETE)
089d79
-		fatal("Didn't receive a SSH2_MSG_KEXGSS_COMPLETE when I expected it");
089d79
-
089d79
 	/* compute shared secret */
089d79
 	switch (kex->kex_type) {
089d79
 	case KEX_GSS_GRP1_SHA1:
089d79
 	case KEX_GSS_GRP14_SHA1:
089d79
 	case KEX_GSS_GRP14_SHA256:
089d79
 	case KEX_GSS_GRP16_SHA512:
089d79
-		r = kex_dh_dec(kex, server_blob, &shared_secret);
089d79
+		r = kex_dh_dec(kex, gss->server_blob, &shared_secret);
089d79
 		break;
089d79
 	case KEX_GSS_C25519_SHA256:
089d79
-		if (sshbuf_ptr(server_blob)[sshbuf_len(server_blob)] & 0x80)
089d79
+		if (sshbuf_ptr(gss->server_blob)[sshbuf_len(gss->server_blob)] & 0x80)
089d79
 			fatal("The received key has MSB of last octet set!");
089d79
-		r = kex_c25519_dec(kex, server_blob, &shared_secret);
089d79
+		r = kex_c25519_dec(kex, gss->server_blob, &shared_secret);
089d79
 		break;
089d79
 	case KEX_GSS_NISTP256_SHA256:
089d79
-		if (sshbuf_len(server_blob) != 65)
089d79
-			fatal("The received NIST-P256 key did not match"
089d79
-			    "expected length (expected 65, got %zu)", sshbuf_len(server_blob));
089d79
+		if (sshbuf_len(gss->server_blob) != 65)
089d79
+			fatal("The received NIST-P256 key did not match "
089d79
+			      "expected length (expected 65, got %zu)",
089d79
+			      sshbuf_len(gss->server_blob));
089d79
 
089d79
-		if (sshbuf_ptr(server_blob)[0] != POINT_CONVERSION_UNCOMPRESSED)
089d79
+		if (sshbuf_ptr(gss->server_blob)[0] != POINT_CONVERSION_UNCOMPRESSED)
089d79
 			fatal("The received NIST-P256 key does not have first octet 0x04");
089d79
 
089d79
-		r = kex_ecdh_dec(kex, server_blob, &shared_secret);
089d79
+		r = kex_ecdh_dec(kex, gss->server_blob, &shared_secret);
089d79
 		break;
089d79
 	default:
089d79
 		r = SSH_ERR_INVALID_ARGUMENT;
089d79
 		break;
089d79
 	}
089d79
-	if (r != 0)
089d79
+	if (r != 0) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 		goto out;
089d79
+	}
089d79
 
089d79
 	if ((empty = sshbuf_new()) == NULL) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 		r = SSH_ERR_ALLOC_FAIL;
089d79
 		goto out;
089d79
 	}
089d79
 
089d79
 	hashlen = sizeof(hash);
089d79
-	if ((r = kex_gen_hash(
089d79
-	    kex->hash_alg,
089d79
-	    kex->client_version,
089d79
-	    kex->server_version,
089d79
-	    kex->my,
089d79
-	    kex->peer,
089d79
-	    (server_host_key_blob ? server_host_key_blob : empty),
089d79
-	    kex->client_pub,
089d79
-	    server_blob,
089d79
-	    shared_secret,
089d79
-	    hash, &hashlen)) != 0)
089d79
+	r = kex_gen_hash(kex->hash_alg, kex->client_version,
089d79
+			 kex->server_version, kex->my, kex->peer,
089d79
+			 (gss->server_host_key_blob ? gss->server_host_key_blob : empty),
089d79
+			 kex->client_pub, gss->server_blob, shared_secret,
089d79
+			 hash, &hashlen);
089d79
+	sshbuf_free(empty);
089d79
+	if (r != 0)
089d79
 		fatal_f("Unexpected KEX type %d", kex->kex_type);
089d79
 
089d79
-	gssbuf.value = hash;
089d79
-	gssbuf.length = hashlen;
089d79
+	gss->buf.value = hash;
089d79
+	gss->buf.length = hashlen;
089d79
 
089d79
 	/* Verify that the hash matches the MIC we just got. */
089d79
-	if (GSS_ERROR(ssh_gssapi_checkmic(ctxt, &gssbuf, &msg_tok)))
089d79
+	if (GSS_ERROR(ssh_gssapi_checkmic(gss, &gss->buf, &gss->msg_tok)))
089d79
 		sshpkt_disconnect(ssh, "Hash's MIC didn't verify");
089d79
 
089d79
-	gss_release_buffer(&min_status, &msg_tok);
089d79
+	gss_release_buffer(&gss->minor, &gss->msg_tok);
089d79
 
089d79
 	if (kex->gss_deleg_creds)
089d79
-		ssh_gssapi_credentials_updated(ctxt);
089d79
+		ssh_gssapi_credentials_updated(gss);
089d79
 
089d79
 	if (gss_kex_context == NULL)
089d79
-		gss_kex_context = ctxt;
089d79
+		gss_kex_context = gss;
089d79
 	else
089d79
-		ssh_gssapi_delete_ctx(&ctxt);
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 
089d79
 	if ((r = kex_derive_keys(ssh, hash, hashlen, shared_secret)) == 0)
089d79
 		r = kex_send_newkeys(ssh);
089d79
 
089d79
+	if (kex->gss != NULL) {
089d79
+		sshbuf_free(gss->server_host_key_blob);
089d79
+		gss->server_host_key_blob = NULL;
089d79
+		sshbuf_free(gss->server_blob);
089d79
+		gss->server_blob = NULL;
089d79
+	}
089d79
 out:
089d79
-	explicit_bzero(hash, sizeof(hash));
089d79
 	explicit_bzero(kex->c25519_client_key, sizeof(kex->c25519_client_key));
089d79
-	sshbuf_free(empty);
089d79
-	sshbuf_free(server_host_key_blob);
089d79
-	sshbuf_free(server_blob);
089d79
+	explicit_bzero(hash, sizeof(hash));
089d79
 	sshbuf_free(shared_secret);
089d79
 	sshbuf_free(kex->client_pub);
089d79
 	kex->client_pub = NULL;
089d79
 	return r;
089d79
 }
089d79
 
089d79
+static int
089d79
+kexgss_init_ctx(struct ssh *ssh,
089d79
+		gss_buffer_desc *token_ptr)
089d79
+{
089d79
+	struct kex *kex = ssh->kex;
089d79
+	Gssctxt *gss = kex->gss;
089d79
+	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER;
089d79
+	OM_uint32 ret_flags;
089d79
+	int r;
089d79
+
089d79
+	debug("Calling gss_init_sec_context");
089d79
+
089d79
+	gss->major = ssh_gssapi_init_ctx(gss, kex->gss_deleg_creds,
089d79
+					 token_ptr, &send_tok, &ret_flags);
089d79
+
089d79
+	if (GSS_ERROR(gss->major)) {
089d79
+		/* XXX Useless code: Missing send? */
089d79
+		if (send_tok.length != 0) {
089d79
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
089d79
+				fatal("sshpkt failed: %s", ssh_err(r));
089d79
+		}
089d79
+		fatal("gss_init_context failed");
089d79
+	}
089d79
+
089d79
+	/* If we've got an old receive buffer get rid of it */
089d79
+	if (token_ptr != GSS_C_NO_BUFFER)
089d79
+		gss_release_buffer(&gss->minor, token_ptr);
089d79
+
089d79
+	if (gss->major == GSS_S_COMPLETE) {
089d79
+		/* If mutual state flag is not true, kex fails */
089d79
+		if (!(ret_flags & GSS_C_MUTUAL_FLAG))
089d79
+			fatal("Mutual authentication failed");
089d79
+
089d79
+		/* If integ avail flag is not true kex fails */
089d79
+		if (!(ret_flags & GSS_C_INTEG_FLAG))
089d79
+			fatal("Integrity check failed");
089d79
+	}
089d79
+
089d79
+	/*
089d79
+	 * If we have data to send, then the last message that we
089d79
+	 * received cannot have been a 'complete'.
089d79
+	 */
089d79
+	if (send_tok.length != 0) {
089d79
+		if (gss->first) {
089d79
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_INIT)) != 0 ||
089d79
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
089d79
+			    (r = sshpkt_put_stringb(ssh, kex->client_pub)) != 0)
089d79
+				fatal("failed to construct packet: %s", ssh_err(r));
089d79
+			gss->first = 0;
089d79
+		} else {
089d79
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
089d79
+				fatal("failed to construct packet: %s", ssh_err(r));
089d79
+		}
089d79
+		if ((r = sshpkt_send(ssh)) != 0)
089d79
+			fatal("failed to send packet: %s", ssh_err(r));
089d79
+		gss_release_buffer(&gss->minor, &send_tok);
089d79
+
089d79
+		/* If we've sent them data, they should reply */
089d79
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, &input_kexgss_hostkey);
089d79
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, &input_kexgss_continue);
089d79
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, &input_kexgss_complete);
089d79
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, &input_kexgss_error);
089d79
+		return 0;
089d79
+	}
089d79
+	/* No data, and not complete */
089d79
+	if (gss->major != GSS_S_COMPLETE)
089d79
+		fatal("Not complete, and no token output");
089d79
+
089d79
+	if  (gss->major & GSS_S_CONTINUE_NEEDED)
089d79
+		return kexgss_init_ctx(ssh, token_ptr);
089d79
+
089d79
+	return kexgss_final(ssh);
089d79
+}
089d79
+
089d79
 int
089d79
-kexgssgex_client(struct ssh *ssh)
089d79
+kexgss_client(struct ssh *ssh)
089d79
 {
089d79
 	struct kex *kex = ssh->kex;
089d79
-	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER,
089d79
-	    recv_tok = GSS_C_EMPTY_BUFFER, gssbuf,
089d79
-            msg_tok = GSS_C_EMPTY_BUFFER, *token_ptr;
089d79
-	Gssctxt *ctxt;
089d79
-	OM_uint32 maj_status, min_status, ret_flags;
089d79
-	struct sshbuf *shared_secret = NULL;
089d79
-	BIGNUM *p = NULL;
089d79
-	BIGNUM *g = NULL;
089d79
-	struct sshbuf *buf = NULL;
089d79
-	struct sshbuf *server_host_key_blob = NULL;
089d79
-	struct sshbuf *server_blob = NULL;
089d79
-	BIGNUM *dh_server_pub = NULL;
089d79
-	u_char *msg;
089d79
-	int type = 0;
089d79
-	int first = 1;
089d79
-	u_char hash[SSH_DIGEST_MAX_LENGTH];
089d79
-	size_t hashlen;
089d79
-	const BIGNUM *pub_key, *dh_p, *dh_g;
089d79
-	int nbits = 0, min = DH_GRP_MIN, max = DH_GRP_MAX;
089d79
-	struct sshbuf *empty = NULL;
089d79
-	u_char c;
089d79
 	int r;
089d79
 
089d79
 	/* Initialise our GSSAPI world */
089d79
-	ssh_gssapi_build_ctx(&ctxt);
089d79
-	if (ssh_gssapi_id_kex(ctxt, kex->name, kex->kex_type)
089d79
-	    == GSS_C_NO_OID)
089d79
+	ssh_gssapi_build_ctx(&kex->gss);
089d79
+	if (ssh_gssapi_id_kex(kex->gss, kex->name, kex->kex_type) == GSS_C_NO_OID)
089d79
 		fatal("Couldn't identify host exchange");
089d79
 
089d79
-	if (ssh_gssapi_import_name(ctxt, kex->gss_host))
089d79
+	if (ssh_gssapi_import_name(kex->gss, kex->gss_host))
089d79
 		fatal("Couldn't import hostname");
089d79
 
089d79
 	if (kex->gss_client &&
089d79
-	    ssh_gssapi_client_identity(ctxt, kex->gss_client))
089d79
+	    ssh_gssapi_client_identity(kex->gss, kex->gss_client))
089d79
 		fatal("Couldn't acquire client credentials");
089d79
 
089d79
-	debug("Doing group exchange");
089d79
-	nbits = dh_estimate(kex->dh_need * 8);
089d79
+	/* Step 1 */
089d79
+	switch (kex->kex_type) {
089d79
+	case KEX_GSS_GRP1_SHA1:
089d79
+	case KEX_GSS_GRP14_SHA1:
089d79
+	case KEX_GSS_GRP14_SHA256:
089d79
+	case KEX_GSS_GRP16_SHA512:
089d79
+		r = kex_dh_keypair(kex);
089d79
+		break;
089d79
+	case KEX_GSS_NISTP256_SHA256:
089d79
+		r = kex_ecdh_keypair(kex);
089d79
+		break;
089d79
+	case KEX_GSS_C25519_SHA256:
089d79
+		r = kex_c25519_keypair(kex);
089d79
+		break;
089d79
+	default:
089d79
+		fatal_f("Unexpected KEX type %d", kex->kex_type);
089d79
+	}
089d79
+	if (r != 0) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
+		return r;
089d79
+	}
089d79
+	return kexgss_init_ctx(ssh, GSS_C_NO_BUFFER);
089d79
+}
089d79
 
089d79
-	kex->min = DH_GRP_MIN;
089d79
-	kex->max = DH_GRP_MAX;
089d79
-	kex->nbits = nbits;
089d79
-	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_GROUPREQ)) != 0 ||
089d79
-	    (r = sshpkt_put_u32(ssh, min)) != 0 ||
089d79
-	    (r = sshpkt_put_u32(ssh, nbits)) != 0 ||
089d79
-	    (r = sshpkt_put_u32(ssh, max)) != 0 ||
089d79
-	    (r = sshpkt_send(ssh)) != 0)
089d79
-		fatal("Failed to construct a packet: %s", ssh_err(r));
089d79
+static int
089d79
+input_kexgss_hostkey(int type,
089d79
+		     u_int32_t seq,
089d79
+		     struct ssh *ssh)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	u_char *tmp = NULL;
089d79
+	size_t tmp_len = 0;
089d79
+	int r;
089d79
+
089d79
+	debug("Received KEXGSS_HOSTKEY");
089d79
+	if (gss->server_host_key_blob)
089d79
+		fatal("Server host key received more than once");
089d79
+	if ((r = sshpkt_get_string(ssh, &tmp, &tmp_len)) != 0)
089d79
+		fatal("Failed to read server host key: %s", ssh_err(r));
089d79
+	if ((gss->server_host_key_blob = sshbuf_from(tmp, tmp_len)) == NULL)
089d79
+		fatal("sshbuf_from failed");
089d79
+	return 0;
089d79
+}
089d79
 
089d79
-	if ((r = ssh_packet_read_expect(ssh, SSH2_MSG_KEXGSS_GROUP)) != 0)
089d79
-		fatal("Error: %s", ssh_err(r));
089d79
+static int
089d79
+input_kexgss_continue(int type,
089d79
+		      u_int32_t seq,
089d79
+		      struct ssh *ssh)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	gss_buffer_desc recv_tok = GSS_C_EMPTY_BUFFER;
089d79
+	int r;
089d79
 
089d79
-	if ((r = sshpkt_get_bignum2(ssh, &p)) != 0 ||
089d79
-	    (r = sshpkt_get_bignum2(ssh, &g)) != 0 ||
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, NULL);
089d79
+
089d79
+	debug("Received GSSAPI_CONTINUE");
089d79
+	if (gss->major == GSS_S_COMPLETE)
089d79
+		fatal("GSSAPI Continue received from server when complete");
089d79
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
089d79
 	    (r = sshpkt_get_end(ssh)) != 0)
089d79
-		fatal("shpkt_get_bignum2 failed: %s", ssh_err(r));
089d79
+		fatal("Failed to read token: %s", ssh_err(r));
089d79
+	if  (!(gss->major & GSS_S_CONTINUE_NEEDED))
089d79
+		fatal("Didn't receive a SSH2_MSG_KEXGSS_COMPLETE when I expected it");
089d79
+	return kexgss_init_ctx(ssh, &recv_tok);
089d79
+}
089d79
 
089d79
-	if (BN_num_bits(p) < min || BN_num_bits(p) > max)
089d79
-		fatal("GSSGRP_GEX group out of range: %d !< %d !< %d",
089d79
-		    min, BN_num_bits(p), max);
089d79
+static int
089d79
+input_kexgss_complete(int type,
089d79
+		      u_int32_t seq,
089d79
+		      struct ssh *ssh)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	gss_buffer_desc recv_tok = GSS_C_EMPTY_BUFFER;
089d79
+	u_char c;
089d79
+	int r;
089d79
 
089d79
-	if ((kex->dh = dh_new_group(g, p)) == NULL)
089d79
-		fatal("dn_new_group() failed");
089d79
-	p = g = NULL; /* belong to kex->dh now */
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, NULL);
089d79
+
089d79
+	debug("Received GSSAPI_COMPLETE");
089d79
+	if (gss->msg_tok.value != NULL)
089d79
+	        fatal("Received GSSAPI_COMPLETE twice?");
089d79
+	if ((r = sshpkt_getb_froms(ssh, &gss->server_blob)) != 0 ||
089d79
+	    (r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &gss->msg_tok)) != 0)
089d79
+		fatal("Failed to read message: %s", ssh_err(r));
089d79
+
089d79
+	/* Is there a token included? */
089d79
+	if ((r = sshpkt_get_u8(ssh, &c)) != 0)
089d79
+		fatal("sshpkt failed: %s", ssh_err(r));
089d79
+	if (c) {
089d79
+		if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0)
089d79
+			fatal("Failed to read token: %s", ssh_err(r));
089d79
+		/* If we're already complete - protocol error */
089d79
+		if (gss->major == GSS_S_COMPLETE)
089d79
+			sshpkt_disconnect(ssh, "Protocol error: received token when complete");
089d79
+	} else {
089d79
+		if (gss->major != GSS_S_COMPLETE)
089d79
+			sshpkt_disconnect(ssh, "Protocol error: did not receive final token");
089d79
+	}
089d79
+	if ((r = sshpkt_get_end(ssh)) != 0)
089d79
+		fatal("Expecting end of packet.");
089d79
 
089d79
-	if ((r = dh_gen_key(kex->dh, kex->we_need * 8)) != 0)
089d79
-		goto out;
089d79
-	DH_get0_key(kex->dh, &pub_key, NULL);
089d79
+	if  (gss->major & GSS_S_CONTINUE_NEEDED)
089d79
+		return kexgss_init_ctx(ssh, &recv_tok);
089d79
 
089d79
-	token_ptr = GSS_C_NO_BUFFER;
089d79
+	return kexgss_final(ssh);
089d79
+}
089d79
 
089d79
-	do {
089d79
-		/* Step 2 - call GSS_Init_sec_context() */
089d79
-		debug("Calling gss_init_sec_context");
089d79
-
089d79
-		maj_status = ssh_gssapi_init_ctx(ctxt,
089d79
-		    kex->gss_deleg_creds, token_ptr, &send_tok,
089d79
-		    &ret_flags);
089d79
-
089d79
-		if (GSS_ERROR(maj_status)) {
089d79
-			/* XXX Useles code: Missing send? */
089d79
-			if (send_tok.length != 0) {
089d79
-				if ((r = sshpkt_start(ssh,
089d79
-				        SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
-				    (r = sshpkt_put_string(ssh, send_tok.value,
089d79
-				        send_tok.length)) != 0)
089d79
-					fatal("sshpkt failed: %s", ssh_err(r));
089d79
-			}
089d79
-			fatal("gss_init_context failed");
089d79
-		}
089d79
+static int
089d79
+input_kexgss_error(int type,
089d79
+		   u_int32_t seq,
089d79
+		   struct ssh *ssh)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	u_char *msg;
089d79
+	int r;
089d79
 
089d79
-		/* If we've got an old receive buffer get rid of it */
089d79
-		if (token_ptr != GSS_C_NO_BUFFER)
089d79
-			gss_release_buffer(&min_status, &recv_tok);
089d79
-
089d79
-		if (maj_status == GSS_S_COMPLETE) {
089d79
-			/* If mutual state flag is not true, kex fails */
089d79
-			if (!(ret_flags & GSS_C_MUTUAL_FLAG))
089d79
-				fatal("Mutual authentication failed");
089d79
-
089d79
-			/* If integ avail flag is not true kex fails */
089d79
-			if (!(ret_flags & GSS_C_INTEG_FLAG))
089d79
-				fatal("Integrity check failed");
089d79
-		}
089d79
+	debug("Received Error");
089d79
+	if ((r = sshpkt_get_u32(ssh, &gss->major)) != 0 ||
089d79
+	    (r = sshpkt_get_u32(ssh, &gss->minor)) != 0 ||
089d79
+	    (r = sshpkt_get_string(ssh, &msg, NULL)) != 0 ||
089d79
+	    (r = sshpkt_get_string(ssh, NULL, NULL)) != 0 || /* lang tag */
089d79
+	    (r = sshpkt_get_end(ssh)) != 0)
089d79
+		fatal("sshpkt_get failed: %s", ssh_err(r));
089d79
+	fatal("GSSAPI Error: \n%.400s", msg);
089d79
+	return 0;
089d79
+}
089d79
 
089d79
-		/*
089d79
-		 * If we have data to send, then the last message that we
089d79
-		 * received cannot have been a 'complete'.
089d79
-		 */
089d79
-		if (send_tok.length != 0) {
089d79
-			if (first) {
089d79
-				if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_INIT)) != 0 ||
089d79
-				    (r = sshpkt_put_string(ssh, send_tok.value,
089d79
-				        send_tok.length)) != 0 ||
089d79
-				    (r = sshpkt_put_bignum2(ssh, pub_key)) != 0)
089d79
-					fatal("sshpkt failed: %s", ssh_err(r));
089d79
-				first = 0;
089d79
-			} else {
089d79
-				if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
-				    (r = sshpkt_put_string(ssh,send_tok.value,
089d79
-				        send_tok.length)) != 0)
089d79
-					fatal("sshpkt failed: %s", ssh_err(r));
089d79
-			}
089d79
-			if ((r = sshpkt_send(ssh)) != 0)
089d79
-				fatal("sshpkt_send failed: %s", ssh_err(r));
089d79
-			gss_release_buffer(&min_status, &send_tok);
089d79
-
089d79
-			/* If we've sent them data, they should reply */
089d79
-			do {
089d79
-				type = ssh_packet_read(ssh);
089d79
-				if (type == SSH2_MSG_KEXGSS_HOSTKEY) {
089d79
-					u_char *tmp = NULL;
089d79
-					size_t tmp_len = 0;
089d79
-
089d79
-					debug("Received KEXGSS_HOSTKEY");
089d79
-					if (server_host_key_blob)
089d79
-						fatal("Server host key received more than once");
089d79
-					if ((r = sshpkt_get_string(ssh, &tmp, &tmp_len)) != 0)
089d79
-						fatal("sshpkt failed: %s", ssh_err(r));
089d79
-					if ((server_host_key_blob = sshbuf_from(tmp, tmp_len)) == NULL)
089d79
-						fatal("sshbuf_from failed");
089d79
-				}
089d79
-			} while (type == SSH2_MSG_KEXGSS_HOSTKEY);
089d79
-
089d79
-			switch (type) {
089d79
-			case SSH2_MSG_KEXGSS_CONTINUE:
089d79
-				debug("Received GSSAPI_CONTINUE");
089d79
-				if (maj_status == GSS_S_COMPLETE)
089d79
-					fatal("GSSAPI Continue received from server when complete");
089d79
-				if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
089d79
-				        &recv_tok)) != 0 ||
089d79
-				    (r = sshpkt_get_end(ssh)) != 0)
089d79
-					fatal("sshpkt failed: %s", ssh_err(r));
089d79
-				break;
089d79
-			case SSH2_MSG_KEXGSS_COMPLETE:
089d79
-				debug("Received GSSAPI_COMPLETE");
089d79
-				if (msg_tok.value != NULL)
089d79
-				        fatal("Received GSSAPI_COMPLETE twice?");
089d79
-				if ((r = sshpkt_getb_froms(ssh, &server_blob)) != 0 ||
089d79
-				    (r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
089d79
-				        &msg_tok)) != 0)
089d79
-					fatal("sshpkt failed: %s", ssh_err(r));
089d79
-
089d79
-				/* Is there a token included? */
089d79
-				if ((r = sshpkt_get_u8(ssh, &c)) != 0)
089d79
-					fatal("sshpkt failed: %s", ssh_err(r));
089d79
-				if (c) {
089d79
-					if ((r = ssh_gssapi_sshpkt_get_buffer_desc(
089d79
-					        ssh, &recv_tok)) != 0 ||
089d79
-					    (r = sshpkt_get_end(ssh)) != 0)
089d79
-						fatal("sshpkt failed: %s", ssh_err(r));
089d79
-					/* If we're already complete - protocol error */
089d79
-					if (maj_status == GSS_S_COMPLETE)
089d79
-						sshpkt_disconnect(ssh, "Protocol error: received token when complete");
089d79
-				} else {
089d79
-					/* No token included */
089d79
-					if (maj_status != GSS_S_COMPLETE)
089d79
-						sshpkt_disconnect(ssh, "Protocol error: did not receive final token");
089d79
-				}
089d79
-				break;
089d79
-			case SSH2_MSG_KEXGSS_ERROR:
089d79
-				debug("Received Error");
089d79
-				if ((r = sshpkt_get_u32(ssh, &maj_status)) != 0 ||
089d79
-				    (r = sshpkt_get_u32(ssh, &min_status)) != 0 ||
089d79
-				    (r = sshpkt_get_string(ssh, &msg, NULL)) != 0 ||
089d79
-				    (r = sshpkt_get_string(ssh, NULL, NULL)) != 0 || /* lang tag */
089d79
-				    (r = sshpkt_get_end(ssh)) != 0)
089d79
-					fatal("sshpkt failed: %s", ssh_err(r));
089d79
-				fatal("GSSAPI Error: \n%.400s", msg);
089d79
-			default:
089d79
-				sshpkt_disconnect(ssh, "Protocol error: didn't expect packet type %d",
089d79
-				    type);
089d79
-			}
089d79
-			token_ptr = &recv_tok;
089d79
-		} else {
089d79
-			/* No data, and not complete */
089d79
-			if (maj_status != GSS_S_COMPLETE)
089d79
-				fatal("Not complete, and no token output");
089d79
-		}
089d79
-	} while (maj_status & GSS_S_CONTINUE_NEEDED);
089d79
+/*******************************************************/
089d79
+/******************** KEXGSSGEX ************************/
089d79
+/*******************************************************/
089d79
+
089d79
+int
089d79
+kexgssgex_client(struct ssh *ssh)
089d79
+{
089d79
+	struct kex *kex = ssh->kex;
089d79
+	int r;
089d79
+
089d79
+	/* Initialise our GSSAPI world */
089d79
+	ssh_gssapi_build_ctx(&kex->gss);
089d79
+	if (ssh_gssapi_id_kex(kex->gss, kex->name, kex->kex_type) == GSS_C_NO_OID)
089d79
+		fatal("Couldn't identify host exchange");
089d79
+
089d79
+	if (ssh_gssapi_import_name(kex->gss, kex->gss_host))
089d79
+		fatal("Couldn't import hostname");
089d79
+
089d79
+	if (kex->gss_client &&
089d79
+	    ssh_gssapi_client_identity(kex->gss, kex->gss_client))
089d79
+		fatal("Couldn't acquire client credentials");
089d79
+
089d79
+	debug("Doing group exchange");
089d79
+	kex->min = DH_GRP_MIN;
089d79
+	kex->max = DH_GRP_MAX;
089d79
+	kex->nbits = dh_estimate(kex->dh_need * 8);
089d79
+
089d79
+	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_GROUPREQ)) != 0 ||
089d79
+	    (r = sshpkt_put_u32(ssh, kex->min)) != 0 ||
089d79
+	    (r = sshpkt_put_u32(ssh, kex->nbits)) != 0 ||
089d79
+	    (r = sshpkt_put_u32(ssh, kex->max)) != 0 ||
089d79
+	    (r = sshpkt_send(ssh)) != 0)
089d79
+		fatal("Failed to construct a packet: %s", ssh_err(r));
089d79
+
089d79
+	debug("Wait SSH2_MSG_KEXGSS_GROUP");
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_GROUP, &input_kexgssgex_group);
089d79
+	return 0;
089d79
+}
089d79
+
089d79
+static int
089d79
+kexgssgex_final(struct ssh *ssh)
089d79
+{
089d79
+	struct kex *kex = ssh->kex;
089d79
+	Gssctxt *gss = kex->gss;
089d79
+	struct sshbuf *buf = NULL;
089d79
+	struct sshbuf *empty = NULL;
089d79
+	struct sshbuf *shared_secret = NULL;
089d79
+	BIGNUM *dh_server_pub = NULL;
089d79
+	const BIGNUM *pub_key, *dh_p, *dh_g;
089d79
+	u_char hash[SSH_DIGEST_MAX_LENGTH];
089d79
+	size_t hashlen;
089d79
+	int r = SSH_ERR_INTERNAL_ERROR;
089d79
 
089d79
 	/*
089d79
 	 * We _must_ have received a COMPLETE message in reply from the
089d79
-	 * server, which will have set dh_server_pub and msg_tok
089d79
+	 * server, which will have set server_blob and msg_tok
089d79
 	 */
089d79
 
089d79
-	if (type != SSH2_MSG_KEXGSS_COMPLETE)
089d79
-		fatal("Didn't receive a SSH2_MSG_KEXGSS_COMPLETE when I expected it");
089d79
-
089d79
 	/* 7. C verifies that the key Q_S is valid */
089d79
 	/* 8. C computes shared secret */
089d79
 	if ((buf = sshbuf_new()) == NULL ||
089d79
-	    (r = sshbuf_put_stringb(buf, server_blob)) != 0 ||
089d79
-	    (r = sshbuf_get_bignum2(buf, &dh_server_pub)) != 0)
089d79
+	    (r = sshbuf_put_stringb(buf, gss->server_blob)) != 0 ||
089d79
+	    (r = sshbuf_get_bignum2(buf, &dh_server_pub)) != 0) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 		goto out;
089d79
+	}
089d79
 	sshbuf_free(buf);
089d79
 	buf = NULL;
089d79
 
089d79
 	if ((shared_secret = sshbuf_new()) == NULL) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 		r = SSH_ERR_ALLOC_FAIL;
089d79
 		goto out;
089d79
 	}
089d79
 
089d79
-	if ((r = kex_dh_compute_key(kex, dh_server_pub, shared_secret)) != 0)
089d79
+	if ((r = kex_dh_compute_key(kex, dh_server_pub, shared_secret)) != 0) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 		goto out;
089d79
+	}
089d79
+
089d79
 	if ((empty = sshbuf_new()) == NULL) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 		r = SSH_ERR_ALLOC_FAIL;
089d79
 		goto out;
089d79
 	}
089d79
 
089d79
+	DH_get0_key(kex->dh, &pub_key, NULL);
089d79
 	DH_get0_pqg(kex->dh, &dh_p, NULL, &dh_g);
089d79
 	hashlen = sizeof(hash);
089d79
-	if ((r = kexgex_hash(
089d79
-	    kex->hash_alg,
089d79
-	    kex->client_version,
089d79
-	    kex->server_version,
089d79
-	    kex->my,
089d79
-	    kex->peer,
089d79
-	    (server_host_key_blob ? server_host_key_blob : empty),
089d79
- 	    kex->min, kex->nbits, kex->max,
089d79
-	    dh_p, dh_g,
089d79
-	    pub_key,
089d79
-	    dh_server_pub,
089d79
-	    sshbuf_ptr(shared_secret), sshbuf_len(shared_secret),
089d79
-	    hash, &hashlen)) != 0)
089d79
+	r = kexgex_hash(kex->hash_alg, kex->client_version,
089d79
+			kex->server_version, kex->my, kex->peer,
089d79
+			(gss->server_host_key_blob ? gss->server_host_key_blob : empty),
089d79
+			kex->min, kex->nbits, kex->max, dh_p, dh_g, pub_key,
089d79
+			dh_server_pub, sshbuf_ptr(shared_secret), sshbuf_len(shared_secret),
089d79
+			hash, &hashlen);
089d79
+	sshbuf_free(empty);
089d79
+	if (r != 0)
089d79
 		fatal("Failed to calculate hash: %s", ssh_err(r));
089d79
 
089d79
-	gssbuf.value = hash;
089d79
-	gssbuf.length = hashlen;
089d79
+	gss->buf.value = hash;
089d79
+	gss->buf.length = hashlen;
089d79
 
089d79
 	/* Verify that the hash matches the MIC we just got. */
089d79
-	if (GSS_ERROR(ssh_gssapi_checkmic(ctxt, &gssbuf, &msg_tok)))
089d79
+	if (GSS_ERROR(ssh_gssapi_checkmic(gss, &gss->buf, &gss->msg_tok)))
089d79
 		sshpkt_disconnect(ssh, "Hash's MIC didn't verify");
089d79
 
089d79
-	gss_release_buffer(&min_status, &msg_tok);
089d79
+	gss_release_buffer(&gss->minor, &gss->msg_tok);
089d79
 
089d79
 	if (kex->gss_deleg_creds)
089d79
-		ssh_gssapi_credentials_updated(ctxt);
089d79
+		ssh_gssapi_credentials_updated(gss);
089d79
 
089d79
 	if (gss_kex_context == NULL)
089d79
-		gss_kex_context = ctxt;
089d79
+		gss_kex_context = gss;
089d79
 	else
089d79
-		ssh_gssapi_delete_ctx(&ctxt);
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 
089d79
 	/* Finally derive the keys and send them */
089d79
 	if ((r = kex_derive_keys(ssh, hash, hashlen, shared_secret)) == 0)
089d79
 		r = kex_send_newkeys(ssh);
089d79
+
089d79
+	if (kex->gss != NULL) {
089d79
+		sshbuf_free(gss->server_host_key_blob);
089d79
+		gss->server_host_key_blob = NULL;
089d79
+		sshbuf_free(gss->server_blob);
089d79
+		gss->server_blob = NULL;
089d79
+	}
089d79
 out:
089d79
-	sshbuf_free(buf);
089d79
-	sshbuf_free(server_blob);
089d79
-	sshbuf_free(empty);
089d79
 	explicit_bzero(hash, sizeof(hash));
089d79
 	DH_free(kex->dh);
089d79
 	kex->dh = NULL;
089d79
 	BN_clear_free(dh_server_pub);
089d79
 	sshbuf_free(shared_secret);
089d79
-	sshbuf_free(server_host_key_blob);
089d79
 	return r;
089d79
 }
089d79
 
089d79
+static int
089d79
+kexgssgex_init_ctx(struct ssh *ssh,
089d79
+		   gss_buffer_desc *token_ptr)
089d79
+{
089d79
+	struct kex *kex = ssh->kex;
089d79
+	Gssctxt *gss = kex->gss;
089d79
+	const BIGNUM *pub_key;
089d79
+	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER;
089d79
+	OM_uint32 ret_flags;
089d79
+	int r;
089d79
+
089d79
+	/* Step 2 - call GSS_Init_sec_context() */
089d79
+	debug("Calling gss_init_sec_context");
089d79
+
089d79
+	gss->major = ssh_gssapi_init_ctx(gss, kex->gss_deleg_creds,
089d79
+					 token_ptr, &send_tok, &ret_flags);
089d79
+
089d79
+	if (GSS_ERROR(gss->major)) {
089d79
+		/* XXX Useless code: Missing send? */
089d79
+		if (send_tok.length != 0) {
089d79
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
089d79
+				fatal("sshpkt failed: %s", ssh_err(r));
089d79
+		}
089d79
+		fatal("gss_init_context failed");
089d79
+	}
089d79
+
089d79
+	/* If we've got an old receive buffer get rid of it */
089d79
+	if (token_ptr != GSS_C_NO_BUFFER)
089d79
+		gss_release_buffer(&gss->minor, token_ptr);
089d79
+
089d79
+	if (gss->major == GSS_S_COMPLETE) {
089d79
+		/* If mutual state flag is not true, kex fails */
089d79
+		if (!(ret_flags & GSS_C_MUTUAL_FLAG))
089d79
+			fatal("Mutual authentication failed");
089d79
+
089d79
+		/* If integ avail flag is not true kex fails */
089d79
+		if (!(ret_flags & GSS_C_INTEG_FLAG))
089d79
+			fatal("Integrity check failed");
089d79
+	}
089d79
+
089d79
+	/*
089d79
+	 * If we have data to send, then the last message that we
089d79
+	 * received cannot have been a 'complete'.
089d79
+	 */
089d79
+	if (send_tok.length != 0) {
089d79
+		if (gss->first) {
089d79
+	                DH_get0_key(kex->dh, &pub_key, NULL);
089d79
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_INIT)) != 0 ||
089d79
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
089d79
+			    (r = sshpkt_put_bignum2(ssh, pub_key)) != 0)
089d79
+				fatal("failed to construct packet: %s", ssh_err(r));
089d79
+			gss->first = 0;
089d79
+		} else {
089d79
+			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
+			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
089d79
+				fatal("failed to construct packet: %s", ssh_err(r));
089d79
+		}
089d79
+		if ((r = sshpkt_send(ssh)) != 0)
089d79
+			fatal("failed to send packet: %s", ssh_err(r));
089d79
+		gss_release_buffer(&gss->minor, &send_tok);
089d79
+
089d79
+		/* If we've sent them data, they should reply */
089d79
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, &input_kexgss_hostkey);
089d79
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, &input_kexgssgex_continue);
089d79
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, &input_kexgssgex_complete);
089d79
+		ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, &input_kexgss_error);
089d79
+		return 0;
089d79
+	}
089d79
+	/* No data, and not complete */
089d79
+	if (gss->major != GSS_S_COMPLETE)
089d79
+		fatal("Not complete, and no token output");
089d79
+
089d79
+	if  (gss->major & GSS_S_CONTINUE_NEEDED)
089d79
+		return kexgssgex_init_ctx(ssh, token_ptr);
089d79
+
089d79
+	return kexgssgex_final(ssh);
089d79
+}
089d79
+
089d79
+static int
089d79
+input_kexgssgex_group(int type,
089d79
+		      u_int32_t seq,
089d79
+		      struct ssh *ssh)
089d79
+{
089d79
+	struct kex *kex = ssh->kex;
089d79
+	BIGNUM *p = NULL;
089d79
+	BIGNUM *g = NULL;
089d79
+	int r;
089d79
+
089d79
+	debug("Received SSH2_MSG_KEXGSS_GROUP");
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_GROUP, NULL);
089d79
+
089d79
+	if ((r = sshpkt_get_bignum2(ssh, &p)) != 0 ||
089d79
+	    (r = sshpkt_get_bignum2(ssh, &g)) != 0 ||
089d79
+	    (r = sshpkt_get_end(ssh)) != 0)
089d79
+		fatal("shpkt_get_bignum2 failed: %s", ssh_err(r));
089d79
+
089d79
+	if (BN_num_bits(p) < kex->min || BN_num_bits(p) > kex->max)
089d79
+		fatal("GSSGRP_GEX group out of range: %d !< %d !< %d",
089d79
+		    kex->min, BN_num_bits(p), kex->max);
089d79
+
089d79
+	if ((kex->dh = dh_new_group(g, p)) == NULL)
089d79
+		fatal("dn_new_group() failed");
089d79
+	p = g = NULL; /* belong to kex->dh now */
089d79
+
089d79
+	if ((r = dh_gen_key(kex->dh, kex->we_need * 8)) != 0) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
+		DH_free(kex->dh);
089d79
+		kex->dh = NULL;
089d79
+		return r;
089d79
+	}
089d79
+
089d79
+	return kexgssgex_init_ctx(ssh, GSS_C_NO_BUFFER);
089d79
+}
089d79
+
089d79
+static int
089d79
+input_kexgssgex_continue(int type,
089d79
+			 u_int32_t seq,
089d79
+			 struct ssh *ssh)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	gss_buffer_desc recv_tok = GSS_C_EMPTY_BUFFER;
089d79
+	int r;
089d79
+
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, NULL);
089d79
+
089d79
+	debug("Received GSSAPI_CONTINUE");
089d79
+	if (gss->major == GSS_S_COMPLETE)
089d79
+		fatal("GSSAPI Continue received from server when complete");
089d79
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
089d79
+	    (r = sshpkt_get_end(ssh)) != 0)
089d79
+		fatal("Failed to read token: %s", ssh_err(r));
089d79
+	if  (!(gss->major & GSS_S_CONTINUE_NEEDED))
089d79
+		fatal("Didn't receive a SSH2_MSG_KEXGSS_COMPLETE when I expected it");
089d79
+	return kexgssgex_init_ctx(ssh, &recv_tok);
089d79
+}
089d79
+
089d79
+static int
089d79
+input_kexgssgex_complete(int type,
089d79
+		      u_int32_t seq,
089d79
+		      struct ssh *ssh)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	gss_buffer_desc recv_tok = GSS_C_EMPTY_BUFFER;
089d79
+	u_char c;
089d79
+	int r;
089d79
+
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_HOSTKEY, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_COMPLETE, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_ERROR, NULL);
089d79
+
089d79
+	debug("Received GSSAPI_COMPLETE");
089d79
+	if (gss->msg_tok.value != NULL)
089d79
+	        fatal("Received GSSAPI_COMPLETE twice?");
089d79
+	if ((r = sshpkt_getb_froms(ssh, &gss->server_blob)) != 0 ||
089d79
+	    (r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &gss->msg_tok)) != 0)
089d79
+		fatal("Failed to read message: %s", ssh_err(r));
089d79
+
089d79
+	/* Is there a token included? */
089d79
+	if ((r = sshpkt_get_u8(ssh, &c)) != 0)
089d79
+		fatal("sshpkt failed: %s", ssh_err(r));
089d79
+	if (c) {
089d79
+		if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0)
089d79
+			fatal("Failed to read token: %s", ssh_err(r));
089d79
+		/* If we're already complete - protocol error */
089d79
+		if (gss->major == GSS_S_COMPLETE)
089d79
+			sshpkt_disconnect(ssh, "Protocol error: received token when complete");
089d79
+	} else {
089d79
+		if (gss->major != GSS_S_COMPLETE)
089d79
+			sshpkt_disconnect(ssh, "Protocol error: did not receive final token");
089d79
+	}
089d79
+	if ((r = sshpkt_get_end(ssh)) != 0)
089d79
+		fatal("Expecting end of packet.");
089d79
+
089d79
+	if  (gss->major & GSS_S_CONTINUE_NEEDED)
089d79
+		return kexgssgex_init_ctx(ssh, &recv_tok);
089d79
+
089d79
+	return kexgssgex_final(ssh);
089d79
+}
089d79
+
089d79
 #endif /* defined(GSSAPI) && defined(WITH_OPENSSL) */
089d79
diff --color -ruNp a/kexgsss.c b/kexgsss.c
089d79
--- a/kexgsss.c	2024-05-16 15:49:43.820407648 +0200
089d79
+++ b/kexgsss.c	2024-07-02 16:29:05.744790839 +0200
089d79
@@ -50,33 +50,18 @@
089d79
 
089d79
 extern ServerOptions options;
089d79
 
089d79
+static int input_kexgss_init(int, u_int32_t, struct ssh *);
089d79
+static int input_kexgss_continue(int, u_int32_t, struct ssh *);
089d79
+static int input_kexgssgex_groupreq(int, u_int32_t, struct ssh *);
089d79
+static int input_kexgssgex_init(int, u_int32_t, struct ssh *);
089d79
+static int input_kexgssgex_continue(int, u_int32_t, struct ssh *);
089d79
+
089d79
 int
089d79
 kexgss_server(struct ssh *ssh)
089d79
 {
089d79
 	struct kex *kex = ssh->kex;
089d79
-	OM_uint32 maj_status, min_status;
089d79
-
089d79
-	/*
089d79
-	 * Some GSSAPI implementations use the input value of ret_flags (an
089d79
-	 * output variable) as a means of triggering mechanism specific
089d79
-	 * features. Initializing it to zero avoids inadvertently
089d79
-	 * activating this non-standard behaviour.
089d79
-	 */
089d79
-
089d79
-	OM_uint32 ret_flags = 0;
089d79
-	gss_buffer_desc gssbuf = {0, NULL}, recv_tok, msg_tok;
089d79
-	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER;
089d79
-	Gssctxt *ctxt = NULL;
089d79
-	struct sshbuf *shared_secret = NULL;
089d79
-	struct sshbuf *client_pubkey = NULL;
089d79
-	struct sshbuf *server_pubkey = NULL;
089d79
-	struct sshbuf *empty = sshbuf_new();
089d79
-	int type = 0;
089d79
 	gss_OID oid;
089d79
 	char *mechs;
089d79
-	u_char hash[SSH_DIGEST_MAX_LENGTH];
089d79
-	size_t hashlen;
089d79
-	int r;
089d79
 
089d79
 	/* Initialise GSSAPI */
089d79
 
089d79
@@ -92,135 +77,91 @@ kexgss_server(struct ssh *ssh)
089d79
 	debug2_f("Identifying %s", kex->name);
089d79
 	oid = ssh_gssapi_id_kex(NULL, kex->name, kex->kex_type);
089d79
 	if (oid == GSS_C_NO_OID)
089d79
-	   fatal("Unknown gssapi mechanism");
089d79
+		fatal("Unknown gssapi mechanism");
089d79
 
089d79
 	debug2_f("Acquiring credentials");
089d79
 
089d79
-	if (GSS_ERROR(mm_ssh_gssapi_server_ctx(&ctxt, oid)))
089d79
+	if (GSS_ERROR(mm_ssh_gssapi_server_ctx(&kex->gss, oid)))
089d79
 		fatal("Unable to acquire credentials for the server");
089d79
 
089d79
-	do {
089d79
-		debug("Wait SSH2_MSG_KEXGSS_INIT");
089d79
-		type = ssh_packet_read(ssh);
089d79
-		switch(type) {
089d79
-		case SSH2_MSG_KEXGSS_INIT:
089d79
-			if (gssbuf.value != NULL)
089d79
-				fatal("Received KEXGSS_INIT after initialising");
089d79
-			if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
089d79
-			        &recv_tok)) != 0 ||
089d79
-			    (r = sshpkt_getb_froms(ssh, &client_pubkey)) != 0 ||
089d79
-			    (r = sshpkt_get_end(ssh)) != 0)
089d79
-				fatal("sshpkt failed: %s", ssh_err(r));
089d79
+	ssh_gssapi_build_ctx(&kex->gss);
089d79
+	if (kex->gss == NULL)
089d79
+		fatal("Unable to allocate memory for gss context");
089d79
+
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, &input_kexgss_init);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, &input_kexgss_continue);
089d79
+	debug("Wait SSH2_MSG_KEXGSS_INIT");
089d79
+	return 0;
089d79
+}
089d79
 
089d79
-			switch (kex->kex_type) {
089d79
-			case KEX_GSS_GRP1_SHA1:
089d79
-			case KEX_GSS_GRP14_SHA1:
089d79
-			case KEX_GSS_GRP14_SHA256:
089d79
-			case KEX_GSS_GRP16_SHA512:
089d79
-				r = kex_dh_enc(kex, client_pubkey, &server_pubkey,
089d79
-				    &shared_secret);
089d79
-				break;
089d79
-			case KEX_GSS_NISTP256_SHA256:
089d79
-				r = kex_ecdh_enc(kex, client_pubkey, &server_pubkey,
089d79
-				    &shared_secret);
089d79
-				break;
089d79
-			case KEX_GSS_C25519_SHA256:
089d79
-				r = kex_c25519_enc(kex, client_pubkey, &server_pubkey,
089d79
-				    &shared_secret);
089d79
-				break;
089d79
-			default:
089d79
-				fatal_f("Unexpected KEX type %d", kex->kex_type);
089d79
-			}
089d79
-			if (r != 0)
089d79
-				goto out;
089d79
-
089d79
-			/* Send SSH_MSG_KEXGSS_HOSTKEY here, if we want */
089d79
-
089d79
-			/* Calculate the hash early so we can free the
089d79
-			* client_pubkey, which has reference to the parent
089d79
-			* buffer state->incoming_packet
089d79
-			*/
089d79
-			hashlen = sizeof(hash);
089d79
-			if ((r = kex_gen_hash(
089d79
-			    kex->hash_alg,
089d79
-			    kex->client_version,
089d79
-			    kex->server_version,
089d79
-			    kex->peer,
089d79
-			    kex->my,
089d79
-			    empty,
089d79
-			    client_pubkey,
089d79
-			    server_pubkey,
089d79
-			    shared_secret,
089d79
-			    hash, &hashlen)) != 0)
089d79
-				goto out;
089d79
-
089d79
-			gssbuf.value = hash;
089d79
-			gssbuf.length = hashlen;
089d79
-
089d79
-			sshbuf_free(client_pubkey);
089d79
-			client_pubkey = NULL;
089d79
-
089d79
-			break;
089d79
-		case SSH2_MSG_KEXGSS_CONTINUE:
089d79
-			if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
089d79
-			        &recv_tok)) != 0 ||
089d79
-			    (r = sshpkt_get_end(ssh)) != 0)
089d79
-				fatal("sshpkt failed: %s", ssh_err(r));
089d79
-			break;
089d79
-		default:
089d79
-			sshpkt_disconnect(ssh,
089d79
-			    "Protocol error: didn't expect packet type %d",
089d79
-			    type);
089d79
-		}
089d79
+static inline void
089d79
+kexgss_accept_ctx(struct ssh *ssh,
089d79
+		  gss_buffer_desc *recv_tok,
089d79
+		  gss_buffer_desc *send_tok,
089d79
+		  OM_uint32 *ret_flags)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	int r;
089d79
 
089d79
-		maj_status = mm_ssh_gssapi_accept_ctx(ctxt, &recv_tok,
089d79
-		    &send_tok, &ret_flags);
089d79
+	gss->major = mm_ssh_gssapi_accept_ctx(gss, recv_tok, send_tok, ret_flags);
089d79
+	gss_release_buffer(&gss->minor, recv_tok);
089d79
 
089d79
-		gss_release_buffer(&min_status, &recv_tok);
089d79
+	if (gss->major != GSS_S_COMPLETE && send_tok->length == 0)
089d79
+		fatal("Zero length token output when incomplete");
089d79
 
089d79
-		if (maj_status != GSS_S_COMPLETE && send_tok.length == 0)
089d79
-			fatal("Zero length token output when incomplete");
089d79
+	if (gss->buf.value == NULL)
089d79
+		fatal("No client public key");
089d79
 
089d79
-		if (gssbuf.value == NULL)
089d79
-			fatal("No client public key");
089d79
+	if (gss->major & GSS_S_CONTINUE_NEEDED) {
089d79
+		debug("Sending GSSAPI_CONTINUE");
089d79
+		if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
+		    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0 ||
089d79
+		    (r = sshpkt_send(ssh)) != 0)
089d79
+			fatal("sshpkt failed: %s", ssh_err(r));
089d79
+		gss_release_buffer(&gss->minor, send_tok);
089d79
+	}
089d79
+}
089d79
 
089d79
-		if (maj_status & GSS_S_CONTINUE_NEEDED) {
089d79
-			debug("Sending GSSAPI_CONTINUE");
089d79
-			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
-			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
089d79
-			    (r = sshpkt_send(ssh)) != 0)
089d79
-				fatal("sshpkt failed: %s", ssh_err(r));
089d79
-			gss_release_buffer(&min_status, &send_tok);
089d79
-		}
089d79
-	} while (maj_status & GSS_S_CONTINUE_NEEDED);
089d79
+static inline int
089d79
+kexgss_final(struct ssh *ssh,
089d79
+	     gss_buffer_desc *send_tok,
089d79
+	     OM_uint32 *ret_flags)
089d79
+{
089d79
+	struct kex *kex = ssh->kex;
089d79
+	Gssctxt *gss = kex->gss;
089d79
+	gss_buffer_desc msg_tok;
089d79
+	int r;
089d79
+
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
089d79
 
089d79
-	if (GSS_ERROR(maj_status)) {
089d79
-		if (send_tok.length > 0) {
089d79
+	if (GSS_ERROR(gss->major)) {
089d79
+		if (send_tok->length > 0) {
089d79
 			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
-			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
089d79
+			    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0 ||
089d79
 			    (r = sshpkt_send(ssh)) != 0)
089d79
 				fatal("sshpkt failed: %s", ssh_err(r));
089d79
 		}
089d79
 		fatal("accept_ctx died");
089d79
 	}
089d79
 
089d79
-	if (!(ret_flags & GSS_C_MUTUAL_FLAG))
089d79
+	if (!(*ret_flags & GSS_C_MUTUAL_FLAG))
089d79
 		fatal("Mutual Authentication flag wasn't set");
089d79
 
089d79
-	if (!(ret_flags & GSS_C_INTEG_FLAG))
089d79
+	if (!(*ret_flags & GSS_C_INTEG_FLAG))
089d79
 		fatal("Integrity flag wasn't set");
089d79
 
089d79
-	if (GSS_ERROR(mm_ssh_gssapi_sign(ctxt, &gssbuf, &msg_tok)))
089d79
+	if (GSS_ERROR(mm_ssh_gssapi_sign(gss, &gss->buf, &msg_tok)))
089d79
 		fatal("Couldn't get MIC");
089d79
 
089d79
 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_COMPLETE)) != 0 ||
089d79
-	    (r = sshpkt_put_stringb(ssh, server_pubkey)) != 0 ||
089d79
+	    (r = sshpkt_put_stringb(ssh, gss->server_pubkey)) != 0 ||
089d79
 	    (r = sshpkt_put_string(ssh, msg_tok.value, msg_tok.length)) != 0)
089d79
 		fatal("sshpkt failed: %s", ssh_err(r));
089d79
 
089d79
-	if (send_tok.length != 0) {
089d79
+	if (send_tok->length != 0) {
089d79
 		if ((r = sshpkt_put_u8(ssh, 1)) != 0 || /* true */
089d79
-		    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
089d79
+		    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0)
089d79
 			fatal("sshpkt failed: %s", ssh_err(r));
089d79
 	} else {
089d79
 		if ((r = sshpkt_put_u8(ssh, 0)) != 0) /* false */
089d79
@@ -229,59 +170,139 @@ kexgss_server(struct ssh *ssh)
089d79
 	if ((r = sshpkt_send(ssh)) != 0)
089d79
 		fatal("sshpkt_send failed: %s", ssh_err(r));
089d79
 
089d79
-	gss_release_buffer(&min_status, &send_tok);
089d79
-	gss_release_buffer(&min_status, &msg_tok);
089d79
+	gss_release_buffer(&gss->minor, send_tok);
089d79
+	gss_release_buffer(&gss->minor, &msg_tok);
089d79
 
089d79
 	if (gss_kex_context == NULL)
089d79
-		gss_kex_context = ctxt;
089d79
+		gss_kex_context = gss;
089d79
 	else
089d79
-		ssh_gssapi_delete_ctx(&ctxt);
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 
089d79
-	if ((r = kex_derive_keys(ssh, hash, hashlen, shared_secret)) == 0)
089d79
+	if ((r = kex_derive_keys(ssh, gss->hash, gss->hashlen, gss->shared_secret)) == 0)
089d79
 		r = kex_send_newkeys(ssh);
089d79
 
089d79
 	/* If this was a rekey, then save out any delegated credentials we
089d79
 	 * just exchanged.  */
089d79
 	if (options.gss_store_rekey)
089d79
 		ssh_gssapi_rekey_creds();
089d79
-out:
089d79
-	sshbuf_free(empty);
089d79
-	explicit_bzero(hash, sizeof(hash));
089d79
-	sshbuf_free(shared_secret);
089d79
-	sshbuf_free(client_pubkey);
089d79
-	sshbuf_free(server_pubkey);
089d79
+
089d79
+	if (kex->gss != NULL) {
089d79
+		explicit_bzero(gss->hash, sizeof(gss->hash));
089d79
+		sshbuf_free(gss->shared_secret);
089d79
+		gss->shared_secret = NULL;
089d79
+		sshbuf_free(gss->server_pubkey);
089d79
+		gss->server_pubkey = NULL;
089d79
+	}
089d79
 	return r;
089d79
 }
089d79
 
089d79
-int
089d79
-kexgssgex_server(struct ssh *ssh)
089d79
+static int
089d79
+input_kexgss_init(int type,
089d79
+		  u_int32_t seq,
089d79
+		  struct ssh *ssh)
089d79
 {
089d79
 	struct kex *kex = ssh->kex;
089d79
-	OM_uint32 maj_status, min_status;
089d79
+	Gssctxt *gss = kex->gss;
089d79
+	struct sshbuf *empty;
089d79
+	struct sshbuf *client_pubkey = NULL;
089d79
+	gss_buffer_desc recv_tok, send_tok = GSS_C_EMPTY_BUFFER;
089d79
+	OM_uint32 ret_flags = 0;
089d79
+	int r;
089d79
+
089d79
+	debug("SSH2_MSG_KEXGSS_INIT received");
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, NULL);
089d79
 
089d79
-	/*
089d79
-	 * Some GSSAPI implementations use the input value of ret_flags (an
089d79
-	 * output variable) as a means of triggering mechanism specific
089d79
-	 * features. Initializing it to zero avoids inadvertently
089d79
-	 * activating this non-standard behaviour.
089d79
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
089d79
+	    (r = sshpkt_getb_froms(ssh, &client_pubkey)) != 0 ||
089d79
+	    (r = sshpkt_get_end(ssh)) != 0)
089d79
+		fatal("sshpkt failed: %s", ssh_err(r));
089d79
+
089d79
+	switch (kex->kex_type) {
089d79
+	case KEX_GSS_GRP1_SHA1:
089d79
+	case KEX_GSS_GRP14_SHA1:
089d79
+	case KEX_GSS_GRP14_SHA256:
089d79
+	case KEX_GSS_GRP16_SHA512:
089d79
+		r = kex_dh_enc(kex, client_pubkey, &gss->server_pubkey, &gss->shared_secret);
089d79
+		break;
089d79
+	case KEX_GSS_NISTP256_SHA256:
089d79
+		r = kex_ecdh_enc(kex, client_pubkey, &gss->server_pubkey, &gss->shared_secret);
089d79
+		break;
089d79
+	case KEX_GSS_C25519_SHA256:
089d79
+		r = kex_c25519_enc(kex, client_pubkey, &gss->server_pubkey, &gss->shared_secret);
089d79
+		break;
089d79
+	default:
089d79
+		fatal_f("Unexpected KEX type %d", kex->kex_type);
089d79
+	}
089d79
+	if (r != 0) {
089d79
+		sshbuf_free(client_pubkey);
089d79
+                ssh_gssapi_delete_ctx(&kex->gss);
089d79
+		return r;
089d79
+	}
089d79
+
089d79
+	/* Send SSH_MSG_KEXGSS_HOSTKEY here, if we want */
089d79
+
089d79
+	if ((empty = sshbuf_new()) == NULL) {
089d79
+		sshbuf_free(client_pubkey);
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
+		return SSH_ERR_ALLOC_FAIL;
089d79
+	}
089d79
+
089d79
+	/* Calculate the hash early so we can free the
089d79
+	 * client_pubkey, which has reference to the parent
089d79
+	 * buffer state->incoming_packet
089d79
 	 */
089d79
+	gss->hashlen = sizeof(gss->hash);
089d79
+	r = kex_gen_hash(kex->hash_alg, kex->client_version, kex->server_version,
089d79
+			 kex->peer, kex->my, empty, client_pubkey, gss->server_pubkey,
089d79
+			 gss->shared_secret, gss->hash, &gss->hashlen);
089d79
+	sshbuf_free(empty);
089d79
+	sshbuf_free(client_pubkey);
089d79
+	if (r != 0) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
+		return r;
089d79
+	}
089d79
+
089d79
+	gss->buf.value = gss->hash;
089d79
+	gss->buf.length = gss->hashlen;
089d79
+
089d79
+	kexgss_accept_ctx(ssh, &recv_tok, &send_tok, &ret_flags);
089d79
+	if (gss->major & GSS_S_CONTINUE_NEEDED)
089d79
+		return 0;
089d79
 
089d79
+	return kexgss_final(ssh, &send_tok, &ret_flags);
089d79
+}
089d79
+
089d79
+static int
089d79
+input_kexgss_continue(int type,
089d79
+		      u_int32_t seq,
089d79
+		      struct ssh *ssh)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	gss_buffer_desc recv_tok, send_tok = GSS_C_EMPTY_BUFFER;
089d79
 	OM_uint32 ret_flags = 0;
089d79
-	gss_buffer_desc gssbuf, recv_tok, msg_tok;
089d79
-	gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER;
089d79
-	Gssctxt *ctxt = NULL;
089d79
-	struct sshbuf *shared_secret = NULL;
089d79
-	int type = 0;
089d79
+	int r;
089d79
+
089d79
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
089d79
+	    (r = sshpkt_get_end(ssh)) != 0)
089d79
+		fatal("sshpkt failed: %s", ssh_err(r));
089d79
+
089d79
+	kexgss_accept_ctx(ssh, &recv_tok, &send_tok, &ret_flags);
089d79
+	if (gss->major & GSS_S_CONTINUE_NEEDED)
089d79
+		return 0;
089d79
+
089d79
+	return kexgss_final(ssh, &send_tok, &ret_flags);
089d79
+}
089d79
+
089d79
+/*******************************************************/
089d79
+/******************** KEXGSSGEX ************************/
089d79
+/*******************************************************/
089d79
+
089d79
+int
089d79
+kexgssgex_server(struct ssh *ssh)
089d79
+{
089d79
+	struct kex *kex = ssh->kex;
089d79
 	gss_OID oid;
089d79
 	char *mechs;
089d79
-	u_char hash[SSH_DIGEST_MAX_LENGTH];
089d79
-	size_t hashlen;
089d79
-	BIGNUM *dh_client_pub = NULL;
089d79
-	const BIGNUM *pub_key, *dh_p, *dh_g;
089d79
-	int min = -1, max = -1, nbits = -1;
089d79
-	int cmin = -1, cmax = -1; /* client proposal */
089d79
-	struct sshbuf *empty = sshbuf_new();
089d79
-	int r;
089d79
 
089d79
 	/* Initialise GSSAPI */
089d79
 
089d79
@@ -289,153 +310,125 @@ kexgssgex_server(struct ssh *ssh)
089d79
 	 * in the GSSAPI code are no longer available. This kludges them back
089d79
 	 * into life
089d79
 	 */
089d79
-	if (!ssh_gssapi_oid_table_ok())
089d79
-		if ((mechs = ssh_gssapi_server_mechanisms()))
089d79
-			free(mechs);
089d79
+	if (!ssh_gssapi_oid_table_ok()) {
089d79
+		mechs = ssh_gssapi_server_mechanisms();
089d79
+		free(mechs);
089d79
+	}
089d79
 
089d79
 	debug2_f("Identifying %s", kex->name);
089d79
 	oid = ssh_gssapi_id_kex(NULL, kex->name, kex->kex_type);
089d79
 	if (oid == GSS_C_NO_OID)
089d79
-	   fatal("Unknown gssapi mechanism");
089d79
+		fatal("Unknown gssapi mechanism");
089d79
 
089d79
 	debug2_f("Acquiring credentials");
089d79
 
089d79
-	if (GSS_ERROR(mm_ssh_gssapi_server_ctx(&ctxt, oid)))
089d79
+	if (GSS_ERROR(mm_ssh_gssapi_server_ctx(&kex->gss, oid)))
089d79
 		fatal("Unable to acquire credentials for the server");
089d79
 
089d79
-	/* 5. S generates an ephemeral key pair (do the allocations early) */
089d79
-	debug("Doing group exchange");
089d79
-	ssh_packet_read_expect(ssh, SSH2_MSG_KEXGSS_GROUPREQ);
089d79
-	/* store client proposal to provide valid signature */
089d79
-	if ((r = sshpkt_get_u32(ssh, &cmin)) != 0 ||
089d79
-	    (r = sshpkt_get_u32(ssh, &nbits)) != 0 ||
089d79
-	    (r = sshpkt_get_u32(ssh, &cmax)) != 0 ||
089d79
-	    (r = sshpkt_get_end(ssh)) != 0)
089d79
-		fatal("sshpkt failed: %s", ssh_err(r));
089d79
-	kex->nbits = nbits;
089d79
-	kex->min = cmin;
089d79
-	kex->max = cmax;
089d79
-	min = MAX(DH_GRP_MIN, cmin);
089d79
-	max = MIN(DH_GRP_MAX, cmax);
089d79
-	nbits = MAXIMUM(DH_GRP_MIN, nbits);
089d79
-	nbits = MINIMUM(DH_GRP_MAX, nbits);
089d79
-	if (max < min || nbits < min || max < nbits)
089d79
-		fatal("GSS_GEX, bad parameters: %d !< %d !< %d",
089d79
-		    min, nbits, max);
089d79
-	kex->dh = mm_choose_dh(min, nbits, max);
089d79
-	if (kex->dh == NULL) {
089d79
-		sshpkt_disconnect(ssh, "Protocol error: no matching group found");
089d79
-		fatal("Protocol error: no matching group found");
089d79
-	}
089d79
+	ssh_gssapi_build_ctx(&kex->gss);
089d79
+	if (kex->gss == NULL)
089d79
+		fatal("Unable to allocate memory for gss context");
089d79
 
089d79
-	DH_get0_pqg(kex->dh, &dh_p, NULL, &dh_g);
089d79
-	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_GROUP)) != 0 ||
089d79
-	    (r = sshpkt_put_bignum2(ssh, dh_p)) != 0 ||
089d79
-	    (r = sshpkt_put_bignum2(ssh, dh_g)) != 0 ||
089d79
-	    (r = sshpkt_send(ssh)) != 0)
089d79
-		fatal("sshpkt failed: %s", ssh_err(r));
089d79
-
089d79
-	if ((r = ssh_packet_write_wait(ssh)) != 0)
089d79
-		fatal("ssh_packet_write_wait: %s", ssh_err(r));
089d79
-
089d79
-	/* Compute our exchange value in parallel with the client */
089d79
-	if ((r = dh_gen_key(kex->dh, kex->we_need * 8)) != 0)
089d79
-		goto out;
089d79
+	debug("Doing group exchange");
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_GROUPREQ, &input_kexgssgex_groupreq);
089d79
+	return 0;
089d79
+}
089d79
 
089d79
-	do {
089d79
-		debug("Wait SSH2_MSG_GSSAPI_INIT");
089d79
-		type = ssh_packet_read(ssh);
089d79
-		switch(type) {
089d79
-		case SSH2_MSG_KEXGSS_INIT:
089d79
-			if (dh_client_pub != NULL)
089d79
-				fatal("Received KEXGSS_INIT after initialising");
089d79
-			if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
089d79
-			        &recv_tok)) != 0 ||
089d79
-			    (r = sshpkt_get_bignum2(ssh, &dh_client_pub)) != 0 ||
089d79
-			    (r = sshpkt_get_end(ssh)) != 0)
089d79
-				fatal("sshpkt failed: %s", ssh_err(r));
089d79
+static inline void
089d79
+kexgssgex_accept_ctx(struct ssh *ssh,
089d79
+		     gss_buffer_desc *recv_tok,
089d79
+		     gss_buffer_desc *send_tok,
089d79
+		     OM_uint32 *ret_flags)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	int r;
089d79
 
089d79
-			/* Send SSH_MSG_KEXGSS_HOSTKEY here, if we want */
089d79
-			break;
089d79
-		case SSH2_MSG_KEXGSS_CONTINUE:
089d79
-			if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh,
089d79
-			        &recv_tok)) != 0 ||
089d79
-			    (r = sshpkt_get_end(ssh)) != 0)
089d79
-				fatal("sshpkt failed: %s", ssh_err(r));
089d79
-			break;
089d79
-		default:
089d79
-			sshpkt_disconnect(ssh,
089d79
-			    "Protocol error: didn't expect packet type %d",
089d79
-			    type);
089d79
-		}
089d79
+	gss->major = mm_ssh_gssapi_accept_ctx(gss, recv_tok, send_tok, ret_flags);
089d79
+	gss_release_buffer(&gss->minor, recv_tok);
089d79
 
089d79
-		maj_status = mm_ssh_gssapi_accept_ctx(ctxt, &recv_tok,
089d79
-		    &send_tok, &ret_flags);
089d79
+	if (gss->major != GSS_S_COMPLETE && send_tok->length == 0)
089d79
+		fatal("Zero length token output when incomplete");
089d79
 
089d79
-		gss_release_buffer(&min_status, &recv_tok);
089d79
+	if (gss->dh_client_pub == NULL)
089d79
+		fatal("No client public key");
089d79
 
089d79
-		if (maj_status != GSS_S_COMPLETE && send_tok.length == 0)
089d79
-			fatal("Zero length token output when incomplete");
089d79
+	if (gss->major & GSS_S_CONTINUE_NEEDED) {
089d79
+		debug("Sending GSSAPI_CONTINUE");
089d79
+		if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
+		    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0 ||
089d79
+		    (r = sshpkt_send(ssh)) != 0)
089d79
+			fatal("sshpkt failed: %s", ssh_err(r));
089d79
+		gss_release_buffer(&gss->minor, send_tok);
089d79
+	}
089d79
+}
089d79
 
089d79
-		if (dh_client_pub == NULL)
089d79
-			fatal("No client public key");
089d79
+static inline int
089d79
+kexgssgex_final(struct ssh *ssh,
089d79
+		gss_buffer_desc *send_tok,
089d79
+		OM_uint32 *ret_flags)
089d79
+{
089d79
+	struct kex *kex = ssh->kex;
089d79
+	Gssctxt *gss = kex->gss;
089d79
+	gss_buffer_desc msg_tok;
089d79
+	u_char hash[SSH_DIGEST_MAX_LENGTH];
089d79
+	size_t hashlen;
089d79
+	const BIGNUM *pub_key, *dh_p, *dh_g;
089d79
+	struct sshbuf *shared_secret = NULL;
089d79
+	struct sshbuf *empty = NULL;
089d79
+	int r;
089d79
 
089d79
-		if (maj_status & GSS_S_CONTINUE_NEEDED) {
089d79
-			debug("Sending GSSAPI_CONTINUE");
089d79
-			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
-			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
089d79
-			    (r = sshpkt_send(ssh)) != 0)
089d79
-				fatal("sshpkt failed: %s", ssh_err(r));
089d79
-			gss_release_buffer(&min_status, &send_tok);
089d79
-		}
089d79
-	} while (maj_status & GSS_S_CONTINUE_NEEDED);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, NULL);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, NULL);
089d79
 
089d79
-	if (GSS_ERROR(maj_status)) {
089d79
-		if (send_tok.length > 0) {
089d79
+	if (GSS_ERROR(gss->major)) {
089d79
+		if (send_tok->length > 0) {
089d79
 			if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_CONTINUE)) != 0 ||
089d79
-			    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0 ||
089d79
+			    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0 ||
089d79
 			    (r = sshpkt_send(ssh)) != 0)
089d79
 				fatal("sshpkt failed: %s", ssh_err(r));
089d79
 		}
089d79
 		fatal("accept_ctx died");
089d79
 	}
089d79
 
089d79
-	if (!(ret_flags & GSS_C_MUTUAL_FLAG))
089d79
+	if (!(*ret_flags & GSS_C_MUTUAL_FLAG))
089d79
 		fatal("Mutual Authentication flag wasn't set");
089d79
 
089d79
-	if (!(ret_flags & GSS_C_INTEG_FLAG))
089d79
+	if (!(*ret_flags & GSS_C_INTEG_FLAG))
089d79
 		fatal("Integrity flag wasn't set");
089d79
 
089d79
 	/* calculate shared secret */
089d79
-	if ((shared_secret = sshbuf_new()) == NULL) {
089d79
+	shared_secret = sshbuf_new();
089d79
+	if (shared_secret == NULL) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 		r = SSH_ERR_ALLOC_FAIL;
089d79
 		goto out;
089d79
 	}
089d79
-	if ((r = kex_dh_compute_key(kex, dh_client_pub, shared_secret)) != 0)
089d79
+	if ((r = kex_dh_compute_key(kex, gss->dh_client_pub, shared_secret)) != 0) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 		goto out;
089d79
+	}
089d79
+
089d79
+	if ((empty = sshbuf_new()) == NULL) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
+		r = SSH_ERR_ALLOC_FAIL;
089d79
+		goto out;
089d79
+	}
089d79
 
089d79
 	DH_get0_key(kex->dh, &pub_key, NULL);
089d79
 	DH_get0_pqg(kex->dh, &dh_p, NULL, &dh_g);
089d79
 	hashlen = sizeof(hash);
089d79
-	if ((r = kexgex_hash(
089d79
-	    kex->hash_alg,
089d79
-	    kex->client_version,
089d79
-	    kex->server_version,
089d79
-	    kex->peer,
089d79
-	    kex->my,
089d79
-	    empty,
089d79
-	    cmin, nbits, cmax,
089d79
-	    dh_p, dh_g,
089d79
-	    dh_client_pub,
089d79
-	    pub_key,
089d79
-	    sshbuf_ptr(shared_secret), sshbuf_len(shared_secret),
089d79
-	    hash, &hashlen)) != 0)
089d79
+	r = kexgex_hash(kex->hash_alg, kex->client_version, kex->server_version,
089d79
+			kex->peer, kex->my, empty, kex->min, kex->nbits, kex->max, dh_p, dh_g,
089d79
+			gss->dh_client_pub, pub_key, sshbuf_ptr(shared_secret),
089d79
+			sshbuf_len(shared_secret), hash, &hashlen);
089d79
+	sshbuf_free(empty);
089d79
+	if (r != 0)
089d79
 		fatal("kexgex_hash failed: %s", ssh_err(r));
089d79
 
089d79
-	gssbuf.value = hash;
089d79
-	gssbuf.length = hashlen;
089d79
+	gss->buf.value = hash;
089d79
+	gss->buf.length = hashlen;
089d79
 
089d79
-	if (GSS_ERROR(mm_ssh_gssapi_sign(ctxt, &gssbuf, &msg_tok)))
089d79
+	if (GSS_ERROR(mm_ssh_gssapi_sign(gss, &gss->buf, &msg_tok)))
089d79
 		fatal("Couldn't get MIC");
089d79
 
089d79
 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_COMPLETE)) != 0 ||
089d79
@@ -443,24 +436,24 @@ kexgssgex_server(struct ssh *ssh)
089d79
 	    (r = sshpkt_put_string(ssh, msg_tok.value, msg_tok.length)) != 0)
089d79
 		fatal("sshpkt failed: %s", ssh_err(r));
089d79
 
089d79
-	if (send_tok.length != 0) {
089d79
+	if (send_tok->length != 0) {
089d79
 		if ((r = sshpkt_put_u8(ssh, 1)) != 0 || /* true */
089d79
-		    (r = sshpkt_put_string(ssh, send_tok.value, send_tok.length)) != 0)
089d79
+		    (r = sshpkt_put_string(ssh, send_tok->value, send_tok->length)) != 0)
089d79
 			fatal("sshpkt failed: %s", ssh_err(r));
089d79
 	} else {
089d79
 		if ((r = sshpkt_put_u8(ssh, 0)) != 0) /* false */
089d79
 			fatal("sshpkt failed: %s", ssh_err(r));
089d79
 	}
089d79
 	if ((r = sshpkt_send(ssh)) != 0)
089d79
-		fatal("sshpkt failed: %s", ssh_err(r));
089d79
+		fatal("sshpkt_send failed: %s", ssh_err(r));
089d79
 
089d79
-	gss_release_buffer(&min_status, &send_tok);
089d79
-	gss_release_buffer(&min_status, &msg_tok);
089d79
+	gss_release_buffer(&gss->minor, send_tok);
089d79
+	gss_release_buffer(&gss->minor, &msg_tok);
089d79
 
089d79
 	if (gss_kex_context == NULL)
089d79
-		gss_kex_context = ctxt;
089d79
+		gss_kex_context = gss;
089d79
 	else
089d79
-		ssh_gssapi_delete_ctx(&ctxt);
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
 
089d79
 	/* Finally derive the keys and send them */
089d79
 	if ((r = kex_derive_keys(ssh, hash, hashlen, shared_secret)) == 0)
089d79
@@ -470,13 +463,128 @@ kexgssgex_server(struct ssh *ssh)
089d79
 	 * just exchanged.  */
089d79
 	if (options.gss_store_rekey)
089d79
 		ssh_gssapi_rekey_creds();
089d79
+
089d79
+	if (kex->gss != NULL)
089d79
+		BN_clear_free(gss->dh_client_pub);
089d79
+
089d79
 out:
089d79
-	sshbuf_free(empty);
089d79
 	explicit_bzero(hash, sizeof(hash));
089d79
 	DH_free(kex->dh);
089d79
 	kex->dh = NULL;
089d79
-	BN_clear_free(dh_client_pub);
089d79
 	sshbuf_free(shared_secret);
089d79
 	return r;
089d79
 }
089d79
+
089d79
+static int
089d79
+input_kexgssgex_groupreq(int type,
089d79
+			 u_int32_t seq,
089d79
+			 struct ssh *ssh)
089d79
+{
089d79
+	struct kex *kex = ssh->kex;
089d79
+	const BIGNUM *dh_p, *dh_g;
089d79
+	int min = -1, max = -1, nbits = -1;
089d79
+	int cmin = -1, cmax = -1; /* client proposal */
089d79
+	int r;
089d79
+
089d79
+	/* 5. S generates an ephemeral key pair (do the allocations early) */
089d79
+
089d79
+	debug("SSH2_MSG_KEXGSS_GROUPREQ received");
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_GROUPREQ, NULL);
089d79
+
089d79
+	/* store client proposal to provide valid signature */
089d79
+	if ((r = sshpkt_get_u32(ssh, &cmin)) != 0 ||
089d79
+	    (r = sshpkt_get_u32(ssh, &nbits)) != 0 ||
089d79
+	    (r = sshpkt_get_u32(ssh, &cmax)) != 0 ||
089d79
+	    (r = sshpkt_get_end(ssh)) != 0)
089d79
+		fatal("sshpkt failed: %s", ssh_err(r));
089d79
+
089d79
+	kex->nbits = nbits;
089d79
+	kex->min = cmin;
089d79
+	kex->max = cmax;
089d79
+	min = MAX(DH_GRP_MIN, cmin);
089d79
+	max = MIN(DH_GRP_MAX, cmax);
089d79
+	nbits = MAXIMUM(DH_GRP_MIN, nbits);
089d79
+	nbits = MINIMUM(DH_GRP_MAX, nbits);
089d79
+
089d79
+	if (max < min || nbits < min || max < nbits)
089d79
+		fatal("GSS_GEX, bad parameters: %d !< %d !< %d", min, nbits, max);
089d79
+
089d79
+	kex->dh = mm_choose_dh(min, nbits, max);
089d79
+	if (kex->dh == NULL) {
089d79
+		sshpkt_disconnect(ssh, "Protocol error: no matching group found");
089d79
+		fatal("Protocol error: no matching group found");
089d79
+	}
089d79
+
089d79
+	DH_get0_pqg(kex->dh, &dh_p, NULL, &dh_g);
089d79
+	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXGSS_GROUP)) != 0 ||
089d79
+	    (r = sshpkt_put_bignum2(ssh, dh_p)) != 0 ||
089d79
+	    (r = sshpkt_put_bignum2(ssh, dh_g)) != 0 ||
089d79
+	    (r = sshpkt_send(ssh)) != 0)
089d79
+		fatal("sshpkt failed: %s", ssh_err(r));
089d79
+
089d79
+	if ((r = ssh_packet_write_wait(ssh)) != 0)
089d79
+		fatal("ssh_packet_write_wait: %s", ssh_err(r));
089d79
+
089d79
+	/* Compute our exchange value in parallel with the client */
089d79
+	if ((r = dh_gen_key(kex->dh, kex->we_need * 8)) != 0) {
089d79
+		ssh_gssapi_delete_ctx(&kex->gss);
089d79
+		DH_free(kex->dh);
089d79
+		kex->dh = NULL;
089d79
+		return r;
089d79
+	}
089d79
+
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, &input_kexgssgex_init);
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_CONTINUE, &input_kexgssgex_continue);
089d79
+	debug("Wait SSH2_MSG_KEXGSS_INIT");
089d79
+	return 0;
089d79
+}
089d79
+
089d79
+static int
089d79
+input_kexgssgex_init(int type,
089d79
+		     u_int32_t seq,
089d79
+		     struct ssh *ssh)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	gss_buffer_desc recv_tok, send_tok = GSS_C_EMPTY_BUFFER;
089d79
+	OM_uint32 ret_flags = 0;
089d79
+	int r;
089d79
+
089d79
+	debug("SSH2_MSG_KEXGSS_INIT received");
089d79
+	ssh_dispatch_set(ssh, SSH2_MSG_KEXGSS_INIT, NULL);
089d79
+
089d79
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
089d79
+	    (r = sshpkt_get_bignum2(ssh, &gss->dh_client_pub)) != 0 ||
089d79
+	    (r = sshpkt_get_end(ssh)) != 0)
089d79
+		fatal("sshpkt failed: %s", ssh_err(r));
089d79
+
089d79
+	/* Send SSH_MSG_KEXGSS_HOSTKEY here, if we want */
089d79
+
089d79
+	kexgssgex_accept_ctx(ssh, &recv_tok, &send_tok, &ret_flags);
089d79
+	if (gss->major & GSS_S_CONTINUE_NEEDED)
089d79
+		return 0;
089d79
+
089d79
+	return kexgssgex_final(ssh, &send_tok, &ret_flags);
089d79
+}
089d79
+
089d79
+static int
089d79
+input_kexgssgex_continue(int type,
089d79
+			 u_int32_t seq,
089d79
+			 struct ssh *ssh)
089d79
+{
089d79
+	Gssctxt *gss = ssh->kex->gss;
089d79
+	gss_buffer_desc recv_tok, send_tok = GSS_C_EMPTY_BUFFER;
089d79
+	OM_uint32 ret_flags = 0;
089d79
+	int r;
089d79
+
089d79
+	if ((r = ssh_gssapi_sshpkt_get_buffer_desc(ssh, &recv_tok)) != 0 ||
089d79
+	    (r = sshpkt_get_end(ssh)) != 0)
089d79
+		fatal("sshpkt failed: %s", ssh_err(r));
089d79
+
089d79
+	kexgssgex_accept_ctx(ssh, &recv_tok, &send_tok, &ret_flags);
089d79
+	if (gss->major & GSS_S_CONTINUE_NEEDED)
089d79
+		return 0;
089d79
+
089d79
+	return kexgssgex_final(ssh, &send_tok, &ret_flags);
089d79
+}
089d79
+
089d79
 #endif /* defined(GSSAPI) && defined(WITH_OPENSSL) */
089d79
diff --color -ruNp a/kex.h b/kex.h
089d79
--- a/kex.h	2024-05-16 15:49:43.986410812 +0200
089d79
+++ b/kex.h	2024-06-18 12:19:48.580347469 +0200
089d79
@@ -29,6 +29,10 @@
089d79
 #include "mac.h"
089d79
 #include "crypto_api.h"
089d79
 
089d79
+#ifdef GSSAPI
089d79
+# include "ssh-gss.h" /* Gssctxt */
089d79
+#endif
089d79
+
089d79
 #ifdef WITH_OPENSSL
089d79
 # include <openssl/bn.h>
089d79
 # include <openssl/dh.h>
089d79
@@ -177,6 +181,7 @@ struct kex {
089d79
 	int	hash_alg;
089d79
 	int	ec_nid;
089d79
 #ifdef GSSAPI
089d79
+	Gssctxt *gss;
089d79
 	int	gss_deleg_creds;
089d79
 	int	gss_trust_dns;
089d79
 	char    *gss_host;
089d79
diff --color -ruNp a/ssh-gss.h b/ssh-gss.h
089d79
--- a/ssh-gss.h	2024-05-16 15:49:43.837407972 +0200
089d79
+++ b/ssh-gss.h	2024-06-27 14:12:48.659866937 +0200
089d79
@@ -88,6 +88,8 @@ extern char **k5users_allowed_cmds;
089d79
 	KEX_GSS_GRP14_SHA1_ID "," \
089d79
 	KEX_GSS_GEX_SHA1_ID
089d79
 
089d79
+#include "digest.h" /* SSH_DIGEST_MAX_LENGTH */
089d79
+
089d79
 typedef struct {
089d79
 	char *filename;
089d79
 	char *envvar;
089d79
@@ -127,6 +129,16 @@ typedef struct {
089d79
 	gss_cred_id_t	creds; /* server */
089d79
 	gss_name_t	client; /* server */
089d79
 	gss_cred_id_t	client_creds; /* both */
089d79
+	struct sshbuf *shared_secret; /* both */
089d79
+	struct sshbuf *server_pubkey; /* server */
089d79
+	struct sshbuf *server_blob; /* client */
089d79
+	struct sshbuf *server_host_key_blob; /* client */
089d79
+	gss_buffer_desc msg_tok; /* client */
089d79
+	gss_buffer_desc buf; /* both */
089d79
+	u_char hash[SSH_DIGEST_MAX_LENGTH]; /* both */
089d79
+	size_t hashlen; /* both */
089d79
+	int first; /* client */
089d79
+	BIGNUM *dh_client_pub; /* server (gex) */
089d79
 } Gssctxt;
089d79
 
089d79
 extern ssh_gssapi_mech *supported_mechs[];