[openssl] master update

tmraz at fedoraproject.org tmraz at fedoraproject.org
Fri Feb 5 13:05:20 UTC 2021


The branch master has been updated
       via  bbde8566191e5851f4418cbb8acb0d50b16170d8 (commit)
       via  26372a4d44f0b4ef5423228b8bf975a5a7c814cb (commit)
      from  e60147fe74c202ef3ce5d36115252b7c3c504cd7 (commit)


- Log -----------------------------------------------------------------
commit bbde8566191e5851f4418cbb8acb0d50b16170d8
Author: Tomas Mraz <tomas at openssl.org>
Date:   Fri Jan 29 17:02:32 2021 +0100

    RSA: properly generate algorithm identifier for RSA-PSS signatures
    
    Fixes #13969
    
    - properly handle the mandatory RSA-PSS key parameters
    - improve parameter checking when setting the parameters
    - compute the algorithm id at the time it is requested so it
      reflects the actual parameters set
    - when generating keys do not override previously set parameters
      with defaults
    - tests added to the test_req recipe that should cover the PSS signature
      handling
    
    Reviewed-by: Richard Levitte <levitte at openssl.org>
    Reviewed-by: Shane Lontis <shane.lontis at oracle.com>
    (Merged from https://github.com/openssl/openssl/pull/13988)

commit 26372a4d44f0b4ef5423228b8bf975a5a7c814cb
Author: Tomas Mraz <tomas at openssl.org>
Date:   Wed Jan 27 10:22:41 2021 +0100

    provider-signature.pod: Fix formatting.
    
    Reviewed-by: Richard Levitte <levitte at openssl.org>
    Reviewed-by: Shane Lontis <shane.lontis at oracle.com>
    (Merged from https://github.com/openssl/openssl/pull/13988)

-----------------------------------------------------------------------

Summary of changes:
 crypto/rsa/rsa_ameth.c                        |   4 +-
 crypto/rsa/rsa_backend.c                      |   8 +-
 crypto/rsa/rsa_pss.c                          |   4 +-
 doc/man7/provider-signature.pod               |   8 +-
 include/crypto/rsa.h                          |   1 +
 providers/common/der/der_rsa.h.in             |   5 +-
 providers/common/der/der_rsa_key.c            |  32 +--
 providers/common/der/der_rsa_sig.c            |   2 +-
 providers/implementations/keymgmt/rsa_kmgmt.c |  17 +-
 providers/implementations/signature/rsa.c     | 295 +++++++++++++++++---------
 test/recipes/25-test_req.t                    |  54 ++++-
 test/testrsapssmandatory.pem                  |  29 +++
 12 files changed, 322 insertions(+), 137 deletions(-)
 create mode 100644 test/testrsapssmandatory.pem

diff --git a/crypto/rsa/rsa_ameth.c b/crypto/rsa/rsa_ameth.c
index 852facf577..e2dec1c98d 100644
--- a/crypto/rsa/rsa_ameth.c
+++ b/crypto/rsa/rsa_ameth.c
@@ -943,6 +943,7 @@ static int rsa_int_import_from(const OSSL_PARAM params[], void *vpctx,
     EVP_PKEY *pkey = EVP_PKEY_CTX_get0_pkey(pctx);
     RSA *rsa = ossl_rsa_new_with_ctx(pctx->libctx);
     RSA_PSS_PARAMS_30 rsa_pss_params = { 0, };
+    int pss_defaults_set = 0;
     int ok = 0;
 
     if (rsa == NULL) {
@@ -953,7 +954,8 @@ static int rsa_int_import_from(const OSSL_PARAM params[], void *vpctx,
     RSA_clear_flags(rsa, RSA_FLAG_TYPE_MASK);
     RSA_set_flags(rsa, rsa_type);
 
-    if (!ossl_rsa_pss_params_30_fromdata(&rsa_pss_params, params, pctx->libctx))
+    if (!ossl_rsa_pss_params_30_fromdata(&rsa_pss_params, &pss_defaults_set,
+                                         params, pctx->libctx))
         goto err;
 
     switch (rsa_type) {
diff --git a/crypto/rsa/rsa_backend.c b/crypto/rsa/rsa_backend.c
index 2f430b34d4..84f070a7ce 100644
--- a/crypto/rsa/rsa_backend.c
+++ b/crypto/rsa/rsa_backend.c
@@ -217,6 +217,7 @@ int ossl_rsa_pss_params_30_todata(const RSA_PSS_PARAMS_30 *pss,
 }
 
 int ossl_rsa_pss_params_30_fromdata(RSA_PSS_PARAMS_30 *pss_params,
+                                    int *defaults_set,
                                     const OSSL_PARAM params[],
                                     OSSL_LIB_CTX *libctx)
 {
@@ -249,10 +250,13 @@ int ossl_rsa_pss_params_30_fromdata(RSA_PSS_PARAMS_30 *pss_params,
      * restrictions, so we start by setting default values, and let each
      * parameter override their specific restriction data.
      */
-    if (param_md != NULL || param_mgf != NULL || param_mgf1md != NULL
-        || param_saltlen != NULL)
+    if (!*defaults_set
+        && (param_md != NULL || param_mgf != NULL || param_mgf1md != NULL
+            || param_saltlen != NULL)) {
         if (!ossl_rsa_pss_params_30_set_defaults(pss_params))
             return 0;
+        *defaults_set = 1;
+    }
 
     if (param_mgf != NULL) {
         int default_maskgenalg_nid = ossl_rsa_pss_params_30_maskgenalg(NULL);
diff --git a/crypto/rsa/rsa_pss.c b/crypto/rsa/rsa_pss.c
index 1b73cbb0f6..3a92ed04dd 100644
--- a/crypto/rsa/rsa_pss.c
+++ b/crypto/rsa/rsa_pss.c
@@ -113,7 +113,9 @@ int RSA_verify_PKCS1_PSS_mgf1(RSA *rsa, const unsigned char *mHash,
         goto err;
     }
     if (sLen != RSA_PSS_SALTLEN_AUTO && (maskedDBLen - i) != sLen) {
-        ERR_raise(ERR_LIB_RSA, RSA_R_SLEN_CHECK_FAILED);
+        ERR_raise_data(ERR_LIB_RSA, RSA_R_SLEN_CHECK_FAILED,
+                       "expected: %d retrieved: %d", sLen,
+                       maskedDBLen - i);
         goto err;
     }
     if (!EVP_DigestInit_ex(ctx, Hash, NULL)
diff --git a/doc/man7/provider-signature.pod b/doc/man7/provider-signature.pod
index bf10b6572c..222693854f 100644
--- a/doc/man7/provider-signature.pod
+++ b/doc/man7/provider-signature.pod
@@ -323,10 +323,10 @@ follows.
 
 =item "digest" (B<OSSL_SIGNATURE_PARAM_DIGEST>) <UTF8 string>
 
-Get or sets the name of the digest algorithm used for the input to the signature
-functions. It is required in order to calculate the "algorithm-id".
+Gets or sets the name of the digest algorithm used for the input to the
+signature functions. It is required in order to calculate the "algorithm-id".
 
-= item "properties" (B<OSSL_SIGNATURE_PARAM_PROPERTIES>) <UTF8 string>
+=item "properties" (B<OSSL_SIGNATURE_PARAM_PROPERTIES>) <UTF8 string>
 
 Sets the name of the property query associated with the "digest" algorithm.
 NULL is used if this optional value is not set.
@@ -337,7 +337,7 @@ Gets or sets the output size of the digest algorithm used for the input to the
 signature functions.
 The length of the "digest-size" parameter should not exceed that of a B<size_t>.
 
-= item "algorithm-id" (B<OSSL_SIGNATURE_PARAM_ALGORITHM_ID>) <octet string>
+=item "algorithm-id" (B<OSSL_SIGNATURE_PARAM_ALGORITHM_ID>) <octet string>
 
 Gets the DER encoded AlgorithmIdentifier that corresponds to the combination of
 signature algorithm and digest algorithm for the signature operation.
diff --git a/include/crypto/rsa.h b/include/crypto/rsa.h
index cb53b5dde6..599978dc3b 100644
--- a/include/crypto/rsa.h
+++ b/include/crypto/rsa.h
@@ -65,6 +65,7 @@ int ossl_rsa_fromdata(RSA *rsa, const OSSL_PARAM params[]);
 int ossl_rsa_pss_params_30_todata(const RSA_PSS_PARAMS_30 *pss,
                                   OSSL_PARAM_BLD *bld, OSSL_PARAM params[]);
 int ossl_rsa_pss_params_30_fromdata(RSA_PSS_PARAMS_30 *pss_params,
+                                    int *defaults_set,
                                     const OSSL_PARAM params[],
                                     OSSL_LIB_CTX *libctx);
 
diff --git a/providers/common/der/der_rsa.h.in b/providers/common/der/der_rsa.h.in
index 412d5bbe7f..733b9d60d6 100644
--- a/providers/common/der/der_rsa.h.in
+++ b/providers/common/der/der_rsa.h.in
@@ -23,6 +23,9 @@ int ossl_DER_w_RSASSA_PSS_params(WPACKET *pkt, int tag,
                                  const RSA_PSS_PARAMS_30 *pss);
 /* Subject Public Key Info */
 int ossl_DER_w_algorithmIdentifier_RSA(WPACKET *pkt, int tag, RSA *rsa);
+int ossl_DER_w_algorithmIdentifier_RSA_PSS(WPACKET *pkt, int tag,
+                                           int rsa_type,
+                                           const RSA_PSS_PARAMS_30 *pss);
 /* Signature */
 int ossl_DER_w_algorithmIdentifier_MDWithRSAEncryption(WPACKET *pkt, int tag,
-                                                       RSA *rsa, int mdnid);
+                                                       int mdnid);
diff --git a/providers/common/der/der_rsa_key.c b/providers/common/der/der_rsa_key.c
index 1cc5874290..70b8edb63b 100644
--- a/providers/common/der/der_rsa_key.c
+++ b/providers/common/der/der_rsa_key.c
@@ -52,18 +52,16 @@
  * around that, we make them non-static, and declare them an extra time to
  * avoid compilers complaining about definitions without declarations.
  */
-#if 0                            /* Currently unused */
 #define DER_AID_V_sha1Identifier                                        \
     DER_P_SEQUENCE|DER_F_CONSTRUCTED,                                   \
         DER_OID_SZ_id_sha1 + DER_SZ_NULL,                               \
         DER_OID_V_id_sha1,                                              \
         DER_V_NULL
-extern const unsigned char der_aid_sha1Identifier[];
-const unsigned char der_aid_sha1Identifier[] = {
+extern const unsigned char ossl_der_aid_sha1Identifier[];
+const unsigned char ossl_der_aid_sha1Identifier[] = {
     DER_AID_V_sha1Identifier
 };
-#define DER_AID_SZ_sha1Identifier sizeof(der_aid_sha1Identifier)
-#endif
+#define DER_AID_SZ_sha1Identifier sizeof(ossl_der_aid_sha1Identifier)
 
 #define DER_AID_V_sha224Identifier                                      \
     DER_P_SEQUENCE|DER_F_CONSTRUCTED,                                   \
@@ -277,8 +275,8 @@ static int DER_w_MaskGenAlgorithm(WPACKET *pkt, int tag,
 
 #define OAEP_PSS_MD_CASE(name, var)                                     \
     case NID_##name:                                                    \
-        var = ossl_der_oid_id_##name;                                   \
-        var##_sz = sizeof(ossl_der_oid_id_##name);                      \
+        var = ossl_der_aid_##name##Identifier;                          \
+        var##_sz = sizeof(ossl_der_aid_##name##Identifier);             \
         break;
 
 int ossl_DER_w_RSASSA_PSS_params(WPACKET *pkt, int tag,
@@ -356,14 +354,15 @@ int ossl_DER_w_RSASSA_PSS_params(WPACKET *pkt, int tag,
     var##_oid_sz = sizeof(ossl_der_oid_##name);                         \
     break;
 
-int ossl_DER_w_algorithmIdentifier_RSA(WPACKET *pkt, int tag, RSA *rsa)
+int ossl_DER_w_algorithmIdentifier_RSA_PSS(WPACKET *pkt, int tag,
+                                           int rsa_type,
+                                           const RSA_PSS_PARAMS_30 *pss)
 {
     int rsa_nid = NID_undef;
     const unsigned char *rsa_oid = NULL;
     size_t rsa_oid_sz = 0;
-    RSA_PSS_PARAMS_30 *pss_params = ossl_rsa_get0_pss_params_30(rsa);
 
-    switch (RSA_test_flags(rsa, RSA_FLAG_TYPE_MASK)) {
+    switch (rsa_type) {
     case RSA_FLAG_TYPE_RSA:
         RSA_CASE(rsaEncryption, rsa);
     case RSA_FLAG_TYPE_RSASSAPSS:
@@ -375,8 +374,17 @@ int ossl_DER_w_algorithmIdentifier_RSA(WPACKET *pkt, int tag, RSA *rsa)
 
     return ossl_DER_w_begin_sequence(pkt, tag)
         && (rsa_nid != NID_rsassaPss
-            || ossl_rsa_pss_params_30_is_unrestricted(pss_params)
-            || ossl_DER_w_RSASSA_PSS_params(pkt, -1, pss_params))
+            || ossl_rsa_pss_params_30_is_unrestricted(pss)
+            || ossl_DER_w_RSASSA_PSS_params(pkt, -1, pss))
         && ossl_DER_w_precompiled(pkt, -1, rsa_oid, rsa_oid_sz)
         && ossl_DER_w_end_sequence(pkt, tag);
 }
+
+int ossl_DER_w_algorithmIdentifier_RSA(WPACKET *pkt, int tag, RSA *rsa)
+{
+    int rsa_type = RSA_test_flags(rsa, RSA_FLAG_TYPE_MASK);
+    RSA_PSS_PARAMS_30 *pss_params = ossl_rsa_get0_pss_params_30(rsa);
+
+    return ossl_DER_w_algorithmIdentifier_RSA_PSS(pkt, tag, rsa_type,
+                                                  pss_params);
+}
diff --git a/providers/common/der/der_rsa_sig.c b/providers/common/der/der_rsa_sig.c
index 1ff9bf789b..94ed60b69f 100644
--- a/providers/common/der/der_rsa_sig.c
+++ b/providers/common/der/der_rsa_sig.c
@@ -29,7 +29,7 @@
         break;
 
 int ossl_DER_w_algorithmIdentifier_MDWithRSAEncryption(WPACKET *pkt, int tag,
-                                                       RSA *rsa, int mdnid)
+                                                       int mdnid)
 {
     const unsigned char *precompiled = NULL;
     size_t precompiled_sz = 0;
diff --git a/providers/implementations/keymgmt/rsa_kmgmt.c b/providers/implementations/keymgmt/rsa_kmgmt.c
index 9f783c56d8..64779ca6be 100644
--- a/providers/implementations/keymgmt/rsa_kmgmt.c
+++ b/providers/implementations/keymgmt/rsa_kmgmt.c
@@ -56,11 +56,12 @@ static OSSL_FUNC_keymgmt_query_operation_name_fn rsa_query_operation_name;
 DEFINE_STACK_OF(BIGNUM)
 DEFINE_SPECIAL_STACK_OF_CONST(BIGNUM_const, BIGNUM)
 
-static int pss_params_fromdata(RSA_PSS_PARAMS_30 *pss_params,
+static int pss_params_fromdata(RSA_PSS_PARAMS_30 *pss_params, int *defaults_set,
                                const OSSL_PARAM params[], int rsa_type,
                                OSSL_LIB_CTX *libctx)
 {
-    if (!ossl_rsa_pss_params_30_fromdata(pss_params, params, libctx))
+    if (!ossl_rsa_pss_params_30_fromdata(pss_params, defaults_set,
+                                         params, libctx))
         return 0;
 
     /* If not a PSS type RSA, sending us PSS parameters is wrong */
@@ -153,6 +154,7 @@ static int rsa_import(void *keydata, int selection, const OSSL_PARAM params[])
     RSA *rsa = keydata;
     int rsa_type;
     int ok = 1;
+    int pss_defaults_set = 0;
 
     if (!ossl_prov_is_running() || rsa == NULL)
         return 0;
@@ -165,8 +167,10 @@ static int rsa_import(void *keydata, int selection, const OSSL_PARAM params[])
     /* TODO(3.0) OAEP should bring on parameters as well */
 
     if ((selection & OSSL_KEYMGMT_SELECT_OTHER_PARAMETERS) != 0)
-        ok = ok && pss_params_fromdata(ossl_rsa_get0_pss_params_30(rsa), params,
-                                       rsa_type, ossl_rsa_get0_libctx(rsa));
+        ok = ok && pss_params_fromdata(ossl_rsa_get0_pss_params_30(rsa),
+                                       &pss_defaults_set,
+                                       params, rsa_type,
+                                       ossl_rsa_get0_libctx(rsa));
     if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) != 0)
         ok = ok && ossl_rsa_fromdata(rsa, params);
 
@@ -391,6 +395,7 @@ struct rsa_gen_ctx {
 
     /* For PSS */
     RSA_PSS_PARAMS_30 pss_params;
+    int pss_defaults_set;
 
     /* For generation callback */
     OSSL_CALLBACK *cb;
@@ -470,8 +475,8 @@ static int rsa_gen_set_params(void *genctx, const OSSL_PARAM params[])
         return 0;
     /* Only attempt to get PSS parameters when generating an RSA-PSS key */
     if (gctx->rsa_type == RSA_FLAG_TYPE_RSASSAPSS
-        && !pss_params_fromdata(&gctx->pss_params, params, gctx->rsa_type,
-                                gctx->libctx))
+        && !pss_params_fromdata(&gctx->pss_params, &gctx->pss_defaults_set, params,
+                                gctx->rsa_type, gctx->libctx))
         return 0;
 #if defined(FIPS_MODULE) && !defined(OPENSSL_NO_ACVP_TESTS)
     /* Any ACVP test related parameters are copied into a params[] */
diff --git a/providers/implementations/signature/rsa.c b/providers/implementations/signature/rsa.c
index 98ebf6b243..e61d8ab04e 100644
--- a/providers/implementations/signature/rsa.c
+++ b/providers/implementations/signature/rsa.c
@@ -13,6 +13,7 @@
  */
 #include "internal/deprecated.h"
 
+#include "e_os.h" /* strcasecmp */
 #include <string.h>
 #include <openssl/crypto.h>
 #include <openssl/core_dispatch.h>
@@ -86,11 +87,7 @@ typedef struct {
      * by their Final function.
      */
     unsigned int flag_allow_md : 1;
-
-    /* The Algorithm Identifier of the combined signature algorithm */
-    unsigned char aid_buf[128];
-    unsigned char *aid;
-    size_t  aid_len;
+    unsigned int mgf1_md_set : 1;
 
     /* main digest */
     EVP_MD *md;
@@ -102,6 +99,7 @@ typedef struct {
     int pad_mode;
     /* message digest for MGF1 */
     EVP_MD *mgf1_md;
+    int mgf1_mdnid;
     char mgf1_mdname[OSSL_MAX_NAME_SIZE]; /* Purely informational */
     /* PSS salt length */
     int saltlen;
@@ -113,6 +111,9 @@ typedef struct {
 
 } PROV_RSA_CTX;
 
+/* True if PSS parameters are restricted */
+#define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
+
 static size_t rsa_get_md_size(const PROV_RSA_CTX *prsactx)
 {
     if (prsactx->md != NULL)
@@ -120,24 +121,37 @@ static size_t rsa_get_md_size(const PROV_RSA_CTX *prsactx)
     return 0;
 }
 
-static int rsa_check_padding(int mdnid, int padding)
+static int rsa_check_padding(const PROV_RSA_CTX *prsactx,
+                             const char *mdname, const char *mgf1_mdname,
+                             int mdnid)
 {
-    if (padding == RSA_NO_PADDING) {
-        ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE);
-        return 0;
-    }
-
-    if (padding == RSA_X931_PADDING) {
-        if (RSA_X931_hash_id(mdnid) == -1) {
-            ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_X931_DIGEST);
+    switch(prsactx->pad_mode) {
+        case RSA_NO_PADDING:
+            ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_PADDING_MODE);
             return 0;
-        }
+        case RSA_X931_PADDING:
+            if (RSA_X931_hash_id(mdnid) == -1) {
+                ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_X931_DIGEST);
+                return 0;
+            }
+            break;
+        case RSA_PKCS1_PSS_PADDING:
+            if (rsa_pss_restricted(prsactx))
+                if ((mdname != NULL && !EVP_MD_is_a(prsactx->md, mdname))
+                    || (mgf1_mdname != NULL
+                        && !EVP_MD_is_a(prsactx->mgf1_md, mgf1_mdname))) {
+                    ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
+                    return 0;
+                }
+            break;
+        default:
+            break;
     }
 
     return 1;
 }
 
-static int rsa_check_parameters(PROV_RSA_CTX *prsactx)
+static int rsa_check_parameters(PROV_RSA_CTX *prsactx, int min_saltlen)
 {
     if (prsactx->pad_mode == RSA_PKCS1_PSS_PADDING) {
         int max_saltlen;
@@ -146,10 +160,11 @@ static int rsa_check_parameters(PROV_RSA_CTX *prsactx)
         max_saltlen = RSA_size(prsactx->rsa) - EVP_MD_size(prsactx->md);
         if ((RSA_bits(prsactx->rsa) & 0x7) == 1)
             max_saltlen--;
-        if (prsactx->min_saltlen < 0 || prsactx->min_saltlen > max_saltlen) {
+        if (min_saltlen < 0 || min_saltlen > max_saltlen) {
             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_SALT_LENGTH);
             return 0;
         }
+        prsactx->min_saltlen = min_saltlen;
     }
     return 1;
 }
@@ -176,8 +191,81 @@ static void *rsa_newctx(void *provctx, const char *propq)
     return prsactx;
 }
 
-/* True if PSS parameters are restricted */
-#define rsa_pss_restricted(prsactx) (prsactx->min_saltlen != -1)
+static int rsa_pss_compute_saltlen(PROV_RSA_CTX *ctx)
+{
+    int saltlen = ctx->saltlen;
+ 
+    if (saltlen == RSA_PSS_SALTLEN_DIGEST) {
+        saltlen = EVP_MD_size(ctx->md);
+    } else if (saltlen == RSA_PSS_SALTLEN_AUTO || saltlen == RSA_PSS_SALTLEN_MAX) {
+        saltlen = RSA_size(ctx->rsa) - EVP_MD_size(ctx->md) - 2;
+        if ((RSA_bits(ctx->rsa) & 0x7) == 1)
+            saltlen--;
+    }
+    if (saltlen < 0) {
+        ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
+        return -1;
+    } else if (saltlen < ctx->min_saltlen) {
+        ERR_raise_data(ERR_LIB_PROV, PROV_R_PSS_SALTLEN_TOO_SMALL,
+                       "minimum salt length: %d, actual salt length: %d",
+                       ctx->min_saltlen, saltlen);
+        return -1;
+    }
+    return saltlen;
+}
+
+static unsigned char *rsa_generate_signature_aid(PROV_RSA_CTX *ctx,
+                                                 unsigned char *aid_buf,
+                                                 size_t buf_len,
+                                                 size_t *aid_len)
+{
+    WPACKET pkt;
+    unsigned char *aid = NULL;
+    int saltlen;
+    RSA_PSS_PARAMS_30 pss_params;
+
+    if (!WPACKET_init_der(&pkt, aid_buf, buf_len)) {
+        ERR_raise(ERR_LIB_PROV, ERR_R_MALLOC_FAILURE);
+        return NULL;
+    }
+
+    switch(ctx->pad_mode) {
+        case RSA_PKCS1_PADDING:
+            if (!ossl_DER_w_algorithmIdentifier_MDWithRSAEncryption(&pkt, -1,
+                                                                    ctx->mdnid)) {
+                ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
+                goto cleanup;
+            }
+            break;
+        case RSA_PKCS1_PSS_PADDING:
+            saltlen = rsa_pss_compute_saltlen(ctx);
+            if (saltlen < 0)
+                goto cleanup;
+            if (!ossl_rsa_pss_params_30_set_defaults(&pss_params)
+                || !ossl_rsa_pss_params_30_set_hashalg(&pss_params, ctx->mdnid)
+                || !ossl_rsa_pss_params_30_set_maskgenhashalg(&pss_params,
+                                                              ctx->mgf1_mdnid)
+                || !ossl_rsa_pss_params_30_set_saltlen(&pss_params, saltlen)
+                || !ossl_DER_w_algorithmIdentifier_RSA_PSS(&pkt, -1,
+                                                           RSA_FLAG_TYPE_RSASSAPSS,
+                                                           &pss_params)) {
+                ERR_raise(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR);
+                goto cleanup;
+            }
+            break;
+        default:
+            ERR_raise_data(ERR_LIB_PROV, ERR_R_UNSUPPORTED,
+                           "Algorithm ID generation");
+            goto cleanup;
+    }
+    if (WPACKET_finish(&pkt)) {
+        WPACKET_get_total_written(&pkt, aid_len);
+        aid = WPACKET_get_curr(&pkt);
+    }
+ cleanup:
+    WPACKET_cleanup(&pkt);
+    return aid;
+}
 
 static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
                         const char *mdprops)
@@ -186,7 +274,6 @@ static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
         mdprops = ctx->propq;
 
     if (mdname != NULL) {
-        WPACKET pkt;
         EVP_MD *md = EVP_MD_fetch(ctx->libctx, mdname, mdprops);
         int sha1_allowed = (ctx->operation != EVP_PKEY_OP_SIGN);
         int md_nid = digest_rsa_sign_get_md_nid(md, sha1_allowed);
@@ -194,7 +281,7 @@ static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
 
         if (md == NULL
             || md_nid == NID_undef
-            || !rsa_check_padding(md_nid, ctx->pad_mode)
+            || !rsa_check_padding(ctx, mdname, NULL, md_nid)
             || mdname_len >= sizeof(ctx->mdname)) {
             if (md == NULL)
                 ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
@@ -209,27 +296,20 @@ static int rsa_setup_md(PROV_RSA_CTX *ctx, const char *mdname,
             return 0;
         }
 
+        if (!ctx->mgf1_md_set) {
+            if (!EVP_MD_up_ref(md)) {
+                EVP_MD_free(md);
+                return 0;
+            }
+            EVP_MD_free(ctx->mgf1_md);
+            ctx->mgf1_md = md;
+            ctx->mgf1_mdnid = md_nid;
+            OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
+        }
+
         EVP_MD_CTX_free(ctx->mdctx);
         EVP_MD_free(ctx->md);
 
-        /*
-         * TODO(3.0) Should we care about DER writing errors?
-         * All it really means is that for some reason, there's no
-         * AlgorithmIdentifier to be had (consider RSA with MD5-SHA1),
-         * but the operation itself is still valid, just as long as it's
-         * not used to construct anything that needs an AlgorithmIdentifier.
-         */
-        ctx->aid_len = 0;
-        if (WPACKET_init_der(&pkt, ctx->aid_buf, sizeof(ctx->aid_buf))
-            && ossl_DER_w_algorithmIdentifier_MDWithRSAEncryption(&pkt, -1,
-                                                                  ctx->rsa,
-                                                                  md_nid)
-            && WPACKET_finish(&pkt)) {
-            WPACKET_get_total_written(&pkt, &ctx->aid_len);
-            ctx->aid = WPACKET_get_curr(&pkt);
-        }
-        WPACKET_cleanup(&pkt);
-
         ctx->mdctx = NULL;
         ctx->md = md;
         ctx->mdnid = md_nid;
@@ -244,33 +324,37 @@ static int rsa_setup_mgf1_md(PROV_RSA_CTX *ctx, const char *mdname,
 {
     size_t len;
     EVP_MD *md = NULL;
+    int mdnid;
 
     if (mdprops == NULL)
         mdprops = ctx->propq;
 
-    if (ctx->mgf1_mdname[0] != '\0')
-        EVP_MD_free(ctx->mgf1_md);
-
     if ((md = EVP_MD_fetch(ctx->libctx, mdname, mdprops)) == NULL) {
         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
                        "%s could not be fetched", mdname);
         return 0;
     }
     /* The default for mgf1 is SHA1 - so allow SHA1 */
-    if (digest_rsa_sign_get_md_nid(md, 1) == NID_undef) {
-        ERR_raise_data(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED,
-                       "digest=%s", mdname);
+    if ((mdnid = digest_rsa_sign_get_md_nid(md, 1)) == NID_undef
+        || !rsa_check_padding(ctx, NULL, mdname, mdnid)) {
+        if (mdnid == NID_undef)
+            ERR_raise_data(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED,
+                           "digest=%s", mdname);
         EVP_MD_free(md);
         return 0;
     }
-    ctx->mgf1_md = md;
     len = OPENSSL_strlcpy(ctx->mgf1_mdname, mdname, sizeof(ctx->mgf1_mdname));
     if (len >= sizeof(ctx->mgf1_mdname)) {
         ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
                        "%s exceeds name buffer length", mdname);
+        EVP_MD_free(md);
         return 0;
     }
 
+    EVP_MD_free(ctx->mgf1_md);
+    ctx->mgf1_md = md;
+    ctx->mgf1_mdnid = mdnid;
+    ctx->mgf1_md_set = 1;
     return 1;
 }
 
@@ -317,7 +401,6 @@ static int rsa_signverify_init(void *vprsactx, void *vrsa, int operation)
 
                 mdname = ossl_rsa_oaeppss_nid2name(md_nid);
                 mgf1mdname = ossl_rsa_oaeppss_nid2name(mgf1md_nid);
-                prsactx->min_saltlen = min_saltlen;
 
                 if (mdname == NULL) {
                     ERR_raise_data(ERR_LIB_PROV, PROV_R_INVALID_DIGEST,
@@ -346,9 +429,10 @@ static int rsa_signverify_init(void *vprsactx, void *vrsa, int operation)
                 }
                 prsactx->saltlen = min_saltlen;
 
-                return rsa_setup_md(prsactx, mdname, prsactx->propq)
-                    && rsa_setup_mgf1_md(prsactx, mgf1mdname, prsactx->propq)
-                    && rsa_check_parameters(prsactx);
+                /* call rsa_setup_mgf1_md before rsa_setup_md to avoid duplication */
+                return rsa_setup_mgf1_md(prsactx, mgf1mdname, prsactx->propq)
+                    && rsa_setup_md(prsactx, mdname, prsactx->propq)
+                    && rsa_check_parameters(prsactx, min_saltlen);
             }
         }
 
@@ -727,8 +811,12 @@ static int rsa_digest_signverify_init(void *vprsactx, const char *mdname,
 
     if (prsactx != NULL)
         prsactx->flag_allow_md = 0;
-    if (!rsa_signverify_init(vprsactx, vrsa, operation)
-        || !rsa_setup_md(prsactx, mdname, NULL)) /* TODO RL */
+    if (!rsa_signverify_init(vprsactx, vrsa, operation))
+        return 0;
+    if (mdname != NULL
+        /* was rsa_setup_md already called in rsa_signverify_init()? */
+        && (mdname[0] == '\0' || strcasecmp(prsactx->mdname, mdname) != 0)
+        && !rsa_setup_md(prsactx, mdname, prsactx->propq))
         return 0;
 
     prsactx->mdctx = EVP_MD_CTX_new();
@@ -912,9 +1000,17 @@ static int rsa_get_ctx_params(void *vprsactx, OSSL_PARAM *params)
         return 0;
 
     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_ALGORITHM_ID);
-    if (p != NULL
-        && !OSSL_PARAM_set_octet_string(p, prsactx->aid, prsactx->aid_len))
-        return 0;
+    if (p != NULL) {
+        /* The Algorithm Identifier of the combined signature algorithm */
+        unsigned char aid_buf[128];
+        unsigned char *aid;
+        size_t  aid_len;
+
+        aid = rsa_generate_signature_aid(prsactx, aid_buf,
+                                         sizeof(aid_buf), &aid_len);
+        if (aid == NULL || !OSSL_PARAM_set_octet_string(p, aid, aid_len))
+            return 0;
+    }
 
     p = OSSL_PARAM_locate(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
     if (p != NULL)
@@ -1011,6 +1107,12 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
 {
     PROV_RSA_CTX *prsactx = (PROV_RSA_CTX *)vprsactx;
     const OSSL_PARAM *p;
+    int pad_mode = prsactx->pad_mode;
+    int saltlen = prsactx->saltlen;
+    char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = NULL;
+    char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = NULL;
+    char mgf1mdname[OSSL_MAX_NAME_SIZE] = "", *pmgf1mdname = NULL;
+    char mgf1mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmgf1mdprops = NULL;
 
     if (prsactx == NULL || params == NULL)
         return 0;
@@ -1020,37 +1122,24 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
     if (p != NULL && !prsactx->flag_allow_md)
         return 0;
     if (p != NULL) {
-        char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
-        char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
         const OSSL_PARAM *propsp =
             OSSL_PARAM_locate_const(params,
                                     OSSL_SIGNATURE_PARAM_PROPERTIES);
 
+        pmdname = mdname;
         if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
             return 0;
 
-        if (propsp == NULL)
-            pmdprops = NULL;
-        else if (!OSSL_PARAM_get_utf8_string(propsp,
-                                             &pmdprops, sizeof(mdprops)))
-            return 0;
-
-        if (rsa_pss_restricted(prsactx)) {
-            /* TODO(3.0) figure out what to do for prsactx->md == NULL */
-            if (prsactx->md == NULL || EVP_MD_is_a(prsactx->md, mdname))
-                return 1;
-            ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
-            return 0;
+        if (propsp != NULL) {
+            pmdprops = mdprops;
+            if (!OSSL_PARAM_get_utf8_string(propsp,
+                                            &pmdprops, sizeof(mdprops)))
+                return 0;
         }
-
-        /* non-PSS code follows */
-        if (!rsa_setup_md(prsactx, mdname, pmdprops))
-            return 0;
     }
 
     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PAD_MODE);
     if (p != NULL) {
-        int pad_mode = 0;
         const char *err_extra_text = NULL;
 
         switch (p->data_type) {
@@ -1092,10 +1181,6 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
                     "PSS padding only allowed for sign and verify operations";
                 goto bad_pad;
             }
-            if (prsactx->md == NULL
-                && !rsa_setup_md(prsactx, RSA_DEFAULT_DIGEST_NAME, NULL)) {
-                return 0;
-            }
             break;
         case RSA_PKCS1_PADDING:
             err_extra_text = "PKCS#1 padding not allowed with RSA-PSS";
@@ -1124,16 +1209,11 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
                                err_extra_text);
             return 0;
         }
-        if (!rsa_check_padding(prsactx->mdnid, pad_mode))
-            return 0;
-        prsactx->pad_mode = pad_mode;
     }
 
     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_PSS_SALTLEN);
     if (p != NULL) {
-        int saltlen;
-
-        if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
+        if (pad_mode != RSA_PKCS1_PSS_PADDING) {
             ERR_raise_data(ERR_LIB_PROV, PROV_R_NOT_SUPPORTED,
                            "PSS saltlen can only be specified if "
                            "PSS padding has been specified first");
@@ -1199,46 +1279,49 @@ static int rsa_set_ctx_params(void *vprsactx, const OSSL_PARAM params[])
                 }
             }
         }
-
-        prsactx->saltlen = saltlen;
     }
 
     p = OSSL_PARAM_locate_const(params, OSSL_SIGNATURE_PARAM_MGF1_DIGEST);
     if (p != NULL) {
-        char mdname[OSSL_MAX_NAME_SIZE] = "", *pmdname = mdname;
-        char mdprops[OSSL_MAX_PROPQUERY_SIZE] = "", *pmdprops = mdprops;
         const OSSL_PARAM *propsp =
             OSSL_PARAM_locate_const(params,
                                     OSSL_SIGNATURE_PARAM_MGF1_PROPERTIES);
 
-        if (!OSSL_PARAM_get_utf8_string(p, &pmdname, sizeof(mdname)))
+        pmgf1mdname = mgf1mdname;
+        if (!OSSL_PARAM_get_utf8_string(p, &pmgf1mdname, sizeof(mgf1mdname)))
             return 0;
 
-        if (propsp == NULL)
-            pmdprops = NULL;
-        else if (!OSSL_PARAM_get_utf8_string(propsp,
-                                             &pmdprops, sizeof(mdprops)))
-            return 0;
+        if (propsp != NULL) {
+            pmgf1mdprops = mgf1mdprops;
+            if (!OSSL_PARAM_get_utf8_string(propsp,
+                                            &pmgf1mdprops, sizeof(mgf1mdprops)))
+                return 0;
+        }
 
-        if (prsactx->pad_mode != RSA_PKCS1_PSS_PADDING) {
+        if (pad_mode != RSA_PKCS1_PSS_PADDING) {
             ERR_raise(ERR_LIB_PROV, PROV_R_INVALID_MGF1_MD);
             return  0;
         }
+    }
 
-        if (rsa_pss_restricted(prsactx)) {
-            /* TODO(3.0) figure out what to do for prsactx->mgf1_md == NULL */
-            if (prsactx->mgf1_md == NULL
-                || EVP_MD_is_a(prsactx->mgf1_md, mdname))
-                return 1;
-            ERR_raise(ERR_LIB_PROV, PROV_R_DIGEST_NOT_ALLOWED);
-            return 0;
-        }
+    prsactx->saltlen = saltlen;
+    prsactx->pad_mode = pad_mode;
+
+    if (prsactx->md == NULL && pmdname == NULL
+        && pad_mode == RSA_PKCS1_PSS_PADDING)
+        pmdname = RSA_DEFAULT_DIGEST_NAME;
 
-        /* non-PSS code follows */
-        if (!rsa_setup_mgf1_md(prsactx, mdname, pmdprops))
+    if (pmgf1mdname != NULL
+        && !rsa_setup_mgf1_md(prsactx, pmgf1mdname, pmgf1mdprops))
+        return 0;
+
+    if (pmdname != NULL) {
+        if (!rsa_setup_md(prsactx, pmdname, pmdprops))
+            return 0;
+    } else {
+        if (!rsa_check_padding(prsactx, NULL, NULL, prsactx->mdnid))
             return 0;
     }
-
     return 1;
 }
 
diff --git a/test/recipes/25-test_req.t b/test/recipes/25-test_req.t
index 3f0d9f59e7..ab6c6e681b 100644
--- a/test/recipes/25-test_req.t
+++ b/test/recipes/25-test_req.t
@@ -93,7 +93,7 @@ subtest "generating certificate requests with RSA" => sub {
 };
 
 subtest "generating certificate requests with RSA-PSS" => sub {
-    plan tests => 4;
+    plan tests => 12;
 
     SKIP: {
         skip "RSA is not supported by this OpenSSL build", 2
@@ -104,7 +104,6 @@ subtest "generating certificate requests with RSA-PSS" => sub {
                     "-new", "-out", "testreq-rsapss.pem", "-utf8",
                     "-key", srctop_file("test", "testrsapss.pem")])),
            "Generating request");
-
         ok(run(app(["openssl", "req",
                     "-config", srctop_file("test", "test.cnf"),
                     "-verify", "-in", "testreq-rsapss.pem", "-noout"])),
@@ -117,11 +116,60 @@ subtest "generating certificate requests with RSA-PSS" => sub {
                     "-sigopt", "rsa_pss_saltlen:-1",
                     "-key", srctop_file("test", "testrsapss.pem")])),
            "Generating request");
-
         ok(run(app(["openssl", "req",
                     "-config", srctop_file("test", "test.cnf"),
                     "-verify", "-in", "testreq-rsapss2.pem", "-noout"])),
            "Verifying signature on request");
+
+        ok(run(app(["openssl", "req",
+                    "-config", srctop_file("test", "test.cnf"),
+                    "-new", "-out", "testreq-rsapssmand.pem", "-utf8",
+                    "-sigopt", "rsa_padding_mode:pss",
+                    "-key", srctop_file("test", "testrsapssmandatory.pem")])),
+           "Generating request");
+        ok(run(app(["openssl", "req",
+                    "-config", srctop_file("test", "test.cnf"),
+                    "-verify", "-in", "testreq-rsapssmand.pem", "-noout"])),
+           "Verifying signature on request");
+
+        ok(run(app(["openssl", "req",
+                    "-config", srctop_file("test", "test.cnf"),
+                    "-new", "-out", "testreq-rsapssmand2.pem", "-utf8",
+                    "-sigopt", "rsa_pss_saltlen:100",
+                    "-key", srctop_file("test", "testrsapssmandatory.pem")])),
+           "Generating request");
+        ok(run(app(["openssl", "req",
+                    "-config", srctop_file("test", "test.cnf"),
+                    "-verify", "-in", "testreq-rsapssmand2.pem", "-noout"])),
+           "Verifying signature on request");
+
+        ok(!run(app(["openssl", "req",
+                     "-config", srctop_file("test", "test.cnf"),
+                     "-new", "-out", "testreq-rsapss3.pem", "-utf8",
+                     "-sigopt", "rsa_padding_mode:pkcs1",
+                     "-key", srctop_file("test", "testrsapss.pem")])),
+           "Generating request with expected failure");
+
+        ok(!run(app(["openssl", "req",
+                     "-config", srctop_file("test", "test.cnf"),
+                     "-new", "-out", "testreq-rsapss3.pem", "-utf8",
+                     "-sigopt", "rsa_pss_saltlen:-4",
+                     "-key", srctop_file("test", "testrsapss.pem")])),
+           "Generating request with expected failure");
+
+        ok(!run(app(["openssl", "req",
+                     "-config", srctop_file("test", "test.cnf"),
+                     "-new", "-out", "testreq-rsapssmand3.pem", "-utf8",
+                     "-sigopt", "rsa_pss_saltlen:10",
+                     "-key", srctop_file("test", "testrsapssmandatory.pem")])),
+           "Generating request with expected failure");
+
+        ok(!run(app(["openssl", "req",
+                     "-config", srctop_file("test", "test.cnf"),
+                     "-new", "-out", "testreq-rsapssmand3.pem", "-utf8",
+                     "-sha256",
+                     "-key", srctop_file("test", "testrsapssmandatory.pem")])),
+           "Generating request with expected failure");
     }
 };
 
diff --git a/test/testrsapssmandatory.pem b/test/testrsapssmandatory.pem
new file mode 100644
index 0000000000..d01ae82c88
--- /dev/null
+++ b/test/testrsapssmandatory.pem
@@ -0,0 +1,29 @@
+-----BEGIN PRIVATE KEY-----
+MIIE7gIBADA9BgkqhkiG9w0BAQowMKANMAsGCWCGSAFlAwQCA6EaMBgGCSqGSIb3
+DQEBCDALBglghkgBZQMEAgOiAwIBQASCBKgwggSkAgEAAoIBAQDdiLMYj8fgrXKB
+dEC704hcfmeJebCyaZbYHBE/1YthJOptbhisBbNk4onKMITO6hkYOoH12rNxqwY5
+d9J1Ray6SJETVHxYCKftJ1LlrUJGqpyRCAAff1LYjjGRyqcMzVItWffy2iCgKGud
+uUqs9Og3wsVxUeXfTSGnLo1UevVc1qTKZJuDRWD2EItuwnFt7GA89IgGx8/liLsg
+cdlnm81gGdDmNKxNGi3VeOaJqFWnP9CpL8iXybG7F32U9mgEdE+EYt8GhQfNLzjL
+j17xfLl5K0SMqL8q+phas6Md0OmTl3Xg8Tupdoo/okAoYGXrv/sHDiV1YBSkXD4i
+dbV42aUfAgMBAAECggEAEyEJrfZEYR85Avqh2FYksS/tCs7qNg2uC80opCVxWpsQ
+bxCRqtD3M5/oHABih2dpcVEkBbGzyv3klLPHBX9VseQwOsYR0pw0u+KoYtK6JVX4
+HQHe2Nlqsu5cU2V3VUCpducM5Ph21r2GxWDJlPO01ZPI7scOnWCQpln7tC7F3xU0
+jNQ0SnFZ6SO4FrrBxOMjnIFiNMexxZt0fU7khy/dGck9aN4DtmQENcQkGdXj5xRv
+lInh92mQ16yMCbEU8cslWaAwqRF/k/5QxoIwTXr8PqaWshH9TIAht0rvTilWpHPg
+zpW6Pog/wGzVat3NeU3vBDYIUayHc6n3gbfJZDNxmQKBgQD41lAkxNsA89mYY7S9
+5NkDJ1N1hKNwg+iEyCZJkjxUk+SymdO7U/iD27Hgn/XyXm4RC5aHYpXJSnuiOk7R
+Z1Az1jjqLzPxsP72sWLORzGq82smYrK+iV2rhozWNlfVyazDkBcRRz2bLSESzgvO
+JWD3K3pjvj8U9ZSUhz+zXo4sUwKBgQDj6TBTKGDb8Au8sUOC916GrIrUEq5SkMDT
+A4CiD4fmvbdNs90AhD/mmqBw/dP3TbCPNmP8tGMUT0BDev6BoRKYOt+1XGYXt2de
+P38teVU/ZUcAO2RGdMNSdWT5o9BCWQZ18qSoOR/QanckOnkhKCgU/wqSdIvBBRMQ
+5e4qdI0qhQKBgB2MJTwYfADi88WaoU2jLPmo48oik926bBPISHOX/73zScbDaVbn
+I61UmwyXMfczq1Iu1BMDa9HZHFEpJ07KO8XL/DoinMJoR/43Fgp0fbtU6DZIpfzm
+Bs9lTLfrAAcMyYz3QSX2FaSleTXobZJu8dKnwQKzBn6QorH4VWIRKkStAoGBAIYL
+M1nlaLpSf4S2OT/A376Ton9CkXaMHmy9JZ2rRsHmGPZBcB0Kq06k6PIrx8wuzEYe
+tkX9jjx2tBQ8NY3mPzp7ffF766vNOaWL8O+86e+EUHMJe1uY9vv7gaz1tNog5BTg
+5gjuuBBrXbFYFr/yj0hyDDTBCSU4J9OLeD1OGWzFAoGBAMGc9h8oLyA3rQEjIuVA
+CuzgvZxOFPbtODFPcL4EQgAKLiKS+oZK0jONfCHaQB1AhIq8/nT/4suw7tWqYoKp
+KGH/+8tKNodKZfZLjVp0k8gsehyMDz1002/RLMJyFRIJWa1BqEJs7v7XgWW3RcmC
+PWznhdpNx3BYDSao5Ibl7I5E
+-----END PRIVATE KEY-----


More information about the openssl-commits mailing list