diff --git a/apisix/plugins/ai-content-moderation.lua b/apisix/plugins/ai-content-moderation.lua index f4958f321ae3..0fbcdb24f847 100644 --- a/apisix/plugins/ai-content-moderation.lua +++ b/apisix/plugins/ai-content-moderation.lua @@ -55,12 +55,23 @@ local schema = { provider = { type = "object", properties = { - aws_comprehend = aws_comprehend_schema + aws_comprehend = aws_comprehend_schema, + custom_comprehend = { + type = "object", + properties = { + endpoint = { + type = "string", + pattern = [[^https?://]] + }, + region = { + type = "string", + default = "us-east-1", + } + }, + required = { "endpoint", } + }, }, maxProperties = 1, - -- ensure only one provider can be configured while implementing support for - -- other providers - required = { "aws_comprehend" } }, moderation_categories = { type = "object", @@ -120,14 +131,18 @@ function _M.rewrite(conf, ctx) local provider = conf.provider[next(conf.provider)] local credentials = aws_instance:Credentials({ - accessKeyId = provider.access_key_id, - secretAccessKey = provider.secret_access_key, + accessKeyId = provider.access_key_id or "", + secretAccessKey = provider.secret_access_key or "", sessionToken = provider.session_token, }) - local default_endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com" - local scheme, host, port = unpack(http:parse_uri(provider.endpoint or default_endpoint)) - local endpoint = scheme .. "://" .. host + local endpoint = provider.endpoint + if not endpoint then + endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com" + end + + local scheme, host, port = unpack(http:parse_uri(endpoint)) + endpoint = scheme .. "://" .. host aws_instance.config.endpoint = endpoint aws_instance.config.ssl_verify = provider.ssl_verify diff --git a/t/plugin/ai-content-moderation.custom.t b/t/plugin/ai-content-moderation.custom.t new file mode 100644 index 000000000000..b88050c7351c --- /dev/null +++ b/t/plugin/ai-content-moderation.custom.t @@ -0,0 +1,295 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $http_config = $block->http_config // <<_EOC_; + server { + listen 2668; + + default_type 'application/json'; + + location / { + content_by_lua_block { + local json = require("cjson.safe") + local open = io.open + local f = open('t/assets/content-moderation-responses.json', "r") + local resp = f:read("*a") + f:close() + + if not resp then + ngx.status(503) + ngx.say("[INTERNAL FAILURE]: failed to open response.json file") + end + + local responses = json.decode(resp) + if not responses then + ngx.status(503) + ngx.say("[INTERNAL FAILURE]: failed to decode response.json contents") + end + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + if not body then + ngx.status(503) + ngx.say("[INTERNAL FAILURE]: failed to get request body: ", err) + end + + body, err = json.decode(body) + if not body then + ngx.status(503) + ngx.say("[INTERNAL FAILURE]: failed to decoded request body: ", err) + end + local result = body.TextSegments[1].Text + local final_response = responses[result] or "invalid" + + if final_response == "invalid" then + ngx.status = 500 + end + ngx.say(json.encode(final_response)) + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: sanity +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "plugins": { + "ai-content-moderation": { + "provider": { + "custom_comprehend": { + "endpoint": "http://localhost:2668" + } + }, + "llm_provider": "openai" + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 2: toxic request should fail +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"toxic"}]} +--- error_code: 400 +--- response_body chomp +request body exceeds toxicity threshold + + + +=== TEST 3: good request should pass +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]} +--- error_code: 200 + + + +=== TEST 4: profanity filter +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "plugins": { + "ai-content-moderation": { + "provider": { + "custom_comprehend": { + "endpoint": "http://localhost:2668" + } + }, + "moderation_categories": { + "PROFANITY": 0.5 + }, + "llm_provider": "openai" + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 5: profane request should fail +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane"}]} +--- error_code: 400 +--- response_body chomp +request body exceeds PROFANITY threshold + + + +=== TEST 6: very profane request should also fail +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"very_profane"}]} +--- error_code: 400 +--- response_body chomp +request body exceeds PROFANITY threshold + + + +=== TEST 7: good_request should pass +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]} +--- error_code: 200 + + + +=== TEST 8: set profanity = 0.7 (allow profane request but disallow very_profane) +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/echo", + "plugins": { + "ai-content-moderation": { + "provider": { + "custom_comprehend": { + "endpoint": "http://localhost:2668" + } + }, + "moderation_categories": { + "PROFANITY": 0.7 + }, + "llm_provider": "openai" + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "127.0.0.1:1980": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 9: profane request should pass profanity check but fail toxicity check +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane"}]} +--- error_code: 400 +--- response_body chomp +request body exceeds toxicity threshold + + + +=== TEST 10: profane_but_not_toxic request should pass +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane_but_not_toxic"}]} +--- error_code: 200 + + + +=== TEST 11: but very profane request will fail +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"very_profane"}]} +--- error_code: 400 +--- response_body chomp +request body exceeds PROFANITY threshold + + + +=== TEST 12: good_request should pass +--- request +POST /echo +{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]} +--- error_code: 200