diff --git a/internal/encryptor/encryptor.go b/internal/encryptor/encryptor.go index 74a7ee70b9..d150cfdec2 100644 --- a/internal/encryptor/encryptor.go +++ b/internal/encryptor/encryptor.go @@ -17,7 +17,6 @@ package encryptor import ( "crypto/rand" "fmt" - "io" "github.com/lf-edge/ekuiper/v2/internal/conf" "github.com/lf-edge/ekuiper/v2/internal/encryptor/aes" @@ -26,20 +25,15 @@ import ( func GetEncryptor(name string) (message.Encryptor, error) { if name == "aes" { + if conf.Config == nil || conf.Config.AesKey == nil { + return nil, fmt.Errorf("AES key is not defined") + } key, iv := getKeyIv() return aes.NewStreamEncrypter(key, iv) } return nil, fmt.Errorf("unsupported encryptor: %s", name) } -func GetEncryptWriter(name string, output io.Writer) (io.Writer, error) { - if name == "aes" { - key, iv := getKeyIv() - return aes.NewStreamWriter(key, iv, output) - } - return nil, fmt.Errorf("unsupported encryptor: %s", name) -} - func getKeyIv() ([]byte, []byte) { key := conf.Config.AesKey iv := make([]byte, 16) diff --git a/internal/topo/node/encrypt_op.go b/internal/topo/node/encrypt_op.go new file mode 100644 index 0000000000..87306f0e33 --- /dev/null +++ b/internal/topo/node/encrypt_op.go @@ -0,0 +1,66 @@ +// Copyright 2024 EMQ Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package node + +import ( + "fmt" + + "github.com/lf-edge/ekuiper/contract/v2/api" + "github.com/lf-edge/ekuiper/v2/internal/encryptor" + "github.com/lf-edge/ekuiper/v2/internal/pkg/def" + "github.com/lf-edge/ekuiper/v2/pkg/infra" + "github.com/lf-edge/ekuiper/v2/pkg/message" +) + +type EncryptNode struct { + *defaultSinkNode + tool message.Encryptor +} + +func NewEncryptOp(name string, rOpt *def.RuleOption, encryptMethod string) (*EncryptNode, error) { + dc, err := encryptor.GetEncryptor(encryptMethod) + if err != nil { + return nil, fmt.Errorf("get encryptor %s fail with error: %v", encryptMethod, err) + } + return &EncryptNode{ + defaultSinkNode: newDefaultSinkNode(name, rOpt), + tool: dc, + }, nil +} + +func (o *EncryptNode) Exec(ctx api.StreamContext, errCh chan<- error) { + o.prepareExec(ctx, errCh, "op") + go func() { + err := infra.SafeRun(func() error { + runWithOrder(ctx, o.defaultSinkNode, o.concurrency, o.Worker) + return nil + }) + if err != nil { + infra.DrainError(ctx, err, errCh) + } + }() +} + +func (o *EncryptNode) Worker(_ api.Logger, item any) []any { + o.statManager.ProcessTimeStart() + defer o.statManager.ProcessTimeEnd() + switch d := item.(type) { + case []byte: + r := o.tool.Encrypt(d) + return []any{r} + default: + return []any{fmt.Errorf("unsupported data received: %v", d)} + } +} diff --git a/internal/topo/node/encrypt_op_test.go b/internal/topo/node/encrypt_op_test.go new file mode 100644 index 0000000000..3ead6bb371 --- /dev/null +++ b/internal/topo/node/encrypt_op_test.go @@ -0,0 +1,78 @@ +// Copyright 2024 EMQ Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package node + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/lf-edge/ekuiper/v2/internal/conf" + "github.com/lf-edge/ekuiper/v2/internal/pkg/def" + mockContext "github.com/lf-edge/ekuiper/v2/pkg/mock/context" +) + +func TestNewEncryptOp(t *testing.T) { + _, err := NewEncryptOp("test", &def.RuleOption{}, "non") + assert.Error(t, err) + assert.Equal(t, "get encryptor non fail with error: unsupported encryptor: non", err.Error()) + _, err = NewEncryptOp("test", &def.RuleOption{}, "aes") + assert.Error(t, err) + assert.Equal(t, errors.New("get encryptor aes fail with error: AES key is not defined"), err) +} + +func TestEncryptOp_Exec(t *testing.T) { + conf.InitConf() + op, err := NewEncryptOp("test", &def.RuleOption{BufferLength: 10, SendError: true}, "aes") + assert.NoError(t, err) + op.tool = &MockEncryptor{} + out := make(chan any, 100) + err = op.AddOutput(out, "test") + assert.NoError(t, err) + ctx := mockContext.NewMockContext("test1", "compress_test") + errCh := make(chan error) + op.Exec(ctx, errCh) + + cases := []any{ + []byte("{\"a\":1,\"b\":2}"), + errors.New("go through error"), + "invalid", + } + expects := [][]any{ + {[]byte("mock encrypt")}, + {errors.New("go through error")}, + {errors.New("unsupported data received: invalid")}, + } + + for i, c := range cases { + op.input <- c + for _, e := range expects[i] { + r := <-out + switch tr := r.(type) { + case error: + assert.EqualError(t, e.(error), tr.Error()) + default: + assert.Equal(t, e, r) + } + } + } +} + +type MockEncryptor struct{} + +func (m *MockEncryptor) Encrypt(_ []byte) []byte { + return []byte("mock encrypt") +} diff --git a/internal/topo/node/props.go b/internal/topo/node/props.go index d3f96093fa..9c6cfe6d71 100644 --- a/internal/topo/node/props.go +++ b/internal/topo/node/props.go @@ -37,6 +37,7 @@ type SinkConf struct { BatchSize int `json:"batchSize"` LingerInterval int `json:"lingerInterval"` Compression string `json:"compression"` + Encryption string `json:"encryption"` conf.SinkConf } diff --git a/internal/topo/planner/planner_sink.go b/internal/topo/planner/planner_sink.go index 38898bdba7..029d07b32e 100644 --- a/internal/topo/planner/planner_sink.go +++ b/internal/topo/planner/planner_sink.go @@ -113,6 +113,16 @@ func splitSink(tp *topo.Topo, inputs []node.Emitter, s api.Sink, sinkName string tp.AddOperator(newInputs, compressOp) newInputs = []node.Emitter{compressOp} } + + if sc.Encryption != "" { + encryptOp, err := node.NewEncryptOp(fmt.Sprintf("%s_%d_encrypt", sinkName, index), options, sc.Encryption) + if err != nil { + return nil, err + } + index++ + tp.AddOperator(newInputs, encryptOp) + newInputs = []node.Emitter{encryptOp} + } } return newInputs, nil } diff --git a/internal/topo/planner/planner_sink_test.go b/internal/topo/planner/planner_sink_test.go index 8c53c86a82..472c6a1cc4 100644 --- a/internal/topo/planner/planner_sink_test.go +++ b/internal/topo/planner/planner_sink_test.go @@ -120,19 +120,55 @@ func TestSinkPlan(t *testing.T) { }, }, }, + { + name: "encrypt and compress sink plan", + rule: &def.Rule{ + Actions: []map[string]any{ + { + "log": map[string]any{ + "compression": "gzip", + "encryption": "aes", + }, + }, + }, + Options: defaultOption, + }, + topo: &def.PrintableTopo{ + Sources: []string{"source_src1"}, + Edges: map[string][]any{ + "source_src1": { + "op_log_0_0_transform", + }, + "op_log_0_0_transform": { + "op_log_0_1_encode", + }, + "op_log_0_1_encode": { + "op_log_0_2_compress", + }, + "op_log_0_2_compress": { + "op_log_0_3_encrypt", + }, + "op_log_0_3_encrypt": { + "sink_log_0", + }, + }, + }, + }, } for _, c := range tc { - tp, err := topo.NewWithNameAndOptions("test", c.rule.Options) - assert.NoError(t, err) - si, err := io.Source("memory") - assert.NoError(t, err) - n, err := node.NewSourceNode(tp.GetContext(), "src1", si, map[string]any{"datasource": "demo"}, &def.RuleOption{SendError: false}) - assert.NoError(t, err) - tp.AddSrc(n) - inputs := []node.Emitter{n} - err = buildActions(tp, c.rule, inputs, 1) - assert.NoError(t, err) - assert.Equal(t, c.topo, tp.GetTopo()) + t.Run(c.name, func(t *testing.T) { + tp, err := topo.NewWithNameAndOptions("test", c.rule.Options) + assert.NoError(t, err) + si, err := io.Source("memory") + assert.NoError(t, err) + n, err := node.NewSourceNode(tp.GetContext(), "src1", si, map[string]any{"datasource": "demo"}, &def.RuleOption{SendError: false}) + assert.NoError(t, err) + tp.AddSrc(n) + inputs := []node.Emitter{n} + err = buildActions(tp, c.rule, inputs, 1) + assert.NoError(t, err) + assert.Equal(t, c.topo, tp.GetTopo()) + }) } }