feat(content-moderation): support custom_comprehend
shreemaan-abhishek committed Feb 26, 2025
1 parent da06747 commit b18a02e
Showing 2 changed files with 319 additions and 9 deletions.
33 changes: 24 additions & 9 deletions apisix/plugins/ai-content-moderation.lua
@@ -55,12 +55,23 @@ local schema = {
         provider = {
             type = "object",
             properties = {
-                aws_comprehend = aws_comprehend_schema
+                aws_comprehend = aws_comprehend_schema,
+                custom_comprehend = {
+                    type = "object",
+                    properties = {
+                        endpoint = {
+                            type = "string",
+                            pattern = [[^https?://]]
+                        },
+                        region = {
+                            type = "string",
+                            default = "us-east-1",
+                        }
+                    },
+                    required = { "endpoint", }
+                },
             },
             maxProperties = 1,
-            -- ensure only one provider can be configured while implementing support for
-            -- other providers
-            required = { "aws_comprehend" }
         },
         moderation_categories = {
             type = "object",
@@ -120,14 +131,18 @@ function _M.rewrite(conf, ctx)
     local provider = conf.provider[next(conf.provider)]

     local credentials = aws_instance:Credentials({
-        accessKeyId = provider.access_key_id,
-        secretAccessKey = provider.secret_access_key,
+        accessKeyId = provider.access_key_id or "",
+        secretAccessKey = provider.secret_access_key or "",
         sessionToken = provider.session_token,
     })

-    local default_endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com"
-    local scheme, host, port = unpack(http:parse_uri(provider.endpoint or default_endpoint))
-    local endpoint = scheme .. "://" .. host
+    local endpoint = provider.endpoint
+    if not endpoint then
+        endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com"
+    end
+
+    local scheme, host, port = unpack(http:parse_uri(endpoint))
+    endpoint = scheme .. "://" .. host
     aws_instance.config.endpoint = endpoint
     aws_instance.config.ssl_verify = provider.ssl_verify

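
A minimal standalone sketch of the fallback behaviour introduced above (resolve_endpoint is a hypothetical helper used only for illustration; the plugin inlines this logic in _M.rewrite):

-- Prefer an explicitly configured provider.endpoint (custom_comprehend),
-- otherwise build the regional AWS Comprehend URL (aws_comprehend).
local function resolve_endpoint(provider)
    local endpoint = provider.endpoint
    if not endpoint then
        endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com"
    end
    return endpoint
end

print(resolve_endpoint({ endpoint = "http://localhost:2668", region = "us-east-1" }))
--> http://localhost:2668
print(resolve_endpoint({ region = "us-east-1" }))
--> https://comprehend.us-east-1.amazonaws.com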
295 changes: 295 additions & 0 deletions t/plugin/ai-content-moderation.custom.t
@@ -0,0 +1,295 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

use t::APISIX 'no_plan';

log_level("info");
repeat_each(1);
no_long_string();
no_root_location();


add_block_preprocessor(sub {
    my ($block) = @_;

    if (!defined $block->request) {
        $block->set_value("request", "GET /t");
    }

    my $http_config = $block->http_config // <<_EOC_;
        server {
            listen 2668;
            default_type 'application/json';
            location / {
                content_by_lua_block {
                    local json = require("cjson.safe")
                    local open = io.open
                    local f = open('t/assets/content-moderation-responses.json', "r")
                    local resp = f:read("*a")
                    f:close()
                    if not resp then
                        ngx.status = 503
                        ngx.say("[INTERNAL FAILURE]: failed to open response.json file")
                    end
                    local responses = json.decode(resp)
                    if not responses then
                        ngx.status = 503
                        ngx.say("[INTERNAL FAILURE]: failed to decode response.json contents")
                    end
                    if ngx.req.get_method() ~= "POST" then
                        ngx.status = 400
                        ngx.say("Unsupported request method: ", ngx.req.get_method())
                    end
                    ngx.req.read_body()
                    local body, err = ngx.req.get_body_data()
                    if not body then
                        ngx.status = 503
                        ngx.say("[INTERNAL FAILURE]: failed to get request body: ", err)
                    end
                    body, err = json.decode(body)
                    if not body then
                        ngx.status = 503
                        ngx.say("[INTERNAL FAILURE]: failed to decode request body: ", err)
                    end
                    local result = body.TextSegments[1].Text
                    local final_response = responses[result] or "invalid"
                    if final_response == "invalid" then
                        ngx.status = 500
                    end
                    ngx.say(json.encode(final_response))
                }
            }
        }
_EOC_

    $block->set_value("http_config", $http_config);
});

run_tests();

__DATA__
=== TEST 1: sanity
--- config
    location /t {
        content_by_lua_block {
            local t = require("lib.test_admin").test
            local code, body = t('/apisix/admin/routes/1',
                ngx.HTTP_PUT,
                [[{
                    "uri": "/echo",
                    "plugins": {
                        "ai-content-moderation": {
                            "provider": {
                                "custom_comprehend": {
                                    "endpoint": "http://localhost:2668"
                                }
                            },
                            "llm_provider": "openai"
                        }
                    },
                    "upstream": {
                        "type": "roundrobin",
                        "nodes": {
                            "127.0.0.1:1980": 1
                        }
                    }
                }]]
            )
            if code >= 300 then
                ngx.status = code
            end
            ngx.say(body)
        }
    }
--- response_body
passed
=== TEST 2: toxic request should fail
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"toxic"}]}
--- error_code: 400
--- response_body chomp
request body exceeds toxicity threshold
=== TEST 3: good request should pass
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]}
--- error_code: 200
=== TEST 4: profanity filter
--- config
    location /t {
        content_by_lua_block {
            local t = require("lib.test_admin").test
            local code, body = t('/apisix/admin/routes/1',
                ngx.HTTP_PUT,
                [[{
                    "uri": "/echo",
                    "plugins": {
                        "ai-content-moderation": {
                            "provider": {
                                "custom_comprehend": {
                                    "endpoint": "http://localhost:2668"
                                }
                            },
                            "moderation_categories": {
                                "PROFANITY": 0.5
                            },
                            "llm_provider": "openai"
                        }
                    },
                    "upstream": {
                        "type": "roundrobin",
                        "nodes": {
                            "127.0.0.1:1980": 1
                        }
                    }
                }]]
            )
            if code >= 300 then
                ngx.status = code
            end
            ngx.say(body)
        }
    }
--- response_body
passed
=== TEST 5: profane request should fail
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane"}]}
--- error_code: 400
--- response_body chomp
request body exceeds PROFANITY threshold
=== TEST 6: very profane request should also fail
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"very_profane"}]}
--- error_code: 400
--- response_body chomp
request body exceeds PROFANITY threshold
=== TEST 7: good_request should pass
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]}
--- error_code: 200
=== TEST 8: set profanity = 0.7 (allow profane request but disallow very_profane)
--- config
    location /t {
        content_by_lua_block {
            local t = require("lib.test_admin").test
            local code, body = t('/apisix/admin/routes/1',
                ngx.HTTP_PUT,
                [[{
                    "uri": "/echo",
                    "plugins": {
                        "ai-content-moderation": {
                            "provider": {
                                "custom_comprehend": {
                                    "endpoint": "http://localhost:2668"
                                }
                            },
                            "moderation_categories": {
                                "PROFANITY": 0.7
                            },
                            "llm_provider": "openai"
                        }
                    },
                    "upstream": {
                        "type": "roundrobin",
                        "nodes": {
                            "127.0.0.1:1980": 1
                        }
                    }
                }]]
            )
            if code >= 300 then
                ngx.status = code
            end
            ngx.say(body)
        }
    }
--- response_body
passed
=== TEST 9: profane request should pass profanity check but fail toxicity check
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane"}]}
--- error_code: 400
--- response_body chomp
request body exceeds toxicity threshold
=== TEST 10: profane_but_not_toxic request should pass
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane_but_not_toxic"}]}
--- error_code: 200
=== TEST 11: but very profane request will fail
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"very_profane"}]}
--- error_code: 400
--- response_body chomp
request body exceeds PROFANITY threshold
=== TEST 12: good_request should pass
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]}
--- error_code: 200
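
The mock server above keys its reply off body.TextSegments[1].Text, so t/assets/content-moderation-responses.json maps each probe word used in the tests (good_request, profane, very_profane, toxic, ...) to a canned Comprehend-style result. A hypothetical entry, shaped like an AWS Comprehend DetectToxicContent response and with scores chosen only to match the test expectations (it is not the actual asset file), could look like:

-- Hypothetical shape only; the real asset file may differ in detail.
-- PROFANITY = 0.6 fails a 0.5 threshold but passes 0.7 (TESTs 5 and 8),
-- while Toxicity = 0.6 still trips the default toxicity check (TEST 9).
local responses = {
    profane = {
        ResultList = {
            {
                Labels = { { Name = "PROFANITY", Score = 0.6 } },
                Toxicity = 0.6,
            },
        },
    },
}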
