Commit

refactor code and tests
shreemaan-abhishek committed Feb 28, 2025
1 parent a131c10 commit f26a3e0
Showing 4 changed files with 72 additions and 112 deletions.
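The heart of the refactor is a flatter plugin configuration: the nested `provider.aws_comprehend` wrapper and the separate, mandatory `llm_provider` field are replaced by a single top-level `comprehend` object. A rough before/after sketch of the config, written here as Lua tables for brevity (the tests below express the same shapes as JSON route configs); the sketch is illustrative only and not part of the commit:

-- Before this commit: nested provider table plus a mandatory llm_provider.
local old_conf = {
    provider = {
        aws_comprehend = {
            access_key_id = "access",
            secret_access_key = "ea+secret",
            region = "us-east-1",
            endpoint = "http://localhost:2668",
        },
    },
    llm_provider = "openai",
}

-- After this commit: a single flat comprehend block; llm_provider is dropped
-- because the plugin no longer parses OpenAI-style chat bodies (see below).
local new_conf = {
    comprehend = {
        access_key_id = "access",
        secret_access_key = "ea+secret",
        region = "us-east-1",
        endpoint = "http://localhost:2668",
    },
    moderation_categories = { PROFANITY = 0.5 },  -- optional, unchanged by this commit
}
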
77 changes: 27 additions & 50 deletions apisix/plugins/ai-aws-content-moderation.lua
@@ -28,39 +28,27 @@ local require = require
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST


local aws_comprehend_schema = {
type = "object",
properties = {
access_key_id = { type = "string" },
secret_access_key = { type = "string" },
region = { type = "string" },
endpoint = {
type = "string",
pattern = [[^https?://]]
},
ssl_verify = {
type = "boolean",
default = true
}
},
required = { "access_key_id", "secret_access_key", "region", }
}

local moderation_categories_pattern = "^(PROFANITY|HATE_SPEECH|INSULT|"..
"HARASSMENT_OR_ABUSE|SEXUAL|VIOLENCE_OR_THREAT)$"
local schema = {
type = "object",
properties = {
provider = {
comprehend = {
type = "object",
properties = {
aws_comprehend = aws_comprehend_schema
access_key_id = { type = "string" },
secret_access_key = { type = "string" },
region = { type = "string" },
endpoint = {
type = "string",
pattern = [[^https?://]]
},
ssl_verify = {
type = "boolean",
default = true
}
},
maxProperties = 1,
-- ensure only one provider can be configured while implementing support for
-- other providers
required = { "aws_comprehend" }
required = { "access_key_id", "secret_access_key", "region", }
},
moderation_categories = {
type = "object",
@@ -78,13 +66,9 @@ local schema = {
minimum = 0,
maximum = 1,
default = 0.5
},
llm_provider = {
type = "string",
enum = { "openai" },
}
},
required = { "provider", "llm_provider" },
required = { "comprehend" },
}


@@ -107,51 +91,44 @@ function _M.rewrite(conf, ctx)
return HTTP_INTERNAL_SERVER_ERROR, "failed to retrieve secrets from conf"
end

local body, err = core.request.get_json_request_body_table()
local body, err = core.request.get_body()
if not body then
return HTTP_BAD_REQUEST, err
end

local msgs = body.messages
if type(msgs) ~= "table" or #msgs < 1 then
return HTTP_BAD_REQUEST, "messages not found in request body"
end

local provider = conf.provider[next(conf.provider)]
local comprehend = conf.comprehend

local credentials = aws_instance:Credentials({
accessKeyId = provider.access_key_id,
secretAccessKey = provider.secret_access_key,
sessionToken = provider.session_token,
accessKeyId = comprehend.access_key_id,
secretAccessKey = comprehend.secret_access_key,
sessionToken = comprehend.session_token,
})

local default_endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com"
local scheme, host, port = unpack(http:parse_uri(provider.endpoint or default_endpoint))
local default_endpoint = "https://comprehend." .. comprehend.region .. ".amazonaws.com"
local scheme, host, port = unpack(http:parse_uri(comprehend.endpoint or default_endpoint))
local endpoint = scheme .. "://" .. host
aws_instance.config.endpoint = endpoint
aws_instance.config.ssl_verify = provider.ssl_verify
aws_instance.config.ssl_verify = comprehend.ssl_verify

local comprehend = aws_instance:Comprehend({
credentials = credentials,
endpoint = endpoint,
region = provider.region,
region = comprehend.region,
port = port,
})

local ai_module = require("apisix.plugins.ai." .. conf.llm_provider)
local create_request_text_segments = ai_module.create_request_text_segments

local text_segments = create_request_text_segments(msgs)
local res, err = comprehend:detectToxicContent({
LanguageCode = "en",
TextSegments = text_segments,
TextSegments = {{
Text = body
}},
})

if not res then
core.log.error("failed to send request to ", endpoint, ": ", err)
return HTTP_INTERNAL_SERVER_ERROR, err
end

core.log.warn("dibag: ", core.json.encode(res))
local results = res.body and res.body.ResultList
if type(results) ~= "table" or core.table.isempty(results) then
return HTTP_INTERNAL_SERVER_ERROR, "failed to get moderation results from response"
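Behaviourally, the rewrite phase now forwards the raw request body to Comprehend as a single text segment instead of extracting an OpenAI-style `messages` array through `apisix.plugins.ai.<llm_provider>`. A minimal, self-contained Lua sketch of the two approaches (illustrative only; the per-message mapping is an assumption about what `create_request_text_segments` produced, and the real plugin builds the call through the resty AWS/Comprehend bindings):

-- Old behaviour (approximation): each chat message became its own text segment.
local function segments_from_messages(messages)
    local segments = {}
    for _, msg in ipairs(messages) do
        segments[#segments + 1] = { Text = msg.content }
    end
    return segments
end

-- New behaviour: the raw request body is sent as a single segment, so the
-- llm_provider indirection is no longer needed.
local function segments_from_raw_body(body)
    return { { Text = body } }
end

print(#segments_from_messages({ { role = "user", content = "hello" } }))  -- 1
print(segments_from_raw_body("hello")[1].Text)                            -- hello
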
30 changes: 12 additions & 18 deletions t/plugin/ai-aws-content-moderation-secrets.t
@@ -115,15 +115,12 @@ Success! Data written to: kv/apisix/foo
"uri": "/echo",
"plugins": {
"ai-aws-content-moderation": {
"provider": {
"aws_comprehend": {
"access_key_id": "$secret://vault/test1/foo/access_key_id",
"secret_access_key": "$secret://vault/test1/foo/secret_access_key",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
}
},
"llm_provider": "openai"
"comprehend": {
"access_key_id": "$secret://vault/test1/foo/access_key_id",
"secret_access_key": "$secret://vault/test1/foo/secret_access_key",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
}
}
},
"upstream": {
@@ -170,15 +167,12 @@ POST /echo
"uri": "/echo",
"plugins": {
"ai-aws-content-moderation": {
"provider": {
"aws_comprehend": {
"access_key_id": "$env://ACCESS_KEY_ID",
"secret_access_key": "$env://SECRET_ACCESS_KEY",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
}
},
"llm_provider": "openai"
"comprehend": {
"access_key_id": "$env://ACCESS_KEY_ID",
"secret_access_key": "$env://SECRET_ACCESS_KEY",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
}
}
},
"upstream": {
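These tests exercise the same flattened `comprehend` block with `$secret://vault/...` and `$env://...` references, which APISIX resolves before the plugin sees the values (hence the "failed to retrieve secrets from conf" error path above). As a hypothetical illustration of the `$env://` case only — the real resolution goes through APISIX's secret module and also covers Vault — a small pure-Lua sketch:

-- Hypothetical sketch: replace $env://NAME values with environment variables,
-- falling back to the raw reference if the variable is unset.
local function resolve_env_refs(conf)
    local resolved = {}
    for key, value in pairs(conf) do
        local env_name = type(value) == "string" and value:match("^%$env://(.+)$")
        resolved[key] = env_name and os.getenv(env_name) or value
    end
    return resolved
end

-- Example: run with ACCESS_KEY_ID=access SECRET_ACCESS_KEY=secret in the environment.
local conf = resolve_env_refs({
    access_key_id = "$env://ACCESS_KEY_ID",
    secret_access_key = "$env://SECRET_ACCESS_KEY",
    region = "us-east-1",
})
print(conf.access_key_id, conf.region)
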
63 changes: 27 additions & 36 deletions t/plugin/ai-aws-content-moderation.t
@@ -102,15 +102,12 @@ __DATA__
"uri": "/echo",
"plugins": {
"ai-aws-content-moderation": {
"provider": {
"aws_comprehend": {
"access_key_id": "access",
"secret_access_key": "ea+secret",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
}
},
"llm_provider": "openai"
"comprehend": {
"access_key_id": "access",
"secret_access_key": "ea+secret",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
}
}
},
"upstream": {
@@ -136,7 +133,7 @@ passed
=== TEST 2: toxic request should fail
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"toxic"}]}
toxic
--- error_code: 400
--- response_body chomp
request body exceeds toxicity threshold
@@ -146,7 +143,7 @@ request body exceeds toxicity threshold
=== TEST 3: good request should pass
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]}
good_request
--- error_code: 200
@@ -162,18 +159,15 @@ POST /echo
"uri": "/echo",
"plugins": {
"ai-aws-content-moderation": {
"provider": {
"aws_comprehend": {
"access_key_id": "access",
"secret_access_key": "ea+secret",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
}
"comprehend": {
"access_key_id": "access",
"secret_access_key": "ea+secret",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
},
"moderation_categories": {
"PROFANITY": 0.5
},
"llm_provider": "openai"
}
}
},
"upstream": {
@@ -199,7 +193,7 @@ passed
=== TEST 5: profane request should fail
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane"}]}
profane
--- error_code: 400
--- response_body chomp
request body exceeds PROFANITY threshold
@@ -209,7 +203,7 @@ request body exceeds PROFANITY threshold
=== TEST 6: very profane request should also fail
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"very_profane"}]}
very_profane
--- error_code: 400
--- response_body chomp
request body exceeds PROFANITY threshold
@@ -219,7 +213,7 @@ request body exceeds PROFANITY threshold
=== TEST 7: good_request should pass
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]}
good_request
--- error_code: 200
@@ -235,18 +229,15 @@ POST /echo
"uri": "/echo",
"plugins": {
"ai-aws-content-moderation": {
"provider": {
"aws_comprehend": {
"access_key_id": "access",
"secret_access_key": "ea+secret",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
}
"comprehend": {
"access_key_id": "access",
"secret_access_key": "ea+secret",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
},
"moderation_categories": {
"PROFANITY": 0.7
},
"llm_provider": "openai"
}
}
},
"upstream": {
@@ -272,7 +263,7 @@ passed
=== TEST 9: profane request should pass profanity check but fail toxicity check
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane"}]}
profane
--- error_code: 400
--- response_body chomp
request body exceeds toxicity threshold
@@ -282,15 +273,15 @@ request body exceeds toxicity threshold
=== TEST 10: profane_but_not_toxic request should pass
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"profane_but_not_toxic"}]}
profane_but_not_toxic
--- error_code: 200
=== TEST 11: but very profane request will fail
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"very_profane"}]}
very_profane
--- error_code: 400
--- response_body chomp
request body exceeds PROFANITY threshold
@@ -300,5 +291,5 @@ request body exceeds PROFANITY threshold
=== TEST 12: good_request should pass
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"good_request"}]}
good_request
--- error_code: 200
14 changes: 6 additions & 8 deletions t/plugin/ai-aws-content-moderation2.t
@@ -46,13 +46,11 @@ __DATA__
"uri": "/echo",
"plugins": {
"ai-aws-content-moderation": {
"provider": {
"aws_comprehend": {
"access_key_id": "access",
"secret_access_key": "ea+secret",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
}
"comprehend": {
"access_key_id": "access",
"secret_access_key": "ea+secret",
"region": "us-east-1",
"endpoint": "http://localhost:2668"
},
"llm_provider": "openai"
}
@@ -80,7 +78,7 @@ passed
=== TEST 2: request should fail
--- request
POST /echo
{"model":"gpt-4o-mini","messages":[{"role":"user","content":"toxic"}]}
toxic
--- error_code: 500
--- response_body chomp
Comprehend:detectToxicContent() failed to connect to 'http://localhost:2668': connection refused
