From cb0479510fa16da3c09069ec3c3e1abaf06ba175 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BE=84=E6=BD=AD?= Date: Mon, 13 Nov 2023 11:03:46 +0800 Subject: [PATCH] add oauth2 plugin (#632) --- plugins/wasm-cpp/extensions/oauth/BUILD | 80 +++ plugins/wasm-cpp/extensions/oauth/README.md | 129 +++++ plugins/wasm-cpp/extensions/oauth/plugin.cc | 463 +++++++++++++++++ plugins/wasm-cpp/extensions/oauth/plugin.h | 105 ++++ .../wasm-cpp/extensions/oauth/plugin_test.cc | 478 ++++++++++++++++++ 5 files changed, 1255 insertions(+) create mode 100644 plugins/wasm-cpp/extensions/oauth/BUILD create mode 100644 plugins/wasm-cpp/extensions/oauth/README.md create mode 100644 plugins/wasm-cpp/extensions/oauth/plugin.cc create mode 100644 plugins/wasm-cpp/extensions/oauth/plugin.h create mode 100644 plugins/wasm-cpp/extensions/oauth/plugin_test.cc diff --git a/plugins/wasm-cpp/extensions/oauth/BUILD b/plugins/wasm-cpp/extensions/oauth/BUILD new file mode 100644 index 0000000000..f86eeab6e5 --- /dev/null +++ b/plugins/wasm-cpp/extensions/oauth/BUILD @@ -0,0 +1,80 @@ +# Copyright (c) 2022 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary") +load("//bazel:wasm.bzl", "declare_wasm_image_targets") + +proxy_wasm_cc_binary( + name = "oauth.wasm", + srcs = [ + "plugin.cc", + "plugin.h", + ], + deps = [ + "//common:random_util", + "@com_github_thalhammer_jwt_cpp//:lib", + "@com_github_mariusbancila_stduuid//:lib", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@boringssl//:ssl", + "//common:json_util", + "//common:http_util", + "//common:rule_util", + ], +) + +cc_library( + name = "oauth_lib", + srcs = [ + "plugin.cc", + ], + hdrs = [ + "plugin.h", + ], + copts = ["-DNULL_PLUGIN"], + deps = [ + "@com_github_thalhammer_jwt_cpp//:lib", + "@com_github_mariusbancila_stduuid//:lib", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@boringssl//:ssl", + "//common:json_util", + "@proxy_wasm_cpp_host//:lib", + "//common:http_util_nullvm", + "//common:rule_util_nullvm", + ], +) + +cc_test( + name = "oauth_test", + srcs = [ + "plugin_test.cc", + ], + copts = ["-DNULL_PLUGIN"], + deps = [ + ":oauth_lib", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@proxy_wasm_cpp_host//:lib", + ], +) + +declare_wasm_image_targets( + name = "oauth", + wasm_file = ":oauth.wasm", +) diff --git a/plugins/wasm-cpp/extensions/oauth/README.md b/plugins/wasm-cpp/extensions/oauth/README.md new file mode 100644 index 0000000000..3a14b794af --- /dev/null +++ b/plugins/wasm-cpp/extensions/oauth/README.md @@ -0,0 +1,129 @@ +# 功能说明 +`OAuth2`插件实现了基于JWT(JSON Web Tokens)进行OAuth2 Access Token签发的能力, 遵循[RFC9068](https://datatracker.ietf.org/doc/html/rfc9068)规范 + +# 插件配置说明 + +## 配置字段 + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ----------- | --------------- | ------------------------------------------- | ------ | ----------------------------------------------------------- | +| `consumers` | array of object | 必填 | - | 配置服务的调用者,用于对请求进行认证 | +| `_rules_` | array of object | 选填 | - | 配置特定路由或域名的访问权限列表,用于对请求进行鉴权 | +| `issuer` | string | 选填 | Higress-Gateway | 用于填充JWT中的issuer | +| `auth_path` | string | 选填 | /oauth2/token | 指定路径后缀用于签发Token,路由级配置时,要确保首先能匹配对应的路由 | +| `global_credentials` | bool | 选填 | ture | 是否开启全局凭证,即允许路由A下的auth_path签发的Token可以用于访问路由B | +| `auth_header_name` | string | 选填 | Authorization | 用于指定从哪个请求头获取JWT | +| `token_ttl` | number | 选填 | 7200 | token从签发后多久内有效,单位为秒 | +| `clock_skew_seconds` | number | 选填 | 60 | 校验JWT的exp和iat字段时允许的时钟偏移量,单位为秒 | +| `keep_token` | bool | 选填 | ture | 转发给后端时是否保留JWT | + +`consumers`中每一项的配置字段说明如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ----------------------- | ----------------- | -------- | ------------------------------------------------- | ------------------------ | +| `name` | string | 必填 | - | 配置该consumer的名称 | +| `client_id` | string | 必填 | - | OAuth2 client id | +| `client_secret` | string | 必填 | - | OAuth2 client secret | + +`_rules_` 中每一项的配置字段说明如下: + +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +| ---------------- | --------------- | ------------------------------------------------- | ------ | -------------------------------------------------- | +| `_match_route_` | array of string | 选填,`_match_route_`,`_match_domain_`中选填一项 | - | 配置要匹配的路由名称 | +| `_match_domain_` | array of string | 选填,`_match_route_`,`_match_domain_`中选填一项 | - | 配置要匹配的域名 | +| `allow` | array of string | 必填 | - | 对于符合匹配条件的请求,配置允许访问的consumer名称 | + +**注意:** +- 对于开启该配置的路由,如果路径后缀和`auth_path`匹配,则该路由到原目标服务,而是用于生成Token +- 如果关闭`global_credentials`,请确保启用此插件的路由不是精确匹配路由,此时若存在另一条前缀匹配路由,则可能导致预期外行为 +- 若不配置`_rules_`字段,则默认对当前网关实例的所有路由开启认证; +- 对于通过认证鉴权的请求,请求的header会被添加一个`X-Mse-Consumer`字段,用以标识调用者的名称。 + +## 配置示例 + +### 对特定路由或域名开启 + +以下配置将对网关特定路由或域名开启 Jwt Auth 认证和鉴权,注意如果一个JWT能匹配多个`jwks`,则按照配置顺序命中第一个匹配的`consumer` + +```yaml +consumers: +- name: consumer1 + client_id: 12345678-xxxx-xxxx-xxxx-xxxxxxxxxxxx + client_secret: abcdefgh-xxxx-xxxx-xxxx-xxxxxxxxxxxx +- name: consumer2 + client_id: 87654321-xxxx-xxxx-xxxx-xxxxxxxxxxxx + client_secret: hgfedcba-xxxx-xxxx-xxxx-xxxxxxxxxxxx +# 使用 _rules_ 字段进行细粒度规则配置 +_rules_: +# 规则一:按路由名称匹配生效 +- _match_route_: + - route-a + - route-b + allow: + - consumer1 +# 规则二:按域名匹配生效 +- _match_domain_: + - "*.example.com" + - test.com + allow: + - consumer2 +``` + +此例 `_match_route_` 中指定的 `route-a` 和 `route-b` 即在创建网关路由时填写的路由名称,当匹配到这两个路由时,将允许`name`为`consumer1`的调用者访问,其他调用者不允许访问; + +此例 `_match_domain_` 中指定的 `*.example.com` 和 `test.com` 用于匹配请求的域名,当发现域名匹配时,将允许`name`为`consumer2`的调用者访问,其他调用者不允许访问。 + +#### 使用 Client Credential 授权模式 + +**获取 AccessToken** + +```bash + +# 通过 GET 方法获取 + +curl 'http://test.com/oauth2/token?grant_type=client_credentials&client_id=12345678-xxxx-xxxx-xxxx-xxxxxxxxxxxx&client_secret=abcdefgh-xxxx-xxxx-xxxx-xxxxxxxxxxxx' + +# 通过 POST 方法获取 (需要先匹配到有真实目标服务的路由) + +curl 'http://test.com/oauth2/token' -H 'content-type: application/x-www-form-urlencoded' -d 'grant_type=client_credentials&client_id=12345678-xxxx-xxxx-xxxx-xxxxxxxxxxxx&client_secret=abcdefgh-xxxx-xxxx-xxxx-xxxxxxxxxxxx' + +# 获取响应中的 access_token 字段即可: +{ + "token_type": "bearer", + "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6ImFwcGxpY2F0aW9uXC9hdCtqd3QifQ.eyJhdWQiOiJkZWZhdWx0IiwiY2xpZW50X2lkIjoiMTIzNDU2NzgteHh4eC14eHh4LXh4eHgteHh4eHh4eHh4eHh4IiwiZXhwIjoxNjg3OTUxNDYzLCJpYXQiOjE2ODc5NDQyNjMsImlzcyI6IkhpZ3Jlc3MtR2F0ZXdheSIsImp0aSI6IjEwOTU5ZDFiLThkNjEtNGRlYy1iZWE3LTk0ODEwMzc1YjYzYyIsInN1YiI6ImNvbnN1bWVyMSJ9.NkT_rG3DcV9543vBQgneVqoGfIhVeOuUBwLJJ4Wycb0", + "expires_in": 7200 +} + +``` + +**使用 AccessToken 请求** + +```bash + +curl 'http://test.com' -H 'Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6ImFwcGxpY2F0aW9uXC9hdCtqd3QifQ.eyJhdWQiOiJkZWZhdWx0IiwiY2xpZW50X2lkIjoiMTIzNDU2NzgteHh4eC14eHh4LXh4eHgteHh4eHh4eHh4eHh4IiwiZXhwIjoxNjg3OTUxNDYzLCJpYXQiOjE2ODc5NDQyNjMsImlzcyI6IkhpZ3Jlc3MtR2F0ZXdheSIsImp0aSI6IjEwOTU5ZDFiLThkNjEtNGRlYy1iZWE3LTk0ODEwMzc1YjYzYyIsInN1YiI6ImNvbnN1bWVyMSJ9.NkT_rG3DcV9543vBQgneVqoGfIhVeOuUBwLJJ4Wycb0' + +``` +因为 test.com 仅授权了 consumer2,但这个 Access Token 是基于 consumer1 的 `client_id`,`client_secret` 获取的,因此将返回 `403 Access Denied` + + +### 网关实例级别开启 + +以下配置未指定`_rules_`字段,因此将对网关实例级别开启 OAuth2 认证 + +```yaml +consumers: +- name: consumer1 + client_id: 12345678-xxxx-xxxx-xxxx-xxxxxxxxxxxx + client_secret: abcdefgh-xxxx-xxxx-xxxx-xxxxxxxxxxxx +- name: consumer2 + client_id: 87654321-xxxx-xxxx-xxxx-xxxxxxxxxxxx + client_secret: hgfedcba-xxxx-xxxx-xxxx-xxxxxxxxxxxx +``` + +# 常见错误码说明 + +| HTTP 状态码 | 出错信息 | 原因说明 | +| ----------- | ---------------------- | -------------------------------------------------------------------------------- | +| 401 | Invalid Jwt token | 请求头未提供JWT, 或者JWT格式错误,或过期等原因 | +| 403 | Access Denied | 无权限访问当前路由 | + diff --git a/plugins/wasm-cpp/extensions/oauth/plugin.cc b/plugins/wasm-cpp/extensions/oauth/plugin.cc new file mode 100644 index 0000000000..b518d8d374 --- /dev/null +++ b/plugins/wasm-cpp/extensions/oauth/plugin.cc @@ -0,0 +1,463 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/oauth/plugin.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "common/common_util.h" +#include "common/http_util.h" +#include "common/json_util.h" +#include "uuid.h" + +using ::nlohmann::json; +using ::Wasm::Common::JsonArrayIterate; +using ::Wasm::Common::JsonGetField; +using ::Wasm::Common::JsonObjectIterate; +using ::Wasm::Common::JsonValueAs; + +#ifdef NULL_PLUGIN + +namespace proxy_wasm { +namespace null_plugin { +namespace oauth { + +PROXY_WASM_NULL_PLUGIN_REGISTRY + +#endif +namespace { +constexpr absl::string_view TokenResponseTemplate = R"( +{ + "token_type": "bearer", + "access_token": "%s", + "expires_in": %u +})"; +const std::string& DefaultAudience = "default"; +const std::string& TypeHeader = "application/at+jwt"; +const std::string& BearerPrefix = "Bearer "; +const std::string& ClientCredentialsGrant = "client_credentials"; +constexpr uint32_t MaximumUriLength = 256; +constexpr std::string_view kRcDetailOAuthPrefix = "oauth_access_denied"; +std::string generateRcDetails(std::string_view error_msg) { + // Replace space with underscore since RCDetails may be written to access log. + // Some log processors assume each log segment is separated by whitespace. + return absl::StrCat(kRcDetailOAuthPrefix, "{", + absl::StrJoin(absl::StrSplit(error_msg, ' '), "_"), "}"); +} +} // namespace +static RegisterContextFactory register_OAuth(CONTEXT_FACTORY(PluginContext), + ROOT_FACTORY(PluginRootContext)); + +#define JSON_FIND_FIELD(dict, field) \ + auto dict##_##field##_json = dict.find(#field); \ + if (dict##_##field##_json == dict.end()) { \ + LOG_WARN("can't find '" #field "' in " #dict); \ + return false; \ + } + +#define JSON_VALUE_AS(type, src, dst, err_msg) \ + auto dst##_v = JsonValueAs(src); \ + if (dst##_v.second != Wasm::Common::JsonParserResultDetail::OK || \ + !dst##_v.first) { \ + LOG_WARN(#err_msg); \ + return false; \ + } \ + auto& dst = dst##_v.first.value(); + +#define JSON_FIELD_VALUE_AS(type, dict, field) \ + JSON_VALUE_AS(type, dict##_##field##_json.value(), dict##_##field, \ + "'" #field "' field in " #dict "convert to " #type " failed") + +bool PluginRootContext::generateToken(const OAuthConfigRule& rule, + const std::string& route_name, + const absl::string_view& raw_params, + std::string* token, + std::string* err_msg) { + auto params = Wasm::Common::Http::parseParameters(raw_params, 0, true); + auto it = params.find("grant_type"); + if (it == params.end()) { + *err_msg = "grant_type is missing"; + return false; + } + if (it->second != ClientCredentialsGrant) { + *err_msg = absl::StrFormat("grant_type:%s is not support", it->second); + return false; + } + it = params.find("client_id"); + if (it == params.end()) { + *err_msg = "client_id is missing"; + return false; + } + auto c_it = rule.consumers.find(it->second); + if (c_it == rule.consumers.end()) { + *err_msg = "invalid client_id or client_secret"; + return false; + } + const auto& consumer = c_it->second; + it = params.find("client_secret"); + if (it == params.end()) { + *err_msg = "client_secret is missing"; + return false; + } + if (it->second != consumer.client_secret) { + *err_msg = "invalid client_id or client_secret"; + return false; + } + auto jwt = jwt::create(); + if (rule.global_credentials) { + jwt.set_audience(DefaultAudience); + } else { + jwt.set_audience(route_name); + } + it = params.find("scope"); + if (it != params.end()) { + jwt.set_payload_claim("scope", jwt::claim(it->second)); + } + std::random_device rd; + auto seed_data = std::array{}; + std::generate(std::begin(seed_data), std::end(seed_data), std::ref(rd)); + std::seed_seq seq(std::begin(seed_data), std::end(seed_data)); + std::mt19937 generator(seq); + uuids::uuid_random_generator gen{generator}; + std::error_code ec; + *token = jwt.set_issuer(rule.issuer) + .set_type(TypeHeader) + .set_subject(consumer.name) + .set_issued_at(std::chrono::system_clock::now()) + .set_expires_at(std::chrono::system_clock::now() + + std::chrono::seconds{rule.token_ttl}) + .set_payload_claim("client_id", jwt::claim(consumer.client_id)) + .set_id(uuids::to_string(gen())) + .sign(jwt::algorithm::hs256{consumer.client_secret}, ec); + if (ec) { + *err_msg = absl::StrCat("jwt sign failed: %s", ec.message()); + return false; + } + return true; +} + +bool PluginRootContext::parsePluginConfig(const json& conf, + OAuthConfigRule& rule) { + std::unordered_set name_set; + if (!JsonArrayIterate(conf, "consumers", [&](const json& consumer) -> bool { + Consumer c; + JSON_FIND_FIELD(consumer, name); + JSON_FIELD_VALUE_AS(std::string, consumer, name); + if (name_set.count(consumer_name) != 0) { + LOG_WARN("consumer already exists: " + consumer_name); + return false; + } + c.name = consumer_name; + JSON_FIND_FIELD(consumer, client_id); + JSON_FIELD_VALUE_AS(std::string, consumer, client_id); + c.client_id = consumer_client_id; + if (rule.consumers.find(c.client_id) != rule.consumers.end()) { + LOG_WARN("consumer client_id already exists: " + c.client_id); + return false; + } + JSON_FIND_FIELD(consumer, client_secret); + JSON_FIELD_VALUE_AS(std::string, consumer, client_secret); + c.client_secret = consumer_client_secret; + rule.consumers.emplace(c.client_id, std::move(c)); + name_set.insert(consumer_name); + return true; + })) { + LOG_WARN("failed to parse configuration for consumers."); + return false; + } + // if (rule.consumers.empty()) { + // LOG_INFO("at least one consumer has to be configured for a rule."); + // return false; + // } + auto conf_issuer_json = conf.find("issuer"); + if (conf_issuer_json != conf.end()) { + JSON_FIELD_VALUE_AS(std::string, conf, issuer); + rule.issuer = conf_issuer; + } + auto conf_auth_header_json = conf.find("auth_header"); + if (conf_auth_header_json != conf.end()) { + JSON_FIELD_VALUE_AS(std::string, conf, auth_header); + rule.auth_header_name = conf_auth_header; + } + auto conf_auth_path_json = conf.find("auth_path"); + if (conf_auth_path_json != conf.end()) { + JSON_FIELD_VALUE_AS(std::string, conf, auth_path); + if (conf_auth_path.empty()) { + conf_auth_path = "/"; + } else if (conf_auth_path[0] != '/') { + conf_auth_path = absl::StrCat("/", conf_auth_path); + } + rule.auth_path = conf_auth_path; + } + auto conf_global_credentials_json = conf.find("global_credentials"); + if (conf_global_credentials_json != conf.end()) { + JSON_FIELD_VALUE_AS(bool, conf, global_credentials); + rule.global_credentials = conf_global_credentials; + } + auto conf_token_ttl_json = conf.find("token_ttl"); + if (conf_token_ttl_json != conf.end()) { + JSON_FIELD_VALUE_AS(uint64_t, conf, token_ttl); + rule.token_ttl = conf_token_ttl; + } + auto conf_keep_token_json = conf.find("keep_token"); + if (conf_keep_token_json != conf.end()) { + JSON_FIELD_VALUE_AS(bool, conf, keep_token); + rule.keep_token = conf_keep_token; + } + auto conf_clock_skew_seconds_json = conf.find("clock_skew_seconds"); + if (conf_clock_skew_seconds_json != conf.end()) { + JSON_FIELD_VALUE_AS(uint64_t, conf, clock_skew_seconds); + rule.clock_skew = conf_clock_skew_seconds; + } + return true; +} + +#define CLAIM_CHECK(token, claim, type) \ + if (!token.has_payload_claim(#claim)) { \ + LOG_DEBUG("claim is missing: " #claim); \ + goto failed; \ + } \ + if (token.get_payload_claim(#claim).get_type() != type) { \ + LOG_DEBUG("claim is invalid: " #claim); \ + goto failed; \ + } + +bool PluginRootContext::checkPlugin( + const OAuthConfigRule& rule, + const std::optional>& allow_set, + const std::string& route_name) { + auto auth_header = getRequestHeader(rule.auth_header_name)->toString(); + bool verified = false; + std::string token_str; + { + size_t pos; + if (auth_header.empty()) { + LOG_DEBUG("auth header is empty"); + goto failed; + } + pos = auth_header.find(BearerPrefix); + if (pos == std::string::npos) { + LOG_DEBUG("auth header is not a bearer token"); + goto failed; + } + auto start = pos + BearerPrefix.size(); + token_str = + std::string{auth_header.c_str() + start, auth_header.size() - start}; + auto token = jwt::decode(token_str); + CLAIM_CHECK(token, client_id, jwt::json::type::string); + CLAIM_CHECK(token, iss, jwt::json::type::string); + CLAIM_CHECK(token, sub, jwt::json::type::string); + CLAIM_CHECK(token, aud, jwt::json::type::string); + CLAIM_CHECK(token, exp, jwt::json::type::integer); + CLAIM_CHECK(token, iat, jwt::json::type::integer); + auto client_id = token.get_payload_claim("client_id").as_string(); + auto it = rule.consumers.find(client_id); + if (it == rule.consumers.end()) { + LOG_DEBUG(absl::StrFormat("client_id not found:%s", client_id)); + goto failed; + } + auto consumer = it->second; + auto verifier = + jwt::verify() + .allow_algorithm(jwt::algorithm::hs256{consumer.client_secret}) + .with_issuer(rule.issuer) + .with_subject(consumer.name) + .with_type(TypeHeader) + .leeway(rule.clock_skew); + std::error_code ec; + verifier.verify(token, ec); + if (ec) { + LOG_INFO(absl::StrFormat("token verify failed, token:%s, reason:%s", + token_str, ec.message())); + goto failed; + } + verified = true; + if (allow_set && + allow_set.value().find(consumer.name) == allow_set.value().end()) { + LOG_DEBUG(absl::StrFormat("consumer:%s is not in route's:%s allow_set", + consumer.name, route_name)); + goto failed; + } + if (!rule.global_credentials) { + auto audience_json = token.get_payload_claim("aud"); + if (audience_json.get_type() != jwt::json::type::string) { + LOG_DEBUG(absl::StrFormat("invalid audience, token:%s", token_str)); + goto failed; + } + auto audience = audience_json.as_string(); + if (audience != route_name) { + LOG_DEBUG(absl::StrFormat("audience:%s not match this route:%s", + audience, route_name)); + goto failed; + } + } + if (!rule.keep_token) { + removeRequestHeader(rule.auth_header_name); + } + addRequestHeader("X-Mse-Consumer", consumer.name); + return true; + } +failed: + if (!verified) { + auto authn_value = absl::StrCat( + "Bearer realm=\"", + Wasm::Common::Http::buildOriginalUri(MaximumUriLength), "\""); + sendLocalResponse(401, kRcDetailOAuthPrefix, "Invalid Jwt token", + {{"WWW-Authenticate", authn_value}}); + } else { + sendLocalResponse(403, kRcDetailOAuthPrefix, "Access Denied", {}); + } + return false; +} + +bool PluginRootContext::onConfigure(size_t size) { + // Parse configuration JSON string. + if (size > 0 && !configure(size)) { + LOG_WARN("configuration has errors initialization will not continue."); + setInvalidConfig(); + return false; + } + return true; +} + +bool PluginRootContext::configure(size_t configuration_size) { + auto configuration_data = getBufferBytes(WasmBufferType::PluginConfiguration, + 0, configuration_size); + // Parse configuration JSON string. + auto result = ::Wasm::Common::JsonParse(configuration_data->view()); + if (!result) { + LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ", + configuration_data->view())); + return false; + } + if (!parseAuthRuleConfig(result.value())) { + LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ", + configuration_data->view())); + return false; + } + return true; +} + +FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) { + auto* rootCtx = rootContext(); + auto config = rootCtx->getMatchAuthConfig(); + if (!config.first) { + return FilterHeadersStatus::Continue; + } + config_ = config.first; + getValue({"route_name"}, &route_name_); + auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString(); + auto params_pos = path.find('?'); + size_t uri_end; + if (params_pos == std::string::npos) { + uri_end = path.size(); + } else { + uri_end = params_pos; + } + // Authorize request + if (absl::EndsWith({path.c_str(), uri_end}, + config_.value().get().auth_path)) { + std::string err_msg, token; + auto method = + getRequestHeader(Wasm::Common::Http::Header::Method)->toString(); + if (method == "GET") { + if (params_pos == std::string::npos) { + err_msg = "Authorize parameters are missing"; + goto done; + } + params_pos++; + rootCtx->generateToken( + config_.value(), route_name_, + {path.c_str() + params_pos, path.size() - params_pos}, &token, + &err_msg); + goto done; + } + if (method == "POST") { + auto content_type = + getRequestHeader(Wasm::Common::Http::Header::ContentType)->toString(); + if (!absl::StrContains(absl::AsciiStrToLower(content_type), + "application/x-www-form-urlencoded")) { + err_msg = "Invalid content-type"; + goto done; + } + check_body_params_ = true; + } + done: + if (!err_msg.empty()) { + sendLocalResponse(400, generateRcDetails(err_msg), err_msg, {}); + return FilterHeadersStatus::StopIteration; + } + if (!token.empty()) { + sendLocalResponse(200, "", + absl::StrFormat(TokenResponseTemplate, token, + config_.value().get().token_ttl), + {{"Content-Type", "application/json"}}); + } + return FilterHeadersStatus::Continue; + } + return rootCtx->checkAuthRule( + [rootCtx, this](const auto& config, const auto& allow_set) { + return rootCtx->checkPlugin(config, allow_set, route_name_); + }) + ? FilterHeadersStatus::Continue + : FilterHeadersStatus::StopIteration; +} + +FilterDataStatus PluginContext::onRequestBody(size_t body_size, + bool end_stream) { + if (!check_body_params_) { + return FilterDataStatus::Continue; + } + body_total_size_ += body_size; + if (!end_stream) { + return FilterDataStatus::StopIterationAndBuffer; + } + auto* rootCtx = rootContext(); + auto body = + getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_); + LOG_DEBUG(absl::StrFormat("authorize request body: %s", body->toString())); + std::string token, err_msg; + if (rootCtx->generateToken(config_.value(), route_name_, body->view(), &token, + &err_msg)) { + sendLocalResponse(200, "", + absl::StrFormat(TokenResponseTemplate, token, + config_.value().get().token_ttl), + {{"Content-Type", "application/json"}}); + return FilterDataStatus::Continue; + } + sendLocalResponse(400, generateRcDetails(err_msg), err_msg, {}); + return FilterDataStatus::StopIterationNoBuffer; +} + +#ifdef NULL_PLUGIN + +} // namespace oauth +} // namespace null_plugin +} // namespace proxy_wasm + +#endif diff --git a/plugins/wasm-cpp/extensions/oauth/plugin.h b/plugins/wasm-cpp/extensions/oauth/plugin.h new file mode 100644 index 0000000000..0a1b4333fa --- /dev/null +++ b/plugins/wasm-cpp/extensions/oauth/plugin.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include + +#include "common/route_rule_matcher.h" +#include "jwt-cpp/jwt.h" +#define ASSERT(_X) assert(_X) + +#ifndef NULL_PLUGIN + +#include "proxy_wasm_intrinsics.h" + +#else + +#include "include/proxy-wasm/null_plugin.h" + +namespace proxy_wasm { +namespace null_plugin { +namespace oauth { + +#endif + +struct Consumer { + std::string name; + std::string client_id; + std::string client_secret; +}; + +struct OAuthConfigRule { + std::unordered_map consumers; + std::string issuer = "Higress-Gateway"; + std::string auth_header_name = "Authorization"; + std::string auth_path = "/oauth2/token"; + bool global_credentials = true; + uint64_t token_ttl = 7200; + bool keep_token = true; + uint64_t clock_skew = 60; +}; + +// PluginRootContext is the root context for all streams processed by the +// thread. It has the same lifetime as the worker thread and acts as target for +// interactions that outlives individual stream, e.g. timer, async calls. +class PluginRootContext : public RootContext, + public RouteRuleMatcher { + public: + PluginRootContext(uint32_t id, std::string_view root_id) + : RootContext(id, root_id) {} + ~PluginRootContext() {} + bool onConfigure(size_t) override; + bool checkPlugin(const OAuthConfigRule&, + const std::optional>&, + const std::string&); + bool configure(size_t); + bool generateToken(const OAuthConfigRule& rule, const std::string& route_name, + const absl::string_view& raw_params, std::string* token, + std::string* err_msg); + + private: + bool parsePluginConfig(const json&, OAuthConfigRule&) override; +}; + +// Per-stream context. +class PluginContext : public Context { + public: + explicit PluginContext(uint32_t id, RootContext* root) : Context(id, root) {} + FilterHeadersStatus onRequestHeaders(uint32_t, bool) override; + FilterDataStatus onRequestBody(size_t, bool) override; + + private: + inline PluginRootContext* rootContext() { + return dynamic_cast(this->root()); + } + + std::string route_name_; + std::optional> config_; + bool check_body_params_ = false; + size_t body_total_size_ = 0; +}; + +#ifdef NULL_PLUGIN + +} // namespace oauth +} // namespace null_plugin +} // namespace proxy_wasm + +#endif diff --git a/plugins/wasm-cpp/extensions/oauth/plugin_test.cc b/plugins/wasm-cpp/extensions/oauth/plugin_test.cc new file mode 100644 index 0000000000..aa564adb33 --- /dev/null +++ b/plugins/wasm-cpp/extensions/oauth/plugin_test.cc @@ -0,0 +1,478 @@ +// Copyright (c) 2022 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "extensions/oauth/plugin.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "include/proxy-wasm/context.h" +#include "include/proxy-wasm/null.h" + +namespace proxy_wasm { +namespace null_plugin { +namespace oauth { + +NullPluginRegistry* context_registry_; +RegisterNullVmPluginFactory register_oauth_plugin("oauth", []() { + return std::make_unique(oauth::context_registry_); +}); + +class MockContext : public proxy_wasm::ContextBase { + public: + MockContext(WasmBase* wasm) : ContextBase(wasm) {} + + MOCK_METHOD(BufferInterface*, getBuffer, (WasmBufferType)); + MOCK_METHOD(WasmResult, log, (uint32_t, std::string_view)); + MOCK_METHOD(WasmDataPtr, getBufferBytes, (WasmBufferType, size_t, size_t)); + MOCK_METHOD(WasmResult, getHeaderMapPairs, (WasmHeaderMapType, Pairs*)); + MOCK_METHOD(WasmResult, getHeaderMapValue, + (WasmHeaderMapType /* type */, std::string_view /* jwt */, + std::string_view* /*result */)); + MOCK_METHOD(WasmResult, addHeaderMapValue, + (WasmHeaderMapType /* type */, std::string_view /* jwt */, + std::string_view /* value */)); + MOCK_METHOD(WasmResult, sendLocalResponse, + (uint32_t /* response_code */, std::string_view /* body */, + Pairs /* additional_headers */, uint32_t /* grpc_status */, + std::string_view /* details */)); + MOCK_METHOD(uint64_t, getCurrentTimeNanoseconds, ()); + MOCK_METHOD(WasmResult, getProperty, (std::string_view, std::string*)); + MOCK_METHOD(WasmResult, httpCall, + (std::string_view, const Pairs&, std::string_view, const Pairs&, + int, uint32_t*)); +}; + +class OAuthTest : public ::testing::Test { + protected: + OAuthTest() { + // Initialize test VM + test_vm_ = createNullVm(); + wasm_base_ = std::make_unique( + std::move(test_vm_), "test-vm", "", "", + std::unordered_map{}, + AllowedCapabilitiesMap{}); + wasm_base_->load("oauth"); + wasm_base_->initialize(); + + // Initialize host side context + mock_context_ = std::make_unique(wasm_base_.get()); + current_context_ = mock_context_.get(); + + ON_CALL(*mock_context_, log(testing::_, testing::_)) + .WillByDefault([](uint32_t, std::string_view m) { + std::cerr << m << "\n"; + return WasmResult::Ok; + }); + + ON_CALL(*mock_context_, getHeaderMapValue(WasmHeaderMapType::RequestHeaders, + testing::_, testing::_)) + .WillByDefault([&](WasmHeaderMapType, std::string_view header, + std::string_view* result) { + if (header == ":authority") { + *result = authority_; + } + if (header == ":path") { + *result = path_; + } + if (header == ":method") { + *result = method_; + } + if (header == "Authorization") { + *result = jwt_header_; + } + if (header == "content-type") { + *result = content_type_; + } + if (header == "x-custom-header") { + *result = custom_header_; + } + return WasmResult::Ok; + }); + ON_CALL(*mock_context_, addHeaderMapValue(WasmHeaderMapType::RequestHeaders, + testing::_, testing::_)) + .WillByDefault([&](WasmHeaderMapType, std::string_view jwt, + std::string_view value) { return WasmResult::Ok; }); + + ON_CALL(*mock_context_, getCurrentTimeNanoseconds()).WillByDefault([&]() { + return current_time_; + }); + + ON_CALL(*mock_context_, getProperty(testing::_, testing::_)) + .WillByDefault([&](std::string_view path, std::string* result) { + *result = route_name_; + return WasmResult::Ok; + }); + + ON_CALL(*mock_context_, getBufferBytes(WasmBufferType::HttpCallResponseBody, + testing::_, testing::_)) + .WillByDefault([&](WasmBufferType, size_t, size_t) { + return std::make_unique(http_call_body_.data(), + http_call_body_.size()); + }); + + ON_CALL(*mock_context_, + getHeaderMapPairs(WasmHeaderMapType::HttpCallResponseHeaders, + testing::_)) + .WillByDefault([&](WasmHeaderMapType, Pairs* result) { + *result = http_call_headers_; + return WasmResult::Ok; + }); + + ON_CALL(*mock_context_, httpCall(testing::_, testing::_, testing::_, + testing::_, testing::_, testing::_)) + .WillByDefault([&](std::string_view, const Pairs&, std::string_view, + const Pairs&, int, uint32_t* token_ptr) { + root_context_->onHttpCallResponse( + *token_ptr, http_call_headers_.size(), http_call_body_.size(), 0); + return WasmResult::Ok; + }); + + // Initialize Wasm sandbox context + root_context_ = std::make_unique(0, ""); + context_ = std::make_unique(1, root_context_.get()); + } + ~OAuthTest() override {} + + std::unique_ptr wasm_base_; + std::unique_ptr test_vm_; + std::unique_ptr mock_context_; + + std::unique_ptr root_context_; + std::unique_ptr context_; + + std::string path_; + std::string method_; + std::string authority_; + std::string route_name_; + std::string jwt_header_; + std::string custom_header_; + std::string content_type_; + uint64_t current_time_; + + Pairs http_call_headers_; + std::string http_call_body_; +}; + +TEST_F(OAuthTest, generateToken) { + std::string configuration = R"( +{ + "consumers": [ + { + "name": "consumer1", + "client_id": "9515b564-0b1d-11ee-9c4c-00163e1250b5", + "client_secret": "9e55de56-0b1d-11ee-b8ec-00163e1250b5" + } + ], + "auth_path": "test/token" +})"; + BufferBase buffer; + buffer.set({configuration.data(), configuration.size()}); + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration)) + .WillOnce([&buffer](WasmBufferType) { return &buffer; }); + EXPECT_TRUE(root_context_->configure(configuration.size())); + path_ = "/abc/test/token"; + method_ = "GET"; + EXPECT_CALL(*mock_context_, + sendLocalResponse( + 400, std::string_view("Authorize parameters are missing"), + testing::_, testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + path_ = "/abc/test/token?"; + method_ = "GET"; + EXPECT_CALL(*mock_context_, + sendLocalResponse(400, std::string_view("grant_type is missing"), + testing::_, testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + path_ = + "/abc/test/" + "token?grant_type=client_credentials"; + method_ = "GET"; + EXPECT_CALL(*mock_context_, + sendLocalResponse(400, std::string_view("client_id is missing"), + testing::_, testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + path_ = + "/abc/test/" + "token?grant_type=client_credentials&client_id=9515b564-0b1d-11ee-9c4c-" + "00163e1250b5&client_secret=abcd"; + method_ = "GET"; + EXPECT_CALL(*mock_context_, + sendLocalResponse( + 400, std::string_view("invalid client_id or client_secret"), + testing::_, testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + path_ = + "/abc/test/" + "token?grant_type=client_credentials&client_id=9515b564-0b1d-11ee-9c4c-" + "00163e1250b5&client_secret=9e55de56-0b1d-11ee-b8ec-00163e1250b5"; + method_ = "GET"; + EXPECT_CALL(*mock_context_, sendLocalResponse(200, testing::_, testing::_, + testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); + + path_ = "/abc/test/token"; + method_ = "POST"; + content_type_ = "application/x-www-form-urlencoded; charset=utf8"; + std::string body = "grant_type=client_credentials&client_id=wrongid"; + BufferBase body_buffer; + body_buffer.set({body.data(), body.size()}); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::HttpRequestBody)) + .WillOnce([&body_buffer](WasmBufferType) { return &body_buffer; }); + EXPECT_CALL(*mock_context_, + sendLocalResponse( + 400, std::string_view("invalid client_id or client_secret"), + testing::_, testing::_, testing::_)); + EXPECT_EQ(context_->onRequestBody(body.size(), true), + FilterDataStatus::StopIterationNoBuffer); + + path_ = "/abc/test/token"; + method_ = "POST"; + content_type_ = "application/x-www-form-urlencoded; charset=utf8"; + body = + "grant_type=client_credentials&client_id=9515b564-0b1d-11ee-9c4c-" + "00163e1250b5&client_secret=9e55de56-0b1d-11ee-b8ec-00163e1250b5"; + body_buffer; + body_buffer.set({body.data(), body.size()}); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::HttpRequestBody)) + .WillOnce([&body_buffer](WasmBufferType) { return &body_buffer; }); + EXPECT_CALL(*mock_context_, sendLocalResponse(200, testing::_, testing::_, + testing::_, testing::_)); + EXPECT_EQ(context_->onRequestBody(body.size(), true), + FilterDataStatus::Continue); +} + +TEST_F(OAuthTest, invalidToken) { + std::string configuration = R"( +{ + "consumers": [ + { + "name": "consumer1", + "client_id": "9515b564-0b1d-11ee-9c4c-00163e1250b5", + "client_secret": "9e55de56-0b1d-11ee-b8ec-00163e1250b5" + } + ] +})"; + BufferBase buffer; + buffer.set({configuration.data(), configuration.size()}); + + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration)) + .WillOnce([&buffer](WasmBufferType) { return &buffer; }); + EXPECT_TRUE(root_context_->configure(configuration.size())); + jwt_header_ = R"(Bearer alksdjf)"; + EXPECT_CALL(*mock_context_, sendLocalResponse(401, testing::_, testing::_, + testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + jwt_header_ = R"(alksdjf)"; + EXPECT_CALL(*mock_context_, sendLocalResponse(401, testing::_, testing::_, + testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + jwt_header_ = + R"(Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6ImFwcGxpY2F0aW9uL2F0K2p3dCJ9.eyJhdWQiOiJkZWZhdWx0IiwiZXhwIjoxNjY1NjczODI5LCJpYXQiOjE2NjU2NzM4MTksImlzcyI6IkhpZ3Jlc3MtR2F0ZXdheSIsImp0aSI6IjEwOTU5ZDFiLThkNjEtNGRlYy1iZWE3LTk0ODEwMzc1YjYzYyIsInNjb3BlIjoidGVzdCIsInN1YiI6ImNvbnN1bWVyMiJ9.al7eoRdoNQlNx8HCqNesj7woiLOJmJLSqnZ)"; + EXPECT_CALL(*mock_context_, sendLocalResponse(401, testing::_, testing::_, + testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); +} + +TEST_F(OAuthTest, expire) { + std::string configuration = R"( +{ + "consumers": [ + { + "name": "consumer1", + "client_id": "9515b564-0b1d-11ee-9c4c-00163e1250b5", + "client_secret": "9e55de56-0b1d-11ee-b8ec-00163e1250b5" + } + ] +})"; + BufferBase buffer; + buffer.set({configuration.data(), configuration.size()}); + + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration)) + .WillOnce([&buffer](WasmBufferType) { return &buffer; }); + EXPECT_TRUE(root_context_->configure(configuration.size())); + jwt_header_ = + R"(Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6ImFwcGxpY2F0aW9uL2F0K2p3dCJ9.eyJhdWQiOiJ0ZXN0MiIsImNsaWVudF9pZCI6Ijk1MTViNTY0LTBiMWQtMTFlZS05YzRjLTAwMTYzZTEyNTBiNSIsImV4cCI6MTY2NTY3MzgyOSwiaWF0IjoxNjY1NjczODE5LCJpc3MiOiJIaWdyZXNzLUdhdGV3YXkiLCJqdGkiOiIxMDk1OWQxYi04ZDYxLTRkZWMtYmVhNy05NDgxMDM3NWI2M2MiLCJzY29wZSI6InRlc3QiLCJzdWIiOiJjb25zdW1lcjEifQ.LsZ6mlRxlaqWa0IAZgmGVuDgypRbctkTcOyoCxqLrHY)"; + route_name_ = "test2"; + EXPECT_CALL(*mock_context_, sendLocalResponse(401, testing::_, testing::_, + testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); +} + +TEST_F(OAuthTest, routeAuth) { + std::string configuration = R"( +{ + "consumers": [ + { + "name": "consumer1", + "client_id": "9515b564-0b1d-11ee-9c4c-00163e1250b5", + "client_secret": "9e55de56-0b1d-11ee-b8ec-00163e1250b5" + } + ], + "global_credentials": false, + "clock_skew_seconds": 3153600000 +})"; + BufferBase buffer; + buffer.set({configuration.data(), configuration.size()}); + + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration)) + .WillOnce([&buffer](WasmBufferType) { return &buffer; }); + EXPECT_TRUE(root_context_->configure(configuration.size())); + jwt_header_ = + R"(Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6ImFwcGxpY2F0aW9uL2F0K2p3dCJ9.eyJhdWQiOiJ0ZXN0MiIsImNsaWVudF9pZCI6Ijk1MTViNTY0LTBiMWQtMTFlZS05YzRjLTAwMTYzZTEyNTBiNSIsImV4cCI6MTY2NTY3MzgyOSwiaWF0IjoxNjY1NjczODE5LCJpc3MiOiJIaWdyZXNzLUdhdGV3YXkiLCJqdGkiOiIxMDk1OWQxYi04ZDYxLTRkZWMtYmVhNy05NDgxMDM3NWI2M2MiLCJzY29wZSI6InRlc3QiLCJzdWIiOiJjb25zdW1lcjEifQ.LsZ6mlRxlaqWa0IAZgmGVuDgypRbctkTcOyoCxqLrHY)"; + EXPECT_CALL(*mock_context_, sendLocalResponse(403, testing::_, testing::_, + testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + route_name_ = "test2"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); +} + +TEST_F(OAuthTest, globalAuth) { + std::string configuration = R"( +{ + "consumers": [ + { + "name": "consumer1", + "client_id": "9515b564-0b1d-11ee-9c4c-00163e1250b5", + "client_secret": "9e55de56-0b1d-11ee-b8ec-00163e1250b5" + } + ], + "clock_skew_seconds": 3153600000 +})"; + BufferBase buffer; + buffer.set({configuration.data(), configuration.size()}); + + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration)) + .WillOnce([&buffer](WasmBufferType) { return &buffer; }); + EXPECT_TRUE(root_context_->configure(configuration.size())); + jwt_header_ = + R"(Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6ImFwcGxpY2F0aW9uL2F0K2p3dCJ9.eyJhdWQiOiJ0ZXN0MiIsImNsaWVudF9pZCI6Ijk1MTViNTY0LTBiMWQtMTFlZS05YzRjLTAwMTYzZTEyNTBiNSIsImV4cCI6MTY2NTY3MzgyOSwiaWF0IjoxNjY1NjczODE5LCJpc3MiOiJIaWdyZXNzLUdhdGV3YXkiLCJqdGkiOiIxMDk1OWQxYi04ZDYxLTRkZWMtYmVhNy05NDgxMDM3NWI2M2MiLCJzY29wZSI6InRlc3QiLCJzdWIiOiJjb25zdW1lcjEifQ.LsZ6mlRxlaqWa0IAZgmGVuDgypRbctkTcOyoCxqLrHY)"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); +} + +TEST_F(OAuthTest, AuthZ) { + std::string configuration = R"( +{ + "consumers": [ + { + "name": "consumer1", + "client_id": "9515b564-0b1d-11ee-9c4c-00163e1250b5", + "client_secret": "9e55de56-0b1d-11ee-b8ec-00163e1250b5" + }, + { + "name": "consumer2", + "client_id": "d001d242-0bf0-11ee-97cb-00163e1250b5", + "client_secret": "d60bdafc-0bf0-11ee-afba-00163e1250b5" + } + ], + "clock_skew_seconds": 3153600000, + "global_credentials": true, + "_rules_": [ + { + "_match_route_": [ + "test1" + ], + "allow": [ + "consumer2" + ] + }, + { + "_match_route_": [ + "test2" + ], + "allow": [ + "consumer1" + ] + } + ] +})"; + BufferBase buffer; + buffer.set({configuration.data(), configuration.size()}); + + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration)) + .WillOnce([&buffer](WasmBufferType) { return &buffer; }); + EXPECT_TRUE(root_context_->configure(configuration.size())); + jwt_header_ = + R"(Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6ImFwcGxpY2F0aW9uL2F0K2p3dCJ9.eyJhdWQiOiJ0ZXN0MiIsImNsaWVudF9pZCI6Ijk1MTViNTY0LTBiMWQtMTFlZS05YzRjLTAwMTYzZTEyNTBiNSIsImV4cCI6MTY2NTY3MzgyOSwiaWF0IjoxNjY1NjczODE5LCJpc3MiOiJIaWdyZXNzLUdhdGV3YXkiLCJqdGkiOiIxMDk1OWQxYi04ZDYxLTRkZWMtYmVhNy05NDgxMDM3NWI2M2MiLCJzY29wZSI6InRlc3QiLCJzdWIiOiJjb25zdW1lcjEifQ.LsZ6mlRxlaqWa0IAZgmGVuDgypRbctkTcOyoCxqLrHY)"; + route_name_ = "test1"; + EXPECT_CALL(*mock_context_, sendLocalResponse(403, testing::_, testing::_, + testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + route_name_ = "test2"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); + jwt_header_ = + R"(Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6ImFwcGxpY2F0aW9uL2F0K2p3dCJ9.eyJhdWQiOiJkZWZhdWx0IiwiY2xpZW50X2lkIjoiZDAwMWQyNDItMGJmMC0xMWVlLTk3Y2ItMDAxNjNlMTI1MGI1IiwiZXhwIjoxNjY1NjczODI5LCJpYXQiOjE2NjU2NzM4MTksImlzcyI6IkhpZ3Jlc3MtR2F0ZXdheSIsImp0aSI6IjEwOTU5ZDFiLThkNjEtNGRlYy1iZWE3LTk0ODEwMzc1YjYzYyIsInNjb3BlIjoidGVzdCIsInN1YiI6ImNvbnN1bWVyMiJ9.whS5U7llGX2BNAX19mjyxiWXa7wVs0_ONVByKVR9ntM)"; + route_name_ = "test2"; + EXPECT_CALL(*mock_context_, sendLocalResponse(403, testing::_, testing::_, + testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + route_name_ = "test1"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); +} + +TEST_F(OAuthTest, EmptyConsumer) { + std::string configuration = R"( +{ + "consumers": [ + ], + "_rules_": [ + { + "_match_route_": [ + "test1" + ], + "allow": [ + ] + } + ] +})"; + BufferBase buffer; + buffer.set({configuration.data(), configuration.size()}); + + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration)) + .WillOnce([&buffer](WasmBufferType) { return &buffer; }); + EXPECT_TRUE(root_context_->configure(configuration.size())); + jwt_header_ = + R"(Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6ImFwcGxpY2F0aW9uL2F0K2p3dCJ9.eyJhdWQiOiJ0ZXN0MiIsImNsaWVudF9pZCI6Ijk1MTViNTY0LTBiMWQtMTFlZS05YzRjLTAwMTYzZTEyNTBiNSIsImV4cCI6MTY2NTY3MzgyOSwiaWF0IjoxNjY1NjczODE5LCJpc3MiOiJIaWdyZXNzLUdhdGV3YXkiLCJqdGkiOiIxMDk1OWQxYi04ZDYxLTRkZWMtYmVhNy05NDgxMDM3NWI2M2MiLCJzY29wZSI6InRlc3QiLCJzdWIiOiJjb25zdW1lcjEifQ.LsZ6mlRxlaqWa0IAZgmGVuDgypRbctkTcOyoCxqLrHY)"; + route_name_ = "test1"; + EXPECT_CALL(*mock_context_, sendLocalResponse(401, testing::_, testing::_, + testing::_, testing::_)); + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + route_name_ = "test2"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); +} + +} // namespace oauth +} // namespace null_plugin +} // namespace proxy_wasm