diff --git a/ext/openssl/ossl_pkey.c b/ext/openssl/ossl_pkey.c index 96e695b14..d2f9b5925 100644 --- a/ext/openssl/ossl_pkey.c +++ b/ext/openssl/ossl_pkey.c @@ -13,6 +13,10 @@ # include #endif +#if OSSL_OPENSSL_PREREQ(3, 0, 0) +# include +#endif + /* * Classes */ @@ -180,6 +184,150 @@ ossl_pkey_new_from_data(int argc, VALUE *argv, VALUE self) return ossl_pkey_new(pkey); } +struct ossl_params_build_args { + const OSSL_PARAM *settable; + VALUE hash, *memo; + OSSL_PARAM_BLD *param_bld; +}; + +static int +ossl_params_set_i(VALUE key, VALUE value, VALUE _args) +{ + struct ossl_params_build_args *args = (void *)_args; + const OSSL_PARAM *p; + int ret; + + if (SYMBOL_P(key)) + key = rb_sym2str(key); + p = OSSL_PARAM_locate_const(args->settable, StringValueCStr(key)); + if (p == NULL) + rb_raise(eOSSLError, "key not settable: %"PRIsVALUE, key); + + switch (p->data_type) { + case OSSL_PARAM_INTEGER: + case OSSL_PARAM_UNSIGNED_INTEGER: + ret = OSSL_PARAM_BLD_push_BN(args->param_bld, p->key, + GetBNPtr(value)); + break; + case OSSL_PARAM_UTF8_STRING: + ret = OSSL_PARAM_BLD_push_utf8_string(args->param_bld, p->key, + StringValueCStr(value), + RSTRING_LEN(value)); + break; + case OSSL_PARAM_OCTET_STRING: + ret = OSSL_PARAM_BLD_push_octet_string(args->param_bld, p->key, + StringValuePtr(value), + RSTRING_LEN(value)); + break; + case OSSL_PARAM_UTF8_PTR: + ret = OSSL_PARAM_BLD_push_utf8_ptr(args->param_bld, p->key, + StringValueCStr(value), + RSTRING_LEN(value)); + if (*args->memo == Qundef) + *args->memo = rb_ary_new(); + rb_ary_push(*args->memo, value); + break; + case OSSL_PARAM_OCTET_PTR: + ret = OSSL_PARAM_BLD_push_utf8_ptr(args->param_bld, p->key, + StringValuePtr(value), + RSTRING_LEN(value)); + if (*args->memo == Qundef) + *args->memo = rb_ary_new(); + rb_ary_push(*args->memo, value); + break; + default: + rb_raise(eOSSLError, "unsupported data type %d for key %s", + (int)p->data_type, p->key); + } + + if (ret <= 0) + ossl_raise(eOSSLError, "OSSL_PARAM_BLD_push*"); + + return ST_CONTINUE; +} + +static VALUE +ossl_params_build(VALUE _args) +{ + struct ossl_params_build_args *args = (void *)_args; + OSSL_PARAM *params; + + args->param_bld = OSSL_PARAM_BLD_new(); + if (args->param_bld == NULL) + ossl_raise(eOSSLError, "OSSL_PARAM_BLD_new"); + + rb_hash_foreach(args->hash, ossl_params_set_i, _args); + + params = OSSL_PARAM_BLD_to_param(args->param_bld); + if (!params) + ossl_raise(eOSSLError, "OSSL_PARAM_BLD_to_params"); + + return (VALUE)params; +} + +static OSSL_PARAM * +ossl_protect_params_build(const OSSL_PARAM *settable, VALUE hash, + VALUE *memo, int *state) +{ + struct ossl_params_build_args args; + OSSL_PARAM *params; + + args.settable = settable; + args.hash = hash; + args.memo = memo; + + params = (void *)rb_protect(ossl_params_build, (VALUE)&args, state); + OSSL_PARAM_BLD_free(args.param_bld); + return params; +} + +/* + * call-seq: + * OpenSSL::PKey.from_data(algo, selection, hash) -> pkey + */ +static VALUE +ossl_pkey_s_from_data(int argc, VALUE *argv, VALUE self) +{ + VALUE type, vselection, hash, memo; + const OSSL_PARAM *settable; + OSSL_PARAM *params; + EVP_PKEY_CTX *pctx; + EVP_PKEY *pkey = NULL; + int selection, state, ret; + + rb_scan_args(argc, argv, "3", &type, &vselection, &hash); + selection = NUM2INT(vselection); + StringValueCStr(type); + Check_Type(hash, T_HASH); + + pctx = EVP_PKEY_CTX_new_from_name(NULL, RSTRING_PTR(type), NULL); + if (!pctx) + ossl_raise(ePKeyError, "EVP_PKEY_CTX_new_from_name"); + settable = EVP_PKEY_fromdata_settable(pctx, selection); + if (!settable) { + EVP_PKEY_CTX_free(pctx); + ossl_raise(ePKeyError, "EVP_PKEY_fromdata_settable"); + } + + params = ossl_protect_params_build(settable, hash, &memo, &state); + if (state) { + EVP_PKEY_CTX_free(pctx); + rb_jump_tag(state); + } + + if (EVP_PKEY_fromdata_init(pctx) <= 0) { + EVP_PKEY_CTX_free(pctx); + ossl_raise(ePKeyError, "EVP_PKEY_fromdata_init"); + } + ret = EVP_PKEY_fromdata(pctx, &pkey, selection, params); + OSSL_PARAM_free(params); + EVP_PKEY_CTX_free(pctx); + + if (ret <= 0) + ossl_raise(ePKeyError, "EVP_PKEY_fromdata"); + return ossl_pkey_new(pkey); +} + static VALUE pkey_ctx_apply_options_i(RB_BLOCK_CALL_FUNC_ARGLIST(i, ctx_v)) { @@ -1687,6 +1835,7 @@ Init_ossl_pkey(void) */ cPKey = rb_define_class_under(mPKey, "PKey", rb_cObject); + rb_define_module_function(mPKey, "from_data", ossl_pkey_s_from_data, -1); rb_define_module_function(mPKey, "read", ossl_pkey_new_from_data, -1); rb_define_module_function(mPKey, "generate_parameters", ossl_pkey_s_generate_parameters, -1); rb_define_module_function(mPKey, "generate_key", ossl_pkey_s_generate_key, -1); @@ -1718,6 +1867,10 @@ Init_ossl_pkey(void) rb_define_method(cPKey, "encrypt", ossl_pkey_encrypt, -1); rb_define_method(cPKey, "decrypt", ossl_pkey_decrypt, -1); + rb_define_const(mPKey, "KEY_PARAMETERS", INT2NUM(EVP_PKEY_KEY_PARAMETERS)); + rb_define_const(mPKey, "PUBLIC_KEY", INT2NUM(EVP_PKEY_PUBLIC_KEY)); + rb_define_const(mPKey, "KEYPAIR", INT2NUM(EVP_PKEY_KEYPAIR)); + id_private_q = rb_intern("private?"); /* diff --git a/test/openssl/test_pkey_rsa.rb b/test/openssl/test_pkey_rsa.rb index 5030af165..abf9286f4 100644 --- a/test/openssl/test_pkey_rsa.rb +++ b/test/openssl/test_pkey_rsa.rb @@ -583,6 +583,12 @@ def test_to_data_public assert_equal nil, rsa.d end + def test_from_data + pkey = Fixtures.pkey("rsa2048") + + rsa1 = OpenSSL::PKey.from_data("RSA", OpenSSL::PKey::KEYPAIR, data) + end + private def assert_same_rsa(expected, key) check_component(expected, key, [:n, :e, :d, :p, :q, :dmp1, :dmq1, :iqmp])